Files
sing-box-extended/protocol/masque/outbound.go

301 lines
9.6 KiB
Go

package masque
import (
"context"
"encoding/base64"
"encoding/json"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/cloudflare"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/masque"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// RegisterOutbound registers NewOutbound in registry as the constructor for
// the C.TypeMASQUE outbound type, decoding its options as
// option.MASQUEOutboundOptions.
func RegisterOutbound(registry *outbound.Registry) {
	outbound.Register[option.MASQUEOutboundOptions](
		registry,
		C.TypeMASQUE,
		NewOutbound,
	)
}
// Outbound is the MASQUE outbound adapter. Its tunnel is created
// asynchronously by startHandler (launched from Start); await is closed when
// that attempt finishes, and every user of tunnel waits on await first.
type Outbound struct {
	outbound.Adapter

	// Dependencies and configuration captured at construction time.
	ctx       context.Context
	dnsRouter adapter.DNSRouter
	logger    logger.ContextLogger
	options   option.MASQUEOutboundOptions

	// Asynchronous start-up state.
	tunnel       *masque.Tunnel // assigned by startHandler once created; read only after await is closed
	startHandler func()         // builds and starts tunnel; run in a goroutine by Start
	await        chan struct{}  // closed when startHandler returns, successful or not
}
// NewOutbound constructs a MASQUE outbound. No network work happens here: the
// Cloudflare profile, TLS client and tunnel are prepared by startHandler,
// which Start launches in a goroutine at the post-start stage. Dial and
// listen calls wait on await until that handler has returned.
//
// router is not used by this constructor; it is part of the constructor
// signature expected by the outbound registry.
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MASQUEOutboundOptions) (adapter.Outbound, error) {
	outbound := &Outbound{
		Adapter:   outbound.NewAdapterWithDialerOptions(C.TypeMASQUE, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions),
		ctx:       ctx,
		dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
		logger:    logger,
		options:   options,
		await:     make(chan struct{}),
	}
	outbound.startHandler = func() {
		// Closing await releases every waiter in isTunnelInitialized, whether
		// or not setup succeeded. Failures are logged, not returned.
		defer close(outbound.await)

		cacheFile := service.FromContext[adapter.CacheFile](ctx)
		var appConfig *Config
		var err error
		if !options.Profile.Recreate && cacheFile != nil && cacheFile.StoreMASQUEConfig() {
			savedProfile := cacheFile.LoadMASQUEConfig(tag)
			if savedProfile != nil {
				if err = json.Unmarshal(savedProfile.Content, &appConfig); err != nil {
					logger.ErrorContext(ctx, E.Cause(err, "decode cached MASQUE config"))
					return
				}
			}
		}
		// No cached config (or recreation requested): obtain a fresh one from
		// the API and persist it when the cache file stores MASQUE configs.
		if appConfig == nil {
			appConfig, err = outbound.createConfig()
			if err != nil {
				logger.ErrorContext(ctx, E.Cause(err, "create MASQUE config"))
				return
			}
			if cacheFile != nil && cacheFile.StoreMASQUEConfig() {
				content, err := json.Marshal(appConfig)
				if err != nil {
					logger.ErrorContext(ctx, E.Cause(err, "encode MASQUE config"))
					return
				}
				cacheFile.SaveMASQUEConfig(tag, &adapter.SavedBinary{
					LastUpdated: time.Now(),
					Content:     content,
					LastEtag:    "",
				})
			}
		}

		privKey, err := appConfig.GetEcPrivateKey()
		if err != nil {
			logger.ErrorContext(ctx, E.New("failed to get private key: ", err))
			return
		}
		peerPubKey, err := appConfig.GetEcEndpointPublicKey()
		if err != nil {
			logger.ErrorContext(ctx, E.New("failed to get public key: ", err))
			return
		}
		// Client certificate derived from the enrolled private key.
		cert, err := masque.GenerateCert(privKey, &privKey.PublicKey)
		if err != nil {
			logger.ErrorContext(ctx, E.New("failed to generate cert: ", err))
			return
		}
		tlsConfig, err := tls.NewMASQUEClient(ctx, logger, "consumer-masque.cloudflareclient.com", cert, privKey, peerPubKey, options.MASQUEOutboundTLSOptions)
		if err != nil {
			logger.ErrorContext(ctx, E.New("failed to prepare TLS config: ", err))
			return
		}
		endpoint, err := appConfig.SelectEndpointFromConfig(options.UseHTTP2, options.UseIPv6, 443)
		if err != nil {
			logger.ErrorContext(ctx, E.New("failed to select endpoint: ", err))
			return
		}

		// Configured UDP timings, falling back to defaults when unset (zero).
		var udpTimeout time.Duration
		if options.UDPTimeout != 0 {
			udpTimeout = time.Duration(options.UDPTimeout)
		} else {
			udpTimeout = C.UDPTimeout
		}
		var udpKeepalivePeriod time.Duration
		if options.UDPKeepalivePeriod != 0 {
			udpKeepalivePeriod = time.Duration(options.UDPKeepalivePeriod)
		} else {
			udpKeepalivePeriod = time.Second * 30
		}

		outboundDialer, err := dialer.NewWithOptions(dialer.Options{
			Context:          ctx,
			Options:          options.DialerOptions,
			RemoteIsDomain:   false,
			ResolverOnDetour: true,
		})
		if err != nil {
			logger.ErrorContext(ctx, E.Cause(err, "create MASQUE dialer"))
			return
		}

		// The interface addresses come from the API or the cache file. Parse
		// them with error handling: an unrecovered panic here would take down
		// the whole process, since this runs in its own goroutine.
		ipv4Prefix, err := netip.ParsePrefix(appConfig.IPv4 + "/32")
		if err != nil {
			logger.ErrorContext(ctx, E.Cause(err, "parse MASQUE interface IPv4 address"))
			return
		}
		ipv6Prefix, err := netip.ParsePrefix(appConfig.IPv6 + "/128")
		if err != nil {
			logger.ErrorContext(ctx, E.Cause(err, "parse MASQUE interface IPv6 address"))
			return
		}

		tunnel, err := masque.NewTunnel(
			ctx,
			logger,
			masque.TunnelOptions{
				Dialer:               outboundDialer,
				Address:              []netip.Prefix{ipv4Prefix, ipv6Prefix},
				Endpoint:             endpoint,
				TLSConfig:            tlsConfig,
				UseHTTP2:             options.UseHTTP2,
				UDPTimeout:           udpTimeout,
				UDPKeepalivePeriod:   udpKeepalivePeriod,
				UDPInitialPacketSize: options.UDPInitialPacketSize,
				ReconnectDelay:       options.ReconnectDelay.Build(),
			})
		if err != nil {
			logger.ErrorContext(ctx, E.Cause(err, "create MASQUE tunnel"))
			return
		}
		// Published before starting so that Close can release the tunnel even
		// if one of the start calls below fails.
		outbound.tunnel = tunnel
		// NOTE(review): Start is called twice (false, then true), presumably
		// two start phases — confirm against masque.Tunnel.Start.
		if err = outbound.tunnel.Start(false); err != nil {
			logger.ErrorContext(ctx, E.Cause(err, "start MASQUE tunnel"))
			return
		}
		if err = outbound.tunnel.Start(true); err != nil {
			logger.ErrorContext(ctx, E.Cause(err, "start MASQUE tunnel"))
			return
		}
	}
	return outbound, nil
}
// Start does nothing for every stage except adapter.StartStatePostStart. At
// that stage it launches startHandler in a new goroutine and returns
// immediately. Tunnel setup errors are logged by the handler, not returned here.
func (w *Outbound) Start(stage adapter.StartStage) error {
	if stage == adapter.StartStatePostStart {
		go w.startHandler()
	}
	return nil
}
// Close shuts down the MASQUE tunnel.
//
// It first waits for the start handler to finish (await closed), so that a
// tunnel still being created is not missed. It returns early with w.ctx.Err()
// if w.ctx is done first. If startup never produced a tunnel, there is nothing
// to release and Close returns nil. That failure was already logged by the
// start handler and is not a shutdown error.
//
// NOTE(review): await is only closed by startHandler, which only runs once
// Start reaches the post-start stage; if startup was aborted earlier, this
// blocks until w.ctx is done — confirm callers cancel ctx in that case.
func (w *Outbound) Close() error {
	select {
	case <-w.await:
	case <-w.ctx.Done():
		return w.ctx.Err()
	}
	if w.tunnel == nil {
		return nil
	}
	return w.tunnel.Close()
}
// DialContext opens a connection to destination through the MASQUE tunnel.
//
// It waits for tunnel start-up first. Domain destinations are resolved with
// the DNS router, and the resulting addresses are handed to N.DialSerial over
// the tunnel. IP destinations are dialed directly. A destination that is
// neither a domain nor a valid address is rejected.
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if err := w.isTunnelInitialized(ctx); err != nil {
		return nil, err
	}

	if network == N.NetworkTCP {
		w.logger.InfoContext(ctx, "outbound connection to ", destination)
	} else if network == N.NetworkUDP {
		w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
	}

	switch {
	case destination.IsDomain():
		resolved, lookupErr := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
		if lookupErr != nil {
			return nil, lookupErr
		}
		return N.DialSerial(ctx, w.tunnel, network, destination, resolved)
	case !destination.Addr.IsValid():
		return nil, E.New("invalid destination: ", destination)
	default:
		return w.tunnel.DialContext(ctx, network, destination)
	}
}
// ListenPacketWithDestination opens a packet connection through the tunnel.
// It also returns the address the connection is bound towards, when known.
//
// Domain destinations are resolved with the DNS router and passed to
// N.ListenSerial. Otherwise the tunnel listens for destination directly. The
// returned address is destination.Addr for IP destinations, and the zero
// netip.Addr when the destination is not an IP.
func (w *Outbound) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) {
	if err := w.isTunnelInitialized(ctx); err != nil {
		return nil, netip.Addr{}, err
	}
	w.logger.InfoContext(ctx, "outbound packet connection to ", destination)

	if !destination.IsDomain() {
		conn, listenErr := w.tunnel.ListenPacket(ctx, destination)
		if listenErr != nil {
			return nil, netip.Addr{}, listenErr
		}
		var boundAddr netip.Addr
		if destination.IsIP() {
			boundAddr = destination.Addr
		}
		return conn, boundAddr, nil
	}

	resolved, lookupErr := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
	if lookupErr != nil {
		return nil, netip.Addr{}, lookupErr
	}
	return N.ListenSerial(ctx, w.tunnel, destination, resolved)
}
// ListenPacket opens a packet connection to destination through the tunnel.
//
// When the connection is bound to an address whose socket address differs
// from destination (e.g. a domain resolved to an IP), the connection is
// wrapped in a NAT packet conn. The wrapper translates between the resolved
// address and the original destination.
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	conn, boundAddr, err := w.ListenPacketWithDestination(ctx, destination)
	if err != nil {
		return nil, err
	}
	if !boundAddr.IsValid() {
		return conn, nil
	}
	boundDestination := M.SocksaddrFrom(boundAddr, destination.Port)
	if boundDestination == destination {
		return conn, nil
	}
	return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), boundDestination, destination), nil
}
// isTunnelInitialized blocks until the start handler has finished (await is
// closed) or ctx is done. It returns ctx.Err() on cancellation, an error when
// no tunnel was created, and nil when w.tunnel is set.
func (w *Outbound) isTunnelInitialized(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-w.await:
	}
	if w.tunnel != nil {
		return nil
	}
	return E.New("tunnel not initialized")
}
// createConfig obtains a Cloudflare profile and enrolls a newly generated
// elliptic-curve public key for MASQUE. It returns the resulting client config.
//
// If both Profile.AuthToken and Profile.ID are configured, that existing
// profile is fetched. Otherwise a new profile is registered with the public
// half of a throwaway WireGuard key. When Profile.Detour is set, API traffic
// is dialed through that outbound.
//
// It returns an error if the detour outbound is missing, any API call fails,
// or the enrolled profile lists no peers.
func (w *Outbound) createConfig() (*Config, error) {
	opts := make([]cloudflare.CloudflareApiOption, 0, 1)
	if w.options.Profile.Detour != "" {
		detour, ok := service.FromContext[adapter.OutboundManager](w.ctx).Outbound(w.options.Profile.Detour)
		if !ok {
			return nil, E.New("outbound detour not found: ", w.options.Profile.Detour)
		}
		opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) {
			return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
		}))
	}
	api := cloudflare.NewCloudflareApi(opts...)

	var profile *cloudflare.CloudflareProfile
	if w.options.Profile.AuthToken != "" && w.options.Profile.ID != "" {
		var err error
		profile, err = api.GetProfile(w.ctx, w.options.Profile.AuthToken, w.options.Profile.ID)
		if err != nil {
			return nil, err
		}
		// NOTE(review): fall back to the configured token in case the fetched
		// profile does not echo it back; EnrollKey and AccessToken rely on it.
		if profile.Token == "" {
			profile.Token = w.options.Profile.AuthToken
		}
	} else {
		wgPrivateKey, err := wgtypes.GeneratePrivateKey()
		if err != nil {
			return nil, err
		}
		profile, err = api.CreateProfile(w.ctx, wgPrivateKey.PublicKey().String())
		if err != nil {
			return nil, err
		}
	}

	privateKey, publicKey, err := masque.GenerateEcKeyPair()
	if err != nil {
		return nil, E.New("failed to generate key pair: ", err)
	}
	updatedProfile, err := api.EnrollKey(w.ctx, profile.Token, profile.ID, cloudflare.KeyTypeMasque, cloudflare.TunTypeMasque, base64.StdEncoding.EncodeToString(publicKey))
	if err != nil {
		return nil, err
	}
	if len(updatedProfile.Config.Peers) == 0 {
		return nil, E.New("enrolled MASQUE profile contains no peers")
	}
	peer := updatedProfile.Config.Peers[0]

	// Peer endpoints are reported in host:port form ("[host]:port" for IPv6).
	// Only the host is stored; the port is supplied when an endpoint is
	// selected. Values without a port are kept as-is instead of being
	// truncated.
	hostOf := func(endpoint string) string {
		if host, _, splitErr := net.SplitHostPort(endpoint); splitErr == nil {
			return host
		}
		return endpoint
	}

	return &Config{
		PrivateKey:     base64.StdEncoding.EncodeToString(privateKey),
		EndpointV4:     hostOf(peer.Endpoint.V4),
		EndpointV6:     hostOf(peer.Endpoint.V6),
		EndpointH2V4:   cloudflare.DefaultEndpointH2V4,
		EndpointH2V6:   cloudflare.DefaultEndpointH2V6,
		EndpointPubKey: peer.PublicKey,
		License:        updatedProfile.Account.License,
		ID:             updatedProfile.ID,
		AccessToken:    profile.Token,
		IPv4:           updatedProfile.Config.Interface.Addresses.V4,
		IPv6:           updatedProfile.Config.Interface.Addresses.V6,
	}, nil
}