Add admin panel, manager, node_manager, bandwidth limiter, connection limiter, bonding, failover, vless encryption, mkcp transport

This commit is contained in:
Sergei Maklagin
2026-02-26 22:44:31 +03:00
parent 287fe834db
commit c0aa3480c5
115 changed files with 12582 additions and 301 deletions

164
protocol/bond/conn.go Normal file
View File

@@ -0,0 +1,164 @@
package bond
import (
"encoding/binary"
"errors"
"io"
"net"
"time"
)
type bondedConn struct {
conns []net.Conn
downloadRatios []uint8
uploadRatios []uint8
readBuffer []byte
readOffset int
readSize int
}
func NewBondedConn(conns []net.Conn, downloadRatios, uploadRatios []uint8) *bondedConn {
return &bondedConn{
conns: conns,
downloadRatios: downloadRatios,
uploadRatios: uploadRatios,
readBuffer: make([]byte, 65535),
}
}
func (c *bondedConn) Read(b []byte) (n int, err error) {
if c.readOffset == c.readSize {
var header [2]byte
_, err := io.ReadFull(c.conns[0], header[:])
if err != nil {
return 0, err
}
size := int(binary.BigEndian.Uint16(header[:]))
chunkLens := splitByRatios(size, c.downloadRatios)
total := 0
for i, chunkLen := range chunkLens {
if chunkLen == 0 {
continue
}
chunk := c.readBuffer[total : total+chunkLen]
n, err := io.ReadFull(c.conns[i], chunk)
total += n
if err != nil {
return total, err
}
}
c.readOffset = 0
c.readSize = size
}
n = copy(b, c.readBuffer[c.readOffset:c.readSize])
c.readOffset += n
return n, nil
}
func (c *bondedConn) Write(b []byte) (n int, err error) {
chunkLens := splitByRatios(len(b), c.uploadRatios)
var header [2]byte
binary.BigEndian.PutUint16(header[:], uint16(len(b)))
_, err = c.conns[0].Write(header[:])
if err != nil {
return 0, err
}
total := 0
for i, chunkLen := range chunkLens {
if chunkLen == 0 {
continue
}
chunk := b[total : total+chunkLen]
conn := c.conns[i]
subTotal := 0
for subTotal < len(chunk) {
n, err := conn.Write(chunk[subTotal:])
subTotal += n
total += n
if err != nil {
return total, err
}
if n == 0 {
return total, io.ErrUnexpectedEOF
}
}
}
return total, err
}
func (c *bondedConn) Close() error {
errs := make([]error, 0)
for _, conn := range c.conns {
err := conn.Close()
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}
func (c *bondedConn) LocalAddr() net.Addr {
return nil
}
func (c *bondedConn) RemoteAddr() net.Addr {
return nil
}
func (c *bondedConn) SetDeadline(t time.Time) error {
errs := make([]error, 0)
for _, conn := range c.conns {
err := conn.SetDeadline(t)
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}
func (c *bondedConn) SetReadDeadline(t time.Time) error {
errs := make([]error, 0)
for _, conn := range c.conns {
err := conn.SetReadDeadline(t)
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}
func (c *bondedConn) SetWriteDeadline(t time.Time) error {
errs := make([]error, 0)
for _, conn := range c.conns {
err := conn.SetWriteDeadline(t)
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}
func splitByRatios(number int, ratios []uint8) []int {
result := make([]int, len(ratios))
remaining := number
for i := 0; i < len(ratios)-1; i++ {
part := number * int(ratios[i]) / 100
result[i] = part
remaining -= part
}
result[len(ratios)-1] = remaining
return result
}

146
protocol/bond/inbound.go Normal file
View File

@@ -0,0 +1,146 @@
package bond
import (
"context"
"errors"
"net"
"sync"
"time"
"github.com/patrickmn/go-cache"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/uot"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.BondInboundOptions](registry, C.TypeBond, NewInbound)
}
type Inbound struct {
inbound.Adapter
logger logger.ContextLogger
router adapter.ConnectionRouterEx
inbounds []adapter.Inbound
conns *cache.Cache
mtx sync.Mutex
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BondInboundOptions) (adapter.Inbound, error) {
if len(options.Inbounds) == 0 {
return nil, E.New("missing tags")
}
inbound := &Inbound{
Adapter: inbound.NewAdapter(C.TypeTunnelServer, tag),
logger: logger,
router: uot.NewRouter(router, logger),
conns: cache.New(C.TCPConnectTimeout, time.Second),
}
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
inbounds := make([]adapter.Inbound, len(options.Inbounds))
for i, inboundOptions := range options.Inbounds {
inbound, err := inboundRegistry.UnsafeCreate(ctx, NewRouter(router, logger, inbound.connHandler), logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options)
if err != nil {
return nil, err
}
inbounds[i] = inbound
}
inbound.inbounds = inbounds
inbound.conns.OnEvicted(func(s string, i interface{}) {
inbound.mtx.Lock()
defer inbound.mtx.Unlock()
ratioConns := i.(map[uint8]*ratioConn)
for _, ratioConn := range ratioConns {
if ratioConn != nil {
ratioConn.conn.Close()
}
}
})
return inbound, nil
}
func (h *Inbound) Start(stage adapter.StartStage) error {
for _, inbound := range h.inbounds {
err := inbound.Start(stage)
if err != nil {
return err
}
}
return nil
}
func (h *Inbound) Close() error {
errs := make([]error, 0)
for _, inbound := range h.inbounds {
err := inbound.Close()
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}
func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
request, err := ReadRequest(conn)
if err != nil {
return err
}
h.mtx.Lock()
defer h.mtx.Unlock()
var ratioConns map[uint8]*ratioConn
rawRatioConns, ok := h.conns.Get(request.UUID.String())
if ok {
ratioConns = rawRatioConns.(map[uint8]*ratioConn)
} else {
ratioConns = make(map[uint8]*ratioConn, request.Count)
h.conns.SetDefault(request.UUID.String(), ratioConns)
}
ratioConns[request.Index] = &ratioConn{
conn: conn,
downloadRatio: request.DownloadRatio,
uploadRatio: request.UploadRatio,
}
if len(ratioConns) == int(request.Count) {
conns := make([]net.Conn, len(ratioConns))
downloadRatios := make([]uint8, len(ratioConns))
uploadRatios := make([]uint8, len(ratioConns))
var totalDownloadRatio, totalUploadRatio uint8
for index, ratioConn := range ratioConns {
conns[index] = ratioConn.conn
downloadRatios[index] = ratioConn.downloadRatio
uploadRatios[index] = ratioConn.uploadRatio
totalDownloadRatio += ratioConn.downloadRatio
totalUploadRatio += ratioConn.uploadRatio
delete(ratioConns, index)
}
if totalDownloadRatio != 100 || totalUploadRatio != 100 {
for _, conn := range conns {
conn.Close()
}
return E.New("invalid ratios")
}
conn = NewBondedConn(conns, downloadRatios, uploadRatios)
metadata.Inbound = h.Tag()
metadata.InboundType = C.TypeBond
metadata.Destination = request.Destination
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}
return nil
}
type ratioConn struct {
conn net.Conn
downloadRatio uint8
uploadRatio uint8
}

152
protocol/bond/outbound.go Normal file
View File

@@ -0,0 +1,152 @@
package bond
import (
"context"
"errors"
"net"
"sync"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.BondOutboundOptions](registry, C.TypeBond, NewOutbound)
}
type Outbound struct {
outbound.Adapter
ctx context.Context
logger logger.ContextLogger
outbounds []adapter.Outbound
downloadRatios []uint8
uploadRatios []uint8
uotClient *uot.Client
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BondOutboundOptions) (adapter.Outbound, error) {
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
outbounds := make([]adapter.Outbound, 0, len(options.Outbounds))
downloadRatios := make([]uint8, 0, len(options.Outbounds))
uploadRatios := make([]uint8, 0, len(options.Outbounds))
var totalDownloadRatio, totalUploadRatio uint8
for _, outboundOptions := range options.Outbounds {
count := outboundOptions.Count
if count == 0 {
count = 1
}
for range count {
outbound, err := outboundRegistry.UnsafeCreateOutbound(ctx, router, logger, outboundOptions.Outbound.Tag, outboundOptions.Outbound.Type, outboundOptions.Outbound.Options)
if err != nil {
return nil, err
}
outbounds = append(outbounds, outbound)
downloadRatios = append(downloadRatios, outboundOptions.DownloadRatio)
uploadRatios = append(uploadRatios, outboundOptions.UploadRatio)
totalDownloadRatio += outboundOptions.DownloadRatio
totalUploadRatio += outboundOptions.UploadRatio
}
}
if totalDownloadRatio != 100 || totalUploadRatio != 100 {
return nil, E.New("invalid ratios")
}
outbound := &Outbound{
Adapter: outbound.NewAdapter(C.TypeTunnelClient, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
ctx: ctx,
outbounds: outbounds,
downloadRatios: downloadRatios,
uploadRatios: uploadRatios,
logger: logger,
}
outbound.uotClient = &uot.Client{
Dialer: outbound,
Version: uot.Version,
}
return outbound, nil
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if N.NetworkName(network) == N.NetworkUDP {
return h.uotClient.DialContext(ctx, network, destination)
}
conns := make([]net.Conn, len(h.outbounds))
connUUID, err := uuid.NewV4()
if err != nil {
return nil, err
}
errs := make([]error, 0, len(conns))
var mtx sync.Mutex
var wg sync.WaitGroup
for i, outbound := range h.outbounds {
wg.Go(
func() {
conn, err := outbound.DialContext(ctx, network, Destination)
if err != nil {
mtx.Lock()
errs = append(errs, err)
mtx.Unlock()
return
}
err = WriteRequest(
conn,
&Request{
UUID: connUUID,
Index: byte(i),
Count: byte(len(h.outbounds)),
DownloadRatio: h.uploadRatios[i],
UploadRatio: h.downloadRatios[i],
Destination: destination,
},
)
if err != nil {
conn.Close()
mtx.Lock()
errs = append(errs, err)
mtx.Unlock()
return
}
conns[i] = conn
},
)
}
wg.Wait()
if len(errs) != 0 {
for _, conn := range conns {
if conn != nil {
conn.Close()
}
}
return nil, errors.Join(errs...)
}
conn := NewBondedConn(conns, h.downloadRatios, h.uploadRatios)
return conn, nil
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return h.uotClient.ListenPacket(ctx, destination)
}
func (h *Outbound) Close() error {
errs := make([]error, 0)
for _, outbound := range h.outbounds {
err := common.Close(outbound)
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}

100
protocol/bond/protocol.go Normal file
View File

@@ -0,0 +1,100 @@
package bond
import (
"encoding/binary"
"io"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
const (
Version = 0
)
var Destination = M.Socksaddr{
Fqdn: "sp.bond.sing-box.arpa",
Port: 444,
}
var AddressSerializer = M.NewSerializer(
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
M.AddressFamilyByte(0x03, M.AddressFamilyIPv6),
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
M.PortThenAddress(),
)
type Request struct {
UUID uuid.UUID
Index byte
Count byte
DownloadRatio byte
UploadRatio byte
Destination M.Socksaddr
}
func ReadRequest(reader io.Reader) (*Request, error) {
var request Request
var version uint8
err := binary.Read(reader, binary.BigEndian, &version)
if err != nil {
return nil, err
}
if version != Version {
return nil, E.New("unknown version: ", version)
}
_, err = io.ReadFull(reader, request.UUID[:])
if err != nil {
return nil, err
}
err = binary.Read(reader, binary.BigEndian, &request.Index)
if err != nil {
return nil, err
}
err = binary.Read(reader, binary.BigEndian, &request.Count)
if err != nil {
return nil, err
}
err = binary.Read(reader, binary.BigEndian, &request.DownloadRatio)
if err != nil {
return nil, err
}
err = binary.Read(reader, binary.BigEndian, &request.UploadRatio)
if err != nil {
return nil, err
}
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
return &request, nil
}
func WriteRequest(writer io.Writer, request *Request) error {
var requestLen int
requestLen += 1 // version
requestLen += 16 // UUID
requestLen += 1 // index
requestLen += 1 // count
requestLen += 1 // download ratio
requestLen += 1 // upload ratio
requestLen += AddressSerializer.AddrPortLen(request.Destination)
buffer := buf.NewSize(requestLen)
defer buffer.Release()
common.Must(
buffer.WriteByte(Version),
common.Error(buffer.Write(request.UUID[:])),
buffer.WriteByte(request.Index),
buffer.WriteByte(request.Count),
buffer.WriteByte(request.DownloadRatio),
buffer.WriteByte(request.UploadRatio),
)
err := AddressSerializer.WriteAddrPort(buffer, request.Destination)
if err != nil {
return err
}
return common.Error(writer.Write(buffer.Bytes()))
}

41
protocol/bond/router.go Normal file
View File

@@ -0,0 +1,41 @@
package bond
import (
"context"
"net"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
)
type Router struct {
adapter.Router
logger logger.ContextLogger
handler func(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) error
}
func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) error) *Router {
return &Router{Router: router, logger: logger, handler: handler}
}
func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return r.handler(ctx, conn, metadata, func(error) {})
}
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return os.ErrInvalid
}
func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if err := r.handler(ctx, conn, metadata, onClose); err != nil {
r.logger.ErrorContext(ctx, err)
N.CloseOnHandshakeFailure(conn, onClose, err)
}
}
func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
r.logger.ErrorContext(ctx, os.ErrInvalid)
N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid)
}

