Migrate components to library

This commit is contained in:
世界
2022-07-11 18:44:59 +08:00
parent f548511d7e
commit 3d16df288c
46 changed files with 109 additions and 2438 deletions

View File

@@ -37,8 +37,8 @@ func NewDefault(router adapter.Router, options option.DialerOptions) *DefaultDia
listener.Control = control.Append(listener.Control, control.ReuseAddr())
}
if options.ProtectPath != "" {
dialer.Control = control.Append(dialer.Control, ProtectPath(options.ProtectPath))
listener.Control = control.Append(listener.Control, ProtectPath(options.ProtectPath))
dialer.Control = control.Append(dialer.Control, control.ProtectPath(options.ProtectPath))
listener.Control = control.Append(listener.Control, control.ProtectPath(options.ProtectPath))
}
if options.ConnectTimeout != 0 {
dialer.Timeout = time.Duration(options.ConnectTimeout)

View File

@@ -4,8 +4,8 @@ import (
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
N "github.com/sagernet/sing/common/network"
)
@@ -20,13 +20,9 @@ func New(router adapter.Router, options option.DialerOptions) N.Dialer {
func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.Dialer {
dialer := New(router, options.DialerOptions)
domainStrategy := C.DomainStrategy(options.DomainStrategy)
if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" {
fallbackDelay := time.Duration(options.FallbackDelay)
if fallbackDelay == 0 {
fallbackDelay = time.Millisecond * 300
}
dialer = NewResolveDialer(router, dialer, domainStrategy, fallbackDelay)
domainStrategy := dns.DomainStrategy(options.DomainStrategy)
if domainStrategy != dns.DomainStrategyAsIS || options.Detour == "" {
dialer = NewResolveDialer(router, dialer, domainStrategy, time.Duration(options.FallbackDelay))
}
if options.OverrideOptions.IsValid() {
dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions))

View File

@@ -1,90 +0,0 @@
package dialer
import (
"context"
"net"
"net/netip"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func DialParallel(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.DomainStrategy, fallbackDelay time.Duration) (net.Conn, error) {
// kanged form net.Dial
returned := make(chan struct{})
defer close(returned)
addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
return address.Is4() || address.Is4In6()
})
addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
return address.Is6() && !address.Is4In6()
})
if len(addresses4) == 0 || len(addresses6) == 0 {
return DialSerial(ctx, dialer, network, destination, destinationAddresses)
}
var primaries, fallbacks []netip.Addr
switch strategy {
case C.DomainStrategyPreferIPv6:
primaries = addresses6
fallbacks = addresses4
default:
primaries = addresses4
fallbacks = addresses6
}
type dialResult struct {
net.Conn
error
primary bool
done bool
}
results := make(chan dialResult) // unbuffered
startRacer := func(ctx context.Context, primary bool) {
ras := primaries
if !primary {
ras = fallbacks
}
c, err := DialSerial(ctx, dialer, network, destination, ras)
select {
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
if c != nil {
c.Close()
}
}
}
var primary, fallback dialResult
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
go startRacer(primaryCtx, true)
fallbackTimer := time.NewTimer(fallbackDelay)
defer fallbackTimer.Stop()
for {
select {
case <-fallbackTimer.C:
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
go startRacer(fallbackCtx, false)
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if res.primary {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
return nil, primary.error
}
if res.primary && fallbackTimer.Stop() {
fallbackTimer.Reset(0)
}
}
}
}

View File

@@ -1,46 +0,0 @@
//go:build android || with_protect
package dialer
import (
"syscall"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
)
func sendAncillaryFileDescriptors(protectPath string, fileDescriptors []int) error {
socket, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
if err != nil {
return E.Cause(err, "open protect socket")
}
defer syscall.Close(socket)
err = syscall.Connect(socket, &syscall.SockaddrUnix{Name: protectPath})
if err != nil {
return E.Cause(err, "connect protect path")
}
oob := syscall.UnixRights(fileDescriptors...)
dummy := []byte{1}
err = syscall.Sendmsg(socket, dummy, oob, nil, 0)
if err != nil {
return err
}
n, err := syscall.Read(socket, dummy)
if err != nil {
return err
}
if n != 1 {
return E.New("failed to protect fd")
}
return nil
}
func ProtectPath(protectPath string) control.Func {
return func(network, address string, conn syscall.RawConn) error {
var innerErr error
err := conn.Control(func(fd uintptr) {
innerErr = sendAncillaryFileDescriptors(protectPath, []int{int(fd)})
})
return E.Errors(innerErr, err)
}
}

View File

