Add Snell protocol. Refactor MASQUE HTTP/2, Fair Queue. Update XHTTP, OpenVPN, Sudoku, Fallback. Fixes

This commit is contained in:
Shtorm
2026-06-26 01:25:57 +03:00
parent d174962a04
commit edf38d33d6
107 changed files with 5346 additions and 708 deletions

View File

@@ -26,6 +26,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_manager - with_manager
- with_admin_panel - with_admin_panel
- with_profiler - with_profiler
@@ -64,6 +65,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_profiler - with_profiler
- badlinkname - badlinkname
- tfogo_checklinkname0 - tfogo_checklinkname0
@@ -123,6 +125,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_manager - with_manager
- with_admin_panel - with_admin_panel
- with_profiler - with_profiler
@@ -156,6 +159,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_manager - with_manager
- with_admin_panel - with_admin_panel
- with_profiler - with_profiler
@@ -189,6 +193,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_manager - with_manager
- with_admin_panel - with_admin_panel
- with_profiler - with_profiler
@@ -222,6 +227,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_manager - with_manager
- with_admin_panel - with_admin_panel
- with_profiler - with_profiler
@@ -255,6 +261,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_manager - with_manager
- with_admin_panel - with_admin_panel
- with_profiler - with_profiler
@@ -304,6 +311,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_manager - with_manager
- with_admin_panel - with_admin_panel
- with_profiler - with_profiler
@@ -361,6 +369,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_profiler - with_profiler
- badlinkname - badlinkname
- tfogo_checklinkname0 - tfogo_checklinkname0
@@ -433,6 +442,7 @@ builds:
- with_openvpn - with_openvpn
- with_trusttunnel - with_trusttunnel
- with_sudoku - with_sudoku
- with_snell
- with_profiler - with_profiler
- badlinkname - badlinkname
- tfogo_checklinkname0 - tfogo_checklinkname0

View File

@@ -13,6 +13,7 @@ type PlatformInterface interface {
UsePlatformAutoDetectInterfaceControl() bool UsePlatformAutoDetectInterfaceControl() bool
AutoDetectInterfaceControl(fd int) error AutoDetectInterfaceControl(fd int) error
BindInterfaceControl(fd int, interfaceName string) error
UsePlatformInterface() bool UsePlatformInterface() bool
OpenInterface(options *tun.Options, platformOptions option.TunPlatformOptions) (tun.Tun, error) OpenInterface(options *tun.Options, platformOptions option.TunPlatformOptions) (tun.Tun, error)

View File

@@ -63,7 +63,7 @@ func init() {
sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -X internal/godebug.defaultGODEBUG=multipathtcp=0 -s -w -buildid= -checklinkname=0") sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -X internal/godebug.defaultGODEBUG=multipathtcp=0 -s -w -buildid= -checklinkname=0")
debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -X internal/godebug.defaultGODEBUG=multipathtcp=0 -checklinkname=0") debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -X internal/godebug.defaultGODEBUG=multipathtcp=0 -checklinkname=0")
sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_masque", "with_mtproxy", "with_trusttunnel", "with_openvpn", "with_sudoku", "with_utls", "with_naive_outbound", "with_clash_api", "badlinkname", "tfogo_checklinkname0") sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_masque", "with_mtproxy", "with_trusttunnel", "with_openvpn", "with_sudoku", "with_snell", "with_utls", "with_naive_outbound", "with_clash_api", "badlinkname", "tfogo_checklinkname0")
darwinTags = append(darwinTags, "with_dhcp", "grpcnotrace") darwinTags = append(darwinTags, "with_dhcp", "grpcnotrace")
// memcTags = append(memcTags, "with_tailscale") // memcTags = append(memcTags, "with_tailscale")
sharedTags = append(sharedTags, "with_tailscale", "ts_omit_logtail", "ts_omit_ssh", "ts_omit_drive", "ts_omit_taildrop", "ts_omit_webclient", "ts_omit_doctor", "ts_omit_capture", "ts_omit_kube", "ts_omit_aws", "ts_omit_synology", "ts_omit_bird") sharedTags = append(sharedTags, "with_tailscale", "ts_omit_logtail", "ts_omit_ssh", "ts_omit_drive", "ts_omit_taildrop", "ts_omit_webclient", "ts_omit_doctor", "ts_omit_capture", "ts_omit_kube", "ts_omit_aws", "ts_omit_synology", "ts_omit_bird")

View File

@@ -1,4 +1,4 @@
package trusttunnel package congestion
import ( import (
"time" "time"
@@ -12,7 +12,7 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
func NewCongestionControl(name string, cwnd int, bbrProfile string, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) { func NewCongestionControl(name string, cwnd int, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) {
if timeFunc == nil { if timeFunc == nil {
timeFunc = time.Now timeFunc = time.Now
} }

View File

@@ -70,9 +70,20 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
if !(C.IsLinux || C.IsDarwin || C.IsWindows) { if !(C.IsLinux || C.IsDarwin || C.IsWindows) {
return nil, E.New("`bind_interface` is only supported on Linux, macOS and Windows") return nil, E.New("`bind_interface` is only supported on Linux, macOS and Windows")
} }
bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1) if platformInterface != nil && platformInterface.UsePlatformAutoDetectInterfaceControl() {
dialer.Control = control.Append(dialer.Control, bindFunc) interfaceName := options.BindInterface
listener.Control = control.Append(listener.Control, bindFunc) bindFunc := func(network, address string, conn syscall.RawConn) error {
return control.Raw(conn, func(fd uintptr) error {
return platformInterface.BindInterfaceControl(int(fd), interfaceName)
})
}
dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc)
} else {
bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc)
}
} }
if options.RoutingMark > 0 { if options.RoutingMark > 0 {
if !C.IsLinux { if !C.IsLinux {

164
common/list/list.go Normal file
View File

@@ -0,0 +1,164 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package list
// Element is an element of a linked list.
type Element[T any] struct {
next, prev *Element[T]
list *List[T]
Value T
}
func (e *Element[T]) Next() *Element[T] {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
func (e *Element[T]) Prev() *Element[T] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
func (e *Element[T]) Remove() bool {
if e.list == nil {
return false
}
e.list.remove(e)
return true
}
type List[T any] struct {
root Element[T]
len int
}
func (l *List[T]) Init() *List[T] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
func New[T any]() *List[T] { return new(List[T]).Init() }
func (l *List[T]) Len() int { return l.len }
func (l *List[T]) Front() *Element[T] {
if l.len == 0 {
return nil
}
return l.root.next
}
func (l *List[T]) Back() *Element[T] {
if l.len == 0 {
return nil
}
return l.root.prev
}
func (l *List[T]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
func (l *List[T]) insert(e, at *Element[T]) *Element[T] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] {
return l.insert(&Element[T]{Value: v}, at)
}
func (l *List[T]) remove(e *Element[T]) {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil
e.prev = nil
e.list = nil
l.len--
}
func (l *List[T]) Remove(e *Element[T]) T {
if e.list == l {
l.remove(e)
}
return e.Value
}
func (l *List[T]) PushFront(v T) *Element[T] {
l.lazyInit()
return l.insertValue(v, &l.root)
}
func (l *List[T]) PushBack(v T) *Element[T] {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] {
if mark.list != l {
return nil
}
return l.insertValue(v, mark.prev)
}
func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] {
if mark.list != l {
return nil
}
return l.insertValue(v, mark)
}
func (l *List[T]) MoveToFront(e *Element[T]) {
if e.list != l || l.root.next == e {
return
}
l.move(e, &l.root)
}
func (l *List[T]) MoveToBack(e *Element[T]) {
if e.list != l || l.root.prev == e {
return
}
l.move(e, l.root.prev)
}
func (l *List[T]) MoveBefore(e, mark *Element[T]) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark.prev)
}
func (l *List[T]) MoveAfter(e, mark *Element[T]) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark)
}
func (l *List[T]) move(e, at *Element[T]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}

View File

@@ -9,7 +9,6 @@ import (
"strings" "strings"
"time" "time"
Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption"
"github.com/sagernet/sing/common/json/badoption" "github.com/sagernet/sing/common/json/badoption"
) )
@@ -69,8 +68,8 @@ func DecodeBase64URLSafe(content string) (string, error) {
return string(result), nil return string(result), nil
} }
func ParseXHTTPRange(value string) (Xbadoption.Range, error) { func ParseXHTTPRange(value string) (badoption.Range[int], error) {
result := Xbadoption.Range{} result := badoption.Range[int]{}
encoded, err := json.Marshal(value) encoded, err := json.Marshal(value)
if err != nil { if err != nil {
return result, err return result, err

View File

@@ -1,83 +0,0 @@
package badoption
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/sagernet/sing-box/common/xray/crypto"
E "github.com/sagernet/sing/common/exceptions"
)
type Range struct {
From int32 `json:"from"`
To int32 `json:"to"`
}
func (c *Range) Build() *Range {
return (*Range)(c)
}
func (c *Range) MarshalJSON() ([]byte, error) {
if c.From == c.To {
return json.Marshal(c.From)
}
return json.Marshal(fmt.Sprintf("%d-%d", c.From, c.To))
}
func (c *Range) UnmarshalJSON(content []byte) error {
var rangeValue struct {
From int32 `json:"from"`
To int32 `json:"to"`
}
var stringValue string
err := json.Unmarshal(content, &stringValue)
if err == nil {
parts := strings.Split(stringValue, "-")
if len(parts) != 2 {
from, err := strconv.ParseInt(parts[0], 10, 32)
if err != nil {
return err
}
rangeValue.From, rangeValue.To = int32(from), int32(from)
} else {
from, err := strconv.ParseInt(parts[0], 10, 32)
if err != nil {
return err
}
to, err := strconv.ParseInt(parts[1], 10, 32)
if err != nil {
return err
}
rangeValue.From, rangeValue.To = int32(from), int32(to)
}
} else {
var int32Value int32
err := json.Unmarshal(content, &int32Value)
if err == nil {
rangeValue.From, rangeValue.To = int32Value, int32Value
} else {
err := json.Unmarshal(content, &rangeValue)
if err != nil {
return err
}
}
}
if rangeValue.From > rangeValue.To {
return E.New("invalid range")
}
*c = Range{rangeValue.From, rangeValue.To}
return nil
}
func (c *Range) String() string {
if c.From == c.To {
return strconv.FormatInt(int64(c.From), 10)
}
return fmt.Sprintf("%d-%d", c.From, c.To)
}
func (c Range) Rand() int32 {
return int32(crypto.RandBetween(int64(c.From), int64(c.To)))
}

View File

@@ -28,6 +28,7 @@ const (
TypeMieru = "mieru" TypeMieru = "mieru"
TypeAnyTLS = "anytls" TypeAnyTLS = "anytls"
TypeSudoku = "sudoku" TypeSudoku = "sudoku"
TypeSnell = "snell"
TypeShadowsocksR = "shadowsocksr" TypeShadowsocksR = "shadowsocksr"
TypeVLESS = "vless" TypeVLESS = "vless"
TypeTUIC = "tuic" TypeTUIC = "tuic"
@@ -41,6 +42,7 @@ const (
TypeBandwidthLimiter = "bandwidth-limiter" TypeBandwidthLimiter = "bandwidth-limiter"
TypeTrafficLimiter = "traffic-limiter" TypeTrafficLimiter = "traffic-limiter"
TypeRateLimiter = "rate-limiter" TypeRateLimiter = "rate-limiter"
TypeFairQueue = "fair-queue"
TypeAdminPanel = "admin-panel" TypeAdminPanel = "admin-panel"
TypeManagerAPI = "manager-api" TypeManagerAPI = "manager-api"
TypeNodeManagerAPI = "node-manager-api" TypeNodeManagerAPI = "node-manager-api"
@@ -129,6 +131,8 @@ func ProxyDisplayName(proxyType string) string {
return "AnyTLS" return "AnyTLS"
case TypeSudoku: case TypeSudoku:
return "Sudoku" return "Sudoku"
case TypeSnell:
return "Snell"
case TypeFallback: case TypeFallback:
return "Fallback" return "Fallback"
case TypeTailscale: case TypeTailscale:
@@ -145,6 +149,8 @@ func ProxyDisplayName(proxyType string) string {
return "Traffic Limiter" return "Traffic Limiter"
case TypeRateLimiter: case TypeRateLimiter:
return "Rate Limiter" return "Rate Limiter"
case TypeFairQueue:
return "Fair Queue"
case TypeVPNClient: case TypeVPNClient:
return "VPN Client" return "VPN Client"
case TypeVPNServer: case TypeVPNServer:

View File

@@ -39,11 +39,14 @@
"udp_keepalive_period": "30s", "udp_keepalive_period": "30s",
"udp_initial_packet_size": 0, "udp_initial_packet_size": 0,
"reconnect_delay": "5s", "reconnect_delay": "5s",
"congestion_controller": "bbr",
"cwnd": 0,
"tls": { // TLS fields for HTTP2 "tls": { // TLS fields for HTTP2
"insecure": false, "insecure": false,
"cipher_suites": [], "cipher_suites": [],
"curve_preferences": [], "curve_preferences": [],
"fragment": false, "fragment": false,
"fragment_fallback_delay": "500ms",
"record_fragment": false, "record_fragment": false,
"kernel_tx": false, "kernel_tx": false,
"kernel_rx": false "kernel_rx": false

View File

@@ -0,0 +1,13 @@
{
"log": {
"level": "info"
},
"services": [
{
"type": "profiler",
"tag": "pprof",
"listen": "127.0.0.1",
"listen_port": 6060
}
]
}

View File

@@ -0,0 +1,46 @@
{
"log": {
"level": "error"
},
"dns": {
"servers": [
{
"type": "local",
"tag": "default"
}
]
},
"inbounds": [
{
"type": "mixed",
"tag": "mixed-in",
"listen_port": 7897
}
],
"outbounds": [
{
"type": "direct",
"tag": "direct"
},
{
"type": "snell",
"tag": "snell-out",
"server": "example.com",
"server_port": 8443,
"psk": "your-secret-psk",
"version": 4, // 1 | 2 | 3 | 4 | 5 (v5 falls back to v4)
"reuse": true, // v4 only, reuse pooled connections
"network": ["tcp", "udp"]
// "obfs": {
// "mode": "tls", // tls | http
// "host": "bing.com"
// }
// Dial Fields
}
],
"route": {
"final": "snell-out",
"default_domain_resolver": "default",
"auto_detect_interface": true
}
}

View File

@@ -0,0 +1,39 @@
{
"log": {
"level": "error"
},
"dns": {
"servers": [
{
"type": "local",
"tag": "default"
}
]
},
"inbounds": [
{
"type": "snell",
"tag": "snell-in",
"listen": "::",
"listen_port": 8443,
"psk": "your-secret-psk",
"version": 4, // 4 | 5 (server supports v4/v5 only)
"network": ["tcp", "udp"]
// "obfs": {
// "mode": "tls", // tls | http
// "host": "bing.com"
// }
}
],
"outbounds": [
{
"type": "direct",
"tag": "direct"
}
],
"route": {
"final": "direct",
"default_domain_resolver": "default",
"auto_detect_interface": true
}
}

View File

@@ -31,7 +31,8 @@
"multiplex": { "multiplex": {
"enabled": true, "enabled": true,
"max_connections": 8, "max_connections": 8,
"min_streams": 5 "min_streams": 5,
"max_streams": 0
}, },
"tls": { "tls": {
"enabled": true, "enabled": true,
@@ -50,12 +51,12 @@
"health_check": true, "health_check": true,
"quic": true, "quic": true,
"congestion_controller": "bbr", // bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno "congestion_controller": "bbr", // bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
"bbr_profile": "standard", // standard, conservative, aggressive
"cwnd": 32, "cwnd": 32,
"multiplex": { "multiplex": {
"enabled": true, "enabled": true,
"max_connections": 8, "max_connections": 8,
"min_streams": 5 "min_streams": 5,
"max_streams": 0
}, },
"tls": { "tls": {
"enabled": true, "enabled": true,

View File

@@ -13,7 +13,6 @@
} }
], ],
"congestion_controller": "bbr", // bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno "congestion_controller": "bbr", // bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
"bbr_profile": "standard", // standard, conservative, aggressive
"cwnd": 32, "cwnd": 32,
"tls": { "tls": {
"enabled": true, "enabled": true,

View File

@@ -65,6 +65,8 @@
"uplink_data_placement": "", "uplink_data_placement": "",
"uplink_data_key": "", "uplink_data_key": "",
"uplink_chunk_size": 0, "uplink_chunk_size": 0,
"congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
"cwnd": 0, // h3 only: initial congestion window in packets, default 32
"server": "example.com", "server": "example.com",
"server_port": 443, "server_port": 443,
"download": { "download": {
@@ -97,6 +99,8 @@
"uplink_data_placement": "", "uplink_data_placement": "",
"uplink_data_key": "", "uplink_data_key": "",
"uplink_chunk_size": 0, "uplink_chunk_size": 0,
"congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
"cwnd": 0, // h3 only: initial congestion window in packets, default 32
"server": "example.com", "server": "example.com",
"server_port": 443, "server_port": 443,
"tls": { // https://sing-box.sagernet.org/configuration/shared/tls/#outbound "tls": { // https://sing-box.sagernet.org/configuration/shared/tls/#outbound

View File

@@ -51,6 +51,8 @@
"seq_key": "", "seq_key": "",
"uplink_data_placement": "", "uplink_data_placement": "",
"uplink_data_key": "", "uplink_data_key": "",
"congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
"cwnd": 0, // h3 only: initial congestion window in packets, default 32
} }
} }
], ],

View File

@@ -78,6 +78,10 @@ func (s *platformInterfaceStub) AutoDetectInterfaceControl(fd int) error {
return nil return nil
} }
func (s *platformInterfaceStub) BindInterfaceControl(fd int, interfaceName string) error {
return os.ErrInvalid
}
func (s *platformInterfaceStub) UsePlatformInterface() bool { func (s *platformInterfaceStub) UsePlatformInterface() bool {
return false return false
} }

View File

@@ -6,6 +6,7 @@ type PlatformInterface interface {
LocalDNSTransport() LocalDNSTransport LocalDNSTransport() LocalDNSTransport
UsePlatformAutoDetectInterfaceControl() bool UsePlatformAutoDetectInterfaceControl() bool
AutoDetectInterfaceControl(fd int32) error AutoDetectInterfaceControl(fd int32) error
BindInterfaceControl(fd int32, interfaceName string) error
OpenTun(options TunOptions) (int32, error) OpenTun(options TunOptions) (int32, error)
UseProcFS() bool UseProcFS() bool
FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (*ConnectionOwner, error) FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (*ConnectionOwner, error)

View File

@@ -49,6 +49,10 @@ func (w *platformInterfaceWrapper) AutoDetectInterfaceControl(fd int) error {
return w.iif.AutoDetectInterfaceControl(int32(fd)) return w.iif.AutoDetectInterfaceControl(int32(fd))
} }
func (w *platformInterfaceWrapper) BindInterfaceControl(fd int, interfaceName string) error {
return w.iif.BindInterfaceControl(int32(fd), interfaceName)
}
func (w *platformInterfaceWrapper) UsePlatformInterface() bool { func (w *platformInterfaceWrapper) UsePlatformInterface() bool {
return true return true
} }

7
go.mod
View File

@@ -32,6 +32,7 @@ require (
github.com/miekg/dns v1.1.72 github.com/miekg/dns v1.1.72
github.com/openai/openai-go/v3 v3.26.0 github.com/openai/openai-go/v3 v3.26.0
github.com/oschwald/maxminddb-golang v1.13.1 github.com/oschwald/maxminddb-golang v1.13.1
github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e
github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1 github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a
github.com/sagernet/cors v1.2.1 github.com/sagernet/cors v1.2.1
@@ -231,8 +232,8 @@ replace github.com/sagernet/sing-vmess => github.com/shtorm-7/sing-vmess v0.2.7-
replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0
replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0
replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0
replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.1.0 replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.2.0

14
go.sum
View File

@@ -268,6 +268,8 @@ github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e h1:dCWirM5F3wMY+cmRda/B1BiPsFtmzXqV9b0hLWtVBMs=
github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e/go.mod h1:9leZcVcItj6m9/CfHY5Em/iBrCz7js8LcRQGTKEEv2M=
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
@@ -373,16 +375,16 @@ github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1h
github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8= github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 h1:ws7BIsYLd31Wjifq88BYCHRVlgO+07iwil39s6ERba8= github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0 h1:3ZV98mKqKNPCPWHevJ6RPsb65DwPrRFEUOHUfDnG6vw=
github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0/go.mod h1:mRwx4w32qQxsWB2kThuHpbo7iNjJiq1jYWubgqEPjHA= github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0/go.mod h1:mRwx4w32qQxsWB2kThuHpbo7iNjJiq1jYWubgqEPjHA=
github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 h1:e5s7RKBd2rIPR0StbvZ2vTVtJ5jDTsTk5wtIIapZTRg= github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 h1:e5s7RKBd2rIPR0StbvZ2vTVtJ5jDTsTk5wtIIapZTRg=
github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI= github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI=
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 h1:PLZ/YHqnApPx13wt6MX3ItqESp4ueBr1tGSi0bEGqYw= github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0 h1:aOd9Vy2LGSwgMM+4805AgLBE/MQf8UymbXHxUZjSmoU=
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4= github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4=
github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 h1:iBLll4ZZG8ULQcHWs6gGslZWtBN72Zo1zjySzMVHF7g= github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 h1:iBLll4ZZG8ULQcHWs6gGslZWtBN72Zo1zjySzMVHF7g=
github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0= github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0=
github.com/shtorm-7/sing v0.8.10-extended-1.1.0 h1:P4JL2cugjvEvnYu8tMmpR30SE1qsS45RcnNEwzDz5as= github.com/shtorm-7/sing v0.8.10-extended-1.2.0 h1:5yw9j0+P2QkRWvxBvb71wvNdpAlHmmpBv4hj2gqvass=
github.com/shtorm-7/sing v0.8.10-extended-1.1.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA= github.com/shtorm-7/sing v0.8.10-extended-1.2.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA=
github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw=
github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0 h1:WVheKmQH5hSQbJU1ZTKthKSutkTLWSb2hp4JuQhJBow= github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0 h1:WVheKmQH5hSQbJU1ZTKthKSutkTLWSb2hp4JuQhJBow=

View File

@@ -91,6 +91,7 @@ func InboundRegistry() *inbound.Registry {
registerStubForRemovedInbounds(registry) registerStubForRemovedInbounds(registry)
registerMTProxyInbound(registry) registerMTProxyInbound(registry)
registerSudokuInbound(registry) registerSudokuInbound(registry)
registerSnellInbound(registry)
return registry return registry
} }
@@ -135,6 +136,7 @@ func OutboundRegistry() *outbound.Registry {
registerQUICOutbounds(registry) registerQUICOutbounds(registry)
registerStubForRemovedOutbounds(registry) registerStubForRemovedOutbounds(registry)
registerSudokuOutbound(registry) registerSudokuOutbound(registry)
registerSnellOutbound(registry)
return registry return registry
} }

17
include/snell.go Normal file
View File

@@ -0,0 +1,17 @@
//go:build with_snell
package include
import (
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/protocol/snell"
)
func registerSnellInbound(registry *inbound.Registry) {
snell.RegisterInbound(registry)
}
func registerSnellOutbound(registry *outbound.Registry) {
snell.RegisterOutbound(registry)
}

27
include/snell_stub.go Normal file
View File

@@ -0,0 +1,27 @@
//go:build !with_snell
package include
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func registerSnellInbound(registry *inbound.Registry) {
inbound.Register[option.SnellInboundOptions](registry, C.TypeSnell, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellInboundOptions) (adapter.Inbound, error) {
return nil, E.New(`Snell is not included in this build, rebuild with -tags with_snell`)
})
}
func registerSnellOutbound(registry *outbound.Registry) {
outbound.Register[option.SnellOutboundOptions](registry, C.TypeSnell, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellOutboundOptions) (adapter.Outbound, error) {
return nil, E.New(`Snell is not included in this build, rebuild with -tags with_snell`)
})
}

View File

@@ -18,7 +18,8 @@ type URLTestOutboundOptions struct {
} }
type FallbackOutboundOptions struct { type FallbackOutboundOptions struct {
Outbounds []string `json:"outbounds"` Outbounds []string `json:"outbounds"`
BlacklistTimeout badoption.Duration `json:"blacklist_timeout,omitempty"`
} }
type GroupCommonOption struct { type GroupCommonOption struct {

View File

@@ -69,3 +69,8 @@ type RateLimiterUser struct {
Count uint32 `json:"count"` Count uint32 `json:"count"`
Interval badoption.Duration `json:"interval"` Interval badoption.Duration `json:"interval"`
} }
type FairQueueOutboundOptions struct {
FlowKeys []string `json:"flow_keys,omitempty"`
Outbound string `json:"outbound"`
}

View File

@@ -18,6 +18,8 @@ type MASQUEOutboundOptions struct {
UDPKeepalivePeriod badoption.Duration `json:"udp_keepalive_period,omitempty"` UDPKeepalivePeriod badoption.Duration `json:"udp_keepalive_period,omitempty"`
UDPInitialPacketSize uint16 `json:"udp_initial_packet_size,omitempty"` UDPInitialPacketSize uint16 `json:"udp_initial_packet_size,omitempty"`
ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"` ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"`
CongestionController string `json:"congestion_controller,omitempty"`
CWND int `json:"cwnd,omitempty"`
MASQUEOutboundTLSOptionsContainer MASQUEOutboundTLSOptionsContainer
} }

View File

@@ -25,6 +25,7 @@ type OpenVPNOutboundOptions struct {
KeyDirection int `json:"key_direction,omitempty"` KeyDirection int `json:"key_direction,omitempty"`
ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"` ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"`
PingInterval badoption.Duration `json:"ping_interval,omitempty"` PingInterval badoption.Duration `json:"ping_interval,omitempty"`
PingRestart badoption.Duration `json:"ping_restart,omitempty"`
OpenVPNOutboundTLSOptionsContainer OpenVPNOutboundTLSOptionsContainer
} }

24
option/snell.go Normal file
View File

@@ -0,0 +1,24 @@
package option
type SnellOutboundOptions struct {
DialerOptions
ServerOptions
PSK string `json:"psk"`
Version int `json:"version,omitempty"`
Reuse bool `json:"reuse,omitempty"`
Network NetworkList `json:"network,omitempty"`
Obfs *SnellObfsOptions `json:"obfs,omitempty"`
}
type SnellInboundOptions struct {
ListenOptions
PSK string `json:"psk"`
Version int `json:"version,omitempty"`
Network NetworkList `json:"network,omitempty"`
Obfs *SnellObfsOptions `json:"obfs,omitempty"`
}
type SnellObfsOptions struct {
Mode string `json:"mode,omitempty"`
Host string `json:"host,omitempty"`
}

View File

@@ -6,7 +6,6 @@ type TrustTunnelInboundOptions struct {
Users []TrustTunnelUser `json:"users,omitempty"` Users []TrustTunnelUser `json:"users,omitempty"`
Network NetworkList `json:"network,omitempty"` Network NetworkList `json:"network,omitempty"`
CongestionController string `json:"congestion_controller,omitempty"` CongestionController string `json:"congestion_controller,omitempty"`
BBRProfile string `json:"bbr_profile,omitempty"`
CWND int `json:"cwnd,omitempty"` CWND int `json:"cwnd,omitempty"`
} }
@@ -32,7 +31,6 @@ type TrustTunnelOutboundOptions struct {
HealthCheck bool `json:"health_check,omitempty"` HealthCheck bool `json:"health_check,omitempty"`
QUIC bool `json:"quic,omitempty"` QUIC bool `json:"quic,omitempty"`
CongestionController string `json:"congestion_controller,omitempty"` CongestionController string `json:"congestion_controller,omitempty"`
BBRProfile string `json:"bbr_profile,omitempty"`
CWND int `json:"cwnd,omitempty"` CWND int `json:"cwnd,omitempty"`
Multiplex *TrustTunnelMultiplexOptions `json:"multiplex,omitempty"` Multiplex *TrustTunnelMultiplexOptions `json:"multiplex,omitempty"`
} }

View File

