Add hijack_dns for tun

This commit is contained in:
世界
2022-07-10 09:15:01 +08:00
parent a34057753c
commit ddfae8cd07
15 changed files with 220 additions and 99 deletions

View File

@@ -79,6 +79,6 @@ func (d *Direct) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.B
case 3:
metadata.Destination.Port = d.overrideDestination.Port
}
d.udpNat.NewPacketDirect(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), metadata.Source.AddrPort(), conn, buffer, adapter.UpstreamMetadata(metadata))
d.udpNat.NewPacketDirect(adapter.WithContext(log.ContextWithID(ctx), &metadata), metadata.Source.AddrPort(), conn, buffer, adapter.UpstreamMetadata(metadata))
return nil
}

107
inbound/dns.go Normal file
View File

@@ -0,0 +1,107 @@
package inbound
import (
"context"
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/dns/dnsmessage"
)
func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn net.Conn, metadata adapter.InboundContext) error {
_buffer := buf.StackNewSize(1024)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
for {
var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil {
return err
}
if queryLength > 1024 {
return io.ErrShortBuffer
}
buffer.FullReset()
_, err = buffer.ReadFullFrom(conn, int(queryLength))
if err != nil {
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
}
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
if err != nil {
return err
}
buffer.FullReset()
responseBuffer, err := response.AppendPack(buffer.Index(0))
if err != nil {
return err
}
err = binary.Write(conn, binary.BigEndian, uint16(len(responseBuffer)))
if err != nil {
return err
}
_, err = conn.Write(responseBuffer)
if err != nil {
return err
}
}
}
func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn N.PacketConn, metadata adapter.InboundContext) error {
for {
buffer := buf.StackNewSize(1024)
destination, err := conn.ReadPacket(buffer)
if err != nil {
buffer.Release()
return err
}
var message dnsmessage.Message
err = message.Unpack(buffer.Bytes())
if err != nil {
return err
}
if len(message.Questions) > 0 {
question := message.Questions[0]
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
}
go func() error {
defer buffer.Release()
response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
if err != nil {
return err
}
buffer.FullReset()
responseBuffer, err := response.AppendPack(buffer.Index(0))
if err != nil {
return err
}
buffer.Truncate(len(responseBuffer))
err = conn.WritePacket(buffer, destination)
return err
}()
}
}
func formatDNSQuestion(question dnsmessage.Question) string {
domain := question.Name.String()
domain = domain[:len(domain)-1]
return string(question.Name.Data[:question.Name.Length-1]) + " " + question.Type.String()[4:] + " " + question.Class.String()[5:]
}

View File

@@ -73,9 +73,9 @@ func newShadowsocks(ctx context.Context, router adapter.Router, logger log.Logge
}
func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *Shadowsocks) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}

View File

@@ -68,11 +68,11 @@ func newShadowsocksMulti(ctx context.Context, router adapter.Router, logger log.
}
func (h *ShadowsocksMulti) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksMulti) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksMulti) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

View File

@@ -68,11 +68,11 @@ func newShadowsocksRelay(ctx context.Context, router adapter.Router, logger log.
}
func (h *ShadowsocksRelay) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksRelay) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
}
func (h *ShadowsocksRelay) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

View File

