Clean up DNS transports

This commit is contained in:
世界
2026-04-23 01:33:38 +08:00
parent a3fc14f35f
commit 3312b8da50
9 changed files with 736 additions and 1031 deletions

View File

@@ -68,6 +68,8 @@ type DNSTransport interface {
Type() string Type() string
Tag() string Tag() string
Dependencies() []string Dependencies() []string
// Reset closes the transport's existing connections so later requests use fresh connections.
// Exchanges that are currently using those connections may fail.
Reset() Reset()
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
} }

View File

@@ -1,145 +0,0 @@
package transport
import (
"context"
"os"
"sync"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
type TransportState int
const (
StateNew TransportState = iota
StateStarted
StateClosing
StateClosed
)
var (
ErrTransportClosed = os.ErrClosed
ErrConnectionReset = E.New("connection reset")
)
type BaseTransport struct {
dns.TransportAdapter
Logger logger.ContextLogger
mutex sync.Mutex
state TransportState
inFlight int32
queriesComplete chan struct{}
closeCtx context.Context
closeCancel context.CancelFunc
}
func NewBaseTransport(adapter dns.TransportAdapter, logger logger.ContextLogger) *BaseTransport {
ctx, cancel := context.WithCancel(context.Background())
return &BaseTransport{
TransportAdapter: adapter,
Logger: logger,
state: StateNew,
closeCtx: ctx,
closeCancel: cancel,
}
}
func (t *BaseTransport) State() TransportState {
t.mutex.Lock()
defer t.mutex.Unlock()
return t.state
}
func (t *BaseTransport) SetStarted() error {
t.mutex.Lock()
defer t.mutex.Unlock()
switch t.state {
case StateNew:
t.state = StateStarted
return nil
case StateStarted:
return nil
default:
return ErrTransportClosed
}
}
func (t *BaseTransport) BeginQuery() bool {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.state != StateStarted {
return false
}
t.inFlight++
return true
}
func (t *BaseTransport) EndQuery() {
t.mutex.Lock()
if t.inFlight > 0 {
t.inFlight--
}
if t.inFlight == 0 && t.queriesComplete != nil {
close(t.queriesComplete)
t.queriesComplete = nil
}
t.mutex.Unlock()
}
func (t *BaseTransport) CloseContext() context.Context {
return t.closeCtx
}
func (t *BaseTransport) Shutdown(ctx context.Context) error {
t.mutex.Lock()
if t.state >= StateClosing {
t.mutex.Unlock()
return nil
}
if t.state == StateNew {
t.state = StateClosed
t.mutex.Unlock()
t.closeCancel()
return nil
}
t.state = StateClosing
if t.inFlight == 0 {
t.state = StateClosed
t.mutex.Unlock()
t.closeCancel()
return nil
}
t.queriesComplete = make(chan struct{})
queriesComplete := t.queriesComplete
t.mutex.Unlock()
t.closeCancel()
select {
case <-queriesComplete:
t.mutex.Lock()
t.state = StateClosed
t.mutex.Unlock()
return nil
case <-ctx.Done():
t.mutex.Lock()
t.state = StateClosed
t.mutex.Unlock()
return ctx.Err()
}
}
func (t *BaseTransport) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), C.TCPTimeout)
defer cancel()
return t.Shutdown(ctx)
}

547
dns/transport/conn_pool.go Normal file
View File