@@ -4,7 +4,6 @@ import (
"net/http" "net/http"
"strings" "strings"
Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption"
"github.com/sagernet/sing-box/common/xray/utils" "github.com/sagernet/sing-box/common/xray/utils"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@@ -119,13 +118,13 @@ type V2RayXHTTPBaseOptions struct {
Path string `json:"path,omitempty"` Path string `json:"path,omitempty"`
Headers map[string]string `json:"headers,omitempty"` Headers map[string]string `json:"headers,omitempty"`
DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"`
XPaddingBytes Xbadoption.Range `json:"x_padding_bytes"` XPaddingBytes badoption.Range[int] `json:"x_padding_bytes"`
NoGRPCHeader bool `json:"no_grpc_header,omitempty"` NoGRPCHeader bool `json:"no_grpc_header,omitempty"`
NoSSEHeader bool `json:"no_sse_header,omitempty"` NoSSEHeader bool `json:"no_sse_header,omitempty"`
ScMaxEachPostBytes *Xbadoption.Range `json:"sc_max_each_post_bytes"` ScMaxEachPostBytes *badoption.Range[int] `json:"sc_max_each_post_bytes"`
ScMinPostsIntervalMs *Xbadoption.Range `json:"sc_min_posts_interval_ms"` ScMinPostsIntervalMs *badoption.Range[int] `json:"sc_min_posts_interval_ms"`
ScMaxBufferedPosts int64 `json:"sc_max_buffered_posts,omitempty"` ScMaxBufferedPosts int64 `json:"sc_max_buffered_posts,omitempty"`
ScStreamUpServerSecs *Xbadoption.Range `json:"sc_stream_up_server_secs"` ScStreamUpServerSecs *badoption.Range[int] `json:"sc_stream_up_server_secs"`
ServerMaxHeaderBytes int `json:"server_max_header_bytes"` ServerMaxHeaderBytes int `json:"server_max_header_bytes"`
TrustedXForwardedFor badoption.Listable[string] `json:"trusted_x_forwarded_for,omitempty"` TrustedXForwardedFor badoption.Listable[string] `json:"trusted_x_forwarded_for,omitempty"`
Xmux *V2RayXHTTPXmuxOptions `json:"xmux"` Xmux *V2RayXHTTPXmuxOptions `json:"xmux"`
@@ -141,7 +140,11 @@ type V2RayXHTTPBaseOptions struct {
SeqKey string `json:"seq_key,omitempty"` SeqKey string `json:"seq_key,omitempty"`
UplinkDataPlacement string `json:"uplink_data_placement,omitempty"` UplinkDataPlacement string `json:"uplink_data_placement,omitempty"`
UplinkDataKey string `json:"uplink_data_key,omitempty"` UplinkDataKey string `json:"uplink_data_key,omitempty"`
UplinkChunkSize *Xbadoption.Range `json:"uplink_chunk_size,omitempty"` UplinkChunkSize *badoption.Range[int] `json:"uplink_chunk_size,omitempty"`
SessionIDTable string `json:"session_id_table,omitempty"`
SessionIDLength badoption.Range[int] `json:"session_id_length,omitempty"`
CongestionController string `json:"congestion_controller,omitempty"`
CWND int `json:"cwnd,omitempty"`
} }
type _V2RayXHTTPOptions struct { type _V2RayXHTTPOptions struct {
@@ -302,6 +305,10 @@ func checkV2RayXHTTPBaseOptions(mode string, options *V2RayXHTTPBaseOptions) err
return E.New("invalid negative value of maxHeaderBytes") return E.New("invalid negative value of maxHeaderBytes")
} }
if mode != "stream-one" && mode != "stream-up" && options.GetNormalizedScMaxEachPostBytes().From <= 0 {
return E.New("`scMaxEachPostBytes` should be bigger than 0")
}
if options.Xmux == nil { if options.Xmux == nil {
options.Xmux = &V2RayXHTTPXmuxOptions{} options.Xmux = &V2RayXHTTPXmuxOptions{}
options.Xmux.MaxConcurrency.From = 1 options.Xmux.MaxConcurrency.From = 1
@@ -346,9 +353,9 @@ func (c *V2RayXHTTPBaseOptions) GetRequestHeader() http.Header {
return header return header
} }
func (c *V2RayXHTTPBaseOptions) GetNormalizedXPaddingBytes() Xbadoption.Range { func (c *V2RayXHTTPBaseOptions) GetNormalizedXPaddingBytes() badoption.Range[int] {
if c.XPaddingBytes.To == 0 { if c.XPaddingBytes.To == 0 {
return Xbadoption.Range{ return badoption.Range[int]{
From: 100, From: 100,
To: 1000, To: 1000,
} }
@@ -363,9 +370,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkHTTPMethod() string {
return c.UplinkHTTPMethod return c.UplinkHTTPMethod
} }
func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() Xbadoption.Range { func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() badoption.Range[int] {
if c.ScMaxEachPostBytes == nil { if c.ScMaxEachPostBytes == nil {
return Xbadoption.Range{ return badoption.Range[int]{
From: 1000000, From: 1000000,
To: 1000000, To: 1000000,
} }
@@ -373,9 +380,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() Xbadoption.Ran
return *c.ScMaxEachPostBytes return *c.ScMaxEachPostBytes
} }
func (c *V2RayXHTTPBaseOptions) GetNormalizedScMinPostsIntervalMs() Xbadoption.Range { func (c *V2RayXHTTPBaseOptions) GetNormalizedScMinPostsIntervalMs() badoption.Range[int] {
if c.ScMinPostsIntervalMs == nil { if c.ScMinPostsIntervalMs == nil {
return Xbadoption.Range{ return badoption.Range[int]{
From: 30, From: 30,
To: 30, To: 30,
} }
@@ -391,9 +398,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxBufferedPosts() int {
return int(c.ScMaxBufferedPosts) return int(c.ScMaxBufferedPosts)
} }
func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() Xbadoption.Range { func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() badoption.Range[int] {
if c.ScStreamUpServerSecs == nil { if c.ScStreamUpServerSecs == nil {
return Xbadoption.Range{ return badoption.Range[int]{
From: 20, From: 20,
To: 80, To: 80,
} }
@@ -401,16 +408,16 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() Xbadoption.R
return *c.ScStreamUpServerSecs return *c.ScStreamUpServerSecs
} }
func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() Xbadoption.Range { func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() badoption.Range[int] {
if c.UplinkChunkSize == nil || c.UplinkChunkSize.To == 0 { if c.UplinkChunkSize == nil || c.UplinkChunkSize.To == 0 {
switch c.UplinkDataPlacement { switch c.UplinkDataPlacement {
case PlacementCookie: case PlacementCookie:
return Xbadoption.Range{ return badoption.Range[int]{
From: 2 * 1024, // 2 KiB From: 2 * 1024, // 2 KiB
To: 3 * 1024, // 3 KiB To: 3 * 1024, // 3 KiB
} }
case PlacementHeader: case PlacementHeader:
return Xbadoption.Range{ return badoption.Range[int]{
From: 3 * 1000, // 3 KB From: 3 * 1000, // 3 KB
To: 4 * 1000, // 4 KB To: 4 * 1000, // 4 KB
} }
@@ -418,7 +425,7 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() Xbadoption.Range
return c.GetNormalizedScMaxEachPostBytes() return c.GetNormalizedScMaxEachPostBytes()
} }
} else if c.UplinkChunkSize.From < 64 { } else if c.UplinkChunkSize.From < 64 {
return Xbadoption.Range{ return badoption.Range[int]{
From: 64, From: 64,
To: max(64, c.UplinkChunkSize.To), To: max(64, c.UplinkChunkSize.To),
} }
@@ -485,31 +492,31 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedSeqKey() string {
} }
type V2RayXHTTPXmuxOptions struct { type V2RayXHTTPXmuxOptions struct {
MaxConcurrency Xbadoption.Range `json:"max_concurrency"` MaxConcurrency badoption.Range[int] `json:"max_concurrency"`
MaxConnections Xbadoption.Range `json:"max_connections"` MaxConnections badoption.Range[int] `json:"max_connections"`
CMaxReuseTimes Xbadoption.Range `json:"c_max_reuse_times"` CMaxReuseTimes badoption.Range[int] `json:"c_max_reuse_times"`
HMaxRequestTimes Xbadoption.Range `json:"h_max_request_times"` HMaxRequestTimes badoption.Range[int] `json:"h_max_request_times"`
HMaxReusableSecs Xbadoption.Range `json:"h_max_reusable_secs"` HMaxReusableSecs badoption.Range[int] `json:"h_max_reusable_secs"`
HKeepAlivePeriod int64 `json:"h_keep_alive_period"` HKeepAlivePeriod int64 `json:"h_keep_alive_period"`
} }
func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConcurrency() Xbadoption.Range { func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConcurrency() badoption.Range[int] {
return m.MaxConcurrency return m.MaxConcurrency
} }
func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConnections() Xbadoption.Range { func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConnections() badoption.Range[int] {
return m.MaxConnections return m.MaxConnections
} }
func (m *V2RayXHTTPXmuxOptions) GetNormalizedCMaxReuseTimes() Xbadoption.Range { func (m *V2RayXHTTPXmuxOptions) GetNormalizedCMaxReuseTimes() badoption.Range[int] {
return m.CMaxReuseTimes return m.CMaxReuseTimes
} }
func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxRequestTimes() Xbadoption.Range { func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxRequestTimes() badoption.Range[int] {
return m.HMaxRequestTimes return m.HMaxRequestTimes
} }
func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxReusableSecs() Xbadoption.Range { func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxReusableSecs() badoption.Range[int] {
return m.HMaxReusableSecs return m.HMaxReusableSecs
} }

View File

@@ -3,7 +3,6 @@ package option
import ( import (
"net/netip" "net/netip"
Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption"
"github.com/sagernet/sing/common/json/badoption" "github.com/sagernet/sing/common/json/badoption"
) )
@@ -40,10 +39,10 @@ type WireGuardAmnezia struct {
S2 int `json:"s2,omitempty"` S2 int `json:"s2,omitempty"`
S3 int `json:"s3,omitempty"` S3 int `json:"s3,omitempty"`
S4 int `json:"s4,omitempty"` S4 int `json:"s4,omitempty"`
H1 *Xbadoption.Range `json:"h1,omitempty"` H1 *badoption.Range[uint32] `json:"h1,omitempty"`
H2 *Xbadoption.Range `json:"h2,omitempty"` H2 *badoption.Range[uint32] `json:"h2,omitempty"`
H3 *Xbadoption.Range `json:"h3,omitempty"` H3 *badoption.Range[uint32] `json:"h3,omitempty"`
H4 *Xbadoption.Range `json:"h4,omitempty"` H4 *badoption.Range[uint32] `json:"h4,omitempty"`
I1 string `json:"i1,omitempty"` I1 string `json:"i1,omitempty"`
I2 string `json:"i2,omitempty"` I2 string `json:"i2,omitempty"`
I3 string `json:"i3,omitempty"` I3 string `json:"i3,omitempty"`

View File

@@ -80,6 +80,7 @@ func (h *Inbound) Start(stage adapter.StartStage) error {
} }
func (h *Inbound) Close() error { func (h *Inbound) Close() error {
h.conns.Close()
errs := make([]error, 0) errs := make([]error, 0)
for _, inbound := range h.inbounds { for _, inbound := range h.inbounds {
err := inbound.Close() err := inbound.Close()

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"net" "net"
"sync" "sync"
"time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
@@ -31,14 +32,19 @@ type Fallback struct {
tags []string tags []string
outbounds map[string]adapter.Outbound outbounds map[string]adapter.Outbound
lastUsedOutbound string lastUsedOutbound string
blacklistTimeout time.Duration
mtx sync.Mutex blacklist map[string]time.Time
mtx sync.Mutex
} }
func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FallbackOutboundOptions) (adapter.Outbound, error) { func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FallbackOutboundOptions) (adapter.Outbound, error) {
if len(options.Outbounds) == 0 { if len(options.Outbounds) == 0 {
return nil, E.New("missing tags") return nil, E.New("missing tags")
} }
blacklistTimeout := time.Duration(options.BlacklistTimeout)
if blacklistTimeout == 0 {
blacklistTimeout = time.Minute
}
outbound := &Fallback{ outbound := &Fallback{
Adapter: outbound.NewAdapter(C.TypeFallback, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeFallback, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
ctx: ctx, ctx: ctx,
@@ -47,6 +53,8 @@ func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextL
tags: options.Outbounds, tags: options.Outbounds,
outbounds: make(map[string]adapter.Outbound, len(options.Outbounds)), outbounds: make(map[string]adapter.Outbound, len(options.Outbounds)),
lastUsedOutbound: options.Outbounds[0], lastUsedOutbound: options.Outbounds[0],
blacklistTimeout: blacklistTimeout,
blacklist: make(map[string]time.Time),
} }
return outbound, nil return outbound, nil
} }
@@ -73,35 +81,110 @@ func (s *Fallback) All() []string {
} }
func (s *Fallback) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (s *Fallback) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
var conn net.Conn s.mtx.Lock()
var active, blacklisted []string
for _, tag := range s.tags {
if s.isBlacklisted(tag) {
blacklisted = append(blacklisted, tag)
} else {
active = append(active, tag)
}
}
s.mtx.Unlock()
var err error var err error
for _, outbound := range s.outbounds { for _, tag := range active {
conn, err = outbound.DialContext(ctx, network, destination) var conn net.Conn
conn, err = s.outbounds[tag].DialContext(ctx, network, destination)
if err != nil {
s.logger.InfoContext(ctx, err)
s.mtx.Lock()
s.addToBlacklist(tag)
s.mtx.Unlock()
continue
}
s.mtx.Lock()
s.lastUsedOutbound = tag
s.mtx.Unlock()
return conn, nil
}
for _, tag := range blacklisted {
var conn net.Conn
conn, err = s.outbounds[tag].DialContext(ctx, network, destination)
if err != nil { if err != nil {
s.logger.InfoContext(ctx, err) s.logger.InfoContext(ctx, err)
continue continue
} }
s.mtx.Lock() s.mtx.Lock()
defer s.mtx.Unlock() delete(s.blacklist, tag)
s.lastUsedOutbound = outbound.Tag() s.lastUsedOutbound = tag
s.mtx.Unlock()
return conn, nil return conn, nil
} }
return nil, err return nil, err
} }
func (s *Fallback) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (s *Fallback) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
var conn net.PacketConn s.mtx.Lock()
var active, blacklisted []string
for _, tag := range s.tags {
if s.isBlacklisted(tag) {
blacklisted = append(blacklisted, tag)
} else {
active = append(active, tag)
}
}
s.mtx.Unlock()
var err error var err error
for _, outbound := range s.outbounds { for _, tag := range active {
conn, err = outbound.ListenPacket(ctx, destination) var conn net.PacketConn
conn, err = s.outbounds[tag].ListenPacket(ctx, destination)
if err != nil {
s.logger.InfoContext(ctx, err)
s.mtx.Lock()
s.addToBlacklist(tag)
s.mtx.Unlock()
continue
}
s.mtx.Lock()
s.lastUsedOutbound = tag
s.mtx.Unlock()
return conn, nil
}
for _, tag := range blacklisted {
var conn net.PacketConn
conn, err = s.outbounds[tag].ListenPacket(ctx, destination)
if err != nil { if err != nil {
s.logger.InfoContext(ctx, err) s.logger.InfoContext(ctx, err)
continue continue
} }
s.mtx.Lock() s.mtx.Lock()
defer s.mtx.Unlock() delete(s.blacklist, tag)
s.lastUsedOutbound = outbound.Tag() s.lastUsedOutbound = tag
s.mtx.Unlock()
return conn, nil return conn, nil
} }
return nil, err return nil, err
} }
func (s *Fallback) isBlacklisted(tag string) bool {
if s.blacklistTimeout == 0 {
return false
}
expiry, ok := s.blacklist[tag]
if !ok {
return false
}
if time.Now().After(expiry) {
delete(s.blacklist, tag)
return false
}
return true
}
func (s *Fallback) addToBlacklist(tag string) {
if s.blacklistTimeout > 0 {
s.blacklist[tag] = time.Now().Add(s.blacklistTimeout)
}
}

View File

@@ -2,11 +2,11 @@ package bandwidth
import ( import (
"context" "context"
"slices"
"sync" "sync"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/list"
) )
type BandwidthLimiter interface { type BandwidthLimiter interface {
@@ -14,123 +14,144 @@ type BandwidthLimiter interface {
SetSpeed(speed uint64) SetSpeed(speed uint64)
} }
type FlowKeysLimiter struct { type FairQueueLimiter struct {
limiter BandwidthLimiter limiter BandwidthLimiter
connIDGetter ConnIDGetter connIDGetter ConnIDGetter
waits map[string][]*wait flows *list.List[*flow]
conns map[string]int index map[string]*list.Element[*flow]
bytes map[string]uint64
pool sync.Pool
queue chan struct{} queue chan struct{}
reset time.Time reset time.Time
mtx sync.Mutex mtx sync.Mutex
} }
func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FlowKeysLimiter { func NewFairQueueLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FairQueueLimiter {
return &FlowKeysLimiter{ return &FairQueueLimiter{
limiter: limiter, limiter: limiter,
connIDGetter: connIDGetter, connIDGetter: connIDGetter,
waits: make(map[string][]*wait), flows: list.New[*flow](),
conns: make(map[string]int), index: make(map[string]*list.Element[*flow]),
bytes: make(map[string]uint64),
pool: sync.Pool{New: func() any { return list.New[*request]() }},
queue: make(chan struct{}, 1), queue: make(chan struct{}, 1),
reset: time.Now().Add(time.Second), reset: time.Now().Add(time.Second),
} }
} }
func (l *FlowKeysLimiter) SetSpeed(speed uint64) { func (l *FairQueueLimiter) SetSpeed(speed uint64) {
l.limiter.SetSpeed(speed) l.limiter.SetSpeed(speed)
} }
func (l *FlowKeysLimiter) WaitN(ctx context.Context, n int) error { func (l *FairQueueLimiter) WaitN(ctx context.Context, n int) error {
id, _ := l.connIDGetter(ctx, adapter.ContextFrom(ctx)) id, _ := l.connIDGetter(ctx, adapter.ContextFrom(ctx))
mainWait := &wait{ctx, make(chan struct{}), n} mainRequest := &request{ctx: ctx, done: make(chan struct{}), n: n}
l.mtx.Lock() l.mtx.Lock()
if waits, ok := l.waits[id]; ok { elem, ok := l.index[id]
l.waits[id] = append(waits, mainWait) if !ok {
} else { f := &flow{id: id, pending: l.pool.Get().(*list.List[*request])}
l.waits[id] = []*wait{mainWait} elem = l.flows.PushFront(f)
l.index[id] = elem
} }
mainRequestElem := elem.Value.pending.PushBack(mainRequest)
l.reorder(elem)
l.mtx.Unlock() l.mtx.Unlock()
select { select {
case l.queue <- struct{}{}: case l.queue <- struct{}{}:
case <-mainWait.finish: case <-mainRequest.done:
return nil return nil
case <-ctx.Done(): case <-ctx.Done():
l.mtx.Lock() l.mtx.Lock()
for i, wait := range l.waits[id] { l.removeRequest(id, mainRequestElem)
if wait == mainWait {
l.waits[id] = slices.Delete(l.waits[id], i, i+1)
close(wait.finish)
break
}
}
l.mtx.Unlock() l.mtx.Unlock()
return ctx.Err() return ctx.Err()
} }
select {
case <-mainRequest.done:
<-l.queue
return nil
default:
}
for { for {
if ctx.Err() != nil { if ctx.Err() != nil {
l.mtx.Lock() l.mtx.Lock()
for i, wait := range l.waits[id] { l.removeRequest(id, mainRequestElem)
if wait == mainWait {
l.waits[id] = slices.Delete(l.waits[id], i, i+1)
close(wait.finish)
break
}
}
l.mtx.Unlock() l.mtx.Unlock()
<-l.queue <-l.queue
return ctx.Err() return ctx.Err()
} }
l.mtx.Lock()
now := time.Now() now := time.Now()
if l.reset.Compare(now) == -1 { if l.reset.Compare(now) == -1 {
clear(l.conns) clear(l.bytes)
l.reset = now.Add(time.Second) l.reset = now.Add(time.Second)
} }
l.mtx.Lock() flowElem := l.flows.Front()
var minConnId string flow := flowElem.Value
var minN int firstRequestElem := flow.pending.Front()
for connID, waits := range l.waits { firstRequest := firstRequestElem.Value
if len(waits) == 0 { l.bytes[flow.id] += uint64(firstRequest.n)
continue firstRequestElem.Remove()
} if flow.pending.Len() == 0 {
if n, ok := l.conns[connID]; ok { l.flows.Remove(flowElem)
if minConnId == "" { delete(l.index, flow.id)
minConnId = connID l.pool.Put(flow.pending)
minN = n } else {
continue l.reorder(flowElem)
}
if n+waits[0].n < minN {
minConnId = connID
minN = n
}
} else {
l.conns[connID] = 0
minConnId = connID
break
}
}
minWait := l.waits[minConnId][0]
l.waits[minConnId][0] = nil
l.waits[minConnId] = l.waits[minConnId][1:]
if len(l.waits) == 0 {
delete(l.waits, minConnId)
} }
l.mtx.Unlock() l.mtx.Unlock()
err := l.limiter.WaitN(ctx, minWait.n) l.limiter.WaitN(firstRequest.ctx, firstRequest.n)
if err != nil { close(firstRequest.done)
continue if firstRequest == mainRequest {
}
l.conns[minConnId] = l.conns[minConnId] + minWait.n
close(minWait.finish)
if minWait == mainWait {
<-l.queue <-l.queue
return nil return nil
} }
} }
} }
type wait struct { func (l *FairQueueLimiter) reorder(elem *list.Element[*flow]) {
ctx context.Context f := elem.Value
finish chan struct{} front := f.pending.Front()
n int if front == nil {
return
}
cost := l.bytes[f.id] + uint64(front.Value.n)
for e := l.flows.Front(); e != nil; e = e.Next() {
if e == elem {
continue
}
eFront := e.Value.pending.Front()
if eFront == nil {
continue
}
if cost < l.bytes[e.Value.id]+uint64(eFront.Value.n) {
l.flows.MoveBefore(elem, e)
return
}
}
l.flows.MoveToBack(elem)
}
func (l *FairQueueLimiter) removeRequest(id string, elem *list.Element[*request]) {
if !elem.Remove() {
return
}
if flowElem, ok := l.index[id]; ok && flowElem.Value.pending.Len() == 0 {
l.flows.Remove(flowElem)
delete(l.index, id)
l.pool.Put(flowElem.Value.pending)
}
}
type flow struct {
id string
pending *list.List[*request]
}
type request struct {
ctx context.Context
done chan struct{}
n int
} }

View File

@@ -357,7 +357,7 @@ func createSpeedLimiter(speed uint64, flowKeys []string) (BandwidthLimiter, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
limiter = NewFlowKeysLimiter(getter, limiter) limiter = NewFairQueueLimiter(getter, limiter)
} }
return limiter, nil return limiter, nil
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/cloudflare" "github.com/sagernet/sing-box/common/cloudflare"
"github.com/sagernet/sing-box/common/congestion"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@@ -23,6 +24,7 @@ import (
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -132,6 +134,15 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
logger.ErrorContext(ctx, err) logger.ErrorContext(ctx, err)
return return
} }
congestionControl, err := congestion.NewCongestionControl(
options.CongestionController,
options.CWND,
ntp.TimeFuncFromContext(ctx),
)
if err != nil {
logger.ErrorContext(ctx, err)
return
}
tunnel, err := masque.NewTunnel( tunnel, err := masque.NewTunnel(
ctx, ctx,
logger, logger,
@@ -156,6 +167,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
UDPKeepalivePeriod: udpKeepalivePeriod, UDPKeepalivePeriod: udpKeepalivePeriod,
UDPInitialPacketSize: options.UDPInitialPacketSize, UDPInitialPacketSize: options.UDPInitialPacketSize,
ReconnectDelay: options.ReconnectDelay.Build(), ReconnectDelay: options.ReconnectDelay.Build(),
CongestionControl: congestionControl,
}, },
) )
if err != nil { if err != nil {

View File

@@ -104,6 +104,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
AllowedAddress: options.AllowedIPs, AllowedAddress: options.AllowedIPs,
ReconnectDelay: time.Duration(options.ReconnectDelay), ReconnectDelay: time.Duration(options.ReconnectDelay),
PingInterval: time.Duration(options.PingInterval), PingInterval: time.Duration(options.PingInterval),
PingRestart: time.Duration(options.PingRestart),
}) })
if err != nil { if err != nil {
return nil, err return nil, err

130
protocol/snell/inbound.go Normal file
View File

@@ -0,0 +1,130 @@
package snell
import (
"context"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/listener"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/snell"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.SnellInboundOptions](registry, C.TypeSnell, NewInbound)
}
type Inbound struct {
inbound.Adapter
router adapter.ConnectionRouterEx
logger logger.ContextLogger
listener *listener.Listener
service *snell.Service
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellInboundOptions) (adapter.Inbound, error) {
if options.PSK == "" {
return nil, E.New("snell requires psk")
}
udpEnabled := common.Contains(options.Network.Build(), N.NetworkUDP)
obfsMode := ""
if options.Obfs != nil {
obfsMode = options.Obfs.Mode
}
in := &Inbound{
Adapter: inbound.NewAdapter(C.TypeSnell, tag),
router: router,
logger: logger,
}
service, err := snell.NewService(snell.ServiceOptions{
PSK: []byte(options.PSK),
Version: options.Version,
ObfsMode: obfsMode,
UDP: udpEnabled,
Logger: logger,
Handler: (*inboundHandler)(in),
})
if err != nil {
return nil, err
}
in.service = service
in.listener = listener.New(listener.Options{
Context: ctx,
Logger: logger,
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
ConnectionHandler: in,
})
return in, nil
}
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}
func (h *Inbound) Close() error {
return h.listener.Close()
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
err := h.service.NewConnection(ctx, conn, metadata.Source)
N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil {
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
}
}
var _ adapter.TCPInjectableInbound = (*Inbound)(nil)
type inboundHandler Inbound
func (h *inboundHandler) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string) {
var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.Source = source
metadata.Destination = destination
if clientID != "" {
metadata.User = clientID
h.logger.InfoContext(ctx, "[", clientID, "] inbound connection to ", destination)
} else {
h.logger.InfoContext(ctx, "inbound connection to ", destination)
}
done := make(chan struct{})
h.router.RouteConnectionEx(ctx, conn, metadata, N.OnceClose(func(error) {
close(done)
}))
<-done
}
func (h *inboundHandler) NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string) {
var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.InboundDetour = h.listener.ListenOptions().Detour
metadata.Source = source
if clientID != "" {
metadata.User = clientID
h.logger.InfoContext(ctx, "[", clientID, "] inbound packet connection")
} else {
h.logger.InfoContext(ctx, "inbound packet connection")
}
done := make(chan struct{})
h.router.RoutePacketConnectionEx(ctx, bufio.NewPacketConn(conn), metadata, N.OnceClose(func(error) {
close(done)
}))
<-done
}

114
protocol/snell/outbound.go Normal file
View File

@@ -0,0 +1,114 @@
package snell
import (
"context"
"fmt"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/snell"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.SnellOutboundOptions](registry, C.TypeSnell, NewOutbound)
}
type Outbound struct {
outbound.Adapter
logger logger.ContextLogger
client *snell.Client
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellOutboundOptions) (adapter.Outbound, error) {
if options.PSK == "" {
return nil, E.New("snell requires psk")
}
version := options.Version
if version == 0 {
version = snell.DefaultSnellVersion
}
if version == snell.Version5 {
version = snell.Version4
}
udpEnabled := common.Contains(options.Network.Build(), N.NetworkUDP)
switch version {
case snell.Version1, snell.Version2:
if udpEnabled {
return nil, fmt.Errorf("snell version %d does not support UDP", version)
}
case snell.Version3, snell.Version4:
default:
return nil, fmt.Errorf("snell version error: %d", version)
}
reuse := version == snell.Version2 || (version == snell.Version4 && options.Reuse)
obfsMode := ""
obfsHost := "bing.com"
if options.Obfs != nil {
switch options.Obfs.Mode {
case "", "tls", "http":
obfsMode = options.Obfs.Mode
default:
return nil, fmt.Errorf("snell obfs mode error: %s", options.Obfs.Mode)
}
if options.Obfs.Host != "" {
obfsHost = options.Obfs.Host
}
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
if err != nil {
return nil, err
}
client := snell.NewClient(snell.ClientOptions{
Dialer: outboundDialer,
Server: options.ServerOptions.Build(),
PSK: []byte(options.PSK),
Version: version,
Reuse: reuse,
ObfsMode: obfsMode,
ObfsHost: obfsHost,
})
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSnell, tag, options.Network.Build(), options.DialerOptions),
logger: logger,
client: client,
}, nil
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
switch N.NetworkName(network) {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
return h.client.DialContext(ctx, destination)
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
conn, err := h.client.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return bufio.NewBindPacketConn(conn, destination), nil
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return h.client.ListenPacket(ctx, destination)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/sagernet/quic-go/http3" "github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/congestion"
"github.com/sagernet/sing-box/common/listener" "github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@@ -136,10 +137,9 @@ func (h *Inbound) Start(stage adapter.StartStage) error {
if err != nil { if err != nil {
return err return err
} }
congestionControlFactory, err := trusttunnel.NewCongestionControl( congestionControlFactory, err := congestion.NewCongestionControl(
h.options.CongestionController, h.options.CongestionController,
h.options.CWND, h.options.CWND,
h.options.BBRProfile,
ntp.TimeFuncFromContext(h.ctx), ntp.TimeFuncFromContext(h.ctx),
) )
if err != nil { if err != nil {

View File

@@ -53,7 +53,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
QUIC: options.QUIC, QUIC: options.QUIC,
CongestionControl: options.CongestionController, CongestionControl: options.CongestionController,
CWND: options.CWND, CWND: options.CWND,
BBRProfile: options.BBRProfile, Logger: logger,
HealthCheck: options.HealthCheck, HealthCheck: options.HealthCheck,
} }
var client trusttunnel.Dialer var client trusttunnel.Dialer

View File

@@ -1 +1 @@
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_naive_outbound,badlinkname,tfogo_checklinkname0 with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,with_naive_outbound,badlinkname,tfogo_checklinkname0

View File

@@ -1 +1 @@
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_manager,with_admin_panel,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,badlinkname,tfogo_checklinkname0 with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_manager,with_admin_panel,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,badlinkname,tfogo_checklinkname0

View File

@@ -1 +1 @@
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0 with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0

View File

@@ -1,4 +1,4 @@
#!/bin/sh #!/bin/sh
[ -s ${IPKG_INSTROOT}/lib/functions.sh ] || exit 0 [ -s ${IPKG_INSTROOT}/lib/functions.sh ] || exit 0
. ${IPKG_INSTROOT}/lib/functions.sh . ${IPKG_INSTROOT}/lib/functions.sh
default_prerm $0 $@ default_prerm $0 $@ || true

View File

@@ -58,6 +58,13 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata
} }
func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
select {
case <-r.started:
case <-ctx.Done():
return ctx.Err()
case <-r.ctx.Done():
return r.ctx.Err()
}
//nolint:staticcheck //nolint:staticcheck
if metadata.InboundDetour != "" { if metadata.InboundDetour != "" {
if metadata.LastInbound == metadata.InboundDetour { if metadata.LastInbound == metadata.InboundDetour {
@@ -192,6 +199,13 @@ func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn,
} }
func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
select {
case <-r.started:
case <-ctx.Done():
return ctx.Err()
case <-r.ctx.Done():
return r.ctx.Err()
}
//nolint:staticcheck //nolint:staticcheck
if metadata.InboundDetour != "" { if metadata.InboundDetour != "" {
if metadata.LastInbound == metadata.InboundDetour { if metadata.LastInbound == metadata.InboundDetour {

146
route/route_start_test.go Normal file
View File

@@ -0,0 +1,146 @@
package route
import (
"context"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
)
// newGateTestRouter builds the minimal Router needed to exercise the
// "wait until started" gate in routeConnection / routePacketConnection.
func newGateTestRouter(ctx context.Context) *Router {
return &Router{
ctx: ctx,
started: make(chan struct{}),
}
}
// gateMetadata returns metadata that hits the InboundDetour loop-detection
// branch (LastInbound == InboundDetour), so routeConnection /
// routePacketConnection return immediately once the gate is open, without
// needing any outbound/inbound managers.
func gateMetadata() adapter.InboundContext {
return adapter.InboundContext{InboundDetour: "self", LastInbound: "self"}
}
// TestRouteConnectionWaitsForStart verifies that a connection arriving before
// the router finishes starting (StartStatePostStart) blocks until the started
// channel is closed, then proceeds, instead of dereferencing a nil
// defaultOutbound.
func TestRouteConnectionWaitsForStart(t *testing.T) {
r := newGateTestRouter(context.Background())
done := make(chan error, 1)
go func() {
done <- r.routeConnection(context.Background(), nil, gateMetadata(), nil)
}()
// The gate must block while the router is not yet started.
select {
case <-done:
t.Fatal("routeConnection returned before router was started")
case <-time.After(50 * time.Millisecond):
}
// Simulate StartStatePostStart completing.
close(r.started)
select {
case err := <-done:
// We expect to have passed the gate and reached the loop-detection branch,
// which returns the "routing loop on detour" error.
if err == nil {
t.Fatal("expected routing-loop error after gate opened, got nil")
}
case <-time.After(time.Second):
t.Fatal("routeConnection did not proceed after router started")
}
}
// TestRoutePacketConnectionWaitsForStart is the UDP counterpart.
func TestRoutePacketConnectionWaitsForStart(t *testing.T) {
r := newGateTestRouter(context.Background())
done := make(chan error, 1)
go func() {
done <- r.routePacketConnection(context.Background(), nil, gateMetadata(), nil)
}()
select {
case <-done:
t.Fatal("routePacketConnection returned before router was started")
case <-time.After(50 * time.Millisecond):
}
close(r.started)
select {
case err := <-done:
if err == nil {
t.Fatal("expected routing-loop error after gate opened, got nil")
}
case <-time.After(time.Second):
t.Fatal("routePacketConnection did not proceed after router started")
}
}
// TestRouteConnectionAbortsOnConnContext verifies that a client disconnecting
// during startup unblocks the gate via the per-connection context, instead of
// hanging until the router starts.
func TestRouteConnectionAbortsOnConnContext(t *testing.T) {
r := newGateTestRouter(context.Background())
connCtx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
done <- r.routeConnection(connCtx, nil, gateMetadata(), nil)
}()
select {
case <-done:
t.Fatal("routeConnection returned before context was cancelled")
case <-time.After(50 * time.Millisecond):
}
cancel()
select {
case err := <-done:
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("routeConnection did not abort on connection context cancellation")
}
}
// TestRouteConnectionAbortsOnRouterContext verifies that shutting down the box
// (router context cancellation) unblocks an in-flight early connection.
func TestRouteConnectionAbortsOnRouterContext(t *testing.T) {
routerCtx, cancel := context.WithCancel(context.Background())
r := newGateTestRouter(routerCtx)
done := make(chan error, 1)
go func() {
done <- r.routeConnection(context.Background(), nil, gateMetadata(), nil)
}()
select {
case <-done:
t.Fatal("routeConnection returned before router context was cancelled")
case <-time.After(50 * time.Millisecond):
}
cancel()
select {
case err := <-done:
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("routeConnection did not abort on router context cancellation")
}
}

View File

@@ -44,7 +44,7 @@ type Router struct {
pauseManager pause.Manager pauseManager pause.Manager
trackers []adapter.ConnectionTracker trackers []adapter.ConnectionTracker
platformInterface adapter.PlatformInterface platformInterface adapter.PlatformInterface
started bool started chan struct{}
} }
func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) *Router { func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) *Router {
@@ -63,6 +63,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess, needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess,
pauseManager: service.FromContext[pause.Manager](ctx), pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[adapter.PlatformInterface](ctx), platformInterface: service.FromContext[adapter.PlatformInterface](ctx),
started: make(chan struct{}),
} }
} }
@@ -180,7 +181,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
} else { } else {
r.defaultOutbound = r.outbound.Default() r.defaultOutbound = r.outbound.Default()
} }
r.started = true close(r.started)
return nil return nil
case adapter.StartStateStarted: case adapter.StartStateStarted:
for _, ruleSet := range r.ruleSets { for _, ruleSet := range r.ruleSets {

View File

@@ -29,13 +29,17 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
case "": case "":
return nil, nil return nil, nil
case C.RuleActionTypeRoute: case C.RuleActionTypeRoute:
overrideGateway := M.ParseAddr(action.RouteOptions.OverrideGateway) var overrideGateway *netip.Addr
if action.RouteOptions.OverrideGateway != "" {
parsed := M.ParseAddr(action.RouteOptions.OverrideGateway)
overrideGateway = &parsed
}
return &RuleActionRoute{ return &RuleActionRoute{
Outbound: action.RouteOptions.Outbound, Outbound: action.RouteOptions.Outbound,
RuleActionRouteOptions: RuleActionRouteOptions{ RuleActionRouteOptions: RuleActionRouteOptions{
OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptions.OverrideAddress, 0), OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptions.OverrideAddress, 0),
OverridePort: action.RouteOptions.OverridePort, OverridePort: action.RouteOptions.OverridePort,
OverrideGateway: &overrideGateway, OverrideGateway: overrideGateway,
NetworkStrategy: (*C.NetworkStrategy)(action.RouteOptions.NetworkStrategy), NetworkStrategy: (*C.NetworkStrategy)(action.RouteOptions.NetworkStrategy),
FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay), FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay),
UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping,

View File

@@ -1,6 +1,6 @@
module test module test
go 1.26.1 go 1.26.4
require github.com/sagernet/sing-box v0.0.0 require github.com/sagernet/sing-box v0.0.0
@@ -14,15 +14,17 @@ replace github.com/sagernet/sing-mux => github.com/shtorm-7/sing-mux v0.3.4-exte
replace github.com/ameshkov/dnscrypt/v2 => github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 replace github.com/ameshkov/dnscrypt/v2 => github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0
replace github.com/sagernet/sing-vmess => github.com/starifly/sing-vmess v0.2.7-mod.9 replace github.com/sagernet/sing-vmess => github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0
replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.0.0 replace github.com/sagernet/sing => /home/shtorm/Projects/shtorm-7/sing
replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1 replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0
replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0
replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 replace github.com/shtorm-7/go-cache/v2 => /home/shtorm/Projects/shtorm-7/go-cache
replace github.com/sagernet/smux => /home/shtorm/Projects/shtorm-7/smux
require ( require (
github.com/docker/docker v28.5.2+incompatible github.com/docker/docker v28.5.2+incompatible
@@ -36,7 +38,6 @@ require (
github.com/spyzhov/ajson v0.9.4 github.com/spyzhov/ajson v0.9.4
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
go.uber.org/goleak v1.3.0 go.uber.org/goleak v1.3.0
golang.org/x/crypto v0.49.0
golang.org/x/net v0.52.0 golang.org/x/net v0.52.0
) )
@@ -221,6 +222,7 @@ require (
go.uber.org/zap/exp v0.3.0 // indirect go.uber.org/zap/exp v0.3.0 // indirect
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
golang.org/x/mod v0.34.0 // indirect golang.org/x/mod v0.34.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect

View File

@@ -362,8 +362,6 @@ github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA=
github.com/sagernet/sing-tun v0.8.9 h1:ixFKKUGdVcJl4wb0xbL36hobiw9l6DIH497EQf5ILpM= github.com/sagernet/sing-tun v0.8.9 h1:ixFKKUGdVcJl4wb0xbL36hobiw9l6DIH497EQf5ILpM=
github.com/sagernet/sing-tun v0.8.9/go.mod h1:QvarqUtHfj1ULaRR+6kZOS/OoCE+pYGq67A5tyIy+dQ= github.com/sagernet/sing-tun v0.8.9/go.mod h1:QvarqUtHfj1ULaRR+6kZOS/OoCE+pYGq67A5tyIy+dQ=
github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478=
github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 h1:ws7BIsYLd31Wjifq88BYCHRVlgO+07iwil39s6ERba8= github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 h1:ws7BIsYLd31Wjifq88BYCHRVlgO+07iwil39s6ERba8=
@@ -372,12 +370,14 @@ github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 h1:e5s7RKBd2rIPR0StbvZ2vTV
github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI= github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI=
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 h1:PLZ/YHqnApPx13wt6MX3ItqESp4ueBr1tGSi0bEGqYw= github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 h1:PLZ/YHqnApPx13wt6MX3ItqESp4ueBr1tGSi0bEGqYw=
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4= github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4=
github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1 h1:UeJkrCJJmIjTBywErVMx7fCSoBf4gh6QgT9bp9o1ajM= github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 h1:iBLll4ZZG8ULQcHWs6gGslZWtBN72Zo1zjySzMVHF7g=
github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0= github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0=
github.com/shtorm-7/sing v0.8.10-extended-1.0.0 h1:mAkyycCQOzCttPOR5fcHkJaZvXMQXeu3mbEfr8D+7A8= github.com/shtorm-7/sing v0.8.10-extended-1.1.0 h1:P4JL2cugjvEvnYu8tMmpR30SE1qsS45RcnNEwzDz5as=
github.com/shtorm-7/sing v0.8.10-extended-1.0.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA= github.com/shtorm-7/sing v0.8.10-extended-1.1.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA=
github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw=
github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0 h1:WVheKmQH5hSQbJU1ZTKthKSutkTLWSb2hp4JuQhJBow=
github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs=
github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2 h1:hSMjh97OszszOd8HrzpaYUQH9dWRRBluJCbwQyz8ZOk= github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2 h1:hSMjh97OszszOd8HrzpaYUQH9dWRRBluJCbwQyz8ZOk=
github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2/go.mod h1:TYIIqO5sZpWq873rLIeO2usszSMUpR3h6WdqVVs65ug= github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2/go.mod h1:TYIIqO5sZpWq873rLIeO2usszSMUpR3h6WdqVVs65ug=
github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.4.3 h1:jtOA73D4F5qRV70//ahOt20KBnWvQimAFjtIiOtt0ps= github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.4.3 h1:jtOA73D4F5qRV70//ahOt20KBnWvQimAFjtIiOtt0ps=
@@ -388,8 +388,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spyzhov/ajson v0.9.4 h1:MVibcTCgO7DY4IlskdqIlCmDOsUOZ9P7oKj8ifdcf84= github.com/spyzhov/ajson v0.9.4 h1:MVibcTCgO7DY4IlskdqIlCmDOsUOZ9P7oKj8ifdcf84=
github.com/spyzhov/ajson v0.9.4/go.mod h1:a6oSw0MMb7Z5aD2tPoPO+jq11ETKgXUr2XktHdT8Wt8= github.com/spyzhov/ajson v0.9.4/go.mod h1:a6oSw0MMb7Z5aD2tPoPO+jq11ETKgXUr2XktHdT8Wt8=
github.com/starifly/sing-vmess v0.2.7-mod.9 h1:xobAmejSbBQ0A3f/EtJ9cJd3m6gK7dDPccPdeGz7tXY=
github.com/starifly/sing-vmess v0.2.7-mod.9/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=

View File

@@ -217,3 +217,19 @@ func TestV2RayGRPCLite(t *testing.T) {
}) })
}) })
} }
func TestV2RayGRPCLiteServiceNameWithSlashes(t *testing.T) {
testV2RayTransportSelfWith(t, &option.V2RayTransportOptions{
Type: C.V2RayTransportTypeGRPC,
GRPCOptions: option.V2RayGRPCOptions{
ServiceName: "/47559/I53GwKHO",
ForceLite: true,
},
}, &option.V2RayTransportOptions{
Type: C.V2RayTransportTypeGRPC,
GRPCOptions: option.V2RayGRPCOptions{
ServiceName: "/47559/I53GwKHO",
ForceLite: true,
},
})
}