@@ -1,9 +0,0 @@
//go:build !android && !with_protect
package dialer
import "github.com/sagernet/sing/common/control"
func ProtectPath(protectPath string) control.Func {
return nil
}

View File

@@ -7,7 +7,7 @@ import (
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
@@ -15,11 +15,11 @@ import (
type ResolveDialer struct {
dialer N.Dialer
router adapter.Router
strategy C.DomainStrategy
strategy dns.DomainStrategy
fallbackDelay time.Duration
}
func NewResolveDialer(router adapter.Router, dialer N.Dialer, strategy C.DomainStrategy, fallbackDelay time.Duration) *ResolveDialer {
func NewResolveDialer(router adapter.Router, dialer N.Dialer, strategy dns.DomainStrategy, fallbackDelay time.Duration) *ResolveDialer {
return &ResolveDialer{
dialer,
router,
@@ -37,7 +37,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == C.DomainStrategyAsIS {
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
@@ -45,7 +45,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
if err != nil {
return nil, err
}
return DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy, d.fallbackDelay)
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, d.fallbackDelay)
}
func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
@@ -57,7 +57,7 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == C.DomainStrategyAsIS {
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
@@ -65,7 +65,7 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
if err != nil {
return nil, err
}
conn, err := ListenSerial(ctx, d.dialer, destination, addresses)
conn, err := N.ListenSerial(ctx, d.dialer, destination, addresses)
if err != nil {
return nil, err
}

View File

@@ -5,14 +5,14 @@ import (
"net"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func NewResolvePacketConn(router adapter.Router, strategy C.DomainStrategy, conn net.PacketConn) N.NetPacketConn {
func NewResolvePacketConn(router adapter.Router, strategy dns.DomainStrategy, conn net.PacketConn) N.NetPacketConn {
if udpConn, ok := conn.(*net.UDPConn); ok {
return &ResolveUDPConn{udpConn, router, strategy}
} else {
@@ -23,7 +23,7 @@ func NewResolvePacketConn(router adapter.Router, strategy C.DomainStrategy, conn
type ResolveUDPConn struct {
*net.UDPConn
router adapter.Router
strategy C.DomainStrategy
strategy dns.DomainStrategy
}
func (w *ResolveUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
@@ -54,7 +54,7 @@ func (w *ResolveUDPConn) Upstream() any {
type ResolvePacketConn struct {
net.PacketConn
router adapter.Router
strategy C.DomainStrategy
strategy dns.DomainStrategy
}
func (w *ResolvePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {

View File

@@ -1,41 +0,0 @@
package dialer
import (
"context"
"net"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func DialSerial(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) {
var conn net.Conn
var err error
var connErrors []error
for _, address := range destinationAddresses {
conn, err = dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port))
if err != nil {
connErrors = append(connErrors, err)
continue
}
return conn, nil
}
return nil, E.Errors(connErrors...)
}
func ListenSerial(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.PacketConn, error) {
var conn net.PacketConn
var err error
var connErrors []error
for _, address := range destinationAddresses {
conn, err = dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port))
if err != nil {
connErrors = append(connErrors, err)
continue
}
return conn, nil
}
return nil, E.Errors(connErrors...)
}

View File

@@ -1,60 +0,0 @@
package domain
import (
"sort"
"unicode/utf8"
)
type Matcher struct {
set *succinctSet
}
func NewMatcher(domains []string, domainSuffix []string) *Matcher {
domainList := make([]string, 0, len(domains)+len(domainSuffix))
seen := make(map[string]bool, len(domainList))
for _, domain := range domainSuffix {
if seen[domain] {
continue
}
seen[domain] = true
domainList = append(domainList, reverseDomainSuffix(domain))
}
for _, domain := range domains {
if seen[domain] {
continue
}
seen[domain] = true
domainList = append(domainList, reverseDomain(domain))
}
sort.Strings(domainList)
return &Matcher{
newSuccinctSet(domainList),
}
}
func (m *Matcher) Match(domain string) bool {
return m.set.Has(reverseDomain(domain))
}
func reverseDomain(domain string) string {
l := len(domain)
b := make([]byte, l)
for i := 0; i < l; {
r, n := utf8.DecodeRuneInString(domain[i:])
i += n
utf8.EncodeRune(b[l-i:], r)
}
return string(b)
}
func reverseDomainSuffix(domain string) string {
l := len(domain)
b := make([]byte, l+1)
for i := 0; i < l; {
r, n := utf8.DecodeRuneInString(domain[i:])
i += n
utf8.EncodeRune(b[l-i:], r)
}
b[l] = prefixLabel
return string(b)
}

View File

@@ -1,21 +0,0 @@
package domain_test
import (
"testing"
"github.com/sagernet/sing-box/common/domain"
"github.com/stretchr/testify/require"
)
func TestMatch(t *testing.T) {
t.Parallel()
r := require.New(t)
matcher := domain.NewMatcher([]string{"domain.com"}, []string{"suffix.com", ".suffix.org"})
r.True(matcher.Match("domain.com"))
r.False(matcher.Match("my.domain.com"))
r.True(matcher.Match("suffix.com"))
r.True(matcher.Match("my.suffix.com"))
r.False(matcher.Match("suffix.org"))
r.True(matcher.Match("my.suffix.org"))
}

View File

@@ -1,231 +0,0 @@
package domain
import (
"math/bits"
)
const prefixLabel = '\r'
// mod from https://github.com/openacid/succinct
type succinctSet struct {
leaves, labelBitmap []uint64
labels []byte
ranks, selects []int32
}
func newSuccinctSet(keys []string) *succinctSet {
ss := &succinctSet{}
lIdx := 0
type qElt struct{ s, e, col int }
queue := []qElt{{0, len(keys), 0}}
for i := 0; i < len(queue); i++ {
elt := queue[i]
if elt.col == len(keys[elt.s]) {
// a leaf node
elt.s++
setBit(&ss.leaves, i, 1)
}
for j := elt.s; j < elt.e; {
frm := j
for ; j < elt.e && keys[j][elt.col] == keys[frm][elt.col]; j++ {
}
queue = append(queue, qElt{frm, j, elt.col + 1})
ss.labels = append(ss.labels, keys[frm][elt.col])
setBit(&ss.labelBitmap, lIdx, 0)
lIdx++
}
setBit(&ss.labelBitmap, lIdx, 1)
lIdx++
}
ss.init()
return ss
}
func (ss *succinctSet) Has(key string) bool {
var nodeId, bmIdx int
for i := 0; i < len(key); i++ {
currentChar := key[i]
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
}
nextLabel := ss.labels[bmIdx-nodeId]
if nextLabel == prefixLabel {
return true
}
if nextLabel == currentChar {
break
}
}
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
}
if getBit(ss.leaves, nodeId) != 0 {
return true
}
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
}
if ss.labels[bmIdx-nodeId] == prefixLabel {
return true
}
}
}
func setBit(bm *[]uint64, i int, v int) {
for i>>6 >= len(*bm) {
*bm = append(*bm, 0)
}
(*bm)[i>>6] |= uint64(v) << uint(i&63)
}
func getBit(bm []uint64, i int) uint64 {
return bm[i>>6] & (1 << uint(i&63))
}
func (ss *succinctSet) init() {
ss.selects, ss.ranks = indexSelect32R64(ss.labelBitmap)
}
func countZeros(bm []uint64, ranks []int32, i int) int {
a, _ := rank64(bm, ranks, int32(i))
return i - int(a)
}
func selectIthOne(bm []uint64, ranks, selects []int32, i int) int {
a, _ := select32R64(bm, selects, ranks, int32(i))
return int(a)
}
func rank64(words []uint64, rindex []int32, i int32) (int32, int32) {
wordI := i >> 6
j := uint32(i & 63)
n := rindex[wordI]
w := words[wordI]
c1 := n + int32(bits.OnesCount64(w&mask[j]))
return c1, int32(w>>uint(j)) & 1
}
func indexRank64(words []uint64, opts ...bool) []int32 {
trailing := false
if len(opts) > 0 {
trailing = opts[0]
}
l := len(words)
if trailing {
l++
}
idx := make([]int32, l)
n := int32(0)
for i := 0; i < len(words); i++ {
idx[i] = n
n += int32(bits.OnesCount64(words[i]))
}
if trailing {
idx[len(words)] = n
}
return idx
}
func select32R64(words []uint64, selectIndex, rankIndex []int32, i int32) (int32, int32) {
a := int32(0)
l := int32(len(words))
wordI := selectIndex[i>>5] >> 6
for ; rankIndex[wordI+1] <= i; wordI++ {
}
w := words[wordI]
ww := w
base := wordI << 6
findIth := int(i - rankIndex[wordI])
offset := int32(0)
ones := bits.OnesCount32(uint32(ww))
if ones <= findIth {
findIth -= ones
offset |= 32
ww >>= 32
}
ones = bits.OnesCount16(uint16(ww))
if ones <= findIth {
findIth -= ones
offset |= 16
ww >>= 16
}
ones = bits.OnesCount8(uint8(ww))
if ones <= findIth {
a = int32(select8Lookup[(ww>>5)&(0x7f8)|uint64(findIth-ones)]) + offset + 8
} else {
a = int32(select8Lookup[(ww&0xff)<<3|uint64(findIth)]) + offset
}
a += base
w &= rMaskUpto[a&63]
if w != 0 {
return a, base + int32(bits.TrailingZeros64(w))
}
wordI++
for ; wordI < l; wordI++ {
w = words[wordI]
if w != 0 {
return a, wordI<<6 + int32(bits.TrailingZeros64(w))
}
}
return a, l << 6
}
func indexSelect32R64(words []uint64) ([]int32, []int32) {
l := len(words) << 6
sidx := make([]int32, 0, len(words))
ith := -1
for i := 0; i < l; i++ {
if words[i>>6]&(1<<uint(i&63)) != 0 {
ith++
if ith&31 == 0 {
sidx = append(sidx, int32(i))
}
}
}
// clone to reduce cap to len
sidx = append(sidx[:0:0], sidx...)
return sidx, indexRank64(words, true)
}
func init() {
initMasks()
initSelectLookup()
}
var (
mask [65]uint64
rMaskUpto [64]uint64
)
func initMasks() {
for i := 0; i < 65; i++ {
mask[i] = (1 << uint(i)) - 1
}
var maskUpto [64]uint64
for i := 0; i < 64; i++ {
maskUpto[i] = (1 << uint(i+1)) - 1
rMaskUpto[i] = ^maskUpto[i]
}
}
var select8Lookup [256 * 8]uint8
func initSelectLookup() {
for i := 0; i < 256; i++ {
w := uint8(i)
for j := 0; j < 8; j++ {
// x-th 1 in w
// if x-th 1 is not found, it is 8
x := bits.TrailingZeros8(w)
w &= w - 1
select8Lookup[i*8+j] = uint8(x)
}
}
}

View File

@@ -1,151 +0,0 @@
package tun
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const defaultNIC tcpip.NICID = 1
var _ adapter.Service = (*GVisorTun)(nil)
type GVisorTun struct {
ctx context.Context
tunFd uintptr
tunMtu uint32
handler Handler
stack *stack.Stack
}
func NewGVisor(ctx context.Context, tunFd uintptr, tunMtu uint32, handler Handler) *GVisorTun {
return &GVisorTun{
ctx: ctx,
tunFd: tunFd,
tunMtu: tunMtu,
handler: handler,
}
}
func (t *GVisorTun) Start() error {
linkEndpoint, err := NewEndpoint(t.tunFd, t.tunMtu)
if err != nil {
return err
}
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
tErr := ipStack.CreateNIC(defaultNIC, linkEndpoint)
if tErr != nil {
return E.New("create nic: ", tErr)
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
})
ipStack.SetSpoofing(defaultNIC, true)
ipStack.SetPromiscuousMode(defaultNIC, true)
bufSize := 20 * 1024
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := r.CreateEndpoint(&wq)
if err != nil {
r.Complete(true)
return
}
r.Complete(false)
endpoint.SocketOptions().SetKeepAlive(true)
tcpConn := gonet.NewTCPConn(&wq, endpoint)
lAddr := tcpConn.RemoteAddr()
rAddr := tcpConn.LocalAddr()
if lAddr == nil || rAddr == nil {
tcpConn.Close()
return
}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx := log.ContextWithID(t.ctx)
hErr := t.handler.NewConnection(ctx, tcpConn, metadata)
if hErr != nil {
t.handler.NewError(ctx, hErr)
}
}()
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool {
return tcpForwarder.HandlePacket(id, buffer)
})
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := request.CreateEndpoint(&wq)
if err != nil {
return
}
udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
lAddr := udpConn.RemoteAddr()
rAddr := udpConn.LocalAddr()
if lAddr == nil || rAddr == nil {
udpConn.Close()
return
}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx := log.ContextWithID(t.ctx)
hErr := t.handler.NewPacketConnection(ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(udpConn), Addr: M.SocksaddrFromNet(rAddr)}), metadata)
if hErr != nil {
t.handler.NewError(ctx, hErr)
}
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
t.stack = ipStack
return nil
}
func (t *GVisorTun) Close() error {
return common.Close(
common.PtrOrNil(t.stack),
)
}

View File

@@ -1,24 +0,0 @@
//go:build linux
package tun
import (
"runtime"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
func NewEndpoint(tunFd uintptr, tunMtu uint32) (stack.LinkEndpoint, error) {
var packetDispatchMode fdbased.PacketDispatchMode
if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
packetDispatchMode = fdbased.PacketMMap
} else {
packetDispatchMode = fdbased.RecvMMsg
}
return fdbased.New(&fdbased.Options{
FDs: []int{int(tunFd)},
MTU: tunMtu,
PacketDispatchMode: packetDispatchMode,
})
}

View File

@@ -1,9 +0,0 @@
//go:build !linux
package tun
import "gvisor.dev/gvisor/pkg/tcpip/stack"
func NewEndpoint(tunFd uintptr, tunMtu uint32) (stack.LinkEndpoint, error) {
return NewPosixEndpoint(tunFd, tunMtu)
}

View File

@@ -1,118 +0,0 @@
package tun
import (
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/rw"
gBuffer "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
var _ stack.LinkEndpoint = (*PosixEndpoint)(nil)
type PosixEndpoint struct {
fd uintptr
mtu uint32
file *os.File
dispatcher stack.NetworkDispatcher
}
func NewPosixEndpoint(tunFd uintptr, tunMtu uint32) (stack.LinkEndpoint, error) {
return &PosixEndpoint{
fd: tunFd,
mtu: tunMtu,
file: os.NewFile(tunFd, "tun"),
}, nil
}
func (e *PosixEndpoint) MTU() uint32 {
return e.mtu
}
func (e *PosixEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (e *PosixEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (e *PosixEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (e *PosixEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
if dispatcher == nil && e.dispatcher != nil {
e.dispatcher = nil
return
}
if dispatcher != nil && e.dispatcher == nil {
e.dispatcher = dispatcher
go e.dispatchLoop()
}
}
func (e *PosixEndpoint) dispatchLoop() {
_buffer := buf.StackNewPacket()
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
for {
n, err := e.file.Read(buffer.FreeBytes())
if err != nil {
break
}
var view gBuffer.View
view.Append(buffer.To(n))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: view,
IsForwardedPacket: true,
})
defer pkt.DecRef()
var p tcpip.NetworkProtocolNumber
ipHeader, ok := pkt.Data().PullUp(1)
if !ok {
continue
}
switch header.IPVersion(ipHeader) {
case header.IPv4Version:
p = header.IPv4ProtocolNumber
case header.IPv6Version:
p = header.IPv6ProtocolNumber
default:
continue
}
e.dispatcher.DeliverNetworkPacket(p, pkt)
}
}
func (e *PosixEndpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *PosixEndpoint) Wait() {
}
func (e *PosixEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *PosixEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (e *PosixEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var n int
for _, packet := range pkts.AsSlice() {
_, err := rw.WriteV(e.fd, packet.Slices())
if err != nil {
return n, &tcpip.ErrAborted{}
}
n++
}
return n, nil
}

View File

@@ -1,12 +0,0 @@
package tun
import (
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)
type Handler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
E.Handler
}

View File

@@ -1,113 +0,0 @@
package tun
import (
"net"
"net/netip"
"github.com/vishvananda/netlink"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
)
func Open(name string) (uintptr, error) {
tunFd, err := tun.Open(name)
if err != nil {
return 0, err
}
return uintptr(tunFd), nil
}
func Configure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu uint32, autoRoute bool) error {
tunLink, err := netlink.LinkByName(name)
if err != nil {
return err
}
if inet4Address.IsValid() {
addr4, _ := netlink.ParseAddr(inet4Address.String())
err = netlink.AddrAdd(tunLink, addr4)
if err != nil {
return err
}
}
if inet6Address.IsValid() {
addr6, _ := netlink.ParseAddr(inet6Address.String())
err = netlink.AddrAdd(tunLink, addr6)
if err != nil {
return err
}
}
err = netlink.LinkSetMTU(tunLink, int(mtu))
if err != nil {
return err
}
err = netlink.LinkSetUp(tunLink)
if err != nil {
return err
}
if autoRoute {
if inet4Address.IsValid() {
err = netlink.RouteAdd(&netlink.Route{
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
})
if err != nil {
return err
}
}
if inet6Address.IsValid() {
err = netlink.RouteAdd(&netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
})
if err != nil {
return err
}
}
}
return nil
}
func UnConfigure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, autoRoute bool) error {
if autoRoute {
tunLink, err := netlink.LinkByName(name)
if err != nil {
return err
}
if inet4Address.IsValid() {
err = netlink.RouteDel(&netlink.Route{
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
})
if err != nil {
return err
}
}
if inet6Address.IsValid() {
err = netlink.RouteDel(&netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
})
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -1,20 +0,0 @@
//go:build !linux
package tun
import (
"net/netip"
"os"
)
func Open(name string) (uintptr, error) {
return 0, os.ErrInvalid
}
func Configure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu uint32, autoRoute bool) error {
return os.ErrInvalid
}
func UnConfigure(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, autoRoute bool) error {
return os.ErrInvalid
}