Add L3 routing support

This commit is contained in:
世界
2023-03-21 21:36:17 +08:00
parent 9b762e3c2e
commit afcc8b204f
61 changed files with 2043 additions and 965 deletions

View File

@@ -1,13 +1,23 @@
package wireguard
import (
"net/netip"
"github.com/sagernet/sing-tun"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/tun"
wgTun "github.com/sagernet/wireguard-go/tun"
)
type Device interface {
tun.Device
wgTun.Device
N.Dialer
Start() error
Inet4Address() netip.Addr
Inet6Address() netip.Addr
// NewEndpoint() (stack.LinkEndpoint, error)
}
type NatDevice interface {
Device
CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination
}

View File

@@ -0,0 +1,75 @@
package wireguard
import (
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
)
var _ Device = (*natDeviceWrapper)(nil)
type natDeviceWrapper struct {
Device
outbound chan *buf.Buffer
mapping *tun.NatMapping
writer *tun.NatWriter
}
func NewNATDevice(upstream Device, ipRewrite bool) NatDevice {
wrapper := &natDeviceWrapper{
Device: upstream,
outbound: make(chan *buf.Buffer, 256),
mapping: tun.NewNatMapping(ipRewrite),
}
if ipRewrite {
wrapper.writer = tun.NewNatWriter(upstream.Inet4Address(), upstream.Inet6Address())
}
return wrapper
}
func (d *natDeviceWrapper) Read(p []byte, offset int) (int, error) {
select {
case packet := <-d.outbound:
defer packet.Release()
return copy(p[offset:], packet.Bytes()), nil
default:
}
return d.Device.Read(p, offset)
}
func (d *natDeviceWrapper) Write(p []byte, offset int) (int, error) {
packet := p[offset:]
handled, err := d.mapping.WritePacket(packet)
if handled {
return len(packet), err
}
return d.Device.Write(p, offset)
}
func (d *natDeviceWrapper) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination {
d.mapping.CreateSession(session, conn)
return &natDestinationWrapper{d, session}
}
var _ tun.DirectDestination = (*natDestinationWrapper)(nil)
type natDestinationWrapper struct {
device *natDeviceWrapper
session tun.RouteSession
}
func (d *natDestinationWrapper) WritePacket(buffer *buf.Buffer) error {
if d.device.writer != nil {
d.device.writer.RewritePacket(buffer.Bytes())
}
d.device.outbound <- buffer
return nil
}
func (d *natDestinationWrapper) Close() error {
d.device.mapping.DeleteSession(d.session)
return nil
}
func (d *natDestinationWrapper) Timeout() bool {
return false
}

View File

@@ -0,0 +1,27 @@
//go:build with_gvisor
package wireguard
import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
func (d *natDestinationWrapper) WritePacketBuffer(buffer *stack.PacketBuffer) error {
defer buffer.DecRef()
if d.device.writer != nil {
d.device.writer.RewritePacketBuffer(buffer)
}
var packetLen int
for _, slice := range buffer.AsSlices() {
packetLen += len(slice)
}
packet := buf.NewSize(packetLen)
for _, slice := range buffer.AsSlices() {
common.Must1(packet.Write(slice))
}
d.device.outbound <- packet
return nil
}

View File

