mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-26 20:29:03 +03:00
158 lines
3.4 KiB
Go
158 lines
3.4 KiB
Go
package bandwidth
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
"github.com/sagernet/sing-box/common/list"
|
|
)
|
|
|
|
// BandwidthLimiter is the underlying limiter that FairQueueLimiter schedules
// requests onto. FairQueueLimiter forwards SetSpeed to it unchanged and calls
// WaitN once per dispatched request, using that request's context and size.
type BandwidthLimiter interface {
	// SetSpeed changes the limiter's rate. The unit of speed is defined by
	// the implementation (TODO confirm, presumably bytes per second).
	SetSpeed(speed uint64)

	// WaitN is expected to block until n units may pass or ctx is done, and
	// to return a non-nil error when the wait did not succeed (TODO confirm
	// against the concrete implementations).
	WaitN(ctx context.Context, n int) error
}
|
|
|
|
type FairQueueLimiter struct {
|
|
limiter BandwidthLimiter
|
|
connIDGetter ConnIDGetter
|
|
|
|
flows *list.List[*flow]
|
|
index map[string]*list.Element[*flow]
|
|
bytes map[string]uint64
|
|
pool sync.Pool
|
|
queue chan struct{}
|
|
reset time.Time
|
|
|
|
mtx sync.Mutex
|
|
}
|
|
|
|
func NewFairQueueLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FairQueueLimiter {
|
|
return &FairQueueLimiter{
|
|
limiter: limiter,
|
|
connIDGetter: connIDGetter,
|
|
flows: list.New[*flow](),
|
|
index: make(map[string]*list.Element[*flow]),
|
|
bytes: make(map[string]uint64),
|
|
pool: sync.Pool{New: func() any { return list.New[*request]() }},
|
|
queue: make(chan struct{}, 1),
|
|
reset: time.Now().Add(time.Second),
|
|
}
|
|
}
|
|
|
|
func (l *FairQueueLimiter) SetSpeed(speed uint64) {
|
|
l.limiter.SetSpeed(speed)
|
|
}
|
|
|
|
func (l *FairQueueLimiter) WaitN(ctx context.Context, n int) error {
|
|
id, _ := l.connIDGetter(ctx, adapter.ContextFrom(ctx))
|
|
mainRequest := &request{ctx: ctx, done: make(chan struct{}), n: n}
|
|
l.mtx.Lock()
|
|
elem, ok := l.index[id]
|
|
if !ok {
|
|
f := &flow{id: id, pending: l.pool.Get().(*list.List[*request])}
|
|
elem = l.flows.PushFront(f)
|
|
l.index[id] = elem
|
|
}
|
|
mainRequestElem := elem.Value.pending.PushBack(mainRequest)
|
|
l.reorder(elem)
|
|
l.mtx.Unlock()
|
|
select {
|
|
case l.queue <- struct{}{}:
|
|
case <-mainRequest.done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
l.mtx.Lock()
|
|
l.removeRequest(id, mainRequestElem)
|
|
l.mtx.Unlock()
|
|
return ctx.Err()
|
|
}
|
|
select {
|
|
case <-mainRequest.done:
|
|
<-l.queue
|
|
return nil
|
|
default:
|
|
}
|
|
for {
|
|
if ctx.Err() != nil {
|
|
l.mtx.Lock()
|
|
l.removeRequest(id, mainRequestElem)
|
|
l.mtx.Unlock()
|
|
<-l.queue
|
|
return ctx.Err()
|
|
}
|
|
l.mtx.Lock()
|
|
now := time.Now()
|
|
if l.reset.Compare(now) == -1 {
|
|
clear(l.bytes)
|
|
l.reset = now.Add(time.Second)
|
|
}
|
|
flowElem := l.flows.Front()
|
|
flow := flowElem.Value
|
|
firstRequestElem := flow.pending.Front()
|
|
firstRequest := firstRequestElem.Value
|
|
l.bytes[flow.id] += uint64(firstRequest.n)
|
|
firstRequestElem.Remove()
|
|
if flow.pending.Len() == 0 {
|
|
l.flows.Remove(flowElem)
|
|
delete(l.index, flow.id)
|
|
l.pool.Put(flow.pending)
|
|
} else {
|
|
l.reorder(flowElem)
|
|
}
|
|
l.mtx.Unlock()
|
|
l.limiter.WaitN(firstRequest.ctx, firstRequest.n)
|
|
close(firstRequest.done)
|
|
if firstRequest == mainRequest {
|
|
<-l.queue
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *FairQueueLimiter) reorder(elem *list.Element[*flow]) {
|
|
f := elem.Value
|
|
front := f.pending.Front()
|
|
if front == nil {
|
|
return
|
|
}
|
|
cost := l.bytes[f.id] + uint64(front.Value.n)
|
|
for e := l.flows.Front(); e != nil; e = e.Next() {
|
|
if e == elem {
|
|
continue
|
|
}
|
|
eFront := e.Value.pending.Front()
|
|
if eFront == nil {
|
|
continue
|
|
}
|
|
if cost < l.bytes[e.Value.id]+uint64(eFront.Value.n) {
|
|
l.flows.MoveBefore(elem, e)
|
|
return
|
|
}
|
|
}
|
|
l.flows.MoveToBack(elem)
|
|
}
|
|
|
|
func (l *FairQueueLimiter) removeRequest(id string, elem *list.Element[*request]) {
|
|
if !elem.Remove() {
|
|
return
|
|
}
|
|
if flowElem, ok := l.index[id]; ok && flowElem.Value.pending.Len() == 0 {
|
|
l.flows.Remove(flowElem)
|
|
delete(l.index, id)
|
|
l.pool.Put(flowElem.Value.pending)
|
|
}
|
|
}
|
|
|
|
type flow struct {
|
|
id string
|
|
pending *list.List[*request]
|
|
}
|
|
|
|
type request struct {
|
|
ctx context.Context
|
|
done chan struct{}
|
|
n int
|
|
}
|