View File

@@ -0,0 +1,331 @@
package masque
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/sagernet/quic-go/quicvarint"
"github.com/yosida95/uritemplate/v3"
"golang.org/x/net/http2"
)
const h2DatagramCapsuleType uint64 = 0
const (
ipv4HeaderLen = 20
ipv6HeaderLen = 40
)
func ConnectTunnelH2(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, endpoint *net.TCPAddr, connectUri string) (io.Closer, IpConn, *http.Response, error) {
if endpoint == nil {
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
}
tlsConfig.SetNextProtos([]string{"h2"})
conn, err := dialer.DialContext(ctx, N.NetworkTCP, M.SocksaddrFromNetIP(endpoint.AddrPort()))
if err != nil {
return nil, nil, nil, err
}
tlsConn, err := tlsConfig.Client(conn)
if err != nil {
_ = conn.Close()
return nil, nil, nil, err
}
if err = tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, nil, nil, err
}
tr := &http2.Transport{
ReadIdleTimeout: 30 * time.Second,
}
cc, err := tr.NewClientConn(tlsConn)
if err != nil {
_ = tlsConn.Close()
return nil, nil, nil, fmt.Errorf("connect-ip: failed to create client connection: %w", err)
}
additionalHeaders := http.Header{
"User-Agent": []string{""},
}
template := uritemplate.MustNew(connectUri)
h2Headers := additionalHeaders.Clone()
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
h2Headers.Set("pq-enabled", "false")
ipConn, rsp, err := dialH2(ctx, cc, template, h2Headers)
if err != nil {
_ = cc.Close()
if strings.Contains(err.Error(), "tls: access denied") {
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
}
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
}
if rsp.StatusCode != http.StatusOK {
_ = ipConn.Close()
_ = cc.Close()
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status)
}
return cc, ipConn, rsp, nil
}
func dialH2(ctx context.Context, rt http.RoundTripper, template *uritemplate.Template, additionalHeaders http.Header) (*h2IpConn, *http.Response, error) {
if len(template.Varnames()) > 0 {
return nil, nil, errors.New("connect-ip: IP flow forwarding not supported")
}
u, err := url.Parse(template.Raw())
if err != nil {
return nil, nil, fmt.Errorf("connect-ip: failed to parse URI: %w", err)
}
reqCtx, cancel := context.WithCancel(context.Background())
pr, pw := io.Pipe()
req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, u.String(), pr)
if err != nil {
cancel()
_ = pr.Close()
_ = pw.Close()
return nil, nil, fmt.Errorf("connect-ip: failed to create request: %w", err)
}
req.Host = authorityFromURL(u)
req.ContentLength = -1
req.Header = make(http.Header)
for k, v := range additionalHeaders {
req.Header[k] = v
}
stop := context.AfterFunc(ctx, cancel)
rsp, err := rt.RoundTrip(req)
stop()
if err != nil {
cancel()
_ = pr.Close()
_ = pw.Close()
return nil, nil, fmt.Errorf("connect-ip: failed to send request: %w", err)
}
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
cancel()
_ = pr.Close()
_ = pw.Close()
_ = rsp.Body.Close()
return nil, rsp, fmt.Errorf("connect-ip: server responded with %d", rsp.StatusCode)
}
stream := &h2DatagramStream{
requestBody: pw,
responseBody: rsp.Body,
cancel: cancel,
}
return &h2IpConn{
str: stream,
closeChan: make(chan struct{}),
}, rsp, nil
}
func authorityFromURL(u *url.URL) string {
if u.Port() != "" {
return u.Host
}
host := u.Hostname()
if host == "" {
return u.Host
}
return host + ":443"
}
type h2IpConn struct {
str *h2DatagramStream
mu sync.Mutex
closeChan chan struct{}
closeErr error
}
func (c *h2IpConn) ReadPacket() (b []byte, err error) {
start:
data, err := c.str.ReceiveDatagram(context.Background())
if err != nil {
defer func() {
_ = c.Close()
}()
select {
case <-c.closeChan:
return nil, c.closeErr
default:
return nil, err
}
}
if err := c.handleIncomingProxiedPacket(data); err != nil {
goto start
}
return data, nil
}
func (c *h2IpConn) handleIncomingProxiedPacket(data []byte) error {
if len(data) == 0 {
return errors.New("connect-ip: empty packet")
}
switch v := ipVersion(data); v {
default:
return fmt.Errorf("connect-ip: unknown IP versions: %d", v)
case 4:
if len(data) < ipv4HeaderLen {
return fmt.Errorf("connect-ip: malformed datagram: too short")
}
case 6:
if len(data) < ipv6HeaderLen {
return fmt.Errorf("connect-ip: malformed datagram: too short")
}
}
return nil
}
func (c *h2IpConn) WritePacket(b []byte) (icmp []byte, err error) {
data, err := c.composeDatagram(b)
if err != nil {
return nil, nil
}
if err := c.str.SendDatagram(data); err != nil {
select {
case <-c.closeChan:
return nil, c.closeErr
default:
return nil, err
}
}
return nil, nil
}
func (c *h2IpConn) composeDatagram(b []byte) ([]byte, error) {
if len(b) == 0 {
return nil, nil
}
switch v := ipVersion(b); v {
default:
return nil, fmt.Errorf("connect-ip: unknown IP versions: %d", v)
case 4:
if len(b) < ipv4HeaderLen {
return nil, fmt.Errorf("connect-ip: IPv4 packet too short")
}
ttl := b[8]
if ttl <= 1 {
return nil, fmt.Errorf("connect-ip: datagram TTL too small: %d", ttl)
}
b[8]--
binary.BigEndian.PutUint16(b[10:12], calculateIPv4Checksum(([ipv4HeaderLen]byte)(b[:ipv4HeaderLen])))
case 6:
if len(b) < ipv6HeaderLen {
return nil, fmt.Errorf("connect-ip: IPv6 packet too short")
}
hopLimit := b[7]
if hopLimit <= 1 {
return nil, fmt.Errorf("connect-ip: datagram Hop Limit too small: %d", hopLimit)
}
b[7]--
}
return b, nil
}
func (c *h2IpConn) Close() error {
c.mu.Lock()
if c.closeErr == nil {
c.closeErr = net.ErrClosed
close(c.closeChan)
}
c.mu.Unlock()
err := c.str.Close()
return err
}
func ipVersion(b []byte) uint8 { return b[0] >> 4 }
func calculateIPv4Checksum(header [ipv4HeaderLen]byte) uint16 {
var sum uint32
for i := 0; i < len(header); i += 2 {
if i == 10 {
continue
}
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
for (sum >> 16) > 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return ^uint16(sum)
}
type h2DatagramStream struct {
requestBody *io.PipeWriter
responseBody io.ReadCloser
cancel context.CancelFunc
readMu sync.Mutex
writeMu sync.Mutex
}
func (s *h2DatagramStream) ReceiveDatagram(_ context.Context) ([]byte, error) {
s.readMu.Lock()
defer s.readMu.Unlock()
reader := quicvarint.NewReader(s.responseBody)
for {
capsuleType, err := quicvarint.Read(reader)
if err != nil {
return nil, err
}
payloadLen, err := quicvarint.Read(reader)
if err != nil {
return nil, err
}
payload := make([]byte, payloadLen)
_, err = io.ReadFull(reader, payload)
if err != nil {
return nil, err
}
if capsuleType != h2DatagramCapsuleType {
continue
}
return payload, nil
}
}
func (s *h2DatagramStream) SendDatagram(data []byte) error {
frame := make([]byte, 0, quicvarint.Len(h2DatagramCapsuleType)+quicvarint.Len(uint64(len(data)))+len(data))
frame = quicvarint.Append(frame, h2DatagramCapsuleType)
frame = quicvarint.Append(frame, uint64(len(data)))
frame = append(frame, data...)
s.writeMu.Lock()
defer s.writeMu.Unlock()
_, err := s.requestBody.Write(frame)
if err != nil {
return fmt.Errorf("connect-ip: failed to send datagram capsule: %w", err)
}
return nil
}
func (s *h2DatagramStream) Close() error {
_ = s.requestBody.Close()
err := s.responseBody.Close()
s.cancel()
return err
}

View File

@@ -2,9 +2,9 @@ package masque
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@@ -12,13 +12,13 @@ import (
connectip "github.com/Diniboy1123/connect-ip-go" connectip "github.com/Diniboy1123/connect-ip-go"
"github.com/sagernet/quic-go" "github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/congestion"
"github.com/sagernet/quic-go/http3" "github.com/sagernet/quic-go/http3"
qtls "github.com/sagernet/sing-quic" qtls "github.com/sagernet/sing-quic"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls" aTLS "github.com/sagernet/sing/common/tls"
"github.com/yosida95/uritemplate/v3" "github.com/yosida95/uritemplate/v3"
"golang.org/x/net/http2"
) )
type ( type (
@@ -26,39 +26,60 @@ type (
ListenPacket func(network string, address string) (net.PacketConn, error) ListenPacket func(network string, address string) (net.PacketConn, error)
) )
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool) (net.PacketConn, *http3.Transport, *connectip.Conn, *http.Response, error) { type IpConn interface {
template := uritemplate.MustNew(connectUri) ReadPacket() (b []byte, err error)
additionalHeaders := http.Header{ WritePacket(b []byte) (icmp []byte, err error)
"User-Agent": []string{""}, Close() error
}
type closerFunc func() error
func (f closerFunc) Close() error { return f() }
type quicIpConn struct {
conn *connectip.Conn
buf []byte
}
func newQuicIpConn(conn *connectip.Conn) *quicIpConn {
return &quicIpConn{
conn: conn,
buf: make([]byte, 0xFFFF),
} }
}
func (c *quicIpConn) ReadPacket() ([]byte, error) {
n, err := c.conn.ReadPacket(c.buf, true)
if err != nil {
return nil, err
}
return c.buf[:n], nil
}
func (c *quicIpConn) WritePacket(b []byte) (icmp []byte, err error) {
return c.conn.WritePacket(b)
}
func (c *quicIpConn) Close() error {
return c.conn.Close()
}
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool, congestionControl func(conn *quic.Conn) congestion.CongestionControl) (io.Closer, IpConn, *http.Response, error) {
if useHTTP2 { if useHTTP2 {
h2Endpoint, ok := endpoint.(*net.TCPAddr) h2Endpoint, ok := endpoint.(*net.TCPAddr)
if !ok || h2Endpoint == nil { if !ok || h2Endpoint == nil {
return nil, nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint") return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
} }
h2Headers := additionalHeaders.Clone() return ConnectTunnelH2(ctx, dialer, tlsConfig, h2Endpoint, connectUri)
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
h2Headers.Set("pq-enabled", "false")
h2Client, err := newHTTP2Client(dialer, tlsConfig, h2Endpoint, connectUri)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to create HTTP/2 client: %w", err)
}
ipConn, rsp, err := connectip.DialH2(ctx, h2Client, template, h2Headers)
if err != nil {
if strings.Contains(err.Error(), "tls: access denied") {
return nil, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
}
return nil, nil, nil, nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
}
return nil, nil, ipConn, rsp, nil
} }
quicEndpoint, ok := endpoint.(*net.UDPAddr) quicEndpoint, ok := endpoint.(*net.UDPAddr)
if !ok || quicEndpoint == nil { if !ok || quicEndpoint == nil {
return nil, nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint") return nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
} }
udpConn, err := dialer.ListenPacket(ctx, M.SocksaddrFromNetIP(quicEndpoint.AddrPort())) udpConn, err := dialer.ListenPacket(ctx, M.SocksaddrFromNetIP(quicEndpoint.AddrPort()))
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, err
} }
conn, err := qtls.Dial( conn, err := qtls.Dial(
ctx, ctx,
@@ -68,28 +89,34 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
quicConfig, quicConfig,
) )
if err != nil { if err != nil {
return nil, nil, nil, nil, err _ = udpConn.Close()
return nil, nil, nil, err
}
if congestionControl != nil {
conn.SetCongestionControl(congestionControl(conn))
} }
tr := &http3.Transport{ tr := &http3.Transport{
EnableDatagrams: true, EnableDatagrams: true,
AdditionalSettings: map[uint64]uint64{ AdditionalSettings: map[uint64]uint64{
// official client still sends this out as well, even though
// it's deprecated, see https://datatracker.ietf.org/doc/draft-ietf-masque-h3-datagram/00/
// SETTINGS_H3_DATAGRAM_00 = 0x0000000000000276
// https://github.com/cloudflare/quiche/blob/7c66757dbc55b8d0c3653d4b345c6785a181f0b7/quiche/src/h3/frame.rs#L46
0x276: 1, 0x276: 1,
}, },
DisableCompression: true, DisableCompression: true,
} }
hconn := tr.NewClientConn(conn) hconn := tr.NewClientConn(conn)
template := uritemplate.MustNew(connectUri)
additionalHeaders := http.Header{
"User-Agent": []string{""},
}
ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true) ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true)
if err != nil { if err != nil {
_ = tr.Close() _ = tr.Close()
_ = conn.CloseWithError(0, "connect-ip dial failed") _ = conn.CloseWithError(0, "connect-ip dial failed")
_ = udpConn.Close()
if strings.Contains(err.Error(), "tls: access denied") { if strings.Contains(err.Error(), "tls: access denied") {
return udpConn, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
} }
return udpConn, nil, nil, nil, fmt.Errorf("failed to dial connect-ip: %w", err) return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %w", err)
} }
err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{ err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{
{ {
@@ -109,34 +136,16 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
}, },
}) })
if err != nil { if err != nil {
return udpConn, nil, nil, nil, err _ = ipConn.Close()
_ = tr.Close()
_ = udpConn.Close()
return nil, nil, nil, err
} }
return udpConn, tr, ipConn, rsp, nil
}
func newHTTP2Client(dialer N.Dialer, baseTLSConfig aTLS.Config, endpoint *net.TCPAddr, connectURI string) (*http.Client, error) { closer := closerFunc(func() error {
if endpoint == nil { _ = tr.Close()
return nil, errors.New("missing HTTP/2 endpoint") _ = udpConn.Close()
} return nil
tlsConfig := baseTLSConfig.Clone() })
tlsConfig.SetNextProtos([]string{"h2"}) return closer, newQuicIpConn(ipConn), rsp, nil
return &http.Client{
Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, _ string, _ *tls.Config) (net.Conn, error) {
conn, err := dialer.DialContext(ctx, network, M.SocksaddrFromNetIP(endpoint.AddrPort()))
if err != nil {
return nil, err
}
tlsConn, err := tlsConfig.Client(conn)
if err != nil {
return nil, err
}
if err := tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, err
}
return tlsConn, nil
},
},
}, nil
} }

View File

@@ -5,6 +5,8 @@ import (
"net/netip" "net/netip"
"time" "time"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/congestion"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/tls" "github.com/sagernet/sing/common/tls"
) )
@@ -23,4 +25,5 @@ type TunnelOptions struct {
UDPKeepalivePeriod time.Duration UDPKeepalivePeriod time.Duration
UDPInitialPacketSize uint16 UDPInitialPacketSize uint16
ReconnectDelay time.Duration ReconnectDelay time.Duration
CongestionControl func(conn *quic.Conn) congestion.CongestionControl
} }

View File

