Add Kmutex

This commit is contained in:
Sergei Maklagin
2026-03-03 00:51:20 +03:00
parent 517f5152e7
commit 0503006f48
3 changed files with 166 additions and 8 deletions

57
common/kmutex/mutex.go Normal file
View File

@@ -0,0 +1,57 @@
package kmutex
import "sync"
type Kmutex[T comparable] struct {
l sync.Locker
s map[T]*klock
}
type klock struct {
cond *sync.Cond
ref uint64
}
// Create new Kmutex
func New[T comparable]() *Kmutex[T] {
l := sync.Mutex{}
return &Kmutex[T]{
l: &l,
s: make(map[T]*klock),
}
}
// Unlock Kmutex by unique ID
func (km *Kmutex[T]) Unlock(key T) {
km.l.Lock()
defer km.l.Unlock()
kl, ok := km.s[key]
if !ok || kl.ref == 0 {
panic("unlock of unlocked kmutex")
}
kl.ref--
if kl.ref == 0 {
delete(km.s, key)
return
}
kl.cond.Signal()
}
// Lock Kmutex by unique ID
func (km *Kmutex[T]) Lock(key T) {
km.l.Lock()
defer km.l.Unlock()
for {
kl, ok := km.s[key]
if !ok {
km.s[key] = &klock{
cond: sync.NewCond(km.l),
ref: 1,
}
return
}
kl.ref++
kl.cond.Wait()
return
}
}

View File

@@ -0,0 +1,96 @@
package kmutex
import (
"sync"
"testing"
"time"
)
// Number of unique resources to access
const number = 100
func makeIds(count int) []int {
ids := make([]int, count)
for i := 0; i < count; i++ {
ids[i] = i
}
return ids
}
func TestKmutex(t *testing.T) {
km := New[int]()
ids := makeIds(number)
resources := make([]int, number)
wg := sync.WaitGroup{}
lc := make(chan int)
uc := make(chan int)
// Start 10n goroutines accessing n resources 10 times each
for i := 0; i < 10*number; i++ {
wg.Add(1)
go func(k int) {
for j := 0; j < 10; j++ {
lc <- k
km.Lock(ids[k])
// read and write resource to check for race
resources[k] = resources[k] + 1
km.Unlock(ids[k])
uc <- k
}
wg.Done()
}(i % len(ids))
}
to := time.After(time.Second)
counts := make(map[int]int)
var lCount, ulCount int
loop:
for {
select {
case k := <-lc:
counts[k] = counts[k] + 1
lCount++
case k := <-uc:
counts[k] = counts[k] - 1
ulCount++
case <-to:
t.Fatal("timed out waiting for results")
break loop
}
expectCount := 100 * number
if lCount == expectCount && ulCount == expectCount {
// Have all results
break
}
}
for k, c := range counts {
if c != 0 {
t.Errorf("Key %d count != 0: %d\n", k, c)
}
}
wg.Wait()
}
func BenchmarkKmutex1000(b *testing.B) {
km := New[int]()
ids := makeIds(number)
resources := make([]int, number)
wg := sync.WaitGroup{}
// Start 1000 goroutines accessing 100 resources N times each
b.ResetTimer()
for i := 0; i < 1000; i++ {
wg.Add(1)
go func(k int) {
for j := 0; j < b.N; j++ {
km.Lock(ids[k])
// read and write resource to check for race
resources[k] = resources[k] + 1
km.Unlock(ids[k])
}
wg.Done()
}(i % len(ids))
}
wg.Wait()
}

View File

@@ -4,12 +4,12 @@ import (
"context"
"errors"
"net"
"sync"
"time"
"github.com/patrickmn/go-cache"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/kmutex"
"github.com/sagernet/sing-box/common/uot"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
@@ -31,7 +31,7 @@ type Inbound struct {
inbounds []adapter.Inbound
conns *cache.Cache
mtx sync.Mutex
mtx *kmutex.Kmutex[string]
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BondInboundOptions) (adapter.Inbound, error) {
@@ -43,6 +43,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
logger: logger,
router: uot.NewRouter(router, logger),
conns: cache.New(C.TCPConnectTimeout, time.Second),
mtx: kmutex.New[string](),
}
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
inbounds := make([]adapter.Inbound, len(options.Inbounds))
@@ -55,8 +56,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
}
inbound.inbounds = inbounds
inbound.conns.OnEvicted(func(s string, i interface{}) {
inbound.mtx.Lock()
defer inbound.mtx.Unlock()
inbound.mtx.Lock(s)
defer inbound.mtx.Unlock(s)
ratioConns := i.(map[uint8]*ratioConn)
for _, ratioConn := range ratioConns {
if ratioConn != nil {
@@ -100,15 +101,15 @@ func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapt
if err != nil {
return err
}
h.mtx.Lock()
defer h.mtx.Unlock()
requestUUID := request.UUID.String()
h.mtx.Lock(requestUUID)
var ratioConns map[uint8]*ratioConn
rawRatioConns, ok := h.conns.Get(request.UUID.String())
rawRatioConns, ok := h.conns.Get(requestUUID)
if ok {
ratioConns = rawRatioConns.(map[uint8]*ratioConn)
} else {
ratioConns = make(map[uint8]*ratioConn, request.Count)
h.conns.SetDefault(request.UUID.String(), ratioConns)
h.conns.SetDefault(requestUUID, ratioConns)
}
ratioConns[request.Index] = &ratioConn{
conn: conn,
@@ -132,14 +133,18 @@ func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapt
for _, conn := range conns {
conn.Close()
}
h.mtx.Unlock(requestUUID)
return E.New("invalid ratios")
}
conn = NewBondedConn(conns, downloadRatios, uploadRatios)
metadata.Inbound = h.Tag()
metadata.InboundType = C.TypeBond
metadata.Destination = request.Destination
h.mtx.Unlock(requestUUID)
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return nil
}
h.mtx.Unlock(requestUUID)
return nil
}