refactor: WireGuard endpoint

This commit is contained in:
世界
2024-11-21 18:10:41 +08:00
parent 987556fd3d
commit cc8ba050dd
91 changed files with 2193 additions and 682 deletions

View File

@@ -5,7 +5,6 @@ package wireguard
import (
"context"
"net"
"net/netip"
"os"
"github.com/sagernet/gvisor/pkg/buffer"
@@ -15,52 +14,41 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
)
var _ Device = (*StackDevice)(nil)
var _ Device = (*stackDevice)(nil)
const defaultNIC tcpip.NICID = 1
type StackDevice struct {
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
packetOutbound chan *buf.Buffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
type stackDevice struct {
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
}
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
HandleLocal: true,
})
tunDevice := &StackDevice{
stack: ipStack,
mtu: mtu,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}),
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
tunDevice := &stackDevice{
mtu: options.MTU,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
}
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
if err != nil {
return nil, E.New(err.String())
return nil, err
}
for _, prefix := range localAddresses {
for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
tunDevice.addr6 = addr
protoAddr.Protocol = ipv6.ProtocolNumber
}
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
if err != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
if gErr != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
}
}
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
cOpt := tcpip.CongestionControlOption("cubic")
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
tunDevice.stack = ipStack
if options.Handler != nil {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
}
return tunDevice, nil
}
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
return (*wireEndpoint)(w), nil
}
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
addr := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
Port: destination.Port,
Addr: tun.AddressFromAddr(destination.Addr),
}
bind := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
@@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
}
}
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
bind := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
@@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
return udpConn, nil
}
func (w *StackDevice) Inet4Address() netip.Addr {
return tun.AddrFromAddress(w.addr4)
func (w *stackDevice) SetDevice(device *device.Device) {
}
func (w *StackDevice) Inet6Address() netip.Addr {
return tun.AddrFromAddress(w.addr6)
}
func (w *StackDevice) Start() error {
func (w *stackDevice) Start() error {
w.events <- wgTun.EventUp
return nil
}
func (w *StackDevice) File() *os.File {
func (w *stackDevice) File() *os.File {
return nil
}
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select {
case packetBuffer, ok := <-w.outbound:
if !ok {
@@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e
sizes[0] = n
count = 1
return
case packet := <-w.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
count = 1
return
case <-w.done:
return 0, os.ErrClosed
}
}
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
for _, b := range bufs {
b = b[offset:]
if len(b) == 0 {
@@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return
}
func (w *StackDevice) Flush() error {
func (w *stackDevice) Flush() error {
return nil
}
func (w *StackDevice) MTU() (int, error) {
func (w *stackDevice) MTU() (int, error) {
return int(w.mtu), nil
}
func (w *StackDevice) Name() (string, error) {
func (w *stackDevice) Name() (string, error) {
return "sing-box", nil
}
func (w *StackDevice) Events() <-chan wgTun.Event {
func (w *stackDevice) Events() <-chan wgTun.Event {
return w.events
}
func (w *StackDevice) Close() error {
func (w *stackDevice) Close() error {
close(w.done)
close(w.events)
w.stack.Close()
@@ -240,13 +213,13 @@ func (w *StackDevice) Close() error {
return nil
}
func (w *StackDevice) BatchSize() int {
func (w *stackDevice) BatchSize() int {
return 1
}
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint StackDevice
type wireEndpoint stackDevice
func (ep *wireEndpoint) MTU() uint32 {
return ep.mtu