diff --git a/adapter/experimental.go b/adapter/experimental.go index 648eb418..fc00b78a 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -41,6 +41,8 @@ type CacheFile interface { StoreGroupExpand(group string, expand bool) error LoadRuleSet(tag string) *SavedBinary SaveRuleSet(tag string, set *SavedBinary) error + LoadCloudflareProfile(tag string) *SavedBinary + SaveCloudflareProfile(tag string, set *SavedBinary) error } type SavedBinary struct { diff --git a/common/cloudflare/api.go b/common/cloudflare/api.go new file mode 100644 index 00000000..136ed192 --- /dev/null +++ b/common/cloudflare/api.go @@ -0,0 +1,58 @@ +package cloudflare + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/tidwall/gjson" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type CloudeflareApi struct { + client http.Client +} + +func NewCloudeflareApi(opts ...CloudflareApiOption) *CloudeflareApi { + api := &CloudeflareApi{http.Client{Timeout: 30 * time.Second}} + for _, opt := range opts { + opt(api) + } + return api +} + +func (api *CloudeflareApi) CreateProfile(ctx context.Context) (*CloudflareProfile, error) { + privateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + request, err := http.NewRequest("POST", "https://api.cloudflareclient.com/v0i1909051800/reg", strings.NewReader( + fmt.Sprintf( + "{\"install_id\":\"\",\"tos\":\"%s\",\"key\":\"%s\",\"fcm_token\":\"\",\"type\":\"ios\",\"locale\":\"en_US\"}", + time.Now().Format("2006-01-02T15:04:05.000Z"), + privateKey.PublicKey().String(), + ), + )) + if err != nil { + return nil, err + } + response, err := api.client.Do(request.WithContext(ctx)) + if err != nil { + return nil, err + } + defer response.Body.Close() + if response.StatusCode != 200 { + return nil, fmt.Errorf("status code is not 200") + } + content, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + profile := new(CloudflareProfile) + profile.Config.PrivateKey = privateKey.String() + return profile, json.NewDecoder(strings.NewReader(gjson.Get(string(content), "result").Raw)).Decode(profile) +} diff --git a/common/cloudflare/option.go b/common/cloudflare/option.go new file mode 100644 index 00000000..00c0d6d4 --- /dev/null +++ b/common/cloudflare/option.go @@ -0,0 +1,17 @@ +package cloudflare + +import ( + "context" + "net" + "net/http" +) + +type CloudflareApiOption func(api *CloudeflareApi) + +func WithDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) CloudflareApiOption { + return func(api *CloudeflareApi) { + api.client.Transport = &http.Transport{ + DialContext: dialContext, + } + } +} diff --git a/common/cloudflare/profile.go b/common/cloudflare/profile.go new file mode 100644 index 00000000..f9cf76d2 --- /dev/null +++ b/common/cloudflare/profile.go @@ -0,0 +1,65 @@ +package cloudflare + +import "time" + +type CloudflareProfile struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + Key string `json:"key"` + Account struct { + ID string `json:"id"` + AccountType string `json:"account_type"` + Created time.Time `json:"created"` + Updated time.Time `json:"updated"` + PremiumData int `json:"premium_data"` + Quota int `json:"quota"` + Usage int `json:"usage"` + WarpPlus bool `json:"warp_plus"` + ReferralCount int `json:"referral_count"` + ReferralRenewalCountdown int `json:"referral_renewal_countdown"` + Role string `json:"role"` + License string `json:"license"` + TTL time.Time `json:"ttl"` + } `json:"account"` + Config struct { + ClientID string `json:"client_id"` + PrivateKey string `json:"private_key"` + Peers []struct { + PublicKey string `json:"public_key"` + Endpoint struct { + V4 string `json:"v4"` + V6 string `json:"v6"` + Host string `json:"host"` + Ports []int `json:"ports"` + } `json:"endpoint"` + } `json:"peers"` + Interface struct { + Addresses struct { + V4 string `json:"v4"` + V6 string `json:"v6"` + } `json:"addresses"` + } `json:"interface"` + Services struct { + HTTPProxy string `json:"http_proxy"` + } `json:"services"` + Metrics struct { + Ping int `json:"ping"` + Report int `json:"report"` + } `json:"metrics"` + } `json:"config"` + Token string `json:"token"` + WarpEnabled bool `json:"warp_enabled"` + WaitlistEnabled bool `json:"waitlist_enabled"` + Created time.Time `json:"created"` + Updated time.Time `json:"updated"` + Tos time.Time `json:"tos"` + Place int `json:"place"` + Locale string `json:"locale"` + Enabled bool `json:"enabled"` + InstallID string `json:"install_id"` + FcmToken string `json:"fcm_token"` + Policy struct { + TunnelProtocol string `json:"tunnel_protocol"` + } `json:"policy"` +} diff --git a/constant/proxy.go b/constant/proxy.go index 3197de60..753eb0bf 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -15,6 +15,7 @@ const ( TypeTrojan = "trojan" TypeNaive = "naive" TypeWireGuard = "wireguard" + TypeWARP = "warp" TypeHysteria = "hysteria" TypeTor = "tor" TypeSSH = "ssh" @@ -60,6 +61,8 @@ func ProxyDisplayName(proxyType string) string { return "Naive" case TypeWireGuard: return "WireGuard" + case TypeWARP: + return "WARP" case TypeHysteria: return "Hysteria" case TypeTor: diff --git a/experimental/cachefile/cache.go b/experimental/cachefile/cache.go index 88cffdbe..7ac324e2 100644 --- a/experimental/cachefile/cache.go +++ b/experimental/cachefile/cache.go @@ -316,3 +316,36 @@ func (c *CacheFile) SaveRuleSet(tag string, set *adapter.SavedBinary) error { return bucket.Put([]byte(tag), setBinary) }) } + +func (c *CacheFile) LoadCloudflareProfile(tag string) *adapter.SavedBinary { + var savedProfile adapter.SavedBinary + err := c.DB.View(func(t *bbolt.Tx) error { + bucket := c.bucket(t, bucketRuleSet) + if bucket == nil { + return os.ErrNotExist + } + profileBinary := bucket.Get([]byte(tag)) + if len(profileBinary) == 0 { + return os.ErrInvalid + } + return savedProfile.UnmarshalBinary(profileBinary) + }) + if err != nil { + return nil + } + return &savedProfile +} + +func (c *CacheFile) SaveCloudflareProfile(tag string, set *adapter.SavedBinary) error { + return c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := c.createBucket(t, bucketRuleSet) + if err != nil { + return err + } + profileBinary, err := set.MarshalBinary() + if err != nil { + return err + } + return bucket.Put([]byte(tag), profileBinary) + }) +} diff --git a/go.mod b/go.mod index 722b27b5..4daa087a 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,11 @@ require ( howett.net/plist v1.0.1 ) +require ( + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect +) + //replace github.com/sagernet/sing => ../sing require ( @@ -91,6 +96,7 @@ require ( github.com/sagernet/nftables v0.3.0-beta.4 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/tevino/abool/v2 v2.1.0 // indirect + github.com/tidwall/gjson v1.18.0 github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/zeebo/blake3 v0.2.3 // indirect diff --git a/go.sum b/go.sum index 346e9a20..09c49c61 100644 --- a/go.sum +++ b/go.sum @@ -166,6 +166,12 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923 h1:tHNk7XK9GkmKUR6Gh8gVBKXc2MVSZ4G/NnWLtzw4gNA= github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= diff --git a/include/wireguard.go b/include/wireguard.go index f2ce9e23..ead47ced 100644 --- a/include/wireguard.go +++ b/include/wireguard.go @@ -14,4 +14,5 @@ func registerWireGuardOutbound(registry *outbound.Registry) { func registerWireGuardEndpoint(registry *endpoint.Registry) { wireguard.RegisterEndpoint(registry) + wireguard.RegisterWarpEndpoint(registry) } diff --git a/option/wireguard.go b/option/wireguard.go index 5d9b1d86..f819ca8f 100644 --- a/option/wireguard.go +++ b/option/wireguard.go @@ -30,6 +30,22 @@ type WireGuardPeer struct { Reserved []uint8 `json:"reserved,omitempty"` } +type WireGuardWarpEndpointOptions struct { + System bool `json:"system,omitempty"` + Name string `json:"name,omitempty"` + ListenPort uint16 `json:"listen_port,omitempty"` + UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` + Workers int `json:"workers,omitempty"` + Amnezia *WireGuardAmnezia `json:"amnezia,omitempty"` + Profile *WireGuardWarpProfile `json:"profile,omitempty"` + DialerOptions +} + +type WireGuardWarpProfile struct { + Detour string `json:"detour,omitempty"` + Recreate bool `json:"recreate,omitempty"` +} + type LegacyWireGuardOutboundOptions struct { DialerOptions SystemInterface bool `json:"system_interface,omitempty"` diff --git a/protocol/wireguard/endpoint_warp.go b/protocol/wireguard/endpoint_warp.go new file mode 100644 index 00000000..8b4d0281 --- /dev/null +++ b/protocol/wireguard/endpoint_warp.go @@ -0,0 +1,173 @@ +package wireguard + +import ( + "context" + "encoding/json" + "math/rand" + "net" + "net/netip" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/endpoint" + "github.com/sagernet/sing-box/common/cloudflare" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json/badoption" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" +) + +func RegisterWarpEndpoint(registry *endpoint.Registry) { + endpoint.Register[option.WireGuardWarpEndpointOptions](registry, C.TypeWARP, NewWarpEndpoint) +} + +type WarpEndpoint struct { + endpoint.Adapter + endpoint adapter.Endpoint + ctx context.Context + router adapter.Router + logger log.ContextLogger + tag string + options option.WireGuardWarpEndpointOptions +} + +func NewWarpEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardWarpEndpointOptions) (adapter.Endpoint, error) { + var dependencies []string + if options.Detour != "" { + dependencies = append(dependencies, options.Detour) + } + if options.Profile != nil { + if options.Profile.Detour != "" { + dependencies = append(dependencies, options.Profile.Detour) + } + } + return &WarpEndpoint{ + Adapter: endpoint.NewAdapter(C.TypeWARP, tag, []string{N.NetworkTCP, N.NetworkUDP}, dependencies), + ctx: ctx, + router: router, + logger: logger, + tag: tag, + options: options, + }, nil +} + +func (w *WarpEndpoint) Start(stage adapter.StartStage) error { + if stage != adapter.StartStatePostStart { + return nil + } + cacheFile := service.FromContext[adapter.CacheFile](w.ctx) + var profile *cloudflare.CloudflareProfile + var err error + if !w.options.Profile.Recreate { + if cacheFile != nil { + savedProfile := cacheFile.LoadCloudflareProfile(w.tag) + if savedProfile != nil { + err := json.Unmarshal(savedProfile.Content, &profile) + if err != nil { + return err + } + } + } + } + if profile == nil { + opts := make([]cloudflare.CloudflareApiOption, 0, 1) + if w.options.Profile != nil { + if w.options.Profile.Detour != "" { + detour, ok := service.FromContext[adapter.OutboundManager](w.ctx).Outbound(w.options.Profile.Detour) + if !ok { + return E.New("outbound detour not found: ", w.options.Profile.Detour) + } + opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) { + return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) + })) + } + } + api := cloudflare.NewCloudeflareApi(opts...) + profile, err = api.CreateProfile(w.ctx) + if err != nil { + return err + } + if cacheFile != nil { + content, err := json.Marshal(profile) + if err != nil { + return err + } + cacheFile.SaveCloudflareProfile(w.tag, &adapter.SavedBinary{ + LastUpdated: time.Now(), + Content: content, + LastEtag: "", + }) + } + } + peer := profile.Config.Peers[0] + hostParts := strings.Split(peer.Endpoint.Host, ":") + w.endpoint, err = NewEndpoint( + w.ctx, + w.router, + w.logger, + w.tag, + option.WireGuardEndpointOptions{ + System: w.options.System, + Name: w.options.Name, + ListenPort: w.options.ListenPort, + UDPTimeout: w.options.UDPTimeout, + Workers: w.options.Workers, + Amnezia: w.options.Amnezia, + DialerOptions: w.options.DialerOptions, + + Address: badoption.Listable[netip.Prefix]{ + netip.MustParsePrefix(profile.Config.Interface.Addresses.V4 + "/32"), + netip.MustParsePrefix(profile.Config.Interface.Addresses.V6 + "/128"), + }, + PrivateKey: profile.Config.PrivateKey, + Peers: []option.WireGuardPeer{ + { + Address: hostParts[0], + Port: uint16(peer.Endpoint.Ports[rand.Intn(len(peer.Endpoint.Ports))]), + PublicKey: peer.PublicKey, + AllowedIPs: badoption.Listable[netip.Prefix]{ + netip.MustParsePrefix("0.0.0.0/0"), + netip.MustParsePrefix("::/0"), + }, + }, + }, + MTU: 1280, + }, + ) + if err != nil { + return err + } + if err := w.endpoint.Start(adapter.StartStateStart); err != nil { + return err + } + if err := w.endpoint.Start(adapter.StartStatePostStart); err != nil { + return err + } + return nil +} + +func (w *WarpEndpoint) Close() error { + if w.endpoint == nil { + return E.New("endpoint not initialized") + } + return w.endpoint.Close() +} + +func (w *WarpEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if w.endpoint == nil { + return nil, E.New("endpoint not initialized") + } + return w.endpoint.DialContext(ctx, network, destination) +} + +func (w *WarpEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if w.endpoint == nil { + return nil, E.New("endpoint not initialized") + } + return w.endpoint.ListenPacket(ctx, destination) +}