Files
sing-box-extended/transport/wireguard/endpoint.go
2026-05-29 14:34:11 +03:00

343 lines
10 KiB
Go

package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
"os"
"reflect"
"strconv"
"strings"
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
"go4.org/netipx"
)
// Endpoint is a WireGuard endpoint built by NewEndpoint and brought up by
// Start. It owns the tunnel device and the wireguard-go device that drives it.
type Endpoint struct {
	// options is the configuration given to NewEndpoint (MTU defaulted to
	// 1408 when it was zero).
	options EndpointOptions
	// peers holds the per-peer settings derived from options.Peers; Start
	// fills in endpoint for peers whose address was given as a domain.
	peers []peerConfig
	// ipcConf holds the interface-level UAPI lines (private_key and, when
	// set, listen_port); peer and Amnezia lines are appended by Start.
	ipcConf string
	// allowedAddress is the union of every peer's allowed IPs, expressed as
	// the prefix list of a netipx.IPSet.
	allowedAddress []netip.Prefix
	// tunDevice is the device created by NewDevice.
	tunDevice Device
	// natDevice is tunDevice itself when it implements NatDevice, otherwise
	// a NewNATDevice wrapper around it.
	natDevice NatDevice
	// device is the wireguard-go device; set by Start, cleared by Close.
	device *device.Device
	// allowedIPs points at the wireguard-go device's unexported allowed-IPs
	// table, obtained via reflection in Start; nil until Start succeeds.
	allowedIPs *device.AllowedIPs
	// pause is the pause manager found in options.Context (may be nil).
	pause pause.Manager
	// pauseCallback is the callback registered with pause, removed by Close.
	pauseCallback *list.Element[pause.Callback]
}
// NewEndpoint validates options and builds an Endpoint without starting it.
//
// Keys in options are base64-encoded. They are decoded, checked to be exactly
// 32 bytes, and converted to the hex form used by the WireGuard UAPI
// configuration. Every peer must list at least one allowed IP. The union of
// all peers' allowed IPs is passed to the device as AllowedAddress. MTU
// defaults to 1408 when unset.
//
// Peers whose endpoint is a domain name are kept unresolved here; Start
// resolves them later.
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
	// WireGuard private, public and pre-shared keys are all 32 bytes. Checking
	// the length here makes a malformed key fail with a clear message instead
	// of a generic IpcSet error during Start.
	const keySize = 32
	decodeKey := func(encoded string) (string, error) {
		keyBytes, err := base64.StdEncoding.DecodeString(encoded)
		if err != nil {
			return "", err
		}
		if len(keyBytes) != keySize {
			return "", E.New("invalid key length ", len(keyBytes), ", expected ", keySize)
		}
		return hex.EncodeToString(keyBytes), nil
	}

	if options.PrivateKey == "" {
		return nil, E.New("missing private key")
	}
	privateKey, err := decodeKey(options.PrivateKey)
	if err != nil {
		return nil, E.Cause(err, "decode private key")
	}
	ipcConf := "private_key=" + privateKey
	if options.ListenPort != 0 {
		ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
	}

	peers := make([]peerConfig, 0, len(options.Peers))
	// Collect all allowed IPs in the same pass that converts the peers.
	var allowedPrefixBuilder netipx.IPSetBuilder
	for peerIndex, rawPeer := range options.Peers {
		peer := peerConfig{
			allowedIPs: rawPeer.AllowedIPs,
			keepalive:  rawPeer.PersistentKeepaliveInterval,
		}
		if rawPeer.Endpoint.Addr.IsValid() {
			peer.endpoint = rawPeer.Endpoint.AddrPort()
		} else if rawPeer.Endpoint.IsDomain() {
			// Resolved in Start once name resolution is available.
			peer.destination = rawPeer.Endpoint
		}
		publicKeyHex, err := decodeKey(rawPeer.PublicKey)
		if err != nil {
			return nil, E.Cause(err, "decode public key for peer ", peerIndex)
		}
		peer.publicKeyHex = publicKeyHex
		if rawPeer.PreSharedKey != "" {
			preSharedKeyHex, err := decodeKey(rawPeer.PreSharedKey)
			if err != nil {
				return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
			}
			peer.preSharedKeyHex = preSharedKeyHex
		}
		if len(rawPeer.AllowedIPs) == 0 {
			return nil, E.New("missing allowed ips for peer ", peerIndex)
		}
		for _, prefix := range rawPeer.AllowedIPs {
			allowedPrefixBuilder.AddPrefix(prefix)
		}
		peers = append(peers, peer)
	}
	allowedIPSet, err := allowedPrefixBuilder.IPSet()
	if err != nil {
		return nil, err
	}
	allowedAddresses := allowedIPSet.Prefixes()

	if options.MTU == 0 {
		options.MTU = 1408
	}
	deviceOptions := DeviceOptions{
		Context:        options.Context,
		Logger:         options.Logger,
		System:         options.System,
		Handler:        options.Handler,
		UDPTimeout:     options.UDPTimeout,
		CreateDialer:   options.CreateDialer,
		Name:           options.Name,
		MTU:            options.MTU,
		Address:        options.Address,
		AllowedAddress: allowedAddresses,
	}
	tunDevice, err := NewDevice(deviceOptions)
	if err != nil {
		return nil, E.Cause(err, "create WireGuard device")
	}
	// Use the device's own NAT implementation when it has one; otherwise wrap it.
	natDevice, isNatDevice := tunDevice.(NatDevice)
	if !isNatDevice {
		natDevice = NewNATDevice(options.Context, options.Logger, tunDevice)
	}
	return &Endpoint{
		options:        options,
		peers:          peers,
		ipcConf:        ipcConf,
		allowedAddress: allowedAddresses,
		tunDevice:      tunDevice,
		natDevice:      natDevice,
	}, nil
}
// Start brings the WireGuard device up and applies its configuration.
//
// The resolve flag selects a phase.
//   - If every peer has a literal IP endpoint, the device is started when
//     resolve is false, and a call with resolve=true is a no-op.
//   - If some peer endpoint is a domain name, a call with resolve=false is a
//     no-op. A call with resolve=true resolves those domains through
//     options.ResolvePeer and then starts the device.
//
// NOTE(review): the "called once per phase" contract is inferred from this
// branching; confirm against the caller.
//
// On success:
//   - the wireguard-go device is stored in e.device;
//   - a pause-manager callback is registered when the context provides one;
//   - e.allowedIPs points at the device's allowed-IPs table.
//
// If applying the UAPI configuration fails, the device is closed. The
// returned error includes the configuration with key material redacted.
func (e *Endpoint) Start(resolve bool) error {
	needsResolve := common.Any(e.peers, func(peer peerConfig) bool {
		return !peer.endpoint.IsValid() && peer.destination.IsDomain()
	})
	if needsResolve != resolve {
		// Not the phase in which this configuration starts the device.
		return nil
	}
	if needsResolve {
		for peerIndex, peer := range e.peers {
			if peer.endpoint.IsValid() || !peer.destination.IsDomain() {
				continue
			}
			destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
			if err != nil {
				return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
			}
			e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
		}
	}

	var bind conn.Bind
	if wgListener, isWgListener := common.Cast[dialer.WireGuardListener](e.options.Dialer); isWgListener {
		bind = conn.NewStdNetBind(wgListener.WireGuardControl())
	} else {
		// Exactly one peer with a known endpoint: create the client bind in
		// connect mode targeting that endpoint.
		var (
			isConnect   bool
			connectAddr netip.AddrPort
		)
		if len(e.peers) == 1 && e.peers[0].endpoint.IsValid() {
			isConnect = true
			connectAddr = e.peers[0].endpoint
		}
		bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr)
	}

	err := e.tunDevice.Start()
	if err != nil {
		return err
	}
	logger := &device.Logger{
		Verbosef: func(format string, args ...any) {
			e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
		},
		Errorf: func(format string, args ...any) {
			e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
		},
	}
	var deviceInput Device
	if e.natDevice != nil {
		deviceInput = e.natDevice
	} else {
		deviceInput = e.tunDevice
	}
	wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers, e.options.PreallocatedBuffersPerPool, e.options.DisablePauses)
	e.tunDevice.SetDevice(wgDevice)

	var ipcConf strings.Builder
	ipcConf.WriteString(e.ipcConf)
	ipcConf.WriteString(e.amneziaIpcLines())
	for _, peer := range e.peers {
		ipcConf.WriteString(peer.GenerateIpcLines())
	}
	err = wgDevice.IpcSet(ipcConf.String())
	if err != nil {
		wgDevice.Close()
		// The raw configuration carries private_key/preshared_key; never put
		// it into an error that may end up in logs.
		return E.Cause(err, "setup wireguard: \n", redactIpcConfig(ipcConf.String()))
	}

	e.device = wgDevice
	e.pause = service.FromContext[pause.Manager](e.options.Context)
	if e.pause != nil {
		e.pauseCallback = e.pause.RegisterCallback(e.onPauseUpdated)
	}
	// wireguard-go does not export its allowed-IPs table, so reach the
	// unexported "allowedips" field through reflection. This panics if the
	// field is ever renamed (FieldByName then yields an invalid Value).
	e.allowedIPs = (*device.AllowedIPs)(unsafe.Pointer(reflect.Indirect(reflect.ValueOf(wgDevice)).FieldByName("allowedips").UnsafeAddr()))
	return nil
}

