dns: Fix conn pool leak

This commit is contained in:
世界
2026-05-11 20:59:49 +08:00
parent e52a4b8d7d
commit e31b5fac0e

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"net" "net"
"sync" "sync"
"time"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
) )
@@ -53,19 +52,6 @@ type connPoolConnect[T comparable] struct {
err error err error
} }
type connPoolDialContext struct {
context.Context
parent context.Context
}
func (c connPoolDialContext) Deadline() (time.Time, bool) {
return c.parent.Deadline()
}
func (c connPoolDialContext) Value(key any) any {
return c.parent.Value(key)
}
func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] { func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] {
return &ConnPool[T]{ return &ConnPool[T]{
options: options, options: options,
@@ -108,67 +94,27 @@ func (p *ConnPool[T]) AcquireShared(ctx context.Context, dial func(context.Conte
} }
func (p *ConnPool[T]) Release(conn T, reuse bool) { func (p *ConnPool[T]) Release(conn T, reuse bool) {
var (
closeConn bool
closeErr error
)
p.access.Lock() p.access.Lock()
if p.closed || p.state == nil { if p.closed {
closeConn = true
closeErr = net.ErrClosed
p.access.Unlock() p.access.Unlock()
if closeConn { p.options.Close(conn, net.ErrClosed)
p.options.Close(conn, closeErr)
}
return return
} }
state := p.state
currentState := p.state if _, tracked := state.all[conn]; !tracked {
_, tracked := currentState.all[conn]
if !tracked {
closeConn = true
closeErr = p.closeCause(currentState)
p.access.Unlock() p.access.Unlock()
if closeConn { p.options.Close(conn, net.ErrClosed)
p.options.Close(conn, closeErr)
}
return return
} }
if !reuse || !p.options.IsAlive(conn) { if !reuse || !p.options.IsAlive(conn) {
delete(currentState.all, conn) p.removeConn(state, conn, net.ErrClosed)
switch p.options.Mode {
case ConnPoolSingle:
if currentState.hasShared && currentState.shared == conn {
var zero T
currentState.shared = zero
currentState.hasShared = false
currentState.sharedClaimed = false
currentState.sharedCtx = nil
if currentState.sharedCancel != nil {
currentState.sharedCancel(net.ErrClosed)
currentState.sharedCancel = nil
}
}
case ConnPoolOrdered:
if element, loaded := currentState.idleElements[conn]; loaded {
currentState.idle.Remove(element)
delete(currentState.idleElements, conn)
}
}
closeConn = true
closeErr = net.ErrClosed
p.access.Unlock() p.access.Unlock()
if closeConn { p.options.Close(conn, net.ErrClosed)
p.options.Close(conn, closeErr)
}
return return
} }
if p.options.Mode == ConnPoolOrdered { if p.options.Mode == ConnPoolOrdered {
if _, loaded := currentState.idleElements[conn]; !loaded { if _, loaded := state.idleElements[conn]; !loaded {
currentState.idleElements[conn] = currentState.idle.PushBack(conn) state.idleElements[conn] = state.idle.PushBack(conn)
} }
} }
p.access.Unlock() p.access.Unlock()
@@ -176,42 +122,43 @@ func (p *ConnPool[T]) Release(conn T, reuse bool) {
func (p *ConnPool[T]) Invalidate(conn T, cause error) { func (p *ConnPool[T]) Invalidate(conn T, cause error) {
p.access.Lock() p.access.Lock()
if p.closed || p.state == nil { if p.closed {
p.access.Unlock() p.access.Unlock()
p.options.Close(conn, cause) p.options.Close(conn, cause)
return return
} }
state := p.state
currentState := p.state if _, tracked := state.all[conn]; !tracked {
_, tracked := currentState.all[conn]
if !tracked {
p.access.Unlock() p.access.Unlock()
return return
} }
p.removeConn(state, conn, cause)
p.access.Unlock()
p.options.Close(conn, cause)
}
delete(currentState.all, conn) // removeConn must be called with p.access held.
func (p *ConnPool[T]) removeConn(state *connPoolState[T], conn T, cause error) {
delete(state.all, conn)
switch p.options.Mode { switch p.options.Mode {
case ConnPoolSingle: case ConnPoolSingle:
if currentState.hasShared && currentState.shared == conn { if state.hasShared && state.shared == conn {
var zero T var zero T
currentState.shared = zero state.shared = zero
currentState.hasShared = false state.hasShared = false
currentState.sharedClaimed = false state.sharedClaimed = false
currentState.sharedCtx = nil state.sharedCtx = nil
if currentState.sharedCancel != nil { if state.sharedCancel != nil {
currentState.sharedCancel(cause) state.sharedCancel(cause)
currentState.sharedCancel = nil state.sharedCancel = nil
} }
} }
case ConnPoolOrdered: case ConnPoolOrdered:
if element, loaded := currentState.idleElements[conn]; loaded { if element, loaded := state.idleElements[conn]; loaded {
currentState.idle.Remove(element) state.idle.Remove(element)
delete(currentState.idleElements, conn) delete(state.idleElements, conn)
} }
} }
p.access.Unlock()
p.options.Close(conn, cause)
} }
func (p *ConnPool[T]) Reset() { func (p *ConnPool[T]) Reset() {
@@ -220,7 +167,6 @@ func (p *ConnPool[T]) Reset() {
p.access.Unlock() p.access.Unlock()
return return
} }
oldState := p.state oldState := p.state
p.state = newConnPoolState[T](p.options.Mode) p.state = newConnPoolState[T](p.options.Mode)
p.access.Unlock() p.access.Unlock()
@@ -234,7 +180,6 @@ func (p *ConnPool[T]) Close() error {
p.access.Unlock() p.access.Unlock()
return nil return nil
} }
p.closed = true p.closed = true
oldState := p.state oldState := p.state
p.state = nil p.state = nil
@@ -247,40 +192,47 @@ func (p *ConnPool[T]) Close() error {
func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) { func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) {
var zero T var zero T
for { for {
var (
staleConn T
hasStale bool
)
p.access.Lock() p.access.Lock()
if p.closed { if p.closed {
p.access.Unlock() p.access.Unlock()
return zero, false, net.ErrClosed return zero, false, net.ErrClosed
} }
current := p.state
currentState := p.state if element := current.idle.Front(); element != nil {
if element := currentState.idle.Front(); element != nil { conn := current.idle.Remove(element)
conn := currentState.idle.Remove(element) delete(current.idleElements, conn)
delete(currentState.idleElements, conn)
if p.options.IsAlive(conn) { if p.options.IsAlive(conn) {
p.access.Unlock() p.access.Unlock()
return conn, false, nil return conn, false, nil
} }
delete(currentState.all, conn) delete(current.all, conn)
staleConn = conn p.access.Unlock()
hasStale = true p.options.Close(conn, net.ErrClosed)
continue
} }
p.access.Unlock() p.access.Unlock()
if hasStale { dialCtx, dialCancel := context.WithCancelCause(ctx)
p.options.Close(staleConn, net.ErrClosed) stopStateCancel := context.AfterFunc(current.ctx, func() {
continue dialCancel(context.Cause(current.ctx))
})
conn, err := dial(dialCtx)
stateCancelStopped := stopStateCancel()
dialErr := context.Cause(dialCtx)
if dialErr == nil && !stateCancelStopped {
dialErr = context.Cause(current.ctx)
} }
dialCancel(nil)
conn, err := p.dial(ctx, currentState, dial)
if err != nil { if err != nil {
if dialErr != nil {
return zero, false, dialErr
}
return zero, false, err return zero, false, err
} }
if dialErr != nil {
p.options.Close(conn, dialErr)
return zero, false, dialErr
}
p.access.Lock() p.access.Lock()
if p.closed { if p.closed {
@@ -288,13 +240,12 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont
p.options.Close(conn, net.ErrClosed) p.options.Close(conn, net.ErrClosed)
return zero, false, net.ErrClosed return zero, false, net.ErrClosed
} }
if p.state != currentState { if p.state != current {
cause := p.closeCause(currentState)
p.access.Unlock() p.access.Unlock()
p.options.Close(conn, cause) p.options.Close(conn, net.ErrClosed)
return zero, false, cause return zero, false, net.ErrClosed
} }
currentState.all[conn] = struct{}{} current.all[conn] = struct{}{}
p.access.Unlock() p.access.Unlock()
return conn, true, nil return conn, true, nil
} }
@@ -303,21 +254,12 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont
func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) { func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) {
var zero T var zero T
for { for {
var (
staleConn T
hasStale bool
state *connPoolConnect[T]
current *connPoolState[T]
startDial bool
)
p.access.Lock() p.access.Lock()
if p.closed { if p.closed {
p.access.Unlock() p.access.Unlock()
return zero, nil, false, net.ErrClosed return zero, nil, false, net.ErrClosed
} }
current := p.state
current = p.state
if current.hasShared { if current.hasShared {
conn := current.shared conn := current.shared
if p.options.IsAlive(conn) { if p.options.IsAlive(conn) {
@@ -327,35 +269,19 @@ func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Conte
p.access.Unlock() p.access.Unlock()
return conn, connCtx, created, nil return conn, connCtx, created, nil
} }
delete(current.all, conn) p.removeConn(current, conn, net.ErrClosed)
var zeroConn T
current.shared = zeroConn
current.hasShared = false
current.sharedClaimed = false
current.sharedCtx = nil
if current.sharedCancel != nil {
current.sharedCancel(net.ErrClosed)
current.sharedCancel = nil
}
staleConn = conn
hasStale = true
p.access.Unlock() p.access.Unlock()
p.options.Close(staleConn, net.ErrClosed) p.options.Close(conn, net.ErrClosed)
continue continue
} }
if current.connecting == nil { startDial := current.connecting == nil
current.connecting = &connPoolConnect[T]{ if startDial {
done: make(chan struct{}), current.connecting = &connPoolConnect[T]{done: make(chan struct{})}
}
startDial = true
} }
state = current.connecting state := current.connecting
p.access.Unlock() p.access.Unlock()
if hasStale {
continue
}
if startDial { if startDial {
go p.connectSingle(current, state, ctx, dial) go p.connectSingle(current, state, ctx, dial)
} }
@@ -381,35 +307,39 @@ func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Conte
} }
func (p *ConnPool[T]) connectSingle(current *connPoolState[T], state *connPoolConnect[T], ctx context.Context, dial func(context.Context) (T, error)) { func (p *ConnPool[T]) connectSingle(current *connPoolState[T], state *connPoolConnect[T], ctx context.Context, dial func(context.Context) (T, error)) {
conn, err := p.dial(ctx, current, dial) dialCtx, dialCancel := context.WithCancelCause(ctx)
if err != nil { stopStateCancel := context.AfterFunc(current.ctx, func() {
p.access.Lock() dialCancel(context.Cause(current.ctx))
if current.connecting == state { })
current.connecting = nil conn, err := dial(dialCtx)
stateCancelStopped := stopStateCancel()
dialErr := context.Cause(dialCtx)
if dialErr == nil && !stateCancelStopped {
dialErr = context.Cause(current.ctx)
}
dialCancel(nil)
if dialErr != nil {
if err == nil {
p.options.Close(conn, dialErr)
} }
state.err = err err = dialErr
p.access.Unlock()
close(state.done)
return
} }
var closeErr error var closeErr error
p.access.Lock() p.access.Lock()
if current.connecting == state { current.connecting = nil
current.connecting = nil if err != nil {
} state.err = err
if p.closed { } else if p.closed {
closeErr = net.ErrClosed closeErr = net.ErrClosed
state.err = closeErr state.err = closeErr
} else if p.state != current { } else if p.state != current {
closeErr = p.closeCause(current) closeErr = net.ErrClosed
state.err = closeErr state.err = closeErr
} else { } else {
sharedCtx, sharedCancel := context.WithCancelCause(current.ctx) sharedCtx, sharedCancel := context.WithCancelCause(current.ctx)
current.shared = conn current.shared = conn
current.hasShared = true current.hasShared = true
current.sharedClaimed = false
current.sharedCtx = sharedCtx current.sharedCtx = sharedCtx
current.sharedCancel = sharedCancel current.sharedCancel = sharedCancel
current.all[conn] = struct{}{} current.all[conn] = struct{}{}
@@ -439,9 +369,8 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo
return zero, nil, false, false, net.ErrClosed return zero, nil, false, false, net.ErrClosed
} }
if p.state != current { if p.state != current {
cause := p.closeCause(current)
p.access.Unlock() p.access.Unlock()
return zero, nil, false, false, cause return zero, nil, false, false, net.ErrClosed
} }
if !current.hasShared { if !current.hasShared {
p.access.Unlock() p.access.Unlock()
@@ -450,16 +379,7 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo
conn := current.shared conn := current.shared
if !p.options.IsAlive(conn) { if !p.options.IsAlive(conn) {
delete(current.all, conn) p.removeConn(current, conn, net.ErrClosed)
var zeroConn T
current.shared = zeroConn
current.hasShared = false
current.sharedClaimed = false
current.sharedCtx = nil
if current.sharedCancel != nil {
current.sharedCancel(net.ErrClosed)
current.sharedCancel = nil
}
p.access.Unlock() p.access.Unlock()
p.options.Close(conn, net.ErrClosed) p.options.Close(conn, net.ErrClosed)
return zero, nil, false, true, nil return zero, nil, false, true, nil
@@ -472,76 +392,9 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo
return conn, connCtx, created, false, nil return conn, connCtx, created, false, nil
} }
func (p *ConnPool[T]) dial(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, error) {
var zero T
if err := ctx.Err(); err != nil {
return zero, err
}
if cause := context.Cause(current.ctx); cause != nil {
return zero, cause
}
dialCtx, cancel := context.WithCancelCause(current.ctx)
var (
stateAccess sync.Mutex
dialComplete bool
)
stopCancel := context.AfterFunc(ctx, func() {
stateAccess.Lock()
if !dialComplete {
cancel(context.Cause(ctx))
}
stateAccess.Unlock()
})
select {
case <-ctx.Done():
stateAccess.Lock()
dialComplete = true
stateAccess.Unlock()
stopCancel()
cancel(context.Cause(ctx))
return zero, ctx.Err()
default:
}
conn, err := dial(connPoolDialContext{
Context: dialCtx,
parent: ctx,
})
stateAccess.Lock()
dialComplete = true
stateAccess.Unlock()
stopCancel()
if err != nil {
if cause := context.Cause(dialCtx); cause != nil {
return zero, cause
}
return zero, err
}
if cause := context.Cause(dialCtx); cause != nil {
p.options.Close(conn, cause)
return zero, cause
}
return conn, nil
}
func (p *ConnPool[T]) closeState(state *connPoolState[T], cause error) { func (p *ConnPool[T]) closeState(state *connPoolState[T], cause error) {
if state == nil {
return
}
state.cancel(cause) state.cancel(cause)
if state.sharedCancel != nil {
state.sharedCancel(cause)
}
for conn := range state.all { for conn := range state.all {
p.options.Close(conn, cause) p.options.Close(conn, cause)
} }
} }
func (p *ConnPool[T]) closeCause(state *connPoolState[T]) error {
_ = state
return net.ErrClosed
}