Files
sing-box-extended/protocol/limiter/connection/strategy.go

197 lines
5.3 KiB
Go

package connection
import (
"context"
"strconv"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/onclose"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
// ConnIDGetter extracts the identifier of a connection from its context and
// inbound metadata. The boolean result reports whether an identifier was found.
type ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)

// LockIDGetter is called with a connection identifier and returns a close
// handler, a context associated with the identifier (may be nil), and an error.
type LockIDGetter = func(string) (onclose.CloseHandlerFunc, context.Context, error)

// ConnectionStrategy decides whether and how a new connection is admitted.
// request returns a close handler to run when the connection ends and a
// context associated with the admission (which may be nil).
type ConnectionStrategy interface {
	request(ctx context.Context, metadata *adapter.InboundContext) (onClose onclose.CloseHandlerFunc, lockCtx context.Context, err error)
}
// DefaultConnectionStrategy identifies each connection with connIDGetter and
// passes the resulting identifier to lockIDGetter. Calls to request on the
// same instance are serialized by mtx.
type DefaultConnectionStrategy struct {
	connIDGetter ConnIDGetter
	lockIDGetter LockIDGetter
	mtx          sync.Mutex
}

// NewDefaultConnectionStrategy returns a DefaultConnectionStrategy that uses
// connIDGetter to identify connections and calls lockIDGetter with that
// identifier.
func NewDefaultConnectionStrategy(connIDGetter ConnIDGetter, lockIDGetter LockIDGetter) *DefaultConnectionStrategy {
	return &DefaultConnectionStrategy{
		connIDGetter: connIDGetter,
		lockIDGetter: lockIDGetter,
	}
}

// request resolves the connection identifier and returns exactly what
// lockIDGetter returns for it. It fails with "id not found" when connIDGetter
// reports no identifier. Both getters run while s.mtx is held.
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
	s.mtx.Lock()
	defer s.mtx.Unlock()
	if id, found := s.connIDGetter(ctx, metadata); found {
		return s.lockIDGetter(id)
	}
	return nil, nil, E.New("id not found")
}
// UsersConnectionStrategy dispatches each request to the strategy registered
// for the connection's user (metadata.User, or "" when metadata is nil).
type UsersConnectionStrategy struct {
	strategies map[string]ConnectionStrategy
	mtx        sync.Mutex // guards access to strategies
}

// NewUsersConnectionStrategy returns a strategy that dispatches on the user
// name. The strategies map is stored as given, without copying.
func NewUsersConnectionStrategy(strategies map[string]ConnectionStrategy) *UsersConnectionStrategy {
	return &UsersConnectionStrategy{
		strategies: strategies,
	}
}

// request looks up the user's strategy and delegates to it. It fails with
// "user strategy not found" when no strategy is registered for the user.
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
	var user string
	if metadata != nil {
		user = metadata.User
	}
	// Only the map lookup needs the lock. Delegating outside it keeps one
	// user's (possibly slow) strategy from serializing every other user's
	// requests, matching ManagerConnectionStrategy.request.
	s.mtx.Lock()
	strategy, ok := s.strategies[user]
	s.mtx.Unlock()
	if !ok {
		return nil, nil, E.New("user strategy not found: ", user)
	}
	return strategy.request(ctx, metadata)
}
// cancelEntry holds the cancel function of one connection admitted by
// ManagerConnectionStrategy. Entries are compared by pointer identity, so a
// connection removes exactly its own entry.
type cancelEntry struct {
	cancel context.CancelFunc
}

// ManagerConnectionStrategy dispatches requests to per-user strategies that
// can be replaced at runtime with UpdateStrategies. Connections of users
// dropped by an update are cancelled.
type ManagerConnectionStrategy struct {
	strategies map[string]ConnectionStrategy
	cancels    map[string][]*cancelEntry // cancel entries of open connections, per user
	mtx        sync.Mutex               // guards strategies and cancels
}