@@ -0,0 +1,547 @@
package transport
import (
"context"
"net"
"sync"
"time"
"github.com/sagernet/sing/common/x/list"
)
type ConnPoolMode int
const (
ConnPoolSingle ConnPoolMode = iota
ConnPoolOrdered
)
type ConnPoolOptions[T comparable] struct {
Mode ConnPoolMode
IsAlive func(T) bool
Close func(T, error)
}
type ConnPool[T comparable] struct {
options ConnPoolOptions[T]
access sync.Mutex
closed bool
state *connPoolState[T]
}
type connPoolState[T comparable] struct {
ctx context.Context
cancel context.CancelCauseFunc
all map[T]struct{}
idle list.List[T]
idleElements map[T]*list.Element[T]
shared T
hasShared bool
sharedClaimed bool
sharedCtx context.Context
sharedCancel context.CancelCauseFunc
connecting *connPoolConnect[T]
}
type connPoolConnect[T comparable] struct {
done chan struct{}
err error
}
type connPoolDialContext struct {
context.Context
parent context.Context
}
func (c connPoolDialContext) Deadline() (time.Time, bool) {
return c.parent.Deadline()
}
func (c connPoolDialContext) Value(key any) any {
return c.parent.Value(key)
}
func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] {
return &ConnPool[T]{
options: options,
state: newConnPoolState[T](options.Mode),
}
}
func newConnPoolState[T comparable](mode ConnPoolMode) *connPoolState[T] {
ctx, cancel := context.WithCancelCause(context.Background())
state := &connPoolState[T]{
ctx: ctx,
cancel: cancel,
all: make(map[T]struct{}),
}
if mode == ConnPoolOrdered {
state.idleElements = make(map[T]*list.Element[T])
}
return state
}
func (p *ConnPool[T]) Acquire(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) {
switch p.options.Mode {
case ConnPoolSingle:
conn, _, created, err := p.acquireShared(ctx, dial)
return conn, created, err
case ConnPoolOrdered:
return p.acquireOrdered(ctx, dial)
default:
var zero T
return zero, false, net.ErrClosed
}
}
func (p *ConnPool[T]) AcquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) {
if p.options.Mode != ConnPoolSingle {
var zero T
return zero, nil, false, net.ErrClosed
}
return p.acquireShared(ctx, dial)
}
func (p *ConnPool[T]) Release(conn T, reuse bool) {
var (
closeConn bool
closeErr error
)
p.access.Lock()
if p.closed || p.state == nil {
closeConn = true
closeErr = net.ErrClosed
p.access.Unlock()
if closeConn {
p.options.Close(conn, closeErr)
}
return
}
currentState := p.state
_, tracked := currentState.all[conn]
if !tracked {
closeConn = true
closeErr = p.closeCause(currentState)
p.access.Unlock()
if closeConn {
p.options.Close(conn, closeErr)
}
return
}
if !reuse || !p.options.IsAlive(conn) {
delete(currentState.all, conn)
switch p.options.Mode {
case ConnPoolSingle:
if currentState.hasShared && currentState.shared == conn {
var zero T
currentState.shared = zero
currentState.hasShared = false
currentState.sharedClaimed = false
currentState.sharedCtx = nil
if currentState.sharedCancel != nil {
currentState.sharedCancel(net.ErrClosed)
currentState.sharedCancel = nil
}
}
case ConnPoolOrdered:
if element, loaded := currentState.idleElements[conn]; loaded {
currentState.idle.Remove(element)
delete(currentState.idleElements, conn)
}
}
closeConn = true
closeErr = net.ErrClosed
p.access.Unlock()
if closeConn {
p.options.Close(conn, closeErr)
}
return
}
if p.options.Mode == ConnPoolOrdered {
if _, loaded := currentState.idleElements[conn]; !loaded {
currentState.idleElements[conn] = currentState.idle.PushBack(conn)
}
}
p.access.Unlock()
}
func (p *ConnPool[T]) Invalidate(conn T, cause error) {
p.access.Lock()
if p.closed || p.state == nil {
p.access.Unlock()
p.options.Close(conn, cause)
return
}
currentState := p.state
_, tracked := currentState.all[conn]
if !tracked {
p.access.Unlock()
return
}
delete(currentState.all, conn)
switch p.options.Mode {
case ConnPoolSingle:
if currentState.hasShared && currentState.shared == conn {
var zero T
currentState.shared = zero
currentState.hasShared = false
currentState.sharedClaimed = false
currentState.sharedCtx = nil
if currentState.sharedCancel != nil {
currentState.sharedCancel(cause)
currentState.sharedCancel = nil
}
}
case ConnPoolOrdered:
if element, loaded := currentState.idleElements[conn]; loaded {
currentState.idle.Remove(element)
delete(currentState.idleElements, conn)
}
}
p.access.Unlock()
p.options.Close(conn, cause)
}
func (p *ConnPool[T]) Reset() {
p.access.Lock()
if p.closed {
p.access.Unlock()
return
}
oldState := p.state
p.state = newConnPoolState[T](p.options.Mode)
p.access.Unlock()
p.closeState(oldState, net.ErrClosed)
}
func (p *ConnPool[T]) Close() error {
p.access.Lock()
if p.closed {
p.access.Unlock()
return nil
}
p.closed = true
oldState := p.state
p.state = nil
p.access.Unlock()
p.closeState(oldState, net.ErrClosed)
return nil
}
func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) {
var zero T
for {
var (
staleConn T
hasStale bool
)
p.access.Lock()
if p.closed {
p.access.Unlock()
return zero, false, net.ErrClosed
}
currentState := p.state
if element := currentState.idle.Front(); element != nil {
conn := currentState.idle.Remove(element)
delete(currentState.idleElements, conn)
if p.options.IsAlive(conn) {
p.access.Unlock()
return conn, false, nil
}
delete(currentState.all, conn)
staleConn = conn
hasStale = true
}
p.access.Unlock()
if hasStale {
p.options.Close(staleConn, net.ErrClosed)
continue
}
conn, err := p.dial(ctx, currentState, dial)
if err != nil {
return zero, false, err
}
p.access.Lock()
if p.closed {
p.access.Unlock()
p.options.Close(conn, net.ErrClosed)
return zero, false, net.ErrClosed
}
if p.state != currentState {
cause := p.closeCause(currentState)
p.access.Unlock()
p.options.Close(conn, cause)
return zero, false, cause
}
currentState.all[conn] = struct{}{}
p.access.Unlock()
return conn, true, nil
}
}
func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) {
var zero T
for {
var (
staleConn T
hasStale bool
state *connPoolConnect[T]
current *connPoolState[T]
startDial bool
)
p.access.Lock()
if p.closed {
p.access.Unlock()
return zero, nil, false, net.ErrClosed
}
current = p.state
if current.hasShared {
conn := current.shared
if p.options.IsAlive(conn) {
created := !current.sharedClaimed
current.sharedClaimed = true
connCtx := current.sharedCtx
p.access.Unlock()
return conn, connCtx, created, nil
}
delete(current.all, conn)
var zeroConn T
current.shared = zeroConn
current.hasShared = false
current.sharedClaimed = false
current.sharedCtx = nil
if current.sharedCancel != nil {
current.sharedCancel(net.ErrClosed)
current.sharedCancel = nil
}
staleConn = conn
hasStale = true
p.access.Unlock()
p.options.Close(staleConn, net.ErrClosed)
continue
}
if current.connecting == nil {
current.connecting = &connPoolConnect[T]{
done: make(chan struct{}),
}
startDial = true
}
state = current.connecting
p.access.Unlock()
if hasStale {
continue
}
if startDial {
go p.connectSingle(current, state, ctx, dial)
}
select {
case <-state.done:
conn, connCtx, created, retry, err := p.collectShared(current, state, startDial)
if retry {
continue
}
return conn, connCtx, created, err
case <-ctx.Done():
return zero, nil, false, ctx.Err()
case <-current.ctx.Done():
p.access.Lock()
closed := p.closed
p.access.Unlock()
if closed {
return zero, nil, false, net.ErrClosed
}
}
}
}
func (p *ConnPool[T]) connectSingle(current *connPoolState[T], state *connPoolConnect[T], ctx context.Context, dial func(context.Context) (T, error)) {
conn, err := p.dial(ctx, current, dial)
if err != nil {
p.access.Lock()
if current.connecting == state {
current.connecting = nil
}
state.err = err
p.access.Unlock()
close(state.done)
return
}
var closeErr error
p.access.Lock()
if current.connecting == state {
current.connecting = nil
}
if p.closed {
closeErr = net.ErrClosed
state.err = closeErr
} else if p.state != current {
closeErr = p.closeCause(current)
state.err = closeErr
} else {
sharedCtx, sharedCancel := context.WithCancelCause(current.ctx)
current.shared = conn
current.hasShared = true
current.sharedClaimed = false
current.sharedCtx = sharedCtx
current.sharedCancel = sharedCancel
current.all[conn] = struct{}{}
}
p.access.Unlock()
if closeErr != nil {
p.options.Close(conn, closeErr)
}
close(state.done)
}
func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolConnect[T], startDial bool) (T, context.Context, bool, bool, error) {
var zero T
p.access.Lock()
if state.err != nil {
err := state.err
p.access.Unlock()
if startDial {
return zero, nil, false, false, err
}
return zero, nil, false, true, nil
}
if p.closed {
p.access.Unlock()
return zero, nil, false, false, net.ErrClosed
}
if p.state != current {
cause := p.closeCause(current)
p.access.Unlock()
return zero, nil, false, false, cause
}
if !current.hasShared {
p.access.Unlock()
return zero, nil, false, true, nil
}
conn := current.shared
if !p.options.IsAlive(conn) {
delete(current.all, conn)
var zeroConn T
current.shared = zeroConn
current.hasShared = false
current.sharedClaimed = false
current.sharedCtx = nil
if current.sharedCancel != nil {
current.sharedCancel(net.ErrClosed)
current.sharedCancel = nil
}
p.access.Unlock()
p.options.Close(conn, net.ErrClosed)
return zero, nil, false, true, nil
}
created := !current.sharedClaimed
current.sharedClaimed = true
connCtx := current.sharedCtx
p.access.Unlock()
return conn, connCtx, created, false, nil
}
func (p *ConnPool[T]) dial(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, error) {
var zero T
if err := ctx.Err(); err != nil {
return zero, err
}
if cause := context.Cause(current.ctx); cause != nil {
return zero, cause
}
dialCtx, cancel := context.WithCancelCause(current.ctx)
var (
stateAccess sync.Mutex
dialComplete bool
)
stopCancel := context.AfterFunc(ctx, func() {
stateAccess.Lock()
if !dialComplete {
cancel(context.Cause(ctx))
}
stateAccess.Unlock()
})
select {
case <-ctx.Done():
stateAccess.Lock()
dialComplete = true
stateAccess.Unlock()
stopCancel()
cancel(context.Cause(ctx))
return zero, ctx.Err()
default:
}
conn, err := dial(connPoolDialContext{
Context: dialCtx,
parent: ctx,
})
stateAccess.Lock()
dialComplete = true
stateAccess.Unlock()
stopCancel()
if err != nil {
if cause := context.Cause(dialCtx); cause != nil {
return zero, cause
}
return zero, err
}
if cause := context.Cause(dialCtx); cause != nil {
p.options.Close(conn, cause)
return zero, cause
}
return conn, nil
}
func (p *ConnPool[T]) closeState(state *connPoolState[T], cause error) {
if state == nil {
return
}
state.cancel(cause)
if state.sharedCancel != nil {
state.sharedCancel(cause)
}
for conn := range state.all {
p.options.Close(conn, cause)
}
}
func (p *ConnPool[T]) closeCause(state *connPoolState[T]) error {
_ = state
return net.ErrClosed
}

