Files
sing-box-extended/protocol/wireguard/endpoint_warp.go

211 lines
6.3 KiB
Go

package wireguard
import (
"context"
"encoding/json"
"math/rand"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/common/cloudflare"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json/badoption"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// RegisterWARPEndpoint adds the WARP endpoint type (C.TypeWARP) to registry,
// binding it to the NewWARPEndpoint constructor with
// option.WireGuardWARPEndpointOptions as its options type.
func RegisterWARPEndpoint(registry *endpoint.Registry) {
	endpoint.Register[option.WireGuardWARPEndpointOptions](
		registry,
		C.TypeWARP,
		NewWARPEndpoint,
	)
}
// WARPEndpoint is a Cloudflare WARP endpoint that wraps a WireGuard endpoint
// created asynchronously after start-up.
type WARPEndpoint struct {
	endpoint.Adapter

	// startHandler builds and starts the underlying endpoint; Start runs it
	// in its own goroutine at adapter.StartStatePostStart.
	startHandler func()

	// mtx is locked by NewWARPEndpoint and unlocked when startHandler
	// returns, so acquiring it also means initialization has finished.
	mtx sync.Mutex
	// endpoint is the underlying WireGuard endpoint assigned by
	// startHandler; it is nil before that.
	endpoint adapter.Endpoint
}
// NewWARPEndpoint creates a Cloudflare WARP endpoint named tag.
//
// The underlying WireGuard endpoint is not built here. Instead, startHandler
// (run in a goroutine by Start at adapter.StartStatePostStart) obtains the
// WARP configuration - from the cache file or from the Cloudflare API,
// optionally through the options.Profile.Detour outbound - and then builds
// and starts the WireGuard endpoint from it.
//
// mtx is locked here and unlocked only when startHandler returns, so callers
// that take mtx (DialContext and ListenPacket via isEndpointInitialized) wait
// until initialization has either succeeded or failed. Failures are logged,
// and the WireGuard endpoint is published only after it started successfully.
//
// options.Detour and options.Profile.Detour are declared as adapter
// dependencies. The returned error is always nil.
func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardWARPEndpointOptions) (adapter.Endpoint, error) {
	var dependencies []string
	if options.Detour != "" {
		dependencies = append(dependencies, options.Detour)
	}
	if options.Profile.Detour != "" {
		dependencies = append(dependencies, options.Profile.Detour)
	}
	warpEndpoint := &WARPEndpoint{
		Adapter: endpoint.NewAdapter(C.TypeWARP, tag, []string{N.NetworkTCP, N.NetworkUDP}, dependencies),
	}
	// Held until startHandler finishes; see the doc comment above.
	warpEndpoint.mtx.Lock()
	warpEndpoint.startHandler = func() {
		defer warpEndpoint.mtx.Unlock()
		wireGuardEndpoint, err := newWARPWireGuardEndpoint(ctx, router, logger, tag, options)
		if err != nil {
			logger.ErrorContext(ctx, err)
			return
		}
		warpEndpoint.endpoint = wireGuardEndpoint
	}
	return warpEndpoint, nil
}