@@ -4,13 +4,12 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
"time" "time"
connectip "github.com/Diniboy1123/connect-ip-go"
"github.com/sagernet/quic-go/http3"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@@ -22,9 +21,8 @@ type Tunnel struct {
options TunnelOptions options TunnelOptions
device Device device Device
udpConn net.PacketConn closer io.Closer
tr *http3.Transport ipConn IpConn
ipConn *connectip.Conn
mtx sync.Mutex mtx sync.Mutex
} }
@@ -83,13 +81,11 @@ func (e *Tunnel) Close() error {
defer e.mtx.Unlock() defer e.mtx.Unlock()
if e.ipConn != nil { if e.ipConn != nil {
e.ipConn.Close() e.ipConn.Close()
if e.udpConn != nil { if e.closer != nil {
e.udpConn.Close() e.closer.Close()
}
if e.tr != nil {
e.tr.Close()
} }
e.ipConn = nil e.ipConn = nil
e.closer = nil
} }
return e.device.Close() return e.device.Close()
} }
@@ -124,7 +120,7 @@ func (e *Tunnel) maintainTunnel() {
} }
icmp, err := ipConn.WritePacket(packet) icmp, err := ipConn.WritePacket(packet)
if err != nil { if err != nil {
if errors.As(err, new(*connectip.CloseError)) { if errors.Is(err, net.ErrClosed) {
if ok := e.closeIpConn(ipConn); ok { if ok := e.closeIpConn(ipConn); ok {
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing to IP connection: %w", err)) e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing to IP connection: %w", err))
} }
@@ -135,7 +131,7 @@ func (e *Tunnel) maintainTunnel() {
} }
if len(icmp) > 0 { if len(icmp) > 0 {
if _, err := e.device.Write([][]byte{icmp}, 0); err != nil { if _, err := e.device.Write([][]byte{icmp}, 0); err != nil {
if errors.As(err, new(*connectip.CloseError)) { if errors.Is(err, net.ErrClosed) {
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing ICMP to TUN device: %v", err)) e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing ICMP to TUN device: %v", err))
continue continue
} }
@@ -145,15 +141,14 @@ func (e *Tunnel) maintainTunnel() {
} }
}() }()
go func() { go func() {
buf := make([]byte, 1280)
for e.ctx.Err() == nil { for e.ctx.Err() == nil {
ipConn, err := e.getIpConn() ipConn, err := e.getIpConn()
if err != nil { if err != nil {
return return
} }
n, err := ipConn.ReadPacket(buf, true) packet, err := ipConn.ReadPacket()
if err != nil { if err != nil {
if e.options.UseHTTP2 || errors.As(err, new(*connectip.CloseError)) { if e.options.UseHTTP2 || errors.Is(err, net.ErrClosed) {
if ok := e.closeIpConn(ipConn); ok { if ok := e.closeIpConn(ipConn); ok {
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while reading from IP connection: %v", err)) e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while reading from IP connection: %v", err))
} }
@@ -162,7 +157,7 @@ func (e *Tunnel) maintainTunnel() {
e.logger.ErrorContext(e.ctx, fmt.Errorf("Error reading from IP connection: %v, continuine...", err)) e.logger.ErrorContext(e.ctx, fmt.Errorf("Error reading from IP connection: %v, continuine...", err))
continue continue
} }
if _, err := e.device.Write([][]byte{buf[:n]}, 0); err != nil { if _, err := e.device.Write([][]byte{packet}, 0); err != nil {
continue continue
} }
} }
@@ -170,7 +165,7 @@ func (e *Tunnel) maintainTunnel() {
<-e.ctx.Done() <-e.ctx.Done()
} }
func (e *Tunnel) getIpConn() (*connectip.Conn, error) { func (e *Tunnel) getIpConn() (IpConn, error) {
e.mtx.Lock() e.mtx.Lock()
defer e.mtx.Unlock() defer e.mtx.Unlock()
if e.ctx.Err() != nil { if e.ctx.Err() != nil {
@@ -184,7 +179,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
defer timer.Stop() defer timer.Stop()
for { for {
e.logger.NoticeContext(e.ctx, fmt.Errorf("Establishing MASQUE connection to %s", e.options.Endpoint)) e.logger.NoticeContext(e.ctx, fmt.Errorf("Establishing MASQUE connection to %s", e.options.Endpoint))
udpConn, tr, ipConn, rsp, err := ConnectTunnel( closer, ipConn, rsp, err := ConnectTunnel(
e.ctx, e.ctx,
e.options.Dialer, e.options.Dialer,
e.options.TLSConfig, e.options.TLSConfig,
@@ -192,6 +187,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
"https://cloudflareaccess.com", "https://cloudflareaccess.com",
e.options.Endpoint, e.options.Endpoint,
e.options.UseHTTP2, e.options.UseHTTP2,
e.options.CongestionControl,
) )
if err != nil { if err != nil {
e.logger.ErrorContext(e.ctx, fmt.Errorf("Failed to connect tunnel: %v", err)) e.logger.ErrorContext(e.ctx, fmt.Errorf("Failed to connect tunnel: %v", err))
@@ -206,11 +202,8 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
if rsp.StatusCode != 200 { if rsp.StatusCode != 200 {
e.logger.ErrorContext(e.ctx, fmt.Errorf("Tunnel connection failed: %s", rsp.Status)) e.logger.ErrorContext(e.ctx, fmt.Errorf("Tunnel connection failed: %s", rsp.Status))
ipConn.Close() ipConn.Close()
if udpConn != nil { if closer != nil {
udpConn.Close() closer.Close()
}
if tr != nil {
tr.Close()
} }
timer.Reset(e.options.ReconnectDelay) timer.Reset(e.options.ReconnectDelay)
select { select {
@@ -220,26 +213,23 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
} }
continue continue
} }
e.udpConn = udpConn e.closer = closer
e.tr = tr
e.ipConn = ipConn e.ipConn = ipConn
e.logger.NoticeContext(e.ctx, "Connected to MASQUE server ", e.options.Endpoint) e.logger.NoticeContext(e.ctx, "Connected to MASQUE server ", e.options.Endpoint)
return ipConn, nil return ipConn, nil
} }
} }
func (e *Tunnel) closeIpConn(ipConn *connectip.Conn) bool { func (e *Tunnel) closeIpConn(ipConn IpConn) bool {
e.mtx.Lock() e.mtx.Lock()
defer e.mtx.Unlock() defer e.mtx.Unlock()
if ipConn == e.ipConn { if ipConn == e.ipConn {
e.ipConn.Close() e.ipConn.Close()
if e.udpConn != nil { if e.closer != nil {
e.udpConn.Close() e.closer.Close()
}
if e.tr != nil {
e.tr.Close()
} }
e.ipConn = nil e.ipConn = nil
e.closer = nil
return true return true
} }
return false return false

View File

@@ -4,6 +4,7 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/hmac" "crypto/hmac"
"crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
@@ -23,7 +24,7 @@ const (
type DataCipher interface { type DataCipher interface {
Encrypt(header []byte, packetID uint32, payload []byte) ([]byte, error) Encrypt(header []byte, packetID uint32, payload []byte) ([]byte, error)
Decrypt(packet []byte, headerSize int) ([]byte, error) Decrypt(packet []byte, headerSize int) (plaintext []byte, packetID uint32, err error)
} }
type AEADDataCipher struct { type AEADDataCipher struct {
@@ -86,9 +87,9 @@ func (g *AEADDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
return out, nil return out, nil
} }
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) { func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
if len(packet) < headerSize+4+AESGCMTagSize+1 { if len(packet) < headerSize+4+AESGCMTagSize+1 {
return nil, errors.New("openvpn gcm data packet too short") return nil, 0, errors.New("openvpn gcm data packet too short")
} }
header := packet[:headerSize] header := packet[:headerSize]
pidBytes := packet[headerSize : headerSize+4] pidBytes := packet[headerSize : headerSize+4]
@@ -96,8 +97,13 @@ func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error)
ciphertext := packet[headerSize+4+AESGCMTagSize:] ciphertext := packet[headerSize+4+AESGCMTagSize:]
combined := append(ciphertext, tag...) combined := append(ciphertext, tag...)
ad := append(header, pidBytes...) ad := append(header, pidBytes...)
nonce := g.nonce(binary.BigEndian.Uint32(pidBytes), g.recvImplicitIV) packetID := binary.BigEndian.Uint32(pidBytes)
return g.recv.Open(nil, nonce[:], combined, ad) nonce := g.nonce(packetID, g.recvImplicitIV)
plain, err := g.recv.Open(nil, nonce[:], combined, ad)
if err != nil {
return nil, 0, err
}
return plain, packetID, nil
} }
func (g *AEADDataCipher) nonce(packetID uint32, implicit [AESGCMIVSize]byte) [AESGCMIVSize]byte { func (g *AEADDataCipher) nonce(packetID uint32, implicit [AESGCMIVSize]byte) [AESGCMIVSize]byte {
@@ -127,6 +133,9 @@ func NewCBCCipher(keys *KeyMaterial, auth string) (*CBCDataCipher, error) {
var newHash func() hash.Hash var newHash func() hash.Hash
var hmacSize int var hmacSize int
switch auth { switch auth {
case AuthMD5:
newHash = md5.New
hmacSize = md5.Size
case AuthSHA256: case AuthSHA256:
newHash = sha256.New newHash = sha256.New
hmacSize = sha256.Size hmacSize = sha256.Size
@@ -176,34 +185,35 @@ func (c *CBCDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
return out, nil return out, nil
} }
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) { func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
minSize := headerSize + c.hmacSize + CBCIVSize + aes.BlockSize minSize := headerSize + c.hmacSize + CBCIVSize + aes.BlockSize
if len(packet) < minSize { if len(packet) < minSize {
return nil, errors.New("openvpn cbc data packet too short") return nil, 0, errors.New("openvpn cbc data packet too short")
} }
tag := packet[headerSize : headerSize+c.hmacSize] tag := packet[headerSize : headerSize+c.hmacSize]
iv := packet[headerSize+c.hmacSize : headerSize+c.hmacSize+CBCIVSize] iv := packet[headerSize+c.hmacSize : headerSize+c.hmacSize+CBCIVSize]
ct := packet[headerSize+c.hmacSize+CBCIVSize:] ct := packet[headerSize+c.hmacSize+CBCIVSize:]
if len(ct)%aes.BlockSize != 0 { if len(ct)%aes.BlockSize != 0 {
return nil, errors.New("openvpn cbc ciphertext not block-aligned") return nil, 0, errors.New("openvpn cbc ciphertext not block-aligned")
} }
mac := hmac.New(c.newHash, c.recvHMAC) mac := hmac.New(c.newHash, c.recvHMAC)
mac.Write(iv) mac.Write(iv)
mac.Write(ct) mac.Write(ct)
if !hmac.Equal(tag, mac.Sum(nil)) { if !hmac.Equal(tag, mac.Sum(nil)) {
return nil, errors.New("openvpn cbc hmac verification failed") return nil, 0, errors.New("openvpn cbc hmac verification failed")
} }
plain := make([]byte, len(ct)) plain := make([]byte, len(ct))
cipher.NewCBCDecrypter(c.recvBlock, iv).CryptBlocks(plain, ct) cipher.NewCBCDecrypter(c.recvBlock, iv).CryptBlocks(plain, ct)
padLen := int(plain[len(plain)-1]) padLen := int(plain[len(plain)-1])
if padLen < 1 || padLen > aes.BlockSize { if padLen < 1 || padLen > aes.BlockSize {
return nil, errors.New("openvpn cbc invalid padding") return nil, 0, errors.New("openvpn cbc invalid padding")
} }
plain = plain[:len(plain)-padLen] plain = plain[:len(plain)-padLen]
if len(plain) < 4 { if len(plain) < 4 {
return nil, errors.New("openvpn cbc payload too short") return nil, 0, errors.New("openvpn cbc payload too short")
} }
return plain[4:], nil packetID := binary.BigEndian.Uint32(plain[:4])
return plain[4:], packetID, nil
} }
func CipherKeyLength(cipher string) int { func CipherKeyLength(cipher string) int {

View File

@@ -8,12 +8,16 @@ import (
"io" "io"
"net" "net"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/sagernet/sing/common/tls" "github.com/sagernet/sing/common/tls"
) )
const defaultHandshakeTimeout = 30 * time.Second const (
defaultHandshakeTimeout = 30 * time.Second
controlRetransmitDelay = time.Second
)
type Client struct { type Client struct {
config *ClientConfig config *ClientConfig
@@ -26,6 +30,8 @@ type Client struct {
push *PushReply push *PushReply
cancel context.CancelFunc cancel context.CancelFunc
lastReceiveNano atomic.Int64
} }
func NewClient(config *ClientConfig, io PacketIO, tlsConfig tls.Config) (*Client, error) { func NewClient(config *ClientConfig, io PacketIO, tlsConfig tls.Config) (*Client, error) {
@@ -154,6 +160,7 @@ func (c *Client) Handshake(ctx context.Context) (*PushReply, error) {
return nil, err return nil, err
} }
c.data = NewDataChannel(cipher, push.PeerID, push.CompLZO) c.data = NewDataChannel(cipher, push.PeerID, push.CompLZO)
c.markReceive()
return push, nil return push, nil
} }
@@ -181,10 +188,21 @@ func (c *Client) ReadIPPacket(ctx context.Context) ([]byte, error) {
if err != nil { if err != nil {
continue continue
} }
c.markReceive()
return plain, nil return plain, nil
} }
} }
func (c *Client) SinceReceive() time.Duration {
return time.Duration(int64(time.Since(clientStart)) - c.lastReceiveNano.Load())
}
func (c *Client) markReceive() {
c.lastReceiveNano.Store(int64(time.Since(clientStart)))
}
var clientStart = time.Now().Add(-time.Hour)
func (c *Client) Close() error { func (c *Client) Close() error {
if c.cancel != nil { if c.cancel != nil {
c.cancel() c.cancel()
@@ -199,10 +217,24 @@ func (c *Client) Close() error {
} }
func (c *Client) waitServerReset(ctx context.Context) error { func (c *Client) waitServerReset(ctx context.Context) error {
retransmits := 0
for { for {
packet, err := c.control.Read(ctx) readCtx := ctx
cancel := func() {}
if c.config.Proto == ProtoUDP {
readCtx, cancel = context.WithTimeout(ctx, controlRetransmitDelay)
}
packet, err := c.control.Read(readCtx)
cancel()
if err != nil { if err != nil {
return fmt.Errorf("read hard reset response: %w", err) if c.config.Proto == ProtoUDP && errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil {
if err := c.control.RetransmitPending(ctx); err != nil {
return fmt.Errorf("retransmit hard reset: %w", err)
}
retransmits++
continue
}
return fmt.Errorf("read hard reset response after %d retransmits: %w", retransmits, err)
} }
switch packet.Opcode { switch packet.Opcode {
case PControlHardResetServerV2: case PControlHardResetServerV2:

View File

@@ -20,6 +20,7 @@ const (
CipherAES256CBC = "AES-256-CBC" CipherAES256CBC = "AES-256-CBC"
CipherCHACHA20POLY = "CHACHA20-POLY1305" CipherCHACHA20POLY = "CHACHA20-POLY1305"
AuthMD5 = "MD5"
AuthSHA1 = "SHA1" AuthSHA1 = "SHA1"
AuthSHA256 = "SHA256" AuthSHA256 = "SHA256"
AuthSHA384 = "SHA384" AuthSHA384 = "SHA384"
@@ -107,7 +108,7 @@ func isValidCipher(cipher string) bool {
func isValidAuth(auth string) bool { func isValidAuth(auth string) bool {
switch auth { switch auth {
case AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512: case AuthMD5, AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
return true return true
} }
return false return false

View File

@@ -30,8 +30,10 @@ type ControlChannel struct {
mu sync.Mutex mu sync.Mutex
sendPacketID uint32 sendPacketID uint32
sendMessage uint32 sendMessage uint32
recvMessage uint32
ackPending []uint32 ackPending []uint32
pending map[uint32]*ControlPacket pending map[uint32]*ControlPacket
recvPending map[uint32]*ControlPacket
readDeadline time.Time readDeadline time.Time
writeDeadline time.Time writeDeadline time.Time
} }
@@ -40,9 +42,10 @@ func NewControlChannel(io PacketIO, crypt ControlCrypt, local SessionID) *Contro
ch := &ControlChannel{ ch := &ControlChannel{
io: io, io: io,
clock: time.Now, clock: time.Now,
local: local, local: local,
pending: make(map[uint32]*ControlPacket), pending: make(map[uint32]*ControlPacket),
recvPending: make(map[uint32]*ControlPacket),
} }
if crypt != nil { if crypt != nil {
ch.encode = func(p *ControlPacket, pid uint32, t uint32) ([]byte, error) { ch.encode = func(p *ControlPacket, pid uint32, t uint32) ([]byte, error) {
@@ -130,10 +133,23 @@ func (c *ControlChannel) SendAck(ctx context.Context) error {
func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) { func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
for { for {
c.mu.Lock()
if packet, ok := c.recvPending[c.recvMessage]; ok {
delete(c.recvPending, c.recvMessage)
c.recvMessage++
c.mu.Unlock()
return packet, nil
}
c.mu.Unlock()
packet, err := c.readControlPacket(ctx) packet, err := c.readControlPacket(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var deliver *ControlPacket
sendAck := false
c.mu.Lock() c.mu.Lock()
if c.remote == (SessionID{}) && packet.LocalSession != c.local { if c.remote == (SessionID{}) && packet.LocalSession != c.local {
c.remote = packet.LocalSession c.remote = packet.LocalSession
@@ -144,11 +160,33 @@ func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
if packet.Opcode.HasMessageID() { if packet.Opcode.HasMessageID() {
c.ackPending = appendAck(c.ackPending, packet.MessageID) c.ackPending = appendAck(c.ackPending, packet.MessageID)
} }
c.mu.Unlock()
if packet.Opcode == PAckV1 { switch {
continue case packet.Opcode == PAckV1:
case !packet.Opcode.HasMessageID():
deliver = packet
case packet.MessageID < c.recvMessage:
sendAck = true
case packet.MessageID == c.recvMessage:
deliver = packet
c.recvMessage++
default:
if _, exists := c.recvPending[packet.MessageID]; !exists {
c.recvPending[packet.MessageID] = packet
}
sendAck = true
}
c.mu.Unlock()
if deliver != nil {
return deliver, nil
}
if sendAck {
if err := c.SendAck(ctx); err != nil {
return nil, err
}
} }
return packet, nil
} }
} }
@@ -349,11 +387,17 @@ func (c *ControlConn) SetWriteDeadline(t time.Time) error {
} }
type streamPacketIO struct { type streamPacketIO struct {
conn net.Conn conn net.Conn
deadlineMu sync.Mutex
readDeadline time.Time
writeDeadline time.Time
} }
type datagramPacketIO struct { type datagramPacketIO struct {
conn net.Conn conn net.Conn
deadlineMu sync.Mutex
readDeadline time.Time
writeDeadline time.Time
} }
func NewDatagramPacketIO(conn net.Conn) PacketIO { func NewDatagramPacketIO(conn net.Conn) PacketIO {
@@ -361,40 +405,23 @@ func NewDatagramPacketIO(conn net.Conn) PacketIO {
} }
func (d *datagramPacketIO) ReadPacket(ctx context.Context) ([]byte, error) { func (d *datagramPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
done := make(chan struct{}) if err := setReadDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.readDeadline); err != nil {
var ( return nil, err
packet []byte
err error
)
go func() {
defer close(done)
buf := make([]byte, 64*1024)
var n int
n, err = d.conn.Read(buf)
if err == nil {
packet = cloneBytes(buf[:n])
}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-done:
return packet, err
} }
buf := make([]byte, 64*1024)
n, err := d.conn.Read(buf)
if err != nil {
return nil, contextIOError(ctx, err)
}
return buf[:n], nil
} }
func (d *datagramPacketIO) WritePacket(ctx context.Context, packet []byte) error { func (d *datagramPacketIO) WritePacket(ctx context.Context, packet []byte) error {
done := make(chan error, 1) if err := setWriteDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.writeDeadline); err != nil {
go func() {
_, err := d.conn.Write(packet)
done <- err
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err return err
} }
_, err := d.conn.Write(packet)
return contextIOError(ctx, err)
} }
func (d *datagramPacketIO) Close() error { func (d *datagramPacketIO) Close() error {
@@ -414,52 +441,37 @@ func NewTCPPacketIO(conn net.Conn) PacketIO {
} }
func (s *streamPacketIO) ReadPacket(ctx context.Context) ([]byte, error) { func (s *streamPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
done := make(chan struct{}) if err := setReadDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.readDeadline); err != nil {
var ( return nil, err
packet []byte
err error
)
go func() {
defer close(done)
var lenBuf [2]byte
if _, err = io.ReadFull(s.conn, lenBuf[:]); err != nil {
return
}
size := int(lenBuf[0])<<8 | int(lenBuf[1])
if size == 0 {
err = errors.New("empty openvpn tcp packet")
return
}
packet = make([]byte, size)
_, err = io.ReadFull(s.conn, packet)
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-done:
return packet, err
} }
var lenBuf [2]byte
if _, err := io.ReadFull(s.conn, lenBuf[:]); err != nil {
return nil, contextIOError(ctx, err)
}
size := int(lenBuf[0])<<8 | int(lenBuf[1])
if size == 0 {
return nil, errors.New("empty openvpn tcp packet")
}
packet := make([]byte, size)
if _, err := io.ReadFull(s.conn, packet); err != nil {
return nil, contextIOError(ctx, err)
}
return packet, nil
} }
func (s *streamPacketIO) WritePacket(ctx context.Context, packet []byte) error { func (s *streamPacketIO) WritePacket(ctx context.Context, packet []byte) error {
if len(packet) > 0xffff { if len(packet) > 0xffff {
return fmt.Errorf("openvpn tcp packet too large: %d", len(packet)) return fmt.Errorf("openvpn tcp packet too large: %d", len(packet))
} }
done := make(chan error, 1) if err := setWriteDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.writeDeadline); err != nil {
go func() {
frame := make([]byte, 2+len(packet))
frame[0] = byte(len(packet) >> 8)
frame[1] = byte(len(packet))
copy(frame[2:], packet)
_, err := s.conn.Write(frame)
done <- err
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err return err
} }
frame := make([]byte, 2+len(packet))
frame[0] = byte(len(packet) >> 8)
frame[1] = byte(len(packet))
copy(frame[2:], packet)
_, err := s.conn.Write(frame)
return contextIOError(ctx, err)
} }
func (s *streamPacketIO) Close() error { func (s *streamPacketIO) Close() error {
@@ -473,3 +485,50 @@ func (s *streamPacketIO) LocalAddr() net.Addr {
func (s *streamPacketIO) RemoteAddr() net.Addr { func (s *streamPacketIO) RemoteAddr() net.Addr {
return s.conn.RemoteAddr() return s.conn.RemoteAddr()
} }
func setReadDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
deadline, hasDeadline := ctx.Deadline()
mu.Lock()
defer mu.Unlock()
if current.Equal(deadline) {
return nil
}
if hasDeadline {
if err := conn.SetReadDeadline(deadline); err != nil {
return err
}
} else if err := conn.SetReadDeadline(time.Time{}); err != nil {
return err
}
*current = deadline
return nil
}
func setWriteDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
deadline, hasDeadline := ctx.Deadline()
mu.Lock()
defer mu.Unlock()
if current.Equal(deadline) {
return nil
}
if hasDeadline {
if err := conn.SetWriteDeadline(deadline); err != nil {
return err
}
} else if err := conn.SetWriteDeadline(time.Time{}); err != nil {
return err
}
*current = deadline
return nil
}
func contextIOError(ctx context.Context, err error) error {
if err == nil {
return nil
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() && ctx.Err() != nil {
return ctx.Err()
}
return err
}

View File

@@ -8,15 +8,21 @@ import (
const ( const (
PeerIDUnset uint32 = 0xffffff PeerIDUnset uint32 = 0xffffff
dataChannelReplayWindow = 64
) )
type DataChannel struct { type DataChannel struct {
cipher DataCipher cipher DataCipher
keyID uint8 keyID uint8
peerID uint32 peerID uint32
compLZO bool compLZO bool
mu sync.Mutex mu sync.Mutex
sendPacketID uint32 sendPacketID uint32
recvHighest uint32
recvWindow uint64
recvSeen bool
} }
func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel { func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel {
@@ -29,10 +35,11 @@ func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel
func (d *DataChannel) Encrypt(packet []byte) ([]byte, error) { func (d *DataChannel) Encrypt(packet []byte) ([]byte, error) {
if d.compLZO { if d.compLZO {
p := make([]byte, 1+len(packet)) compressed, err := lzo1xCompressSafe(packet)
p[0] = 0xFA if err != nil {
copy(p[1:], packet) return nil, err
packet = p }
packet = compressed
} }
d.mu.Lock() d.mu.Lock()
d.sendPacketID++ d.sendPacketID++
@@ -50,18 +57,15 @@ func (d *DataChannel) Decrypt(packet []byte) ([]byte, error) {
if opcode == PDataV2 { if opcode == PDataV2 {
headerSize = 4 headerSize = 4
} }
plain, err := d.cipher.Decrypt(packet, headerSize) plain, packetID, err := d.cipher.Decrypt(packet, headerSize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := d.acceptPacketID(packetID); err != nil {
return nil, err
}
if d.compLZO { if d.compLZO {
if len(plain) < 1 { return lzo1xDecompressSafe(plain)
return nil, errors.New("openvpn comp-lzo packet too short")
}
if plain[0] != 0xFA {
return nil, fmt.Errorf("openvpn compressed packet not supported (byte: 0x%02x)", plain[0])
}
plain = plain[1:]
} }
return plain, nil return plain, nil
} }
@@ -78,6 +82,40 @@ func (d *DataChannel) dataHeader() []byte {
return []byte{opcodeKeyID(PDataV1, d.keyID)} return []byte{opcodeKeyID(PDataV1, d.keyID)}
} }
func (d *DataChannel) acceptPacketID(packetID uint32) error {
d.mu.Lock()
defer d.mu.Unlock()
if !d.recvSeen {
d.recvHighest = packetID
d.recvWindow = 1
d.recvSeen = true
return nil
}
if packetID > d.recvHighest {
shift := packetID - d.recvHighest
if shift >= dataChannelReplayWindow {
d.recvWindow = 1
} else {
d.recvWindow = d.recvWindow<<shift | 1
}
d.recvHighest = packetID
return nil
}
diff := d.recvHighest - packetID
if diff >= dataChannelReplayWindow {
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
}
mask := uint64(1) << diff
if d.recvWindow&mask != 0 {
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
}
d.recvWindow |= mask
return nil
}
func ParsePeerID(options string) uint32 { func ParsePeerID(options string) uint32 {
for _, field := range splitPushOptions(options) { for _, field := range splitPushOptions(options) {
if len(field) > len("peer-id ") && field[:len("peer-id ")] == "peer-id " { if len(field) > len("peer-id ") && field[:len("peer-id ")] == "peer-id " {

View File

@@ -0,0 +1,444 @@
//go:build with_openvpn && with_gvisor
// OpenVPN E2E tests. Require a local OpenVPN server setup.
//
// Setup (run once before testing):
//
// # Generate PKI
// mkdir -p /tmp/ovpn-e2e/pki/{issued,private}
// cd /tmp/ovpn-e2e/pki
// openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -days 1 -nodes -keyout ca.key -out ca.crt -subj "/CN=E2ETestCA"
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/server.key -out server.csr -subj "/CN=server"
// openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/server.crt -days 1
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/client.key -out client.csr -subj "/CN=client"
// openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/client.crt -days 1
// openvpn --genkey secret ta.key
// openvpn --genkey secret ta-auth.key
//
// # Start servers (4 instances: TCP/UDP × tls-crypt/tls-auth)
// # TCP + tls-crypt on :11940, subnet 10.99.0.0/24
// # UDP + tls-crypt on :11941, subnet 10.99.1.0/24
// # TCP + tls-auth on :11942, subnet 10.99.2.0/24
// # UDP + tls-auth on :11943, subnet 10.99.3.0/24
// #
// # Each server config needs: topology subnet, duplicate-cn, persist-tun,
// # data-ciphers AES-256-GCM:AES-128-GCM:AES-192-GCM:CHACHA20-POLY1305:AES-256-CBC:AES-128-CBC:AES-192-CBC
// # auth SHA256, keepalive 10 60, ca/cert/key from above PKI.
// # tls-auth servers use: tls-auth ta-auth.key 0
// # tls-crypt servers use: tls-crypt ta.key
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-crypt.conf --daemon
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-crypt.conf --daemon
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-auth.conf --daemon
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-auth.conf --daemon
//
// # Start HTTP servers on each VPN subnet
// for ip in 10.99.0.1 10.99.1.1 10.99.2.1 10.99.3.1; do
// mkdir -p /tmp/ovpn-e2e/$ip && echo "hello" > /tmp/ovpn-e2e/$ip/index.html
// cd /tmp/ovpn-e2e/$ip && python3 -m http.server 8080 --bind $ip &
// done
//
// Run tests:
//
// go test -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_manager,with_admin_panel,with_v2ray_api,with_ccm,with_ocm,with_profiler,with_openvpn,with_sudoku,with_trusttunnel" \
// -run TestE2E -v -count=1 ./transport/openvpn/ -timeout 300s
//
// Tests all 28 combinations: 2 protos (tcp/udp) × 2 TLS modes (tls-crypt/tls-auth) × 7 ciphers.
package openvpn_test
import (
"context"
"fmt"
"io"
"net"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/sagernet/sing-box"
"github.com/sagernet/sing-box/include"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/json/badoption"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/protocol/socks"
)
// Servers (started externally):
// TCP+tls-crypt :11940 subnet 10.99.0.0/24
// UDP+tls-crypt :11941 subnet 10.99.1.0/24
// TCP+tls-auth :11942 subnet 10.99.2.0/24
// UDP+tls-auth :11943 subnet 10.99.3.0/24
// TCP+plain :11944 subnet 10.99.4.0/24
// UDP+plain :11945 subnet 10.99.5.0/24
// TCP+tls-crypt+SHA1 :11946 subnet 10.99.6.0/24 (CBC only)
// TCP+tls-crypt+SHA512 :11947 subnet 10.99.7.0/24 (CBC only)
// Each has HTTP on .1:8080 serving "hello"
const pkiDir = "/tmp/ovpn-e2e/pki"
type serverConfig struct {
proto string
port uint16
tlsMode string // "tls-crypt" or "tls-auth"
httpAddr string
}
var servers = []serverConfig{
{"tcp", 11940, "tls-crypt", "10.99.0.1:8080"},
{"udp", 11941, "tls-crypt", "10.99.1.1:8080"},
{"tcp", 11942, "tls-auth", "10.99.2.1:8080"},
{"udp", 11943, "tls-auth", "10.99.3.1:8080"},
}
var ciphers = []string{
"AES-128-GCM",
"AES-192-GCM",
"AES-256-GCM",
"CHACHA20-POLY1305",
"AES-128-CBC",
"AES-192-CBC",
"AES-256-CBC",
}
var portCounter atomic.Uint32
func init() { portCounter.Store(18100) }
func nextPort() uint16 { return uint16(portCounter.Add(1)) }
func readFile(t *testing.T, path string) string {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Skipf("PKI not found: %v", err)
}
return string(data)
}
func testCombo(t *testing.T, srv serverConfig, cipher string) {
t.Helper()
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
ovpnOpts := &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: srv.port}},
Proto: srv.proto,
Cipher: cipher,
Auth: "SHA256",
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
}
switch srv.tlsMode {
case "tls-crypt":
ovpnOpts.TLSCrypt = readFile(t, pkiDir+"/ta.key")
case "tls-auth":
ovpnOpts.TLSAuth = readFile(t, pkiDir+"/ta-auth.key")
ovpnOpts.KeyDirection = 1
}
port := nextPort()
opts := option.Options{
Log: &option.LogOptions{Level: "error"},
Inbounds: []option.Inbound{{
Type: "socks",
Options: &option.SocksInboundOptions{
ListenOptions: option.ListenOptions{
Listen: (*badoption.Addr)(&badoption.Addr{}),
ListenPort: port,
},
},
}},
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
Route: &option.RouteOptions{Final: "vpn"},
}
ctx := include.Context(context.Background())
instance, err := box.New(box.Options{Context: ctx, Options: opts})
if err != nil {
t.Fatal(err)
}
if err := instance.Start(); err != nil {
t.Fatal(err)
}
defer instance.Close()
time.Sleep(2 * time.Second)
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(srv.httpAddr))
if err != nil {
t.Fatal("dial:", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
if err != nil {
t.Fatal("write:", err)
}
body, err := io.ReadAll(conn)
if err != nil {
t.Fatal("read:", err)
}
if !strings.Contains(string(body), "hello") {
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
}
}
// 4 servers × 7 ciphers = 28 combinations
func TestE2E(t *testing.T) {
for _, srv := range servers {
for _, cipher := range ciphers {
name := fmt.Sprintf("%s/%s/%s", srv.proto, srv.tlsMode, cipher)
srv, cipher := srv, cipher
t.Run(name, func(t *testing.T) {
testCombo(t, srv, cipher)
})
}
}
}
// Test CBC ciphers with different auth algorithms (SHA1, SHA512)
func TestE2E_Auth(t *testing.T) {
type authServer struct {
port uint16
auth string
httpAddr string
}
authServers := []authServer{
{11946, "SHA1", "10.99.6.1:8080"},
{11947, "SHA512", "10.99.7.1:8080"},
}
cbcCiphers := []string{"AES-128-CBC", "AES-256-CBC"}
for _, as := range authServers {
for _, cipher := range cbcCiphers {
name := fmt.Sprintf("auth-%s/%s", as.auth, cipher)
as, cipher := as, cipher
t.Run(name, func(t *testing.T) {
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
tlsCrypt := readFile(t, pkiDir+"/ta.key")
port := nextPort()
opts := option.Options{
Log: &option.LogOptions{Level: "error"},
Inbounds: []option.Inbound{{
Type: "socks",
Options: &option.SocksInboundOptions{
ListenOptions: option.ListenOptions{
Listen: (*badoption.Addr)(&badoption.Addr{}),
ListenPort: port,
},
},
}},
Outbounds: []option.Outbound{{
Type: "openvpn", Tag: "vpn",
Options: &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: as.port}},
Proto: "tcp",
Cipher: cipher,
Auth: as.auth,
TLSCrypt: tlsCrypt,
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
},
}},
Route: &option.RouteOptions{Final: "vpn"},
}
ctx := include.Context(context.Background())
instance, err := box.New(box.Options{Context: ctx, Options: opts})
if err != nil {
t.Fatal(err)
}
if err := instance.Start(); err != nil {
t.Fatal(err)
}
defer instance.Close()
time.Sleep(2 * time.Second)
doHTTPCheck(t, port, as.httpAddr)
})
}
}
}
// Test tunnel stability with multiple sequential requests
func TestE2E_BulkData(t *testing.T) {
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
tlsCrypt := readFile(t, pkiDir+"/ta.key")
port := nextPort()
opts := option.Options{
Log: &option.LogOptions{Level: "error"},
Inbounds: []option.Inbound{{
Type: "socks",
Options: &option.SocksInboundOptions{
ListenOptions: option.ListenOptions{
Listen: (*badoption.Addr)(&badoption.Addr{}),
ListenPort: port,
},
},
}},
Outbounds: []option.Outbound{{
Type: "openvpn", Tag: "vpn",
Options: &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11940}},
Proto: "tcp",
Cipher: "AES-256-GCM",
Auth: "SHA256",
TLSCrypt: tlsCrypt,
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
},
}},
Route: &option.RouteOptions{Final: "vpn"},
}
ctx := include.Context(context.Background())
instance, err := box.New(box.Options{Context: ctx, Options: opts})
if err != nil {
t.Fatal(err)
}
if err := instance.Start(); err != nil {
t.Fatal(err)
}
defer instance.Close()
time.Sleep(2 * time.Second)
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
for i := 0; i < 10; i++ {
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr("10.99.0.1:8080"))
if err != nil {
t.Fatalf("request %d dial: %v", i, err)
}
conn.SetDeadline(time.Now().Add(5 * time.Second))
fmt.Fprintf(conn, "GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n")
body, err := io.ReadAll(conn)
conn.Close()
if err != nil {
t.Fatalf("request %d read: %v", i, err)
}
if !strings.Contains(string(body), "hello") {
t.Fatalf("request %d: no 'hello'", i)
}
}
}
func doHTTPCheck(t *testing.T, socksPort uint16, httpAddr string) {
t.Helper()
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", socksPort), socks.Version5, "", "")
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(httpAddr))
if err != nil {
t.Fatal("dial:", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
if err != nil {
t.Fatal("write:", err)
}
body, err := io.ReadAll(conn)
if err != nil {
t.Fatal("read:", err)
}
if !strings.Contains(string(body), "hello") {
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
}
}
func startInstance(t *testing.T, ovpnOpts *option.OpenVPNOutboundOptions) uint16 {
t.Helper()
port := nextPort()
opts := option.Options{
Log: &option.LogOptions{Level: "error"},
Inbounds: []option.Inbound{{
Type: "socks",
Options: &option.SocksInboundOptions{
ListenOptions: option.ListenOptions{
Listen: (*badoption.Addr)(&badoption.Addr{}),
ListenPort: port,
},
},
}},
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
Route: &option.RouteOptions{Final: "vpn"},
}
ctx := include.Context(context.Background())
instance, err := box.New(box.Options{Context: ctx, Options: opts})
if err != nil {
t.Fatal(err)
}
if err := instance.Start(); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { instance.Close() })
time.Sleep(2 * time.Second)
return port
}
func TestE2E_CompLZO(t *testing.T) {
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
tlsCrypt := readFile(t, pkiDir+"/ta.key")
for _, cipher := range ciphers {
cipher := cipher
t.Run(cipher, func(t *testing.T) {
port := startInstance(t, &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11948}},
Proto: "udp",
Cipher: cipher,
Auth: "SHA256",
TLSCrypt: tlsCrypt,
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
})
doHTTPCheck(t, port, "10.99.8.1:8080")
})
}
}
func TestE2E_AES192(t *testing.T) {
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
tlsCrypt := readFile(t, pkiDir+"/ta.key")
type combo struct {
proto string
port uint16
httpAddr string
}
for _, c := range []combo{
{"tcp", 11940, "10.99.0.1:8080"},
{"udp", 11941, "10.99.1.1:8080"},
} {
for _, cipher := range []string{"AES-192-GCM", "AES-192-CBC"} {
c, cipher := c, cipher
t.Run(fmt.Sprintf("%s/%s", c.proto, cipher), func(t *testing.T) {
port := startInstance(t, &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: c.port}},
Proto: c.proto,
Cipher: cipher,
Auth: "SHA256",
TLSCrypt: tlsCrypt,
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
})
doHTTPCheck(t, port, c.httpAddr)
})
}
}
}
var _ net.Conn

View File

@@ -114,7 +114,7 @@ func ParseServerKeyMethod2Record(packet []byte) (*KeyMethod2Record, error) {
} }
func DeriveClientKeyMaterial(sources KeySource2, clientSession, serverSession SessionID, cipherKeyLen int) (*KeyMaterial, error) { func DeriveClientKeyMaterial(sources KeySource2, clientSession, serverSession SessionID, cipherKeyLen int) (*KeyMaterial, error) {
if cipherKeyLen != 16 && cipherKeyLen != 32 { if cipherKeyLen != 16 && cipherKeyLen != 24 && cipherKeyLen != 32 {
return nil, fmt.Errorf("unsupported data cipher key length %d", cipherKeyLen) return nil, fmt.Errorf("unsupported data cipher key length %d", cipherKeyLen)
} }
var master [48]byte var master [48]byte

48
transport/openvpn/lzo.go Normal file
View File

@@ -0,0 +1,48 @@
package openvpn
import (
"bytes"
"errors"
"github.com/rasky/go-lzo"
)
const (
lzoCompressNone = 0xFA
lzoCompressLZO = 0x66
)
var ErrLZODecompress = errors.New("lzo decompression failed")
func lzo1xDecompressSafe(src []byte) ([]byte, error) {
if len(src) == 0 {
return nil, ErrLZODecompress
}
switch src[0] {
case lzoCompressNone:
if len(src) > 1 {
return src[1:], nil
}
return nil, nil
case lzoCompressLZO:
if len(src) > 1 {
r := bytes.NewReader(src[1:])
out, err := lzo.Decompress1X(r, len(src)-1, 0)
if err != nil {
return nil, ErrLZODecompress
}
return out, nil
}
return nil, nil
default:
return nil, ErrLZODecompress
}
}
func lzo1xCompressSafe(src []byte) ([]byte, error) {
lzoPacket := make([]byte, 1+len(src))
lzoPacket[0] = lzoCompressNone
copy(lzoPacket[1:], src)
return lzoPacket, nil
}

View File

@@ -10,16 +10,17 @@ import (
const PushRequest = "PUSH_REQUEST" const PushRequest = "PUSH_REQUEST"
type PushReply struct { type PushReply struct {
Raw string Raw string
Prefixes []netip.Prefix Prefixes []netip.Prefix
DNS []netip.Addr DNS []netip.Addr
PeerID uint32 PeerID uint32
Cipher string Cipher string
Ping uint32 Ping uint32
MTU uint32 PingRestart uint32
CompLZO bool MTU uint32
Redirect bool CompLZO bool
BlockIPv6 bool Redirect bool
BlockIPv6 bool
} }
func ParsePushReply(message string) (*PushReply, error) { func ParsePushReply(message string) (*PushReply, error) {
@@ -81,6 +82,12 @@ func ParsePushReply(message string) (*PushReply, error) {
reply.Ping = uint32(v) reply.Ping = uint32(v)
} }
} }
case "ping-restart":
if len(fields) >= 2 {
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
reply.PingRestart = uint32(v)
}
}
case "tun-mtu": case "tun-mtu":
if len(fields) >= 2 { if len(fields) >= 2 {
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil { if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
@@ -113,27 +120,44 @@ func splitPushOptions(message string) []string {
return out return out
} }
func parseIPv4Ifconfig(address, mask string) (netip.Prefix, error) { func parseIPv4Ifconfig(address, maskOrPeer string) (netip.Prefix, error) {
addr, err := netip.ParseAddr(address) addr, err := netip.ParseAddr(address)
if err != nil { if err != nil {
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 address %q: %w", address, err) return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 address %q: %w", address, err)
} }
maskAddr, err := netip.ParseAddr(mask) maskAddr, err := netip.ParseAddr(maskOrPeer)
if err != nil { if err != nil {
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", mask, err) return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", maskOrPeer, err)
} }
if !addr.Is4() || !maskAddr.Is4() { if !addr.Is4() || !maskAddr.Is4() {
return netip.Prefix{}, fmt.Errorf("openvpn ifconfig requires ipv4 address and mask") return netip.Prefix{}, fmt.Errorf("openvpn ifconfig requires ipv4 address and mask")
} }
maskBytes := maskAddr.As4()
if ones, ok := ipv4MaskSize(maskAddr); ok {
return netip.PrefixFrom(addr, ones), nil
}
// Some servers, including SoftEther/VPNGate in net30/p2p mode, push
// "ifconfig <local> <remote>" rather than "ifconfig <local> <netmask>".
// Use a host prefix for that local tunnel address.
return netip.PrefixFrom(addr, 32), nil
}
func ipv4MaskSize(mask netip.Addr) (int, bool) {
maskBytes := mask.As4()
ones := 0 ones := 0
seenZero := false
for _, b := range maskBytes { for _, b := range maskBytes {
for i := 7; i >= 0; i-- { for i := 7; i >= 0; i-- {
if b&(1<<i) == 0 { if b&(1<<i) == 0 {
return netip.PrefixFrom(addr, ones), nil seenZero = true
continue
}
if seenZero {
return 0, false
} }
ones++ ones++
} }
} }
return netip.PrefixFrom(addr, ones), nil return ones, true
} }

View File

@@ -2,6 +2,7 @@ package openvpn
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/md5"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
@@ -35,6 +36,9 @@ func NewTLSAuth(staticKey []byte, keyDirection int, auth string) (*TLSAuth, erro
var newHash func() hash.Hash var newHash func() hash.Hash
var hmacSize int var hmacSize int
switch auth { switch auth {
case AuthMD5:
newHash = md5.New
hmacSize = md5.Size
case AuthSHA256: case AuthSHA256:
newHash = sha256.New newHash = sha256.New
hmacSize = sha256.Size hmacSize = sha256.Size

View File

@@ -30,16 +30,19 @@ type TunnelOptions struct {
UDPTimeout time.Duration UDPTimeout time.Duration
ReconnectDelay time.Duration ReconnectDelay time.Duration
PingInterval time.Duration PingInterval time.Duration
PingRestart time.Duration
} }
type Tunnel struct { type Tunnel struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc
logger logger.ContextLogger logger logger.ContextLogger
options TunnelOptions options TunnelOptions
device Device device Device
client *Client client *Client
mtu uint32 mtu uint32
serverIndex int serverIndex int
wg sync.WaitGroup
await chan struct{} await chan struct{}
mu sync.Mutex mu sync.Mutex
@@ -49,8 +52,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
if options.ReconnectDelay == 0 { if options.ReconnectDelay == 0 {
options.ReconnectDelay = 5 * time.Second options.ReconnectDelay = 5 * time.Second
} }
ctx, cancel := context.WithCancel(ctx)
return &Tunnel{ return &Tunnel{
ctx: ctx, ctx: ctx,
cancel: cancel,
logger: logger, logger: logger,
options: options, options: options,
await: make(chan struct{}), await: make(chan struct{}),
@@ -59,10 +64,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
func (t *Tunnel) Start() error { func (t *Tunnel) Start() error {
go func() { go func() {
defer close(t.await)
client, err := t.getClient() client, err := t.getClient()
if err != nil { if err != nil {
t.logger.Error("OpenVPN connect: ", err) t.logger.Error("OpenVPN connect: ", err)
close(t.await)
return return
} }
t.mtu = 1500 t.mtu = 1500
@@ -84,20 +89,26 @@ func (t *Tunnel) Start() error {
if err != nil { if err != nil {
client.Close() client.Close()
t.logger.Error("create OpenVPN device: ", err) t.logger.Error("create OpenVPN device: ", err)
close(t.await)
return return
} }
t.device = device t.device = device
if err := device.Start(); err != nil { if err := device.Start(); err != nil {
client.Close() client.Close()
t.logger.Error("start OpenVPN device: ", err) t.logger.Error("start OpenVPN device: ", err)
close(t.await)
return return
} }
close(t.await)
t.maintainTunnel() t.maintainTunnel()
}() }()
return nil return nil
} }
func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if err := t.isTunnelInitialized(ctx); err != nil {
return nil, err
}
if !destination.Addr.IsValid() { if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination") return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
} }
@@ -105,6 +116,9 @@ func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.
} }
func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if err := t.isTunnelInitialized(ctx); err != nil {
return nil, err
}
if !destination.Addr.IsValid() { if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination") return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
} }
@@ -112,15 +126,18 @@ func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
} }
func (t *Tunnel) Close() error { func (t *Tunnel) Close() error {
t.cancel()
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock()
if t.client != nil { if t.client != nil {
t.client.Close() t.client.Close()
t.client = nil t.client = nil
} }
if t.device != nil { if t.device != nil {
return t.device.Close() t.device.Close()
t.device = nil
} }
t.mu.Unlock()
t.wg.Wait()
return nil return nil
} }
@@ -137,7 +154,9 @@ func (t *Tunnel) isTunnelInitialized(ctx context.Context) error {
} }
func (t *Tunnel) maintainTunnel() { func (t *Tunnel) maintainTunnel() {
t.wg.Add(2)
go func() { go func() {
defer t.wg.Done()
bufs := make([][]byte, 1) bufs := make([][]byte, 1)
bufs[0] = make([]byte, t.mtu) bufs[0] = make([]byte, t.mtu)
sizes := make([]int, 1) sizes := make([]int, 1)
@@ -161,6 +180,7 @@ func (t *Tunnel) maintainTunnel() {
} }
}() }()
go func() { go func() {
defer t.wg.Done()
for t.ctx.Err() == nil { for t.ctx.Err() == nil {
client, err := t.getClient() client, err := t.getClient()
if err != nil { if err != nil {
@@ -179,10 +199,14 @@ func (t *Tunnel) maintainTunnel() {
if bytes.Equal(packet, pingPayload) { if bytes.Equal(packet, pingPayload) {
continue continue
} }
if t.ctx.Err() != nil {
return
}
if t.ctx.Err() != nil {
return
}
if _, err := t.device.Write([][]byte{packet}, 0); err != nil { if _, err := t.device.Write([][]byte{packet}, 0); err != nil {
if t.ctx.Err() != nil { return
return
}
} }
} }
}() }()
@@ -208,6 +232,34 @@ func (t *Tunnel) maintainTunnel() {
} }
}() }()
} }
pingRestart := t.options.PingRestart
if pingRestart == 0 && t.client != nil && t.client.push.PingRestart > 0 {
pingRestart = time.Duration(t.client.push.PingRestart) * time.Second
}
if pingRestart > 0 {
t.wg.Add(1)
go func() {
defer t.wg.Done()
ticker := time.NewTicker(pingRestart)
defer ticker.Stop()
for {
select {
case <-t.ctx.Done():
return
case <-ticker.C:
client, err := t.getClient()
if err != nil {
return
}
if client.SinceReceive() >= pingRestart {
if ok := t.closeClient(client); ok {
t.logger.ErrorContext(t.ctx, fmt.Errorf("ping-restart timeout: no packet received for %s", pingRestart))
}
}
}
}
}()
}
<-t.ctx.Done() <-t.ctx.Done()
} }

View File

@@ -0,0 +1,100 @@
package obfs
import (
"bufio"
cryptorand "crypto/rand"
"encoding/base64"
"fmt"
"io"
"math/rand/v2"
"net"
"net/http"
"time"
)
// HTTPObfsServer is the server side of the shadowsocks http simple-obfs implementation.
type HTTPObfsServer struct {
net.Conn
buf []byte
bio *bufio.Reader
offset int
firstRequest bool
firstResponse bool
}
func (hos *HTTPObfsServer) Read(b []byte) (int, error) {
if hos.buf != nil {
n := copy(b, hos.buf[hos.offset:])
hos.offset += n
if hos.offset == len(hos.buf) {
hos.offset = 0
hos.buf = nil
}
return n, nil
}
if hos.firstRequest {
bio := bufio.NewReader(hos.Conn)
req, err := http.ReadRequest(bio)
if err != nil {
return 0, err
}
if req.Method != "GET" || req.Header.Get("Connection") != "Upgrade" {
return 0, io.EOF
}
buf, err := io.ReadAll(req.Body)
if err != nil {
return 0, err
}
n := copy(b, buf)
if n < len(buf) {
hos.buf = buf
hos.offset = n
}
req.Body.Close()
hos.bio = bio
hos.firstRequest = false
return n, nil
}
return hos.bio.Read(b)
}
func (hos *HTTPObfsServer) Write(b []byte) (int, error) {
if hos.firstResponse {
randBytes := make([]byte, 16)
cryptorand.Read(randBytes)
date := time.Now().Format(time.RFC1123)
resp := fmt.Sprintf(httpResponseTemplate, vMajor, vMinor, date, base64.URLEncoding.EncodeToString(randBytes))
_, err := hos.Conn.Write([]byte(resp))
if err != nil {
return 0, err
}
hos.firstResponse = false
}
return hos.Conn.Write(b)
}
func (hos *HTTPObfsServer) Upstream() any {
return hos.Conn
}
// NewHTTPObfsServer returns a server-side HTTPObfs.
func NewHTTPObfsServer(conn net.Conn) net.Conn {
return &HTTPObfsServer{
Conn: conn,
firstRequest: true,
firstResponse: true,
}
}
const httpResponseTemplate = "HTTP/1.1 101 Switching Protocols\r\n" +
"Server: nginx/1.%d.%d\r\n" +
"Date: %s\r\n" +
"Upgrade: websocket\r\n" +
"Connection: Upgrade\r\n" +
"Sec-WebSocket-Accept: %s\r\n" +
"\r\n"
var (
vMajor = rand.IntN(11)
vMinor = rand.IntN(12)
)

View File

@@ -0,0 +1,154 @@
package obfs
import (
"bytes"
"crypto/rand"
"encoding/binary"
"io"
"net"
"time"
B "github.com/sagernet/sing/common/buf"
)
// TLSObfsServer is the server side of the shadowsocks tls simple-obfs implementation.
type TLSObfsServer struct {
net.Conn
remain int
firstRequest bool
sessionTicketDone bool
firstResponse bool
}
func (tos *TLSObfsServer) read(b []byte, discardN int) (int, error) {
buf := B.Get(discardN)
_, err := io.ReadFull(tos.Conn, buf)
B.Put(buf)
if err != nil {
return 0, err
}
sizeBuf := make([]byte, 2)
_, err = io.ReadFull(tos.Conn, sizeBuf)
if err != nil {
return 0, nil
}
length := int(binary.BigEndian.Uint16(sizeBuf))
if length > len(b) {
n, err := tos.Conn.Read(b)
if err != nil {
return n, err
}
tos.remain = length - n
return n, nil
}
return io.ReadFull(tos.Conn, b[:length])
}
// skipOtherExts skips SNI & other TLS extensions.
func (tos *TLSObfsServer) skipOtherExts() error {
buf := make([]byte, 256)
_, err := tos.read(buf, 7)
if err != nil {
return err
}
_, err = io.ReadFull(tos.Conn, buf[:4*16+2])
return err
}
func (tos *TLSObfsServer) Read(b []byte) (int, error) {
if tos.remain > 0 {
length := tos.remain
if length > len(b) {
length = len(b)
}
n, err := io.ReadFull(tos.Conn, b[:length])
tos.remain -= n
return n, err
}
if tos.firstRequest {
tos.firstRequest = false
return tos.read(b, 9*16-4)
}
if !tos.sessionTicketDone {
tos.sessionTicketDone = true
err := tos.skipOtherExts()
if err != nil {
return 0, err
}
}
return tos.read(b, 3)
}
func (tos *TLSObfsServer) Write(b []byte) (int, error) {
length := len(b)
for i := 0; i < length; i += chunkSize {
end := i + chunkSize
if end > length {
end = length
}
n, err := tos.write(b[i:end])
if err != nil {
return n, err
}
}
return length, nil
}
func (tos *TLSObfsServer) write(b []byte) (int, error) {
if tos.firstResponse {
serverHello := makeServerHello(b)
_, err := tos.Conn.Write(serverHello)
tos.firstResponse = false
return len(b), err
}
buf := B.NewSize(5 + len(b))
defer buf.Release()
buf.Write([]byte{0x17, 0x03, 0x03})
binary.Write(buf, binary.BigEndian, uint16(len(b)))
buf.Write(b)
_, err := tos.Conn.Write(buf.Bytes())
if err != nil {
return 0, err
}
return len(b), nil
}
func (tos *TLSObfsServer) Upstream() any {
return tos.Conn
}
// NewTLSObfsServer returns a server-side TLS SimpleObfs.
func NewTLSObfsServer(conn net.Conn) net.Conn {
return &TLSObfsServer{
Conn: conn,
firstRequest: true,
firstResponse: true,
}
}
func makeServerHello(data []byte) []byte {
randBytes := make([]byte, 28)
sessionId := make([]byte, 32)
rand.Read(randBytes)
rand.Read(sessionId)
buf := &bytes.Buffer{}
buf.WriteByte(0x16)
binary.Write(buf, binary.BigEndian, uint16(0x0301))
binary.Write(buf, binary.BigEndian, uint16(91))
buf.Write([]byte{2, 0, 0, 87, 0x03, 0x03})
binary.Write(buf, binary.BigEndian, uint32(time.Now().Unix()))
buf.Write(randBytes)
buf.WriteByte(32)
buf.Write(sessionId)
buf.Write([]byte{0xcc, 0xa8})
buf.WriteByte(0)
buf.Write([]byte{0x00, 0x00})
buf.Write([]byte{0xff, 0x01, 0x00, 0x01, 0x00})
buf.Write([]byte{0x00, 0x17, 0x00, 0x00})
buf.Write([]byte{0x00, 0x0b, 0x00, 0x02, 0x01, 0x00})
buf.Write([]byte{0x14, 0x03, 0x03, 0x00, 0x01, 0x01})
buf.Write([]byte{0x16, 0x03, 0x03})
binary.Write(buf, binary.BigEndian, uint16(len(data)))
buf.Write(data)
return buf.Bytes()
}

144
transport/snell/address.go Normal file
View File

@@ -0,0 +1,144 @@
package snell
import (
"encoding/binary"
"net"
"strconv"
)
// SOCKS address types as defined in RFC 1928 section 5.
const (
atypIPv4 = 1
atypDomainName = 3
atypIPv6 = 4
)
// socksAddr represents a SOCKS address as defined in RFC 1928 section 5.
type socksAddr []byte
func (a socksAddr) String() string {
var host, port string
switch a[0] {
case atypDomainName:
hostLen := uint16(a[1])
host = string(a[2 : 2+hostLen])
port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
case atypIPv4:
host = net.IP(a[1 : 1+net.IPv4len]).String()
port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
case atypIPv6:
host = net.IP(a[1 : 1+net.IPv6len]).String()
port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
}
return net.JoinHostPort(host, port)
}
// UDPAddr converts a socksAddr to *net.UDPAddr.
func (a socksAddr) UDPAddr() *net.UDPAddr {
if len(a) == 0 {
return nil
}
switch a[0] {
case atypIPv4:
var ip [net.IPv4len]byte
copy(ip[0:], a[1:1+net.IPv4len])
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
case atypIPv6:
var ip [net.IPv6len]byte
copy(ip[0:], a[1:1+net.IPv6len])
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
}
return nil
}
// splitSocksAddr slices a SOCKS address from beginning of b. Returns nil if failed.
func splitSocksAddr(b []byte) socksAddr {
addrLen := 1
if len(b) < addrLen {
return nil
}
switch b[0] {
case atypDomainName:
if len(b) < 2 {
return nil
}
addrLen = 1 + 1 + int(b[1]) + 2
case atypIPv4:
addrLen = 1 + net.IPv4len + 2
case atypIPv6:
addrLen = 1 + net.IPv6len + 2
default:
return nil
}
if len(b) < addrLen {
return nil
}
return b[:addrLen]
}
// parseAddr parses the address in string s. Returns nil if failed.
func parseAddr(s string) socksAddr {
var addr socksAddr
host, port, err := net.SplitHostPort(s)
if err != nil {
return nil
}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
addr = make([]byte, 1+net.IPv4len+2)
addr[0] = atypIPv4
copy(addr[1:], ip4)
} else {
addr = make([]byte, 1+net.IPv6len+2)
addr[0] = atypIPv6
copy(addr[1:], ip)
}
} else {
if len(host) > 255 {
return nil
}
addr = make([]byte, 1+1+len(host)+2)
addr[0] = atypDomainName
addr[1] = byte(len(host))
copy(addr[2:], host)
}
portnum, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil
}
addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
return addr
}
// parseAddrToSocksAddr parses a socks addr from net.Addr.
// This is a fast path of parseAddr(addr.String()).
func parseAddrToSocksAddr(addr net.Addr) socksAddr {
var hostip net.IP
var port int
switch addr := addr.(type) {
case *net.UDPAddr:
hostip = addr.IP
port = addr.Port
case *net.TCPAddr:
hostip = addr.IP
port = addr.Port
case nil:
return nil
}
if hostip == nil {
return parseAddr(addr.String())
}
var parsed socksAddr
if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
parsed = make([]byte, 1+net.IPv4len+2)
parsed[0] = atypIPv4
copy(parsed[1:], ip4)
binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
} else {
parsed = make([]byte, 1+net.IPv6len+2)
parsed[0] = atypIPv6
copy(parsed[1:], hostip)
binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
}
return parsed
}

56
transport/snell/cipher.go Normal file
View File

@@ -0,0 +1,56 @@
package snell
import (
"crypto/aes"
"crypto/cipher"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/chacha20poly1305"
)
// NewAES128GCM returns the AES-128-GCM cipher used by snell v2/v3.
func NewAES128GCM(psk []byte) Cipher {
return &snellCipher{
psk: psk,
keySize: 16,
makeAEAD: aesGCM,
}
}
// NewChacha20Poly1305 returns the ChaCha20-Poly1305 cipher used by snell v1.
func NewChacha20Poly1305(psk []byte) Cipher {
return &snellCipher{
psk: psk,
keySize: 32,
makeAEAD: chacha20poly1305.New,
}
}
type snellCipher struct {
psk []byte
keySize int
makeAEAD func(key []byte) (cipher.AEAD, error)
}
func (sc *snellCipher) KeySize() int { return sc.keySize }
func (sc *snellCipher) SaltSize() int { return 16 }
func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) {
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
}
func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) {
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
}
func snellKDF(psk, salt []byte, keySize int) []byte {
return argon2.IDKey(psk, salt, 3, 8, 1, 32)[:keySize]
}
func aesGCM(key []byte) (cipher.AEAD, error) {
blk, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(blk)
}

120
transport/snell/client.go Normal file
View File

@@ -0,0 +1,120 @@
package snell
import (
"context"
"net"
"strconv"
"time"
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type ClientOptions struct {
Dialer N.Dialer
Server M.Socksaddr
PSK []byte
Version int
Reuse bool
ObfsMode string
ObfsHost string
}
type Client struct {
dialer N.Dialer
server M.Socksaddr
psk []byte
version int
reuse bool
obfsMode string
obfsHost string
pool *Pool
}
func NewClient(options ClientOptions) *Client {
c := &Client{
dialer: options.Dialer,
server: options.Server,
psk: options.PSK,
version: options.Version,
reuse: options.Reuse,
obfsMode: options.ObfsMode,
obfsHost: options.ObfsHost,
}
if c.reuse {
c.pool = NewPool(func(ctx context.Context) (*Snell, error) {
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
if err != nil {
return nil, err
}
return c.streamConn(conn), nil
})
}
return c
}
func (c *Client) streamConn(conn net.Conn) *Snell {
switch c.obfsMode {
case "tls":
conn = obfs.NewTLSObfs(conn, c.obfsHost)
case "http":
conn = obfs.NewHTTPObfs(conn, c.obfsHost, strconv.Itoa(int(c.server.Port)))
}
return StreamConn(conn, c.psk, c.version)
}
func (c *Client) writeHeader(ctx context.Context, conn net.Conn, destination M.Socksaddr, udp bool) (err error) {
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetWriteDeadline(deadline)
defer conn.SetWriteDeadline(time.Time{})
}
if udp {
err = WriteUDPHeader(conn, c.version)
if err == nil && c.version >= Version4 {
if sc, ok := conn.(*Snell); ok {
err = sc.ReadReply()
}
}
return
}
err = WriteHeaderWithReuse(conn, destination.AddrString(), uint(destination.Port), c.version, c.reuse)
return
}
func (c *Client) DialContext(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
if c.reuse {
conn, err := c.pool.Get()
if err != nil {
return nil, err
}
if err = c.writeHeader(ctx, conn, destination, false); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
if err != nil {
return nil, err
}
stream := c.streamConn(conn)
if err = c.writeHeader(ctx, stream, destination, false); err != nil {
_ = conn.Close()
return nil, err
}
return stream, nil
}
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
if err != nil {
return nil, err
}
stream := c.streamConn(conn)
if err = c.writeHeader(ctx, stream, destination, true); err != nil {
_ = conn.Close()
return nil, err
}
return PacketConn(stream), nil
}

153
transport/snell/pool.go Normal file
View File

@@ -0,0 +1,153 @@
package snell
import (
"context"
"io"
"net"
"runtime"
"sync"
"time"
)
// poolEntry holds a pooled item with its insertion time.
// connPool is a small connection pool with age-based eviction.
// milliseconds
// Pool is a pool of reusable snell connections.
type Pool struct {
pool *connPool
}
func (p *Pool) Get() (net.Conn, error) {
return p.GetContext(context.Background())
}
func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
elm, err := p.pool.GetContext(ctx)
if err != nil {
return nil, err
}
return &PoolConn{Snell: elm, pool: p}, nil
}
func (p *Pool) Put(conn *Snell) {
if err := HalfClose(conn); err != nil {
_ = conn.Close()
return
}
p.pool.put(conn)
}
// PoolConn wraps a pooled snell connection and returns it to the pool on Close.
type PoolConn struct {
*Snell
pool *Pool
closeWriteOnce sync.Once
closeWriteErr error
closeOnce sync.Once
closeErr error
}
func (pc *PoolConn) Read(b []byte) (int, error) {
n, err := pc.Snell.Read(b)
if err == ErrZeroChunk {
return n, io.EOF
}
return n, err
}
func (pc *PoolConn) Write(b []byte) (int, error) {
return pc.Snell.Write(b)
}
func (pc *PoolConn) CloseWrite() error {
pc.closeWriteOnce.Do(func() {
pc.closeWriteErr = writeZeroChunk(pc.Snell)
})
return pc.closeWriteErr
}
func (pc *PoolConn) Close() error {
pc.closeOnce.Do(func() {
if err := pc.CloseWrite(); err != nil {
pc.closeErr = err
_ = pc.Snell.Close()
return
}
_ = pc.Snell.Conn.SetReadDeadline(time.Time{})
pc.Snell.reply = false
pc.pool.pool.put(pc.Snell)
})
return pc.closeErr
}
// NewPool creates a new snell connection pool using the given factory.
func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
cp := &connPool{
ch: make(chan *poolEntry, 10),
factory: factory,
maxAge: 15000,
evict: func(item *Snell) {
_ = item.Close()
},
}
p := &Pool{pool: cp}
runtime.SetFinalizer(p, recycle)
return p
}
type poolEntry struct {
elm *Snell
time time.Time
}
type connPool struct {
ch chan *poolEntry
factory func(context.Context) (*Snell, error)
evict func(*Snell)
maxAge int64
}
func (p *connPool) GetContext(ctx context.Context) (*Snell, error) {
now := time.Now()
for {
select {
case item := <-p.ch:
if p.maxAge != 0 && now.Sub(item.time).Milliseconds() > p.maxAge {
if p.evict != nil {
p.evict(item.elm)
}
continue
}
return item.elm, nil
default:
return p.factory(ctx)
}
}
}
func (p *connPool) put(item *Snell) {
e := &poolEntry{
elm: item,
time: time.Now(),
}
select {
case p.ch <- e:
return
default:
if p.evict != nil {
p.evict(item)
}
return
}
}
func recycle(p *Pool) {
for item := range p.pool.ch {
if p.pool.evict != nil {
p.pool.evict(item.elm)
}
}
}

294
transport/snell/service.go Normal file
View File

@@ -0,0 +1,294 @@
package snell
import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
)
type Handler interface {
NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string)
NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string)
}
type ServiceOptions struct {
PSK []byte
Version int
ObfsMode string
UDP bool
Logger logger.ContextLogger
Handler Handler
}
type Service struct {
psk []byte
version int
obfsMode string
udp bool
logger logger.ContextLogger
handler Handler
}
func NewService(options ServiceOptions) (*Service, error) {
version := options.Version
if version == 0 {
version = Version4
}
if version != Version4 && version != Version5 {
return nil, fmt.Errorf("snell inbound version %d is not supported", version)
}
if len(options.PSK) == 0 {
return nil, errors.New("snell inbound requires psk")
}
switch options.ObfsMode {
case "", "http", "tls":
default:
return nil, fmt.Errorf("snell inbound obfs mode error: %s", options.ObfsMode)
}
return &Service{
psk: options.PSK,
version: version,
obfsMode: options.ObfsMode,
udp: options.UDP,
logger: options.Logger,
handler: options.Handler,
}, nil
}
func (s *Service) NewConnection(ctx context.Context, rawConn net.Conn, source M.Socksaddr) error {
conn := rawConn
switch s.obfsMode {
case "http":
conn = obfs.NewHTTPObfsServer(conn)
case "tls":
conn = obfs.NewTLSObfsServer(conn)
}
stream := ServerStreamConn(conn, s.psk, s.version)
for {
reuse, err := s.handleRequest(ctx, stream, source)
if err != nil || !reuse {
return err
}
}
}
func (s *Service) handleRequest(ctx context.Context, stream *Snell, source M.Socksaddr) (bool, error) {
br := bufio.NewReader(stream)
version, err := br.ReadByte()
if err != nil {
return false, err
}
if version != Version {
return false, fmt.Errorf("snell invalid protocol version: %d", version)
}
command, err := br.ReadByte()
if err != nil {
return false, err
}
if command == CommandPing {
_, _ = stream.Write([]byte{CommandPong})
return false, nil
}
clientID, err := readClientID(br)
if err != nil {
return false, err
}
switch command {
case CommandConnect, CommandConnectV2:
return s.handleTCP(ctx, stream, br, command == CommandConnectV2, clientID, source)
case CommandUDP:
if !s.udp {
return false, errors.New("snell UDP is disabled")
}
return false, s.handleUDP(ctx, stream, clientID, source)
default:
return false, fmt.Errorf("snell unknown command: %d", command)
}
}
func (s *Service) handleTCP(ctx context.Context, stream *Snell, br *bufio.Reader, reuse bool, clientID string, source M.Socksaddr) (bool, error) {
hostLen, err := br.ReadByte()
if err != nil {
return false, err
}
if hostLen == 0 {
return false, errors.New("snell connect host is empty")
}
hostBytes := make([]byte, int(hostLen))
if _, err := io.ReadFull(br, hostBytes); err != nil {
return false, err
}
var portBytes [2]byte
if _, err := io.ReadFull(br, portBytes[:]); err != nil {
return false, err
}
destination := M.ParseSocksaddrHostPort(string(hostBytes), binary.BigEndian.Uint16(portBytes[:]))
conn := &tcpRequestConn{
Conn: stream,
reader: br,
reuse: reuse,
}
s.handler.NewConnection(ctx, conn, source, destination, clientID)
if !reuse {
return false, nil
}
return true, nil
}
func (s *Service) handleUDP(ctx context.Context, stream *Snell, clientID string, source M.Socksaddr) error {
if _, err := stream.Write([]byte{CommandTunnel}); err != nil {
return err
}
pc := &serverPacketConn{
conn: stream,
writeMu: &sync.Mutex{},
}
s.handler.NewPacketConnection(ctx, pc, source, clientID)
return nil
}
const maxPacketLength = 0x3fff
func readClientID(r *bufio.Reader) (string, error) {
length, err := r.ReadByte()
if err != nil {
return "", err
}
if length == 0 {
return "", nil
}
id := make([]byte, int(length))
if _, err := io.ReadFull(r, id); err != nil {
return "", err
}
return string(id), nil
}
func writeCommandError(w io.Writer, code byte, message string) error {
msg := []byte(message)
if len(msg) > 255 {
msg = msg[:255]
}
buf := make([]byte, 0, 3+len(msg))
buf = append(buf, CommandError, code, byte(len(msg)))
buf = append(buf, msg...)
_, err := w.Write(buf)
return err
}
type tcpRequestConn struct {
net.Conn
reader *bufio.Reader
reuse bool
writeMu sync.Mutex
closeOnce sync.Once
replyWritten bool
}
func (c *tcpRequestConn) Read(p []byte) (int, error) {
n, err := c.reader.Read(p)
if errors.Is(err, ErrZeroChunk) {
err = io.EOF
}
return n, err
}
func (c *tcpRequestConn) Write(p []byte) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
if !c.replyWritten {
payload := make([]byte, 1+len(p))
payload[0] = CommandTunnel
copy(payload[1:], p)
if _, err := c.Conn.Write(payload); err != nil {
return 0, err
}
c.replyWritten = true
return len(p), nil
}
return c.Conn.Write(p)
}
func (c *tcpRequestConn) CloseWrite() error {
return c.Close()
}
func (c *tcpRequestConn) Close() error {
var err error
c.closeOnce.Do(func() {
c.writeMu.Lock()
defer c.writeMu.Unlock()
if !c.replyWritten {
err = writeCommandError(c.Conn, 0x65, "Remote EOF")
if !c.reuse {
err = errors.Join(err, c.Conn.Close())
}
return
}
if c.reuse {
_, err = c.Conn.Write(nil)
return
}
err = c.Conn.Close()
})
return err
}
type serverPacketConn struct {
conn *Snell
writeMu *sync.Mutex
readBuf []byte
}
func (c *serverPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
if c.readBuf == nil {
c.readBuf = make([]byte, maxPacketLength)
}
for {
n, err := c.conn.Read(c.readBuf)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, ErrZeroChunk) {
return 0, nil, io.EOF
}
return 0, nil, err
}
request, err := ParseUDPRequest(c.readBuf[:n])
if err != nil {
return 0, nil, err
}
var destination M.Socksaddr
if request.Ip.IsValid() {
destination = M.SocksaddrFrom(request.Ip, request.Port)
} else {
destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
}
length := copy(p, request.Payload)
if destination.IsFqdn() {
return length, destination, nil
}
return length, destination.UDPAddr(), nil
}
}
func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
return WritePacketResponse(c.conn, addr, p)
}
func (c *serverPacketConn) Close() error { return c.conn.Close() }
func (c *serverPacketConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c *serverPacketConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c *serverPacketConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c *serverPacketConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }

