// Files
// sing-box-extended/protocol/limiter/connection/strategy.go
//
// 120 lines
// 3.2 KiB
// Go

package connection
import (
"context"
"strconv"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
// CloseHandlerFunc is a parameterless callback. ConnectionStrategy.request
// returns one as its onClose result.
type CloseHandlerFunc = func()

// ConnIDGetter derives a string ID from the request context and inbound
// metadata. The boolean result reports whether an ID was found.
type ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)

// LockIDGetter takes an ID and returns a close handler, a context and an
// error. NOTE(review): presumably it acquires a per-ID lock whose lifetime is
// tied to the returned context — confirm against the implementation.
type LockIDGetter = func(string) (CloseHandlerFunc, context.Context, error)

// ConnectionStrategy decides, per incoming request, which close handler and
// lock context apply, or reports an error.
type ConnectionStrategy interface {
	request(ctx context.Context, metadata *adapter.InboundContext) (onClose CloseHandlerFunc, lockCtx context.Context, err error)
}
// DefaultConnectionStrategy is a ConnectionStrategy that derives an ID with
// connIDGetter and passes that ID to lockIDGetter.
type DefaultConnectionStrategy struct {
	// connIDGetter extracts the ID from the request context and metadata.
	connIDGetter ConnIDGetter
	// lockIDGetter is called with the extracted ID; its results are returned
	// from request unchanged.
	lockIDGetter LockIDGetter
	// mtx is held for the whole duration of request, serializing calls.
	mtx sync.Mutex
}
// NewDefaultConnectionStrategy returns a DefaultConnectionStrategy that uses
// connIDGetter to obtain an ID and lockIDGetter to process it.
func NewDefaultConnectionStrategy(connIDGetter ConnIDGetter, lockIDGetter LockIDGetter) *DefaultConnectionStrategy {
	return &DefaultConnectionStrategy{
		lockIDGetter: lockIDGetter,
		connIDGetter: connIDGetter,
	}
}
// request obtains the ID for this request and forwards it to lockIDGetter,
// returning that call's results as-is. If connIDGetter reports that no ID is
// available, it returns an "id not found" error and nil for the other results.
//
// The strategy mutex is held for the entire call, including the getters.
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
	s.mtx.Lock()
	defer s.mtx.Unlock()

	if connID, found := s.connIDGetter(ctx, metadata); found {
		return s.lockIDGetter(connID)
	}
	return nil, nil, E.New("id not found")
}
// UsersConnectionStrategy is a ConnectionStrategy that dispatches each request
// to another strategy chosen by the user name in the inbound metadata.
type UsersConnectionStrategy struct {
	// strategies maps a user name to the strategy handling that user.
	strategies map[string]ConnectionStrategy
	// mtx guards strategies and is held for the whole duration of request.
	mtx sync.Mutex
}
// NewUsersConnectionStrategy returns a UsersConnectionStrategy backed by the
// given user-to-strategy map. The map is stored as passed, not copied.
func NewUsersConnectionStrategy(strategies map[string]ConnectionStrategy) *UsersConnectionStrategy {
	s := new(UsersConnectionStrategy)
	s.strategies = strategies
	return s
}
// request looks up the strategy registered for metadata.User and delegates
// the request to it. When metadata is nil the empty user name is used for the
// lookup. If no strategy is registered for the user, an error naming the user
// is returned together with nil results.
//
// The mutex is held for the entire call, including the delegated request.
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
	s.mtx.Lock()
	defer s.mtx.Unlock()

	user := ""
	if metadata != nil {
		user = metadata.User
	}

	strategy, found := s.strategies[user]
	if !found {
		return nil, nil, E.New("user strategy not found: ", user)
	}
	return strategy.request(ctx, metadata)
}
// ManagerConnectionStrategy embeds a UsersConnectionStrategy, inheriting its
// request method, its strategy map and its mutex.
type ManagerConnectionStrategy struct {
	*UsersConnectionStrategy
}
// NewManagerConnectionStrategy returns a ManagerConnectionStrategy whose
// embedded UsersConnectionStrategy starts with an empty, non-nil strategy map.
func NewManagerConnectionStrategy() *ManagerConnectionStrategy {
	users := NewUsersConnectionStrategy(make(map[string]ConnectionStrategy))
	return &ManagerConnectionStrategy{UsersConnectionStrategy: users}
}
// UpdateStrategies replaces the whole user-to-strategy map with strategies.
// The swap happens under the embedded UsersConnectionStrategy's mutex, the
// same one its request method holds. The map is stored as passed, not copied.
func (s *ManagerConnectionStrategy) UpdateStrategies(strategies map[string]ConnectionStrategy) {
	s.mtx.Lock()
	s.strategies = strategies
	s.mtx.Unlock()
}
// CreateStrategy builds the ConnectionStrategy selected by strategy and
// connectionType.
//
// The only strategy handled is "connection". Its connectionType selects how
// the ID handed to lockIDGetter is derived:
//   - "mux":  the mux ID from log.MuxIDFromContext, formatted in decimal.
//   - "hwid": the string stored in the context under the key "hwid".
//   - "ip":   the string form of metadata.Source's IP address.
//
// An unknown strategy or connectionType yields a nil strategy and an error
// naming the rejected value.
func CreateStrategy(strategy string, connectionType string, lockIDGetter LockIDGetter) (ConnectionStrategy, error) {
	switch strategy {
	case "connection":
		var connIDGetter ConnIDGetter
		switch connectionType {
		case "mux":
			connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
				id, ok := log.MuxIDFromContext(ctx)
				if !ok {
					return "", false
				}
				return strconv.FormatUint(uint64(id.ID), 10), true
			}
		case "hwid":
			connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
				// The key must stay the plain string "hwid": it has to match
				// whatever stores the value, which is outside this file.
				id, ok := ctx.Value("hwid").(string)
				return id, ok
			}
		case "ip":
			connIDGetter = func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
				// metadata may be nil; report "no ID" instead of panicking
				// on the dereference below.
				if metadata == nil {
					return "", false
				}
				// NOTE(review): an unset Source presumably yields a shared
				// placeholder string rather than a failure — confirm whether
				// that case should also report "no ID".
				return metadata.Source.IPAddr().String(), true
			}
		default:
			return nil, E.New("connection type not found: ", connectionType)
		}
		return NewDefaultConnectionStrategy(connIDGetter, lockIDGetter), nil
	default:
		return nil, E.New("strategy not found: ", strategy)
	}
}