Files
sing-box-extended/dns/transport/tls.go

143 lines
3.9 KiB
Go

package transport
import (
"context"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns"
)
var _ adapter.DNSTransport = (*TLSTransport)(nil)
func RegisterTLS(registry *dns.TransportRegistry) {
dns.RegisterTransport[option.RemoteTLSDNSServerOptions](registry, C.DNSTypeTLS, NewTLS)
}
type TLSTransport struct {
dns.TransportAdapter
logger logger.ContextLogger
dialer tls.Dialer
serverAddr M.Socksaddr
tlsConfig tls.Config
connections *ConnPool[*tlsDNSConn]
}
type tlsDNSConn struct {
tls.Conn
queryId uint16
}
func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
transportDialer, err := dns.NewRemoteDialer(ctx, options.RemoteDNSServerOptions)
if err != nil {
return nil, err
}
tlsOptions := common.PtrValueOrDefault(options.TLS)
tlsOptions.Enabled = true
tlsConfig, err := tls.NewClient(ctx, logger, options.Server, tlsOptions)
if err != nil {
return nil, err
}
serverAddr := options.DNSServerAddressOptions.Build()
if serverAddr.Port == 0 {
serverAddr.Port = 853
}
if !serverAddr.IsValid() {
return nil, E.New("invalid server address: ", serverAddr)
}
return NewTLSRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeTLS, tag, options.RemoteDNSServerOptions), transportDialer, serverAddr, tlsConfig), nil
}
func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport {
return &TLSTransport{
TransportAdapter: adapter,
logger: logger,
dialer: tls.NewDialer(dialer, tlsConfig),
serverAddr: serverAddr,
tlsConfig: tlsConfig,
connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{
Mode: ConnPoolOrdered,
IsAlive: func(conn *tlsDNSConn) bool {
return conn != nil
},
Close: func(conn *tlsDNSConn, _ error) {
conn.Close()
},
}),
}
}
func (t *TLSTransport) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return dialer.InitializeDetour(t.dialer)
}
func (t *TLSTransport) Close() error {
return t.connections.Close()
}
func (t *TLSTransport) Reset() {
t.connections.Reset()
}
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
var lastErr error
for attempt := 0; attempt < 2; attempt++ {
conn, created, err := t.connections.Acquire(ctx, func(ctx context.Context) (*tlsDNSConn, error) {
tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
if err != nil {
return nil, E.Cause(err, "dial TLS connection")
}
return &tlsDNSConn{Conn: tlsConn}, nil
})
if err != nil {
return nil, err
}
response, err := t.exchange(ctx, message, conn)
if err == nil {
t.connections.Release(conn, true)
return response, nil
}
lastErr = err
t.logger.DebugContext(ctx, "discarded pooled connection: ", err)
t.connections.Release(conn, false)
if created {
return nil, err
}
}
return nil, lastErr
}
func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
if deadline, ok := ctx.Deadline(); ok {
conn.SetDeadline(deadline)
}
conn.queryId++
err := WriteMessage(conn, conn.queryId, message)
if err != nil {
return nil, E.Cause(err, "write request")
}
response, err := ReadMessage(conn)
if err != nil {
return nil, E.Cause(err, "read response")
}
conn.SetDeadline(time.Time{})
return response, nil
}