Files
sing-box-extended/transport/trusttunnel/client.go

324 lines
8.2 KiB
Go

package trusttunnel
import (
"context"
stdtls "crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
qtls "github.com/sagernet/sing-quic"
"golang.org/x/net/http2"
)
// Application identity advertised by TCP tunnels.
var (
	appName    = "sing-box"
	appVersion = C.Version
)

// User-Agent header values sent with CONNECT requests. TCP streams carry the
// application name and version; UDP and ICMP sessions carry their magic
// address instead. NOTE(review): icmpUserAgent is not referenced in this
// file — confirm it is used elsewhere in the package.
var (
	tcpUserAgent  = runtime.GOOS + " " + appName + "/" + appVersion
	udpUserAgent  = runtime.GOOS + " " + UDPMagicAddress
	icmpUserAgent = runtime.GOOS + " " + ICMPMagicAddress
)
// Dialer is the method set shared by Client and MultiplexClient: it opens
// streams and packet sessions tunnelled through a server and can be closed.
type Dialer interface {
	// Dial opens a stream to host through the tunnel server.
	Dial(ctx context.Context, host string) (net.Conn, error)
	// ListenPacket opens a packet session through the tunnel server.
	ListenPacket(ctx context.Context) (net.PacketConn, error)
	// Close releases the resources held by the dialer.
	Close() error
}
// ClientOptions configures NewClient and NewMultiplexClient.
type ClientOptions struct {
	// TLSDialer dials the TLS connection used by the HTTP/2 transport
	// (when QUIC is false).
	TLSDialer tls.Dialer
	// QUICDialer dials the UDP socket used by the HTTP/3 transport
	// (when QUIC is true).
	QUICDialer N.Dialer
	// QUICTLSConfig is the TLS configuration passed to the QUIC handshake.
	QUICTLSConfig tls.Config
	// Server is the tunnel server; every transport connection is dialed to
	// it, and its string form is used as the CONNECT request URL host.
	Server M.Socksaddr
	// Username and Password are combined by buildAuth into the value of the
	// Proxy-Authorization header.
	Username string
	Password string
	// QUIC selects HTTP/3 over QUIC instead of HTTP/2 over TLS.
	QUIC bool
	// CongestionControl, CWND and BBRProfile are passed to
	// NewCongestionControl; they only matter when QUIC is true.
	CongestionControl string
	CWND              int
	BBRProfile        string
	// HealthCheck enables the background loop that periodically sends a
	// CONNECT to HealthCheckMagicAddress.
	HealthCheck bool
	// MaxConnections, MinStreams and MaxStreams size the MultiplexClient
	// pool. When all three are zero, NewMultiplexClient uses
	// MaxConnections=8 and MinStreams=5. NewClient ignores them.
	MaxConnections int
	MinStreams     int
	MaxStreams     int
}
// Client tunnels streams and packet sessions to a single server over one
// HTTP/2 or HTTP/3 round tripper, using one CONNECT request per stream.
type Client struct {
	// ctx is derived in NewClient and is the parent of every per-stream
	// context; cancel is invoked by Close.
	ctx    context.Context
	cancel context.CancelFunc
	// server is the dial target; serverString is its cached String() form,
	// used as the CONNECT URL host.
	server       M.Socksaddr
	serverString string
	// auth is the Proxy-Authorization header value.
	auth string
	// roundTripper is an *http3.Transport or an *http2.Transport.
	roundTripper http.RoundTripper
	// startOnce lazily starts the health-check loop on the first roundTrip.
	startOnce   sync.Once
	healthCheck bool
	// healthCheckTimer stays nil unless health checking was started.
	healthCheckTimer *time.Timer
	// count is the number of open streams. roundTrip increments it, each
	// stream's closeFn decrements it once, and MultiplexClient reads it to
	// balance load.
	count atomic.Int64
}
// NewClient builds a Client that tunnels through options.Server using either
// HTTP/3 over QUIC (options.QUIC) or HTTP/2 over TLS.
//
// The Client owns a context derived from ctx, which Close cancels. Transport
// connections are dialed lazily by the round tripper on first use. An error
// is returned only if the QUIC congestion-control settings are invalid.
func NewClient(ctx context.Context, options ClientOptions) (*Client, error) {
	ctx, cancel := context.WithCancel(ctx)
	client := &Client{
		ctx:          ctx,
		cancel:       cancel,
		server:       options.Server,
		serverString: options.Server.String(),
		auth:         buildAuth(options.Username, options.Password),
		healthCheck:  options.HealthCheck,
	}
	if options.QUIC {
		congestionControlFactory, err := NewCongestionControl(options.CongestionControl, options.CWND, options.BBRProfile, ntp.TimeFuncFromContext(ctx))
		if err != nil {
			cancel()
			return nil, err
		}
		client.roundTripper = &http3.Transport{
			QUICConfig: &quic.Config{
				MaxIdleTimeout:  DefaultSessionTimeout * 2,
				KeepAlivePeriod: DefaultHealthCheckTimeout,
			},
			// addr and tlsCfg from http3 are ignored: every connection goes to
			// the configured server with the configured QUIC TLS settings.
			Dial: func(ctx context.Context, addr string, tlsCfg *stdtls.Config, cfg *quic.Config) (*quic.Conn, error) {
				udpConn, err := options.QUICDialer.DialContext(ctx, N.NetworkUDP, client.server)
				if err != nil {
					return nil, err
				}
				conn, err := qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), options.QUICTLSConfig, cfg)
				if err != nil {
					// The packet conn is ours: quic-go only closes sockets it
					// created itself, so release it on handshake failure.
					_ = udpConn.Close()
					return nil, err
				}
				conn.SetCongestionControl(congestionControlFactory(conn))
				// Likewise, release the socket once the QUIC connection ends,
				// so that connections the transport drops do not leak sockets.
				go func() {
					<-conn.Context().Done()
					_ = udpConn.Close()
				}()
				return conn, nil
			},
		}
	} else {
		client.roundTripper = &http2.Transport{
			// Always dial the configured server over TLS, regardless of addr.
			DialTLSContext: func(ctx context.Context, network, addr string, _ *stdtls.Config) (net.Conn, error) {
				return options.TLSDialer.DialContext(ctx, network, client.server)
			},
			AllowHTTP: true,
		}
	}
	return client, nil
}
// start arms the health-check timer and launches the loop that services it.
// It does nothing when health checking is disabled. roundTrip calls it at
// most once, through startOnce.
func (c *Client) start() {
	if !c.healthCheck {
		return
	}
	c.healthCheckTimer = time.NewTimer(DefaultHealthCheckTimeout)
	go c.loopHealthCheck()
}
// loopHealthCheck runs HealthCheck each time healthCheckTimer fires, until
// the client context is cancelled, at which point it stops the timer and
// exits. Each check is bounded by DefaultHealthCheckTimeout. The result is
// discarded; HealthCheck re-arms the timer itself, so the next tick retries.
func (c *Client) loopHealthCheck() {
	for {
		select {
		case <-c.ctx.Done():
			c.healthCheckTimer.Stop()
			return
		case <-c.healthCheckTimer.C:
		}
		checkCtx, stopCheck := context.WithTimeout(c.ctx, DefaultHealthCheckTimeout)
		_ = c.HealthCheck(checkCtx) // failure is tolerated here
		stopCheck()
	}
}
// resetHealthCheckTimer postpones the next scheduled health check by
// DefaultHealthCheckTimeout. It is a no-op when the health-check loop was
// never started.
func (c *Client) resetHealthCheckTimer() {
	if timer := c.healthCheckTimer; timer != nil {
		timer.Reset(DefaultHealthCheckTimeout)
	}
}
// roundTrip starts an asynchronous CONNECT exchange for request and wires
// the result into conn.
//
// Before returning, it (re)initialises *conn with:
//   - the write end of a pipe, whose read end becomes request.Body;
//   - a fresh created channel;
//   - a closeFn that decrements c.count exactly once;
//   - a cancelFn for the per-stream context.
//
// The round trip itself runs in a goroutine. Its outcome is handed to
// conn.setup: (response.Body, nil) on 200 OK, (nil, err) otherwise.
// NOTE(review): how setup signals completion (presumably through created)
// is defined elsewhere — confirm.
func (c *Client) roundTrip(request *http.Request, conn *httpConn) {
	// Lazily start the health-check loop on first use.
	c.startOnce.Do(c.start)
	pipeReader, pipeWriter := io.Pipe()
	// Bytes written to conn.writer are streamed to the server as the
	// request body.
	request.Body = pipeReader
	*conn = httpConn{writer: pipeWriter, created: make(chan struct{})}
	// count feeds MultiplexClient's load balancing. sync.OnceFunc ensures
	// the matching decrement happens only once.
	c.count.Add(1)
	conn.closeFn = sync.OnceFunc(func() { c.count.Add(-1) })
	// The stream context is a child of the client context, so Client.Close
	// aborts it; conn.cancelFn lets the connection end it individually.
	// NOTE(review): the error branches below do not call cancel, so the
	// context lives until cancelFn or Client.Close runs — confirm that
	// httpConn always calls cancelFn.
	ctx, cancel := context.WithCancel(c.ctx)
	conn.cancelFn = cancel
	go func() {
		// C.TCPTimeout bounds only the CONNECT handshake: the timer is stopped
		// when this goroutine finishes handling the RoundTrip result.
		timeout := time.AfterFunc(C.TCPTimeout, cancel)
		defer timeout.Stop()
		request = request.WithContext(ctx)
		response, err := c.roundTripper.RoundTrip(request)
		if err != nil {
			// Fail both pipe ends so that pending writes and reads see err.
			_ = pipeWriter.CloseWithError(err)
			_ = pipeReader.CloseWithError(err)
			conn.setup(nil, err)
		} else if response.StatusCode != http.StatusOK {
			_ = response.Body.Close()
			err = fmt.Errorf("unexpected status code: %d", response.StatusCode)
			_ = pipeWriter.CloseWithError(err)
			_ = pipeReader.CloseWithError(err)
			conn.setup(nil, err)
		} else {
			// A successful stream proves the server is reachable, so postpone
			// the next health check.
			c.resetHealthCheckTimer()
			conn.setup(response.Body, nil)
		}
	}()
}
// newConnectRequest builds a CONNECT request addressed to the tunnel server
// that asks it to open a stream to host. The request carries the given
// User-Agent and the client's Proxy-Authorization value. Body is left nil
// for the caller to attach.
func (c *Client) newConnectRequest(host, userAgent string) *http.Request {
	header := make(http.Header, 2)
	header["User-Agent"] = []string{userAgent}
	header["Proxy-Authorization"] = []string{c.auth}
	target := &url.URL{
		Scheme: "https",
		Host:   c.serverString,
	}
	return &http.Request{
		Method: http.MethodConnect,
		URL:    target,
		Header: header,
		Host:   host,
	}
}
// Dial opens a TCP tunnel to host. The CONNECT exchange runs asynchronously,
// so Dial returns immediately with a nil error; a failed exchange is
// reported to the returned conn through its setup call. ctx is not used.
func (c *Client) Dial(ctx context.Context, host string) (net.Conn, error) {
	stream := new(tcpConn)
	request := c.newConnectRequest(host, tcpUserAgent)
	c.roundTrip(request, &stream.httpConn)
	return stream, nil
}
// ListenPacket opens a UDP session by sending a CONNECT to UDPMagicAddress.
// Like Dial, it returns immediately with a nil error; the exchange completes
// in the background. ctx is not used.
func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
	session := new(clientPacketConn)
	request := c.newConnectRequest(UDPMagicAddress, udpUserAgent)
	c.roundTrip(request, &session.httpConn)
	return session, nil
}
// Close cancels the client context, which aborts every stream context
// derived from it. It then shuts down the round tripper and stops the
// health-check timer. Transport close errors are ignored, so Close always
// returns nil.
func (c *Client) Close() error {
	c.cancel()
	rt := c.roundTripper
	if closer, isCloser := rt.(io.Closer); isCloser {
		_ = closer.Close() // best effort
	}
	if h2, isH2 := rt.(*http2.Transport); isH2 {
		h2.CloseIdleConnections()
	}
	if timer := c.healthCheckTimer; timer != nil {
		timer.Stop()
	}
	return nil
}
// HealthCheck sends a CONNECT to HealthCheckMagicAddress and returns nil only
// if the server answers 200 OK. Whatever the outcome, it re-arms the
// health-check timer on return, if the timer is running.
func (c *Client) HealthCheck(ctx context.Context) error {
	defer c.resetHealthCheckTimer()
	request := c.newConnectRequest(HealthCheckMagicAddress, runtime.GOOS).WithContext(ctx)
	response, err := c.roundTripper.RoundTrip(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()
	if code := response.StatusCode; code != http.StatusOK {
		return fmt.Errorf("unexpected status code: %d", code)
	}
	return nil
}
// MultiplexClient spreads streams over a pool of Clients that share the same
// options, growing the pool according to the limits below (see getClient).
type MultiplexClient struct {
	// mutex guards clients.
	mutex sync.Mutex
	// Pool-sizing limits, copied from ClientOptions (with defaults applied)
	// by NewMultiplexClient.
	maxConnections int
	minStreams     int
	maxStreams     int
	// ctx and options are used to create every pooled Client.
	ctx     context.Context
	options ClientOptions
	// clients is the pool. Entries are only ever appended; Close resets it
	// to nil.
	clients []*Client
}
// NewMultiplexClient creates a pool seeded with one Client built from
// options. When MaxConnections, MinStreams and MaxStreams are all zero, the
// pool allows up to 8 connections and opens a new one only when the
// least-loaded connection already carries at least 5 streams.
func NewMultiplexClient(ctx context.Context, options ClientOptions) (*MultiplexClient, error) {
	maxConnections, minStreams, maxStreams := options.MaxConnections, options.MinStreams, options.MaxStreams
	if maxConnections == 0 && minStreams == 0 && maxStreams == 0 {
		// Nothing configured: fall back to the built-in pool shape.
		maxConnections, minStreams = 8, 5
	}
	first, err := NewClient(ctx, options)
	if err != nil {
		return nil, err
	}
	pool := &MultiplexClient{
		maxConnections: maxConnections,
		minStreams:     minStreams,
		maxStreams:     maxStreams,
		ctx:            ctx,
		options:        options,
		clients:        []*Client{first},
	}
	return pool, nil
}
// Dial opens a stream to host on the Client that getClient selects from the
// pool.
func (c *MultiplexClient) Dial(ctx context.Context, host string) (net.Conn, error) {
	chosen, err := c.getClient()
	if err != nil {
		return nil, err
	}
	return chosen.Dial(ctx, host)
}
// ListenPacket opens a packet session on the Client that getClient selects
// from the pool.
func (c *MultiplexClient) ListenPacket(ctx context.Context) (net.PacketConn, error) {
	chosen, err := c.getClient()
	if err != nil {
		return nil, err
	}
	return chosen.ListenPacket(ctx)
}
// Close closes every pooled Client, empties the pool and returns their
// errors joined together (nil if there were none).
// NOTE(review): the pool is only emptied, not marked closed, so a later
// Dial or ListenPacket will create a fresh Client through getClient.
func (c *MultiplexClient) Close() error {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	failures := make([]error, 0, len(c.clients))
	for i := range c.clients {
		if err := c.clients[i].Close(); err != nil {
			failures = append(failures, err)
		}
	}
	c.clients = nil
	return errors.Join(failures...)
}
// getClient returns the pooled Client with the fewest open streams, or grows
// the pool when that Client is considered too busy:
//   - an empty pool always gets a new Client;
//   - an idle Client (0 streams) is always reused;
//   - with maxConnections set, the Client is reused when the pool is full or
//     it carries fewer than minStreams streams;
//   - otherwise, with maxStreams set, it is reused while it carries fewer
//     than maxStreams streams;
//   - in every other case a new Client is added.
//
// NOTE(review): when neither maxConnections nor maxStreams is set, the pool
// grows without bound whenever every Client is busy.
func (c *MultiplexClient) getClient() (*Client, error) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	var least *Client
	for _, candidate := range c.clients {
		if least == nil || candidate.count.Load() < least.count.Load() {
			least = candidate
		}
	}
	if least == nil {
		return c.newClientLocked()
	}
	active := int(least.count.Load())
	switch {
	case active == 0:
		return least, nil
	case c.maxConnections > 0:
		if len(c.clients) >= c.maxConnections || active < c.minStreams {
			return least, nil
		}
	case c.maxStreams > 0 && active < c.maxStreams:
		return least, nil
	}
	return c.newClientLocked()
}
// newClientLocked creates a Client from the pool's ctx and options and
// appends it to the pool. The caller must hold c.mutex.
func (c *MultiplexClient) newClientLocked() (*Client, error) {
	fresh, err := NewClient(c.ctx, c.options)
	if err != nil {
		return nil, err
	}
	c.clients = append(c.clients, fresh)
	return fresh, nil
}