View File

@@ -0,0 +1,211 @@
package snell
import (
"crypto/cipher"
"crypto/rand"
"errors"
"io"
"net"
"github.com/sagernet/sing/common/buf"
)
// payloadSizeMask is the maximum size of payload in bytes.
// 16*1024 - 1
// >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead()
// ErrZeroChunk is returned when a zero-length chunk is read, which snell uses
// as an end-of-stream signal.
var ErrZeroChunk = errors.New("zero chunk")
// Cipher is the AEAD cipher abstraction used by the shadowaead stream.
type Cipher interface {
KeySize() int
SaltSize() int
Encrypter(salt []byte) (cipher.AEAD, error)
Decrypter(salt []byte) (cipher.AEAD, error)
}
const (
payloadSizeMask = 0x3FFF
bufSize = 17 * 1024
)
type aeadWriter struct {
io.Writer
cipher.AEAD
nonce [32]byte // should be sufficient for most nonce sizes
}
// newAEADWriter wraps an io.Writer with authenticated encryption.
func newAEADWriter(w io.Writer, aead cipher.AEAD) *aeadWriter {
return &aeadWriter{Writer: w, AEAD: aead}
}
// Write encrypts p and writes to the embedded io.Writer.
func (w *aeadWriter) Write(p []byte) (n int, err error) {
b := buf.Get(bufSize)
defer buf.Put(b)
nonce := w.nonce[:w.NonceSize()]
tag := w.Overhead()
off := 2 + tag
if len(p) == 0 {
b = b[:off]
b[0], b[1] = byte(0), byte(0)
w.Seal(b[:0], nonce, b[:2], nil)
increment(nonce)
_, err = w.Writer.Write(b)
return
}
for nr := 0; n < len(p) && err == nil; n += nr {
nr = payloadSizeMask
if n+nr > len(p) {
nr = len(p) - n
}
b = b[:off+nr+tag]
b[0], b[1] = byte(nr>>8), byte(nr)
w.Seal(b[:0], nonce, b[:2], nil)
increment(nonce)
w.Seal(b[:off], nonce, p[n:n+nr], nil)
increment(nonce)
_, err = w.Writer.Write(b)
}
return
}
type aeadReader struct {
io.Reader
cipher.AEAD
nonce [32]byte // should be sufficient for most nonce sizes
buf []byte // to be put back into bufPool
off int // offset to unconsumed part of buf
}
// newAEADReader wraps an io.Reader with authenticated decryption.
func newAEADReader(r io.Reader, aead cipher.AEAD) *aeadReader {
return &aeadReader{Reader: r, AEAD: aead}
}
// read and decrypt a record into p. len(p) >= max payload size + AEAD overhead.
func (r *aeadReader) read(p []byte) (int, error) {
nonce := r.nonce[:r.NonceSize()]
tag := r.Overhead()
p = p[:2+tag]
if _, err := io.ReadFull(r.Reader, p); err != nil {
return 0, err
}
_, err := r.Open(p[:0], nonce, p, nil)
increment(nonce)
if err != nil {
return 0, err
}
size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask
if size == 0 {
return 0, ErrZeroChunk
}
p = p[:size+tag]
if _, err := io.ReadFull(r.Reader, p); err != nil {
return 0, err
}
_, err = r.Open(p[:0], nonce, p, nil)
increment(nonce)
if err != nil {
return 0, err
}
return size, nil
}
// Read reads from the embedded io.Reader, decrypts and writes to p.
func (r *aeadReader) Read(p []byte) (int, error) {
if r.buf == nil {
if len(p) >= payloadSizeMask+r.Overhead() {
return r.read(p)
}
b := buf.Get(bufSize)
n, err := r.read(b)
if err != nil {
buf.Put(b)
return 0, err
}
r.buf = b[:n]
r.off = 0
}
n := copy(p, r.buf[r.off:])
r.off += n
if r.off == len(r.buf) {
buf.Put(r.buf[:cap(r.buf)])
r.buf = nil
}
return n, nil
}
// increment little-endian encoded unsigned integer b. Wrap around on overflow.
func increment(b []byte) {
for i := range b {
b[i]++
if b[i] != 0 {
return
}
}
}
// aeadConn wraps a stream-oriented net.Conn with the shadowaead cipher.
type aeadConn struct {
net.Conn
Cipher
r *aeadReader
w *aeadWriter
}
// newAEADConn wraps a stream-oriented net.Conn with cipher.
func newAEADConn(c net.Conn, ciph Cipher) *aeadConn {
return &aeadConn{Conn: c, Cipher: ciph}
}
func (c *aeadConn) initReader() error {
salt := make([]byte, c.SaltSize())
if _, err := io.ReadFull(c.Conn, salt); err != nil {
return err
}
aead, err := c.Decrypter(salt)
if err != nil {
return err
}
c.r = newAEADReader(c.Conn, aead)
return nil
}
func (c *aeadConn) Read(b []byte) (int, error) {
if c.r == nil {
if err := c.initReader(); err != nil {
return 0, err
}
}
return c.r.Read(b)
}
func (c *aeadConn) initWriter() error {
salt := make([]byte, c.SaltSize())
if _, err := rand.Read(salt); err != nil {
return err
}
aead, err := c.Encrypter(salt)
if err != nil {
return err
}
_, err = c.Conn.Write(salt)
if err != nil {
return err
}
c.w = newAEADWriter(c.Conn, aead)
return nil
}
func (c *aeadConn) Write(b []byte) (int, error) {
if c.w == nil {
if err := c.initWriter(); err != nil {
return 0, err
}
}
return c.w.Write(b)
}

