mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-05-14 00:51:12 +03:00
Improve multiplexer
This commit is contained in:
@@ -15,40 +15,44 @@ import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
var _ N.Dialer = (*Client)(nil)
|
||||
|
||||
type Client struct {
|
||||
access sync.Mutex
|
||||
connections list.List[*yamux.Session]
|
||||
connections list.List[abstractSession]
|
||||
ctx context.Context
|
||||
dialer N.Dialer
|
||||
protocol Protocol
|
||||
maxConnections int
|
||||
minStreams int
|
||||
maxStreams int
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, dialer N.Dialer, maxConnections int, minStreams int, maxStreams int) *Client {
|
||||
func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client {
|
||||
return &Client{
|
||||
ctx: ctx,
|
||||
dialer: dialer,
|
||||
protocol: protocol,
|
||||
maxConnections: maxConnections,
|
||||
minStreams: minStreams,
|
||||
maxStreams: maxStreams,
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) N.Dialer {
|
||||
func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (N.Dialer, error) {
|
||||
if !options.Enabled {
|
||||
return dialer
|
||||
return dialer, nil
|
||||
}
|
||||
if options.MaxConnections == 0 && options.MaxStreams == 0 {
|
||||
options.MinStreams = 8
|
||||
}
|
||||
return NewClient(ctx, dialer, options.MaxConnections, options.MinStreams, options.MaxStreams)
|
||||
protocol, err := ParseProtocol(options.Protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil
|
||||
}
|
||||
|
||||
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
@@ -80,8 +84,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
|
||||
|
||||
func (c *Client) openStream() (net.Conn, error) {
|
||||
var (
|
||||
session *yamux.Session
|
||||
stream *yamux.Stream
|
||||
session abstractSession
|
||||
stream net.Conn
|
||||
err error
|
||||
)
|
||||
for attempts := 0; attempts < 2; attempts++ {
|
||||
@@ -89,7 +93,7 @@ func (c *Client) openStream() (net.Conn, error) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
stream, err = session.OpenStream()
|
||||
stream, err = session.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -101,11 +105,11 @@ func (c *Client) openStream() (net.Conn, error) {
|
||||
return &wrapStream{stream}, nil
|
||||
}
|
||||
|
||||
func (c *Client) offer() (*yamux.Session, error) {
|
||||
func (c *Client) offer() (abstractSession, error) {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
sessions := make([]*yamux.Session, 0, c.maxConnections)
|
||||
sessions := make([]abstractSession, 0, c.maxConnections)
|
||||
for element := c.connections.Front(); element != nil; {
|
||||
if element.Value.IsClosed() {
|
||||
nextElement := element.Next()
|
||||
@@ -120,10 +124,7 @@ func (c *Client) offer() (*yamux.Session, error) {
|
||||
if sLen == 0 {
|
||||
return c.offerNew()
|
||||
}
|
||||
// session := common.MinBy(sessions, yamux.Session.NumStreams)
|
||||
session := common.MinBy(sessions, func(it *yamux.Session) int {
|
||||
return it.NumStreams()
|
||||
})
|
||||
session := common.MinBy(sessions, abstractSession.NumStreams)
|
||||
numStreams := session.NumStreams()
|
||||
if numStreams == 0 {
|
||||
return session, nil
|
||||
@@ -140,12 +141,12 @@ func (c *Client) offer() (*yamux.Session, error) {
|
||||
return c.offerNew()
|
||||
}
|
||||
|
||||
func (c *Client) offerNew() (*yamux.Session, error) {
|
||||
func (c *Client) offerNew() (abstractSession, error) {
|
||||
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := yamux.Client(conn, newMuxConfig())
|
||||
session, err := c.protocol.newClient(&protocolConn{Conn: conn, protocol: c.protocol})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -170,7 +171,7 @@ type ClientConn struct {
|
||||
}
|
||||
|
||||
func (c *ClientConn) readResponse() error {
|
||||
response, err := ReadResponse(c.Conn)
|
||||
response, err := ReadStreamResponse(c.Conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -195,7 +196,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
|
||||
if c.requestWrite {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
request := Request{
|
||||
request := StreamRequest{
|
||||
Network: N.NetworkTCP,
|
||||
Destination: c.destination,
|
||||
}
|
||||
@@ -203,7 +204,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeRequest(request, buffer)
|
||||
EncodeStreamRequest(request, buffer)
|
||||
buffer.Write(b)
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
@@ -255,7 +256,7 @@ type ClientPacketConn struct {
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) readResponse() error {
|
||||
response, err := ReadResponse(c.ExtendedConn)
|
||||
response, err := ReadStreamResponse(c.ExtendedConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -285,7 +286,7 @@ func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
||||
request := Request{
|
||||
request := StreamRequest{
|
||||
Network: N.NetworkUDP,
|
||||
Destination: c.destination,
|
||||
}
|
||||
@@ -297,7 +298,7 @@ func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeRequest(request, buffer)
|
||||
EncodeStreamRequest(request, buffer)
|
||||
if len(payload) > 0 {
|
||||
common.Must(
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
|
||||
@@ -363,7 +364,7 @@ type ClientPacketAddrConn struct {
|
||||
}
|
||||
|
||||
func (c *ClientPacketAddrConn) readResponse() error {
|
||||
response, err := ReadResponse(c.ExtendedConn)
|
||||
response, err := ReadStreamResponse(c.ExtendedConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -399,7 +400,7 @@ func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
|
||||
}
|
||||
|
||||
func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
|
||||
request := Request{
|
||||
request := StreamRequest{
|
||||
Network: N.NetworkUDP,
|
||||
Destination: c.destination,
|
||||
PacketAddr: true,
|
||||
@@ -412,7 +413,7 @@ func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeRequest(request, buffer)
|
||||
EncodeStreamRequest(request, buffer)
|
||||
if len(payload) > 0 {
|
||||
common.Must(
|
||||
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
var Destination = M.Socksaddr{
|
||||
@@ -21,7 +22,55 @@ var Destination = M.Socksaddr{
|
||||
Port: 444,
|
||||
}
|
||||
|
||||
func newMuxConfig() *yamux.Config {
|
||||
const (
|
||||
ProtocolYAMux Protocol = 0
|
||||
ProtocolSMux Protocol = 1
|
||||
)
|
||||
|
||||
type Protocol byte
|
||||
|
||||
func ParseProtocol(name string) (Protocol, error) {
|
||||
switch name {
|
||||
case "", "yamux":
|
||||
return ProtocolYAMux, nil
|
||||
case "smux":
|
||||
return ProtocolSMux, nil
|
||||
default:
|
||||
return ProtocolYAMux, E.New("unknown multiplex protocol: ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
|
||||
switch p {
|
||||
case ProtocolYAMux:
|
||||
return yamux.Server(conn, yaMuxConfig())
|
||||
case ProtocolSMux:
|
||||
session, err := smux.Server(conn, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &smuxSession{session}, nil
|
||||
default:
|
||||
panic("unknown protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
|
||||
switch p {
|
||||
case ProtocolYAMux:
|
||||
return yamux.Client(conn, yaMuxConfig())
|
||||
case ProtocolSMux:
|
||||
session, err := smux.Client(conn, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &smuxSession{session}, nil
|
||||
default:
|
||||
panic("unknown protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func yaMuxConfig() *yamux.Config {
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
config.StreamCloseTimeout = C.TCPTimeout
|
||||
@@ -29,18 +78,23 @@ func newMuxConfig() *yamux.Config {
|
||||
return config
|
||||
}
|
||||
|
||||
func (p Protocol) String() string {
|
||||
switch p {
|
||||
case ProtocolYAMux:
|
||||
return "yamux"
|
||||
case ProtocolSMux:
|
||||
return "smux"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
version0 = 0
|
||||
flagUDP = 1
|
||||
flagAddr = 2
|
||||
statusSuccess = 0
|
||||
statusError = 1
|
||||
version0 = 0
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Network string
|
||||
Destination M.Socksaddr
|
||||
PacketAddr bool
|
||||
Protocol Protocol
|
||||
}
|
||||
|
||||
func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
@@ -51,8 +105,37 @@ func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
if version != version0 {
|
||||
return nil, E.New("unsupported version: ", version)
|
||||
}
|
||||
protocol, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if protocol > byte(ProtocolSMux) {
|
||||
return nil, E.New("unsupported protocol: ", protocol)
|
||||
}
|
||||
return &Request{Protocol: Protocol(protocol)}, nil
|
||||
}
|
||||
|
||||
func EncodeRequest(buffer *buf.Buffer, request Request) {
|
||||
buffer.WriteByte(version0)
|
||||
buffer.WriteByte(byte(request.Protocol))
|
||||
}
|
||||
|
||||
const (
|
||||
flagUDP = 1
|
||||
flagAddr = 2
|
||||
statusSuccess = 0
|
||||
statusError = 1
|
||||
)
|
||||
|
||||
type StreamRequest struct {
|
||||
Network string
|
||||
Destination M.Socksaddr
|
||||
PacketAddr bool
|
||||
}
|
||||
|
||||
func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
|
||||
var flags uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &flags)
|
||||
err := binary.Read(reader, binary.BigEndian, &flags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -68,10 +151,10 @@ func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
network = N.NetworkUDP
|
||||
udpAddr = flags&flagAddr != 0
|
||||
}
|
||||
return &Request{network, destination, udpAddr}, nil
|
||||
return &StreamRequest{network, destination, udpAddr}, nil
|
||||
}
|
||||
|
||||
func requestLen(request Request) int {
|
||||
func requestLen(request StreamRequest) int {
|
||||
var rLen int
|
||||
rLen += 1 // version
|
||||
rLen += 2 // flags
|
||||
@@ -79,7 +162,7 @@ func requestLen(request Request) int {
|
||||
return rLen
|
||||
}
|
||||
|
||||
func EncodeRequest(request Request, buffer *buf.Buffer) {
|
||||
func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
|
||||
destination := request.Destination
|
||||
var flags uint16
|
||||
if request.Network == N.NetworkUDP {
|
||||
@@ -92,19 +175,18 @@ func EncodeRequest(request Request, buffer *buf.Buffer) {
|
||||
}
|
||||
}
|
||||
common.Must(
|
||||
buffer.WriteByte(version0),
|
||||
binary.Write(buffer, binary.BigEndian, flags),
|
||||
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
|
||||
)
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
type StreamResponse struct {
|
||||
Status uint8
|
||||
Message string
|
||||
}
|
||||
|
||||
func ReadResponse(reader io.Reader) (*Response, error) {
|
||||
var response Response
|
||||
func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
|
||||
var response StreamResponse
|
||||
status, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -14,12 +14,14 @@ import (
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
session, err := yamux.Server(conn, newMuxConfig())
|
||||
request, err := ReadRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session, err := request.Protocol.newServer(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -34,7 +36,7 @@ func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Ha
|
||||
|
||||
func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) {
|
||||
stream = &wrapStream{stream}
|
||||
request, err := ReadRequest(stream)
|
||||
request, err := ReadStreamRequest(stream)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
|
||||
71
common/mux/session.go
Normal file
71
common/mux/session.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
type abstractSession interface {
|
||||
Open() (net.Conn, error)
|
||||
Accept() (net.Conn, error)
|
||||
NumStreams() int
|
||||
Close() error
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
var _ abstractSession = (*smuxSession)(nil)
|
||||
|
||||
type smuxSession struct {
|
||||
*smux.Session
|
||||
}
|
||||
|
||||
func (s *smuxSession) Open() (net.Conn, error) {
|
||||
return s.OpenStream()
|
||||
}
|
||||
|
||||
func (s *smuxSession) Accept() (net.Conn, error) {
|
||||
return s.AcceptStream()
|
||||
}
|
||||
|
||||
type protocolConn struct {
|
||||
net.Conn
|
||||
protocol Protocol
|
||||
protocolWritten bool
|
||||
}
|
||||
|
||||
func (c *protocolConn) Write(p []byte) (n int, err error) {
|
||||
if c.protocolWritten {
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
_buffer := buf.StackNewSize(2 + len(p))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeRequest(buffer, Request{
|
||||
Protocol: c.protocol,
|
||||
})
|
||||
common.Must(common.Error(buffer.Write(p)))
|
||||
n, err = c.Conn.Write(buffer.Bytes())
|
||||
if err == nil {
|
||||
n--
|
||||
}
|
||||
c.protocolWritten = true
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if !c.protocolWritten {
|
||||
return bufio.ReadFrom0(c, r)
|
||||
}
|
||||
return bufio.Copy(c.Conn, r)
|
||||
}
|
||||
|
||||
func (c *protocolConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
Reference in New Issue
Block a user