Add Snell protocol. Refactor MASQUE HTTP/2, Fair Queue. Update XHTTP, OpenVPN, Sudoku, Fallback. Fixes

This commit is contained in:
Shtorm
2026-06-26 01:25:57 +03:00
parent d174962a04
commit edf38d33d6
107 changed files with 5346 additions and 708 deletions

View File

@@ -0,0 +1,331 @@
package masque
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/sagernet/quic-go/quicvarint"
"github.com/yosida95/uritemplate/v3"
"golang.org/x/net/http2"
)
const h2DatagramCapsuleType uint64 = 0
const (
ipv4HeaderLen = 20
ipv6HeaderLen = 40
)
func ConnectTunnelH2(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, endpoint *net.TCPAddr, connectUri string) (io.Closer, IpConn, *http.Response, error) {
if endpoint == nil {
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
}
tlsConfig.SetNextProtos([]string{"h2"})
conn, err := dialer.DialContext(ctx, N.NetworkTCP, M.SocksaddrFromNetIP(endpoint.AddrPort()))
if err != nil {
return nil, nil, nil, err
}
tlsConn, err := tlsConfig.Client(conn)
if err != nil {
_ = conn.Close()
return nil, nil, nil, err
}
if err = tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, nil, nil, err
}
tr := &http2.Transport{
ReadIdleTimeout: 30 * time.Second,
}
cc, err := tr.NewClientConn(tlsConn)
if err != nil {
_ = tlsConn.Close()
return nil, nil, nil, fmt.Errorf("connect-ip: failed to create client connection: %w", err)
}
additionalHeaders := http.Header{
"User-Agent": []string{""},
}
template := uritemplate.MustNew(connectUri)
h2Headers := additionalHeaders.Clone()
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
h2Headers.Set("pq-enabled", "false")
ipConn, rsp, err := dialH2(ctx, cc, template, h2Headers)
if err != nil {
_ = cc.Close()
if strings.Contains(err.Error(), "tls: access denied") {
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
}
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
}
if rsp.StatusCode != http.StatusOK {
_ = ipConn.Close()
_ = cc.Close()
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status)
}
return cc, ipConn, rsp, nil
}
func dialH2(ctx context.Context, rt http.RoundTripper, template *uritemplate.Template, additionalHeaders http.Header) (*h2IpConn, *http.Response, error) {
if len(template.Varnames()) > 0 {
return nil, nil, errors.New("connect-ip: IP flow forwarding not supported")
}
u, err := url.Parse(template.Raw())
if err != nil {
return nil, nil, fmt.Errorf("connect-ip: failed to parse URI: %w", err)
}
reqCtx, cancel := context.WithCancel(context.Background())
pr, pw := io.Pipe()
req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, u.String(), pr)
if err != nil {
cancel()
_ = pr.Close()
_ = pw.Close()
return nil, nil, fmt.Errorf("connect-ip: failed to create request: %w", err)
}
req.Host = authorityFromURL(u)
req.ContentLength = -1
req.Header = make(http.Header)
for k, v := range additionalHeaders {
req.Header[k] = v
}
stop := context.AfterFunc(ctx, cancel)
rsp, err := rt.RoundTrip(req)
stop()
if err != nil {
cancel()
_ = pr.Close()
_ = pw.Close()
return nil, nil, fmt.Errorf("connect-ip: failed to send request: %w", err)
}
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
cancel()
_ = pr.Close()
_ = pw.Close()
_ = rsp.Body.Close()
return nil, rsp, fmt.Errorf("connect-ip: server responded with %d", rsp.StatusCode)
}
stream := &h2DatagramStream{
requestBody: pw,
responseBody: rsp.Body,
cancel: cancel,
}
return &h2IpConn{
str: stream,
closeChan: make(chan struct{}),
}, rsp, nil
}
func authorityFromURL(u *url.URL) string {
if u.Port() != "" {
return u.Host
}
host := u.Hostname()
if host == "" {
return u.Host
}
return host + ":443"
}
type h2IpConn struct {
str *h2DatagramStream
mu sync.Mutex
closeChan chan struct{}
closeErr error
}
func (c *h2IpConn) ReadPacket() (b []byte, err error) {
start:
data, err := c.str.ReceiveDatagram(context.Background())
if err != nil {
defer func() {
_ = c.Close()
}()
select {
case <-c.closeChan:
return nil, c.closeErr
default:
return nil, err
}
}
if err := c.handleIncomingProxiedPacket(data); err != nil {
goto start
}
return data, nil
}
func (c *h2IpConn) handleIncomingProxiedPacket(data []byte) error {
if len(data) == 0 {
return errors.New("connect-ip: empty packet")
}
switch v := ipVersion(data); v {
default:
return fmt.Errorf("connect-ip: unknown IP versions: %d", v)
case 4:
if len(data) < ipv4HeaderLen {
return fmt.Errorf("connect-ip: malformed datagram: too short")
}
case 6:
if len(data) < ipv6HeaderLen {
return fmt.Errorf("connect-ip: malformed datagram: too short")
}
}
return nil
}
func (c *h2IpConn) WritePacket(b []byte) (icmp []byte, err error) {
data, err := c.composeDatagram(b)
if err != nil {
return nil, nil
}
if err := c.str.SendDatagram(data); err != nil {
select {
case <-c.closeChan:
return nil, c.closeErr
default:
return nil, err
}
}
return nil, nil
}
func (c *h2IpConn) composeDatagram(b []byte) ([]byte, error) {
if len(b) == 0 {
return nil, nil
}
switch v := ipVersion(b); v {
default:
return nil, fmt.Errorf("connect-ip: unknown IP versions: %d", v)
case 4:
if len(b) < ipv4HeaderLen {
return nil, fmt.Errorf("connect-ip: IPv4 packet too short")
}
ttl := b[8]
if ttl <= 1 {
return nil, fmt.Errorf("connect-ip: datagram TTL too small: %d", ttl)
}
b[8]--
binary.BigEndian.PutUint16(b[10:12], calculateIPv4Checksum(([ipv4HeaderLen]byte)(b[:ipv4HeaderLen])))
case 6:
if len(b) < ipv6HeaderLen {
return nil, fmt.Errorf("connect-ip: IPv6 packet too short")
}
hopLimit := b[7]
if hopLimit <= 1 {
return nil, fmt.Errorf("connect-ip: datagram Hop Limit too small: %d", hopLimit)
}
b[7]--
}
return b, nil
}
func (c *h2IpConn) Close() error {
c.mu.Lock()
if c.closeErr == nil {
c.closeErr = net.ErrClosed
close(c.closeChan)
}
c.mu.Unlock()
err := c.str.Close()
return err
}
func ipVersion(b []byte) uint8 { return b[0] >> 4 }
func calculateIPv4Checksum(header [ipv4HeaderLen]byte) uint16 {
var sum uint32
for i := 0; i < len(header); i += 2 {
if i == 10 {
continue
}
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
for (sum >> 16) > 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return ^uint16(sum)
}
type h2DatagramStream struct {
requestBody *io.PipeWriter
responseBody io.ReadCloser
cancel context.CancelFunc
readMu sync.Mutex
writeMu sync.Mutex
}
func (s *h2DatagramStream) ReceiveDatagram(_ context.Context) ([]byte, error) {
s.readMu.Lock()
defer s.readMu.Unlock()
reader := quicvarint.NewReader(s.responseBody)
for {
capsuleType, err := quicvarint.Read(reader)
if err != nil {
return nil, err
}
payloadLen, err := quicvarint.Read(reader)
if err != nil {
return nil, err
}
payload := make([]byte, payloadLen)
_, err = io.ReadFull(reader, payload)
if err != nil {
return nil, err
}
if capsuleType != h2DatagramCapsuleType {
continue
}
return payload, nil
}
}
func (s *h2DatagramStream) SendDatagram(data []byte) error {
frame := make([]byte, 0, quicvarint.Len(h2DatagramCapsuleType)+quicvarint.Len(uint64(len(data)))+len(data))
frame = quicvarint.Append(frame, h2DatagramCapsuleType)
frame = quicvarint.Append(frame, uint64(len(data)))
frame = append(frame, data...)
s.writeMu.Lock()
defer s.writeMu.Unlock()
_, err := s.requestBody.Write(frame)
if err != nil {
return fmt.Errorf("connect-ip: failed to send datagram capsule: %w", err)
}
return nil
}
func (s *h2DatagramStream) Close() error {
_ = s.requestBody.Close()
err := s.responseBody.Close()
s.cancel()
return err
}

View File

@@ -2,9 +2,9 @@ package masque
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
@@ -12,13 +12,13 @@ import (
connectip "github.com/Diniboy1123/connect-ip-go"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/congestion"
"github.com/sagernet/quic-go/http3"
qtls "github.com/sagernet/sing-quic"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/yosida95/uritemplate/v3"
"golang.org/x/net/http2"
)
type (
@@ -26,39 +26,60 @@ type (
ListenPacket func(network string, address string) (net.PacketConn, error)
)
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool) (net.PacketConn, *http3.Transport, *connectip.Conn, *http.Response, error) {
template := uritemplate.MustNew(connectUri)
additionalHeaders := http.Header{
"User-Agent": []string{""},
type IpConn interface {
ReadPacket() (b []byte, err error)
WritePacket(b []byte) (icmp []byte, err error)
Close() error
}
type closerFunc func() error
func (f closerFunc) Close() error { return f() }
type quicIpConn struct {
conn *connectip.Conn
buf []byte
}
func newQuicIpConn(conn *connectip.Conn) *quicIpConn {
return &quicIpConn{
conn: conn,
buf: make([]byte, 0xFFFF),
}
}
func (c *quicIpConn) ReadPacket() ([]byte, error) {
n, err := c.conn.ReadPacket(c.buf, true)
if err != nil {
return nil, err
}
return c.buf[:n], nil
}
func (c *quicIpConn) WritePacket(b []byte) (icmp []byte, err error) {
return c.conn.WritePacket(b)
}
func (c *quicIpConn) Close() error {
return c.conn.Close()
}
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool, congestionControl func(conn *quic.Conn) congestion.CongestionControl) (io.Closer, IpConn, *http.Response, error) {
if useHTTP2 {
h2Endpoint, ok := endpoint.(*net.TCPAddr)
if !ok || h2Endpoint == nil {
return nil, nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
}
h2Headers := additionalHeaders.Clone()
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
h2Headers.Set("pq-enabled", "false")
h2Client, err := newHTTP2Client(dialer, tlsConfig, h2Endpoint, connectUri)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to create HTTP/2 client: %w", err)
}
ipConn, rsp, err := connectip.DialH2(ctx, h2Client, template, h2Headers)
if err != nil {
if strings.Contains(err.Error(), "tls: access denied") {
return nil, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
}
return nil, nil, nil, nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
}
return nil, nil, ipConn, rsp, nil
return ConnectTunnelH2(ctx, dialer, tlsConfig, h2Endpoint, connectUri)
}
quicEndpoint, ok := endpoint.(*net.UDPAddr)
if !ok || quicEndpoint == nil {
return nil, nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
return nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
}
udpConn, err := dialer.ListenPacket(ctx, M.SocksaddrFromNetIP(quicEndpoint.AddrPort()))
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
conn, err := qtls.Dial(
ctx,
@@ -68,28 +89,34 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
quicConfig,
)
if err != nil {
return nil, nil, nil, nil, err
_ = udpConn.Close()
return nil, nil, nil, err
}
if congestionControl != nil {
conn.SetCongestionControl(congestionControl(conn))
}
tr := &http3.Transport{
EnableDatagrams: true,
AdditionalSettings: map[uint64]uint64{
// official client still sends this out as well, even though
// it's deprecated, see https://datatracker.ietf.org/doc/draft-ietf-masque-h3-datagram/00/
// SETTINGS_H3_DATAGRAM_00 = 0x0000000000000276
// https://github.com/cloudflare/quiche/blob/7c66757dbc55b8d0c3653d4b345c6785a181f0b7/quiche/src/h3/frame.rs#L46
0x276: 1,
},
DisableCompression: true,
}
hconn := tr.NewClientConn(conn)
template := uritemplate.MustNew(connectUri)
additionalHeaders := http.Header{
"User-Agent": []string{""},
}
ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true)
if err != nil {
_ = tr.Close()
_ = conn.CloseWithError(0, "connect-ip dial failed")
_ = udpConn.Close()
if strings.Contains(err.Error(), "tls: access denied") {
return udpConn, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
}
return udpConn, nil, nil, nil, fmt.Errorf("failed to dial connect-ip: %w", err)
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %w", err)
}
err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{
{
@@ -109,34 +136,16 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
},
})
if err != nil {
return udpConn, nil, nil, nil, err
_ = ipConn.Close()
_ = tr.Close()
_ = udpConn.Close()
return nil, nil, nil, err
}
return udpConn, tr, ipConn, rsp, nil
}
func newHTTP2Client(dialer N.Dialer, baseTLSConfig aTLS.Config, endpoint *net.TCPAddr, connectURI string) (*http.Client, error) {
if endpoint == nil {
return nil, errors.New("missing HTTP/2 endpoint")
}
tlsConfig := baseTLSConfig.Clone()
tlsConfig.SetNextProtos([]string{"h2"})
return &http.Client{
Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, _ string, _ *tls.Config) (net.Conn, error) {
conn, err := dialer.DialContext(ctx, network, M.SocksaddrFromNetIP(endpoint.AddrPort()))
if err != nil {
return nil, err
}
tlsConn, err := tlsConfig.Client(conn)
if err != nil {
return nil, err
}
if err := tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, err
}
return tlsConn, nil
},
},
}, nil
closer := closerFunc(func() error {
_ = tr.Close()
_ = udpConn.Close()
return nil
})
return closer, newQuicIpConn(ipConn), rsp, nil
}

View File

@@ -5,6 +5,8 @@ import (
"net/netip"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/congestion"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/tls"
)
@@ -23,4 +25,5 @@ type TunnelOptions struct {
UDPKeepalivePeriod time.Duration
UDPInitialPacketSize uint16
ReconnectDelay time.Duration
CongestionControl func(conn *quic.Conn) congestion.CongestionControl
}

View File