408
transport/snell/snell.go Normal file
View File

@@ -0,0 +1,408 @@
package snell
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"github.com/sagernet/sing/common/buf"
)
const (
Version1 = 1
Version2 = 2
Version3 = 3
Version4 = 4
Version5 = 5
DefaultSnellVersion = Version1
// max packet length
maxLength = 0x3FFF
)
const (
CommandPing byte = 0
CommandConnect byte = 1
CommandConnectV2 byte = 5
CommandUDP byte = 6
CommandUDPForward byte = 1
CommandTunnel byte = 0
CommandPong byte = 1
CommandError byte = 2
Version byte = 1
)
// Snell wraps an encrypted stream and handles the snell reply header.
type Snell struct {
net.Conn
buffer [1]byte
reply bool
}
func (s *Snell) Read(b []byte) (int, error) {
if err := s.ReadReply(); err != nil {
return 0, err
}
return s.Conn.Read(b)
}
func (s *Snell) ReadReply() error {
if s.reply {
return nil
}
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
return err
}
s.reply = true
if s.buffer[0] == CommandTunnel {
return nil
} else if s.buffer[0] != CommandError {
return errors.New("command not support")
}
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
return err
}
errcode := int(s.buffer[0])
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
return err
}
length := int(s.buffer[0])
msg := make([]byte, length)
if _, err := io.ReadFull(s.Conn, msg); err != nil {
return err
}
return fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
}
func WriteHeader(conn net.Conn, host string, port uint, version int) error {
return WriteHeaderWithReuse(conn, host, port, version, false)
}
func WriteHeaderWithReuse(conn net.Conn, host string, port uint, version int, reuse bool) error {
buffer := &bytes.Buffer{}
buffer.WriteByte(Version)
if version == Version2 || reuse {
buffer.WriteByte(CommandConnectV2)
} else {
buffer.WriteByte(CommandConnect)
}
buffer.WriteByte(0)
buffer.WriteByte(uint8(len(host)))
buffer.WriteString(host)
binary.Write(buffer, binary.BigEndian, uint16(port))
if _, err := conn.Write(buffer.Bytes()); err != nil {
return err
}
return nil
}
func WriteUDPHeader(conn net.Conn, version int) error {
if version < Version3 {
return errors.New("unsupport UDP version")
}
_, err := conn.Write([]byte{Version, CommandUDP, 0x00})
return err
}
// HalfClose only works after the request negotiated the reuse command.
func HalfClose(conn net.Conn) error {
if err := writeZeroChunk(conn); err != nil {
return err
}
if s, ok := conn.(*Snell); ok {
s.reply = false
}
return nil
}
// StreamConn wraps a raw connection with the snell stream cipher for the given version.
func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
if version >= Version4 {
return &Snell{Conn: newV4Conn(conn, psk)}
}
var cipher Cipher
if version != Version1 {
cipher = NewAES128GCM(psk)
} else {
cipher = NewChacha20Poly1305(psk)
}
return &Snell{Conn: newAEADConn(conn, cipher)}
}
// ServerStreamConn wraps a raw connection on the server side.
func ServerStreamConn(conn net.Conn, psk []byte, version int) *Snell {
stream := StreamConn(conn, psk, version)
stream.reply = true
return stream
}
func PacketConn(conn net.Conn) net.PacketConn {
return &packetConn{
Conn: conn,
}
}
func (s *Snell) WritePacketFrame(b []byte) (int, error) {
if fw, ok := s.Conn.(packetFrameWriter); ok {
return fw.WritePacketFrame(b)
}
return s.Conn.Write(b)
}
func WritePacket(w io.Writer, target, payload []byte) (int, error) {
maxPayloadLength := maxLength - udpRequestHeaderLength(target)
if maxPayloadLength <= 0 {
return 0, errors.New("snell UDP address too large")
}
if len(payload) <= maxPayloadLength {
return writePacket(w, target, payload)
}
return 0, errors.New("snell UDP payload too large")
}
func WritePacketResponse(w io.Writer, addr net.Addr, payload []byte) (int, error) {
buffer := &bytes.Buffer{}
target := parseAddrToSocksAddr(addr)
if len(target) == 0 {
return 0, errors.New("snell UDP response address invalid")
}
switch target[0] {
case atypIPv4:
if len(target) < 1+net.IPv4len+2 {
return 0, errors.New("snell UDP response address invalid")
}
buffer.WriteByte(0x04)
buffer.Write(target[1 : 1+net.IPv4len+2])
case atypIPv6:
if len(target) < 1+net.IPv6len+2 {
return 0, errors.New("snell UDP response address invalid")
}
buffer.WriteByte(0x06)
buffer.Write(target[1 : 1+net.IPv6len+2])
default:
return 0, errors.New("snell UDP response address invalid")
}
buffer.Write(payload)
var err error
if fw, ok := w.(packetFrameWriter); ok {
_, err = fw.WritePacketFrame(buffer.Bytes())
} else {
_, err = w.Write(buffer.Bytes())
}
if err != nil {
return 0, err
}
return len(payload), nil
}
// UDPRequest is a parsed snell UDP forward request.
type UDPRequest struct {
Host string
Ip netip.Addr
Port uint16
Payload []byte
}
func ParseUDPRequest(packet []byte) (UDPRequest, error) {
if len(packet) < 2 || packet[0] != CommandUDPForward {
return UDPRequest{}, errors.New("snell invalid UDP request")
}
if hostLen := int(packet[1]); hostLen != 0 {
if len(packet) <= 2+hostLen+2 {
return UDPRequest{}, errors.New("snell invalid UDP domain request")
}
offset := 2 + hostLen
return UDPRequest{
Host: string(packet[2:offset]),
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
Payload: packet[offset+2:],
}, nil
}
if len(packet) < 3 {
return UDPRequest{}, errors.New("snell invalid UDP IP request")
}
switch packet[2] {
case 0x04:
if len(packet) < 3+net.IPv4len+2 {
return UDPRequest{}, errors.New("snell invalid UDP IPv4 request")
}
offset := 3 + net.IPv4len
ip, _ := netip.AddrFromSlice(packet[3:offset])
return UDPRequest{
Ip: ip.Unmap(),
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
Payload: packet[offset+2:],
}, nil
case 0x06:
if len(packet) < 3+net.IPv6len+2 {
return UDPRequest{}, errors.New("snell invalid UDP IPv6 request")
}
offset := 3 + net.IPv6len
ip, _ := netip.AddrFromSlice(packet[3:offset])
return UDPRequest{
Ip: ip.Unmap(),
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
Payload: packet[offset+2:],
}, nil
default:
return UDPRequest{}, errors.New("snell invalid UDP address type")
}
}
func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) {
b := buf.Get(buf.UDPBufferSize)
defer buf.Put(b)
n, err := r.Read(b)
headLen := 1
if err != nil {
return nil, 0, err
}
if n < headLen {
return nil, 0, errors.New("insufficient UDP length")
}
switch b[0] {
case 0x04:
headLen += net.IPv4len + 2
if n < headLen {
err = errors.New("insufficient UDP length")
break
}
b[0] = atypIPv4
case 0x06:
headLen += net.IPv6len + 2
if n < headLen {
err = errors.New("insufficient UDP length")
break
}
b[0] = atypIPv6
default:
err = errors.New("ip version invalid")
}
if err != nil {
return nil, 0, err
}
addr := splitSocksAddr(b[0:])
if addr == nil {
return nil, 0, errors.New("remote address invalid")
}
uAddr := addr.UDPAddr()
if uAddr == nil {
return nil, 0, errors.New("parse addr error")
}
length := len(payload)
if n-headLen < length {
length = n - headLen
}
copy(payload[:], b[headLen:headLen+length])
return uAddr, length, nil
}
var endSignal = []byte{}
type packetFrameWriter interface {
WritePacketFrame([]byte) (int, error)
}
func writeZeroChunk(conn net.Conn) error {
if _, err := conn.Write(endSignal); err != nil {
return err
}
return nil
}
func writePacket(w io.Writer, target, payload []byte) (int, error) {
buffer := &bytes.Buffer{}
buffer.WriteByte(CommandUDPForward)
switch target[0] {
case atypDomainName:
hostLen := target[1]
if len(target) < 1+1+int(hostLen)+2 {
return 0, errors.New("snell UDP address invalid")
}
buffer.Write(target[1 : 1+1+hostLen+2])
case atypIPv4:
if len(target) < 1+net.IPv4len+2 {
return 0, errors.New("snell UDP address invalid")
}
buffer.Write([]byte{0x00, 0x04})
buffer.Write(target[1 : 1+net.IPv4len+2])
case atypIPv6:
if len(target) < 1+net.IPv6len+2 {
return 0, errors.New("snell UDP address invalid")
}
buffer.Write([]byte{0x00, 0x06})
buffer.Write(target[1 : 1+net.IPv6len+2])
default:
return 0, errors.New("snell UDP address invalid")
}
buffer.Write(payload)
if fw, ok := w.(packetFrameWriter); ok {
_, err := fw.WritePacketFrame(buffer.Bytes())
if err != nil {
return 0, err
}
return len(payload), nil
}
_, err := w.Write(buffer.Bytes())
if err != nil {
return 0, err
}
return len(payload), nil
}
func udpRequestHeaderLength(target []byte) int {
if len(target) == 0 {
return maxLength + 1
}
switch target[0] {
case atypDomainName:
if len(target) < 2 {
return maxLength + 1
}
return 1 + 1 + int(target[1]) + 2
case atypIPv4:
return 1 + 2 + net.IPv4len + 2
case atypIPv6:
return 1 + 2 + net.IPv6len + 2
default:
return maxLength + 1
}
}
type packetConn struct {
net.Conn
rMux sync.Mutex
wMux sync.Mutex
}
func (pc *packetConn) WritePacketFrame(b []byte) (int, error) {
if s, ok := pc.Conn.(*Snell); ok {
if fw, ok := s.Conn.(packetFrameWriter); ok {
return fw.WritePacketFrame(b)
}
}
return pc.Conn.Write(b)
}
func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
pc.wMux.Lock()
defer pc.wMux.Unlock()
return WritePacket(pc, parseAddrToSocksAddr(addr), b)
}
func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
pc.rMux.Lock()
defer pc.rMux.Unlock()
addr, n, err := ReadPacket(pc.Conn, b)
if err != nil {
return 0, nil, err
}
return n, addr, nil
}

463
transport/snell/v4.go Normal file
View File

