mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-26 20:29:03 +03:00
Add Snell protocol. Refactor MASQUE HTTP/2, Fair Queue. Update XHTTP, OpenVPN, Sudoku, Fallback. Fixes
This commit is contained in:
331
transport/masque/client_h2.go
Normal file
331
transport/masque/client_h2.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package masque
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/sagernet/quic-go/quicvarint"
|
||||
"github.com/yosida95/uritemplate/v3"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const h2DatagramCapsuleType uint64 = 0
|
||||
|
||||
const (
|
||||
ipv4HeaderLen = 20
|
||||
ipv6HeaderLen = 40
|
||||
)
|
||||
|
||||
func ConnectTunnelH2(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, endpoint *net.TCPAddr, connectUri string) (io.Closer, IpConn, *http.Response, error) {
|
||||
if endpoint == nil {
|
||||
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
}
|
||||
|
||||
tlsConfig.SetNextProtos([]string{"h2"})
|
||||
|
||||
conn, err := dialer.DialContext(ctx, N.NetworkTCP, M.SocksaddrFromNetIP(endpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
tlsConn, err := tlsConfig.Client(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if err = tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
tr := &http2.Transport{
|
||||
ReadIdleTimeout: 30 * time.Second,
|
||||
}
|
||||
cc, err := tr.NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
_ = tlsConn.Close()
|
||||
return nil, nil, nil, fmt.Errorf("connect-ip: failed to create client connection: %w", err)
|
||||
}
|
||||
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
}
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
|
||||
h2Headers := additionalHeaders.Clone()
|
||||
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
|
||||
h2Headers.Set("pq-enabled", "false")
|
||||
|
||||
ipConn, rsp, err := dialH2(ctx, cc, template, h2Headers)
|
||||
if err != nil {
|
||||
_ = cc.Close()
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
|
||||
}
|
||||
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
_ = ipConn.Close()
|
||||
_ = cc.Close()
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status)
|
||||
}
|
||||
|
||||
return cc, ipConn, rsp, nil
|
||||
}
|
||||
|
||||
func dialH2(ctx context.Context, rt http.RoundTripper, template *uritemplate.Template, additionalHeaders http.Header) (*h2IpConn, *http.Response, error) {
|
||||
if len(template.Varnames()) > 0 {
|
||||
return nil, nil, errors.New("connect-ip: IP flow forwarding not supported")
|
||||
}
|
||||
|
||||
u, err := url.Parse(template.Raw())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to parse URI: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, u.String(), pr)
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to create request: %w", err)
|
||||
}
|
||||
req.Host = authorityFromURL(u)
|
||||
req.ContentLength = -1
|
||||
req.Header = make(http.Header)
|
||||
for k, v := range additionalHeaders {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
stop := context.AfterFunc(ctx, cancel)
|
||||
rsp, err := rt.RoundTrip(req)
|
||||
stop()
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to send request: %w", err)
|
||||
}
|
||||
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
_ = rsp.Body.Close()
|
||||
return nil, rsp, fmt.Errorf("connect-ip: server responded with %d", rsp.StatusCode)
|
||||
}
|
||||
|
||||
stream := &h2DatagramStream{
|
||||
requestBody: pw,
|
||||
responseBody: rsp.Body,
|
||||
cancel: cancel,
|
||||
}
|
||||
return &h2IpConn{
|
||||
str: stream,
|
||||
closeChan: make(chan struct{}),
|
||||
}, rsp, nil
|
||||
}
|
||||
|
||||
func authorityFromURL(u *url.URL) string {
|
||||
if u.Port() != "" {
|
||||
return u.Host
|
||||
}
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return u.Host
|
||||
}
|
||||
return host + ":443"
|
||||
}
|
||||
|
||||
type h2IpConn struct {
|
||||
str *h2DatagramStream
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
closeChan chan struct{}
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (c *h2IpConn) ReadPacket() (b []byte, err error) {
|
||||
start:
|
||||
data, err := c.str.ReceiveDatagram(context.Background())
|
||||
if err != nil {
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
select {
|
||||
case <-c.closeChan:
|
||||
return nil, c.closeErr
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := c.handleIncomingProxiedPacket(data); err != nil {
|
||||
goto start
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) handleIncomingProxiedPacket(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return errors.New("connect-ip: empty packet")
|
||||
}
|
||||
switch v := ipVersion(data); v {
|
||||
default:
|
||||
return fmt.Errorf("connect-ip: unknown IP versions: %d", v)
|
||||
case 4:
|
||||
if len(data) < ipv4HeaderLen {
|
||||
return fmt.Errorf("connect-ip: malformed datagram: too short")
|
||||
}
|
||||
case 6:
|
||||
if len(data) < ipv6HeaderLen {
|
||||
return fmt.Errorf("connect-ip: malformed datagram: too short")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) WritePacket(b []byte) (icmp []byte, err error) {
|
||||
data, err := c.composeDatagram(b)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err := c.str.SendDatagram(data); err != nil {
|
||||
select {
|
||||
case <-c.closeChan:
|
||||
return nil, c.closeErr
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) composeDatagram(b []byte) ([]byte, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
switch v := ipVersion(b); v {
|
||||
default:
|
||||
return nil, fmt.Errorf("connect-ip: unknown IP versions: %d", v)
|
||||
case 4:
|
||||
if len(b) < ipv4HeaderLen {
|
||||
return nil, fmt.Errorf("connect-ip: IPv4 packet too short")
|
||||
}
|
||||
ttl := b[8]
|
||||
if ttl <= 1 {
|
||||
return nil, fmt.Errorf("connect-ip: datagram TTL too small: %d", ttl)
|
||||
}
|
||||
b[8]--
|
||||
binary.BigEndian.PutUint16(b[10:12], calculateIPv4Checksum(([ipv4HeaderLen]byte)(b[:ipv4HeaderLen])))
|
||||
case 6:
|
||||
if len(b) < ipv6HeaderLen {
|
||||
return nil, fmt.Errorf("connect-ip: IPv6 packet too short")
|
||||
}
|
||||
hopLimit := b[7]
|
||||
if hopLimit <= 1 {
|
||||
return nil, fmt.Errorf("connect-ip: datagram Hop Limit too small: %d", hopLimit)
|
||||
}
|
||||
b[7]--
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) Close() error {
|
||||
c.mu.Lock()
|
||||
if c.closeErr == nil {
|
||||
c.closeErr = net.ErrClosed
|
||||
close(c.closeChan)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
err := c.str.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func ipVersion(b []byte) uint8 { return b[0] >> 4 }
|
||||
|
||||
func calculateIPv4Checksum(header [ipv4HeaderLen]byte) uint16 {
|
||||
var sum uint32
|
||||
for i := 0; i < len(header); i += 2 {
|
||||
if i == 10 {
|
||||
continue
|
||||
}
|
||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||
}
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xffff) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
type h2DatagramStream struct {
|
||||
requestBody *io.PipeWriter
|
||||
responseBody io.ReadCloser
|
||||
cancel context.CancelFunc
|
||||
|
||||
readMu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) ReceiveDatagram(_ context.Context) ([]byte, error) {
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
|
||||
reader := quicvarint.NewReader(s.responseBody)
|
||||
for {
|
||||
capsuleType, err := quicvarint.Read(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen, err := quicvarint.Read(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := make([]byte, payloadLen)
|
||||
_, err = io.ReadFull(reader, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if capsuleType != h2DatagramCapsuleType {
|
||||
continue
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) SendDatagram(data []byte) error {
|
||||
frame := make([]byte, 0, quicvarint.Len(h2DatagramCapsuleType)+quicvarint.Len(uint64(len(data)))+len(data))
|
||||
frame = quicvarint.Append(frame, h2DatagramCapsuleType)
|
||||
frame = quicvarint.Append(frame, uint64(len(data)))
|
||||
frame = append(frame, data...)
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
_, err := s.requestBody.Write(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect-ip: failed to send datagram capsule: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) Close() error {
|
||||
_ = s.requestBody.Close()
|
||||
err := s.responseBody.Close()
|
||||
s.cancel()
|
||||
return err
|
||||
}
|
||||
@@ -2,9 +2,9 @@ package masque
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -12,13 +12,13 @@ import (
|
||||
|
||||
connectip "github.com/Diniboy1123/connect-ip-go"
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
qtls "github.com/sagernet/sing-quic"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
"github.com/yosida95/uritemplate/v3"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -26,39 +26,60 @@ type (
|
||||
ListenPacket func(network string, address string) (net.PacketConn, error)
|
||||
)
|
||||
|
||||
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool) (net.PacketConn, *http3.Transport, *connectip.Conn, *http.Response, error) {
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
type IpConn interface {
|
||||
ReadPacket() (b []byte, err error)
|
||||
WritePacket(b []byte) (icmp []byte, err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type closerFunc func() error
|
||||
|
||||
func (f closerFunc) Close() error { return f() }
|
||||
|
||||
type quicIpConn struct {
|
||||
conn *connectip.Conn
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newQuicIpConn(conn *connectip.Conn) *quicIpConn {
|
||||
return &quicIpConn{
|
||||
conn: conn,
|
||||
buf: make([]byte, 0xFFFF),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *quicIpConn) ReadPacket() ([]byte, error) {
|
||||
n, err := c.conn.ReadPacket(c.buf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.buf[:n], nil
|
||||
}
|
||||
|
||||
func (c *quicIpConn) WritePacket(b []byte) (icmp []byte, err error) {
|
||||
return c.conn.WritePacket(b)
|
||||
}
|
||||
|
||||
func (c *quicIpConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool, congestionControl func(conn *quic.Conn) congestion.CongestionControl) (io.Closer, IpConn, *http.Response, error) {
|
||||
if useHTTP2 {
|
||||
h2Endpoint, ok := endpoint.(*net.TCPAddr)
|
||||
if !ok || h2Endpoint == nil {
|
||||
return nil, nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
}
|
||||
h2Headers := additionalHeaders.Clone()
|
||||
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
|
||||
h2Headers.Set("pq-enabled", "false")
|
||||
h2Client, err := newHTTP2Client(dialer, tlsConfig, h2Endpoint, connectUri)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to create HTTP/2 client: %w", err)
|
||||
}
|
||||
ipConn, rsp, err := connectip.DialH2(ctx, h2Client, template, h2Headers)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return nil, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
|
||||
}
|
||||
return nil, nil, ipConn, rsp, nil
|
||||
return ConnectTunnelH2(ctx, dialer, tlsConfig, h2Endpoint, connectUri)
|
||||
}
|
||||
|
||||
quicEndpoint, ok := endpoint.(*net.UDPAddr)
|
||||
if !ok || quicEndpoint == nil {
|
||||
return nil, nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
|
||||
return nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
|
||||
}
|
||||
udpConn, err := dialer.ListenPacket(ctx, M.SocksaddrFromNetIP(quicEndpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
conn, err := qtls.Dial(
|
||||
ctx,
|
||||
@@ -68,28 +89,34 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
|
||||
quicConfig,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
_ = udpConn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if congestionControl != nil {
|
||||
conn.SetCongestionControl(congestionControl(conn))
|
||||
}
|
||||
tr := &http3.Transport{
|
||||
EnableDatagrams: true,
|
||||
AdditionalSettings: map[uint64]uint64{
|
||||
// official client still sends this out as well, even though
|
||||
// it's deprecated, see https://datatracker.ietf.org/doc/draft-ietf-masque-h3-datagram/00/
|
||||
// SETTINGS_H3_DATAGRAM_00 = 0x0000000000000276
|
||||
// https://github.com/cloudflare/quiche/blob/7c66757dbc55b8d0c3653d4b345c6785a181f0b7/quiche/src/h3/frame.rs#L46
|
||||
0x276: 1,
|
||||
},
|
||||
DisableCompression: true,
|
||||
}
|
||||
hconn := tr.NewClientConn(conn)
|
||||
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
}
|
||||
ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true)
|
||||
if err != nil {
|
||||
_ = tr.Close()
|
||||
_ = conn.CloseWithError(0, "connect-ip dial failed")
|
||||
_ = udpConn.Close()
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return udpConn, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return udpConn, nil, nil, nil, fmt.Errorf("failed to dial connect-ip: %w", err)
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %w", err)
|
||||
}
|
||||
err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{
|
||||
{
|
||||
@@ -109,34 +136,16 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return udpConn, nil, nil, nil, err
|
||||
_ = ipConn.Close()
|
||||
_ = tr.Close()
|
||||
_ = udpConn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return udpConn, tr, ipConn, rsp, nil
|
||||
}
|
||||
|
||||
func newHTTP2Client(dialer N.Dialer, baseTLSConfig aTLS.Config, endpoint *net.TCPAddr, connectURI string) (*http.Client, error) {
|
||||
if endpoint == nil {
|
||||
return nil, errors.New("missing HTTP/2 endpoint")
|
||||
}
|
||||
tlsConfig := baseTLSConfig.Clone()
|
||||
tlsConfig.SetNextProtos([]string{"h2"})
|
||||
return &http.Client{
|
||||
Transport: &http2.Transport{
|
||||
DialTLSContext: func(ctx context.Context, network, _ string, _ *tls.Config) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctx, network, M.SocksaddrFromNetIP(endpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConn, err := tlsConfig.Client(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return tlsConn, nil
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
closer := closerFunc(func() error {
|
||||
_ = tr.Close()
|
||||
_ = udpConn.Close()
|
||||
return nil
|
||||
})
|
||||
return closer, newQuicIpConn(ipConn), rsp, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
@@ -23,4 +25,5 @@ type TunnelOptions struct {
|
||||
UDPKeepalivePeriod time.Duration
|
||||
UDPInitialPacketSize uint16
|
||||
ReconnectDelay time.Duration
|
||||
CongestionControl func(conn *quic.Conn) congestion.CongestionControl
|
||||
}
|
||||
|
||||
@@ -4,13 +4,12 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
connectip "github.com/Diniboy1123/connect-ip-go"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
@@ -22,9 +21,8 @@ type Tunnel struct {
|
||||
options TunnelOptions
|
||||
device Device
|
||||
|
||||
udpConn net.PacketConn
|
||||
tr *http3.Transport
|
||||
ipConn *connectip.Conn
|
||||
closer io.Closer
|
||||
ipConn IpConn
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
@@ -83,13 +81,11 @@ func (e *Tunnel) Close() error {
|
||||
defer e.mtx.Unlock()
|
||||
if e.ipConn != nil {
|
||||
e.ipConn.Close()
|
||||
if e.udpConn != nil {
|
||||
e.udpConn.Close()
|
||||
}
|
||||
if e.tr != nil {
|
||||
e.tr.Close()
|
||||
if e.closer != nil {
|
||||
e.closer.Close()
|
||||
}
|
||||
e.ipConn = nil
|
||||
e.closer = nil
|
||||
}
|
||||
return e.device.Close()
|
||||
}
|
||||
@@ -124,7 +120,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
icmp, err := ipConn.WritePacket(packet)
|
||||
if err != nil {
|
||||
if errors.As(err, new(*connectip.CloseError)) {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
if ok := e.closeIpConn(ipConn); ok {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing to IP connection: %w", err))
|
||||
}
|
||||
@@ -135,7 +131,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
if len(icmp) > 0 {
|
||||
if _, err := e.device.Write([][]byte{icmp}, 0); err != nil {
|
||||
if errors.As(err, new(*connectip.CloseError)) {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing ICMP to TUN device: %v", err))
|
||||
continue
|
||||
}
|
||||
@@ -145,15 +141,14 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
buf := make([]byte, 1280)
|
||||
for e.ctx.Err() == nil {
|
||||
ipConn, err := e.getIpConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n, err := ipConn.ReadPacket(buf, true)
|
||||
packet, err := ipConn.ReadPacket()
|
||||
if err != nil {
|
||||
if e.options.UseHTTP2 || errors.As(err, new(*connectip.CloseError)) {
|
||||
if e.options.UseHTTP2 || errors.Is(err, net.ErrClosed) {
|
||||
if ok := e.closeIpConn(ipConn); ok {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while reading from IP connection: %v", err))
|
||||
}
|
||||
@@ -162,7 +157,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Error reading from IP connection: %v, continuine...", err))
|
||||
continue
|
||||
}
|
||||
if _, err := e.device.Write([][]byte{buf[:n]}, 0); err != nil {
|
||||
if _, err := e.device.Write([][]byte{packet}, 0); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -170,7 +165,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
<-e.ctx.Done()
|
||||
}
|
||||
|
||||
func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
func (e *Tunnel) getIpConn() (IpConn, error) {
|
||||
e.mtx.Lock()
|
||||
defer e.mtx.Unlock()
|
||||
if e.ctx.Err() != nil {
|
||||
@@ -184,7 +179,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
defer timer.Stop()
|
||||
for {
|
||||
e.logger.NoticeContext(e.ctx, fmt.Errorf("Establishing MASQUE connection to %s", e.options.Endpoint))
|
||||
udpConn, tr, ipConn, rsp, err := ConnectTunnel(
|
||||
closer, ipConn, rsp, err := ConnectTunnel(
|
||||
e.ctx,
|
||||
e.options.Dialer,
|
||||
e.options.TLSConfig,
|
||||
@@ -192,6 +187,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
"https://cloudflareaccess.com",
|
||||
e.options.Endpoint,
|
||||
e.options.UseHTTP2,
|
||||
e.options.CongestionControl,
|
||||
)
|
||||
if err != nil {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Failed to connect tunnel: %v", err))
|
||||
@@ -206,11 +202,8 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
if rsp.StatusCode != 200 {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Tunnel connection failed: %s", rsp.Status))
|
||||
ipConn.Close()
|
||||
if udpConn != nil {
|
||||
udpConn.Close()
|
||||
}
|
||||
if tr != nil {
|
||||
tr.Close()
|
||||
if closer != nil {
|
||||
closer.Close()
|
||||
}
|
||||
timer.Reset(e.options.ReconnectDelay)
|
||||
select {
|
||||
@@ -220,26 +213,23 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
e.udpConn = udpConn
|
||||
e.tr = tr
|
||||
e.closer = closer
|
||||
e.ipConn = ipConn
|
||||
e.logger.NoticeContext(e.ctx, "Connected to MASQUE server ", e.options.Endpoint)
|
||||
return ipConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Tunnel) closeIpConn(ipConn *connectip.Conn) bool {
|
||||
func (e *Tunnel) closeIpConn(ipConn IpConn) bool {
|
||||
e.mtx.Lock()
|
||||
defer e.mtx.Unlock()
|
||||
if ipConn == e.ipConn {
|
||||
e.ipConn.Close()
|
||||
if e.udpConn != nil {
|
||||
e.udpConn.Close()
|
||||
}
|
||||
if e.tr != nil {
|
||||
e.tr.Close()
|
||||
if e.closer != nil {
|
||||
e.closer.Close()
|
||||
}
|
||||
e.ipConn = nil
|
||||
e.closer = nil
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
@@ -23,7 +24,7 @@ const (
|
||||
|
||||
type DataCipher interface {
|
||||
Encrypt(header []byte, packetID uint32, payload []byte) ([]byte, error)
|
||||
Decrypt(packet []byte, headerSize int) ([]byte, error)
|
||||
Decrypt(packet []byte, headerSize int) (plaintext []byte, packetID uint32, err error)
|
||||
}
|
||||
|
||||
type AEADDataCipher struct {
|
||||
@@ -86,9 +87,9 @@ func (g *AEADDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) {
|
||||
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
|
||||
if len(packet) < headerSize+4+AESGCMTagSize+1 {
|
||||
return nil, errors.New("openvpn gcm data packet too short")
|
||||
return nil, 0, errors.New("openvpn gcm data packet too short")
|
||||
}
|
||||
header := packet[:headerSize]
|
||||
pidBytes := packet[headerSize : headerSize+4]
|
||||
@@ -96,8 +97,13 @@ func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error)
|
||||
ciphertext := packet[headerSize+4+AESGCMTagSize:]
|
||||
combined := append(ciphertext, tag...)
|
||||
ad := append(header, pidBytes...)
|
||||
nonce := g.nonce(binary.BigEndian.Uint32(pidBytes), g.recvImplicitIV)
|
||||
return g.recv.Open(nil, nonce[:], combined, ad)
|
||||
packetID := binary.BigEndian.Uint32(pidBytes)
|
||||
nonce := g.nonce(packetID, g.recvImplicitIV)
|
||||
plain, err := g.recv.Open(nil, nonce[:], combined, ad)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return plain, packetID, nil
|
||||
}
|
||||
|
||||
func (g *AEADDataCipher) nonce(packetID uint32, implicit [AESGCMIVSize]byte) [AESGCMIVSize]byte {
|
||||
@@ -127,6 +133,9 @@ func NewCBCCipher(keys *KeyMaterial, auth string) (*CBCDataCipher, error) {
|
||||
var newHash func() hash.Hash
|
||||
var hmacSize int
|
||||
switch auth {
|
||||
case AuthMD5:
|
||||
newHash = md5.New
|
||||
hmacSize = md5.Size
|
||||
case AuthSHA256:
|
||||
newHash = sha256.New
|
||||
hmacSize = sha256.Size
|
||||
@@ -176,34 +185,35 @@ func (c *CBCDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) {
|
||||
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
|
||||
minSize := headerSize + c.hmacSize + CBCIVSize + aes.BlockSize
|
||||
if len(packet) < minSize {
|
||||
return nil, errors.New("openvpn cbc data packet too short")
|
||||
return nil, 0, errors.New("openvpn cbc data packet too short")
|
||||
}
|
||||
tag := packet[headerSize : headerSize+c.hmacSize]
|
||||
iv := packet[headerSize+c.hmacSize : headerSize+c.hmacSize+CBCIVSize]
|
||||
ct := packet[headerSize+c.hmacSize+CBCIVSize:]
|
||||
if len(ct)%aes.BlockSize != 0 {
|
||||
return nil, errors.New("openvpn cbc ciphertext not block-aligned")
|
||||
return nil, 0, errors.New("openvpn cbc ciphertext not block-aligned")
|
||||
}
|
||||
mac := hmac.New(c.newHash, c.recvHMAC)
|
||||
mac.Write(iv)
|
||||
mac.Write(ct)
|
||||
if !hmac.Equal(tag, mac.Sum(nil)) {
|
||||
return nil, errors.New("openvpn cbc hmac verification failed")
|
||||
return nil, 0, errors.New("openvpn cbc hmac verification failed")
|
||||
}
|
||||
plain := make([]byte, len(ct))
|
||||
cipher.NewCBCDecrypter(c.recvBlock, iv).CryptBlocks(plain, ct)
|
||||
padLen := int(plain[len(plain)-1])
|
||||
if padLen < 1 || padLen > aes.BlockSize {
|
||||
return nil, errors.New("openvpn cbc invalid padding")
|
||||
return nil, 0, errors.New("openvpn cbc invalid padding")
|
||||
}
|
||||
plain = plain[:len(plain)-padLen]
|
||||
if len(plain) < 4 {
|
||||
return nil, errors.New("openvpn cbc payload too short")
|
||||
return nil, 0, errors.New("openvpn cbc payload too short")
|
||||
}
|
||||
return plain[4:], nil
|
||||
packetID := binary.BigEndian.Uint32(plain[:4])
|
||||
return plain[4:], packetID, nil
|
||||
}
|
||||
|
||||
func CipherKeyLength(cipher string) int {
|
||||
|
||||
@@ -8,12 +8,16 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
const defaultHandshakeTimeout = 30 * time.Second
|
||||
const (
|
||||
defaultHandshakeTimeout = 30 * time.Second
|
||||
controlRetransmitDelay = time.Second
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
config *ClientConfig
|
||||
@@ -26,6 +30,8 @@ type Client struct {
|
||||
push *PushReply
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
lastReceiveNano atomic.Int64
|
||||
}
|
||||
|
||||
func NewClient(config *ClientConfig, io PacketIO, tlsConfig tls.Config) (*Client, error) {
|
||||
@@ -154,6 +160,7 @@ func (c *Client) Handshake(ctx context.Context) (*PushReply, error) {
|
||||
return nil, err
|
||||
}
|
||||
c.data = NewDataChannel(cipher, push.PeerID, push.CompLZO)
|
||||
c.markReceive()
|
||||
return push, nil
|
||||
}
|
||||
|
||||
@@ -181,10 +188,21 @@ func (c *Client) ReadIPPacket(ctx context.Context) ([]byte, error) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c.markReceive()
|
||||
return plain, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SinceReceive() time.Duration {
|
||||
return time.Duration(int64(time.Since(clientStart)) - c.lastReceiveNano.Load())
|
||||
}
|
||||
|
||||
func (c *Client) markReceive() {
|
||||
c.lastReceiveNano.Store(int64(time.Since(clientStart)))
|
||||
}
|
||||
|
||||
var clientStart = time.Now().Add(-time.Hour)
|
||||
|
||||
func (c *Client) Close() error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
@@ -199,10 +217,24 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) waitServerReset(ctx context.Context) error {
|
||||
retransmits := 0
|
||||
for {
|
||||
packet, err := c.control.Read(ctx)
|
||||
readCtx := ctx
|
||||
cancel := func() {}
|
||||
if c.config.Proto == ProtoUDP {
|
||||
readCtx, cancel = context.WithTimeout(ctx, controlRetransmitDelay)
|
||||
}
|
||||
packet, err := c.control.Read(readCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read hard reset response: %w", err)
|
||||
if c.config.Proto == ProtoUDP && errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil {
|
||||
if err := c.control.RetransmitPending(ctx); err != nil {
|
||||
return fmt.Errorf("retransmit hard reset: %w", err)
|
||||
}
|
||||
retransmits++
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read hard reset response after %d retransmits: %w", retransmits, err)
|
||||
}
|
||||
switch packet.Opcode {
|
||||
case PControlHardResetServerV2:
|
||||
|
||||
@@ -20,6 +20,7 @@ const (
|
||||
CipherAES256CBC = "AES-256-CBC"
|
||||
CipherCHACHA20POLY = "CHACHA20-POLY1305"
|
||||
|
||||
AuthMD5 = "MD5"
|
||||
AuthSHA1 = "SHA1"
|
||||
AuthSHA256 = "SHA256"
|
||||
AuthSHA384 = "SHA384"
|
||||
@@ -107,7 +108,7 @@ func isValidCipher(cipher string) bool {
|
||||
|
||||
func isValidAuth(auth string) bool {
|
||||
switch auth {
|
||||
case AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
|
||||
case AuthMD5, AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -30,8 +30,10 @@ type ControlChannel struct {
|
||||
mu sync.Mutex
|
||||
sendPacketID uint32
|
||||
sendMessage uint32
|
||||
recvMessage uint32
|
||||
ackPending []uint32
|
||||
pending map[uint32]*ControlPacket
|
||||
recvPending map[uint32]*ControlPacket
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
@@ -40,9 +42,10 @@ func NewControlChannel(io PacketIO, crypt ControlCrypt, local SessionID) *Contro
|
||||
ch := &ControlChannel{
|
||||
io: io,
|
||||
|
||||
clock: time.Now,
|
||||
local: local,
|
||||
pending: make(map[uint32]*ControlPacket),
|
||||
clock: time.Now,
|
||||
local: local,
|
||||
pending: make(map[uint32]*ControlPacket),
|
||||
recvPending: make(map[uint32]*ControlPacket),
|
||||
}
|
||||
if crypt != nil {
|
||||
ch.encode = func(p *ControlPacket, pid uint32, t uint32) ([]byte, error) {
|
||||
@@ -130,10 +133,23 @@ func (c *ControlChannel) SendAck(ctx context.Context) error {
|
||||
|
||||
func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
|
||||
for {
|
||||
c.mu.Lock()
|
||||
if packet, ok := c.recvPending[c.recvMessage]; ok {
|
||||
delete(c.recvPending, c.recvMessage)
|
||||
c.recvMessage++
|
||||
c.mu.Unlock()
|
||||
return packet, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
packet, err := c.readControlPacket(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var deliver *ControlPacket
|
||||
sendAck := false
|
||||
|
||||
c.mu.Lock()
|
||||
if c.remote == (SessionID{}) && packet.LocalSession != c.local {
|
||||
c.remote = packet.LocalSession
|
||||
@@ -144,11 +160,33 @@ func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
|
||||
if packet.Opcode.HasMessageID() {
|
||||
c.ackPending = appendAck(c.ackPending, packet.MessageID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
if packet.Opcode == PAckV1 {
|
||||
continue
|
||||
|
||||
switch {
|
||||
case packet.Opcode == PAckV1:
|
||||
case !packet.Opcode.HasMessageID():
|
||||
deliver = packet
|
||||
case packet.MessageID < c.recvMessage:
|
||||
sendAck = true
|
||||
case packet.MessageID == c.recvMessage:
|
||||
deliver = packet
|
||||
c.recvMessage++
|
||||
default:
|
||||
if _, exists := c.recvPending[packet.MessageID]; !exists {
|
||||
c.recvPending[packet.MessageID] = packet
|
||||
}
|
||||
sendAck = true
|
||||
}
|
||||
|
||||
c.mu.Unlock()
|
||||
|
||||
if deliver != nil {
|
||||
return deliver, nil
|
||||
}
|
||||
if sendAck {
|
||||
if err := c.SendAck(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -349,11 +387,17 @@ func (c *ControlConn) SetWriteDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
type streamPacketIO struct {
|
||||
conn net.Conn
|
||||
conn net.Conn
|
||||
deadlineMu sync.Mutex
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
|
||||
type datagramPacketIO struct {
|
||||
conn net.Conn
|
||||
conn net.Conn
|
||||
deadlineMu sync.Mutex
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
|
||||
func NewDatagramPacketIO(conn net.Conn) PacketIO {
|
||||
@@ -361,40 +405,23 @@ func NewDatagramPacketIO(conn net.Conn) PacketIO {
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
|
||||
done := make(chan struct{})
|
||||
var (
|
||||
packet []byte
|
||||
err error
|
||||
)
|
||||
go func() {
|
||||
defer close(done)
|
||||
buf := make([]byte, 64*1024)
|
||||
var n int
|
||||
n, err = d.conn.Read(buf)
|
||||
if err == nil {
|
||||
packet = cloneBytes(buf[:n])
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return packet, err
|
||||
if err := setReadDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.readDeadline); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := make([]byte, 64*1024)
|
||||
n, err := d.conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
return buf[:n], nil
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) WritePacket(ctx context.Context, packet []byte) error {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := d.conn.Write(packet)
|
||||
done <- err
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
if err := setWriteDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.writeDeadline); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := d.conn.Write(packet)
|
||||
return contextIOError(ctx, err)
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) Close() error {
|
||||
@@ -414,52 +441,37 @@ func NewTCPPacketIO(conn net.Conn) PacketIO {
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
|
||||
done := make(chan struct{})
|
||||
var (
|
||||
packet []byte
|
||||
err error
|
||||
)
|
||||
go func() {
|
||||
defer close(done)
|
||||
var lenBuf [2]byte
|
||||
if _, err = io.ReadFull(s.conn, lenBuf[:]); err != nil {
|
||||
return
|
||||
}
|
||||
size := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if size == 0 {
|
||||
err = errors.New("empty openvpn tcp packet")
|
||||
return
|
||||
}
|
||||
packet = make([]byte, size)
|
||||
_, err = io.ReadFull(s.conn, packet)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return packet, err
|
||||
if err := setReadDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.readDeadline); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var lenBuf [2]byte
|
||||
if _, err := io.ReadFull(s.conn, lenBuf[:]); err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
size := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if size == 0 {
|
||||
return nil, errors.New("empty openvpn tcp packet")
|
||||
}
|
||||
packet := make([]byte, size)
|
||||
if _, err := io.ReadFull(s.conn, packet); err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) WritePacket(ctx context.Context, packet []byte) error {
|
||||
if len(packet) > 0xffff {
|
||||
return fmt.Errorf("openvpn tcp packet too large: %d", len(packet))
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
frame := make([]byte, 2+len(packet))
|
||||
frame[0] = byte(len(packet) >> 8)
|
||||
frame[1] = byte(len(packet))
|
||||
copy(frame[2:], packet)
|
||||
_, err := s.conn.Write(frame)
|
||||
done <- err
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
if err := setWriteDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.writeDeadline); err != nil {
|
||||
return err
|
||||
}
|
||||
frame := make([]byte, 2+len(packet))
|
||||
frame[0] = byte(len(packet) >> 8)
|
||||
frame[1] = byte(len(packet))
|
||||
copy(frame[2:], packet)
|
||||
_, err := s.conn.Write(frame)
|
||||
return contextIOError(ctx, err)
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) Close() error {
|
||||
@@ -473,3 +485,50 @@ func (s *streamPacketIO) LocalAddr() net.Addr {
|
||||
func (s *streamPacketIO) RemoteAddr() net.Addr {
|
||||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func setReadDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if current.Equal(deadline) {
|
||||
return nil
|
||||
}
|
||||
if hasDeadline {
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return err
|
||||
}
|
||||
*current = deadline
|
||||
return nil
|
||||
}
|
||||
|
||||
func setWriteDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if current.Equal(deadline) {
|
||||
return nil
|
||||
}
|
||||
if hasDeadline {
|
||||
if err := conn.SetWriteDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := conn.SetWriteDeadline(time.Time{}); err != nil {
|
||||
return err
|
||||
}
|
||||
*current = deadline
|
||||
return nil
|
||||
}
|
||||
|
||||
func contextIOError(ctx context.Context, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() && ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,15 +8,21 @@ import (
|
||||
|
||||
const (
|
||||
PeerIDUnset uint32 = 0xffffff
|
||||
|
||||
dataChannelReplayWindow = 64
|
||||
)
|
||||
|
||||
type DataChannel struct {
|
||||
cipher DataCipher
|
||||
keyID uint8
|
||||
peerID uint32
|
||||
compLZO bool
|
||||
cipher DataCipher
|
||||
keyID uint8
|
||||
peerID uint32
|
||||
compLZO bool
|
||||
|
||||
mu sync.Mutex
|
||||
sendPacketID uint32
|
||||
recvHighest uint32
|
||||
recvWindow uint64
|
||||
recvSeen bool
|
||||
}
|
||||
|
||||
func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel {
|
||||
@@ -29,10 +35,11 @@ func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel
|
||||
|
||||
func (d *DataChannel) Encrypt(packet []byte) ([]byte, error) {
|
||||
if d.compLZO {
|
||||
p := make([]byte, 1+len(packet))
|
||||
p[0] = 0xFA
|
||||
copy(p[1:], packet)
|
||||
packet = p
|
||||
compressed, err := lzo1xCompressSafe(packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packet = compressed
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.sendPacketID++
|
||||
@@ -50,18 +57,15 @@ func (d *DataChannel) Decrypt(packet []byte) ([]byte, error) {
|
||||
if opcode == PDataV2 {
|
||||
headerSize = 4
|
||||
}
|
||||
plain, err := d.cipher.Decrypt(packet, headerSize)
|
||||
plain, packetID, err := d.cipher.Decrypt(packet, headerSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.acceptPacketID(packetID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.compLZO {
|
||||
if len(plain) < 1 {
|
||||
return nil, errors.New("openvpn comp-lzo packet too short")
|
||||
}
|
||||
if plain[0] != 0xFA {
|
||||
return nil, fmt.Errorf("openvpn compressed packet not supported (byte: 0x%02x)", plain[0])
|
||||
}
|
||||
plain = plain[1:]
|
||||
return lzo1xDecompressSafe(plain)
|
||||
}
|
||||
return plain, nil
|
||||
}
|
||||
@@ -78,6 +82,40 @@ func (d *DataChannel) dataHeader() []byte {
|
||||
return []byte{opcodeKeyID(PDataV1, d.keyID)}
|
||||
}
|
||||
|
||||
func (d *DataChannel) acceptPacketID(packetID uint32) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if !d.recvSeen {
|
||||
d.recvHighest = packetID
|
||||
d.recvWindow = 1
|
||||
d.recvSeen = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if packetID > d.recvHighest {
|
||||
shift := packetID - d.recvHighest
|
||||
if shift >= dataChannelReplayWindow {
|
||||
d.recvWindow = 1
|
||||
} else {
|
||||
d.recvWindow = d.recvWindow<<shift | 1
|
||||
}
|
||||
d.recvHighest = packetID
|
||||
return nil
|
||||
}
|
||||
|
||||
diff := d.recvHighest - packetID
|
||||
if diff >= dataChannelReplayWindow {
|
||||
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
|
||||
}
|
||||
mask := uint64(1) << diff
|
||||
if d.recvWindow&mask != 0 {
|
||||
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
|
||||
}
|
||||
d.recvWindow |= mask
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParsePeerID(options string) uint32 {
|
||||
for _, field := range splitPushOptions(options) {
|
||||
if len(field) > len("peer-id ") && field[:len("peer-id ")] == "peer-id " {
|
||||
|
||||
444
transport/openvpn/e2e_test.go
Normal file
444
transport/openvpn/e2e_test.go
Normal file
@@ -0,0 +1,444 @@
|
||||
//go:build with_openvpn && with_gvisor
|
||||
|
||||
// OpenVPN E2E tests. Require a local OpenVPN server setup.
|
||||
//
|
||||
// Setup (run once before testing):
|
||||
//
|
||||
// # Generate PKI
|
||||
// mkdir -p /tmp/ovpn-e2e/pki/{issued,private}
|
||||
// cd /tmp/ovpn-e2e/pki
|
||||
// openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -days 1 -nodes -keyout ca.key -out ca.crt -subj "/CN=E2ETestCA"
|
||||
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/server.key -out server.csr -subj "/CN=server"
|
||||
// openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/server.crt -days 1
|
||||
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/client.key -out client.csr -subj "/CN=client"
|
||||
// openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/client.crt -days 1
|
||||
// openvpn --genkey secret ta.key
|
||||
// openvpn --genkey secret ta-auth.key
|
||||
//
|
||||
// # Start servers (4 instances: TCP/UDP × tls-crypt/tls-auth)
|
||||
// # TCP + tls-crypt on :11940, subnet 10.99.0.0/24
|
||||
// # UDP + tls-crypt on :11941, subnet 10.99.1.0/24
|
||||
// # TCP + tls-auth on :11942, subnet 10.99.2.0/24
|
||||
// # UDP + tls-auth on :11943, subnet 10.99.3.0/24
|
||||
// #
|
||||
// # Each server config needs: topology subnet, duplicate-cn, persist-tun,
|
||||
// # data-ciphers AES-256-GCM:AES-128-GCM:AES-192-GCM:CHACHA20-POLY1305:AES-256-CBC:AES-128-CBC:AES-192-CBC
|
||||
// # auth SHA256, keepalive 10 60, ca/cert/key from above PKI.
|
||||
// # tls-auth servers use: tls-auth ta-auth.key 0
|
||||
// # tls-crypt servers use: tls-crypt ta.key
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-crypt.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-crypt.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-auth.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-auth.conf --daemon
|
||||
//
|
||||
// # Start HTTP servers on each VPN subnet
|
||||
// for ip in 10.99.0.1 10.99.1.1 10.99.2.1 10.99.3.1; do
|
||||
// mkdir -p /tmp/ovpn-e2e/$ip && echo "hello" > /tmp/ovpn-e2e/$ip/index.html
|
||||
// cd /tmp/ovpn-e2e/$ip && python3 -m http.server 8080 --bind $ip &
|
||||
// done
|
||||
//
|
||||
// Run tests:
|
||||
//
|
||||
// go test -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_manager,with_admin_panel,with_v2ray_api,with_ccm,with_ocm,with_profiler,with_openvpn,with_sudoku,with_trusttunnel" \
|
||||
// -run TestE2E -v -count=1 ./transport/openvpn/ -timeout 300s
|
||||
//
|
||||
// Tests all 28 combinations: 2 protos (tcp/udp) × 2 TLS modes (tls-crypt/tls-auth) × 7 ciphers.
|
||||
|
||||
package openvpn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box"
|
||||
"github.com/sagernet/sing-box/include"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
// Servers (started externally):
|
||||
// TCP+tls-crypt :11940 subnet 10.99.0.0/24
|
||||
// UDP+tls-crypt :11941 subnet 10.99.1.0/24
|
||||
// TCP+tls-auth :11942 subnet 10.99.2.0/24
|
||||
// UDP+tls-auth :11943 subnet 10.99.3.0/24
|
||||
// TCP+plain :11944 subnet 10.99.4.0/24
|
||||
// UDP+plain :11945 subnet 10.99.5.0/24
|
||||
// TCP+tls-crypt+SHA1 :11946 subnet 10.99.6.0/24 (CBC only)
|
||||
// TCP+tls-crypt+SHA512 :11947 subnet 10.99.7.0/24 (CBC only)
|
||||
// Each has HTTP on .1:8080 serving "hello"
|
||||
|
||||
const pkiDir = "/tmp/ovpn-e2e/pki"
|
||||
|
||||
type serverConfig struct {
|
||||
proto string
|
||||
port uint16
|
||||
tlsMode string // "tls-crypt" or "tls-auth"
|
||||
httpAddr string
|
||||
}
|
||||
|
||||
var servers = []serverConfig{
|
||||
{"tcp", 11940, "tls-crypt", "10.99.0.1:8080"},
|
||||
{"udp", 11941, "tls-crypt", "10.99.1.1:8080"},
|
||||
{"tcp", 11942, "tls-auth", "10.99.2.1:8080"},
|
||||
{"udp", 11943, "tls-auth", "10.99.3.1:8080"},
|
||||
}
|
||||
|
||||
var ciphers = []string{
|
||||
"AES-128-GCM",
|
||||
"AES-192-GCM",
|
||||
"AES-256-GCM",
|
||||
"CHACHA20-POLY1305",
|
||||
"AES-128-CBC",
|
||||
"AES-192-CBC",
|
||||
"AES-256-CBC",
|
||||
}
|
||||
|
||||
var portCounter atomic.Uint32
|
||||
|
||||
func init() { portCounter.Store(18100) }
|
||||
|
||||
func nextPort() uint16 { return uint16(portCounter.Add(1)) }
|
||||
|
||||
func readFile(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Skipf("PKI not found: %v", err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func testCombo(t *testing.T, srv serverConfig, cipher string) {
|
||||
t.Helper()
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
|
||||
ovpnOpts := &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: srv.port}},
|
||||
Proto: srv.proto,
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
}
|
||||
|
||||
switch srv.tlsMode {
|
||||
case "tls-crypt":
|
||||
ovpnOpts.TLSCrypt = readFile(t, pkiDir+"/ta.key")
|
||||
case "tls-auth":
|
||||
ovpnOpts.TLSAuth = readFile(t, pkiDir+"/ta-auth.key")
|
||||
ovpnOpts.KeyDirection = 1
|
||||
}
|
||||
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(srv.httpAddr))
|
||||
if err != nil {
|
||||
t.Fatal("dial:", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
|
||||
if err != nil {
|
||||
t.Fatal("write:", err)
|
||||
}
|
||||
body, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal("read:", err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
|
||||
}
|
||||
}
|
||||
|
||||
// 4 servers × 7 ciphers = 28 combinations
|
||||
func TestE2E(t *testing.T) {
|
||||
for _, srv := range servers {
|
||||
for _, cipher := range ciphers {
|
||||
name := fmt.Sprintf("%s/%s/%s", srv.proto, srv.tlsMode, cipher)
|
||||
srv, cipher := srv, cipher
|
||||
t.Run(name, func(t *testing.T) {
|
||||
testCombo(t, srv, cipher)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test CBC ciphers with different auth algorithms (SHA1, SHA512)
|
||||
func TestE2E_Auth(t *testing.T) {
|
||||
type authServer struct {
|
||||
port uint16
|
||||
auth string
|
||||
httpAddr string
|
||||
}
|
||||
authServers := []authServer{
|
||||
{11946, "SHA1", "10.99.6.1:8080"},
|
||||
{11947, "SHA512", "10.99.7.1:8080"},
|
||||
}
|
||||
cbcCiphers := []string{"AES-128-CBC", "AES-256-CBC"}
|
||||
|
||||
for _, as := range authServers {
|
||||
for _, cipher := range cbcCiphers {
|
||||
name := fmt.Sprintf("auth-%s/%s", as.auth, cipher)
|
||||
as, cipher := as, cipher
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{
|
||||
Type: "openvpn", Tag: "vpn",
|
||||
Options: &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: as.port}},
|
||||
Proto: "tcp",
|
||||
Cipher: cipher,
|
||||
Auth: as.auth,
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
time.Sleep(2 * time.Second)
|
||||
doHTTPCheck(t, port, as.httpAddr)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test tunnel stability with multiple sequential requests
|
||||
func TestE2E_BulkData(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{
|
||||
Type: "openvpn", Tag: "vpn",
|
||||
Options: &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11940}},
|
||||
Proto: "tcp",
|
||||
Cipher: "AES-256-GCM",
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
|
||||
for i := 0; i < 10; i++ {
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr("10.99.0.1:8080"))
|
||||
if err != nil {
|
||||
t.Fatalf("request %d dial: %v", i, err)
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
fmt.Fprintf(conn, "GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n")
|
||||
body, err := io.ReadAll(conn)
|
||||
conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("request %d read: %v", i, err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("request %d: no 'hello'", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doHTTPCheck(t *testing.T, socksPort uint16, httpAddr string) {
|
||||
t.Helper()
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", socksPort), socks.Version5, "", "")
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(httpAddr))
|
||||
if err != nil {
|
||||
t.Fatal("dial:", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
|
||||
if err != nil {
|
||||
t.Fatal("write:", err)
|
||||
}
|
||||
body, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal("read:", err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
|
||||
}
|
||||
}
|
||||
|
||||
func startInstance(t *testing.T, ovpnOpts *option.OpenVPNOutboundOptions) uint16 {
|
||||
t.Helper()
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { instance.Close() })
|
||||
time.Sleep(2 * time.Second)
|
||||
return port
|
||||
}
|
||||
|
||||
func TestE2E_CompLZO(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
|
||||
for _, cipher := range ciphers {
|
||||
cipher := cipher
|
||||
t.Run(cipher, func(t *testing.T) {
|
||||
port := startInstance(t, &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11948}},
|
||||
Proto: "udp",
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
})
|
||||
doHTTPCheck(t, port, "10.99.8.1:8080")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestE2E_AES192(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
|
||||
type combo struct {
|
||||
proto string
|
||||
port uint16
|
||||
httpAddr string
|
||||
}
|
||||
for _, c := range []combo{
|
||||
{"tcp", 11940, "10.99.0.1:8080"},
|
||||
{"udp", 11941, "10.99.1.1:8080"},
|
||||
} {
|
||||
for _, cipher := range []string{"AES-192-GCM", "AES-192-CBC"} {
|
||||
c, cipher := c, cipher
|
||||
t.Run(fmt.Sprintf("%s/%s", c.proto, cipher), func(t *testing.T) {
|
||||
port := startInstance(t, &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: c.port}},
|
||||
Proto: c.proto,
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
})
|
||||
doHTTPCheck(t, port, c.httpAddr)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ net.Conn
|
||||
@@ -114,7 +114,7 @@ func ParseServerKeyMethod2Record(packet []byte) (*KeyMethod2Record, error) {
|
||||
}
|
||||
|
||||
func DeriveClientKeyMaterial(sources KeySource2, clientSession, serverSession SessionID, cipherKeyLen int) (*KeyMaterial, error) {
|
||||
if cipherKeyLen != 16 && cipherKeyLen != 32 {
|
||||
if cipherKeyLen != 16 && cipherKeyLen != 24 && cipherKeyLen != 32 {
|
||||
return nil, fmt.Errorf("unsupported data cipher key length %d", cipherKeyLen)
|
||||
}
|
||||
var master [48]byte
|
||||
|
||||
48
transport/openvpn/lzo.go
Normal file
48
transport/openvpn/lzo.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/rasky/go-lzo"
|
||||
)
|
||||
|
||||
const (
|
||||
lzoCompressNone = 0xFA
|
||||
lzoCompressLZO = 0x66
|
||||
)
|
||||
|
||||
var ErrLZODecompress = errors.New("lzo decompression failed")
|
||||
|
||||
func lzo1xDecompressSafe(src []byte) ([]byte, error) {
|
||||
if len(src) == 0 {
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
|
||||
switch src[0] {
|
||||
case lzoCompressNone:
|
||||
if len(src) > 1 {
|
||||
return src[1:], nil
|
||||
}
|
||||
return nil, nil
|
||||
case lzoCompressLZO:
|
||||
if len(src) > 1 {
|
||||
r := bytes.NewReader(src[1:])
|
||||
out, err := lzo.Decompress1X(r, len(src)-1, 0)
|
||||
if err != nil {
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
}
|
||||
|
||||
func lzo1xCompressSafe(src []byte) ([]byte, error) {
|
||||
lzoPacket := make([]byte, 1+len(src))
|
||||
lzoPacket[0] = lzoCompressNone
|
||||
copy(lzoPacket[1:], src)
|
||||
return lzoPacket, nil
|
||||
}
|
||||
@@ -10,16 +10,17 @@ import (
|
||||
const PushRequest = "PUSH_REQUEST"
|
||||
|
||||
type PushReply struct {
|
||||
Raw string
|
||||
Prefixes []netip.Prefix
|
||||
DNS []netip.Addr
|
||||
PeerID uint32
|
||||
Cipher string
|
||||
Ping uint32
|
||||
MTU uint32
|
||||
CompLZO bool
|
||||
Redirect bool
|
||||
BlockIPv6 bool
|
||||
Raw string
|
||||
Prefixes []netip.Prefix
|
||||
DNS []netip.Addr
|
||||
PeerID uint32
|
||||
Cipher string
|
||||
Ping uint32
|
||||
PingRestart uint32
|
||||
MTU uint32
|
||||
CompLZO bool
|
||||
Redirect bool
|
||||
BlockIPv6 bool
|
||||
}
|
||||
|
||||
func ParsePushReply(message string) (*PushReply, error) {
|
||||
@@ -81,6 +82,12 @@ func ParsePushReply(message string) (*PushReply, error) {
|
||||
reply.Ping = uint32(v)
|
||||
}
|
||||
}
|
||||
case "ping-restart":
|
||||
if len(fields) >= 2 {
|
||||
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
|
||||
reply.PingRestart = uint32(v)
|
||||
}
|
||||
}
|
||||
case "tun-mtu":
|
||||
if len(fields) >= 2 {
|
||||
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
|
||||
@@ -113,27 +120,44 @@ func splitPushOptions(message string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseIPv4Ifconfig(address, mask string) (netip.Prefix, error) {
|
||||
func parseIPv4Ifconfig(address, maskOrPeer string) (netip.Prefix, error) {
|
||||
addr, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 address %q: %w", address, err)
|
||||
}
|
||||
maskAddr, err := netip.ParseAddr(mask)
|
||||
maskAddr, err := netip.ParseAddr(maskOrPeer)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", mask, err)
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", maskOrPeer, err)
|
||||
}
|
||||
if !addr.Is4() || !maskAddr.Is4() {
|
||||
return netip.Prefix{}, fmt.Errorf("openvpn ifconfig requires ipv4 address and mask")
|
||||
}
|
||||
maskBytes := maskAddr.As4()
|
||||
|
||||
if ones, ok := ipv4MaskSize(maskAddr); ok {
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
}
|
||||
|
||||
// Some servers, including SoftEther/VPNGate in net30/p2p mode, push
|
||||
// "ifconfig <local> <remote>" rather than "ifconfig <local> <netmask>".
|
||||
// Use a host prefix for that local tunnel address.
|
||||
return netip.PrefixFrom(addr, 32), nil
|
||||
}
|
||||
|
||||
func ipv4MaskSize(mask netip.Addr) (int, bool) {
|
||||
maskBytes := mask.As4()
|
||||
ones := 0
|
||||
seenZero := false
|
||||
for _, b := range maskBytes {
|
||||
for i := 7; i >= 0; i-- {
|
||||
if b&(1<<i) == 0 {
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
seenZero = true
|
||||
continue
|
||||
}
|
||||
if seenZero {
|
||||
return 0, false
|
||||
}
|
||||
ones++
|
||||
}
|
||||
}
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
return ones, true
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openvpn
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
@@ -35,6 +36,9 @@ func NewTLSAuth(staticKey []byte, keyDirection int, auth string) (*TLSAuth, erro
|
||||
var newHash func() hash.Hash
|
||||
var hmacSize int
|
||||
switch auth {
|
||||
case AuthMD5:
|
||||
newHash = md5.New
|
||||
hmacSize = md5.Size
|
||||
case AuthSHA256:
|
||||
newHash = sha256.New
|
||||
hmacSize = sha256.Size
|
||||
|
||||
@@ -30,16 +30,19 @@ type TunnelOptions struct {
|
||||
UDPTimeout time.Duration
|
||||
ReconnectDelay time.Duration
|
||||
PingInterval time.Duration
|
||||
PingRestart time.Duration
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger logger.ContextLogger
|
||||
options TunnelOptions
|
||||
device Device
|
||||
client *Client
|
||||
mtu uint32
|
||||
serverIndex int
|
||||
wg sync.WaitGroup
|
||||
|
||||
await chan struct{}
|
||||
mu sync.Mutex
|
||||
@@ -49,8 +52,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
|
||||
if options.ReconnectDelay == 0 {
|
||||
options.ReconnectDelay = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Tunnel{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
options: options,
|
||||
await: make(chan struct{}),
|
||||
@@ -59,10 +64,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
|
||||
|
||||
func (t *Tunnel) Start() error {
|
||||
go func() {
|
||||
defer close(t.await)
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
t.logger.Error("OpenVPN connect: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
t.mtu = 1500
|
||||
@@ -84,20 +89,26 @@ func (t *Tunnel) Start() error {
|
||||
if err != nil {
|
||||
client.Close()
|
||||
t.logger.Error("create OpenVPN device: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
t.device = device
|
||||
if err := device.Start(); err != nil {
|
||||
client.Close()
|
||||
t.logger.Error("start OpenVPN device: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
close(t.await)
|
||||
t.maintainTunnel()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if err := t.isTunnelInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
@@ -105,6 +116,9 @@ func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.
|
||||
}
|
||||
|
||||
func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if err := t.isTunnelInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
@@ -112,15 +126,18 @@ func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
|
||||
}
|
||||
|
||||
func (t *Tunnel) Close() error {
|
||||
t.cancel()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.client != nil {
|
||||
t.client.Close()
|
||||
t.client = nil
|
||||
}
|
||||
if t.device != nil {
|
||||
return t.device.Close()
|
||||
t.device.Close()
|
||||
t.device = nil
|
||||
}
|
||||
t.mu.Unlock()
|
||||
t.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,7 +154,9 @@ func (t *Tunnel) isTunnelInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (t *Tunnel) maintainTunnel() {
|
||||
t.wg.Add(2)
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
bufs := make([][]byte, 1)
|
||||
bufs[0] = make([]byte, t.mtu)
|
||||
sizes := make([]int, 1)
|
||||
@@ -161,6 +180,7 @@ func (t *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
for t.ctx.Err() == nil {
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
@@ -179,10 +199,14 @@ func (t *Tunnel) maintainTunnel() {
|
||||
if bytes.Equal(packet, pingPayload) {
|
||||
continue
|
||||
}
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if _, err := t.device.Write([][]byte{packet}, 0); err != nil {
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -208,6 +232,34 @@ func (t *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
}
|
||||
pingRestart := t.options.PingRestart
|
||||
if pingRestart == 0 && t.client != nil && t.client.push.PingRestart > 0 {
|
||||
pingRestart = time.Duration(t.client.push.PingRestart) * time.Second
|
||||
}
|
||||
if pingRestart > 0 {
|
||||
t.wg.Add(1)
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
ticker := time.NewTicker(pingRestart)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if client.SinceReceive() >= pingRestart {
|
||||
if ok := t.closeClient(client); ok {
|
||||
t.logger.ErrorContext(t.ctx, fmt.Errorf("ping-restart timeout: no packet received for %s", pingRestart))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
<-t.ctx.Done()
|
||||
}
|
||||
|
||||
|
||||
100
transport/simple-obfs/http_server.go
Normal file
100
transport/simple-obfs/http_server.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package obfs
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTPObfsServer is the server side of the shadowsocks http simple-obfs implementation.
|
||||
type HTTPObfsServer struct {
|
||||
net.Conn
|
||||
buf []byte
|
||||
bio *bufio.Reader
|
||||
offset int
|
||||
firstRequest bool
|
||||
firstResponse bool
|
||||
}
|
||||
|
||||
func (hos *HTTPObfsServer) Read(b []byte) (int, error) {
|
||||
if hos.buf != nil {
|
||||
n := copy(b, hos.buf[hos.offset:])
|
||||
hos.offset += n
|
||||
if hos.offset == len(hos.buf) {
|
||||
hos.offset = 0
|
||||
hos.buf = nil
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
if hos.firstRequest {
|
||||
bio := bufio.NewReader(hos.Conn)
|
||||
req, err := http.ReadRequest(bio)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if req.Method != "GET" || req.Header.Get("Connection") != "Upgrade" {
|
||||
return 0, io.EOF
|
||||
}
|
||||
buf, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := copy(b, buf)
|
||||
if n < len(buf) {
|
||||
hos.buf = buf
|
||||
hos.offset = n
|
||||
}
|
||||
req.Body.Close()
|
||||
hos.bio = bio
|
||||
hos.firstRequest = false
|
||||
return n, nil
|
||||
}
|
||||
return hos.bio.Read(b)
|
||||
}
|
||||
|
||||
func (hos *HTTPObfsServer) Write(b []byte) (int, error) {
|
||||
if hos.firstResponse {
|
||||
randBytes := make([]byte, 16)
|
||||
cryptorand.Read(randBytes)
|
||||
date := time.Now().Format(time.RFC1123)
|
||||
resp := fmt.Sprintf(httpResponseTemplate, vMajor, vMinor, date, base64.URLEncoding.EncodeToString(randBytes))
|
||||
_, err := hos.Conn.Write([]byte(resp))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
hos.firstResponse = false
|
||||
}
|
||||
return hos.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (hos *HTTPObfsServer) Upstream() any {
|
||||
return hos.Conn
|
||||
}
|
||||
|
||||
// NewHTTPObfsServer returns a server-side HTTPObfs.
|
||||
func NewHTTPObfsServer(conn net.Conn) net.Conn {
|
||||
return &HTTPObfsServer{
|
||||
Conn: conn,
|
||||
firstRequest: true,
|
||||
firstResponse: true,
|
||||
}
|
||||
}
|
||||
|
||||
const httpResponseTemplate = "HTTP/1.1 101 Switching Protocols\r\n" +
|
||||
"Server: nginx/1.%d.%d\r\n" +
|
||||
"Date: %s\r\n" +
|
||||
"Upgrade: websocket\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Sec-WebSocket-Accept: %s\r\n" +
|
||||
"\r\n"
|
||||
|
||||
var (
|
||||
vMajor = rand.IntN(11)
|
||||
vMinor = rand.IntN(12)
|
||||
)
|
||||
154
transport/simple-obfs/tls_server.go
Normal file
154
transport/simple-obfs/tls_server.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package obfs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
B "github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
// TLSObfsServer is the server side of the shadowsocks tls simple-obfs implementation.
|
||||
type TLSObfsServer struct {
|
||||
net.Conn
|
||||
remain int
|
||||
firstRequest bool
|
||||
sessionTicketDone bool
|
||||
firstResponse bool
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) read(b []byte, discardN int) (int, error) {
|
||||
buf := B.Get(discardN)
|
||||
_, err := io.ReadFull(tos.Conn, buf)
|
||||
B.Put(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizeBuf := make([]byte, 2)
|
||||
_, err = io.ReadFull(tos.Conn, sizeBuf)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
length := int(binary.BigEndian.Uint16(sizeBuf))
|
||||
if length > len(b) {
|
||||
n, err := tos.Conn.Read(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
tos.remain = length - n
|
||||
return n, nil
|
||||
}
|
||||
return io.ReadFull(tos.Conn, b[:length])
|
||||
}
|
||||
|
||||
// skipOtherExts skips SNI & other TLS extensions.
|
||||
func (tos *TLSObfsServer) skipOtherExts() error {
|
||||
buf := make([]byte, 256)
|
||||
_, err := tos.read(buf, 7)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.ReadFull(tos.Conn, buf[:4*16+2])
|
||||
return err
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) Read(b []byte) (int, error) {
|
||||
if tos.remain > 0 {
|
||||
length := tos.remain
|
||||
if length > len(b) {
|
||||
length = len(b)
|
||||
}
|
||||
n, err := io.ReadFull(tos.Conn, b[:length])
|
||||
tos.remain -= n
|
||||
return n, err
|
||||
}
|
||||
if tos.firstRequest {
|
||||
tos.firstRequest = false
|
||||
return tos.read(b, 9*16-4)
|
||||
}
|
||||
if !tos.sessionTicketDone {
|
||||
tos.sessionTicketDone = true
|
||||
err := tos.skipOtherExts()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return tos.read(b, 3)
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) Write(b []byte) (int, error) {
|
||||
length := len(b)
|
||||
for i := 0; i < length; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > length {
|
||||
end = length
|
||||
}
|
||||
n, err := tos.write(b[i:end])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
return length, nil
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) write(b []byte) (int, error) {
|
||||
if tos.firstResponse {
|
||||
serverHello := makeServerHello(b)
|
||||
_, err := tos.Conn.Write(serverHello)
|
||||
tos.firstResponse = false
|
||||
return len(b), err
|
||||
}
|
||||
buf := B.NewSize(5 + len(b))
|
||||
defer buf.Release()
|
||||
buf.Write([]byte{0x17, 0x03, 0x03})
|
||||
binary.Write(buf, binary.BigEndian, uint16(len(b)))
|
||||
buf.Write(b)
|
||||
_, err := tos.Conn.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (tos *TLSObfsServer) Upstream() any {
|
||||
return tos.Conn
|
||||
}
|
||||
|
||||
// NewTLSObfsServer returns a server-side TLS SimpleObfs.
|
||||
func NewTLSObfsServer(conn net.Conn) net.Conn {
|
||||
return &TLSObfsServer{
|
||||
Conn: conn,
|
||||
firstRequest: true,
|
||||
firstResponse: true,
|
||||
}
|
||||
}
|
||||
|
||||
func makeServerHello(data []byte) []byte {
|
||||
randBytes := make([]byte, 28)
|
||||
sessionId := make([]byte, 32)
|
||||
rand.Read(randBytes)
|
||||
rand.Read(sessionId)
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteByte(0x16)
|
||||
binary.Write(buf, binary.BigEndian, uint16(0x0301))
|
||||
binary.Write(buf, binary.BigEndian, uint16(91))
|
||||
buf.Write([]byte{2, 0, 0, 87, 0x03, 0x03})
|
||||
binary.Write(buf, binary.BigEndian, uint32(time.Now().Unix()))
|
||||
buf.Write(randBytes)
|
||||
buf.WriteByte(32)
|
||||
buf.Write(sessionId)
|
||||
buf.Write([]byte{0xcc, 0xa8})
|
||||
buf.WriteByte(0)
|
||||
buf.Write([]byte{0x00, 0x00})
|
||||
buf.Write([]byte{0xff, 0x01, 0x00, 0x01, 0x00})
|
||||
buf.Write([]byte{0x00, 0x17, 0x00, 0x00})
|
||||
buf.Write([]byte{0x00, 0x0b, 0x00, 0x02, 0x01, 0x00})
|
||||
buf.Write([]byte{0x14, 0x03, 0x03, 0x00, 0x01, 0x01})
|
||||
buf.Write([]byte{0x16, 0x03, 0x03})
|
||||
binary.Write(buf, binary.BigEndian, uint16(len(data)))
|
||||
buf.Write(data)
|
||||
return buf.Bytes()
|
||||
}
|
||||
144
transport/snell/address.go
Normal file
144
transport/snell/address.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// SOCKS address types as defined in RFC 1928 section 5.
|
||||
const (
|
||||
atypIPv4 = 1
|
||||
atypDomainName = 3
|
||||
atypIPv6 = 4
|
||||
)
|
||||
|
||||
// socksAddr represents a SOCKS address as defined in RFC 1928 section 5.
|
||||
type socksAddr []byte
|
||||
|
||||
func (a socksAddr) String() string {
|
||||
var host, port string
|
||||
switch a[0] {
|
||||
case atypDomainName:
|
||||
hostLen := uint16(a[1])
|
||||
host = string(a[2 : 2+hostLen])
|
||||
port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
|
||||
case atypIPv4:
|
||||
host = net.IP(a[1 : 1+net.IPv4len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
|
||||
case atypIPv6:
|
||||
host = net.IP(a[1 : 1+net.IPv6len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// UDPAddr converts a socksAddr to *net.UDPAddr.
|
||||
func (a socksAddr) UDPAddr() *net.UDPAddr {
|
||||
if len(a) == 0 {
|
||||
return nil
|
||||
}
|
||||
switch a[0] {
|
||||
case atypIPv4:
|
||||
var ip [net.IPv4len]byte
|
||||
copy(ip[0:], a[1:1+net.IPv4len])
|
||||
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
|
||||
case atypIPv6:
|
||||
var ip [net.IPv6len]byte
|
||||
copy(ip[0:], a[1:1+net.IPv6len])
|
||||
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitSocksAddr slices a SOCKS address from beginning of b. Returns nil if failed.
|
||||
func splitSocksAddr(b []byte) socksAddr {
|
||||
addrLen := 1
|
||||
if len(b) < addrLen {
|
||||
return nil
|
||||
}
|
||||
switch b[0] {
|
||||
case atypDomainName:
|
||||
if len(b) < 2 {
|
||||
return nil
|
||||
}
|
||||
addrLen = 1 + 1 + int(b[1]) + 2
|
||||
case atypIPv4:
|
||||
addrLen = 1 + net.IPv4len + 2
|
||||
case atypIPv6:
|
||||
addrLen = 1 + net.IPv6len + 2
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
if len(b) < addrLen {
|
||||
return nil
|
||||
}
|
||||
return b[:addrLen]
|
||||
}
|
||||
|
||||
// parseAddr parses the address in string s. Returns nil if failed.
|
||||
func parseAddr(s string) socksAddr {
|
||||
var addr socksAddr
|
||||
host, port, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
addr = make([]byte, 1+net.IPv4len+2)
|
||||
addr[0] = atypIPv4
|
||||
copy(addr[1:], ip4)
|
||||
} else {
|
||||
addr = make([]byte, 1+net.IPv6len+2)
|
||||
addr[0] = atypIPv6
|
||||
copy(addr[1:], ip)
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
return nil
|
||||
}
|
||||
addr = make([]byte, 1+1+len(host)+2)
|
||||
addr[0] = atypDomainName
|
||||
addr[1] = byte(len(host))
|
||||
copy(addr[2:], host)
|
||||
}
|
||||
portnum, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
|
||||
return addr
|
||||
}
|
||||
|
||||
// parseAddrToSocksAddr parses a socks addr from net.Addr.
|
||||
// This is a fast path of parseAddr(addr.String()).
|
||||
func parseAddrToSocksAddr(addr net.Addr) socksAddr {
|
||||
var hostip net.IP
|
||||
var port int
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
hostip = addr.IP
|
||||
port = addr.Port
|
||||
case *net.TCPAddr:
|
||||
hostip = addr.IP
|
||||
port = addr.Port
|
||||
case nil:
|
||||
return nil
|
||||
}
|
||||
if hostip == nil {
|
||||
return parseAddr(addr.String())
|
||||
}
|
||||
var parsed socksAddr
|
||||
if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
|
||||
parsed = make([]byte, 1+net.IPv4len+2)
|
||||
parsed[0] = atypIPv4
|
||||
copy(parsed[1:], ip4)
|
||||
binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
|
||||
} else {
|
||||
parsed = make([]byte, 1+net.IPv6len+2)
|
||||
parsed[0] = atypIPv6
|
||||
copy(parsed[1:], hostip)
|
||||
binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
56
transport/snell/cipher.go
Normal file
56
transport/snell/cipher.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// NewAES128GCM returns the AES-128-GCM cipher used by snell v2/v3.
|
||||
func NewAES128GCM(psk []byte) Cipher {
|
||||
return &snellCipher{
|
||||
psk: psk,
|
||||
keySize: 16,
|
||||
makeAEAD: aesGCM,
|
||||
}
|
||||
}
|
||||
|
||||
// NewChacha20Poly1305 returns the ChaCha20-Poly1305 cipher used by snell v1.
|
||||
func NewChacha20Poly1305(psk []byte) Cipher {
|
||||
return &snellCipher{
|
||||
psk: psk,
|
||||
keySize: 32,
|
||||
makeAEAD: chacha20poly1305.New,
|
||||
}
|
||||
}
|
||||
|
||||
type snellCipher struct {
|
||||
psk []byte
|
||||
keySize int
|
||||
makeAEAD func(key []byte) (cipher.AEAD, error)
|
||||
}
|
||||
|
||||
func (sc *snellCipher) KeySize() int { return sc.keySize }
|
||||
func (sc *snellCipher) SaltSize() int { return 16 }
|
||||
|
||||
func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) {
|
||||
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
|
||||
}
|
||||
|
||||
func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) {
|
||||
return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
|
||||
}
|
||||
|
||||
func snellKDF(psk, salt []byte, keySize int) []byte {
|
||||
return argon2.IDKey(psk, salt, 3, 8, 1, 32)[:keySize]
|
||||
}
|
||||
|
||||
func aesGCM(key []byte) (cipher.AEAD, error) {
|
||||
blk, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cipher.NewGCM(blk)
|
||||
}
|
||||
120
transport/snell/client.go
Normal file
120
transport/snell/client.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type ClientOptions struct {
|
||||
Dialer N.Dialer
|
||||
Server M.Socksaddr
|
||||
PSK []byte
|
||||
Version int
|
||||
Reuse bool
|
||||
ObfsMode string
|
||||
ObfsHost string
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
dialer N.Dialer
|
||||
server M.Socksaddr
|
||||
psk []byte
|
||||
version int
|
||||
reuse bool
|
||||
obfsMode string
|
||||
obfsHost string
|
||||
pool *Pool
|
||||
}
|
||||
|
||||
func NewClient(options ClientOptions) *Client {
|
||||
c := &Client{
|
||||
dialer: options.Dialer,
|
||||
server: options.Server,
|
||||
psk: options.PSK,
|
||||
version: options.Version,
|
||||
reuse: options.Reuse,
|
||||
obfsMode: options.ObfsMode,
|
||||
obfsHost: options.ObfsHost,
|
||||
}
|
||||
if c.reuse {
|
||||
c.pool = NewPool(func(ctx context.Context) (*Snell, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.streamConn(conn), nil
|
||||
})
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) streamConn(conn net.Conn) *Snell {
|
||||
switch c.obfsMode {
|
||||
case "tls":
|
||||
conn = obfs.NewTLSObfs(conn, c.obfsHost)
|
||||
case "http":
|
||||
conn = obfs.NewHTTPObfs(conn, c.obfsHost, strconv.Itoa(int(c.server.Port)))
|
||||
}
|
||||
return StreamConn(conn, c.psk, c.version)
|
||||
}
|
||||
|
||||
func (c *Client) writeHeader(ctx context.Context, conn net.Conn, destination M.Socksaddr, udp bool) (err error) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = conn.SetWriteDeadline(deadline)
|
||||
defer conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
if udp {
|
||||
err = WriteUDPHeader(conn, c.version)
|
||||
if err == nil && c.version >= Version4 {
|
||||
if sc, ok := conn.(*Snell); ok {
|
||||
err = sc.ReadReply()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
err = WriteHeaderWithReuse(conn, destination.AddrString(), uint(destination.Port), c.version, c.reuse)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) DialContext(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
||||
if c.reuse {
|
||||
conn, err := c.pool.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = c.writeHeader(ctx, conn, destination, false); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream := c.streamConn(conn)
|
||||
if err = c.writeHeader(ctx, stream, destination, false); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream := c.streamConn(conn)
|
||||
if err = c.writeHeader(ctx, stream, destination, true); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return PacketConn(stream), nil
|
||||
}
|
||||
153
transport/snell/pool.go
Normal file
153
transport/snell/pool.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// poolEntry holds a pooled item with its insertion time.
|
||||
|
||||
// connPool is a small connection pool with age-based eviction.
|
||||
|
||||
// milliseconds
|
||||
|
||||
// Pool is a pool of reusable snell connections.
|
||||
type Pool struct {
|
||||
pool *connPool
|
||||
}
|
||||
|
||||
func (p *Pool) Get() (net.Conn, error) {
|
||||
return p.GetContext(context.Background())
|
||||
}
|
||||
|
||||
func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
|
||||
elm, err := p.pool.GetContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PoolConn{Snell: elm, pool: p}, nil
|
||||
}
|
||||
|
||||
func (p *Pool) Put(conn *Snell) {
|
||||
if err := HalfClose(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
p.pool.put(conn)
|
||||
}
|
||||
|
||||
// PoolConn wraps a pooled snell connection and returns it to the pool on Close.
|
||||
type PoolConn struct {
|
||||
*Snell
|
||||
pool *Pool
|
||||
closeWriteOnce sync.Once
|
||||
closeWriteErr error
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Read(b []byte) (int, error) {
|
||||
n, err := pc.Snell.Read(b)
|
||||
if err == ErrZeroChunk {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Write(b []byte) (int, error) {
|
||||
return pc.Snell.Write(b)
|
||||
}
|
||||
|
||||
func (pc *PoolConn) CloseWrite() error {
|
||||
pc.closeWriteOnce.Do(func() {
|
||||
pc.closeWriteErr = writeZeroChunk(pc.Snell)
|
||||
})
|
||||
return pc.closeWriteErr
|
||||
}
|
||||
|
||||
func (pc *PoolConn) Close() error {
|
||||
pc.closeOnce.Do(func() {
|
||||
if err := pc.CloseWrite(); err != nil {
|
||||
pc.closeErr = err
|
||||
_ = pc.Snell.Close()
|
||||
return
|
||||
}
|
||||
_ = pc.Snell.Conn.SetReadDeadline(time.Time{})
|
||||
pc.Snell.reply = false
|
||||
pc.pool.pool.put(pc.Snell)
|
||||
})
|
||||
return pc.closeErr
|
||||
}
|
||||
|
||||
// NewPool creates a new snell connection pool using the given factory.
|
||||
func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
|
||||
cp := &connPool{
|
||||
ch: make(chan *poolEntry, 10),
|
||||
factory: factory,
|
||||
maxAge: 15000,
|
||||
evict: func(item *Snell) {
|
||||
_ = item.Close()
|
||||
},
|
||||
}
|
||||
p := &Pool{pool: cp}
|
||||
runtime.SetFinalizer(p, recycle)
|
||||
return p
|
||||
}
|
||||
|
||||
type poolEntry struct {
|
||||
elm *Snell
|
||||
time time.Time
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
ch chan *poolEntry
|
||||
factory func(context.Context) (*Snell, error)
|
||||
evict func(*Snell)
|
||||
maxAge int64
|
||||
}
|
||||
|
||||
func (p *connPool) GetContext(ctx context.Context) (*Snell, error) {
|
||||
now := time.Now()
|
||||
for {
|
||||
select {
|
||||
case item := <-p.ch:
|
||||
if p.maxAge != 0 && now.Sub(item.time).Milliseconds() > p.maxAge {
|
||||
if p.evict != nil {
|
||||
p.evict(item.elm)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return item.elm, nil
|
||||
default:
|
||||
return p.factory(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *connPool) put(item *Snell) {
|
||||
e := &poolEntry{
|
||||
elm: item,
|
||||
time: time.Now(),
|
||||
}
|
||||
select {
|
||||
case p.ch <- e:
|
||||
return
|
||||
default:
|
||||
if p.evict != nil {
|
||||
p.evict(item)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func recycle(p *Pool) {
|
||||
for item := range p.pool.ch {
|
||||
if p.pool.evict != nil {
|
||||
p.pool.evict(item.elm)
|
||||
}
|
||||
}
|
||||
}
|
||||
294
transport/snell/service.go
Normal file
294
transport/snell/service.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
obfs "github.com/sagernet/sing-box/transport/simple-obfs"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, clientID string)
|
||||
|
||||
NewPacketConnection(ctx context.Context, conn net.PacketConn, source M.Socksaddr, clientID string)
|
||||
}
|
||||
|
||||
type ServiceOptions struct {
|
||||
PSK []byte
|
||||
Version int
|
||||
ObfsMode string
|
||||
UDP bool
|
||||
Logger logger.ContextLogger
|
||||
Handler Handler
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
psk []byte
|
||||
version int
|
||||
obfsMode string
|
||||
udp bool
|
||||
logger logger.ContextLogger
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewService(options ServiceOptions) (*Service, error) {
|
||||
version := options.Version
|
||||
if version == 0 {
|
||||
version = Version4
|
||||
}
|
||||
if version != Version4 && version != Version5 {
|
||||
return nil, fmt.Errorf("snell inbound version %d is not supported", version)
|
||||
}
|
||||
if len(options.PSK) == 0 {
|
||||
return nil, errors.New("snell inbound requires psk")
|
||||
}
|
||||
switch options.ObfsMode {
|
||||
case "", "http", "tls":
|
||||
default:
|
||||
return nil, fmt.Errorf("snell inbound obfs mode error: %s", options.ObfsMode)
|
||||
}
|
||||
return &Service{
|
||||
psk: options.PSK,
|
||||
version: version,
|
||||
obfsMode: options.ObfsMode,
|
||||
udp: options.UDP,
|
||||
logger: options.Logger,
|
||||
handler: options.Handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(ctx context.Context, rawConn net.Conn, source M.Socksaddr) error {
|
||||
conn := rawConn
|
||||
switch s.obfsMode {
|
||||
case "http":
|
||||
conn = obfs.NewHTTPObfsServer(conn)
|
||||
case "tls":
|
||||
conn = obfs.NewTLSObfsServer(conn)
|
||||
}
|
||||
stream := ServerStreamConn(conn, s.psk, s.version)
|
||||
for {
|
||||
reuse, err := s.handleRequest(ctx, stream, source)
|
||||
if err != nil || !reuse {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleRequest(ctx context.Context, stream *Snell, source M.Socksaddr) (bool, error) {
|
||||
br := bufio.NewReader(stream)
|
||||
version, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if version != Version {
|
||||
return false, fmt.Errorf("snell invalid protocol version: %d", version)
|
||||
}
|
||||
command, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if command == CommandPing {
|
||||
_, _ = stream.Write([]byte{CommandPong})
|
||||
return false, nil
|
||||
}
|
||||
clientID, err := readClientID(br)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
switch command {
|
||||
case CommandConnect, CommandConnectV2:
|
||||
return s.handleTCP(ctx, stream, br, command == CommandConnectV2, clientID, source)
|
||||
case CommandUDP:
|
||||
if !s.udp {
|
||||
return false, errors.New("snell UDP is disabled")
|
||||
}
|
||||
return false, s.handleUDP(ctx, stream, clientID, source)
|
||||
default:
|
||||
return false, fmt.Errorf("snell unknown command: %d", command)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleTCP(ctx context.Context, stream *Snell, br *bufio.Reader, reuse bool, clientID string, source M.Socksaddr) (bool, error) {
|
||||
hostLen, err := br.ReadByte()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if hostLen == 0 {
|
||||
return false, errors.New("snell connect host is empty")
|
||||
}
|
||||
hostBytes := make([]byte, int(hostLen))
|
||||
if _, err := io.ReadFull(br, hostBytes); err != nil {
|
||||
return false, err
|
||||
}
|
||||
var portBytes [2]byte
|
||||
if _, err := io.ReadFull(br, portBytes[:]); err != nil {
|
||||
return false, err
|
||||
}
|
||||
destination := M.ParseSocksaddrHostPort(string(hostBytes), binary.BigEndian.Uint16(portBytes[:]))
|
||||
conn := &tcpRequestConn{
|
||||
Conn: stream,
|
||||
reader: br,
|
||||
reuse: reuse,
|
||||
}
|
||||
s.handler.NewConnection(ctx, conn, source, destination, clientID)
|
||||
if !reuse {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *Service) handleUDP(ctx context.Context, stream *Snell, clientID string, source M.Socksaddr) error {
|
||||
if _, err := stream.Write([]byte{CommandTunnel}); err != nil {
|
||||
return err
|
||||
}
|
||||
pc := &serverPacketConn{
|
||||
conn: stream,
|
||||
writeMu: &sync.Mutex{},
|
||||
}
|
||||
s.handler.NewPacketConnection(ctx, pc, source, clientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxPacketLength = 0x3fff
|
||||
|
||||
func readClientID(r *bufio.Reader) (string, error) {
|
||||
length, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
id := make([]byte, int(length))
|
||||
if _, err := io.ReadFull(r, id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(id), nil
|
||||
}
|
||||
|
||||
func writeCommandError(w io.Writer, code byte, message string) error {
|
||||
msg := []byte(message)
|
||||
if len(msg) > 255 {
|
||||
msg = msg[:255]
|
||||
}
|
||||
buf := make([]byte, 0, 3+len(msg))
|
||||
buf = append(buf, CommandError, code, byte(len(msg)))
|
||||
buf = append(buf, msg...)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
type tcpRequestConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
reuse bool
|
||||
writeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
replyWritten bool
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Read(p []byte) (int, error) {
|
||||
n, err := c.reader.Read(p)
|
||||
if errors.Is(err, ErrZeroChunk) {
|
||||
err = io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Write(p []byte) (int, error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if !c.replyWritten {
|
||||
payload := make([]byte, 1+len(p))
|
||||
payload[0] = CommandTunnel
|
||||
copy(payload[1:], p)
|
||||
if _, err := c.Conn.Write(payload); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.replyWritten = true
|
||||
return len(p), nil
|
||||
}
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) CloseWrite() error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
func (c *tcpRequestConn) Close() error {
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if !c.replyWritten {
|
||||
err = writeCommandError(c.Conn, 0x65, "Remote EOF")
|
||||
if !c.reuse {
|
||||
err = errors.Join(err, c.Conn.Close())
|
||||
}
|
||||
return
|
||||
}
|
||||
if c.reuse {
|
||||
_, err = c.Conn.Write(nil)
|
||||
return
|
||||
}
|
||||
err = c.Conn.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
type serverPacketConn struct {
|
||||
conn *Snell
|
||||
writeMu *sync.Mutex
|
||||
readBuf []byte
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
if c.readBuf == nil {
|
||||
c.readBuf = make([]byte, maxPacketLength)
|
||||
}
|
||||
for {
|
||||
n, err := c.conn.Read(c.readBuf)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, ErrZeroChunk) {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
request, err := ParseUDPRequest(c.readBuf[:n])
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
var destination M.Socksaddr
|
||||
if request.Ip.IsValid() {
|
||||
destination = M.SocksaddrFrom(request.Ip, request.Port)
|
||||
} else {
|
||||
destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
|
||||
}
|
||||
length := copy(p, request.Payload)
|
||||
if destination.IsFqdn() {
|
||||
return length, destination, nil
|
||||
}
|
||||
return length, destination.UDPAddr(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
return WritePacketResponse(c.conn, addr, p)
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Close() error { return c.conn.Close() }
|
||||
func (c *serverPacketConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||
func (c *serverPacketConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||||
func (c *serverPacketConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||||
func (c *serverPacketConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||||
211
transport/snell/shadowaead.go
Normal file
211
transport/snell/shadowaead.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
// payloadSizeMask is the maximum size of payload in bytes.
|
||||
// 16*1024 - 1
|
||||
// >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead()
|
||||
|
||||
// ErrZeroChunk is returned when a zero-length chunk is read, which snell uses
|
||||
// as an end-of-stream signal.
|
||||
var ErrZeroChunk = errors.New("zero chunk")
|
||||
|
||||
// Cipher is the AEAD cipher abstraction used by the shadowaead stream.
|
||||
type Cipher interface {
|
||||
KeySize() int
|
||||
SaltSize() int
|
||||
Encrypter(salt []byte) (cipher.AEAD, error)
|
||||
Decrypter(salt []byte) (cipher.AEAD, error)
|
||||
}
|
||||
|
||||
const (
|
||||
payloadSizeMask = 0x3FFF
|
||||
bufSize = 17 * 1024
|
||||
)
|
||||
|
||||
type aeadWriter struct {
|
||||
io.Writer
|
||||
cipher.AEAD
|
||||
nonce [32]byte // should be sufficient for most nonce sizes
|
||||
}
|
||||
|
||||
// newAEADWriter wraps an io.Writer with authenticated encryption.
|
||||
func newAEADWriter(w io.Writer, aead cipher.AEAD) *aeadWriter {
|
||||
return &aeadWriter{Writer: w, AEAD: aead}
|
||||
}
|
||||
|
||||
// Write encrypts p and writes to the embedded io.Writer.
|
||||
func (w *aeadWriter) Write(p []byte) (n int, err error) {
|
||||
b := buf.Get(bufSize)
|
||||
defer buf.Put(b)
|
||||
nonce := w.nonce[:w.NonceSize()]
|
||||
tag := w.Overhead()
|
||||
off := 2 + tag
|
||||
if len(p) == 0 {
|
||||
b = b[:off]
|
||||
b[0], b[1] = byte(0), byte(0)
|
||||
w.Seal(b[:0], nonce, b[:2], nil)
|
||||
increment(nonce)
|
||||
_, err = w.Writer.Write(b)
|
||||
return
|
||||
}
|
||||
for nr := 0; n < len(p) && err == nil; n += nr {
|
||||
nr = payloadSizeMask
|
||||
if n+nr > len(p) {
|
||||
nr = len(p) - n
|
||||
}
|
||||
b = b[:off+nr+tag]
|
||||
b[0], b[1] = byte(nr>>8), byte(nr)
|
||||
w.Seal(b[:0], nonce, b[:2], nil)
|
||||
increment(nonce)
|
||||
w.Seal(b[:off], nonce, p[n:n+nr], nil)
|
||||
increment(nonce)
|
||||
_, err = w.Writer.Write(b)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type aeadReader struct {
|
||||
io.Reader
|
||||
cipher.AEAD
|
||||
nonce [32]byte // should be sufficient for most nonce sizes
|
||||
buf []byte // to be put back into bufPool
|
||||
off int // offset to unconsumed part of buf
|
||||
}
|
||||
|
||||
// newAEADReader wraps an io.Reader with authenticated decryption.
|
||||
func newAEADReader(r io.Reader, aead cipher.AEAD) *aeadReader {
|
||||
return &aeadReader{Reader: r, AEAD: aead}
|
||||
}
|
||||
|
||||
// read and decrypt a record into p. len(p) >= max payload size + AEAD overhead.
|
||||
func (r *aeadReader) read(p []byte) (int, error) {
|
||||
nonce := r.nonce[:r.NonceSize()]
|
||||
tag := r.Overhead()
|
||||
p = p[:2+tag]
|
||||
if _, err := io.ReadFull(r.Reader, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err := r.Open(p[:0], nonce, p, nil)
|
||||
increment(nonce)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask
|
||||
if size == 0 {
|
||||
return 0, ErrZeroChunk
|
||||
}
|
||||
p = p[:size+tag]
|
||||
if _, err := io.ReadFull(r.Reader, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = r.Open(p[:0], nonce, p, nil)
|
||||
increment(nonce)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// Read reads from the embedded io.Reader, decrypts and writes to p.
|
||||
func (r *aeadReader) Read(p []byte) (int, error) {
|
||||
if r.buf == nil {
|
||||
if len(p) >= payloadSizeMask+r.Overhead() {
|
||||
return r.read(p)
|
||||
}
|
||||
b := buf.Get(bufSize)
|
||||
n, err := r.read(b)
|
||||
if err != nil {
|
||||
buf.Put(b)
|
||||
return 0, err
|
||||
}
|
||||
r.buf = b[:n]
|
||||
r.off = 0
|
||||
}
|
||||
n := copy(p, r.buf[r.off:])
|
||||
r.off += n
|
||||
if r.off == len(r.buf) {
|
||||
buf.Put(r.buf[:cap(r.buf)])
|
||||
r.buf = nil
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// increment little-endian encoded unsigned integer b. Wrap around on overflow.
|
||||
func increment(b []byte) {
|
||||
for i := range b {
|
||||
b[i]++
|
||||
if b[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// aeadConn wraps a stream-oriented net.Conn with the shadowaead cipher.
|
||||
type aeadConn struct {
|
||||
net.Conn
|
||||
Cipher
|
||||
r *aeadReader
|
||||
w *aeadWriter
|
||||
}
|
||||
|
||||
// newAEADConn wraps a stream-oriented net.Conn with cipher.
|
||||
func newAEADConn(c net.Conn, ciph Cipher) *aeadConn {
|
||||
return &aeadConn{Conn: c, Cipher: ciph}
|
||||
}
|
||||
|
||||
func (c *aeadConn) initReader() error {
|
||||
salt := make([]byte, c.SaltSize())
|
||||
if _, err := io.ReadFull(c.Conn, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := c.Decrypter(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.r = newAEADReader(c.Conn, aead)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Read(b []byte) (int, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *aeadConn) initWriter() error {
|
||||
salt := make([]byte, c.SaltSize())
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := c.Encrypter(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.Conn.Write(salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.w = newAEADWriter(c.Conn, aead)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *aeadConn) Write(b []byte) (int, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.w.Write(b)
|
||||
}
|
||||
408
transport/snell/snell.go
Normal file
408
transport/snell/snell.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
const (
|
||||
Version1 = 1
|
||||
Version2 = 2
|
||||
Version3 = 3
|
||||
Version4 = 4
|
||||
Version5 = 5
|
||||
DefaultSnellVersion = Version1
|
||||
|
||||
// max packet length
|
||||
maxLength = 0x3FFF
|
||||
)
|
||||
|
||||
const (
|
||||
CommandPing byte = 0
|
||||
CommandConnect byte = 1
|
||||
CommandConnectV2 byte = 5
|
||||
CommandUDP byte = 6
|
||||
CommandUDPForward byte = 1
|
||||
|
||||
CommandTunnel byte = 0
|
||||
CommandPong byte = 1
|
||||
CommandError byte = 2
|
||||
|
||||
Version byte = 1
|
||||
)
|
||||
|
||||
// Snell wraps an encrypted stream and handles the snell reply header.
|
||||
type Snell struct {
|
||||
net.Conn
|
||||
buffer [1]byte
|
||||
reply bool
|
||||
}
|
||||
|
||||
func (s *Snell) Read(b []byte) (int, error) {
|
||||
if err := s.ReadReply(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (s *Snell) ReadReply() error {
|
||||
if s.reply {
|
||||
return nil
|
||||
}
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
s.reply = true
|
||||
if s.buffer[0] == CommandTunnel {
|
||||
return nil
|
||||
} else if s.buffer[0] != CommandError {
|
||||
return errors.New("command not support")
|
||||
}
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
errcode := int(s.buffer[0])
|
||||
if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
length := int(s.buffer[0])
|
||||
msg := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.Conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
|
||||
}
|
||||
|
||||
func WriteHeader(conn net.Conn, host string, port uint, version int) error {
|
||||
return WriteHeaderWithReuse(conn, host, port, version, false)
|
||||
}
|
||||
|
||||
func WriteHeaderWithReuse(conn net.Conn, host string, port uint, version int, reuse bool) error {
|
||||
buffer := &bytes.Buffer{}
|
||||
buffer.WriteByte(Version)
|
||||
if version == Version2 || reuse {
|
||||
buffer.WriteByte(CommandConnectV2)
|
||||
} else {
|
||||
buffer.WriteByte(CommandConnect)
|
||||
}
|
||||
buffer.WriteByte(0)
|
||||
buffer.WriteByte(uint8(len(host)))
|
||||
buffer.WriteString(host)
|
||||
binary.Write(buffer, binary.BigEndian, uint16(port))
|
||||
if _, err := conn.Write(buffer.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WriteUDPHeader(conn net.Conn, version int) error {
|
||||
if version < Version3 {
|
||||
return errors.New("unsupport UDP version")
|
||||
}
|
||||
_, err := conn.Write([]byte{Version, CommandUDP, 0x00})
|
||||
return err
|
||||
}
|
||||
|
||||
// HalfClose only works after the request negotiated the reuse command.
|
||||
func HalfClose(conn net.Conn) error {
|
||||
if err := writeZeroChunk(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if s, ok := conn.(*Snell); ok {
|
||||
s.reply = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamConn wraps a raw connection with the snell stream cipher for the given version.
|
||||
func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
|
||||
if version >= Version4 {
|
||||
return &Snell{Conn: newV4Conn(conn, psk)}
|
||||
}
|
||||
var cipher Cipher
|
||||
if version != Version1 {
|
||||
cipher = NewAES128GCM(psk)
|
||||
} else {
|
||||
cipher = NewChacha20Poly1305(psk)
|
||||
}
|
||||
return &Snell{Conn: newAEADConn(conn, cipher)}
|
||||
}
|
||||
|
||||
// ServerStreamConn wraps a raw connection on the server side.
|
||||
func ServerStreamConn(conn net.Conn, psk []byte, version int) *Snell {
|
||||
stream := StreamConn(conn, psk, version)
|
||||
stream.reply = true
|
||||
return stream
|
||||
}
|
||||
|
||||
func PacketConn(conn net.Conn) net.PacketConn {
|
||||
return &packetConn{
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Snell) WritePacketFrame(b []byte) (int, error) {
|
||||
if fw, ok := s.Conn.(packetFrameWriter); ok {
|
||||
return fw.WritePacketFrame(b)
|
||||
}
|
||||
return s.Conn.Write(b)
|
||||
}
|
||||
|
||||
func WritePacket(w io.Writer, target, payload []byte) (int, error) {
|
||||
maxPayloadLength := maxLength - udpRequestHeaderLength(target)
|
||||
if maxPayloadLength <= 0 {
|
||||
return 0, errors.New("snell UDP address too large")
|
||||
}
|
||||
if len(payload) <= maxPayloadLength {
|
||||
return writePacket(w, target, payload)
|
||||
}
|
||||
return 0, errors.New("snell UDP payload too large")
|
||||
}
|
||||
|
||||
func WritePacketResponse(w io.Writer, addr net.Addr, payload []byte) (int, error) {
|
||||
buffer := &bytes.Buffer{}
|
||||
target := parseAddrToSocksAddr(addr)
|
||||
if len(target) == 0 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
switch target[0] {
|
||||
case atypIPv4:
|
||||
if len(target) < 1+net.IPv4len+2 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.WriteByte(0x04)
|
||||
buffer.Write(target[1 : 1+net.IPv4len+2])
|
||||
case atypIPv6:
|
||||
if len(target) < 1+net.IPv6len+2 {
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.WriteByte(0x06)
|
||||
buffer.Write(target[1 : 1+net.IPv6len+2])
|
||||
default:
|
||||
return 0, errors.New("snell UDP response address invalid")
|
||||
}
|
||||
buffer.Write(payload)
|
||||
var err error
|
||||
if fw, ok := w.(packetFrameWriter); ok {
|
||||
_, err = fw.WritePacketFrame(buffer.Bytes())
|
||||
} else {
|
||||
_, err = w.Write(buffer.Bytes())
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
// UDPRequest is a parsed snell UDP forward request.
|
||||
type UDPRequest struct {
|
||||
Host string
|
||||
Ip netip.Addr
|
||||
Port uint16
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func ParseUDPRequest(packet []byte) (UDPRequest, error) {
|
||||
if len(packet) < 2 || packet[0] != CommandUDPForward {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP request")
|
||||
}
|
||||
if hostLen := int(packet[1]); hostLen != 0 {
|
||||
if len(packet) <= 2+hostLen+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP domain request")
|
||||
}
|
||||
offset := 2 + hostLen
|
||||
return UDPRequest{
|
||||
Host: string(packet[2:offset]),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
}
|
||||
if len(packet) < 3 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IP request")
|
||||
}
|
||||
switch packet[2] {
|
||||
case 0x04:
|
||||
if len(packet) < 3+net.IPv4len+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IPv4 request")
|
||||
}
|
||||
offset := 3 + net.IPv4len
|
||||
ip, _ := netip.AddrFromSlice(packet[3:offset])
|
||||
return UDPRequest{
|
||||
Ip: ip.Unmap(),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
case 0x06:
|
||||
if len(packet) < 3+net.IPv6len+2 {
|
||||
return UDPRequest{}, errors.New("snell invalid UDP IPv6 request")
|
||||
}
|
||||
offset := 3 + net.IPv6len
|
||||
ip, _ := netip.AddrFromSlice(packet[3:offset])
|
||||
return UDPRequest{
|
||||
Ip: ip.Unmap(),
|
||||
Port: binary.BigEndian.Uint16(packet[offset : offset+2]),
|
||||
Payload: packet[offset+2:],
|
||||
}, nil
|
||||
default:
|
||||
return UDPRequest{}, errors.New("snell invalid UDP address type")
|
||||
}
|
||||
}
|
||||
|
||||
func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) {
|
||||
b := buf.Get(buf.UDPBufferSize)
|
||||
defer buf.Put(b)
|
||||
n, err := r.Read(b)
|
||||
headLen := 1
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if n < headLen {
|
||||
return nil, 0, errors.New("insufficient UDP length")
|
||||
}
|
||||
switch b[0] {
|
||||
case 0x04:
|
||||
headLen += net.IPv4len + 2
|
||||
if n < headLen {
|
||||
err = errors.New("insufficient UDP length")
|
||||
break
|
||||
}
|
||||
b[0] = atypIPv4
|
||||
case 0x06:
|
||||
headLen += net.IPv6len + 2
|
||||
if n < headLen {
|
||||
err = errors.New("insufficient UDP length")
|
||||
break
|
||||
}
|
||||
b[0] = atypIPv6
|
||||
default:
|
||||
err = errors.New("ip version invalid")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
addr := splitSocksAddr(b[0:])
|
||||
if addr == nil {
|
||||
return nil, 0, errors.New("remote address invalid")
|
||||
}
|
||||
uAddr := addr.UDPAddr()
|
||||
if uAddr == nil {
|
||||
return nil, 0, errors.New("parse addr error")
|
||||
}
|
||||
length := len(payload)
|
||||
if n-headLen < length {
|
||||
length = n - headLen
|
||||
}
|
||||
copy(payload[:], b[headLen:headLen+length])
|
||||
return uAddr, length, nil
|
||||
}
|
||||
|
||||
var endSignal = []byte{}
|
||||
|
||||
type packetFrameWriter interface {
|
||||
WritePacketFrame([]byte) (int, error)
|
||||
}
|
||||
|
||||
func writeZeroChunk(conn net.Conn) error {
|
||||
if _, err := conn.Write(endSignal); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePacket(w io.Writer, target, payload []byte) (int, error) {
|
||||
buffer := &bytes.Buffer{}
|
||||
buffer.WriteByte(CommandUDPForward)
|
||||
switch target[0] {
|
||||
case atypDomainName:
|
||||
hostLen := target[1]
|
||||
if len(target) < 1+1+int(hostLen)+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write(target[1 : 1+1+hostLen+2])
|
||||
case atypIPv4:
|
||||
if len(target) < 1+net.IPv4len+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write([]byte{0x00, 0x04})
|
||||
buffer.Write(target[1 : 1+net.IPv4len+2])
|
||||
case atypIPv6:
|
||||
if len(target) < 1+net.IPv6len+2 {
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write([]byte{0x00, 0x06})
|
||||
buffer.Write(target[1 : 1+net.IPv6len+2])
|
||||
default:
|
||||
return 0, errors.New("snell UDP address invalid")
|
||||
}
|
||||
buffer.Write(payload)
|
||||
if fw, ok := w.(packetFrameWriter); ok {
|
||||
_, err := fw.WritePacketFrame(buffer.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
_, err := w.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
func udpRequestHeaderLength(target []byte) int {
|
||||
if len(target) == 0 {
|
||||
return maxLength + 1
|
||||
}
|
||||
switch target[0] {
|
||||
case atypDomainName:
|
||||
if len(target) < 2 {
|
||||
return maxLength + 1
|
||||
}
|
||||
return 1 + 1 + int(target[1]) + 2
|
||||
case atypIPv4:
|
||||
return 1 + 2 + net.IPv4len + 2
|
||||
case atypIPv6:
|
||||
return 1 + 2 + net.IPv6len + 2
|
||||
default:
|
||||
return maxLength + 1
|
||||
}
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
net.Conn
|
||||
rMux sync.Mutex
|
||||
wMux sync.Mutex
|
||||
}
|
||||
|
||||
func (pc *packetConn) WritePacketFrame(b []byte) (int, error) {
|
||||
if s, ok := pc.Conn.(*Snell); ok {
|
||||
if fw, ok := s.Conn.(packetFrameWriter); ok {
|
||||
return fw.WritePacketFrame(b)
|
||||
}
|
||||
}
|
||||
return pc.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
pc.wMux.Lock()
|
||||
defer pc.wMux.Unlock()
|
||||
return WritePacket(pc, parseAddrToSocksAddr(addr), b)
|
||||
}
|
||||
|
||||
func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
pc.rMux.Lock()
|
||||
defer pc.rMux.Unlock()
|
||||
addr, n, err := ReadPacket(pc.Conn, b)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
return n, addr, nil
|
||||
}
|
||||
463
transport/snell/v4.go
Normal file
463
transport/snell/v4.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package snell
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"math/big"
|
||||
"math/bits"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
v4SaltSize = 16
|
||||
v4NonceSize = 12
|
||||
v4HeaderPlainSize = 7
|
||||
v4HeaderCipherSize = v4HeaderPlainSize + 16
|
||||
v4FrameSize = 1460
|
||||
v4InitialPaddingMin = 0x100
|
||||
v4InitialPaddingSpan = 0x100
|
||||
)
|
||||
|
||||
type v4Conn struct {
|
||||
net.Conn
|
||||
psk []byte
|
||||
r *v4Reader
|
||||
w *v4Writer
|
||||
}
|
||||
|
||||
func newV4Conn(conn net.Conn, psk []byte) *v4Conn {
|
||||
return &v4Conn{Conn: conn, psk: psk}
|
||||
}
|
||||
|
||||
func (c *v4Conn) initReader() error {
|
||||
salt := make([]byte, v4SaltSize)
|
||||
if _, err := io.ReadFull(c.Conn, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
aead, err := v4AEAD(c.psk, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.r = &v4Reader{Reader: c.Conn, aead: aead}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) initWriter() error {
|
||||
w, err := newV4Writer(c.Conn, c.psk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.w = w
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) Read(b []byte) (int, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *v4Conn) Write(b []byte) (int, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.w.Write(b)
|
||||
}
|
||||
|
||||
func (c *v4Conn) WritePacketFrame(b []byte) (int, error) {
|
||||
if len(b) > maxLength {
|
||||
return 0, errors.New("snell v4 frame too large")
|
||||
}
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
c.w.mux.Lock()
|
||||
defer c.w.mux.Unlock()
|
||||
if err := c.w.writeFrame(b, c.w.nextFramePaddingLength(len(b))); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *v4Conn) WriteTo(w io.Writer) (int64, error) {
|
||||
if c.r == nil {
|
||||
if err := c.initReader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
var written int64
|
||||
buf := make([]byte, maxLength)
|
||||
for {
|
||||
n, err := c.r.Read(buf)
|
||||
if n > 0 {
|
||||
nw, ew := w.Write(buf[:n])
|
||||
written += int64(nw)
|
||||
if ew != nil {
|
||||
return written, ew
|
||||
}
|
||||
if nw != n {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *v4Conn) ReadFrom(r io.Reader) (int64, error) {
|
||||
if c.w == nil {
|
||||
if err := c.initWriter(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
var read int64
|
||||
buf := make([]byte, maxLength)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
read += int64(n)
|
||||
if _, ew := c.w.Write(buf[:n]); ew != nil {
|
||||
return read, ew
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return read, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func v4AEAD(psk, salt []byte) (cipher.AEAD, error) {
|
||||
return aesGCM(snellKDF(psk, salt, 16))
|
||||
}
|
||||
|
||||
type v4Reader struct {
|
||||
io.Reader
|
||||
aead cipher.AEAD
|
||||
nonce [v4NonceSize]byte
|
||||
buf []byte
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func (r *v4Reader) Read(b []byte) (int, error) {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
if len(r.buf) == 0 {
|
||||
payload, err := r.readFrame()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r.buf = payload
|
||||
}
|
||||
n := copy(b, r.buf)
|
||||
r.buf = r.buf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *v4Reader) readFrame() ([]byte, error) {
|
||||
headerCipher := make([]byte, v4HeaderCipherSize)
|
||||
if _, err := io.ReadFull(r.Reader, headerCipher); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
header, err := r.aead.Open(headerCipher[:0], r.nonce[:], headerCipher, nil)
|
||||
incrementV4Nonce(r.nonce[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(header) != v4HeaderPlainSize || header[0] != 4 {
|
||||
return nil, errors.New("snell v4 invalid frame header")
|
||||
}
|
||||
paddingLength := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
payloadLength := int(binary.BigEndian.Uint16(header[5:7]))
|
||||
if payloadLength == 0 {
|
||||
if paddingLength != 0 {
|
||||
return nil, errors.New("snell v4 zero chunk with padding")
|
||||
}
|
||||
return nil, ErrZeroChunk
|
||||
}
|
||||
if payloadLength > maxLength || paddingLength > maxLength {
|
||||
return nil, errors.New("snell v4 frame too large")
|
||||
}
|
||||
payloadCipherLength := payloadLength + r.aead.Overhead()
|
||||
frame := make([]byte, paddingLength+payloadCipherLength)
|
||||
if _, err := io.ReadFull(r.Reader, frame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paddingLength > 0 {
|
||||
swapPadding(frame[:paddingLength], frame[paddingLength:])
|
||||
}
|
||||
payloadCipher := frame[paddingLength:]
|
||||
payload, err := r.aead.Open(payloadCipher[:0], r.nonce[:], payloadCipher, nil)
|
||||
incrementV4Nonce(r.nonce[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
type v4Writer struct {
|
||||
io.Writer
|
||||
aead cipher.AEAD
|
||||
nonce [v4NonceSize]byte
|
||||
salt [v4SaltSize]byte
|
||||
saltSent bool
|
||||
initialPaddingLength uint16
|
||||
payloadLimit uint16
|
||||
lastWrite time.Time
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newV4Writer(w io.Writer, psk []byte) (*v4Writer, error) {
|
||||
var salt [v4SaltSize]byte
|
||||
if _, err := io.ReadFull(cryptorand.Reader, salt[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := v4AEAD(psk, salt[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddingDelta, err := cryptoRandomInt(v4InitialPaddingSpan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v4Writer{
|
||||
Writer: w,
|
||||
aead: aead,
|
||||
salt: salt,
|
||||
initialPaddingLength: uint16(v4InitialPaddingMin + paddingDelta),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *v4Writer) Write(b []byte) (int, error) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if len(b) == 0 {
|
||||
return 0, w.writeFrame(nil, 0)
|
||||
}
|
||||
written := 0
|
||||
for written < len(b) {
|
||||
payloadLimit := int(w.nextPayloadLimit())
|
||||
if payloadLimit <= 0 || payloadLimit > maxLength {
|
||||
payloadLimit = maxLength
|
||||
}
|
||||
end := written + payloadLimit
|
||||
if end > len(b) {
|
||||
end = len(b)
|
||||
}
|
||||
paddingLength := w.nextFramePaddingLength(end - written)
|
||||
if err := w.writeFrame(b[written:end], paddingLength); err != nil {
|
||||
return written, err
|
||||
}
|
||||
written = end
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (w *v4Writer) nextPayloadLimit() uint16 {
|
||||
now := time.Now()
|
||||
var payloadLimit uint16
|
||||
switch {
|
||||
case w.lastWrite.IsZero():
|
||||
payloadLimit = v4FrameSize - 55 - w.initialPaddingLength
|
||||
case now.Sub(w.lastWrite) > 30*time.Second:
|
||||
payloadLimit = v4FrameSize - 39
|
||||
default:
|
||||
payloadLimit = w.payloadLimit
|
||||
}
|
||||
w.lastWrite = now
|
||||
if payloadLimit <= maxLength-1 {
|
||||
next := int(payloadLimit) + v4FrameSize - 39
|
||||
if next > maxLength {
|
||||
next = maxLength
|
||||
}
|
||||
w.payloadLimit = uint16(next)
|
||||
} else {
|
||||
w.payloadLimit = maxLength
|
||||
}
|
||||
return payloadLimit
|
||||
}
|
||||
|
||||
func (w *v4Writer) nextFramePaddingLength(payloadLength int) int {
|
||||
if w.saltSent || payloadLength == 0 {
|
||||
return 0
|
||||
}
|
||||
return int(w.initialPaddingLength)
|
||||
}
|
||||
|
||||
func (w *v4Writer) writeFrame(payload []byte, paddingLength int) error {
|
||||
if len(payload) > maxLength || paddingLength > maxLength {
|
||||
return errors.New("snell v4 frame too large")
|
||||
}
|
||||
if len(payload) == 0 && paddingLength != 0 {
|
||||
return errors.New("snell v4 zero chunk with padding")
|
||||
}
|
||||
header := make([]byte, v4HeaderPlainSize)
|
||||
header[0] = 4
|
||||
binary.BigEndian.PutUint16(header[3:5], uint16(paddingLength))
|
||||
binary.BigEndian.PutUint16(header[5:7], uint16(len(payload)))
|
||||
headerCipher := w.aead.Seal(nil, w.nonce[:], header, nil)
|
||||
incrementV4Nonce(w.nonce[:])
|
||||
var payloadCipher []byte
|
||||
if len(payload) > 0 {
|
||||
payloadCipher = w.aead.Seal(nil, w.nonce[:], payload, nil)
|
||||
incrementV4Nonce(w.nonce[:])
|
||||
}
|
||||
frameLength := len(headerCipher) + paddingLength + len(payloadCipher)
|
||||
if !w.saltSent {
|
||||
frameLength += v4SaltSize
|
||||
}
|
||||
frame := make([]byte, 0, frameLength)
|
||||
if !w.saltSent {
|
||||
frame = append(frame, w.salt[:]...)
|
||||
w.saltSent = true
|
||||
}
|
||||
frame = append(frame, headerCipher...)
|
||||
if paddingLength > 0 {
|
||||
padding, err := makeV4Padding(payloadCipher, paddingLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
swapPadding(padding, payloadCipher)
|
||||
frame = append(frame, padding...)
|
||||
}
|
||||
frame = append(frame, payloadCipher...)
|
||||
return writeFull(w.Writer, frame)
|
||||
}
|
||||
|
||||
func swapPadding(padding, payloadCipher []byte) {
|
||||
limit := len(padding)
|
||||
if len(payloadCipher) < limit {
|
||||
limit = len(payloadCipher)
|
||||
}
|
||||
for i := 0; i < limit; i += 2 {
|
||||
padding[i], payloadCipher[i] = payloadCipher[i], padding[i]
|
||||
}
|
||||
}
|
||||
|
||||
func makeV4Padding(payloadCipher []byte, paddingLength int) ([]byte, error) {
|
||||
if paddingLength <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
payloadOnes := countV4PayloadOnes(payloadCipher)
|
||||
payloadZeros := 8*len(payloadCipher) - payloadOnes
|
||||
if payloadZeros <= 0 {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
ratio := float64(payloadOnes) / float64(payloadZeros)
|
||||
if ratio <= 0.5 || ratio >= 1.6 {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
targetRatioBase := 1.6
|
||||
if payloadZeros < payloadOnes {
|
||||
targetRatioBase = 0.4
|
||||
}
|
||||
jitter, err := randomUnitFloat64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetRatio := targetRatioBase + jitter/10
|
||||
totalBits := 8 * (paddingLength + len(payloadCipher))
|
||||
targetOnes := int(float64(totalBits)*(targetRatio/(targetRatio+1)) - float64(payloadOnes))
|
||||
if targetOnes < 0 || targetOnes > 8*paddingLength {
|
||||
return makeV4RandomPadding(paddingLength)
|
||||
}
|
||||
return makeV4BitCountPadding(paddingLength, targetOnes)
|
||||
}
|
||||
|
||||
func countV4PayloadOnes(payloadCipher []byte) int {
|
||||
limit := len(payloadCipher) &^ 3
|
||||
ones := 0
|
||||
for _, b := range payloadCipher[:limit] {
|
||||
ones += bits.OnesCount8(b)
|
||||
}
|
||||
return ones
|
||||
}
|
||||
|
||||
func makeV4RandomPadding(length int) ([]byte, error) {
|
||||
padding := make([]byte, length)
|
||||
_, err := io.ReadFull(cryptorand.Reader, padding)
|
||||
return padding, err
|
||||
}
|
||||
|
||||
func makeV4BitCountPadding(length, oneBits int) ([]byte, error) {
|
||||
totalBits := 8 * length
|
||||
if oneBits < 0 || oneBits > totalBits {
|
||||
return nil, errors.New("snell v4 invalid padding bit count")
|
||||
}
|
||||
bitset := make([]byte, totalBits)
|
||||
for i := 0; i < oneBits; i++ {
|
||||
bitset[i] = 1
|
||||
}
|
||||
for i := totalBits - 1; i > 0; i-- {
|
||||
j, err := cryptoRandomInt(i + 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bitset[i], bitset[j] = bitset[j], bitset[i]
|
||||
}
|
||||
padding := make([]byte, length)
|
||||
for i, bit := range bitset {
|
||||
if bit == 1 {
|
||||
padding[i/8] |= 1 << uint(i%8)
|
||||
}
|
||||
}
|
||||
return padding, nil
|
||||
}
|
||||
|
||||
func cryptoRandomInt(max int) (int, error) {
|
||||
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(int64(max)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(n.Int64()), nil
|
||||
}
|
||||
|
||||
func randomUnitFloat64() (float64, error) {
|
||||
n, err := cryptorand.Int(cryptorand.Reader, big.NewInt(1<<53))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return float64(n.Int64()) / math.Exp2(53), nil
|
||||
}
|
||||
|
||||
func writeFull(w io.Writer, p []byte) error {
|
||||
for len(p) > 0 {
|
||||
n, err := w.Write(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
p = p[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementV4Nonce(nonce []byte) {
|
||||
for i := range nonce {
|
||||
nonce[i]++
|
||||
if nonce[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,9 +18,12 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
headerSize = 1 + 4 + 4
|
||||
maxFrameSize = 256 * 1024
|
||||
maxDataPayload = 32 * 1024
|
||||
headerSize = 1 + 4 + 4
|
||||
// maxQueuedBytesPerStream bounds unread payload retained by a single logical stream.
|
||||
// Backpressure is applied to the demux loop instead of dropping data.
|
||||
maxQueuedBytesPerStream = 4 * 1024 * 1024
|
||||
maxFrameSize = 256 * 1024
|
||||
maxDataPayload = 128 * 1024
|
||||
)
|
||||
|
||||
type acceptEvent struct {
|
||||
@@ -344,6 +347,8 @@ type stream struct {
|
||||
closeErr error
|
||||
readBuf []byte
|
||||
queue [][]byte
|
||||
// queuedBytes includes unread bytes in readBuf and queue.
|
||||
queuedBytes int
|
||||
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
@@ -362,16 +367,20 @@ func newStream(session *Session, id uint32) *stream {
|
||||
|
||||
func (c *stream) enqueue(payload []byte) {
|
||||
c.mu.Lock()
|
||||
for !c.closed && c.queuedBytes+len(payload) > maxQueuedBytesPerStream {
|
||||
c.cond.Wait()
|
||||
}
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.queuedBytes += len(payload)
|
||||
if len(c.readBuf) == 0 && len(c.queue) == 0 {
|
||||
c.readBuf = payload
|
||||
} else {
|
||||
c.queue = append(c.queue, payload)
|
||||
}
|
||||
c.cond.Signal()
|
||||
c.cond.Broadcast()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -413,7 +422,11 @@ func (c *stream) Read(p []byte) (int, error) {
|
||||
}
|
||||
if len(c.readBuf) == 0 && len(c.queue) > 0 {
|
||||
c.readBuf = c.queue[0]
|
||||
c.queue[0] = nil
|
||||
c.queue = c.queue[1:]
|
||||
if len(c.queue) == 0 {
|
||||
c.queue = nil
|
||||
}
|
||||
}
|
||||
if len(c.readBuf) == 0 && c.closed {
|
||||
if c.closeErr == nil {
|
||||
@@ -424,6 +437,14 @@ func (c *stream) Read(p []byte) (int, error) {
|
||||
|
||||
n := copy(p, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
if len(c.readBuf) == 0 {
|
||||
c.readBuf = nil
|
||||
}
|
||||
c.queuedBytes -= n
|
||||
if c.queuedBytes < 0 {
|
||||
c.queuedBytes = 0
|
||||
}
|
||||
c.cond.Broadcast()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
||||
91
transport/sudoku/multiplex/session_backpressure_test.go
Normal file
91
transport/sudoku/multiplex/session_backpressure_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestSession_LargeTransferBackpressure verifies that a transfer larger than
|
||||
// maxQueuedBytesPerStream completes correctly: the demux loop applies
|
||||
// backpressure (cond.Wait) instead of dropping data, and the reader draining
|
||||
// the stream wakes the blocked loop without deadlock.
|
||||
func TestSession_LargeTransferBackpressure(t *testing.T) {
|
||||
c1, c2 := net.Pipe()
|
||||
|
||||
client, err := NewClientSession(c1)
|
||||
if err != nil {
|
||||
t.Fatalf("client session: %v", err)
|
||||
}
|
||||
server, err := NewServerSession(c2)
|
||||
if err != nil {
|
||||
t.Fatalf("server session: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
// Payload bigger than the per-stream backpressure window (4MB).
|
||||
const total = 12 * 1024 * 1024
|
||||
payload := make([]byte, total)
|
||||
if _, err := rand.Read(payload); err != nil {
|
||||
t.Fatalf("rand: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var writeErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stream, err := client.OpenStream([]byte("hello"))
|
||||
if err != nil {
|
||||
writeErr = err
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
if _, err := stream.Write(payload); err != nil {
|
||||
writeErr = err
|
||||
return
|
||||
}
|
||||
_ = stream.(interface{ CloseWrite() error }).CloseWrite()
|
||||
}()
|
||||
|
||||
var got []byte
|
||||
var readErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stream, openPayload, err := server.AcceptStream()
|
||||
if err != nil {
|
||||
readErr = err
|
||||
return
|
||||
}
|
||||
if string(openPayload) != "hello" {
|
||||
readErr = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
got, readErr = io.ReadAll(stream)
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { wg.Wait(); close(done) }()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Fatal("transfer deadlocked (backpressure did not release)")
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
t.Fatalf("write: %v", writeErr)
|
||||
}
|
||||
if readErr != nil {
|
||||
t.Fatalf("read: %v", readErr)
|
||||
}
|
||||
if !bytes.Equal(got, payload) {
|
||||
t.Fatalf("payload mismatch: got %d bytes, want %d", len(got), len(payload))
|
||||
}
|
||||
}
|
||||
56
transport/sudoku/obfs/sudoku/ascii_mode_test.go
Normal file
56
transport/sudoku/obfs/sudoku/ascii_mode_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package sudoku
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeASCIIMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", "prefer_entropy"},
|
||||
{"entropy", "prefer_entropy"},
|
||||
{"prefer_ascii", "prefer_ascii"},
|
||||
{"up_ascii_down_entropy", "up_ascii_down_entropy"},
|
||||
{"up_entropy_down_ascii", "up_entropy_down_ascii"},
|
||||
{"up_prefer_ascii_down_prefer_entropy", "up_ascii_down_entropy"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := NormalizeASCIIMode(tt.in)
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizeASCIIMode(%q): %v", tt.in, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("NormalizeASCIIMode(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := NormalizeASCIIMode("up_ascii_down_binary"); err == nil {
|
||||
t.Fatalf("expected invalid directional mode to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTableWithCustomDirectionalOpposite(t *testing.T) {
|
||||
table, err := NewTableWithCustom("seed", "up_ascii_down_entropy", "xpxvvpvv")
|
||||
if err != nil {
|
||||
t.Fatalf("NewTableWithCustom: %v", err)
|
||||
}
|
||||
if !table.IsASCII {
|
||||
t.Fatalf("uplink table should be ascii")
|
||||
}
|
||||
opposite := table.OppositeDirection()
|
||||
if opposite == nil || opposite == table {
|
||||
t.Fatalf("expected distinct opposite table")
|
||||
}
|
||||
if opposite.IsASCII {
|
||||
t.Fatalf("downlink table should be entropy/custom")
|
||||
}
|
||||
|
||||
symmetric, err := NewTableWithCustom("seed", "prefer_ascii", "xpxvvpvv")
|
||||
if err != nil {
|
||||
t.Fatalf("NewTableWithCustom symmetric: %v", err)
|
||||
}
|
||||
if symmetric.OppositeDirection() != symmetric {
|
||||
t.Fatalf("symmetric table should point to itself")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package sudoku
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -10,6 +11,8 @@ import (
|
||||
|
||||
const IOBufferSize = 32 * 1024
|
||||
|
||||
const minDecodeReadSize = 64
|
||||
|
||||
var perm4 = [24][4]byte{
|
||||
{0, 1, 2, 3},
|
||||
{0, 1, 3, 2},
|
||||
@@ -52,7 +55,7 @@ type Conn struct {
|
||||
writeMu sync.Mutex
|
||||
writeBuf []byte
|
||||
|
||||
rng randomSource
|
||||
rng *sudokuRand
|
||||
paddingThreshold uint64
|
||||
}
|
||||
|
||||
@@ -97,6 +100,9 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
|
||||
}
|
||||
|
||||
func (sc *Conn) StopRecording() {
|
||||
if sc == nil {
|
||||
return
|
||||
}
|
||||
sc.recordLock.Lock()
|
||||
sc.recording.Store(false)
|
||||
sc.recorder = nil
|
||||
@@ -115,6 +121,9 @@ func (sc *Conn) GetBufferedAndRecorded() []byte {
|
||||
if sc.recorder != nil {
|
||||
recorded = sc.recorder.Bytes()
|
||||
}
|
||||
if sc.reader == nil {
|
||||
return recorded
|
||||
}
|
||||
|
||||
buffered := sc.reader.Buffered()
|
||||
if buffered > 0 {
|
||||
@@ -131,6 +140,9 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if sc == nil || sc.Conn == nil || sc.table == nil || sc.table.layout == nil || sc.rng == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
sc.writeMu.Lock()
|
||||
defer sc.writeMu.Unlock()
|
||||
@@ -140,16 +152,19 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (sc *Conn) Read(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if sc == nil || sc.Conn == nil || sc.reader == nil || len(sc.rawBuf) == 0 || sc.table == nil || sc.table.layout == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if n, ok := drainPending(p, &sc.pendingData); ok {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
outN := 0
|
||||
for {
|
||||
if sc.pendingData.available() > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
nr, rErr := sc.reader.Read(sc.rawBuf)
|
||||
nr, rErr := readRawLimited(sc.Conn, sc.reader, sc.rawBuf[:sudokuReadSize(len(p)-outN, len(sc.rawBuf))])
|
||||
if nr > 0 {
|
||||
chunk := sc.rawBuf[:nr]
|
||||
if sc.recording.Load() {
|
||||
@@ -160,34 +175,80 @@ func (sc *Conn) Read(p []byte) (n int, err error) {
|
||||
sc.recordLock.Unlock()
|
||||
}
|
||||
|
||||
layout := sc.table.layout
|
||||
for _, b := range chunk {
|
||||
table := sc.table
|
||||
layout := table.layout
|
||||
for i := 0; i < len(chunk); {
|
||||
if sc.hintCount == 0 && outN < len(p) && i+3 < len(chunk) &&
|
||||
layout.hintTable[chunk[i]] &&
|
||||
layout.hintTable[chunk[i+1]] &&
|
||||
layout.hintTable[chunk[i+2]] &&
|
||||
layout.hintTable[chunk[i+3]] {
|
||||
val, ok := table.DecodeMap[packHintBytes(chunk[i], chunk[i+1], chunk[i+2], chunk[i+3])]
|
||||
if !ok {
|
||||
return 0, ErrInvalidSudokuMapMiss
|
||||
}
|
||||
p[outN] = val
|
||||
outN++
|
||||
i += 4
|
||||
continue
|
||||
}
|
||||
|
||||
b := chunk[i]
|
||||
i++
|
||||
if !layout.hintTable[b] {
|
||||
continue
|
||||
}
|
||||
|
||||
sc.hintBuf[sc.hintCount] = b
|
||||
sc.hintCount++
|
||||
if sc.hintCount == len(sc.hintBuf) {
|
||||
key := packHintsToKey(sc.hintBuf)
|
||||
val, ok := sc.table.DecodeMap[key]
|
||||
if !ok {
|
||||
return 0, ErrInvalidSudokuMapMiss
|
||||
}
|
||||
sc.pendingData.appendByte(val)
|
||||
sc.hintCount = 0
|
||||
if sc.hintCount != len(sc.hintBuf) {
|
||||
continue
|
||||
}
|
||||
|
||||
val, ok := table.DecodeMap[packHintBytes(sc.hintBuf[0], sc.hintBuf[1], sc.hintBuf[2], sc.hintBuf[3])]
|
||||
if !ok {
|
||||
return 0, ErrInvalidSudokuMapMiss
|
||||
}
|
||||
outN = appendDecodedByte(p, outN, &sc.pendingData, val)
|
||||
sc.hintCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
if rErr != nil {
|
||||
if outN > 0 {
|
||||
return outN, nil
|
||||
}
|
||||
if n, ok := drainPending(p, &sc.pendingData); ok {
|
||||
return n, nil
|
||||
}
|
||||
return 0, rErr
|
||||
}
|
||||
if sc.pendingData.available() > 0 {
|
||||
break
|
||||
if outN > 0 {
|
||||
return outN, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, _ = drainPending(p, &sc.pendingData)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func sudokuReadSize(decodedRemaining, maxRaw int) int {
|
||||
if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 {
|
||||
return maxRaw
|
||||
}
|
||||
if decodedRemaining > (maxRaw-minDecodeReadSize)/5 {
|
||||
return maxRaw
|
||||
}
|
||||
|
||||
return decodedRemaining*5 + minDecodeReadSize
|
||||
}
|
||||
|
||||
func readRawLimited(conn net.Conn, reader *bufio.Reader, dst []byte) (int, error) {
|
||||
if len(dst) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if reader != nil && reader.Buffered() > 0 {
|
||||
return reader.Read(dst)
|
||||
}
|
||||
if conn == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return conn.Read(dst)
|
||||
}
|
||||
|
||||
51
transport/sudoku/obfs/sudoku/conn_roundtrip_test.go
Normal file
51
transport/sudoku/obfs/sudoku/conn_roundtrip_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConn_Roundtrip exercises the optimized Conn encode/decode hot paths:
|
||||
// the no-padding fast path (pMin==pMax==0), the always-padding path
|
||||
// (pMin==pMax==100), a probabilistic range, and the adaptive read-size /
|
||||
// 4-byte fast hint decode path across a variety of payload sizes and modes.
|
||||
func TestConn_Roundtrip(t *testing.T) {
|
||||
modes := []string{"prefer_entropy", "prefer_ascii"}
|
||||
paddings := []struct{ min, max int }{
|
||||
{0, 0}, // no-padding specialized path
|
||||
{100, 100}, // always-padding specialized path
|
||||
{20, 60}, // probabilistic path
|
||||
}
|
||||
sizes := []int{1, 3, 4, 7, 16, 100, 1000, 64 * 1024}
|
||||
|
||||
for _, mode := range modes {
|
||||
for _, pad := range paddings {
|
||||
for _, size := range sizes {
|
||||
payload := make([]byte, size)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i*31 + 7)
|
||||
}
|
||||
|
||||
table := NewTable("conn-roundtrip-seed", mode)
|
||||
|
||||
// Encode via Conn.Write.
|
||||
w := &mockConn{}
|
||||
enc := NewConn(w, table, pad.min, pad.max, false)
|
||||
if _, err := enc.Write(payload); err != nil {
|
||||
t.Fatalf("mode=%s pad=%v size=%d write: %v", mode, pad, size, err)
|
||||
}
|
||||
|
||||
// Decode via Conn.Read using the same table.
|
||||
dec := NewConn(&mockConn{readBuf: w.writeBuf}, table, pad.min, pad.max, false)
|
||||
got := make([]byte, size)
|
||||
if _, err := io.ReadFull(dec, got); err != nil {
|
||||
t.Fatalf("mode=%s pad=%v size=%d read: %v", mode, pad, size, err)
|
||||
}
|
||||
if !bytes.Equal(got, payload) {
|
||||
t.Fatalf("mode=%s pad=%v size=%d roundtrip mismatch", mode, pad, size)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
package sudoku
|
||||
|
||||
func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThreshold uint64, p []byte) []byte {
|
||||
func encodeSudokuPayload(dst []byte, table *Table, rng *sudokuRand, paddingThreshold uint64, p []byte) []byte {
|
||||
if len(p) == 0 {
|
||||
return dst[:0]
|
||||
}
|
||||
if paddingThreshold == 0 {
|
||||
return encodeSudokuPayloadNoPadding(dst, table, rng, p)
|
||||
}
|
||||
|
||||
outCapacity := len(p)*6 + 1
|
||||
if cap(dst) < outCapacity {
|
||||
@@ -13,8 +16,25 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre
|
||||
pads := table.PaddingPool
|
||||
padLen := len(pads)
|
||||
|
||||
if paddingThreshold >= probOne {
|
||||
for _, b := range p {
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
|
||||
puzzles := table.EncodeTable[b]
|
||||
puzzle := puzzles[rng.Intn(len(puzzles))]
|
||||
|
||||
perm := perm4[rng.Intn(len(perm4))]
|
||||
for _, idx := range perm {
|
||||
out = append(out, pads[rng.Intn(padLen)], puzzle[idx])
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
return out
|
||||
}
|
||||
|
||||
for _, b := range p {
|
||||
if shouldPad(rng, paddingThreshold) {
|
||||
if uint64(rng.Uint32()) < paddingThreshold {
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
}
|
||||
|
||||
@@ -22,15 +42,31 @@ func encodeSudokuPayload(dst []byte, table *Table, rng randomSource, paddingThre
|
||||
puzzle := puzzles[rng.Intn(len(puzzles))]
|
||||
perm := perm4[rng.Intn(len(perm4))]
|
||||
for _, idx := range perm {
|
||||
if shouldPad(rng, paddingThreshold) {
|
||||
if uint64(rng.Uint32()) < paddingThreshold {
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
}
|
||||
out = append(out, puzzle[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if shouldPad(rng, paddingThreshold) {
|
||||
if uint64(rng.Uint32()) < paddingThreshold {
|
||||
out = append(out, pads[rng.Intn(padLen)])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeSudokuPayloadNoPadding(dst []byte, table *Table, rng *sudokuRand, p []byte) []byte {
|
||||
outCapacity := len(p) * 4
|
||||
if cap(dst) < outCapacity {
|
||||
dst = make([]byte, 0, outCapacity)
|
||||
}
|
||||
out := dst[:0]
|
||||
|
||||
for _, b := range p {
|
||||
puzzles := table.EncodeTable[b]
|
||||
puzzle := puzzles[rng.Intn(len(puzzles))]
|
||||
perm := perm4[rng.Intn(len(perm4))]
|
||||
out = append(out, puzzle[perm[0]], puzzle[perm[1]], puzzle[perm[2]], puzzle[perm[3]])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
RngBatchSize = 128
|
||||
|
||||
packedProtectedPrefixBytes = 14
|
||||
packedIOBufferSize = 64 * 1024
|
||||
packedDecodeBufferSize = 96 * 1024
|
||||
)
|
||||
|
||||
// PackedConn encodes traffic with the packed Sudoku layout while preserving
|
||||
@@ -35,7 +35,7 @@ type PackedConn struct {
|
||||
readBits int
|
||||
|
||||
// Padding selection matches Conn's threshold-based model.
|
||||
rng randomSource
|
||||
rng *sudokuRand
|
||||
paddingThreshold uint64
|
||||
padMarker byte
|
||||
padPool []byte
|
||||
@@ -67,18 +67,20 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||
pc := &PackedConn{
|
||||
Conn: c,
|
||||
table: table,
|
||||
reader: bufio.NewReaderSize(c, IOBufferSize),
|
||||
rawBuf: make([]byte, IOBufferSize),
|
||||
reader: bufio.NewReaderSize(c, packedIOBufferSize),
|
||||
rawBuf: make([]byte, packedDecodeBufferSize),
|
||||
pendingData: newPendingBuffer(4096),
|
||||
writeBuf: make([]byte, 0, 4096),
|
||||
rng: localRng,
|
||||
paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
|
||||
}
|
||||
|
||||
pc.padMarker = table.layout.padMarker
|
||||
for _, b := range table.PaddingPool {
|
||||
if b != pc.padMarker {
|
||||
pc.padPool = append(pc.padPool, b)
|
||||
if table != nil && table.layout != nil {
|
||||
pc.padMarker = table.layout.padMarker
|
||||
for _, b := range table.PaddingPool {
|
||||
if b != pc.padMarker {
|
||||
pc.padPool = append(pc.padPool, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(pc.padPool) == 0 {
|
||||
@@ -87,18 +89,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||
return pc
|
||||
}
|
||||
|
||||
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
|
||||
if shouldPad(pc.rng, pc.paddingThreshold) {
|
||||
out = append(out, pc.getPaddingByte())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (pc *PackedConn) appendGroup(out []byte, group byte) []byte {
|
||||
out = pc.maybeAddPadding(out)
|
||||
return append(out, pc.table.layout.groupByte(group))
|
||||
}
|
||||
|
||||
func (pc *PackedConn) appendForcedPadding(out []byte) []byte {
|
||||
return append(out, pc.getPaddingByte())
|
||||
}
|
||||
@@ -134,7 +124,7 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = appendPackedGroup(out, pc.table.layout, pc.rng, pc.paddingThreshold, pc.padPool, group)
|
||||
}
|
||||
|
||||
effective++
|
||||
@@ -148,19 +138,49 @@ func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
|
||||
return out, limit
|
||||
}
|
||||
|
||||
func appendPackedGroup(out []byte, layout *byteLayout, rng *sudokuRand, paddingThreshold uint64, padPool []byte, group byte) []byte {
|
||||
if paddingThreshold != 0 {
|
||||
u := rng.Uint32()
|
||||
if uint64(u) < paddingThreshold {
|
||||
out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))])
|
||||
}
|
||||
}
|
||||
return append(out, layout.encodeGroup[group&0x3F])
|
||||
}
|
||||
|
||||
func maybeAppendPackedPadding(out []byte, rng *sudokuRand, paddingThreshold uint64, padPool []byte) []byte {
|
||||
if paddingThreshold != 0 {
|
||||
u := rng.Uint32()
|
||||
if uint64(u) < paddingThreshold {
|
||||
out = append(out, padPool[fastIntnFromUint32(rng.Uint32(), len(padPool))])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
pc.writeMu.Lock()
|
||||
defer pc.writeMu.Unlock()
|
||||
|
||||
needed := len(p)*3/2 + 32
|
||||
if pc.paddingThreshold == 0 {
|
||||
needed = ((len(p)+2)/3)*4 + 32
|
||||
}
|
||||
if cap(pc.writeBuf) < needed {
|
||||
pc.writeBuf = make([]byte, 0, needed)
|
||||
}
|
||||
out := pc.writeBuf[:0]
|
||||
layout := pc.table.layout
|
||||
rng := pc.rng
|
||||
paddingThreshold := pc.paddingThreshold
|
||||
padPool := pc.padPool
|
||||
|
||||
var prefixN int
|
||||
out, prefixN = pc.writeProtectedPrefix(out, p)
|
||||
@@ -181,7 +201,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,10 +215,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
|
||||
g4 := b3 & 0x3F
|
||||
|
||||
out = pc.appendGroup(out, g1)
|
||||
out = pc.appendGroup(out, g2)
|
||||
out = pc.appendGroup(out, g3)
|
||||
out = pc.appendGroup(out, g4)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,10 +231,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
|
||||
g4 := b3 & 0x3F
|
||||
|
||||
out = pc.appendGroup(out, g1)
|
||||
out = pc.appendGroup(out, g2)
|
||||
out = pc.appendGroup(out, g3)
|
||||
out = pc.appendGroup(out, g4)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g1)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g2)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g3)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, g4)
|
||||
}
|
||||
|
||||
for ; i < n; i++ {
|
||||
@@ -229,7 +249,7 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,11 +257,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
group := byte(pc.bitBuf << (6 - pc.bitCount))
|
||||
pc.bitBuf = 0
|
||||
pc.bitCount = 0
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = appendPackedGroup(out, layout, rng, paddingThreshold, padPool, group)
|
||||
out = append(out, pc.padMarker)
|
||||
}
|
||||
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = maybeAppendPackedPadding(out, rng, paddingThreshold, padPool)
|
||||
|
||||
if len(out) > 0 {
|
||||
pc.writeBuf = out[:0]
|
||||
@@ -252,6 +272,10 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (pc *PackedConn) Flush() error {
|
||||
if pc == nil || pc.Conn == nil || pc.table == nil || pc.table.layout == nil || pc.rng == nil || len(pc.padPool) == 0 {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
|
||||
pc.writeMu.Lock()
|
||||
defer pc.writeMu.Unlock()
|
||||
|
||||
@@ -265,7 +289,7 @@ func (pc *PackedConn) Flush() error {
|
||||
out = append(out, pc.padMarker)
|
||||
}
|
||||
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = maybeAppendPackedPadding(out, pc.rng, pc.paddingThreshold, pc.padPool)
|
||||
|
||||
if len(out) > 0 {
|
||||
pc.writeBuf = out[:0]
|
||||
@@ -289,19 +313,44 @@ func writeFull(w io.Writer, b []byte) error {
|
||||
}
|
||||
|
||||
func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if pc == nil || pc.Conn == nil || pc.reader == nil || len(pc.rawBuf) == 0 || pc.table == nil || pc.table.layout == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if n, ok := drainPending(p, &pc.pendingData); ok {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
outN := 0
|
||||
for {
|
||||
nr, rErr := pc.reader.Read(pc.rawBuf)
|
||||
nr, rErr := readRawLimited(pc.Conn, pc.reader, pc.rawBuf[:packedReadSize(len(p)-outN, len(pc.rawBuf))])
|
||||
if nr > 0 {
|
||||
rBuf := pc.readBitBuf
|
||||
rBits := pc.readBits
|
||||
padMarker := pc.padMarker
|
||||
layout := pc.table.layout
|
||||
|
||||
for _, b := range pc.rawBuf[:nr] {
|
||||
chunk := pc.rawBuf[:nr]
|
||||
for i := 0; i < len(chunk); {
|
||||
if rBits == 0 && outN+3 <= len(p) && i+3 < len(chunk) &&
|
||||
layout.hintTable[chunk[i]] && layout.hintTable[chunk[i+1]] &&
|
||||
layout.hintTable[chunk[i+2]] && layout.hintTable[chunk[i+3]] {
|
||||
g1 := layout.decodeGroup[chunk[i]]
|
||||
g2 := layout.decodeGroup[chunk[i+1]]
|
||||
g3 := layout.decodeGroup[chunk[i+2]]
|
||||
g4 := layout.decodeGroup[chunk[i+3]]
|
||||
p[outN] = (g1 << 2) | (g2 >> 4)
|
||||
p[outN+1] = (g2 << 4) | (g3 >> 2)
|
||||
p[outN+2] = (g3 << 6) | g4
|
||||
outN += 3
|
||||
i += 4
|
||||
continue
|
||||
}
|
||||
|
||||
b := chunk[i]
|
||||
i++
|
||||
if !layout.hintTable[b] {
|
||||
if b == padMarker {
|
||||
rBuf = 0
|
||||
@@ -321,7 +370,7 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
if rBits >= 8 {
|
||||
rBits -= 8
|
||||
val := byte(rBuf >> rBits)
|
||||
pc.pendingData.appendByte(val)
|
||||
outN = appendDecodedByte(p, outN, &pc.pendingData, val)
|
||||
if rBits == 0 {
|
||||
rBuf = 0
|
||||
} else {
|
||||
@@ -339,21 +388,32 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
pc.readBitBuf = 0
|
||||
pc.readBits = 0
|
||||
}
|
||||
if pc.pendingData.available() > 0 {
|
||||
break
|
||||
if outN > 0 {
|
||||
return outN, nil
|
||||
}
|
||||
if n, ok := drainPending(p, &pc.pendingData); ok {
|
||||
return n, nil
|
||||
}
|
||||
return 0, rErr
|
||||
}
|
||||
|
||||
if pc.pendingData.available() > 0 {
|
||||
break
|
||||
if outN > 0 {
|
||||
return outN, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, _ := drainPending(p, &pc.pendingData)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (pc *PackedConn) getPaddingByte() byte {
|
||||
return pc.padPool[pc.rng.Intn(len(pc.padPool))]
|
||||
}
|
||||
|
||||
func packedReadSize(decodedRemaining, maxRaw int) int {
|
||||
if maxRaw <= minDecodeReadSize || decodedRemaining <= 0 {
|
||||
return maxRaw
|
||||
}
|
||||
if decodedRemaining > (maxRaw-minDecodeReadSize)/2 {
|
||||
return maxRaw
|
||||
}
|
||||
|
||||
return decodedRemaining*2 + minDecodeReadSize
|
||||
}
|
||||
|
||||
90
transport/sudoku/obfs/sudoku/packed_prefix_test.go
Normal file
90
transport/sudoku/obfs/sudoku/packed_prefix_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
readBuf []byte
|
||||
writeBuf []byte
|
||||
}
|
||||
|
||||
func (c *mockConn) Read(p []byte) (int, error) {
|
||||
if len(c.readBuf) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Write(p []byte) (int, error) {
|
||||
c.writeBuf = append(c.writeBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error { return nil }
|
||||
func (c *mockConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *mockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *mockConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *mockConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *mockConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
func TestPackedConn_ProtectedPrefixPadding(t *testing.T) {
|
||||
table := NewTable("packed-prefix-seed", "prefer_ascii")
|
||||
mock := &mockConn{}
|
||||
writer := NewPackedConn(mock, table, 0, 0)
|
||||
writer.rng = newSudokuRand(1)
|
||||
|
||||
payload := bytes.Repeat([]byte{0}, 32)
|
||||
if _, err := writer.Write(payload); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
wire := append([]byte(nil), mock.writeBuf...)
|
||||
if len(wire) < 20 {
|
||||
t.Fatalf("wire too short: %d", len(wire))
|
||||
}
|
||||
|
||||
firstHint := -1
|
||||
nonHintCount := 0
|
||||
maxHintRun := 0
|
||||
currentHintRun := 0
|
||||
for i, b := range wire[:20] {
|
||||
if table.layout.isHint(b) {
|
||||
if firstHint == -1 {
|
||||
firstHint = i
|
||||
}
|
||||
currentHintRun++
|
||||
if currentHintRun > maxHintRun {
|
||||
maxHintRun = currentHintRun
|
||||
}
|
||||
continue
|
||||
}
|
||||
nonHintCount++
|
||||
currentHintRun = 0
|
||||
}
|
||||
|
||||
if firstHint < 1 || firstHint > 2 {
|
||||
t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint)
|
||||
}
|
||||
if nonHintCount < 6 {
|
||||
t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount)
|
||||
}
|
||||
if maxHintRun > 3 {
|
||||
t.Fatalf("prefix still exposes long hint run: %d", maxHintRun)
|
||||
}
|
||||
|
||||
reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0)
|
||||
decoded := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(reader, decoded); err != nil {
|
||||
t.Fatalf("read back: %v", err)
|
||||
}
|
||||
if !bytes.Equal(decoded, payload) {
|
||||
t.Fatalf("roundtrip mismatch")
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package sudoku
|
||||
|
||||
const probOne = uint64(1) << 32
|
||||
|
||||
func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 {
|
||||
func pickPaddingThreshold(r *sudokuRand, pMin, pMax int) uint64 {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
@@ -28,7 +28,7 @@ func pickPaddingThreshold(r randomSource, pMin, pMax int) uint64 {
|
||||
return min + (u * (max - min) >> 32)
|
||||
}
|
||||
|
||||
func shouldPad(r randomSource, threshold uint64) bool {
|
||||
func shouldPad(r *sudokuRand, threshold uint64) bool {
|
||||
if threshold == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -25,7 +25,10 @@ func (p *pendingBuffer) reset() {
|
||||
}
|
||||
|
||||
func (p *pendingBuffer) ensureAppendCapacity(extra int) {
|
||||
if p == nil || extra <= 0 || p.off == 0 {
|
||||
if p == nil || extra <= 0 {
|
||||
return
|
||||
}
|
||||
if p.off == 0 {
|
||||
return
|
||||
}
|
||||
if cap(p.data)-len(p.data) >= extra {
|
||||
@@ -43,6 +46,15 @@ func (p *pendingBuffer) appendByte(b byte) {
|
||||
p.data = append(p.data, b)
|
||||
}
|
||||
|
||||
func appendDecodedByte(dst []byte, n int, pending *pendingBuffer, b byte) int {
|
||||
if n < len(dst) {
|
||||
dst[n] = b
|
||||
return n + 1
|
||||
}
|
||||
pending.appendByte(b)
|
||||
return n
|
||||
}
|
||||
|
||||
func drainPending(dst []byte, pending *pendingBuffer) (int, bool) {
|
||||
if pending == nil || pending.available() == 0 {
|
||||
return 0, false
|
||||
|
||||
@@ -6,14 +6,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type randomSource interface {
|
||||
Uint32() uint32
|
||||
Uint64() uint64
|
||||
Intn(n int) int
|
||||
}
|
||||
|
||||
type sudokuRand struct {
|
||||
state uint64
|
||||
state uint64
|
||||
cached uint32
|
||||
haveCached bool
|
||||
}
|
||||
|
||||
func newSeededRand() *sudokuRand {
|
||||
@@ -37,20 +33,36 @@ func (r *sudokuRand) Uint64() uint64 {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
r.state += 0x9e3779b97f4a7c15
|
||||
z := r.state
|
||||
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9
|
||||
z = (z ^ (z >> 27)) * 0x94d049bb133111eb
|
||||
return z ^ (z >> 31)
|
||||
r.haveCached = false
|
||||
x := r.state
|
||||
x ^= x >> 12
|
||||
x ^= x << 25
|
||||
x ^= x >> 27
|
||||
r.state = x
|
||||
return x * 0x2545f4914f6cdd1d
|
||||
}
|
||||
|
||||
func (r *sudokuRand) Uint32() uint32 {
|
||||
return uint32(r.Uint64() >> 32)
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
if r.haveCached {
|
||||
r.haveCached = false
|
||||
return r.cached
|
||||
}
|
||||
v := r.Uint64()
|
||||
r.cached = uint32(v)
|
||||
r.haveCached = true
|
||||
return uint32(v >> 32)
|
||||
}
|
||||
|
||||
func (r *sudokuRand) Intn(n int) int {
|
||||
if n <= 1 {
|
||||
return 0
|
||||
}
|
||||
return int((uint64(r.Uint32()) * uint64(n)) >> 32)
|
||||
return fastIntnFromUint32(r.Uint32(), n)
|
||||
}
|
||||
|
||||
func fastIntnFromUint32(u uint32, n int) int {
|
||||
return int((uint64(u) * uint64(n)) >> 32)
|
||||
}
|
||||
|
||||
@@ -192,23 +192,27 @@ func tableHintFingerprint(key string, mode string, uplinkPattern string, downlin
|
||||
}
|
||||
|
||||
func packHintsToKey(hints [4]byte) uint32 {
|
||||
return packHintBytes(hints[0], hints[1], hints[2], hints[3])
|
||||
}
|
||||
|
||||
func packHintBytes(h0, h1, h2, h3 byte) uint32 {
|
||||
// Sorting network for 4 elements (Bubble sort unrolled)
|
||||
// Swap if a > b
|
||||
if hints[0] > hints[1] {
|
||||
hints[0], hints[1] = hints[1], hints[0]
|
||||
if h0 > h1 {
|
||||
h0, h1 = h1, h0
|
||||
}
|
||||
if hints[2] > hints[3] {
|
||||
hints[2], hints[3] = hints[3], hints[2]
|
||||
if h2 > h3 {
|
||||
h2, h3 = h3, h2
|
||||
}
|
||||
if hints[0] > hints[2] {
|
||||
hints[0], hints[2] = hints[2], hints[0]
|
||||
if h0 > h2 {
|
||||
h0, h2 = h2, h0
|
||||
}
|
||||
if hints[1] > hints[3] {
|
||||
hints[1], hints[3] = hints[3], hints[1]
|
||||
if h1 > h3 {
|
||||
h1, h3 = h3, h1
|
||||
}
|
||||
if hints[1] > hints[2] {
|
||||
hints[1], hints[2] = hints[2], hints[1]
|
||||
if h1 > h2 {
|
||||
h1, h2 = h2, h1
|
||||
}
|
||||
|
||||
return uint32(hints[0])<<24 | uint32(hints[1])<<16 | uint32(hints[2])<<8 | uint32(hints[3])
|
||||
return uint32(h0)<<24 | uint32(h1)<<16 | uint32(h2)<<8 | uint32(h3)
|
||||
}
|
||||
|
||||
@@ -14,12 +14,14 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/common/congestion"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
@@ -50,7 +52,7 @@ type ClientOptions struct {
|
||||
QUIC bool
|
||||
CongestionControl string
|
||||
CWND int
|
||||
BBRProfile string
|
||||
Logger logger.Logger
|
||||
HealthCheck bool
|
||||
MaxConnections int
|
||||
MinStreams int
|
||||
@@ -81,7 +83,7 @@ func NewClient(ctx context.Context, options ClientOptions) (*Client, error) {
|
||||
healthCheck: options.HealthCheck,
|
||||
}
|
||||
if options.QUIC {
|
||||
congestionControlFactory, err := NewCongestionControl(options.CongestionControl, options.CWND, options.BBRProfile, ntp.TimeFuncFromContext(ctx))
|
||||
congestionControlFactory, err := congestion.NewCongestionControl(options.CongestionControl, options.CWND, ntp.TimeFuncFromContext(ctx))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package trusttunnel
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
"github.com/sagernet/sing-quic/congestion_bbr1"
|
||||
"github.com/sagernet/sing-quic/congestion_bbr2"
|
||||
congestion_meta1 "github.com/sagernet/sing-quic/congestion_meta1"
|
||||
congestion_meta2 "github.com/sagernet/sing-quic/congestion_meta2"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func NewCongestionControl(name string, cwnd int, bbrProfile string, timeFunc func() time.Time) (func(conn *quic.Conn) congestion.CongestionControl, error) {
|
||||
if timeFunc == nil {
|
||||
timeFunc = time.Now
|
||||
}
|
||||
if cwnd == 0 {
|
||||
cwnd = 32
|
||||
}
|
||||
switch name {
|
||||
case "", "bbr":
|
||||
return func(conn *quic.Conn) congestion.CongestionControl {
|
||||
return congestion_meta2.NewBbrSender(
|
||||
congestion_meta2.DefaultClock{TimeFunc: timeFunc},
|
||||
congestion.ByteCount(conn.Config().InitialPacketSize),
|
||||
congestion.ByteCount(cwnd)*congestion.ByteCount(conn.Config().InitialPacketSize),
|
||||
)
|
||||
}, nil
|
||||
case "bbr_standard":
|
||||
return func(conn *quic.Conn) congestion.CongestionControl {
|
||||
return congestion_bbr1.NewBbrSender(
|
||||
congestion_bbr1.DefaultClock{TimeFunc: timeFunc},
|
||||
congestion.ByteCount(conn.Config().InitialPacketSize),
|
||||
congestion_bbr1.InitialCongestionWindowPackets,
|
||||
congestion_bbr1.MaxCongestionWindowPackets,
|
||||
)
|
||||
}, nil
|
||||
case "bbr2":
|
||||
return func(conn *quic.Conn) congestion.CongestionControl {
|
||||
return congestion_bbr2.NewBBR2Sender(
|
||||
congestion_bbr2.DefaultClock{TimeFunc: timeFunc},
|
||||
congestion.ByteCount(conn.Config().InitialPacketSize),
|
||||
0,
|
||||
false,
|
||||
)
|
||||
}, nil
|
||||
case "bbr2_variant":
|
||||
return func(conn *quic.Conn) congestion.CongestionControl {
|
||||
return congestion_bbr2.NewBBR2Sender(
|
||||
congestion_bbr2.DefaultClock{TimeFunc: timeFunc},
|
||||
congestion.ByteCount(conn.Config().InitialPacketSize),
|
||||
32*congestion.ByteCount(conn.Config().InitialPacketSize),
|
||||
true,
|
||||
)
|
||||
}, nil
|
||||
case "cubic":
|
||||
return func(conn *quic.Conn) congestion.CongestionControl {
|
||||
return congestion_meta1.NewCubicSender(
|
||||
congestion_meta1.DefaultClock{TimeFunc: timeFunc},
|
||||
congestion.ByteCount(conn.Config().InitialPacketSize),
|
||||
false,
|
||||
)
|
||||
}, nil
|
||||
case "reno":
|
||||
return func(conn *quic.Conn) congestion.CongestionControl {
|
||||
return congestion_meta1.NewCubicSender(
|
||||
congestion_meta1.DefaultClock{TimeFunc: timeFunc},
|
||||
congestion.ByteCount(conn.Config().InitialPacketSize),
|
||||
true,
|
||||
)
|
||||
}, nil
|
||||
default:
|
||||
return nil, E.New("unknown congestion control: ", name)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package v2raygrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
@@ -13,13 +14,21 @@ type GunService interface {
|
||||
}
|
||||
|
||||
func ServerDesc(name string) grpc.ServiceDesc {
|
||||
serviceName := name
|
||||
streamName := "Tun"
|
||||
if strings.Contains(name, "/") {
|
||||
name = strings.TrimPrefix(name, "/")
|
||||
lastSlash := strings.LastIndex(name, "/")
|
||||
serviceName = name[:lastSlash]
|
||||
streamName = name[lastSlash+1:]
|
||||
}
|
||||
return grpc.ServiceDesc{
|
||||
ServiceName: name,
|
||||
ServiceName: serviceName,
|
||||
HandlerType: (*GunServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Tun",
|
||||
StreamName: streamName,
|
||||
Handler: _GunService_Tun_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
@@ -30,7 +39,11 @@ func ServerDesc(name string) grpc.ServiceDesc {
|
||||
}
|
||||
|
||||
func (c *gunServiceClient) TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GunService_TunClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], "/"+name+"/Tun", opts...)
|
||||
path := "/" + name + "/Tun"
|
||||
if strings.Contains(name, "/") {
|
||||
path = name
|
||||
}
|
||||
stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], path, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -53,10 +53,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
DisableCompression: true,
|
||||
},
|
||||
url: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: serverAddr.String(),
|
||||
Path: "/" + options.ServiceName + "/Tun",
|
||||
RawPath: "/" + url.PathEscape(options.ServiceName) + "/Tun",
|
||||
Scheme: "https",
|
||||
Host: serverAddr.String(),
|
||||
Path: grpcPath(options.ServiceName),
|
||||
},
|
||||
host: host,
|
||||
}
|
||||
|
||||
10
transport/v2raygrpclite/path.go
Normal file
10
transport/v2raygrpclite/path.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package v2raygrpclite
|
||||
|
||||
import "strings"
|
||||
|
||||
func grpcPath(serviceName string) string {
|
||||
if strings.Contains(serviceName, "/") {
|
||||
return serviceName
|
||||
}
|
||||
return "/" + serviceName + "/Tun"
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
|
||||
tlsConfig: tlsConfig,
|
||||
logger: logger,
|
||||
handler: handler,
|
||||
path: "/" + options.ServiceName + "/Tun",
|
||||
path: grpcPath(options.ServiceName),
|
||||
h2Server: &http2.Server{
|
||||
IdleTimeout: time.Duration(options.IdleTimeout),
|
||||
},
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package v2raykcp
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/common/list"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type SendingWindow struct {
|
||||
cache *list.List
|
||||
cache *list.List[*DataSegment]
|
||||
totalInFlightSize uint32
|
||||
writer SegmentWriter
|
||||
onPacketLoss func(uint32)
|
||||
@@ -16,7 +16,7 @@ type SendingWindow struct {
|
||||
|
||||
func NewSendingWindow(writer SegmentWriter, onPacketLoss func(uint32)) *SendingWindow {
|
||||
return &SendingWindow{
|
||||
cache: list.New(),
|
||||
cache: list.New[*DataSegment](),
|
||||
writer: writer,
|
||||
onPacketLoss: onPacketLoss,
|
||||
}
|
||||
@@ -27,9 +27,9 @@ func (sw *SendingWindow) Release() {
|
||||
return
|
||||
}
|
||||
for sw.cache.Len() > 0 {
|
||||
seg := sw.cache.Front().Value.(*DataSegment)
|
||||
seg := sw.cache.Front().Value
|
||||
seg.Release()
|
||||
sw.cache.Remove(sw.cache.Front())
|
||||
sw.cache.Front().Remove()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,17 +50,17 @@ func (sw *SendingWindow) Push(number uint32, b *buf.Buffer) {
|
||||
}
|
||||
|
||||
func (sw *SendingWindow) FirstNumber() uint32 {
|
||||
return sw.cache.Front().Value.(*DataSegment).Number
|
||||
return sw.cache.Front().Value.Number
|
||||
}
|
||||
|
||||
func (sw *SendingWindow) Clear(una uint32) {
|
||||
for !sw.IsEmpty() {
|
||||
seg := sw.cache.Front().Value.(*DataSegment)
|
||||
seg := sw.cache.Front().Value
|
||||
if seg.Number >= una {
|
||||
break
|
||||
}
|
||||
seg.Release()
|
||||
sw.cache.Remove(sw.cache.Front())
|
||||
sw.cache.Front().Remove()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,8 +87,7 @@ func (sw *SendingWindow) Visit(visitor func(seg *DataSegment) bool) {
|
||||
}
|
||||
|
||||
for e := sw.cache.Front(); e != nil; e = e.Next() {
|
||||
seg := e.Value.(*DataSegment)
|
||||
if !visitor(seg) {
|
||||
if !visitor(e.Value) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -132,7 +131,7 @@ func (sw *SendingWindow) Remove(number uint32) bool {
|
||||
}
|
||||
|
||||
for e := sw.cache.Front(); e != nil; e = e.Next() {
|
||||
seg := e.Value.(*DataSegment)
|
||||
seg := e.Value
|
||||
if seg.Number > number {
|
||||
return false
|
||||
} else if seg.Number == number {
|
||||
@@ -140,7 +139,7 @@ func (sw *SendingWindow) Remove(number uint32) bool {
|
||||
sw.totalInFlightSize--
|
||||
}
|
||||
seg.Release()
|
||||
sw.cache.Remove(e)
|
||||
e.Remove()
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/congestion"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
"github.com/sagernet/sing-box/common/xray/buf"
|
||||
"github.com/sagernet/sing-box/common/xray/net"
|
||||
"github.com/sagernet/sing-box/common/xray/pipe"
|
||||
"github.com/sagernet/sing-box/common/xray/signal/done"
|
||||
"github.com/sagernet/sing-box/common/xray/uuid"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
qtls "github.com/sagernet/sing-quic"
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
sHTTP "github.com/sagernet/sing/protocol/http"
|
||||
"github.com/sagernet/sing/service"
|
||||
"golang.org/x/net/http2"
|
||||
@@ -42,15 +43,22 @@ type Client struct {
|
||||
baseRequestURL2 url.URL
|
||||
getHTTPClient func() (DialerClient, *XmuxClient)
|
||||
getHTTPClient2 func() (DialerClient, *XmuxClient)
|
||||
xmuxManager *XmuxManager
|
||||
xmuxManager2 *XmuxManager
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayXHTTPOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
|
||||
if options.Mode == "" {
|
||||
return nil, E.New("mode is not set")
|
||||
}
|
||||
if tlsConfig != nil && len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{"h2"})
|
||||
}
|
||||
if _, err := congestion.NewCongestionControl(options.CongestionController, options.CWND, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if options.Download != nil {
|
||||
if _, err := congestion.NewCongestionControl(options.Download.CongestionController, options.Download.CWND, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
dest := serverAddr
|
||||
baseRequestURL, err := getBaseRequestURL(&options.V2RayXHTTPBaseOptions, dest, tlsConfig)
|
||||
if err != nil {
|
||||
@@ -61,7 +69,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
|
||||
xmuxOptions = *options.Xmux
|
||||
}
|
||||
xmuxManager := NewXmuxManager(xmuxOptions, func() XmuxConn {
|
||||
return createHTTPClient(dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig)
|
||||
return createHTTPClient(ctx, dest, dialer, &options.V2RayXHTTPBaseOptions, tlsConfig)
|
||||
})
|
||||
getHTTPClient := func() (DialerClient, *XmuxClient) {
|
||||
xmuxClient := xmuxManager.GetXmuxClient(ctx)
|
||||
@@ -69,6 +77,7 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
|
||||
}
|
||||
baseRequestURL2 := baseRequestURL
|
||||
getHTTPClient2 := getHTTPClient
|
||||
var xmuxManager2 *XmuxManager
|
||||
if options.Download != nil {
|
||||
options2 := options.Download
|
||||
dialer2 := dialer
|
||||
@@ -98,8 +107,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
|
||||
if options2.Xmux != nil {
|
||||
xmuxOptions2 = *options2.Xmux
|
||||
}
|
||||
xmuxManager2 := NewXmuxManager(xmuxOptions2, func() XmuxConn {
|
||||
return createHTTPClient(dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2)
|
||||
xmuxManager2 = NewXmuxManager(xmuxOptions2, func() XmuxConn {
|
||||
return createHTTPClient(ctx, dest2, dialer2, &options2.V2RayXHTTPBaseOptions, tlsConfig2)
|
||||
})
|
||||
getHTTPClient2 = func() (DialerClient, *XmuxClient) {
|
||||
xmuxClient2 := xmuxManager2.GetXmuxClient(ctx)
|
||||
@@ -113,6 +122,8 @@ func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, s
|
||||
getHTTPClient2: getHTTPClient2,
|
||||
baseRequestURL: baseRequestURL,
|
||||
baseRequestURL2: baseRequestURL2,
|
||||
xmuxManager: xmuxManager,
|
||||
xmuxManager2: xmuxManager2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -121,8 +132,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
mode := c.options.Mode
|
||||
sessionId := ""
|
||||
if c.options.Mode != "stream-one" {
|
||||
sessionIdUuid := uuid.New()
|
||||
sessionId = sessionIdUuid.String()
|
||||
sessionId = GenerateSessionID(&c.options.V2RayXHTTPBaseOptions)
|
||||
}
|
||||
requestURL := c.baseRequestURL
|
||||
requestURL2 := c.baseRequestURL2
|
||||
@@ -182,10 +192,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
}
|
||||
scMaxEachPostBytes := options.GetNormalizedScMaxEachPostBytes()
|
||||
scMinPostsIntervalMs := options.GetNormalizedScMinPostsIntervalMs()
|
||||
if scMaxEachPostBytes.From <= 0 {
|
||||
panic("`scMaxEachPostBytes` should be bigger than 0")
|
||||
}
|
||||
maxUploadSize := scMaxEachPostBytes.Rand()
|
||||
maxUploadSize := int32(scMaxEachPostBytes.Rand())
|
||||
// WithSizeLimit(0) will still allow single bytes to pass, and a lot of
|
||||
// code relies on this behavior. Subtract 1 so that together with
|
||||
// uploadWriter wrapper, exact size limits can be enforced
|
||||
@@ -255,6 +262,10 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.xmuxManager.Close()
|
||||
if c.xmuxManager2 != nil {
|
||||
c.xmuxManager2.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -294,7 +305,7 @@ func getBaseRequestURL(options *option.V2RayXHTTPBaseOptions, dest M.Socksaddr,
|
||||
return requestURL, nil
|
||||
}
|
||||
|
||||
func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient {
|
||||
func createHTTPClient(ctx context.Context, dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXHTTPBaseOptions, tlsConfig tls.Config) DialerClient {
|
||||
httpVersion := decideHTTPVersion(tlsConfig)
|
||||
dialContext := func(ctxInner context.Context) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctxInner, "tcp", dest)
|
||||
@@ -319,6 +330,7 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
|
||||
if keepAlivePeriod < 0 {
|
||||
keepAlivePeriod = 0
|
||||
}
|
||||
congestionControlFactory, _ := congestion.NewCongestionControl(options.CongestionController, options.CWND, ntp.TimeFuncFromContext(ctx))
|
||||
quicConfig := &quic.Config{
|
||||
MaxIdleTimeout: net.ConnIdleTimeout,
|
||||
// these two are defaults of quic-go/http3. the default of quic-go (no
|
||||
@@ -334,7 +346,14 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
|
||||
if dErr != nil {
|
||||
return nil, dErr
|
||||
}
|
||||
return qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg)
|
||||
conn, dErr := qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), tlsConfig, cfg)
|
||||
if dErr != nil {
|
||||
return nil, dErr
|
||||
}
|
||||
if congestionControlFactory != nil {
|
||||
conn.SetCongestionControl(congestionControlFactory(conn))
|
||||
}
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
case "2":
|
||||
|
||||
@@ -39,7 +39,7 @@ func (c *splitConn) Close() error {
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
return err
|
||||
return err2
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -147,7 +147,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio
|
||||
if c.httpVersion != "1.1" {
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
c.closed = true
|
||||
c.Close()
|
||||
return err
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
@@ -225,10 +225,9 @@ func (w *WaitReadCloser) Set(rc io.ReadCloser) {
|
||||
}
|
||||
|
||||
func (w *WaitReadCloser) Read(b []byte) (int, error) {
|
||||
<-w.Wait
|
||||
if w.ReadCloser == nil {
|
||||
if <-w.Wait; w.ReadCloser == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return w.ReadCloser.Read(b)
|
||||
}
|
||||
|
||||
@@ -19,8 +19,8 @@ type XmuxConn interface {
|
||||
|
||||
type XmuxClient struct {
|
||||
XmuxConn XmuxConn
|
||||
openUsage int32
|
||||
leftUsage int32
|
||||
openUsage int
|
||||
leftUsage int
|
||||
LeftRequests atomic.Int32
|
||||
UnreusableAt time.Time
|
||||
|
||||
@@ -37,7 +37,7 @@ func (c *XmuxClient) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *XmuxClient) AddOpenUsage(delta int32) {
|
||||
func (c *XmuxClient) AddOpenUsage(delta int) {
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
c.openUsage += delta
|
||||
@@ -46,7 +46,7 @@ func (c *XmuxClient) AddOpenUsage(delta int32) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *XmuxClient) GetOpenUsage() int32 {
|
||||
func (c *XmuxClient) GetOpenUsage() int {
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
return c.openUsage
|
||||
@@ -54,8 +54,8 @@ func (c *XmuxClient) GetOpenUsage() int32 {
|
||||
|
||||
type XmuxManager struct {
|
||||
options option.V2RayXHTTPXmuxOptions
|
||||
concurrency int32
|
||||
connections int32
|
||||
concurrency int
|
||||
connections int
|
||||
newConnFunc func() XmuxConn
|
||||
xmuxClients []*XmuxClient
|
||||
mtx sync.Mutex
|
||||
@@ -71,6 +71,15 @@ func NewXmuxManager(options option.V2RayXHTTPXmuxOptions, newConnFunc func() Xmu
|
||||
}
|
||||
}
|
||||
|
||||
func (m *XmuxManager) Close() {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
for _, xmuxClient := range m.xmuxClients {
|
||||
xmuxClient.Close()
|
||||
}
|
||||
m.xmuxClients = m.xmuxClients[:0]
|
||||
}
|
||||
|
||||
func (m *XmuxManager) newXmuxClient() *XmuxClient {
|
||||
xmuxClient := &XmuxClient{
|
||||
XmuxConn: m.newConnFunc(),
|
||||
@@ -81,7 +90,7 @@ func (m *XmuxManager) newXmuxClient() *XmuxClient {
|
||||
}
|
||||
xmuxClient.LeftRequests.Store(math.MaxInt32)
|
||||
if x := m.options.GetNormalizedHMaxRequestTimes().Rand(); x > 0 {
|
||||
xmuxClient.LeftRequests.Store(x)
|
||||
xmuxClient.LeftRequests.Store(int32(x))
|
||||
}
|
||||
if x := m.options.GetNormalizedHMaxReusableSecs().Rand(); x > 0 {
|
||||
xmuxClient.UnreusableAt = time.Now().Add(time.Duration(x) * time.Second)
|
||||
@@ -112,7 +121,7 @@ func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient {
|
||||
if len(m.xmuxClients) == 0 {
|
||||
return m.newXmuxClient()
|
||||
}
|
||||
if m.connections > 0 && len(m.xmuxClients) < int(m.connections) {
|
||||
if m.connections > 0 && len(m.xmuxClients) < m.connections {
|
||||
return m.newXmuxClient()
|
||||
}
|
||||
xmuxClients := make([]*XmuxClient, 0)
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/congestion"
|
||||
"github.com/sagernet/sing-box/common/kmutex"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
"github.com/sagernet/sing-box/common/xray/buf"
|
||||
xnet "github.com/sagernet/sing-box/common/xray/net"
|
||||
@@ -31,6 +33,7 @@ import (
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
sHttp "github.com/sagernet/sing/protocol/http"
|
||||
)
|
||||
@@ -49,7 +52,7 @@ type Server struct {
|
||||
options *option.V2RayXHTTPOptions
|
||||
host string
|
||||
path string
|
||||
sessionMu sync.Mutex
|
||||
sessionMu *kmutex.Kmutex[string]
|
||||
sessions sync.Map
|
||||
}
|
||||
|
||||
@@ -62,6 +65,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
|
||||
options: &options,
|
||||
host: options.Host,
|
||||
path: options.GetNormalizedPath(),
|
||||
sessionMu: kmutex.New[string](),
|
||||
}
|
||||
if server.network() == N.NetworkTCP {
|
||||
protocols := new(http.Protocols)
|
||||
@@ -80,11 +84,21 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
|
||||
},
|
||||
}
|
||||
} else {
|
||||
congestionControlFactory, err := congestion.NewCongestionControl(options.CongestionController, options.CWND, ntp.TimeFuncFromContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.quicConfig = &quic.Config{
|
||||
DisablePathMTUDiscovery: !C.IsLinux && !C.IsWindows,
|
||||
}
|
||||
server.http3Server = &http3.Server{
|
||||
Handler: server,
|
||||
ConnContext: func(ctx context.Context, conn *quic.Conn) context.Context {
|
||||
if congestionControlFactory != nil {
|
||||
conn.SetCongestionControl(congestionControlFactory(conn))
|
||||
}
|
||||
return log.ContextWithNewID(ctx)
|
||||
},
|
||||
}
|
||||
}
|
||||
return server, nil
|
||||
@@ -102,7 +116,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
return
|
||||
}
|
||||
WriteResponseHeader(writer, request.Method, request.Header, s.options)
|
||||
length := int(s.options.GetNormalizedXPaddingBytes().Rand())
|
||||
length := s.options.GetNormalizedXPaddingBytes().Rand()
|
||||
config := XPaddingConfig{Length: length}
|
||||
if s.options.XPaddingObfsMode {
|
||||
config.Placement = XPaddingPlacement{
|
||||
@@ -125,15 +139,25 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
validRange := s.options.GetNormalizedXPaddingBytes()
|
||||
paddingValue, paddingPlacement := ExtractXPaddingFromRequest(&s.options.V2RayXHTTPBaseOptions, request, s.options.XPaddingObfsMode)
|
||||
if !IsPaddingValid(&s.options.V2RayXHTTPBaseOptions, paddingValue, validRange.From, validRange.To, PaddingMethod(s.options.XPaddingMethod)) {
|
||||
s.logger.ErrorContext(request.Context(), "invalid padding ("+paddingPlacement+") length:", int32(len(paddingValue)))
|
||||
s.logger.ErrorContext(request.Context(), "invalid padding ("+paddingPlacement+") length:", len(paddingValue))
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionId, seqStr := ExtractMetaFromRequest(s.options, request, s.path)
|
||||
if sessionId == "" && s.options.Mode != "" && s.options.Mode != "auto" && s.options.Mode != "stream-one" && s.options.Mode != "stream-up" {
|
||||
s.logger.ErrorContext(request.Context(), "stream-one mode is not allowed")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
if s.options.Mode != "" && s.options.Mode != "auto" {
|
||||
if sessionId == "" {
|
||||
if s.options.Mode != "stream-one" && s.options.Mode != "stream-up" {
|
||||
s.logger.ErrorContext(request.Context(), "stream-one mode is not allowed")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if s.options.Mode == "stream-one" {
|
||||
s.logger.ErrorContext(request.Context(), "session is not allowed in stream-one mode")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
var forwardedAddrs []xnet.Address
|
||||
if len(s.options.TrustedXForwardedFor) > 0 {
|
||||
@@ -171,7 +195,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
if sessionId != "" {
|
||||
currentSession = s.upsertSession(sessionId)
|
||||
}
|
||||
scMaxEachPostBytes := int(s.options.GetNormalizedScMaxEachPostBytes().To)
|
||||
scMaxEachPostBytes := s.options.GetNormalizedScMaxEachPostBytes().To
|
||||
uplinkDataPlacement := s.options.GetNormalizedUplinkDataPlacement()
|
||||
uplinkDataKey := s.options.UplinkDataKey
|
||||
isUplinkRequest := false
|
||||
@@ -207,12 +231,22 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
referrer := request.Header.Get("Referer")
|
||||
if referrer != "" && scStreamUpServerSecs.To > 0 {
|
||||
go func() {
|
||||
timer := time.NewTimer(0)
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
defer timer.Stop()
|
||||
for {
|
||||
_, err := httpSC.Write(bytes.Repeat([]byte{'X'}, int(s.options.GetNormalizedXPaddingBytes().Rand())))
|
||||
_, err := httpSC.Write(bytes.Repeat([]byte{'X'}, s.options.GetNormalizedXPaddingBytes().Rand()))
|
||||
if err != nil {
|
||||
break
|
||||
return
|
||||
}
|
||||
timer.Reset(time.Duration(scStreamUpServerSecs.Rand()) * time.Second)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-httpSC.Wait():
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Duration(scStreamUpServerSecs.Rand()) * time.Second)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -327,7 +361,11 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
// after GET is done, the connection is finished. disable automatic
|
||||
// session reaping, and handle it in defer
|
||||
currentSession.isFullyConnected.Close()
|
||||
defer s.sessions.Delete(sessionId)
|
||||
defer func() {
|
||||
s.sessionMu.Lock(sessionId)
|
||||
defer s.sessionMu.Unlock(sessionId)
|
||||
s.sessions.Delete(sessionId)
|
||||
}()
|
||||
}
|
||||
// magic header instructs nginx + apache to not buffer response body
|
||||
writer.Header().Set("X-Accel-Buffering", "no")
|
||||
@@ -410,32 +448,27 @@ func (s *Server) network() string {
|
||||
}
|
||||
|
||||
func (s *Server) upsertSession(sessionId string) *httpSession {
|
||||
// fast path
|
||||
s.sessionMu.Lock(sessionId)
|
||||
defer s.sessionMu.Unlock(sessionId)
|
||||
currentSessionAny, ok := s.sessions.Load(sessionId)
|
||||
if ok {
|
||||
return currentSessionAny.(*httpSession)
|
||||
}
|
||||
// slow path
|
||||
s.sessionMu.Lock()
|
||||
defer s.sessionMu.Unlock()
|
||||
currentSessionAny, ok = s.sessions.Load(sessionId)
|
||||
if ok {
|
||||
return currentSessionAny.(*httpSession)
|
||||
}
|
||||
session := &httpSession{
|
||||
uploadQueue: NewUploadQueue(s.options.GetNormalizedScMaxBufferedPosts()),
|
||||
isFullyConnected: done.New(),
|
||||
}
|
||||
s.sessions.Store(sessionId, session)
|
||||
shouldReap := done.New()
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second)
|
||||
shouldReap.Close()
|
||||
}()
|
||||
go func() {
|
||||
reapTimer := time.NewTimer(30 * time.Second)
|
||||
defer reapTimer.Stop()
|
||||
select {
|
||||
case <-shouldReap.Wait():
|
||||
s.sessions.Delete(sessionId)
|
||||
case <-reapTimer.C:
|
||||
s.sessionMu.Lock(sessionId)
|
||||
if current, ok := s.sessions.Load(sessionId); ok && current.(*httpSession) == session {
|
||||
s.sessions.Delete(sessionId)
|
||||
}
|
||||
s.sessionMu.Unlock(sessionId)
|
||||
session.uploadQueue.Close()
|
||||
case <-session.isFullyConnected.Wait():
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ package xhttp
|
||||
import (
|
||||
"container/heap"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
@@ -19,19 +18,22 @@ type Packet struct {
|
||||
}
|
||||
|
||||
type uploadQueue struct {
|
||||
reader io.ReadCloser
|
||||
nomore bool
|
||||
pushedPackets chan Packet
|
||||
writeCloseMutex sync.Mutex
|
||||
heap uploadHeap
|
||||
nextSeq uint64
|
||||
closed bool
|
||||
maxPackets int
|
||||
reader io.ReadCloser
|
||||
nomore bool
|
||||
pushedPackets chan Packet
|
||||
done chan struct{}
|
||||
heap uploadHeap
|
||||
nextSeq uint64
|
||||
closed bool
|
||||
maxPackets int
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewUploadQueue(maxPackets int) *uploadQueue {
|
||||
return &uploadQueue{
|
||||
pushedPackets: make(chan Packet, maxPackets),
|
||||
done: make(chan struct{}),
|
||||
heap: uploadHeap{},
|
||||
nextSeq: 0,
|
||||
closed: false,
|
||||
@@ -40,63 +42,83 @@ func NewUploadQueue(maxPackets int) *uploadQueue {
|
||||
}
|
||||
|
||||
func (h *uploadQueue) Push(p Packet) error {
|
||||
h.writeCloseMutex.Lock()
|
||||
defer h.writeCloseMutex.Unlock()
|
||||
h.mtx.Lock()
|
||||
if h.closed {
|
||||
h.mtx.Unlock()
|
||||
return E.New("packet queue closed")
|
||||
}
|
||||
if h.nomore {
|
||||
h.mtx.Unlock()
|
||||
return E.New("h.reader already exists")
|
||||
}
|
||||
if p.Reader != nil {
|
||||
h.nomore = true
|
||||
}
|
||||
h.pushedPackets <- p
|
||||
return nil
|
||||
h.mtx.Unlock()
|
||||
select {
|
||||
case h.pushedPackets <- p:
|
||||
return nil
|
||||
case <-h.done:
|
||||
return E.New("packet queue closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *uploadQueue) Close() error {
|
||||
h.writeCloseMutex.Lock()
|
||||
defer h.writeCloseMutex.Unlock()
|
||||
if !h.closed {
|
||||
h.closed = true
|
||||
runtime.Gosched() // hope Read() gets the packet
|
||||
f:
|
||||
for {
|
||||
select {
|
||||
case p := <-h.pushedPackets:
|
||||
if p.Reader != nil {
|
||||
h.reader = p.Reader
|
||||
}
|
||||
default:
|
||||
break f
|
||||
h.mtx.Lock()
|
||||
if h.closed {
|
||||
h.mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
h.closed = true
|
||||
close(h.done)
|
||||
h.mtx.Unlock()
|
||||
|
||||
for {
|
||||
select {
|
||||
case p := <-h.pushedPackets:
|
||||
if p.Reader != nil {
|
||||
p.Reader.Close()
|
||||
}
|
||||
default:
|
||||
if h.reader != nil {
|
||||
return h.reader.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
close(h.pushedPackets)
|
||||
}
|
||||
if h.reader != nil {
|
||||
return h.reader.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *uploadQueue) Read(b []byte) (int, error) {
|
||||
h.mtx.Lock()
|
||||
if h.closed {
|
||||
h.mtx.Unlock()
|
||||
return 0, io.EOF
|
||||
}
|
||||
h.mtx.Unlock()
|
||||
if h.reader != nil {
|
||||
return h.reader.Read(b)
|
||||
}
|
||||
if h.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if len(h.heap) == 0 {
|
||||
packet, more := <-h.pushedPackets
|
||||
if !more {
|
||||
select {
|
||||
case packet, more := <-h.pushedPackets:
|
||||
if !more {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if packet.Reader != nil {
|
||||
h.mtx.Lock()
|
||||
if h.closed {
|
||||
packet.Reader.Close()
|
||||
h.mtx.Unlock()
|
||||
return 0, io.EOF
|
||||
}
|
||||
h.reader = packet.Reader
|
||||
h.mtx.Unlock()
|
||||
return h.reader.Read(b)
|
||||
}
|
||||
heap.Push(&h.heap, packet)
|
||||
case <-h.done:
|
||||
return 0, io.EOF
|
||||
}
|
||||
if packet.Reader != nil {
|
||||
h.reader = packet.Reader
|
||||
return h.reader.Read(b)
|
||||
}
|
||||
heap.Push(&h.heap, packet)
|
||||
}
|
||||
for len(h.heap) > 0 {
|
||||
packet := heap.Pop(&h.heap).(Packet)
|
||||
@@ -125,11 +147,15 @@ func (h *uploadQueue) Read(b []byte) (int, error) {
|
||||
return 0, E.New("packet queue is too large")
|
||||
}
|
||||
heap.Push(&h.heap, packet)
|
||||
packet2, more := <-h.pushedPackets
|
||||
if !more {
|
||||
select {
|
||||
case packet2, more := <-h.pushedPackets:
|
||||
if !more {
|
||||
return 0, io.EOF
|
||||
}
|
||||
heap.Push(&h.heap, packet2)
|
||||
case <-h.done:
|
||||
return 0, io.EOF
|
||||
}
|
||||
heap.Push(&h.heap, packet2)
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
|
||||
@@ -4,16 +4,49 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
|
||||
"github.com/sagernet/sing-box/common/xray/buf"
|
||||
"github.com/sagernet/sing-box/common/xray/utils"
|
||||
"github.com/sagernet/sing-box/common/xray/uuid"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
// PredefinedTable maps named charsets to their alphabets for session ID generation.
|
||||
|
||||
var PredefinedTable = map[string]string{
|
||||
"ALPHABET": "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||
"Alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
|
||||
"BASE36": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||
"Base62": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
|
||||
"HEX": "0123456789ABCDEF",
|
||||
"alphabet": "abcdefghijklmnopqrstuvwxyz",
|
||||
"base36": "0123456789abcdefghijklmnopqrstuvwxyz",
|
||||
"hex": "0123456789abcdef",
|
||||
"number": "0123456789",
|
||||
}
|
||||
|
||||
func GenerateSessionID(options *option.V2RayXHTTPBaseOptions) string {
|
||||
length := options.SessionIDLength.Rand()
|
||||
table := options.SessionIDTable
|
||||
if predefined, ok := PredefinedTable[table]; ok {
|
||||
table = predefined
|
||||
}
|
||||
if table != "" && length > 0 {
|
||||
id := make([]byte, length)
|
||||
for i := range id {
|
||||
id[i] = table[rand.N(len(table))]
|
||||
}
|
||||
return string(id)
|
||||
}
|
||||
newUUID := uuid.New()
|
||||
return newUUID.String()
|
||||
}
|
||||
|
||||
func FillStreamRequest(request *http.Request, sessionId string, seqStr string, options *option.V2RayXHTTPBaseOptions) {
|
||||
request.Header = options.GetRequestHeader()
|
||||
length := int(options.GetNormalizedXPaddingBytes().Rand())
|
||||
length := options.GetNormalizedXPaddingBytes().Rand()
|
||||
config := XPaddingConfig{Length: length}
|
||||
if options.XPaddingObfsMode {
|
||||
config.Placement = XPaddingPlacement{
|
||||
@@ -58,7 +91,7 @@ func FillPacketRequest(request *http.Request, sessionId string, seqStr string, p
|
||||
}
|
||||
}
|
||||
}
|
||||
length := int(options.GetNormalizedXPaddingBytes().Rand())
|
||||
length := options.GetNormalizedXPaddingBytes().Rand()
|
||||
config := XPaddingConfig{Length: length}
|
||||
if options.XPaddingObfsMode {
|
||||
config.Placement = XPaddingPlacement{
|
||||
@@ -125,7 +158,7 @@ func GetRequestHeaderWithPayload(payload []byte, options *option.V2RayXHTTPBaseO
|
||||
key := options.UplinkDataKey
|
||||
encodedData := base64.RawURLEncoding.EncodeToString(payload)
|
||||
for i := 0; len(encodedData) > 0; i++ {
|
||||
chunkSize := min(int(options.GetNormalizedUplinkChunkSize().Rand()), len(encodedData))
|
||||
chunkSize := min(options.GetNormalizedUplinkChunkSize().Rand(), len(encodedData))
|
||||
chunk := encodedData[:chunkSize]
|
||||
encodedData = encodedData[chunkSize:]
|
||||
headerKey := fmt.Sprintf("%s-%d", key, i)
|
||||
@@ -140,7 +173,7 @@ func GetRequestCookiesWithPayload(payload []byte, options *option.V2RayXHTTPBase
|
||||
key := options.UplinkDataKey
|
||||
encodedData := base64.RawURLEncoding.EncodeToString(payload)
|
||||
for i := 0; len(encodedData) > 0; i++ {
|
||||
chunkSize := min(int(options.GetNormalizedUplinkChunkSize().Rand()), len(encodedData))
|
||||
chunkSize := min(options.GetNormalizedUplinkChunkSize().Rand(), len(encodedData))
|
||||
chunk := encodedData[:chunkSize]
|
||||
encodedData = encodedData[chunkSize:]
|
||||
cookieName := fmt.Sprintf("%s_%d", key, i)
|
||||
|
||||
@@ -31,11 +31,12 @@ func (w uploadWriter) Write(b []byte) (int, error) {
|
||||
|
||||
var writed int
|
||||
for _, buff := range buffer.MultiBuffer {
|
||||
n := int(buff.Len())
|
||||
err := w.WriteMultiBuffer(buf.MultiBuffer{buff})
|
||||
if err != nil {
|
||||
return writed, err
|
||||
}
|
||||
writed += int(buff.Len())
|
||||
writed += n
|
||||
}
|
||||
return writed, nil
|
||||
}
|
||||
|
||||
@@ -264,7 +264,7 @@ func ExtractXPaddingFromRequest(options *option.V2RayXHTTPBaseOptions, req *http
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string, from, to int32, method PaddingMethod) bool {
|
||||
func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string, from, to int, method PaddingMethod) bool {
|
||||
if paddingValue == "" {
|
||||
return false
|
||||
}
|
||||
@@ -274,11 +274,11 @@ func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string,
|
||||
}
|
||||
switch method {
|
||||
case PaddingMethodRepeatX:
|
||||
n := int32(len(paddingValue))
|
||||
n := len(paddingValue)
|
||||
return n >= from && n <= to
|
||||
case PaddingMethodTokenish:
|
||||
const tolerance = int32(validationTolerance)
|
||||
n := int32(hpack.HuffmanEncodeLength(paddingValue))
|
||||
const tolerance = validationTolerance
|
||||
n := int(hpack.HuffmanEncodeLength(paddingValue))
|
||||
f := from - tolerance
|
||||
t := to + tolerance
|
||||
if f < 0 {
|
||||
@@ -286,7 +286,7 @@ func IsPaddingValid(options *option.V2RayXHTTPBaseOptions, paddingValue string,
|
||||
}
|
||||
return n >= f && n <= t
|
||||
default:
|
||||
n := int32(len(paddingValue))
|
||||
n := len(paddingValue)
|
||||
return n >= from && n <= to
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
Xbadoption "github.com/sagernet/sing-box/common/xray/json/badoption"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
tun "github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
@@ -49,10 +49,10 @@ type AmneziaOptions struct {
|
||||
S2 int
|
||||
S3 int
|
||||
S4 int
|
||||
H1 *Xbadoption.Range
|
||||
H2 *Xbadoption.Range
|
||||
H3 *Xbadoption.Range
|
||||
H4 *Xbadoption.Range
|
||||
H1 *badoption.Range[uint32]
|
||||
H2 *badoption.Range[uint32]
|
||||
H3 *badoption.Range[uint32]
|
||||
H4 *badoption.Range[uint32]
|
||||
I1 string
|
||||
I2 string
|
||||
I3 string
|
||||
|
||||
Reference in New Issue
Block a user