View File

@@ -0,0 +1,96 @@
package group
import (
"context"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterFailover(registry *outbound.Registry) {
outbound.Register[option.FailoverOutboundOptions](registry, C.TypeFailover, NewFailover)
}
var (
_ adapter.OutboundGroup = (*Failover)(nil)
)
type Failover struct {
outbound.Adapter
ctx context.Context
outbound adapter.OutboundManager
logger logger.ContextLogger
tags []string
outbounds map[string]adapter.Outbound
}
func NewFailover(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FailoverOutboundOptions) (adapter.Outbound, error) {
if len(options.Outbounds) == 0 {
return nil, E.New("missing tags")
}
outbound := &Failover{
Adapter: outbound.NewAdapter(C.TypeFailover, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
ctx: ctx,
outbound: service.FromContext[adapter.OutboundManager](ctx),
logger: logger,
tags: options.Outbounds,
outbounds: make(map[string]adapter.Outbound, len(options.Outbounds)),
}
return outbound, nil
}
func (s *Failover) Start() error {
for i, tag := range s.tags {
outbound, loaded := s.outbound.Outbound(tag)
if !loaded {
return E.New("outbound ", i, " not found: ", tag)
}
s.outbounds[tag] = outbound
}
return nil
}
func (s *Failover) Now() string {
return s.tags[0]
}
func (s *Failover) All() []string {
return s.tags
}
func (s *Failover) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
var conn net.Conn
var err error
for _, outbound := range s.outbounds {
conn, err = outbound.DialContext(ctx, network, destination)
if err != nil {
s.logger.ErrorContext(ctx, err)
continue
}
return conn, nil
}
return nil, err
}
func (s *Failover) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
var conn net.PacketConn
var err error
for _, outbound := range s.outbounds {
conn, err = outbound.ListenPacket(ctx, destination)
if err != nil {
s.logger.ErrorContext(ctx, err)
continue
}
return conn, nil
}
return nil, err
}