@@ -0,0 +1,463 @@
package snell
import (
"crypto/cipher"
cryptorand "crypto/rand"
"encoding/binary"
"errors"
"io"
"math"
"math/big"
"math/bits"
"net"
"sync"
"time"
)
const (
v4SaltSize = 16
v4NonceSize = 12
v4HeaderPlainSize = 7
v4HeaderCipherSize = v4HeaderPlainSize + 16
v4FrameSize = 1460
v4InitialPaddingMin = 0x100
v4InitialPaddingSpan = 0x100
)
type v4Conn struct {
net.Conn
psk []byte
r *v4Reader
w *v4Writer
}
func newV4Conn(conn net.Conn, psk []byte) *v4Conn {
return &v4Conn{Conn: conn, psk: psk}
}
func (c *v4Conn) initReader() error {
salt := make([]byte, v4SaltSize)
if _, err := io.ReadFull(c.Conn, salt); err != nil {
return err
}
aead, err := v4AEAD(c.psk, salt)
if err != nil {
return err
}
c.r = &v4Reader{Reader: c.Conn, aead: aead}
return nil
}
func (c *v4Conn) initWriter() error {
w, err := newV4Writer(c.Conn, c.psk)
if err != nil {
return err
}
c.w = w
return nil
}
func (c *v4Conn) Read(b []byte) (int, error) {
if c.r == nil {
if err := c.initReader(); err != nil {
return 0, err
}
}
return c.r.Read(b)
}
func (c *v4Conn) Write(b []byte) (int, error) {
if c.w == nil {
if err := c.initWriter(); err != nil {
return 0, err
}
}
return c.w.Write(b)
}
func (c *v4Conn) WritePacketFrame(b []byte) (int, error) {
if len(b) > maxLength {
return 0, errors.New("snell v4 frame too large")
}
if c.w == nil {
if err := c.initWriter(); err != nil {
return 0, err
}
}
c.w.mux.Lock()
defer c.w.mux.Unlock()
if err := c.w.writeFrame(b, c.w.nextFramePaddingLength(len(b))); err != nil {
return 0, err
}
return len(b), nil
}
func (c *v4Conn) WriteTo(w io.Writer) (int64, error) {
if c.r == nil {
if err := c.initReader(); err != nil {
return 0, err
}
}
var written int64
buf := make([]byte, maxLength)
for {
n, err := c.r.Read(buf)
if n > 0 {
nw, ew := w.Write(buf[:n])
written += int64(nw)
if ew != nil {
return written, ew
}
if nw != n {
return written, io.ErrShortWrite
}
}
if err != nil {
if err == io.EOF {
err = nil
}
return written, err
}
}
}
func (c *v4Conn) ReadFrom(r io.Reader) (int64, error) {
if c.w == nil {
if err := c.initWriter(); err != nil {
return 0, err
}
}
var read int64
buf := make([]byte, maxLength)
for {
n, err := r.Read(buf)
if n > 0 {
read += int64(n)
if _, ew := c.w.Write(buf[:n]); ew != nil {
return read, ew
}
}
if err != nil {
if err == io.EOF {
err = nil
}
return read, err
}
}
}
func v4AEAD(psk, salt []byte) (cipher.AEAD, error) {
return aesGCM(snellKDF(psk, salt, 16))
}
type v4Reader struct {
io.Reader
aead cipher.AEAD
nonce [v4NonceSize]byte
buf []byte
mux sync.Mutex
}
func (r *v4Reader) Read(b []byte) (int, error) {
r.mux.Lock()
defer r.mux.Unlock()
if len(r.buf) == 0 {
payload, err := r.readFrame()
if err != nil {
return 0, err
}
r.buf = payload
}
n := copy(b, r.buf)
r.buf = r.buf[n:]
return n, nil
}
func (r *v4Reader) readFrame() ([]byte, error) {
headerCipher := make([]byte, v4HeaderCipherSize)
if _, err := io.ReadFull(r.Reader, headerCipher); err != nil {
return nil, err
}
header, err := r.aead.Open(headerCipher[:0], r.nonce[:], headerCipher, nil)
incrementV4Nonce(r.nonce[:])
if err != nil {
return nil, err
}
if len(header) != v4HeaderPlainSize || header[0] != 4 {
return nil, errors.New("snell v4 invalid frame header")
}
paddingLength := int(binary.BigEndian.Uint16(header[3:5]))
payloadLength := int(binary.BigEndian.Uint16(header[5:7]))
if payloadLength == 0 {
if paddingLength != 0 {
return nil, errors.New("snell v4 zero chunk with padding")
}
return nil, ErrZeroChunk
}
if payloadLength > maxLength || paddingLength > maxLength {
return nil, errors.New("snell v4 frame too large")
}
payloadCipherLength := payloadLength + r.aead.Overhead()
frame := make([]byte, paddingLength+payloadCipherLength)
if _, err := io.ReadFull(r.Reader, frame); err != nil {
return nil, err
}
if paddingLength > 0 {
swapPadding(frame[:paddingLength], frame[paddingLength:])
}
payloadCipher := frame[paddingLength:]
payload, err := r.aead.Open(payloadCipher[:0], r.nonce[:], payloadCipher, nil)
incrementV4Nonce(r.nonce[:])
if err != nil {
return nil, err
}
return payload, nil
}
type v4Writer struct {
io.Writer
aead cipher.AEAD
nonce [v4NonceSize]byte
salt [v4SaltSize]byte
saltSent bool
initialPaddingLength uint16
payloadLimit uint16
lastWrite time.Time
mux sync.Mutex
}
func newV4Writer(w io.Writer, psk []byte) (*v4Writer, error) {
var salt [v4SaltSize]byte
if _, err := io.ReadFull(cryptorand.Reader, salt[:]); err != nil {
return nil, err
}
aead, err := v4AEAD(psk, salt[:])
if err != nil {
return nil, err
}
paddingDelta, err := cryptoRandomInt(v4InitialPaddingSpan)
if err != nil {
return nil, err
}
return &v4Writer{
Writer: w,
aead: aead,
salt: salt,
initialPaddingLength: uint16(v4InitialPaddingMin + paddingDelta),
}, nil
}
func (w *v4Writer) Write(b []byte) (int, error) {
w.mux.Lock()
defer w.mux.Unlock()
if len(b) == 0 {
return 0, w.writeFrame(nil, 0)
}
written := 0
for written < len(b) {
payloadLimit := int(w.nextPayloadLimit())
if payloadLimit <= 0 || payloadLimit > maxLength {
payloadLimit = maxLength
}
end := written + payloadLimit
if end > len(b) {
end = len(b)
}
paddingLength := w.nextFramePaddingLength(end - written)
if err := w.writeFrame(b[written:end], paddingLength); err != nil {
return written, err
}
written = end
}
return written, nil
}
func (w *v4Writer) nextPayloadLimit() uint16 {
now := time.Now()
var payloadLimit uint16
switch {
case w.lastWrite.IsZero():
payloadLimit = v4FrameSize - 55 - w.initialPaddingLength
case now.Sub(w.lastWrite) > 30*time.Second:
payloadLimit = v4FrameSize - 39
default:
payloadLimit = w.payloadLimit
}
w.lastWrite = now
if payloadLimit <= maxLength-1 {
next := int(payloadLimit) + v4FrameSize - 39
if next > maxLength {
next = maxLength
}
w.payloadLimit = uint16(next)
} else {
w.payloadLimit = maxLength
}
return payloadLimit
}
func (w *v4Writer) nextFramePaddingLength(payloadLength int) int {
if w.saltSent || payloadLength == 0 {
return 0
}
return int(w.initialPaddingLength)
}
func (w *v4Writer) writeFrame(payload []byte, paddingLength int) error {
if len(payload) > maxLength || paddingLength > maxLength {
return errors.New("snell v4 frame too large")
}
if len(payload) == 0 && paddingLength != 0 {
return errors.New("snell v4 zero chunk with padding")
}
header := make([]byte, v4HeaderPlainSize)
header[0] = 4
binary.BigEndian.PutUint16(header[3:5], uint16(paddingLength))
binary.BigEndian.PutUint16(header[5:7], uint16(len(payload)))
headerCipher := w.aead.Seal(nil, w.nonce[:], header, nil)
incrementV4Nonce(w.nonce[:])
var payloadCipher []byte
if len(payload) > 0 {
payloadCipher = w.aead.Seal(nil, w.nonce[:], payload, nil)
incrementV4Nonce(w.nonce[:])
}
frameLength := len(headerCipher) + paddingLength + len(payloadCipher)
if !w.saltSent {
frameLength += v4SaltSize
}
frame := make([]byte, 0, frameLength)
if !w.saltSent {
frame = append(frame, w.salt[:]...)
w.saltSent = true
}
frame = append(frame, headerCipher...)
if paddingLength > 0 {
padding, err := makeV4Padding(payloadCipher, paddingLength)
if err != nil {
return err
}
swapPadding(padding, payloadCipher)
frame = append(frame, padding...)
}
frame = append(frame, payloadCipher...)
return writeFull(w.Writer, frame)
}
func swapPadding(padding, payloadCipher []byte) {
limit := len(padding)
if len(payloadCipher) < limit {
limit = len(payloadCipher)
}
for i := 0; i < limit; i += 2 {
padding[i], payloadCipher[i] = payloadCipher[i], padding[i]
}
}
func makeV4Padding(payloadCipher []byte, paddingLength int) ([]byte, error) {
if paddingLength <= 0 {
return nil, nil
}
payloadOnes := countV4PayloadOnes(payloadCipher)
payloadZeros := 8*len(payloadCipher) - payloadOnes
if payloadZeros <= 0 {
return makeV4RandomPadding(paddingLength)
}
ratio := float64(payloadOnes) / float64(payloadZeros)
if ratio <= 0.5 || ratio >= 1.6 {
return makeV4RandomPadding(paddingLength)
}
targetRatioBase := 1.6
if payloadZeros < payloadOnes {
targetRatioBase = 0.4
}
jitter, err := randomUnitFloat64()
if err != nil {
return nil, err
}
targetRatio := targetRatioBase + jitter/10
totalBits := 8 * (paddingLength + len(payloadCipher))
targetOnes := int(float64(totalBits)*(targetRatio/(targetRatio+1)) - float64(payloadOnes))
if targetOnes < 0 || targetOnes > 8*paddingLength {
return makeV4RandomPadding(paddingLength)
}
return makeV4BitCountPadding(paddingLength, targetOnes)
}
func countV4PayloadOnes(payloadCipher []byte) int {
limit := len(payloadCipher) &^ 3
ones := 0
for _, b := range payloadCipher[:limit] {
ones += bits.OnesCount8(b)
}
return ones
}
func makeV4RandomPadding(length int) ([]byte, error) {
padding := make([]byte, length)
_, err := io.ReadFull(cryptorand.Reader, padding)
return padding, err
}
func makeV4BitCountPadding(length, oneBits int) ([]byte, error) {
totalBits := 8 * length
if oneBits < 0 || oneBits > totalBits {
return nil, errors.New("snell v4 invalid padding bit count")
}
bitset := make([]byte, totalBits)
for i := 0; i < oneBits; i++ {
bitset[i] = 1
}
for i := totalBits - 1; i > 0; i-- {
j, err := cryptoRandomInt(i + 1)
if err != nil {
return nil, err
}
bitset[i], bitset[j] = bitset[j], bitset[i]
}
padding := make([]byte, length)
for i, bit := range bitset {
if bit == 1 {
padding[i/8] |= 1 << uint(i%8)
}
}
return padding, nil
}
func cryptoRandomInt(max int) (int, error) {
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(int64(max)))
if err != nil {
return 0, err
}
return int(n.Int64()), nil
}
func randomUnitFloat64() (float64, error) {
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(1<<53))
if err != nil {
return 0, err
}
return float64(n.Int64()) / math.Exp2(53), nil
}
func writeFull(w io.Writer, p []byte) error {
for len(p) > 0 {
n, err := w.Write(p)
if err != nil {
return err
}
if n == 0 {
return io.ErrShortWrite
}
p = p[n:]
}
return nil
}
func incrementV4Nonce(nonce []byte) {
for i := range nonce {
nonce[i]++
if nonce[i] != 0 {
return
}
}
}

View File

