From edf38d33d64e5bb548e50307f4c4ba0e646c6d0b Mon Sep 17 00:00:00 2001 From: Shtorm <108103062+shtorm-7@users.noreply.github.com> Date: Fri, 26 Jun 2026 01:25:57 +0300 Subject: [PATCH] Add Snell protocol. Refactor MASQUE HTTP/2, Fair Queue. Update XHTTP, OpenVPN, Sudoku, Fallback. Fixes --- .goreleaser.yaml | 10 + adapter/platform.go | 1 + cmd/internal/build_libbox/main.go | 2 +- .../congestion/congestion.go | 4 +- common/dialer/default.go | 17 +- common/list/list.go | 164 +++++++ common/utils.go | 5 +- common/xray/json/badoption/range.go | 83 ---- constant/proxy.go | 6 + examples/masque/client.json | 3 + examples/profiler/config.json | 13 + examples/snell/client.json | 46 ++ examples/snell/server.json | 39 ++ examples/trusttunnel/client.json | 7 +- examples/trusttunnel/server.json | 1 - examples/xhttp/client.json | 4 + examples/xhttp/server.json | 2 + experimental/libbox/config.go | 4 + experimental/libbox/platform.go | 1 + experimental/libbox/service.go | 4 + go.mod | 7 +- go.sum | 14 +- include/registry.go | 2 + include/snell.go | 17 + include/snell_stub.go | 27 + option/group.go | 3 +- option/limiter.go | 5 + option/masque.go | 2 + option/openvpn.go | 1 + option/snell.go | 24 + option/trusttunnel.go | 2 - option/v2ray_transport.go | 65 +-- option/wireguard.go | 9 +- protocol/bond/inbound.go | 1 + protocol/group/fallback.go | 107 +++- protocol/limiter/bandwidth/limiter.go | 161 +++--- protocol/limiter/bandwidth/strategy.go | 2 +- protocol/masque/outbound.go | 12 + protocol/openvpn/outbound.go | 1 + protocol/snell/inbound.go | 130 +++++ protocol/snell/outbound.go | 114 +++++ protocol/trusttunnel/inbound.go | 4 +- protocol/trusttunnel/outbound.go | 2 +- release/DEFAULT_BUILD_TAGS | 2 +- release/DEFAULT_BUILD_TAGS_OTHERS | 2 +- release/DEFAULT_BUILD_TAGS_WINDOWS | 2 +- release/config/openwrt.prerm | 2 +- route/route.go | 14 + route/route_start_test.go | 146 ++++++ route/router.go | 5 +- route/rule/rule_action.go | 8 +- test/go.mod | 14 +- test/go.sum | 14 +- test/v2ray_grpc_test.go | 16 + transport/masque/client_h2.go | 331 +++++++++++++ transport/masque/masque.go | 127 ++--- transport/masque/options.go | 3 + transport/masque/tunnel.go | 52 +- transport/openvpn/cipher.go | 34 +- transport/openvpn/client.go | 38 +- transport/openvpn/config.go | 3 +- transport/openvpn/control.go | 207 +++++--- transport/openvpn/data.go | 70 ++- transport/openvpn/e2e_test.go | 444 +++++++++++++++++ transport/openvpn/keymethod.go | 2 +- transport/openvpn/lzo.go | 48 ++ transport/openvpn/push.go | 56 ++- transport/openvpn/tlsauth.go | 4 + transport/openvpn/tunnel.go | 64 ++- transport/simple-obfs/http_server.go | 100 ++++ transport/simple-obfs/tls_server.go | 154 ++++++ transport/snell/address.go | 144 ++++++ transport/snell/cipher.go | 56 +++ transport/snell/client.go | 120 +++++ transport/snell/pool.go | 153 ++++++ transport/snell/service.go | 294 +++++++++++ transport/snell/shadowaead.go | 211 ++++++++ transport/snell/snell.go | 408 +++++++++++++++ transport/snell/v4.go | 463 ++++++++++++++++++ transport/sudoku/multiplex/session.go | 29 +- .../multiplex/session_backpressure_test.go | 91 ++++ .../sudoku/obfs/sudoku/ascii_mode_test.go | 56 +++ transport/sudoku/obfs/sudoku/conn.go | 103 +++- .../sudoku/obfs/sudoku/conn_roundtrip_test.go | 51 ++ transport/sudoku/obfs/sudoku/encode.go | 44 +- transport/sudoku/obfs/sudoku/packed.go | 150 ++++-- .../sudoku/obfs/sudoku/packed_prefix_test.go | 90 ++++ transport/sudoku/obfs/sudoku/padding_prob.go | 4 +- transport/sudoku/obfs/sudoku/pending.go | 14 +- transport/sudoku/obfs/sudoku/rand.go | 40 +- transport/sudoku/obfs/sudoku/table.go | 26 +- transport/trusttunnel/client.go | 6 +- transport/v2raygrpc/custom_name.go | 19 +- transport/v2raygrpclite/client.go | 7 +- transport/v2raygrpclite/path.go | 10 + transport/v2raygrpclite/server.go | 2 +- transport/v2raykcp/sending.go | 23 +- transport/v2rayxhttp/client.go | 49 +- transport/v2rayxhttp/conn.go | 2 +- transport/v2rayxhttp/dialer.go | 7 +- transport/v2rayxhttp/mux.go | 25 +- transport/v2rayxhttp/server.go | 87 +++- transport/v2rayxhttp/upload_queue.go | 116 +++-- transport/v2rayxhttp/utils.go | 41 +- transport/v2rayxhttp/writer.go | 3 +- transport/v2rayxhttp/xpadding.go | 10 +- transport/wireguard/endpoint_options.go | 10 +- 107 files changed, 5346 insertions(+), 708 deletions(-) rename transport/trusttunnel/quic.go => common/congestion/congestion.go (92%) create mode 100644 common/list/list.go delete mode 100644 common/xray/json/badoption/range.go create mode 100644 examples/profiler/config.json create mode 100644 examples/snell/client.json create mode 100644 examples/snell/server.json create mode 100644 include/snell.go create mode 100644 include/snell_stub.go create mode 100644 option/snell.go create mode 100644 protocol/snell/inbound.go create mode 100644 protocol/snell/outbound.go create mode 100644 route/route_start_test.go create mode 100644 transport/masque/client_h2.go create mode 100644 transport/openvpn/e2e_test.go create mode 100644 transport/openvpn/lzo.go create mode 100644 transport/simple-obfs/http_server.go create mode 100644 transport/simple-obfs/tls_server.go create mode 100644 transport/snell/address.go create mode 100644 transport/snell/cipher.go create mode 100644 transport/snell/client.go create mode 100644 transport/snell/pool.go create mode 100644 transport/snell/service.go create mode 100644 transport/snell/shadowaead.go create mode 100644 transport/snell/snell.go create mode 100644 transport/snell/v4.go create mode 100644 transport/sudoku/multiplex/session_backpressure_test.go create mode 100644 transport/sudoku/obfs/sudoku/ascii_mode_test.go create mode 100644 transport/sudoku/obfs/sudoku/conn_roundtrip_test.go create mode 100644 transport/sudoku/obfs/sudoku/packed_prefix_test.go create mode 100644 transport/v2raygrpclite/path.go diff --git a/.goreleaser.yaml b/.goreleaser.yaml index c3f17c15..99c0ff99 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -26,6 +26,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_manager - with_admin_panel - with_profiler @@ -64,6 +65,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_profiler - badlinkname - tfogo_checklinkname0 @@ -123,6 +125,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_manager - with_admin_panel - with_profiler @@ -156,6 +159,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_manager - with_admin_panel - with_profiler @@ -189,6 +193,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_manager - with_admin_panel - with_profiler @@ -222,6 +227,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_manager - with_admin_panel - with_profiler @@ -255,6 +261,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_manager - with_admin_panel - with_profiler @@ -304,6 +311,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_manager - with_admin_panel - with_profiler @@ -361,6 +369,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_profiler - badlinkname - tfogo_checklinkname0 @@ -433,6 +442,7 @@ builds: - with_openvpn - with_trusttunnel - with_sudoku + - with_snell - with_profiler - badlinkname - tfogo_checklinkname0 diff --git a/adapter/platform.go b/adapter/platform.go index df1f4471..52cc1578 100644 --- a/adapter/platform.go +++ b/adapter/platform.go @@ -13,6 +13,7 @@ type PlatformInterface interface { UsePlatformAutoDetectInterfaceControl() bool AutoDetectInterfaceControl(fd int) error + BindInterfaceControl(fd int, interfaceName string) error UsePlatformInterface() bool OpenInterface(options *tun.Options, platformOptions option.TunPlatformOptions) (tun.Tun, error) diff --git a/cmd/internal/build_libbox/main.go b/cmd/internal/build_libbox/main.go index e2659117..5dc354b8 100644 --- a/cmd/internal/build_libbox/main.go +++ b/cmd/internal/build_libbox/main.go @@ -63,7 +63,7 @@ func init() { sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -X internal/godebug.defaultGODEBUG=multipathtcp=0 -s -w -buildid= -checklinkname=0") debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -X internal/godebug.defaultGODEBUG=multipathtcp=0 -checklinkname=0") - sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_masque", "with_mtproxy", "with_trusttunnel", "with_openvpn", "with_sudoku", "with_utls", "with_naive_outbound", "with_clash_api", "badlinkname", "tfogo_checklinkname0") + sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_masque", "with_mtproxy", "with_trusttunnel", "with_openvpn", "with_sudoku", "with_snell", "with_utls", "with_naive_outbound", "with_clash_api", "badlinkname", "tfogo_checklinkname0") darwinTags = append(darwinTags, "with_dhcp", "grpcnotrace") // memcTags = append(memcTags, "with_tailscale") sharedTags = append(sharedTags, "with_tailscale", "ts_omit_logtail", "ts_omit_ssh", "ts_omit_drive", "ts_omit_taildrop", "ts_omit_webclient", "ts_omit_doctor", "ts_omit_capture", "ts_omit_kube", "ts_omit_aws", "ts_omit_synology", "ts_omit_bird") diff --git a/transport/trusttunnel/quic.go b/common/congestion/congestion.go similarity index 92% rename from transport/trusttunnel/quic.go rename to common/congestion/congestion.go index 90c9d6d6..672b458f 100644 --- a/transport/trusttunnel/quic.go +++ b/common/congestion/congestion.go @@ -1,4 +1,4 @@ -package trusttunnel +package congestion import ( "time" @@ -12,7 +12,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -func NewCongestionControl(name string, cwnd int, bbrProfile string, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) { +func NewCongestionControl(name string, cwnd int, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) { if timeFunc == nil { timeFunc = time.Now } diff --git a/common/dialer/default.go b/common/dialer/default.go index 4ffe00c1..e4df1907 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -70,9 +70,20 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial if !(C.IsLinux || C.IsDarwin || C.IsWindows) { return nil, E.New("`bind_interface` is only supported on Linux, macOS and Windows") } - bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1) - dialer.Control = control.Append(dialer.Control, bindFunc) - listener.Control = control.Append(listener.Control, bindFunc) + if platformInterface != nil && platformInterface.UsePlatformAutoDetectInterfaceControl() { + interfaceName := options.BindInterface + bindFunc := func(network, address string, conn syscall.RawConn) error { + return control.Raw(conn, func(fd uintptr) error { + return platformInterface.BindInterfaceControl(int(fd), interfaceName) + }) + } + dialer.Control = control.Append(dialer.Control, bindFunc) + listener.Control = control.Append(listener.Control, bindFunc) + } else { + bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1) + dialer.Control = control.Append(dialer.Control, bindFunc) + listener.Control = control.Append(listener.Control, bindFunc) + } } if options.RoutingMark > 0 { if !C.IsLinux { diff --git a/common/list/list.go b/common/list/list.go new file mode 100644 index 00000000..0c9bbd26 --- /dev/null +++ b/common/list/list.go @@ -0,0 +1,164 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package list + +// Element is an element of a linked list. +type Element[T any] struct { + next, prev *Element[T] + list *List[T] + Value T +} + +func (e *Element[T]) Next() *Element[T] { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +func (e *Element[T]) Prev() *Element[T] { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +func (e *Element[T]) Remove() bool { + if e.list == nil { + return false + } + e.list.remove(e) + return true +} + +type List[T any] struct { + root Element[T] + len int +} + +func (l *List[T]) Init() *List[T] { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +func New[T any]() *List[T] { return new(List[T]).Init() } + +func (l *List[T]) Len() int { return l.len } + +func (l *List[T]) Front() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.next +} + +func (l *List[T]) Back() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.prev +} + +func (l *List[T]) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + return l.insert(&Element[T]{Value: v}, at) +} + +func (l *List[T]) remove(e *Element[T]) { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil + e.prev = nil + e.list = nil + l.len-- +} + +func (l *List[T]) Remove(e *Element[T]) T { + if e.list == l { + l.remove(e) + } + return e.Value +} + +func (l *List[T]) PushFront(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +func (l *List[T]) PushBack(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + return l.insertValue(v, mark.prev) +} + +func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + return l.insertValue(v, mark) +} + +func (l *List[T]) MoveToFront(e *Element[T]) { + if e.list != l || l.root.next == e { + return + } + l.move(e, &l.root) +} + +func (l *List[T]) MoveToBack(e *Element[T]) { + if e.list != l || l.root.prev == e { + return + } + l.move(e, l.root.prev) +} + +func (l *List[T]) MoveBefore(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark.prev) +} + +func (l *List[T]) MoveAfter(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark) +} + +func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e +} diff --git a/common/utils.go b/common/utils.go index a821c003..d2c5d922 100644 --- a/common/utils.go +++ b/common/utils.go @@ -9,7 +9,6 @@ import ( "strings" "time" - Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption" "github.com/sagernet/sing/common/json/badoption" ) @@ -69,8 +68,8 @@ func DecodeBase64URLSafe(content string) (string, error) { return string(result), nil } -func ParseXHTTPRange(value string) (Xbadoption.Range, error) { - result := Xbadoption.Range{} +func ParseXHTTPRange(value string) (badoption.Range[int], error) { + result := badoption.Range[int]{} encoded, err := json.Marshal(value) if err != nil { return result, err diff --git a/common/xray/json/badoption/range.go b/common/xray/json/badoption/range.go deleted file mode 100644 index 28ed0896..00000000 --- a/common/xray/json/badoption/range.go +++ /dev/null @@ -1,83 +0,0 @@ -package badoption - -import ( - "encoding/json" - "fmt" - "strconv" - "strings" - - "github.com/sagernet/sing-box/common/xray/crypto" - E "github.com/sagernet/sing/common/exceptions" -) - -type Range struct { - From int32 `json:"from"` - To int32 `json:"to"` -} - -func (c *Range) Build() *Range { - return (*Range)(c) -} - -func (c *Range) MarshalJSON() ([]byte, error) { - if c.From == c.To { - return json.Marshal(c.From) - } - return json.Marshal(fmt.Sprintf("%d-%d", c.From, c.To)) -} - -func (c *Range) UnmarshalJSON(content []byte) error { - var rangeValue struct { - From int32 `json:"from"` - To int32 `json:"to"` - } - var stringValue string - err := json.Unmarshal(content, &stringValue) - if err == nil { - parts := strings.Split(stringValue, "-") - if len(parts) != 2 { - from, err := strconv.ParseInt(parts[0], 10, 32) - if err != nil { - return err - } - rangeValue.From, rangeValue.To = int32(from), int32(from) - } else { - from, err := strconv.ParseInt(parts[0], 10, 32) - if err != nil { - return err - } - to, err := strconv.ParseInt(parts[1], 10, 32) - if err != nil { - return err - } - rangeValue.From, rangeValue.To = int32(from), int32(to) - } - } else { - var int32Value int32 - err := json.Unmarshal(content, &int32Value) - if err == nil { - rangeValue.From, rangeValue.To = int32Value, int32Value - } else { - err := json.Unmarshal(content, &rangeValue) - if err != nil { - return err - } - } - } - if rangeValue.From > rangeValue.To { - return E.New("invalid range") - } - *c = Range{rangeValue.From, rangeValue.To} - return nil -} - -func (c *Range) String() string { - if c.From == c.To { - return strconv.FormatInt(int64(c.From), 10) - } - return fmt.Sprintf("%d-%d", c.From, c.To) -} - -func (c Range) Rand() int32 { - return int32(crypto.RandBetween(int64(c.From), int64(c.To))) -} diff --git a/constant/proxy.go b/constant/proxy.go index cbb65868..e39c9200 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -28,6 +28,7 @@ const ( TypeMieru = "mieru" TypeAnyTLS = "anytls" TypeSudoku = "sudoku" + TypeSnell = "snell" TypeShadowsocksR = "shadowsocksr" TypeVLESS = "vless" TypeTUIC = "tuic" @@ -41,6 +42,7 @@ const ( TypeBandwidthLimiter = "bandwidth-limiter" TypeTrafficLimiter = "traffic-limiter" TypeRateLimiter = "rate-limiter" + TypeFairQueue = "fair-queue" TypeAdminPanel = "admin-panel" TypeManagerAPI = "manager-api" TypeNodeManagerAPI = "node-manager-api" @@ -129,6 +131,8 @@ func ProxyDisplayName(proxyType string) string { return "AnyTLS" case TypeSudoku: return "Sudoku" + case TypeSnell: + return "Snell" case TypeFallback: return "Fallback" case TypeTailscale: @@ -145,6 +149,8 @@ func ProxyDisplayName(proxyType string) string { return "Traffic Limiter" case TypeRateLimiter: return "Rate Limiter" + case TypeFairQueue: + return "Fair Queue" case TypeVPNClient: return "VPN Client" case TypeVPNServer: diff --git a/examples/masque/client.json b/examples/masque/client.json index e858af6f..aa2bb853 100644 --- a/examples/masque/client.json +++ b/examples/masque/client.json @@ -39,11 +39,14 @@ "udp_keepalive_period": "30s", "udp_initial_packet_size": 0, "reconnect_delay": "5s", + "congestion_controller": "bbr", + "cwnd": 0, "tls": { // TLS fields for HTTP2 "insecure": false, "cipher_suites": [], "curve_preferences": [], "fragment": false, + "fragment_fallback_delay": "500ms", "record_fragment": false, "kernel_tx": false, "kernel_rx": false diff --git a/examples/profiler/config.json b/examples/profiler/config.json new file mode 100644 index 00000000..50709372 --- /dev/null +++ b/examples/profiler/config.json @@ -0,0 +1,13 @@ +{ + "log": { + "level": "info" + }, + "services": [ + { + "type": "profiler", + "tag": "pprof", + "listen": "127.0.0.1", + "listen_port": 6060 + } + ] +} diff --git a/examples/snell/client.json b/examples/snell/client.json new file mode 100644 index 00000000..64e4538a --- /dev/null +++ b/examples/snell/client.json @@ -0,0 +1,46 @@ +{ + "log": { + "level": "error" + }, + "dns": { + "servers": [ + { + "type": "local", + "tag": "default" + } + ] + }, + "inbounds": [ + { + "type": "mixed", + "tag": "mixed-in", + "listen_port": 7897 + } + ], + "outbounds": [ + { + "type": "direct", + "tag": "direct" + }, + { + "type": "snell", + "tag": "snell-out", + "server": "example.com", + "server_port": 8443, + "psk": "your-secret-psk", + "version": 4, // 1 | 2 | 3 | 4 | 5 (v5 falls back to v4) + "reuse": true, // v4 only, reuse pooled connections + "network": ["tcp", "udp"] + // "obfs": { + // "mode": "tls", // tls | http + // "host": "bing.com" + // } + // Dial Fields + } + ], + "route": { + "final": "snell-out", + "default_domain_resolver": "default", + "auto_detect_interface": true + } +} diff --git a/examples/snell/server.json b/examples/snell/server.json new file mode 100644 index 00000000..a1d82a87 --- /dev/null +++ b/examples/snell/server.json @@ -0,0 +1,39 @@ +{ + "log": { + "level": "error" + }, + "dns": { + "servers": [ + { + "type": "local", + "tag": "default" + } + ] + }, + "inbounds": [ + { + "type": "snell", + "tag": "snell-in", + "listen": "::", + "listen_port": 8443, + "psk": "your-secret-psk", + "version": 4, // 4 | 5 (server supports v4/v5 only) + "network": ["tcp", "udp"] + // "obfs": { + // "mode": "tls", // tls | http + // "host": "bing.com" + // } + } + ], + "outbounds": [ + { + "type": "direct", + "tag": "direct" + } + ], + "route": { + "final": "direct", + "default_domain_resolver": "default", + "auto_detect_interface": true + } +} diff --git a/examples/trusttunnel/client.json b/examples/trusttunnel/client.json index 505a1b77..4ecf2e6d 100644 --- a/examples/trusttunnel/client.json +++ b/examples/trusttunnel/client.json @@ -31,7 +31,8 @@ "multiplex": { "enabled": true, "max_connections": 8, - "min_streams": 5 + "min_streams": 5, + "max_streams": 0 }, "tls": { "enabled": true, @@ -50,12 +51,12 @@ "health_check": true, "quic": true, "congestion_controller": "bbr", // bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno - "bbr_profile": "standard", // standard, conservative, aggressive "cwnd": 32, "multiplex": { "enabled": true, "max_connections": 8, - "min_streams": 5 + "min_streams": 5, + "max_streams": 0 }, "tls": { "enabled": true, diff --git a/examples/trusttunnel/server.json b/examples/trusttunnel/server.json index db8e5508..6f7bdfef 100644 --- a/examples/trusttunnel/server.json +++ b/examples/trusttunnel/server.json @@ -13,7 +13,6 @@ } ], "congestion_controller": "bbr", // bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno - "bbr_profile": "standard", // standard, conservative, aggressive "cwnd": 32, "tls": { "enabled": true, diff --git a/examples/xhttp/client.json b/examples/xhttp/client.json index 9e53beb2..12efb76e 100644 --- a/examples/xhttp/client.json +++ b/examples/xhttp/client.json @@ -65,6 +65,8 @@ "uplink_data_placement": "", "uplink_data_key": "", "uplink_chunk_size": 0, + "congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno + "cwnd": 0, // h3 only: initial congestion window in packets, default 32 "server": "example.com", "server_port": 443, "download": { @@ -97,6 +99,8 @@ "uplink_data_placement": "", "uplink_data_key": "", "uplink_chunk_size": 0, + "congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno + "cwnd": 0, // h3 only: initial congestion window in packets, default 32 "server": "example.com", "server_port": 443, "tls": { // https://sing-box.sagernet.org/configuration/shared/tls/#outbound diff --git a/examples/xhttp/server.json b/examples/xhttp/server.json index 4e543909..d172b04e 100644 --- a/examples/xhttp/server.json +++ b/examples/xhttp/server.json @@ -51,6 +51,8 @@ "seq_key": "", "uplink_data_placement": "", "uplink_data_key": "", + "congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno + "cwnd": 0, // h3 only: initial congestion window in packets, default 32 } } ], diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index 45156f77..517f9dda 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -78,6 +78,10 @@ func (s *platformInterfaceStub) AutoDetectInterfaceControl(fd int) error { return nil } +func (s *platformInterfaceStub) BindInterfaceControl(fd int, interfaceName string) error { + return os.ErrInvalid +} + func (s *platformInterfaceStub) UsePlatformInterface() bool { return false } diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index b82121b7..d25d7a4c 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -6,6 +6,7 @@ type PlatformInterface interface { LocalDNSTransport() LocalDNSTransport UsePlatformAutoDetectInterfaceControl() bool AutoDetectInterfaceControl(fd int32) error + BindInterfaceControl(fd int32, interfaceName string) error OpenTun(options TunOptions) (int32, error) UseProcFS() bool FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (*ConnectionOwner, error) diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 61ec98b1..3b5c0987 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -49,6 +49,10 @@ func (w *platformInterfaceWrapper) AutoDetectInterfaceControl(fd int) error { return w.iif.AutoDetectInterfaceControl(int32(fd)) } +func (w *platformInterfaceWrapper) BindInterfaceControl(fd int, interfaceName string) error { + return w.iif.BindInterfaceControl(int32(fd), interfaceName) +} + func (w *platformInterfaceWrapper) UsePlatformInterface() bool { return true } diff --git a/go.mod b/go.mod index 6fde228f..eb1c6d8e 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/miekg/dns v1.1.72 github.com/openai/openai-go/v3 v3.26.0 github.com/oschwald/maxminddb-golang v1.13.1 + github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1 github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/cors v1.2.1 @@ -231,8 +232,8 @@ replace github.com/sagernet/sing-vmess => github.com/shtorm-7/sing-vmess v0.2.7- replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 -replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 +replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0 -replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 +replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0 -replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.1.0 +replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.2.0 diff --git a/go.sum b/go.sum index 4ebaca11..d8a46388 100644 --- a/go.sum +++ b/go.sum @@ -268,6 +268,8 @@ github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= +github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e h1:dCWirM5F3wMY+cmRda/B1BiPsFtmzXqV9b0hLWtVBMs= +github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e/go.mod h1:9leZcVcItj6m9/CfHY5Em/iBrCz7js8LcRQGTKEEv2M= github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= @@ -373,16 +375,16 @@ github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1h github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA= -github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 h1:ws7BIsYLd31Wjifq88BYCHRVlgO+07iwil39s6ERba8= -github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0/go.mod h1:mRwx4w32qQxsWB2kThuHpbo7iNjJiq1jYWubgqEPjHA= +github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0 h1:3ZV98mKqKNPCPWHevJ6RPsb65DwPrRFEUOHUfDnG6vw= +github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0/go.mod h1:mRwx4w32qQxsWB2kThuHpbo7iNjJiq1jYWubgqEPjHA= github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 h1:e5s7RKBd2rIPR0StbvZ2vTVtJ5jDTsTk5wtIIapZTRg= github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI= -github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 h1:PLZ/YHqnApPx13wt6MX3ItqESp4ueBr1tGSi0bEGqYw= -github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4= +github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0 h1:aOd9Vy2LGSwgMM+4805AgLBE/MQf8UymbXHxUZjSmoU= +github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4= github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 h1:iBLll4ZZG8ULQcHWs6gGslZWtBN72Zo1zjySzMVHF7g= github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0= -github.com/shtorm-7/sing v0.8.10-extended-1.1.0 h1:P4JL2cugjvEvnYu8tMmpR30SE1qsS45RcnNEwzDz5as= -github.com/shtorm-7/sing v0.8.10-extended-1.1.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA= +github.com/shtorm-7/sing v0.8.10-extended-1.2.0 h1:5yw9j0+P2QkRWvxBvb71wvNdpAlHmmpBv4hj2gqvass= +github.com/shtorm-7/sing v0.8.10-extended-1.2.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0 h1:WVheKmQH5hSQbJU1ZTKthKSutkTLWSb2hp4JuQhJBow= diff --git a/include/registry.go b/include/registry.go index ebb1cd6a..ddb9f9d5 100644 --- a/include/registry.go +++ b/include/registry.go @@ -91,6 +91,7 @@ func InboundRegistry() *inbound.Registry { registerStubForRemovedInbounds(registry) registerMTProxyInbound(registry) registerSudokuInbound(registry) + registerSnellInbound(registry) return registry } @@ -135,6 +136,7 @@ func OutboundRegistry() *outbound.Registry { registerQUICOutbounds(registry) registerStubForRemovedOutbounds(registry) registerSudokuOutbound(registry) + registerSnellOutbound(registry) return registry } diff --git a/include/snell.go b/include/snell.go new file mode 100644 index 00000000..b4b4e442 --- /dev/null +++ b/include/snell.go @@ -0,0 +1,17 @@ +//go:build with_snell + +package include + +import ( + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/protocol/snell" +) + +func registerSnellInbound(registry *inbound.Registry) { + snell.RegisterInbound(registry) +} + +func registerSnellOutbound(registry *outbound.Registry) { + snell.RegisterOutbound(registry) +} diff --git a/include/snell_stub.go b/include/snell_stub.go new file mode 100644 index 00000000..56041cb9 --- /dev/null +++ b/include/snell_stub.go @@ -0,0 +1,27 @@ +//go:build !with_snell + +package include + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/adapter/outbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func registerSnellInbound(registry *inbound.Registry) { + inbound.Register[option.SnellInboundOptions](registry, C.TypeSnell, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellInboundOptions) (adapter.Inbound, error) { + return nil, E.New(`Snell is not included in this build, rebuild with -tags with_snell`) + }) +} + +func registerSnellOutbound(registry *outbound.Registry) { + outbound.Register[option.SnellOutboundOptions](registry, C.TypeSnell, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellOutboundOptions) (adapter.Outbound, error) { + return nil, E.New(`Snell is not included in this build, rebuild with -tags with_snell`) + }) +} diff --git a/option/group.go b/option/group.go index 2fb8e65b..96b9a636 100644 --- a/option/group.go +++ b/option/group.go @@ -18,7 +18,8 @@ type URLTestOutboundOptions struct { } type FallbackOutboundOptions struct { - Outbounds []string `json:"outbounds"` + Outbounds []string `json:"outbounds"` + BlacklistTimeout badoption.Duration `json:"blacklist_timeout,omitempty"` } type GroupCommonOption struct { diff --git a/option/limiter.go b/option/limiter.go index 1b187c93..358f5beb 100644 --- a/option/limiter.go +++ b/option/limiter.go @@ -69,3 +69,8 @@ type RateLimiterUser struct { Count uint32 `json:"count"` Interval badoption.Duration `json:"interval"` } + +type FairQueueOutboundOptions struct { + FlowKeys []string `json:"flow_keys,omitempty"` + Outbound string `json:"outbound"` +} diff --git a/option/masque.go b/option/masque.go index 7ccef2a8..a2311ef1 100644 --- a/option/masque.go +++ b/option/masque.go @@ -18,6 +18,8 @@ type MASQUEOutboundOptions struct { UDPKeepalivePeriod badoption.Duration `json:"udp_keepalive_period,omitempty"` UDPInitialPacketSize uint16 `json:"udp_initial_packet_size,omitempty"` ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"` + CongestionController string `json:"congestion_controller,omitempty"` + CWND int `json:"cwnd,omitempty"` MASQUEOutboundTLSOptionsContainer } diff --git a/option/openvpn.go b/option/openvpn.go index 3841c109..4e3d1c09 100644 --- a/option/openvpn.go +++ b/option/openvpn.go @@ -25,6 +25,7 @@ type OpenVPNOutboundOptions struct { KeyDirection int `json:"key_direction,omitempty"` ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"` PingInterval badoption.Duration `json:"ping_interval,omitempty"` + PingRestart badoption.Duration `json:"ping_restart,omitempty"` OpenVPNOutboundTLSOptionsContainer } diff --git a/option/snell.go b/option/snell.go new file mode 100644 index 00000000..c9f04f71 --- /dev/null +++ b/option/snell.go @@ -0,0 +1,24 @@ +package option + +type SnellOutboundOptions struct { + DialerOptions + ServerOptions + PSK string `json:"psk"` + Version int `json:"version,omitempty"` + Reuse bool `json:"reuse,omitempty"` + Network NetworkList `json:"network,omitempty"` + Obfs *SnellObfsOptions `json:"obfs,omitempty"` +} + +type SnellInboundOptions struct { + ListenOptions + PSK string `json:"psk"` + Version int `json:"version,omitempty"` + Network NetworkList `json:"network,omitempty"` + Obfs *SnellObfsOptions `json:"obfs,omitempty"` +} + +type SnellObfsOptions struct { + Mode string `json:"mode,omitempty"` + Host string `json:"host,omitempty"` +} diff --git a/option/trusttunnel.go b/option/trusttunnel.go index 9a17d818..046f823c 100644 --- a/option/trusttunnel.go +++ b/option/trusttunnel.go @@ -6,7 +6,6 @@ type TrustTunnelInboundOptions struct { Users []TrustTunnelUser `json:"users,omitempty"` Network NetworkList `json:"network,omitempty"` CongestionController string `json:"congestion_controller,omitempty"` - BBRProfile string `json:"bbr_profile,omitempty"` CWND int `json:"cwnd,omitempty"` } @@ -32,7 +31,6 @@ type TrustTunnelOutboundOptions struct { HealthCheck bool `json:"health_check,omitempty"` QUIC bool `json:"quic,omitempty"` CongestionController string `json:"congestion_controller,omitempty"` - BBRProfile string `json:"bbr_profile,omitempty"` CWND int `json:"cwnd,omitempty"` Multiplex *TrustTunnelMultiplexOptions `json:"multiplex,omitempty"` } diff --git a/option/v2ray_transport.go b/option/v2ray_transport.go index 987f5207..3a9ee057 100644 --- a/option/v2ray_transport.go +++ b/option/v2ray_transport.go @@ -4,7 +4,6 @@ import ( "net/http" "strings" - Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption" "github.com/sagernet/sing-box/common/xray/utils" C "github.com/sagernet/sing-box/constant" E "github.com/sagernet/sing/common/exceptions" @@ -119,13 +118,13 @@ type V2RayXHTTPBaseOptions struct { Path string `json:"path,omitempty"` Headers map[string]string `json:"headers,omitempty"` DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` - XPaddingBytes Xbadoption.Range `json:"x_padding_bytes"` + XPaddingBytes badoption.Range[int] `json:"x_padding_bytes"` NoGRPCHeader bool `json:"no_grpc_header,omitempty"` NoSSEHeader bool `json:"no_sse_header,omitempty"` - ScMaxEachPostBytes *Xbadoption.Range `json:"sc_max_each_post_bytes"` - ScMinPostsIntervalMs *Xbadoption.Range `json:"sc_min_posts_interval_ms"` + ScMaxEachPostBytes *badoption.Range[int] `json:"sc_max_each_post_bytes"` + ScMinPostsIntervalMs *badoption.Range[int] `json:"sc_min_posts_interval_ms"` ScMaxBufferedPosts int64 `json:"sc_max_buffered_posts,omitempty"` - ScStreamUpServerSecs *Xbadoption.Range `json:"sc_stream_up_server_secs"` + ScStreamUpServerSecs *badoption.Range[int] `json:"sc_stream_up_server_secs"` ServerMaxHeaderBytes int `json:"server_max_header_bytes"` TrustedXForwardedFor badoption.Listable[string] `json:"trusted_x_forwarded_for,omitempty"` Xmux *V2RayXHTTPXmuxOptions `json:"xmux"` @@ -141,7 +140,11 @@ type V2RayXHTTPBaseOptions struct { SeqKey string `json:"seq_key,omitempty"` UplinkDataPlacement string `json:"uplink_data_placement,omitempty"` UplinkDataKey string `json:"uplink_data_key,omitempty"` - UplinkChunkSize *Xbadoption.Range `json:"uplink_chunk_size,omitempty"` + UplinkChunkSize *badoption.Range[int] `json:"uplink_chunk_size,omitempty"` + SessionIDTable string `json:"session_id_table,omitempty"` + SessionIDLength badoption.Range[int] `json:"session_id_length,omitempty"` + CongestionController string `json:"congestion_controller,omitempty"` + CWND int `json:"cwnd,omitempty"` } type _V2RayXHTTPOptions struct { @@ -302,6 +305,10 @@ func checkV2RayXHTTPBaseOptions(mode string, options *V2RayXHTTPBaseOptions) err return E.New("invalid negative value of maxHeaderBytes") } + if mode != "stream-one" && mode != "stream-up" && options.GetNormalizedScMaxEachPostBytes().From <= 0 { + return E.New("`scMaxEachPostBytes` should be bigger than 0") + } + if options.Xmux == nil { options.Xmux = &V2RayXHTTPXmuxOptions{} options.Xmux.MaxConcurrency.From = 1 @@ -346,9 +353,9 @@ func (c *V2RayXHTTPBaseOptions) GetRequestHeader() http.Header { return header } -func (c *V2RayXHTTPBaseOptions) GetNormalizedXPaddingBytes() Xbadoption.Range { +func (c *V2RayXHTTPBaseOptions) GetNormalizedXPaddingBytes() badoption.Range[int] { if c.XPaddingBytes.To == 0 { - return Xbadoption.Range{ + return badoption.Range[int]{ From: 100, To: 1000, } @@ -363,9 +370,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkHTTPMethod() string { return c.UplinkHTTPMethod } -func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() Xbadoption.Range { +func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() badoption.Range[int] { if c.ScMaxEachPostBytes == nil { - return Xbadoption.Range{ + return badoption.Range[int]{ From: 1000000, To: 1000000, } @@ -373,9 +380,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() Xbadoption.Ran return *c.ScMaxEachPostBytes } -func (c *V2RayXHTTPBaseOptions) GetNormalizedScMinPostsIntervalMs() Xbadoption.Range { +func (c *V2RayXHTTPBaseOptions) GetNormalizedScMinPostsIntervalMs() badoption.Range[int] { if c.ScMinPostsIntervalMs == nil { - return Xbadoption.Range{ + return badoption.Range[int]{ From: 30, To: 30, } @@ -391,9 +398,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxBufferedPosts() int { return int(c.ScMaxBufferedPosts) } -func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() Xbadoption.Range { +func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() badoption.Range[int] { if c.ScStreamUpServerSecs == nil { - return Xbadoption.Range{ + return badoption.Range[int]{ From: 20, To: 80, } @@ -401,16 +408,16 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() Xbadoption.R return *c.ScStreamUpServerSecs } -func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() Xbadoption.Range { +func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() badoption.Range[int] { if c.UplinkChunkSize == nil || c.UplinkChunkSize.To == 0 { switch c.UplinkDataPlacement { case PlacementCookie: - return Xbadoption.Range{ + return badoption.Range[int]{ From: 2 * 1024, // 2 KiB To: 3 * 1024, // 3 KiB } case PlacementHeader: - return Xbadoption.Range{ + return badoption.Range[int]{ From: 3 * 1000, // 3 KB To: 4 * 1000, // 4 KB } @@ -418,7 +425,7 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() Xbadoption.Range return c.GetNormalizedScMaxEachPostBytes() } } else if c.UplinkChunkSize.From < 64 { - return Xbadoption.Range{ + return badoption.Range[int]{ From: 64, To: max(64, c.UplinkChunkSize.To), } @@ -485,31 +492,31 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedSeqKey() string { } type V2RayXHTTPXmuxOptions struct { - MaxConcurrency Xbadoption.Range `json:"max_concurrency"` - MaxConnections Xbadoption.Range `json:"max_connections"` - CMaxReuseTimes Xbadoption.Range `json:"c_max_reuse_times"` - HMaxRequestTimes Xbadoption.Range `json:"h_max_request_times"` - HMaxReusableSecs Xbadoption.Range `json:"h_max_reusable_secs"` - HKeepAlivePeriod int64 `json:"h_keep_alive_period"` + MaxConcurrency badoption.Range[int] `json:"max_concurrency"` + MaxConnections badoption.Range[int] `json:"max_connections"` + CMaxReuseTimes badoption.Range[int] `json:"c_max_reuse_times"` + HMaxRequestTimes badoption.Range[int] `json:"h_max_request_times"` + HMaxReusableSecs badoption.Range[int] `json:"h_max_reusable_secs"` + HKeepAlivePeriod int64 `json:"h_keep_alive_period"` } -func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConcurrency() Xbadoption.Range { +func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConcurrency() badoption.Range[int] { return m.MaxConcurrency } -func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConnections() Xbadoption.Range { +func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConnections() badoption.Range[int] { return m.MaxConnections } -func (m *V2RayXHTTPXmuxOptions) GetNormalizedCMaxReuseTimes() Xbadoption.Range { +func (m *V2RayXHTTPXmuxOptions) GetNormalizedCMaxReuseTimes() badoption.Range[int] { return m.CMaxReuseTimes } -func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxRequestTimes() Xbadoption.Range { +func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxRequestTimes() badoption.Range[int] { return m.HMaxRequestTimes } -func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxReusableSecs() Xbadoption.Range { +func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxReusableSecs() badoption.Range[int] { return m.HMaxReusableSecs } diff --git a/option/wireguard.go b/option/wireguard.go index b5ae9ebd..d32e6eb2 100644 --- a/option/wireguard.go +++ b/option/wireguard.go @@ -3,7 +3,6 @@ package option import ( "net/netip" - Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption" "github.com/sagernet/sing/common/json/badoption" ) @@ -40,10 +39,10 @@ type WireGuardAmnezia struct { S2 int `json:"s2,omitempty"` S3 int `json:"s3,omitempty"` S4 int `json:"s4,omitempty"` - H1 *Xbadoption.Range `json:"h1,omitempty"` - H2 *Xbadoption.Range `json:"h2,omitempty"` - H3 *Xbadoption.Range `json:"h3,omitempty"` - H4 *Xbadoption.Range `json:"h4,omitempty"` + H1 *badoption.Range[uint32] `json:"h1,omitempty"` + H2 *badoption.Range[uint32] `json:"h2,omitempty"` + H3 *badoption.Range[uint32] `json:"h3,omitempty"` + H4 *badoption.Range[uint32] `json:"h4,omitempty"` I1 string `json:"i1,omitempty"` I2 string `json:"i2,omitempty"` I3 string `json:"i3,omitempty"` diff --git a/protocol/bond/inbound.go b/protocol/bond/inbound.go index b6ea00cd..93d281a0 100644 --- a/protocol/bond/inbound.go +++ b/protocol/bond/inbound.go @@ -80,6 +80,7 @@ func (h *Inbound) Start(stage adapter.StartStage) error { } func (h *Inbound) Close() error { + h.conns.Close() errs := make([]error, 0) for _, inbound := range h.inbounds { err := inbound.Close() diff --git a/protocol/group/fallback.go b/protocol/group/fallback.go index e7d37f72..772367af 100644 --- a/protocol/group/fallback.go +++ b/protocol/group/fallback.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync" + "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/outbound" @@ -31,14 +32,19 @@ type Fallback struct { tags []string outbounds map[string]adapter.Outbound lastUsedOutbound string - - mtx sync.Mutex + blacklistTimeout time.Duration + blacklist map[string]time.Time + mtx sync.Mutex } func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FallbackOutboundOptions) (adapter.Outbound, error) { if len(options.Outbounds) == 0 { return nil, E.New("missing tags") } + blacklistTimeout := time.Duration(options.BlacklistTimeout) + if blacklistTimeout == 0 { + blacklistTimeout = time.Minute + } outbound := &Fallback{ Adapter: outbound.NewAdapter(C.TypeFallback, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), ctx: ctx, @@ -47,6 +53,8 @@ func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextL tags: options.Outbounds, outbounds: make(map[string]adapter.Outbound, len(options.Outbounds)), lastUsedOutbound: options.Outbounds[0], + blacklistTimeout: blacklistTimeout, + blacklist: make(map[string]time.Time), } return outbound, nil } @@ -73,35 +81,110 @@ func (s *Fallback) All() []string { } func (s *Fallback) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - var conn net.Conn + s.mtx.Lock() + var active, blacklisted []string + for _, tag := range s.tags { + if s.isBlacklisted(tag) { + blacklisted = append(blacklisted, tag) + } else { + active = append(active, tag) + } + } + s.mtx.Unlock() + var err error - for _, outbound := range s.outbounds { - conn, err = outbound.DialContext(ctx, network, destination) + for _, tag := range active { + var conn net.Conn + conn, err = s.outbounds[tag].DialContext(ctx, network, destination) + if err != nil { + s.logger.InfoContext(ctx, err) + s.mtx.Lock() + s.addToBlacklist(tag) + s.mtx.Unlock() + continue + } + s.mtx.Lock() + s.lastUsedOutbound = tag + s.mtx.Unlock() + return conn, nil + } + for _, tag := range blacklisted { + var conn net.Conn + conn, err = s.outbounds[tag].DialContext(ctx, network, destination) if err != nil { s.logger.InfoContext(ctx, err) continue } s.mtx.Lock() - defer s.mtx.Unlock() - s.lastUsedOutbound = outbound.Tag() + delete(s.blacklist, tag) + s.lastUsedOutbound = tag + s.mtx.Unlock() return conn, nil } return nil, err } func (s *Fallback) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - var conn net.PacketConn + s.mtx.Lock() + var active, blacklisted []string + for _, tag := range s.tags { + if s.isBlacklisted(tag) { + blacklisted = append(blacklisted, tag) + } else { + active = append(active, tag) + } + } + s.mtx.Unlock() + var err error - for _, outbound := range s.outbounds { - conn, err = outbound.ListenPacket(ctx, destination) + for _, tag := range active { + var conn net.PacketConn + conn, err = s.outbounds[tag].ListenPacket(ctx, destination) + if err != nil { + s.logger.InfoContext(ctx, err) + s.mtx.Lock() + s.addToBlacklist(tag) + s.mtx.Unlock() + continue + } + s.mtx.Lock() + s.lastUsedOutbound = tag + s.mtx.Unlock() + return conn, nil + } + for _, tag := range blacklisted { + var conn net.PacketConn + conn, err = s.outbounds[tag].ListenPacket(ctx, destination) if err != nil { s.logger.InfoContext(ctx, err) continue } s.mtx.Lock() - defer s.mtx.Unlock() - s.lastUsedOutbound = outbound.Tag() + delete(s.blacklist, tag) + s.lastUsedOutbound = tag + s.mtx.Unlock() return conn, nil } return nil, err } + +func (s *Fallback) isBlacklisted(tag string) bool { + if s.blacklistTimeout == 0 { + return false + } + expiry, ok := s.blacklist[tag] + if !ok { + return false + } + if time.Now().After(expiry) { + delete(s.blacklist, tag) + return false + } + return true +} + +func (s *Fallback) addToBlacklist(tag string) { + if s.blacklistTimeout > 0 { + s.blacklist[tag] = time.Now().Add(s.blacklistTimeout) + } +} diff --git a/protocol/limiter/bandwidth/limiter.go b/protocol/limiter/bandwidth/limiter.go index 07ff57e0..b8053e7b 100644 --- a/protocol/limiter/bandwidth/limiter.go +++ b/protocol/limiter/bandwidth/limiter.go @@ -2,11 +2,11 @@ package bandwidth import ( "context" - "slices" "sync" "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/list" ) type BandwidthLimiter interface { @@ -14,123 +14,144 @@ type BandwidthLimiter interface { SetSpeed(speed uint64) } -type FlowKeysLimiter struct { +type FairQueueLimiter struct { limiter BandwidthLimiter connIDGetter ConnIDGetter - waits map[string][]*wait - conns map[string]int + flows *list.List[*flow] + index map[string]*list.Element[*flow] + bytes map[string]uint64 + pool sync.Pool queue chan struct{} reset time.Time mtx sync.Mutex } -func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FlowKeysLimiter { - return &FlowKeysLimiter{ +func NewFairQueueLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FairQueueLimiter { + return &FairQueueLimiter{ limiter: limiter, connIDGetter: connIDGetter, - waits: make(map[string][]*wait), - conns: make(map[string]int), + flows: list.New[*flow](), + index: make(map[string]*list.Element[*flow]), + bytes: make(map[string]uint64), + pool: sync.Pool{New: func() any { return list.New[*request]() }}, queue: make(chan struct{}, 1), reset: time.Now().Add(time.Second), } } -func (l *FlowKeysLimiter) SetSpeed(speed uint64) { +func (l *FairQueueLimiter) SetSpeed(speed uint64) { l.limiter.SetSpeed(speed) } -func (l *FlowKeysLimiter) WaitN(ctx context.Context, n int) error { +func (l *FairQueueLimiter) WaitN(ctx context.Context, n int) error { id, _ := l.connIDGetter(ctx, adapter.ContextFrom(ctx)) - mainWait := &wait{ctx, make(chan struct{}), n} + mainRequest := &request{ctx: ctx, done: make(chan struct{}), n: n} l.mtx.Lock() - if waits, ok := l.waits[id]; ok { - l.waits[id] = append(waits, mainWait) - } else { - l.waits[id] = []*wait{mainWait} + elem, ok := l.index[id] + if !ok { + f := &flow{id: id, pending: l.pool.Get().(*list.List[*request])} + elem = l.flows.PushFront(f) + l.index[id] = elem } + mainRequestElem := elem.Value.pending.PushBack(mainRequest) + l.reorder(elem) l.mtx.Unlock() select { case l.queue <- struct{}{}: - case <-mainWait.finish: + case <-mainRequest.done: return nil case <-ctx.Done(): l.mtx.Lock() - for i, wait := range l.waits[id] { - if wait == mainWait { - l.waits[id] = slices.Delete(l.waits[id], i, i+1) - close(wait.finish) - break - } - } + l.removeRequest(id, mainRequestElem) l.mtx.Unlock() return ctx.Err() } + select { + case <-mainRequest.done: + <-l.queue + return nil + default: + } for { if ctx.Err() != nil { l.mtx.Lock() - for i, wait := range l.waits[id] { - if wait == mainWait { - l.waits[id] = slices.Delete(l.waits[id], i, i+1) - close(wait.finish) - break - } - } + l.removeRequest(id, mainRequestElem) l.mtx.Unlock() <-l.queue return ctx.Err() } + l.mtx.Lock() now := time.Now() if l.reset.Compare(now) == -1 { - clear(l.conns) + clear(l.bytes) l.reset = now.Add(time.Second) } - l.mtx.Lock() - var minConnId string - var minN int - for connID, waits := range l.waits { - if len(waits) == 0 { - continue - } - if n, ok := l.conns[connID]; ok { - if minConnId == "" { - minConnId = connID - minN = n - continue - } - if n+waits[0].n < minN { - minConnId = connID - minN = n - } - } else { - l.conns[connID] = 0 - minConnId = connID - break - } - } - minWait := l.waits[minConnId][0] - l.waits[minConnId][0] = nil - l.waits[minConnId] = l.waits[minConnId][1:] - if len(l.waits) == 0 { - delete(l.waits, minConnId) + flowElem := l.flows.Front() + flow := flowElem.Value + firstRequestElem := flow.pending.Front() + firstRequest := firstRequestElem.Value + l.bytes[flow.id] += uint64(firstRequest.n) + firstRequestElem.Remove() + if flow.pending.Len() == 0 { + l.flows.Remove(flowElem) + delete(l.index, flow.id) + l.pool.Put(flow.pending) + } else { + l.reorder(flowElem) } l.mtx.Unlock() - err := l.limiter.WaitN(ctx, minWait.n) - if err != nil { - continue - } - l.conns[minConnId] = l.conns[minConnId] + minWait.n - close(minWait.finish) - if minWait == mainWait { + l.limiter.WaitN(firstRequest.ctx, firstRequest.n) + close(firstRequest.done) + if firstRequest == mainRequest { <-l.queue return nil } } } -type wait struct { - ctx context.Context - finish chan struct{} - n int +func (l *FairQueueLimiter) reorder(elem *list.Element[*flow]) { + f := elem.Value + front := f.pending.Front() + if front == nil { + return + } + cost := l.bytes[f.id] + uint64(front.Value.n) + for e := l.flows.Front(); e != nil; e = e.Next() { + if e == elem { + continue + } + eFront := e.Value.pending.Front() + if eFront == nil { + continue + } + if cost < l.bytes[e.Value.id]+uint64(eFront.Value.n) { + l.flows.MoveBefore(elem, e) + return + } + } + l.flows.MoveToBack(elem) +} + +func (l *FairQueueLimiter) removeRequest(id string, elem *list.Element[*request]) { + if !elem.Remove() { + return + } + if flowElem, ok := l.index[id]; ok && flowElem.Value.pending.Len() == 0 { + l.flows.Remove(flowElem) + delete(l.index, id) + l.pool.Put(flowElem.Value.pending) + } +} + +type flow struct { + id string + pending *list.List[*request] +} + +type request struct { + ctx context.Context + done chan struct{} + n int } diff --git a/protocol/limiter/bandwidth/strategy.go b/protocol/limiter/bandwidth/strategy.go index 710b895f..87821910 100644 --- a/protocol/limiter/bandwidth/strategy.go +++ b/protocol/limiter/bandwidth/strategy.go @@ -357,7 +357,7 @@ func createSpeedLimiter(speed uint64, flowKeys []string) (BandwidthLimiter, erro if err != nil { return nil, err } - limiter = NewFlowKeysLimiter(getter, limiter) + limiter = NewFairQueueLimiter(getter, limiter) } return limiter, nil } diff --git a/protocol/masque/outbound.go b/protocol/masque/outbound.go index 80a60c5a..2a11d2f5 100644 --- a/protocol/masque/outbound.go +++ b/protocol/masque/outbound.go @@ -11,6 +11,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/common/cloudflare" + "github.com/sagernet/sing-box/common/congestion" "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" @@ -23,6 +24,7 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/service" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -132,6 +134,15 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL logger.ErrorContext(ctx, err) return } + congestionControl, err := congestion.NewCongestionControl( + options.CongestionController, + options.CWND, + ntp.TimeFuncFromContext(ctx), + ) + if err != nil { + logger.ErrorContext(ctx, err) + return + } tunnel, err := masque.NewTunnel( ctx, logger, @@ -156,6 +167,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL UDPKeepalivePeriod: udpKeepalivePeriod, UDPInitialPacketSize: options.UDPInitialPacketSize, ReconnectDelay: options.ReconnectDelay.Build(), + CongestionControl: congestionControl, }, ) if err != nil { diff --git a/protocol/openvpn/outbound.go b/protocol/openvpn/outbound.go index 5ccec1a5..bf1df8b5 100644 --- a/protocol/openvpn/outbound.go +++ b/protocol/openvpn/outbound.go @@ -104,6 +104,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL AllowedAddress: options.AllowedIPs, ReconnectDelay: time.Duration(options.ReconnectDelay), PingInterval: time.Duration(options.PingInterval), + PingRestart: time.Duration(options.PingRestart), }) if err != nil { return nil, err diff --git a/protocol/snell/inbound.go b/protocol/snell/inbound.go new file mode 100644 index 00000000..3c8c3cb7 --- /dev/null +++ b/protocol/snell/inbound.go @@ -0,0 +1,130 @@ +package snell + +import ( + "context" + "net" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/common/listener" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/snell" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func RegisterInbound(registry *inbound.Registry) { + inbound.Register[option.SnellInboundOptions](registry, C.TypeSnell, NewInbound) +} + +type Inbound struct { + inbound.Adapter + router adapter.ConnectionRouterEx + logger logger.ContextLogger + listener *listener.Listener + service *snell.Service +} + +func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellInboundOptions) (adapter.Inbound, error) { + if options.PSK == "" { + return nil, E.New("snell requires psk") + } + udpEnabled := common.Contains(options.Network.Build(), N.NetworkUDP) + obfsMode := "" + if options.Obfs != nil { + obfsMode = options.Obfs.Mode + } + in := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeSnell, tag), + router: router, + logger: logger, + } + service, err := snell.NewService(snell.ServiceOptions{ + PSK: []byte(options.PSK), + Version: options.Version, + ObfsMode: obfsMode, + UDP: udpEnabled, + Logger: logger, + Handler: (*inboundHandler)(in), + }) + if err != nil { + return nil, err + } + in.service = service + in.listener = listener.New(listener.Options{ + Context: ctx, + Logger: logger, + Network: []string{N.NetworkTCP}, + Listen: options.ListenOptions, + ConnectionHandler: in, + }) + return in, nil +} + +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + return h.listener.Start() +} + +func (h *Inbound) Close() error { + return h.listener.Close() +} + +func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + err := h.service.NewConnection(ctx, conn, metadata.Source) + N.CloseOnHandshakeFailure(conn, onClose, err) + if err != nil { + h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) + } +} + +var _ adapter.TCPInjectableInbound = (*Inbound)(nil) + +type inboundHandler Inbound + +func (h *inboundHandler) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string) { + var metadata adapter.InboundContext + metadata.Inbound = h.Tag() + metadata.InboundType = h.Type() + metadata.InboundDetour = h.listener.ListenOptions().Detour + metadata.Source = source + metadata.Destination = destination + if clientID != "" { + metadata.User = clientID + h.logger.InfoContext(ctx, "[", clientID, "] inbound connection to ", destination) + } else { + h.logger.InfoContext(ctx, "inbound connection to ", destination) + } + done := make(chan struct{}) + h.router.RouteConnectionEx(ctx, conn, metadata, N.OnceClose(func(error) { + close(done) + })) + <-done +} + +func (h *inboundHandler) NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string) { + var metadata adapter.InboundContext + metadata.Inbound = h.Tag() + metadata.InboundType = h.Type() + metadata.InboundDetour = h.listener.ListenOptions().Detour + metadata.Source = source + if clientID != "" { + metadata.User = clientID + h.logger.InfoContext(ctx, "[", clientID, "] inbound packet connection") + } else { + h.logger.InfoContext(ctx, "inbound packet connection") + } + done := make(chan struct{}) + h.router.RoutePacketConnectionEx(ctx, bufio.NewPacketConn(conn), metadata, N.OnceClose(func(error) { + close(done) + })) + <-done +} diff --git a/protocol/snell/outbound.go b/protocol/snell/outbound.go new file mode 100644 index 00000000..7a87174f --- /dev/null +++ b/protocol/snell/outbound.go @@ -0,0 +1,114 @@ +package snell + +import ( + "context" + "fmt" + "net" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/snell" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func RegisterOutbound(registry *outbound.Registry) { + outbound.Register[option.SnellOutboundOptions](registry, C.TypeSnell, NewOutbound) +} + +type Outbound struct { + outbound.Adapter + logger logger.ContextLogger + client *snell.Client +} + +func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellOutboundOptions) (adapter.Outbound, error) { + if options.PSK == "" { + return nil, E.New("snell requires psk") + } + version := options.Version + if version == 0 { + version = snell.DefaultSnellVersion + } + if version == snell.Version5 { + version = snell.Version4 + } + udpEnabled := common.Contains(options.Network.Build(), N.NetworkUDP) + switch version { + case snell.Version1, snell.Version2: + if udpEnabled { + return nil, fmt.Errorf("snell version %d does not support UDP", version) + } + case snell.Version3, snell.Version4: + default: + return nil, fmt.Errorf("snell version error: %d", version) + } + reuse := version == snell.Version2 || (version == snell.Version4 && options.Reuse) + obfsMode := "" + obfsHost := "bing.com" + if options.Obfs != nil { + switch options.Obfs.Mode { + case "", "tls", "http": + obfsMode = options.Obfs.Mode + default: + return nil, fmt.Errorf("snell obfs mode error: %s", options.Obfs.Mode) + } + if options.Obfs.Host != "" { + obfsHost = options.Obfs.Host + } + } + outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain()) + if err != nil { + return nil, err + } + client := snell.NewClient(snell.ClientOptions{ + Dialer: outboundDialer, + Server: options.ServerOptions.Build(), + PSK: []byte(options.PSK), + Version: version, + Reuse: reuse, + ObfsMode: obfsMode, + ObfsHost: obfsHost, + }) + return &Outbound{ + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSnell, tag, options.Network.Build(), options.DialerOptions), + logger: logger, + client: client, + }, nil +} + +func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Outbound = h.Tag() + metadata.Destination = destination + switch N.NetworkName(network) { + case N.NetworkTCP: + h.logger.InfoContext(ctx, "outbound connection to ", destination) + return h.client.DialContext(ctx, destination) + case N.NetworkUDP: + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + conn, err := h.client.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return bufio.NewBindPacketConn(conn, destination), nil + default: + return nil, E.Extend(N.ErrUnknownNetwork, network) + } +} + +func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Outbound = h.Tag() + metadata.Destination = destination + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + return h.client.ListenPacket(ctx, destination) +} diff --git a/protocol/trusttunnel/inbound.go b/protocol/trusttunnel/inbound.go index 7e02504e..f9065c63 100644 --- a/protocol/trusttunnel/inbound.go +++ b/protocol/trusttunnel/inbound.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/quic-go/http3" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/common/congestion" "github.com/sagernet/sing-box/common/listener" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" @@ -136,10 +137,9 @@ func (h *Inbound) Start(stage adapter.StartStage) error { if err != nil { return err } - congestionControlFactory, err := trusttunnel.NewCongestionControl( + congestionControlFactory, err := congestion.NewCongestionControl( h.options.CongestionController, h.options.CWND, - h.options.BBRProfile, ntp.TimeFuncFromContext(h.ctx), ) if err != nil { diff --git a/protocol/trusttunnel/outbound.go b/protocol/trusttunnel/outbound.go index 5b95500c..4371de76 100644 --- a/protocol/trusttunnel/outbound.go +++ b/protocol/trusttunnel/outbound.go @@ -53,7 +53,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL QUIC: options.QUIC, CongestionControl: options.CongestionController, CWND: options.CWND, - BBRProfile: options.BBRProfile, + Logger: logger, HealthCheck: options.HealthCheck, } var client trusttunnel.Dialer diff --git a/release/DEFAULT_BUILD_TAGS b/release/DEFAULT_BUILD_TAGS index 03380949..288202d1 100644 --- a/release/DEFAULT_BUILD_TAGS +++ b/release/DEFAULT_BUILD_TAGS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_naive_outbound,badlinkname,tfogo_checklinkname0 \ No newline at end of file +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,with_naive_outbound,badlinkname,tfogo_checklinkname0 \ No newline at end of file diff --git a/release/DEFAULT_BUILD_TAGS_OTHERS b/release/DEFAULT_BUILD_TAGS_OTHERS index 3c99d6ad..c3d0605a 100644 --- a/release/DEFAULT_BUILD_TAGS_OTHERS +++ b/release/DEFAULT_BUILD_TAGS_OTHERS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_manager,with_admin_panel,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,badlinkname,tfogo_checklinkname0 +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_manager,with_admin_panel,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,badlinkname,tfogo_checklinkname0 diff --git a/release/DEFAULT_BUILD_TAGS_WINDOWS b/release/DEFAULT_BUILD_TAGS_WINDOWS index bd722436..a89390bb 100644 --- a/release/DEFAULT_BUILD_TAGS_WINDOWS +++ b/release/DEFAULT_BUILD_TAGS_WINDOWS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0 \ No newline at end of file +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0 \ No newline at end of file diff --git a/release/config/openwrt.prerm b/release/config/openwrt.prerm index 12d06ec7..e1106da6 100755 --- a/release/config/openwrt.prerm +++ b/release/config/openwrt.prerm @@ -1,4 +1,4 @@ #!/bin/sh [ -s ${IPKG_INSTROOT}/lib/functions.sh ] || exit 0 . ${IPKG_INSTROOT}/lib/functions.sh -default_prerm $0 $@ +default_prerm $0 $@ || true diff --git a/route/route.go b/route/route.go index ec6ca399..2293fd86 100644 --- a/route/route.go +++ b/route/route.go @@ -58,6 +58,13 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata } func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + select { + case <-r.started: + case <-ctx.Done(): + return ctx.Err() + case <-r.ctx.Done(): + return r.ctx.Err() + } //nolint:staticcheck if metadata.InboundDetour != "" { if metadata.LastInbound == metadata.InboundDetour { @@ -192,6 +199,13 @@ func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, } func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + select { + case <-r.started: + case <-ctx.Done(): + return ctx.Err() + case <-r.ctx.Done(): + return r.ctx.Err() + } //nolint:staticcheck if metadata.InboundDetour != "" { if metadata.LastInbound == metadata.InboundDetour { diff --git a/route/route_start_test.go b/route/route_start_test.go new file mode 100644 index 00000000..f3d10019 --- /dev/null +++ b/route/route_start_test.go @@ -0,0 +1,146 @@ +package route + +import ( + "context" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" +) + +// newGateTestRouter builds the minimal Router needed to exercise the +// "wait until started" gate in routeConnection / routePacketConnection. +func newGateTestRouter(ctx context.Context) *Router { + return &Router{ + ctx: ctx, + started: make(chan struct{}), + } +} + +// gateMetadata returns metadata that hits the InboundDetour loop-detection +// branch (LastInbound == InboundDetour), so routeConnection / +// routePacketConnection return immediately once the gate is open, without +// needing any outbound/inbound managers. +func gateMetadata() adapter.InboundContext { + return adapter.InboundContext{InboundDetour: "self", LastInbound: "self"} +} + +// TestRouteConnectionWaitsForStart verifies that a connection arriving before +// the router finishes starting (StartStatePostStart) blocks until the started +// channel is closed, then proceeds, instead of dereferencing a nil +// defaultOutbound. +func TestRouteConnectionWaitsForStart(t *testing.T) { + r := newGateTestRouter(context.Background()) + + done := make(chan error, 1) + go func() { + done <- r.routeConnection(context.Background(), nil, gateMetadata(), nil) + }() + + // The gate must block while the router is not yet started. + select { + case <-done: + t.Fatal("routeConnection returned before router was started") + case <-time.After(50 * time.Millisecond): + } + + // Simulate StartStatePostStart completing. + close(r.started) + + select { + case err := <-done: + // We expect to have passed the gate and reached the loop-detection branch, + // which returns the "routing loop on detour" error. + if err == nil { + t.Fatal("expected routing-loop error after gate opened, got nil") + } + case <-time.After(time.Second): + t.Fatal("routeConnection did not proceed after router started") + } +} + +// TestRoutePacketConnectionWaitsForStart is the UDP counterpart. +func TestRoutePacketConnectionWaitsForStart(t *testing.T) { + r := newGateTestRouter(context.Background()) + + done := make(chan error, 1) + go func() { + done <- r.routePacketConnection(context.Background(), nil, gateMetadata(), nil) + }() + + select { + case <-done: + t.Fatal("routePacketConnection returned before router was started") + case <-time.After(50 * time.Millisecond): + } + + close(r.started) + + select { + case err := <-done: + if err == nil { + t.Fatal("expected routing-loop error after gate opened, got nil") + } + case <-time.After(time.Second): + t.Fatal("routePacketConnection did not proceed after router started") + } +} + +// TestRouteConnectionAbortsOnConnContext verifies that a client disconnecting +// during startup unblocks the gate via the per-connection context, instead of +// hanging until the router starts. +func TestRouteConnectionAbortsOnConnContext(t *testing.T) { + r := newGateTestRouter(context.Background()) + + connCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + done <- r.routeConnection(connCtx, nil, gateMetadata(), nil) + }() + + select { + case <-done: + t.Fatal("routeConnection returned before context was cancelled") + case <-time.After(50 * time.Millisecond): + } + + cancel() + + select { + case err := <-done: + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } + case <-time.After(time.Second): + t.Fatal("routeConnection did not abort on connection context cancellation") + } +} + +// TestRouteConnectionAbortsOnRouterContext verifies that shutting down the box +// (router context cancellation) unblocks an in-flight early connection. +func TestRouteConnectionAbortsOnRouterContext(t *testing.T) { + routerCtx, cancel := context.WithCancel(context.Background()) + r := newGateTestRouter(routerCtx) + + done := make(chan error, 1) + go func() { + done <- r.routeConnection(context.Background(), nil, gateMetadata(), nil) + }() + + select { + case <-done: + t.Fatal("routeConnection returned before router context was cancelled") + case <-time.After(50 * time.Millisecond): + } + + cancel() + + select { + case err := <-done: + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } + case <-time.After(time.Second): + t.Fatal("routeConnection did not abort on router context cancellation") + } +} diff --git a/route/router.go b/route/router.go index 04d41322..94ea3cac 100644 --- a/route/router.go +++ b/route/router.go @@ -44,7 +44,7 @@ type Router struct { pauseManager pause.Manager trackers []adapter.ConnectionTracker platformInterface adapter.PlatformInterface - started bool + started chan struct{} } func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) *Router { @@ -63,6 +63,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess, pauseManager: service.FromContext[pause.Manager](ctx), platformInterface: service.FromContext[adapter.PlatformInterface](ctx), + started: make(chan struct{}), } } @@ -180,7 +181,7 @@ func (r *Router) Start(stage adapter.StartStage) error { } else { r.defaultOutbound = r.outbound.Default() } - r.started = true + close(r.started) return nil case adapter.StartStateStarted: for _, ruleSet := range r.ruleSets { diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index fb60a4d7..d24c8d73 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -29,13 +29,17 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti case "": return nil, nil case C.RuleActionTypeRoute: - overrideGateway := M.ParseAddr(action.RouteOptions.OverrideGateway) + var overrideGateway *netip.Addr + if action.RouteOptions.OverrideGateway != "" { + parsed := M.ParseAddr(action.RouteOptions.OverrideGateway) + overrideGateway = &parsed + } return &RuleActionRoute{ Outbound: action.RouteOptions.Outbound, RuleActionRouteOptions: RuleActionRouteOptions{ OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptions.OverrideAddress, 0), OverridePort: action.RouteOptions.OverridePort, - OverrideGateway: &overrideGateway, + OverrideGateway: overrideGateway, NetworkStrategy: (*C.NetworkStrategy)(action.RouteOptions.NetworkStrategy), FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, diff --git a/test/go.mod b/test/go.mod index 5f6e1d15..432bb001 100644 --- a/test/go.mod +++ b/test/go.mod @@ -1,6 +1,6 @@ module test -go 1.26.1 +go 1.26.4 require github.com/sagernet/sing-box v0.0.0 @@ -14,15 +14,17 @@ replace github.com/sagernet/sing-mux => github.com/shtorm-7/sing-mux v0.3.4-exte replace github.com/ameshkov/dnscrypt/v2 => github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 -replace github.com/sagernet/sing-vmess => github.com/starifly/sing-vmess v0.2.7-mod.9 +replace github.com/sagernet/sing-vmess => github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0 -replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.0.0 +replace github.com/sagernet/sing => /home/shtorm/Projects/shtorm-7/sing -replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1 +replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 -replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 +replace github.com/shtorm-7/go-cache/v2 => /home/shtorm/Projects/shtorm-7/go-cache + +replace github.com/sagernet/smux => /home/shtorm/Projects/shtorm-7/smux require ( github.com/docker/docker v28.5.2+incompatible @@ -36,7 +38,6 @@ require ( github.com/spyzhov/ajson v0.9.4 github.com/stretchr/testify v1.11.1 go.uber.org/goleak v1.3.0 - golang.org/x/crypto v0.49.0 golang.org/x/net v0.52.0 ) @@ -221,6 +222,7 @@ require ( go.uber.org/zap/exp v0.3.0 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect + golang.org/x/crypto v0.49.0 // indirect golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect golang.org/x/mod v0.34.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect diff --git a/test/go.sum b/test/go.sum index 0ec1be55..ae169fc0 100644 --- a/test/go.sum +++ b/test/go.sum @@ -362,8 +362,6 @@ github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA= github.com/sagernet/sing-tun v0.8.9 h1:ixFKKUGdVcJl4wb0xbL36hobiw9l6DIH497EQf5ILpM= github.com/sagernet/sing-tun v0.8.9/go.mod h1:QvarqUtHfj1ULaRR+6kZOS/OoCE+pYGq67A5tyIy+dQ= -github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478= -github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA= github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 h1:ws7BIsYLd31Wjifq88BYCHRVlgO+07iwil39s6ERba8= @@ -372,12 +370,14 @@ github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 h1:e5s7RKBd2rIPR0StbvZ2vTV github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI= github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 h1:PLZ/YHqnApPx13wt6MX3ItqESp4ueBr1tGSi0bEGqYw= github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4= -github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1 h1:UeJkrCJJmIjTBywErVMx7fCSoBf4gh6QgT9bp9o1ajM= -github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0= -github.com/shtorm-7/sing v0.8.10-extended-1.0.0 h1:mAkyycCQOzCttPOR5fcHkJaZvXMQXeu3mbEfr8D+7A8= -github.com/shtorm-7/sing v0.8.10-extended-1.0.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA= +github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 h1:iBLll4ZZG8ULQcHWs6gGslZWtBN72Zo1zjySzMVHF7g= +github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0= +github.com/shtorm-7/sing v0.8.10-extended-1.1.0 h1:P4JL2cugjvEvnYu8tMmpR30SE1qsS45RcnNEwzDz5as= +github.com/shtorm-7/sing v0.8.10-extended-1.1.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= +github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0 h1:WVheKmQH5hSQbJU1ZTKthKSutkTLWSb2hp4JuQhJBow= +github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs= github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2 h1:hSMjh97OszszOd8HrzpaYUQH9dWRRBluJCbwQyz8ZOk= github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2/go.mod h1:TYIIqO5sZpWq873rLIeO2usszSMUpR3h6WdqVVs65ug= github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.4.3 h1:jtOA73D4F5qRV70//ahOt20KBnWvQimAFjtIiOtt0ps= @@ -388,8 +388,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spyzhov/ajson v0.9.4 h1:MVibcTCgO7DY4IlskdqIlCmDOsUOZ9P7oKj8ifdcf84= github.com/spyzhov/ajson v0.9.4/go.mod h1:a6oSw0MMb7Z5aD2tPoPO+jq11ETKgXUr2XktHdT8Wt8= -github.com/starifly/sing-vmess v0.2.7-mod.9 h1:xobAmejSbBQ0A3f/EtJ9cJd3m6gK7dDPccPdeGz7tXY= -github.com/starifly/sing-vmess v0.2.7-mod.9/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= diff --git a/test/v2ray_grpc_test.go b/test/v2ray_grpc_test.go index 884cc42e..69915833 100644 --- a/test/v2ray_grpc_test.go +++ b/test/v2ray_grpc_test.go @@ -217,3 +217,19 @@ func TestV2RayGRPCLite(t *testing.T) { }) }) } + +func TestV2RayGRPCLiteServiceNameWithSlashes(t *testing.T) { + testV2RayTransportSelfWith(t, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "/47559/I53GwKHO", + ForceLite: true, + }, + }, &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: "/47559/I53GwKHO", + ForceLite: true, + }, + }) +} diff --git a/transport/masque/client_h2.go b/transport/masque/client_h2.go new file mode 100644 index 00000000..3b9ce86e --- /dev/null +++ b/transport/masque/client_h2.go @@ -0,0 +1,331 @@ +package masque + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" + + "github.com/sagernet/quic-go/quicvarint" + "github.com/yosida95/uritemplate/v3" + "golang.org/x/net/http2" +) + +const h2DatagramCapsuleType uint64 = 0 + +const ( + ipv4HeaderLen = 20 + ipv6HeaderLen = 40 +) + +func ConnectTunnelH2(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, endpoint *net.TCPAddr, connectUri string) (io.Closer, IpConn, *http.Response, error) { + if endpoint == nil { + return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint") + } + + tlsConfig.SetNextProtos([]string{"h2"}) + + conn, err := dialer.DialContext(ctx, N.NetworkTCP, M.SocksaddrFromNetIP(endpoint.AddrPort())) + if err != nil { + return nil, nil, nil, err + } + tlsConn, err := tlsConfig.Client(conn) + if err != nil { + _ = conn.Close() + return nil, nil, nil, err + } + if err = tlsConn.HandshakeContext(ctx); err != nil { + _ = conn.Close() + return nil, nil, nil, err + } + + tr := &http2.Transport{ + ReadIdleTimeout: 30 * time.Second, + } + cc, err := tr.NewClientConn(tlsConn) + if err != nil { + _ = tlsConn.Close() + return nil, nil, nil, fmt.Errorf("connect-ip: failed to create client connection: %w", err) + } + + additionalHeaders := http.Header{ + "User-Agent": []string{""}, + } + template := uritemplate.MustNew(connectUri) + + h2Headers := additionalHeaders.Clone() + h2Headers.Set("cf-connect-proto", "cf-connect-ip") + h2Headers.Set("pq-enabled", "false") + + ipConn, rsp, err := dialH2(ctx, cc, template, h2Headers) + if err != nil { + _ = cc.Close() + if strings.Contains(err.Error(), "tls: access denied") { + return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") + } + return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err) + } + + if rsp.StatusCode != http.StatusOK { + _ = ipConn.Close() + _ = cc.Close() + return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status) + } + + return cc, ipConn, rsp, nil +} + +func dialH2(ctx context.Context, rt http.RoundTripper, template *uritemplate.Template, additionalHeaders http.Header) (*h2IpConn, *http.Response, error) { + if len(template.Varnames()) > 0 { + return nil, nil, errors.New("connect-ip: IP flow forwarding not supported") + } + + u, err := url.Parse(template.Raw()) + if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to parse URI: %w", err) + } + + reqCtx, cancel := context.WithCancel(context.Background()) + + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, u.String(), pr) + if err != nil { + cancel() + _ = pr.Close() + _ = pw.Close() + return nil, nil, fmt.Errorf("connect-ip: failed to create request: %w", err) + } + req.Host = authorityFromURL(u) + req.ContentLength = -1 + req.Header = make(http.Header) + for k, v := range additionalHeaders { + req.Header[k] = v + } + + stop := context.AfterFunc(ctx, cancel) + rsp, err := rt.RoundTrip(req) + stop() + if err != nil { + cancel() + _ = pr.Close() + _ = pw.Close() + return nil, nil, fmt.Errorf("connect-ip: failed to send request: %w", err) + } + if rsp.StatusCode < 200 || rsp.StatusCode > 299 { + cancel() + _ = pr.Close() + _ = pw.Close() + _ = rsp.Body.Close() + return nil, rsp, fmt.Errorf("connect-ip: server responded with %d", rsp.StatusCode) + } + + stream := &h2DatagramStream{ + requestBody: pw, + responseBody: rsp.Body, + cancel: cancel, + } + return &h2IpConn{ + str: stream, + closeChan: make(chan struct{}), + }, rsp, nil +} + +func authorityFromURL(u *url.URL) string { + if u.Port() != "" { + return u.Host + } + host := u.Hostname() + if host == "" { + return u.Host + } + return host + ":443" +} + +type h2IpConn struct { + str *h2DatagramStream + + mu sync.Mutex + + closeChan chan struct{} + closeErr error +} + +func (c *h2IpConn) ReadPacket() (b []byte, err error) { +start: + data, err := c.str.ReceiveDatagram(context.Background()) + if err != nil { + defer func() { + _ = c.Close() + }() + select { + case <-c.closeChan: + return nil, c.closeErr + default: + return nil, err + } + } + if err := c.handleIncomingProxiedPacket(data); err != nil { + goto start + } + return data, nil +} + +func (c *h2IpConn) handleIncomingProxiedPacket(data []byte) error { + if len(data) == 0 { + return errors.New("connect-ip: empty packet") + } + switch v := ipVersion(data); v { + default: + return fmt.Errorf("connect-ip: unknown IP versions: %d", v) + case 4: + if len(data) < ipv4HeaderLen { + return fmt.Errorf("connect-ip: malformed datagram: too short") + } + case 6: + if len(data) < ipv6HeaderLen { + return fmt.Errorf("connect-ip: malformed datagram: too short") + } + } + return nil +} + +func (c *h2IpConn) WritePacket(b []byte) (icmp []byte, err error) { + data, err := c.composeDatagram(b) + if err != nil { + return nil, nil + } + if err := c.str.SendDatagram(data); err != nil { + select { + case <-c.closeChan: + return nil, c.closeErr + default: + return nil, err + } + } + return nil, nil +} + +func (c *h2IpConn) composeDatagram(b []byte) ([]byte, error) { + if len(b) == 0 { + return nil, nil + } + switch v := ipVersion(b); v { + default: + return nil, fmt.Errorf("connect-ip: unknown IP versions: %d", v) + case 4: + if len(b) < ipv4HeaderLen { + return nil, fmt.Errorf("connect-ip: IPv4 packet too short") + } + ttl := b[8] + if ttl <= 1 { + return nil, fmt.Errorf("connect-ip: datagram TTL too small: %d", ttl) + } + b[8]-- + binary.BigEndian.PutUint16(b[10:12], calculateIPv4Checksum(([ipv4HeaderLen]byte)(b[:ipv4HeaderLen]))) + case 6: + if len(b) < ipv6HeaderLen { + return nil, fmt.Errorf("connect-ip: IPv6 packet too short") + } + hopLimit := b[7] + if hopLimit <= 1 { + return nil, fmt.Errorf("connect-ip: datagram Hop Limit too small: %d", hopLimit) + } + b[7]-- + } + return b, nil +} + +func (c *h2IpConn) Close() error { + c.mu.Lock() + if c.closeErr == nil { + c.closeErr = net.ErrClosed + close(c.closeChan) + } + c.mu.Unlock() + err := c.str.Close() + return err +} + +func ipVersion(b []byte) uint8 { return b[0] >> 4 } + +func calculateIPv4Checksum(header [ipv4HeaderLen]byte) uint16 { + var sum uint32 + for i := 0; i < len(header); i += 2 { + if i == 10 { + continue + } + sum += uint32(binary.BigEndian.Uint16(header[i : i+2])) + } + for (sum >> 16) > 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return ^uint16(sum) +} + +type h2DatagramStream struct { + requestBody *io.PipeWriter + responseBody io.ReadCloser + cancel context.CancelFunc + + readMu sync.Mutex + writeMu sync.Mutex +} + +func (s *h2DatagramStream) ReceiveDatagram(_ context.Context) ([]byte, error) { + s.readMu.Lock() + defer s.readMu.Unlock() + + reader := quicvarint.NewReader(s.responseBody) + for { + capsuleType, err := quicvarint.Read(reader) + if err != nil { + return nil, err + } + payloadLen, err := quicvarint.Read(reader) + if err != nil { + return nil, err + } + payload := make([]byte, payloadLen) + _, err = io.ReadFull(reader, payload) + if err != nil { + return nil, err + } + if capsuleType != h2DatagramCapsuleType { + continue + } + return payload, nil + } +} + +func (s *h2DatagramStream) SendDatagram(data []byte) error { + frame := make([]byte, 0, quicvarint.Len(h2DatagramCapsuleType)+quicvarint.Len(uint64(len(data)))+len(data)) + frame = quicvarint.Append(frame, h2DatagramCapsuleType) + frame = quicvarint.Append(frame, uint64(len(data))) + frame = append(frame, data...) + + s.writeMu.Lock() + defer s.writeMu.Unlock() + _, err := s.requestBody.Write(frame) + if err != nil { + return fmt.Errorf("connect-ip: failed to send datagram capsule: %w", err) + } + return nil +} + +func (s *h2DatagramStream) Close() error { + _ = s.requestBody.Close() + err := s.responseBody.Close() + s.cancel() + return err +} diff --git a/transport/masque/masque.go b/transport/masque/masque.go index 477eb9d4..bbed11e0 100644 --- a/transport/masque/masque.go +++ b/transport/masque/masque.go @@ -2,9 +2,9 @@ package masque import ( "context" - "crypto/tls" "errors" "fmt" + "io" "net" "net/http" "net/netip" @@ -12,13 +12,13 @@ import ( connectip "github.com/Diniboy1123/connect-ip-go" "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/congestion" "github.com/sagernet/quic-go/http3" qtls "github.com/sagernet/sing-quic" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" aTLS "github.com/sagernet/sing/common/tls" "github.com/yosida95/uritemplate/v3" - "golang.org/x/net/http2" ) type ( @@ -26,39 +26,60 @@ type ( ListenPacket func(network string, address string) (net.PacketConn, error) ) -func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool) (net.PacketConn, *http3.Transport, *connectip.Conn, *http.Response, error) { - template := uritemplate.MustNew(connectUri) - additionalHeaders := http.Header{ - "User-Agent": []string{""}, +type IpConn interface { + ReadPacket() (b []byte, err error) + WritePacket(b []byte) (icmp []byte, err error) + Close() error +} + +type closerFunc func() error + +func (f closerFunc) Close() error { return f() } + +type quicIpConn struct { + conn *connectip.Conn + buf []byte +} + +func newQuicIpConn(conn *connectip.Conn) *quicIpConn { + return &quicIpConn{ + conn: conn, + buf: make([]byte, 0xFFFF), } +} + +func (c *quicIpConn) ReadPacket() ([]byte, error) { + n, err := c.conn.ReadPacket(c.buf, true) + if err != nil { + return nil, err + } + return c.buf[:n], nil +} + +func (c *quicIpConn) WritePacket(b []byte) (icmp []byte, err error) { + return c.conn.WritePacket(b) +} + +func (c *quicIpConn) Close() error { + return c.conn.Close() +} + +func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool, congestionControl func(conn *quic.Conn) congestion.CongestionControl) (io.Closer, IpConn, *http.Response, error) { if useHTTP2 { h2Endpoint, ok := endpoint.(*net.TCPAddr) if !ok || h2Endpoint == nil { - return nil, nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint") + return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint") } - h2Headers := additionalHeaders.Clone() - h2Headers.Set("cf-connect-proto", "cf-connect-ip") - h2Headers.Set("pq-enabled", "false") - h2Client, err := newHTTP2Client(dialer, tlsConfig, h2Endpoint, connectUri) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("failed to create HTTP/2 client: %w", err) - } - ipConn, rsp, err := connectip.DialH2(ctx, h2Client, template, h2Headers) - if err != nil { - if strings.Contains(err.Error(), "tls: access denied") { - return nil, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") - } - return nil, nil, nil, nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err) - } - return nil, nil, ipConn, rsp, nil + return ConnectTunnelH2(ctx, dialer, tlsConfig, h2Endpoint, connectUri) } + quicEndpoint, ok := endpoint.(*net.UDPAddr) if !ok || quicEndpoint == nil { - return nil, nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint") + return nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint") } udpConn, err := dialer.ListenPacket(ctx, M.SocksaddrFromNetIP(quicEndpoint.AddrPort())) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } conn, err := qtls.Dial( ctx, @@ -68,28 +89,34 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig, ) if err != nil { - return nil, nil, nil, nil, err + _ = udpConn.Close() + return nil, nil, nil, err + } + if congestionControl != nil { + conn.SetCongestionControl(congestionControl(conn)) } tr := &http3.Transport{ EnableDatagrams: true, AdditionalSettings: map[uint64]uint64{ - // official client still sends this out as well, even though - // it's deprecated, see https://datatracker.ietf.org/doc/draft-ietf-masque-h3-datagram/00/ - // SETTINGS_H3_DATAGRAM_00 = 0x0000000000000276 - // https://github.com/cloudflare/quiche/blob/7c66757dbc55b8d0c3653d4b345c6785a181f0b7/quiche/src/h3/frame.rs#L46 0x276: 1, }, DisableCompression: true, } hconn := tr.NewClientConn(conn) + + template := uritemplate.MustNew(connectUri) + additionalHeaders := http.Header{ + "User-Agent": []string{""}, + } ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true) if err != nil { _ = tr.Close() _ = conn.CloseWithError(0, "connect-ip dial failed") + _ = udpConn.Close() if strings.Contains(err.Error(), "tls: access denied") { - return udpConn, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") + return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") } - return udpConn, nil, nil, nil, fmt.Errorf("failed to dial connect-ip: %w", err) + return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %w", err) } err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{ { @@ -109,34 +136,16 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, }, }) if err != nil { - return udpConn, nil, nil, nil, err + _ = ipConn.Close() + _ = tr.Close() + _ = udpConn.Close() + return nil, nil, nil, err } - return udpConn, tr, ipConn, rsp, nil -} -func newHTTP2Client(dialer N.Dialer, baseTLSConfig aTLS.Config, endpoint *net.TCPAddr, connectURI string) (*http.Client, error) { - if endpoint == nil { - return nil, errors.New("missing HTTP/2 endpoint") - } - tlsConfig := baseTLSConfig.Clone() - tlsConfig.SetNextProtos([]string{"h2"}) - return &http.Client{ - Transport: &http2.Transport{ - DialTLSContext: func(ctx context.Context, network, _ string, _ *tls.Config) (net.Conn, error) { - conn, err := dialer.DialContext(ctx, network, M.SocksaddrFromNetIP(endpoint.AddrPort())) - if err != nil { - return nil, err - } - tlsConn, err := tlsConfig.Client(conn) - if err != nil { - return nil, err - } - if err := tlsConn.HandshakeContext(ctx); err != nil { - _ = conn.Close() - return nil, err - } - return tlsConn, nil - }, - }, - }, nil + closer := closerFunc(func() error { + _ = tr.Close() + _ = udpConn.Close() + return nil + }) + return closer, newQuicIpConn(ipConn), rsp, nil } diff --git a/transport/masque/options.go b/transport/masque/options.go index 3dfe1647..be65a1cd 100644 --- a/transport/masque/options.go +++ b/transport/masque/options.go @@ -5,6 +5,8 @@ import ( "net/netip" "time" + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/congestion" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/tls" ) @@ -23,4 +25,5 @@ type TunnelOptions struct { UDPKeepalivePeriod time.Duration UDPInitialPacketSize uint16 ReconnectDelay time.Duration + CongestionControl func(conn *quic.Conn) congestion.CongestionControl } diff --git a/transport/masque/tunnel.go b/transport/masque/tunnel.go index e58a5251..a8bc8cef 100644 --- a/transport/masque/tunnel.go +++ b/transport/masque/tunnel.go @@ -4,13 +4,12 @@ import ( "context" "errors" "fmt" + "io" "net" "os" "sync" "time" - connectip "github.com/Diniboy1123/connect-ip-go" - "github.com/sagernet/quic-go/http3" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -22,9 +21,8 @@ type Tunnel struct { options TunnelOptions device Device - udpConn net.PacketConn - tr *http3.Transport - ipConn *connectip.Conn + closer io.Closer + ipConn IpConn mtx sync.Mutex } @@ -83,13 +81,11 @@ func (e *Tunnel) Close() error { defer e.mtx.Unlock() if e.ipConn != nil { e.ipConn.Close() - if e.udpConn != nil { - e.udpConn.Close() - } - if e.tr != nil { - e.tr.Close() + if e.closer != nil { + e.closer.Close() } e.ipConn = nil + e.closer = nil } return e.device.Close() } @@ -124,7 +120,7 @@ func (e *Tunnel) maintainTunnel() { } icmp, err := ipConn.WritePacket(packet) if err != nil { - if errors.As(err, new(*connectip.CloseError)) { + if errors.Is(err, net.ErrClosed) { if ok := e.closeIpConn(ipConn); ok { e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing to IP connection: %w", err)) } @@ -135,7 +131,7 @@ func (e *Tunnel) maintainTunnel() { } if len(icmp) > 0 { if _, err := e.device.Write([][]byte{icmp}, 0); err != nil { - if errors.As(err, new(*connectip.CloseError)) { + if errors.Is(err, net.ErrClosed) { e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing ICMP to TUN device: %v", err)) continue } @@ -145,15 +141,14 @@ func (e *Tunnel) maintainTunnel() { } }() go func() { - buf := make([]byte, 1280) for e.ctx.Err() == nil { ipConn, err := e.getIpConn() if err != nil { return } - n, err := ipConn.ReadPacket(buf, true) + packet, err := ipConn.ReadPacket() if err != nil { - if e.options.UseHTTP2 || errors.As(err, new(*connectip.CloseError)) { + if e.options.UseHTTP2 || errors.Is(err, net.ErrClosed) { if ok := e.closeIpConn(ipConn); ok { e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while reading from IP connection: %v", err)) } @@ -162,7 +157,7 @@ func (e *Tunnel) maintainTunnel() { e.logger.ErrorContext(e.ctx, fmt.Errorf("Error reading from IP connection: %v, continuine...", err)) continue } - if _, err := e.device.Write([][]byte{buf[:n]}, 0); err != nil { + if _, err := e.device.Write([][]byte{packet}, 0); err != nil { continue } } @@ -170,7 +165,7 @@ func (e *Tunnel) maintainTunnel() { <-e.ctx.Done() } -func (e *Tunnel) getIpConn() (*connectip.Conn, error) { +func (e *Tunnel) getIpConn() (IpConn, error) { e.mtx.Lock() defer e.mtx.Unlock() if e.ctx.Err() != nil { @@ -184,7 +179,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) { defer timer.Stop() for { e.logger.NoticeContext(e.ctx, fmt.Errorf("Establishing MASQUE connection to %s", e.options.Endpoint)) - udpConn, tr, ipConn, rsp, err := ConnectTunnel( + closer, ipConn, rsp, err := ConnectTunnel( e.ctx, e.options.Dialer, e.options.TLSConfig, @@ -192,6 +187,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) { "https://cloudflareaccess.com", e.options.Endpoint, e.options.UseHTTP2, + e.options.CongestionControl, ) if err != nil { e.logger.ErrorContext(e.ctx, fmt.Errorf("Failed to connect tunnel: %v", err)) @@ -206,11 +202,8 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) { if rsp.StatusCode != 200 { e.logger.ErrorContext(e.ctx, fmt.Errorf("Tunnel connection failed: %s", rsp.Status)) ipConn.Close() - if udpConn != nil { - udpConn.Close() - } - if tr != nil { - tr.Close() + if closer != nil { + closer.Close() } timer.Reset(e.options.ReconnectDelay) select { @@ -220,26 +213,23 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) { } continue } - e.udpConn = udpConn - e.tr = tr + e.closer = closer e.ipConn = ipConn e.logger.NoticeContext(e.ctx, "Connected to MASQUE server ", e.options.Endpoint) return ipConn, nil } } -func (e *Tunnel) closeIpConn(ipConn *connectip.Conn) bool { +func (e *Tunnel) closeIpConn(ipConn IpConn) bool { e.mtx.Lock() defer e.mtx.Unlock() if ipConn == e.ipConn { e.ipConn.Close() - if e.udpConn != nil { - e.udpConn.Close() - } - if e.tr != nil { - e.tr.Close() + if e.closer != nil { + e.closer.Close() } e.ipConn = nil + e.closer = nil return true } return false diff --git a/transport/openvpn/cipher.go b/transport/openvpn/cipher.go index 3ec1dc57..7a5519e0 100644 --- a/transport/openvpn/cipher.go +++ b/transport/openvpn/cipher.go @@ -4,6 +4,7 @@ import ( "crypto/aes" "crypto/cipher" "crypto/hmac" + "crypto/md5" "crypto/rand" "crypto/sha1" "crypto/sha256" @@ -23,7 +24,7 @@ const ( type DataCipher interface { Encrypt(header []byte, packetID uint32, payload []byte) ([]byte, error) - Decrypt(packet []byte, headerSize int) ([]byte, error) + Decrypt(packet []byte, headerSize int) (plaintext []byte, packetID uint32, err error) } type AEADDataCipher struct { @@ -86,9 +87,9 @@ func (g *AEADDataCipher) Encrypt(header []byte, packetID uint32, payload []byte) return out, nil } -func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) { +func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) { if len(packet) < headerSize+4+AESGCMTagSize+1 { - return nil, errors.New("openvpn gcm data packet too short") + return nil, 0, errors.New("openvpn gcm data packet too short") } header := packet[:headerSize] pidBytes := packet[headerSize : headerSize+4] @@ -96,8 +97,13 @@ func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) ciphertext := packet[headerSize+4+AESGCMTagSize:] combined := append(ciphertext, tag...) ad := append(header, pidBytes...) - nonce := g.nonce(binary.BigEndian.Uint32(pidBytes), g.recvImplicitIV) - return g.recv.Open(nil, nonce[:], combined, ad) + packetID := binary.BigEndian.Uint32(pidBytes) + nonce := g.nonce(packetID, g.recvImplicitIV) + plain, err := g.recv.Open(nil, nonce[:], combined, ad) + if err != nil { + return nil, 0, err + } + return plain, packetID, nil } func (g *AEADDataCipher) nonce(packetID uint32, implicit [AESGCMIVSize]byte) [AESGCMIVSize]byte { @@ -127,6 +133,9 @@ func NewCBCCipher(keys *KeyMaterial, auth string) (*CBCDataCipher, error) { var newHash func() hash.Hash var hmacSize int switch auth { + case AuthMD5: + newHash = md5.New + hmacSize = md5.Size case AuthSHA256: newHash = sha256.New hmacSize = sha256.Size @@ -176,34 +185,35 @@ func (c *CBCDataCipher) Encrypt(header []byte, packetID uint32, payload []byte) return out, nil } -func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) { +func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) { minSize := headerSize + c.hmacSize + CBCIVSize + aes.BlockSize if len(packet) < minSize { - return nil, errors.New("openvpn cbc data packet too short") + return nil, 0, errors.New("openvpn cbc data packet too short") } tag := packet[headerSize : headerSize+c.hmacSize] iv := packet[headerSize+c.hmacSize : headerSize+c.hmacSize+CBCIVSize] ct := packet[headerSize+c.hmacSize+CBCIVSize:] if len(ct)%aes.BlockSize != 0 { - return nil, errors.New("openvpn cbc ciphertext not block-aligned") + return nil, 0, errors.New("openvpn cbc ciphertext not block-aligned") } mac := hmac.New(c.newHash, c.recvHMAC) mac.Write(iv) mac.Write(ct) if !hmac.Equal(tag, mac.Sum(nil)) { - return nil, errors.New("openvpn cbc hmac verification failed") + return nil, 0, errors.New("openvpn cbc hmac verification failed") } plain := make([]byte, len(ct)) cipher.NewCBCDecrypter(c.recvBlock, iv).CryptBlocks(plain, ct) padLen := int(plain[len(plain)-1]) if padLen < 1 || padLen > aes.BlockSize { - return nil, errors.New("openvpn cbc invalid padding") + return nil, 0, errors.New("openvpn cbc invalid padding") } plain = plain[:len(plain)-padLen] if len(plain) < 4 { - return nil, errors.New("openvpn cbc payload too short") + return nil, 0, errors.New("openvpn cbc payload too short") } - return plain[4:], nil + packetID := binary.BigEndian.Uint32(plain[:4]) + return plain[4:], packetID, nil } func CipherKeyLength(cipher string) int { diff --git a/transport/openvpn/client.go b/transport/openvpn/client.go index e7adfa7e..1fe180b9 100644 --- a/transport/openvpn/client.go +++ b/transport/openvpn/client.go @@ -8,12 +8,16 @@ import ( "io" "net" "strings" + "sync/atomic" "time" "github.com/sagernet/sing/common/tls" ) -const defaultHandshakeTimeout = 30 * time.Second +const ( + defaultHandshakeTimeout = 30 * time.Second + controlRetransmitDelay = time.Second +) type Client struct { config *ClientConfig @@ -26,6 +30,8 @@ type Client struct { push *PushReply cancel context.CancelFunc + + lastReceiveNano atomic.Int64 } func NewClient(config *ClientConfig, io PacketIO, tlsConfig tls.Config) (*Client, error) { @@ -154,6 +160,7 @@ func (c *Client) Handshake(ctx context.Context) (*PushReply, error) { return nil, err } c.data = NewDataChannel(cipher, push.PeerID, push.CompLZO) + c.markReceive() return push, nil } @@ -181,10 +188,21 @@ func (c *Client) ReadIPPacket(ctx context.Context) ([]byte, error) { if err != nil { continue } + c.markReceive() return plain, nil } } +func (c *Client) SinceReceive() time.Duration { + return time.Duration(int64(time.Since(clientStart)) - c.lastReceiveNano.Load()) +} + +func (c *Client) markReceive() { + c.lastReceiveNano.Store(int64(time.Since(clientStart))) +} + +var clientStart = time.Now().Add(-time.Hour) + func (c *Client) Close() error { if c.cancel != nil { c.cancel() @@ -199,10 +217,24 @@ func (c *Client) Close() error { } func (c *Client) waitServerReset(ctx context.Context) error { + retransmits := 0 for { - packet, err := c.control.Read(ctx) + readCtx := ctx + cancel := func() {} + if c.config.Proto == ProtoUDP { + readCtx, cancel = context.WithTimeout(ctx, controlRetransmitDelay) + } + packet, err := c.control.Read(readCtx) + cancel() if err != nil { - return fmt.Errorf("read hard reset response: %w", err) + if c.config.Proto == ProtoUDP && errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil { + if err := c.control.RetransmitPending(ctx); err != nil { + return fmt.Errorf("retransmit hard reset: %w", err) + } + retransmits++ + continue + } + return fmt.Errorf("read hard reset response after %d retransmits: %w", retransmits, err) } switch packet.Opcode { case PControlHardResetServerV2: diff --git a/transport/openvpn/config.go b/transport/openvpn/config.go index 55f9c895..5c926b59 100644 --- a/transport/openvpn/config.go +++ b/transport/openvpn/config.go @@ -20,6 +20,7 @@ const ( CipherAES256CBC = "AES-256-CBC" CipherCHACHA20POLY = "CHACHA20-POLY1305" + AuthMD5 = "MD5" AuthSHA1 = "SHA1" AuthSHA256 = "SHA256" AuthSHA384 = "SHA384" @@ -107,7 +108,7 @@ func isValidCipher(cipher string) bool { func isValidAuth(auth string) bool { switch auth { - case AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512: + case AuthMD5, AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512: return true } return false diff --git a/transport/openvpn/control.go b/transport/openvpn/control.go index b777424a..b2d64f0f 100644 --- a/transport/openvpn/control.go +++ b/transport/openvpn/control.go @@ -30,8 +30,10 @@ type ControlChannel struct { mu sync.Mutex sendPacketID uint32 sendMessage uint32 + recvMessage uint32 ackPending []uint32 pending map[uint32]*ControlPacket + recvPending map[uint32]*ControlPacket readDeadline time.Time writeDeadline time.Time } @@ -40,9 +42,10 @@ func NewControlChannel(io PacketIO, crypt ControlCrypt, local SessionID) *Contro ch := &ControlChannel{ io: io, - clock: time.Now, - local: local, - pending: make(map[uint32]*ControlPacket), + clock: time.Now, + local: local, + pending: make(map[uint32]*ControlPacket), + recvPending: make(map[uint32]*ControlPacket), } if crypt != nil { ch.encode = func(p *ControlPacket, pid uint32, t uint32) ([]byte, error) { @@ -130,10 +133,23 @@ func (c *ControlChannel) SendAck(ctx context.Context) error { func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) { for { + c.mu.Lock() + if packet, ok := c.recvPending[c.recvMessage]; ok { + delete(c.recvPending, c.recvMessage) + c.recvMessage++ + c.mu.Unlock() + return packet, nil + } + c.mu.Unlock() + packet, err := c.readControlPacket(ctx) if err != nil { return nil, err } + + var deliver *ControlPacket + sendAck := false + c.mu.Lock() if c.remote == (SessionID{}) && packet.LocalSession != c.local { c.remote = packet.LocalSession @@ -144,11 +160,33 @@ func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) { if packet.Opcode.HasMessageID() { c.ackPending = appendAck(c.ackPending, packet.MessageID) } - c.mu.Unlock() - if packet.Opcode == PAckV1 { - continue + + switch { + case packet.Opcode == PAckV1: + case !packet.Opcode.HasMessageID(): + deliver = packet + case packet.MessageID < c.recvMessage: + sendAck = true + case packet.MessageID == c.recvMessage: + deliver = packet + c.recvMessage++ + default: + if _, exists := c.recvPending[packet.MessageID]; !exists { + c.recvPending[packet.MessageID] = packet + } + sendAck = true + } + + c.mu.Unlock() + + if deliver != nil { + return deliver, nil + } + if sendAck { + if err := c.SendAck(ctx); err != nil { + return nil, err + } } - return packet, nil } } @@ -349,11 +387,17 @@ func (c *ControlConn) SetWriteDeadline(t time.Time) error { } type streamPacketIO struct { - conn net.Conn + conn net.Conn + deadlineMu sync.Mutex + readDeadline time.Time + writeDeadline time.Time } type datagramPacketIO struct { - conn net.Conn + conn net.Conn + deadlineMu sync.Mutex + readDeadline time.Time + writeDeadline time.Time } func NewDatagramPacketIO(conn net.Conn) PacketIO { @@ -361,40 +405,23 @@ func NewDatagramPacketIO(conn net.Conn) PacketIO { } func (d *datagramPacketIO) ReadPacket(ctx context.Context) ([]byte, error) { - done := make(chan struct{}) - var ( - packet []byte - err error - ) - go func() { - defer close(done) - buf := make([]byte, 64*1024) - var n int - n, err = d.conn.Read(buf) - if err == nil { - packet = cloneBytes(buf[:n]) - } - }() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-done: - return packet, err + if err := setReadDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.readDeadline); err != nil { + return nil, err } + buf := make([]byte, 64*1024) + n, err := d.conn.Read(buf) + if err != nil { + return nil, contextIOError(ctx, err) + } + return buf[:n], nil } func (d *datagramPacketIO) WritePacket(ctx context.Context, packet []byte) error { - done := make(chan error, 1) - go func() { - _, err := d.conn.Write(packet) - done <- err - }() - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: + if err := setWriteDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.writeDeadline); err != nil { return err } + _, err := d.conn.Write(packet) + return contextIOError(ctx, err) } func (d *datagramPacketIO) Close() error { @@ -414,52 +441,37 @@ func NewTCPPacketIO(conn net.Conn) PacketIO { } func (s *streamPacketIO) ReadPacket(ctx context.Context) ([]byte, error) { - done := make(chan struct{}) - var ( - packet []byte - err error - ) - go func() { - defer close(done) - var lenBuf [2]byte - if _, err = io.ReadFull(s.conn, lenBuf[:]); err != nil { - return - } - size := int(lenBuf[0])<<8 | int(lenBuf[1]) - if size == 0 { - err = errors.New("empty openvpn tcp packet") - return - } - packet = make([]byte, size) - _, err = io.ReadFull(s.conn, packet) - }() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-done: - return packet, err + if err := setReadDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.readDeadline); err != nil { + return nil, err } + var lenBuf [2]byte + if _, err := io.ReadFull(s.conn, lenBuf[:]); err != nil { + return nil, contextIOError(ctx, err) + } + size := int(lenBuf[0])<<8 | int(lenBuf[1]) + if size == 0 { + return nil, errors.New("empty openvpn tcp packet") + } + packet := make([]byte, size) + if _, err := io.ReadFull(s.conn, packet); err != nil { + return nil, contextIOError(ctx, err) + } + return packet, nil } func (s *streamPacketIO) WritePacket(ctx context.Context, packet []byte) error { if len(packet) > 0xffff { return fmt.Errorf("openvpn tcp packet too large: %d", len(packet)) } - done := make(chan error, 1) - go func() { - frame := make([]byte, 2+len(packet)) - frame[0] = byte(len(packet) >> 8) - frame[1] = byte(len(packet)) - copy(frame[2:], packet) - _, err := s.conn.Write(frame) - done <- err - }() - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: + if err := setWriteDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.writeDeadline); err != nil { return err } + frame := make([]byte, 2+len(packet)) + frame[0] = byte(len(packet) >> 8) + frame[1] = byte(len(packet)) + copy(frame[2:], packet) + _, err := s.conn.Write(frame) + return contextIOError(ctx, err) } func (s *streamPacketIO) Close() error { @@ -473,3 +485,50 @@ func (s *streamPacketIO) LocalAddr() net.Addr { func (s *streamPacketIO) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } + +func setReadDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error { + deadline, hasDeadline := ctx.Deadline() + mu.Lock() + defer mu.Unlock() + if current.Equal(deadline) { + return nil + } + if hasDeadline { + if err := conn.SetReadDeadline(deadline); err != nil { + return err + } + } else if err := conn.SetReadDeadline(time.Time{}); err != nil { + return err + } + *current = deadline + return nil +} + +func setWriteDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error { + deadline, hasDeadline := ctx.Deadline() + mu.Lock() + defer mu.Unlock() + if current.Equal(deadline) { + return nil + } + if hasDeadline { + if err := conn.SetWriteDeadline(deadline); err != nil { + return err + } + } else if err := conn.SetWriteDeadline(time.Time{}); err != nil { + return err + } + *current = deadline + return nil +} + +func contextIOError(ctx context.Context, err error) error { + if err == nil { + return nil + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() && ctx.Err() != nil { + return ctx.Err() + } + return err +} diff --git a/transport/openvpn/data.go b/transport/openvpn/data.go index a67d15b3..90afabd6 100644 --- a/transport/openvpn/data.go +++ b/transport/openvpn/data.go @@ -8,15 +8,21 @@ import ( const ( PeerIDUnset uint32 = 0xffffff + + dataChannelReplayWindow = 64 ) type DataChannel struct { - cipher DataCipher - keyID uint8 - peerID uint32 - compLZO bool + cipher DataCipher + keyID uint8 + peerID uint32 + compLZO bool + mu sync.Mutex sendPacketID uint32 + recvHighest uint32 + recvWindow uint64 + recvSeen bool } func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel { @@ -29,10 +35,11 @@ func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel func (d *DataChannel) Encrypt(packet []byte) ([]byte, error) { if d.compLZO { - p := make([]byte, 1+len(packet)) - p[0] = 0xFA - copy(p[1:], packet) - packet = p + compressed, err := lzo1xCompressSafe(packet) + if err != nil { + return nil, err + } + packet = compressed } d.mu.Lock() d.sendPacketID++ @@ -50,18 +57,15 @@ func (d *DataChannel) Decrypt(packet []byte) ([]byte, error) { if opcode == PDataV2 { headerSize = 4 } - plain, err := d.cipher.Decrypt(packet, headerSize) + plain, packetID, err := d.cipher.Decrypt(packet, headerSize) if err != nil { return nil, err } + if err := d.acceptPacketID(packetID); err != nil { + return nil, err + } if d.compLZO { - if len(plain) < 1 { - return nil, errors.New("openvpn comp-lzo packet too short") - } - if plain[0] != 0xFA { - return nil, fmt.Errorf("openvpn compressed packet not supported (byte: 0x%02x)", plain[0]) - } - plain = plain[1:] + return lzo1xDecompressSafe(plain) } return plain, nil } @@ -78,6 +82,40 @@ func (d *DataChannel) dataHeader() []byte { return []byte{opcodeKeyID(PDataV1, d.keyID)} } +func (d *DataChannel) acceptPacketID(packetID uint32) error { + d.mu.Lock() + defer d.mu.Unlock() + + if !d.recvSeen { + d.recvHighest = packetID + d.recvWindow = 1 + d.recvSeen = true + return nil + } + + if packetID > d.recvHighest { + shift := packetID - d.recvHighest + if shift >= dataChannelReplayWindow { + d.recvWindow = 1 + } else { + d.recvWindow = d.recvWindow<= dataChannelReplayWindow { + return fmt.Errorf("openvpn replayed data packet id %d", packetID) + } + mask := uint64(1) << diff + if d.recvWindow&mask != 0 { + return fmt.Errorf("openvpn replayed data packet id %d", packetID) + } + d.recvWindow |= mask + return nil +} + func ParsePeerID(options string) uint32 { for _, field := range splitPushOptions(options) { if len(field) > len("peer-id ") && field[:len("peer-id ")] == "peer-id " { diff --git a/transport/openvpn/e2e_test.go b/transport/openvpn/e2e_test.go new file mode 100644 index 00000000..a591851b --- /dev/null +++ b/transport/openvpn/e2e_test.go @@ -0,0 +1,444 @@ +//go:build with_openvpn && with_gvisor + +// OpenVPN E2E tests. Require a local OpenVPN server setup. +// +// Setup (run once before testing): +// +// # Generate PKI +// mkdir -p /tmp/ovpn-e2e/pki/{issued,private} +// cd /tmp/ovpn-e2e/pki +// openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -days 1 -nodes -keyout ca.key -out ca.crt -subj "/CN=E2ETestCA" +// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/server.key -out server.csr -subj "/CN=server" +// openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/server.crt -days 1 +// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/client.key -out client.csr -subj "/CN=client" +// openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/client.crt -days 1 +// openvpn --genkey secret ta.key +// openvpn --genkey secret ta-auth.key +// +// # Start servers (4 instances: TCP/UDP × tls-crypt/tls-auth) +// # TCP + tls-crypt on :11940, subnet 10.99.0.0/24 +// # UDP + tls-crypt on :11941, subnet 10.99.1.0/24 +// # TCP + tls-auth on :11942, subnet 10.99.2.0/24 +// # UDP + tls-auth on :11943, subnet 10.99.3.0/24 +// # +// # Each server config needs: topology subnet, duplicate-cn, persist-tun, +// # data-ciphers AES-256-GCM:AES-128-GCM:AES-192-GCM:CHACHA20-POLY1305:AES-256-CBC:AES-128-CBC:AES-192-CBC +// # auth SHA256, keepalive 10 60, ca/cert/key from above PKI. +// # tls-auth servers use: tls-auth ta-auth.key 0 +// # tls-crypt servers use: tls-crypt ta.key +// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-crypt.conf --daemon +// sudo openvpn --config /tmp/ovpn-e2e/server-udp-crypt.conf --daemon +// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-auth.conf --daemon +// sudo openvpn --config /tmp/ovpn-e2e/server-udp-auth.conf --daemon +// +// # Start HTTP servers on each VPN subnet +// for ip in 10.99.0.1 10.99.1.1 10.99.2.1 10.99.3.1; do +// mkdir -p /tmp/ovpn-e2e/$ip && echo "hello" > /tmp/ovpn-e2e/$ip/index.html +// cd /tmp/ovpn-e2e/$ip && python3 -m http.server 8080 --bind $ip & +// done +// +// Run tests: +// +// go test -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_manager,with_admin_panel,with_v2ray_api,with_ccm,with_ocm,with_profiler,with_openvpn,with_sudoku,with_trusttunnel" \ +// -run TestE2E -v -count=1 ./transport/openvpn/ -timeout 300s +// +// Tests all 28 combinations: 2 protos (tcp/udp) × 2 TLS modes (tls-crypt/tls-auth) × 7 ciphers. + +package openvpn_test + +import ( + "context" + "fmt" + "io" + "net" + "os" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/sagernet/sing-box" + "github.com/sagernet/sing-box/include" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json/badoption" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/protocol/socks" +) + +// Servers (started externally): +// TCP+tls-crypt :11940 subnet 10.99.0.0/24 +// UDP+tls-crypt :11941 subnet 10.99.1.0/24 +// TCP+tls-auth :11942 subnet 10.99.2.0/24 +// UDP+tls-auth :11943 subnet 10.99.3.0/24 +// TCP+plain :11944 subnet 10.99.4.0/24 +// UDP+plain :11945 subnet 10.99.5.0/24 +// TCP+tls-crypt+SHA1 :11946 subnet 10.99.6.0/24 (CBC only) +// TCP+tls-crypt+SHA512 :11947 subnet 10.99.7.0/24 (CBC only) +// Each has HTTP on .1:8080 serving "hello" + +const pkiDir = "/tmp/ovpn-e2e/pki" + +type serverConfig struct { + proto string + port uint16 + tlsMode string // "tls-crypt" or "tls-auth" + httpAddr string +} + +var servers = []serverConfig{ + {"tcp", 11940, "tls-crypt", "10.99.0.1:8080"}, + {"udp", 11941, "tls-crypt", "10.99.1.1:8080"}, + {"tcp", 11942, "tls-auth", "10.99.2.1:8080"}, + {"udp", 11943, "tls-auth", "10.99.3.1:8080"}, +} + +var ciphers = []string{ + "AES-128-GCM", + "AES-192-GCM", + "AES-256-GCM", + "CHACHA20-POLY1305", + "AES-128-CBC", + "AES-192-CBC", + "AES-256-CBC", +} + +var portCounter atomic.Uint32 + +func init() { portCounter.Store(18100) } + +func nextPort() uint16 { return uint16(portCounter.Add(1)) } + +func readFile(t *testing.T, path string) string { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Skipf("PKI not found: %v", err) + } + return string(data) +} + +func testCombo(t *testing.T, srv serverConfig, cipher string) { + t.Helper() + ca := readFile(t, pkiDir+"/ca.crt") + cert := readFile(t, pkiDir+"/issued/client.crt") + key := readFile(t, pkiDir+"/private/client.key") + + ovpnOpts := &option.OpenVPNOutboundOptions{ + Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: srv.port}}, + Proto: srv.proto, + Cipher: cipher, + Auth: "SHA256", + OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{ + TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key}, + }, + } + + switch srv.tlsMode { + case "tls-crypt": + ovpnOpts.TLSCrypt = readFile(t, pkiDir+"/ta.key") + case "tls-auth": + ovpnOpts.TLSAuth = readFile(t, pkiDir+"/ta-auth.key") + ovpnOpts.KeyDirection = 1 + } + + port := nextPort() + opts := option.Options{ + Log: &option.LogOptions{Level: "error"}, + Inbounds: []option.Inbound{{ + Type: "socks", + Options: &option.SocksInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: (*badoption.Addr)(&badoption.Addr{}), + ListenPort: port, + }, + }, + }}, + Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}}, + Route: &option.RouteOptions{Final: "vpn"}, + } + + ctx := include.Context(context.Background()) + instance, err := box.New(box.Options{Context: ctx, Options: opts}) + if err != nil { + t.Fatal(err) + } + if err := instance.Start(); err != nil { + t.Fatal(err) + } + defer instance.Close() + + time.Sleep(2 * time.Second) + + dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "") + conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(srv.httpAddr)) + if err != nil { + t.Fatal("dial:", err) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(5 * time.Second)) + _, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n")) + if err != nil { + t.Fatal("write:", err) + } + body, err := io.ReadAll(conn) + if err != nil { + t.Fatal("read:", err) + } + if !strings.Contains(string(body), "hello") { + t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)]) + } +} + +// 4 servers × 7 ciphers = 28 combinations +func TestE2E(t *testing.T) { + for _, srv := range servers { + for _, cipher := range ciphers { + name := fmt.Sprintf("%s/%s/%s", srv.proto, srv.tlsMode, cipher) + srv, cipher := srv, cipher + t.Run(name, func(t *testing.T) { + testCombo(t, srv, cipher) + }) + } + } +} + +// Test CBC ciphers with different auth algorithms (SHA1, SHA512) +func TestE2E_Auth(t *testing.T) { + type authServer struct { + port uint16 + auth string + httpAddr string + } + authServers := []authServer{ + {11946, "SHA1", "10.99.6.1:8080"}, + {11947, "SHA512", "10.99.7.1:8080"}, + } + cbcCiphers := []string{"AES-128-CBC", "AES-256-CBC"} + + for _, as := range authServers { + for _, cipher := range cbcCiphers { + name := fmt.Sprintf("auth-%s/%s", as.auth, cipher) + as, cipher := as, cipher + t.Run(name, func(t *testing.T) { + ca := readFile(t, pkiDir+"/ca.crt") + cert := readFile(t, pkiDir+"/issued/client.crt") + key := readFile(t, pkiDir+"/private/client.key") + tlsCrypt := readFile(t, pkiDir+"/ta.key") + port := nextPort() + opts := option.Options{ + Log: &option.LogOptions{Level: "error"}, + Inbounds: []option.Inbound{{ + Type: "socks", + Options: &option.SocksInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: (*badoption.Addr)(&badoption.Addr{}), + ListenPort: port, + }, + }, + }}, + Outbounds: []option.Outbound{{ + Type: "openvpn", Tag: "vpn", + Options: &option.OpenVPNOutboundOptions{ + Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: as.port}}, + Proto: "tcp", + Cipher: cipher, + Auth: as.auth, + TLSCrypt: tlsCrypt, + OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{ + TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key}, + }, + }, + }}, + Route: &option.RouteOptions{Final: "vpn"}, + } + ctx := include.Context(context.Background()) + instance, err := box.New(box.Options{Context: ctx, Options: opts}) + if err != nil { + t.Fatal(err) + } + if err := instance.Start(); err != nil { + t.Fatal(err) + } + defer instance.Close() + time.Sleep(2 * time.Second) + doHTTPCheck(t, port, as.httpAddr) + }) + } + } +} + +// Test tunnel stability with multiple sequential requests +func TestE2E_BulkData(t *testing.T) { + ca := readFile(t, pkiDir+"/ca.crt") + cert := readFile(t, pkiDir+"/issued/client.crt") + key := readFile(t, pkiDir+"/private/client.key") + tlsCrypt := readFile(t, pkiDir+"/ta.key") + port := nextPort() + opts := option.Options{ + Log: &option.LogOptions{Level: "error"}, + Inbounds: []option.Inbound{{ + Type: "socks", + Options: &option.SocksInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: (*badoption.Addr)(&badoption.Addr{}), + ListenPort: port, + }, + }, + }}, + Outbounds: []option.Outbound{{ + Type: "openvpn", Tag: "vpn", + Options: &option.OpenVPNOutboundOptions{ + Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11940}}, + Proto: "tcp", + Cipher: "AES-256-GCM", + Auth: "SHA256", + TLSCrypt: tlsCrypt, + OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{ + TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key}, + }, + }, + }}, + Route: &option.RouteOptions{Final: "vpn"}, + } + ctx := include.Context(context.Background()) + instance, err := box.New(box.Options{Context: ctx, Options: opts}) + if err != nil { + t.Fatal(err) + } + if err := instance.Start(); err != nil { + t.Fatal(err) + } + defer instance.Close() + time.Sleep(2 * time.Second) + + dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "") + for i := 0; i < 10; i++ { + conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr("10.99.0.1:8080")) + if err != nil { + t.Fatalf("request %d dial: %v", i, err) + } + conn.SetDeadline(time.Now().Add(5 * time.Second)) + fmt.Fprintf(conn, "GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n") + body, err := io.ReadAll(conn) + conn.Close() + if err != nil { + t.Fatalf("request %d read: %v", i, err) + } + if !strings.Contains(string(body), "hello") { + t.Fatalf("request %d: no 'hello'", i) + } + } +} + +func doHTTPCheck(t *testing.T, socksPort uint16, httpAddr string) { + t.Helper() + dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", socksPort), socks.Version5, "", "") + conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(httpAddr)) + if err != nil { + t.Fatal("dial:", err) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(5 * time.Second)) + _, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n")) + if err != nil { + t.Fatal("write:", err) + } + body, err := io.ReadAll(conn) + if err != nil { + t.Fatal("read:", err) + } + if !strings.Contains(string(body), "hello") { + t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)]) + } +} + +func startInstance(t *testing.T, ovpnOpts *option.OpenVPNOutboundOptions) uint16 { + t.Helper() + port := nextPort() + opts := option.Options{ + Log: &option.LogOptions{Level: "error"}, + Inbounds: []option.Inbound{{ + Type: "socks", + Options: &option.SocksInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: (*badoption.Addr)(&badoption.Addr{}), + ListenPort: port, + }, + }, + }}, + Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}}, + Route: &option.RouteOptions{Final: "vpn"}, + } + ctx := include.Context(context.Background()) + instance, err := box.New(box.Options{Context: ctx, Options: opts}) + if err != nil { + t.Fatal(err) + } + if err := instance.Start(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { instance.Close() }) + time.Sleep(2 * time.Second) + return port +} + +func TestE2E_CompLZO(t *testing.T) { + ca := readFile(t, pkiDir+"/ca.crt") + cert := readFile(t, pkiDir+"/issued/client.crt") + key := readFile(t, pkiDir+"/private/client.key") + tlsCrypt := readFile(t, pkiDir+"/ta.key") + + for _, cipher := range ciphers { + cipher := cipher + t.Run(cipher, func(t *testing.T) { + port := startInstance(t, &option.OpenVPNOutboundOptions{ + Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11948}}, + Proto: "udp", + Cipher: cipher, + Auth: "SHA256", + TLSCrypt: tlsCrypt, + OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{ + TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key}, + }, + }) + doHTTPCheck(t, port, "10.99.8.1:8080") + }) + } +} + +func TestE2E_AES192(t *testing.T) { + ca := readFile(t, pkiDir+"/ca.crt") + cert := readFile(t, pkiDir+"/issued/client.crt") + key := readFile(t, pkiDir+"/private/client.key") + tlsCrypt := readFile(t, pkiDir+"/ta.key") + + type combo struct { + proto string + port uint16 + httpAddr string + } + for _, c := range []combo{ + {"tcp", 11940, "10.99.0.1:8080"}, + {"udp", 11941, "10.99.1.1:8080"}, + } { + for _, cipher := range []string{"AES-192-GCM", "AES-192-CBC"} { + c, cipher := c, cipher + t.Run(fmt.Sprintf("%s/%s", c.proto, cipher), func(t *testing.T) { + port := startInstance(t, &option.OpenVPNOutboundOptions{ + Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: c.port}}, + Proto: c.proto, + Cipher: cipher, + Auth: "SHA256", + TLSCrypt: tlsCrypt, + OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{ + TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key}, + }, + }) + doHTTPCheck(t, port, c.httpAddr) + }) + } + } +} + +var _ net.Conn diff --git a/transport/openvpn/keymethod.go b/transport/openvpn/keymethod.go index 4e74cdd3..702ccc99 100644 --- a/transport/openvpn/keymethod.go +++ b/transport/openvpn/keymethod.go @@ -114,7 +114,7 @@ func ParseServerKeyMethod2Record(packet []byte) (*KeyMethod2Record, error) { } func DeriveClientKeyMaterial(sources KeySource2, clientSession, serverSession SessionID, cipherKeyLen int) (*KeyMaterial, error) { - if cipherKeyLen != 16 && cipherKeyLen != 32 { + if cipherKeyLen != 16 && cipherKeyLen != 24 && cipherKeyLen != 32 { return nil, fmt.Errorf("unsupported data cipher key length %d", cipherKeyLen) } var master [48]byte diff --git a/transport/openvpn/lzo.go b/transport/openvpn/lzo.go new file mode 100644 index 00000000..7fdcd845 --- /dev/null +++ b/transport/openvpn/lzo.go @@ -0,0 +1,48 @@ +package openvpn + +import ( + "bytes" + "errors" + + "github.com/rasky/go-lzo" +) + +const ( + lzoCompressNone = 0xFA + lzoCompressLZO = 0x66 +) + +var ErrLZODecompress = errors.New("lzo decompression failed") + +func lzo1xDecompressSafe(src []byte) ([]byte, error) { + if len(src) == 0 { + return nil, ErrLZODecompress + } + + switch src[0] { + case lzoCompressNone: + if len(src) > 1 { + return src[1:], nil + } + return nil, nil + case lzoCompressLZO: + if len(src) > 1 { + r := bytes.NewReader(src[1:]) + out, err := lzo.Decompress1X(r, len(src)-1, 0) + if err != nil { + return nil, ErrLZODecompress + } + return out, nil + } + return nil, nil + default: + return nil, ErrLZODecompress + } +} + +func lzo1xCompressSafe(src []byte) ([]byte, error) { + lzoPacket := make([]byte, 1+len(src)) + lzoPacket[0] = lzoCompressNone + copy(lzoPacket[1:], src) + return lzoPacket, nil +} diff --git a/transport/openvpn/push.go b/transport/openvpn/push.go index 3fe9af1f..7c12a287 100644 --- a/transport/openvpn/push.go +++ b/transport/openvpn/push.go @@ -10,16 +10,17 @@ import ( const PushRequest = "PUSH_REQUEST" type PushReply struct { - Raw string - Prefixes []netip.Prefix - DNS []netip.Addr - PeerID uint32 - Cipher string - Ping uint32 - MTU uint32 - CompLZO bool - Redirect bool - BlockIPv6 bool + Raw string + Prefixes []netip.Prefix + DNS []netip.Addr + PeerID uint32 + Cipher string + Ping uint32 + PingRestart uint32 + MTU uint32 + CompLZO bool + Redirect bool + BlockIPv6 bool } func ParsePushReply(message string) (*PushReply, error) { @@ -81,6 +82,12 @@ func ParsePushReply(message string) (*PushReply, error) { reply.Ping = uint32(v) } } + case "ping-restart": + if len(fields) >= 2 { + if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil { + reply.PingRestart = uint32(v) + } + } case "tun-mtu": if len(fields) >= 2 { if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil { @@ -113,27 +120,44 @@ func splitPushOptions(message string) []string { return out } -func parseIPv4Ifconfig(address, mask string) (netip.Prefix, error) { +func parseIPv4Ifconfig(address, maskOrPeer string) (netip.Prefix, error) { addr, err := netip.ParseAddr(address) if err != nil { return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 address %q: %w", address, err) } - maskAddr, err := netip.ParseAddr(mask) + maskAddr, err := netip.ParseAddr(maskOrPeer) if err != nil { - return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", mask, err) + return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", maskOrPeer, err) } if !addr.Is4() || !maskAddr.Is4() { return netip.Prefix{}, fmt.Errorf("openvpn ifconfig requires ipv4 address and mask") } - maskBytes := maskAddr.As4() + + if ones, ok := ipv4MaskSize(maskAddr); ok { + return netip.PrefixFrom(addr, ones), nil + } + + // Some servers, including SoftEther/VPNGate in net30/p2p mode, push + // "ifconfig " rather than "ifconfig ". + // Use a host prefix for that local tunnel address. + return netip.PrefixFrom(addr, 32), nil +} + +func ipv4MaskSize(mask netip.Addr) (int, bool) { + maskBytes := mask.As4() ones := 0 + seenZero := false for _, b := range maskBytes { for i := 7; i >= 0; i-- { if b&(1< 0 { + pingRestart = time.Duration(t.client.push.PingRestart) * time.Second + } + if pingRestart > 0 { + t.wg.Add(1) + go func() { + defer t.wg.Done() + ticker := time.NewTicker(pingRestart) + defer ticker.Stop() + for { + select { + case <-t.ctx.Done(): + return + case <-ticker.C: + client, err := t.getClient() + if err != nil { + return + } + if client.SinceReceive() >= pingRestart { + if ok := t.closeClient(client); ok { + t.logger.ErrorContext(t.ctx, fmt.Errorf("ping-restart timeout: no packet received for %s", pingRestart)) + } + } + } + } + }() + } <-t.ctx.Done() } diff --git a/transport/simple-obfs/http_server.go b/transport/simple-obfs/http_server.go new file mode 100644 index 00000000..21be1ede --- /dev/null +++ b/transport/simple-obfs/http_server.go @@ -0,0 +1,100 @@ +package obfs + +import ( + "bufio" + cryptorand "crypto/rand" + "encoding/base64" + "fmt" + "io" + "math/rand/v2" + "net" + "net/http" + "time" +) + +// HTTPObfsServer is the server side of the shadowsocks http simple-obfs implementation. +type HTTPObfsServer struct { + net.Conn + buf []byte + bio *bufio.Reader + offset int + firstRequest bool + firstResponse bool +} + +func (hos *HTTPObfsServer) Read(b []byte) (int, error) { + if hos.buf != nil { + n := copy(b, hos.buf[hos.offset:]) + hos.offset += n + if hos.offset == len(hos.buf) { + hos.offset = 0 + hos.buf = nil + } + return n, nil + } + if hos.firstRequest { + bio := bufio.NewReader(hos.Conn) + req, err := http.ReadRequest(bio) + if err != nil { + return 0, err + } + if req.Method != "GET" || req.Header.Get("Connection") != "Upgrade" { + return 0, io.EOF + } + buf, err := io.ReadAll(req.Body) + if err != nil { + return 0, err + } + n := copy(b, buf) + if n < len(buf) { + hos.buf = buf + hos.offset = n + } + req.Body.Close() + hos.bio = bio + hos.firstRequest = false + return n, nil + } + return hos.bio.Read(b) +} + +func (hos *HTTPObfsServer) Write(b []byte) (int, error) { + if hos.firstResponse { + randBytes := make([]byte, 16) + cryptorand.Read(randBytes) + date := time.Now().Format(time.RFC1123) + resp := fmt.Sprintf(httpResponseTemplate, vMajor, vMinor, date, base64.URLEncoding.EncodeToString(randBytes)) + _, err := hos.Conn.Write([]byte(resp)) + if err != nil { + return 0, err + } + hos.firstResponse = false + } + return hos.Conn.Write(b) +} + +func (hos *HTTPObfsServer) Upstream() any { + return hos.Conn +} + +// NewHTTPObfsServer returns a server-side HTTPObfs. +func NewHTTPObfsServer(conn net.Conn) net.Conn { + return &HTTPObfsServer{ + Conn: conn, + firstRequest: true, + firstResponse: true, + } +} + +const httpResponseTemplate = "HTTP/1.1 101 Switching Protocols\r\n" + + "Server: nginx/1.%d.%d\r\n" + + "Date: %s\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Accept: %s\r\n" + + "\r\n" + +var ( + vMajor = rand.IntN(11) + vMinor = rand.IntN(12) +) diff --git a/transport/simple-obfs/tls_server.go b/transport/simple-obfs/tls_server.go new file mode 100644 index 00000000..9cce8438 --- /dev/null +++ b/transport/simple-obfs/tls_server.go @@ -0,0 +1,154 @@ +package obfs + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "io" + "net" + "time" + + B "github.com/sagernet/sing/common/buf" +) + +// TLSObfsServer is the server side of the shadowsocks tls simple-obfs implementation. +type TLSObfsServer struct { + net.Conn + remain int + firstRequest bool + sessionTicketDone bool + firstResponse bool +} + +func (tos *TLSObfsServer) read(b []byte, discardN int) (int, error) { + buf := B.Get(discardN) + _, err := io.ReadFull(tos.Conn, buf) + B.Put(buf) + if err != nil { + return 0, err + } + sizeBuf := make([]byte, 2) + _, err = io.ReadFull(tos.Conn, sizeBuf) + if err != nil { + return 0, nil + } + length := int(binary.BigEndian.Uint16(sizeBuf)) + if length > len(b) { + n, err := tos.Conn.Read(b) + if err != nil { + return n, err + } + tos.remain = length - n + return n, nil + } + return io.ReadFull(tos.Conn, b[:length]) +} + +// skipOtherExts skips SNI & other TLS extensions. +func (tos *TLSObfsServer) skipOtherExts() error { + buf := make([]byte, 256) + _, err := tos.read(buf, 7) + if err != nil { + return err + } + _, err = io.ReadFull(tos.Conn, buf[:4*16+2]) + return err +} + +func (tos *TLSObfsServer) Read(b []byte) (int, error) { + if tos.remain > 0 { + length := tos.remain + if length > len(b) { + length = len(b) + } + n, err := io.ReadFull(tos.Conn, b[:length]) + tos.remain -= n + return n, err + } + if tos.firstRequest { + tos.firstRequest = false + return tos.read(b, 9*16-4) + } + if !tos.sessionTicketDone { + tos.sessionTicketDone = true + err := tos.skipOtherExts() + if err != nil { + return 0, err + } + } + return tos.read(b, 3) +} + +func (tos *TLSObfsServer) Write(b []byte) (int, error) { + length := len(b) + for i := 0; i < length; i += chunkSize { + end := i + chunkSize + if end > length { + end = length + } + n, err := tos.write(b[i:end]) + if err != nil { + return n, err + } + } + return length, nil +} + +func (tos *TLSObfsServer) write(b []byte) (int, error) { + if tos.firstResponse { + serverHello := makeServerHello(b) + _, err := tos.Conn.Write(serverHello) + tos.firstResponse = false + return len(b), err + } + buf := B.NewSize(5 + len(b)) + defer buf.Release() + buf.Write([]byte{0x17, 0x03, 0x03}) + binary.Write(buf, binary.BigEndian, uint16(len(b))) + buf.Write(b) + _, err := tos.Conn.Write(buf.Bytes()) + if err != nil { + return 0, err + } + return len(b), nil +} + +func (tos *TLSObfsServer) Upstream() any { + return tos.Conn +} + +// NewTLSObfsServer returns a server-side TLS SimpleObfs. +func NewTLSObfsServer(conn net.Conn) net.Conn { + return &TLSObfsServer{ + Conn: conn, + firstRequest: true, + firstResponse: true, + } +} + +func makeServerHello(data []byte) []byte { + randBytes := make([]byte, 28) + sessionId := make([]byte, 32) + rand.Read(randBytes) + rand.Read(sessionId) + buf := &bytes.Buffer{} + buf.WriteByte(0x16) + binary.Write(buf, binary.BigEndian, uint16(0x0301)) + binary.Write(buf, binary.BigEndian, uint16(91)) + buf.Write([]byte{2, 0, 0, 87, 0x03, 0x03}) + binary.Write(buf, binary.BigEndian, uint32(time.Now().Unix())) + buf.Write(randBytes) + buf.WriteByte(32) + buf.Write(sessionId) + buf.Write([]byte{0xcc, 0xa8}) + buf.WriteByte(0) + buf.Write([]byte{0x00, 0x00}) + buf.Write([]byte{0xff, 0x01, 0x00, 0x01, 0x00}) + buf.Write([]byte{0x00, 0x17, 0x00, 0x00}) + buf.Write([]byte{0x00, 0x0b, 0x00, 0x02, 0x01, 0x00}) + buf.Write([]byte{0x14, 0x03, 0x03, 0x00, 0x01, 0x01}) + buf.Write([]byte{0x16, 0x03, 0x03}) + binary.Write(buf, binary.BigEndian, uint16(len(data))) + buf.Write(data) + return buf.Bytes() +} diff --git a/transport/snell/address.go b/transport/snell/address.go new file mode 100644 index 00000000..a6c135fd --- /dev/null +++ b/transport/snell/address.go @@ -0,0 +1,144 @@ +package snell + +import ( + "encoding/binary" + "net" + "strconv" +) + +// SOCKS address types as defined in RFC 1928 section 5. +const ( + atypIPv4 = 1 + atypDomainName = 3 + atypIPv6 = 4 +) + +// socksAddr represents a SOCKS address as defined in RFC 1928 section 5. +type socksAddr []byte + +func (a socksAddr) String() string { + var host, port string + switch a[0] { + case atypDomainName: + hostLen := uint16(a[1]) + host = string(a[2 : 2+hostLen]) + port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1])) + case atypIPv4: + host = net.IP(a[1 : 1+net.IPv4len]).String() + port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1])) + case atypIPv6: + host = net.IP(a[1 : 1+net.IPv6len]).String() + port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1])) + } + return net.JoinHostPort(host, port) +} + +// UDPAddr converts a socksAddr to *net.UDPAddr. +func (a socksAddr) UDPAddr() *net.UDPAddr { + if len(a) == 0 { + return nil + } + switch a[0] { + case atypIPv4: + var ip [net.IPv4len]byte + copy(ip[0:], a[1:1+net.IPv4len]) + return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))} + case atypIPv6: + var ip [net.IPv6len]byte + copy(ip[0:], a[1:1+net.IPv6len]) + return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))} + } + return nil +} + +// splitSocksAddr slices a SOCKS address from beginning of b. Returns nil if failed. +func splitSocksAddr(b []byte) socksAddr { + addrLen := 1 + if len(b) < addrLen { + return nil + } + switch b[0] { + case atypDomainName: + if len(b) < 2 { + return nil + } + addrLen = 1 + 1 + int(b[1]) + 2 + case atypIPv4: + addrLen = 1 + net.IPv4len + 2 + case atypIPv6: + addrLen = 1 + net.IPv6len + 2 + default: + return nil + } + if len(b) < addrLen { + return nil + } + return b[:addrLen] +} + +// parseAddr parses the address in string s. Returns nil if failed. +func parseAddr(s string) socksAddr { + var addr socksAddr + host, port, err := net.SplitHostPort(s) + if err != nil { + return nil + } + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + addr = make([]byte, 1+net.IPv4len+2) + addr[0] = atypIPv4 + copy(addr[1:], ip4) + } else { + addr = make([]byte, 1+net.IPv6len+2) + addr[0] = atypIPv6 + copy(addr[1:], ip) + } + } else { + if len(host) > 255 { + return nil + } + addr = make([]byte, 1+1+len(host)+2) + addr[0] = atypDomainName + addr[1] = byte(len(host)) + copy(addr[2:], host) + } + portnum, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil + } + addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum) + return addr +} + +// parseAddrToSocksAddr parses a socks addr from net.Addr. +// This is a fast path of parseAddr(addr.String()). +func parseAddrToSocksAddr(addr net.Addr) socksAddr { + var hostip net.IP + var port int + switch addr := addr.(type) { + case *net.UDPAddr: + hostip = addr.IP + port = addr.Port + case *net.TCPAddr: + hostip = addr.IP + port = addr.Port + case nil: + return nil + } + if hostip == nil { + return parseAddr(addr.String()) + } + var parsed socksAddr + if ip4 := hostip.To4(); ip4.DefaultMask() != nil { + parsed = make([]byte, 1+net.IPv4len+2) + parsed[0] = atypIPv4 + copy(parsed[1:], ip4) + binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port)) + } else { + parsed = make([]byte, 1+net.IPv6len+2) + parsed[0] = atypIPv6 + copy(parsed[1:], hostip) + binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port)) + } + return parsed +} diff --git a/transport/snell/cipher.go b/transport/snell/cipher.go new file mode 100644 index 00000000..ae8d59d2 --- /dev/null +++ b/transport/snell/cipher.go @@ -0,0 +1,56 @@ +package snell + +import ( + "crypto/aes" + "crypto/cipher" + + "golang.org/x/crypto/argon2" + "golang.org/x/crypto/chacha20poly1305" +) + +// NewAES128GCM returns the AES-128-GCM cipher used by snell v2/v3. +func NewAES128GCM(psk []byte) Cipher { + return &snellCipher{ + psk: psk, + keySize: 16, + makeAEAD: aesGCM, + } +} + +// NewChacha20Poly1305 returns the ChaCha20-Poly1305 cipher used by snell v1. +func NewChacha20Poly1305(psk []byte) Cipher { + return &snellCipher{ + psk: psk, + keySize: 32, + makeAEAD: chacha20poly1305.New, + } +} + +type snellCipher struct { + psk []byte + keySize int + makeAEAD func(key []byte) (cipher.AEAD, error) +} + +func (sc *snellCipher) KeySize() int { return sc.keySize } +func (sc *snellCipher) SaltSize() int { return 16 } + +func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) { + return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize())) +} + +func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) { + return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize())) +} + +func snellKDF(psk, salt []byte, keySize int) []byte { + return argon2.IDKey(psk, salt, 3, 8, 1, 32)[:keySize] +} + +func aesGCM(key []byte) (cipher.AEAD, error) { + blk, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewGCM(blk) +} diff --git a/transport/snell/client.go b/transport/snell/client.go new file mode 100644 index 00000000..65a570ec --- /dev/null +++ b/transport/snell/client.go @@ -0,0 +1,120 @@ +package snell + +import ( + "context" + "net" + "strconv" + "time" + + obfs "github.com/sagernet/sing-box/transport/simple-obfs" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type ClientOptions struct { + Dialer N.Dialer + Server M.Socksaddr + PSK []byte + Version int + Reuse bool + ObfsMode string + ObfsHost string +} + +type Client struct { + dialer N.Dialer + server M.Socksaddr + psk []byte + version int + reuse bool + obfsMode string + obfsHost string + pool *Pool +} + +func NewClient(options ClientOptions) *Client { + c := &Client{ + dialer: options.Dialer, + server: options.Server, + psk: options.PSK, + version: options.Version, + reuse: options.Reuse, + obfsMode: options.ObfsMode, + obfsHost: options.ObfsHost, + } + if c.reuse { + c.pool = NewPool(func(ctx context.Context) (*Snell, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) + if err != nil { + return nil, err + } + return c.streamConn(conn), nil + }) + } + return c +} + +func (c *Client) streamConn(conn net.Conn) *Snell { + switch c.obfsMode { + case "tls": + conn = obfs.NewTLSObfs(conn, c.obfsHost) + case "http": + conn = obfs.NewHTTPObfs(conn, c.obfsHost, strconv.Itoa(int(c.server.Port))) + } + return StreamConn(conn, c.psk, c.version) +} + +func (c *Client) writeHeader(ctx context.Context, conn net.Conn, destination M.Socksaddr, udp bool) (err error) { + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetWriteDeadline(deadline) + defer conn.SetWriteDeadline(time.Time{}) + } + if udp { + err = WriteUDPHeader(conn, c.version) + if err == nil && c.version >= Version4 { + if sc, ok := conn.(*Snell); ok { + err = sc.ReadReply() + } + } + return + } + err = WriteHeaderWithReuse(conn, destination.AddrString(), uint(destination.Port), c.version, c.reuse) + return +} + +func (c *Client) DialContext(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { + if c.reuse { + conn, err := c.pool.Get() + if err != nil { + return nil, err + } + if err = c.writeHeader(ctx, conn, destination, false); err != nil { + _ = conn.Close() + return nil, err + } + return conn, nil + } + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) + if err != nil { + return nil, err + } + stream := c.streamConn(conn) + if err = c.writeHeader(ctx, stream, destination, false); err != nil { + _ = conn.Close() + return nil, err + } + return stream, nil +} + +func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) + if err != nil { + return nil, err + } + stream := c.streamConn(conn) + if err = c.writeHeader(ctx, stream, destination, true); err != nil { + _ = conn.Close() + return nil, err + } + return PacketConn(stream), nil +} diff --git a/transport/snell/pool.go b/transport/snell/pool.go new file mode 100644 index 00000000..703cca20 --- /dev/null +++ b/transport/snell/pool.go @@ -0,0 +1,153 @@ +package snell + +import ( + "context" + "io" + "net" + "runtime" + "sync" + "time" +) + +// poolEntry holds a pooled item with its insertion time. + +// connPool is a small connection pool with age-based eviction. + +// milliseconds + +// Pool is a pool of reusable snell connections. +type Pool struct { + pool *connPool +} + +func (p *Pool) Get() (net.Conn, error) { + return p.GetContext(context.Background()) +} + +func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) { + elm, err := p.pool.GetContext(ctx) + if err != nil { + return nil, err + } + return &PoolConn{Snell: elm, pool: p}, nil +} + +func (p *Pool) Put(conn *Snell) { + if err := HalfClose(conn); err != nil { + _ = conn.Close() + return + } + p.pool.put(conn) +} + +// PoolConn wraps a pooled snell connection and returns it to the pool on Close. +type PoolConn struct { + *Snell + pool *Pool + closeWriteOnce sync.Once + closeWriteErr error + closeOnce sync.Once + closeErr error +} + +func (pc *PoolConn) Read(b []byte) (int, error) { + n, err := pc.Snell.Read(b) + if err == ErrZeroChunk { + return n, io.EOF + } + return n, err +} + +func (pc *PoolConn) Write(b []byte) (int, error) { + return pc.Snell.Write(b) +} + +func (pc *PoolConn) CloseWrite() error { + pc.closeWriteOnce.Do(func() { + pc.closeWriteErr = writeZeroChunk(pc.Snell) + }) + return pc.closeWriteErr +} + +func (pc *PoolConn) Close() error { + pc.closeOnce.Do(func() { + if err := pc.CloseWrite(); err != nil { + pc.closeErr = err + _ = pc.Snell.Close() + return + } + _ = pc.Snell.Conn.SetReadDeadline(time.Time{}) + pc.Snell.reply = false + pc.pool.pool.put(pc.Snell) + }) + return pc.closeErr +} + +// NewPool creates a new snell connection pool using the given factory. +func NewPool(factory func(context.Context) (*Snell, error)) *Pool { + cp := &connPool{ + ch: make(chan *poolEntry, 10), + factory: factory, + maxAge: 15000, + evict: func(item *Snell) { + _ = item.Close() + }, + } + p := &Pool{pool: cp} + runtime.SetFinalizer(p, recycle) + return p +} + +type poolEntry struct { + elm *Snell + time time.Time +} + +type connPool struct { + ch chan *poolEntry + factory func(context.Context) (*Snell, error) + evict func(*Snell) + maxAge int64 +} + +func (p *connPool) GetContext(ctx context.Context) (*Snell, error) { + now := time.Now() + for { + select { + case item := <-p.ch: + if p.maxAge != 0 && now.Sub(item.time).Milliseconds() > p.maxAge { + if p.evict != nil { + p.evict(item.elm) + } + continue + } + return item.elm, nil + default: + return p.factory(ctx) + } + } +} + +func (p *connPool) put(item *Snell) { + e := &poolEntry{ + elm: item, + time: time.Now(), + } + select { + case p.ch <- e: + return + default: + if p.evict != nil { + p.evict(item) + } + return + } +} + +func recycle(p *Pool) { + for item := range p.pool.ch { + if p.pool.evict != nil { + p.pool.evict(item.elm) + } + } +} diff --git a/transport/snell/service.go b/transport/snell/service.go new file mode 100644 index 00000000..d6e0e77a --- /dev/null +++ b/transport/snell/service.go @@ -0,0 +1,294 @@ +package snell + +import ( + "bufio" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + obfs "github.com/sagernet/sing-box/transport/simple-obfs" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" +) + +type Handler interface { + NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string) + + NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string) +} + +type ServiceOptions struct { + PSK []byte + Version int + ObfsMode string + UDP bool + Logger logger.ContextLogger + Handler Handler +} + +type Service struct { + psk []byte + version int + obfsMode string + udp bool + logger logger.ContextLogger + handler Handler +} + +func NewService(options ServiceOptions) (*Service, error) { + version := options.Version + if version == 0 { + version = Version4 + } + if version != Version4 && version != Version5 { + return nil, fmt.Errorf("snell inbound version %d is not supported", version) + } + if len(options.PSK) == 0 { + return nil, errors.New("snell inbound requires psk") + } + switch options.ObfsMode { + case "", "http", "tls": + default: + return nil, fmt.Errorf("snell inbound obfs mode error: %s", options.ObfsMode) + } + return &Service{ + psk: options.PSK, + version: version, + obfsMode: options.ObfsMode, + udp: options.UDP, + logger: options.Logger, + handler: options.Handler, + }, nil +} + +func (s *Service) NewConnection(ctx context.Context, rawConn net.Conn, source M.Socksaddr) error { + conn := rawConn + switch s.obfsMode { + case "http": + conn = obfs.NewHTTPObfsServer(conn) + case "tls": + conn = obfs.NewTLSObfsServer(conn) + } + stream := ServerStreamConn(conn, s.psk, s.version) + for { + reuse, err := s.handleRequest(ctx, stream, source) + if err != nil || !reuse { + return err + } + } +} + +func (s *Service) handleRequest(ctx context.Context, stream *Snell, source M.Socksaddr) (bool, error) { + br := bufio.NewReader(stream) + version, err := br.ReadByte() + if err != nil { + return false, err + } + if version != Version { + return false, fmt.Errorf("snell invalid protocol version: %d", version) + } + command, err := br.ReadByte() + if err != nil { + return false, err + } + if command == CommandPing { + _, _ = stream.Write([]byte{CommandPong}) + return false, nil + } + clientID, err := readClientID(br) + if err != nil { + return false, err + } + switch command { + case CommandConnect, CommandConnectV2: + return s.handleTCP(ctx, stream, br, command == CommandConnectV2, clientID, source) + case CommandUDP: + if !s.udp { + return false, errors.New("snell UDP is disabled") + } + return false, s.handleUDP(ctx, stream, clientID, source) + default: + return false, fmt.Errorf("snell unknown command: %d", command) + } +} + +func (s *Service) handleTCP(ctx context.Context, stream *Snell, br *bufio.Reader, reuse bool, clientID string, source M.Socksaddr) (bool, error) { + hostLen, err := br.ReadByte() + if err != nil { + return false, err + } + if hostLen == 0 { + return false, errors.New("snell connect host is empty") + } + hostBytes := make([]byte, int(hostLen)) + if _, err := io.ReadFull(br, hostBytes); err != nil { + return false, err + } + var portBytes [2]byte + if _, err := io.ReadFull(br, portBytes[:]); err != nil { + return false, err + } + destination := M.ParseSocksaddrHostPort(string(hostBytes), binary.BigEndian.Uint16(portBytes[:])) + conn := &tcpRequestConn{ + Conn: stream, + reader: br, + reuse: reuse, + } + s.handler.NewConnection(ctx, conn, source, destination, clientID) + if !reuse { + return false, nil + } + return true, nil +} + +func (s *Service) handleUDP(ctx context.Context, stream *Snell, clientID string, source M.Socksaddr) error { + if _, err := stream.Write([]byte{CommandTunnel}); err != nil { + return err + } + pc := &serverPacketConn{ + conn: stream, + writeMu: &sync.Mutex{}, + } + s.handler.NewPacketConnection(ctx, pc, source, clientID) + return nil +} + +const maxPacketLength = 0x3fff + +func readClientID(r *bufio.Reader) (string, error) { + length, err := r.ReadByte() + if err != nil { + return "", err + } + if length == 0 { + return "", nil + } + id := make([]byte, int(length)) + if _, err := io.ReadFull(r, id); err != nil { + return "", err + } + return string(id), nil +} + +func writeCommandError(w io.Writer, code byte, message string) error { + msg := []byte(message) + if len(msg) > 255 { + msg = msg[:255] + } + buf := make([]byte, 0, 3+len(msg)) + buf = append(buf, CommandError, code, byte(len(msg))) + buf = append(buf, msg...) + _, err := w.Write(buf) + return err +} + +type tcpRequestConn struct { + net.Conn + reader *bufio.Reader + reuse bool + writeMu sync.Mutex + closeOnce sync.Once + replyWritten bool +} + +func (c *tcpRequestConn) Read(p []byte) (int, error) { + n, err := c.reader.Read(p) + if errors.Is(err, ErrZeroChunk) { + err = io.EOF + } + return n, err +} + +func (c *tcpRequestConn) Write(p []byte) (int, error) { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if !c.replyWritten { + payload := make([]byte, 1+len(p)) + payload[0] = CommandTunnel + copy(payload[1:], p) + if _, err := c.Conn.Write(payload); err != nil { + return 0, err + } + c.replyWritten = true + return len(p), nil + } + return c.Conn.Write(p) +} + +func (c *tcpRequestConn) CloseWrite() error { + return c.Close() +} + +func (c *tcpRequestConn) Close() error { + var err error + c.closeOnce.Do(func() { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if !c.replyWritten { + err = writeCommandError(c.Conn, 0x65, "Remote EOF") + if !c.reuse { + err = errors.Join(err, c.Conn.Close()) + } + return + } + if c.reuse { + _, err = c.Conn.Write(nil) + return + } + err = c.Conn.Close() + }) + return err +} + +type serverPacketConn struct { + conn *Snell + writeMu *sync.Mutex + readBuf []byte +} + +func (c *serverPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + if c.readBuf == nil { + c.readBuf = make([]byte, maxPacketLength) + } + for { + n, err := c.conn.Read(c.readBuf) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, ErrZeroChunk) { + return 0, nil, io.EOF + } + return 0, nil, err + } + request, err := ParseUDPRequest(c.readBuf[:n]) + if err != nil { + return 0, nil, err + } + var destination M.Socksaddr + if request.Ip.IsValid() { + destination = M.SocksaddrFrom(request.Ip, request.Port) + } else { + destination = M.ParseSocksaddrHostPort(request.Host, request.Port) + } + length := copy(p, request.Payload) + if destination.IsFqdn() { + return length, destination, nil + } + return length, destination.UDPAddr(), nil + } +} + +func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + c.writeMu.Lock() + defer c.writeMu.Unlock() + return WritePacketResponse(c.conn, addr, p) +} + +func (c *serverPacketConn) Close() error { return c.conn.Close() } +func (c *serverPacketConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } +func (c *serverPacketConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } +func (c *serverPacketConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } +func (c *serverPacketConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } diff --git a/transport/snell/shadowaead.go b/transport/snell/shadowaead.go new file mode 100644 index 00000000..052a6008 --- /dev/null +++ b/transport/snell/shadowaead.go @@ -0,0 +1,211 @@ +package snell + +import ( + "crypto/cipher" + "crypto/rand" + "errors" + "io" + "net" + + "github.com/sagernet/sing/common/buf" +) + +// payloadSizeMask is the maximum size of payload in bytes. +// 16*1024 - 1 +// >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead() + +// ErrZeroChunk is returned when a zero-length chunk is read, which snell uses +// as an end-of-stream signal. +var ErrZeroChunk = errors.New("zero chunk") + +// Cipher is the AEAD cipher abstraction used by the shadowaead stream. +type Cipher interface { + KeySize() int + SaltSize() int + Encrypter(salt []byte) (cipher.AEAD, error) + Decrypter(salt []byte) (cipher.AEAD, error) +} + +const ( + payloadSizeMask = 0x3FFF + bufSize = 17 * 1024 +) + +type aeadWriter struct { + io.Writer + cipher.AEAD + nonce [32]byte // should be sufficient for most nonce sizes +} + +// newAEADWriter wraps an io.Writer with authenticated encryption. +func newAEADWriter(w io.Writer, aead cipher.AEAD) *aeadWriter { + return &aeadWriter{Writer: w, AEAD: aead} +} + +// Write encrypts p and writes to the embedded io.Writer. +func (w *aeadWriter) Write(p []byte) (n int, err error) { + b := buf.Get(bufSize) + defer buf.Put(b) + nonce := w.nonce[:w.NonceSize()] + tag := w.Overhead() + off := 2 + tag + if len(p) == 0 { + b = b[:off] + b[0], b[1] = byte(0), byte(0) + w.Seal(b[:0], nonce, b[:2], nil) + increment(nonce) + _, err = w.Writer.Write(b) + return + } + for nr := 0; n < len(p) && err == nil; n += nr { + nr = payloadSizeMask + if n+nr > len(p) { + nr = len(p) - n + } + b = b[:off+nr+tag] + b[0], b[1] = byte(nr>>8), byte(nr) + w.Seal(b[:0], nonce, b[:2], nil) + increment(nonce) + w.Seal(b[:off], nonce, p[n:n+nr], nil) + increment(nonce) + _, err = w.Writer.Write(b) + } + return +} + +type aeadReader struct { + io.Reader + cipher.AEAD + nonce [32]byte // should be sufficient for most nonce sizes + buf []byte // to be put back into bufPool + off int // offset to unconsumed part of buf +} + +// newAEADReader wraps an io.Reader with authenticated decryption. +func newAEADReader(r io.Reader, aead cipher.AEAD) *aeadReader { + return &aeadReader{Reader: r, AEAD: aead} +} + +// read and decrypt a record into p. len(p) >= max payload size + AEAD overhead. +func (r *aeadReader) read(p []byte) (int, error) { + nonce := r.nonce[:r.NonceSize()] + tag := r.Overhead() + p = p[:2+tag] + if _, err := io.ReadFull(r.Reader, p); err != nil { + return 0, err + } + _, err := r.Open(p[:0], nonce, p, nil) + increment(nonce) + if err != nil { + return 0, err + } + size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask + if size == 0 { + return 0, ErrZeroChunk + } + p = p[:size+tag] + if _, err := io.ReadFull(r.Reader, p); err != nil { + return 0, err + } + _, err = r.Open(p[:0], nonce, p, nil) + increment(nonce) + if err != nil { + return 0, err + } + return size, nil +} + +// Read reads from the embedded io.Reader, decrypts and writes to p. +func (r *aeadReader) Read(p []byte) (int, error) { + if r.buf == nil { + if len(p) >= payloadSizeMask+r.Overhead() { + return r.read(p) + } + b := buf.Get(bufSize) + n, err := r.read(b) + if err != nil { + buf.Put(b) + return 0, err + } + r.buf = b[:n] + r.off = 0 + } + n := copy(p, r.buf[r.off:]) + r.off += n + if r.off == len(r.buf) { + buf.Put(r.buf[:cap(r.buf)]) + r.buf = nil + } + return n, nil +} + +// increment little-endian encoded unsigned integer b. Wrap around on overflow. +func increment(b []byte) { + for i := range b { + b[i]++ + if b[i] != 0 { + return + } + } +} + +// aeadConn wraps a stream-oriented net.Conn with the shadowaead cipher. +type aeadConn struct { + net.Conn + Cipher + r *aeadReader + w *aeadWriter +} + +// newAEADConn wraps a stream-oriented net.Conn with cipher. +func newAEADConn(c net.Conn, ciph Cipher) *aeadConn { + return &aeadConn{Conn: c, Cipher: ciph} +} + +func (c *aeadConn) initReader() error { + salt := make([]byte, c.SaltSize()) + if _, err := io.ReadFull(c.Conn, salt); err != nil { + return err + } + aead, err := c.Decrypter(salt) + if err != nil { + return err + } + c.r = newAEADReader(c.Conn, aead) + return nil +} + +func (c *aeadConn) Read(b []byte) (int, error) { + if c.r == nil { + if err := c.initReader(); err != nil { + return 0, err + } + } + return c.r.Read(b) +} + +func (c *aeadConn) initWriter() error { + salt := make([]byte, c.SaltSize()) + if _, err := rand.Read(salt); err != nil { + return err + } + aead, err := c.Encrypter(salt) + if err != nil { + return err + } + _, err = c.Conn.Write(salt) + if err != nil { + return err + } + c.w = newAEADWriter(c.Conn, aead) + return nil +} + +func (c *aeadConn) Write(b []byte) (int, error) { + if c.w == nil { + if err := c.initWriter(); err != nil { + return 0, err + } + } + return c.w.Write(b) +} diff --git a/transport/snell/snell.go b/transport/snell/snell.go new file mode 100644 index 00000000..a3c70d83 --- /dev/null +++ b/transport/snell/snell.go @@ -0,0 +1,408 @@ +package snell + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "sync" + + "github.com/sagernet/sing/common/buf" +) + +const ( + Version1 = 1 + Version2 = 2 + Version3 = 3 + Version4 = 4 + Version5 = 5 + DefaultSnellVersion = Version1 + + // max packet length + maxLength = 0x3FFF +) + +const ( + CommandPing byte = 0 + CommandConnect byte = 1 + CommandConnectV2 byte = 5 + CommandUDP byte = 6 + CommandUDPForward byte = 1 + + CommandTunnel byte = 0 + CommandPong byte = 1 + CommandError byte = 2 + + Version byte = 1 +) + +// Snell wraps an encrypted stream and handles the snell reply header. +type Snell struct { + net.Conn + buffer [1]byte + reply bool +} + +func (s *Snell) Read(b []byte) (int, error) { + if err := s.ReadReply(); err != nil { + return 0, err + } + return s.Conn.Read(b) +} + +func (s *Snell) ReadReply() error { + if s.reply { + return nil + } + if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { + return err + } + s.reply = true + if s.buffer[0] == CommandTunnel { + return nil + } else if s.buffer[0] != CommandError { + return errors.New("command not support") + } + if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { + return err + } + errcode := int(s.buffer[0]) + if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { + return err + } + length := int(s.buffer[0]) + msg := make([]byte, length) + if _, err := io.ReadFull(s.Conn, msg); err != nil { + return err + } + return fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg)) +} + +func WriteHeader(conn net.Conn, host string, port uint, version int) error { + return WriteHeaderWithReuse(conn, host, port, version, false) +} + +func WriteHeaderWithReuse(conn net.Conn, host string, port uint, version int, reuse bool) error { + buffer := &bytes.Buffer{} + buffer.WriteByte(Version) + if version == Version2 || reuse { + buffer.WriteByte(CommandConnectV2) + } else { + buffer.WriteByte(CommandConnect) + } + buffer.WriteByte(0) + buffer.WriteByte(uint8(len(host))) + buffer.WriteString(host) + binary.Write(buffer, binary.BigEndian, uint16(port)) + if _, err := conn.Write(buffer.Bytes()); err != nil { + return err + } + return nil +} + +func WriteUDPHeader(conn net.Conn, version int) error { + if version < Version3 { + return errors.New("unsupport UDP version") + } + _, err := conn.Write([]byte{Version, CommandUDP, 0x00}) + return err +} + +// HalfClose only works after the request negotiated the reuse command. +func HalfClose(conn net.Conn) error { + if err := writeZeroChunk(conn); err != nil { + return err + } + if s, ok := conn.(*Snell); ok { + s.reply = false + } + return nil +} + +// StreamConn wraps a raw connection with the snell stream cipher for the given version. +func StreamConn(conn net.Conn, psk []byte, version int) *Snell { + if version >= Version4 { + return &Snell{Conn: newV4Conn(conn, psk)} + } + var cipher Cipher + if version != Version1 { + cipher = NewAES128GCM(psk) + } else { + cipher = NewChacha20Poly1305(psk) + } + return &Snell{Conn: newAEADConn(conn, cipher)} +} + +// ServerStreamConn wraps a raw connection on the server side. +func ServerStreamConn(conn net.Conn, psk []byte, version int) *Snell { + stream := StreamConn(conn, psk, version) + stream.reply = true + return stream +} + +func PacketConn(conn net.Conn) net.PacketConn { + return &packetConn{ + Conn: conn, + } +} + +func (s *Snell) WritePacketFrame(b []byte) (int, error) { + if fw, ok := s.Conn.(packetFrameWriter); ok { + return fw.WritePacketFrame(b) + } + return s.Conn.Write(b) +} + +func WritePacket(w io.Writer, target, payload []byte) (int, error) { + maxPayloadLength := maxLength - udpRequestHeaderLength(target) + if maxPayloadLength <= 0 { + return 0, errors.New("snell UDP address too large") + } + if len(payload) <= maxPayloadLength { + return writePacket(w, target, payload) + } + return 0, errors.New("snell UDP payload too large") +} + +func WritePacketResponse(w io.Writer, addr net.Addr, payload []byte) (int, error) { + buffer := &bytes.Buffer{} + target := parseAddrToSocksAddr(addr) + if len(target) == 0 { + return 0, errors.New("snell UDP response address invalid") + } + switch target[0] { + case atypIPv4: + if len(target) < 1+net.IPv4len+2 { + return 0, errors.New("snell UDP response address invalid") + } + buffer.WriteByte(0x04) + buffer.Write(target[1 : 1+net.IPv4len+2]) + case atypIPv6: + if len(target) < 1+net.IPv6len+2 { + return 0, errors.New("snell UDP response address invalid") + } + buffer.WriteByte(0x06) + buffer.Write(target[1 : 1+net.IPv6len+2]) + default: + return 0, errors.New("snell UDP response address invalid") + } + buffer.Write(payload) + var err error + if fw, ok := w.(packetFrameWriter); ok { + _, err = fw.WritePacketFrame(buffer.Bytes()) + } else { + _, err = w.Write(buffer.Bytes()) + } + if err != nil { + return 0, err + } + return len(payload), nil +} + +// UDPRequest is a parsed snell UDP forward request. +type UDPRequest struct { + Host string + Ip netip.Addr + Port uint16 + Payload []byte +} + +func ParseUDPRequest(packet []byte) (UDPRequest, error) { + if len(packet) < 2 || packet[0] != CommandUDPForward { + return UDPRequest{}, errors.New("snell invalid UDP request") + } + if hostLen := int(packet[1]); hostLen != 0 { + if len(packet) <= 2+hostLen+2 { + return UDPRequest{}, errors.New("snell invalid UDP domain request") + } + offset := 2 + hostLen + return UDPRequest{ + Host: string(packet[2:offset]), + Port: binary.BigEndian.Uint16(packet[offset : offset+2]), + Payload: packet[offset+2:], + }, nil + } + if len(packet) < 3 { + return UDPRequest{}, errors.New("snell invalid UDP IP request") + } + switch packet[2] { + case 0x04: + if len(packet) < 3+net.IPv4len+2 { + return UDPRequest{}, errors.New("snell invalid UDP IPv4 request") + } + offset := 3 + net.IPv4len + ip, _ := netip.AddrFromSlice(packet[3:offset]) + return UDPRequest{ + Ip: ip.Unmap(), + Port: binary.BigEndian.Uint16(packet[offset : offset+2]), + Payload: packet[offset+2:], + }, nil + case 0x06: + if len(packet) < 3+net.IPv6len+2 { + return UDPRequest{}, errors.New("snell invalid UDP IPv6 request") + } + offset := 3 + net.IPv6len + ip, _ := netip.AddrFromSlice(packet[3:offset]) + return UDPRequest{ + Ip: ip.Unmap(), + Port: binary.BigEndian.Uint16(packet[offset : offset+2]), + Payload: packet[offset+2:], + }, nil + default: + return UDPRequest{}, errors.New("snell invalid UDP address type") + } +} + +func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) { + b := buf.Get(buf.UDPBufferSize) + defer buf.Put(b) + n, err := r.Read(b) + headLen := 1 + if err != nil { + return nil, 0, err + } + if n < headLen { + return nil, 0, errors.New("insufficient UDP length") + } + switch b[0] { + case 0x04: + headLen += net.IPv4len + 2 + if n < headLen { + err = errors.New("insufficient UDP length") + break + } + b[0] = atypIPv4 + case 0x06: + headLen += net.IPv6len + 2 + if n < headLen { + err = errors.New("insufficient UDP length") + break + } + b[0] = atypIPv6 + default: + err = errors.New("ip version invalid") + } + if err != nil { + return nil, 0, err + } + addr := splitSocksAddr(b[0:]) + if addr == nil { + return nil, 0, errors.New("remote address invalid") + } + uAddr := addr.UDPAddr() + if uAddr == nil { + return nil, 0, errors.New("parse addr error") + } + length := len(payload) + if n-headLen < length { + length = n - headLen + } + copy(payload[:], b[headLen:headLen+length]) + return uAddr, length, nil +} + +var endSignal = []byte{} + +type packetFrameWriter interface { + WritePacketFrame([]byte) (int, error) +} + +func writeZeroChunk(conn net.Conn) error { + if _, err := conn.Write(endSignal); err != nil { + return err + } + return nil +} + +func writePacket(w io.Writer, target, payload []byte) (int, error) { + buffer := &bytes.Buffer{} + buffer.WriteByte(CommandUDPForward) + switch target[0] { + case atypDomainName: + hostLen := target[1] + if len(target) < 1+1+int(hostLen)+2 { + return 0, errors.New("snell UDP address invalid") + } + buffer.Write(target[1 : 1+1+hostLen+2]) + case atypIPv4: + if len(target) < 1+net.IPv4len+2 { + return 0, errors.New("snell UDP address invalid") + } + buffer.Write([]byte{0x00, 0x04}) + buffer.Write(target[1 : 1+net.IPv4len+2]) + case atypIPv6: + if len(target) < 1+net.IPv6len+2 { + return 0, errors.New("snell UDP address invalid") + } + buffer.Write([]byte{0x00, 0x06}) + buffer.Write(target[1 : 1+net.IPv6len+2]) + default: + return 0, errors.New("snell UDP address invalid") + } + buffer.Write(payload) + if fw, ok := w.(packetFrameWriter); ok { + _, err := fw.WritePacketFrame(buffer.Bytes()) + if err != nil { + return 0, err + } + return len(payload), nil + } + _, err := w.Write(buffer.Bytes()) + if err != nil { + return 0, err + } + return len(payload), nil +} + +func udpRequestHeaderLength(target []byte) int { + if len(target) == 0 { + return maxLength + 1 + } + switch target[0] { + case atypDomainName: + if len(target) < 2 { + return maxLength + 1 + } + return 1 + 1 + int(target[1]) + 2 + case atypIPv4: + return 1 + 2 + net.IPv4len + 2 + case atypIPv6: + return 1 + 2 + net.IPv6len + 2 + default: + return maxLength + 1 + } +} + +type packetConn struct { + net.Conn + rMux sync.Mutex + wMux sync.Mutex +} + +func (pc *packetConn) WritePacketFrame(b []byte) (int, error) { + if s, ok := pc.Conn.(*Snell); ok { + if fw, ok := s.Conn.(packetFrameWriter); ok { + return fw.WritePacketFrame(b) + } + } + return pc.Conn.Write(b) +} + +func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { + pc.wMux.Lock() + defer pc.wMux.Unlock() + return WritePacket(pc, parseAddrToSocksAddr(addr), b) +} + +func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { + pc.rMux.Lock() + defer pc.rMux.Unlock() + addr, n, err := ReadPacket(pc.Conn, b) + if err != nil { + return 0, nil, err + } + return n, addr, nil +} diff --git a/transport/snell/v4.go b/transport/snell/v4.go new file mode 100644 index 00000000..4e88ac56 --- /dev/null +++ b/transport/snell/v4.go @@ -0,0 +1,463 @@ +package snell + +import ( + "crypto/cipher" + cryptorand "crypto/rand" + "encoding/binary" + "errors" + "io" + "math" + "math/big" + "math/bits" + "net" + "sync" + "time" +) + +const ( + v4SaltSize = 16 + v4NonceSize = 12 + v4HeaderPlainSize = 7 + v4HeaderCipherSize = v4HeaderPlainSize + 16 + v4FrameSize = 1460 + v4InitialPaddingMin = 0x100 + v4InitialPaddingSpan = 0x100 +) + +type v4Conn struct { + net.Conn + psk []byte + r *v4Reader + w *v4Writer +} + +func newV4Conn(conn net.Conn, psk []byte) *v4Conn { + return &v4Conn{Conn: conn, psk: psk} +} + +func (c *v4Conn) initReader() error { + salt := make([]byte, v4SaltSize) + if _, err := io.ReadFull(c.Conn, salt); err != nil { + return err + } + aead, err := v4AEAD(c.psk, salt) + if err != nil { + return err + } + c.r = &v4Reader{Reader: c.Conn, aead: aead} + return nil +} + +func (c *v4Conn) initWriter() error { + w, err := newV4Writer(c.Conn, c.psk) + if err != nil { + return err + } + c.w = w + return nil +} + +func (c *v4Conn) Read(b []byte) (int, error) { + if c.r == nil { + if err := c.initReader(); err != nil { + return 0, err + } + } + return c.r.Read(b) +} + +func (c *v4Conn) Write(b []byte) (int, error) { + if c.w == nil { + if err := c.initWriter(); err != nil { + return 0, err + } + } + return c.w.Write(b) +} + +func (c *v4Conn) WritePacketFrame(b []byte) (int, error) { + if len(b) > maxLength { + return 0, errors.New("snell v4 frame too large") + } + if c.w == nil { + if err := c.initWriter(); err != nil { + return 0, err + } + } + c.w.mux.Lock() + defer c.w.mux.Unlock() + if err := c.w.writeFrame(b, c.w.nextFramePaddingLength(len(b))); err != nil { + return 0, err + } + return len(b), nil +} + +func (c *v4Conn) WriteTo(w io.Writer) (int64, error) { + if c.r == nil { + if err := c.initReader(); err != nil { + return 0, err + } + } + var written int64 + buf := make([]byte, maxLength) + for { + n, err := c.r.Read(buf) + if n > 0 { + nw, ew := w.Write(buf[:n]) + written += int64(nw) + if ew != nil { + return written, ew + } + if nw != n { + return written, io.ErrShortWrite + } + } + if err != nil { + if err == io.EOF { + err = nil + } + return written, err + } + } +} + +func (c *v4Conn) ReadFrom(r io.Reader) (int64, error) { + if c.w == nil { + if err := c.initWriter(); err != nil { + return 0, err + } + } + var read int64 + buf := make([]byte, maxLength) + for { + n, err := r.Read(buf) + if n > 0 { + read += int64(n) + if _, ew := c.w.Write(buf[:n]); ew != nil { + return read, ew + } + } + if err != nil { + if err == io.EOF { + err = nil + } + return read, err + } + } +} + +func v4AEAD(psk, salt []byte) (cipher.AEAD, error) { + return aesGCM(snellKDF(psk, salt, 16)) +} + +type v4Reader struct { + io.Reader + aead cipher.AEAD + nonce [v4NonceSize]byte + buf []byte + mux sync.Mutex +} + +func (r *v4Reader) Read(b []byte) (int, error) { + r.mux.Lock() + defer r.mux.Unlock() + if len(r.buf) == 0 { + payload, err := r.readFrame() + if err != nil { + return 0, err + } + r.buf = payload + } + n := copy(b, r.buf) + r.buf = r.buf[n:] + return n, nil +} + +func (r *v4Reader) readFrame() ([]byte, error) { + headerCipher := make([]byte, v4HeaderCipherSize) + if _, err := io.ReadFull(r.Reader, headerCipher); err != nil { + return nil, err + } + header, err := r.aead.Open(headerCipher[:0], r.nonce[:], headerCipher, nil) + incrementV4Nonce(r.nonce[:]) + if err != nil { + return nil, err + } + if len(header) != v4HeaderPlainSize || header[0] != 4 { + return nil, errors.New("snell v4 invalid frame header") + } + paddingLength := int(binary.BigEndian.Uint16(header[3:5])) + payloadLength := int(binary.BigEndian.Uint16(header[5:7])) + if payloadLength == 0 { + if paddingLength != 0 { + return nil, errors.New("snell v4 zero chunk with padding") + } + return nil, ErrZeroChunk + } + if payloadLength > maxLength || paddingLength > maxLength { + return nil, errors.New("snell v4 frame too large") + } + payloadCipherLength := payloadLength + r.aead.Overhead() + frame := make([]byte, paddingLength+payloadCipherLength) + if _, err := io.ReadFull(r.Reader, frame); err != nil { + return nil, err + } + if paddingLength > 0 { + swapPadding(frame[:paddingLength], frame[paddingLength:]) + } + payloadCipher := frame[paddingLength:] + payload, err := r.aead.Open(payloadCipher[:0], r.nonce[:], payloadCipher, nil) + incrementV4Nonce(r.nonce[:]) + if err != nil { + return nil, err + } + return payload, nil +} + +type v4Writer struct { + io.Writer + aead cipher.AEAD + nonce [v4NonceSize]byte + salt [v4SaltSize]byte + saltSent bool + initialPaddingLength uint16 + payloadLimit uint16 + lastWrite time.Time + mux sync.Mutex +} + +func newV4Writer(w io.Writer, psk []byte) (*v4Writer, error) { + var salt [v4SaltSize]byte + if _, err := io.ReadFull(cryptorand.Reader, salt[:]); err != nil { + return nil, err + } + aead, err := v4AEAD(psk, salt[:]) + if err != nil { + return nil, err + } + paddingDelta, err := cryptoRandomInt(v4InitialPaddingSpan) + if err != nil { + return nil, err + } + return &v4Writer{ + Writer: w, + aead: aead, + salt: salt, + initialPaddingLength: uint16(v4InitialPaddingMin + paddingDelta), + }, nil +} + +func (w *v4Writer) Write(b []byte) (int, error) { + w.mux.Lock() + defer w.mux.Unlock() + if len(b) == 0 { + return 0, w.writeFrame(nil, 0) + } + written := 0 + for written < len(b) { + payloadLimit := int(w.nextPayloadLimit()) + if payloadLimit <= 0 || payloadLimit > maxLength { + payloadLimit = maxLength + } + end := written + payloadLimit + if end > len(b) { + end = len(b) + } + paddingLength := w.nextFramePaddingLength(end - written) + if err := w.writeFrame(b[written:end], paddingLength); err != nil { + return written, err + } + written = end + } + return written, nil +} + +func (w *v4Writer) nextPayloadLimit() uint16 { + now := time.Now() + var payloadLimit uint16 + switch { + case w.lastWrite.IsZero(): + payloadLimit = v4FrameSize - 55 - w.initialPaddingLength + case now.Sub(w.lastWrite) > 30*time.Second: + payloadLimit = v4FrameSize - 39 + default: + payloadLimit = w.payloadLimit + } + w.lastWrite = now + if payloadLimit <= maxLength-1 { + next := int(payloadLimit) + v4FrameSize - 39 + if next > maxLength { + next = maxLength + } + w.payloadLimit = uint16(next) + } else { + w.payloadLimit = maxLength + } + return payloadLimit +} + +func (w *v4Writer) nextFramePaddingLength(payloadLength int) int { + if w.saltSent || payloadLength == 0 { + return 0 + } + return int(w.initialPaddingLength) +} + +func (w *v4Writer) writeFrame(payload []byte, paddingLength int) error { + if len(payload) > maxLength || paddingLength > maxLength { + return errors.New("snell v4 frame too large") + } + if len(payload) == 0 && paddingLength != 0 { + return errors.New("snell v4 zero chunk with padding") + } + header := make([]byte, v4HeaderPlainSize) + header[0] = 4 + binary.BigEndian.PutUint16(header[3:5], uint16(paddingLength)) + binary.BigEndian.PutUint16(header[5:7], uint16(len(payload))) + headerCipher := w.aead.Seal(nil, w.nonce[:], header, nil) + incrementV4Nonce(w.nonce[:]) + var payloadCipher []byte + if len(payload) > 0 { + payloadCipher = w.aead.Seal(nil, w.nonce[:], payload, nil) + incrementV4Nonce(w.nonce[:]) + } + frameLength := len(headerCipher) + paddingLength + len(payloadCipher) + if !w.saltSent { + frameLength += v4SaltSize + } + frame := make([]byte, 0, frameLength) + if !w.saltSent { + frame = append(frame, w.salt[:]...) + w.saltSent = true + } + frame = append(frame, headerCipher...) + if paddingLength > 0 { + padding, err := makeV4Padding(payloadCipher, paddingLength) + if err != nil { + return err + } + swapPadding(padding, payloadCipher) + frame = append(frame, padding...) + } + frame = append(frame, payloadCipher...) + return writeFull(w.Writer, frame) +} + +func swapPadding(padding, payloadCipher []byte) { + limit := len(padding) + if len(payloadCipher) < limit { + limit = len(payloadCipher) + } + for i := 0; i < limit; i += 2 { + padding[i], payloadCipher[i] = payloadCipher[i], padding[i] + } +} + +func makeV4Padding(payloadCipher []byte, paddingLength int) ([]byte, error) { + if paddingLength <= 0 { + return nil, nil + } + payloadOnes := countV4PayloadOnes(payloadCipher) + payloadZeros := 8*len(payloadCipher) - payloadOnes + if payloadZeros <= 0 { + return makeV4RandomPadding(paddingLength) + } + ratio := float64(payloadOnes) / float64(payloadZeros) + if ratio <= 0.5 || ratio >= 1.6 { + return makeV4RandomPadding(paddingLength) + } + targetRatioBase := 1.6 + if payloadZeros < payloadOnes { + targetRatioBase = 0.4 + } + jitter, err := randomUnitFloat64() + if err != nil { + return nil, err + } + targetRatio := targetRatioBase + jitter/10 + totalBits := 8 * (paddingLength + len(payloadCipher)) + targetOnes := int(float64(totalBits)*(targetRatio/(targetRatio+1)) - float64(payloadOnes)) + if targetOnes < 0 || targetOnes > 8*paddingLength { + return makeV4RandomPadding(paddingLength) + } + return makeV4BitCountPadding(paddingLength, targetOnes) +} + +func countV4PayloadOnes(payloadCipher []byte) int { + limit := len(payloadCipher) &^ 3 + ones := 0 + for _, b := range payloadCipher[:limit] { + ones += bits.OnesCount8(b) + } + return ones +} + +func makeV4RandomPadding(length int) ([]byte, error) { + padding := make([]byte, length) + _, err := io.ReadFull(cryptorand.Reader, padding) + return padding, err +} + +func makeV4BitCountPadding(length, oneBits int) ([]byte, error) { + totalBits := 8 * length + if oneBits < 0 || oneBits > totalBits { + return nil, errors.New("snell v4 invalid padding bit count") + } + bitset := make([]byte, totalBits) + for i := 0; i < oneBits; i++ { + bitset[i] = 1 + } + for i := totalBits - 1; i > 0; i-- { + j, err := cryptoRandomInt(i + 1) + if err != nil { + return nil, err + } + bitset[i], bitset[j] = bitset[j], bitset[i] + } + padding := make([]byte, length) + for i, bit := range bitset { + if bit == 1 { + padding[i/8] |= 1 << uint(i%8) + } + } + return padding, nil +} + +func cryptoRandomInt(max int) (int, error) { + n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(int64(max))) + if err != nil { + return 0, err + } + return int(n.Int64()), nil +} + +func randomUnitFloat64() (float64, error) { + n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(1<<53)) + if err != nil { + return 0, err + } + return float64(n.Int64()) / math.Exp2(53), nil +} + +func writeFull(w io.Writer, p []byte) error { + for len(p) > 0 { + n, err := w.Write(p) + if err != nil { + return err + } + if n == 0 { + return io.ErrShortWrite + } + p = p[n:] + } + return nil +} + +func incrementV4Nonce(nonce []byte) { + for i := range nonce { + nonce[i]++ + if nonce[i] != 0 { + return + } + } +} diff --git a/transport/sudoku/multiplex/session.go b/transport/sudoku/multiplex/session.go index 4344d8e7..7f5f662d 100644 --- a/transport/sudoku/multiplex/session.go +++ b/transport/sudoku/multiplex/session.go @@ -18,9 +18,12 @@ const ( ) const ( - headerSize = 1 + 4 + 4 - maxFrameSize = 256 * 1024 - maxDataPayload = 32 * 1024 + headerSize = 1 + 4 + 4 + // maxQueuedBytesPerStream bounds unread payload retained by a single logical stream. + // Backpressure is applied to the demux loop instead of dropping data. + maxQueuedBytesPerStream = 4 * 1024 * 1024 + maxFrameSize = 256 * 1024 + maxDataPayload = 128 * 1024 ) type acceptEvent struct { @@ -344,6 +347,8 @@ type stream struct { closeErr error readBuf []byte queue [][]byte + // queuedBytes includes unread bytes in readBuf and queue. + queuedBytes int localAddr net.Addr remoteAddr net.Addr @@ -362,16 +367,20 @@ func newStream(session *Session, id uint32) *stream { func (c *stream) enqueue(payload []byte) { c.mu.Lock() + for !c.closed && c.queuedBytes+len(payload) > maxQueuedBytesPerStream { + c.cond.Wait() + } if c.closed { c.mu.Unlock() return } + c.queuedBytes += len(payload) if len(c.readBuf) == 0 && len(c.queue) == 0 { c.readBuf = payload } else { c.queue = append(c.queue, payload) } - c.cond.Signal() + c.cond.Broadcast() c.mu.Unlock() } @@ -413,7 +422,11 @@ func (c *stream) Read(p []byte) (int, error) { } if len(c.readBuf) == 0 && len(c.queue) > 0 { c.readBuf = c.queue[0] + c.queue[0] = nil c.queue = c.queue[1:] + if len(c.queue) == 0 { + c.queue = nil + } } if len(c.readBuf) == 0 && c.closed { if c.closeErr == nil { @@ -424,6 +437,14 @@ func (c *stream) Read(p []byte) (int, error) { n := copy(p, c.readBuf) c.readBuf = c.readBuf[n:] + if len(c.readBuf) == 0 { + c.readBuf = nil + } + c.queuedBytes -= n + if c.queuedBytes < 0 { + c.queuedBytes = 0 + } + c.cond.Broadcast() return n, nil } diff --git a/transport/sudoku/multiplex/session_backpressure_test.go b/transport/sudoku/multiplex/session_backpressure_test.go new file mode 100644 index 00000000..3238262c --- /dev/null +++ b/transport/sudoku/multiplex/session_backpressure_test.go @@ -0,0 +1,91 @@ +package multiplex + +import ( + "bytes" + "crypto/rand" + "io" + "net" + "sync" + "testing" + "time" +) + +// TestSession_LargeTransferBackpressure verifies that a transfer larger than +// maxQueuedBytesPerStream completes correctly: the demux loop applies +// backpressure (cond.Wait) instead of dropping data, and the reader draining +// the stream wakes the blocked loop without deadlock. +func TestSession_LargeTransferBackpressure(t *testing.T) { + c1, c2 := net.Pipe() + + client, err := NewClientSession(c1) + if err != nil { + t.Fatalf("client session: %v", err) + } + server, err := NewServerSession(c2) + if err != nil { + t.Fatalf("server session: %v", err) + } + defer client.Close() + defer server.Close() + + // Payload bigger than the per-stream backpressure window (4MB). + const total = 12 * 1024 * 1024 + payload := make([]byte, total) + if _, err := rand.Read(payload); err != nil { + t.Fatalf("rand: %v", err) + } + + var wg sync.WaitGroup + wg.Add(2) + + var writeErr error + go func() { + defer wg.Done() + stream, err := client.OpenStream([]byte("hello")) + if err != nil { + writeErr = err + return + } + defer stream.Close() + if _, err := stream.Write(payload); err != nil { + writeErr = err + return + } + _ = stream.(interface{ CloseWrite() error }).CloseWrite() + }() + + var got []byte + var readErr error + go func() { + defer wg.Done() + stream, openPayload, err := server.AcceptStream() + if err != nil { + readErr = err + return + } + if string(openPayload) != "hello" { + readErr = io.ErrUnexpectedEOF + return + } + got, readErr = io.ReadAll(stream) + }() + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + + select { + case <-done: + case <-time.After(30 * time.Second): + t.Fatal("transfer deadlocked (backpressure did not release)") + } + + if writeErr != nil { + t.Fatalf("write: %v", writeErr) + } + if readErr != nil { + t.Fatalf("read: %v", readErr) + } + if !bytes.Equal(got, payload) { + t.Fatalf("payload mismatch: got %d bytes, want %d", len(got), len(payload)) + } +} diff --git a/transport/sudoku/obfs/sudoku/ascii_mode_test.go b/transport/sudoku/obfs/sudoku/ascii_mode_test.go new file mode 100644 index 00000000..7e0094d6 --- /dev/null +++ b/transport/sudoku/obfs/sudoku/ascii_mode_test.go @@ -0,0 +1,56 @@ +package sudoku + +import "testing" + +func TestNormalizeASCIIMode(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"", "prefer_entropy"}, + {"entropy", "prefer_entropy"}, + {"prefer_ascii", "prefer_ascii"}, + {"up_ascii_down_entropy", "up_ascii_down_entropy"}, + {"up_entropy_down_ascii", "up_entropy_down_ascii"}, + {"up_prefer_ascii_down_prefer_entropy", "up_ascii_down_entropy"}, + } + + for _, tt := range tests { + got, err := NormalizeASCIIMode(tt.in) + if err != nil { + t.Fatalf("NormalizeASCIIMode(%q): %v", tt.in, err) + } + if got != tt.want { + t.Fatalf("NormalizeASCIIMode(%q) = %q, want %q", tt.in, got, tt.want) + } + } + + if _, err := NormalizeASCIIMode("up_ascii_down_binary"); err == nil { + t.Fatalf("expected invalid directional mode to fail") + } +} + +func TestNewTableWithCustomDirectionalOpposite(t *testing.T) { + table, err := NewTableWithCustom("seed", "up_ascii_down_entropy", "xpxvvpvv") + if err != nil { + t.Fatalf("NewTableWithCustom: %v", err) + } + if !table.IsASCII { + t.Fatalf("uplink table should be ascii") + } + opposite := table.OppositeDirection() + if opposite == nil || opposite == table { + t.Fatalf("expected distinct opposite table") + } + if opposite.IsASCII { + t.Fatalf("downlink table should be entropy/custom") + } + + symmetric, err := NewTableWithCustom("seed", "prefer_ascii", "xpxvvpvv") + if err != nil { + t.Fatalf("NewTableWithCustom symmetric: %v", err) + } + if symmetric.OppositeDirection() != symmetric { + t.Fatalf("symmetric table should point to itself") + } +} diff --git a/transport/sudoku/obfs/sudoku/conn.go b/transport/sudoku/obfs/sudoku/conn.go index 12998d5e..e19cf221 100644 --- a/transport/sudoku/obfs/sudoku/conn.go +++ b/transport/sudoku/obfs/sudoku/conn.go @@ -3,6 +3,7 @@ package sudoku import ( "bufio" "bytes" + "io" "net" "sync" "sync/atomic" @@ -10,6 +11,8 @@ import ( const IOBufferSize = 32 * 1024 +const minDecodeReadSize = 64 + var perm4 = [24][4]byte{ {0, 1, 2, 3}, {0, 1, 3, 2}, @@ -52,7 +55,7 @@ type Conn struct { writeMu sync.Mutex writeBuf []byte - rng randomSource + rng *sudokuRand paddingThreshold uint64 } @@ -97,6 +100,9 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn { } func (sc *Conn) StopRecording() { + if sc == nil { + return + } sc.recordLock.Lock() sc.recording.Store(false) sc.recorder = nil @@ -115,6 +121,9 @@ func (sc *Conn) GetBufferedAndRecorded() []byte { if sc.recorder != nil { recorded = sc.recorder.Bytes() } + if sc.reader == nil { + return recorded + } buffered := sc.reader.Buffered() if buffered > 0 { @@ -131,6 +140,9 @@ func (sc *Conn) Write(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } + if sc == nil || sc.Conn == nil || sc.table == nil || sc.table.layout == nil || sc.rng == nil { + return 0, io.ErrClosedPipe + } sc.writeMu.Lock() defer sc.writeMu.Unlock() @@ -140,16 +152,19 @@ func (sc *Conn) Write(p []byte) (n int, err error) { } func (sc *Conn) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if sc == nil || sc.Conn == nil || sc.reader == nil || len(sc.rawBuf) == 0 || sc.table == nil || sc.table.layout == nil { + return 0, io.ErrClosedPipe + } if n, ok := drainPending(p, &sc.pendingData); ok { return n, nil } + outN := 0 for { - if sc.pendingData.available() > 0 { - break - } - - nr, rErr := sc.reader.Read(sc.rawBuf) + nr, rErr := readRawLimited(sc.Conn, sc.reader, sc.rawBuf[:sudokuReadSize(len(p)-outN, len(sc.rawBuf))]) if nr > 0 { chunk := sc.rawBuf[:nr] if sc.recording.Load() { @@ -160,34 +175,80 @@ func (sc *Conn) Read(p []byte) (n int, err error) { sc.recordLock.Unlock() } - layout := sc.table.layout - for _, b := range chunk { + table := sc.table + layout := table.layout + for i := 0; i < len(chunk); { + if sc.hintCount == 0 && outN < len(p) && i+3 < len(chunk) && + layout.hintTable[chunk[i]] && + layout.hintTable[chunk[i+1]] && + layout.hintTable[chunk[i+2]] && + layout.hintTable[chunk[i+3]] { + val, ok := table.DecodeMap[packHintBytes(chunk[i], chunk[i+1], chunk[i+2], chunk[i+3])] + if !ok { + return 0, ErrInvalidSudokuMapMiss + } + p[outN] = val + outN++ + i += 4 + continue + } + + b := chunk[i] + i++ if !layout.hintTable[b] { continue } sc.hintBuf[sc.hintCount] = b sc.hintCount++ - if sc.hintCount == len(sc.hintBuf) { - key := packHintsToKey(sc.hintBuf) - val, ok := sc.table.DecodeMap[key] - if !ok { - return 0, ErrInvalidSudokuMapMiss - } - sc.pendingData.appendByte(val) - sc.hintCount = 0 + if sc.hintCount != len(sc.hintBuf) { + continue } + + val, ok := table.DecodeMap[packHintBytes(sc.hintBuf[0], sc.hintBuf[1], sc.hintBuf[2], sc.hintBuf[3])] + if !ok { + return 0, ErrInvalidSudokuMapMiss + } + outN = appendDecodedByte(p, outN, &sc.pendingData, val) + sc.hintCount = 0 } } if rErr != nil { + if outN > 0 { + return outN, nil + } + if n, ok := drainPending(p, &sc.pendingData); ok { + return n, nil + } return 0, rErr } - if sc.pendingData.available() > 0 { - break + if outN > 0 { + return outN, nil } } - - n, _ = drainPending(p, &sc.pendingData) - return n, nil +} + +func sudokuReadSize(decodedRemaining, maxRaw int) int { + if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 { + return maxRaw + } + if decodedRemaining > (maxRaw-minDecodeReadSize)/5 { + return maxRaw + } + + return decodedRemaining*5 + minDecodeReadSize +} + +func readRawLimited(conn net.Conn, reader *bufio.Reader, dst []byte) (int, error) { + if len(dst) == 0 { + return 0, nil + } + if reader != nil && reader.Buffered() > 0 { + return reader.Read(dst) + } + if conn == nil { + return 0, io.ErrClosedPipe + } + return conn.Read(dst) } diff --git a/transport/sudoku/obfs/sudoku/conn_roundtrip_test.go b/transport/sudoku/obfs/sudoku/conn_roundtrip_test.go new file mode 100644 index 00000000..e2aeac7f --- /dev/null +++ b/transport/sudoku/obfs/sudoku/conn_roundtrip_test.go @@ -0,0 +1,51 @@ +package sudoku + +import ( + "bytes" + "io" + "testing" +) + +// TestConn_Roundtrip exercises the optimized Conn encode/decode hot paths: +// the no-padding fast path (pMin==pMax==0), the always-padding path +// (pMin==pMax==100), a probabilistic range, and the adaptive read-size / +// 4-byte fast hint decode path across a variety of payload sizes and modes. +func TestConn_Roundtrip(t *testing.T) { + modes := []string{"prefer_entropy", "prefer_ascii"} + paddings := []struct{ min, max int }{ + {0, 0}, // no-padding specialized path + {100, 100}, // always-padding specialized path + {20, 60}, // probabilistic path + } + sizes := []int{1, 3, 4, 7, 16, 100, 1000, 64 * 1024} + + for _, mode := range modes { + for _, pad := range paddings { + for _, size := range sizes { + payload := make([]byte, size) + for i := range payload { + payload[i] = byte(i*31 + 7) + } + + table := NewTable("conn-roundtrip-seed", mode) + + // Encode via Conn.Write. + w := &mockConn{} + enc := NewConn(w, table, pad.min, pad.max, false) + if _, err := enc.Write(payload); err != nil { + t.Fatalf("mode=%s pad=%v size=%d write: %v", mode, pad, size, err) + } + + // Decode via Conn.Read using the same table. + dec := NewConn(&mockConn{readBuf: w.writeBuf}, table, pad.min, pad.max, false) + got := make([]byte, size) + if _, err := io.ReadFull(dec, got); err != nil { + t.Fatalf("mode=%s pad=%v size=%d read: %v", mode, pad, size, err) + } + if !bytes.Equal(got, payload) { + t.Fatalf("mode=%s pad=%v size=%d roundtrip mismatch", mode, pad, size) + } + } + } + } +} diff --git a/transport/sudoku/obfs/sudoku/encode.go b/transport/sudoku/obfs/sudoku/encode.go index cfcf571e..f93c8d99 100644 --- a/transport/sudoku/obfs/sudoku/encode.go +++ b/transport/sudoku/obfs/sudoku/encode.go @@ -1,9 +1,12 @@ package sudoku -func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThreshold uint64, p []byte) []byte { +func encodeSudokuPayload(dst []byte, table *Table, rng *sudokuRand, paddingThreshold uint64, p []byte) []byte { if len(p) == 0 { return dst[:0] } + if paddingThreshold == 0 { + return encodeSudokuPayloadNoPadding(dst, table, rng, p) + } outCapacity := len(p)*6 + 1 if cap(dst) < outCapacity { @@ -13,8 +16,25 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre pads := table.PaddingPool padLen := len(pads) + if paddingThreshold >= probOne { + for _, b := range p { + out = append(out, pads[rng.Intn(padLen)]) + + puzzles := table.EncodeTable[b] + puzzle := puzzles[rng.Intn(len(puzzles))] + + perm := perm4[rng.Intn(len(perm4))] + for _, idx := range perm { + out = append(out, pads[rng.Intn(padLen)], puzzle[idx]) + } + } + + out = append(out, pads[rng.Intn(padLen)]) + return out + } + for _, b := range p { - if shouldPad(rng, paddingThreshold) { + if uint64(rng.Uint32()) < paddingThreshold { out = append(out, pads[rng.Intn(padLen)]) } @@ -22,15 +42,31 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre puzzle := puzzles[rng.Intn(len(puzzles))] perm := perm4[rng.Intn(len(perm4))] for _, idx := range perm { - if shouldPad(rng, paddingThreshold) { + if uint64(rng.Uint32()) < paddingThreshold { out = append(out, pads[rng.Intn(padLen)]) } out = append(out, puzzle[idx]) } } - if shouldPad(rng, paddingThreshold) { + if uint64(rng.Uint32()) < paddingThreshold { out = append(out, pads[rng.Intn(padLen)]) } return out } + +func encodeSudokuPayloadNoPadding(dst []byte, table *Table, rng *sudokuRand, p []byte) []byte { + outCapacity := len(p) * 4 + if cap(dst) < outCapacity { + dst = make([]byte, 0, outCapacity) + } + out := dst[:0] + + for _, b := range p { + puzzles := table.EncodeTable[b] + puzzle := puzzles[rng.Intn(len(puzzles))] + perm := perm4[rng.Intn(len(perm4))] + out = append(out, puzzle[perm[0]], puzzle[perm[1]], puzzle[perm[2]], puzzle[perm[3]]) + } + return out +} diff --git a/transport/sudoku/obfs/sudoku/packed.go b/transport/sudoku/obfs/sudoku/packed.go index 16e9f968..fc06436a 100644 --- a/transport/sudoku/obfs/sudoku/packed.go +++ b/transport/sudoku/obfs/sudoku/packed.go @@ -8,9 +8,9 @@ import ( ) const ( - RngBatchSize = 128 - packedProtectedPrefixBytes = 14 + packedIOBufferSize = 64 * 1024 + packedDecodeBufferSize = 96 * 1024 ) // PackedConn encodes traffic with the packed Sudoku layout while preserving @@ -35,7 +35,7 @@ type PackedConn struct { readBits int // Padding selection matches Conn's threshold-based model. - rng randomSource + rng *sudokuRand paddingThreshold uint64 padMarker byte padPool []byte @@ -67,18 +67,20 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn { pc := &PackedConn{ Conn: c, table: table, - reader: bufio.NewReaderSize(c, IOBufferSize), - rawBuf: make([]byte, IOBufferSize), + reader: bufio.NewReaderSize(c, packedIOBufferSize), + rawBuf: make([]byte, packedDecodeBufferSize), pendingData: newPendingBuffer(4096), writeBuf: make([]byte, 0, 4096), rng: localRng, paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax), } - pc.padMarker = table.layout.padMarker - for _, b := range table.PaddingPool { - if b != pc.padMarker { - pc.padPool = append(pc.padPool, b) + if table != nil && table.layout != nil { + pc.padMarker = table.layout.padMarker + for _, b := range table.PaddingPool { + if b != pc.padMarker { + pc.padPool = append(pc.padPool, b) + } } } if len(pc.padPool) == 0 { @@ -87,18 +89,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn { return pc } -func (pc *PackedConn) maybeAddPadding(out []byte) []byte { - if shouldPad(pc.rng, pc.paddingThreshold) { - out = append(out, pc.getPaddingByte()) - } - return out -} - -func (pc *PackedConn) appendGroup(out []byte, group byte) []byte { - out = pc.maybeAddPadding(out) - return append(out, pc.table.layout.groupByte(group)) -} - func (pc *PackedConn) appendForcedPadding(out []byte) []byte { return append(out, pc.getPaddingByte()) } @@ -134,7 +124,7 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) { } else { pc.bitBuf &= (1 << pc.bitCount) - 1 } - out = pc.appendGroup(out, group&0x3F) + out = appendPackedGroup(out, pc.table.layout, pc.rng, pc.paddingThreshold, pc.padPool, group) } effective++ @@ -148,19 +138,49 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) { return out, limit } +func appendPackedGroup(out []byte, layout *byteLayout, rng *sudokuRand, paddingThreshold uint64, padPool []byte, group byte) []byte { + if paddingThreshold != 0 { + u := rng.Uint32() + if uint64(u) < paddingThreshold { + out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))]) + } + } + return append(out, layout.encodeGroup[group&0x3F]) +} + +func maybeAppendPackedPadding(out []byte, rng *sudokuRand, paddingThreshold uint64, padPool []byte) []byte { + if paddingThreshold != 0 { + u := rng.Uint32() + if uint64(u) < paddingThreshold { + out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))]) + } + } + return out +} + func (pc *PackedConn) Write(p []byte) (int, error) { if len(p) == 0 { return 0, nil } + if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 { + return 0, io.ErrClosedPipe + } pc.writeMu.Lock() defer pc.writeMu.Unlock() needed := len(p)*3/2 + 32 + if pc.paddingThreshold == 0 { + needed = ((len(p)+2)/3)*4 + 32 + } if cap(pc.writeBuf) < needed { pc.writeBuf = make([]byte, 0, needed) } out := pc.writeBuf[:0] + layout := pc.table.layout + rng := pc.rng + paddingThreshold := pc.paddingThreshold + padPool := pc.padPool var prefixN int out, prefixN = pc.writeProtectedPrefix(out, p) @@ -181,7 +201,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) { } else { pc.bitBuf &= (1 << pc.bitCount) - 1 } - out = pc.appendGroup(out, group&0x3F) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group) } } @@ -195,10 +215,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) { g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) g4 := b3 & 0x3F - out = pc.appendGroup(out, g1) - out = pc.appendGroup(out, g2) - out = pc.appendGroup(out, g3) - out = pc.appendGroup(out, g4) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4) } } @@ -211,10 +231,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) { g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) g4 := b3 & 0x3F - out = pc.appendGroup(out, g1) - out = pc.appendGroup(out, g2) - out = pc.appendGroup(out, g3) - out = pc.appendGroup(out, g4) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4) } for ; i < n; i++ { @@ -229,7 +249,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) { } else { pc.bitBuf &= (1 << pc.bitCount) - 1 } - out = pc.appendGroup(out, group&0x3F) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group) } } @@ -237,11 +257,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) { group := byte(pc.bitBuf << (6 - pc.bitCount)) pc.bitBuf = 0 pc.bitCount = 0 - out = pc.appendGroup(out, group&0x3F) + out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group) out = append(out, pc.padMarker) } - out = pc.maybeAddPadding(out) + out = maybeAppendPackedPadding(out, rng, paddingThreshold, padPool) if len(out) > 0 { pc.writeBuf = out[:0] @@ -252,6 +272,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) { } func (pc *PackedConn) Flush() error { + if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 { + return io.ErrClosedPipe + } + pc.writeMu.Lock() defer pc.writeMu.Unlock() @@ -265,7 +289,7 @@ func (pc *PackedConn) Flush() error { out = append(out, pc.padMarker) } - out = pc.maybeAddPadding(out) + out = maybeAppendPackedPadding(out, pc.rng, pc.paddingThreshold, pc.padPool) if len(out) > 0 { pc.writeBuf = out[:0] @@ -289,19 +313,44 @@ func writeFull(w io.Writer, b []byte) error { } func (pc *PackedConn) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if pc == nil || pc.Conn == nil || pc.reader == nil || len(pc.rawBuf) == 0 || pc.table == nil || pc.table.layout == nil { + return 0, io.ErrClosedPipe + } if n, ok := drainPending(p, &pc.pendingData); ok { return n, nil } + outN := 0 for { - nr, rErr := pc.reader.Read(pc.rawBuf) + nr, rErr := readRawLimited(pc.Conn, pc.reader, pc.rawBuf[:packedReadSize(len(p)-outN, len(pc.rawBuf))]) if nr > 0 { rBuf := pc.readBitBuf rBits := pc.readBits padMarker := pc.padMarker layout := pc.table.layout - for _, b := range pc.rawBuf[:nr] { + chunk := pc.rawBuf[:nr] + for i := 0; i < len(chunk); { + if rBits == 0 && outN+3 <= len(p) && i+3 < len(chunk) && + layout.hintTable[chunk[i]] && layout.hintTable[chunk[i+1]] && + layout.hintTable[chunk[i+2]] && layout.hintTable[chunk[i+3]] { + g1 := layout.decodeGroup[chunk[i]] + g2 := layout.decodeGroup[chunk[i+1]] + g3 := layout.decodeGroup[chunk[i+2]] + g4 := layout.decodeGroup[chunk[i+3]] + p[outN] = (g1 << 2) | (g2 >> 4) + p[outN+1] = (g2 << 4) | (g3 >> 2) + p[outN+2] = (g3 << 6) | g4 + outN += 3 + i += 4 + continue + } + + b := chunk[i] + i++ if !layout.hintTable[b] { if b == padMarker { rBuf = 0 @@ -321,7 +370,7 @@ func (pc *PackedConn) Read(p []byte) (int, error) { if rBits >= 8 { rBits -= 8 val := byte(rBuf >> rBits) - pc.pendingData.appendByte(val) + outN = appendDecodedByte(p, outN, &pc.pendingData, val) if rBits == 0 { rBuf = 0 } else { @@ -339,21 +388,32 @@ func (pc *PackedConn) Read(p []byte) (int, error) { pc.readBitBuf = 0 pc.readBits = 0 } - if pc.pendingData.available() > 0 { - break + if outN > 0 { + return outN, nil + } + if n, ok := drainPending(p, &pc.pendingData); ok { + return n, nil } return 0, rErr } - if pc.pendingData.available() > 0 { - break + if outN > 0 { + return outN, nil } } - - n, _ := drainPending(p, &pc.pendingData) - return n, nil } func (pc *PackedConn) getPaddingByte() byte { return pc.padPool[pc.rng.Intn(len(pc.padPool))] } + +func packedReadSize(decodedRemaining, maxRaw int) int { + if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 { + return maxRaw + } + if decodedRemaining > (maxRaw-minDecodeReadSize)/2 { + return maxRaw + } + + return decodedRemaining*2 + minDecodeReadSize +} diff --git a/transport/sudoku/obfs/sudoku/packed_prefix_test.go b/transport/sudoku/obfs/sudoku/packed_prefix_test.go new file mode 100644 index 00000000..f6f35fcc --- /dev/null +++ b/transport/sudoku/obfs/sudoku/packed_prefix_test.go @@ -0,0 +1,90 @@ +package sudoku + +import ( + "bytes" + "io" + "net" + "testing" + "time" +) + +type mockConn struct { + readBuf []byte + writeBuf []byte +} + +func (c *mockConn) Read(p []byte) (int, error) { + if len(c.readBuf) == 0 { + return 0, io.EOF + } + n := copy(p, c.readBuf) + c.readBuf = c.readBuf[n:] + return n, nil +} + +func (c *mockConn) Write(p []byte) (int, error) { + c.writeBuf = append(c.writeBuf, p...) + return len(p), nil +} + +func (c *mockConn) Close() error { return nil } +func (c *mockConn) LocalAddr() net.Addr { return nil } +func (c *mockConn) RemoteAddr() net.Addr { return nil } +func (c *mockConn) SetDeadline(time.Time) error { return nil } +func (c *mockConn) SetReadDeadline(time.Time) error { return nil } +func (c *mockConn) SetWriteDeadline(time.Time) error { return nil } + +func TestPackedConn_ProtectedPrefixPadding(t *testing.T) { + table := NewTable("packed-prefix-seed", "prefer_ascii") + mock := &mockConn{} + writer := NewPackedConn(mock, table, 0, 0) + writer.rng = newSudokuRand(1) + + payload := bytes.Repeat([]byte{0}, 32) + if _, err := writer.Write(payload); err != nil { + t.Fatalf("write: %v", err) + } + + wire := append([]byte(nil), mock.writeBuf...) + if len(wire) < 20 { + t.Fatalf("wire too short: %d", len(wire)) + } + + firstHint := -1 + nonHintCount := 0 + maxHintRun := 0 + currentHintRun := 0 + for i, b := range wire[:20] { + if table.layout.isHint(b) { + if firstHint == -1 { + firstHint = i + } + currentHintRun++ + if currentHintRun > maxHintRun { + maxHintRun = currentHintRun + } + continue + } + nonHintCount++ + currentHintRun = 0 + } + + if firstHint < 1 || firstHint > 2 { + t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint) + } + if nonHintCount < 6 { + t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount) + } + if maxHintRun > 3 { + t.Fatalf("prefix still exposes long hint run: %d", maxHintRun) + } + + reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0) + decoded := make([]byte, len(payload)) + if _, err := io.ReadFull(reader, decoded); err != nil { + t.Fatalf("read back: %v", err) + } + if !bytes.Equal(decoded, payload) { + t.Fatalf("roundtrip mismatch") + } +} diff --git a/transport/sudoku/obfs/sudoku/padding_prob.go b/transport/sudoku/obfs/sudoku/padding_prob.go index 00ff68ff..bf32111c 100644 --- a/transport/sudoku/obfs/sudoku/padding_prob.go +++ b/transport/sudoku/obfs/sudoku/padding_prob.go @@ -2,7 +2,7 @@ package sudoku const probOne = uint64(1) << 32 -func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 { +func pickPaddingThreshold(r *sudokuRand, pMin, pMax int) uint64 { if r == nil { return 0 } @@ -28,7 +28,7 @@ func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 { return min + (u * (max - min) >> 32) } -func shouldPad(r randomSource, threshold uint64) bool { +func shouldPad(r *sudokuRand, threshold uint64) bool { if threshold == 0 { return false } diff --git a/transport/sudoku/obfs/sudoku/pending.go b/transport/sudoku/obfs/sudoku/pending.go index 60f09a12..6c624374 100644 --- a/transport/sudoku/obfs/sudoku/pending.go +++ b/transport/sudoku/obfs/sudoku/pending.go @@ -25,7 +25,10 @@ func (p *pendingBuffer) reset() { } func (p *pendingBuffer) ensureAppendCapacity(extra int) { - if p == nil || extra <= 0 || p.off == 0 { + if p == nil || extra <= 0 { + return + } + if p.off == 0 { return } if cap(p.data)-len(p.data) >= extra { @@ -43,6 +46,15 @@ func (p *pendingBuffer) appendByte(b byte) { p.data = append(p.data, b) } +func appendDecodedByte(dst []byte, n int, pending *pendingBuffer, b byte) int { + if n < len(dst) { + dst[n] = b + return n + 1 + } + pending.appendByte(b) + return n +} + func drainPending(dst []byte, pending *pendingBuffer) (int, bool) { if pending == nil || pending.available() == 0 { return 0, false diff --git a/transport/sudoku/obfs/sudoku/rand.go b/transport/sudoku/obfs/sudoku/rand.go index abe72d80..f40a0861 100644 --- a/transport/sudoku/obfs/sudoku/rand.go +++ b/transport/sudoku/obfs/sudoku/rand.go @@ -6,14 +6,10 @@ import ( "time" ) -type randomSource interface { - Uint32() uint32 - Uint64() uint64 - Intn(n int) int -} - type sudokuRand struct { - state uint64 + state uint64 + cached uint32 + haveCached bool } func newSeededRand() *sudokuRand { @@ -37,20 +33,36 @@ func (r *sudokuRand) Uint64() uint64 { if r == nil { return 0 } - r.state += 0x9e3779b97f4a7c15 - z := r.state - z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9 - z = (z ^ (z >> 27)) * 0x94d049bb133111eb - return z ^ (z >> 31) + r.haveCached = false + x := r.state + x ^= x >> 12 + x ^= x << 25 + x ^= x >> 27 + r.state = x + return x * 0x2545f4914f6cdd1d } func (r *sudokuRand) Uint32() uint32 { - return uint32(r.Uint64() >> 32) + if r == nil { + return 0 + } + if r.haveCached { + r.haveCached = false + return r.cached + } + v := r.Uint64() + r.cached = uint32(v) + r.haveCached = true + return uint32(v >> 32) } func (r *sudokuRand) Intn(n int) int { if n <= 1 { return 0 } - return int((uint64(r.Uint32()) * uint64(n)) >> 32) + return fastIntnFromUint32(r.Uint32(), n) +} + +func fastIntnFromUint32(u uint32, n int) int { + return int((uint64(u) * uint64(n)) >> 32) } diff --git a/transport/sudoku/obfs/sudoku/table.go b/transport/sudoku/obfs/sudoku/table.go index b32506ae..d355ae0e 100644 --- a/transport/sudoku/obfs/sudoku/table.go +++ b/transport/sudoku/obfs/sudoku/table.go @@ -192,23 +192,27 @@ func tableHintFingerprint(key string, mode string, uplinkPattern string, downlin } func packHintsToKey(hints [4]byte) uint32 { + return packHintBytes(hints[0], hints[1], hints[2], hints[3]) +} + +func packHintBytes(h0, h1, h2, h3 byte) uint32 { // Sorting network for 4 elements (Bubble sort unrolled) // Swap if a > b - if hints[0] > hints[1] { - hints[0], hints[1] = hints[1], hints[0] + if h0 > h1 { + h0, h1 = h1, h0 } - if hints[2] > hints[3] { - hints[2], hints[3] = hints[3], hints[2] + if h2 > h3 { + h2, h3 = h3, h2 } - if hints[0] > hints[2] { - hints[0], hints[2] = hints[2], hints[0] + if h0 > h2 { + h0, h2 = h2, h0 } - if hints[1] > hints[3] { - hints[1], hints[3] = hints[3], hints[1] + if h1 > h3 { + h1, h3 = h3, h1 } - if hints[1] > hints[2] { - hints[1], hints[2] = hints[2], hints[1] + if h1 > h2 { + h1, h2 = h2, h1 } - return uint32(hints[0])<<24 | uint32(hints[1])<<16 | uint32(hints[2])<<8 | uint32(hints[3]) + return uint32(h0)<<24 | uint32(h1)<<16 | uint32(h2)<<8 | uint32(h3) } diff --git a/transport/trusttunnel/client.go b/transport/trusttunnel/client.go index 66dedb02..eead3aae 100644 --- a/transport/trusttunnel/client.go +++ b/transport/trusttunnel/client.go @@ -14,12 +14,14 @@ import ( "sync/atomic" "time" + "github.com/sagernet/sing-box/common/congestion" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/common/logger" "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/http3" @@ -50,7 +52,7 @@ type ClientOptions struct { QUIC bool CongestionControl string CWND int - BBRProfile string + Logger logger.Logger HealthCheck bool MaxConnections int MinStreams int @@ -81,7 +83,7 @@ func NewClient(ctx context.Context, options ClientOptions) (*Client, error) { healthCheck: options.HealthCheck, } if options.QUIC { - congestionControlFactory, err := NewCongestionControl(options.CongestionControl, options.CWND, options.BBRProfile, ntp.TimeFuncFromContext(ctx)) + congestionControlFactory, err := congestion.NewCongestionControl(options.CongestionControl, options.CWND, ntp.TimeFuncFromContext(ctx)) if err != nil { cancel() return nil, err diff --git a/transport/v2raygrpc/custom_name.go b/transport/v2raygrpc/custom_name.go index ce970dc6..b83249e4 100644 --- a/transport/v2raygrpc/custom_name.go +++ b/transport/v2raygrpc/custom_name.go @@ -2,6 +2,7 @@ package v2raygrpc import ( "context" + "strings" "google.golang.org/grpc" ) @@ -13,13 +14,21 @@ type GunService interface { } func ServerDesc(name string) grpc.ServiceDesc { + serviceName := name + streamName := "Tun" + if strings.Contains(name, "/") { + name = strings.TrimPrefix(name, "/") + lastSlash := strings.LastIndex(name, "/") + serviceName = name[:lastSlash] + streamName = name[lastSlash+1:] + } return grpc.ServiceDesc{ - ServiceName: name, + ServiceName: serviceName, HandlerType: (*GunServiceServer)(nil), Methods: []grpc.MethodDesc{}, Streams: []grpc.StreamDesc{ { - StreamName: "Tun", + StreamName: streamName, Handler: _GunService_Tun_Handler, ServerStreams: true, ClientStreams: true, @@ -30,7 +39,11 @@ func ServerDesc(name string) grpc.ServiceDesc { } func (c *gunServiceClient) TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GunService_TunClient, error) { - stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], "/"+name+"/Tun", opts...) + path := "/" + name + "/Tun" + if strings.Contains(name, "/") { + path = name + } + stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], path, opts...) if err != nil { return nil, err } diff --git a/transport/v2raygrpclite/client.go b/transport/v2raygrpclite/client.go index b2aab911..26cbf769 100644 --- a/transport/v2raygrpclite/client.go +++ b/transport/v2raygrpclite/client.go @@ -53,10 +53,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt DisableCompression: true, }, url: &url.URL{ - Scheme: "https", - Host: serverAddr.String(), - Path: "/" + options.ServiceName + "/Tun", - RawPath: "/" + url.PathEscape(options.ServiceName) + "/Tun", + Scheme: "https", + Host: serverAddr.String(), + Path: grpcPath(options.ServiceName), }, host: host, } diff --git a/transport/v2raygrpclite/path.go b/transport/v2raygrpclite/path.go new file mode 100644 index 00000000..5c44188f --- /dev/null +++ b/transport/v2raygrpclite/path.go @@ -0,0 +1,10 @@ +package v2raygrpclite + +import "strings" + +func grpcPath(serviceName string) string { + if strings.Contains(serviceName, "/") { + return serviceName + } + return "/" + serviceName + "/Tun" +} diff --git a/transport/v2raygrpclite/server.go b/transport/v2raygrpclite/server.go index 622d785a..4872d409 100644 --- a/transport/v2raygrpclite/server.go +++ b/transport/v2raygrpclite/server.go @@ -42,7 +42,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option. tlsConfig: tlsConfig, logger: logger, handler: handler, - path: "/" + options.ServiceName + "/Tun", + path: grpcPath(options.ServiceName), h2Server: &http2.Server{ IdleTimeout: time.Duration(options.IdleTimeout), }, diff --git a/transport/v2raykcp/sending.go b/transport/v2raykcp/sending.go index c0e59953..b27605d6 100644 --- a/transport/v2raykcp/sending.go +++ b/transport/v2raykcp/sending.go @@ -1,14 +1,14 @@ package v2raykcp import ( - "container/list" "sync" + "github.com/sagernet/sing-box/common/list" "github.com/sagernet/sing/common/buf" ) type SendingWindow struct { - cache *list.List + cache *list.List[*DataSegment] totalInFlightSize uint32 writer SegmentWriter onPacketLoss func(uint32) @@ -16,7 +16,7 @@ type SendingWindow struct { func NewSendingWindow(writer SegmentWriter, onPacketLoss func(uint32)) *SendingWindow { return &SendingWindow{ - cache: list.New(), + cache: list.New[*DataSegment](), writer: writer, onPacketLoss: onPacketLoss, } @@ -27,9 +27,9 @@ func (sw *SendingWindow) Release() { return } for sw.cache.Len() > 0 { - seg := sw.cache.Front().Value.(*DataSegment) + seg := sw.cache.Front().Value seg.Release() - sw.cache.Remove(sw.cache.Front()) + sw.cache.Front().Remove() } } @@ -50,17 +50,17 @@ func (sw *SendingWindow) Push(number uint32, b *buf.Buffer) { } func (sw *SendingWindow) FirstNumber() uint32 { - return sw.cache.Front().Value.(*DataSegment).Number + return sw.cache.Front().Value.Number } func (sw *SendingWindow) Clear(una uint32) { for !sw.IsEmpty() { - seg := sw.cache.Front().Value.(*DataSegment) + seg := sw.cache.Front().Value if seg.Number >= una { break } seg.Release() - sw.cache.Remove(sw.cache.Front()) + sw.cache.Front().Remove() } } @@ -87,8 +87,7 @@ func (sw *SendingWindow) Visit(visitor func(seg *DataSegment) bool) { } for e := sw.cache.Front(); e != nil; e = e.Next() { - seg := e.Value.(*DataSegment) - if !visitor(seg) { + if !visitor(e.Value) { break } } @@ -132,7 +131,7 @@ func (sw *SendingWindow) Remove(number uint32) bool { } for e := sw.cache.Front(); e != nil; e = e.Next() { - seg := e.Value.(*DataSegment) + seg := e.Value if seg.Number > number { return false } else if seg.Number == number { @@ -140,7 +139,7 @@ func (sw *SendingWindow) Remove(number uint32) bool { sw.totalInFlightSize-- } seg.Release() - sw.cache.Remove(e) + e.Remove() return true } } diff --git a/transport/v2rayxhttp/client.go b/transport/v2rayxhttp/client.go index e6024887..609e2473 100644 --- a/transport/v2rayxhttp/client.go +++ b/transport/v2rayxhttp/client.go @@ -16,12 +16,12 @@ import ( "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/http3" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/congestion" "github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/common/xray/buf" "github.com/sagernet/sing-box/common/xray/net" "github.com/sagernet/sing-box/common/xray/pipe" "github.com/sagernet/sing-box/common/xray/signal/done" - "github.com/sagernet/sing-box/common/xray/uuid" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" qtls "github.com/sagernet/sing-quic" @@ -30,6 +30,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" sHTTP "github.com/sagernet/sing/protocol/http" "github.com/sagernet/sing/service" "golang.org/x/net/http2" @@ -42,15 +43,22 @@ type Client struct { baseRequestURL2 url.URL getHTTPClient func() (DialerClient, *XmuxClient) getHTTPClient2 func() (DialerClient, *XmuxClient) + xmuxManager *XmuxManager + xmuxManager2 *XmuxManager } func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayXHTTPOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) { - if options.Mode == "" { - return nil, E.New("mode is not set") - } if tlsConfig != nil && len(tlsConfig.NextProtos()) == 0 { tlsConfig.SetNextProtos([]string{"h2"}) } + if _, err := congestion.NewCongestionControl(options.CongestionController, options.CWND, nil); err != nil { + return nil, err + } + if options.Download != nil { + if _, err := congestion.NewCongestionControl(options.Download.CongestionController, options.Download.CWND, nil); err != nil { + return nil, err + } + } dest := serverAddr baseRequestURL, err := getBaseRequestURL(&options.V2RayXHTTPBaseOptions, dest, tlsConfig) if err != nil { @@ -61,7 +69,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s xmuxOptions = *options.Xmux } xmuxManager := NewXmuxManager(xmuxOptions, func() XmuxConn { - return createHTTPClient(dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig) + return createHTTPClient(ctx, dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig) }) getHTTPClient := func() (DialerClient, *XmuxClient) { xmuxClient := xmuxManager.GetXmuxClient(ctx) @@ -69,6 +77,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s } baseRequestURL2 := baseRequestURL getHTTPClient2 := getHTTPClient + var xmuxManager2 *XmuxManager if options.Download != nil { options2 := options.Download dialer2 := dialer @@ -98,8 +107,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s if options2.Xmux != nil { xmuxOptions2 = *options2.Xmux } - xmuxManager2 := NewXmuxManager(xmuxOptions2, func() XmuxConn { - return createHTTPClient(dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2) + xmuxManager2 = NewXmuxManager(xmuxOptions2, func() XmuxConn { + return createHTTPClient(ctx, dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2) }) getHTTPClient2 = func() (DialerClient, *XmuxClient) { xmuxClient2 := xmuxManager2.GetXmuxClient(ctx) @@ -113,6 +122,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s getHTTPClient2: getHTTPClient2, baseRequestURL: baseRequestURL, baseRequestURL2: baseRequestURL2, + xmuxManager: xmuxManager, + xmuxManager2: xmuxManager2, }, nil } @@ -121,8 +132,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { mode := c.options.Mode sessionId := "" if c.options.Mode != "stream-one" { - sessionIdUuid := uuid.New() - sessionId = sessionIdUuid.String() + sessionId = GenerateSessionID(&c.options.V2RayXHTTPBaseOptions) } requestURL := c.baseRequestURL requestURL2 := c.baseRequestURL2 @@ -182,10 +192,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { } scMaxEachPostBytes := options.GetNormalizedScMaxEachPostBytes() scMinPostsIntervalMs := options.GetNormalizedScMinPostsIntervalMs() - if scMaxEachPostBytes.From <= 0 { - panic("`scMaxEachPostBytes` should be bigger than 0") - } - maxUploadSize := scMaxEachPostBytes.Rand() + maxUploadSize := int32(scMaxEachPostBytes.Rand()) // WithSizeLimit(0) will still allow single bytes to pass, and a lot of // code relies on this behavior. Subtract 1 so that together with // uploadWriter wrapper, exact size limits can be enforced @@ -255,6 +262,10 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { } func (c *Client) Close() error { + c.xmuxManager.Close() + if c.xmuxManager2 != nil { + c.xmuxManager2.Close() + } return nil } @@ -294,7 +305,7 @@ func getBaseRequestURL(options *option.V2RayXHTTPBaseOptions, dest M.Socksaddr, return requestURL, nil } -func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient { +func createHTTPClient(ctx context.Context, dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient { httpVersion := decideHTTPVersion(tlsConfig) dialContext := func(ctxInner context.Context) (net.Conn, error) { conn, err := dialer.DialContext(ctxInner, "tcp", dest) @@ -319,6 +330,7 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH if keepAlivePeriod < 0 { keepAlivePeriod = 0 } + congestionControlFactory, _ := congestion.NewCongestionControl(options.CongestionController, options.CWND, ntp.TimeFuncFromContext(ctx)) quicConfig := &quic.Config{ MaxIdleTimeout: net.ConnIdleTimeout, // these two are defaults of quic-go/http3. the default of quic-go (no @@ -334,7 +346,14 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH if dErr != nil { return nil, dErr } - return qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg) + conn, dErr := qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg) + if dErr != nil { + return nil, dErr + } + if congestionControlFactory != nil { + conn.SetCongestionControl(congestionControlFactory(conn)) + } + return conn, nil }, } case "2": diff --git a/transport/v2rayxhttp/conn.go b/transport/v2rayxhttp/conn.go index 7da8e916..f8ac0b5d 100644 --- a/transport/v2rayxhttp/conn.go +++ b/transport/v2rayxhttp/conn.go @@ -39,7 +39,7 @@ func (c *splitConn) Close() error { } if err2 != nil { - return err + return err2 } return nil diff --git a/transport/v2rayxhttp/dialer.go b/transport/v2rayxhttp/dialer.go index 255042b1..e2e197a1 100644 --- a/transport/v2rayxhttp/dialer.go +++ b/transport/v2rayxhttp/dialer.go @@ -147,7 +147,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio if c.httpVersion != "1.1" { resp, err := c.client.Do(req) if err != nil { - c.closed = true + c.Close() return err } io.Copy(io.Discard, resp.Body) @@ -225,10 +225,9 @@ func (w *WaitReadCloser) Set(rc io.ReadCloser) { } func (w *WaitReadCloser) Read(b []byte) (int, error) { + <-w.Wait if w.ReadCloser == nil { - if <-w.Wait; w.ReadCloser == nil { - return 0, io.ErrClosedPipe - } + return 0, io.ErrClosedPipe } return w.ReadCloser.Read(b) } diff --git a/transport/v2rayxhttp/mux.go b/transport/v2rayxhttp/mux.go index e134fdeb..f83a3ebf 100644 --- a/transport/v2rayxhttp/mux.go +++ b/transport/v2rayxhttp/mux.go @@ -19,8 +19,8 @@ type XmuxConn interface { type XmuxClient struct { XmuxConn XmuxConn - openUsage int32 - leftUsage int32 + openUsage int + leftUsage int LeftRequests atomic.Int32 UnreusableAt time.Time @@ -37,7 +37,7 @@ func (c *XmuxClient) Close() { } } -func (c *XmuxClient) AddOpenUsage(delta int32) { +func (c *XmuxClient) AddOpenUsage(delta int) { c.mtx.Lock() defer c.mtx.Unlock() c.openUsage += delta @@ -46,7 +46,7 @@ func (c *XmuxClient) AddOpenUsage(delta int32) { } } -func (c *XmuxClient) GetOpenUsage() int32 { +func (c *XmuxClient) GetOpenUsage() int { c.mtx.Lock() defer c.mtx.Unlock() return c.openUsage @@ -54,8 +54,8 @@ func (c *XmuxClient) GetOpenUsage() int32 { type XmuxManager struct { options option.V2RayXHTTPXmuxOptions - concurrency int32 - connections int32 + concurrency int + connections int newConnFunc func() XmuxConn xmuxClients []*XmuxClient mtx sync.Mutex @@ -71,6 +71,15 @@ func NewXmuxManager(options option.V2RayXHTTPXmuxOptions, newConnFunc func() Xmu } } +func (m *XmuxManager) Close() { + m.mtx.Lock() + defer m.mtx.Unlock() + for _, xmuxClient := range m.xmuxClients { + xmuxClient.Close() + } + m.xmuxClients = m.xmuxClients[:0] +} + func (m *XmuxManager) newXmuxClient() *XmuxClient { xmuxClient := &XmuxClient{ XmuxConn: m.newConnFunc(), @@ -81,7 +90,7 @@ func (m *XmuxManager) newXmuxClient() *XmuxClient { } xmuxClient.LeftRequests.Store(math.MaxInt32) if x := m.options.GetNormalizedHMaxRequestTimes().Rand(); x > 0 { - xmuxClient.LeftRequests.Store(x) + xmuxClient.LeftRequests.Store(int32(x)) } if x := m.options.GetNormalizedHMaxReusableSecs().Rand(); x > 0 { xmuxClient.UnreusableAt = time.Now().Add(time.Duration(x) * time.Second) @@ -112,7 +121,7 @@ func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient { if len(m.xmuxClients) == 0 { return m.newXmuxClient() } - if m.connections > 0 && len(m.xmuxClients) < int(m.connections) { + if m.connections > 0 && len(m.xmuxClients) < m.connections { return m.newXmuxClient() } xmuxClients := make([]*XmuxClient, 0) diff --git a/transport/v2rayxhttp/server.go b/transport/v2rayxhttp/server.go index f67d5e89..52b1308c 100644 --- a/transport/v2rayxhttp/server.go +++ b/transport/v2rayxhttp/server.go @@ -18,6 +18,8 @@ import ( "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/http3" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/congestion" + "github.com/sagernet/sing-box/common/kmutex" "github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/common/xray/buf" xnet "github.com/sagernet/sing-box/common/xray/net" @@ -31,6 +33,7 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" aTLS "github.com/sagernet/sing/common/tls" sHttp "github.com/sagernet/sing/protocol/http" ) @@ -49,7 +52,7 @@ type Server struct { options *option.V2RayXHTTPOptions host string path string - sessionMu sync.Mutex + sessionMu *kmutex.Kmutex[string] sessions sync.Map } @@ -62,6 +65,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option. options: &options, host: options.Host, path: options.GetNormalizedPath(), + sessionMu: kmutex.New[string](), } if server.network() == N.NetworkTCP { protocols := new(http.Protocols) @@ -80,11 +84,21 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option. }, } } else { + congestionControlFactory, err := congestion.NewCongestionControl(options.CongestionController, options.CWND, ntp.TimeFuncFromContext(ctx)) + if err != nil { + return nil, err + } server.quicConfig = &quic.Config{ DisablePathMTUDiscovery: !C.IsLinux && !C.IsWindows, } server.http3Server = &http3.Server{ Handler: server, + ConnContext: func(ctx context.Context, conn *quic.Conn) context.Context { + if congestionControlFactory != nil { + conn.SetCongestionControl(congestionControlFactory(conn)) + } + return log.ContextWithNewID(ctx) + }, } } return server, nil @@ -102,7 +116,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { return } WriteResponseHeader(writer, request.Method, request.Header, s.options) - length := int(s.options.GetNormalizedXPaddingBytes().Rand()) + length := s.options.GetNormalizedXPaddingBytes().Rand() config := XPaddingConfig{Length: length} if s.options.XPaddingObfsMode { config.Placement = XPaddingPlacement{ @@ -125,15 +139,25 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { validRange := s.options.GetNormalizedXPaddingBytes() paddingValue, paddingPlacement := ExtractXPaddingFromRequest(&s.options.V2RayXHTTPBaseOptions, request, s.options.XPaddingObfsMode) if !IsPaddingValid(&s.options.V2RayXHTTPBaseOptions, paddingValue, validRange.From, validRange.To, PaddingMethod(s.options.XPaddingMethod)) { - s.logger.ErrorContext(request.Context(), "invalid padding ("+paddingPlacement+") length:", int32(len(paddingValue))) + s.logger.ErrorContext(request.Context(), "invalid padding ("+paddingPlacement+") length:", len(paddingValue)) writer.WriteHeader(http.StatusBadRequest) return } sessionId, seqStr := ExtractMetaFromRequest(s.options, request, s.path) - if sessionId == "" && s.options.Mode != "" && s.options.Mode != "auto" && s.options.Mode != "stream-one" && s.options.Mode != "stream-up" { - s.logger.ErrorContext(request.Context(), "stream-one mode is not allowed") - writer.WriteHeader(http.StatusBadRequest) - return + if s.options.Mode != "" && s.options.Mode != "auto" { + if sessionId == "" { + if s.options.Mode != "stream-one" && s.options.Mode != "stream-up" { + s.logger.ErrorContext(request.Context(), "stream-one mode is not allowed") + writer.WriteHeader(http.StatusBadRequest) + return + } + } else { + if s.options.Mode == "stream-one" { + s.logger.ErrorContext(request.Context(), "session is not allowed in stream-one mode") + writer.WriteHeader(http.StatusBadRequest) + return + } + } } var forwardedAddrs []xnet.Address if len(s.options.TrustedXForwardedFor) > 0 { @@ -171,7 +195,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if sessionId != "" { currentSession = s.upsertSession(sessionId) } - scMaxEachPostBytes := int(s.options.GetNormalizedScMaxEachPostBytes().To) + scMaxEachPostBytes := s.options.GetNormalizedScMaxEachPostBytes().To uplinkDataPlacement := s.options.GetNormalizedUplinkDataPlacement() uplinkDataKey := s.options.UplinkDataKey isUplinkRequest := false @@ -207,12 +231,22 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { referrer := request.Header.Get("Referer") if referrer != "" && scStreamUpServerSecs.To > 0 { go func() { + timer := time.NewTimer(0) + if !timer.Stop() { + <-timer.C + } + defer timer.Stop() for { - _, err := httpSC.Write(bytes.Repeat([]byte{'X'}, int(s.options.GetNormalizedXPaddingBytes().Rand()))) + _, err := httpSC.Write(bytes.Repeat([]byte{'X'}, s.options.GetNormalizedXPaddingBytes().Rand())) if err != nil { - break + return + } + timer.Reset(time.Duration(scStreamUpServerSecs.Rand()) * time.Second) + select { + case <-timer.C: + case <-httpSC.Wait(): + return } - time.Sleep(time.Duration(scStreamUpServerSecs.Rand()) * time.Second) } }() } @@ -327,7 +361,11 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { // after GET is done, the connection is finished. disable automatic // session reaping, and handle it in defer currentSession.isFullyConnected.Close() - defer s.sessions.Delete(sessionId) + defer func() { + s.sessionMu.Lock(sessionId) + defer s.sessionMu.Unlock(sessionId) + s.sessions.Delete(sessionId) + }() } // magic header instructs nginx + apache to not buffer response body writer.Header().Set("X-Accel-Buffering", "no") @@ -410,32 +448,27 @@ func (s *Server) network() string { } func (s *Server) upsertSession(sessionId string) *httpSession { - // fast path + s.sessionMu.Lock(sessionId) + defer s.sessionMu.Unlock(sessionId) currentSessionAny, ok := s.sessions.Load(sessionId) if ok { return currentSessionAny.(*httpSession) } - // slow path - s.sessionMu.Lock() - defer s.sessionMu.Unlock() - currentSessionAny, ok = s.sessions.Load(sessionId) - if ok { - return currentSessionAny.(*httpSession) - } session := &httpSession{ uploadQueue: NewUploadQueue(s.options.GetNormalizedScMaxBufferedPosts()), isFullyConnected: done.New(), } s.sessions.Store(sessionId, session) - shouldReap := done.New() - go func() { - time.Sleep(30 * time.Second) - shouldReap.Close() - }() go func() { + reapTimer := time.NewTimer(30 * time.Second) + defer reapTimer.Stop() select { - case <-shouldReap.Wait(): - s.sessions.Delete(sessionId) + case <-reapTimer.C: + s.sessionMu.Lock(sessionId) + if current, ok := s.sessions.Load(sessionId); ok && current.(*httpSession) == session { + s.sessions.Delete(sessionId) + } + s.sessionMu.Unlock(sessionId) session.uploadQueue.Close() case <-session.isFullyConnected.Wait(): } diff --git a/transport/v2rayxhttp/upload_queue.go b/transport/v2rayxhttp/upload_queue.go index 32259a73..da0b2f3c 100644 --- a/transport/v2rayxhttp/upload_queue.go +++ b/transport/v2rayxhttp/upload_queue.go @@ -6,7 +6,6 @@ package xhttp import ( "container/heap" "io" - "runtime" "sync" E "github.com/sagernet/sing/common/exceptions" @@ -19,19 +18,22 @@ type Packet struct { } type uploadQueue struct { - reader io.ReadCloser - nomore bool - pushedPackets chan Packet - writeCloseMutex sync.Mutex - heap uploadHeap - nextSeq uint64 - closed bool - maxPackets int + reader io.ReadCloser + nomore bool + pushedPackets chan Packet + done chan struct{} + heap uploadHeap + nextSeq uint64 + closed bool + maxPackets int + + mtx sync.Mutex } func NewUploadQueue(maxPackets int) *uploadQueue { return &uploadQueue{ pushedPackets: make(chan Packet, maxPackets), + done: make(chan struct{}), heap: uploadHeap{}, nextSeq: 0, closed: false, @@ -40,63 +42,83 @@ func NewUploadQueue(maxPackets int) *uploadQueue { } func (h *uploadQueue) Push(p Packet) error { - h.writeCloseMutex.Lock() - defer h.writeCloseMutex.Unlock() + h.mtx.Lock() if h.closed { + h.mtx.Unlock() return E.New("packet queue closed") } if h.nomore { + h.mtx.Unlock() return E.New("h.reader already exists") } if p.Reader != nil { h.nomore = true } - h.pushedPackets <- p - return nil + h.mtx.Unlock() + select { + case h.pushedPackets <- p: + return nil + case <-h.done: + return E.New("packet queue closed") + } } func (h *uploadQueue) Close() error { - h.writeCloseMutex.Lock() - defer h.writeCloseMutex.Unlock() - if !h.closed { - h.closed = true - runtime.Gosched() // hope Read() gets the packet - f: - for { - select { - case p := <-h.pushedPackets: - if p.Reader != nil { - h.reader = p.Reader - } - default: - break f + h.mtx.Lock() + if h.closed { + h.mtx.Unlock() + return nil + } + h.closed = true + close(h.done) + h.mtx.Unlock() + + for { + select { + case p := <-h.pushedPackets: + if p.Reader != nil { + p.Reader.Close() } + default: + if h.reader != nil { + return h.reader.Close() + } + return nil } - close(h.pushedPackets) } - if h.reader != nil { - return h.reader.Close() - } - return nil } func (h *uploadQueue) Read(b []byte) (int, error) { + h.mtx.Lock() + if h.closed { + h.mtx.Unlock() + return 0, io.EOF + } + h.mtx.Unlock() if h.reader != nil { return h.reader.Read(b) } - if h.closed { - return 0, io.EOF - } if len(h.heap) == 0 { - packet, more := <-h.pushedPackets - if !more { + select { + case packet, more := <-h.pushedPackets: + if !more { + return 0, io.EOF + } + if packet.Reader != nil { + h.mtx.Lock() + if h.closed { + packet.Reader.Close() + h.mtx.Unlock() + return 0, io.EOF + } + h.reader = packet.Reader + h.mtx.Unlock() + return h.reader.Read(b) + } + heap.Push(&h.heap, packet) + case <-h.done: return 0, io.EOF } - if packet.Reader != nil { - h.reader = packet.Reader - return h.reader.Read(b) - } - heap.Push(&h.heap, packet) } for len(h.heap) > 0 { packet := heap.Pop(&h.heap).(Packet) @@ -125,11 +147,15 @@ func (h *uploadQueue) Read(b []byte) (int, error) { return 0, E.New("packet queue is too large") } heap.Push(&h.heap, packet) - packet2, more := <-h.pushedPackets - if !more { + select { + case packet2, more := <-h.pushedPackets: + if !more { + return 0, io.EOF + } + heap.Push(&h.heap, packet2) + case <-h.done: return 0, io.EOF } - heap.Push(&h.heap, packet2) } } return 0, nil diff --git a/transport/v2rayxhttp/utils.go b/transport/v2rayxhttp/utils.go index 102a64db..408343f3 100644 --- a/transport/v2rayxhttp/utils.go +++ b/transport/v2rayxhttp/utils.go @@ -4,16 +4,49 @@ import ( "encoding/base64" "fmt" "io" + "math/rand/v2" "net/http" "github.com/sagernet/sing-box/common/xray/buf" "github.com/sagernet/sing-box/common/xray/utils" + "github.com/sagernet/sing-box/common/xray/uuid" "github.com/sagernet/sing-box/option" ) +// PredefinedTable maps named charsets to their alphabets for session ID generation. + +var PredefinedTable = map[string]string{ + "ALPHABET": "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + "Alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + "BASE36": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ", + "Base62": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + "HEX": "0123456789ABCDEF", + "alphabet": "abcdefghijklmnopqrstuvwxyz", + "base36": "0123456789abcdefghijklmnopqrstuvwxyz", + "hex": "0123456789abcdef", + "number": "0123456789", +} + +func GenerateSessionID(options *option.V2RayXHTTPBaseOptions) string { + length := options.SessionIDLength.Rand() + table := options.SessionIDTable + if predefined, ok := PredefinedTable[table]; ok { + table = predefined + } + if table != "" && length > 0 { + id := make([]byte, length) + for i := range id { + id[i] = table[rand.N(len(table))] + } + return string(id) + } + newUUID := uuid.New() + return newUUID.String() +} + func FillStreamRequest(request *http.Request, sessionId string, seqStr string, options *option.V2RayXHTTPBaseOptions) { request.Header = options.GetRequestHeader() - length := int(options.GetNormalizedXPaddingBytes().Rand()) + length := options.GetNormalizedXPaddingBytes().Rand() config := XPaddingConfig{Length: length} if options.XPaddingObfsMode { config.Placement = XPaddingPlacement{ @@ -58,7 +91,7 @@ func FillPacketRequest(request *http.Request, sessionId string, seqStr string, p } } } - length := int(options.GetNormalizedXPaddingBytes().Rand()) + length := options.GetNormalizedXPaddingBytes().Rand() config := XPaddingConfig{Length: length} if options.XPaddingObfsMode { config.Placement = XPaddingPlacement{ @@ -125,7 +158,7 @@ func GetRequestHeaderWithPayload(payload []byte, options *option.V2RayXHTTPBaseO key := options.UplinkDataKey encodedData := base64.RawURLEncoding.EncodeToString(payload) for i := 0; len(encodedData) > 0; i++ { - chunkSize := min(int(options.GetNormalizedUplinkChunkSize().Rand()), len(encodedData)) + chunkSize := min(options.GetNormalizedUplinkChunkSize().Rand(), len(encodedData)) chunk := encodedData[:chunkSize] encodedData = encodedData[chunkSize:] headerKey := fmt.Sprintf("%s-%d", key, i) @@ -140,7 +173,7 @@ func GetRequestCookiesWithPayload(payload []byte, options *option.V2RayXHTTPBase key := options.UplinkDataKey encodedData := base64.RawURLEncoding.EncodeToString(payload) for i := 0; len(encodedData) > 0; i++ { - chunkSize := min(int(options.GetNormalizedUplinkChunkSize().Rand()), len(encodedData)) + chunkSize := min(options.GetNormalizedUplinkChunkSize().Rand(), len(encodedData)) chunk := encodedData[:chunkSize] encodedData = encodedData[chunkSize:] cookieName := fmt.Sprintf("%s_%d", key, i) diff --git a/transport/v2rayxhttp/writer.go b/transport/v2rayxhttp/writer.go index 3c11e3b3..3e4dcd6b 100644 --- a/transport/v2rayxhttp/writer.go +++ b/transport/v2rayxhttp/writer.go @@ -31,11 +31,12 @@ func (w uploadWriter) Write(b []byte) (int, error) { var writed int for _, buff := range buffer.MultiBuffer { + n := int(buff.Len()) err := w.WriteMultiBuffer(buf.MultiBuffer{buff}) if err != nil { return writed, err } - writed += int(buff.Len()) + writed += n } return writed, nil } diff --git a/transport/v2rayxhttp/xpadding.go b/transport/v2rayxhttp/xpadding.go index fd36c49d..1d6cf26b 100644 --- a/transport/v2rayxhttp/xpadding.go +++ b/transport/v2rayxhttp/xpadding.go @@ -264,7 +264,7 @@ func ExtractXPaddingFromRequest(options *option.V2RayXHTTPBaseOptions, req *http return "", "" } -func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string, from, to int32, method PaddingMethod) bool { +func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string, from, to int, method PaddingMethod) bool { if paddingValue == "" { return false } @@ -274,11 +274,11 @@ func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string, } switch method { case PaddingMethodRepeatX: - n := int32(len(paddingValue)) + n := len(paddingValue) return n >= from && n <= to case PaddingMethodTokenish: - const tolerance = int32(validationTolerance) - n := int32(hpack.HuffmanEncodeLength(paddingValue)) + const tolerance = validationTolerance + n := int(hpack.HuffmanEncodeLength(paddingValue)) f := from - tolerance t := to + tolerance if f < 0 { @@ -286,7 +286,7 @@ func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string, } return n >= f && n <= t default: - n := int32(len(paddingValue)) + n := len(paddingValue) return n >= from && n <= to } } diff --git a/transport/wireguard/endpoint_options.go b/transport/wireguard/endpoint_options.go index a9771555..4c9341d9 100644 --- a/transport/wireguard/endpoint_options.go +++ b/transport/wireguard/endpoint_options.go @@ -5,7 +5,7 @@ import ( "net/netip" "time" - Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption" + "github.com/sagernet/sing/common/json/badoption" tun "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -49,10 +49,10 @@ type AmneziaOptions struct { S2 int S3 int S4 int - H1 *Xbadoption.Range - H2 *Xbadoption.Range - H3 *Xbadoption.Range - H4 *Xbadoption.Range + H1 *badoption.Range[uint32] + H2 *badoption.Range[uint32] + H3 *badoption.Range[uint32] + H4 *badoption.Range[uint32] I1 string I2 string I3 string