View File

@@ -1,321 +0,0 @@
package transport
import (
"context"
"net"
"sync"
"time"
E "github.com/sagernet/sing/common/exceptions"
)
type ConnectorCallbacks[T any] struct {
IsClosed func(connection T) bool
Close func(connection T)
Reset func(connection T)
}
type Connector[T any] struct {
dial func(ctx context.Context) (T, error)
callbacks ConnectorCallbacks[T]
access sync.Mutex
connection T
hasConnection bool
connectionCancel context.CancelFunc
connecting chan struct{}
closeCtx context.Context
closed bool
}
func NewConnector[T any](closeCtx context.Context, dial func(context.Context) (T, error), callbacks ConnectorCallbacks[T]) *Connector[T] {
return &Connector[T]{
dial: dial,
callbacks: callbacks,
closeCtx: closeCtx,
}
}
func NewSingleflightConnector(closeCtx context.Context, dial func(context.Context) (*Connection, error)) *Connector[*Connection] {
return NewConnector(closeCtx, dial, ConnectorCallbacks[*Connection]{
IsClosed: func(connection *Connection) bool {
return connection.IsClosed()
},
Close: func(connection *Connection) {
connection.CloseWithError(ErrTransportClosed)
},
Reset: func(connection *Connection) {
connection.CloseWithError(ErrConnectionReset)
},
})
}
type contextKeyConnecting struct{}
var errRecursiveConnectorDial = E.New("recursive connector dial")
type connectorDialResult[T any] struct {
connection T
cancel context.CancelFunc
err error
}
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
var zero T
for {
c.access.Lock()
if c.closed {
c.access.Unlock()
return zero, ErrTransportClosed
}
if c.hasConnection && !c.callbacks.IsClosed(c.connection) {
connection := c.connection
c.access.Unlock()
return connection, nil
}
c.hasConnection = false
if c.connectionCancel != nil {
c.connectionCancel()
c.connectionCancel = nil
}
if isRecursiveConnectorDial(ctx, c) {
c.access.Unlock()
return zero, errRecursiveConnectorDial
}
if c.connecting != nil {
connecting := c.connecting
c.access.Unlock()
select {
case <-connecting:
continue
case <-ctx.Done():
return zero, ctx.Err()
case <-c.closeCtx.Done():
return zero, ErrTransportClosed
}
}
if err := ctx.Err(); err != nil {
c.access.Unlock()
return zero, err
}
connecting := make(chan struct{})
c.connecting = connecting
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
dialResult := make(chan connectorDialResult[T], 1)
c.access.Unlock()
go func() {
connection, cancel, err := c.dialWithCancellation(dialContext)
dialResult <- connectorDialResult[T]{
connection: connection,
cancel: cancel,
err: err,
}
}()
select {
case result := <-dialResult:
return c.completeDial(ctx, connecting, result)
case <-ctx.Done():
go func() {
result := <-dialResult
_, _ = c.completeDial(ctx, connecting, result)
}()
return zero, ctx.Err()
case <-c.closeCtx.Done():
go func() {
result := <-dialResult
_, _ = c.completeDial(ctx, connecting, result)
}()
return zero, ErrTransportClosed
}
}
}
func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T]) bool {
dialConnector, loaded := ctx.Value(contextKeyConnecting{}).(*Connector[T])
return loaded && dialConnector == connector
}
func (c *Connector[T]) completeDial(ctx context.Context, connecting chan struct{}, result connectorDialResult[T]) (T, error) {
var zero T
c.access.Lock()
defer c.access.Unlock()
defer func() {
if c.connecting == connecting {
c.connecting = nil
}
close(connecting)
}()
if result.err != nil {
return zero, result.err
}
if c.closed || c.closeCtx.Err() != nil {
result.cancel()
c.callbacks.Close(result.connection)
return zero, ErrTransportClosed
}
if err := ctx.Err(); err != nil {
result.cancel()
c.callbacks.Close(result.connection)
return zero, err
}
c.connection = result.connection
c.hasConnection = true
c.connectionCancel = result.cancel
return c.connection, nil
}
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) {
var zero T
if err := ctx.Err(); err != nil {
return zero, nil, err
}
connCtx, cancel := context.WithCancel(c.closeCtx)
var (
stateAccess sync.Mutex
dialComplete bool
)
stopCancel := context.AfterFunc(ctx, func() {
stateAccess.Lock()
if !dialComplete {
cancel()
}
stateAccess.Unlock()
})
select {
case <-ctx.Done():
stateAccess.Lock()
dialComplete = true
stateAccess.Unlock()
stopCancel()
cancel()
return zero, nil, ctx.Err()
default:
}
connection, err := c.dial(valueContext{connCtx, ctx})
stateAccess.Lock()
dialComplete = true
stateAccess.Unlock()
stopCancel()
if err != nil {
cancel()
return zero, nil, err
}
return connection, cancel, nil
}
type valueContext struct {
context.Context
parent context.Context
}
func (v valueContext) Value(key any) any {
return v.parent.Value(key)
}
func (v valueContext) Deadline() (time.Time, bool) {
return v.parent.Deadline()
}
func (c *Connector[T]) Close() error {
c.access.Lock()
defer c.access.Unlock()
if c.closed {
return nil
}
c.closed = true
if c.connectionCancel != nil {
c.connectionCancel()
c.connectionCancel = nil
}
if c.hasConnection {
c.callbacks.Close(c.connection)
c.hasConnection = false
}
return nil
}
func (c *Connector[T]) Reset() {
c.access.Lock()
defer c.access.Unlock()
if c.connectionCancel != nil {
c.connectionCancel()
c.connectionCancel = nil
}
if c.hasConnection {
c.callbacks.Reset(c.connection)
c.hasConnection = false
}
}
type Connection struct {
net.Conn
closeOnce sync.Once
done chan struct{}
closeError error
}
func WrapConnection(conn net.Conn) *Connection {
return &Connection{
Conn: conn,
done: make(chan struct{}),
}
}
func (c *Connection) Done() <-chan struct{} {
return c.done
}
func (c *Connection) IsClosed() bool {
select {
case <-c.done:
return true
default:
return false
}
}
func (c *Connection) CloseError() error {
select {
case <-c.done:
if c.closeError != nil {
return c.closeError
}
return ErrTransportClosed
default:
return nil
}
}
func (c *Connection) Close() error {
return c.CloseWithError(ErrTransportClosed)
}
func (c *Connection) CloseWithError(err error) error {
var returnError error
c.closeOnce.Do(func() {
c.closeError = err
returnError = c.Conn.Close()
close(c.done)
})
return returnError
}

