mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-08 12:14:54 +03:00
Improve multiplex
This commit is contained in:
@@ -28,9 +28,10 @@ type Client struct {
|
||||
maxConnections int
|
||||
minStreams int
|
||||
maxStreams int
|
||||
paddingEnabled bool
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client {
|
||||
func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int, paddingEnabled bool) (*Client, error) {
|
||||
return &Client{
|
||||
ctx: ctx,
|
||||
dialer: dialer,
|
||||
@@ -38,7 +39,8 @@ func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConne
|
||||
maxConnections: maxConnections,
|
||||
minStreams: minStreams,
|
||||
maxStreams: maxStreams,
|
||||
}
|
||||
paddingEnabled: paddingEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (*Client, error) {
|
||||
@@ -52,7 +54,7 @@ func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.M
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil
|
||||
return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams, options.Padding)
|
||||
}
|
||||
|
||||
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
@@ -145,10 +147,19 @@ func (c *Client) offerNew() (abstractSession, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if vectorisedWriter, isVectorised := bufio.CreateVectorisedWriter(conn); isVectorised {
|
||||
conn = &vectorisedProtocolConn{protocolConn{Conn: conn, protocol: c.protocol}, vectorisedWriter}
|
||||
var version byte
|
||||
if c.paddingEnabled {
|
||||
version = Version1
|
||||
} else {
|
||||
conn = &protocolConn{Conn: conn, protocol: c.protocol}
|
||||
version = Version0
|
||||
}
|
||||
conn = newProtocolConn(conn, Request{
|
||||
Version: version,
|
||||
Protocol: c.protocol,
|
||||
PaddingEnabled: c.paddingEnabled,
|
||||
})
|
||||
if c.paddingEnabled {
|
||||
conn = newPaddingConn(conn)
|
||||
}
|
||||
session, err := c.protocol.newClient(conn)
|
||||
if err != nil {
|
||||
@@ -213,7 +224,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
|
||||
Network: N.NetworkTCP,
|
||||
Destination: c.destination,
|
||||
}
|
||||
_buffer := buf.StackNewSize(requestLen(request) + len(b))
|
||||
_buffer := buf.StackNewSize(streamRequestLen(request) + len(b))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
@@ -307,7 +318,7 @@ func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
||||
Network: N.NetworkUDP,
|
||||
Destination: c.destination,
|
||||
}
|
||||
rLen := requestLen(request)
|
||||
rLen := streamRequestLen(request)
|
||||
if len(payload) > 0 {
|
||||
rLen += 2 + len(payload)
|
||||
}
|
||||
@@ -452,7 +463,7 @@ func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
|
||||
Destination: c.destination,
|
||||
PacketAddr: true,
|
||||
}
|
||||
rLen := requestLen(request)
|
||||
rLen := streamRequestLen(request)
|
||||
if len(payload) > 0 {
|
||||
rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload)
|
||||
}
|
||||
|
||||
240
common/mux/padding.go
Normal file
240
common/mux/padding.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
const kFirstPaddings = 16
|
||||
|
||||
type paddingConn struct {
|
||||
N.ExtendedConn
|
||||
writer N.VectorisedWriter
|
||||
readPadding int
|
||||
writePadding int
|
||||
readRemaining int
|
||||
paddingRemaining int
|
||||
}
|
||||
|
||||
func newPaddingConn(conn net.Conn) net.Conn {
|
||||
writer, isVectorised := bufio.CreateVectorisedWriter(conn)
|
||||
if isVectorised {
|
||||
return &vectorisedPaddingConn{
|
||||
paddingConn{
|
||||
ExtendedConn: bufio.NewExtendedConn(conn),
|
||||
writer: bufio.NewVectorisedWriter(conn),
|
||||
},
|
||||
writer,
|
||||
}
|
||||
} else {
|
||||
return &paddingConn{
|
||||
ExtendedConn: bufio.NewExtendedConn(conn),
|
||||
writer: bufio.NewVectorisedWriter(conn),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *paddingConn) Read(p []byte) (n int, err error) {
|
||||
if c.readRemaining > 0 {
|
||||
if len(p) > c.readRemaining {
|
||||
p = p[:c.readRemaining]
|
||||
}
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.readRemaining -= n
|
||||
return
|
||||
}
|
||||
if c.paddingRemaining > 0 {
|
||||
err = rw.SkipN(c.ExtendedConn, c.paddingRemaining)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.paddingRemaining = 0
|
||||
}
|
||||
if c.readPadding < kFirstPaddings {
|
||||
var paddingHdr []byte
|
||||
if len(p) >= 4 {
|
||||
paddingHdr = p[:4]
|
||||
} else {
|
||||
_paddingHdr := make([]byte, 4)
|
||||
defer common.KeepAlive(_paddingHdr)
|
||||
paddingHdr = common.Dup(_paddingHdr)
|
||||
}
|
||||
_, err = io.ReadFull(c.ExtendedConn, paddingHdr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
|
||||
paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:]))
|
||||
if len(p) > originalDataSize {
|
||||
p = p[:originalDataSize]
|
||||
}
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.readPadding++
|
||||
c.readRemaining = originalDataSize - n
|
||||
c.paddingRemaining = paddingLen
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Read(p)
|
||||
}
|
||||
|
||||
func (c *paddingConn) Write(p []byte) (n int, err error) {
|
||||
for pLen := len(p); pLen > 0; {
|
||||
var data []byte
|
||||
if pLen > 65535 {
|
||||
data = p[:65535]
|
||||
p = p[65535:]
|
||||
pLen -= 65535
|
||||
} else {
|
||||
data = p
|
||||
pLen = 0
|
||||
}
|
||||
var writeN int
|
||||
writeN, err = c.write(data)
|
||||
n += writeN
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *paddingConn) write(p []byte) (n int, err error) {
|
||||
if c.writePadding < kFirstPaddings {
|
||||
paddingLen := 256 + rand.Intn(512)
|
||||
_buffer := buf.StackNewSize(4 + len(p) + paddingLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
header := buffer.Extend(4)
|
||||
binary.BigEndian.PutUint16(header[:2], uint16(len(p)))
|
||||
binary.BigEndian.PutUint16(header[2:], uint16(paddingLen))
|
||||
common.Must1(buffer.Write(p))
|
||||
buffer.Extend(paddingLen)
|
||||
_, err = c.ExtendedConn.Write(buffer.Bytes())
|
||||
if err == nil {
|
||||
n = len(p)
|
||||
}
|
||||
c.writePadding++
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Write(p)
|
||||
}
|
||||
|
||||
func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
p := buffer.FreeBytes()
|
||||
if c.readRemaining > 0 {
|
||||
if len(p) > c.readRemaining {
|
||||
p = p[:c.readRemaining]
|
||||
}
|
||||
n, err := c.ExtendedConn.Read(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.readRemaining -= n
|
||||
buffer.Truncate(n)
|
||||
return nil
|
||||
}
|
||||
if c.paddingRemaining > 0 {
|
||||
err := rw.SkipN(c.ExtendedConn, c.paddingRemaining)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.paddingRemaining = 0
|
||||
}
|
||||
if c.readPadding < kFirstPaddings {
|
||||
var paddingHdr []byte
|
||||
if len(p) >= 4 {
|
||||
paddingHdr = p[:4]
|
||||
} else {
|
||||
_paddingHdr := make([]byte, 4)
|
||||
defer common.KeepAlive(_paddingHdr)
|
||||
paddingHdr = common.Dup(_paddingHdr)
|
||||
}
|
||||
_, err := io.ReadFull(c.ExtendedConn, paddingHdr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
|
||||
paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:]))
|
||||
|
||||
if len(p) > originalDataSize {
|
||||
p = p[:originalDataSize]
|
||||
}
|
||||
n, err := c.ExtendedConn.Read(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.readPadding++
|
||||
c.readRemaining = originalDataSize - n
|
||||
c.paddingRemaining = paddingLen
|
||||
buffer.Truncate(n)
|
||||
return nil
|
||||
}
|
||||
return c.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *paddingConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if c.writePadding < kFirstPaddings {
|
||||
bufferLen := buffer.Len()
|
||||
if bufferLen > 65535 {
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
paddingLen := 256 + rand.Intn(512)
|
||||
header := buffer.ExtendHeader(4)
|
||||
binary.BigEndian.PutUint16(header[:2], uint16(bufferLen))
|
||||
binary.BigEndian.PutUint16(header[2:], uint16(paddingLen))
|
||||
buffer.Extend(paddingLen)
|
||||
c.writePadding++
|
||||
}
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *paddingConn) FrontHeadroom() int {
|
||||
return 4 + 256 + 1024
|
||||
}
|
||||
|
||||
type vectorisedPaddingConn struct {
|
||||
paddingConn
|
||||
writer N.VectorisedWriter
|
||||
}
|
||||
|
||||
func (c *vectorisedPaddingConn) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
if c.writePadding < kFirstPaddings {
|
||||
bufferLen := buf.LenMulti(buffers)
|
||||
if bufferLen > 65535 {
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
for _, buffer := range buffers {
|
||||
_, err := c.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
paddingLen := 256 + rand.Intn(512)
|
||||
header := buf.NewSize(4)
|
||||
common.Must(
|
||||
binary.Write(header, binary.BigEndian, uint16(bufferLen)),
|
||||
binary.Write(header, binary.BigEndian, uint16(paddingLen)),
|
||||
)
|
||||
c.writePadding++
|
||||
padding := buf.NewSize(paddingLen)
|
||||
padding.Extend(paddingLen)
|
||||
buffers = append(append([]*buf.Buffer{header}, buffers...), padding)
|
||||
}
|
||||
return c.writer.WriteVectorised(buffers)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package mux
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
@@ -113,11 +114,14 @@ func (p Protocol) String() string {
|
||||
}
|
||||
|
||||
const (
|
||||
version0 = 0
|
||||
Version0 = iota
|
||||
Version1
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Protocol Protocol
|
||||
Version byte
|
||||
Protocol Protocol
|
||||
PaddingEnabled bool
|
||||
}
|
||||
|
||||
func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
@@ -125,19 +129,60 @@ func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != version0 {
|
||||
if version < Version0 || version > Version1 {
|
||||
return nil, E.New("unsupported version: ", version)
|
||||
}
|
||||
protocol, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Request{Protocol: Protocol(protocol)}, nil
|
||||
var paddingEnabled bool
|
||||
if version == Version1 {
|
||||
err = binary.Read(reader, binary.BigEndian, &paddingEnabled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paddingEnabled {
|
||||
var paddingLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = rw.SkipN(reader, int(paddingLen))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return &Request{Version: version, Protocol: Protocol(protocol), PaddingEnabled: paddingEnabled}, nil
|
||||
}
|
||||
|
||||
func EncodeRequest(buffer *buf.Buffer, request Request) {
|
||||
buffer.WriteByte(version0)
|
||||
buffer.WriteByte(byte(request.Protocol))
|
||||
func EncodeRequest(request Request, payload []byte) *buf.Buffer {
|
||||
var requestLen int
|
||||
requestLen += 2
|
||||
var paddingLen uint16
|
||||
if request.Version == Version1 {
|
||||
requestLen += 1
|
||||
if request.PaddingEnabled {
|
||||
requestLen += 2
|
||||
paddingLen = uint16(256 + rand.Intn(512))
|
||||
requestLen += int(paddingLen)
|
||||
}
|
||||
}
|
||||
buffer := buf.NewSize(requestLen + len(payload))
|
||||
common.Must(
|
||||
buffer.WriteByte(request.Version),
|
||||
buffer.WriteByte(byte(request.Protocol)),
|
||||
)
|
||||
if request.Version == Version1 {
|
||||
common.Must(binary.Write(buffer, binary.BigEndian, request.PaddingEnabled))
|
||||
if request.PaddingEnabled {
|
||||
common.Must(binary.Write(buffer, binary.BigEndian, paddingLen))
|
||||
buffer.Extend(int(paddingLen))
|
||||
}
|
||||
}
|
||||
common.Must1(buffer.Write(payload))
|
||||
return buffer
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -174,7 +219,7 @@ func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
|
||||
return &StreamRequest{network, destination, udpAddr}, nil
|
||||
}
|
||||
|
||||
func requestLen(request StreamRequest) int {
|
||||
func streamRequestLen(request StreamRequest) int {
|
||||
var rLen int
|
||||
rLen += 1 // version
|
||||
rLen += 2 // flags
|
||||
|
||||
@@ -22,6 +22,9 @@ func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Ha
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if request.PaddingEnabled {
|
||||
conn = newPaddingConn(conn)
|
||||
}
|
||||
session, err := request.Protocol.newServer(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
@@ -50,23 +49,35 @@ func (y *yamuxSession) CanTakeNewRequest() bool {
|
||||
|
||||
type protocolConn struct {
|
||||
net.Conn
|
||||
protocol Protocol
|
||||
request Request
|
||||
protocolWritten bool
|
||||
}
|
||||
|
||||
func newProtocolConn(conn net.Conn, request Request) net.Conn {
|
||||
writer, isVectorised := bufio.CreateVectorisedWriter(conn)
|
||||
if isVectorised {
|
||||
return &vectorisedProtocolConn{
|
||||
protocolConn{
|
||||
Conn: conn,
|
||||
request: request,
|
||||
},
|
||||
writer,
|
||||
}
|
||||
} else {
|
||||
return &protocolConn{
|
||||
Conn: conn,
|
||||
request: request,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *protocolConn) Write(p []byte) (n int, err error) {
|
||||
if c.protocolWritten {
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
_buffer := buf.StackNewSize(2 + len(p))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeRequest(buffer, Request{
|
||||
Protocol: c.protocol,
|
||||
})
|
||||
common.Must(common.Error(buffer.Write(p)))
|
||||
buffer := EncodeRequest(c.request, p)
|
||||
n, err = c.Conn.Write(buffer.Bytes())
|
||||
buffer.Release()
|
||||
if err == nil {
|
||||
n--
|
||||
}
|
||||
@@ -87,20 +98,14 @@ func (c *protocolConn) Upstream() any {
|
||||
|
||||
type vectorisedProtocolConn struct {
|
||||
protocolConn
|
||||
N.VectorisedWriter
|
||||
writer N.VectorisedWriter
|
||||
}
|
||||
|
||||
func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
if c.protocolWritten {
|
||||
return c.VectorisedWriter.WriteVectorised(buffers)
|
||||
return c.writer.WriteVectorised(buffers)
|
||||
}
|
||||
c.protocolWritten = true
|
||||
_buffer := buf.StackNewSize(2)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeRequest(buffer, Request{
|
||||
Protocol: c.protocol,
|
||||
})
|
||||
return c.VectorisedWriter.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
|
||||
buffer := EncodeRequest(c.request, nil)
|
||||
return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user