// amneziaIpcLines renders options.Amnezia as UAPI lines, each prefixed with a
// newline, in the fixed key order jc, jmin, jmax, s1-s4, h1-h4, i1-i5, j1-j3,
// itime. Rules for which values are written:
//   - integer values only when positive;
//   - h1-h4 only when non-nil;
//   - string values only when non-empty.
//
// It returns "" when options.Amnezia is nil.
func (e *Endpoint) amneziaIpcLines() string {
	amnezia := e.options.Amnezia
	if amnezia == nil {
		return ""
	}
	var lines strings.Builder
	writeInt := func(prefix string, value int) {
		if value > 0 {
			lines.WriteString(prefix + strconv.Itoa(value))
		}
	}
	writeString := func(prefix, value string) {
		if value != "" {
			lines.WriteString(prefix + value)
		}
	}
	writeInt("\njc=", amnezia.JC)
	writeInt("\njmin=", amnezia.JMin)
	writeInt("\njmax=", amnezia.JMax)
	writeInt("\ns1=", amnezia.S1)
	writeInt("\ns2=", amnezia.S2)
	writeInt("\ns3=", amnezia.S3)
	writeInt("\ns4=", amnezia.S4)
	// nil means "unset" for the header values.
	if amnezia.H1 != nil {
		lines.WriteString("\nh1=" + amnezia.H1.String())
	}
	if amnezia.H2 != nil {
		lines.WriteString("\nh2=" + amnezia.H2.String())
	}
	if amnezia.H3 != nil {
		lines.WriteString("\nh3=" + amnezia.H3.String())
	}
	if amnezia.H4 != nil {
		lines.WriteString("\nh4=" + amnezia.H4.String())
	}
	writeString("\ni1=", amnezia.I1)
	writeString("\ni2=", amnezia.I2)
	writeString("\ni3=", amnezia.I3)
	writeString("\ni4=", amnezia.I4)
	writeString("\ni5=", amnezia.I5)
	writeString("\nj1=", amnezia.J1)
	writeString("\nj2=", amnezia.J2)
	writeString("\nj3=", amnezia.J3)
	if amnezia.ITime > 0 {
		lines.WriteString("\nitime=" + strconv.FormatInt(amnezia.ITime, 10))
	}
	return lines.String()
}