// newWARPWireGuardEndpoint resolves the WARP configuration, builds the
// underlying WireGuard endpoint from it, and runs its Start and PostStart
// stages. The endpoint is returned only if every step succeeded; if starting
// fails, the partially started endpoint is closed before returning the error.
func newWARPWireGuardEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardWARPEndpointOptions) (adapter.Endpoint, error) {
	config, err := loadOrCreateWARPConfig(ctx, tag, options)
	if err != nil {
		return nil, err
	}
	// A cached or API-provided config may be incomplete; indexing below would
	// otherwise panic inside the background goroutine and crash the process.
	if len(config.Peers) == 0 {
		return nil, E.New("WARP config contains no peers")
	}
	peer := config.Peers[0]
	if len(peer.Endpoint.Ports) == 0 {
		return nil, E.New("WARP peer has no endpoint ports")
	}
	address, err := parseWARPInterfaceAddresses(config.Interface.Addresses.V4, config.Interface.Addresses.V6)
	if err != nil {
		return nil, err
	}
	wireGuardEndpoint, err := NewEndpoint(
		ctx,
		router,
		logger,
		tag,
		option.WireGuardEndpointOptions{
			System:                     options.System,
			Name:                       options.Name,
			ListenPort:                 options.ListenPort,
			UDPTimeout:                 options.UDPTimeout,
			Workers:                    options.Workers,
			PreallocatedBuffersPerPool: options.PreallocatedBuffersPerPool,
			DisablePauses:              options.DisablePauses,
			Amnezia:                    options.Amnezia,
			DialerOptions:              options.DialerOptions,
			Address:                    address,
			PrivateKey:                 config.PrivateKey,
			Peers: []option.WireGuardPeer{
				{
					Address: warpPeerHost(peer.Endpoint.Host),
					// One of the advertised ports is picked at random.
					Port:      uint16(peer.Endpoint.Ports[rand.Intn(len(peer.Endpoint.Ports))]),
					PublicKey: peer.PublicKey,
					AllowedIPs: badoption.Listable[netip.Prefix]{
						netip.MustParsePrefix("0.0.0.0/0"),
						netip.MustParsePrefix("::/0"),
					},
					PersistentKeepaliveInterval: options.PersistentKeepaliveInterval,
					Reserved:                    options.Reserved,
				},
			},
			MTU: 1280,
		},
	)
	if err != nil {
		return nil, E.Cause(err, "create WARP WireGuard endpoint")
	}
	for _, stage := range []adapter.StartStage{adapter.StartStateStart, adapter.StartStatePostStart} {
		if err = wireGuardEndpoint.Start(stage); err != nil {
			// Do not leave a half-started endpoint behind: it is never
			// published, so nothing else would close it.
			if closeErr := common.Close(wireGuardEndpoint); closeErr != nil {
				logger.ErrorContext(ctx, closeErr)
			}
			return nil, E.Cause(err, "start WARP WireGuard endpoint")
		}
	}
	return wireGuardEndpoint, nil
}

// loadOrCreateWARPConfig returns the WARP configuration for tag.
//
// Unless options.Profile.Recreate is set, a configuration previously saved in
// the cache file (when WARP config storage is enabled) is reused. Otherwise the
// profile is fetched (when both AuthToken and ID are configured) or registered
// anew through the Cloudflare API, optionally dialing via the
// options.Profile.Detour outbound. A newly obtained configuration is saved to
// the cache file when WARP config storage is enabled.
func loadOrCreateWARPConfig(ctx context.Context, tag string, options option.WireGuardWARPEndpointOptions) (*C.WARPConfig, error) {
	cacheFile := service.FromContext[adapter.CacheFile](ctx)
	storeConfig := cacheFile != nil && cacheFile.StoreWARPConfig()
	if storeConfig && !options.Profile.Recreate {
		if savedProfile := cacheFile.LoadWARPConfig(tag); savedProfile != nil {
			var cached *C.WARPConfig
			if err := json.Unmarshal(savedProfile.Content, &cached); err != nil {
				return nil, E.Cause(err, "decode cached WARP config")
			}
			// Content that decodes to null falls through to creating a profile.
			if cached != nil {
				return cached, nil
			}
		}
	}

	var privateKey wgtypes.Key
	var err error
	if options.Profile.PrivateKey != "" {
		privateKey, err = wgtypes.ParseKey(options.Profile.PrivateKey)
	} else {
		// NOTE(review): when AuthToken and ID are set but PrivateKey is not,
		// this fresh key is paired with an existing profile fetched below;
		// confirm that combination is intended.
		privateKey, err = wgtypes.GeneratePrivateKey()
	}
	if err != nil {
		return nil, E.Cause(err, "prepare WARP private key")
	}

	apiOptions := make([]cloudflare.CloudflareApiOption, 0, 1)
	if options.Profile.Detour != "" {
		detour, loaded := service.FromContext[adapter.OutboundManager](ctx).Outbound(options.Profile.Detour)
		if !loaded {
			return nil, E.New("outbound detour not found: ", options.Profile.Detour)
		}
		apiOptions = append(apiOptions, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) {
			return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
		}))
	}
	api := cloudflare.NewCloudflareApi(apiOptions...)

	var profile *cloudflare.CloudflareProfile
	if options.Profile.AuthToken != "" && options.Profile.ID != "" {
		profile, err = api.GetProfile(ctx, options.Profile.AuthToken, options.Profile.ID)
	} else {
		profile, err = api.CreateProfile(ctx, privateKey.PublicKey().String())
	}
	if err != nil {
		return nil, E.Cause(err, "obtain WARP profile")
	}

	config := &C.WARPConfig{
		PrivateKey: privateKey.String(),
		Interface:  profile.Config.Interface,
		Peers:      profile.Config.Peers,
	}
	if storeConfig {
		content, marshalErr := json.Marshal(config)
		if marshalErr != nil {
			return nil, E.Cause(marshalErr, "encode WARP config")
		}
		cacheFile.SaveWARPConfig(tag, &adapter.SavedBinary{
			LastUpdated: time.Now(),
			Content:     content,
			LastEtag:    "",
		})
	}
	return config, nil
}

