Add custom tls client support for v2ray h2/grpclite transports

This commit is contained in:
世界
2022-09-11 10:22:52 +08:00
parent 7e09beb0c3
commit a2d1f89922
14 changed files with 211 additions and 129 deletions

View File

@@ -16,6 +16,8 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/http2"
)
var _ adapter.V2RayClientTransport = (*Client)(nil)
@@ -24,7 +26,7 @@ type Client struct {
ctx context.Context
dialer N.Dialer
serverAddr M.Socksaddr
client *http.Client
transport http.RoundTripper
http2 bool
url *url.URL
host []string
@@ -33,6 +35,25 @@ type Client struct {
}
func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayHTTPOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
var transport http.RoundTripper
if tlsConfig == nil {
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
} else {
tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
transport = &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
if err != nil {
return nil, err
}
return tls.ClientHandshake(ctx, conn, tlsConfig)
},
}
}
client := &Client{
ctx: ctx,
dialer: dialer,
@@ -40,27 +61,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
host: options.Host,
method: options.Method,
headers: make(http.Header),
client: &http.Client{},
transport: transport,
http2: tlsConfig != nil,
}
if client.http2 {
client.client.Transport = &http.Transport{
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
if err != nil {
return nil, err
}
return tls.ClientHandshake(ctx, conn, tlsConfig)
},
ForceAttemptHTTP2: true,
}
} else {
client.client.Transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
}
if client.method == "" {
client.method = "PUT"
}
@@ -145,17 +148,16 @@ func (c *Client) dialHTTP2(ctx context.Context) (net.Conn, error) {
}
// Disable any compression method from server.
request.Header.Set("Accept-Encoding", "identity")
response, err := c.client.Do(request) // nolint: bodyclose
if err != nil {
pipeInWriter.Close()
return nil, err
}
if response.StatusCode != 200 {
pipeInWriter.Close()
return nil, E.New("unexpected status: ", response.StatusCode, " ", response.Status)
}
return &HTTPConn{
response.Body,
pipeInWriter,
}, nil
conn := newLateHTTPConn(pipeInWriter)
go func() {
response, err := c.transport.RoundTrip(request)
if err != nil {
conn.setup(nil, err)
} else if response.StatusCode != 200 {
conn.setup(nil, E.New("unexpected status: ", response.StatusCode, " ", response.Status))
} else {
conn.setup(response.Body, nil)
}
}()
return conn, nil
}

View File

@@ -14,9 +14,37 @@ import (
type HTTPConn struct {
reader io.Reader
writer io.Writer
create chan struct{}
err error
}
func newHTTPConn(reader io.Reader, writer io.Writer) HTTPConn {
return HTTPConn{
reader: reader,
writer: writer,
}
}
func newLateHTTPConn(writer io.Writer) *HTTPConn {
return &HTTPConn{
create: make(chan struct{}),
writer: writer,
}
}
func (c *HTTPConn) setup(reader io.Reader, err error) {
c.reader = reader
c.err = err
close(c.create)
}
func (c *HTTPConn) Read(b []byte) (n int, err error) {
if c.reader == nil {
<-c.create
if c.err != nil {
return 0, c.err
}
}
n, err = c.reader.Read(b)
return n, baderror.WrapH2(err)
}

View File

@@ -16,6 +16,8 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
sHttp "github.com/sagernet/sing/protocol/http"
"golang.org/x/net/http2"
)
var _ adapter.V2RayServerTransport = (*Server)(nil)
@@ -110,10 +112,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
s.handler.NewConnection(request.Context(), conn, metadata)
} else {
conn := &ServerHTTPConn{
HTTPConn{
request.Body,
writer,
},
newHTTPConn(request.Body, writer),
writer.(http.Flusher),
}
s.handler.NewConnection(request.Context(), conn, metadata)
@@ -128,6 +127,10 @@ func (s *Server) Serve(listener net.Listener) error {
if s.httpServer.TLSConfig == nil {
return s.httpServer.Serve(listener)
} else {
err := http2.ConfigureServer(s.httpServer, &http2.Server{})
if err != nil {
return err
}
return s.httpServer.ServeTLS(listener, "", "")
}
}