mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-26 12:23:12 +03:00
Add Snell protocol. Refactor MASQUE HTTP/2, Fair Queue. Update XHTTP, OpenVPN, Sudoku, Fallback. Fixes
This commit is contained in:
@@ -26,6 +26,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_manager
|
||||
- with_admin_panel
|
||||
- with_profiler
|
||||
@@ -64,6 +65,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_profiler
|
||||
- badlinkname
|
||||
- tfogo_checklinkname0
|
||||
@@ -123,6 +125,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_manager
|
||||
- with_admin_panel
|
||||
- with_profiler
|
||||
@@ -156,6 +159,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_manager
|
||||
- with_admin_panel
|
||||
- with_profiler
|
||||
@@ -189,6 +193,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_manager
|
||||
- with_admin_panel
|
||||
- with_profiler
|
||||
@@ -222,6 +227,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_manager
|
||||
- with_admin_panel
|
||||
- with_profiler
|
||||
@@ -255,6 +261,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_manager
|
||||
- with_admin_panel
|
||||
- with_profiler
|
||||
@@ -304,6 +311,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_manager
|
||||
- with_admin_panel
|
||||
- with_profiler
|
||||
@@ -361,6 +369,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_profiler
|
||||
- badlinkname
|
||||
- tfogo_checklinkname0
|
||||
@@ -433,6 +442,7 @@ builds:
|
||||
- with_openvpn
|
||||
- with_trusttunnel
|
||||
- with_sudoku
|
||||
- with_snell
|
||||
- with_profiler
|
||||
- badlinkname
|
||||
- tfogo_checklinkname0
|
||||
|
||||
@@ -13,6 +13,7 @@ type PlatformInterface interface {
|
||||
|
||||
UsePlatformAutoDetectInterfaceControl() bool
|
||||
AutoDetectInterfaceControl(fd int) error
|
||||
BindInterfaceControl(fd int, interfaceName string) error
|
||||
|
||||
UsePlatformInterface() bool
|
||||
OpenInterface(options *tun.Options, platformOptions option.TunPlatformOptions) (tun.Tun, error)
|
||||
|
||||
@@ -63,7 +63,7 @@ func init() {
|
||||
sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -X internal/godebug.defaultGODEBUG=multipathtcp=0 -s -w -buildid= -checklinkname=0")
|
||||
debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -X internal/godebug.defaultGODEBUG=multipathtcp=0 -checklinkname=0")
|
||||
|
||||
sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_masque", "with_mtproxy", "with_trusttunnel", "with_openvpn", "with_sudoku", "with_utls", "with_naive_outbound", "with_clash_api", "badlinkname", "tfogo_checklinkname0")
|
||||
sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_masque", "with_mtproxy", "with_trusttunnel", "with_openvpn", "with_sudoku", "with_snell", "with_utls", "with_naive_outbound", "with_clash_api", "badlinkname", "tfogo_checklinkname0")
|
||||
darwinTags = append(darwinTags, "with_dhcp", "grpcnotrace")
|
||||
// memcTags = append(memcTags, "with_tailscale")
|
||||
sharedTags = append(sharedTags, "with_tailscale", "ts_omit_logtail", "ts_omit_ssh", "ts_omit_drive", "ts_omit_taildrop", "ts_omit_webclient", "ts_omit_doctor", "ts_omit_capture", "ts_omit_kube", "ts_omit_aws", "ts_omit_synology", "ts_omit_bird")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package trusttunnel
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func NewCongestionControl(name string, cwnd int, bbrProfile string, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) {
|
||||
func NewCongestionControl(name string, cwnd int, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) {
|
||||
if timeFunc == nil {
|
||||
timeFunc = time.Now
|
||||
}
|
||||
@@ -70,9 +70,20 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
|
||||
if !(C.IsLinux || C.IsDarwin || C.IsWindows) {
|
||||
return nil, E.New("`bind_interface` is only supported on Linux, macOS and Windows")
|
||||
}
|
||||
bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
|
||||
dialer.Control = control.Append(dialer.Control, bindFunc)
|
||||
listener.Control = control.Append(listener.Control, bindFunc)
|
||||
if platformInterface != nil && platformInterface.UsePlatformAutoDetectInterfaceControl() {
|
||||
interfaceName := options.BindInterface
|
||||
bindFunc := func(network, address string, conn syscall.RawConn) error {
|
||||
return control.Raw(conn, func(fd uintptr) error {
|
||||
return platformInterface.BindInterfaceControl(int(fd), interfaceName)
|
||||
})
|
||||
}
|
||||
dialer.Control = control.Append(dialer.Control, bindFunc)
|
||||
listener.Control = control.Append(listener.Control, bindFunc)
|
||||
} else {
|
||||
bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
|
||||
dialer.Control = control.Append(dialer.Control, bindFunc)
|
||||
listener.Control = control.Append(listener.Control, bindFunc)
|
||||
}
|
||||
}
|
||||
if options.RoutingMark > 0 {
|
||||
if !C.IsLinux {
|
||||
|
||||
164
common/list/list.go
Normal file
164
common/list/list.go
Normal file
@@ -0,0 +1,164 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package list
|
||||
|
||||
// Element is an element of a linked list.
|
||||
type Element[T any] struct {
|
||||
next, prev *Element[T]
|
||||
list *List[T]
|
||||
Value T
|
||||
}
|
||||
|
||||
func (e *Element[T]) Next() *Element[T] {
|
||||
if p := e.next; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Element[T]) Prev() *Element[T] {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Element[T]) Remove() bool {
|
||||
if e.list == nil {
|
||||
return false
|
||||
}
|
||||
e.list.remove(e)
|
||||
return true
|
||||
}
|
||||
|
||||
type List[T any] struct {
|
||||
root Element[T]
|
||||
len int
|
||||
}
|
||||
|
||||
func (l *List[T]) Init() *List[T] {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
func New[T any]() *List[T] { return new(List[T]).Init() }
|
||||
|
||||
func (l *List[T]) Len() int { return l.len }
|
||||
|
||||
func (l *List[T]) Front() *Element[T] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.next
|
||||
}
|
||||
|
||||
func (l *List[T]) Back() *Element[T] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
func (l *List[T]) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *List[T]) insert(e, at *Element[T]) *Element[T] {
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] {
|
||||
return l.insert(&Element[T]{Value: v}, at)
|
||||
}
|
||||
|
||||
func (l *List[T]) remove(e *Element[T]) {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil
|
||||
e.prev = nil
|
||||
e.list = nil
|
||||
l.len--
|
||||
}
|
||||
|
||||
func (l *List[T]) Remove(e *Element[T]) T {
|
||||
if e.list == l {
|
||||
l.remove(e)
|
||||
}
|
||||
return e.Value
|
||||
}
|
||||
|
||||
func (l *List[T]) PushFront(v T) *Element[T] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(v, &l.root)
|
||||
}
|
||||
|
||||
func (l *List[T]) PushBack(v T) *Element[T] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(v, l.root.prev)
|
||||
}
|
||||
|
||||
func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
return l.insertValue(v, mark.prev)
|
||||
}
|
||||
|
||||
func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
return l.insertValue(v, mark)
|
||||
}
|
||||
|
||||
func (l *List[T]) MoveToFront(e *Element[T]) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
l.move(e, &l.root)
|
||||
}
|
||||
|
||||
func (l *List[T]) MoveToBack(e *Element[T]) {
|
||||
if e.list != l || l.root.prev == e {
|
||||
return
|
||||
}
|
||||
l.move(e, l.root.prev)
|
||||
}
|
||||
|
||||
func (l *List[T]) MoveBefore(e, mark *Element[T]) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
}
|
||||
l.move(e, mark.prev)
|
||||
}
|
||||
|
||||
func (l *List[T]) MoveAfter(e, mark *Element[T]) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
}
|
||||
l.move(e, mark)
|
||||
}
|
||||
|
||||
func (l *List[T]) move(e, at *Element[T]) {
|
||||
if e == at {
|
||||
return
|
||||
}
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
)
|
||||
|
||||
@@ -69,8 +68,8 @@ func DecodeBase64URLSafe(content string) (string, error) {
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
func ParseXHTTPRange(value string) (Xbadoption.Range, error) {
|
||||
result := Xbadoption.Range{}
|
||||
func ParseXHTTPRange(value string) (badoption.Range[int], error) {
|
||||
result := badoption.Range[int]{}
|
||||
encoded, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return result, err
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
package badoption
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/common/xray/crypto"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type Range struct {
|
||||
From int32 `json:"from"`
|
||||
To int32 `json:"to"`
|
||||
}
|
||||
|
||||
func (c *Range) Build() *Range {
|
||||
return (*Range)(c)
|
||||
}
|
||||
|
||||
func (c *Range) MarshalJSON() ([]byte, error) {
|
||||
if c.From == c.To {
|
||||
return json.Marshal(c.From)
|
||||
}
|
||||
return json.Marshal(fmt.Sprintf("%d-%d", c.From, c.To))
|
||||
}
|
||||
|
||||
func (c *Range) UnmarshalJSON(content []byte) error {
|
||||
var rangeValue struct {
|
||||
From int32 `json:"from"`
|
||||
To int32 `json:"to"`
|
||||
}
|
||||
var stringValue string
|
||||
err := json.Unmarshal(content, &stringValue)
|
||||
if err == nil {
|
||||
parts := strings.Split(stringValue, "-")
|
||||
if len(parts) != 2 {
|
||||
from, err := strconv.ParseInt(parts[0], 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rangeValue.From, rangeValue.To = int32(from), int32(from)
|
||||
} else {
|
||||
from, err := strconv.ParseInt(parts[0], 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
to, err := strconv.ParseInt(parts[1], 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rangeValue.From, rangeValue.To = int32(from), int32(to)
|
||||
}
|
||||
} else {
|
||||
var int32Value int32
|
||||
err := json.Unmarshal(content, &int32Value)
|
||||
if err == nil {
|
||||
rangeValue.From, rangeValue.To = int32Value, int32Value
|
||||
} else {
|
||||
err := json.Unmarshal(content, &rangeValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if rangeValue.From > rangeValue.To {
|
||||
return E.New("invalid range")
|
||||
}
|
||||
*c = Range{rangeValue.From, rangeValue.To}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Range) String() string {
|
||||
if c.From == c.To {
|
||||
return strconv.FormatInt(int64(c.From), 10)
|
||||
}
|
||||
return fmt.Sprintf("%d-%d", c.From, c.To)
|
||||
}
|
||||
|
||||
func (c Range) Rand() int32 {
|
||||
return int32(crypto.RandBetween(int64(c.From), int64(c.To)))
|
||||
}
|
||||
@@ -28,6 +28,7 @@ const (
|
||||
TypeMieru = "mieru"
|
||||
TypeAnyTLS = "anytls"
|
||||
TypeSudoku = "sudoku"
|
||||
TypeSnell = "snell"
|
||||
TypeShadowsocksR = "shadowsocksr"
|
||||
TypeVLESS = "vless"
|
||||
TypeTUIC = "tuic"
|
||||
@@ -41,6 +42,7 @@ const (
|
||||
TypeBandwidthLimiter = "bandwidth-limiter"
|
||||
TypeTrafficLimiter = "traffic-limiter"
|
||||
TypeRateLimiter = "rate-limiter"
|
||||
TypeFairQueue = "fair-queue"
|
||||
TypeAdminPanel = "admin-panel"
|
||||
TypeManagerAPI = "manager-api"
|
||||
TypeNodeManagerAPI = "node-manager-api"
|
||||
@@ -129,6 +131,8 @@ func ProxyDisplayName(proxyType string) string {
|
||||
return "AnyTLS"
|
||||
case TypeSudoku:
|
||||
return "Sudoku"
|
||||
case TypeSnell:
|
||||
return "Snell"
|
||||
case TypeFallback:
|
||||
return "Fallback"
|
||||
case TypeTailscale:
|
||||
@@ -145,6 +149,8 @@ func ProxyDisplayName(proxyType string) string {
|
||||
return "Traffic Limiter"
|
||||
case TypeRateLimiter:
|
||||
return "Rate Limiter"
|
||||
case TypeFairQueue:
|
||||
return "Fair Queue"
|
||||
case TypeVPNClient:
|
||||
return "VPN Client"
|
||||
case TypeVPNServer:
|
||||
|
||||
@@ -39,11 +39,14 @@
|
||||
"udp_keepalive_period": "30s",
|
||||
"udp_initial_packet_size": 0,
|
||||
"reconnect_delay": "5s",
|
||||
"congestion_controller": "bbr",
|
||||
"cwnd": 0,
|
||||
"tls": { // TLS fields for HTTP2
|
||||
"insecure": false,
|
||||
"cipher_suites": [],
|
||||
"curve_preferences": [],
|
||||
"fragment": false,
|
||||
"fragment_fallback_delay": "500ms",
|
||||
"record_fragment": false,
|
||||
"kernel_tx": false,
|
||||
"kernel_rx": false
|
||||
|
||||
13
examples/profiler/config.json
Normal file
13
examples/profiler/config.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"log": {
|
||||
"level": "info"
|
||||
},
|
||||
"services": [
|
||||
{
|
||||
"type": "profiler",
|
||||
"tag": "pprof",
|
||||
"listen": "127.0.0.1",
|
||||
"listen_port": 6060
|
||||
}
|
||||
]
|
||||
}
|
||||
46
examples/snell/client.json
Normal file
46
examples/snell/client.json
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"log": {
|
||||
"level": "error"
|
||||
},
|
||||
"dns": {
|
||||
"servers": [
|
||||
{
|
||||
"type": "local",
|
||||
"tag": "default"
|
||||
}
|
||||
]
|
||||
},
|
||||
"inbounds": [
|
||||
{
|
||||
"type": "mixed",
|
||||
"tag": "mixed-in",
|
||||
"listen_port": 7897
|
||||
}
|
||||
],
|
||||
"outbounds": [
|
||||
{
|
||||
"type": "direct",
|
||||
"tag": "direct"
|
||||
},
|
||||
{
|
||||
"type": "snell",
|
||||
"tag": "snell-out",
|
||||
"server": "example.com",
|
||||
"server_port": 8443,
|
||||
"psk": "your-secret-psk",
|
||||
"version": 4, // 1 | 2 | 3 | 4 | 5 (v5 falls back to v4)
|
||||
"reuse": true, // v4 only, reuse pooled connections
|
||||
"network": ["tcp", "udp"]
|
||||
// "obfs": {
|
||||
// "mode": "tls", // tls | http
|
||||
// "host": "bing.com"
|
||||
// }
|
||||
// Dial Fields
|
||||
}
|
||||
],
|
||||
"route": {
|
||||
"final": "snell-out",
|
||||
"default_domain_resolver": "default",
|
||||
"auto_detect_interface": true
|
||||
}
|
||||
}
|
||||
39
examples/snell/server.json
Normal file
39
examples/snell/server.json
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"log": {
|
||||
"level": "error"
|
||||
},
|
||||
"dns": {
|
||||
"servers": [
|
||||
{
|
||||
"type": "local",
|
||||
"tag": "default"
|
||||
}
|
||||
]
|
||||
},
|
||||
"inbounds": [
|
||||
{
|
||||
"type": "snell",
|
||||
"tag": "snell-in",
|
||||
"listen": "::",
|
||||
"listen_port": 8443,
|
||||
"psk": "your-secret-psk",
|
||||
"version": 4, // 4 | 5 (server supports v4/v5 only)
|
||||
"network": ["tcp", "udp"]
|
||||
// "obfs": {
|
||||
// "mode": "tls", // tls | http
|
||||
// "host": "bing.com"
|
||||
// }
|
||||
}
|
||||
],
|
||||
"outbounds": [
|
||||
{
|
||||
"type": "direct",
|
||||
"tag": "direct"
|
||||
}
|
||||
],
|
||||
"route": {
|
||||
"final": "direct",
|
||||
"default_domain_resolver": "default",
|
||||
"auto_detect_interface": true
|
||||
}
|
||||
}
|
||||
@@ -31,7 +31,8 @@
|
||||
"multiplex": {
|
||||
"enabled": true,
|
||||
"max_connections": 8,
|
||||
"min_streams": 5
|
||||
"min_streams": 5,
|
||||
"max_streams": 0
|
||||
},
|
||||
"tls": {
|
||||
"enabled": true,
|
||||
@@ -50,12 +51,12 @@
|
||||
"health_check": true,
|
||||
"quic": true,
|
||||
"congestion_controller": "bbr", // bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
|
||||
"bbr_profile": "standard", // standard, conservative, aggressive
|
||||
"cwnd": 32,
|
||||
"multiplex": {
|
||||
"enabled": true,
|
||||
"max_connections": 8,
|
||||
"min_streams": 5
|
||||
"min_streams": 5,
|
||||
"max_streams": 0
|
||||
},
|
||||
"tls": {
|
||||
"enabled": true,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
}
|
||||
],
|
||||
"congestion_controller": "bbr", // bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
|
||||
"bbr_profile": "standard", // standard, conservative, aggressive
|
||||
"cwnd": 32,
|
||||
"tls": {
|
||||
"enabled": true,
|
||||
|
||||
@@ -65,6 +65,8 @@
|
||||
"uplink_data_placement": "",
|
||||
"uplink_data_key": "",
|
||||
"uplink_chunk_size": 0,
|
||||
"congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
|
||||
"cwnd": 0, // h3 only: initial congestion window in packets, default 32
|
||||
"server": "example.com",
|
||||
"server_port": 443,
|
||||
"download": {
|
||||
@@ -97,6 +99,8 @@
|
||||
"uplink_data_placement": "",
|
||||
"uplink_data_key": "",
|
||||
"uplink_chunk_size": 0,
|
||||
"congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
|
||||
"cwnd": 0, // h3 only: initial congestion window in packets, default 32
|
||||
"server": "example.com",
|
||||
"server_port": 443,
|
||||
"tls": { // https://sing-box.sagernet.org/configuration/shared/tls/#outbound
|
||||
|
||||
@@ -51,6 +51,8 @@
|
||||
"seq_key": "",
|
||||
"uplink_data_placement": "",
|
||||
"uplink_data_key": "",
|
||||
"congestion_controller": "", // h3 only: bbr, bbr_standard, bbr2, bbr2_variant, cubic, reno
|
||||
"cwnd": 0, // h3 only: initial congestion window in packets, default 32
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -78,6 +78,10 @@ func (s *platformInterfaceStub) AutoDetectInterfaceControl(fd int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) BindInterfaceControl(fd int, interfaceName string) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (s *platformInterfaceStub) UsePlatformInterface() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ type PlatformInterface interface {
|
||||
LocalDNSTransport() LocalDNSTransport
|
||||
UsePlatformAutoDetectInterfaceControl() bool
|
||||
AutoDetectInterfaceControl(fd int32) error
|
||||
BindInterfaceControl(fd int32, interfaceName string) error
|
||||
OpenTun(options TunOptions) (int32, error)
|
||||
UseProcFS() bool
|
||||
FindConnectionOwner(ipProtocol int32, sourceAddress string, sourcePort int32, destinationAddress string, destinationPort int32) (*ConnectionOwner, error)
|
||||
|
||||
@@ -49,6 +49,10 @@ func (w *platformInterfaceWrapper) AutoDetectInterfaceControl(fd int) error {
|
||||
return w.iif.AutoDetectInterfaceControl(int32(fd))
|
||||
}
|
||||
|
||||
func (w *platformInterfaceWrapper) BindInterfaceControl(fd int, interfaceName string) error {
|
||||
return w.iif.BindInterfaceControl(int32(fd), interfaceName)
|
||||
}
|
||||
|
||||
func (w *platformInterfaceWrapper) UsePlatformInterface() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
7
go.mod
7
go.mod
@@ -32,6 +32,7 @@ require (
|
||||
github.com/miekg/dns v1.1.72
|
||||
github.com/openai/openai-go/v3 v3.26.0
|
||||
github.com/oschwald/maxminddb-golang v1.13.1
|
||||
github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e
|
||||
github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1
|
||||
github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a
|
||||
github.com/sagernet/cors v1.2.1
|
||||
@@ -231,8 +232,8 @@ replace github.com/sagernet/sing-vmess => github.com/shtorm-7/sing-vmess v0.2.7-
|
||||
|
||||
replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0
|
||||
|
||||
replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0
|
||||
replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0
|
||||
|
||||
replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0
|
||||
replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0
|
||||
|
||||
replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.1.0
|
||||
replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.2.0
|
||||
|
||||
14
go.sum
14
go.sum
@@ -268,6 +268,8 @@ github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||
github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e h1:dCWirM5F3wMY+cmRda/B1BiPsFtmzXqV9b0hLWtVBMs=
|
||||
github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e/go.mod h1:9leZcVcItj6m9/CfHY5Em/iBrCz7js8LcRQGTKEEv2M=
|
||||
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
|
||||
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
@@ -373,16 +375,16 @@ github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1h
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8=
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
|
||||
github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 h1:ws7BIsYLd31Wjifq88BYCHRVlgO+07iwil39s6ERba8=
|
||||
github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0/go.mod h1:mRwx4w32qQxsWB2kThuHpbo7iNjJiq1jYWubgqEPjHA=
|
||||
github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0 h1:3ZV98mKqKNPCPWHevJ6RPsb65DwPrRFEUOHUfDnG6vw=
|
||||
github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.1.0/go.mod h1:mRwx4w32qQxsWB2kThuHpbo7iNjJiq1jYWubgqEPjHA=
|
||||
github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 h1:e5s7RKBd2rIPR0StbvZ2vTVtJ5jDTsTk5wtIIapZTRg=
|
||||
github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI=
|
||||
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 h1:PLZ/YHqnApPx13wt6MX3ItqESp4ueBr1tGSi0bEGqYw=
|
||||
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4=
|
||||
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0 h1:aOd9Vy2LGSwgMM+4805AgLBE/MQf8UymbXHxUZjSmoU=
|
||||
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.2.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4=
|
||||
github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 h1:iBLll4ZZG8ULQcHWs6gGslZWtBN72Zo1zjySzMVHF7g=
|
||||
github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0=
|
||||
github.com/shtorm-7/sing v0.8.10-extended-1.1.0 h1:P4JL2cugjvEvnYu8tMmpR30SE1qsS45RcnNEwzDz5as=
|
||||
github.com/shtorm-7/sing v0.8.10-extended-1.1.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA=
|
||||
github.com/shtorm-7/sing v0.8.10-extended-1.2.0 h1:5yw9j0+P2QkRWvxBvb71wvNdpAlHmmpBv4hj2gqvass=
|
||||
github.com/shtorm-7/sing v0.8.10-extended-1.2.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA=
|
||||
github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw=
|
||||
github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
|
||||
github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0 h1:WVheKmQH5hSQbJU1ZTKthKSutkTLWSb2hp4JuQhJBow=
|
||||
|
||||
@@ -91,6 +91,7 @@ func InboundRegistry() *inbound.Registry {
|
||||
registerStubForRemovedInbounds(registry)
|
||||
registerMTProxyInbound(registry)
|
||||
registerSudokuInbound(registry)
|
||||
registerSnellInbound(registry)
|
||||
|
||||
return registry
|
||||
}
|
||||
@@ -135,6 +136,7 @@ func OutboundRegistry() *outbound.Registry {
|
||||
registerQUICOutbounds(registry)
|
||||
registerStubForRemovedOutbounds(registry)
|
||||
registerSudokuOutbound(registry)
|
||||
registerSnellOutbound(registry)
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
17
include/snell.go
Normal file
17
include/snell.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build with_snell
|
||||
|
||||
package include
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/protocol/snell"
|
||||
)
|
||||
|
||||
func registerSnellInbound(registry *inbound.Registry) {
|
||||
snell.RegisterInbound(registry)
|
||||
}
|
||||
|
||||
func registerSnellOutbound(registry *outbound.Registry) {
|
||||
snell.RegisterOutbound(registry)
|
||||
}
|
||||
27
include/snell_stub.go
Normal file
27
include/snell_stub.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build !with_snell
|
||||
|
||||
package include
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func registerSnellInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.SnellInboundOptions](registry, C.TypeSnell, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellInboundOptions) (adapter.Inbound, error) {
|
||||
return nil, E.New(`Snell is not included in this build, rebuild with -tags with_snell`)
|
||||
})
|
||||
}
|
||||
|
||||
func registerSnellOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.SnellOutboundOptions](registry, C.TypeSnell, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellOutboundOptions) (adapter.Outbound, error) {
|
||||
return nil, E.New(`Snell is not included in this build, rebuild with -tags with_snell`)
|
||||
})
|
||||
}
|
||||
@@ -18,7 +18,8 @@ type URLTestOutboundOptions struct {
|
||||
}
|
||||
|
||||
type FallbackOutboundOptions struct {
|
||||
Outbounds []string `json:"outbounds"`
|
||||
Outbounds []string `json:"outbounds"`
|
||||
BlacklistTimeout badoption.Duration `json:"blacklist_timeout,omitempty"`
|
||||
}
|
||||
|
||||
type GroupCommonOption struct {
|
||||
|
||||
@@ -69,3 +69,8 @@ type RateLimiterUser struct {
|
||||
Count uint32 `json:"count"`
|
||||
Interval badoption.Duration `json:"interval"`
|
||||
}
|
||||
|
||||
type FairQueueOutboundOptions struct {
|
||||
FlowKeys []string `json:"flow_keys,omitempty"`
|
||||
Outbound string `json:"outbound"`
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ type MASQUEOutboundOptions struct {
|
||||
UDPKeepalivePeriod badoption.Duration `json:"udp_keepalive_period,omitempty"`
|
||||
UDPInitialPacketSize uint16 `json:"udp_initial_packet_size,omitempty"`
|
||||
ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"`
|
||||
CongestionController string `json:"congestion_controller,omitempty"`
|
||||
CWND int `json:"cwnd,omitempty"`
|
||||
MASQUEOutboundTLSOptionsContainer
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ type OpenVPNOutboundOptions struct {
|
||||
KeyDirection int `json:"key_direction,omitempty"`
|
||||
ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"`
|
||||
PingInterval badoption.Duration `json:"ping_interval,omitempty"`
|
||||
PingRestart badoption.Duration `json:"ping_restart,omitempty"`
|
||||
OpenVPNOutboundTLSOptionsContainer
|
||||
}
|
||||
|
||||
|
||||
24
option/snell.go
Normal file
24
option/snell.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package option
|
||||
|
||||
type SnellOutboundOptions struct {
|
||||
DialerOptions
|
||||
ServerOptions
|
||||
PSK string `json:"psk"`
|
||||
Version int `json:"version,omitempty"`
|
||||
Reuse bool `json:"reuse,omitempty"`
|
||||
Network NetworkList `json:"network,omitempty"`
|
||||
Obfs *SnellObfsOptions `json:"obfs,omitempty"`
|
||||
}
|
||||
|
||||
type SnellInboundOptions struct {
|
||||
ListenOptions
|
||||
PSK string `json:"psk"`
|
||||
Version int `json:"version,omitempty"`
|
||||
Network NetworkList `json:"network,omitempty"`
|
||||
Obfs *SnellObfsOptions `json:"obfs,omitempty"`
|
||||
}
|
||||
|
||||
type SnellObfsOptions struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Host string `json:"host,omitempty"`
|
||||
}
|
||||
@@ -6,7 +6,6 @@ type TrustTunnelInboundOptions struct {
|
||||
Users []TrustTunnelUser `json:"users,omitempty"`
|
||||
Network NetworkList `json:"network,omitempty"`
|
||||
CongestionController string `json:"congestion_controller,omitempty"`
|
||||
BBRProfile string `json:"bbr_profile,omitempty"`
|
||||
CWND int `json:"cwnd,omitempty"`
|
||||
}
|
||||
|
||||
@@ -32,7 +31,6 @@ type TrustTunnelOutboundOptions struct {
|
||||
HealthCheck bool `json:"health_check,omitempty"`
|
||||
QUIC bool `json:"quic,omitempty"`
|
||||
CongestionController string `json:"congestion_controller,omitempty"`
|
||||
BBRProfile string `json:"bbr_profile,omitempty"`
|
||||
CWND int `json:"cwnd,omitempty"`
|
||||
Multiplex *TrustTunnelMultiplexOptions `json:"multiplex,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption"
|
||||
"github.com/sagernet/sing-box/common/xray/utils"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
@@ -119,13 +118,13 @@ type V2RayXHTTPBaseOptions struct {
|
||||
Path string `json:"path,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"`
|
||||
XPaddingBytes Xbadoption.Range `json:"x_padding_bytes"`
|
||||
XPaddingBytes badoption.Range[int] `json:"x_padding_bytes"`
|
||||
NoGRPCHeader bool `json:"no_grpc_header,omitempty"`
|
||||
NoSSEHeader bool `json:"no_sse_header,omitempty"`
|
||||
ScMaxEachPostBytes *Xbadoption.Range `json:"sc_max_each_post_bytes"`
|
||||
ScMinPostsIntervalMs *Xbadoption.Range `json:"sc_min_posts_interval_ms"`
|
||||
ScMaxEachPostBytes *badoption.Range[int] `json:"sc_max_each_post_bytes"`
|
||||
ScMinPostsIntervalMs *badoption.Range[int] `json:"sc_min_posts_interval_ms"`
|
||||
ScMaxBufferedPosts int64 `json:"sc_max_buffered_posts,omitempty"`
|
||||
ScStreamUpServerSecs *Xbadoption.Range `json:"sc_stream_up_server_secs"`
|
||||
ScStreamUpServerSecs *badoption.Range[int] `json:"sc_stream_up_server_secs"`
|
||||
ServerMaxHeaderBytes int `json:"server_max_header_bytes"`
|
||||
TrustedXForwardedFor badoption.Listable[string] `json:"trusted_x_forwarded_for,omitempty"`
|
||||
Xmux *V2RayXHTTPXmuxOptions `json:"xmux"`
|
||||
@@ -141,7 +140,11 @@ type V2RayXHTTPBaseOptions struct {
|
||||
SeqKey string `json:"seq_key,omitempty"`
|
||||
UplinkDataPlacement string `json:"uplink_data_placement,omitempty"`
|
||||
UplinkDataKey string `json:"uplink_data_key,omitempty"`
|
||||
UplinkChunkSize *Xbadoption.Range `json:"uplink_chunk_size,omitempty"`
|
||||
UplinkChunkSize *badoption.Range[int] `json:"uplink_chunk_size,omitempty"`
|
||||
SessionIDTable string `json:"session_id_table,omitempty"`
|
||||
SessionIDLength badoption.Range[int] `json:"session_id_length,omitempty"`
|
||||
CongestionController string `json:"congestion_controller,omitempty"`
|
||||
CWND int `json:"cwnd,omitempty"`
|
||||
}
|
||||
|
||||
type _V2RayXHTTPOptions struct {
|
||||
@@ -302,6 +305,10 @@ func checkV2RayXHTTPBaseOptions(mode string, options *V2RayXHTTPBaseOptions) err
|
||||
return E.New("invalid negative value of maxHeaderBytes")
|
||||
}
|
||||
|
||||
if mode != "stream-one" && mode != "stream-up" && options.GetNormalizedScMaxEachPostBytes().From <= 0 {
|
||||
return E.New("`scMaxEachPostBytes` should be bigger than 0")
|
||||
}
|
||||
|
||||
if options.Xmux == nil {
|
||||
options.Xmux = &V2RayXHTTPXmuxOptions{}
|
||||
options.Xmux.MaxConcurrency.From = 1
|
||||
@@ -346,9 +353,9 @@ func (c *V2RayXHTTPBaseOptions) GetRequestHeader() http.Header {
|
||||
return header
|
||||
}
|
||||
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedXPaddingBytes() Xbadoption.Range {
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedXPaddingBytes() badoption.Range[int] {
|
||||
if c.XPaddingBytes.To == 0 {
|
||||
return Xbadoption.Range{
|
||||
return badoption.Range[int]{
|
||||
From: 100,
|
||||
To: 1000,
|
||||
}
|
||||
@@ -363,9 +370,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkHTTPMethod() string {
|
||||
return c.UplinkHTTPMethod
|
||||
}
|
||||
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() Xbadoption.Range {
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() badoption.Range[int] {
|
||||
if c.ScMaxEachPostBytes == nil {
|
||||
return Xbadoption.Range{
|
||||
return badoption.Range[int]{
|
||||
From: 1000000,
|
||||
To: 1000000,
|
||||
}
|
||||
@@ -373,9 +380,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxEachPostBytes() Xbadoption.Ran
|
||||
return *c.ScMaxEachPostBytes
|
||||
}
|
||||
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedScMinPostsIntervalMs() Xbadoption.Range {
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedScMinPostsIntervalMs() badoption.Range[int] {
|
||||
if c.ScMinPostsIntervalMs == nil {
|
||||
return Xbadoption.Range{
|
||||
return badoption.Range[int]{
|
||||
From: 30,
|
||||
To: 30,
|
||||
}
|
||||
@@ -391,9 +398,9 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScMaxBufferedPosts() int {
|
||||
return int(c.ScMaxBufferedPosts)
|
||||
}
|
||||
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() Xbadoption.Range {
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() badoption.Range[int] {
|
||||
if c.ScStreamUpServerSecs == nil {
|
||||
return Xbadoption.Range{
|
||||
return badoption.Range[int]{
|
||||
From: 20,
|
||||
To: 80,
|
||||
}
|
||||
@@ -401,16 +408,16 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedScStreamUpServerSecs() Xbadoption.R
|
||||
return *c.ScStreamUpServerSecs
|
||||
}
|
||||
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() Xbadoption.Range {
|
||||
func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() badoption.Range[int] {
|
||||
if c.UplinkChunkSize == nil || c.UplinkChunkSize.To == 0 {
|
||||
switch c.UplinkDataPlacement {
|
||||
case PlacementCookie:
|
||||
return Xbadoption.Range{
|
||||
return badoption.Range[int]{
|
||||
From: 2 * 1024, // 2 KiB
|
||||
To: 3 * 1024, // 3 KiB
|
||||
}
|
||||
case PlacementHeader:
|
||||
return Xbadoption.Range{
|
||||
return badoption.Range[int]{
|
||||
From: 3 * 1000, // 3 KB
|
||||
To: 4 * 1000, // 4 KB
|
||||
}
|
||||
@@ -418,7 +425,7 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedUplinkChunkSize() Xbadoption.Range
|
||||
return c.GetNormalizedScMaxEachPostBytes()
|
||||
}
|
||||
} else if c.UplinkChunkSize.From < 64 {
|
||||
return Xbadoption.Range{
|
||||
return badoption.Range[int]{
|
||||
From: 64,
|
||||
To: max(64, c.UplinkChunkSize.To),
|
||||
}
|
||||
@@ -485,31 +492,31 @@ func (c *V2RayXHTTPBaseOptions) GetNormalizedSeqKey() string {
|
||||
}
|
||||
|
||||
type V2RayXHTTPXmuxOptions struct {
|
||||
MaxConcurrency Xbadoption.Range `json:"max_concurrency"`
|
||||
MaxConnections Xbadoption.Range `json:"max_connections"`
|
||||
CMaxReuseTimes Xbadoption.Range `json:"c_max_reuse_times"`
|
||||
HMaxRequestTimes Xbadoption.Range `json:"h_max_request_times"`
|
||||
HMaxReusableSecs Xbadoption.Range `json:"h_max_reusable_secs"`
|
||||
HKeepAlivePeriod int64 `json:"h_keep_alive_period"`
|
||||
MaxConcurrency badoption.Range[int] `json:"max_concurrency"`
|
||||
MaxConnections badoption.Range[int] `json:"max_connections"`
|
||||
CMaxReuseTimes badoption.Range[int] `json:"c_max_reuse_times"`
|
||||
HMaxRequestTimes badoption.Range[int] `json:"h_max_request_times"`
|
||||
HMaxReusableSecs badoption.Range[int] `json:"h_max_reusable_secs"`
|
||||
HKeepAlivePeriod int64 `json:"h_keep_alive_period"`
|
||||
}
|
||||
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConcurrency() Xbadoption.Range {
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConcurrency() badoption.Range[int] {
|
||||
return m.MaxConcurrency
|
||||
}
|
||||
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConnections() Xbadoption.Range {
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedMaxConnections() badoption.Range[int] {
|
||||
return m.MaxConnections
|
||||
}
|
||||
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedCMaxReuseTimes() Xbadoption.Range {
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedCMaxReuseTimes() badoption.Range[int] {
|
||||
return m.CMaxReuseTimes
|
||||
}
|
||||
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxRequestTimes() Xbadoption.Range {
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxRequestTimes() badoption.Range[int] {
|
||||
return m.HMaxRequestTimes
|
||||
}
|
||||
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxReusableSecs() Xbadoption.Range {
|
||||
func (m *V2RayXHTTPXmuxOptions) GetNormalizedHMaxReusableSecs() badoption.Range[int] {
|
||||
return m.HMaxReusableSecs
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package option
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
)
|
||||
|
||||
@@ -40,10 +39,10 @@ type WireGuardAmnezia struct {
|
||||
S2 int `json:"s2,omitempty"`
|
||||
S3 int `json:"s3,omitempty"`
|
||||
S4 int `json:"s4,omitempty"`
|
||||
H1 *Xbadoption.Range `json:"h1,omitempty"`
|
||||
H2 *Xbadoption.Range `json:"h2,omitempty"`
|
||||
H3 *Xbadoption.Range `json:"h3,omitempty"`
|
||||
H4 *Xbadoption.Range `json:"h4,omitempty"`
|
||||
H1 *badoption.Range[uint32] `json:"h1,omitempty"`
|
||||
H2 *badoption.Range[uint32] `json:"h2,omitempty"`
|
||||
H3 *badoption.Range[uint32] `json:"h3,omitempty"`
|
||||
H4 *badoption.Range[uint32] `json:"h4,omitempty"`
|
||||
I1 string `json:"i1,omitempty"`
|
||||
I2 string `json:"i2,omitempty"`
|
||||
I3 string `json:"i3,omitempty"`
|
||||
|
||||
@@ -80,6 +80,7 @@ func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
}
|
||||
|
||||
func (h *Inbound) Close() error {
|
||||
h.conns.Close()
|
||||
errs := make([]error, 0)
|
||||
for _, inbound := range h.inbounds {
|
||||
err := inbound.Close()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
@@ -31,14 +32,19 @@ type Fallback struct {
|
||||
tags []string
|
||||
outbounds map[string]adapter.Outbound
|
||||
lastUsedOutbound string
|
||||
|
||||
mtx sync.Mutex
|
||||
blacklistTimeout time.Duration
|
||||
blacklist map[string]time.Time
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FallbackOutboundOptions) (adapter.Outbound, error) {
|
||||
if len(options.Outbounds) == 0 {
|
||||
return nil, E.New("missing tags")
|
||||
}
|
||||
blacklistTimeout := time.Duration(options.BlacklistTimeout)
|
||||
if blacklistTimeout == 0 {
|
||||
blacklistTimeout = time.Minute
|
||||
}
|
||||
outbound := &Fallback{
|
||||
Adapter: outbound.NewAdapter(C.TypeFallback, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
|
||||
ctx: ctx,
|
||||
@@ -47,6 +53,8 @@ func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
tags: options.Outbounds,
|
||||
outbounds: make(map[string]adapter.Outbound, len(options.Outbounds)),
|
||||
lastUsedOutbound: options.Outbounds[0],
|
||||
blacklistTimeout: blacklistTimeout,
|
||||
blacklist: make(map[string]time.Time),
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
@@ -73,35 +81,110 @@ func (s *Fallback) All() []string {
|
||||
}
|
||||
|
||||
func (s *Fallback) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
var conn net.Conn
|
||||
s.mtx.Lock()
|
||||
var active, blacklisted []string
|
||||
for _, tag := range s.tags {
|
||||
if s.isBlacklisted(tag) {
|
||||
blacklisted = append(blacklisted, tag)
|
||||
} else {
|
||||
active = append(active, tag)
|
||||
}
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
|
||||
var err error
|
||||
for _, outbound := range s.outbounds {
|
||||
conn, err = outbound.DialContext(ctx, network, destination)
|
||||
for _, tag := range active {
|
||||
var conn net.Conn
|
||||
conn, err = s.outbounds[tag].DialContext(ctx, network, destination)
|
||||
if err != nil {
|
||||
s.logger.InfoContext(ctx, err)
|
||||
s.mtx.Lock()
|
||||
s.addToBlacklist(tag)
|
||||
s.mtx.Unlock()
|
||||
continue
|
||||
}
|
||||
s.mtx.Lock()
|
||||
s.lastUsedOutbound = tag
|
||||
s.mtx.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
for _, tag := range blacklisted {
|
||||
var conn net.Conn
|
||||
conn, err = s.outbounds[tag].DialContext(ctx, network, destination)
|
||||
if err != nil {
|
||||
s.logger.InfoContext(ctx, err)
|
||||
continue
|
||||
}
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
s.lastUsedOutbound = outbound.Tag()
|
||||
delete(s.blacklist, tag)
|
||||
s.lastUsedOutbound = tag
|
||||
s.mtx.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *Fallback) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
var conn net.PacketConn
|
||||
s.mtx.Lock()
|
||||
var active, blacklisted []string
|
||||
for _, tag := range s.tags {
|
||||
if s.isBlacklisted(tag) {
|
||||
blacklisted = append(blacklisted, tag)
|
||||
} else {
|
||||
active = append(active, tag)
|
||||
}
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
|
||||
var err error
|
||||
for _, outbound := range s.outbounds {
|
||||
conn, err = outbound.ListenPacket(ctx, destination)
|
||||
for _, tag := range active {
|
||||
var conn net.PacketConn
|
||||
conn, err = s.outbounds[tag].ListenPacket(ctx, destination)
|
||||
if err != nil {
|
||||
s.logger.InfoContext(ctx, err)
|
||||
s.mtx.Lock()
|
||||
s.addToBlacklist(tag)
|
||||
s.mtx.Unlock()
|
||||
continue
|
||||
}
|
||||
s.mtx.Lock()
|
||||
s.lastUsedOutbound = tag
|
||||
s.mtx.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
for _, tag := range blacklisted {
|
||||
var conn net.PacketConn
|
||||
conn, err = s.outbounds[tag].ListenPacket(ctx, destination)
|
||||
if err != nil {
|
||||
s.logger.InfoContext(ctx, err)
|
||||
continue
|
||||
}
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
s.lastUsedOutbound = outbound.Tag()
|
||||
delete(s.blacklist, tag)
|
||||
s.lastUsedOutbound = tag
|
||||
s.mtx.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *Fallback) isBlacklisted(tag string) bool {
|
||||
if s.blacklistTimeout == 0 {
|
||||
return false
|
||||
}
|
||||
expiry, ok := s.blacklist[tag]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if time.Now().After(expiry) {
|
||||
delete(s.blacklist, tag)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Fallback) addToBlacklist(tag string) {
|
||||
if s.blacklistTimeout > 0 {
|
||||
s.blacklist[tag] = time.Now().Add(s.blacklistTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@ package bandwidth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/list"
|
||||
)
|
||||
|
||||
type BandwidthLimiter interface {
|
||||
@@ -14,123 +14,144 @@ type BandwidthLimiter interface {
|
||||
SetSpeed(speed uint64)
|
||||
}
|
||||
|
||||
type FlowKeysLimiter struct {
|
||||
type FairQueueLimiter struct {
|
||||
limiter BandwidthLimiter
|
||||
connIDGetter ConnIDGetter
|
||||
|
||||
waits map[string][]*wait
|
||||
conns map[string]int
|
||||
flows *list.List[*flow]
|
||||
index map[string]*list.Element[*flow]
|
||||
bytes map[string]uint64
|
||||
pool sync.Pool
|
||||
queue chan struct{}
|
||||
reset time.Time
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FlowKeysLimiter {
|
||||
return &FlowKeysLimiter{
|
||||
func NewFairQueueLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FairQueueLimiter {
|
||||
return &FairQueueLimiter{
|
||||
limiter: limiter,
|
||||
connIDGetter: connIDGetter,
|
||||
waits: make(map[string][]*wait),
|
||||
conns: make(map[string]int),
|
||||
flows: list.New[*flow](),
|
||||
index: make(map[string]*list.Element[*flow]),
|
||||
bytes: make(map[string]uint64),
|
||||
pool: sync.Pool{New: func() any { return list.New[*request]() }},
|
||||
queue: make(chan struct{}, 1),
|
||||
reset: time.Now().Add(time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *FlowKeysLimiter) SetSpeed(speed uint64) {
|
||||
func (l *FairQueueLimiter) SetSpeed(speed uint64) {
|
||||
l.limiter.SetSpeed(speed)
|
||||
}
|
||||
|
||||
func (l *FlowKeysLimiter) WaitN(ctx context.Context, n int) error {
|
||||
func (l *FairQueueLimiter) WaitN(ctx context.Context, n int) error {
|
||||
id, _ := l.connIDGetter(ctx, adapter.ContextFrom(ctx))
|
||||
mainWait := &wait{ctx, make(chan struct{}), n}
|
||||
mainRequest := &request{ctx: ctx, done: make(chan struct{}), n: n}
|
||||
l.mtx.Lock()
|
||||
if waits, ok := l.waits[id]; ok {
|
||||
l.waits[id] = append(waits, mainWait)
|
||||
} else {
|
||||
l.waits[id] = []*wait{mainWait}
|
||||
elem, ok := l.index[id]
|
||||
if !ok {
|
||||
f := &flow{id: id, pending: l.pool.Get().(*list.List[*request])}
|
||||
elem = l.flows.PushFront(f)
|
||||
l.index[id] = elem
|
||||
}
|
||||
mainRequestElem := elem.Value.pending.PushBack(mainRequest)
|
||||
l.reorder(elem)
|
||||
l.mtx.Unlock()
|
||||
select {
|
||||
case l.queue <- struct{}{}:
|
||||
case <-mainWait.finish:
|
||||
case <-mainRequest.done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
l.mtx.Lock()
|
||||
for i, wait := range l.waits[id] {
|
||||
if wait == mainWait {
|
||||
l.waits[id] = slices.Delete(l.waits[id], i, i+1)
|
||||
close(wait.finish)
|
||||
break
|
||||
}
|
||||
}
|
||||
l.removeRequest(id, mainRequestElem)
|
||||
l.mtx.Unlock()
|
||||
return ctx.Err()
|
||||
}
|
||||
select {
|
||||
case <-mainRequest.done:
|
||||
<-l.queue
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
l.mtx.Lock()
|
||||
for i, wait := range l.waits[id] {
|
||||
if wait == mainWait {
|
||||
l.waits[id] = slices.Delete(l.waits[id], i, i+1)
|
||||
close(wait.finish)
|
||||
break
|
||||
}
|
||||
}
|
||||
l.removeRequest(id, mainRequestElem)
|
||||
l.mtx.Unlock()
|
||||
<-l.queue
|
||||
return ctx.Err()
|
||||
}
|
||||
l.mtx.Lock()
|
||||
now := time.Now()
|
||||
if l.reset.Compare(now) == -1 {
|
||||
clear(l.conns)
|
||||
clear(l.bytes)
|
||||
l.reset = now.Add(time.Second)
|
||||
}
|
||||
l.mtx.Lock()
|
||||
var minConnId string
|
||||
var minN int
|
||||
for connID, waits := range l.waits {
|
||||
if len(waits) == 0 {
|
||||
continue
|
||||
}
|
||||
if n, ok := l.conns[connID]; ok {
|
||||
if minConnId == "" {
|
||||
minConnId = connID
|
||||
minN = n
|
||||
continue
|
||||
}
|
||||
if n+waits[0].n < minN {
|
||||
minConnId = connID
|
||||
minN = n
|
||||
}
|
||||
} else {
|
||||
l.conns[connID] = 0
|
||||
minConnId = connID
|
||||
break
|
||||
}
|
||||
}
|
||||
minWait := l.waits[minConnId][0]
|
||||
l.waits[minConnId][0] = nil
|
||||
l.waits[minConnId] = l.waits[minConnId][1:]
|
||||
if len(l.waits) == 0 {
|
||||
delete(l.waits, minConnId)
|
||||
flowElem := l.flows.Front()
|
||||
flow := flowElem.Value
|
||||
firstRequestElem := flow.pending.Front()
|
||||
firstRequest := firstRequestElem.Value
|
||||
l.bytes[flow.id] += uint64(firstRequest.n)
|
||||
firstRequestElem.Remove()
|
||||
if flow.pending.Len() == 0 {
|
||||
l.flows.Remove(flowElem)
|
||||
delete(l.index, flow.id)
|
||||
l.pool.Put(flow.pending)
|
||||
} else {
|
||||
l.reorder(flowElem)
|
||||
}
|
||||
l.mtx.Unlock()
|
||||
err := l.limiter.WaitN(ctx, minWait.n)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
l.conns[minConnId] = l.conns[minConnId] + minWait.n
|
||||
close(minWait.finish)
|
||||
if minWait == mainWait {
|
||||
l.limiter.WaitN(firstRequest.ctx, firstRequest.n)
|
||||
close(firstRequest.done)
|
||||
if firstRequest == mainRequest {
|
||||
<-l.queue
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type wait struct {
|
||||
ctx context.Context
|
||||
finish chan struct{}
|
||||
n int
|
||||
func (l *FairQueueLimiter) reorder(elem *list.Element[*flow]) {
|
||||
f := elem.Value
|
||||
front := f.pending.Front()
|
||||
if front == nil {
|
||||
return
|
||||
}
|
||||
cost := l.bytes[f.id] + uint64(front.Value.n)
|
||||
for e := l.flows.Front(); e != nil; e = e.Next() {
|
||||
if e == elem {
|
||||
continue
|
||||
}
|
||||
eFront := e.Value.pending.Front()
|
||||
if eFront == nil {
|
||||
continue
|
||||
}
|
||||
if cost < l.bytes[e.Value.id]+uint64(eFront.Value.n) {
|
||||
l.flows.MoveBefore(elem, e)
|
||||
return
|
||||
}
|
||||
}
|
||||
l.flows.MoveToBack(elem)
|
||||
}
|
||||
|
||||
func (l *FairQueueLimiter) removeRequest(id string, elem *list.Element[*request]) {
|
||||
if !elem.Remove() {
|
||||
return
|
||||
}
|
||||
if flowElem, ok := l.index[id]; ok && flowElem.Value.pending.Len() == 0 {
|
||||
l.flows.Remove(flowElem)
|
||||
delete(l.index, id)
|
||||
l.pool.Put(flowElem.Value.pending)
|
||||
}
|
||||
}
|
||||
|
||||
type flow struct {
|
||||
id string
|
||||
pending *list.List[*request]
|
||||
}
|
||||
|
||||
type request struct {
|
||||
ctx context.Context
|
||||
done chan struct{}
|
||||
n int
|
||||
}
|
||||
|
||||
@@ -357,7 +357,7 @@ func createSpeedLimiter(speed uint64, flowKeys []string) (BandwidthLimiter, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
limiter = NewFlowKeysLimiter(getter, limiter)
|
||||
limiter = NewFairQueueLimiter(getter, limiter)
|
||||
}
|
||||
return limiter, nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/cloudflare"
|
||||
"github.com/sagernet/sing-box/common/congestion"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/service"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -132,6 +134,15 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
congestionControl, err := congestion.NewCongestionControl(
|
||||
options.CongestionController,
|
||||
options.CWND,
|
||||
ntp.TimeFuncFromContext(ctx),
|
||||
)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
tunnel, err := masque.NewTunnel(
|
||||
ctx,
|
||||
logger,
|
||||
@@ -156,6 +167,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
UDPKeepalivePeriod: udpKeepalivePeriod,
|
||||
UDPInitialPacketSize: options.UDPInitialPacketSize,
|
||||
ReconnectDelay: options.ReconnectDelay.Build(),
|
||||
CongestionControl: congestionControl,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@@ -104,6 +104,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
AllowedAddress: options.AllowedIPs,
|
||||
ReconnectDelay: time.Duration(options.ReconnectDelay),
|
||||
PingInterval: time.Duration(options.PingInterval),
|
||||
PingRestart: time.Duration(options.PingRestart),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
130
protocol/snell/inbound.go
Normal file
130
protocol/snell/inbound.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/snell"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func RegisterInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.SnellInboundOptions](registry, C.TypeSnell, NewInbound)
|
||||
}
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
router adapter.ConnectionRouterEx
|
||||
logger logger.ContextLogger
|
||||
listener *listener.Listener
|
||||
service *snell.Service
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellInboundOptions) (adapter.Inbound, error) {
|
||||
if options.PSK == "" {
|
||||
return nil, E.New("snell requires psk")
|
||||
}
|
||||
udpEnabled := common.Contains(options.Network.Build(), N.NetworkUDP)
|
||||
obfsMode := ""
|
||||
if options.Obfs != nil {
|
||||
obfsMode = options.Obfs.Mode
|
||||
}
|
||||
in := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeSnell, tag),
|
||||
router: router,
|
||||
logger: logger,
|
||||
}
|
||||
service, err := snell.NewService(snell.ServiceOptions{
|
||||
PSK: []byte(options.PSK),
|
||||
Version: options.Version,
|
||||
ObfsMode: obfsMode,
|
||||
UDP: udpEnabled,
|
||||
Logger: logger,
|
||||
Handler: (*inboundHandler)(in),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
in.service = service
|
||||
in.listener = listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
ConnectionHandler: in,
|
||||
})
|
||||
return in, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
func (h *Inbound) Close() error {
|
||||
return h.listener.Close()
|
||||
}
|
||||
|
||||
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
err := h.service.NewConnection(ctx, conn, metadata.Source)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
if err != nil {
|
||||
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
|
||||
}
|
||||
}
|
||||
|
||||
var _ adapter.TCPInjectableInbound = (*Inbound)(nil)
|
||||
|
||||
type inboundHandler Inbound
|
||||
|
||||
func (h *inboundHandler) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string) {
|
||||
var metadata adapter.InboundContext
|
||||
metadata.Inbound = h.Tag()
|
||||
metadata.InboundType = h.Type()
|
||||
metadata.InboundDetour = h.listener.ListenOptions().Detour
|
||||
metadata.Source = source
|
||||
metadata.Destination = destination
|
||||
if clientID != "" {
|
||||
metadata.User = clientID
|
||||
h.logger.InfoContext(ctx, "[", clientID, "] inbound connection to ", destination)
|
||||
} else {
|
||||
h.logger.InfoContext(ctx, "inbound connection to ", destination)
|
||||
}
|
||||
done := make(chan struct{})
|
||||
h.router.RouteConnectionEx(ctx, conn, metadata, N.OnceClose(func(error) {
|
||||
close(done)
|
||||
}))
|
||||
<-done
|
||||
}
|
||||
|
||||
func (h *inboundHandler) NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string) {
|
||||
var metadata adapter.InboundContext
|
||||
metadata.Inbound = h.Tag()
|
||||
metadata.InboundType = h.Type()
|
||||
metadata.InboundDetour = h.listener.ListenOptions().Detour
|
||||
metadata.Source = source
|
||||
if clientID != "" {
|
||||
metadata.User = clientID
|
||||
h.logger.InfoContext(ctx, "[", clientID, "] inbound packet connection")
|
||||
} else {
|
||||
h.logger.InfoContext(ctx, "inbound packet connection")
|
||||
}
|
||||
done := make(chan struct{})
|
||||
h.router.RoutePacketConnectionEx(ctx, bufio.NewPacketConn(conn), metadata, N.OnceClose(func(error) {
|
||||
close(done)
|
||||
}))
|
||||
<-done
|
||||
}
|
||||
114
protocol/snell/outbound.go
Normal file
114
protocol/snell/outbound.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/snell"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func RegisterOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.SnellOutboundOptions](registry, C.TypeSnell, NewOutbound)
|
||||
}
|
||||
|
||||
type Outbound struct {
|
||||
outbound.Adapter
|
||||
logger logger.ContextLogger
|
||||
client *snell.Client
|
||||
}
|
||||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SnellOutboundOptions) (adapter.Outbound, error) {
|
||||
if options.PSK == "" {
|
||||
return nil, E.New("snell requires psk")
|
||||
}
|
||||
version := options.Version
|
||||
if version == 0 {
|
||||
version = snell.DefaultSnellVersion
|
||||
}
|
||||
if version == snell.Version5 {
|
||||
version = snell.Version4
|
||||
}
|
||||
udpEnabled := common.Contains(options.Network.Build(), N.NetworkUDP)
|
||||
switch version {
|
||||
case snell.Version1, snell.Version2:
|
||||
if udpEnabled {
|
||||
return nil, fmt.Errorf("snell version %d does not support UDP", version)
|
||||
}
|
||||
case snell.Version3, snell.Version4:
|
||||
default:
|
||||
return nil, fmt.Errorf("snell version error: %d", version)
|
||||
}
|
||||
reuse := version == snell.Version2 || (version == snell.Version4 && options.Reuse)
|
||||
obfsMode := ""
|
||||
obfsHost := "bing.com"
|
||||
if options.Obfs != nil {
|
||||
switch options.Obfs.Mode {
|
||||
case "", "tls", "http":
|
||||
obfsMode = options.Obfs.Mode
|
||||
default:
|
||||
return nil, fmt.Errorf("snell obfs mode error: %s", options.Obfs.Mode)
|
||||
}
|
||||
if options.Obfs.Host != "" {
|
||||
obfsHost = options.Obfs.Host
|
||||
}
|
||||
}
|
||||
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := snell.NewClient(snell.ClientOptions{
|
||||
Dialer: outboundDialer,
|
||||
Server: options.ServerOptions.Build(),
|
||||
PSK: []byte(options.PSK),
|
||||
Version: version,
|
||||
Reuse: reuse,
|
||||
ObfsMode: obfsMode,
|
||||
ObfsHost: obfsHost,
|
||||
})
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSnell, tag, options.Network.Build(), options.DialerOptions),
|
||||
logger: logger,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
ctx, metadata := adapter.ExtendContext(ctx)
|
||||
metadata.Outbound = h.Tag()
|
||||
metadata.Destination = destination
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
h.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
return h.client.DialContext(ctx, destination)
|
||||
case N.NetworkUDP:
|
||||
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
conn, err := h.client.ListenPacket(ctx, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bufio.NewBindPacketConn(conn, destination), nil
|
||||
default:
|
||||
return nil, E.Extend(N.ErrUnknownNetwork, network)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
ctx, metadata := adapter.ExtendContext(ctx)
|
||||
metadata.Outbound = h.Tag()
|
||||
metadata.Destination = destination
|
||||
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
return h.client.ListenPacket(ctx, destination)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/common/congestion"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
@@ -136,10 +137,9 @@ func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
congestionControlFactory, err := trusttunnel.NewCongestionControl(
|
||||
congestionControlFactory, err := congestion.NewCongestionControl(
|
||||
h.options.CongestionController,
|
||||
h.options.CWND,
|
||||
h.options.BBRProfile,
|
||||
ntp.TimeFuncFromContext(h.ctx),
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@@ -53,7 +53,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
QUIC: options.QUIC,
|
||||
CongestionControl: options.CongestionController,
|
||||
CWND: options.CWND,
|
||||
BBRProfile: options.BBRProfile,
|
||||
Logger: logger,
|
||||
HealthCheck: options.HealthCheck,
|
||||
}
|
||||
var client trusttunnel.Dialer
|
||||
|
||||
@@ -1 +1 @@
|
||||
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_naive_outbound,badlinkname,tfogo_checklinkname0
|
||||
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,with_naive_outbound,badlinkname,tfogo_checklinkname0
|
||||
@@ -1 +1 @@
|
||||
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_manager,with_admin_panel,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,badlinkname,tfogo_checklinkname0
|
||||
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_manager,with_admin_panel,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,badlinkname,tfogo_checklinkname0
|
||||
|
||||
@@ -1 +1 @@
|
||||
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0
|
||||
with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_ccm,with_ocm,with_openvpn,with_trusttunnel,with_sudoku,with_snell,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/bin/sh
|
||||
[ -s ${IPKG_INSTROOT}/lib/functions.sh ] || exit 0
|
||||
. ${IPKG_INSTROOT}/lib/functions.sh
|
||||
default_prerm $0 $@
|
||||
default_prerm $0 $@ || true
|
||||
|
||||
@@ -58,6 +58,13 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata
|
||||
}
|
||||
|
||||
func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
|
||||
select {
|
||||
case <-r.started:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-r.ctx.Done():
|
||||
return r.ctx.Err()
|
||||
}
|
||||
//nolint:staticcheck
|
||||
if metadata.InboundDetour != "" {
|
||||
if metadata.LastInbound == metadata.InboundDetour {
|
||||
@@ -192,6 +199,13 @@ func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
||||
}
|
||||
|
||||
func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
|
||||
select {
|
||||
case <-r.started:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-r.ctx.Done():
|
||||
return r.ctx.Err()
|
||||
}
|
||||
//nolint:staticcheck
|
||||
if metadata.InboundDetour != "" {
|
||||
if metadata.LastInbound == metadata.InboundDetour {
|
||||
|
||||
146
route/route_start_test.go
Normal file
146
route/route_start_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
// newGateTestRouter builds the minimal Router needed to exercise the
|
||||
// "wait until started" gate in routeConnection / routePacketConnection.
|
||||
func newGateTestRouter(ctx context.Context) *Router {
|
||||
return &Router{
|
||||
ctx: ctx,
|
||||
started: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// gateMetadata returns metadata that hits the InboundDetour loop-detection
|
||||
// branch (LastInbound == InboundDetour), so routeConnection /
|
||||
// routePacketConnection return immediately once the gate is open, without
|
||||
// needing any outbound/inbound managers.
|
||||
func gateMetadata() adapter.InboundContext {
|
||||
return adapter.InboundContext{InboundDetour: "self", LastInbound: "self"}
|
||||
}
|
||||
|
||||
// TestRouteConnectionWaitsForStart verifies that a connection arriving before
|
||||
// the router finishes starting (StartStatePostStart) blocks until the started
|
||||
// channel is closed, then proceeds, instead of dereferencing a nil
|
||||
// defaultOutbound.
|
||||
func TestRouteConnectionWaitsForStart(t *testing.T) {
|
||||
r := newGateTestRouter(context.Background())
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- r.routeConnection(context.Background(), nil, gateMetadata(), nil)
|
||||
}()
|
||||
|
||||
// The gate must block while the router is not yet started.
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("routeConnection returned before router was started")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Simulate StartStatePostStart completing.
|
||||
close(r.started)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
// We expect to have passed the gate and reached the loop-detection branch,
|
||||
// which returns the "routing loop on detour" error.
|
||||
if err == nil {
|
||||
t.Fatal("expected routing-loop error after gate opened, got nil")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("routeConnection did not proceed after router started")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutePacketConnectionWaitsForStart is the UDP counterpart.
|
||||
func TestRoutePacketConnectionWaitsForStart(t *testing.T) {
|
||||
r := newGateTestRouter(context.Background())
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- r.routePacketConnection(context.Background(), nil, gateMetadata(), nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("routePacketConnection returned before router was started")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(r.started)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err == nil {
|
||||
t.Fatal("expected routing-loop error after gate opened, got nil")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("routePacketConnection did not proceed after router started")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouteConnectionAbortsOnConnContext verifies that a client disconnecting
|
||||
// during startup unblocks the gate via the per-connection context, instead of
|
||||
// hanging until the router starts.
|
||||
func TestRouteConnectionAbortsOnConnContext(t *testing.T) {
|
||||
r := newGateTestRouter(context.Background())
|
||||
|
||||
connCtx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- r.routeConnection(connCtx, nil, gateMetadata(), nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("routeConnection returned before context was cancelled")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("routeConnection did not abort on connection context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouteConnectionAbortsOnRouterContext verifies that shutting down the box
|
||||
// (router context cancellation) unblocks an in-flight early connection.
|
||||
func TestRouteConnectionAbortsOnRouterContext(t *testing.T) {
|
||||
routerCtx, cancel := context.WithCancel(context.Background())
|
||||
r := newGateTestRouter(routerCtx)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- r.routeConnection(context.Background(), nil, gateMetadata(), nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("routeConnection returned before router context was cancelled")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("routeConnection did not abort on router context cancellation")
|
||||
}
|
||||
}
|
||||
@@ -44,7 +44,7 @@ type Router struct {
|
||||
pauseManager pause.Manager
|
||||
trackers []adapter.ConnectionTracker
|
||||
platformInterface adapter.PlatformInterface
|
||||
started bool
|
||||
started chan struct{}
|
||||
}
|
||||
|
||||
func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) *Router {
|
||||
@@ -63,6 +63,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
|
||||
needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess,
|
||||
pauseManager: service.FromContext[pause.Manager](ctx),
|
||||
platformInterface: service.FromContext[adapter.PlatformInterface](ctx),
|
||||
started: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,7 +181,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
} else {
|
||||
r.defaultOutbound = r.outbound.Default()
|
||||
}
|
||||
r.started = true
|
||||
close(r.started)
|
||||
return nil
|
||||
case adapter.StartStateStarted:
|
||||
for _, ruleSet := range r.ruleSets {
|
||||
|
||||
@@ -29,13 +29,17 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
|
||||
case "":
|
||||
return nil, nil
|
||||
case C.RuleActionTypeRoute:
|
||||
overrideGateway := M.ParseAddr(action.RouteOptions.OverrideGateway)
|
||||
var overrideGateway *netip.Addr
|
||||
if action.RouteOptions.OverrideGateway != "" {
|
||||
parsed := M.ParseAddr(action.RouteOptions.OverrideGateway)
|
||||
overrideGateway = &parsed
|
||||
}
|
||||
return &RuleActionRoute{
|
||||
Outbound: action.RouteOptions.Outbound,
|
||||
RuleActionRouteOptions: RuleActionRouteOptions{
|
||||
OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptions.OverrideAddress, 0),
|
||||
OverridePort: action.RouteOptions.OverridePort,
|
||||
OverrideGateway: &overrideGateway,
|
||||
OverrideGateway: overrideGateway,
|
||||
NetworkStrategy: (*C.NetworkStrategy)(action.RouteOptions.NetworkStrategy),
|
||||
FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay),
|
||||
UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping,
|
||||
|
||||
14
test/go.mod
14
test/go.mod
@@ -1,6 +1,6 @@
|
||||
module test
|
||||
|
||||
go 1.26.1
|
||||
go 1.26.4
|
||||
|
||||
require github.com/sagernet/sing-box v0.0.0
|
||||
|
||||
@@ -14,15 +14,17 @@ replace github.com/sagernet/sing-mux => github.com/shtorm-7/sing-mux v0.3.4-exte
|
||||
|
||||
replace github.com/ameshkov/dnscrypt/v2 => github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0
|
||||
|
||||
replace github.com/sagernet/sing-vmess => github.com/starifly/sing-vmess v0.2.7-mod.9
|
||||
replace github.com/sagernet/sing-vmess => github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0
|
||||
|
||||
replace github.com/sagernet/sing => github.com/shtorm-7/sing v0.8.10-extended-1.0.0
|
||||
replace github.com/sagernet/sing => /home/shtorm/Projects/shtorm-7/sing
|
||||
|
||||
replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1
|
||||
replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0
|
||||
|
||||
replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0
|
||||
|
||||
replace github.com/shtorm-7/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0
|
||||
replace github.com/shtorm-7/go-cache/v2 => /home/shtorm/Projects/shtorm-7/go-cache
|
||||
|
||||
replace github.com/sagernet/smux => /home/shtorm/Projects/shtorm-7/smux
|
||||
|
||||
require (
|
||||
github.com/docker/docker v28.5.2+incompatible
|
||||
@@ -36,7 +38,6 @@ require (
|
||||
github.com/spyzhov/ajson v0.9.4
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/goleak v1.3.0
|
||||
golang.org/x/crypto v0.49.0
|
||||
golang.org/x/net v0.52.0
|
||||
)
|
||||
|
||||
@@ -221,6 +222,7 @@ require (
|
||||
go.uber.org/zap/exp v0.3.0 // indirect
|
||||
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
|
||||
golang.org/x/mod v0.34.0 // indirect
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
|
||||
14
test/go.sum
14
test/go.sum
@@ -362,8 +362,6 @@ github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75
|
||||
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA=
|
||||
github.com/sagernet/sing-tun v0.8.9 h1:ixFKKUGdVcJl4wb0xbL36hobiw9l6DIH497EQf5ILpM=
|
||||
github.com/sagernet/sing-tun v0.8.9/go.mod h1:QvarqUtHfj1ULaRR+6kZOS/OoCE+pYGq67A5tyIy+dQ=
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478=
|
||||
github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8=
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
|
||||
github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 h1:ws7BIsYLd31Wjifq88BYCHRVlgO+07iwil39s6ERba8=
|
||||
@@ -372,12 +370,14 @@ github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 h1:e5s7RKBd2rIPR0StbvZ2vTV
|
||||
github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI=
|
||||
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0 h1:PLZ/YHqnApPx13wt6MX3ItqESp4ueBr1tGSi0bEGqYw=
|
||||
github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.1.0/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4=
|
||||
github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1 h1:UeJkrCJJmIjTBywErVMx7fCSoBf4gh6QgT9bp9o1ajM=
|
||||
github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.1/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0=
|
||||
github.com/shtorm-7/sing v0.8.10-extended-1.0.0 h1:mAkyycCQOzCttPOR5fcHkJaZvXMQXeu3mbEfr8D+7A8=
|
||||
github.com/shtorm-7/sing v0.8.10-extended-1.0.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA=
|
||||
github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0 h1:iBLll4ZZG8ULQcHWs6gGslZWtBN72Zo1zjySzMVHF7g=
|
||||
github.com/shtorm-7/mtg-multi v1.11.0-extended-1.0.0/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0=
|
||||
github.com/shtorm-7/sing v0.8.10-extended-1.1.0 h1:P4JL2cugjvEvnYu8tMmpR30SE1qsS45RcnNEwzDz5as=
|
||||
github.com/shtorm-7/sing v0.8.10-extended-1.1.0/go.mod h1:olXxWQNqRW/l2Q6JI3b2Qmz8iQnIFlOeeH8bx6JhgUA=
|
||||
github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw=
|
||||
github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
|
||||
github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0 h1:WVheKmQH5hSQbJU1ZTKthKSutkTLWSb2hp4JuQhJBow=
|
||||
github.com/shtorm-7/sing-vmess v0.2.7-extended-1.0.0/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs=
|
||||
github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2 h1:hSMjh97OszszOd8HrzpaYUQH9dWRRBluJCbwQyz8ZOk=
|
||||
github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2/go.mod h1:TYIIqO5sZpWq873rLIeO2usszSMUpR3h6WdqVVs65ug=
|
||||
github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.4.3 h1:jtOA73D4F5qRV70//ahOt20KBnWvQimAFjtIiOtt0ps=
|
||||
@@ -388,8 +388,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/spyzhov/ajson v0.9.4 h1:MVibcTCgO7DY4IlskdqIlCmDOsUOZ9P7oKj8ifdcf84=
|
||||
github.com/spyzhov/ajson v0.9.4/go.mod h1:a6oSw0MMb7Z5aD2tPoPO+jq11ETKgXUr2XktHdT8Wt8=
|
||||
github.com/starifly/sing-vmess v0.2.7-mod.9 h1:xobAmejSbBQ0A3f/EtJ9cJd3m6gK7dDPccPdeGz7tXY=
|
||||
github.com/starifly/sing-vmess v0.2.7-mod.9/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
|
||||
@@ -217,3 +217,19 @@ func TestV2RayGRPCLite(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestV2RayGRPCLiteServiceNameWithSlashes(t *testing.T) {
|
||||
testV2RayTransportSelfWith(t, &option.V2RayTransportOptions{
|
||||
Type: C.V2RayTransportTypeGRPC,
|
||||
GRPCOptions: option.V2RayGRPCOptions{
|
||||
ServiceName: "/47559/I53GwKHO",
|
||||
ForceLite: true,
|
||||
},
|
||||
}, &option.V2RayTransportOptions{
|
||||
Type: C.V2RayTransportTypeGRPC,
|
||||
GRPCOptions: option.V2RayGRPCOptions{
|
||||
ServiceName: "/47559/I53GwKHO",
|
||||
ForceLite: true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
331
transport/masque/client_h2.go
Normal file
331
transport/masque/client_h2.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package masque
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/sagernet/quic-go/quicvarint"
|
||||
"github.com/yosida95/uritemplate/v3"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const h2DatagramCapsuleType uint64 = 0
|
||||
|
||||
const (
|
||||
ipv4HeaderLen = 20
|
||||
ipv6HeaderLen = 40
|
||||
)
|
||||
|
||||
func ConnectTunnelH2(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, endpoint *net.TCPAddr, connectUri string) (io.Closer, IpConn, *http.Response, error) {
|
||||
if endpoint == nil {
|
||||
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
}
|
||||
|
||||
tlsConfig.SetNextProtos([]string{"h2"})
|
||||
|
||||
conn, err := dialer.DialContext(ctx, N.NetworkTCP, M.SocksaddrFromNetIP(endpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
tlsConn, err := tlsConfig.Client(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if err = tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
tr := &http2.Transport{
|
||||
ReadIdleTimeout: 30 * time.Second,
|
||||
}
|
||||
cc, err := tr.NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
_ = tlsConn.Close()
|
||||
return nil, nil, nil, fmt.Errorf("connect-ip: failed to create client connection: %w", err)
|
||||
}
|
||||
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
}
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
|
||||
h2Headers := additionalHeaders.Clone()
|
||||
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
|
||||
h2Headers.Set("pq-enabled", "false")
|
||||
|
||||
ipConn, rsp, err := dialH2(ctx, cc, template, h2Headers)
|
||||
if err != nil {
|
||||
_ = cc.Close()
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
|
||||
}
|
||||
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
_ = ipConn.Close()
|
||||
_ = cc.Close()
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status)
|
||||
}
|
||||
|
||||
return cc, ipConn, rsp, nil
|
||||
}
|
||||
|
||||
func dialH2(ctx context.Context, rt http.RoundTripper, template *uritemplate.Template, additionalHeaders http.Header) (*h2IpConn, *http.Response, error) {
|
||||
if len(template.Varnames()) > 0 {
|
||||
return nil, nil, errors.New("connect-ip: IP flow forwarding not supported")
|
||||
}
|
||||
|
||||
u, err := url.Parse(template.Raw())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to parse URI: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, u.String(), pr)
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to create request: %w", err)
|
||||
}
|
||||
req.Host = authorityFromURL(u)
|
||||
req.ContentLength = -1
|
||||
req.Header = make(http.Header)
|
||||
for k, v := range additionalHeaders {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
stop := context.AfterFunc(ctx, cancel)
|
||||
rsp, err := rt.RoundTrip(req)
|
||||
stop()
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to send request: %w", err)
|
||||
}
|
||||
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
_ = rsp.Body.Close()
|
||||
return nil, rsp, fmt.Errorf("connect-ip: server responded with %d", rsp.StatusCode)
|
||||
}
|
||||
|
||||
stream := &h2DatagramStream{
|
||||
requestBody: pw,
|
||||
responseBody: rsp.Body,
|
||||
cancel: cancel,
|
||||
}
|
||||
return &h2IpConn{
|
||||
str: stream,
|
||||
closeChan: make(chan struct{}),
|
||||
}, rsp, nil
|
||||
}
|
||||
|
||||
func authorityFromURL(u *url.URL) string {
|
||||
if u.Port() != "" {
|
||||
return u.Host
|
||||
}
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return u.Host
|
||||
}
|
||||
return host + ":443"
|
||||
}
|
||||
|
||||
type h2IpConn struct {
|
||||
str *h2DatagramStream
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
closeChan chan struct{}
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (c *h2IpConn) ReadPacket() (b []byte, err error) {
|
||||
start:
|
||||
data, err := c.str.ReceiveDatagram(context.Background())
|
||||
if err != nil {
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
select {
|
||||
case <-c.closeChan:
|
||||
return nil, c.closeErr
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := c.handleIncomingProxiedPacket(data); err != nil {
|
||||
goto start
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) handleIncomingProxiedPacket(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return errors.New("connect-ip: empty packet")
|
||||
}
|
||||
switch v := ipVersion(data); v {
|
||||
default:
|
||||
return fmt.Errorf("connect-ip: unknown IP versions: %d", v)
|
||||
case 4:
|
||||
if len(data) < ipv4HeaderLen {
|
||||
return fmt.Errorf("connect-ip: malformed datagram: too short")
|
||||
}
|
||||
case 6:
|
||||
if len(data) < ipv6HeaderLen {
|
||||
return fmt.Errorf("connect-ip: malformed datagram: too short")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) WritePacket(b []byte) (icmp []byte, err error) {
|
||||
data, err := c.composeDatagram(b)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err := c.str.SendDatagram(data); err != nil {
|
||||
select {
|
||||
case <-c.closeChan:
|
||||
return nil, c.closeErr
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) composeDatagram(b []byte) ([]byte, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
switch v := ipVersion(b); v {
|
||||
default:
|
||||
return nil, fmt.Errorf("connect-ip: unknown IP versions: %d", v)
|
||||
case 4:
|
||||
if len(b) < ipv4HeaderLen {
|
||||
return nil, fmt.Errorf("connect-ip: IPv4 packet too short")
|
||||
}
|
||||
ttl := b[8]
|
||||
if ttl <= 1 {
|
||||
return nil, fmt.Errorf("connect-ip: datagram TTL too small: %d", ttl)
|
||||
}
|
||||
b[8]--
|
||||
binary.BigEndian.PutUint16(b[10:12], calculateIPv4Checksum(([ipv4HeaderLen]byte)(b[:ipv4HeaderLen])))
|
||||
case 6:
|
||||
if len(b) < ipv6HeaderLen {
|
||||
return nil, fmt.Errorf("connect-ip: IPv6 packet too short")
|
||||
}
|
||||
hopLimit := b[7]
|
||||
if hopLimit <= 1 {
|
||||
return nil, fmt.Errorf("connect-ip: datagram Hop Limit too small: %d", hopLimit)
|
||||
}
|
||||
b[7]--
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) Close() error {
|
||||
c.mu.Lock()
|
||||
if c.closeErr == nil {
|
||||
c.closeErr = net.ErrClosed
|
||||
close(c.closeChan)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
err := c.str.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func ipVersion(b []byte) uint8 { return b[0] >> 4 }
|
||||
|
||||
func calculateIPv4Checksum(header [ipv4HeaderLen]byte) uint16 {
|
||||
var sum uint32
|
||||
for i := 0; i < len(header); i += 2 {
|
||||
if i == 10 {
|
||||
continue
|
||||
}
|
||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||
}
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xffff) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
type h2DatagramStream struct {
|
||||
requestBody *io.PipeWriter
|
||||
responseBody io.ReadCloser
|
||||
cancel context.CancelFunc
|
||||
|
||||
readMu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) ReceiveDatagram(_ context.Context) ([]byte, error) {
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
|
||||
reader := quicvarint.NewReader(s.responseBody)
|
||||
for {
|
||||
capsuleType, err := quicvarint.Read(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen, err := quicvarint.Read(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := make([]byte, payloadLen)
|
||||
_, err = io.ReadFull(reader, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if capsuleType != h2DatagramCapsuleType {
|
||||
continue
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) SendDatagram(data []byte) error {
|
||||
frame := make([]byte, 0, quicvarint.Len(h2DatagramCapsuleType)+quicvarint.Len(uint64(len(data)))+len(data))
|
||||
frame = quicvarint.Append(frame, h2DatagramCapsuleType)
|
||||
frame = quicvarint.Append(frame, uint64(len(data)))
|
||||
frame = append(frame, data...)
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
_, err := s.requestBody.Write(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect-ip: failed to send datagram capsule: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) Close() error {
|
||||
_ = s.requestBody.Close()
|
||||
err := s.responseBody.Close()
|
||||
s.cancel()
|
||||
return err
|
||||
}
|
||||
@@ -2,9 +2,9 @@ package masque
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -12,13 +12,13 @@ import (
|
||||
|
||||
connectip "github.com/Diniboy1123/connect-ip-go"
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
qtls "github.com/sagernet/sing-quic"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
"github.com/yosida95/uritemplate/v3"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -26,39 +26,60 @@ type (
|
||||
ListenPacket func(network string, address string) (net.PacketConn, error)
|
||||
)
|
||||
|
||||
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool) (net.PacketConn, *http3.Transport, *connectip.Conn, *http.Response, error) {
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
type IpConn interface {
|
||||
ReadPacket() (b []byte, err error)
|
||||
WritePacket(b []byte) (icmp []byte, err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type closerFunc func() error
|
||||
|
||||
func (f closerFunc) Close() error { return f() }
|
||||
|
||||
type quicIpConn struct {
|
||||
conn *connectip.Conn
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newQuicIpConn(conn *connectip.Conn) *quicIpConn {
|
||||
return &quicIpConn{
|
||||
conn: conn,
|
||||
buf: make([]byte, 0xFFFF),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *quicIpConn) ReadPacket() ([]byte, error) {
|
||||
n, err := c.conn.ReadPacket(c.buf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.buf[:n], nil
|
||||
}
|
||||
|
||||
func (c *quicIpConn) WritePacket(b []byte) (icmp []byte, err error) {
|
||||
return c.conn.WritePacket(b)
|
||||
}
|
||||
|
||||
func (c *quicIpConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool, congestionControl func(conn *quic.Conn) congestion.CongestionControl) (io.Closer, IpConn, *http.Response, error) {
|
||||
if useHTTP2 {
|
||||
h2Endpoint, ok := endpoint.(*net.TCPAddr)
|
||||
if !ok || h2Endpoint == nil {
|
||||
return nil, nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
}
|
||||
h2Headers := additionalHeaders.Clone()
|
||||
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
|
||||
h2Headers.Set("pq-enabled", "false")
|
||||
h2Client, err := newHTTP2Client(dialer, tlsConfig, h2Endpoint, connectUri)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to create HTTP/2 client: %w", err)
|
||||
}
|
||||
ipConn, rsp, err := connectip.DialH2(ctx, h2Client, template, h2Headers)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return nil, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
|
||||
}
|
||||
return nil, nil, ipConn, rsp, nil
|
||||
return ConnectTunnelH2(ctx, dialer, tlsConfig, h2Endpoint, connectUri)
|
||||
}
|
||||
|
||||
quicEndpoint, ok := endpoint.(*net.UDPAddr)
|
||||
if !ok || quicEndpoint == nil {
|
||||
return nil, nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
|
||||
return nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
|
||||
}
|
||||
udpConn, err := dialer.ListenPacket(ctx, M.SocksaddrFromNetIP(quicEndpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
conn, err := qtls.Dial(
|
||||
ctx,
|
||||
@@ -68,28 +89,34 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
|
||||
quicConfig,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
_ = udpConn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if congestionControl != nil {
|
||||
conn.SetCongestionControl(congestionControl(conn))
|
||||
}
|
||||
tr := &http3.Transport{
|
||||
EnableDatagrams: true,
|
||||
AdditionalSettings: map[uint64]uint64{
|
||||
// official client still sends this out as well, even though
|
||||
// it's deprecated, see https://datatracker.ietf.org/doc/draft-ietf-masque-h3-datagram/00/
|
||||
// SETTINGS_H3_DATAGRAM_00 = 0x0000000000000276
|
||||
// https://github.com/cloudflare/quiche/blob/7c66757dbc55b8d0c3653d4b345c6785a181f0b7/quiche/src/h3/frame.rs#L46
|
||||
0x276: 1,
|
||||
},
|
||||
DisableCompression: true,
|
||||
}
|
||||
hconn := tr.NewClientConn(conn)
|
||||
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
}
|
||||
ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true)
|
||||
if err != nil {
|
||||
_ = tr.Close()
|
||||
_ = conn.CloseWithError(0, "connect-ip dial failed")
|
||||
_ = udpConn.Close()
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return udpConn, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return udpConn, nil, nil, nil, fmt.Errorf("failed to dial connect-ip: %w", err)
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %w", err)
|
||||
}
|
||||
err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{
|
||||
{
|
||||
@@ -109,34 +136,16 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return udpConn, nil, nil, nil, err
|
||||
_ = ipConn.Close()
|
||||
_ = tr.Close()
|
||||
_ = udpConn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return udpConn, tr, ipConn, rsp, nil
|
||||
}
|
||||
|
||||
func newHTTP2Client(dialer N.Dialer, baseTLSConfig aTLS.Config, endpoint *net.TCPAddr, connectURI string) (*http.Client, error) {
|
||||
if endpoint == nil {
|
||||
return nil, errors.New("missing HTTP/2 endpoint")
|
||||
}
|
||||
tlsConfig := baseTLSConfig.Clone()
|
||||
tlsConfig.SetNextProtos([]string{"h2"})
|
||||
return &http.Client{
|
||||
Transport: &http2.Transport{
|
||||
DialTLSContext: func(ctx context.Context, network, _ string, _ *tls.Config) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctx, network, M.SocksaddrFromNetIP(endpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConn, err := tlsConfig.Client(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return tlsConn, nil
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
closer := closerFunc(func() error {
|
||||
_ = tr.Close()
|
||||
_ = udpConn.Close()
|
||||
return nil
|
||||
})
|
||||
return closer, newQuicIpConn(ipConn), rsp, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
@@ -23,4 +25,5 @@ type TunnelOptions struct {
|
||||
UDPKeepalivePeriod time.Duration
|
||||
UDPInitialPacketSize uint16
|
||||
ReconnectDelay time.Duration
|
||||
CongestionControl func(conn *quic.Conn) congestion.CongestionControl
|
||||
}
|
||||
|
||||
@@ -4,13 +4,12 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
connectip "github.com/Diniboy1123/connect-ip-go"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
@@ -22,9 +21,8 @@ type Tunnel struct {
|
||||
options TunnelOptions
|
||||
device Device
|
||||
|
||||
udpConn net.PacketConn
|
||||
tr *http3.Transport
|
||||
ipConn *connectip.Conn
|
||||
closer io.Closer
|
||||
ipConn IpConn
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
@@ -83,13 +81,11 @@ func (e *Tunnel) Close() error {
|
||||
defer e.mtx.Unlock()
|
||||
if e.ipConn != nil {
|
||||
e.ipConn.Close()
|
||||
if e.udpConn != nil {
|
||||
e.udpConn.Close()
|
||||
}
|
||||
if e.tr != nil {
|
||||
e.tr.Close()
|
||||
if e.closer != nil {
|
||||
e.closer.Close()
|
||||
}
|
||||
e.ipConn = nil
|
||||
e.closer = nil
|
||||
}
|
||||
return e.device.Close()
|
||||
}
|
||||
@@ -124,7 +120,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
icmp, err := ipConn.WritePacket(packet)
|
||||
if err != nil {
|
||||
if errors.As(err, new(*connectip.CloseError)) {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
if ok := e.closeIpConn(ipConn); ok {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing to IP connection: %w", err))
|
||||
}
|
||||
@@ -135,7 +131,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
if len(icmp) > 0 {
|
||||
if _, err := e.device.Write([][]byte{icmp}, 0); err != nil {
|
||||
if errors.As(err, new(*connectip.CloseError)) {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing ICMP to TUN device: %v", err))
|
||||
continue
|
||||
}
|
||||
@@ -145,15 +141,14 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
buf := make([]byte, 1280)
|
||||
for e.ctx.Err() == nil {
|
||||
ipConn, err := e.getIpConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n, err := ipConn.ReadPacket(buf, true)
|
||||
packet, err := ipConn.ReadPacket()
|
||||
if err != nil {
|
||||
if e.options.UseHTTP2 || errors.As(err, new(*connectip.CloseError)) {
|
||||
if e.options.UseHTTP2 || errors.Is(err, net.ErrClosed) {
|
||||
if ok := e.closeIpConn(ipConn); ok {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while reading from IP connection: %v", err))
|
||||
}
|
||||
@@ -162,7 +157,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Error reading from IP connection: %v, continuine...", err))
|
||||
continue
|
||||
}
|
||||
if _, err := e.device.Write([][]byte{buf[:n]}, 0); err != nil {
|
||||
if _, err := e.device.Write([][]byte{packet}, 0); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -170,7 +165,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
<-e.ctx.Done()
|
||||
}
|
||||
|
||||
func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
func (e *Tunnel) getIpConn() (IpConn, error) {
|
||||
e.mtx.Lock()
|
||||
defer e.mtx.Unlock()
|
||||
if e.ctx.Err() != nil {
|
||||
@@ -184,7 +179,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
defer timer.Stop()
|
||||
for {
|
||||
e.logger.NoticeContext(e.ctx, fmt.Errorf("Establishing MASQUE connection to %s", e.options.Endpoint))
|
||||
udpConn, tr, ipConn, rsp, err := ConnectTunnel(
|
||||
closer, ipConn, rsp, err := ConnectTunnel(
|
||||
e.ctx,
|
||||
e.options.Dialer,
|
||||
e.options.TLSConfig,
|
||||
@@ -192,6 +187,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
"https://cloudflareaccess.com",
|
||||
e.options.Endpoint,
|
||||
e.options.UseHTTP2,
|
||||
e.options.CongestionControl,
|
||||
)
|
||||
if err != nil {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Failed to connect tunnel: %v", err))
|
||||
@@ -206,11 +202,8 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
if rsp.StatusCode != 200 {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Tunnel connection failed: %s", rsp.Status))
|
||||
ipConn.Close()
|
||||
if udpConn != nil {
|
||||
udpConn.Close()
|
||||
}
|
||||
if tr != nil {
|
||||
tr.Close()
|
||||
if closer != nil {
|
||||
closer.Close()
|
||||
}
|
||||
timer.Reset(e.options.ReconnectDelay)
|
||||
select {
|
||||
@@ -220,26 +213,23 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
e.udpConn = udpConn
|
||||
e.tr = tr
|
||||
e.closer = closer
|
||||
e.ipConn = ipConn
|
||||
e.logger.NoticeContext(e.ctx, "Connected to MASQUE server ", e.options.Endpoint)
|
||||
return ipConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Tunnel) closeIpConn(ipConn *connectip.Conn) bool {
|
||||
func (e *Tunnel) closeIpConn(ipConn IpConn) bool {
|
||||
e.mtx.Lock()
|
||||
defer e.mtx.Unlock()
|
||||
if ipConn == e.ipConn {
|
||||
e.ipConn.Close()
|
||||
if e.udpConn != nil {
|
||||
e.udpConn.Close()
|
||||
}
|
||||
if e.tr != nil {
|
||||
e.tr.Close()
|
||||
if e.closer != nil {
|
||||
e.closer.Close()
|
||||
}
|
||||
e.ipConn = nil
|
||||
e.closer = nil
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
@@ -23,7 +24,7 @@ const (
|
||||
|
||||
type DataCipher interface {
|
||||
Encrypt(header []byte, packetID uint32, payload []byte) ([]byte, error)
|
||||
Decrypt(packet []byte, headerSize int) ([]byte, error)
|
||||
Decrypt(packet []byte, headerSize int) (plaintext []byte, packetID uint32, err error)
|
||||
}
|
||||
|
||||
type AEADDataCipher struct {
|
||||
@@ -86,9 +87,9 @@ func (g *AEADDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) {
|
||||
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
|
||||
if len(packet) < headerSize+4+AESGCMTagSize+1 {
|
||||
return nil, errors.New("openvpn gcm data packet too short")
|
||||
return nil, 0, errors.New("openvpn gcm data packet too short")
|
||||
}
|
||||
header := packet[:headerSize]
|
||||
pidBytes := packet[headerSize : headerSize+4]
|
||||
@@ -96,8 +97,13 @@ func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error)
|
||||
ciphertext := packet[headerSize+4+AESGCMTagSize:]
|
||||
combined := append(ciphertext, tag...)
|
||||
ad := append(header, pidBytes...)
|
||||
nonce := g.nonce(binary.BigEndian.Uint32(pidBytes), g.recvImplicitIV)
|
||||
return g.recv.Open(nil, nonce[:], combined, ad)
|
||||
packetID := binary.BigEndian.Uint32(pidBytes)
|
||||
nonce := g.nonce(packetID, g.recvImplicitIV)
|
||||
plain, err := g.recv.Open(nil, nonce[:], combined, ad)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return plain, packetID, nil
|
||||
}
|
||||
|
||||
func (g *AEADDataCipher) nonce(packetID uint32, implicit [AESGCMIVSize]byte) [AESGCMIVSize]byte {
|
||||
@@ -127,6 +133,9 @@ func NewCBCCipher(keys *KeyMaterial, auth string) (*CBCDataCipher, error) {
|
||||
var newHash func() hash.Hash
|
||||
var hmacSize int
|
||||
switch auth {
|
||||
case AuthMD5:
|
||||
newHash = md5.New
|
||||
hmacSize = md5.Size
|
||||
case AuthSHA256:
|
||||
newHash = sha256.New
|
||||
hmacSize = sha256.Size
|
||||
@@ -176,34 +185,35 @@ func (c *CBCDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) {
|
||||
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
|
||||
minSize := headerSize + c.hmacSize + CBCIVSize + aes.BlockSize
|
||||
if len(packet) < minSize {
|
||||
return nil, errors.New("openvpn cbc data packet too short")
|
||||
return nil, 0, errors.New("openvpn cbc data packet too short")
|
||||
}
|
||||
tag := packet[headerSize : headerSize+c.hmacSize]
|
||||
iv := packet[headerSize+c.hmacSize : headerSize+c.hmacSize+CBCIVSize]
|
||||
ct := packet[headerSize+c.hmacSize+CBCIVSize:]
|
||||
if len(ct)%aes.BlockSize != 0 {
|
||||
return nil, errors.New("openvpn cbc ciphertext not block-aligned")
|
||||
return nil, 0, errors.New("openvpn cbc ciphertext not block-aligned")
|
||||
}
|
||||
mac := hmac.New(c.newHash, c.recvHMAC)
|
||||
mac.Write(iv)
|
||||
mac.Write(ct)
|
||||
if !hmac.Equal(tag, mac.Sum(nil)) {
|
||||
return nil, errors.New("openvpn cbc hmac verification failed")
|
||||
return nil, 0, errors.New("openvpn cbc hmac verification failed")
|
||||
}
|
||||
plain := make([]byte, len(ct))
|
||||
cipher.NewCBCDecrypter(c.recvBlock, iv).CryptBlocks(plain, ct)
|
||||
padLen := int(plain[len(plain)-1])
|
||||
if padLen < 1 || padLen > aes.BlockSize {
|
||||
return nil, errors.New("openvpn cbc invalid padding")
|
||||
return nil, 0, errors.New("openvpn cbc invalid padding")
|
||||
}
|
||||
plain = plain[:len(plain)-padLen]
|
||||
if len(plain) < 4 {
|
||||
return nil, errors.New("openvpn cbc payload too short")
|
||||
return nil, 0, errors.New("openvpn cbc payload too short")
|
||||
}
|
||||
return plain[4:], nil
|
||||
packetID := binary.BigEndian.Uint32(plain[:4])
|
||||
return plain[4:], packetID, nil
|
||||
}
|
||||
|
||||
func CipherKeyLength(cipher string) int {
|
||||
|
||||
@@ -8,12 +8,16 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
const defaultHandshakeTimeout = 30 * time.Second
|
||||
const (
|
||||
defaultHandshakeTimeout = 30 * time.Second
|
||||
controlRetransmitDelay = time.Second
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
config *ClientConfig
|
||||
@@ -26,6 +30,8 @@ type Client struct {
|
||||
push *PushReply
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
lastReceiveNano atomic.Int64
|
||||
}
|
||||
|
||||
func NewClient(config *ClientConfig, io PacketIO, tlsConfig tls.Config) (*Client, error) {
|
||||
@@ -154,6 +160,7 @@ func (c *Client) Handshake(ctx context.Context) (*PushReply, error) {
|
||||
return nil, err
|
||||
}
|
||||
c.data = NewDataChannel(cipher, push.PeerID, push.CompLZO)
|
||||
c.markReceive()
|
||||
return push, nil
|
||||
}
|
||||
|
||||
@@ -181,10 +188,21 @@ func (c *Client) ReadIPPacket(ctx context.Context) ([]byte, error) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c.markReceive()
|
||||
return plain, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SinceReceive() time.Duration {
|
||||
return time.Duration(int64(time.Since(clientStart)) - c.lastReceiveNano.Load())
|
||||
}
|
||||
|
||||
func (c *Client) markReceive() {
|
||||
c.lastReceiveNano.Store(int64(time.Since(clientStart)))
|
||||
}
|
||||
|
||||
var clientStart = time.Now().Add(-time.Hour)
|
||||
|
||||
func (c *Client) Close() error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
@@ -199,10 +217,24 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) waitServerReset(ctx context.Context) error {
|
||||
retransmits := 0
|
||||
for {
|
||||
packet, err := c.control.Read(ctx)
|
||||
readCtx := ctx
|
||||
cancel := func() {}
|
||||
if c.config.Proto == ProtoUDP {
|
||||
readCtx, cancel = context.WithTimeout(ctx, controlRetransmitDelay)
|
||||
}
|
||||
packet, err := c.control.Read(readCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read hard reset response: %w", err)
|
||||
if c.config.Proto == ProtoUDP && errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil {
|
||||
if err := c.control.RetransmitPending(ctx); err != nil {
|
||||
return fmt.Errorf("retransmit hard reset: %w", err)
|
||||
}
|
||||
retransmits++
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read hard reset response after %d retransmits: %w", retransmits, err)
|
||||
}
|
||||
switch packet.Opcode {
|
||||
case PControlHardResetServerV2:
|
||||
|
||||
@@ -20,6 +20,7 @@ const (
|
||||
CipherAES256CBC = "AES-256-CBC"
|
||||
CipherCHACHA20POLY = "CHACHA20-POLY1305"
|
||||
|
||||
AuthMD5 = "MD5"
|
||||
AuthSHA1 = "SHA1"
|
||||
AuthSHA256 = "SHA256"
|
||||
AuthSHA384 = "SHA384"
|
||||
@@ -107,7 +108,7 @@ func isValidCipher(cipher string) bool {
|
||||
|
||||
func isValidAuth(auth string) bool {
|
||||
switch auth {
|
||||
case AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
|
||||
case AuthMD5, AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -30,8 +30,10 @@ type ControlChannel struct {
|
||||
mu sync.Mutex
|
||||
sendPacketID uint32
|
||||
sendMessage uint32
|
||||
recvMessage uint32
|
||||
ackPending []uint32
|
||||
pending map[uint32]*ControlPacket
|
||||
recvPending map[uint32]*ControlPacket
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
@@ -40,9 +42,10 @@ func NewControlChannel(io PacketIO, crypt ControlCrypt, local SessionID) *Contro
|
||||
ch := &ControlChannel{
|
||||
io: io,
|
||||
|
||||
clock: time.Now,
|
||||
local: local,
|
||||
pending: make(map[uint32]*ControlPacket),
|
||||
clock: time.Now,
|
||||
local: local,
|
||||
pending: make(map[uint32]*ControlPacket),
|
||||
recvPending: make(map[uint32]*ControlPacket),
|
||||
}
|
||||
if crypt != nil {
|
||||
ch.encode = func(p *ControlPacket, pid uint32, t uint32) ([]byte, error) {
|
||||
@@ -130,10 +133,23 @@ func (c *ControlChannel) SendAck(ctx context.Context) error {
|
||||
|
||||
func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
|
||||
for {
|
||||
c.mu.Lock()
|
||||
if packet, ok := c.recvPending[c.recvMessage]; ok {
|
||||
delete(c.recvPending, c.recvMessage)
|
||||
c.recvMessage++
|
||||
c.mu.Unlock()
|
||||
return packet, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
packet, err := c.readControlPacket(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var deliver *ControlPacket
|
||||
sendAck := false
|
||||
|
||||
c.mu.Lock()
|
||||
if c.remote == (SessionID{}) && packet.LocalSession != c.local {
|
||||
c.remote = packet.LocalSession
|
||||
@@ -144,11 +160,33 @@ func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
|
||||
if packet.Opcode.HasMessageID() {
|
||||
c.ackPending = appendAck(c.ackPending, packet.MessageID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
if packet.Opcode == PAckV1 {
|
||||
continue
|
||||
|
||||
switch {
|
||||
case packet.Opcode == PAckV1:
|
||||
case !packet.Opcode.HasMessageID():
|
||||
deliver = packet
|
||||
case packet.MessageID < c.recvMessage:
|
||||
sendAck = true
|
||||
case packet.MessageID == c.recvMessage:
|
||||
deliver = packet
|
||||
c.recvMessage++
|
||||
default:
|
||||
if _, exists := c.recvPending[packet.MessageID]; !exists {
|
||||
c.recvPending[packet.MessageID] = packet
|
||||
}
|
||||
sendAck = true
|
||||
}
|
||||
|
||||
c.mu.Unlock()
|
||||
|
||||
if deliver != nil {
|
||||
return deliver, nil
|
||||
}
|
||||
if sendAck {
|
||||
if err := c.SendAck(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -349,11 +387,17 @@ func (c *ControlConn) SetWriteDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
type streamPacketIO struct {
|
||||
conn net.Conn
|
||||
conn net.Conn
|
||||
deadlineMu sync.Mutex
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
|
||||
type datagramPacketIO struct {
|
||||
conn net.Conn
|
||||
conn net.Conn
|
||||
deadlineMu sync.Mutex
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
|
||||
func NewDatagramPacketIO(conn net.Conn) PacketIO {
|
||||
@@ -361,40 +405,23 @@ func NewDatagramPacketIO(conn net.Conn) PacketIO {
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
|
||||
done := make(chan struct{})
|
||||
var (
|
||||
packet []byte
|
||||
err error
|
||||
)
|
||||
go func() {
|
||||
defer close(done)
|
||||
buf := make([]byte, 64*1024)
|
||||
var n int
|
||||
n, err = d.conn.Read(buf)
|
||||
if err == nil {
|
||||
packet = cloneBytes(buf[:n])
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return packet, err
|
||||
if err := setReadDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.readDeadline); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := make([]byte, 64*1024)
|
||||
n, err := d.conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
return buf[:n], nil
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) WritePacket(ctx context.Context, packet []byte) error {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := d.conn.Write(packet)
|
||||
done <- err
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
if err := setWriteDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.writeDeadline); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := d.conn.Write(packet)
|
||||
return contextIOError(ctx, err)
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) Close() error {
|
||||
@@ -414,52 +441,37 @@ func NewTCPPacketIO(conn net.Conn) PacketIO {
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
|
||||
done := make(chan struct{})
|
||||
var (
|
||||
packet []byte
|
||||
err error
|
||||
)
|
||||
go func() {
|
||||
defer close(done)
|
||||
var lenBuf [2]byte
|
||||
if _, err = io.ReadFull(s.conn, lenBuf[:]); err != nil {
|
||||
return
|
||||
}
|
||||
size := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if size == 0 {
|
||||
err = errors.New("empty openvpn tcp packet")
|
||||
return
|
||||
}
|
||||
packet = make([]byte, size)
|
||||
_, err = io.ReadFull(s.conn, packet)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return packet, err
|
||||
if err := setReadDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.readDeadline); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var lenBuf [2]byte
|
||||
if _, err := io.ReadFull(s.conn, lenBuf[:]); err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
size := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if size == 0 {
|
||||
return nil, errors.New("empty openvpn tcp packet")
|
||||
}
|
||||
packet := make([]byte, size)
|
||||
if _, err := io.ReadFull(s.conn, packet); err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) WritePacket(ctx context.Context, packet []byte) error {
|
||||
if len(packet) > 0xffff {
|
||||
return fmt.Errorf("openvpn tcp packet too large: %d", len(packet))
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
frame := make([]byte, 2+len(packet))
|
||||
frame[0] = byte(len(packet) >> 8)
|
||||
frame[1] = byte(len(packet))
|
||||
copy(frame[2:], packet)
|
||||
_, err := s.conn.Write(frame)
|
||||
done <- err
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
if err := setWriteDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.writeDeadline); err != nil {
|
||||
return err
|
||||
}
|
||||
frame := make([]byte, 2+len(packet))
|
||||
frame[0] = byte(len(packet) >> 8)
|
||||
frame[1] = byte(len(packet))
|
||||
copy(frame[2:], packet)
|
||||
_, err := s.conn.Write(frame)
|
||||
return contextIOError(ctx, err)
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) Close() error {
|
||||
@@ -473,3 +485,50 @@ func (s *streamPacketIO) LocalAddr() net.Addr {
|
||||
func (s *streamPacketIO) RemoteAddr() net.Addr {
|
||||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func setReadDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if current.Equal(deadline) {
|
||||
return nil
|
||||
}
|
||||
if hasDeadline {
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return err
|
||||
}
|
||||
*current = deadline
|
||||
return nil
|
||||
}
|
||||
|
||||
func setWriteDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if current.Equal(deadline) {
|
||||
return nil
|
||||
}
|
||||
if hasDeadline {
|
||||
if err := conn.SetWriteDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := conn.SetWriteDeadline(time.Time{}); err != nil {
|
||||
return err
|
||||
}
|
||||
*current = deadline
|
||||
return nil
|
||||
}
|
||||
|
||||
func contextIOError(ctx context.Context, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() && ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,15 +8,21 @@ import (
|
||||
|
||||
const (
|
||||
PeerIDUnset uint32 = 0xffffff
|
||||
|
||||
dataChannelReplayWindow = 64
|
||||
)
|
||||
|
||||
type DataChannel struct {
|
||||
cipher DataCipher
|
||||
keyID uint8
|
||||
peerID uint32
|
||||
compLZO bool
|
||||
cipher DataCipher
|
||||
keyID uint8
|
||||
peerID uint32
|
||||
compLZO bool
|
||||
|
||||
mu sync.Mutex
|
||||
sendPacketID uint32
|
||||
recvHighest uint32
|
||||
recvWindow uint64
|
||||
recvSeen bool
|
||||
}
|
||||
|
||||
func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel {
|
||||
@@ -29,10 +35,11 @@ func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel
|
||||
|
||||
func (d *DataChannel) Encrypt(packet []byte) ([]byte, error) {
|
||||
if d.compLZO {
|
||||
p := make([]byte, 1+len(packet))
|
||||
p[0] = 0xFA
|
||||
copy(p[1:], packet)
|
||||
packet = p
|
||||
compressed, err := lzo1xCompressSafe(packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packet = compressed
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.sendPacketID++
|
||||
@@ -50,18 +57,15 @@ func (d *DataChannel) Decrypt(packet []byte) ([]byte, error) {
|
||||
if opcode == PDataV2 {
|
||||
headerSize = 4
|
||||
}
|
||||
plain, err := d.cipher.Decrypt(packet, headerSize)
|
||||
plain, packetID, err := d.cipher.Decrypt(packet, headerSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.acceptPacketID(packetID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.compLZO {
|
||||
if len(plain) < 1 {
|
||||
return nil, errors.New("openvpn comp-lzo packet too short")
|
||||
}
|
||||
if plain[0] != 0xFA {
|
||||
return nil, fmt.Errorf("openvpn compressed packet not supported (byte: 0x%02x)", plain[0])
|
||||
}
|
||||
plain = plain[1:]
|
||||
return lzo1xDecompressSafe(plain)
|
||||
}
|
||||
return plain, nil
|
||||
}
|
||||
@@ -78,6 +82,40 @@ func (d *DataChannel) dataHeader() []byte {
|
||||
return []byte{opcodeKeyID(PDataV1, d.keyID)}
|
||||
}
|
||||
|
||||
func (d *DataChannel) acceptPacketID(packetID uint32) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if !d.recvSeen {
|
||||
d.recvHighest = packetID
|
||||
d.recvWindow = 1
|
||||
d.recvSeen = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if packetID > d.recvHighest {
|
||||
shift := packetID - d.recvHighest
|
||||
if shift >= dataChannelReplayWindow {
|
||||
d.recvWindow = 1
|
||||
} else {
|
||||
d.recvWindow = d.recvWindow<<shift | 1
|
||||
}
|
||||
d.recvHighest = packetID
|
||||
return nil
|
||||
}
|
||||
|
||||
diff := d.recvHighest - packetID
|
||||
if diff >= dataChannelReplayWindow {
|
||||
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
|
||||
}
|
||||
mask := uint64(1) << diff
|
||||
if d.recvWindow&mask != 0 {
|
||||
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
|
||||
}
|
||||
d.recvWindow |= mask
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParsePeerID(options string) uint32 {
|
||||
for _, field := range splitPushOptions(options) {
|
||||
if len(field) > len("peer-id ") && field[:len("peer-id ")] == "peer-id " {
|
||||
|
||||
444
transport/openvpn/e2e_test.go
Normal file
444
transport/openvpn/e2e_test.go
Normal file
@@ -0,0 +1,444 @@
|
||||
//go:build with_openvpn && with_gvisor
|
||||
|
||||
// OpenVPN E2E tests. Require a local OpenVPN server setup.
|
||||
//
|
||||
// Setup (run once before testing):
|
||||
//
|
||||
// # Generate PKI
|
||||
// mkdir -p /tmp/ovpn-e2e/pki/{issued,private}
|
||||
// cd /tmp/ovpn-e2e/pki
|
||||
// openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -days 1 -nodes -keyout ca.key -out ca.crt -subj "/CN=E2ETestCA"
|
||||
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/server.key -out server.csr -subj "/CN=server"
|
||||
// openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/server.crt -days 1
|
||||
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/client.key -out client.csr -subj "/CN=client"
|
||||
// openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/client.crt -days 1
|
||||
// openvpn --genkey secret ta.key
|
||||
// openvpn --genkey secret ta-auth.key
|
||||
//
|
||||
// # Start servers (4 instances: TCP/UDP × tls-crypt/tls-auth)
|
||||
// # TCP + tls-crypt on :11940, subnet 10.99.0.0/24
|
||||
// # UDP + tls-crypt on :11941, subnet 10.99.1.0/24
|
||||
// # TCP + tls-auth on :11942, subnet 10.99.2.0/24
|
||||
// # UDP + tls-auth on :11943, subnet 10.99.3.0/24
|
||||
// #
|
||||
// # Each server config needs: topology subnet, duplicate-cn, persist-tun,
|
||||
// # data-ciphers AES-256-GCM:AES-128-GCM:AES-192-GCM:CHACHA20-POLY1305:AES-256-CBC:AES-128-CBC:AES-192-CBC
|
||||
// # auth SHA256, keepalive 10 60, ca/cert/key from above PKI.
|
||||
// # tls-auth servers use: tls-auth ta-auth.key 0
|
||||
// # tls-crypt servers use: tls-crypt ta.key
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-crypt.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-crypt.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-auth.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-auth.conf --daemon
|
||||
//
|
||||
// # Start HTTP servers on each VPN subnet
|
||||
// for ip in 10.99.0.1 10.99.1.1 10.99.2.1 10.99.3.1; do
|
||||
// mkdir -p /tmp/ovpn-e2e/$ip && echo "hello" > /tmp/ovpn-e2e/$ip/index.html
|
||||
// cd /tmp/ovpn-e2e/$ip && python3 -m http.server 8080 --bind $ip &
|
||||
// done
|
||||
//
|
||||
// Run tests:
|
||||
//
|
||||
// go test -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_manager,with_admin_panel,with_v2ray_api,with_ccm,with_ocm,with_profiler,with_openvpn,with_sudoku,with_trusttunnel" \
|
||||
// -run TestE2E -v -count=1 ./transport/openvpn/ -timeout 300s
|
||||
//
|
||||
// Tests all 28 combinations: 2 protos (tcp/udp) × 2 TLS modes (tls-crypt/tls-auth) × 7 ciphers.
|
||||
|
||||
package openvpn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box"
|
||||
"github.com/sagernet/sing-box/include"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
// Servers (started externally):
|
||||
// TCP+tls-crypt :11940 subnet 10.99.0.0/24
|
||||
// UDP+tls-crypt :11941 subnet 10.99.1.0/24
|
||||
// TCP+tls-auth :11942 subnet 10.99.2.0/24
|
||||
// UDP+tls-auth :11943 subnet 10.99.3.0/24
|
||||
// TCP+plain :11944 subnet 10.99.4.0/24
|
||||
// UDP+plain :11945 subnet 10.99.5.0/24
|
||||
// TCP+tls-crypt+SHA1 :11946 subnet 10.99.6.0/24 (CBC only)
|
||||
// TCP+tls-crypt+SHA512 :11947 subnet 10.99.7.0/24 (CBC only)
|
||||
// Each has HTTP on .1:8080 serving "hello"
|
||||
|
||||
const pkiDir = "/tmp/ovpn-e2e/pki"
|
||||
|
||||
type serverConfig struct {
|
||||
proto string
|
||||
port uint16
|
||||
tlsMode string // "tls-crypt" or "tls-auth"
|
||||
httpAddr string
|
||||
}
|
||||
|
||||
var servers = []serverConfig{
|
||||
{"tcp", 11940, "tls-crypt", "10.99.0.1:8080"},
|
||||
{"udp", 11941, "tls-crypt", "10.99.1.1:8080"},
|
||||
{"tcp", 11942, "tls-auth", "10.99.2.1:8080"},
|
||||
{"udp", 11943, "tls-auth", "10.99.3.1:8080"},
|
||||
}
|
||||
|
||||
var ciphers = []string{
|
||||
"AES-128-GCM",
|
||||
"AES-192-GCM",
|
||||
"AES-256-GCM",
|
||||
"CHACHA20-POLY1305",
|
||||
"AES-128-CBC",
|
||||
"AES-192-CBC",
|
||||
"AES-256-CBC",
|
||||
}
|
||||
|
||||
var portCounter atomic.Uint32
|
||||
|
||||
func init() { portCounter.Store(18100) }
|
||||
|
||||
func nextPort() uint16 { return uint16(portCounter.Add(1)) }
|
||||
|
||||
func readFile(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Skipf("PKI not found: %v", err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func testCombo(t *testing.T, srv serverConfig, cipher string) {
|
||||
t.Helper()
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
|
||||
ovpnOpts := &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: srv.port}},
|
||||
Proto: srv.proto,
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
}
|
||||
|
||||
switch srv.tlsMode {
|
||||
case "tls-crypt":
|
||||
ovpnOpts.TLSCrypt = readFile(t, pkiDir+"/ta.key")
|
||||
case "tls-auth":
|
||||
ovpnOpts.TLSAuth = readFile(t, pkiDir+"/ta-auth.key")
|
||||
ovpnOpts.KeyDirection = 1
|
||||
}
|
||||
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(srv.httpAddr))
|
||||
if err != nil {
|
||||
t.Fatal("dial:", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
|
||||
if err != nil {
|
||||
t.Fatal("write:", err)
|
||||
}
|
||||
body, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal("read:", err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
|
||||
}
|
||||
}
|
||||
|
||||
// 4 servers × 7 ciphers = 28 combinations
|
||||
func TestE2E(t *testing.T) {
|
||||
for _, srv := range servers {
|
||||
for _, cipher := range ciphers {
|
||||
name := fmt.Sprintf("%s/%s/%s", srv.proto, srv.tlsMode, cipher)
|
||||
srv, cipher := srv, cipher
|
||||
t.Run(name, func(t *testing.T) {
|
||||
testCombo(t, srv, cipher)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test CBC ciphers with different auth algorithms (SHA1, SHA512)
|
||||
func TestE2E_Auth(t *testing.T) {
|
||||
type authServer struct {
|
||||
port uint16
|
||||
auth string
|
||||
httpAddr string
|
||||
}
|
||||
authServers := []authServer{
|
||||
{11946, "SHA1", "10.99.6.1:8080"},
|
||||
{11947, "SHA512", "10.99.7.1:8080"},
|
||||
}
|
||||
cbcCiphers := []string{"AES-128-CBC", "AES-256-CBC"}
|
||||
|
||||
for _, as := range authServers {
|
||||
for _, cipher := range cbcCiphers {
|
||||
name := fmt.Sprintf("auth-%s/%s", as.auth, cipher)
|
||||
as, cipher := as, cipher
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{
|
||||
Type: "openvpn", Tag: "vpn",
|
||||
Options: &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: as.port}},
|
||||
Proto: "tcp",
|
||||
Cipher: cipher,
|
||||
Auth: as.auth,
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
time.Sleep(2 * time.Second)
|
||||
doHTTPCheck(t, port, as.httpAddr)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test tunnel stability with multiple sequential requests
|
||||
func TestE2E_BulkData(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{
|
||||
Type: "openvpn", Tag: "vpn",
|
||||
Options: &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11940}},
|
||||
Proto: "tcp",
|
||||
Cipher: "AES-256-GCM",
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
|
||||
for i := 0; i < 10; i++ {
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr("10.99.0.1:8080"))
|
||||
if err != nil {
|
||||
t.Fatalf("request %d dial: %v", i, err)
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
fmt.Fprintf(conn, "GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n")
|
||||
body, err := io.ReadAll(conn)
|
||||
conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("request %d read: %v", i, err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("request %d: no 'hello'", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doHTTPCheck(t *testing.T, socksPort uint16, httpAddr string) {
|
||||
t.Helper()
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", socksPort), socks.Version5, "", "")
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(httpAddr))
|
||||
if err != nil {
|
||||
t.Fatal("dial:", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
|
||||
if err != nil {
|
||||
t.Fatal("write:", err)
|
||||
}
|
||||
body, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal("read:", err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
|
||||
}
|
||||
}
|
||||
|
||||
func startInstance(t *testing.T, ovpnOpts *option.OpenVPNOutboundOptions) uint16 {
|
||||
t.Helper()
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { instance.Close() })
|
||||
time.Sleep(2 * time.Second)
|
||||
return port
|
||||
}
|
||||
|
||||
func TestE2E_CompLZO(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
|
||||
for _, cipher := range ciphers {
|
||||
cipher := cipher
|
||||
t.Run(cipher, func(t *testing.T) {
|
||||
port := startInstance(t, &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11948}},
|
||||
Proto: "udp",
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
})
|
||||
doHTTPCheck(t, port, "10.99.8.1:8080")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestE2E_AES192(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
|
||||
type combo struct {
|
||||
proto string
|
||||
port uint16
|
||||
httpAddr string
|
||||
}
|
||||
for _, c := range []combo{
|
||||
{"tcp", 11940, "10.99.0.1:8080"},
|
||||
{"udp", 11941, "10.99.1.1:8080"},
|
||||
} {
|
||||
for _, cipher := range []string{"AES-192-GCM", "AES-192-CBC"} {
|
||||
c, cipher := c, cipher
|
||||
t.Run(fmt.Sprintf("%s/%s", c.proto, cipher), func(t *testing.T) {
|
||||
port := startInstance(t, &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: c.port}},
|
||||
Proto: c.proto,
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
})
|
||||
doHTTPCheck(t, port, c.httpAddr)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ net.Conn
|
||||
@@ -114,7 +114,7 @@ func ParseServerKeyMethod2Record(packet []byte) (*KeyMethod2Record, error) {
|
||||
}
|
||||
|
||||
func DeriveClientKeyMaterial(sources KeySource2, clientSession, serverSession SessionID, cipherKeyLen int) (*KeyMaterial, error) {
|
||||
if cipherKeyLen != 16 && cipherKeyLen != 32 {
|
||||
if cipherKeyLen != 16 && cipherKeyLen != 24 && cipherKeyLen != 32 {
|
||||
return nil, fmt.Errorf("unsupported data cipher key length %d", cipherKeyLen)
|
||||
}
|
||||
var master [48]byte
|
||||
|
||||
48
transport/openvpn/lzo.go
Normal file
48
transport/openvpn/lzo.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/rasky/go-lzo"
|
||||
)
|
||||
|
||||
const (
|
||||
lzoCompressNone = 0xFA
|
||||
lzoCompressLZO = 0x66
|
||||
)
|
||||
|
||||
var ErrLZODecompress = errors.New("lzo decompression failed")
|
||||
|
||||
func lzo1xDecompressSafe(src []byte) ([]byte, error) {
|
||||
if len(src) == 0 {
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
|
||||
switch src[0] {
|
||||
case lzoCompressNone:
|
||||
if len(src) > 1 {
|
||||
return src[1:], nil
|
||||
}
|
||||
return nil, nil
|
||||
case lzoCompressLZO:
|
||||
if len(src) > 1 {
|
||||
r := bytes.NewReader(src[1:])
|
||||
out, err := lzo.Decompress1X(r, len(src)-1, 0)
|
||||
if err != nil {
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
}
|
||||
|
||||
func lzo1xCompressSafe(src []byte) ([]byte, error) {
|
||||
lzoPacket := make([]byte, 1+len(src))
|
||||
lzoPacket[0] = lzoCompressNone
|
||||
copy(lzoPacket[1:], src)
|
||||
return lzoPacket, nil
|
||||
}
|
||||
@@ -10,16 +10,17 @@ import (
|
||||
const PushRequest = "PUSH_REQUEST"
|
||||
|
||||
type PushReply struct {
|
||||
Raw string
|
||||
Prefixes []netip.Prefix
|
||||
DNS []netip.Addr
|
||||
PeerID uint32
|
||||
Cipher string
|
||||
Ping uint32
|
||||
MTU uint32
|
||||
CompLZO bool
|
||||
Redirect bool
|
||||
BlockIPv6 bool
|
||||
Raw string
|
||||
Prefixes []netip.Prefix
|
||||
DNS []netip.Addr
|
||||
PeerID uint32
|
||||
Cipher string
|
||||
Ping uint32
|
||||
PingRestart uint32
|
||||
MTU uint32
|
||||
CompLZO bool
|
||||
Redirect bool
|
||||
BlockIPv6 bool
|
||||
}
|
||||
|
||||
func ParsePushReply(message string) (*PushReply, error) {
|
||||
@@ -81,6 +82,12 @@ func ParsePushReply(message string) (*PushReply, error) {
|
||||
reply.Ping = uint32(v)
|
||||
}
|
||||
}
|
||||
case "ping-restart":
|
||||
if len(fields) >= 2 {
|
||||
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
|
||||
reply.PingRestart = uint32(v)
|
||||
}
|
||||
}
|
||||
case "tun-mtu":
|
||||
if len(fields) >= 2 {
|
||||
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
|
||||
@@ -113,27 +120,44 @@ func splitPushOptions(message string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseIPv4Ifconfig(address, mask string) (netip.Prefix, error) {
|
||||
func parseIPv4Ifconfig(address, maskOrPeer string) (netip.Prefix, error) {
|
||||
addr, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 address %q: %w", address, err)
|
||||
}
|
||||
maskAddr, err := netip.ParseAddr(mask)
|
||||
maskAddr, err := netip.ParseAddr(maskOrPeer)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", mask, err)
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", maskOrPeer, err)
|
||||
}
|
||||
if !addr.Is4() || !maskAddr.Is4() {
|
||||
return netip.Prefix{}, fmt.Errorf("openvpn ifconfig requires ipv4 address and mask")
|
||||
}
|
||||
maskBytes := maskAddr.As4()
|
||||
|
||||
if ones, ok := ipv4MaskSize(maskAddr); ok {
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
}
|
||||
|
||||
// Some servers, including SoftEther/VPNGate in net30/p2p mode, push
|
||||
// "ifconfig <local> <remote>" rather than "ifconfig <local> <netmask>".
|
||||
// Use a host prefix for that local tunnel address.
|
||||
return netip.PrefixFrom(addr, 32), nil
|
||||
}
|
||||
|
||||
func ipv4MaskSize(mask netip.Addr) (int, bool) {
|
||||
maskBytes := mask.As4()
|
||||
ones := 0
|
||||
seenZero := false
|
||||
for _, b := range maskBytes {
|
||||
for i := 7; i >= 0; i-- {
|
||||
if b&(1<<i) == 0 {
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
seenZero = true
|
||||
continue
|
||||
}
|
||||
if seenZero {
|
||||
return 0, false
|
||||
}
|
||||
ones++
|
||||
}
|
||||
}
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
return ones, true
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openvpn
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
@@ -35,6 +36,9 @@ func NewTLSAuth(staticKey []byte, keyDirection int, auth string) (*TLSAuth, erro
|
||||
var newHash func() hash.Hash
|
||||
var hmacSize int
|
||||
switch auth {
|
||||
case AuthMD5:
|
||||
newHash = md5.New
|
||||
hmacSize = md5.Size
|
||||
case AuthSHA256:
|
||||
newHash = sha256.New
|
||||
hmacSize = sha256.Size
|
||||
|
||||
@@ -30,16 +30,19 @@ type TunnelOptions struct {
|
||||
UDPTimeout time.Duration
|
||||
ReconnectDelay time.Duration
|
||||
PingInterval time.Duration
|
||||
PingRestart time.Duration
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger logger.ContextLogger
|
||||
options TunnelOptions
|
||||
device Device
|
||||
client *Client
|
||||
mtu uint32
|
||||
serverIndex int
|
||||
wg sync.WaitGroup
|
||||
|
||||
await chan struct{}
|
||||
mu sync.Mutex
|
||||
@@ -49,8 +52,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
|
||||
if options.ReconnectDelay == 0 {
|
||||
options.ReconnectDelay = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Tunnel{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
options: options,
|
||||
await: make(chan struct{}),
|
||||
@@ -59,10 +64,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
|
||||
|
||||
func (t *Tunnel) Start() error {
|
||||
go func() {
|
||||
defer close(t.await)
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
t.logger.Error("OpenVPN connect: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
t.mtu = 1500
|
||||
@@ -84,20 +89,26 @@ func (t *Tunnel) Start() error {
|
||||
if err != nil {
|
||||
client.Close()
|
||||
t.logger.Error("create OpenVPN device: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
t.device = device
|
||||
if err := device.Start(); err != nil {
|
||||
client.Close()
|
||||
t.logger.Error("start OpenVPN device: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
close(t.await)
|
||||
t.maintainTunnel()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if err := t.isTunnelInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
@@ -105,6 +116,9 @@ func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.
|
||||
}
|
||||
|
||||
func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if err := t.isTunnelInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
@@ -112,15 +126,18 @@ func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
|
||||
}
|
||||
|
||||
func (t *Tunnel) Close() error {
|
||||
t.cancel()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.client != nil {
|
||||
t.client.Close()
|
||||
t.client = nil
|
||||
}
|
||||
if t.device != nil {
|
||||
return t.device.Close()
|
||||
t.device.Close()
|
||||
t.device = nil
|
||||
}
|
||||
t.mu.Unlock()
|
||||
t.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,7 +154,9 @@ func (t *Tunnel) isTunnelInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (t *Tunnel) maintainTunnel() {
|
||||
t.wg.Add(2)
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
bufs := make([][]byte, 1)
|
||||
bufs[0] = make([]byte, t.mtu)
|
||||
sizes := make([]int, 1)
|
||||
@@ -161,6 +180,7 @@ func (t *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
for t.ctx.Err() == nil {
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
@@ -179,10 +199,14 @@ func (t *Tunnel) maintainTunnel() {
|
||||
if bytes.Equal(packet, pingPayload) {
|
||||
continue
|
||||
}
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if _, err := t.device.Write([][]byte{packet}, 0); err != nil {
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -208,6 +232,34 @@ func (t *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
}
|
||||
pingRestart := t.options.PingRestart
|
||||
if pingRestart == 0 && t.client != nil && t.client.push.PingRestart > 0 {
|
||||
pingRestart = time.Duration(t.client.push.PingRestart) * time.Second
|
||||
}
|
||||
if pingRestart > 0 {
|
||||
t.wg.Add(1)
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
ticker := time.NewTicker(pingRestart)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if client.SinceReceive() >= pingRestart {
|
||||
if ok := t.closeClient(client); ok {
|
||||
t.logger.ErrorContext(t.ctx, fmt.Errorf("ping-restart timeout: no packet received for %s", pingRestart))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
<-t.ctx.Done()
|
||||
}
|
||||
|
||||
|
||||
100
transport/simple-obfs/http_server.go
Normal file
100
transport/simple-obfs/http_server.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package obfs
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTPObfsServer is the server side of the shadowsocks http simple-obfs implementation.
|
||||
type HTTPObfsServer struct {
|
||||
net.Conn
|
||||
buf []byte
|
||||
bio *bufio.Reader
|
||||
offset int
|
||||
firstRequest bool
|
||||
firstResponse bool
|
||||
}
|
||||
|
||||
func (hos *HTTPObfsServer) Read(b []byte) (int, error) {
|
||||
if hos.buf != nil {
|
||||
n := copy(b, hos.buf[hos.offset:])
|
||||
hos.offset += n
|
||||
if hos.offset == len(hos.buf) {
|
||||
hos.offset = 0
|
||||
hos.buf = nil
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
if hos.firstRequest {
|
||||
bio := bufio.NewReader(hos.Conn)
|
||||
req, err := http.ReadRequest(bio)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if req.Method != "GET" || req.Header.Get("Connection") != "Upgrade" {
|
||||
return 0, io.EOF
|
||||
}
|
||||
buf, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := copy(b, buf)
|
||||
if n < len(buf) {
|
||||
hos.buf = buf
|
||||
hos.offset = n
|
||||
}
|
||||
req.Body.Close()
|
||||
hos.bio = bio
|
||||
hos.firstRequest = false
|
||||
return n, nil
|
||||
}
|
||||
return hos.bio.Read(b)
|
||||
}
|
||||
|
||||
func (hos *HTTPObfsServer) Write(b []byte) (int, error) {
|
||||
if hos.firstResponse {
|
||||
randBytes := make([]byte, 16)
|
||||
cryptorand.Read(randBytes)
|
||||
date := time.Now().Format(time.RFC1123)
|
||||
resp := fmt.Sprintf(httpResponseTemplate, vMajor, vMinor, date, base64.URLEncoding.EncodeToString(randBytes))
|
||||
_, err := hos.Conn.Write([]byte(resp))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
hos.firstResponse = false
|
||||
}
|
||||
return hos.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (hos *HTTPObfsServer) Upstream() any {
|
||||
return hos.Conn
|
||||
}
|
||||
|
||||
// NewHTTPObfsServer returns a server-side HTTPObfs.
|
||||
func NewHTTPObfsServer(conn net.Conn) net.Conn {
|
||||
return &HTTPObfsServer{
|
||||
Conn: conn,
|
||||
firstRequest: true,
|
||||
firstResponse: true,
|
||||
}
|
||||
}
|
||||
|
||||
const httpResponseTemplate = "HTTP/1.1 101 Switching Protocols\r\n" +
|
||||
"Server: nginx/1.%d.%d\r\n" +
|
||||
"Date: %s\r\n" +
|
||||
"Upgrade: websocket\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Sec-WebSocket-Accept: %s\r\n" +
|
||||
"\r\n"
|
||||
|
||||
var (
|
||||
vMajor = rand.IntN(11)
|
||||
vMinor = rand.IntN(12)
|
||||
)
|
||||
154
transport/simple-obfs/tls_server.go
Normal file
154
transport/simple-obfs/tls_server.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package obfs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
B "github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
// TLSObfsServer is the server side of the shadowsocks tls simple-obfs implementation.
|
||||
type TLSObfsServer struct {
|
||||
net.Conn
|
||||
remain int
|
||||
firstRequest bool
|
||||
sessionTicketDone bool
|
||||
firstResponse bool
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) read(b []byte, discardN int) (int, error) {
|
||||
buf := B.Get(discardN)
|
||||
_, err := io.ReadFull(tos.Conn, buf)
|
||||
B.Put(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizeBuf := make([]byte, 2)
|
||||
_, err = io.ReadFull(tos.Conn, sizeBuf)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
length := int(binary.BigEndian.Uint16(sizeBuf))
|
||||
if length > len(b) {
|
||||
n, err := tos.Conn.Read(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
tos.remain = length - n
|
||||
return n, nil
|
||||
}
|
||||
return io.ReadFull(tos.Conn, b[:length])
|
||||
}
|
||||
|
||||
// skipOtherExts skips SNI & other TLS extensions.
|
||||
func (tos *TLSObfsServer) skipOtherExts() error {
|
||||
buf := make([]byte, 256)
|
||||
_, err := tos.read(buf, 7)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.ReadFull(tos.Conn, buf[:4*16+2])
|
||||
return err
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) Read(b []byte) (int, error) {
|
||||
if tos.remain > 0 {
|
||||
length := tos.remain
|
||||
if length > len(b) {
|
||||
length = len(b)
|
||||
}
|
||||
n, err := io.ReadFull(tos.Conn, b[:length])
|
||||
tos.remain -= n
|
||||
return n, err
|
||||
}
|
||||
if tos.firstRequest {
|
||||
tos.firstRequest = false
|
||||
return tos.read(b, 9*16-4)
|
||||
}
|
||||
if !tos.sessionTicketDone {
|
||||
tos.sessionTicketDone = true
|
||||
err := tos.skipOtherExts()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return tos.read(b, 3)
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) Write(b []byte) (int, error) {
|
||||
length := len(b)
|
||||
for i := 0; i < length; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > length {
|
||||
end = length
|
||||
}
|
||||
n, err := tos.write(b[i:end])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
return length, nil
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) write(b []byte) (int, error) {
|
||||
if tos.firstResponse {
|
||||
serverHello := makeServerHello(b)
|
||||
_, err := tos.Conn.Write(serverHello)
|
||||
tos.firstResponse = false
|
||||
return len(b), err
|
||||
}
|
||||
buf := B.NewSize(5 + len(b))
|
||||
defer buf.Release()
|
||||
buf.Write([]byte{0x17, 0x03, 0x03})
|
||||
binary.Write(buf, binary.BigEndian, uint16(len(b)))
|
||||
buf.Write(b)
|
||||
_, err := tos.Conn.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) Upstream() any {
|
||||
return tos.Conn
|
||||
}
|
||||
|
||||
// NewTLSObfsServer returns a server-side TLS SimpleObfs.
|
||||
func NewTLSObfsServer(conn net.Conn) net.Conn {
|
||||
return &TLSObfsServer{
|
||||
Conn: conn,
|
||||
firstRequest: true,
|
||||
firstResponse: true,
|
||||
}
|
||||
}
|
||||
|
||||
func makeServerHello(data []byte) []byte {
|
||||
randBytes := make([]byte, 28)
|
||||
sessionId := make([]byte, 32)
|
||||
rand.Read(randBytes)
|
||||
rand.Read(sessionId)
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteByte(0x16)
|
||||
binary.Write(buf, binary.BigEndian, uint16(0x0301))
|
||||
binary.Write(buf, binary.BigEndian, uint16(91))
|
||||
buf.Write([]byte{2, 0, 0, 87, 0x03, 0x03})
|
||||
binary.Write(buf, binary.BigEndian, uint32(time.Now().Unix()))
|
||||
buf.Write(randBytes)
|
||||
buf.WriteByte(32)
|
||||
buf.Write(sessionId)
|
||||
buf.Write([]byte{0xcc, 0xa8})
|
||||
buf.WriteByte(0)
|
||||
buf.Write([]byte{0x00, 0x00})
|
||||
buf.Write([]byte{0xff, 0x01, 0x00, 0x01, 0x00})
|
||||
buf.Write([]byte{0x00, 0x17, 0x00, 0x00})
|
||||
buf.Write([]byte{0x00, 0x0b, 0x00, 0x02, 0x01, 0x00})
|
||||
buf.Write([]byte{0x14, 0x03, 0x03, 0x00, 0x01, 0x01})
|
||||
buf.Write([]byte{0x16, 0x03, 0x03})
|
||||
binary.Write(buf, binary.BigEndian, uint16(len(data)))
|
||||
buf.Write(data)
|
||||
return buf.Bytes()
|
||||
}
|
||||
144
transport/snell/address.go
Normal file
144
transport/snell/address.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// SOCKS address types as defined in RFC 1928 section 5.
|
||||
const (
|
||||
atypIPv4 = 1
|
||||
atypDomainName = 3
|
||||
atypIPv6 = 4
|
||||
)
|
||||
|
||||
// socksAddr represents a SOCKS address as defined in RFC 1928 section 5.
|
||||
type socksAddr []byte
|
||||
|
||||
func (a socksAddr) String() string {
|
||||
var host, port string
|
||||
switch a[0] {
|
||||
case atypDomainName:
|
||||
hostLen := uint16(a[1])
|
||||
host = string(a[2 : 2+hostLen])
|
||||
port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
|
||||
case atypIPv4:
|
||||
host = net.IP(a[1 : 1+net.IPv4len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
|
||||
case atypIPv6:
|
||||
host = net.IP(a[1 : 1+net.IPv6len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// UDPAddr converts a socksAddr to *net.UDPAddr.
|
||||
func (a socksAddr) UDPAddr() *net.UDPAddr {
|
||||
if len(a) == 0 {
|
||||
return nil
|
||||
}
|
||||
switch a[0] {
|
||||
case atypIPv4:
|
||||
var ip [net.IPv4len]byte
|
||||
copy(ip[0:], a[1:1+net.IPv4len])
|
||||
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
|
||||
case atypIPv6:
|
||||
var ip [net.IPv6len]byte
|
||||
copy(ip[0:], a[1:1+net.IPv6len])
|
||||
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitSocksAddr slices a SOCKS address from beginning of b. Returns nil if failed.
|
||||
func splitSocksAddr(b []byte) socksAddr {
|
||||
addrLen := 1
|
||||
if len(b) < addrLen {
|
||||
return nil
|
||||
}
|
||||
switch b[0] {
|
||||
case atypDomainName:
|
||||
if len(b) < 2 {
|
||||
return nil
|
||||
}
|
||||
addrLen = 1 + 1 + int(b[1]) + 2
|
||||
case atypIPv4:
|
||||
addrLen = 1 + net.IPv4len + 2
|
||||
case atypIPv6:
|
||||
addrLen = 1 + net.IPv6len + 2
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
if len(b) < addrLen {
|
||||
return nil
|
||||
}
|
||||
return b[:addrLen]
|
||||
}
|
||||
|
||||
// parseAddr parses the address in string s. Returns nil if failed.
|
||||
func parseAddr(s string) socksAddr {
|
||||
var addr socksAddr
|
||||
host, port, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
addr = make([]byte, 1+net.IPv4len+2)
|
||||
addr[0] = atypIPv4
|
||||
copy(addr[1:], ip4)
|
||||
} else {
|
||||
addr = make([]byte, 1+net.IPv6len+2)
|
||||
addr[0] = atypIPv6
|
||||
copy(addr[1:], ip)
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
return nil
|
||||
}
|
||||
addr = make([]byte, 1+1+len(host)+2)
|
||||
addr[0] = atypDomainName
|
||||
addr[1] = byte(len(host))
|
||||
copy(addr[2:], host)
|
||||
}
|
||||
portnum, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
|
||||
return addr
|
||||
}
|
||||
|
||||
// parseAddrToSocksAddr parses a socks addr from net.Addr.
|
||||
// This is a fast path of parseAddr(addr.String()).
|
||||
func parseAddrToSocksAddr(addr net.Addr) socksAddr {
|
||||
var hostip net.IP
|
||||
var port int
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
hostip = addr.IP
|
||||
port = addr.Port
|
||||
case *net.TCPAddr:
|
||||
hostip = addr.IP
|
||||
port = addr.Port
|
||||
case nil:
|
||||
return nil
|
||||
}
|
||||
if hostip == nil {
|
||||
return parseAddr(addr.String())
|
||||
}
|
||||
var parsed socksAddr
|
||||
if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
|
||||
parsed = make([]byte, 1+net.IPv4len+2)
|
||||
parsed[0] = atypIPv4
|
||||
copy(parsed[1:], ip4)
|
||||
binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
|
||||
} else {
|
||||
parsed = make([]byte, 1+net.IPv6len+2)
|
||||
parsed[0] = atypIPv6
|
||||
copy(parsed[1:], hostip)
|
||||
binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
56
transport/snell/cipher.go
Normal file
56
transport/snell/cipher.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// NewAES128GCM returns the AES-128-GCM cipher used by snell v2/v3.
|
||||
func NewAES128GCM(psk []byte) Cipher {
|
||||
return &snellCipher{
|
||||
psk: psk,
|
||||
keySize: 16,
|
||||
makeAEAD: aesGCM,
|
||||
}
|
||||
}
|
||||
|
||||
// NewChacha20Poly1305 returns the ChaCha20-Poly1305 cipher used by snell v1.
|
||||
func NewChacha20Poly1305(psk []byte) Cipher {
|
||||
return &snellCipher{
|
||||
psk: psk,
|
||||
keySize: 32,
|
||||
makeAEAD: chacha20poly1305.New,
|
||||
}
|
||||
}
|
||||
|
||||
type snellCipher struct {
|
||||
psk []byte
|
||||
keySize int
|
||||
makeAEAD func(key []byte) (cipher.AEAD, error)
|
||||
}
|
||||
|
||||
func (sc *snellCipher) KeySize() int { return sc.keySize }
|
||||
func (sc *snellCipher) SaltSize() int { return 16 }
|
||||
|
||||
func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) {
|
||||
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
|
||||
}
|
||||
|
||||
func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) {
|
||||
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
|
||||
}
|
||||
|
||||
func snellKDF(psk, salt []byte, keySize int) []byte {
|
||||
return argon2.IDKey(psk, salt, 3, 8, 1, 32)[:keySize]
|
||||
}
|
||||
|
||||
func aesGCM(key []byte) (cipher.AEAD, error) {
|
||||
blk, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cipher.NewGCM(blk)
|
||||
}
|
||||
120
transport/snell/client.go
Normal file
120
transport/snell/client.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type ClientOptions struct {
|
||||
Dialer N.Dialer
|
||||
Server M.Socksaddr
|
||||
PSK []byte
|
||||
Version int
|
||||
Reuse bool
|
||||
ObfsMode string
|
||||
ObfsHost string
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
dialer N.Dialer
|
||||
server M.Socksaddr
|
||||
psk []byte
|
||||
version int
|
||||
reuse bool
|
||||
obfsMode string
|
||||
obfsHost string
|
||||
pool *Pool
|
||||
}
|
||||
|
||||
func NewClient(options ClientOptions) *Client {
|
||||
c := &Client{
|
||||
dialer: options.Dialer,
|
||||
server: options.Server,
|
||||
psk: options.PSK,
|
||||
version: options.Version,
|
||||
reuse: options.Reuse,
|
||||
obfsMode: options.ObfsMode,
|
||||
obfsHost: options.ObfsHost,
|
||||
}
|
||||
if c.reuse {
|
||||
c.pool = NewPool(func(ctx context.Context) (*Snell, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.streamConn(conn), nil
|
||||
})
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) streamConn(conn net.Conn) *Snell {
|
||||
switch c.obfsMode {
|
||||
case "tls":
|
||||
conn = obfs.NewTLSObfs(conn, c.obfsHost)
|
||||
case "http":
|
||||
conn = obfs.NewHTTPObfs(conn, c.obfsHost, strconv.Itoa(int(c.server.Port)))
|
||||
}
|
||||
return StreamConn(conn, c.psk, c.version)
|
||||
}
|
||||
|
||||
func (c *Client) writeHeader(ctx context.Context, conn net.Conn, destination M.Socksaddr, udp bool) (err error) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = conn.SetWriteDeadline(deadline)
|
||||
defer conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
if udp {
|
||||
err = WriteUDPHeader(conn, c.version)
|
||||
if err == nil && c.version >= Version4 {
|
||||
if sc, ok := conn.(*Snell); ok {
|
||||
err = sc.ReadReply()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
err = WriteHeaderWithReuse(conn, destination.AddrString(), uint(destination.Port), c.version, c.reuse)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) DialContext(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
||||
if c.reuse {
|
||||
conn, err := c.pool.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = c.writeHeader(ctx, conn, destination, false); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream := c.streamConn(conn)
|
||||
if err = c.writeHeader(ctx, stream, destination, false); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream := c.streamConn(conn)
|
||||
if err = c.writeHeader(ctx, stream, destination, true); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return PacketConn(stream), nil
|
||||
}
|
||||
153
transport/snell/pool.go
Normal file
153
transport/snell/pool.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// poolEntry holds a pooled item with its insertion time.
|
||||
|
||||
// connPool is a small connection pool with age-based eviction.
|
||||
|
||||
// milliseconds
|
||||
|
||||
// Pool is a pool of reusable snell connections.
|
||||
type Pool struct {
|
||||
pool *connPool
|
||||
}
|
||||
|
||||
func (p *Pool) Get() (net.Conn, error) {
|
||||
return p.GetContext(context.Background())
|
||||
}
|
||||
|
||||
func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
|
||||
elm, err := p.pool.GetContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PoolConn{Snell: elm, pool: p}, nil
|
||||
}
|
||||
|
||||
func (p *Pool) Put(conn *Snell) {
|
||||
if err := HalfClose(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
p.pool.put(conn)
|
||||
}
|
||||
|
||||
// PoolConn wraps a pooled snell connection and returns it to the pool on Close.
|
||||
type PoolConn struct {
|
||||
*Snell
|
||||
pool *Pool
|
||||
closeWriteOnce sync.Once
|
||||
closeWriteErr error
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Read(b []byte) (int, error) {
|
||||
n, err := pc.Snell.Read(b)
|
||||
if err == ErrZeroChunk {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Write(b []byte) (int, error) {
|
||||
return pc.Snell.Write(b)
|
||||
}
|
||||
|
||||
func (pc *PoolConn) CloseWrite() error {
|
||||
pc.closeWriteOnce.Do(func() {
|
||||
pc.closeWriteErr = writeZeroChunk(pc.Snell)
|
||||
})
|
||||
return pc.closeWriteErr
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Close() error {
|
||||
pc.closeOnce.Do(func() {
|
||||
if err := pc.CloseWrite(); err != nil {
|
||||
pc.closeErr = err
|
||||
_ = pc.Snell.Close()
|
||||
return
|
||||
}
|
||||
_ = pc.Snell.Conn.SetReadDeadline(time.Time{})
|
||||
pc.Snell.reply = false
|
||||
pc.pool.pool.put(pc.Snell)
|
||||
})
|
||||
return pc.closeErr
|
||||
}
|
||||
|
||||
// NewPool creates a new snell connection pool using the given factory.
|
||||
func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
|
||||
cp := &connPool{
|
||||
ch: make(chan *poolEntry, 10),
|
||||
factory: factory,
|
||||
maxAge: 15000,
|
||||
evict: func(item *Snell) {
|
||||
_ = item.Close()
|
||||
},
|
||||
}
|
||||
p := &Pool{pool: cp}
|
||||
runtime.SetFinalizer(p, recycle)
|
||||
return p
|
||||
}
|
||||
|
||||
type poolEntry struct {
|
||||
elm *Snell
|
||||
time time.Time
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
ch chan *poolEntry
|
||||
factory func(context.Context) (*Snell, error)
|
||||
evict func(*Snell)
|
||||
maxAge int64
|
||||
}
|
||||
|
||||
func (p *connPool) GetContext(ctx context.Context) (*Snell, error) {
|
||||
now := time.Now()
|
||||
for {
|
||||
select {
|
||||
case item := <-p.ch:
|
||||
if p.maxAge != 0 && now.Sub(item.time).Milliseconds() > p.maxAge {
|
||||
if p.evict != nil {
|
||||
p.evict(item.elm)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return item.elm, nil
|
||||
default:
|
||||
return p.factory(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *connPool) put(item *Snell) {
|
||||
e := &poolEntry{
|
||||
elm: item,
|
||||
time: time.Now(),
|
||||
}
|
||||
select {
|
||||
case p.ch <- e:
|
||||
return
|
||||
default:
|
||||
if p.evict != nil {
|
||||
p.evict(item)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func recycle(p *Pool) {
|
||||
for item := range p.pool.ch {
|
||||
if p.pool.evict != nil {
|
||||
p.pool.evict(item.elm)
|
||||
}
|
||||
}
|
||||
}
|
||||
294
transport/snell/service.go
Normal file
294
transport/snell/service.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string)
|
||||
|
||||
NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string)
|
||||
}
|
||||
|
||||
type ServiceOptions struct {
|
||||
PSK []byte
|
||||
Version int
|
||||
ObfsMode string
|
||||
UDP bool
|
||||
Logger logger.ContextLogger
|
||||
Handler Handler
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
psk []byte
|
||||
version int
|
||||
obfsMode string
|
||||
udp bool
|
||||
logger logger.ContextLogger
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewService(options ServiceOptions) (*Service, error) {
|
||||
version := options.Version
|
||||
if version == 0 {
|
||||
version = Version4
|
||||
}
|
||||
if version != Version4 && version != Version5 {
|
||||
return nil, fmt.Errorf("snell inbound version %d is not supported", version)
|
||||
}
|
||||
if len(options.PSK) == 0 {
|
||||
return nil, errors.New("snell inbound requires psk")
|
||||
}
|
||||
switch options.ObfsMode {
|
||||
case "", "http", "tls":
|
||||
default:
|
||||
return nil, fmt.Errorf("snell inbound obfs mode error: %s", options.ObfsMode)
|
||||
}
|
||||
return &Service{
|
||||
psk: options.PSK,
|
||||
version: version,
|
||||
obfsMode: options.ObfsMode,
|
||||
udp: options.UDP,
|
||||
logger: options.Logger,
|
||||
handler: options.Handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(ctx context.Context, rawConn net.Conn, source M.Socksaddr) error {
|
||||
conn := rawConn
|
||||
switch s.obfsMode {
|
||||
case "http":
|
||||
conn = obfs.NewHTTPObfsServer(conn)
|
||||
case "tls":
|
||||
conn = obfs.NewTLSObfsServer(conn)
|
||||
}
|
||||
stream := ServerStreamConn(conn, s.psk, s.version)
|
||||
for {
|
||||
reuse, err := s.handleRequest(ctx, stream, source)
|
||||
if err != nil || !reuse {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleRequest(ctx context.Context, stream *Snell, source M.Socksaddr) (bool, error) {
|
||||
br := bufio.NewReader(stream)
|
||||
version, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if version != Version {
|
||||
return false, fmt.Errorf("snell invalid protocol version: %d", version)
|
||||
}
|
||||
command, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if command == CommandPing {
|
||||
_, _ = stream.Write([]byte{CommandPong})
|
||||
return false, nil
|
||||
}
|
||||
clientID, err := readClientID(br)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
switch command {
|
||||
case CommandConnect, CommandConnectV2:
|
||||
return s.handleTCP(ctx, stream, br, command == CommandConnectV2, clientID, source)
|
||||
case CommandUDP:
|
||||
if !s.udp {
|
||||
return false, errors.New("snell UDP is disabled")
|
||||
}
|
||||
return false, s.handleUDP(ctx, stream, clientID, source)
|
||||
default:
|
||||
return false, fmt.Errorf("snell unknown command: %d", command)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleTCP(ctx context.Context, stream *Snell, br *bufio.Reader, reuse bool, clientID string, source M.Socksaddr) (bool, error) {
|
||||
hostLen, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if hostLen == 0 {
|
||||
return false, errors.New("snell connect host is empty")
|
||||
}
|
||||
hostBytes := make([]byte, int(hostLen))
|
||||
if _, err := io.ReadFull(br, hostBytes); err != nil {
|
||||
return false, err
|
||||
}
|
||||
var portBytes [2]byte
|
||||
if _, err := io.ReadFull(br, portBytes[:]); err != nil {
|
||||
return false, err
|
||||
}
|
||||
destination := M.ParseSocksaddrHostPort(string(hostBytes), binary.BigEndian.Uint16(portBytes[:]))
|
||||
conn := &tcpRequestConn{
|
||||
Conn: stream,
|
||||
reader: br,
|
||||
reuse: reuse,
|
||||
}
|
||||
s.handler.NewConnection(ctx, conn, source, destination, clientID)
|
||||
if !reuse {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *Service) handleUDP(ctx context.Context, stream *Snell, clientID string, source M.Socksaddr) error {
|
||||
if _, err := stream.Write([]byte{CommandTunnel}); err != nil {
|
||||
return err
|
||||
}
|
||||
pc := &serverPacketConn{
|
||||
conn: stream,
|
||||
writeMu: &sync.Mutex{},
|
||||
}
|
||||
s.handler.NewPacketConnection(ctx, pc, source, clientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxPacketLength = 0x3fff
|
||||
|
||||
func readClientID(r *bufio.Reader) (string, error) {
|
||||
length, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
id := make([]byte, int(length))
|
||||
if _, err := io.ReadFull(r, id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(id), nil
|
||||
}
|
||||
|
||||
func writeCommandError(w io.Writer, code byte, message string) error {
|
||||
msg := []byte(message)
|
||||
if len(msg) > 255 {
|
||||
msg = msg[:255]
|
||||
}
|
||||
buf := make([]byte, 0, 3+len(msg))
|
||||
buf = append(buf, CommandError, code, byte(len(msg)))
|
||||
buf = append(buf, msg...)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
type tcpRequestConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
reuse bool
|
||||
writeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
replyWritten bool
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Read(p []byte) (int, error) {
|
||||
n, err := c.reader.Read(p)
|
||||
if errors.Is(err, ErrZeroChunk) {
|
||||
err = io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Write(p []byte) (int, error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if !c.replyWritten {
|
||||
payload := make([]byte, 1+len(p))
|
||||
payload[0] = CommandTunnel
|
||||
copy(payload[1:], p)
|
||||
if _, err := c.Conn.Write(payload); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.replyWritten = true
|
||||
return len(p), nil
|
||||
}
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) CloseWrite() error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Close() error {
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if !c.replyWritten {
|
||||
err = writeCommandError(c.Conn, 0x65, "Remote EOF")
|
||||
if !c.reuse {
|
||||
err = errors.Join(err, c.Conn.Close())
|
||||
}
|
||||
return
|
||||
}
|
||||
if c.reuse {
|
||||
_, err = c.Conn.Write(nil)
|
||||
return
|
||||
}
|
||||
err = c.Conn.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
type serverPacketConn struct {
|
||||
conn *Snell
|
||||
writeMu *sync.Mutex
|
||||
readBuf []byte
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
if c.readBuf == nil {
|
||||
c.readBuf = make([]byte, maxPacketLength)
|
||||
}
|
||||
for {
|
||||
n, err := c.conn.Read(c.readBuf)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, ErrZeroChunk) {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
request, err := ParseUDPRequest(c.readBuf[:n])
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
if request.Ip.IsValid() {
|
||||
destination = M.SocksaddrFrom(request.Ip, request.Port)
|
||||
} else {
|
||||
destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
|
||||
}
|
||||
length := copy(p, request.Payload)
|
||||
if destination.IsFqdn() {
|
||||
return length, destination, nil
|
||||
}
|
||||
return length, destination.UDPAddr(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
return WritePacketResponse(c.conn, addr, p)
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Close() error { return c.conn.Close() }
|
||||
func (c *serverPacketConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||
func (c *serverPacketConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||||
func (c *serverPacketConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||||
func (c *serverPacketConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||||
211
transport/snell/shadowaead.go
Normal file
211
transport/snell/shadowaead.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
// payloadSizeMask is the maximum size of payload in bytes.
|
||||
// 16*1024 - 1
|
||||
// >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead()
|
||||
|
||||
// ErrZeroChunk is returned when a zero-length chunk is read, which snell uses
|
||||
// as an end-of-stream signal.
|
||||
var ErrZeroChunk = errors.New("zero chunk")
|
||||
|
||||
// Cipher is the AEAD cipher abstraction used by the shadowaead stream.
|
||||
type Cipher interface {
|
||||
KeySize() int
|
||||
SaltSize() int
|
||||
Encrypter(salt []byte) (cipher.AEAD, error)
|
||||
Decrypter(salt []byte) (cipher.AEAD, error)
|
||||
}
|
||||
|
||||
const (
|
||||
payloadSizeMask = 0x3FFF
|
||||
bufSize = 17 * 1024
|
||||
)
|
||||
|
||||
type aeadWriter struct {
|
||||
io.Writer
|
||||
cipher.AEAD
|
||||
nonce [32]byte // should be sufficient for most nonce sizes
|
||||
}
|
||||
|
||||
// newAEADWriter wraps an io.Writer with authenticated encryption.
|
||||
func newAEADWriter(w io.Writer, aead cipher.AEAD) *aeadWriter {
|
||||
return &aeadWriter{Writer: w, AEAD: aead}
|
||||
}
|
||||
|
||||
// Write encrypts p and writes to the embedded io.Writer.
|
||||
func (w *aeadWriter) Write(p []byte) (n int, err error) {
|
||||
b := buf.Get(bufSize)
|
||||
defer buf.Put(b)
|
||||
nonce := w.nonce[:w.NonceSize()]
|
||||
tag := w.Overhead()
|
||||
off := 2 + tag
|
||||
if len(p) == 0 {
|
||||
b = b[:off]
|
||||
b[0], b[1] = byte(0), byte(0)
|
||||
w.Seal(b[:0], nonce, b[:2], nil)
|
||||
increment(nonce)
|
||||
_, err = w.Writer.Write(b)
|
||||
return
|
||||
}
|
||||
for nr := 0; n < len(p) && err == nil; n += nr {
|
||||
nr = payloadSizeMask
|
||||
if n+nr > len(p) {
|
||||
nr = len(p) - n
|
||||
}
|
||||
b = b[:off+nr+tag]
|
||||
b[0], b[1] = byte(nr>>8), byte(nr)
|
||||
w.Seal(b[:0], nonce, b[:2], nil)
|
||||
increment(nonce)
|
||||
w.Seal(b[:off], nonce, p[n:n+nr], nil)
|
||||
increment(nonce)
|
||||
_, err = w.Writer.Write(b)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type aeadReader struct {
|
||||
io.Reader
|
||||
cipher.AEAD
|
||||
nonce [32]byte // should be sufficient for most nonce sizes
|
||||
buf []byte // to be put back into bufPool
|
||||
off int // offset to unconsumed part of buf
|
||||
}
|
||||
|
||||
// newAEADReader wraps an io.Reader with authenticated decryption.
|
||||
func newAEADReader(r io.Reader, aead cipher.AEAD) *aeadReader {
|
||||
return &aeadReader{Reader: r, AEAD: aead}
|
||||
}
|
||||
|
||||
// read and decrypt a record into p. len(p) >= max payload size + AEAD overhead.
|
||||
func (r *aeadReader) read(p []byte) (int, error) {
|
||||
nonce := r.nonce[:r.NonceSize()]
|
||||
tag := r.Overhead()
|
||||
p = p[:2+tag]
|
||||
if _, err := io.ReadFull(r.Reader, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err := r.Open(p[:0], nonce, p, nil)
|
||||
increment(nonce)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask
|
||||
if size == 0 {
|
||||
return 0, ErrZeroChunk
|
||||
}
|
||||
p = p[:size+tag]
|
||||
if _, err := io.ReadFull(r.Reader, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = r.Open(p[:0], nonce, p, nil)
|
||||
increment(nonce)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// Read reads from the embedded io.Reader, decrypts and writes to p.
|
||||
func (r *aeadReader) Read(p []byte) (int, error) {
|
||||
if r.buf == nil {
|
||||
if len(p) >= payloadSizeMask+r.Overhead() {
|
||||
return r.read(p)
|
||||
}
|
||||
b := buf.Get(bufSize)
|
||||
n, err := r.read(b)
|
||||
if err != nil {
|
||||
buf.Put(b)
|
||||
return 0, err
|
||||
}
|
||||
r.buf = b[:n]
|
||||
r.off = 0
|
||||
}
|
||||
n := copy(p, r.buf[r.off:])
|
||||
r.off += n
|
||||
if r.off == len(r.buf) {
|
||||
buf.Put(r.buf[:cap(r.buf)])
|
||||
r.buf = nil
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// increment little-endian encoded unsigned integer b. Wrap around on overflow.
|
||||
func increment(b []byte) {
|
||||
for i := range b {
|
||||
b[i]++
|
||||
if b[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// aeadConn wraps a stream-oriented net.Conn with the shadowaead cipher.
|
||||
type aeadConn struct {
|
||||
net.Conn
|
||||
Cipher
|
||||
r *aeadReader
|
||||
w *aeadWriter
|
||||
}
|
||||
|
||||
// newAEADConn wraps a stream-oriented net.Conn with cipher.
|
||||
func newAEADConn(c net.Conn, ciph Cipher) *aeadConn {
|
||||
return &aeadConn{Conn: c, Cipher: ciph}
|
||||
}
|
||||
|
||||
func (c *aeadConn) initReader() error {
|
||||
salt := make([]byte, c.SaltSize())
|
||||
if _, err := io.ReadFull(c.Conn, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := c.Decrypter(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.r = newAEADReader(c.Conn, aead)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Read(b []byte) (int, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *aeadConn) initWriter() error {
|
||||
salt := make([]byte, c.SaltSize())
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := c.Encrypter(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.Conn.Write(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.w = newAEADWriter(c.Conn, aead)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Write(b []byte) (int, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.w.Write(b)
|
||||
}
|
||||
408
transport/snell/snell.go
Normal file
408
transport/snell/snell.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
const (
|
||||
Version1 = 1
|
||||
Version2 = 2
|
||||
Version3 = 3
|
||||
Version4 = 4
|
||||
Version5 = 5
|
||||
DefaultSnellVersion = Version1
|
||||
|
||||
// max packet length
|
||||
maxLength = 0x3FFF
|
||||
)
|
||||
|
||||
const (
|
||||
CommandPing byte = 0
|
||||
CommandConnect byte = 1
|
||||
CommandConnectV2 byte = 5
|
||||
CommandUDP byte = 6
|
||||
CommandUDPForward byte = 1
|
||||
|
||||
CommandTunnel byte = 0
|
||||
CommandPong byte = 1
|
||||
CommandError byte = 2
|
||||
|
||||
Version byte = 1
|
||||
)
|
||||
|
||||
// Snell wraps an encrypted stream and handles the snell reply header.
|
||||
type Snell struct {
|
||||
net.Conn
|
||||
buffer [1]byte
|
||||
reply bool
|
||||
}
|
||||
|
||||
func (s *Snell) Read(b []byte) (int, error) {
|
||||
if err := s.ReadReply(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (s *Snell) ReadReply() error {
|
||||
if s.reply {
|
||||
return nil
|
||||
}
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
s.reply = true
|
||||
if s.buffer[0] == CommandTunnel {
|
||||
return nil
|
||||
} else if s.buffer[0] != CommandError {
|
||||
return errors.New("command not support")
|
||||
}
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
errcode := int(s.buffer[0])
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
length := int(s.buffer[0])
|
||||
msg := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.Conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
|
||||
}
|
||||
|
||||
func WriteHeader(conn net.Conn, host string, port uint, version int) error {
|
||||
return WriteHeaderWithReuse(conn, host, port, version, false)
|
||||
}
|
||||
|
||||
func WriteHeaderWithReuse(conn net.Conn, host string, port uint, version int, reuse bool) error {
|
||||
buffer := &bytes.Buffer{}
|
||||
buffer.WriteByte(Version)
|
||||
if version == Version2 || reuse {
|
||||
buffer.WriteByte(CommandConnectV2)
|
||||
} else {
|
||||
buffer.WriteByte(CommandConnect)
|
||||
}
|
||||
buffer.WriteByte(0)
|
||||
buffer.WriteByte(uint8(len(host)))
|
||||
buffer.WriteString(host)
|
||||
binary.Write(buffer, binary.BigEndian, uint16(port))
|
||||
if _, err := conn.Write(buffer.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WriteUDPHeader(conn net.Conn, version int) error {
|
||||
if version < Version3 {
|
||||
return errors.New("unsupport UDP version")
|
||||
}
|
||||
_, err := conn.Write([]byte{Version, CommandUDP, 0x00})
|
||||
return err
|
||||
}
|
||||
|
||||
// HalfClose only works after the request negotiated the reuse command.
|
||||
func HalfClose(conn net.Conn) error {
|
||||
if err := writeZeroChunk(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if s, ok := conn.(*Snell); ok {
|
||||
s.reply = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamConn wraps a raw connection with the snell stream cipher for the given version.
|
||||
func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
|
||||
if version >= Version4 {
|
||||
return &Snell{Conn: newV4Conn(conn, psk)}
|
||||
}
|
||||
var cipher Cipher
|
||||
if version != Version1 {
|
||||
cipher = NewAES128GCM(psk)
|
||||
} else {
|
||||
cipher = NewChacha20Poly1305(psk)
|
||||
}
|
||||
return &Snell{Conn: newAEADConn(conn, cipher)}
|
||||
}
|
||||
|
||||
// ServerStreamConn wraps a raw connection on the server side.
|
||||
func ServerStreamConn(conn net.Conn, psk []byte, version int) *Snell {
|
||||
stream := StreamConn(conn, psk, version)
|
||||
stream.reply = true
|
||||
return stream
|
||||
}
|
||||
|
||||
func PacketConn(conn net.Conn) net.PacketConn {
|
||||
return &packetConn{
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Snell) WritePacketFrame(b []byte) (int, error) {
|
||||
if fw, ok := s.Conn.(packetFrameWriter); ok {
|
||||
return fw.WritePacketFrame(b)
|
||||
}
|
||||
return s.Conn.Write(b)
|
||||
}
|
||||
|
||||
func WritePacket(w io.Writer, target, payload []byte) (int, error) {
|
||||
maxPayloadLength := maxLength - udpRequestHeaderLength(target)
|
||||
if maxPayloadLength <= 0 {
|
||||
return 0, errors.New("snell UDP address too large")
|
||||
}
|
||||
if len(payload) <= maxPayloadLength {
|
||||
return writePacket(w, target, payload)
|
||||
}
|
||||
return 0, errors.New("snell UDP payload too large")
|
||||
}
|
||||
|
||||
func WritePacketResponse(w io.Writer, addr net.Addr, payload []byte) (int, error) {
|
||||
buffer := &bytes.Buffer{}
|
||||
target := parseAddrToSocksAddr(addr)
|
||||
if len(target) == 0 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
switch target[0] {
|
||||
case atypIPv4:
|
||||
if len(target) < 1+net.IPv4len+2 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.WriteByte(0x04)
|
||||
buffer.Write(target[1 : 1+net.IPv4len+2])
|
||||
case atypIPv6:
|
||||
if len(target) < 1+net.IPv6len+2 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.WriteByte(0x06)
|
||||
buffer.Write(target[1 : 1+net.IPv6len+2])
|
||||
default:
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.Write(payload)
|
||||
var err error
|
||||
if fw, ok := w.(packetFrameWriter); ok {
|
||||
_, err = fw.WritePacketFrame(buffer.Bytes())
|
||||
} else {
|
||||
_, err = w.Write(buffer.Bytes())
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
// UDPRequest is a parsed snell UDP forward request.
|
||||
type UDPRequest struct {
|
||||
Host string
|
||||
Ip netip.Addr
|
||||
Port uint16
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func ParseUDPRequest(packet []byte) (UDPRequest, error) {
|
||||
if len(packet) < 2 || packet[0] != CommandUDPForward {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP request")
|
||||
}
|
||||
if hostLen := int(packet[1]); hostLen != 0 {
|
||||
if len(packet) <= 2+hostLen+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP domain request")
|
||||
}
|
||||
offset := 2 + hostLen
|
||||
return UDPRequest{
|
||||
Host: string(packet[2:offset]),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
}
|
||||
if len(packet) < 3 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IP request")
|
||||
}
|
||||
switch packet[2] {
|
||||
case 0x04:
|
||||
if len(packet) < 3+net.IPv4len+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IPv4 request")
|
||||
}
|
||||
offset := 3 + net.IPv4len
|
||||
ip, _ := netip.AddrFromSlice(packet[3:offset])
|
||||
return UDPRequest{
|
||||
Ip: ip.Unmap(),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
case 0x06:
|
||||
if len(packet) < 3+net.IPv6len+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IPv6 request")
|
||||
}
|
||||
offset := 3 + net.IPv6len
|
||||
ip, _ := netip.AddrFromSlice(packet[3:offset])
|
||||
return UDPRequest{
|
||||
Ip: ip.Unmap(),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
default:
|
||||
return UDPRequest{}, errors.New("snell invalid UDP address type")
|
||||
}
|
||||
}
|
||||
|
||||
func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) {
|
||||
b := buf.Get(buf.UDPBufferSize)
|
||||
defer buf.Put(b)
|
||||
n, err := r.Read(b)
|
||||
headLen := 1
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if n < headLen {
|
||||
return nil, 0, errors.New("insufficient UDP length")
|
||||
}
|
||||
switch b[0] {
|
||||
case 0x04:
|
||||
headLen += net.IPv4len + 2
|
||||
if n < headLen {
|
||||
err = errors.New("insufficient UDP length")
|
||||
break
|
||||
}
|
||||
b[0] = atypIPv4
|
||||
case 0x06:
|
||||
headLen += net.IPv6len + 2
|
||||
if n < headLen {
|
||||
err = errors.New("insufficient UDP length")
|
||||
break
|
||||
}
|
||||
b[0] = atypIPv6
|
||||
default:
|
||||
err = errors.New("ip version invalid")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
addr := splitSocksAddr(b[0:])
|
||||
if addr == nil {
|
||||
return nil, 0, errors.New("remote address invalid")
|
||||
}
|
||||
uAddr := addr.UDPAddr()
|
||||
if uAddr == nil {
|
||||
return nil, 0, errors.New("parse addr error")
|
||||
}
|
||||
length := len(payload)
|
||||
if n-headLen < length {
|
||||
length = n - headLen
|
||||
}
|
||||
copy(payload[:], b[headLen:headLen+length])
|
||||
return uAddr, length, nil
|
||||
}
|
||||
|
||||
var endSignal = []byte{}
|
||||
|
||||
type packetFrameWriter interface {
|
||||
WritePacketFrame([]byte) (int, error)
|
||||
}
|
||||
|
||||
func writeZeroChunk(conn net.Conn) error {
|
||||
if _, err := conn.Write(endSignal); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePacket(w io.Writer, target, payload []byte) (int, error) {
|
||||
buffer := &bytes.Buffer{}
|
||||
buffer.WriteByte(CommandUDPForward)
|
||||
switch target[0] {
|
||||
case atypDomainName:
|
||||
hostLen := target[1]
|
||||
if len(target) < 1+1+int(hostLen)+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write(target[1 : 1+1+hostLen+2])
|
||||
case atypIPv4:
|
||||
if len(target) < 1+net.IPv4len+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write([]byte{0x00, 0x04})
|
||||
buffer.Write(target[1 : 1+net.IPv4len+2])
|
||||
case atypIPv6:
|
||||
if len(target) < 1+net.IPv6len+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write([]byte{0x00, 0x06})
|
||||
buffer.Write(target[1 : 1+net.IPv6len+2])
|
||||
default:
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write(payload)
|
||||
if fw, ok := w.(packetFrameWriter); ok {
|
||||
_, err := fw.WritePacketFrame(buffer.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
_, err := w.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
func udpRequestHeaderLength(target []byte) int {
|
||||
if len(target) == 0 {
|
||||
return maxLength + 1
|
||||
}
|
||||
switch target[0] {
|
||||
case atypDomainName:
|
||||
if len(target) < 2 {
|
||||
return maxLength + 1
|
||||
}
|
||||
return 1 + 1 + int(target[1]) + 2
|
||||
case atypIPv4:
|
||||
return 1 + 2 + net.IPv4len + 2
|
||||
case atypIPv6:
|
||||
return 1 + 2 + net.IPv6len + 2
|
||||
default:
|
||||
return maxLength + 1
|
||||
}
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
net.Conn
|
||||
rMux sync.Mutex
|
||||
wMux sync.Mutex
|
||||
}
|
||||
|
||||
func (pc *packetConn) WritePacketFrame(b []byte) (int, error) {
|
||||
if s, ok := pc.Conn.(*Snell); ok {
|
||||
if fw, ok := s.Conn.(packetFrameWriter); ok {
|
||||
return fw.WritePacketFrame(b)
|
||||
}
|
||||
}
|
||||
return pc.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
pc.wMux.Lock()
|
||||
defer pc.wMux.Unlock()
|
||||
return WritePacket(pc, parseAddrToSocksAddr(addr), b)
|
||||
}
|
||||
|
||||
func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
pc.rMux.Lock()
|
||||
defer pc.rMux.Unlock()
|
||||
addr, n, err := ReadPacket(pc.Conn, b)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
return n, addr, nil
|
||||
}
|
||||
463
transport/snell/v4.go
Normal file
463
transport/snell/v4.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"math/big"
|
||||
"math/bits"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
v4SaltSize = 16
|
||||
v4NonceSize = 12
|
||||
v4HeaderPlainSize = 7
|
||||
v4HeaderCipherSize = v4HeaderPlainSize + 16
|
||||
v4FrameSize = 1460
|
||||
v4InitialPaddingMin = 0x100
|
||||
v4InitialPaddingSpan = 0x100
|
||||
)
|
||||
|
||||
type v4Conn struct {
|
||||
net.Conn
|
||||
psk []byte
|
||||
r *v4Reader
|
||||
w *v4Writer
|
||||
}
|
||||
|
||||
func newV4Conn(conn net.Conn, psk []byte) *v4Conn {
|
||||
return &v4Conn{Conn: conn, psk: psk}
|
||||
}
|
||||
|
||||
func (c *v4Conn) initReader() error {
|
||||
salt := make([]byte, v4SaltSize)
|
||||
if _, err := io.ReadFull(c.Conn, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := v4AEAD(c.psk, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.r = &v4Reader{Reader: c.Conn, aead: aead}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) initWriter() error {
|
||||
w, err := newV4Writer(c.Conn, c.psk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.w = w
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) Read(b []byte) (int, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *v4Conn) Write(b []byte) (int, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.w.Write(b)
|
||||
}
|
||||
|
||||
func (c *v4Conn) WritePacketFrame(b []byte) (int, error) {
|
||||
if len(b) > maxLength {
|
||||
return 0, errors.New("snell v4 frame too large")
|
||||
}
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
c.w.mux.Lock()
|
||||
defer c.w.mux.Unlock()
|
||||
if err := c.w.writeFrame(b, c.w.nextFramePaddingLength(len(b))); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) WriteTo(w io.Writer) (int64, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
var written int64
|
||||
buf := make([]byte, maxLength)
|
||||
for {
|
||||
n, err := c.r.Read(buf)
|
||||
if n > 0 {
|
||||
nw, ew := w.Write(buf[:n])
|
||||
written += int64(nw)
|
||||
if ew != nil {
|
||||
return written, ew
|
||||
}
|
||||
if nw != n {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *v4Conn) ReadFrom(r io.Reader) (int64, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
var read int64
|
||||
buf := make([]byte, maxLength)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
read += int64(n)
|
||||
if _, ew := c.w.Write(buf[:n]); ew != nil {
|
||||
return read, ew
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return read, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func v4AEAD(psk, salt []byte) (cipher.AEAD, error) {
|
||||
return aesGCM(snellKDF(psk, salt, 16))
|
||||
}
|
||||
|
||||
type v4Reader struct {
|
||||
io.Reader
|
||||
aead cipher.AEAD
|
||||
nonce [v4NonceSize]byte
|
||||
buf []byte
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func (r *v4Reader) Read(b []byte) (int, error) {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
if len(r.buf) == 0 {
|
||||
payload, err := r.readFrame()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r.buf = payload
|
||||
}
|
||||
n := copy(b, r.buf)
|
||||
r.buf = r.buf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *v4Reader) readFrame() ([]byte, error) {
|
||||
headerCipher := make([]byte, v4HeaderCipherSize)
|
||||
if _, err := io.ReadFull(r.Reader, headerCipher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
header, err := r.aead.Open(headerCipher[:0], r.nonce[:], headerCipher, nil)
|
||||
incrementV4Nonce(r.nonce[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(header) != v4HeaderPlainSize || header[0] != 4 {
|
||||
return nil, errors.New("snell v4 invalid frame header")
|
||||
}
|
||||
paddingLength := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
payloadLength := int(binary.BigEndian.Uint16(header[5:7]))
|
||||
if payloadLength == 0 {
|
||||
if paddingLength != 0 {
|
||||
return nil, errors.New("snell v4 zero chunk with padding")
|
||||
}
|
||||
return nil, ErrZeroChunk
|
||||
}
|
||||
if payloadLength > maxLength || paddingLength > maxLength {
|
||||
return nil, errors.New("snell v4 frame too large")
|
||||
}
|
||||
payloadCipherLength := payloadLength + r.aead.Overhead()
|
||||
frame := make([]byte, paddingLength+payloadCipherLength)
|
||||
if _, err := io.ReadFull(r.Reader, frame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paddingLength > 0 {
|
||||
swapPadding(frame[:paddingLength], frame[paddingLength:])
|
||||
}
|
||||
payloadCipher := frame[paddingLength:]
|
||||
payload, err := r.aead.Open(payloadCipher[:0], r.nonce[:], payloadCipher, nil)
|
||||
incrementV4Nonce(r.nonce[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
type v4Writer struct {
|
||||
io.Writer
|
||||
aead cipher.AEAD
|
||||
nonce [v4NonceSize]byte
|
||||
salt [v4SaltSize]byte
|
||||
saltSent bool
|
||||
initialPaddingLength uint16
|
||||
payloadLimit uint16
|
||||
lastWrite time.Time
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newV4Writer(w io.Writer, psk []byte) (*v4Writer, error) {
|
||||
var salt [v4SaltSize]byte
|
||||
if _, err := io.ReadFull(cryptorand.Reader, salt[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := v4AEAD(psk, salt[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddingDelta, err := cryptoRandomInt(v4InitialPaddingSpan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v4Writer{
|
||||
Writer: w,
|
||||
aead: aead,
|
||||
salt: salt,
|
||||
initialPaddingLength: uint16(v4InitialPaddingMin + paddingDelta),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *v4Writer) Write(b []byte) (int, error) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if len(b) == 0 {
|
||||
return 0, w.writeFrame(nil, 0)
|
||||
}
|
||||
written := 0
|
||||
for written < len(b) {
|
||||
payloadLimit := int(w.nextPayloadLimit())
|
||||
if payloadLimit <= 0 || payloadLimit > maxLength {
|
||||
payloadLimit = maxLength
|
||||
}
|
||||
end := written + payloadLimit
|
||||
if end > len(b) {
|
||||
end = len(b)
|
||||
}
|
||||
paddingLength := w.nextFramePaddingLength(end - written)
|
||||
if err := w.writeFrame(b[written:end], paddingLength); err != nil {
|
||||
return written, err
|
||||
}
|
||||
written = end
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (w *v4Writer) nextPayloadLimit() uint16 {
|
||||
now := time.Now()
|
||||
var payloadLimit uint16
|
||||
switch {
|
||||
case w.lastWrite.IsZero():
|
||||
payloadLimit = v4FrameSize - 55 - w.initialPaddingLength
|
||||
case now.Sub(w.lastWrite) > 30*time.Second:
|
||||
payloadLimit = v4FrameSize - 39
|
||||
default:
|
||||
payloadLimit = w.payloadLimit
|
||||
}
|
||||
w.lastWrite = now
|
||||
if payloadLimit <= maxLength-1 {
|
||||
next := int(payloadLimit) + v4FrameSize - 39
|
||||
if next > maxLength {
|
||||
next = maxLength
|
||||
}
|
||||
w.payloadLimit = uint16(next)
|
||||
} else {
|
||||
w.payloadLimit = maxLength
|
||||
}
|
||||
return payloadLimit
|
||||
}
|
||||
|
||||
func (w *v4Writer) nextFramePaddingLength(payloadLength int) int {
|
||||
if w.saltSent || payloadLength == 0 {
|
||||
return 0
|
||||
}
|
||||
return int(w.initialPaddingLength)
|
||||
}
|
||||
|
||||
func (w *v4Writer) writeFrame(payload []byte, paddingLength int) error {
|
||||
if len(payload) > maxLength || paddingLength > maxLength {
|
||||
return errors.New("snell v4 frame too large")
|
||||
}
|
||||
if len(payload) == 0 && paddingLength != 0 {
|
||||
return errors.New("snell v4 zero chunk with padding")
|
||||
}
|
||||
header := make([]byte, v4HeaderPlainSize)
|
||||
header[0] = 4
|
||||
binary.BigEndian.PutUint16(header[3:5], uint16(paddingLength))
|
||||
binary.BigEndian.PutUint16(header[5:7], uint16(len(payload)))
|
||||
headerCipher := w.aead.Seal(nil, w.nonce[:], header, nil)
|
||||
incrementV4Nonce(w.nonce[:])
|
||||
var payloadCipher []byte
|
||||
if len(payload) > 0 {
|
||||
payloadCipher = w.aead.Seal(nil, w.nonce[:], payload, nil)
|
||||
incrementV4Nonce(w.nonce[:])
|
||||
}
|
||||
frameLength := len(headerCipher) + paddingLength + len(payloadCipher)
|
||||
if !w.saltSent {
|
||||
frameLength += v4SaltSize
|
||||
}
|
||||
frame := make([]byte, 0, frameLength)
|
||||
if !w.saltSent {
|
||||
frame = append(frame, w.salt[:]...)
|
||||
w.saltSent = true
|
||||
}
|
||||
frame = append(frame, headerCipher...)
|
||||
if paddingLength > 0 {
|
||||
padding, err := makeV4Padding(payloadCipher, paddingLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
swapPadding(padding, payloadCipher)
|
||||
frame = append(frame, padding...)
|
||||
}
|
||||
frame = append(frame, payloadCipher...)
|
||||
return writeFull(w.Writer, frame)
|
||||
}
|
||||
|
||||
func swapPadding(padding, payloadCipher []byte) {
|
||||
limit := len(padding)
|
||||
if len(payloadCipher) < limit {
|
||||
limit = len(payloadCipher)
|
||||
}
|
||||
for i := 0; i < limit; i += 2 {
|
||||
padding[i], payloadCipher[i] = payloadCipher[i], padding[i]
|
||||
}
|
||||
}
|
||||
|
||||
func makeV4Padding(payloadCipher []byte, paddingLength int) ([]byte, error) {
|
||||
if paddingLength <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
payloadOnes := countV4PayloadOnes(payloadCipher)
|
||||
payloadZeros := 8*len(payloadCipher) - payloadOnes
|
||||
if payloadZeros <= 0 {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
ratio := float64(payloadOnes) / float64(payloadZeros)
|
||||
if ratio <= 0.5 || ratio >= 1.6 {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
targetRatioBase := 1.6
|
||||
if payloadZeros < payloadOnes {
|
||||
targetRatioBase = 0.4
|
||||
}
|
||||
jitter, err := randomUnitFloat64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetRatio := targetRatioBase + jitter/10
|
||||
totalBits := 8 * (paddingLength + len(payloadCipher))
|
||||
targetOnes := int(float64(totalBits)*(targetRatio/(targetRatio+1)) - float64(payloadOnes))
|
||||
if targetOnes < 0 || targetOnes > 8*paddingLength {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
return makeV4BitCountPadding(paddingLength, targetOnes)
|
||||
}
|
||||
|
||||
func countV4PayloadOnes(payloadCipher []byte) int {
|
||||
limit := len(payloadCipher) &^ 3
|
||||
ones := 0
|
||||
for _, b := range payloadCipher[:limit] {
|
||||
ones += bits.OnesCount8(b)
|
||||
}
|
||||
return ones
|
||||
}
|
||||
|
||||
func makeV4RandomPadding(length int) ([]byte, error) {
|
||||
padding := make([]byte, length)
|
||||
_, err := io.ReadFull(cryptorand.Reader, padding)
|
||||
return padding, err
|
||||
}
|
||||
|
||||
func makeV4BitCountPadding(length, oneBits int) ([]byte, error) {
|
||||
totalBits := 8 * length
|
||||
if oneBits < 0 || oneBits > totalBits {
|
||||
return nil, errors.New("snell v4 invalid padding bit count")
|
||||
}
|
||||
bitset := make([]byte, totalBits)
|
||||
for i := 0; i < oneBits; i++ {
|
||||
bitset[i] = 1
|
||||
}
|
||||
for i := totalBits - 1; i > 0; i-- {
|
||||
j, err := cryptoRandomInt(i + 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bitset[i], bitset[j] = bitset[j], bitset[i]
|
||||
}
|
||||
padding := make([]byte, length)
|
||||
for i, bit := range bitset {
|
||||
if bit == 1 {
|
||||
padding[i/8] |= 1 << uint(i%8)
|
||||
}
|
||||
}
|
||||
return padding, nil
|
||||
}
|
||||
|
||||
func cryptoRandomInt(max int) (int, error) {
|
||||
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(int64(max)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(n.Int64()), nil
|
||||
}
|
||||
|
||||
func randomUnitFloat64() (float64, error) {
|
||||
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(1<<53))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return float64(n.Int64()) / math.Exp2(53), nil
|
||||
}
|
||||
|
||||
func writeFull(w io.Writer, p []byte) error {
|
||||
for len(p) > 0 {
|
||||
n, err := w.Write(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
p = p[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementV4Nonce(nonce []byte) {
|
||||
for i := range nonce {
|
||||
nonce[i]++
|
||||
if nonce[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,9 +18,12 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
headerSize = 1 + 4 + 4
|
||||
maxFrameSize = 256 * 1024
|
||||
maxDataPayload = 32 * 1024
|
||||
headerSize = 1 + 4 + 4
|
||||
// maxQueuedBytesPerStream bounds unread payload retained by a single logical stream.
|
||||
// Backpressure is applied to the demux loop instead of dropping data.
|
||||
maxQueuedBytesPerStream = 4 * 1024 * 1024
|
||||
maxFrameSize = 256 * 1024
|
||||
maxDataPayload = 128 * 1024
|
||||
)
|
||||
|
||||
type acceptEvent struct {
|
||||
@@ -344,6 +347,8 @@ type stream struct {
|
||||
closeErr error
|
||||
readBuf []byte
|
||||
queue [][]byte
|
||||
// queuedBytes includes unread bytes in readBuf and queue.
|
||||
queuedBytes int
|
||||
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
@@ -362,16 +367,20 @@ func newStream(session *Session, id uint32) *stream {
|
||||
|
||||
func (c *stream) enqueue(payload []byte) {
|
||||
c.mu.Lock()
|
||||
for !c.closed && c.queuedBytes+len(payload) > maxQueuedBytesPerStream {
|
||||
c.cond.Wait()
|
||||
}
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.queuedBytes += len(payload)
|
||||
if len(c.readBuf) == 0 && len(c.queue) == 0 {
|
||||
c.readBuf = payload
|
||||
} else {
|
||||
c.queue = append(c.queue, payload)
|
||||
}
|
||||
c.cond.Signal()
|
||||
c.cond.Broadcast()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -413,7 +422,11 @@ func (c *stream) Read(p []byte) (int, error) {
|
||||
}
|
||||
if len(c.readBuf) == 0 && len(c.queue) > 0 {
|
||||
c.readBuf = c.queue[0]
|
||||
c.queue[0] = nil
|
||||
c.queue = c.queue[1:]
|
||||
if len(c.queue) == 0 {
|
||||
c.queue = nil
|
||||
}
|
||||
}
|
||||
if len(c.readBuf) == 0 && c.closed {
|
||||
if c.closeErr == nil {
|
||||
@@ -424,6 +437,14 @@ func (c *stream) Read(p []byte) (int, error) {
|
||||
|
||||
n := copy(p, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
if len(c.readBuf) == 0 {
|
||||
c.readBuf = nil
|
||||
}
|
||||
c.queuedBytes -= n
|
||||
if c.queuedBytes < 0 {
|
||||
c.queuedBytes = 0
|
||||
}
|
||||
c.cond.Broadcast()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
||||
91
transport/sudoku/multiplex/session_backpressure_test.go
Normal file
91
transport/sudoku/multiplex/session_backpressure_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestSession_LargeTransferBackpressure verifies that a transfer larger than
|
||||
// maxQueuedBytesPerStream completes correctly: the demux loop applies
|
||||
// backpressure (cond.Wait) instead of dropping data, and the reader draining
|
||||
// the stream wakes the blocked loop without deadlock.
|
||||
func TestSession_LargeTransferBackpressure(t *testing.T) {
|
||||
c1, c2 := net.Pipe()
|
||||
|
||||
client, err := NewClientSession(c1)
|
||||
if err != nil {
|
||||
t.Fatalf("client session: %v", err)
|
||||
}
|
||||
server, err := NewServerSession(c2)
|
||||
if err != nil {
|
||||
t.Fatalf("server session: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
// Payload bigger than the per-stream backpressure window (4MB).
|
||||
const total = 12 * 1024 * 1024
|
||||
payload := make([]byte, total)
|
||||
if _, err := rand.Read(payload); err != nil {
|
||||
t.Fatalf("rand: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var writeErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stream, err := client.OpenStream([]byte("hello"))
|
||||
if err != nil {
|
||||
writeErr = err
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
if _, err := stream.Write(payload); err != nil {
|
||||
writeErr = err
|
||||
return
|
||||
}
|
||||
_ = stream.(interface{ CloseWrite() error }).CloseWrite()
|
||||
}()
|
||||
|
||||
var got []byte
|
||||
var readErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stream, openPayload, err := server.AcceptStream()
|
||||
if err != nil {
|
||||
readErr = err
|
||||
return
|
||||
}
|
||||
if string(openPayload) != "hello" {
|
||||
readErr = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
got, readErr = io.ReadAll(stream)
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { wg.Wait(); close(done) }()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Fatal("transfer deadlocked (backpressure did not release)")
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
t.Fatalf("write: %v", writeErr)
|
||||
}
|
||||
if readErr != nil {
|
||||
t.Fatalf("read: %v", readErr)
|
||||
}
|
||||
if !bytes.Equal(got, payload) {
|
||||
t.Fatalf("payload mismatch: got %d bytes, want %d", len(got), len(payload))
|
||||
}
|
||||
}
|
||||
56
transport/sudoku/obfs/sudoku/ascii_mode_test.go
Normal file
56
transport/sudoku/obfs/sudoku/ascii_mode_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package sudoku
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeASCIIMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", "prefer_entropy"},
|
||||
{"entropy", "prefer_entropy"},
|
||||
{"prefer_ascii", "prefer_ascii"},
|
||||
{"up_ascii_down_entropy", "up_ascii_down_entropy"},
|
||||
{"up_entropy_down_ascii", "up_entropy_down_ascii"},
|
||||
{"up_prefer_ascii_down_prefer_entropy", "up_ascii_down_entropy"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := NormalizeASCIIMode(tt.in)
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizeASCIIMode(%q): %v", tt.in, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("NormalizeASCIIMode(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := NormalizeASCIIMode("up_ascii_down_binary"); err == nil {
|
||||
t.Fatalf("expected invalid directional mode to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTableWithCustomDirectionalOpposite(t *testing.T) {
|
||||
table, err := NewTableWithCustom("seed", "up_ascii_down_entropy", "xpxvvpvv")
|
||||
if err != nil {
|
||||
t.Fatalf("NewTableWithCustom: %v", err)
|
||||
}
|
||||
if !table.IsASCII {
|
||||
t.Fatalf("uplink table should be ascii")
|
||||
}
|
||||
opposite := table.OppositeDirection()
|
||||
if opposite == nil || opposite == table {
|
||||
t.Fatalf("expected distinct opposite table")
|
||||
}
|
||||
if opposite.IsASCII {
|
||||
t.Fatalf("downlink table should be entropy/custom")
|
||||
}
|
||||
|
||||
symmetric, err := NewTableWithCustom("seed", "prefer_ascii", "xpxvvpvv")
|
||||
if err != nil {
|
||||
t.Fatalf("NewTableWithCustom symmetric: %v", err)
|
||||
}
|
||||
if symmetric.OppositeDirection() != symmetric {
|
||||
t.Fatalf("symmetric table should point to itself")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package sudoku
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -10,6 +11,8 @@ import (
|
||||
|
||||
const IOBufferSize = 32 * 1024
|
||||
|
||||
const minDecodeReadSize = 64
|
||||
|
||||
var perm4 = [24][4]byte{
|
||||
{0, 1, 2, 3},
|
||||
{0, 1, 3, 2},
|
||||
@@ -52,7 +55,7 @@ type Conn struct {
|
||||
writeMu sync.Mutex
|
||||
writeBuf []byte
|
||||
|
||||
rng randomSource
|
||||
rng *sudokuRand
|
||||
paddingThreshold uint64
|
||||
}
|
||||
|
||||
@@ -97,6 +100,9 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
|
||||
}
|
||||
|
||||
func (sc *Conn) StopRecording() {
|
||||
if sc == nil {
|
||||
return
|
||||
}
|
||||
sc.recordLock.Lock()
|
||||
sc.recording.Store(false)
|
||||
sc.recorder = nil
|
||||
@@ -115,6 +121,9 @@ func (sc *Conn) GetBufferedAndRecorded() []byte {
|
||||
if sc.recorder != nil {
|
||||
recorded = sc.recorder.Bytes()
|
||||
}
|
||||
if sc.reader == nil {
|
||||
return recorded
|
||||
}
|
||||
|
||||
buffered := sc.reader.Buffered()
|
||||
if buffered > 0 {
|
||||
@@ -131,6 +140,9 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if sc == nil || sc.Conn == nil || sc.table == nil || sc.table.layout == nil || sc.rng == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
sc.writeMu.Lock()
|
||||
defer sc.writeMu.Unlock()
|
||||
@@ -140,16 +152,19 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (sc *Conn) Read(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if sc == nil || sc.Conn == nil || sc.reader == nil || len(sc.rawBuf) == 0 || sc.table == nil || sc.table.layout == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if n, ok := drainPending(p, &sc.pendingData); ok {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
outN := 0
|
||||
for {
|
||||
if sc.pendingData.available() > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
nr, rErr := sc.reader.Read(sc.rawBuf)
|
||||
nr, rErr := readRawLimited(sc.Conn, sc.reader, sc.rawBuf[:sudokuReadSize(len(p)-outN, len(sc.rawBuf))])
|
||||
if nr > 0 {
|
||||
chunk := sc.rawBuf[:nr]
|
||||
if sc.recording.Load() {
|
||||
@@ -160,34 +175,80 @@ func (sc *Conn) Read(p []byte) (n int, err error) {
|
||||
sc.recordLock.Unlock()
|
||||
}
|
||||
|
||||
layout := sc.table.layout
|
||||
for _, b := range chunk {
|
||||
table := sc.table
|
||||
layout := table.layout
|
||||
for i := 0; i < len(chunk); {
|
||||
if sc.hintCount == 0 && outN < len(p) && i+3 < len(chunk) &&
|
||||
layout.hintTable[chunk[i]] &&
|
||||
layout.hintTable[chunk[i+1]] &&
|
||||
layout.hintTable[chunk[i+2]] &&
|
||||
layout.hintTable[chunk[i+3]] {
|
||||
val, ok := table.DecodeMap[packHintBytes(chunk[i], chunk[i+1], chunk[i+2], chunk[i+3])]
|
||||
if !ok {
|
||||
return 0, ErrInvalidSudokuMapMiss
|
||||
}
|
||||
p[outN] = val
|
||||
outN++
|
||||
i += 4
|
||||
continue
|
||||
}
|
||||
|
||||
b := chunk[i]
|
||||
i++
|
||||
if !layout.hintTable[b] {
|
||||
continue
|
||||
}
|
||||
|
||||
sc.hintBuf[sc.hintCount] = b
|
||||
sc.hintCount++
|
||||
if sc.hintCount == len(sc.hintBuf) {
|
||||
key := packHintsToKey(sc.hintBuf)
|
||||
val, ok := sc.table.DecodeMap[key]
|
||||
if !ok {
|
||||
return 0, ErrInvalidSudokuMapMiss
|
||||
}
|
||||
sc.pendingData.appendByte(val)
|
||||
sc.hintCount = 0
|
||||
if sc.hintCount != len(sc.hintBuf) {
|
||||
continue
|
||||
}
|
||||
|
||||
val, ok := table.DecodeMap[packHintBytes(sc.hintBuf[0], sc.hintBuf[1], sc.hintBuf[2], sc.hintBuf[3])]
|
||||
if !ok {
|
||||
return 0, ErrInvalidSudokuMapMiss
|
||||
}
|
||||
outN = appendDecodedByte(p, outN, &sc.pendingData, val)
|
||||
sc.hintCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
if rErr != nil {
|
||||
if outN > 0 {
|
||||
return outN, nil
|
||||
}
|
||||
if n, ok := drainPending(p, &sc.pendingData); ok {
|
||||
return n, nil
|
||||
}
|
||||
return 0, rErr
|
||||
}
|
||||
if sc.pendingData.available() > 0 {
|
||||
break
|
||||
if outN > 0 {
|
||||
return outN, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, _ = drainPending(p, &sc.pendingData)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func sudokuReadSize(decodedRemaining, maxRaw int) int {
|
||||
if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 {
|
||||
return maxRaw
|
||||
}
|
||||
if decodedRemaining > (maxRaw-minDecodeReadSize)/5 {
|
||||
return maxRaw
|
||||
}
|
||||
|
||||
return decodedRemaining*5 + minDecodeReadSize
|
||||
}
|
||||
|
||||
func readRawLimited(conn net.Conn, reader *bufio.Reader, dst []byte) (int, error) {
|
||||
if len(dst) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if reader != nil && reader.Buffered() > 0 {
|
||||
return reader.Read(dst)
|
||||
}
|
||||
if conn == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return conn.Read(dst)
|
||||
}
|
||||
|
||||
51
transport/sudoku/obfs/sudoku/conn_roundtrip_test.go
Normal file
51
transport/sudoku/obfs/sudoku/conn_roundtrip_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConn_Roundtrip exercises the optimized Conn encode/decode hot paths:
|
||||
// the no-padding fast path (pMin==pMax==0), the always-padding path
|
||||
// (pMin==pMax==100), a probabilistic range, and the adaptive read-size /
|
||||
// 4-byte fast hint decode path across a variety of payload sizes and modes.
|
||||
func TestConn_Roundtrip(t *testing.T) {
|
||||
modes := []string{"prefer_entropy", "prefer_ascii"}
|
||||
paddings := []struct{ min, max int }{
|
||||
{0, 0}, // no-padding specialized path
|
||||
{100, 100}, // always-padding specialized path
|
||||
{20, 60}, // probabilistic path
|
||||
}
|
||||
sizes := []int{1, 3, 4, 7, 16, 100, 1000, 64 * 1024}
|
||||
|
||||
for _, mode := range modes {
|
||||
for _, pad := range paddings {
|
||||
for _, size := range sizes {
|
||||
payload := make([]byte, size)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i*31 + 7)
|
||||
}
|
||||
|
||||
table := NewTable("conn-roundtrip-seed", mode)
|
||||
|
||||
// Encode via Conn.Write.
|
||||
w := &mockConn{}
|
||||
enc := NewConn(w, table, pad.min, pad.max, false)
|
||||
if _, err := enc.Write(payload); err != nil {
|
||||
t.Fatalf("mode=%s pad=%v size=%d write: %v", mode, pad, size, err)
|
||||
}
|
||||
|
||||
// Decode via Conn.Read using the same table.
|
||||
dec := NewConn(&mockConn{readBuf: w.writeBuf}, table, pad.min, pad.max, false)
|
||||
got := make([]byte, size)
|
||||
if _, err := io.ReadFull(dec, got); err != nil {
|
||||
t.Fatalf("mode=%s pad=%v size=%d read: %v", mode, pad, size, err)
|
||||
}
|
||||
if !bytes.Equal(got, payload) {
|
||||
t.Fatalf("mode=%s pad=%v size=%d roundtrip mismatch", mode, pad, size)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
package sudoku
|
||||
|
||||
func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThreshold uint64, p []byte) []byte {
|
||||
func encodeSudokuPayload(dst []byte, table *Table, rng *sudokuRand, paddingThreshold uint64, p []byte) []byte {
|
||||
if len(p) == 0 {
|
||||
return dst[:0]
|
||||
}
|
||||
if paddingThreshold == 0 {
|
||||
return encodeSudokuPayloadNoPadding(dst, table, rng, p)
|
||||
}
|
||||
|
||||
outCapacity := len(p)*6 + 1
|
||||
if cap(dst) < outCapacity {
|
||||
@@ -13,8 +16,25 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre
|
||||
pads := table.PaddingPool
|
||||
padLen := len(pads)
|
||||
|
||||
if paddingThreshold >= probOne {
|
||||
for _, b := range p {
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
|
||||
puzzles := table.EncodeTable[b]
|
||||
puzzle := puzzles[rng.Intn(len(puzzles))]
|
||||
|
||||
perm := perm4[rng.Intn(len(perm4))]
|
||||
for _, idx := range perm {
|
||||
out = append(out, pads[rng.Intn(padLen)], puzzle[idx])
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
return out
|
||||
}
|
||||
|
||||
for _, b := range p {
|
||||
if shouldPad(rng, paddingThreshold) {
|
||||
if uint64(rng.Uint32()) < paddingThreshold {
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
}
|
||||
|
||||
@@ -22,15 +42,31 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre
|
||||
puzzle := puzzles[rng.Intn(len(puzzles))]
|
||||
perm := perm4[rng.Intn(len(perm4))]
|
||||
for _, idx := range perm {
|
||||
if shouldPad(rng, paddingThreshold) {
|
||||
if uint64(rng.Uint32()) < paddingThreshold {
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
}
|
||||
out = append(out, puzzle[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if shouldPad(rng, paddingThreshold) {
|
||||
if uint64(rng.Uint32()) < paddingThreshold {
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeSudokuPayloadNoPadding(dst []byte, table *Table, rng *sudokuRand, p []byte) []byte {
|
||||
outCapacity := len(p) * 4
|
||||
if cap(dst) < outCapacity {
|
||||
dst = make([]byte, 0, outCapacity)
|
||||
}
|
||||
out := dst[:0]
|
||||
|
||||
for _, b := range p {
|
||||
puzzles := table.EncodeTable[b]
|
||||
puzzle := puzzles[rng.Intn(len(puzzles))]
|
||||
perm := perm4[rng.Intn(len(perm4))]
|
||||
out = append(out, puzzle[perm[0]], puzzle[perm[1]], puzzle[perm[2]], puzzle[perm[3]])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
RngBatchSize = 128
|
||||
|
||||
packedProtectedPrefixBytes = 14
|
||||
packedIOBufferSize = 64 * 1024
|
||||
packedDecodeBufferSize = 96 * 1024
|
||||
)
|
||||
|
||||
// PackedConn encodes traffic with the packed Sudoku layout while preserving
|
||||
@@ -35,7 +35,7 @@ type PackedConn struct {
|
||||
readBits int
|
||||
|
||||
// Padding selection matches Conn's threshold-based model.
|
||||
rng randomSource
|
||||
rng *sudokuRand
|
||||
paddingThreshold uint64
|
||||
padMarker byte
|
||||
padPool []byte
|
||||
@@ -67,18 +67,20 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||
pc := &PackedConn{
|
||||
Conn: c,
|
||||
table: table,
|
||||
reader: bufio.NewReaderSize(c, IOBufferSize),
|
||||
rawBuf: make([]byte, IOBufferSize),
|
||||
reader: bufio.NewReaderSize(c, packedIOBufferSize),
|
||||
rawBuf: make([]byte, packedDecodeBufferSize),
|
||||
pendingData: newPendingBuffer(4096),
|
||||
writeBuf: make([]byte, 0, 4096),
|
||||
rng: localRng,
|
||||
paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
|
||||
}
|
||||
|
||||
pc.padMarker = table.layout.padMarker
|
||||
for _, b := range table.PaddingPool {
|
||||
if b != pc.padMarker {
|
||||
pc.padPool = append(pc.padPool, b)
|
||||
if table != nil && table.layout != nil {
|
||||
pc.padMarker = table.layout.padMarker
|
||||
for _, b := range table.PaddingPool {
|
||||
if b != pc.padMarker {
|
||||
pc.padPool = append(pc.padPool, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(pc.padPool) == 0 {
|
||||
@@ -87,18 +89,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||
return pc
|
||||
}
|
||||
|
||||
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
|
||||
if shouldPad(pc.rng, pc.paddingThreshold) {
|
||||
out = append(out, pc.getPaddingByte())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (pc *PackedConn) appendGroup(out []byte, group byte) []byte {
|
||||
out = pc.maybeAddPadding(out)
|
||||
return append(out, pc.table.layout.groupByte(group))
|
||||
}
|
||||
|
||||
func (pc *PackedConn) appendForcedPadding(out []byte) []byte {
|
||||
return append(out, pc.getPaddingByte())
|
||||
}
|
||||
@@ -134,7 +124,7 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = appendPackedGroup(out, pc.table.layout, pc.rng, pc.paddingThreshold, pc.padPool, group)
|
||||
}
|
||||
|
||||
effective++
|
||||
@@ -148,19 +138,49 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
|
||||
return out, limit
|
||||
}
|
||||
|
||||
func appendPackedGroup(out []byte, layout *byteLayout, rng *sudokuRand, paddingThreshold uint64, padPool []byte, group byte) []byte {
|
||||
if paddingThreshold != 0 {
|
||||
u := rng.Uint32()
|
||||
if uint64(u) < paddingThreshold {
|
||||
out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))])
|
||||
}
|
||||
}
|
||||
return append(out, layout.encodeGroup[group&0x3F])
|
||||
}
|
||||
|
||||
func maybeAppendPackedPadding(out []byte, rng *sudokuRand, paddingThreshold uint64, padPool []byte) []byte {
|
||||
if paddingThreshold != 0 {
|
||||
u := rng.Uint32()
|
||||
if uint64(u) < paddingThreshold {
|
||||
out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
pc.writeMu.Lock()
|
||||
defer pc.writeMu.Unlock()
|
||||
|
||||
needed := len(p)*3/2 + 32
|
||||
if pc.paddingThreshold == 0 {
|
||||
needed = ((len(p)+2)/3)*4 + 32
|
||||
}
|
||||
if cap(pc.writeBuf) < needed {
|
||||
pc.writeBuf = make([]byte, 0, needed)
|
||||
}
|
||||
out := pc.writeBuf[:0]
|
||||
layout := pc.table.layout
|
||||
rng := pc.rng
|
||||
paddingThreshold := pc.paddingThreshold
|
||||
padPool := pc.padPool
|
||||
|
||||
var prefixN int
|
||||
out, prefixN = pc.writeProtectedPrefix(out, p)
|
||||
@@ -181,7 +201,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,10 +215,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
|
||||
g4 := b3 & 0x3F
|
||||
|
||||
out = pc.appendGroup(out, g1)
|
||||
out = pc.appendGroup(out, g2)
|
||||
out = pc.appendGroup(out, g3)
|
||||
out = pc.appendGroup(out, g4)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,10 +231,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
|
||||
g4 := b3 & 0x3F
|
||||
|
||||
out = pc.appendGroup(out, g1)
|
||||
out = pc.appendGroup(out, g2)
|
||||
out = pc.appendGroup(out, g3)
|
||||
out = pc.appendGroup(out, g4)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4)
|
||||
}
|
||||
|
||||
for ; i < n; i++ {
|
||||
@@ -229,7 +249,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,11 +257,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
group := byte(pc.bitBuf << (6 - pc.bitCount))
|
||||
pc.bitBuf = 0
|
||||
pc.bitCount = 0
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
|
||||
out = append(out, pc.padMarker)
|
||||
}
|
||||
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = maybeAppendPackedPadding(out, rng, paddingThreshold, padPool)
|
||||
|
||||
if len(out) > 0 {
|
||||
pc.writeBuf = out[:0]
|
||||
@@ -252,6 +272,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (pc *PackedConn) Flush() error {
|
||||
if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
|
||||
pc.writeMu.Lock()
|
||||
defer pc.writeMu.Unlock()
|
||||
|
||||
@@ -265,7 +289,7 @@ func (pc *PackedConn) Flush() error {
|
||||
out = append(out, pc.padMarker)
|
||||
}
|
||||
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = maybeAppendPackedPadding(out, pc.rng, pc.paddingThreshold, pc.padPool)
|
||||
|
||||
if len(out) > 0 {
|
||||
pc.writeBuf = out[:0]
|
||||
@@ -289,19 +313,44 @@ func writeFull(w io.Writer, b []byte) error {
|
||||
}
|
||||
|
||||
func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if pc == nil || pc.Conn == nil || pc.reader == nil || len(pc.rawBuf) == 0 || pc.table == nil || pc.table.layout == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if n, ok := drainPending(p, &pc.pendingData); ok {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
outN := 0
|
||||
for {
|
||||
nr, rErr := pc.reader.Read(pc.rawBuf)
|
||||
nr, rErr := readRawLimited(pc.Conn, pc.reader, pc.rawBuf[:packedReadSize(len(p)-outN, len(pc.rawBuf))])
|
||||
if nr > 0 {
|
||||
rBuf := pc.readBitBuf
|
||||
rBits := pc.readBits
|
||||
padMarker := pc.padMarker
|
||||
layout := pc.table.layout
|
||||
|
||||
for _, b := range pc.rawBuf[:nr] {
|
||||
chunk := pc.rawBuf[:nr]
|
||||
for i := 0; i < len(chunk); {
|
||||
if rBits == 0 && outN+3 <= len(p) && i+3 < len(chunk) &&
|
||||
layout.hintTable[chunk[i]] && layout.hintTable[chunk[i+1]] &&
|
||||
layout.hintTable[chunk[i+2]] && layout.hintTable[chunk[i+3]] {
|
||||
g1 := layout.decodeGroup[chunk[i]]
|
||||
g2 := layout.decodeGroup[chunk[i+1]]
|
||||
g3 := layout.decodeGroup[chunk[i+2]]
|
||||
g4 := layout.decodeGroup[chunk[i+3]]
|
||||
p[outN] = (g1 << 2) | (g2 >> 4)
|
||||
p[outN+1] = (g2 << 4) | (g3 >> 2)
|
||||
p[outN+2] = (g3 << 6) | g4
|
||||
outN += 3
|
||||
i += 4
|
||||
continue
|
||||
}
|
||||
|
||||
b := chunk[i]
|
||||
i++
|
||||
if !layout.hintTable[b] {
|
||||
if b == padMarker {
|
||||
rBuf = 0
|
||||
@@ -321,7 +370,7 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
if rBits >= 8 {
|
||||
rBits -= 8
|
||||
val := byte(rBuf >> rBits)
|
||||
pc.pendingData.appendByte(val)
|
||||
outN = appendDecodedByte(p, outN, &pc.pendingData, val)
|
||||
if rBits == 0 {
|
||||
rBuf = 0
|
||||
} else {
|
||||
@@ -339,21 +388,32 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
pc.readBitBuf = 0
|
||||
pc.readBits = 0
|
||||
}
|
||||
if pc.pendingData.available() > 0 {
|
||||
break
|
||||
if outN > 0 {
|
||||
return outN, nil
|
||||
}
|
||||
if n, ok := drainPending(p, &pc.pendingData); ok {
|
||||
return n, nil
|
||||
}
|
||||
return 0, rErr
|
||||
}
|
||||
|
||||
if pc.pendingData.available() > 0 {
|
||||
break
|
||||
if outN > 0 {
|
||||
return outN, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, _ := drainPending(p, &pc.pendingData)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (pc *PackedConn) getPaddingByte() byte {
|
||||
return pc.padPool[pc.rng.Intn(len(pc.padPool))]
|
||||
}
|
||||
|
||||
func packedReadSize(decodedRemaining, maxRaw int) int {
|
||||
if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 {
|
||||
return maxRaw
|
||||
}
|
||||
if decodedRemaining > (maxRaw-minDecodeReadSize)/2 {
|
||||
return maxRaw
|
||||
}
|
||||
|
||||
return decodedRemaining*2 + minDecodeReadSize
|
||||
}
|
||||
|
||||
90
transport/sudoku/obfs/sudoku/packed_prefix_test.go
Normal file
90
transport/sudoku/obfs/sudoku/packed_prefix_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
readBuf []byte
|
||||
writeBuf []byte
|
||||
}
|
||||
|
||||
func (c *mockConn) Read(p []byte) (int, error) {
|
||||
if len(c.readBuf) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Write(p []byte) (int, error) {
|
||||
c.writeBuf = append(c.writeBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error { return nil }
|
||||
func (c *mockConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *mockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *mockConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *mockConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *mockConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
func TestPackedConn_ProtectedPrefixPadding(t *testing.T) {
|
||||
table := NewTable("packed-prefix-seed", "prefer_ascii")
|
||||
mock := &mockConn{}
|
||||
writer := NewPackedConn(mock, table, 0, 0)
|
||||
writer.rng = newSudokuRand(1)
|
||||
|
||||
payload := bytes.Repeat([]byte{0}, 32)
|
||||
if _, err := writer.Write(payload); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
wire := append([]byte(nil), mock.writeBuf...)
|
||||
if len(wire) < 20 {
|
||||
t.Fatalf("wire too short: %d", len(wire))
|
||||
}
|
||||
|
||||
firstHint := -1
|
||||
nonHintCount := 0
|
||||
maxHintRun := 0
|
||||
currentHintRun := 0
|
||||
for i, b := range wire[:20] {
|
||||
if table.layout.isHint(b) {
|
||||
if firstHint == -1 {
|
||||
firstHint = i
|
||||
}
|
||||
currentHintRun++
|
||||
if currentHintRun > maxHintRun {
|
||||
maxHintRun = currentHintRun
|
||||
}
|
||||
continue
|
||||
}
|
||||
nonHintCount++
|
||||
currentHintRun = 0
|
||||
}
|
||||
|
||||
if firstHint < 1 || firstHint > 2 {
|
||||
t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint)
|
||||
}
|
||||
if nonHintCount < 6 {
|
||||
t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount)
|
||||
}
|
||||
if maxHintRun > 3 {
|
||||
t.Fatalf("prefix still exposes long hint run: %d", maxHintRun)
|
||||
}
|
||||
|
||||
reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0)
|
||||
decoded := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(reader, decoded); err != nil {
|
||||
t.Fatalf("read back: %v", err)
|
||||
}
|
||||
if !bytes.Equal(decoded, payload) {
|
||||
t.Fatalf("roundtrip mismatch")
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package sudoku
|
||||
|
||||
const probOne = uint64(1) << 32
|
||||
|
||||
func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 {
|
||||
func pickPaddingThreshold(r *sudokuRand, pMin, pMax int) uint64 {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
@@ -28,7 +28,7 @@ func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 {
|
||||
return min + (u * (max - min) >> 32)
|
||||
}
|
||||
|
||||
func shouldPad(r randomSource, threshold uint64) bool {
|
||||
func shouldPad(r *sudokuRand, threshold uint64) bool {
|
||||
if threshold == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -25,7 +25,10 @@ func (p *pendingBuffer) reset() {
|
||||
}
|
||||
|
||||
func (p *pendingBuffer) ensureAppendCapacity(extra int) {
|
||||
if p == nil || extra <= 0 || p.off == 0 {
|
||||
if p == nil || extra <= 0 {
|
||||
return
|
||||
}
|
||||
if p.off == 0 {
|
||||
return
|
||||
}
|
||||
if cap(p.data)-len(p.data) >= extra {
|
||||
@@ -43,6 +46,15 @@ func (p *pendingBuffer) appendByte(b byte) {
|
||||
p.data = append(p.data, b)
|
||||
}
|
||||
|
||||
func appendDecodedByte(dst []byte, n int, pending *pendingBuffer, b byte) int {
|
||||
if n < len(dst) {
|
||||
dst[n] = b
|
||||
return n + 1
|
||||
}
|
||||
pending.appendByte(b)
|
||||
return n
|
||||
}
|
||||
|
||||
func drainPending(dst []byte, pending *pendingBuffer) (int, bool) {
|
||||
if pending == nil || pending.available() == 0 {
|
||||
return 0, false
|
||||
|
||||
@@ -6,14 +6,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type randomSource interface {
|
||||
Uint32() uint32
|
||||
Uint64() uint64
|
||||
Intn(n int) int
|
||||
}
|
||||
|
||||
type sudokuRand struct {
|
||||
state uint64
|
||||
state uint64
|
||||
cached uint32
|
||||
haveCached bool
|
||||
}
|
||||
|
||||
func newSeededRand() *sudokuRand {
|
||||
@@ -37,20 +33,36 @@ func (r *sudokuRand) Uint64() uint64 {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
r.state += 0x9e3779b97f4a7c15
|
||||
z := r.state
|
||||
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9
|
||||
z = (z ^ (z >> 27)) * 0x94d049bb133111eb
|
||||
return z ^ (z >> 31)
|
||||
r.haveCached = false
|
||||
x := r.state
|
||||
x ^= x >> 12
|
||||
x ^= x << 25
|
||||
x ^= x >> 27
|
||||
r.state = x
|
||||
return x * 0x2545f4914f6cdd1d
|
||||
}
|
||||
|
||||
func (r *sudokuRand) Uint32() uint32 {
|
||||
return uint32(r.Uint64() >> 32)
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
if r.haveCached {
|
||||
r.haveCached = false
|
||||
return r.cached
|
||||
}
|
||||
v := r.Uint64()
|
||||
r.cached = uint32(v)
|
||||
r.haveCached = true
|
||||
return uint32(v >> 32)
|
||||
}
|
||||
|
||||
func (r *sudokuRand) Intn(n int) int {
|
||||
if n <= 1 {
|
||||
return 0
|
||||
}
|
||||
return int((uint64(r.Uint32()) * uint64(n)) >> 32)
|
||||
return fastIntnFromUint32(r.Uint32(), n)
|
||||
}
|
||||
|
||||
func fastIntnFromUint32(u uint32, n int) int {
|
||||
return int((uint64(u) * uint64(n)) >> 32)
|
||||
}
|
||||
|
||||
@@ -192,23 +192,27 @@ func tableHintFingerprint(key string, mode string, uplinkPattern string, downlin
|
||||
}
|
||||
|
||||
func packHintsToKey(hints [4]byte) uint32 {
|
||||
return packHintBytes(hints[0], hints[1], hints[2], hints[3])
|
||||
}
|
||||
|
||||
func packHintBytes(h0, h1, h2, h3 byte) uint32 {
|
||||
// Sorting network for 4 elements (Bubble sort unrolled)
|
||||
// Swap if a > b
|
||||
if hints[0] > hints[1] {
|
||||
hints[0], hints[1] = hints[1], hints[0]
|
||||
if h0 > h1 {
|
||||
h0, h1 = h1, h0
|
||||
}
|
||||
if hints[2] > hints[3] {
|
||||
hints[2], hints[3] = hints[3], hints[2]
|
||||
if h2 > h3 {
|
||||
h2, h3 = h3, h2
|
||||
}
|
||||
if hints[0] > hints[2] {
|
||||
hints[0], hints[2] = hints[2], hints[0]
|
||||
if h0 > h2 {
|
||||
h0, h2 = h2, h0
|
||||
}
|
||||
if hints[1] > hints[3] {
|
||||
hints[1], hints[3] = hints[3], hints[1]
|
||||
if h1 > h3 {
|
||||
h1, h3 = h3, h1
|
||||
}
|
||||
if hints[1] > hints[2] {
|
||||
hints[1], hints[2] = hints[2], hints[1]
|
||||
if h1 > h2 {
|
||||
h1, h2 = h2, h1
|
||||
}
|
||||
|
||||
return uint32(hints[0])<<24 | uint32(hints[1])<<16 | uint32(hints[2])<<8 | uint32(hints[3])
|
||||
return uint32(h0)<<24 | uint32(h1)<<16 | uint32(h2)<<8 | uint32(h3)
|
||||
}
|
||||
|
||||
@@ -14,12 +14,14 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/common/congestion"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
@@ -50,7 +52,7 @@ type ClientOptions struct {
|
||||
QUIC bool
|
||||
CongestionControl string
|
||||
CWND int
|
||||
BBRProfile string
|
||||
Logger logger.Logger
|
||||
HealthCheck bool
|
||||
MaxConnections int
|
||||
MinStreams int
|
||||
@@ -81,7 +83,7 @@ func NewClient(ctx context.Context, options ClientOptions) (*Client, error) {
|
||||
healthCheck: options.HealthCheck,
|
||||
}
|
||||
if options.QUIC {
|
||||
congestionControlFactory, err := NewCongestionControl(options.CongestionControl, options.CWND, options.BBRProfile, ntp.TimeFuncFromContext(ctx))
|
||||
congestionControlFactory, err := congestion.NewCongestionControl(options.CongestionControl, options.CWND, ntp.TimeFuncFromContext(ctx))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
|
||||
@@ -2,6 +2,7 @@ package v2raygrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
@@ -13,13 +14,21 @@ type GunService interface {
|
||||
}
|
||||
|
||||
func ServerDesc(name string) grpc.ServiceDesc {
|
||||
serviceName := name
|
||||
streamName := "Tun"
|
||||
if strings.Contains(name, "/") {
|
||||
name = strings.TrimPrefix(name, "/")
|
||||
lastSlash := strings.LastIndex(name, "/")
|
||||
serviceName = name[:lastSlash]
|
||||
streamName = name[lastSlash+1:]
|
||||
}
|
||||
return grpc.ServiceDesc{
|
||||
ServiceName: name,
|
||||
ServiceName: serviceName,
|
||||
HandlerType: (*GunServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Tun",
|
||||
StreamName: streamName,
|
||||
Handler: _GunService_Tun_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
@@ -30,7 +39,11 @@ func ServerDesc(name string) grpc.ServiceDesc {
|
||||
}
|
||||
|
||||
func (c *gunServiceClient) TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GunService_TunClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], "/"+name+"/Tun", opts...)
|
||||
path := "/" + name + "/Tun"
|
||||
if strings.Contains(name, "/") {
|
||||
path = name
|
||||
}
|
||||
stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], path, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -53,10 +53,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
DisableCompression: true,
|
||||
},
|
||||
url: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: serverAddr.String(),
|
||||
Path: "/" + options.ServiceName + "/Tun",
|
||||
RawPath: "/" + url.PathEscape(options.ServiceName) + "/Tun",
|
||||
Scheme: "https",
|
||||
Host: serverAddr.String(),
|
||||
Path: grpcPath(options.ServiceName),
|
||||
},
|
||||
host: host,
|
||||
}
|
||||
|
||||
10
transport/v2raygrpclite/path.go
Normal file
10
transport/v2raygrpclite/path.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package v2raygrpclite
|
||||
|
||||
import "strings"
|
||||
|
||||
func grpcPath(serviceName string) string {
|
||||
if strings.Contains(serviceName, "/") {
|
||||
return serviceName
|
||||
}
|
||||
return "/" + serviceName + "/Tun"
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
|
||||
tlsConfig: tlsConfig,
|
||||
logger: logger,
|
||||
handler: handler,
|
||||
path: "/" + options.ServiceName + "/Tun",
|
||||
path: grpcPath(options.ServiceName),
|
||||
h2Server: &http2.Server{
|
||||
IdleTimeout: time.Duration(options.IdleTimeout),
|
||||
},
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package v2raykcp
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/common/list"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type SendingWindow struct {
|
||||
cache *list.List
|
||||
cache *list.List[*DataSegment]
|
||||
totalInFlightSize uint32
|
||||
writer SegmentWriter
|
||||
onPacketLoss func(uint32)
|
||||
@@ -16,7 +16,7 @@ type SendingWindow struct {
|
||||
|
||||
func NewSendingWindow(writer SegmentWriter, onPacketLoss func(uint32)) *SendingWindow {
|
||||
return &SendingWindow{
|
||||
cache: list.New(),
|
||||
cache: list.New[*DataSegment](),
|
||||
writer: writer,
|
||||
onPacketLoss: onPacketLoss,
|
||||
}
|
||||
@@ -27,9 +27,9 @@ func (sw *SendingWindow) Release() {
|
||||
return
|
||||
}
|
||||
for sw.cache.Len() > 0 {
|
||||
seg := sw.cache.Front().Value.(*DataSegment)
|
||||
seg := sw.cache.Front().Value
|
||||
seg.Release()
|
||||
sw.cache.Remove(sw.cache.Front())
|
||||
sw.cache.Front().Remove()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,17 +50,17 @@ func (sw *SendingWindow) Push(number uint32, b *buf.Buffer) {
|
||||
}
|
||||
|
||||
func (sw *SendingWindow) FirstNumber() uint32 {
|
||||
return sw.cache.Front().Value.(*DataSegment).Number
|
||||
return sw.cache.Front().Value.Number
|
||||
}
|
||||
|
||||
func (sw *SendingWindow) Clear(una uint32) {
|
||||
for !sw.IsEmpty() {
|
||||
seg := sw.cache.Front().Value.(*DataSegment)
|
||||
seg := sw.cache.Front().Value
|
||||
if seg.Number >= una {
|
||||
break
|
||||
}
|
||||
seg.Release()
|
||||
sw.cache.Remove(sw.cache.Front())
|
||||
sw.cache.Front().Remove()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,8 +87,7 @@ func (sw *SendingWindow) Visit(visitor func(seg *DataSegment) bool) {
|
||||
}
|
||||
|
||||
for e := sw.cache.Front(); e != nil; e = e.Next() {
|
||||
seg := e.Value.(*DataSegment)
|
||||
if !visitor(seg) {
|
||||
if !visitor(e.Value) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -132,7 +131,7 @@ func (sw *SendingWindow) Remove(number uint32) bool {
|
||||
}
|
||||
|
||||
for e := sw.cache.Front(); e != nil; e = e.Next() {
|
||||
seg := e.Value.(*DataSegment)
|
||||
seg := e.Value
|
||||
if seg.Number > number {
|
||||
return false
|
||||
} else if seg.Number == number {
|
||||
@@ -140,7 +139,7 @@ func (sw *SendingWindow) Remove(number uint32) bool {
|
||||
sw.totalInFlightSize--
|
||||
}
|
||||
seg.Release()
|
||||
sw.cache.Remove(e)
|
||||
e.Remove()
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/congestion"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
"github.com/sagernet/sing-box/common/xray/buf"
|
||||
"github.com/sagernet/sing-box/common/xray/net"
|
||||
"github.com/sagernet/sing-box/common/xray/pipe"
|
||||
"github.com/sagernet/sing-box/common/xray/signal/done"
|
||||
"github.com/sagernet/sing-box/common/xray/uuid"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
qtls "github.com/sagernet/sing-quic"
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
sHTTP "github.com/sagernet/sing/protocol/http"
|
||||
"github.com/sagernet/sing/service"
|
||||
"golang.org/x/net/http2"
|
||||
@@ -42,15 +43,22 @@ type Client struct {
|
||||
baseRequestURL2 url.URL
|
||||
getHTTPClient func() (DialerClient, *XmuxClient)
|
||||
getHTTPClient2 func() (DialerClient, *XmuxClient)
|
||||
xmuxManager *XmuxManager
|
||||
xmuxManager2 *XmuxManager
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayXHTTPOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
|
||||
if options.Mode == "" {
|
||||
return nil, E.New("mode is not set")
|
||||
}
|
||||
if tlsConfig != nil && len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{"h2"})
|
||||
}
|
||||
if _, err := congestion.NewCongestionControl(options.CongestionController, options.CWND, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if options.Download != nil {
|
||||
if _, err := congestion.NewCongestionControl(options.Download.CongestionController, options.Download.CWND, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
dest := serverAddr
|
||||
baseRequestURL, err := getBaseRequestURL(&options.V2RayXHTTPBaseOptions, dest, tlsConfig)
|
||||
if err != nil {
|
||||
@@ -61,7 +69,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
|
||||
xmuxOptions = *options.Xmux
|
||||
}
|
||||
xmuxManager := NewXmuxManager(xmuxOptions, func() XmuxConn {
|
||||
return createHTTPClient(dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig)
|
||||
return createHTTPClient(ctx, dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig)
|
||||
})
|
||||
getHTTPClient := func() (DialerClient, *XmuxClient) {
|
||||
xmuxClient := xmuxManager.GetXmuxClient(ctx)
|
||||
@@ -69,6 +77,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
|
||||
}
|
||||
baseRequestURL2 := baseRequestURL
|
||||
getHTTPClient2 := getHTTPClient
|
||||
var xmuxManager2 *XmuxManager
|
||||
if options.Download != nil {
|
||||
options2 := options.Download
|
||||
dialer2 := dialer
|
||||
@@ -98,8 +107,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
|
||||
if options2.Xmux != nil {
|
||||
xmuxOptions2 = *options2.Xmux
|
||||
}
|
||||
xmuxManager2 := NewXmuxManager(xmuxOptions2, func() XmuxConn {
|
||||
return createHTTPClient(dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2)
|
||||
xmuxManager2 = NewXmuxManager(xmuxOptions2, func() XmuxConn {
|
||||
return createHTTPClient(ctx, dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2)
|
||||
})
|
||||
getHTTPClient2 = func() (DialerClient, *XmuxClient) {
|
||||
xmuxClient2 := xmuxManager2.GetXmuxClient(ctx)
|
||||
@@ -113,6 +122,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
|
||||
getHTTPClient2: getHTTPClient2,
|
||||
baseRequestURL: baseRequestURL,
|
||||
baseRequestURL2: baseRequestURL2,
|
||||
xmuxManager: xmuxManager,
|
||||
xmuxManager2: xmuxManager2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -121,8 +132,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
mode := c.options.Mode
|
||||
sessionId := ""
|
||||
if c.options.Mode != "stream-one" {
|
||||
sessionIdUuid := uuid.New()
|
||||
sessionId = sessionIdUuid.String()
|
||||
sessionId = GenerateSessionID(&c.options.V2RayXHTTPBaseOptions)
|
||||
}
|
||||
requestURL := c.baseRequestURL
|
||||
requestURL2 := c.baseRequestURL2
|
||||
@@ -182,10 +192,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
}
|
||||
scMaxEachPostBytes := options.GetNormalizedScMaxEachPostBytes()
|
||||
scMinPostsIntervalMs := options.GetNormalizedScMinPostsIntervalMs()
|
||||
if scMaxEachPostBytes.From <= 0 {
|
||||
panic("`scMaxEachPostBytes` should be bigger than 0")
|
||||
}
|
||||
maxUploadSize := scMaxEachPostBytes.Rand()
|
||||
maxUploadSize := int32(scMaxEachPostBytes.Rand())
|
||||
// WithSizeLimit(0) will still allow single bytes to pass, and a lot of
|
||||
// code relies on this behavior. Subtract 1 so that together with
|
||||
// uploadWriter wrapper, exact size limits can be enforced
|
||||
@@ -255,6 +262,10 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.xmuxManager.Close()
|
||||
if c.xmuxManager2 != nil {
|
||||
c.xmuxManager2.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -294,7 +305,7 @@ func getBaseRequestURL(options *option.V2RayXHTTPBaseOptions, dest M.Socksaddr,
|
||||
return requestURL, nil
|
||||
}
|
||||
|
||||
func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient {
|
||||
func createHTTPClient(ctx context.Context, dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient {
|
||||
httpVersion := decideHTTPVersion(tlsConfig)
|
||||
dialContext := func(ctxInner context.Context) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctxInner, "tcp", dest)
|
||||
@@ -319,6 +330,7 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
|
||||
if keepAlivePeriod < 0 {
|
||||
keepAlivePeriod = 0
|
||||
}
|
||||
congestionControlFactory, _ := congestion.NewCongestionControl(options.CongestionController, options.CWND, ntp.TimeFuncFromContext(ctx))
|
||||
quicConfig := &quic.Config{
|
||||
MaxIdleTimeout: net.ConnIdleTimeout,
|
||||
// these two are defaults of quic-go/http3. the default of quic-go (no
|
||||
@@ -334,7 +346,14 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
|
||||
if dErr != nil {
|
||||
return nil, dErr
|
||||
}
|
||||
return qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg)
|
||||
conn, dErr := qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg)
|
||||
if dErr != nil {
|
||||
return nil, dErr
|
||||
}
|
||||
if congestionControlFactory != nil {
|
||||
conn.SetCongestionControl(congestionControlFactory(conn))
|
||||
}
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
case "2":
|
||||
|
||||
@@ -39,7 +39,7 @@ func (c *splitConn) Close() error {
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
return err
|
||||
return err2
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -147,7 +147,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio
|
||||
if c.httpVersion != "1.1" {
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
c.closed = true
|
||||
c.Close()
|
||||
return err
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
@@ -225,10 +225,9 @@ func (w *WaitReadCloser) Set(rc io.ReadCloser) {
|
||||
}
|
||||
|
||||
func (w *WaitReadCloser) Read(b []byte) (int, error) {
|
||||
<-w.Wait
|
||||
if w.ReadCloser == nil {
|
||||
if <-w.Wait; w.ReadCloser == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return w.ReadCloser.Read(b)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user