mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-05-14 00:51:12 +03:00
Update sing-box core
This commit is contained in:
176
common/badtls/raw_conn.go
Normal file
176
common/badtls/raw_conn.go
Normal file
@@ -0,0 +1,176 @@
|
||||
//go:build go1.25 && badlinkname
|
||||
|
||||
package badtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
type RawConn struct {
|
||||
pointer unsafe.Pointer
|
||||
methods *Methods
|
||||
|
||||
IsClient *bool
|
||||
IsHandshakeComplete *atomic.Bool
|
||||
Vers *uint16
|
||||
CipherSuite *uint16
|
||||
|
||||
RawInput *bytes.Buffer
|
||||
Input *bytes.Reader
|
||||
Hand *bytes.Buffer
|
||||
|
||||
CloseNotifySent *bool
|
||||
CloseNotifyErr *error
|
||||
|
||||
In *RawHalfConn
|
||||
Out *RawHalfConn
|
||||
|
||||
BytesSent *int64
|
||||
PacketsSent *int64
|
||||
|
||||
ActiveCall *atomic.Int32
|
||||
Tmp *[16]byte
|
||||
}
|
||||
|
||||
func NewRawConn(rawTLSConn tls.Conn) (*RawConn, error) {
|
||||
var (
|
||||
pointer unsafe.Pointer
|
||||
methods *Methods
|
||||
loaded bool
|
||||
)
|
||||
for _, tlsCreator := range methodRegistry {
|
||||
pointer, methods, loaded = tlsCreator(rawTLSConn)
|
||||
if loaded {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !loaded {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
conn := &RawConn{
|
||||
pointer: pointer,
|
||||
methods: methods,
|
||||
}
|
||||
|
||||
rawConn := reflect.Indirect(reflect.ValueOf(rawTLSConn))
|
||||
|
||||
rawIsClient := rawConn.FieldByName("isClient")
|
||||
if !rawIsClient.IsValid() || rawIsClient.Kind() != reflect.Bool {
|
||||
return nil, E.New("invalid Conn.isClient")
|
||||
}
|
||||
conn.IsClient = (*bool)(unsafe.Pointer(rawIsClient.UnsafeAddr()))
|
||||
|
||||
rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
|
||||
if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
|
||||
return nil, E.New("invalid Conn.isHandshakeComplete")
|
||||
}
|
||||
conn.IsHandshakeComplete = (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
|
||||
|
||||
rawVers := rawConn.FieldByName("vers")
|
||||
if !rawVers.IsValid() || rawVers.Kind() != reflect.Uint16 {
|
||||
return nil, E.New("invalid Conn.vers")
|
||||
}
|
||||
conn.Vers = (*uint16)(unsafe.Pointer(rawVers.UnsafeAddr()))
|
||||
|
||||
rawCipherSuite := rawConn.FieldByName("cipherSuite")
|
||||
if !rawCipherSuite.IsValid() || rawCipherSuite.Kind() != reflect.Uint16 {
|
||||
return nil, E.New("invalid Conn.cipherSuite")
|
||||
}
|
||||
conn.CipherSuite = (*uint16)(unsafe.Pointer(rawCipherSuite.UnsafeAddr()))
|
||||
|
||||
rawRawInput := rawConn.FieldByName("rawInput")
|
||||
if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
|
||||
return nil, E.New("invalid Conn.rawInput")
|
||||
}
|
||||
conn.RawInput = (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
|
||||
|
||||
rawInput := rawConn.FieldByName("input")
|
||||
if !rawInput.IsValid() || rawInput.Kind() != reflect.Struct {
|
||||
return nil, E.New("invalid Conn.input")
|
||||
}
|
||||
conn.Input = (*bytes.Reader)(unsafe.Pointer(rawInput.UnsafeAddr()))
|
||||
|
||||
rawHand := rawConn.FieldByName("hand")
|
||||
if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
|
||||
return nil, E.New("invalid Conn.hand")
|
||||
}
|
||||
conn.Hand = (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
|
||||
|
||||
rawCloseNotifySent := rawConn.FieldByName("closeNotifySent")
|
||||
if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool {
|
||||
return nil, E.New("invalid Conn.closeNotifySent")
|
||||
}
|
||||
conn.CloseNotifySent = (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr()))
|
||||
|
||||
rawCloseNotifyErr := rawConn.FieldByName("closeNotifyErr")
|
||||
if !rawCloseNotifyErr.IsValid() || rawCloseNotifyErr.Kind() != reflect.Interface {
|
||||
return nil, E.New("invalid Conn.closeNotifyErr")
|
||||
}
|
||||
conn.CloseNotifyErr = (*error)(unsafe.Pointer(rawCloseNotifyErr.UnsafeAddr()))
|
||||
|
||||
rawIn := rawConn.FieldByName("in")
|
||||
if !rawIn.IsValid() || rawIn.Kind() != reflect.Struct {
|
||||
return nil, E.New("invalid Conn.in")
|
||||
}
|
||||
halfIn, err := NewRawHalfConn(rawIn, methods)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "invalid Conn.in")
|
||||
}
|
||||
conn.In = halfIn
|
||||
|
||||
rawOut := rawConn.FieldByName("out")
|
||||
if !rawOut.IsValid() || rawOut.Kind() != reflect.Struct {
|
||||
return nil, E.New("invalid Conn.out")
|
||||
}
|
||||
halfOut, err := NewRawHalfConn(rawOut, methods)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "invalid Conn.out")
|
||||
}
|
||||
conn.Out = halfOut
|
||||
|
||||
rawBytesSent := rawConn.FieldByName("bytesSent")
|
||||
if !rawBytesSent.IsValid() || rawBytesSent.Kind() != reflect.Int64 {
|
||||
return nil, E.New("invalid Conn.bytesSent")
|
||||
}
|
||||
conn.BytesSent = (*int64)(unsafe.Pointer(rawBytesSent.UnsafeAddr()))
|
||||
|
||||
rawPacketsSent := rawConn.FieldByName("packetsSent")
|
||||
if !rawPacketsSent.IsValid() || rawPacketsSent.Kind() != reflect.Int64 {
|
||||
return nil, E.New("invalid Conn.packetsSent")
|
||||
}
|
||||
conn.PacketsSent = (*int64)(unsafe.Pointer(rawPacketsSent.UnsafeAddr()))
|
||||
|
||||
rawActiveCall := rawConn.FieldByName("activeCall")
|
||||
if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
|
||||
return nil, E.New("invalid Conn.activeCall")
|
||||
}
|
||||
conn.ActiveCall = (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
|
||||
|
||||
rawTmp := rawConn.FieldByName("tmp")
|
||||
if !rawTmp.IsValid() || rawTmp.Kind() != reflect.Array || rawTmp.Len() != 16 || rawTmp.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, E.New("invalid Conn.tmp")
|
||||
}
|
||||
conn.Tmp = (*[16]byte)(unsafe.Pointer(rawTmp.UnsafeAddr()))
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *RawConn) ReadRecord() error {
|
||||
return c.methods.readRecord(c.pointer)
|
||||
}
|
||||
|
||||
func (c *RawConn) HandlePostHandshakeMessage() error {
|
||||
return c.methods.handlePostHandshakeMessage(c.pointer)
|
||||
}
|
||||
|
||||
func (c *RawConn) WriteRecordLocked(typ uint16, data []byte) (int, error) {
|
||||
return c.methods.writeRecordLocked(c.pointer, typ, data)
|
||||
}
|
||||
121
common/badtls/raw_half_conn.go
Normal file
121
common/badtls/raw_half_conn.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build go1.25 && badlinkname
|
||||
|
||||
package badtls
|
||||
|
||||
import (
|
||||
"hash"
|
||||
"reflect"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type RawHalfConn struct {
|
||||
pointer unsafe.Pointer
|
||||
methods *Methods
|
||||
*sync.Mutex
|
||||
Err *error
|
||||
Version *uint16
|
||||
Cipher *any
|
||||
Seq *[8]byte
|
||||
ScratchBuf *[13]byte
|
||||
TrafficSecret *[]byte
|
||||
Mac *hash.Hash
|
||||
RawKey *[]byte
|
||||
RawIV *[]byte
|
||||
RawMac *[]byte
|
||||
}
|
||||
|
||||
func NewRawHalfConn(rawHalfConn reflect.Value, methods *Methods) (*RawHalfConn, error) {
|
||||
halfConn := &RawHalfConn{
|
||||
pointer: (unsafe.Pointer)(rawHalfConn.UnsafeAddr()),
|
||||
methods: methods,
|
||||
}
|
||||
|
||||
rawMutex := rawHalfConn.FieldByName("Mutex")
|
||||
if !rawMutex.IsValid() || rawMutex.Kind() != reflect.Struct {
|
||||
return nil, E.New("badtls: invalid halfConn.Mutex")
|
||||
}
|
||||
halfConn.Mutex = (*sync.Mutex)(unsafe.Pointer(rawMutex.UnsafeAddr()))
|
||||
|
||||
rawErr := rawHalfConn.FieldByName("err")
|
||||
if !rawErr.IsValid() || rawErr.Kind() != reflect.Interface {
|
||||
return nil, E.New("badtls: invalid halfConn.err")
|
||||
}
|
||||
halfConn.Err = (*error)(unsafe.Pointer(rawErr.UnsafeAddr()))
|
||||
|
||||
rawVersion := rawHalfConn.FieldByName("version")
|
||||
if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 {
|
||||
return nil, E.New("badtls: invalid halfConn.version")
|
||||
}
|
||||
halfConn.Version = (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr()))
|
||||
|
||||
rawCipher := rawHalfConn.FieldByName("cipher")
|
||||
if !rawCipher.IsValid() || rawCipher.Kind() != reflect.Interface {
|
||||
return nil, E.New("badtls: invalid halfConn.cipher")
|
||||
}
|
||||
halfConn.Cipher = (*any)(unsafe.Pointer(rawCipher.UnsafeAddr()))
|
||||
|
||||
rawSeq := rawHalfConn.FieldByName("seq")
|
||||
if !rawSeq.IsValid() || rawSeq.Kind() != reflect.Array || rawSeq.Len() != 8 || rawSeq.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, E.New("badtls: invalid halfConn.seq")
|
||||
}
|
||||
halfConn.Seq = (*[8]byte)(unsafe.Pointer(rawSeq.UnsafeAddr()))
|
||||
|
||||
rawScratchBuf := rawHalfConn.FieldByName("scratchBuf")
|
||||
if !rawScratchBuf.IsValid() || rawScratchBuf.Kind() != reflect.Array || rawScratchBuf.Len() != 13 || rawScratchBuf.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, E.New("badtls: invalid halfConn.scratchBuf")
|
||||
}
|
||||
halfConn.ScratchBuf = (*[13]byte)(unsafe.Pointer(rawScratchBuf.UnsafeAddr()))
|
||||
|
||||
rawTrafficSecret := rawHalfConn.FieldByName("trafficSecret")
|
||||
if !rawTrafficSecret.IsValid() || rawTrafficSecret.Kind() != reflect.Slice || rawTrafficSecret.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, E.New("badtls: invalid halfConn.trafficSecret")
|
||||
}
|
||||
halfConn.TrafficSecret = (*[]byte)(unsafe.Pointer(rawTrafficSecret.UnsafeAddr()))
|
||||
|
||||
rawMac := rawHalfConn.FieldByName("mac")
|
||||
if !rawMac.IsValid() || rawMac.Kind() != reflect.Interface {
|
||||
return nil, E.New("badtls: invalid halfConn.mac")
|
||||
}
|
||||
halfConn.Mac = (*hash.Hash)(unsafe.Pointer(rawMac.UnsafeAddr()))
|
||||
|
||||
rawKey := rawHalfConn.FieldByName("rawKey")
|
||||
if rawKey.IsValid() {
|
||||
if /*!rawKey.IsValid() || */ rawKey.Kind() != reflect.Slice || rawKey.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, E.New("badtls: invalid halfConn.rawKey")
|
||||
}
|
||||
halfConn.RawKey = (*[]byte)(unsafe.Pointer(rawKey.UnsafeAddr()))
|
||||
|
||||
rawIV := rawHalfConn.FieldByName("rawIV")
|
||||
if !rawIV.IsValid() || rawIV.Kind() != reflect.Slice || rawIV.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, E.New("badtls: invalid halfConn.rawIV")
|
||||
}
|
||||
halfConn.RawIV = (*[]byte)(unsafe.Pointer(rawIV.UnsafeAddr()))
|
||||
|
||||
rawMAC := rawHalfConn.FieldByName("rawMac")
|
||||
if !rawMAC.IsValid() || rawMAC.Kind() != reflect.Slice || rawMAC.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return nil, E.New("badtls: invalid halfConn.rawMac")
|
||||
}
|
||||
halfConn.RawMac = (*[]byte)(unsafe.Pointer(rawMAC.UnsafeAddr()))
|
||||
}
|
||||
|
||||
return halfConn, nil
|
||||
}
|
||||
|
||||
func (hc *RawHalfConn) Decrypt(record []byte) ([]byte, uint8, error) {
|
||||
return hc.methods.decrypt(hc.pointer, record)
|
||||
}
|
||||
|
||||
func (hc *RawHalfConn) SetErrorLocked(err error) error {
|
||||
return hc.methods.setErrorLocked(hc.pointer, err)
|
||||
}
|
||||
|
||||
func (hc *RawHalfConn) SetTrafficSecret(suite unsafe.Pointer, level int, secret []byte) {
|
||||
hc.methods.setTrafficSecret(hc.pointer, suite, level, secret)
|
||||
}
|
||||
|
||||
func (hc *RawHalfConn) ExplicitNonceLen() int {
|
||||
return hc.methods.explicitNonceLen(hc.pointer)
|
||||
}
|
||||
@@ -1,18 +1,9 @@
|
||||
//go:build go1.21 && !without_badtls
|
||||
//go:build go1.25 && badlinkname
|
||||
|
||||
package badtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
@@ -21,63 +12,21 @@ var _ N.ReadWaiter = (*ReadWaitConn)(nil)
|
||||
|
||||
type ReadWaitConn struct {
|
||||
tls.Conn
|
||||
halfAccess *sync.Mutex
|
||||
rawInput *bytes.Buffer
|
||||
input *bytes.Reader
|
||||
hand *bytes.Buffer
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
tlsReadRecord func() error
|
||||
tlsHandlePostHandshakeMessage func() error
|
||||
rawConn *RawConn
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
}
|
||||
|
||||
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
||||
var (
|
||||
loaded bool
|
||||
tlsReadRecord func() error
|
||||
tlsHandlePostHandshakeMessage func() error
|
||||
)
|
||||
for _, tlsCreator := range tlsRegistry {
|
||||
loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
|
||||
if loaded {
|
||||
break
|
||||
}
|
||||
if _, isReadWaitConn := conn.(N.ReadWaiter); isReadWaitConn {
|
||||
return conn, nil
|
||||
}
|
||||
if !loaded {
|
||||
return nil, os.ErrInvalid
|
||||
rawConn, err := NewRawConn(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawConn := reflect.Indirect(reflect.ValueOf(conn))
|
||||
rawHalfConn := rawConn.FieldByName("in")
|
||||
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
||||
return nil, E.New("badtls: invalid half conn")
|
||||
}
|
||||
rawHalfMutex := rawHalfConn.FieldByName("Mutex")
|
||||
if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
|
||||
return nil, E.New("badtls: invalid half mutex")
|
||||
}
|
||||
halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
|
||||
rawRawInput := rawConn.FieldByName("rawInput")
|
||||
if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
|
||||
return nil, E.New("badtls: invalid raw input")
|
||||
}
|
||||
rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
|
||||
rawInput0 := rawConn.FieldByName("input")
|
||||
if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct {
|
||||
return nil, E.New("badtls: invalid input")
|
||||
}
|
||||
input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr()))
|
||||
rawHand := rawConn.FieldByName("hand")
|
||||
if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
|
||||
return nil, E.New("badtls: invalid hand")
|
||||
}
|
||||
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
|
||||
return &ReadWaitConn{
|
||||
Conn: conn,
|
||||
halfAccess: halfAccess,
|
||||
rawInput: rawInput,
|
||||
input: input,
|
||||
hand: hand,
|
||||
tlsReadRecord: tlsReadRecord,
|
||||
tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
|
||||
Conn: conn,
|
||||
rawConn: rawConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -87,36 +36,36 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
|
||||
}
|
||||
|
||||
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
err = c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.halfAccess.Lock()
|
||||
defer c.halfAccess.Unlock()
|
||||
for c.input.Len() == 0 {
|
||||
err = c.tlsReadRecord()
|
||||
//err = c.HandshakeContext(context.Background())
|
||||
//if err != nil {
|
||||
// return
|
||||
//}
|
||||
c.rawConn.In.Lock()
|
||||
defer c.rawConn.In.Unlock()
|
||||
for c.rawConn.Input.Len() == 0 {
|
||||
err = c.rawConn.ReadRecord()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for c.hand.Len() > 0 {
|
||||
err = c.tlsHandlePostHandshakeMessage()
|
||||
for c.rawConn.Hand.Len() > 0 {
|
||||
err = c.rawConn.HandlePostHandshakeMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
buffer = c.readWaitOptions.NewBuffer()
|
||||
n, err := c.input.Read(buffer.FreeBytes())
|
||||
n, err := c.rawConn.Input.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
|
||||
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
|
||||
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
|
||||
c.rawInput.Bytes()[0] == 21 {
|
||||
_ = c.tlsReadRecord()
|
||||
if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.Input.Len() > 0 &&
|
||||
// recordType(c.RawInput.Bytes()[0]) == recordTypeAlert {
|
||||
c.rawConn.RawInput.Bytes()[0] == 21 {
|
||||
_ = c.rawConn.ReadRecord()
|
||||
// return n, err // will be io.EOF on closeNotify
|
||||
}
|
||||
|
||||
@@ -131,25 +80,3 @@ func (c *ReadWaitConn) Upstream() any {
|
||||
func (c *ReadWaitConn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
|
||||
|
||||
func init() {
|
||||
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
||||
tlsConn, loaded := conn.(*tls.STDConn)
|
||||
if !loaded {
|
||||
return
|
||||
}
|
||||
return true, func() error {
|
||||
return stdTLSReadRecord(tlsConn)
|
||||
}, func() error {
|
||||
return stdTLSHandlePostHandshakeMessage(tlsConn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
|
||||
func stdTLSReadRecord(c *tls.STDConn) error
|
||||
|
||||
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
||||
func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !go1.21 || without_badtls
|
||||
//go:build !go1.25 || !badlinkname
|
||||
|
||||
package badtls
|
||||
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
//go:build go1.21 && !without_badtls && with_utls
|
||||
|
||||
package badtls
|
||||
|
||||
import (
|
||||
"net"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/metacubex/utls"
|
||||
)
|
||||
|
||||
func init() {
|
||||
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
||||
switch tlsConn := conn.(type) {
|
||||
case *tls.UConn:
|
||||
return true, func() error {
|
||||
return utlsReadRecord(tlsConn.Conn)
|
||||
}, func() error {
|
||||
return utlsHandlePostHandshakeMessage(tlsConn.Conn)
|
||||
}
|
||||
case *tls.Conn:
|
||||
return true, func() error {
|
||||
return utlsReadRecord(tlsConn)
|
||||
}, func() error {
|
||||
return utlsHandlePostHandshakeMessage(tlsConn)
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
//go:linkname utlsReadRecord github.com/metacubex/utls.(*Conn).readRecord
|
||||
func utlsReadRecord(c *tls.Conn) error
|
||||
|
||||
//go:linkname utlsHandlePostHandshakeMessage github.com/metacubex/utls.(*Conn).handlePostHandshakeMessage
|
||||
func utlsHandlePostHandshakeMessage(c *tls.Conn) error
|
||||
62
common/badtls/registry.go
Normal file
62
common/badtls/registry.go
Normal file
@@ -0,0 +1,62 @@
|
||||
//go:build go1.25 && badlinkname
|
||||
|
||||
package badtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type Methods struct {
|
||||
readRecord func(c unsafe.Pointer) error
|
||||
handlePostHandshakeMessage func(c unsafe.Pointer) error
|
||||
writeRecordLocked func(c unsafe.Pointer, typ uint16, data []byte) (int, error)
|
||||
|
||||
setErrorLocked func(hc unsafe.Pointer, err error) error
|
||||
decrypt func(hc unsafe.Pointer, record []byte) ([]byte, uint8, error)
|
||||
setTrafficSecret func(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte)
|
||||
explicitNonceLen func(hc unsafe.Pointer) int
|
||||
}
|
||||
|
||||
var methodRegistry []func(conn net.Conn) (unsafe.Pointer, *Methods, bool)
|
||||
|
||||
func init() {
|
||||
methodRegistry = append(methodRegistry, func(conn net.Conn) (unsafe.Pointer, *Methods, bool) {
|
||||
tlsConn, loaded := conn.(*tls.Conn)
|
||||
if !loaded {
|
||||
return nil, nil, false
|
||||
}
|
||||
return unsafe.Pointer(tlsConn), &Methods{
|
||||
readRecord: stdTLSReadRecord,
|
||||
handlePostHandshakeMessage: stdTLSHandlePostHandshakeMessage,
|
||||
writeRecordLocked: stdWriteRecordLocked,
|
||||
|
||||
setErrorLocked: stdSetErrorLocked,
|
||||
decrypt: stdDecrypt,
|
||||
setTrafficSecret: stdSetTrafficSecret,
|
||||
explicitNonceLen: stdExplicitNonceLen,
|
||||
}, true
|
||||
})
|
||||
}
|
||||
|
||||
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
|
||||
func stdTLSReadRecord(c unsafe.Pointer) error
|
||||
|
||||
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
||||
func stdTLSHandlePostHandshakeMessage(c unsafe.Pointer) error
|
||||
|
||||
//go:linkname stdWriteRecordLocked crypto/tls.(*Conn).writeRecordLocked
|
||||
func stdWriteRecordLocked(c unsafe.Pointer, typ uint16, data []byte) (int, error)
|
||||
|
||||
//go:linkname stdSetErrorLocked crypto/tls.(*halfConn).setErrorLocked
|
||||
func stdSetErrorLocked(hc unsafe.Pointer, err error) error
|
||||
|
||||
//go:linkname stdDecrypt crypto/tls.(*halfConn).decrypt
|
||||
func stdDecrypt(hc unsafe.Pointer, record []byte) ([]byte, uint8, error)
|
||||
|
||||
//go:linkname stdSetTrafficSecret crypto/tls.(*halfConn).setTrafficSecret
|
||||
func stdSetTrafficSecret(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte)
|
||||
|
||||
//go:linkname stdExplicitNonceLen crypto/tls.(*halfConn).explicitNonceLen
|
||||
func stdExplicitNonceLen(hc unsafe.Pointer) int
|
||||
56
common/badtls/registry_utls.go
Normal file
56
common/badtls/registry_utls.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build go1.25 && badlinkname
|
||||
|
||||
package badtls
|
||||
|
||||
import (
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/metacubex/utls"
|
||||
)
|
||||
|
||||
func init() {
|
||||
methodRegistry = append(methodRegistry, func(conn net.Conn) (unsafe.Pointer, *Methods, bool) {
|
||||
var pointer unsafe.Pointer
|
||||
if uConn, loaded := N.CastReader[*tls.Conn](conn); loaded {
|
||||
pointer = unsafe.Pointer(uConn)
|
||||
} else if uConn, loaded := N.CastReader[*tls.UConn](conn); loaded {
|
||||
pointer = unsafe.Pointer(uConn.Conn)
|
||||
} else {
|
||||
return nil, nil, false
|
||||
}
|
||||
return pointer, &Methods{
|
||||
readRecord: utlsReadRecord,
|
||||
handlePostHandshakeMessage: utlsHandlePostHandshakeMessage,
|
||||
writeRecordLocked: utlsWriteRecordLocked,
|
||||
|
||||
setErrorLocked: utlsSetErrorLocked,
|
||||
decrypt: utlsDecrypt,
|
||||
setTrafficSecret: utlsSetTrafficSecret,
|
||||
explicitNonceLen: utlsExplicitNonceLen,
|
||||
}, true
|
||||
})
|
||||
}
|
||||
|
||||
//go:linkname utlsReadRecord github.com/metacubex/utls.(*Conn).readRecord
|
||||
func utlsReadRecord(c unsafe.Pointer) error
|
||||
|
||||
//go:linkname utlsHandlePostHandshakeMessage github.com/metacubex/utls.(*Conn).handlePostHandshakeMessage
|
||||
func utlsHandlePostHandshakeMessage(c unsafe.Pointer) error
|
||||
|
||||
//go:linkname utlsWriteRecordLocked github.com/metacubex/utls.(*Conn).writeRecordLocked
|
||||
func utlsWriteRecordLocked(hc unsafe.Pointer, typ uint16, data []byte) (int, error)
|
||||
|
||||
//go:linkname utlsSetErrorLocked github.com/metacubex/utls.(*halfConn).setErrorLocked
|
||||
func utlsSetErrorLocked(hc unsafe.Pointer, err error) error
|
||||
|
||||
//go:linkname utlsDecrypt github.com/metacubex/utls.(*halfConn).decrypt
|
||||
func utlsDecrypt(hc unsafe.Pointer, record []byte) ([]byte, uint8, error)
|
||||
|
||||
//go:linkname utlsSetTrafficSecret github.com/metacubex/utls.(*halfConn).setTrafficSecret
|
||||
func utlsSetTrafficSecret(hc unsafe.Pointer, suite unsafe.Pointer, level int, secret []byte)
|
||||
|
||||
//go:linkname utlsExplicitNonceLen github.com/metacubex/utls.(*halfConn).explicitNonceLen
|
||||
func utlsExplicitNonceLen(hc unsafe.Pointer) int
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"strings"
|
||||
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
type Version struct {
|
||||
@@ -16,7 +18,19 @@ type Version struct {
|
||||
PreReleaseVersion int
|
||||
}
|
||||
|
||||
func (v Version) After(anotherVersion Version) bool {
|
||||
func (v Version) LessThan(anotherVersion Version) bool {
|
||||
return !v.GreaterThanOrEqual(anotherVersion)
|
||||
}
|
||||
|
||||
func (v Version) LessThanOrEqual(anotherVersion Version) bool {
|
||||
return v == anotherVersion || anotherVersion.GreaterThan(v)
|
||||
}
|
||||
|
||||
func (v Version) GreaterThanOrEqual(anotherVersion Version) bool {
|
||||
return v == anotherVersion || v.GreaterThan(anotherVersion)
|
||||
}
|
||||
|
||||
func (v Version) GreaterThan(anotherVersion Version) bool {
|
||||
if v.Major > anotherVersion.Major {
|
||||
return true
|
||||
} else if v.Major < anotherVersion.Major {
|
||||
@@ -44,19 +58,29 @@ func (v Version) After(anotherVersion Version) bool {
|
||||
} else if v.PreReleaseVersion < anotherVersion.PreReleaseVersion {
|
||||
return false
|
||||
}
|
||||
} else if v.PreReleaseIdentifier == "rc" && anotherVersion.PreReleaseIdentifier == "beta" {
|
||||
}
|
||||
preReleaseIdentifier := parsePreReleaseIdentifier(v.PreReleaseIdentifier)
|
||||
anotherPreReleaseIdentifier := parsePreReleaseIdentifier(anotherVersion.PreReleaseIdentifier)
|
||||
if preReleaseIdentifier < anotherPreReleaseIdentifier {
|
||||
return true
|
||||
} else if v.PreReleaseIdentifier == "beta" && anotherVersion.PreReleaseIdentifier == "rc" {
|
||||
return false
|
||||
} else if v.PreReleaseIdentifier == "beta" && anotherVersion.PreReleaseIdentifier == "alpha" {
|
||||
return true
|
||||
} else if v.PreReleaseIdentifier == "alpha" && anotherVersion.PreReleaseIdentifier == "beta" {
|
||||
} else if preReleaseIdentifier > anotherPreReleaseIdentifier {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parsePreReleaseIdentifier(identifier string) int {
|
||||
if strings.HasPrefix(identifier, "rc") {
|
||||
return 1
|
||||
} else if strings.HasPrefix(identifier, "beta") {
|
||||
return 2
|
||||
} else if strings.HasPrefix(identifier, "alpha") {
|
||||
return 3
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (v Version) VersionString() string {
|
||||
return F.ToString(v.Major, ".", v.Minor, ".", v.Patch)
|
||||
}
|
||||
@@ -83,6 +107,10 @@ func (v Version) BadString() string {
|
||||
return version
|
||||
}
|
||||
|
||||
func IsValid(versionName string) bool {
|
||||
return semver.IsValid("v" + versionName)
|
||||
}
|
||||
|
||||
func Parse(versionName string) (version Version) {
|
||||
if strings.HasPrefix(versionName, "v") {
|
||||
versionName = versionName[1:]
|
||||
|
||||
@@ -10,9 +10,9 @@ func TestCompareVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, "1.3.0-beta.1", Parse("v1.3.0-beta1").String())
|
||||
require.Equal(t, "1.3-beta1", Parse("v1.3.0-beta.1").BadString())
|
||||
require.True(t, Parse("1.3.0").After(Parse("1.3-beta1")))
|
||||
require.True(t, Parse("1.3.0").After(Parse("1.3.0-beta1")))
|
||||
require.True(t, Parse("1.3.0-beta1").After(Parse("1.3.0-alpha1")))
|
||||
require.True(t, Parse("1.3.1").After(Parse("1.3.0")))
|
||||
require.True(t, Parse("1.4").After(Parse("1.3")))
|
||||
require.True(t, Parse("1.3.0").GreaterThan(Parse("1.3-beta1")))
|
||||
require.True(t, Parse("1.3.0").GreaterThan(Parse("1.3.0-beta1")))
|
||||
require.True(t, Parse("1.3.0-beta1").GreaterThan(Parse("1.3.0-alpha1")))
|
||||
require.True(t, Parse("1.3.1").GreaterThan(Parse("1.3.0")))
|
||||
require.True(t, Parse("1.4").GreaterThan(Parse("1.3")))
|
||||
}
|
||||
|
||||
2817
common/certificate/chrome.go
Normal file
2817
common/certificate/chrome.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/experimental/libbox/platform"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
@@ -36,7 +35,7 @@ func NewStore(ctx context.Context, logger logger.Logger, options option.Certific
|
||||
switch options.Store {
|
||||
case C.CertificateStoreSystem, "":
|
||||
systemPool = x509.NewCertPool()
|
||||
platformInterface := service.FromContext[platform.Interface](ctx)
|
||||
platformInterface := service.FromContext[adapter.PlatformInterface](ctx)
|
||||
var systemValid bool
|
||||
if platformInterface != nil {
|
||||
for _, cert := range platformInterface.SystemCertificates() {
|
||||
@@ -54,6 +53,8 @@ func NewStore(ctx context.Context, logger logger.Logger, options option.Certific
|
||||
}
|
||||
case C.CertificateStoreMozilla:
|
||||
systemPool = mozillaIncluded
|
||||
case C.CertificateStoreChrome:
|
||||
systemPool = chromeIncluded
|
||||
case C.CertificateStoreNone:
|
||||
systemPool = nil
|
||||
default:
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
element *list.Element[io.Closer]
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn) (net.Conn, error) {
|
||||
connAccess.Lock()
|
||||
element := openConnection.PushBack(conn)
|
||||
connAccess.Unlock()
|
||||
if KillerEnabled {
|
||||
err := KillerCheck()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
element: element,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
if c.element.Value != nil {
|
||||
connAccess.Lock()
|
||||
if c.element.Value != nil {
|
||||
openConnection.Remove(c.element)
|
||||
c.element.Value = nil
|
||||
}
|
||||
connAccess.Unlock()
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *Conn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
runtimeDebug "runtime/debug"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/memory"
|
||||
)
|
||||
|
||||
var (
|
||||
KillerEnabled bool
|
||||
MemoryLimit uint64
|
||||
killerLastCheck time.Time
|
||||
)
|
||||
|
||||
func KillerCheck() error {
|
||||
if !KillerEnabled {
|
||||
return nil
|
||||
}
|
||||
nowTime := time.Now()
|
||||
if nowTime.Sub(killerLastCheck) < 3*time.Second {
|
||||
return nil
|
||||
}
|
||||
killerLastCheck = nowTime
|
||||
if memory.Total() > MemoryLimit {
|
||||
Close()
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
runtimeDebug.FreeOSMemory()
|
||||
}()
|
||||
return E.New("out of memory")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
element *list.Element[io.Closer]
|
||||
}
|
||||
|
||||
func NewPacketConn(conn net.PacketConn) (net.PacketConn, error) {
|
||||
connAccess.Lock()
|
||||
element := openConnection.PushBack(conn)
|
||||
connAccess.Unlock()
|
||||
if KillerEnabled {
|
||||
err := KillerCheck()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &PacketConn{
|
||||
PacketConn: conn,
|
||||
element: element,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *PacketConn) Close() error {
|
||||
if c.element.Value != nil {
|
||||
connAccess.Lock()
|
||||
if c.element.Value != nil {
|
||||
openConnection.Remove(c.element)
|
||||
c.element.Value = nil
|
||||
}
|
||||
connAccess.Unlock()
|
||||
}
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
func (c *PacketConn) Upstream() any {
|
||||
return bufio.NewPacketConn(c.PacketConn)
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *PacketConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
var (
|
||||
connAccess sync.RWMutex
|
||||
openConnection list.List[io.Closer]
|
||||
)
|
||||
|
||||
func Count() int {
|
||||
if !Enabled {
|
||||
return 0
|
||||
}
|
||||
return openConnection.Len()
|
||||
}
|
||||
|
||||
func List() []io.Closer {
|
||||
if !Enabled {
|
||||
return nil
|
||||
}
|
||||
connAccess.RLock()
|
||||
defer connAccess.RUnlock()
|
||||
connList := make([]io.Closer, 0, openConnection.Len())
|
||||
for element := openConnection.Front(); element != nil; element = element.Next() {
|
||||
connList = append(connList, element.Value)
|
||||
}
|
||||
return connList
|
||||
}
|
||||
|
||||
func Close() {
|
||||
if !Enabled {
|
||||
return
|
||||
}
|
||||
connAccess.Lock()
|
||||
defer connAccess.Unlock()
|
||||
for element := openConnection.Front(); element != nil; element = element.Next() {
|
||||
common.Close(element.Value)
|
||||
element.Value = nil
|
||||
}
|
||||
openConnection.Init()
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
//go:build !with_conntrack
|
||||
|
||||
package conntrack
|
||||
|
||||
const Enabled = false
|
||||
@@ -1,5 +0,0 @@
|
||||
//go:build with_conntrack
|
||||
|
||||
package conntrack
|
||||
|
||||
const Enabled = true
|
||||
@@ -9,10 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/conntrack"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/experimental/libbox/platform"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
@@ -20,6 +18,8 @@ import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"github.com/database64128/tfo-go/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -28,14 +28,15 @@ var (
|
||||
)
|
||||
|
||||
type DefaultDialer struct {
|
||||
dialer4 tcpDialer
|
||||
dialer6 tcpDialer
|
||||
dialer4 tfo.Dialer
|
||||
dialer6 tfo.Dialer
|
||||
udpDialer4 net.Dialer
|
||||
udpDialer6 net.Dialer
|
||||
udpListener net.ListenConfig
|
||||
udpAddr4 string
|
||||
udpAddr6 string
|
||||
netns string
|
||||
connectionManager adapter.ConnectionManager
|
||||
networkManager adapter.NetworkManager
|
||||
networkStrategy *C.NetworkStrategy
|
||||
defaultNetworkStrategy bool
|
||||
@@ -46,8 +47,9 @@ type DefaultDialer struct {
|
||||
}
|
||||
|
||||
func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) {
|
||||
connectionManager := service.FromContext[adapter.ConnectionManager](ctx)
|
||||
networkManager := service.FromContext[adapter.NetworkManager](ctx)
|
||||
platformInterface := service.FromContext[platform.Interface](ctx)
|
||||
platformInterface := service.FromContext[adapter.PlatformInterface](ctx)
|
||||
|
||||
var (
|
||||
dialer net.Dialer
|
||||
@@ -88,7 +90,7 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
|
||||
|
||||
if networkManager != nil {
|
||||
defaultOptions := networkManager.DefaultOptions()
|
||||
if defaultOptions.BindInterface != "" {
|
||||
if defaultOptions.BindInterface != "" && !disableDefaultBind {
|
||||
bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), defaultOptions.BindInterface, -1)
|
||||
dialer.Control = control.Append(dialer.Control, bindFunc)
|
||||
listener.Control = control.Append(listener.Control, bindFunc)
|
||||
@@ -136,14 +138,32 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
|
||||
dialer.Control = control.Append(dialer.Control, control.ProtectPath(options.ProtectPath))
|
||||
listener.Control = control.Append(listener.Control, control.ProtectPath(options.ProtectPath))
|
||||
}
|
||||
if options.BindAddressNoPort {
|
||||
if !C.IsLinux {
|
||||
return nil, E.New("`bind_address_no_port` is only supported on Linux")
|
||||
}
|
||||
dialer.Control = control.Append(dialer.Control, control.BindAddressNoPort())
|
||||
}
|
||||
if options.ConnectTimeout != 0 {
|
||||
dialer.Timeout = time.Duration(options.ConnectTimeout)
|
||||
} else {
|
||||
dialer.Timeout = C.TCPConnectTimeout
|
||||
}
|
||||
// TODO: Add an option to customize the keep alive period
|
||||
dialer.KeepAlive = C.TCPKeepAliveInitial
|
||||
dialer.Control = control.Append(dialer.Control, control.SetKeepAlivePeriod(C.TCPKeepAliveInitial, C.TCPKeepAliveInterval))
|
||||
if !options.DisableTCPKeepAlive {
|
||||
keepIdle := time.Duration(options.TCPKeepAlive)
|
||||
if keepIdle == 0 {
|
||||
keepIdle = C.TCPKeepAliveInitial
|
||||
}
|
||||
keepInterval := time.Duration(options.TCPKeepAliveInterval)
|
||||
if keepInterval == 0 {
|
||||
keepInterval = C.TCPKeepAliveInterval
|
||||
}
|
||||
dialer.KeepAliveConfig = net.KeepAliveConfig{
|
||||
Enable: true,
|
||||
Idle: keepIdle,
|
||||
Interval: keepInterval,
|
||||
}
|
||||
}
|
||||
var udpFragment bool
|
||||
if options.UDPFragment != nil {
|
||||
udpFragment = *options.UDPFragment
|
||||
@@ -177,19 +197,10 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
|
||||
udpAddr6 = M.SocksaddrFrom(bindAddr, 0).String()
|
||||
}
|
||||
if options.TCPMultiPath {
|
||||
if !go121Available {
|
||||
return nil, E.New("MultiPath TCP requires go1.21, please recompile your binary.")
|
||||
}
|
||||
setMultiPathTCP(&dialer4)
|
||||
}
|
||||
tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tcpDialer6, err := newTCPDialer(dialer6, options.TCPFastOpen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
dialer4.SetMultipathTCP(true)
|
||||
}
|
||||
tcpDialer4 := tfo.Dialer{Dialer: dialer4, DisableTFO: !options.TCPFastOpen}
|
||||
tcpDialer6 := tfo.Dialer{Dialer: dialer6, DisableTFO: !options.TCPFastOpen}
|
||||
return &DefaultDialer{
|
||||
dialer4: tcpDialer4,
|
||||
dialer6: tcpDialer6,
|
||||
@@ -199,6 +210,7 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
|
||||
udpAddr4: udpAddr4,
|
||||
udpAddr6: udpAddr6,
|
||||
netns: options.NetNs,
|
||||
connectionManager: connectionManager,
|
||||
networkManager: networkManager,
|
||||
networkStrategy: networkStrategy,
|
||||
defaultNetworkStrategy: defaultNetworkStrategy,
|
||||
@@ -231,7 +243,7 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address
|
||||
return nil, E.New("domain not resolved")
|
||||
}
|
||||
if d.networkStrategy == nil {
|
||||
return trackConn(listener.ListenNetworkNamespace[net.Conn](d.netns, func() (net.Conn, error) {
|
||||
return d.trackConn(listener.ListenNetworkNamespace[net.Conn](d.netns, func() (net.Conn, error) {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkUDP:
|
||||
if !address.IsIPv6() {
|
||||
@@ -269,7 +281,7 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
|
||||
}
|
||||
var dialer net.Dialer
|
||||
if N.NetworkName(network) == N.NetworkTCP {
|
||||
dialer = dialerFromTCPDialer(d.dialer4)
|
||||
dialer = d.dialer4.Dialer
|
||||
} else {
|
||||
dialer = d.udpDialer4
|
||||
}
|
||||
@@ -296,12 +308,12 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
|
||||
if !fastFallback && !isPrimary {
|
||||
d.networkLastFallback.Store(time.Now())
|
||||
}
|
||||
return trackConn(conn, nil)
|
||||
return d.trackConn(conn, nil)
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if d.networkStrategy == nil {
|
||||
return trackPacketConn(listener.ListenNetworkNamespace[net.PacketConn](d.netns, func() (net.PacketConn, error) {
|
||||
return d.trackPacketConn(listener.ListenNetworkNamespace[net.PacketConn](d.netns, func() (net.PacketConn, error) {
|
||||
if destination.IsIPv6() {
|
||||
return d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6)
|
||||
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
|
||||
@@ -315,6 +327,14 @@ func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksadd
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) DialerForICMPDestination(destination netip.Addr) net.Dialer {
|
||||
if !destination.Is6() {
|
||||
return d.dialer6.Dialer
|
||||
} else {
|
||||
return d.dialer4.Dialer
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
|
||||
if strategy == nil {
|
||||
strategy = d.networkStrategy
|
||||
@@ -345,33 +365,23 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return trackPacketConn(packetConn, nil)
|
||||
return d.trackPacketConn(packetConn, nil)
|
||||
}
|
||||
|
||||
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
|
||||
udpListener := d.udpListener
|
||||
udpListener.Control = control.Append(udpListener.Control, func(network, address string, conn syscall.RawConn) error {
|
||||
for _, wgControlFn := range WgControlFns {
|
||||
err := wgControlFn(network, address, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return udpListener.ListenPacket(context.Background(), network, address)
|
||||
func (d *DefaultDialer) WireGuardControl() control.Func {
|
||||
return d.udpListener.Control
|
||||
}
|
||||
|
||||
func trackConn(conn net.Conn, err error) (net.Conn, error) {
|
||||
if !conntrack.Enabled || err != nil {
|
||||
func (d *DefaultDialer) trackConn(conn net.Conn, err error) (net.Conn, error) {
|
||||
if d.connectionManager == nil || err != nil {
|
||||
return conn, err
|
||||
}
|
||||
return conntrack.NewConn(conn)
|
||||
return d.connectionManager.TrackConn(conn), nil
|
||||
}
|
||||
|
||||
func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
|
||||
if !conntrack.Enabled || err != nil {
|
||||
func (d *DefaultDialer) trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
|
||||
if d.connectionManager == nil || err != nil {
|
||||
return conn, err
|
||||
}
|
||||
return conntrack.NewPacketConn(conn)
|
||||
return d.connectionManager.TrackPacketConn(conn), nil
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:build go1.20
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/metacubex/tfo-go"
|
||||
)
|
||||
|
||||
type tcpDialer = tfo.Dialer
|
||||
|
||||
func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) {
|
||||
return tfo.Dialer{Dialer: dialer, DisableTFO: !tfoEnabled}, nil
|
||||
}
|
||||
|
||||
func dialerFromTCPDialer(dialer tcpDialer) net.Dialer {
|
||||
return dialer.Dialer
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build go1.21
|
||||
|
||||
package dialer
|
||||
|
||||
import "net"
|
||||
|
||||
const go121Available = true
|
||||
|
||||
func setMultiPathTCP(dialer *net.Dialer) {
|
||||
dialer.SetMultipathTCP(true)
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
//go:build !go1.20
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type tcpDialer = net.Dialer
|
||||
|
||||
func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) {
|
||||
if tfoEnabled {
|
||||
return dialer, E.New("TCP Fast Open requires go1.20, please recompile your binary.")
|
||||
}
|
||||
return dialer, nil
|
||||
}
|
||||
|
||||
func dialerFromTCPDialer(dialer tcpDialer) net.Dialer {
|
||||
return dialer
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !go1.21
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
const go121Available = false
|
||||
|
||||
func setMultiPathTCP(dialer *net.Dialer) {
|
||||
}
|
||||
@@ -145,3 +145,7 @@ type ParallelNetworkDialer interface {
|
||||
DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error)
|
||||
ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error)
|
||||
}
|
||||
|
||||
type PacketDialerWithDestination interface {
|
||||
ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build go1.20
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
@@ -16,7 +14,7 @@ import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/metacubex/tfo-go"
|
||||
"github.com/database64128/tfo-go/v2"
|
||||
)
|
||||
|
||||
type slowOpenConn struct {
|
||||
@@ -32,7 +30,7 @@ type slowOpenConn struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
func DialSlowContext(dialer *tfo.Dialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP, N.NetworkUDP:
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
//go:build !go1.20
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP, N.NetworkUDP:
|
||||
return dialer.DialContext(ctx, network, destination.String())
|
||||
default:
|
||||
return dialer.DialContext(ctx, network, destination.AddrString())
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,9 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/control"
|
||||
)
|
||||
|
||||
type WireGuardListener interface {
|
||||
ListenPacketCompat(network, address string) (net.PacketConn, error)
|
||||
WireGuardControl() control.Func
|
||||
}
|
||||
|
||||
var WgControlFns []control.Func
|
||||
|
||||
234
common/geosite/compat_test.go
Normal file
234
common/geosite/compat_test.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package geosite
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Old implementation using varbin reflection-based serialization
|
||||
|
||||
func oldWriteString(writer varbin.Writer, value string) error {
|
||||
//nolint:staticcheck
|
||||
return varbin.Write(writer, binary.BigEndian, value)
|
||||
}
|
||||
|
||||
func oldWriteItem(writer varbin.Writer, item Item) error {
|
||||
//nolint:staticcheck
|
||||
return varbin.Write(writer, binary.BigEndian, item)
|
||||
}
|
||||
|
||||
func oldReadString(reader varbin.Reader) (string, error) {
|
||||
//nolint:staticcheck
|
||||
return varbin.ReadValue[string](reader, binary.BigEndian)
|
||||
}
|
||||
|
||||
func oldReadItem(reader varbin.Reader) (Item, error) {
|
||||
//nolint:staticcheck
|
||||
return varbin.ReadValue[Item](reader, binary.BigEndian)
|
||||
}
|
||||
|
||||
func TestStringCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"single_char", "a"},
|
||||
{"ascii", "example.com"},
|
||||
{"utf8", "测试域名.中国"},
|
||||
{"special_chars", "\x00\xff\n\t"},
|
||||
{"127_bytes", strings.Repeat("x", 127)},
|
||||
{"128_bytes", strings.Repeat("x", 128)},
|
||||
{"16383_bytes", strings.Repeat("x", 16383)},
|
||||
{"16384_bytes", strings.Repeat("x", 16384)},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Old write
|
||||
var oldBuf bytes.Buffer
|
||||
err := oldWriteString(&oldBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// New write
|
||||
var newBuf bytes.Buffer
|
||||
err = writeString(&newBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bytes must match
|
||||
require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
|
||||
"mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
|
||||
|
||||
// New write -> old read
|
||||
readBack, err := oldReadString(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.input, readBack)
|
||||
|
||||
// Old write -> new read
|
||||
readBack2, err := readString(bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.input, readBack2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestItemCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Note: varbin.Write has a bug where struct values (not pointers) don't write their fields
|
||||
// because field.CanSet() returns false for non-addressable values.
|
||||
// The old geosite code passed Item values to varbin.Write, which silently wrote nothing.
|
||||
// The new code correctly writes Type + Value using manual serialization.
|
||||
// This test verifies the new serialization format and round-trip correctness.
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input Item
|
||||
}{
|
||||
{"domain_empty", Item{Type: RuleTypeDomain, Value: ""}},
|
||||
{"domain_normal", Item{Type: RuleTypeDomain, Value: "example.com"}},
|
||||
{"domain_suffix", Item{Type: RuleTypeDomainSuffix, Value: ".example.com"}},
|
||||
{"domain_keyword", Item{Type: RuleTypeDomainKeyword, Value: "google"}},
|
||||
{"domain_regex", Item{Type: RuleTypeDomainRegex, Value: `^.*\.example\.com$`}},
|
||||
{"utf8_domain", Item{Type: RuleTypeDomain, Value: "测试.com"}},
|
||||
{"long_domain", Item{Type: RuleTypeDomainSuffix, Value: strings.Repeat("a", 200) + ".com"}},
|
||||
{"128_bytes_value", Item{Type: RuleTypeDomain, Value: strings.Repeat("x", 128)}},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// New write
|
||||
var newBuf bytes.Buffer
|
||||
err := newBuf.WriteByte(byte(tc.input.Type))
|
||||
require.NoError(t, err)
|
||||
err = writeString(&newBuf, tc.input.Value)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify format: Type (1 byte) + Value (uvarint len + bytes)
|
||||
require.True(t, len(newBuf.Bytes()) >= 1, "output too short")
|
||||
require.Equal(t, byte(tc.input.Type), newBuf.Bytes()[0], "type byte mismatch")
|
||||
|
||||
// New write -> old read (varbin can read correctly when given addressable target)
|
||||
readBack, err := oldReadItem(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.input, readBack)
|
||||
|
||||
// New write -> new read
|
||||
reader := bufio.NewReader(bytes.NewReader(newBuf.Bytes()))
|
||||
typeByte, err := reader.ReadByte()
|
||||
require.NoError(t, err)
|
||||
value, err := readString(reader)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.input, Item{Type: ItemType(typeByte), Value: value})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeWriteReadCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input map[string][]Item
|
||||
}{
|
||||
{
|
||||
"empty_map",
|
||||
map[string][]Item{},
|
||||
},
|
||||
{
|
||||
"single_code_empty_items",
|
||||
map[string][]Item{"test": {}},
|
||||
},
|
||||
{
|
||||
"single_code_single_item",
|
||||
map[string][]Item{"test": {{Type: RuleTypeDomain, Value: "a.com"}}},
|
||||
},
|
||||
{
|
||||
"single_code_multi_items",
|
||||
map[string][]Item{
|
||||
"test": {
|
||||
{Type: RuleTypeDomain, Value: "a.com"},
|
||||
{Type: RuleTypeDomainSuffix, Value: ".b.com"},
|
||||
{Type: RuleTypeDomainKeyword, Value: "keyword"},
|
||||
{Type: RuleTypeDomainRegex, Value: `^.*$`},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"multi_code",
|
||||
map[string][]Item{
|
||||
"cn": {{Type: RuleTypeDomain, Value: "baidu.com"}, {Type: RuleTypeDomainSuffix, Value: ".cn"}},
|
||||
"us": {{Type: RuleTypeDomain, Value: "google.com"}},
|
||||
"jp": {{Type: RuleTypeDomainSuffix, Value: ".jp"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"utf8_values",
|
||||
map[string][]Item{
|
||||
"test": {
|
||||
{Type: RuleTypeDomain, Value: "测试.中国"},
|
||||
{Type: RuleTypeDomainSuffix, Value: ".テスト"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"large_items",
|
||||
generateLargeItems(1000),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Write using new implementation
|
||||
var buf bytes.Buffer
|
||||
err := Write(&buf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read back and verify
|
||||
reader, codes, err := NewReader(bytes.NewReader(buf.Bytes()))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all codes exist
|
||||
codeSet := make(map[string]bool)
|
||||
for _, code := range codes {
|
||||
codeSet[code] = true
|
||||
}
|
||||
for code := range tc.input {
|
||||
require.True(t, codeSet[code], "missing code: %s", code)
|
||||
}
|
||||
|
||||
// Verify items match
|
||||
for code, expectedItems := range tc.input {
|
||||
items, err := reader.Read(code)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedItems, items, "items mismatch for code: %s", code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateLargeItems(count int) map[string][]Item {
|
||||
items := make([]Item, count)
|
||||
for i := 0; i < count; i++ {
|
||||
items[i] = Item{
|
||||
Type: ItemType(i % 4),
|
||||
Value: strings.Repeat("x", i%200) + ".com",
|
||||
}
|
||||
}
|
||||
return map[string][]Item{"large": items}
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"sync/atomic"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
)
|
||||
|
||||
type Reader struct {
|
||||
@@ -78,7 +77,7 @@ func (r *Reader) readMetadata() error {
|
||||
codeIndex uint64
|
||||
codeLength uint64
|
||||
)
|
||||
code, err = varbin.ReadValue[string](reader, binary.BigEndian)
|
||||
code, err = readString(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -112,9 +111,16 @@ func (r *Reader) Read(code string) ([]Item, error) {
|
||||
}
|
||||
r.bufferedReader.Reset(r.reader)
|
||||
itemList := make([]Item, r.domainLength[code])
|
||||
err = varbin.Read(r.bufferedReader, binary.BigEndian, &itemList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
for i := range itemList {
|
||||
typeByte, err := r.bufferedReader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
itemList[i].Type = ItemType(typeByte)
|
||||
itemList[i].Value, err = readString(r.bufferedReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return itemList, nil
|
||||
}
|
||||
@@ -135,3 +141,18 @@ func (r *readCounter) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func readString(reader io.ByteReader) (string, error) {
|
||||
length, err := binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
bytes := make([]byte, length)
|
||||
for i := range bytes {
|
||||
bytes[i], err = reader.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package geosite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"sort"
|
||||
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
@@ -20,7 +19,11 @@ func Write(writer varbin.Writer, domains map[string][]Item) error {
|
||||
for _, code := range keys {
|
||||
index[code] = content.Len()
|
||||
for _, item := range domains[code] {
|
||||
err := varbin.Write(content, binary.BigEndian, item)
|
||||
err := content.WriteByte(byte(item.Type))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = writeString(content, item.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -38,7 +41,7 @@ func Write(writer varbin.Writer, domains map[string][]Item) error {
|
||||
}
|
||||
|
||||
for _, code := range keys {
|
||||
err = varbin.Write(writer, binary.BigEndian, code)
|
||||
err = writeString(writer, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,3 +62,12 @@ func Write(writer varbin.Writer, domains map[string][]Item) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeString(writer varbin.Writer, value string) error {
|
||||
_, err := varbin.WriteUvarint(writer, uint64(len(value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write([]byte(value))
|
||||
return err
|
||||
}
|
||||
|
||||
133
common/ktls/ktls.go
Normal file
133
common/ktls/ktls.go
Normal file
@@ -0,0 +1,133 @@
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-box/common/badtls"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
aTLS.Conn
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
conn net.Conn
|
||||
rawConn *badtls.RawConn
|
||||
syscallConn syscall.Conn
|
||||
rawSyscallConn syscall.RawConn
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
kernelTx bool
|
||||
kernelRx bool
|
||||
pendingRxSplice bool
|
||||
}
|
||||
|
||||
func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
|
||||
err := Load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
syscallConn, isSyscallConn := N.CastReader[interface {
|
||||
io.Reader
|
||||
syscall.Conn
|
||||
}](conn.NetConn())
|
||||
if !isSyscallConn {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
rawSyscallConn, err := syscallConn.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawConn, err := badtls.NewRawConn(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if *rawConn.Vers != tls.VersionTLS13 {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
for rawConn.RawInput.Len() > 0 {
|
||||
err = rawConn.ReadRecord()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rawConn.Hand.Len() > 0 {
|
||||
err = rawConn.HandlePostHandshakeMessage()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "handle post-handshake messages")
|
||||
}
|
||||
}
|
||||
}
|
||||
kConn := &Conn{
|
||||
Conn: conn,
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
conn: conn.NetConn(),
|
||||
rawConn: rawConn,
|
||||
syscallConn: syscallConn,
|
||||
rawSyscallConn: rawSyscallConn,
|
||||
}
|
||||
err = kConn.setupKernel(txOffload, rxOffload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return kConn, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *Conn) SyscallConnForRead() syscall.RawConn {
|
||||
if !c.kernelRx {
|
||||
return nil
|
||||
}
|
||||
if !*c.rawConn.IsClient {
|
||||
c.logger.WarnContext(c.ctx, "ktls: RX splice is unavailable on the server size, since it will cause an unknown failure")
|
||||
return nil
|
||||
}
|
||||
c.logger.DebugContext(c.ctx, "ktls: RX splice requested")
|
||||
return c.rawSyscallConn
|
||||
}
|
||||
|
||||
func (c *Conn) HandleSyscallReadError(inputErr error) ([]byte, error) {
|
||||
if errors.Is(inputErr, unix.EINVAL) {
|
||||
c.pendingRxSplice = true
|
||||
err := c.readRecord()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "ktls: handle non-application-data record")
|
||||
}
|
||||
var input bytes.Buffer
|
||||
if c.rawConn.Input.Len() > 0 {
|
||||
_, err = c.rawConn.Input.WriteTo(&input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return input.Bytes(), nil
|
||||
} else if errors.Is(inputErr, unix.EBADMSG) {
|
||||
return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertBadRecordMAC))
|
||||
} else {
|
||||
return nil, E.Cause(inputErr, "ktls: unexpected errno")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) SyscallConnForWrite() syscall.RawConn {
|
||||
if !c.kernelTx {
|
||||
return nil
|
||||
}
|
||||
c.logger.DebugContext(c.ctx, "ktls: TX splice requested")
|
||||
return c.rawSyscallConn
|
||||
}
|
||||
80
common/ktls/ktls_alert.go
Normal file
80
common/ktls/ktls_alert.go
Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
// alert level
|
||||
alertLevelWarning = 1
|
||||
alertLevelError = 2
|
||||
)
|
||||
|
||||
const (
|
||||
alertCloseNotify = 0
|
||||
alertUnexpectedMessage = 10
|
||||
alertBadRecordMAC = 20
|
||||
alertDecryptionFailed = 21
|
||||
alertRecordOverflow = 22
|
||||
alertDecompressionFailure = 30
|
||||
alertHandshakeFailure = 40
|
||||
alertBadCertificate = 42
|
||||
alertUnsupportedCertificate = 43
|
||||
alertCertificateRevoked = 44
|
||||
alertCertificateExpired = 45
|
||||
alertCertificateUnknown = 46
|
||||
alertIllegalParameter = 47
|
||||
alertUnknownCA = 48
|
||||
alertAccessDenied = 49
|
||||
alertDecodeError = 50
|
||||
alertDecryptError = 51
|
||||
alertExportRestriction = 60
|
||||
alertProtocolVersion = 70
|
||||
alertInsufficientSecurity = 71
|
||||
alertInternalError = 80
|
||||
alertInappropriateFallback = 86
|
||||
alertUserCanceled = 90
|
||||
alertNoRenegotiation = 100
|
||||
alertMissingExtension = 109
|
||||
alertUnsupportedExtension = 110
|
||||
alertCertificateUnobtainable = 111
|
||||
alertUnrecognizedName = 112
|
||||
alertBadCertificateStatusResponse = 113
|
||||
alertBadCertificateHashValue = 114
|
||||
alertUnknownPSKIdentity = 115
|
||||
alertCertificateRequired = 116
|
||||
alertNoApplicationProtocol = 120
|
||||
alertECHRequired = 121
|
||||
)
|
||||
|
||||
func (c *Conn) sendAlertLocked(err uint8) error {
|
||||
switch err {
|
||||
case alertNoRenegotiation, alertCloseNotify:
|
||||
c.rawConn.Tmp[0] = alertLevelWarning
|
||||
default:
|
||||
c.rawConn.Tmp[0] = alertLevelError
|
||||
}
|
||||
c.rawConn.Tmp[1] = byte(err)
|
||||
|
||||
_, writeErr := c.writeRecordLocked(recordTypeAlert, c.rawConn.Tmp[0:2])
|
||||
if err == alertCloseNotify {
|
||||
// closeNotify is a special case in that it isn't an error.
|
||||
return writeErr
|
||||
}
|
||||
|
||||
return c.rawConn.Out.SetErrorLocked(&net.OpError{Op: "local error", Err: tls.AlertError(err)})
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
func (c *Conn) sendAlert(err uint8) error {
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
return c.sendAlertLocked(err)
|
||||
}
|
||||
326
common/ktls/ktls_cipher_suites_linux.go
Normal file
326
common/ktls/ktls_cipher_suites_linux.go
Normal file
@@ -0,0 +1,326 @@
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-box/common/badtls"
|
||||
)
|
||||
|
||||
type kernelCryptoCipherType uint16
|
||||
|
||||
const (
|
||||
TLS_CIPHER_AES_GCM_128 kernelCryptoCipherType = 51
|
||||
TLS_CIPHER_AES_GCM_128_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_AES_GCM_128_KEY_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_GCM_128_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_AES_GCM_128_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_AES_GCM_256 kernelCryptoCipherType = 52
|
||||
TLS_CIPHER_AES_GCM_256_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_AES_GCM_256_KEY_SIZE kernelCryptoCipherType = 32
|
||||
TLS_CIPHER_AES_GCM_256_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_AES_GCM_256_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_AES_CCM_128 kernelCryptoCipherType = 53
|
||||
TLS_CIPHER_AES_CCM_128_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_AES_CCM_128_KEY_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_CCM_128_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_AES_CCM_128_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_CHACHA20_POLY1305 kernelCryptoCipherType = 54
|
||||
TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE kernelCryptoCipherType = 12
|
||||
TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE kernelCryptoCipherType = 32
|
||||
TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE kernelCryptoCipherType = 0
|
||||
TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
// TLS_CIPHER_SM4_GCM kernelCryptoCipherType = 55
|
||||
// TLS_CIPHER_SM4_GCM_IV_SIZE kernelCryptoCipherType = 8
|
||||
// TLS_CIPHER_SM4_GCM_KEY_SIZE kernelCryptoCipherType = 16
|
||||
// TLS_CIPHER_SM4_GCM_SALT_SIZE kernelCryptoCipherType = 4
|
||||
// TLS_CIPHER_SM4_GCM_TAG_SIZE kernelCryptoCipherType = 16
|
||||
// TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
// TLS_CIPHER_SM4_CCM kernelCryptoCipherType = 56
|
||||
// TLS_CIPHER_SM4_CCM_IV_SIZE kernelCryptoCipherType = 8
|
||||
// TLS_CIPHER_SM4_CCM_KEY_SIZE kernelCryptoCipherType = 16
|
||||
// TLS_CIPHER_SM4_CCM_SALT_SIZE kernelCryptoCipherType = 4
|
||||
// TLS_CIPHER_SM4_CCM_TAG_SIZE kernelCryptoCipherType = 16
|
||||
// TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_ARIA_GCM_128 kernelCryptoCipherType = 57
|
||||
TLS_CIPHER_ARIA_GCM_128_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_ARIA_GCM_128_KEY_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_ARIA_GCM_128_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_ARIA_GCM_128_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
|
||||
TLS_CIPHER_ARIA_GCM_256 kernelCryptoCipherType = 58
|
||||
TLS_CIPHER_ARIA_GCM_256_IV_SIZE kernelCryptoCipherType = 8
|
||||
TLS_CIPHER_ARIA_GCM_256_KEY_SIZE kernelCryptoCipherType = 32
|
||||
TLS_CIPHER_ARIA_GCM_256_SALT_SIZE kernelCryptoCipherType = 4
|
||||
TLS_CIPHER_ARIA_GCM_256_TAG_SIZE kernelCryptoCipherType = 16
|
||||
TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE kernelCryptoCipherType = 8
|
||||
)
|
||||
|
||||
type kernelCrypto interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
type kernelCryptoInfo struct {
|
||||
version uint16
|
||||
cipher_type kernelCryptoCipherType
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoAES128GCM{}
|
||||
|
||||
type kernelCryptoAES128GCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_AES_GCM_128_IV_SIZE]byte
|
||||
key [TLS_CIPHER_AES_GCM_128_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_AES_GCM_128_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoAES128GCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_AES_GCM_128
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoAES256GCM{}
|
||||
|
||||
type kernelCryptoAES256GCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_AES_GCM_256_IV_SIZE]byte
|
||||
key [TLS_CIPHER_AES_GCM_256_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_AES_GCM_256_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoAES256GCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_AES_GCM_256
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoAES128CCM{}
|
||||
|
||||
type kernelCryptoAES128CCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_AES_CCM_128_IV_SIZE]byte
|
||||
key [TLS_CIPHER_AES_CCM_128_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_AES_CCM_128_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoAES128CCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_AES_CCM_128
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoChacha20Poly1035{}
|
||||
|
||||
type kernelCryptoChacha20Poly1035 struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE]byte
|
||||
key [TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoChacha20Poly1035) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_CHACHA20_POLY1305
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
// var _ kernelCrypto = &kernelCryptoSM4GCM{}
|
||||
|
||||
// type kernelCryptoSM4GCM struct {
|
||||
// kernelCryptoInfo
|
||||
// iv [TLS_CIPHER_SM4_GCM_IV_SIZE]byte
|
||||
// key [TLS_CIPHER_SM4_GCM_KEY_SIZE]byte
|
||||
// salt [TLS_CIPHER_SM4_GCM_SALT_SIZE]byte
|
||||
// rec_seq [TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE]byte
|
||||
// }
|
||||
|
||||
// func (crypto *kernelCryptoSM4GCM) String() string {
|
||||
// crypto.cipher_type = TLS_CIPHER_SM4_GCM
|
||||
// return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
// }
|
||||
|
||||
// var _ kernelCrypto = &kernelCryptoSM4CCM{}
|
||||
|
||||
// type kernelCryptoSM4CCM struct {
|
||||
// kernelCryptoInfo
|
||||
// iv [TLS_CIPHER_SM4_CCM_IV_SIZE]byte
|
||||
// key [TLS_CIPHER_SM4_CCM_KEY_SIZE]byte
|
||||
// salt [TLS_CIPHER_SM4_CCM_SALT_SIZE]byte
|
||||
// rec_seq [TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE]byte
|
||||
// }
|
||||
|
||||
// func (crypto *kernelCryptoSM4CCM) String() string {
|
||||
// crypto.cipher_type = TLS_CIPHER_SM4_CCM
|
||||
// return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
// }
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoARIA128GCM{}
|
||||
|
||||
type kernelCryptoARIA128GCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_ARIA_GCM_128_IV_SIZE]byte
|
||||
key [TLS_CIPHER_ARIA_GCM_128_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_ARIA_GCM_128_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoARIA128GCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_ARIA_GCM_128
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
var _ kernelCrypto = &kernelCryptoARIA256GCM{}
|
||||
|
||||
type kernelCryptoARIA256GCM struct {
|
||||
kernelCryptoInfo
|
||||
iv [TLS_CIPHER_ARIA_GCM_256_IV_SIZE]byte
|
||||
key [TLS_CIPHER_ARIA_GCM_256_KEY_SIZE]byte
|
||||
salt [TLS_CIPHER_ARIA_GCM_256_SALT_SIZE]byte
|
||||
rec_seq [TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE]byte
|
||||
}
|
||||
|
||||
func (crypto *kernelCryptoARIA256GCM) String() string {
|
||||
crypto.cipher_type = TLS_CIPHER_ARIA_GCM_256
|
||||
return string((*[unsafe.Sizeof(*crypto)]byte)(unsafe.Pointer(crypto))[:])
|
||||
}
|
||||
|
||||
func kernelCipher(kernel *Support, hc *badtls.RawHalfConn, cipherSuite uint16, isRX bool) kernelCrypto {
|
||||
if !kernel.TLS {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch *hc.Version {
|
||||
case tls.VersionTLS12:
|
||||
if isRX && !kernel.TLS_Version13_RX {
|
||||
return nil
|
||||
}
|
||||
|
||||
case tls.VersionTLS13:
|
||||
if !kernel.TLS_Version13 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isRX && !kernel.TLS_Version13_RX {
|
||||
return nil
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
var key, iv []byte
|
||||
if *hc.Version == tls.VersionTLS13 {
|
||||
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), *hc.TrafficSecret)
|
||||
/*if isRX {
|
||||
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), keyLog.RemoteTrafficSecret)
|
||||
} else {
|
||||
key, iv = trafficKey(cipherSuiteTLS13ByID(cipherSuite), keyLog.TrafficSecret)
|
||||
}*/
|
||||
} else {
|
||||
// csPtr := cipherSuiteByID(cipherSuite)
|
||||
// keysFromMasterSecret(*hc.Version, csPtr, keyLog.Secret, keyLog.Random)
|
||||
return nil
|
||||
}
|
||||
|
||||
switch cipherSuite {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
|
||||
crypto := new(kernelCryptoAES128GCM)
|
||||
|
||||
crypto.version = *hc.Version
|
||||
copy(crypto.key[:], key)
|
||||
copy(crypto.iv[:], iv[4:])
|
||||
copy(crypto.salt[:], iv[:4])
|
||||
crypto.rec_seq = *hc.Seq
|
||||
|
||||
return crypto
|
||||
case tls.TLS_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
|
||||
if !kernel.TLS_AES_256_GCM {
|
||||
return nil
|
||||
}
|
||||
|
||||
crypto := new(kernelCryptoAES256GCM)
|
||||
|
||||
crypto.version = *hc.Version
|
||||
copy(crypto.key[:], key)
|
||||
copy(crypto.iv[:], iv[4:])
|
||||
copy(crypto.salt[:], iv[:4])
|
||||
crypto.rec_seq = *hc.Seq
|
||||
|
||||
return crypto
|
||||
//case tls.TLS_AES_128_CCM_SHA256, tls.TLS_RSA_WITH_AES_128_CCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_SHA256:
|
||||
// if !kernel.TLS_AES_128_CCM {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// crypto := new(kernelCryptoAES128CCM)
|
||||
//
|
||||
// crypto.version = *hc.Version
|
||||
// copy(crypto.key[:], key)
|
||||
// copy(crypto.iv[:], iv[4:])
|
||||
// copy(crypto.salt[:], iv[:4])
|
||||
// crypto.rec_seq = *hc.Seq
|
||||
//
|
||||
// return crypto
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
|
||||
if !kernel.TLS_CHACHA20_POLY1305 {
|
||||
return nil
|
||||
}
|
||||
|
||||
crypto := new(kernelCryptoChacha20Poly1035)
|
||||
|
||||
crypto.version = *hc.Version
|
||||
copy(crypto.key[:], key)
|
||||
copy(crypto.iv[:], iv)
|
||||
crypto.rec_seq = *hc.Seq
|
||||
|
||||
return crypto
|
||||
//case tls.TLS_RSA_WITH_ARIA_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256:
|
||||
// if !kernel.TLS_ARIA_GCM {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// crypto := new(kernelCryptoARIA128GCM)
|
||||
//
|
||||
// crypto.version = *hc.Version
|
||||
// copy(crypto.key[:], key)
|
||||
// copy(crypto.iv[:], iv[4:])
|
||||
// copy(crypto.salt[:], iv[:4])
|
||||
// crypto.rec_seq = *hc.Seq
|
||||
//
|
||||
// return crypto
|
||||
//case tls.TLS_RSA_WITH_ARIA_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384:
|
||||
// if !kernel.TLS_ARIA_GCM {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// crypto := new(kernelCryptoARIA256GCM)
|
||||
//
|
||||
// crypto.version = *hc.Version
|
||||
// copy(crypto.key[:], key)
|
||||
// copy(crypto.iv[:], iv[4:])
|
||||
// copy(crypto.salt[:], iv[:4])
|
||||
// crypto.rec_seq = *hc.Seq
|
||||
//
|
||||
// return crypto
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
67
common/ktls/ktls_close.go
Normal file
67
common/ktls/ktls_close.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
if !c.kernelTx {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// Interlock with Conn.Write above.
|
||||
var x int32
|
||||
for {
|
||||
x = c.rawConn.ActiveCall.Load()
|
||||
if x&1 != 0 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if c.rawConn.ActiveCall.CompareAndSwap(x, x|1) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if x != 0 {
|
||||
// io.Writer and io.Closer should not be used concurrently.
|
||||
// If Close is called while a Write is currently in-flight,
|
||||
// interpret that as a sign that this Close is really just
|
||||
// being used to break the Write and/or clean up resources and
|
||||
// avoid sending the alertCloseNotify, which may block
|
||||
// waiting on handshakeMutex or the c.out mutex.
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
var alertErr error
|
||||
if c.rawConn.IsHandshakeComplete.Load() {
|
||||
if err := c.closeNotify(); err != nil {
|
||||
alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return alertErr
|
||||
}
|
||||
|
||||
func (c *Conn) closeNotify() error {
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
|
||||
if !*c.rawConn.CloseNotifySent {
|
||||
// Set a Write Deadline to prevent possibly blocking forever.
|
||||
c.SetWriteDeadline(time.Now().Add(time.Second * 5))
|
||||
*c.rawConn.CloseNotifyErr = c.sendAlertLocked(alertCloseNotify)
|
||||
*c.rawConn.CloseNotifySent = true
|
||||
// Any subsequent writes will fail.
|
||||
c.SetWriteDeadline(time.Now())
|
||||
}
|
||||
return *c.rawConn.CloseNotifyErr
|
||||
}
|
||||
24
common/ktls/ktls_const.go
Normal file
24
common/ktls/ktls_const.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
const (
|
||||
maxPlaintext = 16384 // maximum plaintext payload length
|
||||
maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
|
||||
maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3
|
||||
recordHeaderLen = 5 // record header length
|
||||
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
|
||||
maxHandshakeCertificateMsg = 262144 // maximum certificate message size (256 KiB)
|
||||
maxUselessRecords = 16 // maximum number of consecutive non-advancing records
|
||||
)
|
||||
|
||||
const (
|
||||
recordTypeChangeCipherSpec = 20
|
||||
recordTypeAlert = 21
|
||||
recordTypeHandshake = 22
|
||||
recordTypeApplicationData = 23
|
||||
)
|
||||
238
common/ktls/ktls_handshake_messages.go
Normal file
238
common/ktls/ktls_handshake_messages.go
Normal file
@@ -0,0 +1,238 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
)
|
||||
|
||||
// The marshalingFunction type is an adapter to allow the use of ordinary
|
||||
// functions as cryptobyte.MarshalingValue.
|
||||
type marshalingFunction func(b *cryptobyte.Builder) error
|
||||
|
||||
func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
|
||||
return f(b)
|
||||
}
|
||||
|
||||
// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
|
||||
// the length of the sequence is not the value specified, it produces an error.
|
||||
func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
|
||||
b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
|
||||
if len(v) != n {
|
||||
return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
|
||||
}
|
||||
b.AddBytes(v)
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
|
||||
func addUint64(b *cryptobyte.Builder, v uint64) {
|
||||
b.AddUint32(uint32(v >> 32))
|
||||
b.AddUint32(uint32(v))
|
||||
}
|
||||
|
||||
// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
|
||||
// It reports whether the read was successful.
|
||||
func readUint64(s *cryptobyte.String, out *uint64) bool {
|
||||
var hi, lo uint32
|
||||
if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
|
||||
return false
|
||||
}
|
||||
*out = uint64(hi)<<32 | uint64(lo)
|
||||
return true
|
||||
}
|
||||
|
||||
// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
|
||||
// []byte instead of a cryptobyte.String.
|
||||
func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
|
||||
return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
|
||||
}
|
||||
|
||||
// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
|
||||
// []byte instead of a cryptobyte.String.
|
||||
func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
|
||||
return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
|
||||
}
|
||||
|
||||
// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
|
||||
// []byte instead of a cryptobyte.String.
|
||||
func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
|
||||
return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
|
||||
}
|
||||
|
||||
type keyUpdateMsg struct {
|
||||
updateRequested bool
|
||||
}
|
||||
|
||||
func (m *keyUpdateMsg) marshal() ([]byte, error) {
|
||||
var b cryptobyte.Builder
|
||||
b.AddUint8(typeKeyUpdate)
|
||||
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
if m.updateRequested {
|
||||
b.AddUint8(1)
|
||||
} else {
|
||||
b.AddUint8(0)
|
||||
}
|
||||
})
|
||||
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
|
||||
s := cryptobyte.String(data)
|
||||
|
||||
var updateRequested uint8
|
||||
if !s.Skip(4) || // message type and uint24 length field
|
||||
!s.ReadUint8(&updateRequested) || !s.Empty() {
|
||||
return false
|
||||
}
|
||||
switch updateRequested {
|
||||
case 0:
|
||||
m.updateRequested = false
|
||||
case 1:
|
||||
m.updateRequested = true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TLS handshake message types.
|
||||
const (
|
||||
typeHelloRequest uint8 = 0
|
||||
typeClientHello uint8 = 1
|
||||
typeServerHello uint8 = 2
|
||||
typeNewSessionTicket uint8 = 4
|
||||
typeEndOfEarlyData uint8 = 5
|
||||
typeEncryptedExtensions uint8 = 8
|
||||
typeCertificate uint8 = 11
|
||||
typeServerKeyExchange uint8 = 12
|
||||
typeCertificateRequest uint8 = 13
|
||||
typeServerHelloDone uint8 = 14
|
||||
typeCertificateVerify uint8 = 15
|
||||
typeClientKeyExchange uint8 = 16
|
||||
typeFinished uint8 = 20
|
||||
typeCertificateStatus uint8 = 22
|
||||
typeKeyUpdate uint8 = 24
|
||||
typeCompressedCertificate uint8 = 25
|
||||
typeMessageHash uint8 = 254 // synthetic message
|
||||
)
|
||||
|
||||
// TLS compression types.
|
||||
const (
|
||||
compressionNone uint8 = 0
|
||||
)
|
||||
|
||||
// TLS extension numbers
|
||||
const (
|
||||
extensionServerName uint16 = 0
|
||||
extensionStatusRequest uint16 = 5
|
||||
extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7
|
||||
extensionSupportedPoints uint16 = 11
|
||||
extensionSignatureAlgorithms uint16 = 13
|
||||
extensionALPN uint16 = 16
|
||||
extensionSCT uint16 = 18
|
||||
extensionPadding uint16 = 21
|
||||
extensionExtendedMasterSecret uint16 = 23
|
||||
extensionCompressCertificate uint16 = 27 // compress_certificate in TLS 1.3
|
||||
extensionSessionTicket uint16 = 35
|
||||
extensionPreSharedKey uint16 = 41
|
||||
extensionEarlyData uint16 = 42
|
||||
extensionSupportedVersions uint16 = 43
|
||||
extensionCookie uint16 = 44
|
||||
extensionPSKModes uint16 = 45
|
||||
extensionCertificateAuthorities uint16 = 47
|
||||
extensionSignatureAlgorithmsCert uint16 = 50
|
||||
extensionKeyShare uint16 = 51
|
||||
extensionQUICTransportParameters uint16 = 57
|
||||
extensionALPS uint16 = 17513
|
||||
extensionRenegotiationInfo uint16 = 0xff01
|
||||
extensionECHOuterExtensions uint16 = 0xfd00
|
||||
extensionEncryptedClientHello uint16 = 0xfe0d
|
||||
)
|
||||
|
||||
type handshakeMessage interface {
|
||||
marshal() ([]byte, error)
|
||||
unmarshal([]byte) bool
|
||||
}
|
||||
type newSessionTicketMsgTLS13 struct {
|
||||
lifetime uint32
|
||||
ageAdd uint32
|
||||
nonce []byte
|
||||
label []byte
|
||||
maxEarlyData uint32
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
|
||||
var b cryptobyte.Builder
|
||||
b.AddUint8(typeNewSessionTicket)
|
||||
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddUint32(m.lifetime)
|
||||
b.AddUint32(m.ageAdd)
|
||||
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(m.nonce)
|
||||
})
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(m.label)
|
||||
})
|
||||
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
if m.maxEarlyData > 0 {
|
||||
b.AddUint16(extensionEarlyData)
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddUint32(m.maxEarlyData)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
|
||||
*m = newSessionTicketMsgTLS13{}
|
||||
s := cryptobyte.String(data)
|
||||
|
||||
var extensions cryptobyte.String
|
||||
if !s.Skip(4) || // message type and uint24 length field
|
||||
!s.ReadUint32(&m.lifetime) ||
|
||||
!s.ReadUint32(&m.ageAdd) ||
|
||||
!readUint8LengthPrefixed(&s, &m.nonce) ||
|
||||
!readUint16LengthPrefixed(&s, &m.label) ||
|
||||
!s.ReadUint16LengthPrefixed(&extensions) ||
|
||||
!s.Empty() {
|
||||
return false
|
||||
}
|
||||
|
||||
for !extensions.Empty() {
|
||||
var extension uint16
|
||||
var extData cryptobyte.String
|
||||
if !extensions.ReadUint16(&extension) ||
|
||||
!extensions.ReadUint16LengthPrefixed(&extData) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch extension {
|
||||
case extensionEarlyData:
|
||||
if !extData.ReadUint32(&m.maxEarlyData) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
// Ignore unknown extensions.
|
||||
continue
|
||||
}
|
||||
|
||||
if !extData.Empty() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
173
common/ktls/ktls_key_update.go
Normal file
173
common/ktls/ktls_key_update.go
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// handlePostHandshakeMessage processes a handshake message arrived after the
|
||||
// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
|
||||
func (c *Conn) handlePostHandshakeMessage() error {
|
||||
if *c.rawConn.Vers != tls.VersionTLS13 {
|
||||
return errors.New("ktls: kernel does not support TLS 1.2 renegotiation")
|
||||
}
|
||||
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//c.retryCount++
|
||||
//if c.retryCount > maxUselessRecords {
|
||||
// c.sendAlert(alertUnexpectedMessage)
|
||||
// return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
|
||||
//}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *newSessionTicketMsgTLS13:
|
||||
// return errors.New("ktls: received new session ticket")
|
||||
return nil
|
||||
case *keyUpdateMsg:
|
||||
return c.handleKeyUpdate(msg)
|
||||
}
|
||||
// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
|
||||
// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
|
||||
// unexpected_message alert here doesn't provide it with enough information to distinguish
|
||||
// this condition from other unexpected messages. This is probably fine.
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
|
||||
}
|
||||
|
||||
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
|
||||
//if c.quic != nil {
|
||||
// c.sendAlert(alertUnexpectedMessage)
|
||||
// return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
|
||||
//}
|
||||
|
||||
cipherSuite := cipherSuiteTLS13ByID(*c.rawConn.CipherSuite)
|
||||
if cipherSuite == nil {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertInternalError))
|
||||
}
|
||||
|
||||
newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.In.TrafficSecret)
|
||||
c.rawConn.In.SetTrafficSecret(cipherSuite, 0 /*tls.QUICEncryptionLevelInitial*/, newSecret)
|
||||
|
||||
err := c.resetupRX()
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return c.rawConn.In.SetErrorLocked(fmt.Errorf("ktls: resetupRX failed: %w", err))
|
||||
}
|
||||
|
||||
if keyUpdate.updateRequested {
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
|
||||
resetup, err := c.resetupTX()
|
||||
if err != nil {
|
||||
c.sendAlertLocked(alertInternalError)
|
||||
return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err))
|
||||
}
|
||||
|
||||
msg := &keyUpdateMsg{}
|
||||
msgBytes, err := msg.marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
|
||||
if err != nil {
|
||||
// Surface the error at the next write.
|
||||
c.rawConn.Out.SetErrorLocked(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
newSecret := nextTrafficSecret(cipherSuite, *c.rawConn.Out.TrafficSecret)
|
||||
c.rawConn.Out.SetTrafficSecret(cipherSuite, 0 /*QUICEncryptionLevelInitial*/, newSecret)
|
||||
|
||||
err = resetup()
|
||||
if err != nil {
|
||||
return c.rawConn.Out.SetErrorLocked(fmt.Errorf("ktls: resetupTX failed: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) readHandshakeBytes(n int) error {
|
||||
//if c.quic != nil {
|
||||
// return c.quicReadHandshakeBytes(n)
|
||||
//}
|
||||
for c.rawConn.Hand.Len() < n {
|
||||
if err := c.readRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) readHandshake(transcript io.Writer) (any, error) {
|
||||
if err := c.readHandshakeBytes(4); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := c.rawConn.Hand.Bytes()
|
||||
|
||||
maxHandshakeSize := maxHandshake
|
||||
// hasVers indicates we're past the first message, forcing someone trying to
|
||||
// make us just allocate a large buffer to at least do the initial part of
|
||||
// the handshake first.
|
||||
//if c.haveVers && data[0] == typeCertificate {
|
||||
// Since certificate messages are likely to be the only messages that
|
||||
// can be larger than maxHandshake, we use a special limit for just
|
||||
// those messages.
|
||||
//maxHandshakeSize = maxHandshakeCertificateMsg
|
||||
//}
|
||||
|
||||
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
if n > maxHandshakeSize {
|
||||
c.sendAlertLocked(alertInternalError)
|
||||
return nil, c.rawConn.In.SetErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
|
||||
}
|
||||
if err := c.readHandshakeBytes(4 + n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data = c.rawConn.Hand.Next(4 + n)
|
||||
return c.unmarshalHandshakeMessage(data, transcript)
|
||||
}
|
||||
|
||||
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript io.Writer) (any, error) {
|
||||
var m handshakeMessage
|
||||
switch data[0] {
|
||||
case typeNewSessionTicket:
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 {
|
||||
m = new(newSessionTicketMsgTLS13)
|
||||
} else {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
case typeKeyUpdate:
|
||||
m = new(keyUpdateMsg)
|
||||
default:
|
||||
return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
||||
// The handshake message unmarshalers
|
||||
// expect to be able to keep references to data,
|
||||
// so pass in a fresh copy that won't be overwritten.
|
||||
data = append([]byte(nil), data...)
|
||||
|
||||
if !m.unmarshal(data) {
|
||||
return nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError))
|
||||
}
|
||||
|
||||
if transcript != nil {
|
||||
transcript.Write(data)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
329
common/ktls/ktls_linux.go
Normal file
329
common/ktls/ktls_linux.go
Normal file
@@ -0,0 +1,329 @@
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-box/common/badversion"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/shell"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// mod from https://gitlab.com/go-extension/tls
|
||||
|
||||
const (
|
||||
TLS_TX = 1
|
||||
TLS_RX = 2
|
||||
TLS_TX_ZEROCOPY_RO = 3 // TX zerocopy (only sendfile now)
|
||||
TLS_RX_EXPECT_NO_PAD = 4 // Attempt opportunistic zero-copy, TLS 1.3 only
|
||||
|
||||
TLS_SET_RECORD_TYPE = 1
|
||||
TLS_GET_RECORD_TYPE = 2
|
||||
)
|
||||
|
||||
type Support struct {
|
||||
TLS, TLS_RX bool
|
||||
TLS_Version13, TLS_Version13_RX bool
|
||||
|
||||
TLS_TX_ZEROCOPY bool
|
||||
TLS_RX_NOPADDING bool
|
||||
|
||||
TLS_AES_256_GCM bool
|
||||
TLS_AES_128_CCM bool
|
||||
TLS_CHACHA20_POLY1305 bool
|
||||
TLS_SM4 bool
|
||||
TLS_ARIA_GCM bool
|
||||
|
||||
TLS_Version13_KeyUpdate bool
|
||||
}
|
||||
|
||||
var KernelSupport = sync.OnceValues(func() (*Support, error) {
|
||||
var uname unix.Utsname
|
||||
err := unix.Uname(&uname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kernelVersion := badversion.Parse(strings.Trim(string(uname.Release[:]), "\x00"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var support Support
|
||||
switch {
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6, Minor: 14}):
|
||||
support.TLS_Version13_KeyUpdate = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6, Minor: 1}):
|
||||
support.TLS_ARIA_GCM = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 6}):
|
||||
support.TLS_Version13_RX = true
|
||||
support.TLS_RX_NOPADDING = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 19}):
|
||||
support.TLS_TX_ZEROCOPY = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 16}):
|
||||
support.TLS_SM4 = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 11}):
|
||||
support.TLS_CHACHA20_POLY1305 = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 2}):
|
||||
support.TLS_AES_128_CCM = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 5, Minor: 1}):
|
||||
support.TLS_AES_256_GCM = true
|
||||
support.TLS_Version13 = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 4, Minor: 17}):
|
||||
support.TLS_RX = true
|
||||
fallthrough
|
||||
case kernelVersion.GreaterThanOrEqual(badversion.Version{Major: 4, Minor: 13}):
|
||||
support.TLS = true
|
||||
}
|
||||
|
||||
if support.TLS && support.TLS_Version13 {
|
||||
_, err := os.Stat("/sys/module/tls")
|
||||
if err != nil {
|
||||
if os.Getuid() == 0 {
|
||||
output, err := shell.Exec("modprobe", "tls").Read()
|
||||
if err != nil {
|
||||
return nil, E.Extend(E.Cause(err, "modprobe tls"), output)
|
||||
}
|
||||
} else {
|
||||
return nil, E.New("ktls: kernel TLS module not loaded")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &support, nil
|
||||
})
|
||||
|
||||
func Load() error {
|
||||
support, err := KernelSupport()
|
||||
if err != nil {
|
||||
return E.Cause(err, "ktls: check availability")
|
||||
}
|
||||
if !support.TLS || !support.TLS_Version13 {
|
||||
return E.New("ktls: kernel does not support TLS 1.3")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) setupKernel(txOffload, rxOffload bool) error {
|
||||
if !txOffload && !rxOffload {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
support, err := KernelSupport()
|
||||
if err != nil {
|
||||
return E.Cause(err, "check availability")
|
||||
}
|
||||
if !support.TLS || !support.TLS_Version13 {
|
||||
return E.New("kernel does not support TLS 1.3")
|
||||
}
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TCP, unix.TCP_ULP, "tls")
|
||||
})
|
||||
if err != nil {
|
||||
return os.NewSyscallError("setsockopt", err)
|
||||
}
|
||||
|
||||
if txOffload {
|
||||
txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false)
|
||||
if txCrypto == nil {
|
||||
return E.New("unsupported cipher suite")
|
||||
}
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String())
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if support.TLS_TX_ZEROCOPY {
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_TX_ZEROCOPY_RO, 1)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.kernelTx = true
|
||||
c.logger.DebugContext(c.ctx, "ktls: kernel TLS TX enabled")
|
||||
}
|
||||
|
||||
if rxOffload {
|
||||
rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true)
|
||||
if rxCrypto == nil {
|
||||
return E.New("unsupported cipher suite")
|
||||
}
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String())
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *c.rawConn.Vers >= tls.VersionTLS13 && support.TLS_RX_NOPADDING {
|
||||
err = control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.kernelRx = true
|
||||
c.logger.DebugContext(c.ctx, "ktls: kernel TLS RX enabled")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) resetupTX() (func() error, error) {
|
||||
if !c.kernelTx {
|
||||
return nil, nil
|
||||
}
|
||||
support, err := KernelSupport()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !support.TLS_Version13_KeyUpdate {
|
||||
return nil, errors.New("ktls: kernel does not support rekey")
|
||||
}
|
||||
txCrypto := kernelCipher(support, c.rawConn.Out, *c.rawConn.CipherSuite, false)
|
||||
if txCrypto == nil {
|
||||
return nil, errors.New("ktls: set kernelCipher on unsupported tls session")
|
||||
}
|
||||
return func() error {
|
||||
return control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, txCrypto.String())
|
||||
})
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) resetupRX() error {
|
||||
if !c.kernelRx {
|
||||
return nil
|
||||
}
|
||||
support, err := KernelSupport()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !support.TLS_Version13_KeyUpdate {
|
||||
return errors.New("ktls: kernel does not support rekey")
|
||||
}
|
||||
rxCrypto := kernelCipher(support, c.rawConn.In, *c.rawConn.CipherSuite, true)
|
||||
if rxCrypto == nil {
|
||||
return errors.New("ktls: set kernelCipher on unsupported tls session")
|
||||
}
|
||||
return control.Raw(c.rawSyscallConn, func(fd uintptr) error {
|
||||
return syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, rxCrypto.String())
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) readKernelRecord() (uint8, []byte, error) {
|
||||
if c.rawConn.RawInput.Len() < maxPlaintext {
|
||||
c.rawConn.RawInput.Grow(maxPlaintext - c.rawConn.RawInput.Len())
|
||||
}
|
||||
|
||||
data := c.rawConn.RawInput.Bytes()[:maxPlaintext]
|
||||
|
||||
// cmsg for record type
|
||||
buffer := make([]byte, unix.CmsgSpace(1))
|
||||
cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0]))
|
||||
cmsg.SetLen(unix.CmsgLen(1))
|
||||
|
||||
var iov unix.Iovec
|
||||
iov.Base = &data[0]
|
||||
iov.SetLen(len(data))
|
||||
|
||||
var msg unix.Msghdr
|
||||
msg.Control = &buffer[0]
|
||||
msg.Controllen = cmsg.Len
|
||||
msg.Iov = &iov
|
||||
msg.Iovlen = 1
|
||||
|
||||
var n int
|
||||
var err error
|
||||
er := c.rawSyscallConn.Read(func(fd uintptr) bool {
|
||||
n, err = recvmsg(int(fd), &msg, 0)
|
||||
return err != unix.EAGAIN || c.pendingRxSplice
|
||||
})
|
||||
if er != nil {
|
||||
return 0, nil, er
|
||||
}
|
||||
switch err {
|
||||
case nil:
|
||||
case syscall.EINVAL, syscall.EAGAIN:
|
||||
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertProtocolVersion))
|
||||
case syscall.EMSGSIZE:
|
||||
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow))
|
||||
case syscall.EBADMSG:
|
||||
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecryptError))
|
||||
default:
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if n <= 0 {
|
||||
return 0, nil, c.rawConn.In.SetErrorLocked(io.EOF)
|
||||
}
|
||||
|
||||
if cmsg.Level == unix.SOL_TLS && cmsg.Type == TLS_GET_RECORD_TYPE {
|
||||
typ := buffer[unix.CmsgLen(0)]
|
||||
return typ, data[:n], nil
|
||||
}
|
||||
|
||||
return recordTypeApplicationData, data[:n], nil
|
||||
}
|
||||
|
||||
func (c *Conn) writeKernelRecord(typ uint16, data []byte) (int, error) {
|
||||
if typ == recordTypeApplicationData {
|
||||
return c.conn.Write(data)
|
||||
}
|
||||
|
||||
// cmsg for record type
|
||||
buffer := make([]byte, unix.CmsgSpace(1))
|
||||
cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0]))
|
||||
cmsg.SetLen(unix.CmsgLen(1))
|
||||
buffer[unix.CmsgLen(0)] = byte(typ)
|
||||
cmsg.Level = unix.SOL_TLS
|
||||
cmsg.Type = TLS_SET_RECORD_TYPE
|
||||
|
||||
var iov unix.Iovec
|
||||
iov.Base = &data[0]
|
||||
iov.SetLen(len(data))
|
||||
|
||||
var msg unix.Msghdr
|
||||
msg.Control = &buffer[0]
|
||||
msg.Controllen = cmsg.Len
|
||||
msg.Iov = &iov
|
||||
msg.Iovlen = 1
|
||||
|
||||
var n int
|
||||
var err error
|
||||
ew := c.rawSyscallConn.Write(func(fd uintptr) bool {
|
||||
n, err = sendmsg(int(fd), &msg, 0)
|
||||
return err != unix.EAGAIN
|
||||
})
|
||||
if ew != nil {
|
||||
return 0, ew
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
//go:linkname recvmsg golang.org/x/sys/unix.recvmsg
|
||||
func recvmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error)
|
||||
|
||||
//go:linkname sendmsg golang.org/x/sys/unix.sendmsg
|
||||
func sendmsg(fd int, msg *unix.Msghdr, flags int) (n int, err error)
|
||||
24
common/ktls/ktls_prf.go
Normal file
24
common/ktls/ktls_prf.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import "unsafe"
|
||||
|
||||
//go:linkname cipherSuiteByID github.com/metacubex/utls.cipherSuiteByID
|
||||
func cipherSuiteByID(id uint16) unsafe.Pointer
|
||||
|
||||
//go:linkname keysFromMasterSecret github.com/metacubex/utls.keysFromMasterSecret
|
||||
func keysFromMasterSecret(version uint16, suite unsafe.Pointer, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte)
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/metacubex/utls.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) unsafe.Pointer
|
||||
|
||||
//go:linkname nextTrafficSecret github.com/metacubex/utls.(*cipherSuiteTLS13).nextTrafficSecret
|
||||
func nextTrafficSecret(cs unsafe.Pointer, trafficSecret []byte) []byte
|
||||
|
||||
//go:linkname trafficKey github.com/metacubex/utls.(*cipherSuiteTLS13).trafficKey
|
||||
func trafficKey(cs unsafe.Pointer, trafficSecret []byte) (key, iv []byte)
|
||||
292
common/ktls/ktls_read.go
Normal file
292
common/ktls/ktls_read.go
Normal file
@@ -0,0 +1,292 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
if !c.kernelRx {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
// Put this after Handshake, in case people were calling
|
||||
// Read(nil) for the side effect of the Handshake.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
c.rawConn.In.Lock()
|
||||
defer c.rawConn.In.Unlock()
|
||||
|
||||
for c.rawConn.Input.Len() == 0 {
|
||||
if err := c.readRecord(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for c.rawConn.Hand.Len() > 0 {
|
||||
if err := c.handlePostHandshakeMessage(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n, _ := c.rawConn.Input.Read(b)
|
||||
|
||||
// If a close-notify alert is waiting, read it so that we can return (n,
|
||||
// EOF) instead of (n, nil), to signal to the HTTP response reading
|
||||
// goroutine that the connection is now closed. This eliminates a race
|
||||
// where the HTTP response reading goroutine would otherwise not observe
|
||||
// the EOF until its next read, by which time a client goroutine might
|
||||
// have already tried to reuse the HTTP connection for a new request.
|
||||
// See https://golang.org/cl/76400046 and https://golang.org/issue/3514
|
||||
if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.RawInput.Len() > 0 &&
|
||||
c.rawConn.RawInput.Bytes()[0] == recordTypeAlert {
|
||||
if err := c.readRecord(); err != nil {
|
||||
return n, err // will be io.EOF on closeNotify
|
||||
}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) readRecord() error {
|
||||
if *c.rawConn.In.Err != nil {
|
||||
return *c.rawConn.In.Err
|
||||
}
|
||||
|
||||
typ, data, err := c.readRawRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(data) > maxPlaintext {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow))
|
||||
}
|
||||
|
||||
// Application Data messages are always protected.
|
||||
if c.rawConn.In.Cipher == nil && typ == recordTypeApplicationData {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
||||
//if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
|
||||
// This is a state-advancing message: reset the retry count.
|
||||
// c.retryCount = 0
|
||||
//}
|
||||
|
||||
// Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 && typ != recordTypeHandshake && c.rawConn.Hand.Len() > 0 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
||||
switch typ {
|
||||
default:
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
case recordTypeAlert:
|
||||
//if c.quic != nil {
|
||||
// return c.rawConn.In.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
//}
|
||||
if len(data) != 2 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
if data[1] == alertCloseNotify {
|
||||
return c.rawConn.In.SetErrorLocked(io.EOF)
|
||||
}
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 {
|
||||
// TLS 1.3 removed warning-level alerts except for alertUserCanceled
|
||||
// (RFC 8446, § 6.1). Since at least one major implementation
|
||||
// (https://bugs.openjdk.org/browse/JDK-8323517) misuses this alert,
|
||||
// many TLS stacks now ignore it outright when seen in a TLS 1.3
|
||||
// handshake (e.g. BoringSSL, NSS, Rustls).
|
||||
if data[1] == alertUserCanceled {
|
||||
// Like TLS 1.2 alertLevelWarning alerts, we drop the record and retry.
|
||||
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
|
||||
}
|
||||
return c.rawConn.In.SetErrorLocked(&net.OpError{Op: "remote error", Err: tls.AlertError(data[1])})
|
||||
}
|
||||
switch data[0] {
|
||||
case alertLevelWarning:
|
||||
// Drop the record on the floor and retry.
|
||||
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
|
||||
case alertLevelError:
|
||||
return c.rawConn.In.SetErrorLocked(&net.OpError{Op: "remote error", Err: tls.AlertError(data[1])})
|
||||
default:
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
||||
case recordTypeChangeCipherSpec:
|
||||
if len(data) != 1 || data[0] != 1 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertDecodeError))
|
||||
}
|
||||
// Handshake messages are not allowed to fragment across the CCS.
|
||||
if c.rawConn.Hand.Len() > 0 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
// In TLS 1.3, change_cipher_spec records are ignored until the
|
||||
// Finished. See RFC 8446, Appendix D.4. Note that according to Section
|
||||
// 5, a server can send a ChangeCipherSpec before its ServerHello, when
|
||||
// c.vers is still unset. That's not useful though and suspicious if the
|
||||
// server then selects a lower protocol version, so don't allow that.
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 {
|
||||
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
|
||||
}
|
||||
// if !expectChangeCipherSpec {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
//}
|
||||
//if err := c.rawConn.In.changeCipherSpec(); err != nil {
|
||||
// return c.rawConn.In.setErrorLocked(c.sendAlert(err.(alert)))
|
||||
//}
|
||||
|
||||
case recordTypeApplicationData:
|
||||
// Some OpenSSL servers send empty records in order to randomize the
|
||||
// CBC RawIV. Ignore a limited number of empty records.
|
||||
if len(data) == 0 {
|
||||
return c.retryReadRecord( /*expectChangeCipherSpec*/ )
|
||||
}
|
||||
// Note that data is owned by c.rawInput, following the Next call above,
|
||||
// to avoid copying the plaintext. This is safe because c.rawInput is
|
||||
// not read from or written to until c.input is drained.
|
||||
c.rawConn.Input.Reset(data)
|
||||
case recordTypeHandshake:
|
||||
if len(data) == 0 {
|
||||
return c.rawConn.In.SetErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
c.rawConn.Hand.Write(data)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
func (c *Conn) readRawRecord() (typ uint8, data []byte, err error) {
|
||||
// Read from kernel.
|
||||
if c.kernelRx {
|
||||
return c.readKernelRecord()
|
||||
}
|
||||
|
||||
// Read header, payload.
|
||||
if err = c.readFromUntil(c.conn, recordHeaderLen); err != nil {
|
||||
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
|
||||
// is an error, but popular web sites seem to do this, so we accept it
|
||||
// if and only if at the record boundary.
|
||||
if err == io.ErrUnexpectedEOF && c.rawConn.RawInput.Len() == 0 {
|
||||
err = io.EOF
|
||||
}
|
||||
if e, ok := err.(net.Error); !ok || !e.Temporary() {
|
||||
c.rawConn.In.SetErrorLocked(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
hdr := c.rawConn.RawInput.Bytes()[:recordHeaderLen]
|
||||
typ = hdr[0]
|
||||
|
||||
vers := uint16(hdr[1])<<8 | uint16(hdr[2])
|
||||
expectedVers := *c.rawConn.Vers
|
||||
if expectedVers == tls.VersionTLS13 {
|
||||
// All TLS 1.3 records are expected to have 0x0303 (1.2) after
|
||||
// the initial hello (RFC 8446 Section 5.1).
|
||||
expectedVers = tls.VersionTLS12
|
||||
}
|
||||
n := int(hdr[3])<<8 | int(hdr[4])
|
||||
if /*c.haveVers && */ vers != expectedVers {
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
|
||||
err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(nil, msg))
|
||||
return
|
||||
}
|
||||
//if !c.haveVers {
|
||||
// // First message, be extra suspicious: this might not be a TLS
|
||||
// // client. Bail out before reading a full 'body', if possible.
|
||||
// // The current max version is 3.3 so if the version is >= 16.0,
|
||||
// // it's probably not real.
|
||||
// if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
|
||||
// err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
|
||||
// return
|
||||
// }
|
||||
//}
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
|
||||
c.sendAlert(alertRecordOverflow)
|
||||
msg := fmt.Sprintf("oversized record received with length %d", n)
|
||||
err = c.rawConn.In.SetErrorLocked(c.newRecordHeaderError(nil, msg))
|
||||
return
|
||||
}
|
||||
if err = c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
|
||||
if e, ok := err.(net.Error); !ok || !e.Temporary() {
|
||||
c.rawConn.In.SetErrorLocked(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Process message.
|
||||
record := c.rawConn.RawInput.Next(recordHeaderLen + n)
|
||||
data, typ, err = c.rawConn.In.Decrypt(record)
|
||||
if err != nil {
|
||||
err = c.rawConn.In.SetErrorLocked(c.sendAlert(uint8(err.(tls.AlertError))))
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like
|
||||
// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
|
||||
func (c *Conn) retryReadRecord( /*expectChangeCipherSpec bool*/ ) error {
|
||||
//c.retryCount++
|
||||
//if c.retryCount > maxUselessRecords {
|
||||
// c.sendAlert(alertUnexpectedMessage)
|
||||
// return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
|
||||
//}
|
||||
return c.readRecord( /*expectChangeCipherSpec*/ )
|
||||
}
|
||||
|
||||
// atLeastReader reads from R, stopping with EOF once at least N bytes have been
|
||||
// read. It is different from an io.LimitedReader in that it doesn't cut short
|
||||
// the last Read call, and in that it considers an early EOF an error.
|
||||
type atLeastReader struct {
|
||||
R io.Reader
|
||||
N int64
|
||||
}
|
||||
|
||||
func (r *atLeastReader) Read(p []byte) (int, error) {
|
||||
if r.N <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err := r.R.Read(p)
|
||||
r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
|
||||
if r.N > 0 && err == io.EOF {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
}
|
||||
if r.N <= 0 && err == nil {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// readFromUntil reads from r into c.rawConn.RawInput until c.rawConn.RawInput contains
|
||||
// at least n bytes or else returns an error.
|
||||
func (c *Conn) readFromUntil(r io.Reader, n int) error {
|
||||
if c.rawConn.RawInput.Len() >= n {
|
||||
return nil
|
||||
}
|
||||
needs := n - c.rawConn.RawInput.Len()
|
||||
// There might be extra input waiting on the wire. Make a best effort
|
||||
// attempt to fetch it so that it can be used in (*Conn).Read to
|
||||
// "predict" closeNotify alerts.
|
||||
c.rawConn.RawInput.Grow(needs + bytes.MinRead)
|
||||
_, err := c.rawConn.RawInput.ReadFrom(&atLeastReader{r, int64(needs)})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err tls.RecordHeaderError) {
|
||||
err.Msg = msg
|
||||
err.Conn = conn
|
||||
copy(err.RecordHeader[:], c.rawConn.RawInput.Bytes())
|
||||
return err
|
||||
}
|
||||
41
common/ktls/ktls_read_wait.go
Normal file
41
common/ktls/ktls_read_wait.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||||
c.readWaitOptions = options
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Conn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
c.rawConn.In.Lock()
|
||||
defer c.rawConn.In.Unlock()
|
||||
for c.rawConn.Input.Len() == 0 {
|
||||
err = c.readRecord()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
buffer = c.readWaitOptions.NewBuffer()
|
||||
n, err := c.rawConn.Input.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
if n != 0 && c.rawConn.Input.Len() == 0 && c.rawConn.Input.Len() > 0 &&
|
||||
c.rawConn.RawInput.Bytes()[0] == recordTypeAlert {
|
||||
_ = c.rawConn.ReadRecord()
|
||||
}
|
||||
c.readWaitOptions.PostReturn(buffer)
|
||||
return
|
||||
}
|
||||
15
common/ktls/ktls_stub_nolinkname.go
Normal file
15
common/ktls/ktls_stub_nolinkname.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build linux && go1.25 && !badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
|
||||
return nil, E.New("kTLS requires build flags `badlinkname` and `-ldflags=-checklinkname=0`, please recompile your binary")
|
||||
}
|
||||
15
common/ktls/ktls_stub_nonlinux.go
Normal file
15
common/ktls/ktls_stub_nonlinux.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !linux
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
|
||||
return nil, E.New("kTLS is only supported on Linux")
|
||||
}
|
||||
15
common/ktls/ktls_stub_oldgo.go
Normal file
15
common/ktls/ktls_stub_oldgo.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build linux && !go1.25
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
|
||||
return nil, E.New("kTLS requires Go 1.25 or later, please recompile your binary")
|
||||
}
|
||||
154
common/ktls/ktls_write.go
Normal file
154
common/ktls/ktls_write.go
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux && go1.25 && badlinkname
|
||||
|
||||
package ktls
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
if !c.kernelTx {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
// interlock with Close below
|
||||
for {
|
||||
x := c.rawConn.ActiveCall.Load()
|
||||
if x&1 != 0 {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if c.rawConn.ActiveCall.CompareAndSwap(x, x+2) {
|
||||
break
|
||||
}
|
||||
}
|
||||
defer c.rawConn.ActiveCall.Add(-2)
|
||||
|
||||
//if err := c.Conn.HandshakeContext(context.Background()); err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
|
||||
c.rawConn.Out.Lock()
|
||||
defer c.rawConn.Out.Unlock()
|
||||
|
||||
if err := *c.rawConn.Out.Err; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if !c.rawConn.IsHandshakeComplete.Load() {
|
||||
return 0, tls.AlertError(alertInternalError)
|
||||
}
|
||||
|
||||
if *c.rawConn.CloseNotifySent {
|
||||
// return 0, errShutdown
|
||||
return 0, errors.New("tls: protocol is shutdown")
|
||||
}
|
||||
|
||||
// TLS 1.0 is susceptible to a chosen-plaintext
|
||||
// attack when using block mode ciphers due to predictable IVs.
|
||||
// This can be prevented by splitting each Application Data
|
||||
// record into two records, effectively randomizing the RawIV.
|
||||
//
|
||||
// https://www.openssl.org/~bodo/tls-cbc.txt
|
||||
// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
|
||||
// https://www.imperialviolet.org/2012/01/15/beastfollowup.html
|
||||
|
||||
var m int
|
||||
if len(b) > 1 && *c.rawConn.Vers == tls.VersionTLS10 {
|
||||
if _, ok := (*c.rawConn.Out.Cipher).(cipher.BlockMode); ok {
|
||||
n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
|
||||
if err != nil {
|
||||
return n, c.rawConn.Out.SetErrorLocked(err)
|
||||
}
|
||||
m, b = 1, b[1:]
|
||||
}
|
||||
}
|
||||
|
||||
n, err := c.writeRecordLocked(recordTypeApplicationData, b)
|
||||
return n + m, c.rawConn.Out.SetErrorLocked(err)
|
||||
}
|
||||
|
||||
func (c *Conn) writeRecordLocked(typ uint16, data []byte) (n int, err error) {
|
||||
if !c.kernelTx {
|
||||
return c.rawConn.WriteRecordLocked(typ, data)
|
||||
}
|
||||
/*for len(data) > 0 {
|
||||
m := len(data)
|
||||
if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
|
||||
m = maxPayload
|
||||
}
|
||||
_, err = c.writeKernelRecord(typ, data[:m])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += m
|
||||
data = data[m:]
|
||||
}*/
|
||||
return c.writeKernelRecord(typ, data)
|
||||
}
|
||||
|
||||
const (
|
||||
// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
|
||||
// size (MSS). A constant is used, rather than querying the kernel for
|
||||
// the actual MSS, to avoid complexity. The value here is the IPv6
|
||||
// minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40
|
||||
// bytes) and a TCP header with timestamps (32 bytes).
|
||||
tcpMSSEstimate = 1208
|
||||
|
||||
// recordSizeBoostThreshold is the number of bytes of application data
|
||||
// sent after which the TLS record size will be increased to the
|
||||
// maximum.
|
||||
recordSizeBoostThreshold = 128 * 1024
|
||||
)
|
||||
|
||||
func (c *Conn) maxPayloadSizeForWrite(typ uint16) int {
|
||||
if /*c.config.DynamicRecordSizingDisabled ||*/ typ != recordTypeApplicationData {
|
||||
return maxPlaintext
|
||||
}
|
||||
|
||||
if *c.rawConn.PacketsSent >= recordSizeBoostThreshold {
|
||||
return maxPlaintext
|
||||
}
|
||||
|
||||
// Subtract TLS overheads to get the maximum payload size.
|
||||
payloadBytes := tcpMSSEstimate - recordHeaderLen - c.rawConn.Out.ExplicitNonceLen()
|
||||
if rawCipher := *c.rawConn.Out.Cipher; rawCipher != nil {
|
||||
switch ciph := rawCipher.(type) {
|
||||
case cipher.Stream:
|
||||
payloadBytes -= (*c.rawConn.Out.Mac).Size()
|
||||
case cipher.AEAD:
|
||||
payloadBytes -= ciph.Overhead()
|
||||
/*case cbcMode:
|
||||
blockSize := ciph.BlockSize()
|
||||
// The payload must fit in a multiple of blockSize, with
|
||||
// room for at least one padding byte.
|
||||
payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
|
||||
// The RawMac is appended before padding so affects the
|
||||
// payload size directly.
|
||||
payloadBytes -= c.out.mac.Size()*/
|
||||
default:
|
||||
panic("unknown cipher type")
|
||||
}
|
||||
}
|
||||
if *c.rawConn.Vers == tls.VersionTLS13 {
|
||||
payloadBytes-- // encrypted ContentType
|
||||
}
|
||||
|
||||
// Allow packet growth in arithmetic progression up to max.
|
||||
pkt := *c.rawConn.PacketsSent
|
||||
*c.rawConn.PacketsSent++
|
||||
if pkt > 1000 {
|
||||
return maxPlaintext // avoid overflow in multiply below
|
||||
}
|
||||
|
||||
n := payloadBytes * int(pkt+1)
|
||||
if n > maxPlaintext {
|
||||
n = maxPlaintext
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build go1.21
|
||||
|
||||
package listener
|
||||
|
||||
import "net"
|
||||
|
||||
const go121Available = true
|
||||
|
||||
func setMultiPathTCP(listenConfig *net.ListenConfig) {
|
||||
listenConfig.SetMultipathTCP(true)
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build go1.23
|
||||
|
||||
package listener
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func setKeepAliveConfig(listener *net.ListenConfig, idle time.Duration, interval time.Duration) {
|
||||
listener.KeepAliveConfig = net.KeepAliveConfig{
|
||||
Enable: true,
|
||||
Idle: idle,
|
||||
Interval: interval,
|
||||
}
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
//go:build !go1.21
|
||||
|
||||
package listener
|
||||
|
||||
import "net"
|
||||
|
||||
const go121Available = false
|
||||
|
||||
func setMultiPathTCP(listenConfig *net.ListenConfig) {
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build !go1.23
|
||||
|
||||
package listener
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/control"
|
||||
)
|
||||
|
||||
func setKeepAliveConfig(listener *net.ListenConfig, idle time.Duration, interval time.Duration) {
|
||||
listener.KeepAlive = idle
|
||||
listener.Control = control.Append(listener.Control, control.SetKeepAlivePeriod(idle, interval))
|
||||
}
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"github.com/metacubex/tfo-go"
|
||||
"github.com/database64128/tfo-go/v2"
|
||||
)
|
||||
|
||||
func (l *Listener) ListenTCP() (net.Listener, error) {
|
||||
@@ -37,7 +37,7 @@ func (l *Listener) ListenTCP() (net.Listener, error) {
|
||||
if l.listenOptions.ReuseAddr {
|
||||
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
|
||||
}
|
||||
if l.listenOptions.TCPKeepAlive >= 0 {
|
||||
if !l.listenOptions.DisableTCPKeepAlive {
|
||||
keepIdle := time.Duration(l.listenOptions.TCPKeepAlive)
|
||||
if keepIdle == 0 {
|
||||
keepIdle = C.TCPKeepAliveInitial
|
||||
@@ -46,13 +46,14 @@ func (l *Listener) ListenTCP() (net.Listener, error) {
|
||||
if keepInterval == 0 {
|
||||
keepInterval = C.TCPKeepAliveInterval
|
||||
}
|
||||
setKeepAliveConfig(&listenConfig, keepIdle, keepInterval)
|
||||
listenConfig.KeepAliveConfig = net.KeepAliveConfig{
|
||||
Enable: true,
|
||||
Idle: keepIdle,
|
||||
Interval: keepInterval,
|
||||
}
|
||||
}
|
||||
if l.listenOptions.TCPMultiPath {
|
||||
if !go121Available {
|
||||
return nil, E.New("MultiPath TCP requires go1.21, please recompile your binary.")
|
||||
}
|
||||
setMultiPathTCP(&listenConfig)
|
||||
listenConfig.SetMultipathTCP(true)
|
||||
}
|
||||
if l.tproxy {
|
||||
listenConfig.Control = control.Append(listenConfig.Control, func(network, address string, conn syscall.RawConn) error {
|
||||
@@ -98,8 +99,6 @@ func (l *Listener) loopTCPIn() {
|
||||
}
|
||||
//nolint:staticcheck
|
||||
metadata.InboundDetour = l.listenOptions.Detour
|
||||
//nolint:staticcheck
|
||||
metadata.InboundOptions = l.listenOptions.InboundOptions
|
||||
metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
|
||||
metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
|
||||
ctx := log.ContextWithNewID(l.ctx)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"os/user"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-tun"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type Searcher interface {
|
||||
FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error)
|
||||
FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error)
|
||||
}
|
||||
|
||||
var ErrNotFound = E.New("process not found")
|
||||
@@ -22,15 +23,7 @@ type Config struct {
|
||||
PackageManager tun.PackageManager
|
||||
}
|
||||
|
||||
type Info struct {
|
||||
ProcessID uint32
|
||||
ProcessPath string
|
||||
PackageName string
|
||||
User string
|
||||
UserId int32
|
||||
}
|
||||
|
||||
func FindProcessInfo(searcher Searcher, ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
|
||||
func FindProcessInfo(searcher Searcher, ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||
info, err := searcher.FindProcessInfo(ctx, network, source, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -38,7 +31,7 @@ func FindProcessInfo(searcher Searcher, ctx context.Context, network string, sou
|
||||
if info.UserId != -1 {
|
||||
osUser, _ := user.LookupId(F.ToString(info.UserId))
|
||||
if osUser != nil {
|
||||
info.User = osUser.Username
|
||||
info.UserName = osUser.Username
|
||||
}
|
||||
}
|
||||
return info, nil
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-tun"
|
||||
)
|
||||
|
||||
@@ -17,22 +18,22 @@ func NewSearcher(config Config) (Searcher, error) {
|
||||
return &androidSearcher{config.PackageManager}, nil
|
||||
}
|
||||
|
||||
func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
|
||||
func (s *androidSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||
_, uid, err := resolveSocketByNetlink(network, source, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sharedPackage, loaded := s.packageManager.SharedPackageByID(uid % 100000); loaded {
|
||||
return &Info{
|
||||
UserId: int32(uid),
|
||||
PackageName: sharedPackage,
|
||||
return &adapter.ConnectionOwner{
|
||||
UserId: int32(uid),
|
||||
AndroidPackageName: sharedPackage,
|
||||
}, nil
|
||||
}
|
||||
if packageName, loaded := s.packageManager.PackageByID(uid % 100000); loaded {
|
||||
return &Info{
|
||||
UserId: int32(uid),
|
||||
PackageName: packageName,
|
||||
return &adapter.ConnectionOwner{
|
||||
UserId: int32(uid),
|
||||
AndroidPackageName: packageName,
|
||||
}, nil
|
||||
}
|
||||
return &Info{UserId: int32(uid)}, nil
|
||||
return &adapter.ConnectionOwner{UserId: int32(uid)}, nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -23,12 +24,12 @@ func NewSearcher(_ Config) (Searcher, error) {
|
||||
return &darwinSearcher{}, nil
|
||||
}
|
||||
|
||||
func (d *darwinSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
|
||||
func (d *darwinSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||
processName, err := findProcessName(network, source.Addr(), int(source.Port()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Info{ProcessPath: processName, UserId: -1}, nil
|
||||
return &adapter.ConnectionOwner{ProcessPath: processName, UserId: -1}, nil
|
||||
}
|
||||
|
||||
var structSize = func() int {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
)
|
||||
|
||||
@@ -19,7 +20,7 @@ func NewSearcher(config Config) (Searcher, error) {
|
||||
return &linuxSearcher{config.Logger}, nil
|
||||
}
|
||||
|
||||
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
|
||||
func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||
inode, uid, err := resolveSocketByNetlink(network, source, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -28,7 +29,7 @@ func (s *linuxSearcher) FindProcessInfo(ctx context.Context, network string, sou
|
||||
if err != nil {
|
||||
s.logger.DebugContext(ctx, "find process path: ", err)
|
||||
}
|
||||
return &Info{
|
||||
return &adapter.ConnectionOwner{
|
||||
UserId: int32(uid),
|
||||
ProcessPath: processPath,
|
||||
}, nil
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/winiphlpapi"
|
||||
|
||||
@@ -27,16 +28,16 @@ func initWin32API() error {
|
||||
return winiphlpapi.LoadExtendedTable()
|
||||
}
|
||||
|
||||
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
|
||||
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*adapter.ConnectionOwner, error) {
|
||||
pid, err := winiphlpapi.FindPid(network, source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path, err := getProcessPath(pid)
|
||||
if err != nil {
|
||||
return &Info{ProcessID: pid, UserId: -1}, err
|
||||
return &adapter.ConnectionOwner{ProcessID: pid, UserId: -1}, err
|
||||
}
|
||||
return &Info{ProcessID: pid, ProcessPath: path, UserId: -1}, nil
|
||||
return &adapter.ConnectionOwner{ProcessID: pid, ProcessPath: path, UserId: -1}, nil
|
||||
}
|
||||
|
||||
func getProcessPath(pid uint32) (string, error) {
|
||||
|
||||
9
common/settings/wifi.go
Normal file
9
common/settings/wifi.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package settings
|
||||
|
||||
import "github.com/sagernet/sing-box/adapter"
|
||||
|
||||
type WIFIMonitor interface {
|
||||
ReadWIFIState() adapter.WIFIState
|
||||
Start() error
|
||||
Close() error
|
||||
}
|
||||
46
common/settings/wifi_linux.go
Normal file
46
common/settings/wifi_linux.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type LinuxWIFIMonitor struct {
|
||||
monitor WIFIMonitor
|
||||
}
|
||||
|
||||
func NewWIFIMonitor(callback func(adapter.WIFIState)) (WIFIMonitor, error) {
|
||||
monitors := []func(func(adapter.WIFIState)) (WIFIMonitor, error){
|
||||
newNetworkManagerMonitor,
|
||||
newIWDMonitor,
|
||||
newWpaSupplicantMonitor,
|
||||
newConnManMonitor,
|
||||
}
|
||||
var errors []error
|
||||
for _, factory := range monitors {
|
||||
monitor, err := factory(callback)
|
||||
if err == nil {
|
||||
return &LinuxWIFIMonitor{monitor: monitor}, nil
|
||||
}
|
||||
errors = append(errors, err)
|
||||
}
|
||||
return nil, E.Cause(E.Errors(errors...), "no supported WIFI manager found")
|
||||
}
|
||||
|
||||
func (m *LinuxWIFIMonitor) ReadWIFIState() adapter.WIFIState {
|
||||
return m.monitor.ReadWIFIState()
|
||||
}
|
||||
|
||||
func (m *LinuxWIFIMonitor) Start() error {
|
||||
if m.monitor != nil {
|
||||
return m.monitor.Start()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *LinuxWIFIMonitor) Close() error {
|
||||
if m.monitor != nil {
|
||||
return m.monitor.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
168
common/settings/wifi_linux_connman.go
Normal file
168
common/settings/wifi_linux_connman.go
Normal file
@@ -0,0 +1,168 @@
|
||||
//go:build linux
|
||||
|
||||
package settings
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
type connmanMonitor struct {
|
||||
conn *dbus.Conn
|
||||
callback func(adapter.WIFIState)
|
||||
cancel context.CancelFunc
|
||||
signalChan chan *dbus.Signal
|
||||
}
|
||||
|
||||
func newConnManMonitor(callback func(adapter.WIFIState)) (WIFIMonitor, error) {
|
||||
conn, err := dbus.ConnectSystemBus()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmObj := conn.Object("net.connman", "/")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
call := cmObj.CallWithContext(ctx, "net.connman.Manager.GetServices", 0)
|
||||
if call.Err != nil {
|
||||
conn.Close()
|
||||
return nil, call.Err
|
||||
}
|
||||
return &connmanMonitor{conn: conn, callback: callback}, nil
|
||||
}
|
||||
|
||||
func (m *connmanMonitor) ReadWIFIState() adapter.WIFIState {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmObj := m.conn.Object("net.connman", "/")
|
||||
var services []interface{}
|
||||
err := cmObj.CallWithContext(ctx, "net.connman.Manager.GetServices", 0).Store(&services)
|
||||
if err != nil {
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
servicePair, ok := service.([]interface{})
|
||||
if !ok || len(servicePair) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
serviceProps, ok := servicePair[1].(map[string]dbus.Variant)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
typeVariant, hasType := serviceProps["Type"]
|
||||
if !hasType {
|
||||
continue
|
||||
}
|
||||
serviceType, ok := typeVariant.Value().(string)
|
||||
if !ok || serviceType != "wifi" {
|
||||
continue
|
||||
}
|
||||
|
||||
stateVariant, hasState := serviceProps["State"]
|
||||
if !hasState {
|
||||
continue
|
||||
}
|
||||
state, ok := stateVariant.Value().(string)
|
||||
if !ok || (state != "online" && state != "ready") {
|
||||
continue
|
||||
}
|
||||
|
||||
nameVariant, hasName := serviceProps["Name"]
|
||||
if !hasName {
|
||||
continue
|
||||
}
|
||||
ssid, ok := nameVariant.Value().(string)
|
||||
if !ok || ssid == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
bssidVariant, hasBSSID := serviceProps["BSSID"]
|
||||
if !hasBSSID {
|
||||
return adapter.WIFIState{SSID: ssid}
|
||||
}
|
||||
bssid, ok := bssidVariant.Value().(string)
|
||||
if !ok {
|
||||
return adapter.WIFIState{SSID: ssid}
|
||||
}
|
||||
|
||||
return adapter.WIFIState{
|
||||
SSID: ssid,
|
||||
BSSID: strings.ToUpper(strings.ReplaceAll(bssid, ":", "")),
|
||||
}
|
||||
}
|
||||
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
func (m *connmanMonitor) Start() error {
|
||||
if m.callback == nil {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
m.signalChan = make(chan *dbus.Signal, 10)
|
||||
m.conn.Signal(m.signalChan)
|
||||
|
||||
err := m.conn.AddMatchSignal(
|
||||
dbus.WithMatchInterface("net.connman.Service"),
|
||||
dbus.WithMatchSender("net.connman"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := m.ReadWIFIState()
|
||||
go m.monitorSignals(ctx, m.signalChan, state)
|
||||
m.callback(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connmanMonitor) monitorSignals(ctx context.Context, signalChan chan *dbus.Signal, lastState adapter.WIFIState) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case signal, ok := <-signalChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// godbus Signal.Name uses "interface.member" format (e.g. "net.connman.Service.PropertyChanged"),
|
||||
// not just the member name. This differs from the D-Bus signal member in the match rule.
|
||||
if signal.Name == "net.connman.Service.PropertyChanged" {
|
||||
state := m.ReadWIFIState()
|
||||
if state != lastState {
|
||||
lastState = state
|
||||
m.callback(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connmanMonitor) Close() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
if m.signalChan != nil {
|
||||
m.conn.RemoveSignal(m.signalChan)
|
||||
close(m.signalChan)
|
||||
}
|
||||
if m.conn != nil {
|
||||
m.conn.RemoveMatchSignal(
|
||||
dbus.WithMatchInterface("net.connman.Service"),
|
||||
dbus.WithMatchSender("net.connman"),
|
||||
)
|
||||
return m.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
190
common/settings/wifi_linux_iwd.go
Normal file
190
common/settings/wifi_linux_iwd.go
Normal file
@@ -0,0 +1,190 @@
|
||||
//go:build linux
|
||||
|
||||
package settings
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
type iwdMonitor struct {
|
||||
conn *dbus.Conn
|
||||
callback func(adapter.WIFIState)
|
||||
cancel context.CancelFunc
|
||||
signalChan chan *dbus.Signal
|
||||
}
|
||||
|
||||
func newIWDMonitor(callback func(adapter.WIFIState)) (WIFIMonitor, error) {
|
||||
conn, err := dbus.ConnectSystemBus()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iwdObj := conn.Object("net.connman.iwd", "/")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
call := iwdObj.CallWithContext(ctx, "org.freedesktop.DBus.ObjectManager.GetManagedObjects", 0)
|
||||
if call.Err != nil {
|
||||
conn.Close()
|
||||
return nil, call.Err
|
||||
}
|
||||
return &iwdMonitor{conn: conn, callback: callback}, nil
|
||||
}
|
||||
|
||||
func (m *iwdMonitor) ReadWIFIState() adapter.WIFIState {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
iwdObj := m.conn.Object("net.connman.iwd", "/")
|
||||
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
|
||||
err := iwdObj.CallWithContext(ctx, "org.freedesktop.DBus.ObjectManager.GetManagedObjects", 0).Store(&objects)
|
||||
if err != nil {
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
for _, interfaces := range objects {
|
||||
stationProps, hasStation := interfaces["net.connman.iwd.Station"]
|
||||
if !hasStation {
|
||||
continue
|
||||
}
|
||||
|
||||
stateVariant, hasState := stationProps["State"]
|
||||
if !hasState {
|
||||
continue
|
||||
}
|
||||
state, ok := stateVariant.Value().(string)
|
||||
if !ok || state != "connected" {
|
||||
continue
|
||||
}
|
||||
|
||||
connectedNetworkVariant, hasNetwork := stationProps["ConnectedNetwork"]
|
||||
if !hasNetwork {
|
||||
continue
|
||||
}
|
||||
networkPath, ok := connectedNetworkVariant.Value().(dbus.ObjectPath)
|
||||
if !ok || networkPath == "/" {
|
||||
continue
|
||||
}
|
||||
|
||||
networkInterfaces, hasNetworkPath := objects[networkPath]
|
||||
if !hasNetworkPath {
|
||||
continue
|
||||
}
|
||||
|
||||
networkProps, hasNetworkInterface := networkInterfaces["net.connman.iwd.Network"]
|
||||
if !hasNetworkInterface {
|
||||
continue
|
||||
}
|
||||
|
||||
nameVariant, hasName := networkProps["Name"]
|
||||
if !hasName {
|
||||
continue
|
||||
}
|
||||
ssid, ok := nameVariant.Value().(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
connectedBSSVariant, hasBSS := stationProps["ConnectedAccessPoint"]
|
||||
if !hasBSS {
|
||||
return adapter.WIFIState{SSID: ssid}
|
||||
}
|
||||
bssPath, ok := connectedBSSVariant.Value().(dbus.ObjectPath)
|
||||
if !ok || bssPath == "/" {
|
||||
return adapter.WIFIState{SSID: ssid}
|
||||
}
|
||||
|
||||
bssInterfaces, hasBSSPath := objects[bssPath]
|
||||
if !hasBSSPath {
|
||||
return adapter.WIFIState{SSID: ssid}
|
||||
}
|
||||
|
||||
bssProps, hasBSSInterface := bssInterfaces["net.connman.iwd.BasicServiceSet"]
|
||||
if !hasBSSInterface {
|
||||
return adapter.WIFIState{SSID: ssid}
|
||||
}
|
||||
|
||||
addressVariant, hasAddress := bssProps["Address"]
|
||||
if !hasAddress {
|
||||
return adapter.WIFIState{SSID: ssid}
|
||||
}
|
||||
bssid, ok := addressVariant.Value().(string)
|
||||
if !ok {
|
||||
return adapter.WIFIState{SSID: ssid}
|
||||
}
|
||||
|
||||
return adapter.WIFIState{
|
||||
SSID: ssid,
|
||||
BSSID: strings.ToUpper(strings.ReplaceAll(bssid, ":", "")),
|
||||
}
|
||||
}
|
||||
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
func (m *iwdMonitor) Start() error {
|
||||
if m.callback == nil {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
m.signalChan = make(chan *dbus.Signal, 10)
|
||||
m.conn.Signal(m.signalChan)
|
||||
|
||||
err := m.conn.AddMatchSignal(
|
||||
dbus.WithMatchInterface("org.freedesktop.DBus.Properties"),
|
||||
dbus.WithMatchSender("net.connman.iwd"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := m.ReadWIFIState()
|
||||
go m.monitorSignals(ctx, m.signalChan, state)
|
||||
m.callback(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *iwdMonitor) monitorSignals(ctx context.Context, signalChan chan *dbus.Signal, lastState adapter.WIFIState) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case signal, ok := <-signalChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if signal.Name == "org.freedesktop.DBus.Properties.PropertiesChanged" {
|
||||
state := m.ReadWIFIState()
|
||||
if state != lastState {
|
||||
lastState = state
|
||||
m.callback(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *iwdMonitor) Close() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
if m.signalChan != nil {
|
||||
m.conn.RemoveSignal(m.signalChan)
|
||||
close(m.signalChan)
|
||||
}
|
||||
if m.conn != nil {
|
||||
m.conn.RemoveMatchSignal(
|
||||
dbus.WithMatchInterface("org.freedesktop.DBus.Properties"),
|
||||
dbus.WithMatchSender("net.connman.iwd"),
|
||||
)
|
||||
return m.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
165
common/settings/wifi_linux_nm.go
Normal file
165
common/settings/wifi_linux_nm.go
Normal file
@@ -0,0 +1,165 @@
|
||||
//go:build linux
|
||||
|
||||
package settings
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
type networkManagerMonitor struct {
|
||||
conn *dbus.Conn
|
||||
callback func(adapter.WIFIState)
|
||||
cancel context.CancelFunc
|
||||
signalChan chan *dbus.Signal
|
||||
}
|
||||
|
||||
func newNetworkManagerMonitor(callback func(adapter.WIFIState)) (WIFIMonitor, error) {
|
||||
conn, err := dbus.ConnectSystemBus()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nmObj := conn.Object("org.freedesktop.NetworkManager", "/org/freedesktop/NetworkManager")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
var state uint32
|
||||
err = nmObj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, "org.freedesktop.NetworkManager", "State").Store(&state)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &networkManagerMonitor{conn: conn, callback: callback}, nil
|
||||
}
|
||||
|
||||
func (m *networkManagerMonitor) ReadWIFIState() adapter.WIFIState {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
nmObj := m.conn.Object("org.freedesktop.NetworkManager", "/org/freedesktop/NetworkManager")
|
||||
|
||||
var activeConnectionPaths []dbus.ObjectPath
|
||||
err := nmObj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, "org.freedesktop.NetworkManager", "ActiveConnections").Store(&activeConnectionPaths)
|
||||
if err != nil || len(activeConnectionPaths) == 0 {
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
for _, connectionPath := range activeConnectionPaths {
|
||||
connObj := m.conn.Object("org.freedesktop.NetworkManager", connectionPath)
|
||||
|
||||
var devicePaths []dbus.ObjectPath
|
||||
err = connObj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, "org.freedesktop.NetworkManager.Connection.Active", "Devices").Store(&devicePaths)
|
||||
if err != nil || len(devicePaths) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, devicePath := range devicePaths {
|
||||
deviceObj := m.conn.Object("org.freedesktop.NetworkManager", devicePath)
|
||||
|
||||
var deviceType uint32
|
||||
err = deviceObj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, "org.freedesktop.NetworkManager.Device", "DeviceType").Store(&deviceType)
|
||||
if err != nil || deviceType != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
var accessPointPath dbus.ObjectPath
|
||||
err = deviceObj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, "org.freedesktop.NetworkManager.Device.Wireless", "ActiveAccessPoint").Store(&accessPointPath)
|
||||
if err != nil || accessPointPath == "/" {
|
||||
continue
|
||||
}
|
||||
|
||||
apObj := m.conn.Object("org.freedesktop.NetworkManager", accessPointPath)
|
||||
|
||||
var ssidBytes []byte
|
||||
err = apObj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, "org.freedesktop.NetworkManager.AccessPoint", "Ssid").Store(&ssidBytes)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var hwAddress string
|
||||
err = apObj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, "org.freedesktop.NetworkManager.AccessPoint", "HwAddress").Store(&hwAddress)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ssid := strings.TrimSpace(string(ssidBytes))
|
||||
if ssid == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
return adapter.WIFIState{
|
||||
SSID: ssid,
|
||||
BSSID: strings.ToUpper(strings.ReplaceAll(hwAddress, ":", "")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
func (m *networkManagerMonitor) Start() error {
|
||||
if m.callback == nil {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
m.signalChan = make(chan *dbus.Signal, 10)
|
||||
m.conn.Signal(m.signalChan)
|
||||
|
||||
err := m.conn.AddMatchSignal(
|
||||
dbus.WithMatchSender("org.freedesktop.NetworkManager"),
|
||||
dbus.WithMatchInterface("org.freedesktop.DBus.Properties"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := m.ReadWIFIState()
|
||||
go m.monitorSignals(ctx, m.signalChan, state)
|
||||
m.callback(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *networkManagerMonitor) monitorSignals(ctx context.Context, signalChan chan *dbus.Signal, lastState adapter.WIFIState) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case signal, ok := <-signalChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if signal.Name == "org.freedesktop.DBus.Properties.PropertiesChanged" {
|
||||
state := m.ReadWIFIState()
|
||||
if state != lastState {
|
||||
lastState = state
|
||||
m.callback(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *networkManagerMonitor) Close() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
if m.signalChan != nil {
|
||||
m.conn.RemoveSignal(m.signalChan)
|
||||
close(m.signalChan)
|
||||
}
|
||||
if m.conn != nil {
|
||||
m.conn.RemoveMatchSignal(
|
||||
dbus.WithMatchSender("org.freedesktop.NetworkManager"),
|
||||
dbus.WithMatchInterface("org.freedesktop.DBus.Properties"),
|
||||
)
|
||||
return m.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
225
common/settings/wifi_linux_wpa.go
Normal file
225
common/settings/wifi_linux_wpa.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
var wpaSocketCounter atomic.Uint64
|
||||
|
||||
type wpaSupplicantMonitor struct {
|
||||
socketPath string
|
||||
callback func(adapter.WIFIState)
|
||||
cancel context.CancelFunc
|
||||
monitorConn *net.UnixConn
|
||||
connMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newWpaSupplicantMonitor(callback func(adapter.WIFIState)) (WIFIMonitor, error) {
|
||||
socketDirs := []string{"/var/run/wpa_supplicant", "/run/wpa_supplicant"}
|
||||
for _, socketDir := range socketDirs {
|
||||
entries, err := os.ReadDir(socketDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || entry.Name() == "." || entry.Name() == ".." {
|
||||
continue
|
||||
}
|
||||
socketPath := filepath.Join(socketDir, entry.Name())
|
||||
id := wpaSocketCounter.Add(1)
|
||||
localAddr := &net.UnixAddr{Name: fmt.Sprintf("@sing-box-wpa-%d-%d", os.Getpid(), id), Net: "unixgram"}
|
||||
remoteAddr := &net.UnixAddr{Name: socketPath, Net: "unixgram"}
|
||||
conn, err := net.DialUnix("unixgram", localAddr, remoteAddr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Close()
|
||||
return &wpaSupplicantMonitor{socketPath: socketPath, callback: callback}, nil
|
||||
}
|
||||
}
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
func (m *wpaSupplicantMonitor) ReadWIFIState() adapter.WIFIState {
|
||||
id := wpaSocketCounter.Add(1)
|
||||
localAddr := &net.UnixAddr{Name: fmt.Sprintf("@sing-box-wpa-%d-%d", os.Getpid(), id), Net: "unixgram"}
|
||||
remoteAddr := &net.UnixAddr{Name: m.socketPath, Net: "unixgram"}
|
||||
conn, err := net.DialUnix("unixgram", localAddr, remoteAddr)
|
||||
if err != nil {
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(3 * time.Second))
|
||||
|
||||
status, err := m.sendCommand(conn, "STATUS")
|
||||
if err != nil {
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
var ssid, bssid string
|
||||
var connected bool
|
||||
scanner := bufio.NewScanner(strings.NewReader(status))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "wpa_state=") {
|
||||
state := strings.TrimPrefix(line, "wpa_state=")
|
||||
connected = state == "COMPLETED"
|
||||
} else if strings.HasPrefix(line, "ssid=") {
|
||||
ssid = strings.TrimPrefix(line, "ssid=")
|
||||
} else if strings.HasPrefix(line, "bssid=") {
|
||||
bssid = strings.TrimPrefix(line, "bssid=")
|
||||
}
|
||||
}
|
||||
|
||||
if !connected || ssid == "" {
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
return adapter.WIFIState{
|
||||
SSID: ssid,
|
||||
BSSID: strings.ToUpper(strings.ReplaceAll(bssid, ":", "")),
|
||||
}
|
||||
}
|
||||
|
||||
// sendCommand sends a command to wpa_supplicant and returns the response.
|
||||
// Commands are sent without trailing newlines per the wpa_supplicant control
|
||||
// interface protocol - the official wpa_ctrl.c sends raw command strings.
|
||||
func (m *wpaSupplicantMonitor) sendCommand(conn *net.UnixConn, command string) (string, error) {
|
||||
_, err := conn.Write([]byte(command))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
response := string(buf[:n])
|
||||
if strings.HasPrefix(response, "FAIL") {
|
||||
return "", os.ErrInvalid
|
||||
}
|
||||
|
||||
return strings.TrimSpace(response), nil
|
||||
}
|
||||
|
||||
func (m *wpaSupplicantMonitor) Start() error {
|
||||
if m.callback == nil {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
state := m.ReadWIFIState()
|
||||
go m.monitorEvents(ctx, state)
|
||||
m.callback(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *wpaSupplicantMonitor) monitorEvents(ctx context.Context, lastState adapter.WIFIState) {
|
||||
var consecutiveErrors int
|
||||
var debounceTimer *time.Timer
|
||||
var debounceMutex sync.Mutex
|
||||
|
||||
localAddr := &net.UnixAddr{Name: fmt.Sprintf("@sing-box-wpa-mon-%d", os.Getpid()), Net: "unixgram"}
|
||||
remoteAddr := &net.UnixAddr{Name: m.socketPath, Net: "unixgram"}
|
||||
conn, err := net.DialUnix("unixgram", localAddr, remoteAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
m.connMutex.Lock()
|
||||
m.monitorConn = conn
|
||||
m.connMutex.Unlock()
|
||||
|
||||
// ATTACH/DETACH commands use os_strcmp() for exact matching in wpa_supplicant,
|
||||
// so they must be sent without trailing newlines.
|
||||
// See: https://w1.fi/cgit/hostap/tree/wpa_supplicant/ctrl_iface_unix.c
|
||||
_, err = conn.Write([]byte("ATTACH"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil || !strings.HasPrefix(string(buf[:n]), "OK") {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
debounceMutex.Lock()
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
debounceMutex.Unlock()
|
||||
conn.Write([]byte("DETACH"))
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
consecutiveErrors++
|
||||
if consecutiveErrors > 10 {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
consecutiveErrors = 0
|
||||
|
||||
msg := string(buf[:n])
|
||||
if strings.Contains(msg, "CTRL-EVENT-CONNECTED") || strings.Contains(msg, "CTRL-EVENT-DISCONNECTED") {
|
||||
debounceMutex.Lock()
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
debounceTimer = time.AfterFunc(500*time.Millisecond, func() {
|
||||
state := m.ReadWIFIState()
|
||||
if state != lastState {
|
||||
lastState = state
|
||||
m.callback(state)
|
||||
}
|
||||
})
|
||||
debounceMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *wpaSupplicantMonitor) Close() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
m.connMutex.Lock()
|
||||
if m.monitorConn != nil {
|
||||
m.monitorConn.Close()
|
||||
}
|
||||
m.connMutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
27
common/settings/wifi_stub.go
Normal file
27
common/settings/wifi_stub.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package settings
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
type stubWIFIMonitor struct{}
|
||||
|
||||
func NewWIFIMonitor(callback func(adapter.WIFIState)) (WIFIMonitor, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (m *stubWIFIMonitor) ReadWIFIState() adapter.WIFIState {
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
func (m *stubWIFIMonitor) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *stubWIFIMonitor) Close() error {
|
||||
return nil
|
||||
}
|
||||
144
common/settings/wifi_windows.go
Normal file
144
common/settings/wifi_windows.go
Normal file
@@ -0,0 +1,144 @@
|
||||
//go:build windows
|
||||
|
||||
package settings
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/winwlanapi"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type windowsWIFIMonitor struct {
|
||||
handle windows.Handle
|
||||
callback func(adapter.WIFIState)
|
||||
cancel context.CancelFunc
|
||||
lastState adapter.WIFIState
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewWIFIMonitor(callback func(adapter.WIFIState)) (WIFIMonitor, error) {
|
||||
handle, err := winwlanapi.OpenHandle()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
interfaces, err := winwlanapi.EnumInterfaces(handle)
|
||||
if err != nil {
|
||||
winwlanapi.CloseHandle(handle)
|
||||
return nil, err
|
||||
}
|
||||
if len(interfaces) == 0 {
|
||||
winwlanapi.CloseHandle(handle)
|
||||
return nil, fmt.Errorf("no wireless interfaces found")
|
||||
}
|
||||
|
||||
return &windowsWIFIMonitor{
|
||||
handle: handle,
|
||||
callback: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *windowsWIFIMonitor) ReadWIFIState() adapter.WIFIState {
|
||||
interfaces, err := winwlanapi.EnumInterfaces(m.handle)
|
||||
if err != nil || len(interfaces) == 0 {
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if iface.InterfaceState != winwlanapi.InterfaceStateConnected {
|
||||
continue
|
||||
}
|
||||
|
||||
guid := iface.InterfaceGUID
|
||||
attrs, err := winwlanapi.QueryCurrentConnection(m.handle, &guid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ssidLength := attrs.AssociationAttributes.SSID.Length
|
||||
if ssidLength == 0 || ssidLength > winwlanapi.Dot11SSIDMaxLength {
|
||||
continue
|
||||
}
|
||||
|
||||
ssid := string(attrs.AssociationAttributes.SSID.SSID[:ssidLength])
|
||||
bssid := formatBSSID(attrs.AssociationAttributes.BSSID)
|
||||
|
||||
return adapter.WIFIState{
|
||||
SSID: strings.TrimSpace(ssid),
|
||||
BSSID: bssid,
|
||||
}
|
||||
}
|
||||
|
||||
return adapter.WIFIState{}
|
||||
}
|
||||
|
||||
func formatBSSID(mac winwlanapi.Dot11MacAddress) string {
|
||||
return fmt.Sprintf("%02X%02X%02X%02X%02X%02X",
|
||||
mac[0], mac[1], mac[2], mac[3], mac[4], mac[5])
|
||||
}
|
||||
|
||||
func (m *windowsWIFIMonitor) Start() error {
|
||||
if m.callback == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
m.lastState = m.ReadWIFIState()
|
||||
|
||||
callbackFunc := func(data *winwlanapi.NotificationData, callbackContext uintptr) uintptr {
|
||||
if data.NotificationSource != winwlanapi.NotificationSourceACM {
|
||||
return 0
|
||||
}
|
||||
switch data.NotificationCode {
|
||||
case winwlanapi.NotificationACMConnectionComplete,
|
||||
winwlanapi.NotificationACMDisconnected:
|
||||
m.checkAndNotify()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
callbackPointer := syscall.NewCallback(callbackFunc)
|
||||
|
||||
err := winwlanapi.RegisterNotification(m.handle, winwlanapi.NotificationSourceACM, callbackPointer, 0)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
}()
|
||||
|
||||
m.callback(m.lastState)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *windowsWIFIMonitor) checkAndNotify() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
state := m.ReadWIFIState()
|
||||
if state != m.lastState {
|
||||
m.lastState = state
|
||||
if m.callback != nil {
|
||||
m.callback(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *windowsWIFIMonitor) Close() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
winwlanapi.UnregisterNotification(m.handle)
|
||||
return winwlanapi.CloseHandle(m.handle)
|
||||
}
|
||||
@@ -6,12 +6,15 @@ import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/domain"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
|
||||
"go4.org/netipx"
|
||||
@@ -41,6 +44,8 @@ const (
|
||||
ruleItemNetworkType
|
||||
ruleItemNetworkIsExpensive
|
||||
ruleItemNetworkIsConstrained
|
||||
ruleItemNetworkInterfaceAddress
|
||||
ruleItemDefaultInterfaceAddress
|
||||
ruleItemFinal uint8 = 0xFF
|
||||
)
|
||||
|
||||
@@ -230,6 +235,51 @@ func readDefaultRule(reader varbin.Reader, recover bool) (rule option.DefaultHea
|
||||
rule.NetworkIsExpensive = true
|
||||
case ruleItemNetworkIsConstrained:
|
||||
rule.NetworkIsConstrained = true
|
||||
case ruleItemNetworkInterfaceAddress:
|
||||
rule.NetworkInterfaceAddress = new(badjson.TypedMap[option.InterfaceType, badoption.Listable[*badoption.Prefixable]])
|
||||
var size uint64
|
||||
size, err = binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for i := uint64(0); i < size; i++ {
|
||||
var key uint8
|
||||
err = binary.Read(reader, binary.BigEndian, &key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var value []*badoption.Prefixable
|
||||
var prefixCount uint64
|
||||
prefixCount, err = binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for j := uint64(0); j < prefixCount; j++ {
|
||||
var prefix netip.Prefix
|
||||
prefix, err = readPrefix(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
value = append(value, common.Ptr(badoption.Prefixable(prefix)))
|
||||
}
|
||||
rule.NetworkInterfaceAddress.Put(option.InterfaceType(key), value)
|
||||
}
|
||||
case ruleItemDefaultInterfaceAddress:
|
||||
var value []*badoption.Prefixable
|
||||
var prefixCount uint64
|
||||
prefixCount, err = binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for j := uint64(0); j < prefixCount; j++ {
|
||||
var prefix netip.Prefix
|
||||
prefix, err = readPrefix(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
value = append(value, common.Ptr(badoption.Prefixable(prefix)))
|
||||
}
|
||||
rule.DefaultInterfaceAddress = value
|
||||
case ruleItemFinal:
|
||||
err = binary.Read(reader, binary.BigEndian, &rule.Invert)
|
||||
return
|
||||
@@ -346,7 +396,7 @@ func writeDefaultRule(writer varbin.Writer, rule option.DefaultHeadlessRule, gen
|
||||
}
|
||||
if len(rule.NetworkType) > 0 {
|
||||
if generateVersion < C.RuleSetVersion3 {
|
||||
return E.New("network_type rule item is only supported in version 3 or later")
|
||||
return E.New("`network_type` rule item is only supported in version 3 or later")
|
||||
}
|
||||
err = writeRuleItemUint8(writer, ruleItemNetworkType, rule.NetworkType)
|
||||
if err != nil {
|
||||
@@ -354,17 +404,71 @@ func writeDefaultRule(writer varbin.Writer, rule option.DefaultHeadlessRule, gen
|
||||
}
|
||||
}
|
||||
if rule.NetworkIsExpensive {
|
||||
if generateVersion < C.RuleSetVersion3 {
|
||||
return E.New("`network_is_expensive` rule item is only supported in version 3 or later")
|
||||
}
|
||||
err = binary.Write(writer, binary.BigEndian, ruleItemNetworkIsExpensive)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if rule.NetworkIsConstrained {
|
||||
if generateVersion < C.RuleSetVersion3 {
|
||||
return E.New("`network_is_constrained` rule item is only supported in version 3 or later")
|
||||
}
|
||||
err = binary.Write(writer, binary.BigEndian, ruleItemNetworkIsConstrained)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if rule.NetworkInterfaceAddress != nil && rule.NetworkInterfaceAddress.Size() > 0 {
|
||||
if generateVersion < C.RuleSetVersion4 {
|
||||
return E.New("`network_interface_address` rule item is only supported in version 4 or later")
|
||||
}
|
||||
err = writer.WriteByte(ruleItemNetworkInterfaceAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = varbin.WriteUvarint(writer, uint64(rule.NetworkInterfaceAddress.Size()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, entry := range rule.NetworkInterfaceAddress.Entries() {
|
||||
err = binary.Write(writer, binary.BigEndian, uint8(entry.Key.Build()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(entry.Value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rawPrefix := range entry.Value {
|
||||
err = writePrefix(writer, rawPrefix.Build(netip.Prefix{}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(rule.DefaultInterfaceAddress) > 0 {
|
||||
if generateVersion < C.RuleSetVersion4 {
|
||||
return E.New("`default_interface_address` rule item is only supported in version 4 or later")
|
||||
}
|
||||
err = writer.WriteByte(ruleItemDefaultInterfaceAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(rule.DefaultInterfaceAddress)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rawPrefix := range rule.DefaultInterfaceAddress {
|
||||
err = writePrefix(writer, rawPrefix.Build(netip.Prefix{}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(rule.WIFISSID) > 0 {
|
||||
err = writeRuleItemString(writer, ruleItemWIFISSID, rule.WIFISSID)
|
||||
if err != nil {
|
||||
@@ -402,7 +506,24 @@ func writeDefaultRule(writer varbin.Writer, rule option.DefaultHeadlessRule, gen
|
||||
}
|
||||
|
||||
func readRuleItemString(reader varbin.Reader) ([]string, error) {
|
||||
return varbin.ReadValue[[]string](reader, binary.BigEndian)
|
||||
length, err := binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]string, length)
|
||||
for i := range result {
|
||||
strLen, err := binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := make([]byte, strLen)
|
||||
_, err = io.ReadFull(reader, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[i] = string(buf)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func writeRuleItemString(writer varbin.Writer, itemType uint8, value []string) error {
|
||||
@@ -410,11 +531,34 @@ func writeRuleItemString(writer varbin.Writer, itemType uint8, value []string) e
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return varbin.Write(writer, binary.BigEndian, value)
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, s := range value {
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(s)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write([]byte(s))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readRuleItemUint8[E ~uint8](reader varbin.Reader) ([]E, error) {
|
||||
return varbin.ReadValue[[]E](reader, binary.BigEndian)
|
||||
length, err := binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]E, length)
|
||||
_, err = io.ReadFull(reader, *(*[]byte)(unsafe.Pointer(&result)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func writeRuleItemUint8[E ~uint8](writer varbin.Writer, itemType uint8, value []E) error {
|
||||
@@ -422,11 +566,25 @@ func writeRuleItemUint8[E ~uint8](writer varbin.Writer, itemType uint8, value []
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return varbin.Write(writer, binary.BigEndian, value)
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write(*(*[]byte)(unsafe.Pointer(&value)))
|
||||
return err
|
||||
}
|
||||
|
||||
func readRuleItemUint16(reader varbin.Reader) ([]uint16, error) {
|
||||
return varbin.ReadValue[[]uint16](reader, binary.BigEndian)
|
||||
length, err := binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]uint16, length)
|
||||
err = binary.Read(reader, binary.BigEndian, result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func writeRuleItemUint16(writer varbin.Writer, itemType uint8, value []uint16) error {
|
||||
@@ -434,7 +592,11 @@ func writeRuleItemUint16(writer varbin.Writer, itemType uint8, value []uint16) e
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return varbin.Write(writer, binary.BigEndian, value)
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(writer, binary.BigEndian, value)
|
||||
}
|
||||
|
||||
func writeRuleItemCIDR(writer varbin.Writer, itemType uint8, value []string) error {
|
||||
|
||||
494
common/srs/compat_test.go
Normal file
494
common/srs/compat_test.go
Normal file
@@ -0,0 +1,494 @@
|
||||
package srs
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
// Old implementations using varbin reflection-based serialization
|
||||
|
||||
func oldWriteStringSlice(writer varbin.Writer, value []string) error {
|
||||
//nolint:staticcheck
|
||||
return varbin.Write(writer, binary.BigEndian, value)
|
||||
}
|
||||
|
||||
func oldReadStringSlice(reader varbin.Reader) ([]string, error) {
|
||||
//nolint:staticcheck
|
||||
return varbin.ReadValue[[]string](reader, binary.BigEndian)
|
||||
}
|
||||
|
||||
func oldWriteUint8Slice[E ~uint8](writer varbin.Writer, value []E) error {
|
||||
//nolint:staticcheck
|
||||
return varbin.Write(writer, binary.BigEndian, value)
|
||||
}
|
||||
|
||||
func oldReadUint8Slice[E ~uint8](reader varbin.Reader) ([]E, error) {
|
||||
//nolint:staticcheck
|
||||
return varbin.ReadValue[[]E](reader, binary.BigEndian)
|
||||
}
|
||||
|
||||
func oldWriteUint16Slice(writer varbin.Writer, value []uint16) error {
|
||||
//nolint:staticcheck
|
||||
return varbin.Write(writer, binary.BigEndian, value)
|
||||
}
|
||||
|
||||
func oldReadUint16Slice(reader varbin.Reader) ([]uint16, error) {
|
||||
//nolint:staticcheck
|
||||
return varbin.ReadValue[[]uint16](reader, binary.BigEndian)
|
||||
}
|
||||
|
||||
func oldWritePrefix(writer varbin.Writer, prefix netip.Prefix) error {
|
||||
//nolint:staticcheck
|
||||
err := varbin.Write(writer, binary.BigEndian, prefix.Addr().AsSlice())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(writer, binary.BigEndian, uint8(prefix.Bits()))
|
||||
}
|
||||
|
||||
type oldIPRangeData struct {
|
||||
From []byte
|
||||
To []byte
|
||||
}
|
||||
|
||||
// Note: The old writeIPSet had a bug where varbin.Write(writer, binary.BigEndian, data)
|
||||
// with a struct VALUE (not pointer) silently wrote nothing because field.CanSet() returned false.
|
||||
// This caused IP range data to be missing from the output.
|
||||
// The new implementation correctly writes all range data.
|
||||
//
|
||||
// The old readIPSet used varbin.Read with a pre-allocated slice, which worked because
|
||||
// slice elements are addressable and CanSet() returns true for them.
|
||||
//
|
||||
// For compatibility testing, we verify:
|
||||
// 1. New write produces correct output with range data
|
||||
// 2. New read can parse the new format correctly
|
||||
// 3. Round-trip works correctly
|
||||
|
||||
func oldReadIPSet(reader varbin.Reader) (*netipx.IPSet, error) {
|
||||
version, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != 1 {
|
||||
return nil, err
|
||||
}
|
||||
var length uint64
|
||||
err = binary.Read(reader, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ranges := make([]oldIPRangeData, length)
|
||||
//nolint:staticcheck
|
||||
err = varbin.Read(reader, binary.BigEndian, &ranges)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mySet := &myIPSet{
|
||||
rr: make([]myIPRange, len(ranges)),
|
||||
}
|
||||
for i, rangeData := range ranges {
|
||||
mySet.rr[i].from = M.AddrFromIP(rangeData.From)
|
||||
mySet.rr[i].to = M.AddrFromIP(rangeData.To)
|
||||
}
|
||||
return (*netipx.IPSet)(unsafe.Pointer(mySet)), nil
|
||||
}
|
||||
|
||||
// New write functions (without itemType prefix for testing)
|
||||
|
||||
func newWriteStringSlice(writer varbin.Writer, value []string) error {
|
||||
_, err := varbin.WriteUvarint(writer, uint64(len(value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, s := range value {
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(s)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write([]byte(s))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newWriteUint8Slice[E ~uint8](writer varbin.Writer, value []E) error {
|
||||
_, err := varbin.WriteUvarint(writer, uint64(len(value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write(*(*[]byte)(unsafe.Pointer(&value)))
|
||||
return err
|
||||
}
|
||||
|
||||
func newWriteUint16Slice(writer varbin.Writer, value []uint16) error {
|
||||
_, err := varbin.WriteUvarint(writer, uint64(len(value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Write(writer, binary.BigEndian, value)
|
||||
}
|
||||
|
||||
func newWritePrefix(writer varbin.Writer, prefix netip.Prefix) error {
|
||||
addrSlice := prefix.Addr().AsSlice()
|
||||
_, err := varbin.WriteUvarint(writer, uint64(len(addrSlice)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write(addrSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writer.WriteByte(uint8(prefix.Bits()))
|
||||
}
|
||||
|
||||
// Tests
|
||||
|
||||
func TestStringSliceCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input []string
|
||||
}{
|
||||
{"nil", nil},
|
||||
{"empty", []string{}},
|
||||
{"single_empty", []string{""}},
|
||||
{"single", []string{"test"}},
|
||||
{"multi", []string{"a", "b", "c"}},
|
||||
{"with_empty", []string{"a", "", "c"}},
|
||||
{"utf8", []string{"测试", "テスト", "тест"}},
|
||||
{"long_string", []string{strings.Repeat("x", 128)}},
|
||||
{"many_elements", generateStrings(128)},
|
||||
{"many_elements_256", generateStrings(256)},
|
||||
{"127_byte_string", []string{strings.Repeat("x", 127)}},
|
||||
{"128_byte_string", []string{strings.Repeat("x", 128)}},
|
||||
{"mixed_lengths", []string{"a", strings.Repeat("b", 100), "", strings.Repeat("c", 200)}},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Old write
|
||||
var oldBuf bytes.Buffer
|
||||
err := oldWriteStringSlice(&oldBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// New write
|
||||
var newBuf bytes.Buffer
|
||||
err = newWriteStringSlice(&newBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bytes must match
|
||||
require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
|
||||
"mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
|
||||
|
||||
// New write -> old read
|
||||
readBack, err := oldReadStringSlice(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
requireStringSliceEqual(t, tc.input, readBack)
|
||||
|
||||
// Old write -> new read
|
||||
readBack2, err := readRuleItemString(bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
requireStringSliceEqual(t, tc.input, readBack2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUint8SliceCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input []uint8
|
||||
}{
|
||||
{"nil", nil},
|
||||
{"empty", []uint8{}},
|
||||
{"single_zero", []uint8{0}},
|
||||
{"single_max", []uint8{255}},
|
||||
{"multi", []uint8{0, 1, 127, 128, 255}},
|
||||
{"boundary", []uint8{0x00, 0x7f, 0x80, 0xff}},
|
||||
{"sequential", generateUint8Slice(256)},
|
||||
{"127_elements", generateUint8Slice(127)},
|
||||
{"128_elements", generateUint8Slice(128)},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Old write
|
||||
var oldBuf bytes.Buffer
|
||||
err := oldWriteUint8Slice(&oldBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// New write
|
||||
var newBuf bytes.Buffer
|
||||
err = newWriteUint8Slice(&newBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bytes must match
|
||||
require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
|
||||
"mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
|
||||
|
||||
// New write -> old read
|
||||
readBack, err := oldReadUint8Slice[uint8](bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
requireUint8SliceEqual(t, tc.input, readBack)
|
||||
|
||||
// Old write -> new read
|
||||
readBack2, err := readRuleItemUint8[uint8](bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
requireUint8SliceEqual(t, tc.input, readBack2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUint16SliceCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input []uint16
|
||||
}{
|
||||
{"nil", nil},
|
||||
{"empty", []uint16{}},
|
||||
{"single_zero", []uint16{0}},
|
||||
{"single_max", []uint16{65535}},
|
||||
{"multi", []uint16{0, 255, 256, 32767, 32768, 65535}},
|
||||
{"ports", []uint16{80, 443, 8080, 8443}},
|
||||
{"127_elements", generateUint16Slice(127)},
|
||||
{"128_elements", generateUint16Slice(128)},
|
||||
{"256_elements", generateUint16Slice(256)},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Old write
|
||||
var oldBuf bytes.Buffer
|
||||
err := oldWriteUint16Slice(&oldBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// New write
|
||||
var newBuf bytes.Buffer
|
||||
err = newWriteUint16Slice(&newBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bytes must match
|
||||
require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
|
||||
"mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
|
||||
|
||||
// New write -> old read
|
||||
readBack, err := oldReadUint16Slice(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
requireUint16SliceEqual(t, tc.input, readBack)
|
||||
|
||||
// Old write -> new read
|
||||
readBack2, err := readRuleItemUint16(bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
requireUint16SliceEqual(t, tc.input, readBack2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefixCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input netip.Prefix
|
||||
}{
|
||||
{"ipv4_0", netip.MustParsePrefix("0.0.0.0/0")},
|
||||
{"ipv4_8", netip.MustParsePrefix("10.0.0.0/8")},
|
||||
{"ipv4_16", netip.MustParsePrefix("192.168.0.0/16")},
|
||||
{"ipv4_24", netip.MustParsePrefix("192.168.1.0/24")},
|
||||
{"ipv4_32", netip.MustParsePrefix("1.2.3.4/32")},
|
||||
{"ipv6_0", netip.MustParsePrefix("::/0")},
|
||||
{"ipv6_64", netip.MustParsePrefix("2001:db8::/64")},
|
||||
{"ipv6_128", netip.MustParsePrefix("::1/128")},
|
||||
{"ipv6_full", netip.MustParsePrefix("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128")},
|
||||
{"ipv4_private", netip.MustParsePrefix("172.16.0.0/12")},
|
||||
{"ipv6_link_local", netip.MustParsePrefix("fe80::/10")},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Old write
|
||||
var oldBuf bytes.Buffer
|
||||
err := oldWritePrefix(&oldBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// New write
|
||||
var newBuf bytes.Buffer
|
||||
err = newWritePrefix(&newBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bytes must match
|
||||
require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
|
||||
"mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
|
||||
|
||||
// New write -> new read (no old read for prefix)
|
||||
readBack, err := readPrefix(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.input, readBack)
|
||||
|
||||
// Old write -> new read
|
||||
readBack2, err := readPrefix(bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.input, readBack2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPSetCompat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Note: The old writeIPSet was buggy (varbin.Write with struct values wrote nothing).
|
||||
// This test verifies the new implementation writes correct data and round-trips correctly.
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input *netipx.IPSet
|
||||
}{
|
||||
{"single_ipv4", buildIPSet("1.2.3.4")},
|
||||
{"ipv4_range", buildIPSet("192.168.0.0/16")},
|
||||
{"multi_ipv4", buildIPSet("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")},
|
||||
{"single_ipv6", buildIPSet("::1")},
|
||||
{"ipv6_range", buildIPSet("2001:db8::/32")},
|
||||
{"mixed", buildIPSet("10.0.0.0/8", "::1", "2001:db8::/32")},
|
||||
{"large", buildLargeIPSet(100)},
|
||||
{"adjacent_ranges", buildIPSet("192.168.0.0/24", "192.168.1.0/24", "192.168.2.0/24")},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// New write
|
||||
var newBuf bytes.Buffer
|
||||
err := writeIPSet(&newBuf, tc.input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify format starts with version byte (1) + uint64 count
|
||||
require.True(t, len(newBuf.Bytes()) >= 9, "output too short")
|
||||
require.Equal(t, byte(1), newBuf.Bytes()[0], "version byte mismatch")
|
||||
|
||||
// New write -> old read (varbin.Read with pre-allocated slice works correctly)
|
||||
readBack, err := oldReadIPSet(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
requireIPSetEqual(t, tc.input, readBack)
|
||||
|
||||
// New write -> new read
|
||||
readBack2, err := readIPSet(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
|
||||
require.NoError(t, err)
|
||||
requireIPSetEqual(t, tc.input, readBack2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func generateStrings(count int) []string {
|
||||
result := make([]string, count)
|
||||
for i := range result {
|
||||
result[i] = strings.Repeat("x", i%50)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func generateUint8Slice(count int) []uint8 {
|
||||
result := make([]uint8, count)
|
||||
for i := range result {
|
||||
result[i] = uint8(i % 256)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func generateUint16Slice(count int) []uint16 {
|
||||
result := make([]uint16, count)
|
||||
for i := range result {
|
||||
result[i] = uint16(i * 257)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildIPSet(cidrs ...string) *netipx.IPSet {
|
||||
var builder netipx.IPSetBuilder
|
||||
for _, cidr := range cidrs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
addr, err := netip.ParseAddr(cidr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
builder.Add(addr)
|
||||
} else {
|
||||
builder.AddPrefix(prefix)
|
||||
}
|
||||
}
|
||||
set, _ := builder.IPSet()
|
||||
return set
|
||||
}
|
||||
|
||||
func buildLargeIPSet(count int) *netipx.IPSet {
|
||||
var builder netipx.IPSetBuilder
|
||||
for i := 0; i < count; i++ {
|
||||
prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{10, byte(i / 256), byte(i % 256), 0}), 24)
|
||||
builder.AddPrefix(prefix)
|
||||
}
|
||||
set, _ := builder.IPSet()
|
||||
return set
|
||||
}
|
||||
|
||||
func requireStringSliceEqual(t *testing.T, expected, actual []string) {
|
||||
t.Helper()
|
||||
if len(expected) == 0 && len(actual) == 0 {
|
||||
return
|
||||
}
|
||||
require.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
func requireUint8SliceEqual(t *testing.T, expected, actual []uint8) {
|
||||
t.Helper()
|
||||
if len(expected) == 0 && len(actual) == 0 {
|
||||
return
|
||||
}
|
||||
require.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
func requireUint16SliceEqual(t *testing.T, expected, actual []uint16) {
|
||||
t.Helper()
|
||||
if len(expected) == 0 && len(actual) == 0 {
|
||||
return
|
||||
}
|
||||
require.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
func requireIPSetEqual(t *testing.T, expected, actual *netipx.IPSet) {
|
||||
t.Helper()
|
||||
expectedRanges := expected.Ranges()
|
||||
actualRanges := actual.Ranges()
|
||||
require.Equal(t, len(expectedRanges), len(actualRanges), "range count mismatch")
|
||||
for i := range expectedRanges {
|
||||
require.Equal(t, expectedRanges[i].From(), actualRanges[i].From(), "range[%d].from mismatch", i)
|
||||
require.Equal(t, expectedRanges[i].To(), actualRanges[i].To(), "range[%d].to mismatch", i)
|
||||
}
|
||||
}
|
||||
44
common/srs/ip_cidr.go
Normal file
44
common/srs/ip_cidr.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package srs
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net/netip"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
)
|
||||
|
||||
func readPrefix(reader varbin.Reader) (netip.Prefix, error) {
|
||||
addrLen, err := binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
addrSlice := make([]byte, addrLen)
|
||||
_, err = io.ReadFull(reader, addrSlice)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
prefixBits, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
return netip.PrefixFrom(M.AddrFromIP(addrSlice), int(prefixBits)), nil
|
||||
}
|
||||
|
||||
func writePrefix(writer varbin.Writer, prefix netip.Prefix) error {
|
||||
addrSlice := prefix.Addr().AsSlice()
|
||||
_, err := varbin.WriteUvarint(writer, uint64(len(addrSlice)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write(addrSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = writer.WriteByte(uint8(prefix.Bits()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2,11 +2,11 @@ package srs
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/varbin"
|
||||
|
||||
@@ -22,11 +22,6 @@ type myIPRange struct {
|
||||
to netip.Addr
|
||||
}
|
||||
|
||||
type myIPRangeData struct {
|
||||
From []byte
|
||||
To []byte
|
||||
}
|
||||
|
||||
func readIPSet(reader varbin.Reader) (*netipx.IPSet, error) {
|
||||
version, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
@@ -41,17 +36,30 @@ func readIPSet(reader varbin.Reader) (*netipx.IPSet, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ranges := make([]myIPRangeData, length)
|
||||
err = varbin.Read(reader, binary.BigEndian, &ranges)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mySet := &myIPSet{
|
||||
rr: make([]myIPRange, len(ranges)),
|
||||
rr: make([]myIPRange, length),
|
||||
}
|
||||
for i, rangeData := range ranges {
|
||||
mySet.rr[i].from = M.AddrFromIP(rangeData.From)
|
||||
mySet.rr[i].to = M.AddrFromIP(rangeData.To)
|
||||
for i := range mySet.rr {
|
||||
fromLen, err := binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fromBytes := make([]byte, fromLen)
|
||||
_, err = io.ReadFull(reader, fromBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toLen, err := binary.ReadUvarint(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toBytes := make([]byte, toLen)
|
||||
_, err = io.ReadFull(reader, toBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mySet.rr[i].from = M.AddrFromIP(fromBytes)
|
||||
mySet.rr[i].to = M.AddrFromIP(toBytes)
|
||||
}
|
||||
return (*netipx.IPSet)(unsafe.Pointer(mySet)), nil
|
||||
}
|
||||
@@ -61,18 +69,27 @@ func writeIPSet(writer varbin.Writer, set *netipx.IPSet) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dataList := common.Map((*myIPSet)(unsafe.Pointer(set)).rr, func(rr myIPRange) myIPRangeData {
|
||||
return myIPRangeData{
|
||||
From: rr.from.AsSlice(),
|
||||
To: rr.to.AsSlice(),
|
||||
}
|
||||
})
|
||||
err = binary.Write(writer, binary.BigEndian, uint64(len(dataList)))
|
||||
mySet := (*myIPSet)(unsafe.Pointer(set))
|
||||
err = binary.Write(writer, binary.BigEndian, uint64(len(mySet.rr)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, data := range dataList {
|
||||
err = varbin.Write(writer, binary.BigEndian, data)
|
||||
for _, rr := range mySet.rr {
|
||||
fromBytes := rr.from.AsSlice()
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(fromBytes)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write(fromBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
toBytes := rr.to.AsSlice()
|
||||
_, err = varbin.WriteUvarint(writer, uint64(len(toBytes)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = writer.Write(toBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
"github.com/libdns/acmedns"
|
||||
"github.com/libdns/alidns"
|
||||
"github.com/libdns/cloudflare"
|
||||
"github.com/mholt/acmez/v3/acme"
|
||||
@@ -114,13 +115,24 @@ func startACME(ctx context.Context, logger logger.Logger, options option.Inbound
|
||||
switch dnsOptions.Provider {
|
||||
case C.DNSProviderAliDNS:
|
||||
solver.DNSProvider = &alidns.Provider{
|
||||
AccKeyID: dnsOptions.AliDNSOptions.AccessKeyID,
|
||||
AccKeySecret: dnsOptions.AliDNSOptions.AccessKeySecret,
|
||||
RegionID: dnsOptions.AliDNSOptions.RegionID,
|
||||
CredentialInfo: alidns.CredentialInfo{
|
||||
AccessKeyID: dnsOptions.AliDNSOptions.AccessKeyID,
|
||||
AccessKeySecret: dnsOptions.AliDNSOptions.AccessKeySecret,
|
||||
RegionID: dnsOptions.AliDNSOptions.RegionID,
|
||||
SecurityToken: dnsOptions.AliDNSOptions.SecurityToken,
|
||||
},
|
||||
}
|
||||
case C.DNSProviderCloudflare:
|
||||
solver.DNSProvider = &cloudflare.Provider{
|
||||
APIToken: dnsOptions.CloudflareOptions.APIToken,
|
||||
APIToken: dnsOptions.CloudflareOptions.APIToken,
|
||||
ZoneToken: dnsOptions.CloudflareOptions.ZoneToken,
|
||||
}
|
||||
case C.DNSProviderACMEDNS:
|
||||
solver.DNSProvider = &acmedns.Provider{
|
||||
Username: dnsOptions.ACMEDNSOptions.Username,
|
||||
Password: dnsOptions.ACMEDNSOptions.Password,
|
||||
Subdomain: dnsOptions.ACMEDNSOptions.Subdomain,
|
||||
ServerURL: dnsOptions.ACMEDNSOptions.ServerURL,
|
||||
}
|
||||
default:
|
||||
return nil, nil, E.New("unsupported ACME DNS01 provider type: " + dnsOptions.Provider)
|
||||
|
||||
@@ -2,39 +2,71 @@ package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/badtls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
func NewDialerFromOptions(ctx context.Context, router adapter.Router, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
|
||||
func NewDialerFromOptions(ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
|
||||
if !options.Enabled {
|
||||
return dialer, nil
|
||||
}
|
||||
config, err := NewClient(ctx, serverAddress, options)
|
||||
config, err := NewClientWithOptions(ClientOptions{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
ServerAddress: serverAddress,
|
||||
Options: options,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewDialer(dialer, config), nil
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
if !options.Enabled {
|
||||
func NewClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
return NewClientWithOptions(ClientOptions{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
ServerAddress: serverAddress,
|
||||
Options: options,
|
||||
})
|
||||
}
|
||||
|
||||
type ClientOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.ContextLogger
|
||||
ServerAddress string
|
||||
Options option.OutboundTLSOptions
|
||||
KTLSCompatible bool
|
||||
}
|
||||
|
||||
func NewClientWithOptions(options ClientOptions) (Config, error) {
|
||||
if !options.Options.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
if options.Reality != nil && options.Reality.Enabled {
|
||||
return NewRealityClient(ctx, serverAddress, options)
|
||||
} else if options.UTLS != nil && options.UTLS.Enabled {
|
||||
return NewUTLSClient(ctx, serverAddress, options)
|
||||
if !options.KTLSCompatible {
|
||||
if options.Options.KernelTx {
|
||||
options.Logger.Warn("enabling kTLS TX in current scenarios will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_tx")
|
||||
}
|
||||
}
|
||||
return NewSTDClient(ctx, serverAddress, options)
|
||||
if options.Options.KernelRx {
|
||||
options.Logger.Warn("enabling kTLS RX will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_rx")
|
||||
}
|
||||
if options.Options.Reality != nil && options.Options.Reality.Enabled {
|
||||
return NewRealityClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
||||
} else if options.Options.UTLS != nil && options.Options.UTLS.Enabled {
|
||||
return NewUTLSClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
||||
}
|
||||
return NewSTDClient(options.Context, options.Logger, options.ServerAddress, options.Options)
|
||||
}
|
||||
|
||||
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) {
|
||||
@@ -79,10 +111,10 @@ func (d *defaultDialer) ListenPacket(ctx context.Context, destination M.Socksadd
|
||||
}
|
||||
|
||||
func (d *defaultDialer) DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error) {
|
||||
return d.dialContext(ctx, destination)
|
||||
return d.dialContext(ctx, destination, true)
|
||||
}
|
||||
|
||||
func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr) (Conn, error) {
|
||||
func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr, echRetry bool) (Conn, error) {
|
||||
conn, err := d.dialer.DialContext(ctx, N.NetworkTCP, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -90,6 +122,13 @@ func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr
|
||||
tlsConn, err := aTLS.ClientHandshake(ctx, conn, d.config)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
var echErr *tls.ECHRejectionError
|
||||
if echRetry && errors.As(err, &echErr) && len(echErr.RetryConfigList) > 0 {
|
||||
if echConfig, isECH := d.config.(ECHCapableConfig); isECH {
|
||||
echConfig.SetECHConfigList(echErr.RetryConfigList)
|
||||
return d.dialContext(ctx, destination, false)
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return tlsConn, nil
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/dns"
|
||||
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
@@ -38,7 +37,7 @@ func parseECHClientConfig(ctx context.Context, clientConfig ECHCapableConfig, op
|
||||
}
|
||||
//nolint:staticcheck
|
||||
if options.ECH.PQSignatureSchemesEnabled || options.ECH.DynamicRecordSizingDisabled {
|
||||
deprecated.Report(ctx, deprecated.OptionLegacyECHOptions)
|
||||
return nil, E.New("legacy ECH options are deprecated in sing-box 1.12.0 and removed in sing-box 1.13.0")
|
||||
}
|
||||
if len(echConfig) > 0 {
|
||||
block, rest := pem.Decode(echConfig)
|
||||
@@ -51,6 +50,7 @@ func parseECHClientConfig(ctx context.Context, clientConfig ECHCapableConfig, op
|
||||
return &ECHClientConfig{
|
||||
ECHCapableConfig: clientConfig,
|
||||
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
|
||||
queryServerName: options.ECH.QueryServerName,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions,
|
||||
tlsConfig.EncryptedClientHelloKeys = echKeys
|
||||
//nolint:staticcheck
|
||||
if options.ECH.PQSignatureSchemesEnabled || options.ECH.DynamicRecordSizingDisabled {
|
||||
deprecated.Report(ctx, deprecated.OptionLegacyECHOptions)
|
||||
return E.New("legacy ECH options are deprecated in sing-box 1.12.0 and removed in sing-box 1.13.0")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -108,10 +108,11 @@ func parseECHKeys(echKey []byte) ([]tls.EncryptedClientHelloKey, error) {
|
||||
|
||||
type ECHClientConfig struct {
|
||||
ECHCapableConfig
|
||||
access sync.Mutex
|
||||
dnsRouter adapter.DNSRouter
|
||||
lastTTL time.Duration
|
||||
lastUpdate time.Time
|
||||
access sync.Mutex
|
||||
dnsRouter adapter.DNSRouter
|
||||
queryServerName string
|
||||
lastTTL time.Duration
|
||||
lastUpdate time.Time
|
||||
}
|
||||
|
||||
func (s *ECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
|
||||
@@ -130,13 +131,17 @@ func (s *ECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn)
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
if len(s.ECHConfigList()) == 0 || s.lastTTL == 0 || time.Since(s.lastUpdate) > s.lastTTL {
|
||||
queryServerName := s.queryServerName
|
||||
if queryServerName == "" {
|
||||
queryServerName = s.ServerName()
|
||||
}
|
||||
message := &mDNS.Msg{
|
||||
MsgHdr: mDNS.MsgHdr{
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []mDNS.Question{
|
||||
{
|
||||
Name: mDNS.Fqdn(s.ServerName()),
|
||||
Name: mDNS.Fqdn(queryServerName),
|
||||
Qtype: mDNS.TypeHTTPS,
|
||||
Qclass: mDNS.ClassINET,
|
||||
},
|
||||
@@ -175,7 +180,12 @@ func (s *ECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn)
|
||||
}
|
||||
|
||||
func (s *ECHClientConfig) Clone() Config {
|
||||
return &ECHClientConfig{ECHCapableConfig: s.ECHCapableConfig.Clone().(ECHCapableConfig), dnsRouter: s.dnsRouter, lastUpdate: s.lastUpdate}
|
||||
return &ECHClientConfig{
|
||||
ECHCapableConfig: s.ECHCapableConfig.Clone().(ECHCapableConfig),
|
||||
dnsRouter: s.dnsRouter,
|
||||
queryServerName: s.queryServerName,
|
||||
lastUpdate: s.lastUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) {
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
//go:build !go1.24
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func parseECHClientConfig(ctx context.Context, clientConfig ECHCapableConfig, options option.OutboundTLSOptions) (Config, error) {
|
||||
return nil, E.New("ECH requires go1.24, please recompile your binary.")
|
||||
}
|
||||
|
||||
func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions, tlsConfig *tls.Config, echKeyPath *string) error {
|
||||
return E.New("ECH requires go1.24, please recompile your binary.")
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) setECHServerConfig(echKey []byte) error {
|
||||
panic("unreachable")
|
||||
}
|
||||
67
common/tls/ktls.go
Normal file
67
common/tls/ktls.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/common/ktls"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
type KTLSClientConfig struct {
|
||||
Config
|
||||
logger logger.ContextLogger
|
||||
kernelTx, kernelRx bool
|
||||
}
|
||||
|
||||
func (w *KTLSClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
|
||||
tlsConn, err := aTLS.ClientHandshake(ctx, conn, w.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kConn, err := ktls.NewConn(ctx, w.logger, tlsConn, w.kernelTx, w.kernelRx)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, E.Cause(err, "initialize kernel TLS")
|
||||
}
|
||||
return kConn, nil
|
||||
}
|
||||
|
||||
func (w *KTLSClientConfig) Clone() Config {
|
||||
return &KTLSClientConfig{
|
||||
w.Config.Clone(),
|
||||
w.logger,
|
||||
w.kernelTx,
|
||||
w.kernelRx,
|
||||
}
|
||||
}
|
||||
|
||||
type KTlSServerConfig struct {
|
||||
ServerConfig
|
||||
logger logger.ContextLogger
|
||||
kernelTx, kernelRx bool
|
||||
}
|
||||
|
||||
func (w *KTlSServerConfig) ServerHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
|
||||
tlsConn, err := aTLS.ServerHandshake(ctx, conn, w.ServerConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kConn, err := ktls.NewConn(ctx, w.logger, tlsConn, w.kernelTx, w.kernelRx)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, E.Cause(err, "initialize kernel TLS")
|
||||
}
|
||||
return kConn, nil
|
||||
}
|
||||
|
||||
func (w *KTlSServerConfig) Clone() Config {
|
||||
return &KTlSServerConfig{
|
||||
w.ServerConfig.Clone().(ServerConfig),
|
||||
w.logger,
|
||||
w.kernelTx,
|
||||
w.kernelRx,
|
||||
}
|
||||
}
|
||||
@@ -28,10 +28,12 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
@@ -49,12 +51,12 @@ type RealityClientConfig struct {
|
||||
shortID [8]byte
|
||||
}
|
||||
|
||||
func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*RealityClientConfig, error) {
|
||||
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
if options.UTLS == nil || !options.UTLS.Enabled {
|
||||
return nil, E.New("uTLS is required by reality client")
|
||||
}
|
||||
|
||||
uClient, err := NewUTLSClient(ctx, serverAddress, options)
|
||||
uClient, err := NewUTLSClient(ctx, logger, serverAddress, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,7 +76,20 @@ func NewRealityClient(ctx context.Context, serverAddress string, options option.
|
||||
if decodedLen > 8 {
|
||||
return nil, E.New("invalid short_id")
|
||||
}
|
||||
return &RealityClientConfig{ctx, uClient.(*UTLSClientConfig), publicKey, shortID}, nil
|
||||
|
||||
var config Config = &RealityClientConfig{ctx, uClient.(*UTLSClientConfig), publicKey, shortID}
|
||||
if options.KernelRx || options.KernelTx {
|
||||
if !C.IsLinux {
|
||||
return nil, E.New("kTLS is only supported on Linux")
|
||||
}
|
||||
config = &KTLSClientConfig{
|
||||
Config: config,
|
||||
logger: logger,
|
||||
kernelTx: options.KernelTx,
|
||||
kernelRx: options.KernelRx,
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (e *RealityClientConfig) ServerName() string {
|
||||
@@ -93,7 +108,7 @@ func (e *RealityClientConfig) SetNextProtos(nextProto []string) {
|
||||
e.uClient.SetNextProtos(nextProto)
|
||||
}
|
||||
|
||||
func (e *RealityClientConfig) Config() (*STDConfig, error) {
|
||||
func (e *RealityClientConfig) STDConfig() (*STDConfig, error) {
|
||||
return nil, E.New("unsupported usage for reality")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
@@ -28,7 +29,7 @@ type RealityServerConfig struct {
|
||||
config *utls.RealityConfig
|
||||
}
|
||||
|
||||
func NewRealityServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (*RealityServerConfig, error) {
|
||||
func NewRealityServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
var tlsConfig utls.RealityConfig
|
||||
|
||||
if options.ACME != nil && len(options.ACME.Domain) > 0 {
|
||||
@@ -67,7 +68,10 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
|
||||
return nil, E.New("unknown cipher_suite: ", cipherSuite)
|
||||
}
|
||||
}
|
||||
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
||||
if len(options.CurvePreferences) > 0 {
|
||||
return nil, E.New("curve preferences is unavailable in reality")
|
||||
}
|
||||
if len(options.Certificate) > 0 || options.CertificatePath != "" || len(options.ClientCertificatePublicKeySHA256) > 0 {
|
||||
return nil, E.New("certificate is unavailable in reality")
|
||||
}
|
||||
if len(options.Key) > 0 || options.KeyPath != "" {
|
||||
@@ -119,7 +123,22 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
|
||||
return handshakeDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
}
|
||||
|
||||
return &RealityServerConfig{&tlsConfig}, nil
|
||||
if options.ECH != nil && options.ECH.Enabled {
|
||||
return nil, E.New("Reality is conflict with ECH")
|
||||
}
|
||||
var config ServerConfig = &RealityServerConfig{&tlsConfig}
|
||||
if options.KernelTx || options.KernelRx {
|
||||
if !C.IsLinux {
|
||||
return nil, E.New("kTLS is only supported on Linux")
|
||||
}
|
||||
config = &KTlSServerConfig{
|
||||
ServerConfig: config,
|
||||
logger: logger,
|
||||
kernelTx: options.KernelTx,
|
||||
kernelRx: options.KernelRx,
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (c *RealityServerConfig) ServerName() string {
|
||||
@@ -138,7 +157,7 @@ func (c *RealityServerConfig) SetNextProtos(nextProto []string) {
|
||||
c.config.NextProtos = nextProto
|
||||
}
|
||||
|
||||
func (c *RealityServerConfig) Config() (*tls.Config, error) {
|
||||
func (c *RealityServerConfig) STDConfig() (*tls.Config, error) {
|
||||
return nil, E.New("unsupported usage for reality")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,14 +12,37 @@ import (
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
func NewServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
if !options.Enabled {
|
||||
type ServerOptions struct {
|
||||
Context context.Context
|
||||
Logger log.ContextLogger
|
||||
Options option.InboundTLSOptions
|
||||
KTLSCompatible bool
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
return NewServerWithOptions(ServerOptions{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Options: options,
|
||||
})
|
||||
}
|
||||
|
||||
func NewServerWithOptions(options ServerOptions) (ServerConfig, error) {
|
||||
if !options.Options.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
if options.Reality != nil && options.Reality.Enabled {
|
||||
return NewRealityServer(ctx, logger, options)
|
||||
if !options.KTLSCompatible {
|
||||
if options.Options.KernelTx {
|
||||
options.Logger.Warn("enabling kTLS TX in current scenarios will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_tx")
|
||||
}
|
||||
}
|
||||
return NewSTDServer(ctx, logger, options)
|
||||
if options.Options.KernelRx {
|
||||
options.Logger.Warn("enabling kTLS RX will definitely reduce performance, please checkout https://sing-box.sagernet.org/configuration/shared/tls/#kernel_rx")
|
||||
}
|
||||
if options.Options.Reality != nil && options.Options.Reality.Enabled {
|
||||
return NewRealityServer(options.Context, options.Logger, options.Options)
|
||||
}
|
||||
return NewSTDServer(options.Context, options.Logger, options.Options)
|
||||
}
|
||||
|
||||
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -11,8 +14,10 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/tlsfragment"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
)
|
||||
|
||||
@@ -40,7 +45,7 @@ func (c *STDClientConfig) SetNextProtos(nextProto []string) {
|
||||
c.config.NextProtos = nextProto
|
||||
}
|
||||
|
||||
func (c *STDClientConfig) Config() (*STDConfig, error) {
|
||||
func (c *STDClientConfig) STDConfig() (*STDConfig, error) {
|
||||
return c.config, nil
|
||||
}
|
||||
|
||||
@@ -52,7 +57,13 @@ func (c *STDClientConfig) Client(conn net.Conn) (Conn, error) {
|
||||
}
|
||||
|
||||
func (c *STDClientConfig) Clone() Config {
|
||||
return &STDClientConfig{c.ctx, c.config.Clone(), c.fragment, c.fragmentFallbackDelay, c.recordFragment}
|
||||
return &STDClientConfig{
|
||||
ctx: c.ctx,
|
||||
config: c.config.Clone(),
|
||||
fragment: c.fragment,
|
||||
fragmentFallbackDelay: c.fragmentFallbackDelay,
|
||||
recordFragment: c.recordFragment,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *STDClientConfig) ECHConfigList() []byte {
|
||||
@@ -63,7 +74,7 @@ func (c *STDClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte
|
||||
c.config.EncryptedClientHelloConfigList = EncryptedClientHelloConfigList
|
||||
}
|
||||
|
||||
func NewSTDClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
var serverName string
|
||||
if options.ServerName != "" {
|
||||
serverName = options.ServerName
|
||||
@@ -100,6 +111,15 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(options.CertificatePublicKeySHA256) > 0 {
|
||||
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
||||
return nil, E.New("certificate_public_key_sha256 is conflict with certificate or certificate_path")
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
||||
}
|
||||
}
|
||||
if len(options.ALPN) > 0 {
|
||||
tlsConfig.NextProtos = options.ALPN
|
||||
}
|
||||
@@ -129,6 +149,9 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
|
||||
return nil, E.New("unknown cipher_suite: ", cipherSuite)
|
||||
}
|
||||
}
|
||||
for _, curve := range options.CurvePreferences {
|
||||
tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curve))
|
||||
}
|
||||
var certificate []byte
|
||||
if len(options.Certificate) > 0 {
|
||||
certificate = []byte(strings.Join(options.Certificate, "\n"))
|
||||
@@ -146,10 +169,72 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
|
||||
}
|
||||
tlsConfig.RootCAs = certPool
|
||||
}
|
||||
stdConfig := &STDClientConfig{ctx, &tlsConfig, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
|
||||
if options.ECH != nil && options.ECH.Enabled {
|
||||
return parseECHClientConfig(ctx, stdConfig, options)
|
||||
} else {
|
||||
return stdConfig, nil
|
||||
var clientCertificate []byte
|
||||
if len(options.ClientCertificate) > 0 {
|
||||
clientCertificate = []byte(strings.Join(options.ClientCertificate, "\n"))
|
||||
} else if options.ClientCertificatePath != "" {
|
||||
content, err := os.ReadFile(options.ClientCertificatePath)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read client certificate")
|
||||
}
|
||||
clientCertificate = content
|
||||
}
|
||||
var clientKey []byte
|
||||
if len(options.ClientKey) > 0 {
|
||||
clientKey = []byte(strings.Join(options.ClientKey, "\n"))
|
||||
} else if options.ClientKeyPath != "" {
|
||||
content, err := os.ReadFile(options.ClientKeyPath)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read client key")
|
||||
}
|
||||
clientKey = content
|
||||
}
|
||||
if len(clientCertificate) > 0 && len(clientKey) > 0 {
|
||||
keyPair, err := tls.X509KeyPair(clientCertificate, clientKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse client x509 key pair")
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{keyPair}
|
||||
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
||||
return nil, E.New("client certificate and client key must be provided together")
|
||||
}
|
||||
var config Config = &STDClientConfig{ctx, &tlsConfig, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
|
||||
if options.ECH != nil && options.ECH.Enabled {
|
||||
var err error
|
||||
config, err = parseECHClientConfig(ctx, config.(ECHCapableConfig), options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if options.KernelRx || options.KernelTx {
|
||||
if !C.IsLinux {
|
||||
return nil, E.New("kTLS is only supported on Linux")
|
||||
}
|
||||
config = &KTLSClientConfig{
|
||||
Config: config,
|
||||
logger: logger,
|
||||
kernelTx: options.KernelTx,
|
||||
kernelRx: options.KernelRx,
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func verifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte, timeFunc func() time.Time) error {
|
||||
leafCertificate, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return E.Cause(err, "failed to parse leaf certificate")
|
||||
}
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(leafCertificate.PublicKey)
|
||||
if err != nil {
|
||||
return E.Cause(err, "failed to marshal public key")
|
||||
}
|
||||
hashValue := sha256.Sum256(pubKeyBytes)
|
||||
for _, value := range knownHashValues {
|
||||
if bytes.Equal(value, hashValue[:]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return E.New("unrecognized remote public key: ", base64.StdEncoding.EncodeToString(hashValue[:]))
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package tls
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
@@ -21,16 +23,17 @@ import (
|
||||
var errInsecureUnused = E.New("tls: insecure unused")
|
||||
|
||||
type STDServerConfig struct {
|
||||
access sync.RWMutex
|
||||
config *tls.Config
|
||||
logger log.Logger
|
||||
acmeService adapter.SimpleLifecycle
|
||||
certificate []byte
|
||||
key []byte
|
||||
certificatePath string
|
||||
keyPath string
|
||||
echKeyPath string
|
||||
watcher *fswatch.Watcher
|
||||
access sync.RWMutex
|
||||
config *tls.Config
|
||||
logger log.Logger
|
||||
acmeService adapter.SimpleLifecycle
|
||||
certificate []byte
|
||||
key []byte
|
||||
certificatePath string
|
||||
keyPath string
|
||||
clientCertificatePath []string
|
||||
echKeyPath string
|
||||
watcher *fswatch.Watcher
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) ServerName() string {
|
||||
@@ -69,7 +72,7 @@ func (c *STDServerConfig) SetNextProtos(nextProto []string) {
|
||||
c.config = config
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Config() (*STDConfig, error) {
|
||||
func (c *STDServerConfig) STDConfig() (*STDConfig, error) {
|
||||
return c.config, nil
|
||||
}
|
||||
|
||||
@@ -110,6 +113,9 @@ func (c *STDServerConfig) startWatcher() error {
|
||||
if c.echKeyPath != "" {
|
||||
watchPath = append(watchPath, c.echKeyPath)
|
||||
}
|
||||
if len(c.clientCertificatePath) > 0 {
|
||||
watchPath = append(watchPath, c.clientCertificatePath...)
|
||||
}
|
||||
if len(watchPath) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -158,6 +164,30 @@ func (c *STDServerConfig) certificateUpdated(path string) error {
|
||||
c.config = config
|
||||
c.access.Unlock()
|
||||
c.logger.Info("reloaded TLS certificate")
|
||||
} else if common.Contains(c.clientCertificatePath, path) {
|
||||
clientCertificateCA := x509.NewCertPool()
|
||||
var reloaded bool
|
||||
for _, certPath := range c.clientCertificatePath {
|
||||
content, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
c.logger.Error(E.Cause(err, "reload certificate from ", c.clientCertificatePath))
|
||||
continue
|
||||
}
|
||||
if !clientCertificateCA.AppendCertsFromPEM(content) {
|
||||
c.logger.Error(E.New("invalid client certificate file: ", certPath))
|
||||
continue
|
||||
}
|
||||
reloaded = true
|
||||
}
|
||||
if !reloaded {
|
||||
return E.New("client certificates is empty")
|
||||
}
|
||||
c.access.Lock()
|
||||
config := c.config.Clone()
|
||||
config.ClientCAs = clientCertificateCA
|
||||
c.config = config
|
||||
c.access.Unlock()
|
||||
c.logger.Info("reloaded client certificates")
|
||||
} else if path == c.echKeyPath {
|
||||
echKey, err := os.ReadFile(c.echKeyPath)
|
||||
if err != nil {
|
||||
@@ -182,7 +212,7 @@ func (c *STDServerConfig) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewSTDServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
if !options.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -234,8 +264,14 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
|
||||
return nil, E.New("unknown cipher_suite: ", cipherSuite)
|
||||
}
|
||||
}
|
||||
var certificate []byte
|
||||
var key []byte
|
||||
for _, curveID := range options.CurvePreferences {
|
||||
tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curveID))
|
||||
}
|
||||
tlsConfig.ClientAuth = tls.ClientAuthType(options.ClientAuthentication)
|
||||
var (
|
||||
certificate []byte
|
||||
key []byte
|
||||
)
|
||||
if acmeService == nil {
|
||||
if len(options.Certificate) > 0 {
|
||||
certificate = []byte(strings.Join(options.Certificate, "\n"))
|
||||
@@ -277,6 +313,43 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
|
||||
tlsConfig.Certificates = []tls.Certificate{keyPair}
|
||||
}
|
||||
}
|
||||
if len(options.ClientCertificate) > 0 || len(options.ClientCertificatePath) > 0 {
|
||||
if tlsConfig.ClientAuth == tls.NoClientCert {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
}
|
||||
if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven || tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
|
||||
if len(options.ClientCertificate) > 0 {
|
||||
clientCertificateCA := x509.NewCertPool()
|
||||
if !clientCertificateCA.AppendCertsFromPEM([]byte(strings.Join(options.ClientCertificate, "\n"))) {
|
||||
return nil, E.New("invalid client certificate strings")
|
||||
}
|
||||
tlsConfig.ClientCAs = clientCertificateCA
|
||||
} else if len(options.ClientCertificatePath) > 0 {
|
||||
clientCertificateCA := x509.NewCertPool()
|
||||
for _, path := range options.ClientCertificatePath {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read client certificate from ", path)
|
||||
}
|
||||
if !clientCertificateCA.AppendCertsFromPEM(content) {
|
||||
return nil, E.New("invalid client certificate file: ", path)
|
||||
}
|
||||
}
|
||||
tlsConfig.ClientCAs = clientCertificateCA
|
||||
} else if len(options.ClientCertificatePublicKeySHA256) > 0 {
|
||||
if tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
|
||||
tlsConfig.ClientAuth = tls.RequireAnyClientCert
|
||||
} else if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven {
|
||||
tlsConfig.ClientAuth = tls.RequestClientCert
|
||||
}
|
||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return verifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
||||
}
|
||||
} else {
|
||||
return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
|
||||
}
|
||||
}
|
||||
var echKeyPath string
|
||||
if options.ECH != nil && options.ECH.Enabled {
|
||||
err = parseECHServerConfig(ctx, options, tlsConfig, &echKeyPath)
|
||||
@@ -285,19 +358,32 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
|
||||
}
|
||||
}
|
||||
serverConfig := &STDServerConfig{
|
||||
config: tlsConfig,
|
||||
logger: logger,
|
||||
acmeService: acmeService,
|
||||
certificate: certificate,
|
||||
key: key,
|
||||
certificatePath: options.CertificatePath,
|
||||
keyPath: options.KeyPath,
|
||||
echKeyPath: echKeyPath,
|
||||
config: tlsConfig,
|
||||
logger: logger,
|
||||
acmeService: acmeService,
|
||||
certificate: certificate,
|
||||
key: key,
|
||||
certificatePath: options.CertificatePath,
|
||||
clientCertificatePath: options.ClientCertificatePath,
|
||||
keyPath: options.KeyPath,
|
||||
echKeyPath: echKeyPath,
|
||||
}
|
||||
serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
serverConfig.access.Lock()
|
||||
defer serverConfig.access.Unlock()
|
||||
return serverConfig.config, nil
|
||||
}
|
||||
return serverConfig, nil
|
||||
var config ServerConfig = serverConfig
|
||||
if options.KernelTx || options.KernelRx {
|
||||
if !C.IsLinux {
|
||||
return nil, E.New("kTLS is only supported on Linux")
|
||||
}
|
||||
config = &KTlSServerConfig{
|
||||
ServerConfig: config,
|
||||
logger: logger,
|
||||
kernelTx: options.KernelTx,
|
||||
kernelRx: options.KernelRx,
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -14,8 +14,11 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/tlsfragment"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
|
||||
utls "github.com/metacubex/utls"
|
||||
@@ -50,7 +53,7 @@ func (c *UTLSClientConfig) SetNextProtos(nextProto []string) {
|
||||
c.config.NextProtos = nextProto
|
||||
}
|
||||
|
||||
func (c *UTLSClientConfig) Config() (*STDConfig, error) {
|
||||
func (c *UTLSClientConfig) STDConfig() (*STDConfig, error) {
|
||||
return nil, E.New("unsupported usage for uTLS")
|
||||
}
|
||||
|
||||
@@ -139,7 +142,7 @@ func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error {
|
||||
return c.UConn.HandshakeContext(ctx)
|
||||
}
|
||||
|
||||
func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
var serverName string
|
||||
if options.ServerName != "" {
|
||||
serverName = options.ServerName
|
||||
@@ -164,6 +167,15 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out
|
||||
}
|
||||
tlsConfig.InsecureServerNameToVerify = serverName
|
||||
}
|
||||
if len(options.CertificatePublicKeySHA256) > 0 {
|
||||
if len(options.Certificate) > 0 || options.CertificatePath != "" {
|
||||
return nil, E.New("certificate_public_key_sha256 is conflict with certificate or certificate_path")
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
|
||||
}
|
||||
}
|
||||
if len(options.ALPN) > 0 {
|
||||
tlsConfig.NextProtos = options.ALPN
|
||||
}
|
||||
@@ -210,19 +222,61 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out
|
||||
}
|
||||
tlsConfig.RootCAs = certPool
|
||||
}
|
||||
var clientCertificate []byte
|
||||
if len(options.ClientCertificate) > 0 {
|
||||
clientCertificate = []byte(strings.Join(options.ClientCertificate, "\n"))
|
||||
} else if options.ClientCertificatePath != "" {
|
||||
content, err := os.ReadFile(options.ClientCertificatePath)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read client certificate")
|
||||
}
|
||||
clientCertificate = content
|
||||
}
|
||||
var clientKey []byte
|
||||
if len(options.ClientKey) > 0 {
|
||||
clientKey = []byte(strings.Join(options.ClientKey, "\n"))
|
||||
} else if options.ClientKeyPath != "" {
|
||||
content, err := os.ReadFile(options.ClientKeyPath)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read client key")
|
||||
}
|
||||
clientKey = content
|
||||
}
|
||||
if len(clientCertificate) > 0 && len(clientKey) > 0 {
|
||||
keyPair, err := utls.X509KeyPair(clientCertificate, clientKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse client x509 key pair")
|
||||
}
|
||||
tlsConfig.Certificates = []utls.Certificate{keyPair}
|
||||
} else if len(clientCertificate) > 0 || len(clientKey) > 0 {
|
||||
return nil, E.New("client certificate and client key must be provided together")
|
||||
}
|
||||
id, err := uTLSClientHelloID(options.UTLS.Fingerprint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
uConfig := &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
|
||||
var config Config = &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}
|
||||
if options.ECH != nil && options.ECH.Enabled {
|
||||
if options.Reality != nil && options.Reality.Enabled {
|
||||
return nil, E.New("Reality is conflict with ECH")
|
||||
}
|
||||
return parseECHClientConfig(ctx, uConfig, options)
|
||||
} else {
|
||||
return uConfig, nil
|
||||
config, err = parseECHClientConfig(ctx, config.(ECHCapableConfig), options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if (options.KernelRx || options.KernelTx) && !common.PtrValueOrDefault(options.Reality).Enabled {
|
||||
if !C.IsLinux {
|
||||
return nil, E.New("kTLS is only supported on Linux")
|
||||
}
|
||||
config = &KTLSClientConfig{
|
||||
Config: config,
|
||||
logger: logger,
|
||||
kernelTx: options.KernelTx,
|
||||
kernelRx: options.KernelRx,
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -8,13 +8,14 @@ import (
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`)
|
||||
}
|
||||
|
||||
func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
func NewRealityClient(ctx context.Context, logger logger.ContextLogger, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/observable"
|
||||
)
|
||||
|
||||
var _ adapter.URLTestHistoryStorage = (*HistoryStorage)(nil)
|
||||
@@ -22,7 +22,7 @@ var _ adapter.URLTestHistoryStorage = (*HistoryStorage)(nil)
|
||||
type HistoryStorage struct {
|
||||
access sync.RWMutex
|
||||
delayHistory map[string]*adapter.URLTestHistory
|
||||
updateHook chan<- struct{}
|
||||
updateHook *observable.Subscriber[struct{}]
|
||||
}
|
||||
|
||||
func NewHistoryStorage() *HistoryStorage {
|
||||
@@ -31,7 +31,7 @@ func NewHistoryStorage() *HistoryStorage {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HistoryStorage) SetHook(hook chan<- struct{}) {
|
||||
func (s *HistoryStorage) SetHook(hook *observable.Subscriber[struct{}]) {
|
||||
s.updateHook = hook
|
||||
}
|
||||
|
||||
@@ -61,10 +61,7 @@ func (s *HistoryStorage) StoreURLTestHistory(tag string, history *adapter.URLTes
|
||||
func (s *HistoryStorage) notifyUpdated() {
|
||||
updateHook := s.updateHook
|
||||
if updateHook != nil {
|
||||
select {
|
||||
case updateHook <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
updateHook.Emit(struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +97,7 @@ func URLTest(ctx context.Context, link string, detour N.Dialer) (t uint16, err e
|
||||
return
|
||||
}
|
||||
defer instance.Close()
|
||||
if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](instance); isEarlyConn && earlyConn.NeedHandshake() {
|
||||
if N.NeedHandshakeForWrite(instance) {
|
||||
start = time.Now()
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodHead, link, nil)
|
||||
|
||||
Reference in New Issue
Block a user