// NewManagerConnectionStrategy returns a manager with no user strategies.
// Every request fails until UpdateStrategies installs some.
func NewManagerConnectionStrategy() *ManagerConnectionStrategy {
	s := new(ManagerConnectionStrategy)
	s.strategies = map[string]ConnectionStrategy{}
	s.cancels = map[string][]*cancelEntry{}
	return s
}
// request resolves the strategy registered for metadata.User ("" when
// metadata is nil) and delegates to it. On success it registers a cancel
// entry so that UpdateStrategies can revoke the connection later.
//
// The returned context is cancelled in three cases: the returned close handler
// runs, UpdateStrategies drops the user, or the context returned by the user's
// strategy (when non-nil) is cancelled. The returned close handler removes the
// cancel entry, cancels that context, and then runs the user strategy's close
// handler, if any.
//
// If UpdateStrategies removes the user while the user's strategy is being
// consulted, the strategy's close handler is run again and the request fails.
func (s *ManagerConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
	var user string
	if metadata != nil {
		user = metadata.User
	}
	s.mtx.Lock()
	strategy, ok := s.strategies[user]
	s.mtx.Unlock()
	if !ok {
		return nil, nil, E.New("user strategy not found: ", user)
	}
	// The user strategy is consulted without s.mtx held, so it cannot block
	// other users or UpdateStrategies.
	onClose, lockCtx, err := strategy.request(ctx, metadata)
	if err != nil {
		return nil, nil, err
	}
	// Derive from the strategy's context so that any cancellation it signals
	// still reaches the caller. Some strategies (e.g. bypass) return nil.
	parent := lockCtx
	if parent == nil {
		parent = context.Background()
	}
	cancelCtx, cancel := context.WithCancel(parent)
	entry := &cancelEntry{cancel: cancel}
	s.mtx.Lock()
	if _, stillActive := s.strategies[user]; !stillActive {
		// UpdateStrategies dropped the user after the lookup above. Registering
		// now would leave a connection that nothing could ever revoke.
		s.mtx.Unlock()
		cancel()
		if onClose != nil {
			onClose()
		}
		return nil, nil, E.New("user strategy not found: ", user)
	}
	s.cancels[user] = append(s.cancels[user], entry)
	s.mtx.Unlock()
	wrappedOnClose := func() {
		s.mtx.Lock()
		entries := s.cancels[user]
		for i, e := range entries {
			if e == entry {
				last := len(entries) - 1
				copy(entries[i:], entries[i+1:])
				entries[last] = nil // drop the stale pointer left in the backing array
				entries = entries[:last]
				break
			}
		}
		if len(entries) == 0 {
			// Do not keep empty slices around for users with no open connections.
			delete(s.cancels, user)
		} else {
			s.cancels[user] = entries
		}
		s.mtx.Unlock()
		cancel()
		if onClose != nil {
			onClose()
		}
	}
	return wrappedOnClose, cancelCtx, nil
}
// UpdateStrategies replaces the per-user strategies with strategies (stored
// without copying) and cancels the contexts of all registered connections
// whose user is absent from the new set. The cancel functions are invoked
// after the manager's lock has been released.
func (s *ManagerConnectionStrategy) UpdateStrategies(strategies map[string]ConnectionStrategy) {
	var revoked []*cancelEntry
	func() {
		s.mtx.Lock()
		defer s.mtx.Unlock()
		for user := range s.cancels {
			if _, keep := strategies[user]; keep {
				continue
			}
			revoked = append(revoked, s.cancels[user]...)
			delete(s.cancels, user)
		}
		s.strategies = strategies
	}()
	for _, e := range revoked {
		e.cancel()
	}
}
// BypassConnectionStrategy admits every connection unconditionally.
type BypassConnectionStrategy struct{}

// NewBypassConnectionStrategy returns a strategy that never rejects a request.
func NewBypassConnectionStrategy() *BypassConnectionStrategy {
	return new(BypassConnectionStrategy)
}

// request always succeeds. It returns a no-op close handler and a nil context.
func (s *BypassConnectionStrategy) request(_ context.Context, _ *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
	noop := func() {}
	return noop, nil, nil
}
// CreateStrategy builds a ConnectionStrategy from configuration values.
//
// strategy selects the implementation:
//   - "connection": a DefaultConnectionStrategy that calls lockIDGetter with a
//     connection identifier chosen by connectionType:
//     "hwid" (string stored in ctx under the key "hwid"),
//     "mux" (mux id from the log context),
//     "source_ip" (source IP address of the inbound metadata), or
//     "default"/"" (connection id from the log context).
//   - "bypass": a BypassConnectionStrategy; connectionType and lockIDGetter
//     are ignored.
//
// It returns an error for an unknown strategy or connection type.
func CreateStrategy(strategy string, connectionType string, lockIDGetter LockIDGetter) (ConnectionStrategy, error) {
	switch strategy {
	case "connection":
		var connIDGetter ConnIDGetter
		switch connectionType {
		case "hwid":
			connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
				// NOTE(review): plain string context key (staticcheck SA1029). It has
				// to match whatever stores the hwid elsewhere, so it is kept as-is.
				id, ok := ctx.Value("hwid").(string)
				return id, ok
			}
		case "mux":
			connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
				id, ok := log.MuxIDFromContext(ctx)
				if !ok {
					return "", ok
				}
				return strconv.FormatUint(uint64(id.ID), 10), ok
			}
		case "source_ip":
			connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
				// metadata may be nil (the other strategies in this file guard
				// against it). Report "no id" instead of panicking.
				if metadata == nil {
					return "", false
				}
				return metadata.Source.IPAddr().String(), true
			}
		case "default", "":
			connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
				id, ok := log.IDFromContext(ctx)
				if !ok {
					return "", ok
				}
				return strconv.FormatUint(uint64(id.ID), 10), ok
			}
		default:
			return nil, E.New("connection type not found: ", connectionType)
		}
		return NewDefaultConnectionStrategy(connIDGetter, lockIDGetter), nil
	case "bypass":
		return NewBypassConnectionStrategy(), nil
	default:
		return nil, E.New("strategy not found: ", strategy)
	}
}