@@ -18,9 +18,12 @@ const (
) )
const ( const (
headerSize = 1 + 4 + 4 headerSize = 1 + 4 + 4
maxFrameSize = 256 * 1024 // maxQueuedBytesPerStream bounds unread payload retained by a single logical stream.
maxDataPayload = 32 * 1024 // Backpressure is applied to the demux loop instead of dropping data.
maxQueuedBytesPerStream = 4 * 1024 * 1024
maxFrameSize = 256 * 1024
maxDataPayload = 128 * 1024
) )
type acceptEvent struct { type acceptEvent struct {
@@ -344,6 +347,8 @@ type stream struct {
closeErr error closeErr error
readBuf []byte readBuf []byte
queue [][]byte queue [][]byte
// queuedBytes includes unread bytes in readBuf and queue.
queuedBytes int
localAddr net.Addr localAddr net.Addr
remoteAddr net.Addr remoteAddr net.Addr
@@ -362,16 +367,20 @@ func newStream(session *Session, id uint32) *stream {
func (c *stream) enqueue(payload []byte) { func (c *stream) enqueue(payload []byte) {
c.mu.Lock() c.mu.Lock()
for !c.closed && c.queuedBytes+len(payload) > maxQueuedBytesPerStream {
c.cond.Wait()
}
if c.closed { if c.closed {
c.mu.Unlock() c.mu.Unlock()
return return
} }
c.queuedBytes += len(payload)
if len(c.readBuf) == 0 && len(c.queue) == 0 { if len(c.readBuf) == 0 && len(c.queue) == 0 {
c.readBuf = payload c.readBuf = payload
} else { } else {
c.queue = append(c.queue, payload) c.queue = append(c.queue, payload)
} }
c.cond.Signal() c.cond.Broadcast()
c.mu.Unlock() c.mu.Unlock()
} }
@@ -413,7 +422,11 @@ func (c *stream) Read(p []byte) (int, error) {
} }
if len(c.readBuf) == 0 && len(c.queue) > 0 { if len(c.readBuf) == 0 && len(c.queue) > 0 {
c.readBuf = c.queue[0] c.readBuf = c.queue[0]
c.queue[0] = nil
c.queue = c.queue[1:] c.queue = c.queue[1:]
if len(c.queue) == 0 {
c.queue = nil
}
} }
if len(c.readBuf) == 0 && c.closed { if len(c.readBuf) == 0 && c.closed {
if c.closeErr == nil { if c.closeErr == nil {
@@ -424,6 +437,14 @@ func (c *stream) Read(p []byte) (int, error) {
n := copy(p, c.readBuf) n := copy(p, c.readBuf)
c.readBuf = c.readBuf[n:] c.readBuf = c.readBuf[n:]
if len(c.readBuf) == 0 {
c.readBuf = nil
}
c.queuedBytes -= n
if c.queuedBytes < 0 {
c.queuedBytes = 0
}
c.cond.Broadcast()
return n, nil return n, nil
} }

View File

@@ -0,0 +1,91 @@
package multiplex
import (
"bytes"
"crypto/rand"
"io"
"net"
"sync"
"testing"
"time"
)
// TestSession_LargeTransferBackpressure verifies that a transfer larger than
// maxQueuedBytesPerStream completes correctly: the demux loop applies
// backpressure (cond.Wait) instead of dropping data, and the reader draining
// the stream wakes the blocked loop without deadlock.
func TestSession_LargeTransferBackpressure(t *testing.T) {
c1, c2 := net.Pipe()
client, err := NewClientSession(c1)
if err != nil {
t.Fatalf("client session: %v", err)
}
server, err := NewServerSession(c2)
if err != nil {
t.Fatalf("server session: %v", err)
}
defer client.Close()
defer server.Close()
// Payload bigger than the per-stream backpressure window (4MB).
const total = 12 * 1024 * 1024
payload := make([]byte, total)
if _, err := rand.Read(payload); err != nil {
t.Fatalf("rand: %v", err)
}
var wg sync.WaitGroup
wg.Add(2)
var writeErr error
go func() {
defer wg.Done()
stream, err := client.OpenStream([]byte("hello"))
if err != nil {
writeErr = err
return
}
defer stream.Close()
if _, err := stream.Write(payload); err != nil {
writeErr = err
return
}
_ = stream.(interface{ CloseWrite() error }).CloseWrite()
}()
var got []byte
var readErr error
go func() {
defer wg.Done()
stream, openPayload, err := server.AcceptStream()
if err != nil {
readErr = err
return
}
if string(openPayload) != "hello" {
readErr = io.ErrUnexpectedEOF
return
}
got, readErr = io.ReadAll(stream)
}()
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
select {
case <-done:
case <-time.After(30 * time.Second):
t.Fatal("transfer deadlocked (backpressure did not release)")
}
if writeErr != nil {
t.Fatalf("write: %v", writeErr)
}
if readErr != nil {
t.Fatalf("read: %v", readErr)
}
if !bytes.Equal(got, payload) {
t.Fatalf("payload mismatch: got %d bytes, want %d", len(got), len(payload))
}
}

View File

@@ -0,0 +1,56 @@
package sudoku
import "testing"
func TestNormalizeASCIIMode(t *testing.T) {
tests := []struct {
in string
want string
}{
{"", "prefer_entropy"},
{"entropy", "prefer_entropy"},
{"prefer_ascii", "prefer_ascii"},
{"up_ascii_down_entropy", "up_ascii_down_entropy"},
{"up_entropy_down_ascii", "up_entropy_down_ascii"},
{"up_prefer_ascii_down_prefer_entropy", "up_ascii_down_entropy"},
}
for _, tt := range tests {
got, err := NormalizeASCIIMode(tt.in)
if err != nil {
t.Fatalf("NormalizeASCIIMode(%q): %v", tt.in, err)
}
if got != tt.want {
t.Fatalf("NormalizeASCIIMode(%q) = %q, want %q", tt.in, got, tt.want)
}
}
if _, err := NormalizeASCIIMode("up_ascii_down_binary"); err == nil {
t.Fatalf("expected invalid directional mode to fail")
}
}
func TestNewTableWithCustomDirectionalOpposite(t *testing.T) {
table, err := NewTableWithCustom("seed", "up_ascii_down_entropy", "xpxvvpvv")
if err != nil {
t.Fatalf("NewTableWithCustom: %v", err)
}
if !table.IsASCII {
t.Fatalf("uplink table should be ascii")
}
opposite := table.OppositeDirection()
if opposite == nil || opposite == table {
t.Fatalf("expected distinct opposite table")
}
if opposite.IsASCII {
t.Fatalf("downlink table should be entropy/custom")
}
symmetric, err := NewTableWithCustom("seed", "prefer_ascii", "xpxvvpvv")
if err != nil {
t.Fatalf("NewTableWithCustom symmetric: %v", err)
}
if symmetric.OppositeDirection() != symmetric {
t.Fatalf("symmetric table should point to itself")
}
}

View File

@@ -3,6 +3,7 @@ package sudoku
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"io"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -10,6 +11,8 @@ import (
const IOBufferSize = 32 * 1024 const IOBufferSize = 32 * 1024
const minDecodeReadSize = 64
var perm4 = [24][4]byte{ var perm4 = [24][4]byte{
{0, 1, 2, 3}, {0, 1, 2, 3},
{0, 1, 3, 2}, {0, 1, 3, 2},
@@ -52,7 +55,7 @@ type Conn struct {
writeMu sync.Mutex writeMu sync.Mutex
writeBuf []byte writeBuf []byte
rng randomSource rng *sudokuRand
paddingThreshold uint64 paddingThreshold uint64
} }
@@ -97,6 +100,9 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
} }
func (sc *Conn) StopRecording() { func (sc *Conn) StopRecording() {
if sc == nil {
return
}
sc.recordLock.Lock() sc.recordLock.Lock()
sc.recording.Store(false) sc.recording.Store(false)
sc.recorder = nil sc.recorder = nil
@@ -115,6 +121,9 @@ func (sc *Conn) GetBufferedAndRecorded() []byte {
if sc.recorder != nil { if sc.recorder != nil {
recorded = sc.recorder.Bytes() recorded = sc.recorder.Bytes()
} }
if sc.reader == nil {
return recorded
}
buffered := sc.reader.Buffered() buffered := sc.reader.Buffered()
if buffered > 0 { if buffered > 0 {
@@ -131,6 +140,9 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
if len(p) == 0 { if len(p) == 0 {
return 0, nil return 0, nil
} }
if sc == nil || sc.Conn == nil || sc.table == nil || sc.table.layout == nil || sc.rng == nil {
return 0, io.ErrClosedPipe
}
sc.writeMu.Lock() sc.writeMu.Lock()
defer sc.writeMu.Unlock() defer sc.writeMu.Unlock()
@@ -140,16 +152,19 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
} }
func (sc *Conn) Read(p []byte) (n int, err error) { func (sc *Conn) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
if sc == nil || sc.Conn == nil || sc.reader == nil || len(sc.rawBuf) == 0 || sc.table == nil || sc.table.layout == nil {
return 0, io.ErrClosedPipe
}
if n, ok := drainPending(p, &sc.pendingData); ok { if n, ok := drainPending(p, &sc.pendingData); ok {
return n, nil return n, nil
} }
outN := 0
for { for {
if sc.pendingData.available() > 0 { nr, rErr := readRawLimited(sc.Conn, sc.reader, sc.rawBuf[:sudokuReadSize(len(p)-outN, len(sc.rawBuf))])
break
}
nr, rErr := sc.reader.Read(sc.rawBuf)
if nr > 0 { if nr > 0 {
chunk := sc.rawBuf[:nr] chunk := sc.rawBuf[:nr]
if sc.recording.Load() { if sc.recording.Load() {
@@ -160,34 +175,80 @@ func (sc *Conn) Read(p []byte) (n int, err error) {
sc.recordLock.Unlock() sc.recordLock.Unlock()
} }
layout := sc.table.layout table := sc.table
for _, b := range chunk { layout := table.layout
for i := 0; i < len(chunk); {
if sc.hintCount == 0 && outN < len(p) && i+3 < len(chunk) &&
layout.hintTable[chunk[i]] &&
layout.hintTable[chunk[i+1]] &&
layout.hintTable[chunk[i+2]] &&
layout.hintTable[chunk[i+3]] {
val, ok := table.DecodeMap[packHintBytes(chunk[i], chunk[i+1], chunk[i+2], chunk[i+3])]
if !ok {
return 0, ErrInvalidSudokuMapMiss
}
p[outN] = val
outN++
i += 4
continue
}
b := chunk[i]
i++
if !layout.hintTable[b] { if !layout.hintTable[b] {
continue continue
} }
sc.hintBuf[sc.hintCount] = b sc.hintBuf[sc.hintCount] = b
sc.hintCount++ sc.hintCount++
if sc.hintCount == len(sc.hintBuf) { if sc.hintCount != len(sc.hintBuf) {
key := packHintsToKey(sc.hintBuf) continue
val, ok := sc.table.DecodeMap[key]
if !ok {
return 0, ErrInvalidSudokuMapMiss
}
sc.pendingData.appendByte(val)
sc.hintCount = 0
} }
val, ok := table.DecodeMap[packHintBytes(sc.hintBuf[0], sc.hintBuf[1], sc.hintBuf[2], sc.hintBuf[3])]
if !ok {
return 0, ErrInvalidSudokuMapMiss
}
outN = appendDecodedByte(p, outN, &sc.pendingData, val)
sc.hintCount = 0
} }
} }
if rErr != nil { if rErr != nil {
if outN > 0 {
return outN, nil
}
if n, ok := drainPending(p, &sc.pendingData); ok {
return n, nil
}
return 0, rErr return 0, rErr
} }
if sc.pendingData.available() > 0 { if outN > 0 {
break return outN, nil
} }
} }
}
n, _ = drainPending(p, &sc.pendingData)
return n, nil func sudokuReadSize(decodedRemaining, maxRaw int) int {
if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 {
return maxRaw
}
if decodedRemaining > (maxRaw-minDecodeReadSize)/5 {
return maxRaw
}
return decodedRemaining*5 + minDecodeReadSize
}
func readRawLimited(conn net.Conn, reader *bufio.Reader, dst []byte) (int, error) {
if len(dst) == 0 {
return 0, nil
}
if reader != nil && reader.Buffered() > 0 {
return reader.Read(dst)
}
if conn == nil {
return 0, io.ErrClosedPipe
}
return conn.Read(dst)
} }

View File

@@ -0,0 +1,51 @@
package sudoku
import (
"bytes"
"io"
"testing"
)
// TestConn_Roundtrip exercises the optimized Conn encode/decode hot paths:
// the no-padding fast path (pMin==pMax==0), the always-padding path
// (pMin==pMax==100), a probabilistic range, and the adaptive read-size /
// 4-byte fast hint decode path across a variety of payload sizes and modes.
func TestConn_Roundtrip(t *testing.T) {
modes := []string{"prefer_entropy", "prefer_ascii"}
paddings := []struct{ min, max int }{
{0, 0}, // no-padding specialized path
{100, 100}, // always-padding specialized path
{20, 60}, // probabilistic path
}
sizes := []int{1, 3, 4, 7, 16, 100, 1000, 64 * 1024}
for _, mode := range modes {
for _, pad := range paddings {
for _, size := range sizes {
payload := make([]byte, size)
for i := range payload {
payload[i] = byte(i*31 + 7)
}
table := NewTable("conn-roundtrip-seed", mode)
// Encode via Conn.Write.
w := &mockConn{}
enc := NewConn(w, table, pad.min, pad.max, false)
if _, err := enc.Write(payload); err != nil {
t.Fatalf("mode=%s pad=%v size=%d write: %v", mode, pad, size, err)
}
// Decode via Conn.Read using the same table.
dec := NewConn(&mockConn{readBuf: w.writeBuf}, table, pad.min, pad.max, false)
got := make([]byte, size)
if _, err := io.ReadFull(dec, got); err != nil {
t.Fatalf("mode=%s pad=%v size=%d read: %v", mode, pad, size, err)
}
if !bytes.Equal(got, payload) {
t.Fatalf("mode=%s pad=%v size=%d roundtrip mismatch", mode, pad, size)
}
}
}
}
}

View File

@@ -1,9 +1,12 @@
package sudoku package sudoku
func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThreshold uint64, p []byte) []byte { func encodeSudokuPayload(dst []byte, table *Table, rng *sudokuRand, paddingThreshold uint64, p []byte) []byte {
if len(p) == 0 { if len(p) == 0 {
return dst[:0] return dst[:0]
} }
if paddingThreshold == 0 {
return encodeSudokuPayloadNoPadding(dst, table, rng, p)
}
outCapacity := len(p)*6 + 1 outCapacity := len(p)*6 + 1
if cap(dst) < outCapacity { if cap(dst) < outCapacity {
@@ -13,8 +16,25 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre
pads := table.PaddingPool pads := table.PaddingPool
padLen := len(pads) padLen := len(pads)
if paddingThreshold >= probOne {
for _, b := range p {
out = append(out, pads[rng.Intn(padLen)])
puzzles := table.EncodeTable[b]
puzzle := puzzles[rng.Intn(len(puzzles))]
perm := perm4[rng.Intn(len(perm4))]
for _, idx := range perm {
out = append(out, pads[rng.Intn(padLen)], puzzle[idx])
}
}
out = append(out, pads[rng.Intn(padLen)])
return out
}
for _, b := range p { for _, b := range p {
if shouldPad(rng, paddingThreshold) { if uint64(rng.Uint32()) < paddingThreshold {
out = append(out, pads[rng.Intn(padLen)]) out = append(out, pads[rng.Intn(padLen)])
} }
@@ -22,15 +42,31 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre
puzzle := puzzles[rng.Intn(len(puzzles))] puzzle := puzzles[rng.Intn(len(puzzles))]
perm := perm4[rng.Intn(len(perm4))] perm := perm4[rng.Intn(len(perm4))]
for _, idx := range perm { for _, idx := range perm {
if shouldPad(rng, paddingThreshold) { if uint64(rng.Uint32()) < paddingThreshold {
out = append(out, pads[rng.Intn(padLen)]) out = append(out, pads[rng.Intn(padLen)])
} }
out = append(out, puzzle[idx]) out = append(out, puzzle[idx])
} }
} }
if shouldPad(rng, paddingThreshold) { if uint64(rng.Uint32()) < paddingThreshold {
out = append(out, pads[rng.Intn(padLen)]) out = append(out, pads[rng.Intn(padLen)])
} }
return out return out
} }
func encodeSudokuPayloadNoPadding(dst []byte, table *Table, rng *sudokuRand, p []byte) []byte {
outCapacity := len(p) * 4
if cap(dst) < outCapacity {
dst = make([]byte, 0, outCapacity)
}
out := dst[:0]
for _, b := range p {
puzzles := table.EncodeTable[b]
puzzle := puzzles[rng.Intn(len(puzzles))]
perm := perm4[rng.Intn(len(perm4))]
out = append(out, puzzle[perm[0]], puzzle[perm[1]], puzzle[perm[2]], puzzle[perm[3]])
}
return out
}

View File

@@ -8,9 +8,9 @@ import (
) )
const ( const (
RngBatchSize = 128
packedProtectedPrefixBytes = 14 packedProtectedPrefixBytes = 14
packedIOBufferSize = 64 * 1024
packedDecodeBufferSize = 96 * 1024
) )
// PackedConn encodes traffic with the packed Sudoku layout while preserving // PackedConn encodes traffic with the packed Sudoku layout while preserving
@@ -35,7 +35,7 @@ type PackedConn struct {
readBits int readBits int
// Padding selection matches Conn's threshold-based model. // Padding selection matches Conn's threshold-based model.
rng randomSource rng *sudokuRand
paddingThreshold uint64 paddingThreshold uint64
padMarker byte padMarker byte
padPool []byte padPool []byte
@@ -67,18 +67,20 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
pc := &PackedConn{ pc := &PackedConn{
Conn: c, Conn: c,
table: table, table: table,
reader: bufio.NewReaderSize(c, IOBufferSize), reader: bufio.NewReaderSize(c, packedIOBufferSize),
rawBuf: make([]byte, IOBufferSize), rawBuf: make([]byte, packedDecodeBufferSize),
pendingData: newPendingBuffer(4096), pendingData: newPendingBuffer(4096),
writeBuf: make([]byte, 0, 4096), writeBuf: make([]byte, 0, 4096),
rng: localRng, rng: localRng,
paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax), paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
} }
pc.padMarker = table.layout.padMarker if table != nil && table.layout != nil {
for _, b := range table.PaddingPool { pc.padMarker = table.layout.padMarker
if b != pc.padMarker { for _, b := range table.PaddingPool {
pc.padPool = append(pc.padPool, b) if b != pc.padMarker {
pc.padPool = append(pc.padPool, b)
}
} }
} }
if len(pc.padPool) == 0 { if len(pc.padPool) == 0 {
@@ -87,18 +89,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
return pc return pc
} }
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
if shouldPad(pc.rng, pc.paddingThreshold) {
out = append(out, pc.getPaddingByte())
}
return out
}
func (pc *PackedConn) appendGroup(out []byte, group byte) []byte {
out = pc.maybeAddPadding(out)
return append(out, pc.table.layout.groupByte(group))
}
func (pc *PackedConn) appendForcedPadding(out []byte) []byte { func (pc *PackedConn) appendForcedPadding(out []byte) []byte {
return append(out, pc.getPaddingByte()) return append(out, pc.getPaddingByte())
} }
@@ -134,7 +124,7 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
} else { } else {
pc.bitBuf &= (1 << pc.bitCount) - 1 pc.bitBuf &= (1 << pc.bitCount) - 1
} }
out = pc.appendGroup(out, group&0x3F) out = appendPackedGroup(out, pc.table.layout, pc.rng, pc.paddingThreshold, pc.padPool, group)
} }
effective++ effective++
@@ -148,19 +138,49 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
return out, limit return out, limit
} }
func appendPackedGroup(out []byte, layout *byteLayout, rng *sudokuRand, paddingThreshold uint64, padPool []byte, group byte) []byte {
if paddingThreshold != 0 {
u := rng.Uint32()
if uint64(u) < paddingThreshold {
out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))])
}
}
return append(out, layout.encodeGroup[group&0x3F])
}
func maybeAppendPackedPadding(out []byte, rng *sudokuRand, paddingThreshold uint64, padPool []byte) []byte {
if paddingThreshold != 0 {
u := rng.Uint32()
if uint64(u) < paddingThreshold {
out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))])
}
}
return out
}
func (pc *PackedConn) Write(p []byte) (int, error) { func (pc *PackedConn) Write(p []byte) (int, error) {
if len(p) == 0 { if len(p) == 0 {
return 0, nil return 0, nil
} }
if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 {
return 0, io.ErrClosedPipe
}
pc.writeMu.Lock() pc.writeMu.Lock()
defer pc.writeMu.Unlock() defer pc.writeMu.Unlock()
needed := len(p)*3/2 + 32 needed := len(p)*3/2 + 32
if pc.paddingThreshold == 0 {
needed = ((len(p)+2)/3)*4 + 32
}
if cap(pc.writeBuf) < needed { if cap(pc.writeBuf) < needed {
pc.writeBuf = make([]byte, 0, needed) pc.writeBuf = make([]byte, 0, needed)
} }
out := pc.writeBuf[:0] out := pc.writeBuf[:0]
layout := pc.table.layout
rng := pc.rng
paddingThreshold := pc.paddingThreshold
padPool := pc.padPool
var prefixN int var prefixN int
out, prefixN = pc.writeProtectedPrefix(out, p) out, prefixN = pc.writeProtectedPrefix(out, p)
@@ -181,7 +201,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} else { } else {
pc.bitBuf &= (1 << pc.bitCount) - 1 pc.bitBuf &= (1 << pc.bitCount) - 1
} }
out = pc.appendGroup(out, group&0x3F) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
} }
} }
@@ -195,10 +215,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
g4 := b3 & 0x3F g4 := b3 & 0x3F
out = pc.appendGroup(out, g1) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1)
out = pc.appendGroup(out, g2) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2)
out = pc.appendGroup(out, g3) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3)
out = pc.appendGroup(out, g4) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4)
} }
} }
@@ -211,10 +231,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
g4 := b3 & 0x3F g4 := b3 & 0x3F
out = pc.appendGroup(out, g1) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1)
out = pc.appendGroup(out, g2) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2)
out = pc.appendGroup(out, g3) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3)
out = pc.appendGroup(out, g4) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4)
} }
for ; i < n; i++ { for ; i < n; i++ {
@@ -229,7 +249,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} else { } else {
pc.bitBuf &= (1 << pc.bitCount) - 1 pc.bitBuf &= (1 << pc.bitCount) - 1
} }
out = pc.appendGroup(out, group&0x3F) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
} }
} }
@@ -237,11 +257,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
group := byte(pc.bitBuf << (6 - pc.bitCount)) group := byte(pc.bitBuf << (6 - pc.bitCount))
pc.bitBuf = 0 pc.bitBuf = 0
pc.bitCount = 0 pc.bitCount = 0
out = pc.appendGroup(out, group&0x3F) out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
out = append(out, pc.padMarker) out = append(out, pc.padMarker)
} }
out = pc.maybeAddPadding(out) out = maybeAppendPackedPadding(out, rng, paddingThreshold, padPool)
if len(out) > 0 { if len(out) > 0 {
pc.writeBuf = out[:0] pc.writeBuf = out[:0]
@@ -252,6 +272,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} }
func (pc *PackedConn) Flush() error { func (pc *PackedConn) Flush() error {
if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 {
return io.ErrClosedPipe
}
pc.writeMu.Lock() pc.writeMu.Lock()
defer pc.writeMu.Unlock() defer pc.writeMu.Unlock()
@@ -265,7 +289,7 @@ func (pc *PackedConn) Flush() error {
out = append(out, pc.padMarker) out = append(out, pc.padMarker)
} }
out = pc.maybeAddPadding(out) out = maybeAppendPackedPadding(out, pc.rng, pc.paddingThreshold, pc.padPool)
if len(out) > 0 { if len(out) > 0 {
pc.writeBuf = out[:0] pc.writeBuf = out[:0]
@@ -289,19 +313,44 @@ func writeFull(w io.Writer, b []byte) error {
} }
func (pc *PackedConn) Read(p []byte) (int, error) { func (pc *PackedConn) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
if pc == nil || pc.Conn == nil || pc.reader == nil || len(pc.rawBuf) == 0 || pc.table == nil || pc.table.layout == nil {
return 0, io.ErrClosedPipe
}
if n, ok := drainPending(p, &pc.pendingData); ok { if n, ok := drainPending(p, &pc.pendingData); ok {
return n, nil return n, nil
} }
outN := 0
for { for {
nr, rErr := pc.reader.Read(pc.rawBuf) nr, rErr := readRawLimited(pc.Conn, pc.reader, pc.rawBuf[:packedReadSize(len(p)-outN, len(pc.rawBuf))])
if nr > 0 { if nr > 0 {
rBuf := pc.readBitBuf rBuf := pc.readBitBuf
rBits := pc.readBits rBits := pc.readBits
padMarker := pc.padMarker padMarker := pc.padMarker
layout := pc.table.layout layout := pc.table.layout
for _, b := range pc.rawBuf[:nr] { chunk := pc.rawBuf[:nr]
for i := 0; i < len(chunk); {
if rBits == 0 && outN+3 <= len(p) && i+3 < len(chunk) &&
layout.hintTable[chunk[i]] && layout.hintTable[chunk[i+1]] &&
layout.hintTable[chunk[i+2]] && layout.hintTable[chunk[i+3]] {
g1 := layout.decodeGroup[chunk[i]]
g2 := layout.decodeGroup[chunk[i+1]]
g3 := layout.decodeGroup[chunk[i+2]]
g4 := layout.decodeGroup[chunk[i+3]]
p[outN] = (g1 << 2) | (g2 >> 4)
p[outN+1] = (g2 << 4) | (g3 >> 2)
p[outN+2] = (g3 << 6) | g4
outN += 3
i += 4
continue
}
b := chunk[i]
i++
if !layout.hintTable[b] { if !layout.hintTable[b] {
if b == padMarker { if b == padMarker {
rBuf = 0 rBuf = 0
@@ -321,7 +370,7 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
if rBits >= 8 { if rBits >= 8 {
rBits -= 8 rBits -= 8
val := byte(rBuf >> rBits) val := byte(rBuf >> rBits)
pc.pendingData.appendByte(val) outN = appendDecodedByte(p, outN, &pc.pendingData, val)
if rBits == 0 { if rBits == 0 {
rBuf = 0 rBuf = 0
} else { } else {
@@ -339,21 +388,32 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
pc.readBitBuf = 0 pc.readBitBuf = 0
pc.readBits = 0 pc.readBits = 0
} }
if pc.pendingData.available() > 0 { if outN > 0 {
break return outN, nil
}
if n, ok := drainPending(p, &pc.pendingData); ok {
return n, nil
} }
return 0, rErr return 0, rErr
} }
if pc.pendingData.available() > 0 { if outN > 0 {
break return outN, nil
} }
} }
n, _ := drainPending(p, &pc.pendingData)
return n, nil
} }
func (pc *PackedConn) getPaddingByte() byte { func (pc *PackedConn) getPaddingByte() byte {
return pc.padPool[pc.rng.Intn(len(pc.padPool))] return pc.padPool[pc.rng.Intn(len(pc.padPool))]
} }
func packedReadSize(decodedRemaining, maxRaw int) int {
if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 {
return maxRaw
}
if decodedRemaining > (maxRaw-minDecodeReadSize)/2 {
return maxRaw
}
return decodedRemaining*2 + minDecodeReadSize
}

View File

@@ -0,0 +1,90 @@
package sudoku
import (
"bytes"
"io"
"net"
"testing"
"time"
)
type mockConn struct {
readBuf []byte
writeBuf []byte
}
func (c *mockConn) Read(p []byte) (int, error) {
if len(c.readBuf) == 0 {
return 0, io.EOF
}
n := copy(p, c.readBuf)
c.readBuf = c.readBuf[n:]
return n, nil
}
func (c *mockConn) Write(p []byte) (int, error) {
c.writeBuf = append(c.writeBuf, p...)
return len(p), nil
}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) LocalAddr() net.Addr { return nil }
func (c *mockConn) RemoteAddr() net.Addr { return nil }
func (c *mockConn) SetDeadline(time.Time) error { return nil }
func (c *mockConn) SetReadDeadline(time.Time) error { return nil }
func (c *mockConn) SetWriteDeadline(time.Time) error { return nil }
func TestPackedConn_ProtectedPrefixPadding(t *testing.T) {
table := NewTable("packed-prefix-seed", "prefer_ascii")
mock := &mockConn{}
writer := NewPackedConn(mock, table, 0, 0)
writer.rng = newSudokuRand(1)
payload := bytes.Repeat([]byte{0}, 32)
if _, err := writer.Write(payload); err != nil {
t.Fatalf("write: %v", err)
}
wire := append([]byte(nil), mock.writeBuf...)
if len(wire) < 20 {
t.Fatalf("wire too short: %d", len(wire))
}
firstHint := -1
nonHintCount := 0
maxHintRun := 0
currentHintRun := 0
for i, b := range wire[:20] {
if table.layout.isHint(b) {
if firstHint == -1 {
firstHint = i
}
currentHintRun++
if currentHintRun > maxHintRun {
maxHintRun = currentHintRun
}
continue
}
nonHintCount++
currentHintRun = 0
}
if firstHint < 1 || firstHint > 2 {
t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint)
}
if nonHintCount < 6 {
t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount)
}
if maxHintRun > 3 {
t.Fatalf("prefix still exposes long hint run: %d", maxHintRun)
}
reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0)
decoded := make([]byte, len(payload))
if _, err := io.ReadFull(reader, decoded); err != nil {
t.Fatalf("read back: %v", err)
}
if !bytes.Equal(decoded, payload) {
t.Fatalf("roundtrip mismatch")
}
}

View File

@@ -2,7 +2,7 @@ package sudoku
const probOne = uint64(1) << 32 const probOne = uint64(1) << 32
func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 { func pickPaddingThreshold(r *sudokuRand, pMin, pMax int) uint64 {
if r == nil { if r == nil {
return 0 return 0
} }
@@ -28,7 +28,7 @@ func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 {
return min + (u * (max - min) >> 32) return min + (u * (max - min) >> 32)
} }
func shouldPad(r randomSource, threshold uint64) bool { func shouldPad(r *sudokuRand, threshold uint64) bool {
if threshold == 0 { if threshold == 0 {
return false return false
} }

View File

@@ -25,7 +25,10 @@ func (p *pendingBuffer) reset() {
} }
func (p *pendingBuffer) ensureAppendCapacity(extra int) { func (p *pendingBuffer) ensureAppendCapacity(extra int) {
if p == nil || extra <= 0 || p.off == 0 { if p == nil || extra <= 0 {
return
}
if p.off == 0 {
return return
} }
if cap(p.data)-len(p.data) >= extra { if cap(p.data)-len(p.data) >= extra {
@@ -43,6 +46,15 @@ func (p *pendingBuffer) appendByte(b byte) {
p.data = append(p.data, b) p.data = append(p.data, b)
} }
func appendDecodedByte(dst []byte, n int, pending *pendingBuffer, b byte) int {
if n < len(dst) {
dst[n] = b
return n + 1
}
pending.appendByte(b)
return n
}
func drainPending(dst []byte, pending *pendingBuffer) (int, bool) { func drainPending(dst []byte, pending *pendingBuffer) (int, bool) {
if pending == nil || pending.available() == 0 { if pending == nil || pending.available() == 0 {
return 0, false return 0, false

View File

@@ -6,14 +6,10 @@ import (
"time" "time"
) )
type randomSource interface {
Uint32() uint32
Uint64() uint64
Intn(n int) int
}
type sudokuRand struct { type sudokuRand struct {
state uint64 state uint64
cached uint32
haveCached bool
} }
func newSeededRand() *sudokuRand { func newSeededRand() *sudokuRand {
@@ -37,20 +33,36 @@ func (r *sudokuRand) Uint64() uint64 {
if r == nil { if r == nil {
return 0 return 0
} }
r.state += 0x9e3779b97f4a7c15 r.haveCached = false
z := r.state x := r.state
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9 x ^= x >> 12
z = (z ^ (z >> 27)) * 0x94d049bb133111eb x ^= x << 25
return z ^ (z >> 31) x ^= x >> 27
r.state = x
return x * 0x2545f4914f6cdd1d
} }
func (r *sudokuRand) Uint32() uint32 { func (r *sudokuRand) Uint32() uint32 {
return uint32(r.Uint64() >> 32) if r == nil {
return 0
}
if r.haveCached {
r.haveCached = false
return r.cached
}
v := r.Uint64()
r.cached = uint32(v)
r.haveCached = true
return uint32(v >> 32)
} }
func (r *sudokuRand) Intn(n int) int { func (r *sudokuRand) Intn(n int) int {
if n <= 1 { if n <= 1 {
return 0 return 0
} }
return int((uint64(r.Uint32()) * uint64(n)) >> 32) return fastIntnFromUint32(r.Uint32(), n)
}
func fastIntnFromUint32(u uint32, n int) int {
return int((uint64(u) * uint64(n)) >> 32)
} }

View File

@@ -192,23 +192,27 @@ func tableHintFingerprint(key string, mode string, uplinkPattern string, downlin
} }
func packHintsToKey(hints [4]byte) uint32 { func packHintsToKey(hints [4]byte) uint32 {
return packHintBytes(hints[0], hints[1], hints[2], hints[3])
}
func packHintBytes(h0, h1, h2, h3 byte) uint32 {
// Sorting network for 4 elements (Bubble sort unrolled) // Sorting network for 4 elements (Bubble sort unrolled)
// Swap if a > b // Swap if a > b
if hints[0] > hints[1] { if h0 > h1 {
hints[0], hints[1] = hints[1], hints[0] h0, h1 = h1, h0
} }
if hints[2] > hints[3] { if h2 > h3 {
hints[2], hints[3] = hints[3], hints[2] h2, h3 = h3, h2
} }
if hints[0] > hints[2] { if h0 > h2 {
hints[0], hints[2] = hints[2], hints[0] h0, h2 = h2, h0
} }
if hints[1] > hints[3] { if h1 > h3 {
hints[1], hints[3] = hints[3], hints[1] h1, h3 = h3, h1
} }
if hints[1] > hints[2] { if h1 > h2 {
hints[1], hints[2] = hints[2], hints[1] h1, h2 = h2, h1
} }
return uint32(hints[0])<<24 | uint32(hints[1])<<16 | uint32(hints[2])<<8 | uint32(hints[3]) return uint32(h0)<<24 | uint32(h1)<<16 | uint32(h2)<<8 | uint32(h3)
} }

View File

@@ -14,12 +14,14 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/sagernet/sing-box/common/congestion"
"github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/quic-go" "github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3" "github.com/sagernet/quic-go/http3"
@@ -50,7 +52,7 @@ type ClientOptions struct {
QUIC bool QUIC bool
CongestionControl string CongestionControl string
CWND int CWND int
BBRProfile string Logger logger.Logger
HealthCheck bool HealthCheck bool
MaxConnections int MaxConnections int
MinStreams int MinStreams int
@@ -81,7 +83,7 @@ func NewClient(ctx context.Context, options ClientOptions) (*Client, error) {
healthCheck: options.HealthCheck, healthCheck: options.HealthCheck,
} }
if options.QUIC { if options.QUIC {
congestionControlFactory, err := NewCongestionControl(options.CongestionControl, options.CWND, options.BBRProfile, ntp.TimeFuncFromContext(ctx)) congestionControlFactory, err := congestion.NewCongestionControl(options.CongestionControl, options.CWND, ntp.TimeFuncFromContext(ctx))
if err != nil { if err != nil {
cancel() cancel()
return nil, err return nil, err

View File

@@ -2,6 +2,7 @@ package v2raygrpc
import ( import (
"context" "context"
"strings"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@@ -13,13 +14,21 @@ type GunService interface {
} }
func ServerDesc(name string) grpc.ServiceDesc { func ServerDesc(name string) grpc.ServiceDesc {
serviceName := name
streamName := "Tun"
if strings.Contains(name, "/") {
name = strings.TrimPrefix(name, "/")
lastSlash := strings.LastIndex(name, "/")
serviceName = name[:lastSlash]
streamName = name[lastSlash+1:]
}
return grpc.ServiceDesc{ return grpc.ServiceDesc{
ServiceName: name, ServiceName: serviceName,
HandlerType: (*GunServiceServer)(nil), HandlerType: (*GunServiceServer)(nil),
Methods: []grpc.MethodDesc{}, Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{ Streams: []grpc.StreamDesc{
{ {
StreamName: "Tun", StreamName: streamName,
Handler: _GunService_Tun_Handler, Handler: _GunService_Tun_Handler,
ServerStreams: true, ServerStreams: true,
ClientStreams: true, ClientStreams: true,
@@ -30,7 +39,11 @@ func ServerDesc(name string) grpc.ServiceDesc {
} }
func (c *gunServiceClient) TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GunService_TunClient, error) { func (c *gunServiceClient) TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GunService_TunClient, error) {
stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], "/"+name+"/Tun", opts...) path := "/" + name + "/Tun"
if strings.Contains(name, "/") {
path = name
}
stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], path, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -53,10 +53,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
DisableCompression: true, DisableCompression: true,
}, },
url: &url.URL{ url: &url.URL{
Scheme: "https", Scheme: "https",
Host: serverAddr.String(), Host: serverAddr.String(),
Path: "/" + options.ServiceName + "/Tun", Path: grpcPath(options.ServiceName),
RawPath: "/" + url.PathEscape(options.ServiceName) + "/Tun",
}, },
host: host, host: host,
} }

View File

@@ -0,0 +1,10 @@
package v2raygrpclite
import "strings"
func grpcPath(serviceName string) string {
if strings.Contains(serviceName, "/") {
return serviceName
}
return "/" + serviceName + "/Tun"
}

View File

@@ -42,7 +42,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
logger: logger, logger: logger,
handler: handler, handler: handler,
path: "/" + options.ServiceName + "/Tun", path: grpcPath(options.ServiceName),
h2Server: &http2.Server{ h2Server: &http2.Server{
IdleTimeout: time.Duration(options.IdleTimeout), IdleTimeout: time.Duration(options.IdleTimeout),
}, },

View File

@@ -1,14 +1,14 @@
package v2raykcp package v2raykcp
import ( import (
"container/list"
"sync" "sync"
"github.com/sagernet/sing-box/common/list"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
) )
type SendingWindow struct { type SendingWindow struct {
cache *list.List cache *list.List[*DataSegment]
totalInFlightSize uint32 totalInFlightSize uint32
writer SegmentWriter writer SegmentWriter
onPacketLoss func(uint32) onPacketLoss func(uint32)
@@ -16,7 +16,7 @@ type SendingWindow struct {
func NewSendingWindow(writer SegmentWriter, onPacketLoss func(uint32)) *SendingWindow { func NewSendingWindow(writer SegmentWriter, onPacketLoss func(uint32)) *SendingWindow {
return &SendingWindow{ return &SendingWindow{
cache: list.New(), cache: list.New[*DataSegment](),
writer: writer, writer: writer,
onPacketLoss: onPacketLoss, onPacketLoss: onPacketLoss,
} }
@@ -27,9 +27,9 @@ func (sw *SendingWindow) Release() {
return return
} }
for sw.cache.Len() > 0 { for sw.cache.Len() > 0 {
seg := sw.cache.Front().Value.(*DataSegment) seg := sw.cache.Front().Value
seg.Release() seg.Release()
sw.cache.Remove(sw.cache.Front()) sw.cache.Front().Remove()
} }
} }
@@ -50,17 +50,17 @@ func (sw *SendingWindow) Push(number uint32, b *buf.Buffer) {
} }
func (sw *SendingWindow) FirstNumber() uint32 { func (sw *SendingWindow) FirstNumber() uint32 {
return sw.cache.Front().Value.(*DataSegment).Number return sw.cache.Front().Value.Number
} }
func (sw *SendingWindow) Clear(una uint32) { func (sw *SendingWindow) Clear(una uint32) {
for !sw.IsEmpty() { for !sw.IsEmpty() {
seg := sw.cache.Front().Value.(*DataSegment) seg := sw.cache.Front().Value
if seg.Number >= una { if seg.Number >= una {
break break
} }
seg.Release() seg.Release()
sw.cache.Remove(sw.cache.Front()) sw.cache.Front().Remove()
} }
} }
@@ -87,8 +87,7 @@ func (sw *SendingWindow) Visit(visitor func(seg *DataSegment) bool) {
} }
for e := sw.cache.Front(); e != nil; e = e.Next() { for e := sw.cache.Front(); e != nil; e = e.Next() {
seg := e.Value.(*DataSegment) if !visitor(e.Value) {
if !visitor(seg) {
break break
} }
} }
@@ -132,7 +131,7 @@ func (sw *SendingWindow) Remove(number uint32) bool {
} }
for e := sw.cache.Front(); e != nil; e = e.Next() { for e := sw.cache.Front(); e != nil; e = e.Next() {
seg := e.Value.(*DataSegment) seg := e.Value
if seg.Number > number { if seg.Number > number {
return false return false
} else if seg.Number == number { } else if seg.Number == number {
@@ -140,7 +139,7 @@ func (sw *SendingWindow) Remove(number uint32) bool {
sw.totalInFlightSize-- sw.totalInFlightSize--
} }
seg.Release() seg.Release()
sw.cache.Remove(e) e.Remove()
return true return true
} }
} }

View File

@@ -16,12 +16,12 @@ import (
"github.com/sagernet/quic-go" "github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3" "github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/congestion"
"github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/common/xray/buf" "github.com/sagernet/sing-box/common/xray/buf"
"github.com/sagernet/sing-box/common/xray/net" "github.com/sagernet/sing-box/common/xray/net"
"github.com/sagernet/sing-box/common/xray/pipe" "github.com/sagernet/sing-box/common/xray/pipe"
"github.com/sagernet/sing-box/common/xray/signal/done" "github.com/sagernet/sing-box/common/xray/signal/done"
"github.com/sagernet/sing-box/common/xray/uuid"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
qtls "github.com/sagernet/sing-quic" qtls "github.com/sagernet/sing-quic"
@@ -30,6 +30,7 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
sHTTP "github.com/sagernet/sing/protocol/http" sHTTP "github.com/sagernet/sing/protocol/http"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
"golang.org/x/net/http2" "golang.org/x/net/http2"
@@ -42,15 +43,22 @@ type Client struct {
baseRequestURL2 url.URL baseRequestURL2 url.URL
getHTTPClient func() (DialerClient, *XmuxClient) getHTTPClient func() (DialerClient, *XmuxClient)
getHTTPClient2 func() (DialerClient, *XmuxClient) getHTTPClient2 func() (DialerClient, *XmuxClient)
xmuxManager *XmuxManager
xmuxManager2 *XmuxManager
} }
func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayXHTTPOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) { func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayXHTTPOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
if options.Mode == "" {
return nil, E.New("mode is not set")
}
if tlsConfig != nil && len(tlsConfig.NextProtos()) == 0 { if tlsConfig != nil && len(tlsConfig.NextProtos()) == 0 {
tlsConfig.SetNextProtos([]string{"h2"}) tlsConfig.SetNextProtos([]string{"h2"})
} }
if _, err := congestion.NewCongestionControl(options.CongestionController, options.CWND, nil); err != nil {
return nil, err
}
if options.Download != nil {
if _, err := congestion.NewCongestionControl(options.Download.CongestionController, options.Download.CWND, nil); err != nil {
return nil, err
}
}
dest := serverAddr dest := serverAddr
baseRequestURL, err := getBaseRequestURL(&options.V2RayXHTTPBaseOptions, dest, tlsConfig) baseRequestURL, err := getBaseRequestURL(&options.V2RayXHTTPBaseOptions, dest, tlsConfig)
if err != nil { if err != nil {
@@ -61,7 +69,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
xmuxOptions = *options.Xmux xmuxOptions = *options.Xmux
} }
xmuxManager := NewXmuxManager(xmuxOptions, func() XmuxConn { xmuxManager := NewXmuxManager(xmuxOptions, func() XmuxConn {
return createHTTPClient(dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig) return createHTTPClient(ctx, dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig)
}) })
getHTTPClient := func() (DialerClient, *XmuxClient) { getHTTPClient := func() (DialerClient, *XmuxClient) {
xmuxClient := xmuxManager.GetXmuxClient(ctx) xmuxClient := xmuxManager.GetXmuxClient(ctx)
@@ -69,6 +77,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
} }
baseRequestURL2 := baseRequestURL baseRequestURL2 := baseRequestURL
getHTTPClient2 := getHTTPClient getHTTPClient2 := getHTTPClient
var xmuxManager2 *XmuxManager
if options.Download != nil { if options.Download != nil {
options2 := options.Download options2 := options.Download
dialer2 := dialer dialer2 := dialer
@@ -98,8 +107,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
if options2.Xmux != nil { if options2.Xmux != nil {
xmuxOptions2 = *options2.Xmux xmuxOptions2 = *options2.Xmux
} }
xmuxManager2 := NewXmuxManager(xmuxOptions2, func() XmuxConn { xmuxManager2 = NewXmuxManager(xmuxOptions2, func() XmuxConn {
return createHTTPClient(dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2) return createHTTPClient(ctx, dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2)
}) })
getHTTPClient2 = func() (DialerClient, *XmuxClient) { getHTTPClient2 = func() (DialerClient, *XmuxClient) {
xmuxClient2 := xmuxManager2.GetXmuxClient(ctx) xmuxClient2 := xmuxManager2.GetXmuxClient(ctx)
@@ -113,6 +122,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
getHTTPClient2: getHTTPClient2, getHTTPClient2: getHTTPClient2,
baseRequestURL: baseRequestURL, baseRequestURL: baseRequestURL,
baseRequestURL2: baseRequestURL2, baseRequestURL2: baseRequestURL2,
xmuxManager: xmuxManager,
xmuxManager2: xmuxManager2,
}, nil }, nil
} }
@@ -121,8 +132,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
mode := c.options.Mode mode := c.options.Mode
sessionId := "" sessionId := ""
if c.options.Mode != "stream-one" { if c.options.Mode != "stream-one" {
sessionIdUuid := uuid.New() sessionId = GenerateSessionID(&c.options.V2RayXHTTPBaseOptions)
sessionId = sessionIdUuid.String()
} }
requestURL := c.baseRequestURL requestURL := c.baseRequestURL
requestURL2 := c.baseRequestURL2 requestURL2 := c.baseRequestURL2
@@ -182,10 +192,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
} }
scMaxEachPostBytes := options.GetNormalizedScMaxEachPostBytes() scMaxEachPostBytes := options.GetNormalizedScMaxEachPostBytes()
scMinPostsIntervalMs := options.GetNormalizedScMinPostsIntervalMs() scMinPostsIntervalMs := options.GetNormalizedScMinPostsIntervalMs()
if scMaxEachPostBytes.From <= 0 { maxUploadSize := int32(scMaxEachPostBytes.Rand())
panic("`scMaxEachPostBytes` should be bigger than 0")
}
maxUploadSize := scMaxEachPostBytes.Rand()
// WithSizeLimit(0) will still allow single bytes to pass, and a lot of // WithSizeLimit(0) will still allow single bytes to pass, and a lot of
// code relies on this behavior. Subtract 1 so that together with // code relies on this behavior. Subtract 1 so that together with
// uploadWriter wrapper, exact size limits can be enforced // uploadWriter wrapper, exact size limits can be enforced
@@ -255,6 +262,10 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
} }
func (c *Client) Close() error { func (c *Client) Close() error {
c.xmuxManager.Close()
if c.xmuxManager2 != nil {
c.xmuxManager2.Close()
}
return nil return nil
} }
@@ -294,7 +305,7 @@ func getBaseRequestURL(options *option.V2RayXHTTPBaseOptions, dest M.Socksaddr,
return requestURL, nil return requestURL, nil
} }
func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient { func createHTTPClient(ctx context.Context, dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient {
httpVersion := decideHTTPVersion(tlsConfig) httpVersion := decideHTTPVersion(tlsConfig)
dialContext := func(ctxInner context.Context) (net.Conn, error) { dialContext := func(ctxInner context.Context) (net.Conn, error) {
conn, err := dialer.DialContext(ctxInner, "tcp", dest) conn, err := dialer.DialContext(ctxInner, "tcp", dest)
@@ -319,6 +330,7 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
if keepAlivePeriod < 0 { if keepAlivePeriod < 0 {
keepAlivePeriod = 0 keepAlivePeriod = 0
} }
congestionControlFactory, _ := congestion.NewCongestionControl(options.CongestionController, options.CWND, ntp.TimeFuncFromContext(ctx))
quicConfig := &quic.Config{ quicConfig := &quic.Config{
MaxIdleTimeout: net.ConnIdleTimeout, MaxIdleTimeout: net.ConnIdleTimeout,
// these two are defaults of quic-go/http3. the default of quic-go (no // these two are defaults of quic-go/http3. the default of quic-go (no
@@ -334,7 +346,14 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
if dErr != nil { if dErr != nil {
return nil, dErr return nil, dErr
} }
return qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg) conn, dErr := qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg)
if dErr != nil {
return nil, dErr
}
if congestionControlFactory != nil {
conn.SetCongestionControl(congestionControlFactory(conn))
}
return conn, nil
}, },
} }
case "2": case "2":

View File

@@ -39,7 +39,7 @@ func (c *splitConn) Close() error {
} }
if err2 != nil { if err2 != nil {
return err return err2
} }
return nil return nil

View File

@@ -147,7 +147,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio
if c.httpVersion != "1.1" { if c.httpVersion != "1.1" {
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if err != nil { if err != nil {
c.closed = true c.Close()
return err return err
} }
io.Copy(io.Discard, resp.Body) io.Copy(io.Discard, resp.Body)
@@ -225,10 +225,9 @@ func (w *WaitReadCloser) Set(rc io.ReadCloser) {
} }
func (w *WaitReadCloser) Read(b []byte) (int, error) { func (w *WaitReadCloser) Read(b []byte) (int, error) {
<-w.Wait
if w.ReadCloser == nil { if w.ReadCloser == nil {
if <-w.Wait; w.ReadCloser == nil { return 0, io.ErrClosedPipe
return 0, io.ErrClosedPipe
}
} }
return w.ReadCloser.Read(b) return w.ReadCloser.Read(b)
} }

Some files were not shown because too many files have changed in this diff Show More