View File

@@ -180,3 +180,11 @@ func (h *Inbound) Close() error {
common.PtrOrNil(h.service),
)
}
func (h *Inbound) UpdateUsers(users []option.HysteriaUser) {
h.service.UpdateUsers(common.MapIndexed(users, func(index int, _ option.HysteriaUser) int {
return index
}), common.Map(users, func(it option.HysteriaUser) string {
return it.AuthString
}))
}

View File

@@ -213,3 +213,11 @@ func (h *Inbound) Close() error {
common.PtrOrNil(h.service),
)
}
func (h *Inbound) UpdateUsers(users []option.Hysteria2User) {
h.service.UpdateUsers(common.MapIndexed(users, func(index int, _ option.Hysteria2User) int {
return index
}), common.Map(users, func(it option.Hysteria2User) string {
return it.Password
}))
}

View File

@@ -0,0 +1,158 @@
package bandwidth
import (
"context"
"net"
"golang.org/x/time/rate"
)
type connWithDownloadBandwidthLimiter struct {
net.Conn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithDownloadBandwidthLimiter {
return &connWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error) {
var nn int
for {
end := len(p)
if end == 0 {
break
}
if conn.burst < len(p) {
end = conn.burst
}
err = conn.limiter.WaitN(conn.ctx, end)
if err != nil {
return
}
nn, err = conn.Conn.Write(p[:end])
n += nn
if err != nil {
return
}
p = p[end:]
}
return
}
type connWithUploadBandwidthLimiter struct {
net.Conn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithUploadBandwidthLimiter {
return &connWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) {
if conn.burst < len(p) {
p = p[:conn.burst]
}
n, err = conn.Conn.Read(p)
if err != nil {
return
}
err = conn.limiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return
}
type connWithCloseHandler struct {
net.Conn
onClose CloseHandlerFunc
}
func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler {
return &connWithCloseHandler{conn, onClose}
}
func (conn *connWithCloseHandler) Close() error {
conn.onClose()
return conn.Conn.Close()
}
type packetConnWithDownloadBandwidthLimiter struct {
net.PacketConn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithDownloadBandwidthLimiter {
return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
var nn int
for {
end := len(p)
if end == 0 {
break
}
if conn.burst < len(p) {
end = conn.burst
}
err = conn.limiter.WaitN(conn.ctx, end)
if err != nil {
return
}
nn, err = conn.PacketConn.WriteTo(p[:end], addr)
n += nn
if err != nil {
return
}
p = p[end:]
}
return
}
type packetConnWithUploadBandwidthLimiter struct {
net.PacketConn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithUploadBandwidthLimiter {
return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if conn.burst < len(p) {
p = p[:conn.burst]
}
n, addr, err = conn.PacketConn.ReadFrom(p)
if err != nil {
return
}
err = conn.limiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return
}
type packetConnWithCloseHandler struct {
net.PacketConn
onClose CloseHandlerFunc
}
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler {
return &packetConnWithCloseHandler{conn, onClose}
}
func (conn *packetConnWithCloseHandler) Close() error {
conn.onClose()
return conn.PacketConn.Close()
}

View File

@@ -0,0 +1,146 @@
package bandwidth
import (
"context"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/route"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.BandwidthLimiterOutboundOptions](registry, C.TypeBandwidthLimiter, NewOutbound)
}
type Outbound struct {
outbound.Adapter
ctx context.Context
outbound adapter.OutboundManager
connection adapter.ConnectionManager
logger logger.ContextLogger
strategy BandwidthStrategy
outboundTag string
detour adapter.Outbound
router *route.Router
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BandwidthLimiterOutboundOptions) (adapter.Outbound, error) {
if options.Strategy == "" {
return nil, E.New("missing strategy")
}
if options.Route.Final == "" {
return nil, E.New("missing final outbound")
}
var strategy BandwidthStrategy
var err error
switch options.Strategy {
case "users":
usersStrategies := make(map[string]BandwidthStrategy, len(options.Users))
for _, user := range options.Users {
userStrategy, err := CreateStrategy(user.Strategy, user.Mode, user.ConnectionType, options.Speed.Value())
if err != nil {
return nil, err
}
usersStrategies[user.Name] = userStrategy
}
strategy = NewUsersBandwidthStrategy(usersStrategies)
case "manager":
strategy = NewManagerBandwidthStrategy()
default:
strategy, err = CreateStrategy(options.Strategy, options.Mode, options.ConnectionType, options.Speed.Value())
if err != nil {
return nil, err
}
}
logFactory := service.FromContext[log.Factory](ctx)
r := route.NewRouter(ctx, logFactory, options.Route, option.DNSOptions{})
err = r.Initialize(options.Route.Rules, options.Route.RuleSet)
if err != nil {
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapter(C.TypeBandwidthLimiter, tag, nil, []string{}),
ctx: ctx,
outbound: service.FromContext[adapter.OutboundManager](ctx),
connection: service.FromContext[adapter.ConnectionManager](ctx),
logger: logger,
strategy: strategy,
outboundTag: options.Route.Final,
router: r,
}
return outbound, nil
}
func (h *Outbound) Network() []string {
return []string{N.NetworkTCP, N.NetworkUDP}
}
func (h *Outbound) Start() error {
detour, loaded := h.outbound.Outbound(h.outboundTag)
if !loaded {
return E.New("outbound not found: ", h.outboundTag)
}
h.detour = detour
for _, stage := range []adapter.StartStage{adapter.StartStateStart, adapter.StartStatePostStart, adapter.StartStateStarted} {
err := h.router.Start(stage)
if err != nil {
return err
}
}
return nil
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
conn, err := h.detour.DialContext(ctx, network, destination)
if err != nil {
return nil, err
}
return h.strategy.wrapConn(ctx, conn, adapter.ContextFrom(ctx), true)
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
conn, err := h.detour.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return h.strategy.wrapPacketConn(ctx, conn, adapter.ContextFrom(ctx), true)
}
func (h *Outbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
conn, err := h.strategy.wrapConn(ctx, conn, &metadata, false)
if err != nil {
h.logger.ErrorContext(ctx, err)
return
}
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return
}
func (h *Outbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
packetConn, err := h.strategy.wrapPacketConn(ctx, bufio.NewNetPacketConn(conn), &metadata, false)
if err != nil {
h.logger.ErrorContext(ctx, err)
return
}
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
h.router.RoutePacketConnectionEx(ctx, bufio.NewPacketConn(packetConn), metadata, onClose)
return
}
func (h *Outbound) GetStrategy() BandwidthStrategy {
return h.strategy
}

View File

@@ -0,0 +1,266 @@
package bandwidth
import (
"context"
"net"
"strconv"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/time/rate"
)
type (
CloseHandlerFunc = func()
ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter *rate.Limiter, reverse bool) net.Conn
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter, reverse bool) net.PacketConn
)
type BandwidthStrategy interface {
wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error)
wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error)
}
type BandwidthLimiterStrategy interface {
getLimiter(ctx context.Context, metadata *adapter.InboundContext) (*rate.Limiter, CloseHandlerFunc, error)
}
type DefaultWrapStrategy struct {
limiterStrategy BandwidthLimiterStrategy
connWrapper ConnWrapper
packetConnWrapper PacketConnWrapper
}
func NewDefaultWrapStrategy(limiterStrategy BandwidthLimiterStrategy, connWrapper ConnWrapper, packetConnWrapper PacketConnWrapper) *DefaultWrapStrategy {
return &DefaultWrapStrategy{limiterStrategy, connWrapper, packetConnWrapper}
}
func (s *DefaultWrapStrategy) wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error) {
limiter, onClose, err := s.limiterStrategy.getLimiter(ctx, metadata)
if err != nil {
return nil, err
}
return NewConnWithCloseHandler(s.connWrapper(ctx, conn, limiter, reverse), onClose), nil
}
func (s *DefaultWrapStrategy) wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error) {
limiter, onClose, err := s.limiterStrategy.getLimiter(ctx, metadata)
if err != nil {
return nil, err
}
return NewPacketConnWithCloseHandler(s.packetConnWrapper(ctx, conn, limiter, reverse), onClose), nil
}
type GlobalBandwidthStrategy struct {
limiter *rate.Limiter
}
func NewGlobalBandwidthStrategy(speed uint64) *GlobalBandwidthStrategy {
return &GlobalBandwidthStrategy{
limiter: createSpeedLimiter(speed),
}
}
func (s *GlobalBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (*rate.Limiter, CloseHandlerFunc, error) {
return s.limiter, func() {}, nil
}
type idBandwidthLimiter struct {
limiter *rate.Limiter
handles uint32
}
type ConnectionBandwidthStrategy struct {
limiters map[string]*idBandwidthLimiter
connIDGetter ConnIDGetter
speed uint64
mtx sync.Mutex
}
func NewConnectionBandwidthStrategy(connIDGetter ConnIDGetter, speed uint64) *ConnectionBandwidthStrategy {
return &ConnectionBandwidthStrategy{
limiters: make(map[string]*idBandwidthLimiter),
connIDGetter: connIDGetter,
speed: speed,
}
}
func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (*rate.Limiter, CloseHandlerFunc, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
id, ok := s.connIDGetter(ctx, metadata)
if !ok {
return nil, nil, E.New("id not found")
}
limiter, ok := s.limiters[id]
if !ok {
limiter = &idBandwidthLimiter{
limiter: createSpeedLimiter(s.speed),
}
s.limiters[id] = limiter
}
limiter.handles++
var once sync.Once
return limiter.limiter, func() {
once.Do(func() {
s.mtx.Lock()
defer s.mtx.Unlock()
limiter.handles--
if limiter.handles == 0 {
delete(s.limiters, id)
}
})
}, nil
}
type UsersBandwidthStrategy struct {
strategies map[string]BandwidthStrategy
mtx sync.Mutex
}
func NewUsersBandwidthStrategy(strategies map[string]BandwidthStrategy) *UsersBandwidthStrategy {
return &UsersBandwidthStrategy{
strategies: strategies,
}
}
func (s *UsersBandwidthStrategy) wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error) {
strategy, err := s.getStrategy(ctx, metadata)
if err != nil {
return nil, err
}
return strategy.wrapConn(ctx, conn, metadata, reverse)
}
func (s *UsersBandwidthStrategy) wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error) {
strategy, err := s.getStrategy(ctx, metadata)
if err != nil {
return nil, err
}
return strategy.wrapPacketConn(ctx, conn, metadata, reverse)
}
func (s *UsersBandwidthStrategy) getStrategy(ctx context.Context, metadata *adapter.InboundContext) (BandwidthStrategy, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
var user string
if metadata != nil {
user = metadata.User
}
strategy, ok := s.strategies[user]
if ok {
return strategy, nil
}
return nil, E.New("user strategy not found: ", user)
}
type ManagerBandwidthStrategy struct {
*UsersBandwidthStrategy
}
func NewManagerBandwidthStrategy() *ManagerBandwidthStrategy {
return &ManagerBandwidthStrategy{
UsersBandwidthStrategy: NewUsersBandwidthStrategy(map[string]BandwidthStrategy{}),
}
}
func (s *ManagerBandwidthStrategy) UpdateStrategies(strategies map[string]BandwidthStrategy) {
s.mtx.Lock()
defer s.mtx.Unlock()
s.strategies = strategies
}
func CreateStrategy(strategy string, mode string, connectionType string, speed uint64) (BandwidthStrategy, error) {
var limiterStrategy BandwidthLimiterStrategy
switch strategy {
case "global":
limiterStrategy = NewGlobalBandwidthStrategy(speed)
case "connection":
var connIDGetter ConnIDGetter
switch connectionType {
case "mux":
connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
id, ok := log.MuxIDFromContext(ctx)
if !ok {
return "", ok
}
return strconv.FormatUint(uint64(id.ID), 10), ok
}
case "hwid":
connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
id, ok := ctx.Value("hwid").(string)
return id, ok
}
case "ip":
connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
return metadata.Source.IPAddr().String(), true
}
default:
return nil, E.New("connection type not found: ", connectionType)
}
limiterStrategy = NewConnectionBandwidthStrategy(connIDGetter, speed)
default:
return nil, E.New("strategy not found: ", strategy)
}
var (
connWrapper ConnWrapper
packetConnWrapper PacketConnWrapper
)
switch mode {
case "download":
connWrapper = connWithDownloadBandwidthWrapper
packetConnWrapper = packetConnWithDownloadBandwidthWrapper
case "upload":
connWrapper = connWithUploadBandwidthWrapper
packetConnWrapper = packetConnWithUploadBandwidthWrapper
case "duplex":
connWrapper = connWithDuplexBandwidthWrapper
packetConnWrapper = packetConnWithDuplexBandwidthWrapper
default:
return nil, E.New("mode not found: ", mode)
}
return NewDefaultWrapStrategy(limiterStrategy, connWrapper, packetConnWrapper), nil
}
func createSpeedLimiter(speed uint64) *rate.Limiter {
return rate.NewLimiter(rate.Limit(float64(speed)), 10000)
}
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter *rate.Limiter, reverse bool) net.Conn {
if reverse {
return NewConnWithUploadBandwidthLimiter(ctx, conn, limiter)
}
return NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
}
func connWithUploadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter *rate.Limiter, reverse bool) net.Conn {
if reverse {
return NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
}
return NewConnWithUploadBandwidthLimiter(ctx, conn, limiter)
}
func connWithDuplexBandwidthWrapper(ctx context.Context, conn net.Conn, limiter *rate.Limiter, reverse bool) net.Conn {
return NewConnWithUploadBandwidthLimiter(ctx, NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter), limiter)
}
func packetConnWithDownloadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter, reverse bool) net.PacketConn {
if reverse {
return NewPacketConnWithUploadBandwidthLimiter(ctx, conn, limiter)
}
return NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
}
func packetConnWithUploadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter, reverse bool) net.PacketConn {
if reverse {
return NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
}
return NewPacketConnWithUploadBandwidthLimiter(ctx, conn, limiter)
}
func packetConnWithDuplexBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter, reverse bool) net.PacketConn {
return NewPacketConnWithUploadBandwidthLimiter(ctx, NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter), limiter)
}

