diff --git a/common/kmutex/mutex.go b/common/kmutex/mutex.go new file mode 100644 index 00000000..9767959f --- /dev/null +++ b/common/kmutex/mutex.go @@ -0,0 +1,57 @@ +package kmutex + +import "sync" + +type Kmutex[T comparable] struct { + l sync.Locker + s map[T]*klock +} + +type klock struct { + cond *sync.Cond + ref uint64 +} + +// Create new Kmutex +func New[T comparable]() *Kmutex[T] { + l := sync.Mutex{} + return &Kmutex[T]{ + l: &l, + s: make(map[T]*klock), + } +} + +// Unlock Kmutex by unique ID +func (km *Kmutex[T]) Unlock(key T) { + km.l.Lock() + defer km.l.Unlock() + kl, ok := km.s[key] + if !ok || kl.ref == 0 { + panic("unlock of unlocked kmutex") + } + kl.ref-- + if kl.ref == 0 { + delete(km.s, key) + return + } + kl.cond.Signal() +} + +// Lock Kmutex by unique ID +func (km *Kmutex[T]) Lock(key T) { + km.l.Lock() + defer km.l.Unlock() + for { + kl, ok := km.s[key] + if !ok { + km.s[key] = &klock{ + cond: sync.NewCond(km.l), + ref: 1, + } + return + } + kl.ref++ + kl.cond.Wait() + return + } +} diff --git a/common/kmutex/mutex_test.go b/common/kmutex/mutex_test.go new file mode 100644 index 00000000..6648442b --- /dev/null +++ b/common/kmutex/mutex_test.go @@ -0,0 +1,96 @@ +package kmutex + +import ( + "sync" + "testing" + "time" +) + +// Number of unique resources to access +const number = 100 + +func makeIds(count int) []int { + ids := make([]int, count) + for i := 0; i < count; i++ { + ids[i] = i + } + return ids +} + +func TestKmutex(t *testing.T) { + km := New[int]() + ids := makeIds(number) + resources := make([]int, number) + wg := sync.WaitGroup{} + + lc := make(chan int) + uc := make(chan int) + // Start 10n goroutines accessing n resources 10 times each + for i := 0; i < 10*number; i++ { + wg.Add(1) + go func(k int) { + for j := 0; j < 10; j++ { + lc <- k + km.Lock(ids[k]) + // read and write resource to check for race + resources[k] = resources[k] + 1 + km.Unlock(ids[k]) + uc <- k + } + wg.Done() + }(i % len(ids)) + } + + to := time.After(time.Second) + counts := make(map[int]int) + var lCount, ulCount int +loop: + for { + select { + case k := <-lc: + counts[k] = counts[k] + 1 + lCount++ + case k := <-uc: + counts[k] = counts[k] - 1 + ulCount++ + case <-to: + t.Fatal("timed out waiting for results") + break loop + } + expectCount := 100 * number + if lCount == expectCount && ulCount == expectCount { + // Have all results + break + } + } + for k, c := range counts { + if c != 0 { + t.Errorf("Key %d count != 0: %d\n", k, c) + } + } + + wg.Wait() +} + +func BenchmarkKmutex1000(b *testing.B) { + km := New[int]() + ids := makeIds(number) + resources := make([]int, number) + wg := sync.WaitGroup{} + + // Start 1000 goroutines accessing 100 resources N times each + b.ResetTimer() + for i := 0; i < 1000; i++ { + wg.Add(1) + go func(k int) { + for j := 0; j < b.N; j++ { + km.Lock(ids[k]) + // read and write resource to check for race + resources[k] = resources[k] + 1 + km.Unlock(ids[k]) + } + wg.Done() + }(i % len(ids)) + } + wg.Wait() +} diff --git a/protocol/bond/inbound.go b/protocol/bond/inbound.go index 35e1cc48..6eac51c6 100644 --- a/protocol/bond/inbound.go +++ b/protocol/bond/inbound.go @@ -4,12 +4,12 @@ import ( "context" "errors" "net" - "sync" "time" "github.com/patrickmn/go-cache" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/common/kmutex" "github.com/sagernet/sing-box/common/uot" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" @@ -31,7 +31,7 @@ type Inbound struct { inbounds []adapter.Inbound conns *cache.Cache - mtx sync.Mutex + mtx *kmutex.Kmutex[string] } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BondInboundOptions) (adapter.Inbound, error) { @@ -43,6 +43,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo logger: logger, router: uot.NewRouter(router, logger), conns: cache.New(C.TCPConnectTimeout, time.Second), + mtx: kmutex.New[string](), } inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) inbounds := make([]adapter.Inbound, len(options.Inbounds)) @@ -55,8 +56,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo } inbound.inbounds = inbounds inbound.conns.OnEvicted(func(s string, i interface{}) { - inbound.mtx.Lock() - defer inbound.mtx.Unlock() + inbound.mtx.Lock(s) + defer inbound.mtx.Unlock(s) ratioConns := i.(map[uint8]*ratioConn) for _, ratioConn := range ratioConns { if ratioConn != nil { @@ -100,15 +101,15 @@ func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapt if err != nil { return err } - h.mtx.Lock() - defer h.mtx.Unlock() + requestUUID := request.UUID.String() + h.mtx.Lock(requestUUID) var ratioConns map[uint8]*ratioConn - rawRatioConns, ok := h.conns.Get(request.UUID.String()) + rawRatioConns, ok := h.conns.Get(requestUUID) if ok { ratioConns = rawRatioConns.(map[uint8]*ratioConn) } else { ratioConns = make(map[uint8]*ratioConn, request.Count) - h.conns.SetDefault(request.UUID.String(), ratioConns) + h.conns.SetDefault(requestUUID, ratioConns) } ratioConns[request.Index] = &ratioConn{ conn: conn, @@ -132,14 +133,18 @@ func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapt for _, conn := range conns { conn.Close() } + h.mtx.Unlock(requestUUID) return E.New("invalid ratios") } conn = NewBondedConn(conns, downloadRatios, uploadRatios) metadata.Inbound = h.Tag() metadata.InboundType = C.TypeBond metadata.Destination = request.Destination + h.mtx.Unlock(requestUUID) h.router.RouteConnectionEx(ctx, conn, metadata, onClose) + return nil } + h.mtx.Unlock(requestUUID) return nil }