mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-17 16:44:29 +03:00
Add OpenVPN, TrustTunnel, Sudoku, inbound managers. Fixes
This commit is contained in:
@@ -3,15 +3,17 @@ package bandwidth
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
)
|
||||
|
||||
type connWithDownloadBandwidthLimiter struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter Limiter) *connWithDownloadBandwidthLimiter {
|
||||
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter BandwidthLimiter) *connWithDownloadBandwidthLimiter {
|
||||
return &connWithDownloadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
@@ -26,10 +28,10 @@ func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error)
|
||||
type connWithUploadBandwidthLimiter struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter Limiter) *connWithUploadBandwidthLimiter {
|
||||
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter BandwidthLimiter) *connWithUploadBandwidthLimiter {
|
||||
return &connWithUploadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
@@ -47,10 +49,10 @@ func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) {
|
||||
|
||||
type connWithCloseHandler struct {
|
||||
net.Conn
|
||||
onClose CloseHandlerFunc
|
||||
onClose onclose.CloseHandlerFunc
|
||||
}
|
||||
|
||||
func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler {
|
||||
func NewConnWithCloseHandler(conn net.Conn, onClose onclose.CloseHandlerFunc) *connWithCloseHandler {
|
||||
return &connWithCloseHandler{conn, onClose}
|
||||
}
|
||||
|
||||
@@ -62,10 +64,10 @@ func (conn *connWithCloseHandler) Close() error {
|
||||
type packetConnWithDownloadBandwidthLimiter struct {
|
||||
net.PacketConn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter Limiter) *packetConnWithDownloadBandwidthLimiter {
|
||||
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter) *packetConnWithDownloadBandwidthLimiter {
|
||||
return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
@@ -80,10 +82,10 @@ func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.A
|
||||
type packetConnWithUploadBandwidthLimiter struct {
|
||||
net.PacketConn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter Limiter) *packetConnWithUploadBandwidthLimiter {
|
||||
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter) *packetConnWithUploadBandwidthLimiter {
|
||||
return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
@@ -101,10 +103,10 @@ func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, add
|
||||
|
||||
type packetConnWithCloseHandler struct {
|
||||
net.PacketConn
|
||||
onClose CloseHandlerFunc
|
||||
onClose onclose.CloseHandlerFunc
|
||||
}
|
||||
|
||||
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler {
|
||||
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose onclose.CloseHandlerFunc) *packetConnWithCloseHandler {
|
||||
return &packetConnWithCloseHandler{conn, onClose}
|
||||
}
|
||||
|
||||
@@ -113,38 +115,38 @@ func (conn *packetConnWithCloseHandler) Close() error {
|
||||
return conn.PacketConn.Close()
|
||||
}
|
||||
|
||||
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
|
||||
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
|
||||
if reverse {
|
||||
return NewConnWithUploadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
return NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
|
||||
func connWithUploadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
|
||||
func connWithUploadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
|
||||
if reverse {
|
||||
return NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
return NewConnWithUploadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
|
||||
func connWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
|
||||
func connWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
|
||||
return NewConnWithUploadBandwidthLimiter(ctx, NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter), limiter)
|
||||
}
|
||||
|
||||
func packetConnWithDownloadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
|
||||
func packetConnWithDownloadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
|
||||
if reverse {
|
||||
return NewPacketConnWithUploadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
return NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
|
||||
func packetConnWithUploadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
|
||||
func packetConnWithUploadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
|
||||
if reverse {
|
||||
return NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
return NewPacketConnWithUploadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
|
||||
func packetConnWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
|
||||
func packetConnWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
|
||||
return NewPacketConnWithUploadBandwidthLimiter(ctx, NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter), limiter)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,13 @@ import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
type Limiter interface {
|
||||
type BandwidthLimiter interface {
|
||||
WaitN(ctx context.Context, n int) (err error)
|
||||
SetSpeed(speed uint64)
|
||||
}
|
||||
|
||||
type FlowKeysLimiter struct {
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
connIDGetter ConnIDGetter
|
||||
|
||||
waits map[string][]*wait
|
||||
@@ -25,7 +26,7 @@ type FlowKeysLimiter struct {
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter Limiter) *FlowKeysLimiter {
|
||||
func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FlowKeysLimiter {
|
||||
return &FlowKeysLimiter{
|
||||
limiter: limiter,
|
||||
connIDGetter: connIDGetter,
|
||||
@@ -36,6 +37,10 @@ func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter Limiter) *FlowKeysLim
|
||||
}
|
||||
}
|
||||
|
||||
func (l *FlowKeysLimiter) SetSpeed(speed uint64) {
|
||||
l.limiter.SetSpeed(speed)
|
||||
}
|
||||
|
||||
func (l *FlowKeysLimiter) WaitN(ctx context.Context, n int) error {
|
||||
id, _ := l.connIDGetter(ctx, adapter.ContextFrom(ctx))
|
||||
mainWait := &wait{ctx, make(chan struct{}), n}
|
||||
|
||||
@@ -7,16 +7,16 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type (
|
||||
CloseHandlerFunc = func()
|
||||
ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)
|
||||
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn
|
||||
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn
|
||||
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn
|
||||
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn
|
||||
)
|
||||
|
||||
type BandwidthStrategy interface {
|
||||
@@ -24,8 +24,12 @@ type BandwidthStrategy interface {
|
||||
wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error)
|
||||
}
|
||||
|
||||
type SpeedUpdater interface {
|
||||
SetSpeed(speed uint64)
|
||||
}
|
||||
|
||||
type BandwidthLimiterStrategy interface {
|
||||
getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error)
|
||||
getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error)
|
||||
}
|
||||
|
||||
type DefaultWrapStrategy struct {
|
||||
@@ -54,8 +58,14 @@ func (s *DefaultWrapStrategy) wrapPacketConn(ctx context.Context, conn net.Packe
|
||||
return NewPacketConnWithCloseHandler(s.packetConnWrapper(ctx, conn, limiter, reverse), onClose), nil
|
||||
}
|
||||
|
||||
func (s *DefaultWrapStrategy) SetSpeed(speed uint64) {
|
||||
if updater, ok := s.limiterStrategy.(SpeedUpdater); ok {
|
||||
updater.SetSpeed(speed)
|
||||
}
|
||||
}
|
||||
|
||||
type GlobalBandwidthStrategy struct {
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewGlobalBandwidthStrategy(speed uint64, flowKeys []string) (*GlobalBandwidthStrategy, error) {
|
||||
@@ -68,12 +78,16 @@ func NewGlobalBandwidthStrategy(speed uint64, flowKeys []string) (*GlobalBandwid
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GlobalBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error) {
|
||||
func (s *GlobalBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error) {
|
||||
return s.limiter, func() {}, nil
|
||||
}
|
||||
|
||||
func (s *GlobalBandwidthStrategy) SetSpeed(speed uint64) {
|
||||
s.limiter.SetSpeed(speed)
|
||||
}
|
||||
|
||||
type idBandwidthLimiter struct {
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
handles uint32
|
||||
}
|
||||
|
||||
@@ -94,7 +108,7 @@ func NewConnectionBandwidthStrategy(connIDGetter ConnIDGetter, speed uint64, flo
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error) {
|
||||
func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
id, ok := s.connIDGetter(ctx, metadata)
|
||||
@@ -126,6 +140,15 @@ func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ConnectionBandwidthStrategy) SetSpeed(speed uint64) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
s.speed = speed
|
||||
for _, limiter := range s.limiters {
|
||||
limiter.limiter.SetSpeed(speed)
|
||||
}
|
||||
}
|
||||
|
||||
type UsersBandwidthStrategy struct {
|
||||
strategies map[string]BandwidthStrategy
|
||||
mtx sync.Mutex
|
||||
@@ -167,20 +190,86 @@ func (s *UsersBandwidthStrategy) getStrategy(ctx context.Context, metadata *adap
|
||||
return nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
|
||||
type bwConnEntry struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
|
||||
|
||||
type ManagerBandwidthStrategy struct {
|
||||
*UsersBandwidthStrategy
|
||||
strategies map[string]BandwidthStrategy
|
||||
conns map[string][]*bwConnEntry
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewManagerBandwidthStrategy() *ManagerBandwidthStrategy {
|
||||
return &ManagerBandwidthStrategy{
|
||||
UsersBandwidthStrategy: NewUsersBandwidthStrategy(map[string]BandwidthStrategy{}),
|
||||
strategies: make(map[string]BandwidthStrategy),
|
||||
conns: make(map[string][]*bwConnEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerBandwidthStrategy) wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error) {
|
||||
s.mtx.Lock()
|
||||
var user string
|
||||
if metadata != nil {
|
||||
user = metadata.User
|
||||
}
|
||||
strategy, ok := s.strategies[user]
|
||||
s.mtx.Unlock()
|
||||
if !ok {
|
||||
return nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
wrapped, err := strategy.wrapConn(ctx, conn, metadata, reverse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry := &bwConnEntry{conn: conn}
|
||||
s.mtx.Lock()
|
||||
s.conns[user] = append(s.conns[user], entry)
|
||||
s.mtx.Unlock()
|
||||
return onclose.NewConn(wrapped, func() {
|
||||
s.mtx.Lock()
|
||||
entries := s.conns[user]
|
||||
for i, e := range entries {
|
||||
if e == entry {
|
||||
s.conns[user] = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *ManagerBandwidthStrategy) wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error) {
|
||||
s.mtx.Lock()
|
||||
var user string
|
||||
if metadata != nil {
|
||||
user = metadata.User
|
||||
}
|
||||
strategy, ok := s.strategies[user]
|
||||
s.mtx.Unlock()
|
||||
if !ok {
|
||||
return nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
return strategy.wrapPacketConn(ctx, conn, metadata, reverse)
|
||||
}
|
||||
|
||||
func (s *ManagerBandwidthStrategy) UpdateStrategies(strategies map[string]BandwidthStrategy) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var closedEntries []*bwConnEntry
|
||||
for user, entries := range s.conns {
|
||||
if _, exists := strategies[user]; !exists {
|
||||
closedEntries = append(closedEntries, entries...)
|
||||
delete(s.conns, user)
|
||||
}
|
||||
}
|
||||
s.strategies = strategies
|
||||
s.mtx.Unlock()
|
||||
for _, entry := range closedEntries {
|
||||
entry.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type BypassBandwidthStrategy struct{}
|
||||
@@ -263,8 +352,8 @@ func CreateStrategy(strategy string, mode string, connectionType string, speed u
|
||||
return NewDefaultWrapStrategy(limiterStrategy, connWrapper, packetConnWrapper), nil
|
||||
}
|
||||
|
||||
func createSpeedLimiter(speed uint64, flowKeys []string) (Limiter, error) {
|
||||
var limiter Limiter = rate.NewLimiter(rate.Limit(float64(speed)), 65536)
|
||||
func createSpeedLimiter(speed uint64, flowKeys []string) (BandwidthLimiter, error) {
|
||||
var limiter BandwidthLimiter = &speedLimiter{limiter: rate.NewLimiter(rate.Limit(float64(speed)), 65536)}
|
||||
for i := len(flowKeys) - 1; i >= 0; i-- {
|
||||
getter, err := flowKeysConnIDGetter(flowKeys[i])
|
||||
if err != nil {
|
||||
@@ -275,16 +364,24 @@ func createSpeedLimiter(speed uint64, flowKeys []string) (Limiter, error) {
|
||||
return limiter, nil
|
||||
}
|
||||
|
||||
type speedLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func (r *speedLimiter) WaitN(ctx context.Context, n int) error {
|
||||
return r.limiter.WaitN(ctx, n)
|
||||
}
|
||||
|
||||
func (r *speedLimiter) SetSpeed(speed uint64) {
|
||||
r.limiter.SetLimit(rate.Limit(float64(speed)))
|
||||
}
|
||||
|
||||
func flowKeysConnIDGetter(name string) (ConnIDGetter, error) {
|
||||
switch name {
|
||||
case "user":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.User, true
|
||||
}, nil
|
||||
case "destination":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.Destination.String(), true
|
||||
}, nil
|
||||
case "source_ip":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.Source.IPAddr().String(), true
|
||||
@@ -302,6 +399,14 @@ func flowKeysConnIDGetter(name string) (ConnIDGetter, error) {
|
||||
}
|
||||
return strconv.FormatUint(uint64(id.ID), 10), ok
|
||||
}, nil
|
||||
case "protocol":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.Protocol, metadata.Protocol != ""
|
||||
}, nil
|
||||
case "destination":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.Destination.String(), true
|
||||
}, nil
|
||||
default:
|
||||
return nil, E.New("flow key not found: ", name)
|
||||
}
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func NewDefaultLock(max uint32) LockIDGetter {
|
||||
locks := make(map[string]*uint32)
|
||||
mtx := sync.Mutex{}
|
||||
return func(id string) (CloseHandlerFunc, context.Context, error) {
|
||||
return func(id string) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
mtx.Lock()
|
||||
defer mtx.Unlock()
|
||||
handles, ok := locks[id]
|
||||
@@ -22,16 +23,13 @@ func NewDefaultLock(max uint32) LockIDGetter {
|
||||
locks[id] = handles
|
||||
}
|
||||
*handles++
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
mtx.Lock()
|
||||
defer mtx.Unlock()
|
||||
*handles--
|
||||
if *handles == 0 {
|
||||
delete(locks, id)
|
||||
}
|
||||
})
|
||||
mtx.Lock()
|
||||
defer mtx.Unlock()
|
||||
*handles--
|
||||
if *handles == 0 {
|
||||
delete(locks, id)
|
||||
}
|
||||
}, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/route"
|
||||
@@ -110,7 +111,7 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination
|
||||
onClose()
|
||||
return nil, err
|
||||
}
|
||||
conn = newConnWithCloseHandlerFunc(conn, onClose)
|
||||
conn = onclose.NewConn(conn, onClose)
|
||||
if lockCtx != nil {
|
||||
go connChecker(lockCtx, conn.Close)
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
|
||||
onClose()
|
||||
return nil, err
|
||||
}
|
||||
conn = newPacketConnWithCloseHandlerFunc(conn, onClose)
|
||||
conn = onclose.NewPacketConn(conn, onClose)
|
||||
if lockCtx != nil {
|
||||
go connChecker(lockCtx, conn.Close)
|
||||
}
|
||||
@@ -141,7 +142,7 @@ func (h *Outbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
return
|
||||
}
|
||||
conn = newConnWithCloseHandlerFunc(conn, limiterOnClose)
|
||||
conn = onclose.NewConn(conn, limiterOnClose)
|
||||
if lockCtx != nil {
|
||||
go connChecker(lockCtx, conn.Close)
|
||||
}
|
||||
@@ -158,7 +159,7 @@ func (h *Outbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
return
|
||||
}
|
||||
conn = bufio.NewPacketConn(newPacketConnWithCloseHandlerFunc(bufio.NewNetPacketConn(conn), limiterOnClose))
|
||||
conn = bufio.NewPacketConn(onclose.NewPacketConn(bufio.NewNetPacketConn(conn), limiterOnClose))
|
||||
if lockCtx != nil {
|
||||
go connChecker(lockCtx, conn.Close)
|
||||
}
|
||||
@@ -172,33 +173,7 @@ func (h *Outbound) GetStrategy() ConnectionStrategy {
|
||||
return h.strategy
|
||||
}
|
||||
|
||||
type connWithCloseHandlerFunc struct {
|
||||
net.Conn
|
||||
onClose CloseHandlerFunc
|
||||
}
|
||||
|
||||
func newConnWithCloseHandlerFunc(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandlerFunc {
|
||||
return &connWithCloseHandlerFunc{conn, onClose}
|
||||
}
|
||||
|
||||
func (conn *connWithCloseHandlerFunc) Close() error {
|
||||
conn.onClose()
|
||||
return conn.Conn.Close()
|
||||
}
|
||||
|
||||
type packetConnWithCloseHandlerFunc struct {
|
||||
net.PacketConn
|
||||
onClose CloseHandlerFunc
|
||||
}
|
||||
|
||||
func newPacketConnWithCloseHandlerFunc(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandlerFunc {
|
||||
return &packetConnWithCloseHandlerFunc{conn, onClose}
|
||||
}
|
||||
|
||||
func (conn *packetConnWithCloseHandlerFunc) Close() error {
|
||||
conn.onClose()
|
||||
return conn.PacketConn.Close()
|
||||
}
|
||||
|
||||
func connChecker(ctx context.Context, closeFunc func() error) {
|
||||
<-ctx.Done()
|
||||
|
||||
@@ -6,18 +6,17 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type (
|
||||
CloseHandlerFunc = func()
|
||||
|
||||
ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)
|
||||
LockIDGetter = func(string) (CloseHandlerFunc, context.Context, error)
|
||||
LockIDGetter = func(string) (onclose.CloseHandlerFunc, context.Context, error)
|
||||
|
||||
ConnectionStrategy interface {
|
||||
request(ctx context.Context, metadata *adapter.InboundContext) (onClose CloseHandlerFunc, lockCtx context.Context, err error)
|
||||
request(ctx context.Context, metadata *adapter.InboundContext) (onClose onclose.CloseHandlerFunc, lockCtx context.Context, err error)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -36,7 +35,7 @@ func NewDefaultConnectionStrategy(connIDGetter ConnIDGetter, lockIDGetter LockID
|
||||
return outbound
|
||||
}
|
||||
|
||||
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
|
||||
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
id, ok := s.connIDGetter(ctx, metadata)
|
||||
@@ -57,7 +56,7 @@ func NewUsersConnectionStrategy(strategies map[string]ConnectionStrategy) *Users
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
|
||||
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var user string
|
||||
@@ -71,20 +70,78 @@ func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter
|
||||
return nil, nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
|
||||
type cancelEntry struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type ManagerConnectionStrategy struct {
|
||||
*UsersConnectionStrategy
|
||||
strategies map[string]ConnectionStrategy
|
||||
cancels map[string][]*cancelEntry
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewManagerConnectionStrategy() *ManagerConnectionStrategy {
|
||||
return &ManagerConnectionStrategy{
|
||||
UsersConnectionStrategy: NewUsersConnectionStrategy(map[string]ConnectionStrategy{}),
|
||||
strategies: make(map[string]ConnectionStrategy),
|
||||
cancels: make(map[string][]*cancelEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
s.mtx.Lock()
|
||||
var user string
|
||||
if metadata != nil {
|
||||
user = metadata.User
|
||||
}
|
||||
strategy, ok := s.strategies[user]
|
||||
if !ok {
|
||||
s.mtx.Unlock()
|
||||
return nil, nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
onClose, _, err := strategy.request(ctx, metadata)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
entry := &cancelEntry{cancel: cancel}
|
||||
s.mtx.Lock()
|
||||
s.cancels[user] = append(s.cancels[user], entry)
|
||||
s.mtx.Unlock()
|
||||
originalOnClose := onClose
|
||||
wrappedOnClose := func() {
|
||||
s.mtx.Lock()
|
||||
entries := s.cancels[user]
|
||||
for i, e := range entries {
|
||||
if e == entry {
|
||||
s.cancels[user] = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
cancel()
|
||||
if originalOnClose != nil {
|
||||
originalOnClose()
|
||||
}
|
||||
}
|
||||
return wrappedOnClose, cancelCtx, nil
|
||||
}
|
||||
|
||||
func (s *ManagerConnectionStrategy) UpdateStrategies(strategies map[string]ConnectionStrategy) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var entries []*cancelEntry
|
||||
for user, cancels := range s.cancels {
|
||||
if _, exists := strategies[user]; !exists {
|
||||
entries = append(entries, cancels...)
|
||||
delete(s.cancels, user)
|
||||
}
|
||||
}
|
||||
s.strategies = strategies
|
||||
s.mtx.Unlock()
|
||||
for _, entry := range entries {
|
||||
entry.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
type BypassConnectionStrategy struct{}
|
||||
@@ -93,7 +150,7 @@ func NewBypassConnectionStrategy() *BypassConnectionStrategy {
|
||||
return &BypassConnectionStrategy{}
|
||||
}
|
||||
|
||||
func (s *BypassConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
|
||||
func (s *BypassConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
return func() {}, nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type (
|
||||
CloseHandlerFunc = func()
|
||||
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter TrafficLimiter, reverse bool) net.Conn
|
||||
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter TrafficLimiter, reverse bool) net.PacketConn
|
||||
)
|
||||
@@ -72,32 +72,60 @@ func (s *GlobalTrafficStrategy) getLimiter(ctx context.Context, metadata *adapte
|
||||
return s.limiter, nil
|
||||
}
|
||||
|
||||
type connEntry struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
|
||||
|
||||
type ManagerTrafficStrategy struct {
|
||||
strategies map[string]TrafficStrategy
|
||||
mtx sync.Mutex
|
||||
conns map[string][]*connEntry
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewManagerTrafficStrategy() *ManagerTrafficStrategy {
|
||||
return &ManagerTrafficStrategy{}
|
||||
return &ManagerTrafficStrategy{
|
||||
conns: make(map[string][]*connEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerTrafficStrategy) wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error) {
|
||||
strategy, err := s.getStrategy(ctx, metadata)
|
||||
strategy, user, err := s.getStrategy(ctx, metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return strategy.wrapConn(ctx, conn, metadata, reverse)
|
||||
wrapped, err := strategy.wrapConn(ctx, conn, metadata, reverse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry := &connEntry{conn: conn}
|
||||
s.mtx.Lock()
|
||||
s.conns[user] = append(s.conns[user], entry)
|
||||
s.mtx.Unlock()
|
||||
return onclose.NewConn(wrapped, func() {
|
||||
s.mtx.Lock()
|
||||
entries := s.conns[user]
|
||||
for i, e := range entries {
|
||||
if e == entry {
|
||||
s.conns[user] = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *ManagerTrafficStrategy) wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error) {
|
||||
strategy, err := s.getStrategy(ctx, metadata)
|
||||
strategy, _, err := s.getStrategy(ctx, metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return strategy.wrapPacketConn(ctx, conn, metadata, reverse)
|
||||
}
|
||||
|
||||
func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adapter.InboundContext) (TrafficStrategy, error) {
|
||||
func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adapter.InboundContext) (TrafficStrategy, string, error) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var user string
|
||||
@@ -106,15 +134,25 @@ func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adap
|
||||
}
|
||||
strategy, ok := s.strategies[user]
|
||||
if ok {
|
||||
return strategy, nil
|
||||
return strategy, user, nil
|
||||
}
|
||||
return nil, E.New("user strategy not found: ", user)
|
||||
return nil, user, E.New("user strategy not found: ", user)
|
||||
}
|
||||
|
||||
func (s *ManagerTrafficStrategy) UpdateStrategies(strategies map[string]TrafficStrategy) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var closedEntries []*connEntry
|
||||
for user, entries := range s.conns {
|
||||
if _, exists := strategies[user]; !exists {
|
||||
closedEntries = append(closedEntries, entries...)
|
||||
delete(s.conns, user)
|
||||
}
|
||||
}
|
||||
s.strategies = strategies
|
||||
s.mtx.Unlock()
|
||||
for _, entry := range closedEntries {
|
||||
entry.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type BypassTrafficStrategy struct{}
|
||||
|
||||
Reference in New Issue
Block a user