@@ -4,13 +4,12 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"time"
connectip "github.com/Diniboy1123/connect-ip-go"
"github.com/sagernet/quic-go/http3"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
@@ -22,9 +21,8 @@ type Tunnel struct {
options TunnelOptions
device Device
udpConn net.PacketConn
tr *http3.Transport
ipConn *connectip.Conn
closer io.Closer
ipConn IpConn
mtx sync.Mutex
}
@@ -83,13 +81,11 @@ func (e *Tunnel) Close() error {
defer e.mtx.Unlock()
if e.ipConn != nil {
e.ipConn.Close()
if e.udpConn != nil {
e.udpConn.Close()
}
if e.tr != nil {
e.tr.Close()
if e.closer != nil {
e.closer.Close()
}
e.ipConn = nil
e.closer = nil
}
return e.device.Close()
}
@@ -124,7 +120,7 @@ func (e *Tunnel) maintainTunnel() {
}
icmp, err := ipConn.WritePacket(packet)
if err != nil {
if errors.As(err, new(*connectip.CloseError)) {
if errors.Is(err, net.ErrClosed) {
if ok := e.closeIpConn(ipConn); ok {
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing to IP connection: %w", err))
}
@@ -135,7 +131,7 @@ func (e *Tunnel) maintainTunnel() {
}
if len(icmp) > 0 {
if _, err := e.device.Write([][]byte{icmp}, 0); err != nil {
if errors.As(err, new(*connectip.CloseError)) {
if errors.Is(err, net.ErrClosed) {
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing ICMP to TUN device: %v", err))
continue
}
@@ -145,15 +141,14 @@ func (e *Tunnel) maintainTunnel() {
}
}()
go func() {
buf := make([]byte, 1280)
for e.ctx.Err() == nil {
ipConn, err := e.getIpConn()
if err != nil {
return
}
n, err := ipConn.ReadPacket(buf, true)
packet, err := ipConn.ReadPacket()
if err != nil {
if e.options.UseHTTP2 || errors.As(err, new(*connectip.CloseError)) {
if e.options.UseHTTP2 || errors.Is(err, net.ErrClosed) {
if ok := e.closeIpConn(ipConn); ok {
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while reading from IP connection: %v", err))
}
@@ -162,7 +157,7 @@ func (e *Tunnel) maintainTunnel() {
e.logger.ErrorContext(e.ctx, fmt.Errorf("Error reading from IP connection: %v, continuine...", err))
continue
}
if _, err := e.device.Write([][]byte{buf[:n]}, 0); err != nil {
if _, err := e.device.Write([][]byte{packet}, 0); err != nil {
continue
}
}
@@ -170,7 +165,7 @@ func (e *Tunnel) maintainTunnel() {
<-e.ctx.Done()
}
func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
func (e *Tunnel) getIpConn() (IpConn, error) {
e.mtx.Lock()
defer e.mtx.Unlock()
if e.ctx.Err() != nil {
@@ -184,7 +179,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
defer timer.Stop()
for {
e.logger.NoticeContext(e.ctx, fmt.Errorf("Establishing MASQUE connection to %s", e.options.Endpoint))
udpConn, tr, ipConn, rsp, err := ConnectTunnel(
closer, ipConn, rsp, err := ConnectTunnel(
e.ctx,
e.options.Dialer,
e.options.TLSConfig,
@@ -192,6 +187,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
"https://cloudflareaccess.com",
e.options.Endpoint,
e.options.UseHTTP2,
e.options.CongestionControl,
)
if err != nil {
e.logger.ErrorContext(e.ctx, fmt.Errorf("Failed to connect tunnel: %v", err))
@@ -206,11 +202,8 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
if rsp.StatusCode != 200 {
e.logger.ErrorContext(e.ctx, fmt.Errorf("Tunnel connection failed: %s", rsp.Status))
ipConn.Close()
if udpConn != nil {
udpConn.Close()
}
if tr != nil {
tr.Close()
if closer != nil {
closer.Close()
}
timer.Reset(e.options.ReconnectDelay)
select {
@@ -220,26 +213,23 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
}
continue
}
e.udpConn = udpConn
e.tr = tr
e.closer = closer
e.ipConn = ipConn
e.logger.NoticeContext(e.ctx, "Connected to MASQUE server ", e.options.Endpoint)
return ipConn, nil
}
}
func (e *Tunnel) closeIpConn(ipConn *connectip.Conn) bool {
func (e *Tunnel) closeIpConn(ipConn IpConn) bool {
e.mtx.Lock()
defer e.mtx.Unlock()
if ipConn == e.ipConn {
e.ipConn.Close()
if e.udpConn != nil {
e.udpConn.Close()
}
if e.tr != nil {
e.tr.Close()
if e.closer != nil {
e.closer.Close()
}
e.ipConn = nil
e.closer = nil
return true
}
return false

View File

@@ -4,6 +4,7 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
@@ -23,7 +24,7 @@ const (
type DataCipher interface {
Encrypt(header []byte, packetID uint32, payload []byte) ([]byte, error)
Decrypt(packet []byte, headerSize int) ([]byte, error)
Decrypt(packet []byte, headerSize int) (plaintext []byte, packetID uint32, err error)
}
type AEADDataCipher struct {
@@ -86,9 +87,9 @@ func (g *AEADDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
return out, nil
}
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) {
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
if len(packet) < headerSize+4+AESGCMTagSize+1 {
return nil, errors.New("openvpn gcm data packet too short")
return nil, 0, errors.New("openvpn gcm data packet too short")
}
header := packet[:headerSize]
pidBytes := packet[headerSize : headerSize+4]
@@ -96,8 +97,13 @@ func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error)
ciphertext := packet[headerSize+4+AESGCMTagSize:]
combined := append(ciphertext, tag...)
ad := append(header, pidBytes...)
nonce := g.nonce(binary.BigEndian.Uint32(pidBytes), g.recvImplicitIV)
return g.recv.Open(nil, nonce[:], combined, ad)
packetID := binary.BigEndian.Uint32(pidBytes)
nonce := g.nonce(packetID, g.recvImplicitIV)
plain, err := g.recv.Open(nil, nonce[:], combined, ad)
if err != nil {
return nil, 0, err
}
return plain, packetID, nil
}
func (g *AEADDataCipher) nonce(packetID uint32, implicit [AESGCMIVSize]byte) [AESGCMIVSize]byte {
@@ -127,6 +133,9 @@ func NewCBCCipher(keys *KeyMaterial, auth string) (*CBCDataCipher, error) {
var newHash func() hash.Hash
var hmacSize int
switch auth {
case AuthMD5:
newHash = md5.New
hmacSize = md5.Size
case AuthSHA256:
newHash = sha256.New
hmacSize = sha256.Size
@@ -176,34 +185,35 @@ func (c *CBCDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
return out, nil
}
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) {
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
minSize := headerSize + c.hmacSize + CBCIVSize + aes.BlockSize
if len(packet) < minSize {
return nil, errors.New("openvpn cbc data packet too short")
return nil, 0, errors.New("openvpn cbc data packet too short")
}
tag := packet[headerSize : headerSize+c.hmacSize]
iv := packet[headerSize+c.hmacSize : headerSize+c.hmacSize+CBCIVSize]
ct := packet[headerSize+c.hmacSize+CBCIVSize:]
if len(ct)%aes.BlockSize != 0 {
return nil, errors.New("openvpn cbc ciphertext not block-aligned")
return nil, 0, errors.New("openvpn cbc ciphertext not block-aligned")
}
mac := hmac.New(c.newHash, c.recvHMAC)
mac.Write(iv)
mac.Write(ct)
if !hmac.Equal(tag, mac.Sum(nil)) {
return nil, errors.New("openvpn cbc hmac verification failed")
return nil, 0, errors.New("openvpn cbc hmac verification failed")
}
plain := make([]byte, len(ct))
cipher.NewCBCDecrypter(c.recvBlock, iv).CryptBlocks(plain, ct)
padLen := int(plain[len(plain)-1])
if padLen < 1 || padLen > aes.BlockSize {
return nil, errors.New("openvpn cbc invalid padding")
return nil, 0, errors.New("openvpn cbc invalid padding")
}
plain = plain[:len(plain)-padLen]
if len(plain) < 4 {
return nil, errors.New("openvpn cbc payload too short")
return nil, 0, errors.New("openvpn cbc payload too short")
}
return plain[4:], nil
packetID := binary.BigEndian.Uint32(plain[:4])
return plain[4:], packetID, nil
}
func CipherKeyLength(cipher string) int {

View File

@@ -8,12 +8,16 @@ import (
"io"
"net"
"strings"
"sync/atomic"
"time"
"github.com/sagernet/sing/common/tls"
)
const defaultHandshakeTimeout = 30 * time.Second
const (
defaultHandshakeTimeout = 30 * time.Second
controlRetransmitDelay = time.Second
)
type Client struct {
config *ClientConfig
@@ -26,6 +30,8 @@ type Client struct {
push *PushReply
cancel context.CancelFunc
lastReceiveNano atomic.Int64
}
func NewClient(config *ClientConfig, io PacketIO, tlsConfig tls.Config) (*Client, error) {
@@ -154,6 +160,7 @@ func (c *Client) Handshake(ctx context.Context) (*PushReply, error) {
return nil, err
}
c.data = NewDataChannel(cipher, push.PeerID, push.CompLZO)
c.markReceive()
return push, nil
}
@@ -181,10 +188,21 @@ func (c *Client) ReadIPPacket(ctx context.Context) ([]byte, error) {
if err != nil {
continue
}
c.markReceive()
return plain, nil
}
}
func (c *Client) SinceReceive() time.Duration {
return time.Duration(int64(time.Since(clientStart)) - c.lastReceiveNano.Load())
}
func (c *Client) markReceive() {
c.lastReceiveNano.Store(int64(time.Since(clientStart)))
}
var clientStart = time.Now().Add(-time.Hour)
func (c *Client) Close() error {
if c.cancel != nil {
c.cancel()
@@ -199,10 +217,24 @@ func (c *Client) Close() error {
}
func (c *Client) waitServerReset(ctx context.Context) error {
retransmits := 0
for {
packet, err := c.control.Read(ctx)
readCtx := ctx
cancel := func() {}
if c.config.Proto == ProtoUDP {
readCtx, cancel = context.WithTimeout(ctx, controlRetransmitDelay)
}
packet, err := c.control.Read(readCtx)
cancel()
if err != nil {
return fmt.Errorf("read hard reset response: %w", err)
if c.config.Proto == ProtoUDP && errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil {
if err := c.control.RetransmitPending(ctx); err != nil {
return fmt.Errorf("retransmit hard reset: %w", err)
}
retransmits++
continue
}
return fmt.Errorf("read hard reset response after %d retransmits: %w", retransmits, err)
}
switch packet.Opcode {
case PControlHardResetServerV2:

View File

@@ -20,6 +20,7 @@ const (
CipherAES256CBC = "AES-256-CBC"
CipherCHACHA20POLY = "CHACHA20-POLY1305"
AuthMD5 = "MD5"
AuthSHA1 = "SHA1"
AuthSHA256 = "SHA256"
AuthSHA384 = "SHA384"
@@ -107,7 +108,7 @@ func isValidCipher(cipher string) bool {
func isValidAuth(auth string) bool {
switch auth {
case AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
case AuthMD5, AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
return true
}
return false

View File

@@ -30,8 +30,10 @@ type ControlChannel struct {
mu sync.Mutex
sendPacketID uint32
sendMessage uint32
recvMessage uint32
ackPending []uint32
pending map[uint32]*ControlPacket
recvPending map[uint32]*ControlPacket
readDeadline time.Time
writeDeadline time.Time
}
@@ -40,9 +42,10 @@ func NewControlChannel(io PacketIO, crypt ControlCrypt, local SessionID) *Contro
ch := &ControlChannel{
io: io,
clock: time.Now,
local: local,
pending: make(map[uint32]*ControlPacket),
clock: time.Now,
local: local,
pending: make(map[uint32]*ControlPacket),
recvPending: make(map[uint32]*ControlPacket),
}
if crypt != nil {
ch.encode = func(p *ControlPacket, pid uint32, t uint32) ([]byte, error) {
@@ -130,10 +133,23 @@ func (c *ControlChannel) SendAck(ctx context.Context) error {
func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
for {
c.mu.Lock()
if packet, ok := c.recvPending[c.recvMessage]; ok {
delete(c.recvPending, c.recvMessage)
c.recvMessage++
c.mu.Unlock()
return packet, nil
}
c.mu.Unlock()
packet, err := c.readControlPacket(ctx)
if err != nil {
return nil, err
}
var deliver *ControlPacket
sendAck := false
c.mu.Lock()
if c.remote == (SessionID{}) && packet.LocalSession != c.local {
c.remote = packet.LocalSession
@@ -144,11 +160,33 @@ func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
if packet.Opcode.HasMessageID() {
c.ackPending = appendAck(c.ackPending, packet.MessageID)
}
c.mu.Unlock()
if packet.Opcode == PAckV1 {
continue
switch {
case packet.Opcode == PAckV1:
case !packet.Opcode.HasMessageID():
deliver = packet
case packet.MessageID < c.recvMessage:
sendAck = true
case packet.MessageID == c.recvMessage:
deliver = packet
c.recvMessage++
default:
if _, exists := c.recvPending[packet.MessageID]; !exists {
c.recvPending[packet.MessageID] = packet
}
sendAck = true
}
c.mu.Unlock()
if deliver != nil {
return deliver, nil
}
if sendAck {
if err := c.SendAck(ctx); err != nil {
return nil, err
}
}
return packet, nil
}
}
@@ -349,11 +387,17 @@ func (c *ControlConn) SetWriteDeadline(t time.Time) error {
}
type streamPacketIO struct {
conn net.Conn
conn net.Conn
deadlineMu sync.Mutex
readDeadline time.Time
writeDeadline time.Time
}
type datagramPacketIO struct {
conn net.Conn
conn net.Conn
deadlineMu sync.Mutex
readDeadline time.Time
writeDeadline time.Time
}
func NewDatagramPacketIO(conn net.Conn) PacketIO {
@@ -361,40 +405,23 @@ func NewDatagramPacketIO(conn net.Conn) PacketIO {
}
func (d *datagramPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
done := make(chan struct{})
var (
packet []byte
err error
)
go func() {
defer close(done)
buf := make([]byte, 64*1024)
var n int
n, err = d.conn.Read(buf)
if err == nil {
packet = cloneBytes(buf[:n])
}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-done:
return packet, err
if err := setReadDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.readDeadline); err != nil {
return nil, err
}
buf := make([]byte, 64*1024)
n, err := d.conn.Read(buf)
if err != nil {
return nil, contextIOError(ctx, err)
}
return buf[:n], nil
}
func (d *datagramPacketIO) WritePacket(ctx context.Context, packet []byte) error {
done := make(chan error, 1)
go func() {
_, err := d.conn.Write(packet)
done <- err
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
if err := setWriteDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.writeDeadline); err != nil {
return err
}
_, err := d.conn.Write(packet)
return contextIOError(ctx, err)
}
func (d *datagramPacketIO) Close() error {
@@ -414,52 +441,37 @@ func NewTCPPacketIO(conn net.Conn) PacketIO {
}
func (s *streamPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
done := make(chan struct{})
var (
packet []byte
err error
)
go func() {
defer close(done)
var lenBuf [2]byte
if _, err = io.ReadFull(s.conn, lenBuf[:]); err != nil {
return
}
size := int(lenBuf[0])<<8 | int(lenBuf[1])
if size == 0 {
err = errors.New("empty openvpn tcp packet")
return
}
packet = make([]byte, size)
_, err = io.ReadFull(s.conn, packet)
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-done:
return packet, err
if err := setReadDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.readDeadline); err != nil {
return nil, err
}
var lenBuf [2]byte
if _, err := io.ReadFull(s.conn, lenBuf[:]); err != nil {
return nil, contextIOError(ctx, err)
}
size := int(lenBuf[0])<<8 | int(lenBuf[1])
if size == 0 {
return nil, errors.New("empty openvpn tcp packet")
}
packet := make([]byte, size)
if _, err := io.ReadFull(s.conn, packet); err != nil {
return nil, contextIOError(ctx, err)
}
return packet, nil
}
func (s *streamPacketIO) WritePacket(ctx context.Context, packet []byte) error {
if len(packet) > 0xffff {
return fmt.Errorf("openvpn tcp packet too large: %d", len(packet))
}
done := make(chan error, 1)
go func() {
frame := make([]byte, 2+len(packet))
frame[0] = byte(len(packet) >> 8)
frame[1] = byte(len(packet))
copy(frame[2:], packet)
_, err := s.conn.Write(frame)
done <- err
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
if err := setWriteDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.writeDeadline); err != nil {
return err
}
frame := make([]byte, 2+len(packet))
frame[0] = byte(len(packet) >> 8)
frame[1] = byte(len(packet))
copy(frame[2:], packet)
_, err := s.conn.Write(frame)
return contextIOError(ctx, err)
}
func (s *streamPacketIO) Close() error {
@@ -473,3 +485,50 @@ func (s *streamPacketIO) LocalAddr() net.Addr {
func (s *streamPacketIO) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func setReadDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
deadline, hasDeadline := ctx.Deadline()
mu.Lock()
defer mu.Unlock()
if current.Equal(deadline) {
return nil
}
if hasDeadline {
if err := conn.SetReadDeadline(deadline); err != nil {
return err
}
} else if err := conn.SetReadDeadline(time.Time{}); err != nil {
return err
}
*current = deadline
return nil
}
func setWriteDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
deadline, hasDeadline := ctx.Deadline()
mu.Lock()
defer mu.Unlock()
if current.Equal(deadline) {
return nil
}
if hasDeadline {
if err := conn.SetWriteDeadline(deadline); err != nil {
return err
}
} else if err := conn.SetWriteDeadline(time.Time{}); err != nil {
return err
}
*current = deadline
return nil
}
func contextIOError(ctx context.Context, err error) error {
if err == nil {
return nil
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() && ctx.Err() != nil {
return ctx.Err()
}
return err
}

View File

@@ -8,15 +8,21 @@ import (
const (
PeerIDUnset uint32 = 0xffffff
dataChannelReplayWindow = 64
)
type DataChannel struct {
cipher DataCipher
keyID uint8
peerID uint32
compLZO bool
cipher DataCipher
keyID uint8
peerID uint32
compLZO bool
mu sync.Mutex
sendPacketID uint32
recvHighest uint32
recvWindow uint64
recvSeen bool
}
func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel {
@@ -29,10 +35,11 @@ func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel
func (d *DataChannel) Encrypt(packet []byte) ([]byte, error) {
if d.compLZO {
p := make([]byte, 1+len(packet))
p[0] = 0xFA
copy(p[1:], packet)
packet = p
compressed, err := lzo1xCompressSafe(packet)
if err != nil {
return nil, err
}
packet = compressed
}
d.mu.Lock()
d.sendPacketID++
@@ -50,18 +57,15 @@ func (d *DataChannel) Decrypt(packet []byte) ([]byte, error) {
if opcode == PDataV2 {
headerSize = 4
}
plain, err := d.cipher.Decrypt(packet, headerSize)
plain, packetID, err := d.cipher.Decrypt(packet, headerSize)
if err != nil {
return nil, err
}
if err := d.acceptPacketID(packetID); err != nil {
return nil, err
}
if d.compLZO {
if len(plain) < 1 {
return nil, errors.New("openvpn comp-lzo packet too short")
}
if plain[0] != 0xFA {
return nil, fmt.Errorf("openvpn compressed packet not supported (byte: 0x%02x)", plain[0])
}
plain = plain[1:]
return lzo1xDecompressSafe(plain)
}
return plain, nil
}
@@ -78,6 +82,40 @@ func (d *DataChannel) dataHeader() []byte {
return []byte{opcodeKeyID(PDataV1, d.keyID)}
}
func (d *DataChannel) acceptPacketID(packetID uint32) error {
d.mu.Lock()
defer d.mu.Unlock()
if !d.recvSeen {
d.recvHighest = packetID
d.recvWindow = 1
d.recvSeen = true
return nil
}
if packetID > d.recvHighest {
shift := packetID - d.recvHighest
if shift >= dataChannelReplayWindow {
d.recvWindow = 1
} else {
d.recvWindow = d.recvWindow<<shift | 1
}
d.recvHighest = packetID
return nil
}
diff := d.recvHighest - packetID
if diff >= dataChannelReplayWindow {
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
}
mask := uint64(1) << diff
if d.recvWindow&mask != 0 {
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
}
d.recvWindow |= mask
return nil
}
func ParsePeerID(options string) uint32 {
for _, field := range splitPushOptions(options) {
if len(field) > len("peer-id ") && field[:len("peer-id ")] == "peer-id " {

View File

@@ -0,0 +1,444 @@
//go:build with_openvpn && with_gvisor
// OpenVPN E2E tests. Require a local OpenVPN server setup.
//
// Setup (run once before testing):
//
// # Generate PKI
// mkdir -p /tmp/ovpn-e2e/pki/{issued,private}
// cd /tmp/ovpn-e2e/pki
// openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -days 1 -nodes -keyout ca.key -out ca.crt -subj "/CN=E2ETestCA"
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/server.key -out server.csr -subj "/CN=server"
// openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/server.crt -days 1
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/client.key -out client.csr -subj "/CN=client"
// openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/client.crt -days 1
// openvpn --genkey secret ta.key
// openvpn --genkey secret ta-auth.key
//
// # Start servers (4 instances: TCP/UDP × tls-crypt/tls-auth)
// # TCP + tls-crypt on :11940, subnet 10.99.0.0/24
// # UDP + tls-crypt on :11941, subnet 10.99.1.0/24
// # TCP + tls-auth on :11942, subnet 10.99.2.0/24
// # UDP + tls-auth on :11943, subnet 10.99.3.0/24
// #
// # Each server config needs: topology subnet, duplicate-cn, persist-tun,
// # data-ciphers AES-256-GCM:AES-128-GCM:AES-192-GCM:CHACHA20-POLY1305:AES-256-CBC:AES-128-CBC:AES-192-CBC
// # auth SHA256, keepalive 10 60, ca/cert/key from above PKI.
// # tls-auth servers use: tls-auth ta-auth.key 0
// # tls-crypt servers use: tls-crypt ta.key
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-crypt.conf --daemon
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-crypt.conf --daemon
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-auth.conf --daemon
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-auth.conf --daemon
//
// # Start HTTP servers on each VPN subnet
// for ip in 10.99.0.1 10.99.1.1 10.99.2.1 10.99.3.1; do
// mkdir -p /tmp/ovpn-e2e/$ip && echo "hello" > /tmp/ovpn-e2e/$ip/index.html
// cd /tmp/ovpn-e2e/$ip && python3 -m http.server 8080 --bind $ip &
// done
//
// Run tests:
//
// go test -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_manager,with_admin_panel,with_v2ray_api,with_ccm,with_ocm,with_profiler,with_openvpn,with_sudoku,with_trusttunnel" \
// -run TestE2E -v -count=1 ./transport/openvpn/ -timeout 300s
//
// Tests all 28 combinations: 2 protos (tcp/udp) × 2 TLS modes (tls-crypt/tls-auth) × 7 ciphers.
package openvpn_test
import (
"context"
"fmt"
"io"
"net"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/sagernet/sing-box"
"github.com/sagernet/sing-box/include"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/json/badoption"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/protocol/socks"
)
// Servers (started externally):
// TCP+tls-crypt :11940 subnet 10.99.0.0/24
// UDP+tls-crypt :11941 subnet 10.99.1.0/24
// TCP+tls-auth :11942 subnet 10.99.2.0/24
// UDP+tls-auth :11943 subnet 10.99.3.0/24
// TCP+plain :11944 subnet 10.99.4.0/24
// UDP+plain :11945 subnet 10.99.5.0/24
// TCP+tls-crypt+SHA1 :11946 subnet 10.99.6.0/24 (CBC only)
// TCP+tls-crypt+SHA512 :11947 subnet 10.99.7.0/24 (CBC only)
// Each has HTTP on .1:8080 serving "hello"
const pkiDir = "/tmp/ovpn-e2e/pki"
type serverConfig struct {
proto string
port uint16
tlsMode string // "tls-crypt" or "tls-auth"
httpAddr string
}
var servers = []serverConfig{
{"tcp", 11940, "tls-crypt", "10.99.0.1:8080"},
{"udp", 11941, "tls-crypt", "10.99.1.1:8080"},
{"tcp", 11942, "tls-auth", "10.99.2.1:8080"},
{"udp", 11943, "tls-auth", "10.99.3.1:8080"},
}
var ciphers = []string{
"AES-128-GCM",
"AES-192-GCM",
"AES-256-GCM",
"CHACHA20-POLY1305",
"AES-128-CBC",
"AES-192-CBC",
"AES-256-CBC",
}
var portCounter atomic.Uint32
func init() { portCounter.Store(18100) }
func nextPort() uint16 { return uint16(portCounter.Add(1)) }
func readFile(t *testing.T, path string) string {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Skipf("PKI not found: %v", err)
}
return string(data)
}
func testCombo(t *testing.T, srv serverConfig, cipher string) {
t.Helper()
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
ovpnOpts := &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: srv.port}},
Proto: srv.proto,
Cipher: cipher,
Auth: "SHA256",
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
}
switch srv.tlsMode {
case "tls-crypt":
ovpnOpts.TLSCrypt = readFile(t, pkiDir+"/ta.key")
case "tls-auth":
ovpnOpts.TLSAuth = readFile(t, pkiDir+"/ta-auth.key")
ovpnOpts.KeyDirection = 1
}
port := nextPort()
opts := option.Options{
Log: &option.LogOptions{Level: "error"},
Inbounds: []option.Inbound{{
Type: "socks",
Options: &option.SocksInboundOptions{
ListenOptions: option.ListenOptions{
Listen: (*badoption.Addr)(&badoption.Addr{}),
ListenPort: port,
},
},
}},
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
Route: &option.RouteOptions{Final: "vpn"},
}
ctx := include.Context(context.Background())
instance, err := box.New(box.Options{Context: ctx, Options: opts})
if err != nil {
t.Fatal(err)
}
if err := instance.Start(); err != nil {
t.Fatal(err)
}
defer instance.Close()
time.Sleep(2 * time.Second)
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(srv.httpAddr))
if err != nil {
t.Fatal("dial:", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
if err != nil {
t.Fatal("write:", err)
}
body, err := io.ReadAll(conn)
if err != nil {
t.Fatal("read:", err)
}
if !strings.Contains(string(body), "hello") {
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
}
}
// 4 servers × 7 ciphers = 28 combinations
func TestE2E(t *testing.T) {
for _, srv := range servers {
for _, cipher := range ciphers {
name := fmt.Sprintf("%s/%s/%s", srv.proto, srv.tlsMode, cipher)
srv, cipher := srv, cipher
t.Run(name, func(t *testing.T) {
testCombo(t, srv, cipher)
})
}
}
}
// Test CBC ciphers with different auth algorithms (SHA1, SHA512)
func TestE2E_Auth(t *testing.T) {
type authServer struct {
port uint16
auth string
httpAddr string
}
authServers := []authServer{
{11946, "SHA1", "10.99.6.1:8080"},
{11947, "SHA512", "10.99.7.1:8080"},
}
cbcCiphers := []string{"AES-128-CBC", "AES-256-CBC"}
for _, as := range authServers {
for _, cipher := range cbcCiphers {
name := fmt.Sprintf("auth-%s/%s", as.auth, cipher)
as, cipher := as, cipher
t.Run(name, func(t *testing.T) {
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
tlsCrypt := readFile(t, pkiDir+"/ta.key")
port := nextPort()
opts := option.Options{
Log: &option.LogOptions{Level: "error"},
Inbounds: []option.Inbound{{
Type: "socks",
Options: &option.SocksInboundOptions{
ListenOptions: option.ListenOptions{
Listen: (*badoption.Addr)(&badoption.Addr{}),
ListenPort: port,
},
},
}},
Outbounds: []option.Outbound{{
Type: "openvpn", Tag: "vpn",
Options: &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: as.port}},
Proto: "tcp",
Cipher: cipher,
Auth: as.auth,
TLSCrypt: tlsCrypt,
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
},
}},
Route: &option.RouteOptions{Final: "vpn"},
}
ctx := include.Context(context.Background())
instance, err := box.New(box.Options{Context: ctx, Options: opts})
if err != nil {
t.Fatal(err)
}
if err := instance.Start(); err != nil {
t.Fatal(err)
}
defer instance.Close()
time.Sleep(2 * time.Second)
doHTTPCheck(t, port, as.httpAddr)
})
}
}
}
// Test tunnel stability with multiple sequential requests
func TestE2E_BulkData(t *testing.T) {
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
tlsCrypt := readFile(t, pkiDir+"/ta.key")
port := nextPort()
opts := option.Options{
Log: &option.LogOptions{Level: "error"},
Inbounds: []option.Inbound{{
Type: "socks",
Options: &option.SocksInboundOptions{
ListenOptions: option.ListenOptions{
Listen: (*badoption.Addr)(&badoption.Addr{}),
ListenPort: port,
},
},
}},
Outbounds: []option.Outbound{{
Type: "openvpn", Tag: "vpn",
Options: &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11940}},
Proto: "tcp",
Cipher: "AES-256-GCM",
Auth: "SHA256",
TLSCrypt: tlsCrypt,
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
},
}},
Route: &option.RouteOptions{Final: "vpn"},
}
ctx := include.Context(context.Background())
instance, err := box.New(box.Options{Context: ctx, Options: opts})
if err != nil {
t.Fatal(err)
}
if err := instance.Start(); err != nil {
t.Fatal(err)
}
defer instance.Close()
time.Sleep(2 * time.Second)
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
for i := 0; i < 10; i++ {
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr("10.99.0.1:8080"))
if err != nil {
t.Fatalf("request %d dial: %v", i, err)
}
conn.SetDeadline(time.Now().Add(5 * time.Second))
fmt.Fprintf(conn, "GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n")
body, err := io.ReadAll(conn)
conn.Close()
if err != nil {
t.Fatalf("request %d read: %v", i, err)
}
if !strings.Contains(string(body), "hello") {
t.Fatalf("request %d: no 'hello'", i)
}
}
}
func doHTTPCheck(t *testing.T, socksPort uint16, httpAddr string) {
t.Helper()
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", socksPort), socks.Version5, "", "")
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(httpAddr))
if err != nil {
t.Fatal("dial:", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
if err != nil {
t.Fatal("write:", err)
}
body, err := io.ReadAll(conn)
if err != nil {
t.Fatal("read:", err)
}
if !strings.Contains(string(body), "hello") {
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
}
}
func startInstance(t *testing.T, ovpnOpts *option.OpenVPNOutboundOptions) uint16 {
t.Helper()
port := nextPort()
opts := option.Options{
Log: &option.LogOptions{Level: "error"},
Inbounds: []option.Inbound{{
Type: "socks",
Options: &option.SocksInboundOptions{
ListenOptions: option.ListenOptions{
Listen: (*badoption.Addr)(&badoption.Addr{}),
ListenPort: port,
},
},
}},
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
Route: &option.RouteOptions{Final: "vpn"},
}
ctx := include.Context(context.Background())
instance, err := box.New(box.Options{Context: ctx, Options: opts})
if err != nil {
t.Fatal(err)
}
if err := instance.Start(); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { instance.Close() })
time.Sleep(2 * time.Second)
return port
}
func TestE2E_CompLZO(t *testing.T) {
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
tlsCrypt := readFile(t, pkiDir+"/ta.key")
for _, cipher := range ciphers {
cipher := cipher
t.Run(cipher, func(t *testing.T) {
port := startInstance(t, &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11948}},
Proto: "udp",
Cipher: cipher,
Auth: "SHA256",
TLSCrypt: tlsCrypt,
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
})
doHTTPCheck(t, port, "10.99.8.1:8080")
})
}
}
func TestE2E_AES192(t *testing.T) {
ca := readFile(t, pkiDir+"/ca.crt")
cert := readFile(t, pkiDir+"/issued/client.crt")
key := readFile(t, pkiDir+"/private/client.key")
tlsCrypt := readFile(t, pkiDir+"/ta.key")
type combo struct {
proto string
port uint16
httpAddr string
}
for _, c := range []combo{
{"tcp", 11940, "10.99.0.1:8080"},
{"udp", 11941, "10.99.1.1:8080"},
} {
for _, cipher := range []string{"AES-192-GCM", "AES-192-CBC"} {
c, cipher := c, cipher
t.Run(fmt.Sprintf("%s/%s", c.proto, cipher), func(t *testing.T) {
port := startInstance(t, &option.OpenVPNOutboundOptions{
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: c.port}},
Proto: c.proto,
Cipher: cipher,
Auth: "SHA256",
TLSCrypt: tlsCrypt,
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
},
})
doHTTPCheck(t, port, c.httpAddr)
})
}
}
}
var _ net.Conn

View File

@@ -114,7 +114,7 @@ func ParseServerKeyMethod2Record(packet []byte) (*KeyMethod2Record, error) {
}
func DeriveClientKeyMaterial(sources KeySource2, clientSession, serverSession SessionID, cipherKeyLen int) (*KeyMaterial, error) {
if cipherKeyLen != 16 && cipherKeyLen != 32 {
if cipherKeyLen != 16 && cipherKeyLen != 24 && cipherKeyLen != 32 {
return nil, fmt.Errorf("unsupported data cipher key length %d", cipherKeyLen)
}
var master [48]byte

48
transport/openvpn/lzo.go Normal file
View File

@@ -0,0 +1,48 @@
package openvpn
import (
"bytes"
"errors"
"github.com/rasky/go-lzo"
)
const (
lzoCompressNone = 0xFA
lzoCompressLZO = 0x66
)
var ErrLZODecompress = errors.New("lzo decompression failed")
func lzo1xDecompressSafe(src []byte) ([]byte, error) {
if len(src) == 0 {
return nil, ErrLZODecompress
}
switch src[0] {
case lzoCompressNone:
if len(src) > 1 {
return src[1:], nil
}
return nil, nil
case lzoCompressLZO:
if len(src) > 1 {
r := bytes.NewReader(src[1:])
out, err := lzo.Decompress1X(r, len(src)-1, 0)
if err != nil {
return nil, ErrLZODecompress
}
return out, nil
}
return nil, nil
default:
return nil, ErrLZODecompress
}
}
func lzo1xCompressSafe(src []byte) ([]byte, error) {
lzoPacket := make([]byte, 1+len(src))
lzoPacket[0] = lzoCompressNone
copy(lzoPacket[1:], src)
return lzoPacket, nil
}

View File

@@ -10,16 +10,17 @@ import (
const PushRequest = "PUSH_REQUEST"
type PushReply struct {
Raw string
Prefixes []netip.Prefix
DNS []netip.Addr
PeerID uint32
Cipher string
Ping uint32
MTU uint32
CompLZO bool
Redirect bool
BlockIPv6 bool
Raw string
Prefixes []netip.Prefix
DNS []netip.Addr
PeerID uint32
Cipher string
Ping uint32
PingRestart uint32
MTU uint32
CompLZO bool
Redirect bool
BlockIPv6 bool
}
func ParsePushReply(message string) (*PushReply, error) {
@@ -81,6 +82,12 @@ func ParsePushReply(message string) (*PushReply, error) {
reply.Ping = uint32(v)
}
}
case "ping-restart":
if len(fields) >= 2 {
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
reply.PingRestart = uint32(v)
}
}
case "tun-mtu":
if len(fields) >= 2 {
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
@@ -113,27 +120,44 @@ func splitPushOptions(message string) []string {
return out
}
func parseIPv4Ifconfig(address, mask string) (netip.Prefix, error) {
func parseIPv4Ifconfig(address, maskOrPeer string) (netip.Prefix, error) {
addr, err := netip.ParseAddr(address)
if err != nil {
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 address %q: %w", address, err)
}
maskAddr, err := netip.ParseAddr(mask)
maskAddr, err := netip.ParseAddr(maskOrPeer)
if err != nil {
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", mask, err)
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", maskOrPeer, err)
}
if !addr.Is4() || !maskAddr.Is4() {
return netip.Prefix{}, fmt.Errorf("openvpn ifconfig requires ipv4 address and mask")
}
maskBytes := maskAddr.As4()
if ones, ok := ipv4MaskSize(maskAddr); ok {
return netip.PrefixFrom(addr, ones), nil
}
// Some servers, including SoftEther/VPNGate in net30/p2p mode, push
// "ifconfig <local> <remote>" rather than "ifconfig <local> <netmask>".
// Use a host prefix for that local tunnel address.
return netip.PrefixFrom(addr, 32), nil
}
func ipv4MaskSize(mask netip.Addr) (int, bool) {
maskBytes := mask.As4()
ones := 0
seenZero := false
for _, b := range maskBytes {
for i := 7; i >= 0; i-- {
if b&(1<<i) == 0 {
return netip.PrefixFrom(addr, ones), nil
seenZero = true
continue
}
if seenZero {
return 0, false
}
ones++
}
}
return netip.PrefixFrom(addr, ones), nil
return ones, true
}

View File

@@ -2,6 +2,7 @@ package openvpn
import (
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
@@ -35,6 +36,9 @@ func NewTLSAuth(staticKey []byte, keyDirection int, auth string) (*TLSAuth, erro
var newHash func() hash.Hash
var hmacSize int
switch auth {
case AuthMD5:
newHash = md5.New
hmacSize = md5.Size
case AuthSHA256:
newHash = sha256.New
hmacSize = sha256.Size

View File

@@ -30,16 +30,19 @@ type TunnelOptions struct {
UDPTimeout time.Duration
ReconnectDelay time.Duration
PingInterval time.Duration
PingRestart time.Duration
}
type Tunnel struct {
ctx context.Context
cancel context.CancelFunc
logger logger.ContextLogger
options TunnelOptions
device Device
client *Client
mtu uint32
serverIndex int
wg sync.WaitGroup
await chan struct{}
mu sync.Mutex
@@ -49,8 +52,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
if options.ReconnectDelay == 0 {
options.ReconnectDelay = 5 * time.Second
}
ctx, cancel := context.WithCancel(ctx)
return &Tunnel{
ctx: ctx,
cancel: cancel,
logger: logger,
options: options,
await: make(chan struct{}),
@@ -59,10 +64,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
func (t *Tunnel) Start() error {
go func() {
defer close(t.await)
client, err := t.getClient()
if err != nil {
t.logger.Error("OpenVPN connect: ", err)
close(t.await)
return
}
t.mtu = 1500
@@ -84,20 +89,26 @@ func (t *Tunnel) Start() error {
if err != nil {
client.Close()
t.logger.Error("create OpenVPN device: ", err)
close(t.await)
return
}
t.device = device
if err := device.Start(); err != nil {
client.Close()
t.logger.Error("start OpenVPN device: ", err)
close(t.await)
return
}
close(t.await)
t.maintainTunnel()
}()
return nil
}
func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if err := t.isTunnelInitialized(ctx); err != nil {
return nil, err
}
if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
@@ -105,6 +116,9 @@ func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.
}
func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if err := t.isTunnelInitialized(ctx); err != nil {
return nil, err
}
if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
@@ -112,15 +126,18 @@ func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
}
func (t *Tunnel) Close() error {
t.cancel()
t.mu.Lock()
defer t.mu.Unlock()
if t.client != nil {
t.client.Close()
t.client = nil
}
if t.device != nil {
return t.device.Close()
t.device.Close()
t.device = nil
}
t.mu.Unlock()
t.wg.Wait()
return nil
}
@@ -137,7 +154,9 @@ func (t *Tunnel) isTunnelInitialized(ctx context.Context) error {
}
func (t *Tunnel) maintainTunnel() {
t.wg.Add(2)
go func() {
defer t.wg.Done()
bufs := make([][]byte, 1)
bufs[0] = make([]byte, t.mtu)
sizes := make([]int, 1)
@@ -161,6 +180,7 @@ func (t *Tunnel) maintainTunnel() {
}
}()
go func() {
defer t.wg.Done()
for t.ctx.Err() == nil {
client, err := t.getClient()
if err != nil {
@@ -179,10 +199,14 @@ func (t *Tunnel) maintainTunnel() {
if bytes.Equal(packet, pingPayload) {
continue
}
if t.ctx.Err() != nil {
return
}
if t.ctx.Err() != nil {
return
}
if _, err := t.device.Write([][]byte{packet}, 0); err != nil {
if t.ctx.Err() != nil {
return
}
return
}
}
}()
@@ -208,6 +232,34 @@ func (t *Tunnel) maintainTunnel() {
}
}()
}
pingRestart := t.options.PingRestart
if pingRestart == 0 && t.client != nil && t.client.push.PingRestart > 0 {
pingRestart = time.Duration(t.client.push.PingRestart) * time.Second
}
if pingRestart > 0 {
t.wg.Add(1)
go func() {
defer t.wg.Done()
ticker := time.NewTicker(pingRestart)
defer ticker.Stop()
for {
select {
case <-t.ctx.Done():
return
case <-ticker.C:
client, err := t.getClient()
if err != nil {
return
}
if client.SinceReceive() >= pingRestart {
if ok := t.closeClient(client); ok {
t.logger.ErrorContext(t.ctx, fmt.Errorf("ping-restart timeout: no packet received for %s", pingRestart))
}
}
}
}
}()
}
<-t.ctx.Done()
}

View File

@@ -0,0 +1,100 @@
package obfs
import (
"bufio"
cryptorand "crypto/rand"
"encoding/base64"
"fmt"
"io"
"math/rand/v2"
"net"
"net/http"
"time"
)
// HTTPObfsServer is the server side of the shadowsocks http simple-obfs implementation.
type HTTPObfsServer struct {
net.Conn
buf []byte
bio *bufio.Reader
offset int
firstRequest bool
firstResponse bool
}
func (hos *HTTPObfsServer) Read(b []byte) (int, error) {
if hos.buf != nil {
n := copy(b, hos.buf[hos.offset:])
hos.offset += n
if hos.offset == len(hos.buf) {
hos.offset = 0
hos.buf = nil
}
return n, nil
}
if hos.firstRequest {
bio := bufio.NewReader(hos.Conn)
req, err := http.ReadRequest(bio)
if err != nil {
return 0, err
}
if req.Method != "GET" || req.Header.Get("Connection") != "Upgrade" {
return 0, io.EOF
}
buf, err := io.ReadAll(req.Body)
if err != nil {
return 0, err
}
n := copy(b, buf)
if n < len(buf) {
hos.buf = buf
hos.offset = n
}
req.Body.Close()
hos.bio = bio
hos.firstRequest = false
return n, nil
}
return hos.bio.Read(b)
}
func (hos *HTTPObfsServer) Write(b []byte) (int, error) {
if hos.firstResponse {
randBytes := make([]byte, 16)
cryptorand.Read(randBytes)
date := time.Now().Format(time.RFC1123)
resp := fmt.Sprintf(httpResponseTemplate, vMajor, vMinor, date, base64.URLEncoding.EncodeToString(randBytes))
_, err := hos.Conn.Write([]byte(resp))
if err != nil {
return 0, err
}
hos.firstResponse = false
}
return hos.Conn.Write(b)
}
func (hos *HTTPObfsServer) Upstream() any {
return hos.Conn
}
// NewHTTPObfsServer returns a server-side HTTPObfs.
func NewHTTPObfsServer(conn net.Conn) net.Conn {
return &HTTPObfsServer{
Conn: conn,
firstRequest: true,
firstResponse: true,
}
}
const httpResponseTemplate = "HTTP/1.1 101 Switching Protocols\r\n" +
"Server: nginx/1.%d.%d\r\n" +
"Date: %s\r\n" +
"Upgrade: websocket\r\n" +
"Connection: Upgrade\r\n" +
"Sec-WebSocket-Accept: %s\r\n" +
"\r\n"
var (
vMajor = rand.IntN(11)
vMinor = rand.IntN(12)
)

View File

@@ -0,0 +1,154 @@
package obfs
import (
"bytes"
"crypto/rand"
"encoding/binary"
"io"
"net"
"time"
B "github.com/sagernet/sing/common/buf"
)
// TLSObfsServer is the server side of the shadowsocks tls simple-obfs implementation.
type TLSObfsServer struct {
net.Conn
remain int
firstRequest bool
sessionTicketDone bool
firstResponse bool
}
func (tos *TLSObfsServer) read(b []byte, discardN int) (int, error) {
buf := B.Get(discardN)
_, err := io.ReadFull(tos.Conn, buf)
B.Put(buf)
if err != nil {
return 0, err
}
sizeBuf := make([]byte, 2)
_, err = io.ReadFull(tos.Conn, sizeBuf)
if err != nil {
return 0, nil
}
length := int(binary.BigEndian.Uint16(sizeBuf))
if length > len(b) {
n, err := tos.Conn.Read(b)
if err != nil {
return n, err
}
tos.remain = length - n
return n, nil
}
return io.ReadFull(tos.Conn, b[:length])
}
// skipOtherExts skips SNI & other TLS extensions.
func (tos *TLSObfsServer) skipOtherExts() error {
buf := make([]byte, 256)
_, err := tos.read(buf, 7)
if err != nil {
return err
}
_, err = io.ReadFull(tos.Conn, buf[:4*16+2])
return err
}
func (tos *TLSObfsServer) Read(b []byte) (int, error) {
if tos.remain > 0 {
length := tos.remain
if length > len(b) {
length = len(b)
}
n, err := io.ReadFull(tos.Conn, b[:length])
tos.remain -= n
return n, err
}
if tos.firstRequest {
tos.firstRequest = false
return tos.read(b, 9*16-4)
}
if !tos.sessionTicketDone {
tos.sessionTicketDone = true
err := tos.skipOtherExts()
if err != nil {
return 0, err
}
}
return tos.read(b, 3)
}
func (tos *TLSObfsServer) Write(b []byte) (int, error) {
length := len(b)
for i := 0; i < length; i += chunkSize {
end := i + chunkSize
if end > length {
end = length
}
n, err := tos.write(b[i:end])
if err != nil {
return n, err
}
}
return length, nil
}
func (tos *TLSObfsServer) write(b []byte) (int, error) {
if tos.firstResponse {
serverHello := makeServerHello(b)
_, err := tos.Conn.Write(serverHello)
tos.firstResponse = false
return len(b), err
}
buf := B.NewSize(5 + len(b))
defer buf.Release()
buf.Write([]byte{0x17, 0x03, 0x03})
binary.Write(buf, binary.BigEndian, uint16(len(b)))
buf.Write(b)
_, err := tos.Conn.Write(buf.Bytes())
if err != nil {
return 0, err
}
return len(b), nil
}
func (tos *TLSObfsServer) Upstream() any {
return tos.Conn
}
// NewTLSObfsServer returns a server-side TLS SimpleObfs.
func NewTLSObfsServer(conn net.Conn) net.Conn {
return &TLSObfsServer{
Conn: conn,
firstRequest: true,
firstResponse: true,
}
}
func makeServerHello(data []byte) []byte {
randBytes := make([]byte, 28)
sessionId := make([]byte, 32)
rand.Read(randBytes)
rand.Read(sessionId)
buf := &bytes.Buffer{}
buf.WriteByte(0x16)
binary.Write(buf, binary.BigEndian, uint16(0x0301))
binary.Write(buf, binary.BigEndian, uint16(91))
buf.Write([]byte{2, 0, 0, 87, 0x03, 0x03})
binary.Write(buf, binary.BigEndian, uint32(time.Now().Unix()))
buf.Write(randBytes)
buf.WriteByte(32)
buf.Write(sessionId)
buf.Write([]byte{0xcc, 0xa8})
buf.WriteByte(0)
buf.Write([]byte{0x00, 0x00})
buf.Write([]byte{0xff, 0x01, 0x00, 0x01, 0x00})
buf.Write([]byte{0x00, 0x17, 0x00, 0x00})
buf.Write([]byte{0x00, 0x0b, 0x00, 0x02, 0x01, 0x00})
buf.Write([]byte{0x14, 0x03, 0x03, 0x00, 0x01, 0x01})
buf.Write([]byte{0x16, 0x03, 0x03})
binary.Write(buf, binary.BigEndian, uint16(len(data)))
buf.Write(data)
return buf.Bytes()
}

144
transport/snell/address.go Normal file
View File

@@ -0,0 +1,144 @@
package snell
import (
"encoding/binary"
"net"
"strconv"
)
// SOCKS address types as defined in RFC 1928 section 5.
const (
atypIPv4 = 1
atypDomainName = 3
atypIPv6 = 4
)
// socksAddr represents a SOCKS address as defined in RFC 1928 section 5.
type socksAddr []byte
func (a socksAddr) String() string {
var host, port string
switch a[0] {
case atypDomainName:
hostLen := uint16(a[1])
host = string(a[2 : 2+hostLen])
port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
case atypIPv4:
host = net.IP(a[1 : 1+net.IPv4len]).String()
port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
case atypIPv6:
host = net.IP(a[1 : 1+net.IPv6len]).String()
port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
}
return net.JoinHostPort(host, port)
}
// UDPAddr converts a socksAddr to *net.UDPAddr.
func (a socksAddr) UDPAddr() *net.UDPAddr {
if len(a) == 0 {
return nil
}
switch a[0] {
case atypIPv4:
var ip [net.IPv4len]byte
copy(ip[0:], a[1:1+net.IPv4len])
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
case atypIPv6:
var ip [net.IPv6len]byte
copy(ip[0:], a[1:1+net.IPv6len])
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
}
return nil
}
// splitSocksAddr slices a SOCKS address from beginning of b. Returns nil if failed.
func splitSocksAddr(b []byte) socksAddr {
addrLen := 1
if len(b) < addrLen {
return nil
}
switch b[0] {
case atypDomainName:
if len(b) < 2 {
return nil
}
addrLen = 1 + 1 + int(b[1]) + 2
case atypIPv4:
addrLen = 1 + net.IPv4len + 2
case atypIPv6:
addrLen = 1 + net.IPv6len + 2
default:
return nil
}
if len(b) < addrLen {
return nil
}
return b[:addrLen]
}
// parseAddr parses the address in string s. Returns nil if failed.
func parseAddr(s string) socksAddr {
var addr socksAddr
host, port, err := net.SplitHostPort(s)
if err != nil {
return nil
}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
addr = make([]byte, 1+net.IPv4len+2)
addr[0] = atypIPv4
copy(addr[1:], ip4)
} else {
addr = make([]byte, 1+net.IPv6len+2)
addr[0] = atypIPv6
copy(addr[1:], ip)
}
} else {
if len(host) > 255 {
return nil
}
addr = make([]byte, 1+1+len(host)+2)
addr[0] = atypDomainName
addr[1] = byte(len(host))
copy(addr[2:], host)
}
portnum, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil
}
addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
return addr
}
// parseAddrToSocksAddr parses a socks addr from net.Addr.
// This is a fast path of parseAddr(addr.String()).
func parseAddrToSocksAddr(addr net.Addr) socksAddr {
var hostip net.IP
var port int
switch addr := addr.(type) {
case *net.UDPAddr:
hostip = addr.IP
port = addr.Port
case *net.TCPAddr:
hostip = addr.IP
port = addr.Port
case nil:
return nil
}
if hostip == nil {
return parseAddr(addr.String())
}
var parsed socksAddr
if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
parsed = make([]byte, 1+net.IPv4len+2)
parsed[0] = atypIPv4
copy(parsed[1:], ip4)
binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
} else {
parsed = make([]byte, 1+net.IPv6len+2)
parsed[0] = atypIPv6
copy(parsed[1:], hostip)
binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
}
return parsed
}

56
transport/snell/cipher.go Normal file
View File

@@ -0,0 +1,56 @@
package snell
import (
"crypto/aes"
"crypto/cipher"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/chacha20poly1305"
)
// NewAES128GCM returns the AES-128-GCM cipher used by snell v2/v3.
func NewAES128GCM(psk []byte) Cipher {
return &snellCipher{
psk: psk,
keySize: 16,
makeAEAD: aesGCM,
}
}
// NewChacha20Poly1305 returns the ChaCha20-Poly1305 cipher used by snell v1.
func NewChacha20Poly1305(psk []byte) Cipher {
return &snellCipher{
psk: psk,
keySize: 32,
makeAEAD: chacha20poly1305.New,
}
}
type snellCipher struct {
psk []byte
keySize int
makeAEAD func(key []byte) (cipher.AEAD, error)
}
func (sc *snellCipher) KeySize() int { return sc.keySize }
func (sc *snellCipher) SaltSize() int { return 16 }
func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) {
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
}
func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) {
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
}
func snellKDF(psk, salt []byte, keySize int) []byte {
return argon2.IDKey(psk, salt, 3, 8, 1, 32)[:keySize]
}
func aesGCM(key []byte) (cipher.AEAD, error) {
blk, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(blk)
}

120
transport/snell/client.go Normal file
View File

@@ -0,0 +1,120 @@
package snell
import (
"context"
"net"
"strconv"
"time"
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type ClientOptions struct {
Dialer N.Dialer
Server M.Socksaddr
PSK []byte
Version int
Reuse bool
ObfsMode string
ObfsHost string
}
type Client struct {
dialer N.Dialer
server M.Socksaddr
psk []byte
version int
reuse bool
obfsMode string
obfsHost string
pool *Pool
}
func NewClient(options ClientOptions) *Client {
c := &Client{
dialer: options.Dialer,
server: options.Server,
psk: options.PSK,
version: options.Version,
reuse: options.Reuse,
obfsMode: options.ObfsMode,
obfsHost: options.ObfsHost,
}
if c.reuse {
c.pool = NewPool(func(ctx context.Context) (*Snell, error) {
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
if err != nil {
return nil, err
}
return c.streamConn(conn), nil
})
}
return c
}
func (c *Client) streamConn(conn net.Conn) *Snell {
switch c.obfsMode {
case "tls":
conn = obfs.NewTLSObfs(conn, c.obfsHost)
case "http":
conn = obfs.NewHTTPObfs(conn, c.obfsHost, strconv.Itoa(int(c.server.Port)))
}
return StreamConn(conn, c.psk, c.version)
}
func (c *Client) writeHeader(ctx context.Context, conn net.Conn, destination M.Socksaddr, udp bool) (err error) {
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetWriteDeadline(deadline)
defer conn.SetWriteDeadline(time.Time{})
}
if udp {
err = WriteUDPHeader(conn, c.version)
if err == nil && c.version >= Version4 {
if sc, ok := conn.(*Snell); ok {
err = sc.ReadReply()
}
}
return
}
err = WriteHeaderWithReuse(conn, destination.AddrString(), uint(destination.Port), c.version, c.reuse)
return
}
func (c *Client) DialContext(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
if c.reuse {
conn, err := c.pool.Get()
if err != nil {
return nil, err
}
if err = c.writeHeader(ctx, conn, destination, false); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
if err != nil {
return nil, err
}
stream := c.streamConn(conn)
if err = c.writeHeader(ctx, stream, destination, false); err != nil {
_ = conn.Close()
return nil, err
}
return stream, nil
}
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
if err != nil {
return nil, err
}
stream := c.streamConn(conn)
if err = c.writeHeader(ctx, stream, destination, true); err != nil {
_ = conn.Close()
return nil, err
}
return PacketConn(stream), nil
}

153
transport/snell/pool.go Normal file
View File

@@ -0,0 +1,153 @@
package snell
import (
"context"
"io"
"net"
"runtime"
"sync"
"time"
)
// poolEntry holds a pooled item with its insertion time.
// connPool is a small connection pool with age-based eviction.
// milliseconds
// Pool is a pool of reusable snell connections.
type Pool struct {
pool *connPool
}
func (p *Pool) Get() (net.Conn, error) {
return p.GetContext(context.Background())
}
func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
elm, err := p.pool.GetContext(ctx)
if err != nil {
return nil, err
}
return &PoolConn{Snell: elm, pool: p}, nil
}
func (p *Pool) Put(conn *Snell) {
if err := HalfClose(conn); err != nil {
_ = conn.Close()
return
}
p.pool.put(conn)
}
// PoolConn wraps a pooled snell connection and returns it to the pool on Close.
type PoolConn struct {
*Snell
pool *Pool
closeWriteOnce sync.Once
closeWriteErr error
closeOnce sync.Once
closeErr error
}
func (pc *PoolConn) Read(b []byte) (int, error) {
n, err := pc.Snell.Read(b)
if err == ErrZeroChunk {
return n, io.EOF
}
return n, err
}
func (pc *PoolConn) Write(b []byte) (int, error) {
return pc.Snell.Write(b)
}
func (pc *PoolConn) CloseWrite() error {
pc.closeWriteOnce.Do(func() {
pc.closeWriteErr = writeZeroChunk(pc.Snell)
})
return pc.closeWriteErr
}
func (pc *PoolConn) Close() error {
pc.closeOnce.Do(func() {
if err := pc.CloseWrite(); err != nil {
pc.closeErr = err
_ = pc.Snell.Close()
return
}
_ = pc.Snell.Conn.SetReadDeadline(time.Time{})
pc.Snell.reply = false
pc.pool.pool.put(pc.Snell)
})
return pc.closeErr
}
// NewPool creates a new snell connection pool using the given factory.
func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
cp := &connPool{
ch: make(chan *poolEntry, 10),
factory: factory,
maxAge: 15000,
evict: func(item *Snell) {
_ = item.Close()
},
}
p := &Pool{pool: cp}
runtime.SetFinalizer(p, recycle)
return p
}
type poolEntry struct {
elm *Snell
time time.Time
}
type connPool struct {
ch chan *poolEntry
factory func(context.Context) (*Snell, error)
evict func(*Snell)
maxAge int64
}
func (p *connPool) GetContext(ctx context.Context) (*Snell, error) {
now := time.Now()
for {
select {
case item := <-p.ch:
if p.maxAge != 0 && now.Sub(item.time).Milliseconds() > p.maxAge {
if p.evict != nil {
p.evict(item.elm)
}
continue
}
return item.elm, nil
default:
return p.factory(ctx)
}
}
}
func (p *connPool) put(item *Snell) {
e := &poolEntry{
elm: item,
time: time.Now(),
}
select {
case p.ch <- e:
return
default:
if p.evict != nil {
p.evict(item)
}
return
}
}
func recycle(p *Pool) {
for item := range p.pool.ch {
if p.pool.evict != nil {
p.pool.evict(item.elm)
}
}
}

294
transport/snell/service.go Normal file
View File

@@ -0,0 +1,294 @@
package snell
import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
)
type Handler interface {
NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string)
NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string)
}
type ServiceOptions struct {
PSK []byte
Version int
ObfsMode string
UDP bool
Logger logger.ContextLogger
Handler Handler
}
type Service struct {
psk []byte
version int
obfsMode string
udp bool
logger logger.ContextLogger
handler Handler
}
func NewService(options ServiceOptions) (*Service, error) {
version := options.Version
if version == 0 {
version = Version4
}
if version != Version4 && version != Version5 {
return nil, fmt.Errorf("snell inbound version %d is not supported", version)
}
if len(options.PSK) == 0 {
return nil, errors.New("snell inbound requires psk")
}
switch options.ObfsMode {
case "", "http", "tls":
default:
return nil, fmt.Errorf("snell inbound obfs mode error: %s", options.ObfsMode)
}
return &Service{
psk: options.PSK,
version: version,
obfsMode: options.ObfsMode,
udp: options.UDP,
logger: options.Logger,
handler: options.Handler,
}, nil
}
func (s *Service) NewConnection(ctx context.Context, rawConn net.Conn, source M.Socksaddr) error {
conn := rawConn
switch s.obfsMode {
case "http":
conn = obfs.NewHTTPObfsServer(conn)
case "tls":
conn = obfs.NewTLSObfsServer(conn)
}
stream := ServerStreamConn(conn, s.psk, s.version)
for {
reuse, err := s.handleRequest(ctx, stream, source)
if err != nil || !reuse {
return err
}
}
}
func (s *Service) handleRequest(ctx context.Context, stream *Snell, source M.Socksaddr) (bool, error) {
br := bufio.NewReader(stream)
version, err := br.ReadByte()
if err != nil {
return false, err
}
if version != Version {
return false, fmt.Errorf("snell invalid protocol version: %d", version)
}
command, err := br.ReadByte()
if err != nil {
return false, err
}
if command == CommandPing {
_, _ = stream.Write([]byte{CommandPong})
return false, nil
}
clientID, err := readClientID(br)
if err != nil {
return false, err
}
switch command {
case CommandConnect, CommandConnectV2:
return s.handleTCP(ctx, stream, br, command == CommandConnectV2, clientID, source)
case CommandUDP:
if !s.udp {
return false, errors.New("snell UDP is disabled")
}
return false, s.handleUDP(ctx, stream, clientID, source)
default:
return false, fmt.Errorf("snell unknown command: %d", command)
}
}
func (s *Service) handleTCP(ctx context.Context, stream *Snell, br *bufio.Reader, reuse bool, clientID string, source M.Socksaddr) (bool, error) {
hostLen, err := br.ReadByte()
if err != nil {
return false, err
}
if hostLen == 0 {
return false, errors.New("snell connect host is empty")
}
hostBytes := make([]byte, int(hostLen))
if _, err := io.ReadFull(br, hostBytes); err != nil {
return false, err
}
var portBytes [2]byte
if _, err := io.ReadFull(br, portBytes[:]); err != nil {
return false, err
}
destination := M.ParseSocksaddrHostPort(string(hostBytes), binary.BigEndian.Uint16(portBytes[:]))
conn := &tcpRequestConn{
Conn: stream,
reader: br,
reuse: reuse,
}
s.handler.NewConnection(ctx, conn, source, destination, clientID)
if !reuse {
return false, nil
}
return true, nil
}
func (s *Service) handleUDP(ctx context.Context, stream *Snell, clientID string, source M.Socksaddr) error {
if _, err := stream.Write([]byte{CommandTunnel}); err != nil {
return err
}
pc := &serverPacketConn{
conn: stream,
writeMu: &sync.Mutex{},
}
s.handler.NewPacketConnection(ctx, pc, source, clientID)
return nil
}
const maxPacketLength = 0x3fff
func readClientID(r *bufio.Reader) (string, error) {
length, err := r.ReadByte()
if err != nil {
return "", err
}
if length == 0 {
return "", nil
}
id := make([]byte, int(length))
if _, err := io.ReadFull(r, id); err != nil {
return "", err
}
return string(id), nil
}
func writeCommandError(w io.Writer, code byte, message string) error {
msg := []byte(message)
if len(msg) > 255 {
msg = msg[:255]
}
buf := make([]byte, 0, 3+len(msg))
buf = append(buf, CommandError, code, byte(len(msg)))
buf = append(buf, msg...)
_, err := w.Write(buf)
return err
}
type tcpRequestConn struct {
net.Conn
reader *bufio.Reader
reuse bool
writeMu sync.Mutex
closeOnce sync.Once
replyWritten bool
}
func (c *tcpRequestConn) Read(p []byte) (int, error) {
n, err := c.reader.Read(p)
if errors.Is(err, ErrZeroChunk) {
err = io.EOF
}
return n, err
}
func (c *tcpRequestConn) Write(p []byte) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
if !c.replyWritten {
payload := make([]byte, 1+len(p))
payload[0] = CommandTunnel
copy(payload[1:], p)
if _, err := c.Conn.Write(payload); err != nil {
return 0, err
}
c.replyWritten = true
return len(p), nil
}
return c.Conn.Write(p)
}
func (c *tcpRequestConn) CloseWrite() error {
return c.Close()
}
func (c *tcpRequestConn) Close() error {
var err error
c.closeOnce.Do(func() {
c.writeMu.Lock()
defer c.writeMu.Unlock()
if !c.replyWritten {
err = writeCommandError(c.Conn, 0x65, "Remote EOF")
if !c.reuse {
err = errors.Join(err, c.Conn.Close())
}
return
}
if c.reuse {
_, err = c.Conn.Write(nil)
return
}
err = c.Conn.Close()
})
return err
}
type serverPacketConn struct {
conn *Snell
writeMu *sync.Mutex
readBuf []byte
}
func (c *serverPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
if c.readBuf == nil {
c.readBuf = make([]byte, maxPacketLength)
}
for {
n, err := c.conn.Read(c.readBuf)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, ErrZeroChunk) {
return 0, nil, io.EOF
}
return 0, nil, err
}
request, err := ParseUDPRequest(c.readBuf[:n])
if err != nil {
return 0, nil, err
}
var destination M.Socksaddr
if request.Ip.IsValid() {
destination = M.SocksaddrFrom(request.Ip, request.Port)
} else {
destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
}
length := copy(p, request.Payload)
if destination.IsFqdn() {
return length, destination, nil
}
return length, destination.UDPAddr(), nil
}
}
func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
return WritePacketResponse(c.conn, addr, p)
}
func (c *serverPacketConn) Close() error { return c.conn.Close() }
func (c *serverPacketConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c *serverPacketConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c *serverPacketConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c *serverPacketConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }

View File

@@ -0,0 +1,211 @@
package snell
import (
"crypto/cipher"
"crypto/rand"
"errors"
"io"
"net"
"github.com/sagernet/sing/common/buf"
)
// payloadSizeMask is the maximum size of payload in bytes.
// 16*1024 - 1
// >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead()
// ErrZeroChunk is returned when a zero-length chunk is read, which snell uses
// as an end-of-stream signal.
var ErrZeroChunk = errors.New("zero chunk")
// Cipher is the AEAD cipher abstraction used by the shadowaead stream.
type Cipher interface {
KeySize() int
SaltSize() int
Encrypter(salt []byte) (cipher.AEAD, error)
Decrypter(salt []byte) (cipher.AEAD, error)
}
const (
payloadSizeMask = 0x3FFF
bufSize = 17 * 1024
)
type aeadWriter struct {
io.Writer
cipher.AEAD
nonce [32]byte // should be sufficient for most nonce sizes
}
// newAEADWriter wraps an io.Writer with authenticated encryption.
func newAEADWriter(w io.Writer, aead cipher.AEAD) *aeadWriter {
return &aeadWriter{Writer: w, AEAD: aead}
}
// Write encrypts p and writes to the embedded io.Writer.
func (w *aeadWriter) Write(p []byte) (n int, err error) {
b := buf.Get(bufSize)
defer buf.Put(b)
nonce := w.nonce[:w.NonceSize()]
tag := w.Overhead()
off := 2 + tag
if len(p) == 0 {
b = b[:off]
b[0], b[1] = byte(0), byte(0)
w.Seal(b[:0], nonce, b[:2], nil)
increment(nonce)
_, err = w.Writer.Write(b)
return
}
for nr := 0; n < len(p) && err == nil; n += nr {
nr = payloadSizeMask
if n+nr > len(p) {
nr = len(p) - n
}
b = b[:off+nr+tag]
b[0], b[1] = byte(nr>>8), byte(nr)
w.Seal(b[:0], nonce, b[:2], nil)
increment(nonce)
w.Seal(b[:off], nonce, p[n:n+nr], nil)
increment(nonce)
_, err = w.Writer.Write(b)
}
return
}
type aeadReader struct {
io.Reader
cipher.AEAD
nonce [32]byte // should be sufficient for most nonce sizes
buf []byte // to be put back into bufPool
off int // offset to unconsumed part of buf
}
// newAEADReader wraps an io.Reader with authenticated decryption.
func newAEADReader(r io.Reader, aead cipher.AEAD) *aeadReader {
return &aeadReader{Reader: r, AEAD: aead}
}
// read and decrypt a record into p. len(p) >= max payload size + AEAD overhead.
func (r *aeadReader) read(p []byte) (int, error) {
nonce := r.nonce[:r.NonceSize()]
tag := r.Overhead()
p = p[:2+tag]
if _, err := io.ReadFull(r.Reader, p); err != nil {
return 0, err
}
_, err := r.Open(p[:0], nonce, p, nil)
increment(nonce)
if err != nil {
return 0, err
}
size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask
if size == 0 {
return 0, ErrZeroChunk
}
p = p[:size+tag]
if _, err := io.ReadFull(r.Reader, p); err != nil {
return 0, err
}
_, err = r.Open(p[:0], nonce, p, nil)
increment(nonce)
if err != nil {
return 0, err
}
return size, nil
}
// Read reads from the embedded io.Reader, decrypts and writes to p.
func (r *aeadReader) Read(p []byte) (int, error) {
if r.buf == nil {
if len(p) >= payloadSizeMask+r.Overhead() {
return r.read(p)
}
b := buf.Get(bufSize)
n, err := r.read(b)
if err != nil {
buf.Put(b)
return 0, err
}
r.buf = b[:n]
r.off = 0
}
n := copy(p, r.buf[r.off:])
r.off += n
if r.off == len(r.buf) {
buf.Put(r.buf[:cap(r.buf)])
r.buf = nil
}
return n, nil
}
// increment little-endian encoded unsigned integer b. Wrap around on overflow.
func increment(b []byte) {
for i := range b {
b[i]++
if b[i] != 0 {
return
}
}
}
// aeadConn wraps a stream-oriented net.Conn with the shadowaead cipher.
type aeadConn struct {
net.Conn
Cipher
r *aeadReader
w *aeadWriter
}
// newAEADConn wraps a stream-oriented net.Conn with cipher.
func newAEADConn(c net.Conn, ciph Cipher) *aeadConn {
return &aeadConn{Conn: c, Cipher: ciph}
}
func (c *aeadConn) initReader() error {
salt := make([]byte, c.SaltSize())
if _, err := io.ReadFull(c.Conn, salt); err != nil {
return err
}
aead, err := c.Decrypter(salt)
if err != nil {
return err
}
c.r = newAEADReader(c.Conn, aead)
return nil
}
func (c *aeadConn) Read(b []byte) (int, error) {
if c.r == nil {
if err := c.initReader(); err != nil {
return 0, err
}
}
return c.r.Read(b)
}
func (c *aeadConn) initWriter() error {
salt := make([]byte, c.SaltSize())
if _, err := rand.Read(salt); err != nil {
return err
}
aead, err := c.Encrypter(salt)
if err != nil {
return err
}
_, err = c.Conn.Write(salt)
if err != nil {
return err
}
c.w = newAEADWriter(c.Conn, aead)
return nil
}
func (c *aeadConn) Write(b []byte) (int, error) {
if c.w == nil {
if err := c.initWriter(); err != nil {
return 0, err
}
}
return c.w.Write(b)
}

408
transport/snell/snell.go Normal file
View File

@@ -0,0 +1,408 @@
package snell
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"github.com/sagernet/sing/common/buf"
)
const (
Version1 = 1
Version2 = 2
Version3 = 3
Version4 = 4
Version5 = 5
DefaultSnellVersion = Version1
// max packet length
maxLength = 0x3FFF
)
const (
CommandPing byte = 0
CommandConnect byte = 1
CommandConnectV2 byte = 5
CommandUDP byte = 6
CommandUDPForward byte = 1
CommandTunnel byte = 0
CommandPong byte = 1
CommandError byte = 2
Version byte = 1
)
// Snell wraps an encrypted stream and handles the snell reply header.
type Snell struct {
net.Conn
buffer [1]byte
reply bool
}
func (s *Snell) Read(b []byte) (int, error) {
if err := s.ReadReply(); err != nil {
return 0, err
}
return s.Conn.Read(b)
}
func (s *Snell) ReadReply() error {
if s.reply {
return nil
}
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
return err
}
s.reply = true
if s.buffer[0] == CommandTunnel {
return nil
} else if s.buffer[0] != CommandError {
return errors.New("command not support")
}
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
return err
}
errcode := int(s.buffer[0])
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
return err
}
length := int(s.buffer[0])
msg := make([]byte, length)
if _, err := io.ReadFull(s.Conn, msg); err != nil {
return err
}
return fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
}
func WriteHeader(conn net.Conn, host string, port uint, version int) error {
return WriteHeaderWithReuse(conn, host, port, version, false)
}
func WriteHeaderWithReuse(conn net.Conn, host string, port uint, version int, reuse bool) error {
buffer := &bytes.Buffer{}
buffer.WriteByte(Version)
if version == Version2 || reuse {
buffer.WriteByte(CommandConnectV2)
} else {
buffer.WriteByte(CommandConnect)
}
buffer.WriteByte(0)
buffer.WriteByte(uint8(len(host)))
buffer.WriteString(host)
binary.Write(buffer, binary.BigEndian, uint16(port))
if _, err := conn.Write(buffer.Bytes()); err != nil {
return err
}
return nil
}
func WriteUDPHeader(conn net.Conn, version int) error {
if version < Version3 {
return errors.New("unsupport UDP version")
}
_, err := conn.Write([]byte{Version, CommandUDP, 0x00})
return err
}
// HalfClose only works after the request negotiated the reuse command.
func HalfClose(conn net.Conn) error {
if err := writeZeroChunk(conn); err != nil {
return err
}
if s, ok := conn.(*Snell); ok {
s.reply = false
}
return nil
}
// StreamConn wraps a raw connection with the snell stream cipher for the given version.
func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
if version >= Version4 {
return &Snell{Conn: newV4Conn(conn, psk)}
}
var cipher Cipher
if version != Version1 {
cipher = NewAES128GCM(psk)
} else {
cipher = NewChacha20Poly1305(psk)
}
return &Snell{Conn: newAEADConn(conn, cipher)}
}
// ServerStreamConn wraps a raw connection on the server side.
func ServerStreamConn(conn net.Conn, psk []byte, version int) *Snell {
stream := StreamConn(conn, psk, version)
stream.reply = true
return stream
}
func PacketConn(conn net.Conn) net.PacketConn {
return &packetConn{
Conn: conn,
}
}
func (s *Snell) WritePacketFrame(b []byte) (int, error) {
if fw, ok := s.Conn.(packetFrameWriter); ok {
return fw.WritePacketFrame(b)
}
return s.Conn.Write(b)
}
func WritePacket(w io.Writer, target, payload []byte) (int, error) {
maxPayloadLength := maxLength - udpRequestHeaderLength(target)
if maxPayloadLength <= 0 {
return 0, errors.New("snell UDP address too large")
}
if len(payload) <= maxPayloadLength {
return writePacket(w, target, payload)
}
return 0, errors.New("snell UDP payload too large")
}
func WritePacketResponse(w io.Writer, addr net.Addr, payload []byte) (int, error) {
buffer := &bytes.Buffer{}
target := parseAddrToSocksAddr(addr)
if len(target) == 0 {
return 0, errors.New("snell UDP response address invalid")
}
switch target[0] {
case atypIPv4:
if len(target) < 1+net.IPv4len+2 {
return 0, errors.New("snell UDP response address invalid")
}
buffer.WriteByte(0x04)
buffer.Write(target[1 : 1+net.IPv4len+2])
case atypIPv6:
if len(target) < 1+net.IPv6len+2 {
return 0, errors.New("snell UDP response address invalid")
}
buffer.WriteByte(0x06)
buffer.Write(target[1 : 1+net.IPv6len+2])
default:
return 0, errors.New("snell UDP response address invalid")
}
buffer.Write(payload)
var err error
if fw, ok := w.(packetFrameWriter); ok {
_, err = fw.WritePacketFrame(buffer.Bytes())
} else {
_, err = w.Write(buffer.Bytes())
}
if err != nil {
return 0, err
}
return len(payload), nil
}
// UDPRequest is a parsed snell UDP forward request.
type UDPRequest struct {
Host string
Ip netip.Addr
Port uint16
Payload []byte
}
func ParseUDPRequest(packet []byte) (UDPRequest, error) {
if len(packet) < 2 || packet[0] != CommandUDPForward {
return UDPRequest{}, errors.New("snell invalid UDP request")
}
if hostLen := int(packet[1]); hostLen != 0 {
if len(packet) <= 2+hostLen+2 {
return UDPRequest{}, errors.New("snell invalid UDP domain request")
}
offset := 2 + hostLen
return UDPRequest{
Host: string(packet[2:offset]),
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
Payload: packet[offset+2:],
}, nil
}
if len(packet) < 3 {
return UDPRequest{}, errors.New("snell invalid UDP IP request")
}
switch packet[2] {
case 0x04:
if len(packet) < 3+net.IPv4len+2 {
return UDPRequest{}, errors.New("snell invalid UDP IPv4 request")
}
offset := 3 + net.IPv4len
ip, _ := netip.AddrFromSlice(packet[3:offset])
return UDPRequest{
Ip: ip.Unmap(),
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
Payload: packet[offset+2:],
}, nil
case 0x06:
if len(packet) < 3+net.IPv6len+2 {
return UDPRequest{}, errors.New("snell invalid UDP IPv6 request")
}
offset := 3 + net.IPv6len
ip, _ := netip.AddrFromSlice(packet[3:offset])
return UDPRequest{
Ip: ip.Unmap(),
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
Payload: packet[offset+2:],
}, nil
default:
return UDPRequest{}, errors.New("snell invalid UDP address type")
}
}
func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) {
b := buf.Get(buf.UDPBufferSize)
defer buf.Put(b)
n, err := r.Read(b)
headLen := 1
if err != nil {
return nil, 0, err
}
if n < headLen {
return nil, 0, errors.New("insufficient UDP length")
}
switch b[0] {
case 0x04:
headLen += net.IPv4len + 2
if n < headLen {
err = errors.New("insufficient UDP length")
break
}
b[0] = atypIPv4
case 0x06:
headLen += net.IPv6len + 2
if n < headLen {
err = errors.New("insufficient UDP length")
break
}
b[0] = atypIPv6
default:
err = errors.New("ip version invalid")
}
if err != nil {
return nil, 0, err
}
addr := splitSocksAddr(b[0:])
if addr == nil {
return nil, 0, errors.New("remote address invalid")
}
uAddr := addr.UDPAddr()
if uAddr == nil {
return nil, 0, errors.New("parse addr error")
}
length := len(payload)
if n-headLen < length {
length = n - headLen
}
copy(payload[:], b[headLen:headLen+length])
return uAddr, length, nil
}
var endSignal = []byte{}
type packetFrameWriter interface {
WritePacketFrame([]byte) (int, error)
}
func writeZeroChunk(conn net.Conn) error {
if _, err := conn.Write(endSignal); err != nil {
return err
}
return nil
}
func writePacket(w io.Writer, target, payload []byte) (int, error) {
buffer := &bytes.Buffer{}
buffer.WriteByte(CommandUDPForward)
switch target[0] {
case atypDomainName:
hostLen := target[1]
if len(target) < 1+1+int(hostLen)+2 {
return 0, errors.New("snell UDP address invalid")
}
buffer.Write(target[1 : 1+1+hostLen+2])
case atypIPv4:
if len(target) < 1+net.IPv4len+2 {
return 0, errors.New("snell UDP address invalid")
}
buffer.Write([]byte{0x00, 0x04})
buffer.Write(target[1 : 1+net.IPv4len+2])
case atypIPv6:
if len(target) < 1+net.IPv6len+2 {
return 0, errors.New("snell UDP address invalid")
}
buffer.Write([]byte{0x00, 0x06})
buffer.Write(target[1 : 1+net.IPv6len+2])
default:
return 0, errors.New("snell UDP address invalid")
}
buffer.Write(payload)
if fw, ok := w.(packetFrameWriter); ok {
_, err := fw.WritePacketFrame(buffer.Bytes())
if err != nil {
return 0, err
}
return len(payload), nil
}
_, err := w.Write(buffer.Bytes())
if err != nil {
return 0, err
}
return len(payload), nil
}
func udpRequestHeaderLength(target []byte) int {
if len(target) == 0 {
return maxLength + 1
}
switch target[0] {
case atypDomainName:
if len(target) < 2 {
return maxLength + 1
}
return 1 + 1 + int(target[1]) + 2
case atypIPv4:
return 1 + 2 + net.IPv4len + 2
case atypIPv6:
return 1 + 2 + net.IPv6len + 2
default:
return maxLength + 1
}
}
type packetConn struct {
net.Conn
rMux sync.Mutex
wMux sync.Mutex
}
func (pc *packetConn) WritePacketFrame(b []byte) (int, error) {
if s, ok := pc.Conn.(*Snell); ok {
if fw, ok := s.Conn.(packetFrameWriter); ok {
return fw.WritePacketFrame(b)
}
}
return pc.Conn.Write(b)
}
func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
pc.wMux.Lock()
defer pc.wMux.Unlock()
return WritePacket(pc, parseAddrToSocksAddr(addr), b)
}
func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
pc.rMux.Lock()
defer pc.rMux.Unlock()
addr, n, err := ReadPacket(pc.Conn, b)
if err != nil {
return 0, nil, err
}
return n, addr, nil
}

463
transport/snell/v4.go Normal file
View File

@@ -0,0 +1,463 @@
package snell
import (
"crypto/cipher"
cryptorand "crypto/rand"
"encoding/binary"
"errors"
"io"
"math"
"math/big"
"math/bits"
"net"
"sync"
"time"
)
const (
v4SaltSize = 16
v4NonceSize = 12
v4HeaderPlainSize = 7
v4HeaderCipherSize = v4HeaderPlainSize + 16
v4FrameSize = 1460
v4InitialPaddingMin = 0x100
v4InitialPaddingSpan = 0x100
)
type v4Conn struct {
net.Conn
psk []byte
r *v4Reader
w *v4Writer
}
func newV4Conn(conn net.Conn, psk []byte) *v4Conn {
return &v4Conn{Conn: conn, psk: psk}
}
func (c *v4Conn) initReader() error {
salt := make([]byte, v4SaltSize)
if _, err := io.ReadFull(c.Conn, salt); err != nil {
return err
}
aead, err := v4AEAD(c.psk, salt)
if err != nil {
return err
}
c.r = &v4Reader{Reader: c.Conn, aead: aead}
return nil
}
func (c *v4Conn) initWriter() error {
w, err := newV4Writer(c.Conn, c.psk)
if err != nil {
return err
}
c.w = w
return nil
}
func (c *v4Conn) Read(b []byte) (int, error) {
if c.r == nil {
if err := c.initReader(); err != nil {
return 0, err
}
}
return c.r.Read(b)
}
func (c *v4Conn) Write(b []byte) (int, error) {
if c.w == nil {
if err := c.initWriter(); err != nil {
return 0, err
}
}
return c.w.Write(b)
}
func (c *v4Conn) WritePacketFrame(b []byte) (int, error) {
if len(b) > maxLength {
return 0, errors.New("snell v4 frame too large")
}
if c.w == nil {
if err := c.initWriter(); err != nil {
return 0, err
}
}
c.w.mux.Lock()
defer c.w.mux.Unlock()
if err := c.w.writeFrame(b, c.w.nextFramePaddingLength(len(b))); err != nil {
return 0, err
}
return len(b), nil
}
func (c *v4Conn) WriteTo(w io.Writer) (int64, error) {
if c.r == nil {
if err := c.initReader(); err != nil {
return 0, err
}
}
var written int64
buf := make([]byte, maxLength)
for {
n, err := c.r.Read(buf)
if n > 0 {
nw, ew := w.Write(buf[:n])
written += int64(nw)
if ew != nil {
return written, ew
}
if nw != n {
return written, io.ErrShortWrite
}
}
if err != nil {
if err == io.EOF {
err = nil
}
return written, err
}
}
}
func (c *v4Conn) ReadFrom(r io.Reader) (int64, error) {
if c.w == nil {
if err := c.initWriter(); err != nil {
return 0, err
}
}
var read int64
buf := make([]byte, maxLength)
for {
n, err := r.Read(buf)
if n > 0 {
read += int64(n)
if _, ew := c.w.Write(buf[:n]); ew != nil {
return read, ew
}
}
if err != nil {
if err == io.EOF {
err = nil
}
return read, err
}
}
}
func v4AEAD(psk, salt []byte) (cipher.AEAD, error) {
return aesGCM(snellKDF(psk, salt, 16))
}
type v4Reader struct {
io.Reader
aead cipher.AEAD
nonce [v4NonceSize]byte
buf []byte
mux sync.Mutex
}
func (r *v4Reader) Read(b []byte) (int, error) {
r.mux.Lock()
defer r.mux.Unlock()
if len(r.buf) == 0 {
payload, err := r.readFrame()
if err != nil {
return 0, err
}
r.buf = payload
}
n := copy(b, r.buf)
r.buf = r.buf[n:]
return n, nil
}
func (r *v4Reader) readFrame() ([]byte, error) {
headerCipher := make([]byte, v4HeaderCipherSize)
if _, err := io.ReadFull(r.Reader, headerCipher); err != nil {
return nil, err
}
header, err := r.aead.Open(headerCipher[:0], r.nonce[:], headerCipher, nil)
incrementV4Nonce(r.nonce[:])
if err != nil {
return nil, err
}
if len(header) != v4HeaderPlainSize || header[0] != 4 {
return nil, errors.New("snell v4 invalid frame header")
}
paddingLength := int(binary.BigEndian.Uint16(header[3:5]))
payloadLength := int(binary.BigEndian.Uint16(header[5:7]))
if payloadLength == 0 {
if paddingLength != 0 {
return nil, errors.New("snell v4 zero chunk with padding")
}
return nil, ErrZeroChunk
}
if payloadLength > maxLength || paddingLength > maxLength {
return nil, errors.New("snell v4 frame too large")
}
payloadCipherLength := payloadLength + r.aead.Overhead()
frame := make([]byte, paddingLength+payloadCipherLength)
if _, err := io.ReadFull(r.Reader, frame); err != nil {
return nil, err
}
if paddingLength > 0 {
swapPadding(frame[:paddingLength], frame[paddingLength:])
}
payloadCipher := frame[paddingLength:]
payload, err := r.aead.Open(payloadCipher[:0], r.nonce[:], payloadCipher, nil)
incrementV4Nonce(r.nonce[:])
if err != nil {
return nil, err
}
return payload, nil
}
type v4Writer struct {
io.Writer
aead cipher.AEAD
nonce [v4NonceSize]byte
salt [v4SaltSize]byte
saltSent bool
initialPaddingLength uint16
payloadLimit uint16
lastWrite time.Time
mux sync.Mutex
}
func newV4Writer(w io.Writer, psk []byte) (*v4Writer, error) {
var salt [v4SaltSize]byte
if _, err := io.ReadFull(cryptorand.Reader, salt[:]); err != nil {
return nil, err
}
aead, err := v4AEAD(psk, salt[:])
if err != nil {
return nil, err
}
paddingDelta, err := cryptoRandomInt(v4InitialPaddingSpan)
if err != nil {
return nil, err
}
return &v4Writer{
Writer: w,
aead: aead,
salt: salt,
initialPaddingLength: uint16(v4InitialPaddingMin + paddingDelta),
}, nil
}
func (w *v4Writer) Write(b []byte) (int, error) {
w.mux.Lock()
defer w.mux.Unlock()
if len(b) == 0 {
return 0, w.writeFrame(nil, 0)
}
written := 0
for written < len(b) {
payloadLimit := int(w.nextPayloadLimit())
if payloadLimit <= 0 || payloadLimit > maxLength {
payloadLimit = maxLength
}
end := written + payloadLimit
if end > len(b) {
end = len(b)
}
paddingLength := w.nextFramePaddingLength(end - written)
if err := w.writeFrame(b[written:end], paddingLength); err != nil {
return written, err
}
written = end
}
return written, nil
}
func (w *v4Writer) nextPayloadLimit() uint16 {
now := time.Now()
var payloadLimit uint16
switch {
case w.lastWrite.IsZero():
payloadLimit = v4FrameSize - 55 - w.initialPaddingLength
case now.Sub(w.lastWrite) > 30*time.Second:
payloadLimit = v4FrameSize - 39
default:
payloadLimit = w.payloadLimit
}
w.lastWrite = now
if payloadLimit <= maxLength-1 {
next := int(payloadLimit) + v4FrameSize - 39
if next > maxLength {
next = maxLength
}
w.payloadLimit = uint16(next)
} else {
w.payloadLimit = maxLength
}
return payloadLimit
}
func (w *v4Writer) nextFramePaddingLength(payloadLength int) int {
if w.saltSent || payloadLength == 0 {
return 0
}
return int(w.initialPaddingLength)
}
func (w *v4Writer) writeFrame(payload []byte, paddingLength int) error {
if len(payload) > maxLength || paddingLength > maxLength {
return errors.New("snell v4 frame too large")
}
if len(payload) == 0 && paddingLength != 0 {
return errors.New("snell v4 zero chunk with padding")
}
header := make([]byte, v4HeaderPlainSize)
header[0] = 4
binary.BigEndian.PutUint16(header[3:5], uint16(paddingLength))
binary.BigEndian.PutUint16(header[5:7], uint16(len(payload)))
headerCipher := w.aead.Seal(nil, w.nonce[:], header, nil)
incrementV4Nonce(w.nonce[:])
var payloadCipher []byte
if len(payload) > 0 {
payloadCipher = w.aead.Seal(nil, w.nonce[:], payload, nil)
incrementV4Nonce(w.nonce[:])
}
frameLength := len(headerCipher) + paddingLength + len(payloadCipher)
if !w.saltSent {
frameLength += v4SaltSize
}
frame := make([]byte, 0, frameLength)
if !w.saltSent {
frame = append(frame, w.salt[:]...)
w.saltSent = true
}
frame = append(frame, headerCipher...)
if paddingLength > 0 {
padding, err := makeV4Padding(payloadCipher, paddingLength)
if err != nil {
return err
}
swapPadding(padding, payloadCipher)
frame = append(frame, padding...)
}
frame = append(frame, payloadCipher...)
return writeFull(w.Writer, frame)
}
func swapPadding(padding, payloadCipher []byte) {
limit := len(padding)
if len(payloadCipher) < limit {
limit = len(payloadCipher)
}
for i := 0; i < limit; i += 2 {
padding[i], payloadCipher[i] = payloadCipher[i], padding[i]
}
}
func makeV4Padding(payloadCipher []byte, paddingLength int) ([]byte, error) {
if paddingLength <= 0 {
return nil, nil
}
payloadOnes := countV4PayloadOnes(payloadCipher)
payloadZeros := 8*len(payloadCipher) - payloadOnes
if payloadZeros <= 0 {
return makeV4RandomPadding(paddingLength)
}
ratio := float64(payloadOnes) / float64(payloadZeros)
if ratio <= 0.5 || ratio >= 1.6 {
return makeV4RandomPadding(paddingLength)
}
targetRatioBase := 1.6
if payloadZeros < payloadOnes {
targetRatioBase = 0.4
}
jitter, err := randomUnitFloat64()
if err != nil {
return nil, err
}
targetRatio := targetRatioBase + jitter/10
totalBits := 8 * (paddingLength + len(payloadCipher))
targetOnes := int(float64(totalBits)*(targetRatio/(targetRatio+1)) - float64(payloadOnes))
if targetOnes < 0 || targetOnes > 8*paddingLength {
return makeV4RandomPadding(paddingLength)
}
return makeV4BitCountPadding(paddingLength, targetOnes)
}
func countV4PayloadOnes(payloadCipher []byte) int {
limit := len(payloadCipher) &^ 3
ones := 0
for _, b := range payloadCipher[:limit] {
ones += bits.OnesCount8(b)
}
return ones
}
func makeV4RandomPadding(length int) ([]byte, error) {
padding := make([]byte, length)
_, err := io.ReadFull(cryptorand.Reader, padding)
return padding, err
}
func makeV4BitCountPadding(length, oneBits int) ([]byte, error) {
totalBits := 8 * length
if oneBits < 0 || oneBits > totalBits {
return nil, errors.New("snell v4 invalid padding bit count")
}
bitset := make([]byte, totalBits)
for i := 0; i < oneBits; i++ {
bitset[i] = 1
}
for i := totalBits - 1; i > 0; i-- {
j, err := cryptoRandomInt(i + 1)
if err != nil {
return nil, err
}
bitset[i], bitset[j] = bitset[j], bitset[i]
}
padding := make([]byte, length)
for i, bit := range bitset {
if bit == 1 {
padding[i/8] |= 1 << uint(i%8)
}
}
return padding, nil
}
func cryptoRandomInt(max int) (int, error) {
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(int64(max)))
if err != nil {
return 0, err
}
return int(n.Int64()), nil
}
func randomUnitFloat64() (float64, error) {
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(1<<53))
if err != nil {
return 0, err
}
return float64(n.Int64()) / math.Exp2(53), nil
}
func writeFull(w io.Writer, p []byte) error {
for len(p) > 0 {
n, err := w.Write(p)
if err != nil {
return err
}
if n == 0 {
return io.ErrShortWrite
}
p = p[n:]
}
return nil
}
func incrementV4Nonce(nonce []byte) {
for i := range nonce {
nonce[i]++
if nonce[i] != 0 {
return
}
}
}

View File

@@ -18,9 +18,12 @@ const (
)
const (
headerSize = 1 + 4 + 4
maxFrameSize = 256 * 1024
maxDataPayload = 32 * 1024
headerSize = 1 + 4 + 4
// maxQueuedBytesPerStream bounds unread payload retained by a single logical stream.
// Backpressure is applied to the demux loop instead of dropping data.
maxQueuedBytesPerStream = 4 * 1024 * 1024
maxFrameSize = 256 * 1024
maxDataPayload = 128 * 1024
)
type acceptEvent struct {
@@ -344,6 +347,8 @@ type stream struct {
closeErr error
readBuf []byte
queue [][]byte
// queuedBytes includes unread bytes in readBuf and queue.
queuedBytes int
localAddr net.Addr
remoteAddr net.Addr
@@ -362,16 +367,20 @@ func newStream(session *Session, id uint32) *stream {
func (c *stream) enqueue(payload []byte) {
c.mu.Lock()
for !c.closed && c.queuedBytes+len(payload) > maxQueuedBytesPerStream {
c.cond.Wait()
}
if c.closed {
c.mu.Unlock()
return
}
c.queuedBytes += len(payload)
if len(c.readBuf) == 0 && len(c.queue) == 0 {
c.readBuf = payload
} else {
c.queue = append(c.queue, payload)
}
c.cond.Signal()
c.cond.Broadcast()
c.mu.Unlock()
}
@@ -413,7 +422,11 @@ func (c *stream) Read(p []byte) (int, error) {
}
if len(c.readBuf) == 0 && len(c.queue) > 0 {
c.readBuf = c.queue[0]
c.queue[0] = nil
c.queue = c.queue[1:]
if len(c.queue) == 0 {
c.queue = nil
}
}
if len(c.readBuf) == 0 && c.closed {
if c.closeErr == nil {
@@ -424,6 +437,14 @@ func (c *stream) Read(p []byte) (int, error) {
n := copy(p, c.readBuf)
c.readBuf = c.readBuf[n:]
if len(c.readBuf) == 0 {
c.readBuf = nil
}
c.queuedBytes -= n
if c.queuedBytes < 0 {
c.queuedBytes = 0
}
c.cond.Broadcast()
return n, nil
}

View File

@@ -0,0 +1,91 @@
package multiplex
import (
"bytes"
"crypto/rand"
"io"
"net"
"sync"
"testing"
"time"
)
// TestSession_LargeTransferBackpressure verifies that a transfer larger than
// maxQueuedBytesPerStream completes correctly: the demux loop applies
// backpressure (cond.Wait) instead of dropping data, and the reader draining
// the stream wakes the blocked loop without deadlock.
func TestSession_LargeTransferBackpressure(t *testing.T) {
c1, c2 := net.Pipe()
client, err := NewClientSession(c1)
if err != nil {
t.Fatalf("client session: %v", err)
}
server, err := NewServerSession(c2)
if err != nil {
t.Fatalf("server session: %v", err)
}
defer client.Close()
defer server.Close()
// Payload bigger than the per-stream backpressure window (4MB).
const total = 12 * 1024 * 1024
payload := make([]byte, total)
if _, err := rand.Read(payload); err != nil {
t.Fatalf("rand: %v", err)
}
var wg sync.WaitGroup
wg.Add(2)
var writeErr error
go func() {
defer wg.Done()
stream, err := client.OpenStream([]byte("hello"))
if err != nil {
writeErr = err
return
}
defer stream.Close()
if _, err := stream.Write(payload); err != nil {
writeErr = err
return
}
_ = stream.(interface{ CloseWrite() error }).CloseWrite()
}()
var got []byte
var readErr error
go func() {
defer wg.Done()
stream, openPayload, err := server.AcceptStream()
if err != nil {
readErr = err
return
}
if string(openPayload) != "hello" {
readErr = io.ErrUnexpectedEOF
return
}
got, readErr = io.ReadAll(stream)
}()
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
select {
case <-done:
case <-time.After(30 * time.Second):
t.Fatal("transfer deadlocked (backpressure did not release)")
}
if writeErr != nil {
t.Fatalf("write: %v", writeErr)
}
if readErr != nil {
t.Fatalf("read: %v", readErr)
}
if !bytes.Equal(got, payload) {
t.Fatalf("payload mismatch: got %d bytes, want %d", len(got), len(payload))
}
}

View File

@@ -0,0 +1,56 @@
package sudoku
import "testing"
func TestNormalizeASCIIMode(t *testing.T) {
tests := []struct {
in string
want string
}{
{"", "prefer_entropy"},
{"entropy", "prefer_entropy"},
{"prefer_ascii", "prefer_ascii"},
{"up_ascii_down_entropy", "up_ascii_down_entropy"},
{"up_entropy_down_ascii", "up_entropy_down_ascii"},
{"up_prefer_ascii_down_prefer_entropy", "up_ascii_down_entropy"},
}
for _, tt := range tests {
got, err := NormalizeASCIIMode(tt.in)
if err != nil {
t.Fatalf("NormalizeASCIIMode(%q): %v", tt.in, err)
}
if got != tt.want {
t.Fatalf("NormalizeASCIIMode(%q) = %q, want %q", tt.in, got, tt.want)
}
}
if _, err := NormalizeASCIIMode("up_ascii_down_binary"); err == nil {
t.Fatalf("expected invalid directional mode to fail")
}
}
func TestNewTableWithCustomDirectionalOpposite(t *testing.T) {
table, err := NewTableWithCustom("seed", "up_ascii_down_entropy", "xpxvvpvv")
if err != nil {
t.Fatalf("NewTableWithCustom: %v", err)
}
if !table.IsASCII {
t.Fatalf("uplink table should be ascii")
}
opposite := table.OppositeDirection()
if opposite == nil || opposite == table {
t.Fatalf("expected distinct opposite table")
}
if opposite.IsASCII {
t.Fatalf("downlink table should be entropy/custom")
}
symmetric, err := NewTableWithCustom("seed", "prefer_ascii", "xpxvvpvv")
if err != nil {
t.Fatalf("NewTableWithCustom symmetric: %v", err)
}
if symmetric.OppositeDirection() != symmetric {
t.Fatalf("symmetric table should point to itself")
}
}

View File

@@ -3,6 +3,7 @@ package sudoku
import (
"bufio"
"bytes"
"io"
"net"
"sync"
"sync/atomic"
@@ -10,6 +11,8 @@ import (
const IOBufferSize = 32 * 1024
const minDecodeReadSize = 64
var perm4 = [24][4]byte{
{0, 1, 2, 3},
{0, 1, 3, 2},
@@ -52,7 +55,7 @@ type Conn struct {
writeMu sync.Mutex
writeBuf []byte
rng randomSource
rng *sudokuRand
paddingThreshold uint64
}
@@ -97,6 +100,9 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
}
func (sc *Conn) StopRecording() {
if sc == nil {
return
}
sc.recordLock.Lock()
sc.recording.Store(false)
sc.recorder = nil
@@ -115,6 +121,9 @@ func (sc *Conn) GetBufferedAndRecorded() []byte {
if sc.recorder != nil {
recorded = sc.recorder.Bytes()
}
if sc.reader == nil {
return recorded
}
buffered := sc.reader.Buffered()
if buffered > 0 {
@@ -131,6 +140,9 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
if sc == nil || sc.Conn == nil || sc.table == nil || sc.table.layout == nil || sc.rng == nil {
return 0, io.ErrClosedPipe
}
sc.writeMu.Lock()
defer sc.writeMu.Unlock()
@@ -140,16 +152,19 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
}
func (sc *Conn) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
if sc == nil || sc.Conn == nil || sc.reader == nil || len(sc.rawBuf) == 0 || sc.table == nil || sc.table.layout == nil {
return 0, io.ErrClosedPipe
}
if n, ok := drainPending(p, &sc.pendingData); ok {
return n, nil
}
outN := 0
for {
if sc.pendingData.available() > 0 {
break
}
nr, rErr := sc.reader.Read(sc.rawBuf)
nr, rErr := readRawLimited(sc.Conn, sc.reader, sc.rawBuf[:sudokuReadSize(len(p)-outN, len(sc.rawBuf))])
if nr > 0 {
chunk := sc.rawBuf[:nr]
if sc.recording.Load() {
@@ -160,34 +175,80 @@ func (sc *Conn) Read(p []byte) (n int, err error) {
sc.recordLock.Unlock()
}
layout := sc.table.layout
for _, b := range chunk {
table := sc.table
layout := table.layout
for i := 0; i < len(chunk); {
if sc.hintCount == 0 && outN < len(p) && i+3 < len(chunk) &&
layout.hintTable[chunk[i]] &&
layout.hintTable[chunk[i+1]] &&
layout.hintTable[chunk[i+2]] &&
layout.hintTable[chunk[i+3]] {
val, ok := table.DecodeMap[packHintBytes(chunk[i], chunk[i+1], chunk[i+2], chunk[i+3])]
if !ok {
return 0, ErrInvalidSudokuMapMiss
}
p[outN] = val
outN++
i += 4
continue
}
b := chunk[i]
i++
if !layout.hintTable[b] {
continue
}
sc.hintBuf[sc.hintCount] = b
sc.hintCount++
if sc.hintCount == len(sc.hintBuf) {
key := packHintsToKey(sc.hintBuf)
val, ok := sc.table.DecodeMap[key]
if !ok {
return 0, ErrInvalidSudokuMapMiss
}
sc.pendingData.appendByte(val)
sc.hintCount = 0
if sc.hintCount != len(sc.hintBuf) {
continue
}
val, ok := table.DecodeMap[packHintBytes(sc.hintBuf[0], sc.hintBuf[1], sc.hintBuf[2], sc.hintBuf[3])]
if !ok {
return 0, ErrInvalidSudokuMapMiss
}
outN = appendDecodedByte(p, outN, &sc.pendingData, val)
sc.hintCount = 0
}
}
if rErr != nil {
if outN > 0 {
return outN, nil
}
if n, ok := drainPending(p, &sc.pendingData); ok {
return n, nil
}
return 0, rErr
}
if sc.pendingData.available() > 0 {
break
if outN > 0 {
return outN, nil
}
}
n, _ = drainPending(p, &sc.pendingData)
return n, nil
}
func sudokuReadSize(decodedRemaining, maxRaw int) int {
if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 {
return maxRaw
}
if decodedRemaining > (maxRaw-minDecodeReadSize)/5 {
return maxRaw
}
return decodedRemaining*5 + minDecodeReadSize
}
func readRawLimited(conn net.Conn, reader *bufio.Reader, dst []byte) (int, error) {
if len(dst) == 0 {
return 0, nil
}
if reader != nil && reader.Buffered() > 0 {
return reader.Read(dst)
}
if conn == nil {
return 0, io.ErrClosedPipe
}
return conn.Read(dst)
}

View File

@@ -0,0 +1,51 @@
package sudoku
import (
"bytes"
"io"
"testing"
)
// TestConn_Roundtrip exercises the optimized Conn encode/decode hot paths:
// the no-padding fast path (pMin==pMax==0), the always-padding path
// (pMin==pMax==100), a probabilistic range, and the adaptive read-size /
// 4-byte fast hint decode path across a variety of payload sizes and modes.
func TestConn_Roundtrip(t *testing.T) {
modes := []string{"prefer_entropy", "prefer_ascii"}
paddings := []struct{ min, max int }{
{0, 0}, // no-padding specialized path
{100, 100}, // always-padding specialized path
{20, 60}, // probabilistic path
}
sizes := []int{1, 3, 4, 7, 16, 100, 1000, 64 * 1024}
for _, mode := range modes {
for _, pad := range paddings {
for _, size := range sizes {
payload := make([]byte, size)
for i := range payload {
payload[i] = byte(i*31 + 7)
}
table := NewTable("conn-roundtrip-seed", mode)
// Encode via Conn.Write.
w := &mockConn{}
enc := NewConn(w, table, pad.min, pad.max, false)
if _, err := enc.Write(payload); err != nil {
t.Fatalf("mode=%s pad=%v size=%d write: %v", mode, pad, size, err)
}
// Decode via Conn.Read using the same table.
dec := NewConn(&mockConn{readBuf: w.writeBuf}, table, pad.min, pad.max, false)
got := make([]byte, size)
if _, err := io.ReadFull(dec, got); err != nil {
t.Fatalf("mode=%s pad=%v size=%d read: %v", mode, pad, size, err)
}
if !bytes.Equal(got, payload) {
t.Fatalf("mode=%s pad=%v size=%d roundtrip mismatch", mode, pad, size)
}
}
}
}
}

View File

@@ -1,9 +1,12 @@
package sudoku
func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThreshold uint64, p []byte) []byte {
func encodeSudokuPayload(dst []byte, table *Table, rng *sudokuRand, paddingThreshold uint64, p []byte) []byte {
if len(p) == 0 {
return dst[:0]
}
if paddingThreshold == 0 {
return encodeSudokuPayloadNoPadding(dst, table, rng, p)
}
outCapacity := len(p)*6 + 1
if cap(dst) < outCapacity {
@@ -13,8 +16,25 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre
pads := table.PaddingPool
padLen := len(pads)
if paddingThreshold >= probOne {
for _, b := range p {
out = append(out, pads[rng.Intn(padLen)])
puzzles := table.EncodeTable[b]
puzzle := puzzles[rng.Intn(len(puzzles))]
perm := perm4[rng.Intn(len(perm4))]
for _, idx := range perm {
out = append(out, pads[rng.Intn(padLen)], puzzle[idx])
}
}
out = append(out, pads[rng.Intn(padLen)])
return out
}
for _, b := range p {
if shouldPad(rng, paddingThreshold) {
if uint64(rng.Uint32()) < paddingThreshold {
out = append(out, pads[rng.Intn(padLen)])
}
@@ -22,15 +42,31 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre
puzzle := puzzles[rng.Intn(len(puzzles))]
perm := perm4[rng.Intn(len(perm4))]
for _, idx := range perm {
if shouldPad(rng, paddingThreshold) {
if uint64(rng.Uint32()) < paddingThreshold {
out = append(out, pads[rng.Intn(padLen)])
}
out = append(out, puzzle[idx])
}
}
if shouldPad(rng, paddingThreshold) {
if uint64(rng.Uint32()) < paddingThreshold {
out = append(out, pads[rng.Intn(padLen)])
}
return out
}
func encodeSudokuPayloadNoPadding(dst []byte, table *Table, rng *sudokuRand, p []byte) []byte {
outCapacity := len(p) * 4
if cap(dst) < outCapacity {
dst = make([]byte, 0, outCapacity)
}
out := dst[:0]
for _, b := range p {
puzzles := table.EncodeTable[b]
puzzle := puzzles[rng.Intn(len(puzzles))]
perm := perm4[rng.Intn(len(perm4))]
out = append(out, puzzle[perm[0]], puzzle[perm[1]], puzzle[perm[2]], puzzle[perm[3]])
}
return out
}

View File

@@ -8,9 +8,9 @@ import (
)
const (
RngBatchSize = 128
packedProtectedPrefixBytes = 14
packedIOBufferSize = 64 * 1024
packedDecodeBufferSize = 96 * 1024
)
// PackedConn encodes traffic with the packed Sudoku layout while preserving
@@ -35,7 +35,7 @@ type PackedConn struct {
readBits int
// Padding selection matches Conn's threshold-based model.
rng randomSource
rng *sudokuRand
paddingThreshold uint64
padMarker byte
padPool []byte
@@ -67,18 +67,20 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
pc := &PackedConn{
Conn: c,
table: table,
reader: bufio.NewReaderSize(c, IOBufferSize),
rawBuf: make([]byte, IOBufferSize),
reader: bufio.NewReaderSize(c, packedIOBufferSize),
rawBuf: make([]byte, packedDecodeBufferSize),
pendingData: newPendingBuffer(4096),
writeBuf: make([]byte, 0, 4096),
rng: localRng,
paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
}
pc.padMarker = table.layout.padMarker
for _, b := range table.PaddingPool {
if b != pc.padMarker {
pc.padPool = append(pc.padPool, b)
if table != nil && table.layout != nil {
pc.padMarker = table.layout.padMarker
for _, b := range table.PaddingPool {
if b != pc.padMarker {
pc.padPool = append(pc.padPool, b)
}
}
}
if len(pc.padPool) == 0 {
@@ -87,18 +89,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
return pc
}
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
if shouldPad(pc.rng, pc.paddingThreshold) {
out = append(out, pc.getPaddingByte())
}
return out
}
func (pc *PackedConn) appendGroup(out []byte, group byte) []byte {
out = pc.maybeAddPadding(out)
return append(out, pc.table.layout.groupByte(group))
}
func (pc *PackedConn) appendForcedPadding(out []byte) []byte {
return append(out, pc.getPaddingByte())
}
@@ -134,7 +124,7 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
} else {
pc.bitBuf &= (1 << pc.bitCount) - 1
}
out = pc.appendGroup(out, group&0x3F)
out = appendPackedGroup(out, pc.table.layout, pc.rng, pc.paddingThreshold, pc.padPool, group)
}
effective++
@@ -148,19 +138,49 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
return out, limit
}
func appendPackedGroup(out []byte, layout *byteLayout, rng *sudokuRand, paddingThreshold uint64, padPool []byte, group byte) []byte {
if paddingThreshold != 0 {
u := rng.Uint32()
if uint64(u) < paddingThreshold {
out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))])
}
}
return append(out, layout.encodeGroup[group&0x3F])
}
func maybeAppendPackedPadding(out []byte, rng *sudokuRand, paddingThreshold uint64, padPool []byte) []byte {
if paddingThreshold != 0 {
u := rng.Uint32()
if uint64(u) < paddingThreshold {
out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))])
}
}
return out
}
func (pc *PackedConn) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 {
return 0, io.ErrClosedPipe
}
pc.writeMu.Lock()
defer pc.writeMu.Unlock()
needed := len(p)*3/2 + 32
if pc.paddingThreshold == 0 {
needed = ((len(p)+2)/3)*4 + 32
}
if cap(pc.writeBuf) < needed {
pc.writeBuf = make([]byte, 0, needed)
}
out := pc.writeBuf[:0]
layout := pc.table.layout
rng := pc.rng
paddingThreshold := pc.paddingThreshold
padPool := pc.padPool
var prefixN int
out, prefixN = pc.writeProtectedPrefix(out, p)
@@ -181,7 +201,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} else {
pc.bitBuf &= (1 << pc.bitCount) - 1
}
out = pc.appendGroup(out, group&0x3F)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
}
}
@@ -195,10 +215,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
g4 := b3 & 0x3F
out = pc.appendGroup(out, g1)
out = pc.appendGroup(out, g2)
out = pc.appendGroup(out, g3)
out = pc.appendGroup(out, g4)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4)
}
}
@@ -211,10 +231,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
g4 := b3 & 0x3F
out = pc.appendGroup(out, g1)
out = pc.appendGroup(out, g2)
out = pc.appendGroup(out, g3)
out = pc.appendGroup(out, g4)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4)
}
for ; i < n; i++ {
@@ -229,7 +249,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} else {
pc.bitBuf &= (1 << pc.bitCount) - 1
}
out = pc.appendGroup(out, group&0x3F)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
}
}
@@ -237,11 +257,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
group := byte(pc.bitBuf << (6 - pc.bitCount))
pc.bitBuf = 0
pc.bitCount = 0
out = pc.appendGroup(out, group&0x3F)
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
out = append(out, pc.padMarker)
}
out = pc.maybeAddPadding(out)
out = maybeAppendPackedPadding(out, rng, paddingThreshold, padPool)
if len(out) > 0 {
pc.writeBuf = out[:0]
@@ -252,6 +272,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
}
func (pc *PackedConn) Flush() error {
if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 {
return io.ErrClosedPipe
}
pc.writeMu.Lock()
defer pc.writeMu.Unlock()
@@ -265,7 +289,7 @@ func (pc *PackedConn) Flush() error {
out = append(out, pc.padMarker)
}
out = pc.maybeAddPadding(out)
out = maybeAppendPackedPadding(out, pc.rng, pc.paddingThreshold, pc.padPool)
if len(out) > 0 {
pc.writeBuf = out[:0]
@@ -289,19 +313,44 @@ func writeFull(w io.Writer, b []byte) error {
}
func (pc *PackedConn) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
if pc == nil || pc.Conn == nil || pc.reader == nil || len(pc.rawBuf) == 0 || pc.table == nil || pc.table.layout == nil {
return 0, io.ErrClosedPipe
}
if n, ok := drainPending(p, &pc.pendingData); ok {
return n, nil
}
outN := 0
for {
nr, rErr := pc.reader.Read(pc.rawBuf)
nr, rErr := readRawLimited(pc.Conn, pc.reader, pc.rawBuf[:packedReadSize(len(p)-outN, len(pc.rawBuf))])
if nr > 0 {
rBuf := pc.readBitBuf
rBits := pc.readBits
padMarker := pc.padMarker
layout := pc.table.layout
for _, b := range pc.rawBuf[:nr] {
chunk := pc.rawBuf[:nr]
for i := 0; i < len(chunk); {
if rBits == 0 && outN+3 <= len(p) && i+3 < len(chunk) &&
layout.hintTable[chunk[i]] && layout.hintTable[chunk[i+1]] &&
layout.hintTable[chunk[i+2]] && layout.hintTable[chunk[i+3]] {
g1 := layout.decodeGroup[chunk[i]]
g2 := layout.decodeGroup[chunk[i+1]]
g3 := layout.decodeGroup[chunk[i+2]]
g4 := layout.decodeGroup[chunk[i+3]]
p[outN] = (g1 << 2) | (g2 >> 4)
p[outN+1] = (g2 << 4) | (g3 >> 2)
p[outN+2] = (g3 << 6) | g4
outN += 3
i += 4
continue
}
b := chunk[i]
i++
if !layout.hintTable[b] {
if b == padMarker {
rBuf = 0
@@ -321,7 +370,7 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
if rBits >= 8 {
rBits -= 8
val := byte(rBuf >> rBits)
pc.pendingData.appendByte(val)
outN = appendDecodedByte(p, outN, &pc.pendingData, val)
if rBits == 0 {
rBuf = 0
} else {
@@ -339,21 +388,32 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
pc.readBitBuf = 0
pc.readBits = 0
}
if pc.pendingData.available() > 0 {
break
if outN > 0 {
return outN, nil
}
if n, ok := drainPending(p, &pc.pendingData); ok {
return n, nil
}
return 0, rErr
}
if pc.pendingData.available() > 0 {
break
if outN > 0 {
return outN, nil
}
}
n, _ := drainPending(p, &pc.pendingData)
return n, nil
}
func (pc *PackedConn) getPaddingByte() byte {
return pc.padPool[pc.rng.Intn(len(pc.padPool))]
}
func packedReadSize(decodedRemaining, maxRaw int) int {
if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 {
return maxRaw
}
if decodedRemaining > (maxRaw-minDecodeReadSize)/2 {
return maxRaw
}
return decodedRemaining*2 + minDecodeReadSize
}

View File

@@ -0,0 +1,90 @@
package sudoku
import (
"bytes"
"io"
"net"
"testing"
"time"
)
type mockConn struct {
readBuf []byte
writeBuf []byte
}
func (c *mockConn) Read(p []byte) (int, error) {
if len(c.readBuf) == 0 {
return 0, io.EOF
}
n := copy(p, c.readBuf)
c.readBuf = c.readBuf[n:]
return n, nil
}
func (c *mockConn) Write(p []byte) (int, error) {
c.writeBuf = append(c.writeBuf, p...)
return len(p), nil
}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) LocalAddr() net.Addr { return nil }
func (c *mockConn) RemoteAddr() net.Addr { return nil }
func (c *mockConn) SetDeadline(time.Time) error { return nil }
func (c *mockConn) SetReadDeadline(time.Time) error { return nil }
func (c *mockConn) SetWriteDeadline(time.Time) error { return nil }
func TestPackedConn_ProtectedPrefixPadding(t *testing.T) {
table := NewTable("packed-prefix-seed", "prefer_ascii")
mock := &mockConn{}
writer := NewPackedConn(mock, table, 0, 0)
writer.rng = newSudokuRand(1)
payload := bytes.Repeat([]byte{0}, 32)
if _, err := writer.Write(payload); err != nil {
t.Fatalf("write: %v", err)
}
wire := append([]byte(nil), mock.writeBuf...)
if len(wire) < 20 {
t.Fatalf("wire too short: %d", len(wire))
}
firstHint := -1
nonHintCount := 0
maxHintRun := 0
currentHintRun := 0
for i, b := range wire[:20] {
if table.layout.isHint(b) {
if firstHint == -1 {
firstHint = i
}
currentHintRun++
if currentHintRun > maxHintRun {
maxHintRun = currentHintRun
}
continue
}
nonHintCount++
currentHintRun = 0
}
if firstHint < 1 || firstHint > 2 {
t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint)
}
if nonHintCount < 6 {
t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount)
}
if maxHintRun > 3 {
t.Fatalf("prefix still exposes long hint run: %d", maxHintRun)
}
reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0)
decoded := make([]byte, len(payload))
if _, err := io.ReadFull(reader, decoded); err != nil {
t.Fatalf("read back: %v", err)
}
if !bytes.Equal(decoded, payload) {
t.Fatalf("roundtrip mismatch")
}
}

View File

@@ -2,7 +2,7 @@ package sudoku
const probOne = uint64(1) << 32
func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 {
func pickPaddingThreshold(r *sudokuRand, pMin, pMax int) uint64 {
if r == nil {
return 0
}
@@ -28,7 +28,7 @@ func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 {
return min + (u * (max - min) >> 32)
}
func shouldPad(r randomSource, threshold uint64) bool {
func shouldPad(r *sudokuRand, threshold uint64) bool {
if threshold == 0 {
return false
}

View File

@@ -25,7 +25,10 @@ func (p *pendingBuffer) reset() {
}
func (p *pendingBuffer) ensureAppendCapacity(extra int) {
if p == nil || extra <= 0 || p.off == 0 {
if p == nil || extra <= 0 {
return
}
if p.off == 0 {
return
}
if cap(p.data)-len(p.data) >= extra {
@@ -43,6 +46,15 @@ func (p *pendingBuffer) appendByte(b byte) {
p.data = append(p.data, b)
}
func appendDecodedByte(dst []byte, n int, pending *pendingBuffer, b byte) int {
if n < len(dst) {
dst[n] = b
return n + 1
}
pending.appendByte(b)
return n
}
func drainPending(dst []byte, pending *pendingBuffer) (int, bool) {
if pending == nil || pending.available() == 0 {
return 0, false

View File

@@ -6,14 +6,10 @@ import (
"time"
)
type randomSource interface {
Uint32() uint32
Uint64() uint64
Intn(n int) int
}
type sudokuRand struct {
state uint64
state uint64
cached uint32
haveCached bool
}
func newSeededRand() *sudokuRand {
@@ -37,20 +33,36 @@ func (r *sudokuRand) Uint64() uint64 {
if r == nil {
return 0
}
r.state += 0x9e3779b97f4a7c15
z := r.state
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9
z = (z ^ (z >> 27)) * 0x94d049bb133111eb
return z ^ (z >> 31)
r.haveCached = false
x := r.state
x ^= x >> 12
x ^= x << 25
x ^= x >> 27
r.state = x
return x * 0x2545f4914f6cdd1d
}
func (r *sudokuRand) Uint32() uint32 {
return uint32(r.Uint64() >> 32)
if r == nil {
return 0
}
if r.haveCached {
r.haveCached = false
return r.cached
}
v := r.Uint64()
r.cached = uint32(v)
r.haveCached = true
return uint32(v >> 32)
}
func (r *sudokuRand) Intn(n int) int {
if n <= 1 {
return 0
}
return int((uint64(r.Uint32()) * uint64(n)) >> 32)
return fastIntnFromUint32(r.Uint32(), n)
}
func fastIntnFromUint32(u uint32, n int) int {
return int((uint64(u) * uint64(n)) >> 32)
}

View File

@@ -192,23 +192,27 @@ func tableHintFingerprint(key string, mode string, uplinkPattern string, downlin
}
func packHintsToKey(hints [4]byte) uint32 {
return packHintBytes(hints[0], hints[1], hints[2], hints[3])
}
func packHintBytes(h0, h1, h2, h3 byte) uint32 {
// Sorting network for 4 elements (Bubble sort unrolled)
// Swap if a > b
if hints[0] > hints[1] {
hints[0], hints[1] = hints[1], hints[0]
if h0 > h1 {
h0, h1 = h1, h0
}
if hints[2] > hints[3] {
hints[2], hints[3] = hints[3], hints[2]
if h2 > h3 {
h2, h3 = h3, h2
}
if hints[0] > hints[2] {
hints[0], hints[2] = hints[2], hints[0]
if h0 > h2 {
h0, h2 = h2, h0
}
if hints[1] > hints[3] {
hints[1], hints[3] = hints[3], hints[1]
if h1 > h3 {
h1, h3 = h3, h1
}
if hints[1] > hints[2] {
hints[1], hints[2] = hints[2], hints[1]
if h1 > h2 {
h1, h2 = h2, h1
}
return uint32(hints[0])<<24 | uint32(hints[1])<<16 | uint32(hints[2])<<8 | uint32(hints[3])
return uint32(h0)<<24 | uint32(h1)<<16 | uint32(h2)<<8 | uint32(h3)
}

View File

@@ -14,12 +14,14 @@ import (
"sync/atomic"
"time"
"github.com/sagernet/sing-box/common/congestion"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
@@ -50,7 +52,7 @@ type ClientOptions struct {
QUIC bool
CongestionControl string
CWND int
BBRProfile string
Logger logger.Logger
HealthCheck bool
MaxConnections int
MinStreams int
@@ -81,7 +83,7 @@ func NewClient(ctx context.Context, options ClientOptions) (*Client, error) {
healthCheck: options.HealthCheck,
}
if options.QUIC {
congestionControlFactory, err := NewCongestionControl(options.CongestionControl, options.CWND, options.BBRProfile, ntp.TimeFuncFromContext(ctx))
congestionControlFactory, err := congestion.NewCongestionControl(options.CongestionControl, options.CWND, ntp.TimeFuncFromContext(ctx))
if err != nil {
cancel()
return nil, err

View File

@@ -1,77 +0,0 @@
package trusttunnel
import (
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/congestion"
"github.com/sagernet/sing-quic/congestion_bbr1"
"github.com/sagernet/sing-quic/congestion_bbr2"
congestion_meta1 "github.com/sagernet/sing-quic/congestion_meta1"
congestion_meta2 "github.com/sagernet/sing-quic/congestion_meta2"
E "github.com/sagernet/sing/common/exceptions"
)
func NewCongestionControl(name string, cwnd int, bbrProfile string, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) {
if timeFunc == nil {
timeFunc = time.Now
}
if cwnd == 0 {
cwnd = 32
}
switch name {
case "", "bbr":
return func(conn *quic.Conn) congestion.CongestionControl {
return congestion_meta2.NewBbrSender(
congestion_meta2.DefaultClock{TimeFunc: timeFunc},
congestion.ByteCount(conn.Config().InitialPacketSize),
congestion.ByteCount(cwnd)*congestion.ByteCount(conn.Config().InitialPacketSize),
)
}, nil
case "bbr_standard":
return func(conn *quic.Conn) congestion.CongestionControl {
return congestion_bbr1.NewBbrSender(
congestion_bbr1.DefaultClock{TimeFunc: timeFunc},
congestion.ByteCount(conn.Config().InitialPacketSize),
congestion_bbr1.InitialCongestionWindowPackets,
congestion_bbr1.MaxCongestionWindowPackets,
)
}, nil
case "bbr2":
return func(conn *quic.Conn) congestion.CongestionControl {
return congestion_bbr2.NewBBR2Sender(
congestion_bbr2.DefaultClock{TimeFunc: timeFunc},
congestion.ByteCount(conn.Config().InitialPacketSize),
0,
false,
)
}, nil
case "bbr2_variant":
return func(conn *quic.Conn) congestion.CongestionControl {
return congestion_bbr2.NewBBR2Sender(
congestion_bbr2.DefaultClock{TimeFunc: timeFunc},
congestion.ByteCount(conn.Config().InitialPacketSize),
32*congestion.ByteCount(conn.Config().InitialPacketSize),
true,
)
}, nil
case "cubic":
return func(conn *quic.Conn) congestion.CongestionControl {
return congestion_meta1.NewCubicSender(
congestion_meta1.DefaultClock{TimeFunc: timeFunc},
congestion.ByteCount(conn.Config().InitialPacketSize),
false,
)
}, nil
case "reno":
return func(conn *quic.Conn) congestion.CongestionControl {
return congestion_meta1.NewCubicSender(
congestion_meta1.DefaultClock{TimeFunc: timeFunc},
congestion.ByteCount(conn.Config().InitialPacketSize),
true,
)
}, nil
default:
return nil, E.New("unknown congestion control: ", name)
}
}

View File

@@ -2,6 +2,7 @@ package v2raygrpc
import (
"context"
"strings"
"google.golang.org/grpc"
)
@@ -13,13 +14,21 @@ type GunService interface {
}
func ServerDesc(name string) grpc.ServiceDesc {
serviceName := name
streamName := "Tun"
if strings.Contains(name, "/") {
name = strings.TrimPrefix(name, "/")
lastSlash := strings.LastIndex(name, "/")
serviceName = name[:lastSlash]
streamName = name[lastSlash+1:]
}
return grpc.ServiceDesc{
ServiceName: name,
ServiceName: serviceName,
HandlerType: (*GunServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "Tun",
StreamName: streamName,
Handler: _GunService_Tun_Handler,
ServerStreams: true,
ClientStreams: true,
@@ -30,7 +39,11 @@ func ServerDesc(name string) grpc.ServiceDesc {
}
func (c *gunServiceClient) TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GunService_TunClient, error) {
stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], "/"+name+"/Tun", opts...)
path := "/" + name + "/Tun"
if strings.Contains(name, "/") {
path = name
}
stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], path, opts...)
if err != nil {
return nil, err
}

View File

@@ -53,10 +53,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
DisableCompression: true,
},
url: &url.URL{
Scheme: "https",
Host: serverAddr.String(),
Path: "/" + options.ServiceName + "/Tun",
RawPath: "/" + url.PathEscape(options.ServiceName) + "/Tun",
Scheme: "https",
Host: serverAddr.String(),
Path: grpcPath(options.ServiceName),
},
host: host,
}

View File

@@ -0,0 +1,10 @@
package v2raygrpclite
import "strings"
func grpcPath(serviceName string) string {
if strings.Contains(serviceName, "/") {
return serviceName
}
return "/" + serviceName + "/Tun"
}

View File

@@ -42,7 +42,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
tlsConfig: tlsConfig,
logger: logger,
handler: handler,
path: "/" + options.ServiceName + "/Tun",
path: grpcPath(options.ServiceName),
h2Server: &http2.Server{
IdleTimeout: time.Duration(options.IdleTimeout),
},

View File

@@ -1,14 +1,14 @@
package v2raykcp
import (
"container/list"
"sync"
"github.com/sagernet/sing-box/common/list"
"github.com/sagernet/sing/common/buf"
)
type SendingWindow struct {
cache *list.List
cache *list.List[*DataSegment]
totalInFlightSize uint32
writer SegmentWriter
onPacketLoss func(uint32)
@@ -16,7 +16,7 @@ type SendingWindow struct {
func NewSendingWindow(writer SegmentWriter, onPacketLoss func(uint32)) *SendingWindow {
return &SendingWindow{
cache: list.New(),
cache: list.New[*DataSegment](),
writer: writer,
onPacketLoss: onPacketLoss,
}
@@ -27,9 +27,9 @@ func (sw *SendingWindow) Release() {
return
}
for sw.cache.Len() > 0 {
seg := sw.cache.Front().Value.(*DataSegment)
seg := sw.cache.Front().Value
seg.Release()
sw.cache.Remove(sw.cache.Front())
sw.cache.Front().Remove()
}
}
@@ -50,17 +50,17 @@ func (sw *SendingWindow) Push(number uint32, b *buf.Buffer) {
}
func (sw *SendingWindow) FirstNumber() uint32 {
return sw.cache.Front().Value.(*DataSegment).Number
return sw.cache.Front().Value.Number
}
func (sw *SendingWindow) Clear(una uint32) {
for !sw.IsEmpty() {
seg := sw.cache.Front().Value.(*DataSegment)
seg := sw.cache.Front().Value
if seg.Number >= una {
break
}
seg.Release()
sw.cache.Remove(sw.cache.Front())
sw.cache.Front().Remove()
}
}
@@ -87,8 +87,7 @@ func (sw *SendingWindow) Visit(visitor func(seg *DataSegment) bool) {
}
for e := sw.cache.Front(); e != nil; e = e.Next() {
seg := e.Value.(*DataSegment)
if !visitor(seg) {
if !visitor(e.Value) {
break
}
}
@@ -132,7 +131,7 @@ func (sw *SendingWindow) Remove(number uint32) bool {
}
for e := sw.cache.Front(); e != nil; e = e.Next() {
seg := e.Value.(*DataSegment)
seg := e.Value
if seg.Number > number {
return false
} else if seg.Number == number {
@@ -140,7 +139,7 @@ func (sw *SendingWindow) Remove(number uint32) bool {
sw.totalInFlightSize--
}
seg.Release()
sw.cache.Remove(e)
e.Remove()
return true
}
}

View File

@@ -16,12 +16,12 @@ import (
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/congestion"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/common/xray/buf"
"github.com/sagernet/sing-box/common/xray/net"
"github.com/sagernet/sing-box/common/xray/pipe"
"github.com/sagernet/sing-box/common/xray/signal/done"
"github.com/sagernet/sing-box/common/xray/uuid"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
qtls "github.com/sagernet/sing-quic"
@@ -30,6 +30,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
sHTTP "github.com/sagernet/sing/protocol/http"
"github.com/sagernet/sing/service"
"golang.org/x/net/http2"
@@ -42,15 +43,22 @@ type Client struct {
baseRequestURL2 url.URL
getHTTPClient func() (DialerClient, *XmuxClient)
getHTTPClient2 func() (DialerClient, *XmuxClient)
xmuxManager *XmuxManager
xmuxManager2 *XmuxManager
}
func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayXHTTPOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
if options.Mode == "" {
return nil, E.New("mode is not set")
}
if tlsConfig != nil && len(tlsConfig.NextProtos()) == 0 {
tlsConfig.SetNextProtos([]string{"h2"})
}
if _, err := congestion.NewCongestionControl(options.CongestionController, options.CWND, nil); err != nil {
return nil, err
}
if options.Download != nil {
if _, err := congestion.NewCongestionControl(options.Download.CongestionController, options.Download.CWND, nil); err != nil {
return nil, err
}
}
dest := serverAddr
baseRequestURL, err := getBaseRequestURL(&options.V2RayXHTTPBaseOptions, dest, tlsConfig)
if err != nil {
@@ -61,7 +69,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
xmuxOptions = *options.Xmux
}
xmuxManager := NewXmuxManager(xmuxOptions, func() XmuxConn {
return createHTTPClient(dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig)
return createHTTPClient(ctx, dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig)
})
getHTTPClient := func() (DialerClient, *XmuxClient) {
xmuxClient := xmuxManager.GetXmuxClient(ctx)
@@ -69,6 +77,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
}
baseRequestURL2 := baseRequestURL
getHTTPClient2 := getHTTPClient
var xmuxManager2 *XmuxManager
if options.Download != nil {
options2 := options.Download
dialer2 := dialer
@@ -98,8 +107,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
if options2.Xmux != nil {
xmuxOptions2 = *options2.Xmux
}
xmuxManager2 := NewXmuxManager(xmuxOptions2, func() XmuxConn {
return createHTTPClient(dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2)
xmuxManager2 = NewXmuxManager(xmuxOptions2, func() XmuxConn {
return createHTTPClient(ctx, dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2)
})
getHTTPClient2 = func() (DialerClient, *XmuxClient) {
xmuxClient2 := xmuxManager2.GetXmuxClient(ctx)
@@ -113,6 +122,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
getHTTPClient2: getHTTPClient2,
baseRequestURL: baseRequestURL,
baseRequestURL2: baseRequestURL2,
xmuxManager: xmuxManager,
xmuxManager2: xmuxManager2,
}, nil
}
@@ -121,8 +132,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
mode := c.options.Mode
sessionId := ""
if c.options.Mode != "stream-one" {
sessionIdUuid := uuid.New()
sessionId = sessionIdUuid.String()
sessionId = GenerateSessionID(&c.options.V2RayXHTTPBaseOptions)
}
requestURL := c.baseRequestURL
requestURL2 := c.baseRequestURL2
@@ -182,10 +192,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
}
scMaxEachPostBytes := options.GetNormalizedScMaxEachPostBytes()
scMinPostsIntervalMs := options.GetNormalizedScMinPostsIntervalMs()
if scMaxEachPostBytes.From <= 0 {
panic("`scMaxEachPostBytes` should be bigger than 0")
}
maxUploadSize := scMaxEachPostBytes.Rand()
maxUploadSize := int32(scMaxEachPostBytes.Rand())
// WithSizeLimit(0) will still allow single bytes to pass, and a lot of
// code relies on this behavior. Subtract 1 so that together with
// uploadWriter wrapper, exact size limits can be enforced
@@ -255,6 +262,10 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
}
func (c *Client) Close() error {
c.xmuxManager.Close()
if c.xmuxManager2 != nil {
c.xmuxManager2.Close()
}
return nil
}
@@ -294,7 +305,7 @@ func getBaseRequestURL(options *option.V2RayXHTTPBaseOptions, dest M.Socksaddr,
return requestURL, nil
}
func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient {
func createHTTPClient(ctx context.Context, dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient {
httpVersion := decideHTTPVersion(tlsConfig)
dialContext := func(ctxInner context.Context) (net.Conn, error) {
conn, err := dialer.DialContext(ctxInner, "tcp", dest)
@@ -319,6 +330,7 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
if keepAlivePeriod < 0 {
keepAlivePeriod = 0
}
congestionControlFactory, _ := congestion.NewCongestionControl(options.CongestionController, options.CWND, ntp.TimeFuncFromContext(ctx))
quicConfig := &quic.Config{
MaxIdleTimeout: net.ConnIdleTimeout,
// these two are defaults of quic-go/http3. the default of quic-go (no
@@ -334,7 +346,14 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
if dErr != nil {
return nil, dErr
}
return qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg)
conn, dErr := qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg)
if dErr != nil {
return nil, dErr
}
if congestionControlFactory != nil {
conn.SetCongestionControl(congestionControlFactory(conn))
}
return conn, nil
},
}
case "2":

View File

@@ -39,7 +39,7 @@ func (c *splitConn) Close() error {
}
if err2 != nil {
return err
return err2
}
return nil

View File

@@ -147,7 +147,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio
if c.httpVersion != "1.1" {
resp, err := c.client.Do(req)
if err != nil {
c.closed = true
c.Close()
return err
}
io.Copy(io.Discard, resp.Body)
@@ -225,10 +225,9 @@ func (w *WaitReadCloser) Set(rc io.ReadCloser) {
}
func (w *WaitReadCloser) Read(b []byte) (int, error) {
<-w.Wait
if w.ReadCloser == nil {
if <-w.Wait; w.ReadCloser == nil {
return 0, io.ErrClosedPipe
}
return 0, io.ErrClosedPipe
}
return w.ReadCloser.Read(b)
}

View File

@@ -19,8 +19,8 @@ type XmuxConn interface {
type XmuxClient struct {
XmuxConn XmuxConn
openUsage int32
leftUsage int32
openUsage int
leftUsage int
LeftRequests atomic.Int32
UnreusableAt time.Time
@@ -37,7 +37,7 @@ func (c *XmuxClient) Close() {
}
}
func (c *XmuxClient) AddOpenUsage(delta int32) {
func (c *XmuxClient) AddOpenUsage(delta int) {
c.mtx.Lock()
defer c.mtx.Unlock()
c.openUsage += delta
@@ -46,7 +46,7 @@ func (c *XmuxClient) AddOpenUsage(delta int32) {
}
}
func (c *XmuxClient) GetOpenUsage() int32 {
func (c *XmuxClient) GetOpenUsage() int {
c.mtx.Lock()
defer c.mtx.Unlock()
return c.openUsage
@@ -54,8 +54,8 @@ func (c *XmuxClient) GetOpenUsage() int32 {
type XmuxManager struct {
options option.V2RayXHTTPXmuxOptions
concurrency int32
connections int32
concurrency int
connections int
newConnFunc func() XmuxConn
xmuxClients []*XmuxClient
mtx sync.Mutex
@@ -71,6 +71,15 @@ func NewXmuxManager(options option.V2RayXHTTPXmuxOptions, newConnFunc func() Xmu
}
}
func (m *XmuxManager) Close() {
m.mtx.Lock()
defer m.mtx.Unlock()
for _, xmuxClient := range m.xmuxClients {
xmuxClient.Close()
}
m.xmuxClients = m.xmuxClients[:0]
}
func (m *XmuxManager) newXmuxClient() *XmuxClient {
xmuxClient := &XmuxClient{
XmuxConn: m.newConnFunc(),
@@ -81,7 +90,7 @@ func (m *XmuxManager) newXmuxClient() *XmuxClient {
}
xmuxClient.LeftRequests.Store(math.MaxInt32)
if x := m.options.GetNormalizedHMaxRequestTimes().Rand(); x > 0 {
xmuxClient.LeftRequests.Store(x)
xmuxClient.LeftRequests.Store(int32(x))
}
if x := m.options.GetNormalizedHMaxReusableSecs().Rand(); x > 0 {
xmuxClient.UnreusableAt = time.Now().Add(time.Duration(x) * time.Second)
@@ -112,7 +121,7 @@ func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient {
if len(m.xmuxClients) == 0 {
return m.newXmuxClient()
}
if m.connections > 0 && len(m.xmuxClients) < int(m.connections) {
if m.connections > 0 && len(m.xmuxClients) < m.connections {
return m.newXmuxClient()
}
xmuxClients := make([]*XmuxClient, 0)

View File

@@ -18,6 +18,8 @@ import (
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/congestion"
"github.com/sagernet/sing-box/common/kmutex"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/common/xray/buf"
xnet "github.com/sagernet/sing-box/common/xray/net"
@@ -31,6 +33,7 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
aTLS "github.com/sagernet/sing/common/tls"
sHttp "github.com/sagernet/sing/protocol/http"
)
@@ -49,7 +52,7 @@ type Server struct {
options *option.V2RayXHTTPOptions
host string
path string
sessionMu sync.Mutex
sessionMu *kmutex.Kmutex[string]
sessions sync.Map
}
@@ -62,6 +65,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
options: &options,
host: options.Host,
path: options.GetNormalizedPath(),
sessionMu: kmutex.New[string](),
}
if server.network() == N.NetworkTCP {
protocols := new(http.Protocols)
@@ -80,11 +84,21 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
},
}
} else {
congestionControlFactory, err := congestion.NewCongestionControl(options.CongestionController, options.CWND, ntp.TimeFuncFromContext(ctx))
if err != nil {
return nil, err
}
server.quicConfig = &quic.Config{
DisablePathMTUDiscovery: !C.IsLinux && !C.IsWindows,
}
server.http3Server = &http3.Server{
Handler: server,
ConnContext: func(ctx context.Context, conn *quic.Conn) context.Context {
if congestionControlFactory != nil {
conn.SetCongestionControl(congestionControlFactory(conn))
}
return log.ContextWithNewID(ctx)
},
}
}
return server, nil
@@ -102,7 +116,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
return
}
WriteResponseHeader(writer, request.Method, request.Header, s.options)
length := int(s.options.GetNormalizedXPaddingBytes().Rand())
length := s.options.GetNormalizedXPaddingBytes().Rand()
config := XPaddingConfig{Length: length}
if s.options.XPaddingObfsMode {
config.Placement = XPaddingPlacement{
@@ -125,15 +139,25 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
validRange := s.options.GetNormalizedXPaddingBytes()
paddingValue, paddingPlacement := ExtractXPaddingFromRequest(&s.options.V2RayXHTTPBaseOptions, request, s.options.XPaddingObfsMode)
if !IsPaddingValid(&s.options.V2RayXHTTPBaseOptions, paddingValue, validRange.From, validRange.To, PaddingMethod(s.options.XPaddingMethod)) {
s.logger.ErrorContext(request.Context(), "invalid padding ("+paddingPlacement+") length:", int32(len(paddingValue)))
s.logger.ErrorContext(request.Context(), "invalid padding ("+paddingPlacement+") length:", len(paddingValue))
writer.WriteHeader(http.StatusBadRequest)
return
}
sessionId, seqStr := ExtractMetaFromRequest(s.options, request, s.path)
if sessionId == "" && s.options.Mode != "" && s.options.Mode != "auto" && s.options.Mode != "stream-one" && s.options.Mode != "stream-up" {
s.logger.ErrorContext(request.Context(), "stream-one mode is not allowed")
writer.WriteHeader(http.StatusBadRequest)
return
if s.options.Mode != "" && s.options.Mode != "auto" {
if sessionId == "" {
if s.options.Mode != "stream-one" && s.options.Mode != "stream-up" {
s.logger.ErrorContext(request.Context(), "stream-one mode is not allowed")
writer.WriteHeader(http.StatusBadRequest)
return
}
} else {
if s.options.Mode == "stream-one" {
s.logger.ErrorContext(request.Context(), "session is not allowed in stream-one mode")
writer.WriteHeader(http.StatusBadRequest)
return
}
}
}
var forwardedAddrs []xnet.Address
if len(s.options.TrustedXForwardedFor) > 0 {
@@ -171,7 +195,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if sessionId != "" {
currentSession = s.upsertSession(sessionId)
}
scMaxEachPostBytes := int(s.options.GetNormalizedScMaxEachPostBytes().To)
scMaxEachPostBytes := s.options.GetNormalizedScMaxEachPostBytes().To
uplinkDataPlacement := s.options.GetNormalizedUplinkDataPlacement()
uplinkDataKey := s.options.UplinkDataKey
isUplinkRequest := false
@@ -207,12 +231,22 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
referrer := request.Header.Get("Referer")
if referrer != "" && scStreamUpServerSecs.To > 0 {
go func() {
timer := time.NewTimer(0)
if !timer.Stop() {
<-timer.C
}
defer timer.Stop()
for {
_, err := httpSC.Write(bytes.Repeat([]byte{'X'}, int(s.options.GetNormalizedXPaddingBytes().Rand())))
_, err := httpSC.Write(bytes.Repeat([]byte{'X'}, s.options.GetNormalizedXPaddingBytes().Rand()))
if err != nil {
break
return
}
timer.Reset(time.Duration(scStreamUpServerSecs.Rand()) * time.Second)
select {
case <-timer.C:
case <-httpSC.Wait():
return
}
time.Sleep(time.Duration(scStreamUpServerSecs.Rand()) * time.Second)
}
}()
}
@@ -327,7 +361,11 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
// after GET is done, the connection is finished. disable automatic
// session reaping, and handle it in defer
currentSession.isFullyConnected.Close()
defer s.sessions.Delete(sessionId)
defer func() {
s.sessionMu.Lock(sessionId)
defer s.sessionMu.Unlock(sessionId)
s.sessions.Delete(sessionId)
}()
}
// magic header instructs nginx + apache to not buffer response body
writer.Header().Set("X-Accel-Buffering", "no")
@@ -410,32 +448,27 @@ func (s *Server) network() string {
}
func (s *Server) upsertSession(sessionId string) *httpSession {
// fast path
s.sessionMu.Lock(sessionId)
defer s.sessionMu.Unlock(sessionId)
currentSessionAny, ok := s.sessions.Load(sessionId)
if ok {
return currentSessionAny.(*httpSession)
}
// slow path
s.sessionMu.Lock()
defer s.sessionMu.Unlock()
currentSessionAny, ok = s.sessions.Load(sessionId)
if ok {
return currentSessionAny.(*httpSession)
}
session := &httpSession{
uploadQueue: NewUploadQueue(s.options.GetNormalizedScMaxBufferedPosts()),
isFullyConnected: done.New(),
}
s.sessions.Store(sessionId, session)
shouldReap := done.New()
go func() {
time.Sleep(30 * time.Second)
shouldReap.Close()
}()
go func() {
reapTimer := time.NewTimer(30 * time.Second)
defer reapTimer.Stop()
select {
case <-shouldReap.Wait():
s.sessions.Delete(sessionId)
case <-reapTimer.C:
s.sessionMu.Lock(sessionId)
if current, ok := s.sessions.Load(sessionId); ok && current.(*httpSession) == session {
s.sessions.Delete(sessionId)
}
s.sessionMu.Unlock(sessionId)
session.uploadQueue.Close()
case <-session.isFullyConnected.Wait():
}

View File

@@ -6,7 +6,6 @@ package xhttp
import (
"container/heap"
"io"
"runtime"
"sync"
E "github.com/sagernet/sing/common/exceptions"
@@ -19,19 +18,22 @@ type Packet struct {
}
type uploadQueue struct {
reader io.ReadCloser
nomore bool
pushedPackets chan Packet
writeCloseMutex sync.Mutex
heap uploadHeap
nextSeq uint64
closed bool
maxPackets int
reader io.ReadCloser
nomore bool
pushedPackets chan Packet
done chan struct{}
heap uploadHeap
nextSeq uint64
closed bool
maxPackets int
mtx sync.Mutex
}
func NewUploadQueue(maxPackets int) *uploadQueue {
return &uploadQueue{
pushedPackets: make(chan Packet, maxPackets),
done: make(chan struct{}),
heap: uploadHeap{},
nextSeq: 0,
closed: false,
@@ -40,63 +42,83 @@ func NewUploadQueue(maxPackets int) *uploadQueue {
}
func (h *uploadQueue) Push(p Packet) error {
h.writeCloseMutex.Lock()
defer h.writeCloseMutex.Unlock()
h.mtx.Lock()
if h.closed {
h.mtx.Unlock()
return E.New("packet queue closed")
}
if h.nomore {
h.mtx.Unlock()
return E.New("h.reader already exists")
}
if p.Reader != nil {
h.nomore = true
}
h.pushedPackets <- p
return nil
h.mtx.Unlock()
select {
case h.pushedPackets <- p:
return nil
case <-h.done:
return E.New("packet queue closed")
}
}
func (h *uploadQueue) Close() error {
h.writeCloseMutex.Lock()
defer h.writeCloseMutex.Unlock()
if !h.closed {
h.closed = true
runtime.Gosched() // hope Read() gets the packet
f:
for {
select {
case p := <-h.pushedPackets:
if p.Reader != nil {
h.reader = p.Reader
}
default:
break f
h.mtx.Lock()
if h.closed {
h.mtx.Unlock()
return nil
}
h.closed = true
close(h.done)
h.mtx.Unlock()
for {
select {
case p := <-h.pushedPackets:
if p.Reader != nil {
p.Reader.Close()
}
default:
if h.reader != nil {
return h.reader.Close()
}
return nil
}
close(h.pushedPackets)
}
if h.reader != nil {
return h.reader.Close()
}
return nil
}
func (h *uploadQueue) Read(b []byte) (int, error) {
h.mtx.Lock()
if h.closed {
h.mtx.Unlock()
return 0, io.EOF
}
h.mtx.Unlock()
if h.reader != nil {
return h.reader.Read(b)
}
if h.closed {
return 0, io.EOF
}
if len(h.heap) == 0 {
packet, more := <-h.pushedPackets
if !more {
select {
case packet, more := <-h.pushedPackets:
if !more {
return 0, io.EOF
}
if packet.Reader != nil {
h.mtx.Lock()
if h.closed {
packet.Reader.Close()
h.mtx.Unlock()
return 0, io.EOF
}
h.reader = packet.Reader
h.mtx.Unlock()
return h.reader.Read(b)
}
heap.Push(&h.heap, packet)
case <-h.done:
return 0, io.EOF
}
if packet.Reader != nil {
h.reader = packet.Reader
return h.reader.Read(b)
}
heap.Push(&h.heap, packet)
}
for len(h.heap) > 0 {
packet := heap.Pop(&h.heap).(Packet)
@@ -125,11 +147,15 @@ func (h *uploadQueue) Read(b []byte) (int, error) {
return 0, E.New("packet queue is too large")
}
heap.Push(&h.heap, packet)
packet2, more := <-h.pushedPackets
if !more {
select {
case packet2, more := <-h.pushedPackets:
if !more {
return 0, io.EOF
}
heap.Push(&h.heap, packet2)
case <-h.done:
return 0, io.EOF
}
heap.Push(&h.heap, packet2)
}
}
return 0, nil

View File

@@ -4,16 +4,49 @@ import (
"encoding/base64"
"fmt"
"io"
"math/rand/v2"
"net/http"
"github.com/sagernet/sing-box/common/xray/buf"
"github.com/sagernet/sing-box/common/xray/utils"
"github.com/sagernet/sing-box/common/xray/uuid"
"github.com/sagernet/sing-box/option"
)
// PredefinedTable maps named charsets to their alphabets for session ID generation.
var PredefinedTable = map[string]string{
"ALPHABET": "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
"Alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
"BASE36": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ",
"Base62": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
"HEX": "0123456789ABCDEF",
"alphabet": "abcdefghijklmnopqrstuvwxyz",
"base36": "0123456789abcdefghijklmnopqrstuvwxyz",
"hex": "0123456789abcdef",
"number": "0123456789",
}
func GenerateSessionID(options *option.V2RayXHTTPBaseOptions) string {
length := options.SessionIDLength.Rand()
table := options.SessionIDTable
if predefined, ok := PredefinedTable[table]; ok {
table = predefined
}
if table != "" && length > 0 {
id := make([]byte, length)
for i := range id {
id[i] = table[rand.N(len(table))]
}
return string(id)
}
newUUID := uuid.New()
return newUUID.String()
}
func FillStreamRequest(request *http.Request, sessionId string, seqStr string, options *option.V2RayXHTTPBaseOptions) {
request.Header = options.GetRequestHeader()
length := int(options.GetNormalizedXPaddingBytes().Rand())
length := options.GetNormalizedXPaddingBytes().Rand()
config := XPaddingConfig{Length: length}
if options.XPaddingObfsMode {
config.Placement = XPaddingPlacement{
@@ -58,7 +91,7 @@ func FillPacketRequest(request *http.Request, sessionId string, seqStr string, p
}
}
}
length := int(options.GetNormalizedXPaddingBytes().Rand())
length := options.GetNormalizedXPaddingBytes().Rand()
config := XPaddingConfig{Length: length}
if options.XPaddingObfsMode {
config.Placement = XPaddingPlacement{
@@ -125,7 +158,7 @@ func GetRequestHeaderWithPayload(payload []byte, options *option.V2RayXHTTPBaseO
key := options.UplinkDataKey
encodedData := base64.RawURLEncoding.EncodeToString(payload)
for i := 0; len(encodedData) > 0; i++ {
chunkSize := min(int(options.GetNormalizedUplinkChunkSize().Rand()), len(encodedData))
chunkSize := min(options.GetNormalizedUplinkChunkSize().Rand(), len(encodedData))
chunk := encodedData[:chunkSize]
encodedData = encodedData[chunkSize:]
headerKey := fmt.Sprintf("%s-%d", key, i)
@@ -140,7 +173,7 @@ func GetRequestCookiesWithPayload(payload []byte, options *option.V2RayXHTTPBase
key := options.UplinkDataKey
encodedData := base64.RawURLEncoding.EncodeToString(payload)
for i := 0; len(encodedData) > 0; i++ {
chunkSize := min(int(options.GetNormalizedUplinkChunkSize().Rand()), len(encodedData))
chunkSize := min(options.GetNormalizedUplinkChunkSize().Rand(), len(encodedData))
chunk := encodedData[:chunkSize]
encodedData = encodedData[chunkSize:]
cookieName := fmt.Sprintf("%s_%d", key, i)

View File

@@ -31,11 +31,12 @@ func (w uploadWriter) Write(b []byte) (int, error) {
var writed int
for _, buff := range buffer.MultiBuffer {
n := int(buff.Len())
err := w.WriteMultiBuffer(buf.MultiBuffer{buff})
if err != nil {
return writed, err
}
writed += int(buff.Len())
writed += n
}
return writed, nil
}

View File

@@ -264,7 +264,7 @@ func ExtractXPaddingFromRequest(options *option.V2RayXHTTPBaseOptions, req *http
return "", ""
}
func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string, from, to int32, method PaddingMethod) bool {
func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string, from, to int, method PaddingMethod) bool {
if paddingValue == "" {
return false
}
@@ -274,11 +274,11 @@ func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string,
}
switch method {
case PaddingMethodRepeatX:
n := int32(len(paddingValue))
n := len(paddingValue)
return n >= from && n <= to
case PaddingMethodTokenish:
const tolerance = int32(validationTolerance)
n := int32(hpack.HuffmanEncodeLength(paddingValue))
const tolerance = validationTolerance
n := int(hpack.HuffmanEncodeLength(paddingValue))
f := from - tolerance
t := to + tolerance
if f < 0 {
@@ -286,7 +286,7 @@ func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string,
}
return n >= f && n <= t
default:
n := int32(len(paddingValue))
n := len(paddingValue)
return n >= from && n <= to
}
}

View File

@@ -5,7 +5,7 @@ import (
"net/netip"
"time"
Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption"
"github.com/sagernet/sing/common/json/badoption"
tun "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
@@ -49,10 +49,10 @@ type AmneziaOptions struct {
S2 int
S3 int
S4 int
H1 *Xbadoption.Range
H2 *Xbadoption.Range
H3 *Xbadoption.Range
H4 *Xbadoption.Range
H1 *badoption.Range[uint32]
H2 *badoption.Range[uint32]
H3 *badoption.Range[uint32]
H4 *badoption.Range[uint32]
I1 string
I2 string
I3 string