View File

@@ -1,407 +0,0 @@
package transport
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type testConnectorConnection struct{}
func TestConnectorRecursiveGetFailsFast(t *testing.T) {
t.Parallel()
var (
dialCount atomic.Int32
closeCount atomic.Int32
connector *Connector[*testConnectorConnection]
)
dial := func(ctx context.Context) (*testConnectorConnection, error) {
dialCount.Add(1)
_, err := connector.Get(ctx)
if err != nil {
return nil, err
}
return &testConnectorConnection{}, nil
}
connector = NewConnector(context.Background(), dial, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
Reset: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
})
_, err := connector.Get(context.Background())
require.ErrorIs(t, err, errRecursiveConnectorDial)
require.EqualValues(t, 1, dialCount.Load())
require.EqualValues(t, 0, closeCount.Load())
}
func TestConnectorRecursiveGetAcrossConnectorsAllowed(t *testing.T) {
t.Parallel()
var (
outerDialCount atomic.Int32
innerDialCount atomic.Int32
outerConnector *Connector[*testConnectorConnection]
innerConnector *Connector[*testConnectorConnection]
)
innerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
innerDialCount.Add(1)
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
outerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
outerDialCount.Add(1)
_, err := innerConnector.Get(ctx)
if err != nil {
return nil, err
}
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
_, err := outerConnector.Get(context.Background())
require.NoError(t, err)
require.EqualValues(t, 1, outerDialCount.Load())
require.EqualValues(t, 1, innerDialCount.Load())
}
func TestConnectorDialContextPreservesValueAndDeadline(t *testing.T) {
t.Parallel()
type contextKey struct{}
var (
dialValue any
dialDeadline time.Time
dialHasDeadline bool
)
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialValue = ctx.Value(contextKey{})
dialDeadline, dialHasDeadline = ctx.Deadline()
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
deadline := time.Now().Add(time.Minute)
requestContext, cancel := context.WithDeadline(context.WithValue(context.Background(), contextKey{}, "test-value"), deadline)
defer cancel()
_, err := connector.Get(requestContext)
require.NoError(t, err)
require.Equal(t, "test-value", dialValue)
require.True(t, dialHasDeadline)
require.WithinDuration(t, deadline, dialDeadline, time.Second)
}
func TestConnectorDialSkipsCanceledRequest(t *testing.T) {
t.Parallel()
var dialCount atomic.Int32
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialCount.Add(1)
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
cancel()
_, err := connector.Get(requestContext)
require.ErrorIs(t, err, context.Canceled)
require.EqualValues(t, 0, dialCount.Load())
}
func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) {
t.Parallel()
var (
dialCount atomic.Int32
closeCount atomic.Int32
)
dialStarted := make(chan struct{}, 1)
releaseDial := make(chan struct{})
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialCount.Add(1)
select {
case dialStarted <- struct{}{}:
default:
}
<-releaseDial
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
result := make(chan error, 1)
go func() {
_, err := connector.Get(requestContext)
result <- err
}()
<-dialStarted
cancel()
close(releaseDial)
err := <-result
require.ErrorIs(t, err, context.Canceled)
require.EqualValues(t, 1, dialCount.Load())
require.Eventually(t, func() bool {
return closeCount.Load() == 1
}, time.Second, 10*time.Millisecond)
_, err = connector.Get(context.Background())
require.NoError(t, err)
require.EqualValues(t, 2, dialCount.Load())
}
func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) {
t.Parallel()
var (
dialCount atomic.Int32
closeCount atomic.Int32
)
dialStarted := make(chan struct{}, 1)
releaseDial := make(chan struct{})
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialCount.Add(1)
select {
case dialStarted <- struct{}{}:
default:
}
<-releaseDial
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
result := make(chan error, 1)
go func() {
_, err := connector.Get(requestContext)
result <- err
}()
<-dialStarted
cancel()
select {
case err := <-result:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("Get did not return after request cancel")
}
require.EqualValues(t, 1, dialCount.Load())
require.EqualValues(t, 0, closeCount.Load())
close(releaseDial)
require.Eventually(t, func() bool {
return closeCount.Load() == 1
}, time.Second, 10*time.Millisecond)
_, err := connector.Get(context.Background())
require.NoError(t, err)
require.EqualValues(t, 2, dialCount.Load())
}
func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) {
t.Parallel()
var (
dialCount atomic.Int32
closeCount atomic.Int32
)
firstDialStarted := make(chan struct{}, 1)
secondDialStarted := make(chan struct{}, 1)
releaseFirstDial := make(chan struct{})
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
attempt := dialCount.Add(1)
switch attempt {
case 1:
select {
case firstDialStarted <- struct{}{}:
default:
}
<-releaseFirstDial
case 2:
select {
case secondDialStarted <- struct{}{}:
default:
}
}
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
firstResult := make(chan error, 1)
go func() {
_, err := connector.Get(requestContext)
firstResult <- err
}()
<-firstDialStarted
cancel()
secondResult := make(chan error, 1)
go func() {
_, err := connector.Get(context.Background())
secondResult <- err
}()
select {
case <-secondDialStarted:
t.Fatal("second dial started before first dial completed")
case <-time.After(100 * time.Millisecond):
}
select {
case err := <-firstResult:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("first Get did not return after request cancel")
}
close(releaseFirstDial)
require.Eventually(t, func() bool {
return closeCount.Load() == 1
}, time.Second, 10*time.Millisecond)
select {
case <-secondDialStarted:
case <-time.After(time.Second):
t.Fatal("second dial did not start after first dial completed")
}
err := <-secondResult
require.NoError(t, err)
require.EqualValues(t, 2, dialCount.Load())
}
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
t.Parallel()
var dialContext context.Context
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialContext = ctx
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
_, err := connector.Get(requestContext)
require.NoError(t, err)
require.NotNil(t, dialContext)
cancel()
select {
case <-dialContext.Done():
t.Fatal("dial context canceled by request context after successful dial")
case <-time.After(100 * time.Millisecond):
}
err = connector.Close()
require.NoError(t, err)
}
func TestConnectorDialContextCanceledOnClose(t *testing.T) {
t.Parallel()
var dialContext context.Context
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialContext = ctx
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
_, err := connector.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, dialContext)
select {
case <-dialContext.Done():
t.Fatal("dial context canceled before connector close")
default:
}
err = connector.Close()
require.NoError(t, err)
select {
case <-dialContext.Done():
case <-time.After(time.Second):
t.Fatal("dial context not canceled after connector close")
}
}