// redactIpcConfig returns a copy of a newline-separated UAPI configuration in
// which the values of private_key and preshared_key lines are replaced with
// "<redacted>". All other lines are kept verbatim, so the result stays useful
// for diagnosing configuration errors.
func redactIpcConfig(config string) string {
	lines := strings.Split(config, "\n")
	for i, line := range lines {
		key, _, found := strings.Cut(line, "=")
		if found && (key == "private_key" || key == "preshared_key") {
			lines[i] = key + "=<redacted>"
		}
	}
	return strings.Join(lines, "\n")
}
// DialContext opens a connection of the given network to destination through
// the tunnel device. Only IP destinations are accepted. A destination without
// a valid IP address is rejected with an error wrapping os.ErrInvalid.
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if destination.Addr.IsValid() {
		return e.tunDevice.DialContext(ctx, network, destination)
	}
	return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
// ListenPacket opens a packet connection toward destination through the tunnel
// device. Like DialContext, it accepts only IP destinations and returns an
// error wrapping os.ErrInvalid otherwise.
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	if destination.Addr.IsValid() {
		return e.tunDevice.ListenPacket(ctx, destination)
	}
	return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
// Close detaches the pause-manager callback registered by Start, if any. It
// then takes the wireguard-go device down and closes it. Both references are
// cleared, so a repeated call does nothing. Close always returns nil.
func (e *Endpoint) Close() error {
	if callback := e.pauseCallback; callback != nil {
		e.pause.UnregisterCallback(callback)
		e.pauseCallback = nil
	}
	if wgDevice := e.device; wgDevice != nil {
		wgDevice.Down()
		wgDevice.Close()
		e.device = nil
	}
	return nil
}
// Lookup asks the device's allowed-IPs table which peer routes address. It
// returns nil before Start has populated the table.
func (e *Endpoint) Lookup(address netip.Addr) *device.Peer {
	table := e.allowedIPs
	if table != nil {
		return table.Lookup(address.AsSlice())
	}
	return nil
}
// NewDirectRouteConnection delegates direct-route destination creation to the
// endpoint's NAT device. It returns os.ErrInvalid when no NAT device is set.
func (e *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
	if nat := e.natDevice; nat != nil {
		return nat.CreateDestination(metadata, routeContext, timeout)
	}
	return nil, os.ErrInvalid
}
// onPauseUpdated is the pause-manager callback registered by Start. Device
// and network pause events take the wireguard-go device down, and the
// matching wake events bring it back up. Any other event is ignored.
func (e *Endpoint) onPauseUpdated(event int) {
	switch {
	case event == pause.EventDevicePaused || event == pause.EventNetworkPause:
		e.device.Down()
	case event == pause.EventDeviceWake || event == pause.EventNetworkWake:
		e.device.Up()
	}
}
// peerConfig is the per-peer configuration derived from the endpoint options
// and rendered into UAPI lines by GenerateIpcLines.
type peerConfig struct {
	// destination is the configured endpoint when it was given as a domain
	// name; Start resolves it into endpoint.
	destination M.Socksaddr
	// endpoint is the peer's IP endpoint, either configured directly or
	// resolved from destination; omitted from UAPI output while invalid.
	endpoint netip.AddrPort
	// publicKeyHex is the peer public key, hex-encoded for UAPI.
	publicKeyHex string
	// preSharedKeyHex is the hex-encoded pre-shared key; empty when unset.
	preSharedKeyHex string
	// allowedIPs lists the prefixes routed to this peer.
	allowedIPs []netip.Prefix
	// keepalive is the UAPI persistent_keepalive_interval; 0 omits it.
	keepalive uint16
}
// GenerateIpcLines renders this peer as WireGuard UAPI lines. Every line
// begins with a newline, so the result can be appended directly to the
// interface section. public_key is always written. endpoint, preshared_key and
// persistent_keepalive_interval are written only when set. One allowed_ip line
// is emitted per allowed prefix.
func (c peerConfig) GenerateIpcLines() string {
	var out strings.Builder
	emit := func(prefix, value string) {
		out.WriteString(prefix)
		out.WriteString(value)
	}
	emit("\npublic_key=", c.publicKeyHex)
	if c.endpoint.IsValid() {
		emit("\nendpoint=", c.endpoint.String())
	}
	if len(c.preSharedKeyHex) != 0 {
		emit("\npreshared_key=", c.preSharedKeyHex)
	}
	for i := range c.allowedIPs {
		emit("\nallowed_ip=", c.allowedIPs[i].String())
	}
	if c.keepalive != 0 {
		emit("\npersistent_keepalive_interval=", F.ToString(c.keepalive))
	}
	return out.String()
}