@@ -8,10 +8,12 @@ import (
"net/netip"
"os"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/tun"
wgTun "github.com/sagernet/wireguard-go/tun"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -25,33 +27,38 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
var _ Device = (*StackDevice)(nil)
var _ NatDevice = (*StackDevice)(nil)
const defaultNIC tcpip.NICID = 1
type StackDevice struct {
stack *stack.Stack
mtu uint32
events chan tun.Event
outbound chan *stack.PacketBuffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
packetOutbound chan *buf.Buffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
mapping *tun.NatMapping
writer *tun.NatWriter
}
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32, ipRewrite bool) (*StackDevice, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
HandleLocal: true,
})
tunDevice := &StackDevice{
stack: ipStack,
mtu: mtu,
events: make(chan tun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
stack: ipStack,
mtu: mtu,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}),
mapping: tun.NewNatMapping(ipRewrite),
}
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
if err != nil {
@@ -77,6 +84,9 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
}
}
if ipRewrite {
tunDevice.writer = tun.NewNatWriter(tunDevice.Inet4Address(), tunDevice.Inet6Address())
}
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
cOpt := tcpip.CongestionControlOption("cubic")
@@ -144,8 +154,16 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
return udpConn, nil
}
func (w *StackDevice) Inet4Address() netip.Addr {
return M.AddrFromIP(net.IP(w.addr4))
}
func (w *StackDevice) Inet6Address() netip.Addr {
return M.AddrFromIP(net.IP(w.addr6))
}
func (w *StackDevice) Start() error {
w.events <- tun.EventUp
w.events <- wgTun.EventUp
return nil
}
@@ -165,6 +183,10 @@ func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
n += copy(p[n:], slice)
}
return
case packet := <-w.packetOutbound:
defer packet.Release()
n = copy(p[offset:], packet.Bytes())
return
case <-w.done:
return 0, os.ErrClosed
}
@@ -175,6 +197,10 @@ func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
if len(p) == 0 {
return
}
handled, err := w.mapping.WritePacket(p)
if handled {
return len(p), err
}
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(p) {
case header.IPv4Version:
@@ -203,7 +229,7 @@ func (w *StackDevice) Name() (string, error) {
return "sing-box", nil
}
func (w *StackDevice) Events() chan tun.Event {
func (w *StackDevice) Events() chan wgTun.Event {
return w.events
}
@@ -222,6 +248,44 @@ func (w *StackDevice) Close() error {
return nil
}
func (w *StackDevice) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination {
w.mapping.CreateSession(session, conn)
return &stackNatDestination{
device: w,
session: session,
}
}
type stackNatDestination struct {
device *StackDevice
session tun.RouteSession
}
func (d *stackNatDestination) WritePacket(buffer *buf.Buffer) error {
if d.device.writer != nil {
d.device.writer.RewritePacket(buffer.Bytes())
}
d.device.packetOutbound <- buffer
return nil
}
func (d *stackNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error {
if d.device.writer != nil {
d.device.writer.RewritePacketBuffer(buffer)
}
d.device.outbound <- buffer
return nil
}
func (d *stackNatDestination) Close() error {
d.device.mapping.DeleteSession(d.session)
return nil
}
func (d *stackNatDestination) Timeout() bool {
return false
}
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint StackDevice

View File

@@ -8,6 +8,6 @@ import (
"github.com/sagernet/sing-tun"
)
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32, ipRewrite bool) (Device, error) {
return nil, tun.ErrGVisorNotIncluded
}

View File

@@ -23,6 +23,8 @@ type SystemDevice struct {
name string
mtu int
events chan wgTun.Event
addr4 netip.Addr
addr6 netip.Addr
}
/*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) {
@@ -55,11 +57,24 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes
if err != nil {
return nil, err
}
var inet4Address netip.Addr
var inet6Address netip.Addr
if len(inet4Addresses) > 0 {
inet4Address = inet4Addresses[0].Addr()
}
if len(inet6Addresses) > 0 {
inet6Address = inet6Addresses[0].Addr()
}
return &SystemDevice{
dialer.NewDefault(router, option.DialerOptions{
dialer: dialer.NewDefault(router, option.DialerOptions{
BindInterface: interfaceName,
}),
tunInterface, interfaceName, int(mtu), make(chan wgTun.Event),
device: tunInterface,
name: interfaceName,
mtu: int(mtu),
events: make(chan wgTun.Event),
addr4: inet4Address,
addr6: inet6Address,
}, nil
}
@@ -71,6 +86,14 @@ func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr
return w.dialer.ListenPacket(ctx, destination)
}
func (w *SystemDevice) Inet4Address() netip.Addr {
return w.addr4
}
func (w *SystemDevice) Inet6Address() netip.Addr {
return w.addr6
}
func (w *SystemDevice) Start() error {
w.events <- wgTun.EventUp
return nil
@@ -80,12 +103,12 @@ func (w *SystemDevice) File() *os.File {
return nil
}
func (w *SystemDevice) Read(bytes []byte, index int) (int, error) {
return w.device.Read(bytes[index-tun.PacketOffset:])
func (w *SystemDevice) Read(p []byte, offset int) (int, error) {
return w.device.Read(p[offset-tun.PacketOffset:])
}
func (w *SystemDevice) Write(bytes []byte, index int) (int, error) {
return w.device.Write(bytes[index:])
func (w *SystemDevice) Write(p []byte, offset int) (int, error) {
return w.device.Write(p[offset:])
}
func (w *SystemDevice) Flush() error {