View File

@@ -0,0 +1,37 @@
package connection
import (
"context"
"sync"
E "github.com/sagernet/sing/common/exceptions"
)
func NewDefaultLock(max uint32) LockIDGetter {
locks := make(map[string]*uint32)
mtx := sync.Mutex{}
return func(id string) (CloseHandlerFunc, context.Context, error) {
mtx.Lock()
defer mtx.Unlock()
handles, ok := locks[id]
if !ok {
if len(locks) == int(max) {
return nil, nil, E.New("not enough free locks")
}
handles = new(uint32)
locks[id] = handles
}
*handles++
var once sync.Once
return func() {
once.Do(func() {
mtx.Lock()
defer mtx.Unlock()
*handles--
if *handles == 0 {
delete(locks, id)
}
})
}, nil, nil
}
}

View File

@@ -0,0 +1,204 @@
package connection
import (
"context"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/route"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.ConnectionLimiterOutboundOptions](registry, C.TypeConnectionLimiter, NewOutbound)
}
type Outbound struct {
outbound.Adapter
ctx context.Context
outbound adapter.OutboundManager
connection adapter.ConnectionManager
logger logger.ContextLogger
strategy ConnectionStrategy
outboundTag string
detour adapter.Outbound
router *route.Router
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ConnectionLimiterOutboundOptions) (adapter.Outbound, error) {
if options.Strategy == "" {
return nil, E.New("missing strategy")
}
if options.Route.Final == "" {
return nil, E.New("missing final outbound")
}
var strategy ConnectionStrategy
var err error
switch options.Strategy {
case "users":
usersStrategies := make(map[string]ConnectionStrategy, len(options.Users))
for _, user := range options.Users {
userStrategy, err := CreateStrategy(user.Strategy, user.ConnectionType, NewDefaultLock(user.Count))
if err != nil {
return nil, err
}
usersStrategies[user.Name] = userStrategy
}
strategy = NewUsersConnectionStrategy(usersStrategies)
case "manager":
strategy = NewManagerConnectionStrategy()
default:
strategy, err = CreateStrategy(options.Strategy, options.ConnectionType, NewDefaultLock(options.Count))
if err != nil {
return nil, err
}
}
logFactory := service.FromContext[log.Factory](ctx)
r := route.NewRouter(ctx, logFactory, options.Route, option.DNSOptions{})
err = r.Initialize(options.Route.Rules, options.Route.RuleSet)
if err != nil {
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapter(C.TypeConnectionLimiter, tag, nil, []string{}),
ctx: ctx,
outbound: service.FromContext[adapter.OutboundManager](ctx),
connection: service.FromContext[adapter.ConnectionManager](ctx),
logger: logger,
outboundTag: options.Route.Final,
strategy: strategy,
router: r,
}
return outbound, nil
}
func (h *Outbound) Network() []string {
return []string{N.NetworkTCP, N.NetworkUDP}
}
func (h *Outbound) Start() error {
detour, loaded := h.outbound.Outbound(h.outboundTag)
if !loaded {
return E.New("outbound not found: ", h.outboundTag)
}
h.detour = detour
for _, stage := range []adapter.StartStage{adapter.StartStateStart, adapter.StartStatePostStart, adapter.StartStateStarted} {
err := h.router.Start(stage)
if err != nil {
return err
}
}
return nil
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
onClose, lockCtx, err := h.strategy.request(ctx, adapter.ContextFrom(ctx))
if err != nil {
return nil, err
}
conn, err := h.detour.DialContext(ctx, network, destination)
if err != nil {
onClose()
return nil, err
}
conn = newConnWithCloseHandlerFunc(conn, onClose)
if lockCtx != nil {
go connChecker(lockCtx, conn.Close)
}
return conn, nil
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
onClose, lockCtx, err := h.strategy.request(ctx, adapter.ContextFrom(ctx))
if err != nil {
return nil, err
}
conn, err := h.detour.ListenPacket(ctx, destination)
if err != nil {
onClose()
return nil, err
}
conn = newPacketConnWithCloseHandlerFunc(conn, onClose)
if lockCtx != nil {
go connChecker(lockCtx, conn.Close)
}
return conn, nil
}
func (h *Outbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
limiterOnClose, lockCtx, err := h.strategy.request(ctx, &metadata)
if err != nil {
h.logger.ErrorContext(ctx, err)
return
}
conn = newConnWithCloseHandlerFunc(conn, limiterOnClose)
if lockCtx != nil {
go connChecker(lockCtx, conn.Close)
}
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return
}
func (h *Outbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
limiterOnClose, lockCtx, err := h.strategy.request(ctx, &metadata)
if err != nil {
h.logger.ErrorContext(ctx, err)
return
}
conn = bufio.NewPacketConn(newPacketConnWithCloseHandlerFunc(bufio.NewNetPacketConn(conn), limiterOnClose))
if lockCtx != nil {
go connChecker(lockCtx, conn.Close)
}
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
return
}
func (h *Outbound) GetStrategy() ConnectionStrategy {
return h.strategy
}
type connWithCloseHandlerFunc struct {
net.Conn
onClose CloseHandlerFunc
}
func newConnWithCloseHandlerFunc(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandlerFunc {
return &connWithCloseHandlerFunc{conn, onClose}
}
func (conn *connWithCloseHandlerFunc) Close() error {
conn.onClose()
return conn.Conn.Close()
}
type packetConnWithCloseHandlerFunc struct {
net.PacketConn
onClose CloseHandlerFunc
}
func newPacketConnWithCloseHandlerFunc(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandlerFunc {
return &packetConnWithCloseHandlerFunc{conn, onClose}
}
func (conn *packetConnWithCloseHandlerFunc) Close() error {
conn.onClose()
return conn.PacketConn.Close()
}
func connChecker(ctx context.Context, closeFunc func() error) {
<-ctx.Done()
closeFunc()
}

View File

@@ -0,0 +1,119 @@
package connection
import (
"context"
"strconv"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
type (
CloseHandlerFunc = func()
ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)
LockIDGetter = func(string) (CloseHandlerFunc, context.Context, error)
ConnectionStrategy interface {
request(ctx context.Context, metadata *adapter.InboundContext) (onClose CloseHandlerFunc, lockCtx context.Context, err error)
}
)
type DefaultConnectionStrategy struct {
connIDGetter ConnIDGetter
lockIDGetter LockIDGetter
mtx sync.Mutex
}
func NewDefaultConnectionStrategy(connIDGetter ConnIDGetter, lockIDGetter LockIDGetter) *DefaultConnectionStrategy {
outbound := &DefaultConnectionStrategy{
connIDGetter: connIDGetter,
lockIDGetter: lockIDGetter,
}
return outbound
}
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
id, ok := s.connIDGetter(ctx, metadata)
if !ok {
return nil, nil, E.New("id not found")
}
return s.lockIDGetter(id)
}
type UsersConnectionStrategy struct {
strategies map[string]ConnectionStrategy
mtx sync.Mutex
}
func NewUsersConnectionStrategy(strategies map[string]ConnectionStrategy) *UsersConnectionStrategy {
return &UsersConnectionStrategy{
strategies: strategies,
}
}
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
var user string
if metadata != nil {
user = metadata.User
}
strategy, ok := s.strategies[user]
if ok {
return strategy.request(ctx, metadata)
}
return nil, nil, E.New("user strategy not found: ", user)
}
type ManagerConnectionStrategy struct {
*UsersConnectionStrategy
}
func NewManagerConnectionStrategy() *ManagerConnectionStrategy {
return &ManagerConnectionStrategy{
UsersConnectionStrategy: NewUsersConnectionStrategy(map[string]ConnectionStrategy{}),
}
}
func (s *ManagerConnectionStrategy) UpdateStrategies(strategies map[string]ConnectionStrategy) {
s.mtx.Lock()
defer s.mtx.Unlock()
s.strategies = strategies
}
func CreateStrategy(strategy string, connectionType string, lockIDGetter LockIDGetter) (ConnectionStrategy, error) {
switch strategy {
case "connection":
var connIDGetter ConnIDGetter
switch connectionType {
case "mux":
connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
id, ok := log.MuxIDFromContext(ctx)
if !ok {
return "", ok
}
return strconv.FormatUint(uint64(id.ID), 10), ok
}
case "hwid":
connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
id, ok := ctx.Value("hwid").(string)
return id, ok
}
case "ip":
connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
return metadata.Source.IPAddr().String(), true
}
default:
return nil, E.New("connection type not found: ", connectionType)
}
return NewDefaultConnectionStrategy(connIDGetter, lockIDGetter), nil
default:
return nil, E.New("strategy not found: ", strategy)
}
}