View File

@@ -31,14 +31,13 @@ func RegisterTransport(registry *dns.TransportRegistry) {
} }
type Transport struct { type Transport struct {
*transport.BaseTransport dns.TransportAdapter
ctx context.Context
dialer N.Dialer dialer N.Dialer
serverAddr M.Socksaddr serverAddr M.Socksaddr
tlsConfig tls.Config tlsConfig tls.Config
connector *transport.Connector[*quic.Conn] connection *transport.ConnPool[*quic.Conn]
} }
func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) { func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
@@ -63,93 +62,76 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options
return nil, E.New("invalid server address: ", serverAddr) return nil, E.New("invalid server address: ", serverAddr)
} }
t := &Transport{ return &Transport{
BaseTransport: transport.NewBaseTransport( TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions), dialer: transportDialer,
logger, serverAddr: serverAddr,
), tlsConfig: tlsConfig,
ctx: ctx, connection: transport.NewConnPool(transport.ConnPoolOptions[*quic.Conn]{
dialer: transportDialer, Mode: transport.ConnPoolSingle,
serverAddr: serverAddr, IsAlive: func(conn *quic.Conn) bool {
tlsConfig: tlsConfig, return conn != nil && !common.Done(conn.Context())
} },
Close: func(conn *quic.Conn, _ error) {
t.connector = transport.NewConnector(t.CloseContext(), t.dial, transport.ConnectorCallbacks[*quic.Conn]{ conn.CloseWithError(0, "")
IsClosed: func(connection *quic.Conn) bool { },
return common.Done(connection.Context()) }),
}, }, nil
Close: func(connection *quic.Conn) {
connection.CloseWithError(0, "")
},
Reset: func(connection *quic.Conn) {
connection.CloseWithError(0, "")
},
})
return t, nil
}
func (t *Transport) dial(ctx context.Context) (*quic.Conn, error) {
conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
if err != nil {
return nil, E.Cause(err, "dial UDP connection")
}
earlyConnection, err := sQUIC.DialEarly(
ctx,
bufio.NewUnbindPacketConn(conn),
t.serverAddr.UDPAddr(),
t.tlsConfig,
nil,
)
if err != nil {
conn.Close()
return nil, E.Cause(err, "establish QUIC connection")
}
return earlyConnection, nil
} }
func (t *Transport) Start(stage adapter.StartStage) error { func (t *Transport) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart { if stage != adapter.StartStateStart {
return nil return nil
} }
err := t.SetStarted()
if err != nil {
return err
}
return dialer.InitializeDetour(t.dialer) return dialer.InitializeDetour(t.dialer)
} }
func (t *Transport) Close() error { func (t *Transport) Close() error {
return E.Errors(t.BaseTransport.Close(), t.connector.Close()) return t.connection.Close()
} }
func (t *Transport) Reset() { func (t *Transport) Reset() {
t.connector.Reset() t.connection.Reset()
} }
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
if !t.BeginQuery() {
return nil, transport.ErrTransportClosed
}
defer t.EndQuery()
var ( var (
conn *quic.Conn conn *quic.Conn
err error err error
response *mDNS.Msg response *mDNS.Msg
) )
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
conn, err = t.connector.Get(ctx) conn, _, err = t.connection.Acquire(ctx, func(ctx context.Context) (*quic.Conn, error) {
rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
if err != nil {
return nil, E.Cause(err, "dial UDP connection")
}
earlyConnection, err := sQUIC.DialEarly(
ctx,
bufio.NewUnbindPacketConn(rawConn),
t.serverAddr.UDPAddr(),
t.tlsConfig,
nil,
)
if err != nil {
rawConn.Close()
return nil, E.Cause(err, "establish QUIC connection")
}
return earlyConnection, nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
response, err = t.exchange(ctx, message, conn) response, err = t.exchange(ctx, message, conn)
if err == nil { if err == nil {
t.connection.Release(conn, true)
return response, nil return response, nil
} else if !isQUICRetryError(err) { } else if !isQUICRetryError(err) {
t.connection.Release(conn, true)
return nil, err return nil, err
} else { } else {
t.connector.Reset() t.connection.Release(conn, true)
t.Reset()
continue continue
} }
} }

