From 9ff7a84afe4c9242ac88d116cd0b5cd8b780080f Mon Sep 17 00:00:00 2001 From: Shtorm <108103062+shtorm-7@users.noreply.github.com> Date: Thu, 4 Jun 2026 10:03:20 +0300 Subject: [PATCH] Refactor TrustTunnel --- README.md | 2 +- option/trusttunnel.go | 16 ++--- protocol/trusttunnel/inbound.go | 118 ++++++++++++++++++++----------- protocol/trusttunnel/outbound.go | 28 ++------ transport/trusttunnel/client.go | 20 ++++-- transport/trusttunnel/icmp.go | 62 ---------------- transport/trusttunnel/quic.go | 68 +----------------- 7 files changed, 105 insertions(+), 209 deletions(-) delete mode 100644 transport/trusttunnel/icmp.go diff --git a/README.md b/README.md index f894cc50..c03633ca 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Sing-box with extended features. ## 🔥 Features -### Outbounds +### Protocols - **WARP** — Cloudflare WARP integration through WireGuard - **MASQUE** — Cloudflare MASQUE proxy over QUIC / HTTP-2 - **MTProxy** — Telegram MTProxy server with FakeTLS and domain fronting diff --git a/option/trusttunnel.go b/option/trusttunnel.go index 61342dd7..9a17d818 100644 --- a/option/trusttunnel.go +++ b/option/trusttunnel.go @@ -26,13 +26,13 @@ type TrustTunnelOutboundOptions struct { DialerOptions ServerOptions OutboundTLSOptionsContainer - Username string `json:"username,omitempty"` - Password string `json:"password,omitempty"` - Network NetworkList `json:"network,omitempty"` - HealthCheck bool `json:"health_check,omitempty"` - QUIC bool `json:"quic,omitempty"` - CongestionController string `json:"congestion_controller,omitempty"` - BBRProfile string `json:"bbr_profile,omitempty"` - CWND int `json:"cwnd,omitempty"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + Network NetworkList `json:"network,omitempty"` + HealthCheck bool `json:"health_check,omitempty"` + QUIC bool `json:"quic,omitempty"` + CongestionController string `json:"congestion_controller,omitempty"` + BBRProfile string `json:"bbr_profile,omitempty"` + CWND int `json:"cwnd,omitempty"` Multiplex *TrustTunnelMultiplexOptions `json:"multiplex,omitempty"` } diff --git a/protocol/trusttunnel/inbound.go b/protocol/trusttunnel/inbound.go index 3b955ed1..e6ad9b15 100644 --- a/protocol/trusttunnel/inbound.go +++ b/protocol/trusttunnel/inbound.go @@ -6,6 +6,8 @@ import ( "net" "net/http" + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/http3" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/common/listener" @@ -14,11 +16,13 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/trusttunnel" + "github.com/sagernet/sing-quic" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" aTLS "github.com/sagernet/sing/common/tls" "golang.org/x/net/http2" @@ -31,36 +35,34 @@ func RegisterInbound(registry *inbound.Registry) { type Inbound struct { inbound.Adapter - ctx context.Context - router adapter.Router - logger logger.ContextLogger - listener *listener.Listener - tlsConfig tls.ServerConfig - service *trusttunnel.Service - httpServer *http.Server - quicService *trusttunnel.QUICService - network []string + ctx context.Context + router adapter.Router + logger logger.ContextLogger + options option.TrustTunnelInboundOptions + listener *listener.Listener + service *trusttunnel.Service + httpServer *http.Server + http3Server *http3.Server + httpTLSConfig tls.ServerConfig + http3TLSConfig tls.ServerConfig + network []string } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrustTunnelInboundOptions) (adapter.Inbound, error) { if options.TLS == nil || !options.TLS.Enabled { return nil, C.ErrTLSRequired } - tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) - if err != nil { - return nil, err - } networkList := options.Network.Build() if len(networkList) == 0 { networkList = []string{N.NetworkTCP} } inbound := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeTrustTunnel, tag), - ctx: ctx, - router: router, - logger: logger, - tlsConfig: tlsConfig, - network: networkList, + Adapter: inbound.NewAdapter(C.TypeTrustTunnel, tag), + ctx: ctx, + router: router, + logger: logger, + options: options, + network: networkList, listener: listener.New(listener.Options{ Context: ctx, Logger: logger, @@ -78,9 +80,6 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo } service.UpdateUsers(userMap) inbound.service = service - if common.Contains(networkList, N.NetworkUDP) { - inbound.quicService = trusttunnel.NewQUICService(service, options.CongestionController, options.CWND, options.BBRProfile) - } return inbound, nil } @@ -88,14 +87,9 @@ func (h *Inbound) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - if h.tlsConfig != nil { - err := h.tlsConfig.Start() - if err != nil { - return err - } - } + var err error if common.Contains(h.network, N.NetworkTCP) { - tcpListener, err := h.listener.ListenTCP() + listener, err := h.listener.ListenTCP() if err != nil { return err } @@ -105,31 +99,70 @@ func (h *Inbound) Start(stage adapter.StartStage) error { return h.ctx }, } + h.httpTLSConfig, err = tls.NewServer(h.ctx, h.logger, common.PtrValueOrDefault(h.options.TLS)) + if err != nil { + return err + } + if len(h.httpTLSConfig.NextProtos()) == 0 { + h.httpTLSConfig.SetNextProtos([]string{http2.NextProtoTLS}) + } else if !common.Contains(h.httpTLSConfig.NextProtos(), http2.NextProtoTLS) { + h.httpTLSConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, h.httpTLSConfig.NextProtos()...)) + } + err = h.httpTLSConfig.Start() + if err != nil { + return err + } + listener = aTLS.NewListener(listener, h.httpTLSConfig) go func() { - var l net.Listener = tcpListener - if h.tlsConfig != nil { - if len(h.tlsConfig.NextProtos()) == 0 { - h.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS}) - } else if !common.Contains(h.tlsConfig.NextProtos(), http2.NextProtoTLS) { - h.tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, h.tlsConfig.NextProtos()...)) - } - l = aTLS.NewListener(tcpListener, h.tlsConfig) - } - sErr := h.httpServer.Serve(l) + sErr := h.httpServer.Serve(listener) if sErr != nil && !errors.Is(sErr, http.ErrServerClosed) { h.logger.Error("HTTP server error: ", sErr) } }() } if common.Contains(h.network, N.NetworkUDP) { + h.http3TLSConfig, err = tls.NewServer(h.ctx, h.logger, common.PtrValueOrDefault(h.options.TLS)) + if err != nil { + return err + } + if err := qtls.ConfigureHTTP3(h.http3TLSConfig); err != nil { + return err + } + err = h.http3TLSConfig.Start() + if err != nil { + return err + } udpConn, err := h.listener.ListenUDP() if err != nil { return err } - err = h.quicService.Start(h.ctx, udpConn, h.tlsConfig) + congestionControlFactory, err := trusttunnel.NewCongestionControl( + h.options.CongestionController, + h.options.CWND, + h.options.BBRProfile, + ntp.TimeFuncFromContext(h.ctx), + ) if err != nil { return err } + h.http3Server = &http3.Server{ + Handler: h.service, + ConnContext: func(ctx context.Context, conn *quic.Conn) context.Context { + conn.SetCongestionControl(congestionControlFactory(conn)) + return ctx + }, + } + quicListener, err := qtls.ListenEarly(udpConn, h.http3TLSConfig, &quic.Config{ + MaxIdleTimeout: trusttunnel.DefaultSessionTimeout * 2, + MaxIncomingStreams: 1 << 60, + Allow0RTT: true, + }) + if err != nil { + return err + } + go func() { + _ = h.http3Server.ServeListener(quicListener) + }() } return nil } @@ -138,8 +171,9 @@ func (h *Inbound) Close() error { return common.Close( h.listener, common.PtrOrNil(h.httpServer), - common.PtrOrNil(h.quicService), - h.tlsConfig, + common.PtrOrNil(h.http3Server), + h.httpTLSConfig, + h.http3TLSConfig, ) } diff --git a/protocol/trusttunnel/outbound.go b/protocol/trusttunnel/outbound.go index b75f7068..5b95500c 100644 --- a/protocol/trusttunnel/outbound.go +++ b/protocol/trusttunnel/outbound.go @@ -18,8 +18,6 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - - "golang.org/x/net/http2" ) func RegisterOutbound(registry *outbound.Registry) { @@ -42,7 +40,13 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL } serverAddr := options.ServerOptions.Build() networkList := options.Network.Build() + tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } clientOpts := trusttunnel.ClientOptions{ + Dialer: outboundDialer, + TLSConfig: tlsConfig, Server: serverAddr, Username: options.Username, Password: options.Password, @@ -52,26 +56,6 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL BBRProfile: options.BBRProfile, HealthCheck: options.HealthCheck, } - if options.QUIC { - tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS)) - if err != nil { - return nil, err - } - if len(tlsConfig.NextProtos()) == 0 { - tlsConfig.SetNextProtos([]string{"h3"}) - } - clientOpts.QUICDialer = outboundDialer - clientOpts.QUICTLSConfig = tlsConfig - } else { - tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS)) - if err != nil { - return nil, err - } - if len(tlsConfig.NextProtos()) == 0 { - tlsConfig.SetNextProtos([]string{http2.NextProtoTLS}) - } - clientOpts.TLSDialer = tls.NewDialer(outboundDialer, tlsConfig) - } var client trusttunnel.Dialer if options.Multiplex != nil && options.Multiplex.Enabled { clientOpts.MaxConnections = options.Multiplex.MaxConnections diff --git a/transport/trusttunnel/client.go b/transport/trusttunnel/client.go index b1dcc6a4..66dedb02 100644 --- a/transport/trusttunnel/client.go +++ b/transport/trusttunnel/client.go @@ -16,9 +16,9 @@ import ( "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/ntp" "github.com/sagernet/quic-go" @@ -42,9 +42,8 @@ type Dialer interface { } type ClientOptions struct { - TLSDialer tls.Dialer - QUICDialer N.Dialer - QUICTLSConfig tls.Config + Dialer N.Dialer + TLSConfig tls.Config Server M.Socksaddr Username string Password string @@ -87,17 +86,20 @@ func NewClient(ctx context.Context, options ClientOptions) (*Client, error) { cancel() return nil, err } + if len(options.TLSConfig.NextProtos()) == 0 { + options.TLSConfig.SetNextProtos([]string{"h3"}) + } client.roundTripper = &http3.Transport{ QUICConfig: &quic.Config{ MaxIdleTimeout: DefaultSessionTimeout * 2, KeepAlivePeriod: DefaultHealthCheckTimeout, }, Dial: func(ctx context.Context, addr string, tlsCfg *stdtls.Config, cfg *quic.Config) (*quic.Conn, error) { - udpConn, err := options.QUICDialer.DialContext(ctx, N.NetworkUDP, client.server) + udpConn, err := options.Dialer.DialContext(ctx, N.NetworkUDP, client.server) if err != nil { return nil, err } - conn, err := qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), options.QUICTLSConfig, cfg) + conn, err := qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), options.TLSConfig, cfg) if err != nil { return nil, err } @@ -106,9 +108,13 @@ func NewClient(ctx context.Context, options ClientOptions) (*Client, error) { }, } } else { + if len(options.TLSConfig.NextProtos()) == 0 { + options.TLSConfig.SetNextProtos([]string{http2.NextProtoTLS}) + } + tlsDialer := tls.NewDialer(options.Dialer, options.TLSConfig) client.roundTripper = &http2.Transport{ DialTLSContext: func(ctx context.Context, network, addr string, _ *stdtls.Config) (net.Conn, error) { - return options.TLSDialer.DialContext(ctx, network, client.server) + return tlsDialer.DialContext(ctx, network, client.server) }, AllowHTTP: true, } diff --git a/transport/trusttunnel/icmp.go b/transport/trusttunnel/icmp.go deleted file mode 100644 index dd39a838..00000000 --- a/transport/trusttunnel/icmp.go +++ /dev/null @@ -1,62 +0,0 @@ -package trusttunnel - -import ( - "encoding/binary" - "net/netip" - - "github.com/sagernet/sing/common/buf" -) - -type IcmpConn struct { - httpConn -} - -func (i *IcmpConn) WritePing(id uint16, destination netip.Addr, sequenceNumber uint16, ttl uint8, size uint16) error { - request := buf.NewSize(2 + 16 + 2 + 1 + 2) - defer request.Release() - must(binary.Write(request, binary.BigEndian, id)) - destinationAddress := buildPaddingIP(destination) - must1(request.Write(destinationAddress[:])) - must(binary.Write(request, binary.BigEndian, sequenceNumber)) - must(binary.Write(request, binary.BigEndian, ttl)) - must(binary.Write(request, binary.BigEndian, size)) - _, err := i.writeFlush(request.Bytes()) - return err -} - -func (i *IcmpConn) ReadPing() (id uint16, sourceAddress netip.Addr, icmpType uint8, code uint8, sequenceNumber uint16, err error) { - err = i.waitCreated() - if err != nil { - return - } - response := buf.NewSize(2 + 16 + 1 + 1 + 2) - defer response.Release() - _, err = response.ReadFullFrom(i.body, response.FreeLen()) - if err != nil { - return - } - must(binary.Read(response, binary.BigEndian, &id)) - var sourceAddressBuffer [16]byte - must1(response.Read(sourceAddressBuffer[:])) - sourceAddress = parse16BytesIP(sourceAddressBuffer) - must(binary.Read(response, binary.BigEndian, &icmpType)) - must(binary.Read(response, binary.BigEndian, &code)) - must(binary.Read(response, binary.BigEndian, &sequenceNumber)) - return -} - -func (i *IcmpConn) Close() error { - return i.httpConn.Close() -} - -func must(err error) { - if err != nil { - panic(err) - } -} - -func must1[T any](_ T, err error) { - if err != nil { - panic(err) - } -} diff --git a/transport/trusttunnel/quic.go b/transport/trusttunnel/quic.go index 5d0f85d1..90c9d6d6 100644 --- a/transport/trusttunnel/quic.go +++ b/transport/trusttunnel/quic.go @@ -1,22 +1,15 @@ package trusttunnel import ( - "context" - "errors" - "net" "time" "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/congestion" - "github.com/sagernet/quic-go/http3" - "github.com/sagernet/sing-box/common/tls" - E "github.com/sagernet/sing/common/exceptions" - qtls "github.com/sagernet/sing-quic" "github.com/sagernet/sing-quic/congestion_bbr1" "github.com/sagernet/sing-quic/congestion_bbr2" congestion_meta1 "github.com/sagernet/sing-quic/congestion_meta1" congestion_meta2 "github.com/sagernet/sing-quic/congestion_meta2" - "github.com/sagernet/sing/common/ntp" + E "github.com/sagernet/sing/common/exceptions" ) func NewCongestionControl(name string, cwnd int, bbrProfile string, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) { @@ -82,62 +75,3 @@ func NewCongestionControl(name string, cwnd int, bbrProfile string, timeFunc fun return nil, E.New("unknown congestion control: ", name) } } - -type QUICService struct { - service *Service - h3Server *http3.Server - udpConn net.PacketConn - congestionControl string - cwnd int - bbrProfile string -} - -func NewQUICService(service *Service, congestionControl string, cwnd int, bbrProfile string) *QUICService { - return &QUICService{ - service: service, - congestionControl: congestionControl, - cwnd: cwnd, - bbrProfile: bbrProfile, - } -} - -func (s *QUICService) Start(ctx context.Context, udpConn net.PacketConn, tlsConfig tls.ServerConfig) error { - s.udpConn = udpConn - congestionControlFactory, err := NewCongestionControl(s.congestionControl, s.cwnd, s.bbrProfile, ntp.TimeFuncFromContext(ctx)) - if err != nil { - return err - } - s.h3Server = &http3.Server{ - Handler: s.service, - ConnContext: func(ctx context.Context, conn *quic.Conn) context.Context { - conn.SetCongestionControl(congestionControlFactory(conn)) - return ctx - }, - } - if err := qtls.ConfigureHTTP3(tlsConfig); err != nil { - return err - } - quicListener, err := qtls.ListenEarly(udpConn, tlsConfig, &quic.Config{ - MaxIdleTimeout: DefaultSessionTimeout * 2, - MaxIncomingStreams: 1 << 60, - Allow0RTT: true, - }) - if err != nil { - return err - } - go func() { - _ = s.h3Server.ServeListener(quicListener) - }() - return nil -} - -func (s *QUICService) Close() error { - var errs []error - if s.h3Server != nil { - errs = append(errs, s.h3Server.Close()) - } - if s.udpConn != nil { - errs = append(errs, s.udpConn.Close()) - } - return errors.Join(errs...) -}