View File

@@ -158,6 +158,14 @@ func (h *Inbound) Close() error {
)
}
func (h *Inbound) UpdateUsers(users []option.TrojanUser) {
h.service.UpdateUsers(common.MapIndexed(users, func(index int, _ option.TrojanUser) int {
return index
}), common.Map(users, func(it option.TrojanUser) string {
return it.Password
}))
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if h.tlsConfig != nil && h.transport == nil {
tlsConn, err := tls.ServerHandshake(ctx, conn, h.tlsConfig)

View File

@@ -170,3 +170,13 @@ func (h *Inbound) Close() error {
common.PtrOrNil(h.server),
)
}
func (h *Inbound) UpdateUsers(users []option.TUICUser) {
h.server.UpdateUsers(common.MapIndexed(users, func(index int, _ option.TUICUser) int {
return index
}), common.Map(users, func(it option.TUICUser) [16]byte {
return [16]byte(uuid.Must(uuid.FromString(it.UUID)).Bytes())
}), common.Map(users, func(it option.TUICUser) string {
return it.Password
}))
}

View File

@@ -3,13 +3,13 @@ package tunnel
import (
"context"
"net"
"os"
"time"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound"
sbUot "github.com/sagernet/sing-box/common/uot"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
@@ -18,6 +18,7 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
)
@@ -27,12 +28,13 @@ func RegisterClientEndpoint(registry *endpoint.Registry) {
type ClientEndpoint struct {
outbound.Adapter
ctx context.Context
outbound adapter.Outbound
router adapter.ConnectionRouterEx
logger logger.ContextLogger
uuid uuid.UUID
key uuid.UUID
ctx context.Context
outbound adapter.Outbound
router adapter.ConnectionRouterEx
logger logger.ContextLogger
uuid uuid.UUID
key uuid.UUID
uotClient *uot.Client
}
func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelClientEndpointOptions) (adapter.Endpoint, error) {
@@ -45,9 +47,9 @@ func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.Co
return nil, err
}
client := &ClientEndpoint{
Adapter: outbound.NewAdapter(C.TypeTunnelClient, tag, []string{N.NetworkTCP}, []string{}),
Adapter: outbound.NewAdapter(C.TypeTunnelClient, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
ctx: ctx,
router: router,
router: sbUot.NewRouter(router, logger),
logger: logger,
uuid: clientUUID,
key: clientKey,
@@ -58,6 +60,10 @@ func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.Co
return nil, err
}
client.outbound = outbound
client.uotClient = &uot.Client{
Dialer: outbound,
Version: uot.Version,
}
return client, nil
}
@@ -85,8 +91,8 @@ func (c *ClientEndpoint) Start(stage adapter.StartStage) error {
}
func (c *ClientEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if network != N.NetworkTCP {
return nil, os.ErrInvalid
if N.NetworkName(network) == N.NetworkUDP {
return c.uotClient.DialContext(ctx, network, destination)
}
var destinationUUID *uuid.UUID
if metadata := adapter.ContextFrom(ctx); metadata != nil {
@@ -109,11 +115,14 @@ func (c *ClientEndpoint) DialContext(ctx context.Context, network string, destin
return nil, err
}
err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandTCP, DestinationUUID: *destinationUUID, Destination: destination})
return conn, err
if err != nil {
return nil, err
}
return conn, nil
}
func (c *ClientEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return nil, os.ErrInvalid
return c.uotClient.ListenPacket(ctx, destination)
}
func (c *ClientEndpoint) Close() error {
@@ -139,6 +148,7 @@ func (c *ClientEndpoint) startInboundConn() error {
func (c *ClientEndpoint) connHandler(conn net.Conn, request *Request) {
metadata := adapter.InboundContext{
Inbound: c.Tag(),
Source: M.ParseSocksaddr(conn.RemoteAddr().String()),
Destination: request.Destination,
}

View File

@@ -3,14 +3,13 @@ package tunnel
import (
"context"
"net"
"os"
"sync"
"time"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound"
sbUot "github.com/sagernet/sing-box/common/uot"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
@@ -19,6 +18,7 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
)
@@ -28,16 +28,15 @@ func RegisterServerEndpoint(registry *endpoint.Registry) {
type ServerEndpoint struct {
outbound.Adapter
logger logger.ContextLogger
inbound adapter.Inbound
router adapter.Router
uuid uuid.UUID
users map[uuid.UUID]uuid.UUID
keys map[uuid.UUID]uuid.UUID
conns map[uuid.UUID]chan net.Conn
timeout time.Duration
mtx sync.Mutex
logger logger.ContextLogger
inbound adapter.Inbound
router adapter.ConnectionRouterEx
uuid uuid.UUID
users map[uuid.UUID]uuid.UUID
keys map[uuid.UUID]uuid.UUID
conns map[uuid.UUID]chan net.Conn
timeout time.Duration
uotClient *uot.Client
}
func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelServerEndpointOptions) (adapter.Endpoint, error) {
@@ -46,9 +45,9 @@ func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.Co
return nil, err
}
server := &ServerEndpoint{
Adapter: outbound.NewAdapter(C.TypeTunnelServer, tag, []string{N.NetworkTCP}, []string{}),
Adapter: outbound.NewAdapter(C.TypeTunnelServer, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
logger: logger,
router: router,
router: sbUot.NewRouter(router, logger),
uuid: serverUUID,
}
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
@@ -78,6 +77,10 @@ func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.Co
} else {
server.timeout = C.TCPConnectTimeout
}
server.uotClient = &uot.Client{
Dialer: server,
Version: uot.Version,
}
return server, nil
}
@@ -86,8 +89,8 @@ func (s *ServerEndpoint) Start(stage adapter.StartStage) error {
}
func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if network != N.NetworkTCP {
return nil, os.ErrInvalid
if N.NetworkName(network) == N.NetworkUDP {
return s.uotClient.DialContext(ctx, network, destination)
}
var sourceUUID *uuid.UUID
var ch chan net.Conn
@@ -97,13 +100,11 @@ func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destin
if err != nil {
return nil, err
}
s.mtx.Lock()
var ok bool
ch, ok = s.conns[tunnelDestination]
if !ok {
return nil, E.New("user ", metadata.TunnelDestination, " not found")
}
s.mtx.Unlock()
}
if metadata.TunnelSource != "" {
tunnelSource, err := uuid.FromString(metadata.TunnelSource)
@@ -131,6 +132,7 @@ func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destin
case conn := <-ch:
err := WriteRequest(conn, &Request{UUID: *sourceUUID, Command: CommandTCP, Destination: destination})
if err != nil {
conn.Close()
s.logger.ErrorContext(ctx, err)
continue
}
@@ -142,7 +144,7 @@ func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destin
}
func (s *ServerEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return nil, os.ErrInvalid
return s.uotClient.ListenPacket(ctx, destination)
}
func (s *ServerEndpoint) Close() error {
@@ -159,8 +161,6 @@ func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadat
return err
}
if request.Command == CommandInbound {
s.mtx.Lock()
defer s.mtx.Unlock()
uuid, ok := s.users[request.UUID]
if !ok {
return E.New("key ", request.UUID.String(), " not found")
@@ -183,14 +183,12 @@ func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadat
if sourceUUID == request.DestinationUUID {
return E.New("routing loop on ", sourceUUID)
}
s.mtx.Lock()
if request.DestinationUUID != s.uuid {
_, ok = s.keys[request.DestinationUUID]
if !ok {
return E.New("user ", sourceUUID, " not found")
return E.New("user ", request.DestinationUUID, " not found")
}
}
s.mtx.Unlock()
metadata.Inbound = s.Tag()
metadata.InboundType = C.TypeTunnelServer
metadata.Destination = request.Destination

View File

@@ -138,6 +138,17 @@ func (h *Inbound) Close() error {
)
}
func (h *Inbound) UpdateUsers(users []option.VLESSUser) {
h.users = users
h.service.UpdateUsers(common.MapIndexed(users, func(index int, _ option.VLESSUser) int {
return index
}), common.Map(users, func(it option.VLESSUser) string {
return it.UUID
}), common.Map(users, func(it option.VLESSUser) string {
return it.Flow
}))
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if h.tlsConfig != nil && h.transport == nil {
tlsConn, err := tls.ServerHandshake(ctx, conn, h.tlsConfig)

View File

@@ -153,6 +153,16 @@ func (h *Inbound) Close() error {
)
}
func (h *Inbound) UpdateUsers(users []option.VMessUser) {
h.service.UpdateUsers(common.MapIndexed(users, func(index int, _ option.VMessUser) int {
return index
}), common.Map(users, func(it option.VMessUser) string {
return it.UUID
}), common.Map(users, func(it option.VMessUser) int {
return it.AlterId
}))
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if h.tlsConfig != nil && h.transport == nil {
tlsConn, err := tls.ServerHandshake(ctx, conn, h.tlsConfig)

View File

@@ -154,6 +154,8 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
netip.MustParsePrefix("0.0.0.0/0"),
netip.MustParsePrefix("::/0"),
},
PersistentKeepaliveInterval: options.PersistentKeepaliveInterval,
Reserved: options.Reserved,
},
},
MTU: 1280,