View File

@@ -2,7 +2,6 @@ package transport
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
@@ -17,7 +16,6 @@ import (
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
) )
@@ -29,13 +27,13 @@ func RegisterTLS(registry *dns.TransportRegistry) {
} }
type TLSTransport struct { type TLSTransport struct {
*BaseTransport dns.TransportAdapter
logger logger.ContextLogger
dialer tls.Dialer dialer tls.Dialer
serverAddr M.Socksaddr serverAddr M.Socksaddr
tlsConfig tls.Config tlsConfig tls.Config
access sync.Mutex connections *ConnPool[*tlsDNSConn]
connections list.List[*tlsDNSConn]
} }
type tlsDNSConn struct { type tlsDNSConn struct {
@@ -66,10 +64,20 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o
func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport { func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport {
return &TLSTransport{ return &TLSTransport{
BaseTransport: NewBaseTransport(adapter, logger), TransportAdapter: adapter,
dialer: tls.NewDialer(dialer, tlsConfig), logger: logger,
serverAddr: serverAddr, dialer: tls.NewDialer(dialer, tlsConfig),
tlsConfig: tlsConfig, serverAddr: serverAddr,
tlsConfig: tlsConfig,
connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{
Mode: ConnPoolOrdered,
IsAlive: func(conn *tlsDNSConn) bool {
return conn != nil
},
Close: func(conn *tlsDNSConn, _ error) {
conn.Close()
},
}),
} }
} }
@@ -77,53 +85,43 @@ func (t *TLSTransport) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart { if stage != adapter.StartStateStart {
return nil return nil
} }
err := t.SetStarted()
if err != nil {
return err
}
return dialer.InitializeDetour(t.dialer) return dialer.InitializeDetour(t.dialer)
} }
func (t *TLSTransport) Close() error { func (t *TLSTransport) Close() error {
t.access.Lock() return t.connections.Close()
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
connection.Value.Close()
}
t.connections.Init()
t.access.Unlock()
return t.BaseTransport.Close()
} }
func (t *TLSTransport) Reset() { func (t *TLSTransport) Reset() {
t.access.Lock() t.connections.Reset()
defer t.access.Unlock()
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
connection.Value.Close()
}
t.connections.Init()
} }
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
if !t.BeginQuery() { var lastErr error
return nil, ErrTransportClosed for attempt := 0; attempt < 2; attempt++ {
} conn, created, err := t.connections.Acquire(ctx, func(ctx context.Context) (*tlsDNSConn, error) {
defer t.EndQuery() tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
if err != nil {
t.access.Lock() return nil, E.Cause(err, "dial TLS connection")
conn := t.connections.PopFront() }
t.access.Unlock() return &tlsDNSConn{Conn: tlsConn}, nil
if conn != nil { })
if err != nil {
return nil, err
}
response, err := t.exchange(ctx, message, conn) response, err := t.exchange(ctx, message, conn)
if err == nil { if err == nil {
t.connections.Release(conn, true)
return response, nil return response, nil
} }
t.Logger.DebugContext(ctx, "discarded pooled connection: ", err) lastErr = err
t.logger.DebugContext(ctx, "discarded pooled connection: ", err)
t.connections.Release(conn, false)
if created {
return nil, err
}
} }
tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr) return nil, lastErr
if err != nil {
return nil, E.Cause(err, "dial TLS connection")
}
return t.exchange(ctx, message, &tlsDNSConn{Conn: tlsConn})
} }
func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) { func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
@@ -133,22 +131,12 @@ func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tl
conn.queryId++ conn.queryId++
err := WriteMessage(conn, conn.queryId, message) err := WriteMessage(conn, conn.queryId, message)
if err != nil { if err != nil {
conn.Close()
return nil, E.Cause(err, "write request") return nil, E.Cause(err, "write request")
} }
response, err := ReadMessage(conn) response, err := ReadMessage(conn)
if err != nil { if err != nil {
conn.Close()
return nil, E.Cause(err, "read response") return nil, E.Cause(err, "read response")
} }
t.access.Lock()
if t.State() >= StateClosing {
t.access.Unlock()
conn.Close()
return response, nil
}
conn.SetDeadline(time.Time{}) conn.SetDeadline(time.Time{})
t.connections.PushBack(conn)
t.access.Unlock()
return response, nil return response, nil
} }

View File

