mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-05-14 00:51:12 +03:00
Add new admin panel, failover, dns fallback, providers, limiters. Update XHTTP
This commit is contained in:
232
protocol/failover/conn.go
Normal file
232
protocol/failover/conn.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package failover
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
)
|
||||
|
||||
type dial func() (net.Conn, error)
|
||||
|
||||
type failoverConn struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
dial dial
|
||||
onClose func()
|
||||
|
||||
readIndex uint32
|
||||
readBuffer *bytes.Buffer
|
||||
writeIndex uint32
|
||||
writeBuffers [BufferSize][]byte
|
||||
|
||||
await chan struct{}
|
||||
awaitMtx sync.Mutex
|
||||
|
||||
err error
|
||||
|
||||
once sync.Once
|
||||
mtx sync.RWMutex
|
||||
}
|
||||
|
||||
func NewFailoverConn(ctx context.Context, conn net.Conn, dial dial, onClose func()) *failoverConn {
|
||||
var writeBuffers [BufferSize][]byte
|
||||
for i := range BufferSize {
|
||||
writeBuffers[i] = make([]byte, 0, 1000)
|
||||
}
|
||||
return &failoverConn{
|
||||
Conn: conn,
|
||||
ctx: ctx,
|
||||
dial: dial,
|
||||
readBuffer: bytes.NewBuffer(make([]byte, 0, 1000)),
|
||||
writeBuffers: writeBuffers,
|
||||
onClose: onClose,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *failoverConn) Read(b []byte) (int, error) {
|
||||
for {
|
||||
c.mtx.RLock()
|
||||
conn := c.Conn
|
||||
n, err := c.read(conn, b)
|
||||
if err != nil {
|
||||
if err == SessionClosed {
|
||||
c.err = io.EOF
|
||||
conn.Close()
|
||||
c.mtx.RUnlock()
|
||||
return 0, c.err
|
||||
}
|
||||
c.mtx.RUnlock()
|
||||
err = c.awaitConn(conn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
c.readIndex++
|
||||
c.mtx.RUnlock()
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *failoverConn) Write(b []byte) (int, error) {
|
||||
for {
|
||||
c.mtx.RLock()
|
||||
conn := c.Conn
|
||||
n, err := c.write(conn, b)
|
||||
if err != nil {
|
||||
c.mtx.RUnlock()
|
||||
err = c.awaitConn(conn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
writeIndex := c.writeIndex % BufferSize
|
||||
c.writeBuffers[writeIndex] = append(c.writeBuffers[writeIndex][:0], b...)
|
||||
c.writeIndex++
|
||||
c.mtx.RUnlock()
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *failoverConn) RestoreConn(conn net.Conn) error {
|
||||
c.Conn.Close()
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
_, err := conn.Write([]byte{
|
||||
byte(c.readIndex >> 24),
|
||||
byte(c.readIndex >> 16),
|
||||
byte(c.readIndex >> 8),
|
||||
byte(c.readIndex),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var data [4]byte
|
||||
_, err = io.ReadFull(conn, data[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
writeIndex := binary.BigEndian.Uint32(data[:])
|
||||
buffers := make([][]byte, 0, BufferSize)
|
||||
for writeIndex != c.writeIndex {
|
||||
if len(buffers) == BufferSize {
|
||||
return SessionBroken
|
||||
}
|
||||
buffers = append(buffers, c.writeBuffers[writeIndex%BufferSize])
|
||||
writeIndex++
|
||||
}
|
||||
for _, buffer := range buffers {
|
||||
_, err = c.write(conn, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.Conn = conn
|
||||
if c.await != nil {
|
||||
close(c.await)
|
||||
c.await = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *failoverConn) Close() error {
|
||||
c.once.Do(func() {
|
||||
c.mtx.RLock()
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
c.err = io.EOF
|
||||
c.mtx.RUnlock()
|
||||
c.Write([]byte{})
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *failoverConn) read(conn net.Conn, b []byte) (int, error) {
|
||||
if c.readBuffer.Len() == 0 {
|
||||
c.readBuffer.Reset()
|
||||
var data [2]byte
|
||||
_, err := io.ReadFull(conn, data[:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.BigEndian.Uint16(data[:])
|
||||
if n == 0 {
|
||||
return 0, SessionClosed
|
||||
}
|
||||
_, err = io.CopyN(c.readBuffer, conn, int64(n))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.readBuffer.Read(b)
|
||||
}
|
||||
|
||||
func (c *failoverConn) write(conn net.Conn, b []byte) (int, error) {
|
||||
buffer := make([]byte, 2+len(b))
|
||||
binary.BigEndian.PutUint16(buffer, uint16(len(b)))
|
||||
copy(buffer[2:], b)
|
||||
n, err := conn.Write(buffer)
|
||||
return n - 2, err
|
||||
}
|
||||
|
||||
func (c *failoverConn) awaitConn(oldConn net.Conn) error {
|
||||
c.awaitMtx.Lock()
|
||||
defer c.awaitMtx.Unlock()
|
||||
if c.err != nil {
|
||||
return c.err
|
||||
}
|
||||
if c.Conn != oldConn {
|
||||
return c.ctx.Err()
|
||||
}
|
||||
oldConn.Close()
|
||||
timer := time.NewTimer(C.TCPConnectTimeout)
|
||||
defer timer.Stop()
|
||||
if c.dial != nil {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return c.ctx.Err()
|
||||
case <-timer.C:
|
||||
c.err = SessionExpired
|
||||
return c.err
|
||||
default:
|
||||
}
|
||||
conn, err := c.dial()
|
||||
if err != nil {
|
||||
if err == SessionNotFound {
|
||||
c.err = err
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
err = c.RestoreConn(conn)
|
||||
if err != nil {
|
||||
if err == SessionBroken {
|
||||
c.err = err
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
c.await = make(chan struct{})
|
||||
select {
|
||||
case <-c.await:
|
||||
case <-timer.C:
|
||||
c.err = SessionExpired
|
||||
return c.err
|
||||
case <-c.ctx.Done():
|
||||
return c.ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
136
protocol/failover/inbound.go
Normal file
136
protocol/failover/inbound.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package failover
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/common/kmutex"
|
||||
"github.com/sagernet/sing-box/common/uot"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.FailoverInboundOptions](registry, C.TypeFailover, NewInbound)
|
||||
}
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
logger logger.ContextLogger
|
||||
router adapter.ConnectionRouterEx
|
||||
inbounds []adapter.Inbound
|
||||
conns map[uuid.UUID]*failoverConn
|
||||
|
||||
sessionMtx *kmutex.Kmutex[uuid.UUID]
|
||||
mtx sync.RWMutex
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FailoverInboundOptions) (adapter.Inbound, error) {
|
||||
if len(options.Inbounds) == 0 {
|
||||
return nil, E.New("missing inbounds")
|
||||
}
|
||||
inbound := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeFailover, tag),
|
||||
logger: logger,
|
||||
router: uot.NewRouter(router, logger),
|
||||
conns: make(map[uuid.UUID]*failoverConn),
|
||||
sessionMtx: kmutex.New[uuid.UUID](),
|
||||
}
|
||||
router = NewRouter(router, logger, inbound.connHandler)
|
||||
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
||||
inbounds := make([]adapter.Inbound, len(options.Inbounds))
|
||||
for i, inboundOptions := range options.Inbounds {
|
||||
inbound, err := inboundRegistry.UnsafeCreate(ctx, router, logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbounds[i] = inbound
|
||||
}
|
||||
inbound.inbounds = inbounds
|
||||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
for _, inbound := range h.inbounds {
|
||||
err := inbound.Start(stage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Close() error {
|
||||
errs := make([]error, 0)
|
||||
for _, inbound := range h.inbounds {
|
||||
err := inbound.Close()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if len(errs) != 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
|
||||
if metadata.Destination != Destination {
|
||||
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
request, err := ReadRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionUUID := request.UUID
|
||||
h.sessionMtx.Lock(sessionUUID)
|
||||
if request.Command == CommandTCP {
|
||||
failoverConn := NewFailoverConn(ctx, conn, nil, func() {
|
||||
h.sessionMtx.Lock(sessionUUID)
|
||||
h.mtx.Lock()
|
||||
defer h.sessionMtx.Unlock(sessionUUID)
|
||||
defer h.mtx.Unlock()
|
||||
delete(h.conns, sessionUUID)
|
||||
})
|
||||
h.mtx.Lock()
|
||||
h.conns[sessionUUID] = failoverConn
|
||||
h.mtx.Unlock()
|
||||
metadata.Inbound = h.Tag()
|
||||
metadata.InboundType = C.TypeFailover
|
||||
metadata.Destination = request.Destination
|
||||
h.sessionMtx.Unlock(sessionUUID)
|
||||
h.router.RouteConnectionEx(ctx, failoverConn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
if request.Command == CommandReconnect {
|
||||
h.mtx.RLock()
|
||||
serverConn, ok := h.conns[sessionUUID]
|
||||
h.mtx.RUnlock()
|
||||
if !ok {
|
||||
_, err := conn.Write([]byte{StatusSessionNotFound})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return SessionNotFound
|
||||
}
|
||||
_, err = conn.Write([]byte{StatusOK})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := serverConn.RestoreConn(conn)
|
||||
h.sessionMtx.Unlock(sessionUUID)
|
||||
return err
|
||||
}
|
||||
return E.New("command ", request.Command, " not found")
|
||||
}
|
||||
109
protocol/failover/outbound.go
Normal file
109
protocol/failover/outbound.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package failover
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/uot"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.FailoverOutboundOptions](registry, C.TypeFailover, NewFailover)
|
||||
}
|
||||
|
||||
type Failover struct {
|
||||
outbound.Adapter
|
||||
ctx context.Context
|
||||
outbound adapter.OutboundManager
|
||||
logger logger.ContextLogger
|
||||
dial DialStrategy
|
||||
uotClient *uot.Client
|
||||
}
|
||||
|
||||
func NewFailover(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FailoverOutboundOptions) (adapter.Outbound, error) {
|
||||
if len(options.Outbounds) == 0 {
|
||||
return nil, E.New("missing outbounds")
|
||||
}
|
||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||
outbounds := make([]adapter.Outbound, len(options.Outbounds))
|
||||
for i, outboundOptions := range options.Outbounds {
|
||||
outbound, err := outboundRegistry.UnsafeCreateOutbound(ctx, router, logger, outboundOptions.Tag, outboundOptions.Type, outboundOptions.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outbounds[i] = outbound
|
||||
}
|
||||
dial, err := CreateStrategy(options.Strategy, outbounds, logger, options.Delay.Build())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outbound := &Failover{
|
||||
Adapter: outbound.NewAdapter(C.TypeFailover, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
|
||||
ctx: ctx,
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
logger: logger,
|
||||
dial: dial,
|
||||
}
|
||||
outbound.uotClient = &uot.Client{
|
||||
Dialer: outbound,
|
||||
Version: uot.Version,
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
func (f *Failover) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if N.NetworkName(network) == N.NetworkUDP {
|
||||
return f.uotClient.DialContext(ctx, network, destination)
|
||||
}
|
||||
conn, err := f.dial(ctx, network, Destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionUUID, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = WriteRequest(conn, &Request{Command: CommandTCP, UUID: sessionUUID, Destination: destination})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewFailoverConn(ctx, conn, func() (net.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, C.TCPConnectTimeout)
|
||||
defer cancel()
|
||||
conn, err := f.dial(ctx, network, Destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = WriteRequest(conn, &Request{Command: CommandReconnect, UUID: sessionUUID, Destination: destination})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data [1]byte
|
||||
_, err = io.ReadFull(conn, data[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var status uint8 = data[0]
|
||||
if status == StatusSessionNotFound {
|
||||
conn.Close()
|
||||
return nil, SessionNotFound
|
||||
}
|
||||
return conn, nil
|
||||
}, nil), nil
|
||||
}
|
||||
|
||||
func (f *Failover) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return f.uotClient.ListenPacket(ctx, destination)
|
||||
}
|
||||
97
protocol/failover/protocol.go
Normal file
97
protocol/failover/protocol.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package failover
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = 0
|
||||
BufferSize = 10
|
||||
)
|
||||
|
||||
const (
|
||||
CommandTCP = 1
|
||||
CommandReconnect = 2
|
||||
)
|
||||
|
||||
const (
|
||||
StatusOK uint8 = iota + 1
|
||||
StatusSessionNotFound
|
||||
)
|
||||
|
||||
var (
|
||||
SessionClosed = E.New("session closed")
|
||||
SessionNotFound = E.New("session not found")
|
||||
SessionExpired = E.New("session expired")
|
||||
SessionBroken = E.New("session broken")
|
||||
)
|
||||
|
||||
var Destination = M.Socksaddr{
|
||||
Fqdn: "sp.failover.sing-box.arpa",
|
||||
Port: 444,
|
||||
}
|
||||
|
||||
var AddressSerializer = M.NewSerializer(
|
||||
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
|
||||
M.AddressFamilyByte(0x03, M.AddressFamilyIPv6),
|
||||
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
|
||||
M.PortThenAddress(),
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
UUID uuid.UUID
|
||||
Command byte
|
||||
Destination M.Socksaddr
|
||||
}
|
||||
|
||||
func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
var request Request
|
||||
var version uint8
|
||||
err := binary.Read(reader, binary.BigEndian, &version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != Version {
|
||||
return nil, E.New("unknown version: ", version)
|
||||
}
|
||||
_, err = io.ReadFull(reader, request.UUID[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &request.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &request, nil
|
||||
}
|
||||
|
||||
func WriteRequest(writer io.Writer, request *Request) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // version
|
||||
requestLen += 1 // command
|
||||
requestLen += 16 // UUID
|
||||
requestLen += AddressSerializer.AddrPortLen(request.Destination)
|
||||
buffer := buf.NewSize(requestLen)
|
||||
defer buffer.Release()
|
||||
common.Must(
|
||||
buffer.WriteByte(Version),
|
||||
common.Error(buffer.Write(request.UUID[:])),
|
||||
buffer.WriteByte(request.Command),
|
||||
)
|
||||
err := AddressSerializer.WriteAddrPort(buffer, request.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(writer.Write(buffer.Bytes()))
|
||||
}
|
||||
55
protocol/failover/router.go
Normal file
55
protocol/failover/router.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package failover
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
adapter.Router
|
||||
logger logger.ContextLogger
|
||||
handler func(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) error
|
||||
}
|
||||
|
||||
func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) error) *Router {
|
||||
return &Router{Router: router, logger: logger, handler: handler}
|
||||
}
|
||||
|
||||
func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
if metadata.Destination != Destination {
|
||||
return r.Router.RouteConnection(ctx, conn, metadata)
|
||||
}
|
||||
return r.handler(ctx, conn, metadata, func(error) {})
|
||||
}
|
||||
|
||||
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
if metadata.Destination != Destination {
|
||||
return r.Router.RoutePacketConnection(ctx, conn, metadata)
|
||||
}
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
if metadata.Destination != Destination {
|
||||
r.Router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return
|
||||
}
|
||||
if err := r.handler(ctx, conn, metadata, onClose); err != nil {
|
||||
r.logger.ErrorContext(ctx, err)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
if metadata.Destination != Destination {
|
||||
r.Router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
return
|
||||
}
|
||||
r.logger.ErrorContext(ctx, os.ErrInvalid)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid)
|
||||
}
|
||||
70
protocol/failover/strategy.go
Normal file
70
protocol/failover/strategy.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package failover
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type DialStrategy = func(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error)
|
||||
|
||||
func cycleStrategy(outbounds []adapter.Outbound, logger logger.ContextLogger, delay time.Duration) DialStrategy {
|
||||
return func(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
for {
|
||||
for _, outbound := range outbounds {
|
||||
conn, err := outbound.DialContext(ctx, network, destination)
|
||||
if err != nil {
|
||||
logger.InfoContext(ctx, err)
|
||||
if delay > 0 {
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sequentialStrategy(outbounds []adapter.Outbound, logger logger.ContextLogger, delay time.Duration) DialStrategy {
|
||||
return func(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
var err error
|
||||
for _, outbound := range outbounds {
|
||||
var conn net.Conn
|
||||
conn, err = outbound.DialContext(ctx, network, destination)
|
||||
if err != nil {
|
||||
logger.InfoContext(ctx, err)
|
||||
if delay > 0 {
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func CreateStrategy(strategy string, outbounds []adapter.Outbound, logger logger.ContextLogger, delay time.Duration) (DialStrategy, error) {
|
||||
switch strategy {
|
||||
case "cycle":
|
||||
return cycleStrategy(outbounds, logger, delay), nil
|
||||
case "sequential", "":
|
||||
return sequentialStrategy(outbounds, logger, delay), nil
|
||||
default:
|
||||
return nil, E.New("strategy not found: ", strategy)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user