mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-05-14 00:51:12 +03:00
Clean up DNS transports
This commit is contained in:
@@ -68,6 +68,8 @@ type DNSTransport interface {
|
|||||||
Type() string
|
Type() string
|
||||||
Tag() string
|
Tag() string
|
||||||
Dependencies() []string
|
Dependencies() []string
|
||||||
|
// Reset closes the transport's existing connections so later requests use fresh connections.
|
||||||
|
// Exchanges that are currently using those connections may fail.
|
||||||
Reset()
|
Reset()
|
||||||
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
|
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,145 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
C "github.com/sagernet/sing-box/constant"
|
|
||||||
"github.com/sagernet/sing-box/dns"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
"github.com/sagernet/sing/common/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TransportState int
|
|
||||||
|
|
||||||
const (
|
|
||||||
StateNew TransportState = iota
|
|
||||||
StateStarted
|
|
||||||
StateClosing
|
|
||||||
StateClosed
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrTransportClosed = os.ErrClosed
|
|
||||||
ErrConnectionReset = E.New("connection reset")
|
|
||||||
)
|
|
||||||
|
|
||||||
type BaseTransport struct {
|
|
||||||
dns.TransportAdapter
|
|
||||||
Logger logger.ContextLogger
|
|
||||||
|
|
||||||
mutex sync.Mutex
|
|
||||||
state TransportState
|
|
||||||
inFlight int32
|
|
||||||
queriesComplete chan struct{}
|
|
||||||
closeCtx context.Context
|
|
||||||
closeCancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBaseTransport(adapter dns.TransportAdapter, logger logger.ContextLogger) *BaseTransport {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
return &BaseTransport{
|
|
||||||
TransportAdapter: adapter,
|
|
||||||
Logger: logger,
|
|
||||||
state: StateNew,
|
|
||||||
closeCtx: ctx,
|
|
||||||
closeCancel: cancel,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *BaseTransport) State() TransportState {
|
|
||||||
t.mutex.Lock()
|
|
||||||
defer t.mutex.Unlock()
|
|
||||||
return t.state
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *BaseTransport) SetStarted() error {
|
|
||||||
t.mutex.Lock()
|
|
||||||
defer t.mutex.Unlock()
|
|
||||||
switch t.state {
|
|
||||||
case StateNew:
|
|
||||||
t.state = StateStarted
|
|
||||||
return nil
|
|
||||||
case StateStarted:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return ErrTransportClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *BaseTransport) BeginQuery() bool {
|
|
||||||
t.mutex.Lock()
|
|
||||||
defer t.mutex.Unlock()
|
|
||||||
if t.state != StateStarted {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
t.inFlight++
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *BaseTransport) EndQuery() {
|
|
||||||
t.mutex.Lock()
|
|
||||||
if t.inFlight > 0 {
|
|
||||||
t.inFlight--
|
|
||||||
}
|
|
||||||
if t.inFlight == 0 && t.queriesComplete != nil {
|
|
||||||
close(t.queriesComplete)
|
|
||||||
t.queriesComplete = nil
|
|
||||||
}
|
|
||||||
t.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *BaseTransport) CloseContext() context.Context {
|
|
||||||
return t.closeCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *BaseTransport) Shutdown(ctx context.Context) error {
|
|
||||||
t.mutex.Lock()
|
|
||||||
|
|
||||||
if t.state >= StateClosing {
|
|
||||||
t.mutex.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.state == StateNew {
|
|
||||||
t.state = StateClosed
|
|
||||||
t.mutex.Unlock()
|
|
||||||
t.closeCancel()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
t.state = StateClosing
|
|
||||||
|
|
||||||
if t.inFlight == 0 {
|
|
||||||
t.state = StateClosed
|
|
||||||
t.mutex.Unlock()
|
|
||||||
t.closeCancel()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
t.queriesComplete = make(chan struct{})
|
|
||||||
queriesComplete := t.queriesComplete
|
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
t.closeCancel()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-queriesComplete:
|
|
||||||
t.mutex.Lock()
|
|
||||||
t.state = StateClosed
|
|
||||||
t.mutex.Unlock()
|
|
||||||
return nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
t.mutex.Lock()
|
|
||||||
t.state = StateClosed
|
|
||||||
t.mutex.Unlock()
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *BaseTransport) Close() error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), C.TCPTimeout)
|
|
||||||
defer cancel()
|
|
||||||
return t.Shutdown(ctx)
|
|
||||||
}
|
|
||||||
547
dns/transport/conn_pool.go
Normal file
547
dns/transport/conn_pool.go
Normal file
@@ -0,0 +1,547 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common/x/list"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnPoolMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ConnPoolSingle ConnPoolMode = iota
|
||||||
|
ConnPoolOrdered
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnPoolOptions[T comparable] struct {
|
||||||
|
Mode ConnPoolMode
|
||||||
|
IsAlive func(T) bool
|
||||||
|
Close func(T, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnPool[T comparable] struct {
|
||||||
|
options ConnPoolOptions[T]
|
||||||
|
|
||||||
|
access sync.Mutex
|
||||||
|
closed bool
|
||||||
|
state *connPoolState[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
type connPoolState[T comparable] struct {
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelCauseFunc
|
||||||
|
|
||||||
|
all map[T]struct{}
|
||||||
|
|
||||||
|
idle list.List[T]
|
||||||
|
idleElements map[T]*list.Element[T]
|
||||||
|
|
||||||
|
shared T
|
||||||
|
hasShared bool
|
||||||
|
sharedClaimed bool
|
||||||
|
sharedCtx context.Context
|
||||||
|
sharedCancel context.CancelCauseFunc
|
||||||
|
|
||||||
|
connecting *connPoolConnect[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
type connPoolConnect[T comparable] struct {
|
||||||
|
done chan struct{}
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type connPoolDialContext struct {
|
||||||
|
context.Context
|
||||||
|
parent context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c connPoolDialContext) Deadline() (time.Time, bool) {
|
||||||
|
return c.parent.Deadline()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c connPoolDialContext) Value(key any) any {
|
||||||
|
return c.parent.Value(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] {
|
||||||
|
return &ConnPool[T]{
|
||||||
|
options: options,
|
||||||
|
state: newConnPoolState[T](options.Mode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnPoolState[T comparable](mode ConnPoolMode) *connPoolState[T] {
|
||||||
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
state := &connPoolState[T]{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
all: make(map[T]struct{}),
|
||||||
|
}
|
||||||
|
if mode == ConnPoolOrdered {
|
||||||
|
state.idleElements = make(map[T]*list.Element[T])
|
||||||
|
}
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) Acquire(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) {
|
||||||
|
switch p.options.Mode {
|
||||||
|
case ConnPoolSingle:
|
||||||
|
conn, _, created, err := p.acquireShared(ctx, dial)
|
||||||
|
return conn, created, err
|
||||||
|
case ConnPoolOrdered:
|
||||||
|
return p.acquireOrdered(ctx, dial)
|
||||||
|
default:
|
||||||
|
var zero T
|
||||||
|
return zero, false, net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) AcquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) {
|
||||||
|
if p.options.Mode != ConnPoolSingle {
|
||||||
|
var zero T
|
||||||
|
return zero, nil, false, net.ErrClosed
|
||||||
|
}
|
||||||
|
return p.acquireShared(ctx, dial)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) Release(conn T, reuse bool) {
|
||||||
|
var (
|
||||||
|
closeConn bool
|
||||||
|
closeErr error
|
||||||
|
)
|
||||||
|
|
||||||
|
p.access.Lock()
|
||||||
|
if p.closed || p.state == nil {
|
||||||
|
closeConn = true
|
||||||
|
closeErr = net.ErrClosed
|
||||||
|
p.access.Unlock()
|
||||||
|
if closeConn {
|
||||||
|
p.options.Close(conn, closeErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := p.state
|
||||||
|
_, tracked := currentState.all[conn]
|
||||||
|
if !tracked {
|
||||||
|
closeConn = true
|
||||||
|
closeErr = p.closeCause(currentState)
|
||||||
|
p.access.Unlock()
|
||||||
|
if closeConn {
|
||||||
|
p.options.Close(conn, closeErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reuse || !p.options.IsAlive(conn) {
|
||||||
|
delete(currentState.all, conn)
|
||||||
|
switch p.options.Mode {
|
||||||
|
case ConnPoolSingle:
|
||||||
|
if currentState.hasShared && currentState.shared == conn {
|
||||||
|
var zero T
|
||||||
|
currentState.shared = zero
|
||||||
|
currentState.hasShared = false
|
||||||
|
currentState.sharedClaimed = false
|
||||||
|
currentState.sharedCtx = nil
|
||||||
|
if currentState.sharedCancel != nil {
|
||||||
|
currentState.sharedCancel(net.ErrClosed)
|
||||||
|
currentState.sharedCancel = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case ConnPoolOrdered:
|
||||||
|
if element, loaded := currentState.idleElements[conn]; loaded {
|
||||||
|
currentState.idle.Remove(element)
|
||||||
|
delete(currentState.idleElements, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closeConn = true
|
||||||
|
closeErr = net.ErrClosed
|
||||||
|
p.access.Unlock()
|
||||||
|
if closeConn {
|
||||||
|
p.options.Close(conn, closeErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.options.Mode == ConnPoolOrdered {
|
||||||
|
if _, loaded := currentState.idleElements[conn]; !loaded {
|
||||||
|
currentState.idleElements[conn] = currentState.idle.PushBack(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.access.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) Invalidate(conn T, cause error) {
|
||||||
|
p.access.Lock()
|
||||||
|
if p.closed || p.state == nil {
|
||||||
|
p.access.Unlock()
|
||||||
|
p.options.Close(conn, cause)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := p.state
|
||||||
|
_, tracked := currentState.all[conn]
|
||||||
|
if !tracked {
|
||||||
|
p.access.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(currentState.all, conn)
|
||||||
|
switch p.options.Mode {
|
||||||
|
case ConnPoolSingle:
|
||||||
|
if currentState.hasShared && currentState.shared == conn {
|
||||||
|
var zero T
|
||||||
|
currentState.shared = zero
|
||||||
|
currentState.hasShared = false
|
||||||
|
currentState.sharedClaimed = false
|
||||||
|
currentState.sharedCtx = nil
|
||||||
|
if currentState.sharedCancel != nil {
|
||||||
|
currentState.sharedCancel(cause)
|
||||||
|
currentState.sharedCancel = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case ConnPoolOrdered:
|
||||||
|
if element, loaded := currentState.idleElements[conn]; loaded {
|
||||||
|
currentState.idle.Remove(element)
|
||||||
|
delete(currentState.idleElements, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.access.Unlock()
|
||||||
|
|
||||||
|
p.options.Close(conn, cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) Reset() {
|
||||||
|
p.access.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.access.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oldState := p.state
|
||||||
|
p.state = newConnPoolState[T](p.options.Mode)
|
||||||
|
p.access.Unlock()
|
||||||
|
|
||||||
|
p.closeState(oldState, net.ErrClosed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) Close() error {
|
||||||
|
p.access.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.access.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.closed = true
|
||||||
|
oldState := p.state
|
||||||
|
p.state = nil
|
||||||
|
p.access.Unlock()
|
||||||
|
|
||||||
|
p.closeState(oldState, net.ErrClosed)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) {
|
||||||
|
var zero T
|
||||||
|
for {
|
||||||
|
var (
|
||||||
|
staleConn T
|
||||||
|
hasStale bool
|
||||||
|
)
|
||||||
|
|
||||||
|
p.access.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.access.Unlock()
|
||||||
|
return zero, false, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := p.state
|
||||||
|
if element := currentState.idle.Front(); element != nil {
|
||||||
|
conn := currentState.idle.Remove(element)
|
||||||
|
delete(currentState.idleElements, conn)
|
||||||
|
if p.options.IsAlive(conn) {
|
||||||
|
p.access.Unlock()
|
||||||
|
return conn, false, nil
|
||||||
|
}
|
||||||
|
delete(currentState.all, conn)
|
||||||
|
staleConn = conn
|
||||||
|
hasStale = true
|
||||||
|
}
|
||||||
|
p.access.Unlock()
|
||||||
|
|
||||||
|
if hasStale {
|
||||||
|
p.options.Close(staleConn, net.ErrClosed)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := p.dial(ctx, currentState, dial)
|
||||||
|
if err != nil {
|
||||||
|
return zero, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.access.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.access.Unlock()
|
||||||
|
p.options.Close(conn, net.ErrClosed)
|
||||||
|
return zero, false, net.ErrClosed
|
||||||
|
}
|
||||||
|
if p.state != currentState {
|
||||||
|
cause := p.closeCause(currentState)
|
||||||
|
p.access.Unlock()
|
||||||
|
p.options.Close(conn, cause)
|
||||||
|
return zero, false, cause
|
||||||
|
}
|
||||||
|
currentState.all[conn] = struct{}{}
|
||||||
|
p.access.Unlock()
|
||||||
|
return conn, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) {
|
||||||
|
var zero T
|
||||||
|
for {
|
||||||
|
var (
|
||||||
|
staleConn T
|
||||||
|
hasStale bool
|
||||||
|
state *connPoolConnect[T]
|
||||||
|
current *connPoolState[T]
|
||||||
|
startDial bool
|
||||||
|
)
|
||||||
|
|
||||||
|
p.access.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.access.Unlock()
|
||||||
|
return zero, nil, false, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
current = p.state
|
||||||
|
if current.hasShared {
|
||||||
|
conn := current.shared
|
||||||
|
if p.options.IsAlive(conn) {
|
||||||
|
created := !current.sharedClaimed
|
||||||
|
current.sharedClaimed = true
|
||||||
|
connCtx := current.sharedCtx
|
||||||
|
p.access.Unlock()
|
||||||
|
return conn, connCtx, created, nil
|
||||||
|
}
|
||||||
|
delete(current.all, conn)
|
||||||
|
var zeroConn T
|
||||||
|
current.shared = zeroConn
|
||||||
|
current.hasShared = false
|
||||||
|
current.sharedClaimed = false
|
||||||
|
current.sharedCtx = nil
|
||||||
|
if current.sharedCancel != nil {
|
||||||
|
current.sharedCancel(net.ErrClosed)
|
||||||
|
current.sharedCancel = nil
|
||||||
|
}
|
||||||
|
staleConn = conn
|
||||||
|
hasStale = true
|
||||||
|
p.access.Unlock()
|
||||||
|
p.options.Close(staleConn, net.ErrClosed)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if current.connecting == nil {
|
||||||
|
current.connecting = &connPoolConnect[T]{
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
startDial = true
|
||||||
|
}
|
||||||
|
state = current.connecting
|
||||||
|
p.access.Unlock()
|
||||||
|
|
||||||
|
if hasStale {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if startDial {
|
||||||
|
go p.connectSingle(current, state, ctx, dial)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-state.done:
|
||||||
|
conn, connCtx, created, retry, err := p.collectShared(current, state, startDial)
|
||||||
|
if retry {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return conn, connCtx, created, err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return zero, nil, false, ctx.Err()
|
||||||
|
case <-current.ctx.Done():
|
||||||
|
p.access.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
p.access.Unlock()
|
||||||
|
if closed {
|
||||||
|
return zero, nil, false, net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) connectSingle(current *connPoolState[T], state *connPoolConnect[T], ctx context.Context, dial func(context.Context) (T, error)) {
|
||||||
|
conn, err := p.dial(ctx, current, dial)
|
||||||
|
if err != nil {
|
||||||
|
p.access.Lock()
|
||||||
|
if current.connecting == state {
|
||||||
|
current.connecting = nil
|
||||||
|
}
|
||||||
|
state.err = err
|
||||||
|
p.access.Unlock()
|
||||||
|
close(state.done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var closeErr error
|
||||||
|
|
||||||
|
p.access.Lock()
|
||||||
|
if current.connecting == state {
|
||||||
|
current.connecting = nil
|
||||||
|
}
|
||||||
|
if p.closed {
|
||||||
|
closeErr = net.ErrClosed
|
||||||
|
state.err = closeErr
|
||||||
|
} else if p.state != current {
|
||||||
|
closeErr = p.closeCause(current)
|
||||||
|
state.err = closeErr
|
||||||
|
} else {
|
||||||
|
sharedCtx, sharedCancel := context.WithCancelCause(current.ctx)
|
||||||
|
current.shared = conn
|
||||||
|
current.hasShared = true
|
||||||
|
current.sharedClaimed = false
|
||||||
|
current.sharedCtx = sharedCtx
|
||||||
|
current.sharedCancel = sharedCancel
|
||||||
|
current.all[conn] = struct{}{}
|
||||||
|
}
|
||||||
|
p.access.Unlock()
|
||||||
|
|
||||||
|
if closeErr != nil {
|
||||||
|
p.options.Close(conn, closeErr)
|
||||||
|
}
|
||||||
|
close(state.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolConnect[T], startDial bool) (T, context.Context, bool, bool, error) {
|
||||||
|
var zero T
|
||||||
|
|
||||||
|
p.access.Lock()
|
||||||
|
if state.err != nil {
|
||||||
|
err := state.err
|
||||||
|
p.access.Unlock()
|
||||||
|
if startDial {
|
||||||
|
return zero, nil, false, false, err
|
||||||
|
}
|
||||||
|
return zero, nil, false, true, nil
|
||||||
|
}
|
||||||
|
if p.closed {
|
||||||
|
p.access.Unlock()
|
||||||
|
return zero, nil, false, false, net.ErrClosed
|
||||||
|
}
|
||||||
|
if p.state != current {
|
||||||
|
cause := p.closeCause(current)
|
||||||
|
p.access.Unlock()
|
||||||
|
return zero, nil, false, false, cause
|
||||||
|
}
|
||||||
|
if !current.hasShared {
|
||||||
|
p.access.Unlock()
|
||||||
|
return zero, nil, false, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := current.shared
|
||||||
|
if !p.options.IsAlive(conn) {
|
||||||
|
delete(current.all, conn)
|
||||||
|
var zeroConn T
|
||||||
|
current.shared = zeroConn
|
||||||
|
current.hasShared = false
|
||||||
|
current.sharedClaimed = false
|
||||||
|
current.sharedCtx = nil
|
||||||
|
if current.sharedCancel != nil {
|
||||||
|
current.sharedCancel(net.ErrClosed)
|
||||||
|
current.sharedCancel = nil
|
||||||
|
}
|
||||||
|
p.access.Unlock()
|
||||||
|
p.options.Close(conn, net.ErrClosed)
|
||||||
|
return zero, nil, false, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
created := !current.sharedClaimed
|
||||||
|
current.sharedClaimed = true
|
||||||
|
connCtx := current.sharedCtx
|
||||||
|
p.access.Unlock()
|
||||||
|
return conn, connCtx, created, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) dial(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, error) {
|
||||||
|
var zero T
|
||||||
|
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
if cause := context.Cause(current.ctx); cause != nil {
|
||||||
|
return zero, cause
|
||||||
|
}
|
||||||
|
|
||||||
|
dialCtx, cancel := context.WithCancelCause(current.ctx)
|
||||||
|
var (
|
||||||
|
stateAccess sync.Mutex
|
||||||
|
dialComplete bool
|
||||||
|
)
|
||||||
|
stopCancel := context.AfterFunc(ctx, func() {
|
||||||
|
stateAccess.Lock()
|
||||||
|
if !dialComplete {
|
||||||
|
cancel(context.Cause(ctx))
|
||||||
|
}
|
||||||
|
stateAccess.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
stateAccess.Lock()
|
||||||
|
dialComplete = true
|
||||||
|
stateAccess.Unlock()
|
||||||
|
stopCancel()
|
||||||
|
cancel(context.Cause(ctx))
|
||||||
|
return zero, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dial(connPoolDialContext{
|
||||||
|
Context: dialCtx,
|
||||||
|
parent: ctx,
|
||||||
|
})
|
||||||
|
stateAccess.Lock()
|
||||||
|
dialComplete = true
|
||||||
|
stateAccess.Unlock()
|
||||||
|
stopCancel()
|
||||||
|
if err != nil {
|
||||||
|
if cause := context.Cause(dialCtx); cause != nil {
|
||||||
|
return zero, cause
|
||||||
|
}
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
if cause := context.Cause(dialCtx); cause != nil {
|
||||||
|
p.options.Close(conn, cause)
|
||||||
|
return zero, cause
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) closeState(state *connPoolState[T], cause error) {
|
||||||
|
if state == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state.cancel(cause)
|
||||||
|
if state.sharedCancel != nil {
|
||||||
|
state.sharedCancel(cause)
|
||||||
|
}
|
||||||
|
for conn := range state.all {
|
||||||
|
p.options.Close(conn, cause)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool[T]) closeCause(state *connPoolState[T]) error {
|
||||||
|
_ = state
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
@@ -1,321 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ConnectorCallbacks[T any] struct {
|
|
||||||
IsClosed func(connection T) bool
|
|
||||||
Close func(connection T)
|
|
||||||
Reset func(connection T)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Connector[T any] struct {
|
|
||||||
dial func(ctx context.Context) (T, error)
|
|
||||||
callbacks ConnectorCallbacks[T]
|
|
||||||
|
|
||||||
access sync.Mutex
|
|
||||||
connection T
|
|
||||||
hasConnection bool
|
|
||||||
connectionCancel context.CancelFunc
|
|
||||||
connecting chan struct{}
|
|
||||||
|
|
||||||
closeCtx context.Context
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConnector[T any](closeCtx context.Context, dial func(context.Context) (T, error), callbacks ConnectorCallbacks[T]) *Connector[T] {
|
|
||||||
return &Connector[T]{
|
|
||||||
dial: dial,
|
|
||||||
callbacks: callbacks,
|
|
||||||
closeCtx: closeCtx,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSingleflightConnector(closeCtx context.Context, dial func(context.Context) (*Connection, error)) *Connector[*Connection] {
|
|
||||||
return NewConnector(closeCtx, dial, ConnectorCallbacks[*Connection]{
|
|
||||||
IsClosed: func(connection *Connection) bool {
|
|
||||||
return connection.IsClosed()
|
|
||||||
},
|
|
||||||
Close: func(connection *Connection) {
|
|
||||||
connection.CloseWithError(ErrTransportClosed)
|
|
||||||
},
|
|
||||||
Reset: func(connection *Connection) {
|
|
||||||
connection.CloseWithError(ErrConnectionReset)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type contextKeyConnecting struct{}
|
|
||||||
|
|
||||||
var errRecursiveConnectorDial = E.New("recursive connector dial")
|
|
||||||
|
|
||||||
type connectorDialResult[T any] struct {
|
|
||||||
connection T
|
|
||||||
cancel context.CancelFunc
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
|
||||||
var zero T
|
|
||||||
for {
|
|
||||||
c.access.Lock()
|
|
||||||
|
|
||||||
if c.closed {
|
|
||||||
c.access.Unlock()
|
|
||||||
return zero, ErrTransportClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.hasConnection && !c.callbacks.IsClosed(c.connection) {
|
|
||||||
connection := c.connection
|
|
||||||
c.access.Unlock()
|
|
||||||
return connection, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c.hasConnection = false
|
|
||||||
if c.connectionCancel != nil {
|
|
||||||
c.connectionCancel()
|
|
||||||
c.connectionCancel = nil
|
|
||||||
}
|
|
||||||
if isRecursiveConnectorDial(ctx, c) {
|
|
||||||
c.access.Unlock()
|
|
||||||
return zero, errRecursiveConnectorDial
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.connecting != nil {
|
|
||||||
connecting := c.connecting
|
|
||||||
c.access.Unlock()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-connecting:
|
|
||||||
continue
|
|
||||||
case <-ctx.Done():
|
|
||||||
return zero, ctx.Err()
|
|
||||||
case <-c.closeCtx.Done():
|
|
||||||
return zero, ErrTransportClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ctx.Err(); err != nil {
|
|
||||||
c.access.Unlock()
|
|
||||||
return zero, err
|
|
||||||
}
|
|
||||||
|
|
||||||
connecting := make(chan struct{})
|
|
||||||
c.connecting = connecting
|
|
||||||
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
|
|
||||||
dialResult := make(chan connectorDialResult[T], 1)
|
|
||||||
c.access.Unlock()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
connection, cancel, err := c.dialWithCancellation(dialContext)
|
|
||||||
dialResult <- connectorDialResult[T]{
|
|
||||||
connection: connection,
|
|
||||||
cancel: cancel,
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case result := <-dialResult:
|
|
||||||
return c.completeDial(ctx, connecting, result)
|
|
||||||
case <-ctx.Done():
|
|
||||||
go func() {
|
|
||||||
result := <-dialResult
|
|
||||||
_, _ = c.completeDial(ctx, connecting, result)
|
|
||||||
}()
|
|
||||||
return zero, ctx.Err()
|
|
||||||
case <-c.closeCtx.Done():
|
|
||||||
go func() {
|
|
||||||
result := <-dialResult
|
|
||||||
_, _ = c.completeDial(ctx, connecting, result)
|
|
||||||
}()
|
|
||||||
return zero, ErrTransportClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T]) bool {
|
|
||||||
dialConnector, loaded := ctx.Value(contextKeyConnecting{}).(*Connector[T])
|
|
||||||
return loaded && dialConnector == connector
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connector[T]) completeDial(ctx context.Context, connecting chan struct{}, result connectorDialResult[T]) (T, error) {
|
|
||||||
var zero T
|
|
||||||
|
|
||||||
c.access.Lock()
|
|
||||||
defer c.access.Unlock()
|
|
||||||
defer func() {
|
|
||||||
if c.connecting == connecting {
|
|
||||||
c.connecting = nil
|
|
||||||
}
|
|
||||||
close(connecting)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if result.err != nil {
|
|
||||||
return zero, result.err
|
|
||||||
}
|
|
||||||
if c.closed || c.closeCtx.Err() != nil {
|
|
||||||
result.cancel()
|
|
||||||
c.callbacks.Close(result.connection)
|
|
||||||
return zero, ErrTransportClosed
|
|
||||||
}
|
|
||||||
if err := ctx.Err(); err != nil {
|
|
||||||
result.cancel()
|
|
||||||
c.callbacks.Close(result.connection)
|
|
||||||
return zero, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.connection = result.connection
|
|
||||||
c.hasConnection = true
|
|
||||||
c.connectionCancel = result.cancel
|
|
||||||
return c.connection, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) {
|
|
||||||
var zero T
|
|
||||||
if err := ctx.Err(); err != nil {
|
|
||||||
return zero, nil, err
|
|
||||||
}
|
|
||||||
connCtx, cancel := context.WithCancel(c.closeCtx)
|
|
||||||
|
|
||||||
var (
|
|
||||||
stateAccess sync.Mutex
|
|
||||||
dialComplete bool
|
|
||||||
)
|
|
||||||
stopCancel := context.AfterFunc(ctx, func() {
|
|
||||||
stateAccess.Lock()
|
|
||||||
if !dialComplete {
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
stateAccess.Unlock()
|
|
||||||
})
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
stateAccess.Lock()
|
|
||||||
dialComplete = true
|
|
||||||
stateAccess.Unlock()
|
|
||||||
stopCancel()
|
|
||||||
cancel()
|
|
||||||
return zero, nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
connection, err := c.dial(valueContext{connCtx, ctx})
|
|
||||||
stateAccess.Lock()
|
|
||||||
dialComplete = true
|
|
||||||
stateAccess.Unlock()
|
|
||||||
stopCancel()
|
|
||||||
if err != nil {
|
|
||||||
cancel()
|
|
||||||
return zero, nil, err
|
|
||||||
}
|
|
||||||
return connection, cancel, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type valueContext struct {
|
|
||||||
context.Context
|
|
||||||
parent context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v valueContext) Value(key any) any {
|
|
||||||
return v.parent.Value(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v valueContext) Deadline() (time.Time, bool) {
|
|
||||||
return v.parent.Deadline()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connector[T]) Close() error {
|
|
||||||
c.access.Lock()
|
|
||||||
defer c.access.Unlock()
|
|
||||||
|
|
||||||
if c.closed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
c.closed = true
|
|
||||||
|
|
||||||
if c.connectionCancel != nil {
|
|
||||||
c.connectionCancel()
|
|
||||||
c.connectionCancel = nil
|
|
||||||
}
|
|
||||||
if c.hasConnection {
|
|
||||||
c.callbacks.Close(c.connection)
|
|
||||||
c.hasConnection = false
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connector[T]) Reset() {
|
|
||||||
c.access.Lock()
|
|
||||||
defer c.access.Unlock()
|
|
||||||
|
|
||||||
if c.connectionCancel != nil {
|
|
||||||
c.connectionCancel()
|
|
||||||
c.connectionCancel = nil
|
|
||||||
}
|
|
||||||
if c.hasConnection {
|
|
||||||
c.callbacks.Reset(c.connection)
|
|
||||||
c.hasConnection = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Connection struct {
|
|
||||||
net.Conn
|
|
||||||
|
|
||||||
closeOnce sync.Once
|
|
||||||
done chan struct{}
|
|
||||||
closeError error
|
|
||||||
}
|
|
||||||
|
|
||||||
func WrapConnection(conn net.Conn) *Connection {
|
|
||||||
return &Connection{
|
|
||||||
Conn: conn,
|
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) Done() <-chan struct{} {
|
|
||||||
return c.done
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) IsClosed() bool {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) CloseError() error {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
if c.closeError != nil {
|
|
||||||
return c.closeError
|
|
||||||
}
|
|
||||||
return ErrTransportClosed
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) Close() error {
|
|
||||||
return c.CloseWithError(ErrTransportClosed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) CloseWithError(err error) error {
|
|
||||||
var returnError error
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
c.closeError = err
|
|
||||||
returnError = c.Conn.Close()
|
|
||||||
close(c.done)
|
|
||||||
})
|
|
||||||
return returnError
|
|
||||||
}
|
|
||||||
@@ -1,407 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type testConnectorConnection struct{}
|
|
||||||
|
|
||||||
func TestConnectorRecursiveGetFailsFast(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var (
|
|
||||||
dialCount atomic.Int32
|
|
||||||
closeCount atomic.Int32
|
|
||||||
connector *Connector[*testConnectorConnection]
|
|
||||||
)
|
|
||||||
|
|
||||||
dial := func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
dialCount.Add(1)
|
|
||||||
_, err := connector.Get(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
connector = NewConnector(context.Background(), dial, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {
|
|
||||||
closeCount.Add(1)
|
|
||||||
},
|
|
||||||
Reset: func(connection *testConnectorConnection) {
|
|
||||||
closeCount.Add(1)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
_, err := connector.Get(context.Background())
|
|
||||||
require.ErrorIs(t, err, errRecursiveConnectorDial)
|
|
||||||
require.EqualValues(t, 1, dialCount.Load())
|
|
||||||
require.EqualValues(t, 0, closeCount.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectorRecursiveGetAcrossConnectorsAllowed(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var (
|
|
||||||
outerDialCount atomic.Int32
|
|
||||||
innerDialCount atomic.Int32
|
|
||||||
outerConnector *Connector[*testConnectorConnection]
|
|
||||||
innerConnector *Connector[*testConnectorConnection]
|
|
||||||
)
|
|
||||||
|
|
||||||
innerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
innerDialCount.Add(1)
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
outerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
outerDialCount.Add(1)
|
|
||||||
_, err := innerConnector.Get(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
_, err := outerConnector.Get(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 1, outerDialCount.Load())
|
|
||||||
require.EqualValues(t, 1, innerDialCount.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectorDialContextPreservesValueAndDeadline(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
type contextKey struct{}
|
|
||||||
|
|
||||||
var (
|
|
||||||
dialValue any
|
|
||||||
dialDeadline time.Time
|
|
||||||
dialHasDeadline bool
|
|
||||||
)
|
|
||||||
|
|
||||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
dialValue = ctx.Value(contextKey{})
|
|
||||||
dialDeadline, dialHasDeadline = ctx.Deadline()
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
deadline := time.Now().Add(time.Minute)
|
|
||||||
requestContext, cancel := context.WithDeadline(context.WithValue(context.Background(), contextKey{}, "test-value"), deadline)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err := connector.Get(requestContext)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "test-value", dialValue)
|
|
||||||
require.True(t, dialHasDeadline)
|
|
||||||
require.WithinDuration(t, deadline, dialDeadline, time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectorDialSkipsCanceledRequest(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var dialCount atomic.Int32
|
|
||||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
dialCount.Add(1)
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
requestContext, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
_, err := connector.Get(requestContext)
|
|
||||||
require.ErrorIs(t, err, context.Canceled)
|
|
||||||
require.EqualValues(t, 0, dialCount.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var (
|
|
||||||
dialCount atomic.Int32
|
|
||||||
closeCount atomic.Int32
|
|
||||||
)
|
|
||||||
dialStarted := make(chan struct{}, 1)
|
|
||||||
releaseDial := make(chan struct{})
|
|
||||||
|
|
||||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
dialCount.Add(1)
|
|
||||||
select {
|
|
||||||
case dialStarted <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
<-releaseDial
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {
|
|
||||||
closeCount.Add(1)
|
|
||||||
},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
requestContext, cancel := context.WithCancel(context.Background())
|
|
||||||
result := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
_, err := connector.Get(requestContext)
|
|
||||||
result <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-dialStarted
|
|
||||||
cancel()
|
|
||||||
close(releaseDial)
|
|
||||||
|
|
||||||
err := <-result
|
|
||||||
require.ErrorIs(t, err, context.Canceled)
|
|
||||||
require.EqualValues(t, 1, dialCount.Load())
|
|
||||||
require.Eventually(t, func() bool {
|
|
||||||
return closeCount.Load() == 1
|
|
||||||
}, time.Second, 10*time.Millisecond)
|
|
||||||
|
|
||||||
_, err = connector.Get(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 2, dialCount.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var (
|
|
||||||
dialCount atomic.Int32
|
|
||||||
closeCount atomic.Int32
|
|
||||||
)
|
|
||||||
dialStarted := make(chan struct{}, 1)
|
|
||||||
releaseDial := make(chan struct{})
|
|
||||||
|
|
||||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
dialCount.Add(1)
|
|
||||||
select {
|
|
||||||
case dialStarted <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
<-releaseDial
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {
|
|
||||||
closeCount.Add(1)
|
|
||||||
},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
requestContext, cancel := context.WithCancel(context.Background())
|
|
||||||
result := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
_, err := connector.Get(requestContext)
|
|
||||||
result <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-dialStarted
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-result:
|
|
||||||
require.ErrorIs(t, err, context.Canceled)
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("Get did not return after request cancel")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.EqualValues(t, 1, dialCount.Load())
|
|
||||||
require.EqualValues(t, 0, closeCount.Load())
|
|
||||||
|
|
||||||
close(releaseDial)
|
|
||||||
|
|
||||||
require.Eventually(t, func() bool {
|
|
||||||
return closeCount.Load() == 1
|
|
||||||
}, time.Second, 10*time.Millisecond)
|
|
||||||
|
|
||||||
_, err := connector.Get(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 2, dialCount.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var (
|
|
||||||
dialCount atomic.Int32
|
|
||||||
closeCount atomic.Int32
|
|
||||||
)
|
|
||||||
firstDialStarted := make(chan struct{}, 1)
|
|
||||||
secondDialStarted := make(chan struct{}, 1)
|
|
||||||
releaseFirstDial := make(chan struct{})
|
|
||||||
|
|
||||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
attempt := dialCount.Add(1)
|
|
||||||
switch attempt {
|
|
||||||
case 1:
|
|
||||||
select {
|
|
||||||
case firstDialStarted <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
<-releaseFirstDial
|
|
||||||
case 2:
|
|
||||||
select {
|
|
||||||
case secondDialStarted <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {
|
|
||||||
closeCount.Add(1)
|
|
||||||
},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
requestContext, cancel := context.WithCancel(context.Background())
|
|
||||||
firstResult := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
_, err := connector.Get(requestContext)
|
|
||||||
firstResult <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-firstDialStarted
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
secondResult := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
_, err := connector.Get(context.Background())
|
|
||||||
secondResult <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-secondDialStarted:
|
|
||||||
t.Fatal("second dial started before first dial completed")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-firstResult:
|
|
||||||
require.ErrorIs(t, err, context.Canceled)
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("first Get did not return after request cancel")
|
|
||||||
}
|
|
||||||
|
|
||||||
close(releaseFirstDial)
|
|
||||||
|
|
||||||
require.Eventually(t, func() bool {
|
|
||||||
return closeCount.Load() == 1
|
|
||||||
}, time.Second, 10*time.Millisecond)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-secondDialStarted:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("second dial did not start after first dial completed")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := <-secondResult
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 2, dialCount.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var dialContext context.Context
|
|
||||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
dialContext = ctx
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
requestContext, cancel := context.WithCancel(context.Background())
|
|
||||||
_, err := connector.Get(requestContext)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, dialContext)
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-dialContext.Done():
|
|
||||||
t.Fatal("dial context canceled by request context after successful dial")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
}
|
|
||||||
|
|
||||||
err = connector.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectorDialContextCanceledOnClose(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var dialContext context.Context
|
|
||||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
||||||
dialContext = ctx
|
|
||||||
return &testConnectorConnection{}, nil
|
|
||||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
||||||
IsClosed: func(connection *testConnectorConnection) bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
Close: func(connection *testConnectorConnection) {},
|
|
||||||
Reset: func(connection *testConnectorConnection) {},
|
|
||||||
})
|
|
||||||
|
|
||||||
_, err := connector.Get(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, dialContext)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-dialContext.Done():
|
|
||||||
t.Fatal("dial context canceled before connector close")
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
err = connector.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-dialContext.Done():
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("dial context not canceled after connector close")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -31,14 +31,13 @@ func RegisterTransport(registry *dns.TransportRegistry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
*transport.BaseTransport
|
dns.TransportAdapter
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
serverAddr M.Socksaddr
|
serverAddr M.Socksaddr
|
||||||
tlsConfig tls.Config
|
tlsConfig tls.Config
|
||||||
|
|
||||||
connector *transport.Connector[*quic.Conn]
|
connection *transport.ConnPool[*quic.Conn]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
|
func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
|
||||||
@@ -63,93 +62,76 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options
|
|||||||
return nil, E.New("invalid server address: ", serverAddr)
|
return nil, E.New("invalid server address: ", serverAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &Transport{
|
return &Transport{
|
||||||
BaseTransport: transport.NewBaseTransport(
|
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
|
||||||
dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
|
dialer: transportDialer,
|
||||||
logger,
|
serverAddr: serverAddr,
|
||||||
),
|
tlsConfig: tlsConfig,
|
||||||
ctx: ctx,
|
connection: transport.NewConnPool(transport.ConnPoolOptions[*quic.Conn]{
|
||||||
dialer: transportDialer,
|
Mode: transport.ConnPoolSingle,
|
||||||
serverAddr: serverAddr,
|
IsAlive: func(conn *quic.Conn) bool {
|
||||||
tlsConfig: tlsConfig,
|
return conn != nil && !common.Done(conn.Context())
|
||||||
}
|
},
|
||||||
|
Close: func(conn *quic.Conn, _ error) {
|
||||||
t.connector = transport.NewConnector(t.CloseContext(), t.dial, transport.ConnectorCallbacks[*quic.Conn]{
|
conn.CloseWithError(0, "")
|
||||||
IsClosed: func(connection *quic.Conn) bool {
|
},
|
||||||
return common.Done(connection.Context())
|
}),
|
||||||
},
|
}, nil
|
||||||
Close: func(connection *quic.Conn) {
|
|
||||||
connection.CloseWithError(0, "")
|
|
||||||
},
|
|
||||||
Reset: func(connection *quic.Conn) {
|
|
||||||
connection.CloseWithError(0, "")
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Transport) dial(ctx context.Context) (*quic.Conn, error) {
|
|
||||||
conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "dial UDP connection")
|
|
||||||
}
|
|
||||||
earlyConnection, err := sQUIC.DialEarly(
|
|
||||||
ctx,
|
|
||||||
bufio.NewUnbindPacketConn(conn),
|
|
||||||
t.serverAddr.UDPAddr(),
|
|
||||||
t.tlsConfig,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return nil, E.Cause(err, "establish QUIC connection")
|
|
||||||
}
|
|
||||||
return earlyConnection, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) Start(stage adapter.StartStage) error {
|
func (t *Transport) Start(stage adapter.StartStage) error {
|
||||||
if stage != adapter.StartStateStart {
|
if stage != adapter.StartStateStart {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err := t.SetStarted()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return dialer.InitializeDetour(t.dialer)
|
return dialer.InitializeDetour(t.dialer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) Close() error {
|
func (t *Transport) Close() error {
|
||||||
return E.Errors(t.BaseTransport.Close(), t.connector.Close())
|
return t.connection.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) Reset() {
|
func (t *Transport) Reset() {
|
||||||
t.connector.Reset()
|
t.connection.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
if !t.BeginQuery() {
|
|
||||||
return nil, transport.ErrTransportClosed
|
|
||||||
}
|
|
||||||
defer t.EndQuery()
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
conn *quic.Conn
|
conn *quic.Conn
|
||||||
err error
|
err error
|
||||||
response *mDNS.Msg
|
response *mDNS.Msg
|
||||||
)
|
)
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
conn, err = t.connector.Get(ctx)
|
conn, _, err = t.connection.Acquire(ctx, func(ctx context.Context) (*quic.Conn, error) {
|
||||||
|
rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "dial UDP connection")
|
||||||
|
}
|
||||||
|
earlyConnection, err := sQUIC.DialEarly(
|
||||||
|
ctx,
|
||||||
|
bufio.NewUnbindPacketConn(rawConn),
|
||||||
|
t.serverAddr.UDPAddr(),
|
||||||
|
t.tlsConfig,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
rawConn.Close()
|
||||||
|
return nil, E.Cause(err, "establish QUIC connection")
|
||||||
|
}
|
||||||
|
return earlyConnection, nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
response, err = t.exchange(ctx, message, conn)
|
response, err = t.exchange(ctx, message, conn)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
t.connection.Release(conn, true)
|
||||||
return response, nil
|
return response, nil
|
||||||
} else if !isQUICRetryError(err) {
|
} else if !isQUICRetryError(err) {
|
||||||
|
t.connection.Release(conn, true)
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
t.connector.Reset()
|
t.connection.Release(conn, true)
|
||||||
|
t.Reset()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package transport
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
@@ -17,7 +16,6 @@ import (
|
|||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/sing/common/x/list"
|
|
||||||
|
|
||||||
mDNS "github.com/miekg/dns"
|
mDNS "github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
@@ -29,13 +27,13 @@ func RegisterTLS(registry *dns.TransportRegistry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TLSTransport struct {
|
type TLSTransport struct {
|
||||||
*BaseTransport
|
dns.TransportAdapter
|
||||||
|
logger logger.ContextLogger
|
||||||
|
|
||||||
dialer tls.Dialer
|
dialer tls.Dialer
|
||||||
serverAddr M.Socksaddr
|
serverAddr M.Socksaddr
|
||||||
tlsConfig tls.Config
|
tlsConfig tls.Config
|
||||||
access sync.Mutex
|
connections *ConnPool[*tlsDNSConn]
|
||||||
connections list.List[*tlsDNSConn]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type tlsDNSConn struct {
|
type tlsDNSConn struct {
|
||||||
@@ -66,10 +64,20 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o
|
|||||||
|
|
||||||
func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport {
|
func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport {
|
||||||
return &TLSTransport{
|
return &TLSTransport{
|
||||||
BaseTransport: NewBaseTransport(adapter, logger),
|
TransportAdapter: adapter,
|
||||||
dialer: tls.NewDialer(dialer, tlsConfig),
|
logger: logger,
|
||||||
serverAddr: serverAddr,
|
dialer: tls.NewDialer(dialer, tlsConfig),
|
||||||
tlsConfig: tlsConfig,
|
serverAddr: serverAddr,
|
||||||
|
tlsConfig: tlsConfig,
|
||||||
|
connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{
|
||||||
|
Mode: ConnPoolOrdered,
|
||||||
|
IsAlive: func(conn *tlsDNSConn) bool {
|
||||||
|
return conn != nil
|
||||||
|
},
|
||||||
|
Close: func(conn *tlsDNSConn, _ error) {
|
||||||
|
conn.Close()
|
||||||
|
},
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,53 +85,43 @@ func (t *TLSTransport) Start(stage adapter.StartStage) error {
|
|||||||
if stage != adapter.StartStateStart {
|
if stage != adapter.StartStateStart {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err := t.SetStarted()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return dialer.InitializeDetour(t.dialer)
|
return dialer.InitializeDetour(t.dialer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TLSTransport) Close() error {
|
func (t *TLSTransport) Close() error {
|
||||||
t.access.Lock()
|
return t.connections.Close()
|
||||||
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
|
|
||||||
connection.Value.Close()
|
|
||||||
}
|
|
||||||
t.connections.Init()
|
|
||||||
t.access.Unlock()
|
|
||||||
return t.BaseTransport.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TLSTransport) Reset() {
|
func (t *TLSTransport) Reset() {
|
||||||
t.access.Lock()
|
t.connections.Reset()
|
||||||
defer t.access.Unlock()
|
|
||||||
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
|
|
||||||
connection.Value.Close()
|
|
||||||
}
|
|
||||||
t.connections.Init()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
if !t.BeginQuery() {
|
var lastErr error
|
||||||
return nil, ErrTransportClosed
|
for attempt := 0; attempt < 2; attempt++ {
|
||||||
}
|
conn, created, err := t.connections.Acquire(ctx, func(ctx context.Context) (*tlsDNSConn, error) {
|
||||||
defer t.EndQuery()
|
tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
|
||||||
|
if err != nil {
|
||||||
t.access.Lock()
|
return nil, E.Cause(err, "dial TLS connection")
|
||||||
conn := t.connections.PopFront()
|
}
|
||||||
t.access.Unlock()
|
return &tlsDNSConn{Conn: tlsConn}, nil
|
||||||
if conn != nil {
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
response, err := t.exchange(ctx, message, conn)
|
response, err := t.exchange(ctx, message, conn)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
t.connections.Release(conn, true)
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
t.Logger.DebugContext(ctx, "discarded pooled connection: ", err)
|
lastErr = err
|
||||||
|
t.logger.DebugContext(ctx, "discarded pooled connection: ", err)
|
||||||
|
t.connections.Release(conn, false)
|
||||||
|
if created {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
|
return nil, lastErr
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "dial TLS connection")
|
|
||||||
}
|
|
||||||
return t.exchange(ctx, message, &tlsDNSConn{Conn: tlsConn})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
|
func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
|
||||||
@@ -133,22 +131,12 @@ func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tl
|
|||||||
conn.queryId++
|
conn.queryId++
|
||||||
err := WriteMessage(conn, conn.queryId, message)
|
err := WriteMessage(conn, conn.queryId, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
|
||||||
return nil, E.Cause(err, "write request")
|
return nil, E.Cause(err, "write request")
|
||||||
}
|
}
|
||||||
response, err := ReadMessage(conn)
|
response, err := ReadMessage(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
|
||||||
return nil, E.Cause(err, "read response")
|
return nil, E.Cause(err, "read response")
|
||||||
}
|
}
|
||||||
t.access.Lock()
|
|
||||||
if t.State() >= StateClosing {
|
|
||||||
t.access.Unlock()
|
|
||||||
conn.Close()
|
|
||||||
return response, nil
|
|
||||||
}
|
|
||||||
conn.SetDeadline(time.Time{})
|
conn.SetDeadline(time.Time{})
|
||||||
t.connections.PushBack(conn)
|
|
||||||
t.access.Unlock()
|
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package transport
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@@ -27,13 +28,14 @@ func RegisterUDP(registry *dns.TransportRegistry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UDPTransport struct {
|
type UDPTransport struct {
|
||||||
*BaseTransport
|
dns.TransportAdapter
|
||||||
|
logger logger.ContextLogger
|
||||||
|
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
serverAddr M.Socksaddr
|
serverAddr M.Socksaddr
|
||||||
udpSize atomic.Int32
|
udpSize atomic.Int32
|
||||||
|
|
||||||
connector *Connector[*Connection]
|
connection *ConnPool[net.Conn]
|
||||||
|
|
||||||
callbackAccess sync.RWMutex
|
callbackAccess sync.RWMutex
|
||||||
queryId uint16
|
queryId uint16
|
||||||
@@ -63,43 +65,38 @@ func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options o
|
|||||||
|
|
||||||
func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
|
func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
|
||||||
t := &UDPTransport{
|
t := &UDPTransport{
|
||||||
BaseTransport: NewBaseTransport(adapter, logger),
|
TransportAdapter: adapter,
|
||||||
dialer: dialerInstance,
|
logger: logger,
|
||||||
serverAddr: serverAddr,
|
dialer: dialerInstance,
|
||||||
callbacks: make(map[uint16]*udpCallback),
|
serverAddr: serverAddr,
|
||||||
|
callbacks: make(map[uint16]*udpCallback),
|
||||||
|
connection: NewConnPool(ConnPoolOptions[net.Conn]{
|
||||||
|
Mode: ConnPoolSingle,
|
||||||
|
IsAlive: func(conn net.Conn) bool {
|
||||||
|
return conn != nil
|
||||||
|
},
|
||||||
|
Close: func(conn net.Conn, cause error) {
|
||||||
|
conn.Close()
|
||||||
|
},
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
t.udpSize.Store(2048)
|
t.udpSize.Store(2048)
|
||||||
t.connector = NewSingleflightConnector(t.CloseContext(), t.dial)
|
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) dial(ctx context.Context) (*Connection, error) {
|
|
||||||
rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "dial UDP connection")
|
|
||||||
}
|
|
||||||
conn := WrapConnection(rawConn)
|
|
||||||
go t.recvLoop(conn)
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *UDPTransport) Start(stage adapter.StartStage) error {
|
func (t *UDPTransport) Start(stage adapter.StartStage) error {
|
||||||
if stage != adapter.StartStateStart {
|
if stage != adapter.StartStateStart {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err := t.SetStarted()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return dialer.InitializeDetour(t.dialer)
|
return dialer.InitializeDetour(t.dialer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) Close() error {
|
func (t *UDPTransport) Close() error {
|
||||||
return E.Errors(t.BaseTransport.Close(), t.connector.Close())
|
return t.connection.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) Reset() {
|
func (t *UDPTransport) Reset() {
|
||||||
t.connector.Reset()
|
t.connection.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) nextAvailableQueryId() (uint16, error) {
|
func (t *UDPTransport) nextAvailableQueryId() (uint16, error) {
|
||||||
@@ -116,17 +113,12 @@ func (t *UDPTransport) nextAvailableQueryId() (uint16, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
if !t.BeginQuery() {
|
|
||||||
return nil, ErrTransportClosed
|
|
||||||
}
|
|
||||||
defer t.EndQuery()
|
|
||||||
|
|
||||||
response, err := t.exchange(ctx, message)
|
response, err := t.exchange(ctx, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if response.Truncated {
|
if response.Truncated {
|
||||||
t.Logger.InfoContext(ctx, "response truncated, retrying with TCP")
|
t.logger.InfoContext(ctx, "response truncated, retrying with TCP")
|
||||||
return t.exchangeTCP(ctx, message)
|
return t.exchangeTCP(ctx, message)
|
||||||
}
|
}
|
||||||
return response, nil
|
return response, nil
|
||||||
@@ -158,16 +150,25 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
if t.udpSize.CompareAndSwap(current, udpSize) {
|
if t.udpSize.CompareAndSwap(current, udpSize) {
|
||||||
t.connector.Reset()
|
t.Reset()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := t.connector.Get(ctx)
|
conn, connCtx, created, err := t.connection.AcquireShared(ctx, func(ctx context.Context) (net.Conn, error) {
|
||||||
|
rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "dial UDP connection")
|
||||||
|
}
|
||||||
|
return rawConn, nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if created {
|
||||||
|
go t.recvLoop(conn)
|
||||||
|
}
|
||||||
|
|
||||||
callback := &udpCallback{
|
callback := &udpCallback{
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
@@ -177,6 +178,7 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||||||
queryId, err := t.nextAvailableQueryId()
|
queryId, err := t.nextAvailableQueryId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.callbackAccess.Unlock()
|
t.callbackAccess.Unlock()
|
||||||
|
t.connection.Release(conn, true)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t.callbacks[queryId] = callback
|
t.callbacks[queryId] = callback
|
||||||
@@ -203,30 +205,30 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||||||
|
|
||||||
_, err = conn.Write(rawMessage)
|
_, err = conn.Write(rawMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.CloseWithError(err)
|
t.connection.Invalidate(conn, err)
|
||||||
return nil, E.Cause(err, "write request")
|
return nil, E.Cause(err, "write request")
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-callback.done:
|
case <-callback.done:
|
||||||
|
t.connection.Release(conn, true)
|
||||||
callback.response.Id = originalId
|
callback.response.Id = originalId
|
||||||
return callback.response, nil
|
return callback.response, nil
|
||||||
case <-conn.Done():
|
case <-connCtx.Done():
|
||||||
return nil, conn.CloseError()
|
return nil, context.Cause(connCtx)
|
||||||
case <-t.CloseContext().Done():
|
|
||||||
return nil, ErrTransportClosed
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
t.connection.Release(conn, true)
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) recvLoop(conn *Connection) {
|
func (t *UDPTransport) recvLoop(conn net.Conn) {
|
||||||
for {
|
for {
|
||||||
buffer := buf.NewSize(int(t.udpSize.Load()))
|
buffer := buf.NewSize(int(t.udpSize.Load()))
|
||||||
_, err := buffer.ReadOnceFrom(conn)
|
_, err := buffer.ReadOnceFrom(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
conn.CloseWithError(err)
|
t.connection.Invalidate(conn, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,7 +236,7 @@ func (t *UDPTransport) recvLoop(conn *Connection) {
|
|||||||
err = message.Unpack(buffer.Bytes())
|
err = message.Unpack(buffer.Bytes())
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logger.Debug("discarded malformed UDP response: ", err)
|
t.logger.Debug("discarded malformed UDP response: ", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ type DNSTransport struct {
|
|||||||
dnsRouter adapter.DNSRouter
|
dnsRouter adapter.DNSRouter
|
||||||
endpointManager adapter.EndpointManager
|
endpointManager adapter.EndpointManager
|
||||||
endpoint *Endpoint
|
endpoint *Endpoint
|
||||||
|
access sync.RWMutex
|
||||||
routePrefixes []netip.Prefix
|
routePrefixes []netip.Prefix
|
||||||
routes map[string][]adapter.DNSTransport
|
routes map[string][]adapter.DNSTransport
|
||||||
hosts map[string][]netip.Addr
|
hosts map[string][]netip.Addr
|
||||||
@@ -91,6 +92,12 @@ func (t *DNSTransport) Start(stage adapter.StartStage) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *DNSTransport) Reset() {
|
func (t *DNSTransport) Reset() {
|
||||||
|
t.access.RLock()
|
||||||
|
transports := t.collectResolversLocked()
|
||||||
|
t.access.RUnlock()
|
||||||
|
for _, transport := range transports {
|
||||||
|
transport.Reset()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) {
|
func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) {
|
||||||
@@ -101,7 +108,7 @@ func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, d
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error {
|
func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error {
|
||||||
t.routePrefixes = buildRoutePrefixes(routeConfig)
|
routePrefixes := buildRoutePrefixes(routeConfig)
|
||||||
directDialerOnce := sync.OnceValue(func() N.Dialer {
|
directDialerOnce := sync.OnceValue(func() N.Dialer {
|
||||||
directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{}))
|
directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{}))
|
||||||
return &DNSDialer{transport: t, fallbackDialer: directDialer}
|
return &DNSDialer{transport: t, fallbackDialer: directDialer}
|
||||||
@@ -130,9 +137,19 @@ func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *n
|
|||||||
}
|
}
|
||||||
defaultResolvers = append(defaultResolvers, myResolver)
|
defaultResolvers = append(defaultResolvers, myResolver)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.access.Lock()
|
||||||
|
oldResolvers := t.collectResolversLocked()
|
||||||
|
t.routePrefixes = routePrefixes
|
||||||
t.routes = routes
|
t.routes = routes
|
||||||
t.hosts = hosts
|
t.hosts = hosts
|
||||||
t.defaultResolvers = defaultResolvers
|
t.defaultResolvers = defaultResolvers
|
||||||
|
t.access.Unlock()
|
||||||
|
|
||||||
|
for _, transport := range oldResolvers {
|
||||||
|
transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
if len(defaultResolvers) > 0 {
|
if len(defaultResolvers) > 0 {
|
||||||
t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ",
|
t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ",
|
||||||
strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " "))
|
strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " "))
|
||||||
@@ -207,7 +224,22 @@ func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *DNSTransport) Close() error {
|
func (t *DNSTransport) Close() error {
|
||||||
return nil
|
t.access.Lock()
|
||||||
|
transports := t.collectResolversLocked()
|
||||||
|
t.routePrefixes = nil
|
||||||
|
t.routes = nil
|
||||||
|
t.hosts = nil
|
||||||
|
t.defaultResolvers = nil
|
||||||
|
t.access.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for _, transport := range transports {
|
||||||
|
name := "resolver/" + transport.Type() + "[" + transport.Tag() + "]"
|
||||||
|
err = E.Append(err, transport.Close(), func(err error) error {
|
||||||
|
return E.Cause(err, "close ", name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *DNSTransport) Raw() bool {
|
func (t *DNSTransport) Raw() bool {
|
||||||
@@ -219,7 +251,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||||||
return nil, os.ErrInvalid
|
return nil, os.ErrInvalid
|
||||||
}
|
}
|
||||||
question := message.Question[0]
|
question := message.Question[0]
|
||||||
addresses, hostsLoaded := t.hosts[question.Name]
|
|
||||||
|
t.access.RLock()
|
||||||
|
hosts := t.hosts
|
||||||
|
routes := t.routes
|
||||||
|
defaultResolvers := t.defaultResolvers
|
||||||
|
acceptDefaultResolvers := t.acceptDefaultResolvers
|
||||||
|
t.access.RUnlock()
|
||||||
|
|
||||||
|
addresses, hostsLoaded := hosts[question.Name]
|
||||||
if hostsLoaded {
|
if hostsLoaded {
|
||||||
switch question.Qtype {
|
switch question.Qtype {
|
||||||
case mDNS.TypeA:
|
case mDNS.TypeA:
|
||||||
@@ -238,7 +278,7 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for domainSuffix, transports := range t.routes {
|
for domainSuffix, transports := range routes {
|
||||||
if strings.HasSuffix(question.Name, domainSuffix) {
|
if strings.HasSuffix(question.Name, domainSuffix) {
|
||||||
if len(transports) == 0 {
|
if len(transports) == 0 {
|
||||||
return &mDNS.Msg{
|
return &mDNS.Msg{
|
||||||
@@ -262,10 +302,10 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if t.acceptDefaultResolvers {
|
if acceptDefaultResolvers {
|
||||||
if len(t.defaultResolvers) > 0 {
|
if len(defaultResolvers) > 0 {
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for _, resolver := range t.defaultResolvers {
|
for _, resolver := range defaultResolvers {
|
||||||
response, err := resolver.Exchange(ctx, message)
|
response, err := resolver.Exchange(ctx, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
@@ -281,6 +321,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||||||
return nil, dns.RcodeNameError
|
return nil, dns.RcodeNameError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *DNSTransport) collectResolversLocked() []adapter.DNSTransport {
|
||||||
|
var transports []adapter.DNSTransport
|
||||||
|
for _, resolvers := range t.routes {
|
||||||
|
transports = append(transports, resolvers...)
|
||||||
|
}
|
||||||
|
transports = append(transports, t.defaultResolvers...)
|
||||||
|
return transports
|
||||||
|
}
|
||||||
|
|
||||||
type DNSDialer struct {
|
type DNSDialer struct {
|
||||||
transport *DNSTransport
|
transport *DNSTransport
|
||||||
fallbackDialer N.Dialer
|
fallbackDialer N.Dialer
|
||||||
@@ -290,7 +339,8 @@ func (d *DNSDialer) DialContext(ctx context.Context, network string, destination
|
|||||||
if destination.IsDomain() {
|
if destination.IsDomain() {
|
||||||
panic("invalid request here")
|
panic("invalid request here")
|
||||||
}
|
}
|
||||||
for _, prefix := range d.transport.routePrefixes {
|
routePrefixes := d.transport.routePrefixesSnapshot()
|
||||||
|
for _, prefix := range routePrefixes {
|
||||||
if prefix.Contains(destination.Addr) {
|
if prefix.Contains(destination.Addr) {
|
||||||
return d.transport.endpoint.DialContext(ctx, network, destination)
|
return d.transport.endpoint.DialContext(ctx, network, destination)
|
||||||
}
|
}
|
||||||
@@ -302,10 +352,17 @@ func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (
|
|||||||
if destination.IsDomain() {
|
if destination.IsDomain() {
|
||||||
panic("invalid request here")
|
panic("invalid request here")
|
||||||
}
|
}
|
||||||
for _, prefix := range d.transport.routePrefixes {
|
routePrefixes := d.transport.routePrefixesSnapshot()
|
||||||
|
for _, prefix := range routePrefixes {
|
||||||
if prefix.Contains(destination.Addr) {
|
if prefix.Contains(destination.Addr) {
|
||||||
return d.transport.endpoint.ListenPacket(ctx, destination)
|
return d.transport.endpoint.ListenPacket(ctx, destination)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return d.fallbackDialer.ListenPacket(ctx, destination)
|
return d.fallbackDialer.ListenPacket(ctx, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *DNSTransport) routePrefixesSnapshot() []netip.Prefix {
|
||||||
|
t.access.RLock()
|
||||||
|
defer t.access.RUnlock()
|
||||||
|
return append([]netip.Prefix(nil), t.routePrefixes...)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user