@@ -2,6 +2,7 @@ package transport
import ( import (
"context" "context"
"net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -27,13 +28,14 @@ func RegisterUDP(registry *dns.TransportRegistry) {
} }
type UDPTransport struct { type UDPTransport struct {
*BaseTransport dns.TransportAdapter
logger logger.ContextLogger
dialer N.Dialer dialer N.Dialer
serverAddr M.Socksaddr serverAddr M.Socksaddr
udpSize atomic.Int32 udpSize atomic.Int32
connector *Connector[*Connection] connection *ConnPool[net.Conn]
callbackAccess sync.RWMutex callbackAccess sync.RWMutex
queryId uint16 queryId uint16
@@ -63,43 +65,38 @@ func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options o
func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport { func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
t := &UDPTransport{ t := &UDPTransport{
BaseTransport: NewBaseTransport(adapter, logger), TransportAdapter: adapter,
dialer: dialerInstance, logger: logger,
serverAddr: serverAddr, dialer: dialerInstance,
callbacks: make(map[uint16]*udpCallback), serverAddr: serverAddr,
callbacks: make(map[uint16]*udpCallback),
connection: NewConnPool(ConnPoolOptions[net.Conn]{
Mode: ConnPoolSingle,
IsAlive: func(conn net.Conn) bool {
return conn != nil
},
Close: func(conn net.Conn, cause error) {
conn.Close()
},
}),
} }
t.udpSize.Store(2048) t.udpSize.Store(2048)
t.connector = NewSingleflightConnector(t.CloseContext(), t.dial)
return t return t
} }
func (t *UDPTransport) dial(ctx context.Context) (*Connection, error) {
rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
if err != nil {
return nil, E.Cause(err, "dial UDP connection")
}
conn := WrapConnection(rawConn)
go t.recvLoop(conn)
return conn, nil
}
func (t *UDPTransport) Start(stage adapter.StartStage) error { func (t *UDPTransport) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart { if stage != adapter.StartStateStart {
return nil return nil
} }
err := t.SetStarted()
if err != nil {
return err
}
return dialer.InitializeDetour(t.dialer) return dialer.InitializeDetour(t.dialer)
} }
func (t *UDPTransport) Close() error { func (t *UDPTransport) Close() error {
return E.Errors(t.BaseTransport.Close(), t.connector.Close()) return t.connection.Close()
} }
func (t *UDPTransport) Reset() { func (t *UDPTransport) Reset() {
t.connector.Reset() t.connection.Reset()
} }
func (t *UDPTransport) nextAvailableQueryId() (uint16, error) { func (t *UDPTransport) nextAvailableQueryId() (uint16, error) {
@@ -116,17 +113,12 @@ func (t *UDPTransport) nextAvailableQueryId() (uint16, error) {
} }
func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
if !t.BeginQuery() {
return nil, ErrTransportClosed
}
defer t.EndQuery()
response, err := t.exchange(ctx, message) response, err := t.exchange(ctx, message)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if response.Truncated { if response.Truncated {
t.Logger.InfoContext(ctx, "response truncated, retrying with TCP") t.logger.InfoContext(ctx, "response truncated, retrying with TCP")
return t.exchangeTCP(ctx, message) return t.exchangeTCP(ctx, message)
} }
return response, nil return response, nil
@@ -158,16 +150,25 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
break break
} }
if t.udpSize.CompareAndSwap(current, udpSize) { if t.udpSize.CompareAndSwap(current, udpSize) {
t.connector.Reset() t.Reset()
break break
} }
} }
} }
conn, err := t.connector.Get(ctx) conn, connCtx, created, err := t.connection.AcquireShared(ctx, func(ctx context.Context) (net.Conn, error) {
rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
if err != nil {
return nil, E.Cause(err, "dial UDP connection")
}
return rawConn, nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if created {
go t.recvLoop(conn)
}
callback := &udpCallback{ callback := &udpCallback{
done: make(chan struct{}), done: make(chan struct{}),
@@ -177,6 +178,7 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
queryId, err := t.nextAvailableQueryId() queryId, err := t.nextAvailableQueryId()
if err != nil { if err != nil {
t.callbackAccess.Unlock() t.callbackAccess.Unlock()
t.connection.Release(conn, true)
return nil, err return nil, err
} }
t.callbacks[queryId] = callback t.callbacks[queryId] = callback
@@ -203,30 +205,30 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
_, err = conn.Write(rawMessage) _, err = conn.Write(rawMessage)
if err != nil { if err != nil {
conn.CloseWithError(err) t.connection.Invalidate(conn, err)
return nil, E.Cause(err, "write request") return nil, E.Cause(err, "write request")
} }
select { select {
case <-callback.done: case <-callback.done:
t.connection.Release(conn, true)
callback.response.Id = originalId callback.response.Id = originalId
return callback.response, nil return callback.response, nil
case <-conn.Done(): case <-connCtx.Done():
return nil, conn.CloseError() return nil, context.Cause(connCtx)
case <-t.CloseContext().Done():
return nil, ErrTransportClosed
case <-ctx.Done(): case <-ctx.Done():
t.connection.Release(conn, true)
return nil, ctx.Err() return nil, ctx.Err()
} }
} }
func (t *UDPTransport) recvLoop(conn *Connection) { func (t *UDPTransport) recvLoop(conn net.Conn) {
for { for {
buffer := buf.NewSize(int(t.udpSize.Load())) buffer := buf.NewSize(int(t.udpSize.Load()))
_, err := buffer.ReadOnceFrom(conn) _, err := buffer.ReadOnceFrom(conn)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
conn.CloseWithError(err) t.connection.Invalidate(conn, err)
return return
} }
@@ -234,7 +236,7 @@ func (t *UDPTransport) recvLoop(conn *Connection) {
err = message.Unpack(buffer.Bytes()) err = message.Unpack(buffer.Bytes())
buffer.Release() buffer.Release()
if err != nil { if err != nil {
t.Logger.Debug("discarded malformed UDP response: ", err) t.logger.Debug("discarded malformed UDP response: ", err)
continue continue
} }

View File

@@ -49,6 +49,7 @@ type DNSTransport struct {
dnsRouter adapter.DNSRouter dnsRouter adapter.DNSRouter
endpointManager adapter.EndpointManager endpointManager adapter.EndpointManager
endpoint *Endpoint endpoint *Endpoint
access sync.RWMutex
routePrefixes []netip.Prefix routePrefixes []netip.Prefix
routes map[string][]adapter.DNSTransport routes map[string][]adapter.DNSTransport
hosts map[string][]netip.Addr hosts map[string][]netip.Addr
@@ -91,6 +92,12 @@ func (t *DNSTransport) Start(stage adapter.StartStage) error {
} }
func (t *DNSTransport) Reset() { func (t *DNSTransport) Reset() {
t.access.RLock()
transports := t.collectResolversLocked()
t.access.RUnlock()
for _, transport := range transports {
transport.Reset()
}
} }
func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) { func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) {
@@ -101,7 +108,7 @@ func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, d
} }
func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error { func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error {
t.routePrefixes = buildRoutePrefixes(routeConfig) routePrefixes := buildRoutePrefixes(routeConfig)
directDialerOnce := sync.OnceValue(func() N.Dialer { directDialerOnce := sync.OnceValue(func() N.Dialer {
directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{})) directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{}))
return &DNSDialer{transport: t, fallbackDialer: directDialer} return &DNSDialer{transport: t, fallbackDialer: directDialer}
@@ -130,9 +137,19 @@ func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *n
} }
defaultResolvers = append(defaultResolvers, myResolver) defaultResolvers = append(defaultResolvers, myResolver)
} }
t.access.Lock()
oldResolvers := t.collectResolversLocked()
t.routePrefixes = routePrefixes
t.routes = routes t.routes = routes
t.hosts = hosts t.hosts = hosts
t.defaultResolvers = defaultResolvers t.defaultResolvers = defaultResolvers
t.access.Unlock()
for _, transport := range oldResolvers {
transport.Close()
}
if len(defaultResolvers) > 0 { if len(defaultResolvers) > 0 {
t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ", t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ",
strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " ")) strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " "))
@@ -207,7 +224,22 @@ func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix {
} }
func (t *DNSTransport) Close() error { func (t *DNSTransport) Close() error {
return nil t.access.Lock()
transports := t.collectResolversLocked()
t.routePrefixes = nil
t.routes = nil
t.hosts = nil
t.defaultResolvers = nil
t.access.Unlock()
var err error
for _, transport := range transports {
name := "resolver/" + transport.Type() + "[" + transport.Tag() + "]"
err = E.Append(err, transport.Close(), func(err error) error {
return E.Cause(err, "close ", name)
})
}
return err
} }
func (t *DNSTransport) Raw() bool { func (t *DNSTransport) Raw() bool {
@@ -219,7 +251,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
question := message.Question[0] question := message.Question[0]
addresses, hostsLoaded := t.hosts[question.Name]
t.access.RLock()
hosts := t.hosts
routes := t.routes
defaultResolvers := t.defaultResolvers
acceptDefaultResolvers := t.acceptDefaultResolvers
t.access.RUnlock()
addresses, hostsLoaded := hosts[question.Name]
if hostsLoaded { if hostsLoaded {
switch question.Qtype { switch question.Qtype {
case mDNS.TypeA: case mDNS.TypeA:
@@ -238,7 +278,7 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
} }
} }
} }
for domainSuffix, transports := range t.routes { for domainSuffix, transports := range routes {
if strings.HasSuffix(question.Name, domainSuffix) { if strings.HasSuffix(question.Name, domainSuffix) {
if len(transports) == 0 { if len(transports) == 0 {
return &mDNS.Msg{ return &mDNS.Msg{
@@ -262,10 +302,10 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
return nil, lastErr return nil, lastErr
} }
} }
if t.acceptDefaultResolvers { if acceptDefaultResolvers {
if len(t.defaultResolvers) > 0 { if len(defaultResolvers) > 0 {
var lastErr error var lastErr error
for _, resolver := range t.defaultResolvers { for _, resolver := range defaultResolvers {
response, err := resolver.Exchange(ctx, message) response, err := resolver.Exchange(ctx, message)
if err != nil { if err != nil {
lastErr = err lastErr = err
@@ -281,6 +321,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
return nil, dns.RcodeNameError return nil, dns.RcodeNameError
} }
func (t *DNSTransport) collectResolversLocked() []adapter.DNSTransport {
var transports []adapter.DNSTransport
for _, resolvers := range t.routes {
transports = append(transports, resolvers...)
}
transports = append(transports, t.defaultResolvers...)
return transports
}
type DNSDialer struct { type DNSDialer struct {
transport *DNSTransport transport *DNSTransport
fallbackDialer N.Dialer fallbackDialer N.Dialer
@@ -290,7 +339,8 @@ func (d *DNSDialer) DialContext(ctx context.Context, network string, destination
if destination.IsDomain() { if destination.IsDomain() {
panic("invalid request here") panic("invalid request here")
} }
for _, prefix := range d.transport.routePrefixes { routePrefixes := d.transport.routePrefixesSnapshot()
for _, prefix := range routePrefixes {
if prefix.Contains(destination.Addr) { if prefix.Contains(destination.Addr) {
return d.transport.endpoint.DialContext(ctx, network, destination) return d.transport.endpoint.DialContext(ctx, network, destination)
} }
@@ -302,10 +352,17 @@ func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (
if destination.IsDomain() { if destination.IsDomain() {
panic("invalid request here") panic("invalid request here")
} }
for _, prefix := range d.transport.routePrefixes { routePrefixes := d.transport.routePrefixesSnapshot()
for _, prefix := range routePrefixes {
if prefix.Contains(destination.Addr) { if prefix.Contains(destination.Addr) {
return d.transport.endpoint.ListenPacket(ctx, destination) return d.transport.endpoint.ListenPacket(ctx, destination)
} }
} }
return d.fallbackDialer.ListenPacket(ctx, destination) return d.fallbackDialer.ListenPacket(ctx, destination)
} }
func (t *DNSTransport) routePrefixesSnapshot() []netip.Prefix {
t.access.RLock()
defer t.access.RUnlock()
return append([]netip.Prefix(nil), t.routePrefixes...)
}