@@ -20,6 +20,7 @@ import (
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
)
var _ adapter.Inbound = (*Tun)(nil)
@@ -27,23 +28,42 @@ var _ adapter.Inbound = (*Tun)(nil)
type Tun struct {
tag string
ctx context.Context
router adapter.Router
logger log.Logger
options option.TunInboundOptions
ctx context.Context
router adapter.Router
logger log.Logger
inboundOptions option.InboundOptions
tunName string
tunMTU uint32
inet4Address netip.Prefix
inet6Address netip.Prefix
autoRoute bool
hijackDNS bool
tunName string
tunFd uintptr
tun *tun.GVisorTun
tunFd uintptr
tun *tun.GVisorTun
}
func NewTun(ctx context.Context, router adapter.Router, logger log.Logger, tag string, options option.TunInboundOptions) (*Tun, error) {
tunName := options.InterfaceName
if tunName == "" {
tunName = mkInterfaceName()
}
tunMTU := options.MTU
if tunMTU == 0 {
tunMTU = 1500
}
return &Tun{
tag: tag,
ctx: ctx,
router: router,
logger: logger,
options: options,
tag: tag,
ctx: ctx,
router: router,
logger: logger,
inboundOptions: options.InboundOptions,
tunName: tunName,
tunMTU: tunMTU,
inet4Address: netip.Prefix(options.Inet4Address),
inet6Address: netip.Prefix(options.Inet6Address),
autoRoute: options.AutoRoute,
hijackDNS: options.HijackDNS,
}, nil
}
@@ -56,38 +76,26 @@ func (t *Tun) Tag() string {
}
func (t *Tun) Start() error {
tunName := t.options.InterfaceName
if tunName == "" {
tunName = mkInterfaceName()
}
var mtu uint32
if t.options.MTU != 0 {
mtu = t.options.MTU
} else {
mtu = 1500
}
tunFd, err := tun.Open(tunName)
tunFd, err := tun.Open(t.tunName)
if err != nil {
return E.Cause(err, "create tun interface")
}
err = tun.Configure(tunName, netip.Prefix(t.options.Inet4Address), netip.Prefix(t.options.Inet6Address), mtu, t.options.AutoRoute)
err = tun.Configure(t.tunName, t.inet4Address, t.inet6Address, t.tunMTU, t.autoRoute)
if err != nil {
return E.Cause(err, "configure tun interface")
}
t.tunName = tunName
t.tunFd = tunFd
t.tun = tun.NewGVisor(t.ctx, tunFd, mtu, t)
t.tun = tun.NewGVisor(t.ctx, tunFd, t.tunMTU, t)
err = t.tun.Start()
if err != nil {
return err
}
t.logger.Info("started at ", tunName)
t.logger.Info("started at ", t.tunName)
return nil
}
func (t *Tun) Close() error {
err := tun.UnConfigure(t.tunName, netip.Prefix(t.options.Inet4Address), netip.Prefix(t.options.Inet6Address), t.options.AutoRoute)
err := tun.UnConfigure(t.tunName, t.inet4Address, t.inet6Address, t.autoRoute)
if err != nil {
return err
}
@@ -98,30 +106,40 @@ func (t *Tun) Close() error {
}
func (t *Tun) NewConnection(ctx context.Context, conn net.Conn, upstreamMetadata M.Metadata) error {
t.logger.WithContext(ctx).Info("inbound connection from ", upstreamMetadata.Source)
t.logger.WithContext(ctx).Info("inbound connection to ", upstreamMetadata.Destination)
var metadata adapter.InboundContext
metadata.Inbound = t.tag
metadata.Network = C.NetworkTCP
metadata.Source = upstreamMetadata.Source
metadata.Destination = upstreamMetadata.Destination
metadata.SniffEnabled = t.options.SniffEnabled
metadata.SniffOverrideDestination = t.options.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.options.DomainStrategy)
metadata.SniffEnabled = t.inboundOptions.SniffEnabled
metadata.SniffOverrideDestination = t.inboundOptions.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.inboundOptions.DomainStrategy)
if t.hijackDNS && upstreamMetadata.Destination.Port == 53 {
return task.Run(ctx, func() error {
return NewDNSConnection(ctx, t.router, t.logger, conn, metadata)
})
}
t.logger.WithContext(ctx).Info("inbound connection from ", metadata.Source)
t.logger.WithContext(ctx).Info("inbound connection to ", metadata.Destination)
return t.router.RouteConnection(ctx, conn, metadata)
}
func (t *Tun) NewPacketConnection(ctx context.Context, conn N.PacketConn, upstreamMetadata M.Metadata) error {
t.logger.WithContext(ctx).Info("inbound packet connection from ", upstreamMetadata.Source)
t.logger.WithContext(ctx).Info("inbound packet connection to ", upstreamMetadata.Destination)
var metadata adapter.InboundContext
metadata.Inbound = t.tag
metadata.Network = C.NetworkUDP
metadata.Source = upstreamMetadata.Source
metadata.Destination = upstreamMetadata.Destination
metadata.SniffEnabled = t.options.SniffEnabled
metadata.SniffOverrideDestination = t.options.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.options.DomainStrategy)
metadata.SniffEnabled = t.inboundOptions.SniffEnabled
metadata.SniffOverrideDestination = t.inboundOptions.SniffOverrideDestination
metadata.DomainStrategy = C.DomainStrategy(t.inboundOptions.DomainStrategy)
if t.hijackDNS && upstreamMetadata.Destination.Port == 53 {
return task.Run(ctx, func() error {
return NewDNSPacketConnection(ctx, t.router, t.logger, conn, metadata)
})
}
t.logger.WithContext(ctx).Info("inbound packet connection from ", metadata.Source)
t.logger.WithContext(ctx).Info("inbound packet connection to ", metadata.Destination)
return t.router.RoutePacketConnection(ctx, conn, metadata)
}