// parseWARPInterfaceAddresses converts the bare IPv4 and IPv6 interface
// addresses of a WARP profile into single-host prefixes (/32 and /128),
// returning an error instead of panicking on malformed input.
func parseWARPInterfaceAddresses(v4, v6 string) (badoption.Listable[netip.Prefix], error) {
	prefix4, err := netip.ParsePrefix(v4 + "/32")
	if err != nil {
		return nil, E.Cause(err, "parse WARP IPv4 interface address")
	}
	prefix6, err := netip.ParsePrefix(v6 + "/128")
	if err != nil {
		return nil, E.Cause(err, "parse WARP IPv6 interface address")
	}
	return badoption.Listable[netip.Prefix]{prefix4, prefix6}, nil
}

// warpPeerHost extracts the host part of a WARP peer endpoint such as
// "engage.cloudflareclient.com:2408" or "[2606:4700::1]:2408". Values without
// a port (a hostname, IPv4 address, or optionally bracketed IPv6 address) are
// returned as-is, minus any surrounding brackets.
func warpPeerHost(hostport string) string {
	if host, _, err := net.SplitHostPort(hostport); err == nil {
		return host
	}
	return strings.TrimSuffix(strings.TrimPrefix(hostport, "["), "]")
}
// Start launches the background initialization (startHandler) when the
// PostStart stage is reached; every other stage is a no-op. It always
// returns nil: initialization errors are logged by startHandler instead.
func (w *WARPEndpoint) Start(stage adapter.StartStage) error {
	if stage == adapter.StartStatePostStart {
		// Obtaining the profile can involve Cloudflare API calls (optionally
		// through another outbound), so it runs in its own goroutine rather
		// than inside Start.
		go w.startHandler()
	}
	return nil
}
// Close closes the underlying WireGuard endpoint via common.Close.
//
// NOTE(review): w.endpoint is read here without holding mtx while the
// startHandler goroutine may still be assigning it. Taking mtx is not an
// option as-is, because NewWARPEndpoint keeps it locked until startHandler
// returns, which never happens if PostStart is never reached.
func (w *WARPEndpoint) Close() error {
	underlying := w.endpoint
	return common.Close(underlying)
}
// DialContext dials destination through the underlying WireGuard endpoint.
// It waits for background initialization to finish and fails with
// "endpoint not initialized" when no underlying endpoint was created.
func (w *WARPEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if w.isEndpointInitialized() {
		return w.endpoint.DialContext(ctx, network, destination)
	}
	return nil, E.New("endpoint not initialized")
}
// ListenPacket opens a packet connection towards destination through the
// underlying WireGuard endpoint. It waits for background initialization to
// finish and fails with "endpoint not initialized" when no underlying
// endpoint was created.
func (w *WARPEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	if w.isEndpointInitialized() {
		return w.endpoint.ListenPacket(ctx, destination)
	}
	return nil, E.New("endpoint not initialized")
}
// isEndpointInitialized reports whether the underlying WireGuard endpoint has
// been assigned. It acquires mtx, which NewWARPEndpoint holds until
// startHandler returns. While initialization is in progress this call blocks;
// if startHandler never runs, it blocks indefinitely.
func (w *WARPEndpoint) isEndpointInitialized() bool {
	w.mtx.Lock()
	initialized := w.endpoint != nil
	w.mtx.Unlock()
	return initialized
}