mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-26 20:29:03 +03:00
Add Snell protocol. Refactor MASQUE HTTP/2, Fair Queue. Update XHTTP, OpenVPN, Sudoku, Fallback. Fixes
This commit is contained in:
144
transport/snell/address.go
Normal file
144
transport/snell/address.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// SOCKS address types as defined in RFC 1928 section 5.
|
||||
const (
|
||||
atypIPv4 = 1
|
||||
atypDomainName = 3
|
||||
atypIPv6 = 4
|
||||
)
|
||||
|
||||
// socksAddr represents a SOCKS address as defined in RFC 1928 section 5.
|
||||
type socksAddr []byte
|
||||
|
||||
func (a socksAddr) String() string {
|
||||
var host, port string
|
||||
switch a[0] {
|
||||
case atypDomainName:
|
||||
hostLen := uint16(a[1])
|
||||
host = string(a[2 : 2+hostLen])
|
||||
port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
|
||||
case atypIPv4:
|
||||
host = net.IP(a[1 : 1+net.IPv4len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
|
||||
case atypIPv6:
|
||||
host = net.IP(a[1 : 1+net.IPv6len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// UDPAddr converts a socksAddr to *net.UDPAddr.
|
||||
func (a socksAddr) UDPAddr() *net.UDPAddr {
|
||||
if len(a) == 0 {
|
||||
return nil
|
||||
}
|
||||
switch a[0] {
|
||||
case atypIPv4:
|
||||
var ip [net.IPv4len]byte
|
||||
copy(ip[0:], a[1:1+net.IPv4len])
|
||||
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
|
||||
case atypIPv6:
|
||||
var ip [net.IPv6len]byte
|
||||
copy(ip[0:], a[1:1+net.IPv6len])
|
||||
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitSocksAddr slices a SOCKS address from beginning of b. Returns nil if failed.
|
||||
func splitSocksAddr(b []byte) socksAddr {
|
||||
addrLen := 1
|
||||
if len(b) < addrLen {
|
||||
return nil
|
||||
}
|
||||
switch b[0] {
|
||||
case atypDomainName:
|
||||
if len(b) < 2 {
|
||||
return nil
|
||||
}
|
||||
addrLen = 1 + 1 + int(b[1]) + 2
|
||||
case atypIPv4:
|
||||
addrLen = 1 + net.IPv4len + 2
|
||||
case atypIPv6:
|
||||
addrLen = 1 + net.IPv6len + 2
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
if len(b) < addrLen {
|
||||
return nil
|
||||
}
|
||||
return b[:addrLen]
|
||||
}
|
||||
|
||||
// parseAddr parses the address in string s. Returns nil if failed.
|
||||
func parseAddr(s string) socksAddr {
|
||||
var addr socksAddr
|
||||
host, port, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
addr = make([]byte, 1+net.IPv4len+2)
|
||||
addr[0] = atypIPv4
|
||||
copy(addr[1:], ip4)
|
||||
} else {
|
||||
addr = make([]byte, 1+net.IPv6len+2)
|
||||
addr[0] = atypIPv6
|
||||
copy(addr[1:], ip)
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
return nil
|
||||
}
|
||||
addr = make([]byte, 1+1+len(host)+2)
|
||||
addr[0] = atypDomainName
|
||||
addr[1] = byte(len(host))
|
||||
copy(addr[2:], host)
|
||||
}
|
||||
portnum, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
|
||||
return addr
|
||||
}
|
||||
|
||||
// parseAddrToSocksAddr parses a socks addr from net.Addr.
|
||||
// This is a fast path of parseAddr(addr.String()).
|
||||
func parseAddrToSocksAddr(addr net.Addr) socksAddr {
|
||||
var hostip net.IP
|
||||
var port int
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
hostip = addr.IP
|
||||
port = addr.Port
|
||||
case *net.TCPAddr:
|
||||
hostip = addr.IP
|
||||
port = addr.Port
|
||||
case nil:
|
||||
return nil
|
||||
}
|
||||
if hostip == nil {
|
||||
return parseAddr(addr.String())
|
||||
}
|
||||
var parsed socksAddr
|
||||
if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
|
||||
parsed = make([]byte, 1+net.IPv4len+2)
|
||||
parsed[0] = atypIPv4
|
||||
copy(parsed[1:], ip4)
|
||||
binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
|
||||
} else {
|
||||
parsed = make([]byte, 1+net.IPv6len+2)
|
||||
parsed[0] = atypIPv6
|
||||
copy(parsed[1:], hostip)
|
||||
binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
56
transport/snell/cipher.go
Normal file
56
transport/snell/cipher.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// NewAES128GCM returns the AES-128-GCM cipher used by snell v2/v3.
|
||||
func NewAES128GCM(psk []byte) Cipher {
|
||||
return &snellCipher{
|
||||
psk: psk,
|
||||
keySize: 16,
|
||||
makeAEAD: aesGCM,
|
||||
}
|
||||
}
|
||||
|
||||
// NewChacha20Poly1305 returns the ChaCha20-Poly1305 cipher used by snell v1.
|
||||
func NewChacha20Poly1305(psk []byte) Cipher {
|
||||
return &snellCipher{
|
||||
psk: psk,
|
||||
keySize: 32,
|
||||
makeAEAD: chacha20poly1305.New,
|
||||
}
|
||||
}
|
||||
|
||||
type snellCipher struct {
|
||||
psk []byte
|
||||
keySize int
|
||||
makeAEAD func(key []byte) (cipher.AEAD, error)
|
||||
}
|
||||
|
||||
func (sc *snellCipher) KeySize() int { return sc.keySize }
|
||||
func (sc *snellCipher) SaltSize() int { return 16 }
|
||||
|
||||
func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) {
|
||||
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
|
||||
}
|
||||
|
||||
func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) {
|
||||
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
|
||||
}
|
||||
|
||||
func snellKDF(psk, salt []byte, keySize int) []byte {
|
||||
return argon2.IDKey(psk, salt, 3, 8, 1, 32)[:keySize]
|
||||
}
|
||||
|
||||
func aesGCM(key []byte) (cipher.AEAD, error) {
|
||||
blk, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cipher.NewGCM(blk)
|
||||
}
|
||||
120
transport/snell/client.go
Normal file
120
transport/snell/client.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type ClientOptions struct {
|
||||
Dialer N.Dialer
|
||||
Server M.Socksaddr
|
||||
PSK []byte
|
||||
Version int
|
||||
Reuse bool
|
||||
ObfsMode string
|
||||
ObfsHost string
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
dialer N.Dialer
|
||||
server M.Socksaddr
|
||||
psk []byte
|
||||
version int
|
||||
reuse bool
|
||||
obfsMode string
|
||||
obfsHost string
|
||||
pool *Pool
|
||||
}
|
||||
|
||||
func NewClient(options ClientOptions) *Client {
|
||||
c := &Client{
|
||||
dialer: options.Dialer,
|
||||
server: options.Server,
|
||||
psk: options.PSK,
|
||||
version: options.Version,
|
||||
reuse: options.Reuse,
|
||||
obfsMode: options.ObfsMode,
|
||||
obfsHost: options.ObfsHost,
|
||||
}
|
||||
if c.reuse {
|
||||
c.pool = NewPool(func(ctx context.Context) (*Snell, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.streamConn(conn), nil
|
||||
})
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) streamConn(conn net.Conn) *Snell {
|
||||
switch c.obfsMode {
|
||||
case "tls":
|
||||
conn = obfs.NewTLSObfs(conn, c.obfsHost)
|
||||
case "http":
|
||||
conn = obfs.NewHTTPObfs(conn, c.obfsHost, strconv.Itoa(int(c.server.Port)))
|
||||
}
|
||||
return StreamConn(conn, c.psk, c.version)
|
||||
}
|
||||
|
||||
func (c *Client) writeHeader(ctx context.Context, conn net.Conn, destination M.Socksaddr, udp bool) (err error) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = conn.SetWriteDeadline(deadline)
|
||||
defer conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
if udp {
|
||||
err = WriteUDPHeader(conn, c.version)
|
||||
if err == nil && c.version >= Version4 {
|
||||
if sc, ok := conn.(*Snell); ok {
|
||||
err = sc.ReadReply()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
err = WriteHeaderWithReuse(conn, destination.AddrString(), uint(destination.Port), c.version, c.reuse)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) DialContext(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
||||
if c.reuse {
|
||||
conn, err := c.pool.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = c.writeHeader(ctx, conn, destination, false); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream := c.streamConn(conn)
|
||||
if err = c.writeHeader(ctx, stream, destination, false); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream := c.streamConn(conn)
|
||||
if err = c.writeHeader(ctx, stream, destination, true); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return PacketConn(stream), nil
|
||||
}
|
||||
153
transport/snell/pool.go
Normal file
153
transport/snell/pool.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// poolEntry holds a pooled item with its insertion time.
|
||||
|
||||
// connPool is a small connection pool with age-based eviction.
|
||||
|
||||
// milliseconds
|
||||
|
||||
// Pool is a pool of reusable snell connections.
|
||||
type Pool struct {
|
||||
pool *connPool
|
||||
}
|
||||
|
||||
func (p *Pool) Get() (net.Conn, error) {
|
||||
return p.GetContext(context.Background())
|
||||
}
|
||||
|
||||
func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
|
||||
elm, err := p.pool.GetContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PoolConn{Snell: elm, pool: p}, nil
|
||||
}
|
||||
|
||||
func (p *Pool) Put(conn *Snell) {
|
||||
if err := HalfClose(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
p.pool.put(conn)
|
||||
}
|
||||
|
||||
// PoolConn wraps a pooled snell connection and returns it to the pool on Close.
|
||||
type PoolConn struct {
|
||||
*Snell
|
||||
pool *Pool
|
||||
closeWriteOnce sync.Once
|
||||
closeWriteErr error
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Read(b []byte) (int, error) {
|
||||
n, err := pc.Snell.Read(b)
|
||||
if err == ErrZeroChunk {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Write(b []byte) (int, error) {
|
||||
return pc.Snell.Write(b)
|
||||
}
|
||||
|
||||
func (pc *PoolConn) CloseWrite() error {
|
||||
pc.closeWriteOnce.Do(func() {
|
||||
pc.closeWriteErr = writeZeroChunk(pc.Snell)
|
||||
})
|
||||
return pc.closeWriteErr
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Close() error {
|
||||
pc.closeOnce.Do(func() {
|
||||
if err := pc.CloseWrite(); err != nil {
|
||||
pc.closeErr = err
|
||||
_ = pc.Snell.Close()
|
||||
return
|
||||
}
|
||||
_ = pc.Snell.Conn.SetReadDeadline(time.Time{})
|
||||
pc.Snell.reply = false
|
||||
pc.pool.pool.put(pc.Snell)
|
||||
})
|
||||
return pc.closeErr
|
||||
}
|
||||
|
||||
// NewPool creates a new snell connection pool using the given factory.
|
||||
func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
|
||||
cp := &connPool{
|
||||
ch: make(chan *poolEntry, 10),
|
||||
factory: factory,
|
||||
maxAge: 15000,
|
||||
evict: func(item *Snell) {
|
||||
_ = item.Close()
|
||||
},
|
||||
}
|
||||
p := &Pool{pool: cp}
|
||||
runtime.SetFinalizer(p, recycle)
|
||||
return p
|
||||
}
|
||||
|
||||
type poolEntry struct {
|
||||
elm *Snell
|
||||
time time.Time
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
ch chan *poolEntry
|
||||
factory func(context.Context) (*Snell, error)
|
||||
evict func(*Snell)
|
||||
maxAge int64
|
||||
}
|
||||
|
||||
func (p *connPool) GetContext(ctx context.Context) (*Snell, error) {
|
||||
now := time.Now()
|
||||
for {
|
||||
select {
|
||||
case item := <-p.ch:
|
||||
if p.maxAge != 0 && now.Sub(item.time).Milliseconds() > p.maxAge {
|
||||
if p.evict != nil {
|
||||
p.evict(item.elm)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return item.elm, nil
|
||||
default:
|
||||
return p.factory(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *connPool) put(item *Snell) {
|
||||
e := &poolEntry{
|
||||
elm: item,
|
||||
time: time.Now(),
|
||||
}
|
||||
select {
|
||||
case p.ch <- e:
|
||||
return
|
||||
default:
|
||||
if p.evict != nil {
|
||||
p.evict(item)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func recycle(p *Pool) {
|
||||
for item := range p.pool.ch {
|
||||
if p.pool.evict != nil {
|
||||
p.pool.evict(item.elm)
|
||||
}
|
||||
}
|
||||
}
|
||||
294
transport/snell/service.go
Normal file
294
transport/snell/service.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string)
|
||||
|
||||
NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string)
|
||||
}
|
||||
|
||||
type ServiceOptions struct {
|
||||
PSK []byte
|
||||
Version int
|
||||
ObfsMode string
|
||||
UDP bool
|
||||
Logger logger.ContextLogger
|
||||
Handler Handler
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
psk []byte
|
||||
version int
|
||||
obfsMode string
|
||||
udp bool
|
||||
logger logger.ContextLogger
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewService(options ServiceOptions) (*Service, error) {
|
||||
version := options.Version
|
||||
if version == 0 {
|
||||
version = Version4
|
||||
}
|
||||
if version != Version4 && version != Version5 {
|
||||
return nil, fmt.Errorf("snell inbound version %d is not supported", version)
|
||||
}
|
||||
if len(options.PSK) == 0 {
|
||||
return nil, errors.New("snell inbound requires psk")
|
||||
}
|
||||
switch options.ObfsMode {
|
||||
case "", "http", "tls":
|
||||
default:
|
||||
return nil, fmt.Errorf("snell inbound obfs mode error: %s", options.ObfsMode)
|
||||
}
|
||||
return &Service{
|
||||
psk: options.PSK,
|
||||
version: version,
|
||||
obfsMode: options.ObfsMode,
|
||||
udp: options.UDP,
|
||||
logger: options.Logger,
|
||||
handler: options.Handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(ctx context.Context, rawConn net.Conn, source M.Socksaddr) error {
|
||||
conn := rawConn
|
||||
switch s.obfsMode {
|
||||
case "http":
|
||||
conn = obfs.NewHTTPObfsServer(conn)
|
||||
case "tls":
|
||||
conn = obfs.NewTLSObfsServer(conn)
|
||||
}
|
||||
stream := ServerStreamConn(conn, s.psk, s.version)
|
||||
for {
|
||||
reuse, err := s.handleRequest(ctx, stream, source)
|
||||
if err != nil || !reuse {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleRequest(ctx context.Context, stream *Snell, source M.Socksaddr) (bool, error) {
|
||||
br := bufio.NewReader(stream)
|
||||
version, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if version != Version {
|
||||
return false, fmt.Errorf("snell invalid protocol version: %d", version)
|
||||
}
|
||||
command, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if command == CommandPing {
|
||||
_, _ = stream.Write([]byte{CommandPong})
|
||||
return false, nil
|
||||
}
|
||||
clientID, err := readClientID(br)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
switch command {
|
||||
case CommandConnect, CommandConnectV2:
|
||||
return s.handleTCP(ctx, stream, br, command == CommandConnectV2, clientID, source)
|
||||
case CommandUDP:
|
||||
if !s.udp {
|
||||
return false, errors.New("snell UDP is disabled")
|
||||
}
|
||||
return false, s.handleUDP(ctx, stream, clientID, source)
|
||||
default:
|
||||
return false, fmt.Errorf("snell unknown command: %d", command)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleTCP(ctx context.Context, stream *Snell, br *bufio.Reader, reuse bool, clientID string, source M.Socksaddr) (bool, error) {
|
||||
hostLen, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if hostLen == 0 {
|
||||
return false, errors.New("snell connect host is empty")
|
||||
}
|
||||
hostBytes := make([]byte, int(hostLen))
|
||||
if _, err := io.ReadFull(br, hostBytes); err != nil {
|
||||
return false, err
|
||||
}
|
||||
var portBytes [2]byte
|
||||
if _, err := io.ReadFull(br, portBytes[:]); err != nil {
|
||||
return false, err
|
||||
}
|
||||
destination := M.ParseSocksaddrHostPort(string(hostBytes), binary.BigEndian.Uint16(portBytes[:]))
|
||||
conn := &tcpRequestConn{
|
||||
Conn: stream,
|
||||
reader: br,
|
||||
reuse: reuse,
|
||||
}
|
||||
s.handler.NewConnection(ctx, conn, source, destination, clientID)
|
||||
if !reuse {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *Service) handleUDP(ctx context.Context, stream *Snell, clientID string, source M.Socksaddr) error {
|
||||
if _, err := stream.Write([]byte{CommandTunnel}); err != nil {
|
||||
return err
|
||||
}
|
||||
pc := &serverPacketConn{
|
||||
conn: stream,
|
||||
writeMu: &sync.Mutex{},
|
||||
}
|
||||
s.handler.NewPacketConnection(ctx, pc, source, clientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxPacketLength = 0x3fff
|
||||
|
||||
func readClientID(r *bufio.Reader) (string, error) {
|
||||
length, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
id := make([]byte, int(length))
|
||||
if _, err := io.ReadFull(r, id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(id), nil
|
||||
}
|
||||
|
||||
func writeCommandError(w io.Writer, code byte, message string) error {
|
||||
msg := []byte(message)
|
||||
if len(msg) > 255 {
|
||||
msg = msg[:255]
|
||||
}
|
||||
buf := make([]byte, 0, 3+len(msg))
|
||||
buf = append(buf, CommandError, code, byte(len(msg)))
|
||||
buf = append(buf, msg...)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
type tcpRequestConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
reuse bool
|
||||
writeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
replyWritten bool
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Read(p []byte) (int, error) {
|
||||
n, err := c.reader.Read(p)
|
||||
if errors.Is(err, ErrZeroChunk) {
|
||||
err = io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Write(p []byte) (int, error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if !c.replyWritten {
|
||||
payload := make([]byte, 1+len(p))
|
||||
payload[0] = CommandTunnel
|
||||
copy(payload[1:], p)
|
||||
if _, err := c.Conn.Write(payload); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.replyWritten = true
|
||||
return len(p), nil
|
||||
}
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) CloseWrite() error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Close() error {
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if !c.replyWritten {
|
||||
err = writeCommandError(c.Conn, 0x65, "Remote EOF")
|
||||
if !c.reuse {
|
||||
err = errors.Join(err, c.Conn.Close())
|
||||
}
|
||||
return
|
||||
}
|
||||
if c.reuse {
|
||||
_, err = c.Conn.Write(nil)
|
||||
return
|
||||
}
|
||||
err = c.Conn.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
type serverPacketConn struct {
|
||||
conn *Snell
|
||||
writeMu *sync.Mutex
|
||||
readBuf []byte
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
if c.readBuf == nil {
|
||||
c.readBuf = make([]byte, maxPacketLength)
|
||||
}
|
||||
for {
|
||||
n, err := c.conn.Read(c.readBuf)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, ErrZeroChunk) {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
request, err := ParseUDPRequest(c.readBuf[:n])
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
if request.Ip.IsValid() {
|
||||
destination = M.SocksaddrFrom(request.Ip, request.Port)
|
||||
} else {
|
||||
destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
|
||||
}
|
||||
length := copy(p, request.Payload)
|
||||
if destination.IsFqdn() {
|
||||
return length, destination, nil
|
||||
}
|
||||
return length, destination.UDPAddr(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
return WritePacketResponse(c.conn, addr, p)
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Close() error { return c.conn.Close() }
|
||||
func (c *serverPacketConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||
func (c *serverPacketConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||||
func (c *serverPacketConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||||
func (c *serverPacketConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||||
211
transport/snell/shadowaead.go
Normal file
211
transport/snell/shadowaead.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
// payloadSizeMask is the maximum size of payload in bytes.
|
||||
// 16*1024 - 1
|
||||
// >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead()
|
||||
|
||||
// ErrZeroChunk is returned when a zero-length chunk is read, which snell uses
|
||||
// as an end-of-stream signal.
|
||||
var ErrZeroChunk = errors.New("zero chunk")
|
||||
|
||||
// Cipher is the AEAD cipher abstraction used by the shadowaead stream.
|
||||
type Cipher interface {
|
||||
KeySize() int
|
||||
SaltSize() int
|
||||
Encrypter(salt []byte) (cipher.AEAD, error)
|
||||
Decrypter(salt []byte) (cipher.AEAD, error)
|
||||
}
|
||||
|
||||
const (
|
||||
payloadSizeMask = 0x3FFF
|
||||
bufSize = 17 * 1024
|
||||
)
|
||||
|
||||
type aeadWriter struct {
|
||||
io.Writer
|
||||
cipher.AEAD
|
||||
nonce [32]byte // should be sufficient for most nonce sizes
|
||||
}
|
||||
|
||||
// newAEADWriter wraps an io.Writer with authenticated encryption.
|
||||
func newAEADWriter(w io.Writer, aead cipher.AEAD) *aeadWriter {
|
||||
return &aeadWriter{Writer: w, AEAD: aead}
|
||||
}
|
||||
|
||||
// Write encrypts p and writes to the embedded io.Writer.
|
||||
func (w *aeadWriter) Write(p []byte) (n int, err error) {
|
||||
b := buf.Get(bufSize)
|
||||
defer buf.Put(b)
|
||||
nonce := w.nonce[:w.NonceSize()]
|
||||
tag := w.Overhead()
|
||||
off := 2 + tag
|
||||
if len(p) == 0 {
|
||||
b = b[:off]
|
||||
b[0], b[1] = byte(0), byte(0)
|
||||
w.Seal(b[:0], nonce, b[:2], nil)
|
||||
increment(nonce)
|
||||
_, err = w.Writer.Write(b)
|
||||
return
|
||||
}
|
||||
for nr := 0; n < len(p) && err == nil; n += nr {
|
||||
nr = payloadSizeMask
|
||||
if n+nr > len(p) {
|
||||
nr = len(p) - n
|
||||
}
|
||||
b = b[:off+nr+tag]
|
||||
b[0], b[1] = byte(nr>>8), byte(nr)
|
||||
w.Seal(b[:0], nonce, b[:2], nil)
|
||||
increment(nonce)
|
||||
w.Seal(b[:off], nonce, p[n:n+nr], nil)
|
||||
increment(nonce)
|
||||
_, err = w.Writer.Write(b)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type aeadReader struct {
|
||||
io.Reader
|
||||
cipher.AEAD
|
||||
nonce [32]byte // should be sufficient for most nonce sizes
|
||||
buf []byte // to be put back into bufPool
|
||||
off int // offset to unconsumed part of buf
|
||||
}
|
||||
|
||||
// newAEADReader wraps an io.Reader with authenticated decryption.
|
||||
func newAEADReader(r io.Reader, aead cipher.AEAD) *aeadReader {
|
||||
return &aeadReader{Reader: r, AEAD: aead}
|
||||
}
|
||||
|
||||
// read and decrypt a record into p. len(p) >= max payload size + AEAD overhead.
|
||||
func (r *aeadReader) read(p []byte) (int, error) {
|
||||
nonce := r.nonce[:r.NonceSize()]
|
||||
tag := r.Overhead()
|
||||
p = p[:2+tag]
|
||||
if _, err := io.ReadFull(r.Reader, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err := r.Open(p[:0], nonce, p, nil)
|
||||
increment(nonce)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask
|
||||
if size == 0 {
|
||||
return 0, ErrZeroChunk
|
||||
}
|
||||
p = p[:size+tag]
|
||||
if _, err := io.ReadFull(r.Reader, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = r.Open(p[:0], nonce, p, nil)
|
||||
increment(nonce)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// Read reads from the embedded io.Reader, decrypts and writes to p.
|
||||
func (r *aeadReader) Read(p []byte) (int, error) {
|
||||
if r.buf == nil {
|
||||
if len(p) >= payloadSizeMask+r.Overhead() {
|
||||
return r.read(p)
|
||||
}
|
||||
b := buf.Get(bufSize)
|
||||
n, err := r.read(b)
|
||||
if err != nil {
|
||||
buf.Put(b)
|
||||
return 0, err
|
||||
}
|
||||
r.buf = b[:n]
|
||||
r.off = 0
|
||||
}
|
||||
n := copy(p, r.buf[r.off:])
|
||||
r.off += n
|
||||
if r.off == len(r.buf) {
|
||||
buf.Put(r.buf[:cap(r.buf)])
|
||||
r.buf = nil
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// increment little-endian encoded unsigned integer b. Wrap around on overflow.
|
||||
func increment(b []byte) {
|
||||
for i := range b {
|
||||
b[i]++
|
||||
if b[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// aeadConn wraps a stream-oriented net.Conn with the shadowaead cipher.
|
||||
type aeadConn struct {
|
||||
net.Conn
|
||||
Cipher
|
||||
r *aeadReader
|
||||
w *aeadWriter
|
||||
}
|
||||
|
||||
// newAEADConn wraps a stream-oriented net.Conn with cipher.
|
||||
func newAEADConn(c net.Conn, ciph Cipher) *aeadConn {
|
||||
return &aeadConn{Conn: c, Cipher: ciph}
|
||||
}
|
||||
|
||||
func (c *aeadConn) initReader() error {
|
||||
salt := make([]byte, c.SaltSize())
|
||||
if _, err := io.ReadFull(c.Conn, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := c.Decrypter(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.r = newAEADReader(c.Conn, aead)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Read(b []byte) (int, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *aeadConn) initWriter() error {
|
||||
salt := make([]byte, c.SaltSize())
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := c.Encrypter(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.Conn.Write(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.w = newAEADWriter(c.Conn, aead)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Write(b []byte) (int, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.w.Write(b)
|
||||
}
|
||||
408
transport/snell/snell.go
Normal file
408
transport/snell/snell.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
const (
|
||||
Version1 = 1
|
||||
Version2 = 2
|
||||
Version3 = 3
|
||||
Version4 = 4
|
||||
Version5 = 5
|
||||
DefaultSnellVersion = Version1
|
||||
|
||||
// max packet length
|
||||
maxLength = 0x3FFF
|
||||
)
|
||||
|
||||
const (
|
||||
CommandPing byte = 0
|
||||
CommandConnect byte = 1
|
||||
CommandConnectV2 byte = 5
|
||||
CommandUDP byte = 6
|
||||
CommandUDPForward byte = 1
|
||||
|
||||
CommandTunnel byte = 0
|
||||
CommandPong byte = 1
|
||||
CommandError byte = 2
|
||||
|
||||
Version byte = 1
|
||||
)
|
||||
|
||||
// Snell wraps an encrypted stream and handles the snell reply header.
|
||||
type Snell struct {
|
||||
net.Conn
|
||||
buffer [1]byte
|
||||
reply bool
|
||||
}
|
||||
|
||||
func (s *Snell) Read(b []byte) (int, error) {
|
||||
if err := s.ReadReply(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (s *Snell) ReadReply() error {
|
||||
if s.reply {
|
||||
return nil
|
||||
}
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
s.reply = true
|
||||
if s.buffer[0] == CommandTunnel {
|
||||
return nil
|
||||
} else if s.buffer[0] != CommandError {
|
||||
return errors.New("command not support")
|
||||
}
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
errcode := int(s.buffer[0])
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
length := int(s.buffer[0])
|
||||
msg := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.Conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
|
||||
}
|
||||
|
||||
func WriteHeader(conn net.Conn, host string, port uint, version int) error {
|
||||
return WriteHeaderWithReuse(conn, host, port, version, false)
|
||||
}
|
||||
|
||||
func WriteHeaderWithReuse(conn net.Conn, host string, port uint, version int, reuse bool) error {
|
||||
buffer := &bytes.Buffer{}
|
||||
buffer.WriteByte(Version)
|
||||
if version == Version2 || reuse {
|
||||
buffer.WriteByte(CommandConnectV2)
|
||||
} else {
|
||||
buffer.WriteByte(CommandConnect)
|
||||
}
|
||||
buffer.WriteByte(0)
|
||||
buffer.WriteByte(uint8(len(host)))
|
||||
buffer.WriteString(host)
|
||||
binary.Write(buffer, binary.BigEndian, uint16(port))
|
||||
if _, err := conn.Write(buffer.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WriteUDPHeader(conn net.Conn, version int) error {
|
||||
if version < Version3 {
|
||||
return errors.New("unsupport UDP version")
|
||||
}
|
||||
_, err := conn.Write([]byte{Version, CommandUDP, 0x00})
|
||||
return err
|
||||
}
|
||||
|
||||
// HalfClose only works after the request negotiated the reuse command.
|
||||
func HalfClose(conn net.Conn) error {
|
||||
if err := writeZeroChunk(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if s, ok := conn.(*Snell); ok {
|
||||
s.reply = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamConn wraps a raw connection with the snell stream cipher for the given version.
|
||||
func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
|
||||
if version >= Version4 {
|
||||
return &Snell{Conn: newV4Conn(conn, psk)}
|
||||
}
|
||||
var cipher Cipher
|
||||
if version != Version1 {
|
||||
cipher = NewAES128GCM(psk)
|
||||
} else {
|
||||
cipher = NewChacha20Poly1305(psk)
|
||||
}
|
||||
return &Snell{Conn: newAEADConn(conn, cipher)}
|
||||
}
|
||||
|
||||
// ServerStreamConn wraps a raw connection on the server side.
|
||||
func ServerStreamConn(conn net.Conn, psk []byte, version int) *Snell {
|
||||
stream := StreamConn(conn, psk, version)
|
||||
stream.reply = true
|
||||
return stream
|
||||
}
|
||||
|
||||
func PacketConn(conn net.Conn) net.PacketConn {
|
||||
return &packetConn{
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Snell) WritePacketFrame(b []byte) (int, error) {
|
||||
if fw, ok := s.Conn.(packetFrameWriter); ok {
|
||||
return fw.WritePacketFrame(b)
|
||||
}
|
||||
return s.Conn.Write(b)
|
||||
}
|
||||
|
||||
func WritePacket(w io.Writer, target, payload []byte) (int, error) {
|
||||
maxPayloadLength := maxLength - udpRequestHeaderLength(target)
|
||||
if maxPayloadLength <= 0 {
|
||||
return 0, errors.New("snell UDP address too large")
|
||||
}
|
||||
if len(payload) <= maxPayloadLength {
|
||||
return writePacket(w, target, payload)
|
||||
}
|
||||
return 0, errors.New("snell UDP payload too large")
|
||||
}
|
||||
|
||||
func WritePacketResponse(w io.Writer, addr net.Addr, payload []byte) (int, error) {
|
||||
buffer := &bytes.Buffer{}
|
||||
target := parseAddrToSocksAddr(addr)
|
||||
if len(target) == 0 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
switch target[0] {
|
||||
case atypIPv4:
|
||||
if len(target) < 1+net.IPv4len+2 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.WriteByte(0x04)
|
||||
buffer.Write(target[1 : 1+net.IPv4len+2])
|
||||
case atypIPv6:
|
||||
if len(target) < 1+net.IPv6len+2 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.WriteByte(0x06)
|
||||
buffer.Write(target[1 : 1+net.IPv6len+2])
|
||||
default:
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.Write(payload)
|
||||
var err error
|
||||
if fw, ok := w.(packetFrameWriter); ok {
|
||||
_, err = fw.WritePacketFrame(buffer.Bytes())
|
||||
} else {
|
||||
_, err = w.Write(buffer.Bytes())
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
// UDPRequest is a parsed snell UDP forward request.
|
||||
type UDPRequest struct {
|
||||
Host string
|
||||
Ip netip.Addr
|
||||
Port uint16
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func ParseUDPRequest(packet []byte) (UDPRequest, error) {
|
||||
if len(packet) < 2 || packet[0] != CommandUDPForward {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP request")
|
||||
}
|
||||
if hostLen := int(packet[1]); hostLen != 0 {
|
||||
if len(packet) <= 2+hostLen+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP domain request")
|
||||
}
|
||||
offset := 2 + hostLen
|
||||
return UDPRequest{
|
||||
Host: string(packet[2:offset]),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
}
|
||||
if len(packet) < 3 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IP request")
|
||||
}
|
||||
switch packet[2] {
|
||||
case 0x04:
|
||||
if len(packet) < 3+net.IPv4len+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IPv4 request")
|
||||
}
|
||||
offset := 3 + net.IPv4len
|
||||
ip, _ := netip.AddrFromSlice(packet[3:offset])
|
||||
return UDPRequest{
|
||||
Ip: ip.Unmap(),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
case 0x06:
|
||||
if len(packet) < 3+net.IPv6len+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IPv6 request")
|
||||
}
|
||||
offset := 3 + net.IPv6len
|
||||
ip, _ := netip.AddrFromSlice(packet[3:offset])
|
||||
return UDPRequest{
|
||||
Ip: ip.Unmap(),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
default:
|
||||
return UDPRequest{}, errors.New("snell invalid UDP address type")
|
||||
}
|
||||
}
|
||||
|
||||
func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) {
|
||||
b := buf.Get(buf.UDPBufferSize)
|
||||
defer buf.Put(b)
|
||||
n, err := r.Read(b)
|
||||
headLen := 1
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if n < headLen {
|
||||
return nil, 0, errors.New("insufficient UDP length")
|
||||
}
|
||||
switch b[0] {
|
||||
case 0x04:
|
||||
headLen += net.IPv4len + 2
|
||||
if n < headLen {
|
||||
err = errors.New("insufficient UDP length")
|
||||
break
|
||||
}
|
||||
b[0] = atypIPv4
|
||||
case 0x06:
|
||||
headLen += net.IPv6len + 2
|
||||
if n < headLen {
|
||||
err = errors.New("insufficient UDP length")
|
||||
break
|
||||
}
|
||||
b[0] = atypIPv6
|
||||
default:
|
||||
err = errors.New("ip version invalid")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
addr := splitSocksAddr(b[0:])
|
||||
if addr == nil {
|
||||
return nil, 0, errors.New("remote address invalid")
|
||||
}
|
||||
uAddr := addr.UDPAddr()
|
||||
if uAddr == nil {
|
||||
return nil, 0, errors.New("parse addr error")
|
||||
}
|
||||
length := len(payload)
|
||||
if n-headLen < length {
|
||||
length = n - headLen
|
||||
}
|
||||
copy(payload[:], b[headLen:headLen+length])
|
||||
return uAddr, length, nil
|
||||
}
|
||||
|
||||
var endSignal = []byte{}
|
||||
|
||||
type packetFrameWriter interface {
|
||||
WritePacketFrame([]byte) (int, error)
|
||||
}
|
||||
|
||||
func writeZeroChunk(conn net.Conn) error {
|
||||
if _, err := conn.Write(endSignal); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePacket(w io.Writer, target, payload []byte) (int, error) {
|
||||
buffer := &bytes.Buffer{}
|
||||
buffer.WriteByte(CommandUDPForward)
|
||||
switch target[0] {
|
||||
case atypDomainName:
|
||||
hostLen := target[1]
|
||||
if len(target) < 1+1+int(hostLen)+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write(target[1 : 1+1+hostLen+2])
|
||||
case atypIPv4:
|
||||
if len(target) < 1+net.IPv4len+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write([]byte{0x00, 0x04})
|
||||
buffer.Write(target[1 : 1+net.IPv4len+2])
|
||||
case atypIPv6:
|
||||
if len(target) < 1+net.IPv6len+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write([]byte{0x00, 0x06})
|
||||
buffer.Write(target[1 : 1+net.IPv6len+2])
|
||||
default:
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write(payload)
|
||||
if fw, ok := w.(packetFrameWriter); ok {
|
||||
_, err := fw.WritePacketFrame(buffer.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
_, err := w.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
func udpRequestHeaderLength(target []byte) int {
|
||||
if len(target) == 0 {
|
||||
return maxLength + 1
|
||||
}
|
||||
switch target[0] {
|
||||
case atypDomainName:
|
||||
if len(target) < 2 {
|
||||
return maxLength + 1
|
||||
}
|
||||
return 1 + 1 + int(target[1]) + 2
|
||||
case atypIPv4:
|
||||
return 1 + 2 + net.IPv4len + 2
|
||||
case atypIPv6:
|
||||
return 1 + 2 + net.IPv6len + 2
|
||||
default:
|
||||
return maxLength + 1
|
||||
}
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
net.Conn
|
||||
rMux sync.Mutex
|
||||
wMux sync.Mutex
|
||||
}
|
||||
|
||||
func (pc *packetConn) WritePacketFrame(b []byte) (int, error) {
|
||||
if s, ok := pc.Conn.(*Snell); ok {
|
||||
if fw, ok := s.Conn.(packetFrameWriter); ok {
|
||||
return fw.WritePacketFrame(b)
|
||||
}
|
||||
}
|
||||
return pc.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
pc.wMux.Lock()
|
||||
defer pc.wMux.Unlock()
|
||||
return WritePacket(pc, parseAddrToSocksAddr(addr), b)
|
||||
}
|
||||
|
||||
func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
pc.rMux.Lock()
|
||||
defer pc.rMux.Unlock()
|
||||
addr, n, err := ReadPacket(pc.Conn, b)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
return n, addr, nil
|
||||
}
|
||||
463
transport/snell/v4.go
Normal file
463
transport/snell/v4.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"math/big"
|
||||
"math/bits"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
v4SaltSize = 16
|
||||
v4NonceSize = 12
|
||||
v4HeaderPlainSize = 7
|
||||
v4HeaderCipherSize = v4HeaderPlainSize + 16
|
||||
v4FrameSize = 1460
|
||||
v4InitialPaddingMin = 0x100
|
||||
v4InitialPaddingSpan = 0x100
|
||||
)
|
||||
|
||||
type v4Conn struct {
|
||||
net.Conn
|
||||
psk []byte
|
||||
r *v4Reader
|
||||
w *v4Writer
|
||||
}
|
||||
|
||||
func newV4Conn(conn net.Conn, psk []byte) *v4Conn {
|
||||
return &v4Conn{Conn: conn, psk: psk}
|
||||
}
|
||||
|
||||
func (c *v4Conn) initReader() error {
|
||||
salt := make([]byte, v4SaltSize)
|
||||
if _, err := io.ReadFull(c.Conn, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := v4AEAD(c.psk, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.r = &v4Reader{Reader: c.Conn, aead: aead}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) initWriter() error {
|
||||
w, err := newV4Writer(c.Conn, c.psk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.w = w
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) Read(b []byte) (int, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *v4Conn) Write(b []byte) (int, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.w.Write(b)
|
||||
}
|
||||
|
||||
func (c *v4Conn) WritePacketFrame(b []byte) (int, error) {
|
||||
if len(b) > maxLength {
|
||||
return 0, errors.New("snell v4 frame too large")
|
||||
}
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
c.w.mux.Lock()
|
||||
defer c.w.mux.Unlock()
|
||||
if err := c.w.writeFrame(b, c.w.nextFramePaddingLength(len(b))); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) WriteTo(w io.Writer) (int64, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
var written int64
|
||||
buf := make([]byte, maxLength)
|
||||
for {
|
||||
n, err := c.r.Read(buf)
|
||||
if n > 0 {
|
||||
nw, ew := w.Write(buf[:n])
|
||||
written += int64(nw)
|
||||
if ew != nil {
|
||||
return written, ew
|
||||
}
|
||||
if nw != n {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *v4Conn) ReadFrom(r io.Reader) (int64, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
var read int64
|
||||
buf := make([]byte, maxLength)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
read += int64(n)
|
||||
if _, ew := c.w.Write(buf[:n]); ew != nil {
|
||||
return read, ew
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return read, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func v4AEAD(psk, salt []byte) (cipher.AEAD, error) {
|
||||
return aesGCM(snellKDF(psk, salt, 16))
|
||||
}
|
||||
|
||||
type v4Reader struct {
|
||||
io.Reader
|
||||
aead cipher.AEAD
|
||||
nonce [v4NonceSize]byte
|
||||
buf []byte
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func (r *v4Reader) Read(b []byte) (int, error) {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
if len(r.buf) == 0 {
|
||||
payload, err := r.readFrame()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r.buf = payload
|
||||
}
|
||||
n := copy(b, r.buf)
|
||||
r.buf = r.buf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *v4Reader) readFrame() ([]byte, error) {
|
||||
headerCipher := make([]byte, v4HeaderCipherSize)
|
||||
if _, err := io.ReadFull(r.Reader, headerCipher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
header, err := r.aead.Open(headerCipher[:0], r.nonce[:], headerCipher, nil)
|
||||
incrementV4Nonce(r.nonce[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(header) != v4HeaderPlainSize || header[0] != 4 {
|
||||
return nil, errors.New("snell v4 invalid frame header")
|
||||
}
|
||||
paddingLength := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
payloadLength := int(binary.BigEndian.Uint16(header[5:7]))
|
||||
if payloadLength == 0 {
|
||||
if paddingLength != 0 {
|
||||
return nil, errors.New("snell v4 zero chunk with padding")
|
||||
}
|
||||
return nil, ErrZeroChunk
|
||||
}
|
||||
if payloadLength > maxLength || paddingLength > maxLength {
|
||||
return nil, errors.New("snell v4 frame too large")
|
||||
}
|
||||
payloadCipherLength := payloadLength + r.aead.Overhead()
|
||||
frame := make([]byte, paddingLength+payloadCipherLength)
|
||||
if _, err := io.ReadFull(r.Reader, frame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paddingLength > 0 {
|
||||
swapPadding(frame[:paddingLength], frame[paddingLength:])
|
||||
}
|
||||
payloadCipher := frame[paddingLength:]
|
||||
payload, err := r.aead.Open(payloadCipher[:0], r.nonce[:], payloadCipher, nil)
|
||||
incrementV4Nonce(r.nonce[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
type v4Writer struct {
|
||||
io.Writer
|
||||
aead cipher.AEAD
|
||||
nonce [v4NonceSize]byte
|
||||
salt [v4SaltSize]byte
|
||||
saltSent bool
|
||||
initialPaddingLength uint16
|
||||
payloadLimit uint16
|
||||
lastWrite time.Time
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newV4Writer(w io.Writer, psk []byte) (*v4Writer, error) {
|
||||
var salt [v4SaltSize]byte
|
||||
if _, err := io.ReadFull(cryptorand.Reader, salt[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := v4AEAD(psk, salt[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddingDelta, err := cryptoRandomInt(v4InitialPaddingSpan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v4Writer{
|
||||
Writer: w,
|
||||
aead: aead,
|
||||
salt: salt,
|
||||
initialPaddingLength: uint16(v4InitialPaddingMin + paddingDelta),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *v4Writer) Write(b []byte) (int, error) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if len(b) == 0 {
|
||||
return 0, w.writeFrame(nil, 0)
|
||||
}
|
||||
written := 0
|
||||
for written < len(b) {
|
||||
payloadLimit := int(w.nextPayloadLimit())
|
||||
if payloadLimit <= 0 || payloadLimit > maxLength {
|
||||
payloadLimit = maxLength
|
||||
}
|
||||
end := written + payloadLimit
|
||||
if end > len(b) {
|
||||
end = len(b)
|
||||
}
|
||||
paddingLength := w.nextFramePaddingLength(end - written)
|
||||
if err := w.writeFrame(b[written:end], paddingLength); err != nil {
|
||||
return written, err
|
||||
}
|
||||
written = end
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (w *v4Writer) nextPayloadLimit() uint16 {
|
||||
now := time.Now()
|
||||
var payloadLimit uint16
|
||||
switch {
|
||||
case w.lastWrite.IsZero():
|
||||
payloadLimit = v4FrameSize - 55 - w.initialPaddingLength
|
||||
case now.Sub(w.lastWrite) > 30*time.Second:
|
||||
payloadLimit = v4FrameSize - 39
|
||||
default:
|
||||
payloadLimit = w.payloadLimit
|
||||
}
|
||||
w.lastWrite = now
|
||||
if payloadLimit <= maxLength-1 {
|
||||
next := int(payloadLimit) + v4FrameSize - 39
|
||||
if next > maxLength {
|
||||
next = maxLength
|
||||
}
|
||||
w.payloadLimit = uint16(next)
|
||||
} else {
|
||||
w.payloadLimit = maxLength
|
||||
}
|
||||
return payloadLimit
|
||||
}
|
||||
|
||||
func (w *v4Writer) nextFramePaddingLength(payloadLength int) int {
|
||||
if w.saltSent || payloadLength == 0 {
|
||||
return 0
|
||||
}
|
||||
return int(w.initialPaddingLength)
|
||||
}
|
||||
|
||||
func (w *v4Writer) writeFrame(payload []byte, paddingLength int) error {
|
||||
if len(payload) > maxLength || paddingLength > maxLength {
|
||||
return errors.New("snell v4 frame too large")
|
||||
}
|
||||
if len(payload) == 0 && paddingLength != 0 {
|
||||
return errors.New("snell v4 zero chunk with padding")
|
||||
}
|
||||
header := make([]byte, v4HeaderPlainSize)
|
||||
header[0] = 4
|
||||
binary.BigEndian.PutUint16(header[3:5], uint16(paddingLength))
|
||||
binary.BigEndian.PutUint16(header[5:7], uint16(len(payload)))
|
||||
headerCipher := w.aead.Seal(nil, w.nonce[:], header, nil)
|
||||
incrementV4Nonce(w.nonce[:])
|
||||
var payloadCipher []byte
|
||||
if len(payload) > 0 {
|
||||
payloadCipher = w.aead.Seal(nil, w.nonce[:], payload, nil)
|
||||
incrementV4Nonce(w.nonce[:])
|
||||
}
|
||||
frameLength := len(headerCipher) + paddingLength + len(payloadCipher)
|
||||
if !w.saltSent {
|
||||
frameLength += v4SaltSize
|
||||
}
|
||||
frame := make([]byte, 0, frameLength)
|
||||
if !w.saltSent {
|
||||
frame = append(frame, w.salt[:]...)
|
||||
w.saltSent = true
|
||||
}
|
||||
frame = append(frame, headerCipher...)
|
||||
if paddingLength > 0 {
|
||||
padding, err := makeV4Padding(payloadCipher, paddingLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
swapPadding(padding, payloadCipher)
|
||||
frame = append(frame, padding...)
|
||||
}
|
||||
frame = append(frame, payloadCipher...)
|
||||
return writeFull(w.Writer, frame)
|
||||
}
|
||||
|
||||
func swapPadding(padding, payloadCipher []byte) {
|
||||
limit := len(padding)
|
||||
if len(payloadCipher) < limit {
|
||||
limit = len(payloadCipher)
|
||||
}
|
||||
for i := 0; i < limit; i += 2 {
|
||||
padding[i], payloadCipher[i] = payloadCipher[i], padding[i]
|
||||
}
|
||||
}
|
||||
|
||||
func makeV4Padding(payloadCipher []byte, paddingLength int) ([]byte, error) {
|
||||
if paddingLength <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
payloadOnes := countV4PayloadOnes(payloadCipher)
|
||||
payloadZeros := 8*len(payloadCipher) - payloadOnes
|
||||
if payloadZeros <= 0 {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
ratio := float64(payloadOnes) / float64(payloadZeros)
|
||||
if ratio <= 0.5 || ratio >= 1.6 {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
targetRatioBase := 1.6
|
||||
if payloadZeros < payloadOnes {
|
||||
targetRatioBase = 0.4
|
||||
}
|
||||
jitter, err := randomUnitFloat64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetRatio := targetRatioBase + jitter/10
|
||||
totalBits := 8 * (paddingLength + len(payloadCipher))
|
||||
targetOnes := int(float64(totalBits)*(targetRatio/(targetRatio+1)) - float64(payloadOnes))
|
||||
if targetOnes < 0 || targetOnes > 8*paddingLength {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
return makeV4BitCountPadding(paddingLength, targetOnes)
|
||||
}
|
||||
|
||||
func countV4PayloadOnes(payloadCipher []byte) int {
|
||||
limit := len(payloadCipher) &^ 3
|
||||
ones := 0
|
||||
for _, b := range payloadCipher[:limit] {
|
||||
ones += bits.OnesCount8(b)
|
||||
}
|
||||
return ones
|
||||
}
|
||||
|
||||
func makeV4RandomPadding(length int) ([]byte, error) {
|
||||
padding := make([]byte, length)
|
||||
_, err := io.ReadFull(cryptorand.Reader, padding)
|
||||
return padding, err
|
||||
}
|
||||
|
||||
func makeV4BitCountPadding(length, oneBits int) ([]byte, error) {
|
||||
totalBits := 8 * length
|
||||
if oneBits < 0 || oneBits > totalBits {
|
||||
return nil, errors.New("snell v4 invalid padding bit count")
|
||||
}
|
||||
bitset := make([]byte, totalBits)
|
||||
for i := 0; i < oneBits; i++ {
|
||||
bitset[i] = 1
|
||||
}
|
||||
for i := totalBits - 1; i > 0; i-- {
|
||||
j, err := cryptoRandomInt(i + 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bitset[i], bitset[j] = bitset[j], bitset[i]
|
||||
}
|
||||
padding := make([]byte, length)
|
||||
for i, bit := range bitset {
|
||||
if bit == 1 {
|
||||
padding[i/8] |= 1 << uint(i%8)
|
||||
}
|
||||
}
|
||||
return padding, nil
|
||||
}
|
||||
|
||||
func cryptoRandomInt(max int) (int, error) {
|
||||
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(int64(max)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(n.Int64()), nil
|
||||
}
|
||||
|
||||
func randomUnitFloat64() (float64, error) {
|
||||
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(1<<53))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return float64(n.Int64()) / math.Exp2(53), nil
|
||||
}
|
||||
|
||||
func writeFull(w io.Writer, p []byte) error {
|
||||
for len(p) > 0 {
|
||||
n, err := w.Write(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
p = p[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementV4Nonce(nonce []byte) {
|
||||
for i := range nonce {
|
||||
nonce[i]++
|
||||
if nonce[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user