Add MTProxy, MASQUE, VPN, Link parser. Update AmneziaWG. Remove Tunneling

This commit is contained in:
Sergei Maklagin
2026-04-29 22:11:30 +03:00
parent 09f9f114aa
commit 04908a6a67
158 changed files with 7994 additions and 2277 deletions

View File

@@ -1,6 +1,7 @@
package bond
import (
"bytes"
"encoding/binary"
"errors"
"io"
@@ -13,9 +14,7 @@ type bondedConn struct {
downloadRatios []uint8
uploadRatios []uint8
readBuffer []byte
readOffset int
readSize int
readBuffer *bytes.Buffer
}
func NewBondedConn(conns []net.Conn, downloadRatios, uploadRatios []uint8) *bondedConn {
@@ -23,12 +22,13 @@ func NewBondedConn(conns []net.Conn, downloadRatios, uploadRatios []uint8) *bond
conns: conns,
downloadRatios: downloadRatios,
uploadRatios: uploadRatios,
readBuffer: make([]byte, 65535),
readBuffer: bytes.NewBuffer(make([]byte, 0, 65536)),
}
}
func (c *bondedConn) Read(b []byte) (n int, err error) {
if c.readOffset == c.readSize {
if c.readBuffer.Len() == 0 {
c.readBuffer.Reset()
var header [2]byte
_, err := io.ReadFull(c.conns[0], header[:])
if err != nil {
@@ -41,19 +41,14 @@ func (c *bondedConn) Read(b []byte) (n int, err error) {
if chunkLen == 0 {
continue
}
chunk := c.readBuffer[total : total+chunkLen]
n, err := io.ReadFull(c.conns[i], chunk)
total += n
n, err := io.CopyN(c.readBuffer, c.conns[i], int64(chunkLen))
total += int(n)
if err != nil {
return total, err
}
}
c.readOffset = 0
c.readSize = size
}
n = copy(b, c.readBuffer[c.readOffset:c.readSize])
c.readOffset += n
return n, nil
return c.readBuffer.Read(b)
}
func (c *bondedConn) Write(b []byte) (n int, err error) {

View File

@@ -6,7 +6,8 @@ import (
"net"
"time"
"github.com/patrickmn/go-cache"
"github.com/gofrs/uuid/v5"
"github.com/patrickmn/go-cache/v2"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/kmutex"
@@ -29,9 +30,9 @@ type Inbound struct {
logger logger.ContextLogger
router adapter.ConnectionRouterEx
inbounds []adapter.Inbound
conns *cache.Cache
conns *cache.Cache[uuid.UUID, map[uint8]*ratioConn]
mtx *kmutex.Kmutex[string]
mtx *kmutex.Kmutex[uuid.UUID]
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BondInboundOptions) (adapter.Inbound, error) {
@@ -42,23 +43,23 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
Adapter: inbound.NewAdapter(C.TypeBond, tag),
logger: logger,
router: uot.NewRouter(router, logger),
conns: cache.New(C.TCPConnectTimeout, time.Second),
mtx: kmutex.New[string](),
conns: cache.New[uuid.UUID, map[uint8]*ratioConn](C.TCPConnectTimeout, time.Second),
mtx: kmutex.New[uuid.UUID](),
}
router = NewRouter(router, logger, inbound.connHandler)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
inbounds := make([]adapter.Inbound, len(options.Inbounds))
for i, inboundOptions := range options.Inbounds {
inbound, err := inboundRegistry.UnsafeCreate(ctx, NewRouter(router, logger, inbound.connHandler), logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options)
inbound, err := inboundRegistry.UnsafeCreate(ctx, router, logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options)
if err != nil {
return nil, err
}
inbounds[i] = inbound
}
inbound.inbounds = inbounds
inbound.conns.OnEvicted(func(s string, i interface{}) {
inbound.conns.OnEvicted(func(s uuid.UUID, ratioConns map[uint8]*ratioConn) {
inbound.mtx.Lock(s)
defer inbound.mtx.Unlock(s)
ratioConns := i.(map[uint8]*ratioConn)
for _, ratioConn := range ratioConns {
if ratioConn != nil {
ratioConn.conn.Close()
@@ -93,21 +94,15 @@ func (h *Inbound) Close() error {
}
func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
if metadata.Destination != Destination {
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return nil
}
request, err := ReadRequest(conn)
if err != nil {
return err
}
requestUUID := request.UUID.String()
requestUUID := request.UUID
h.mtx.Lock(requestUUID)
var ratioConns map[uint8]*ratioConn
rawRatioConns, ok := h.conns.Get(requestUUID)
if ok {
ratioConns = rawRatioConns.(map[uint8]*ratioConn)
} else {
ratioConns, ok := h.conns.Get(requestUUID)
if !ok {
ratioConns = make(map[uint8]*ratioConn, request.Count)
h.conns.SetDefault(requestUUID, ratioConns)
}

View File

@@ -21,14 +21,24 @@ func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func(
}
func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
if metadata.Destination != Destination {
return r.Router.RouteConnection(ctx, conn, metadata)
}
return r.handler(ctx, conn, metadata, func(error) {})
}
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
if metadata.Destination != Destination {
return r.Router.RoutePacketConnection(ctx, conn, metadata)
}
return os.ErrInvalid
}
func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if metadata.Destination != Destination {
r.Router.RouteConnectionEx(ctx, conn, metadata, onClose)
return
}
if err := r.handler(ctx, conn, metadata, onClose); err != nil {
r.logger.ErrorContext(ctx, err)
N.CloseOnHandshakeFailure(conn, onClose, err)
@@ -36,6 +46,10 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata
}
func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if metadata.Destination != Destination {
r.Router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
return
}
r.logger.ErrorContext(ctx, os.ErrInvalid)
N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid)
}

View File

@@ -17,15 +17,15 @@ import (
"github.com/sagernet/sing/service"
)
func RegisterFailover(registry *outbound.Registry) {
outbound.Register[option.FailoverOutboundOptions](registry, C.TypeFailover, NewFailover)
func RegisterFallback(registry *outbound.Registry) {
outbound.Register[option.FallbackOutboundOptions](registry, C.TypeFallback, NewFallback)
}
var (
_ adapter.OutboundGroup = (*Failover)(nil)
_ adapter.OutboundGroup = (*Fallback)(nil)
)
type Failover struct {
type Fallback struct {
outbound.Adapter
ctx context.Context
outbound adapter.OutboundManager
@@ -37,12 +37,12 @@ type Failover struct {
mtx sync.Mutex
}
func NewFailover(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FailoverOutboundOptions) (adapter.Outbound, error) {
func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FallbackOutboundOptions) (adapter.Outbound, error) {
if len(options.Outbounds) == 0 {
return nil, E.New("missing tags")
}
outbound := &Failover{
Adapter: outbound.NewAdapter(C.TypeFailover, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
outbound := &Fallback{
Adapter: outbound.NewAdapter(C.TypeFallback, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
ctx: ctx,
outbound: service.FromContext[adapter.OutboundManager](ctx),
logger: logger,
@@ -53,7 +53,7 @@ func NewFailover(ctx context.Context, router adapter.Router, logger log.ContextL
return outbound, nil
}
func (s *Failover) Start() error {
func (s *Fallback) Start() error {
for i, tag := range s.tags {
outbound, loaded := s.outbound.Outbound(tag)
if !loaded {
@@ -64,17 +64,17 @@ func (s *Failover) Start() error {
return nil
}
func (s *Failover) Now() string {
func (s *Fallback) Now() string {
s.mtx.Lock()
defer s.mtx.Unlock()
return s.lastUsedOutbound
}
func (s *Failover) All() []string {
func (s *Fallback) All() []string {
return s.tags
}
func (s *Failover) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
func (s *Fallback) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
var conn net.Conn
var err error
for _, outbound := range s.outbounds {
@@ -91,7 +91,7 @@ func (s *Failover) DialContext(ctx context.Context, network string, destination
return nil, err
}
func (s *Failover) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
func (s *Fallback) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
var conn net.PacketConn
var err error
for _, outbound := range s.outbounds {

View File

@@ -3,6 +3,7 @@ package group
import (
"context"
"net"
"regexp"
"time"
"github.com/sagernet/sing-box/adapter"
@@ -42,11 +43,20 @@ type Selector struct {
selected common.TypedValue[adapter.Outbound]
interruptGroup *interrupt.Group
interruptExternalConnections bool
provider adapter.ProviderManager
providers map[string]adapter.Provider
outboundsCache map[string][]adapter.Outbound
providerTags []string
exclude *regexp.Regexp
include *regexp.Regexp
useAllProviders bool
}
func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) {
outbound := &Selector{
Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds),
Adapter: outbound.NewAdapter(C.TypeSelector, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
ctx: ctx,
outbound: service.FromContext[adapter.OutboundManager](ctx),
connection: service.FromContext[adapter.ConnectionManager](ctx),
@@ -56,9 +66,15 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL
outbounds: make(map[string]adapter.Outbound),
interruptGroup: interrupt.NewGroup(),
interruptExternalConnections: options.InterruptExistConnections,
}
if len(outbound.tags) == 0 {
return nil, E.New("missing tags")
provider: service.FromContext[adapter.ProviderManager](ctx),
providers: make(map[string]adapter.Provider),
outboundsCache: make(map[string][]adapter.Outbound),
providerTags: options.Providers,
exclude: (*regexp.Regexp)(options.Exclude),
include: (*regexp.Regexp)(options.Include),
useAllProviders: options.UseAllProviders,
}
return outbound, nil
}
@@ -72,6 +88,28 @@ func (s *Selector) Network() []string {
}
func (s *Selector) Start() error {
if s.useAllProviders {
var providerTags []string
for _, provider := range s.provider.Providers() {
providerTags = append(providerTags, provider.Tag())
s.providers[provider.Tag()] = provider
provider.RegisterCallback(s.onProviderUpdated)
}
s.providerTags = providerTags
} else {
for i, tag := range s.providerTags {
provider, loaded := s.provider.Get(tag)
if !loaded {
return E.New("outbound provider ", i, " not found: ", tag)
}
s.providers[tag] = provider
provider.RegisterCallback(s.onProviderUpdated)
}
}
if len(s.tags)+len(s.providerTags) == 0 {
return E.New("missing outbound and provider tags")
}
for i, tag := range s.tags {
detour, loaded := s.outbound.Outbound(tag)
if !loaded {
@@ -79,31 +117,16 @@ func (s *Selector) Start() error {
}
s.outbounds[tag] = detour
}
if s.Tag() != "" {
cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
if cacheFile != nil {
selected := cacheFile.LoadSelected(s.Tag())
if selected != "" {
detour, loaded := s.outbounds[selected]
if loaded {
s.selected.Store(detour)
return nil
}
}
}
if len(s.tags) == 0 {
detour, _ := s.outbound.Outbound("Compatible")
s.tags = append(s.tags, detour.Tag())
s.outbounds[detour.Tag()] = detour
}
if s.defaultTag != "" {
detour, loaded := s.outbounds[s.defaultTag]
if !loaded {
return E.New("default outbound not found: ", s.defaultTag)
}
s.selected.Store(detour)
return nil
outbound, err := s.outboundSelect()
if err != nil {
return err
}
s.selected.Store(s.outbounds[s.tags[0]])
s.selected.Store(outbound)
return nil
}
@@ -145,7 +168,7 @@ func (s *Selector) DialContext(ctx context.Context, network string, destination
if err != nil {
return nil, err
}
return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil
}
func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
@@ -153,13 +176,13 @@ func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
if err != nil {
return nil, err
}
return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil
}
func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
ctx = interrupt.ContextWithIsExternalConnection(ctx)
selected := s.selected.Load()
conn = s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx))
conn = s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx))
if outboundHandler, isHandler := selected.(adapter.ConnectionHandlerEx); isHandler {
outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose)
} else {
@@ -170,7 +193,7 @@ func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata
func (s *Selector) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
ctx = interrupt.ContextWithIsExternalConnection(ctx)
selected := s.selected.Load()
conn = s.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx))
conn = s.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx))
if outboundHandler, isHandler := selected.(adapter.PacketConnectionHandlerEx); isHandler {
outboundHandler.NewPacketConnectionEx(ctx, conn, metadata, onClose)
} else {
@@ -192,3 +215,77 @@ func RealTag(detour adapter.Outbound) string {
}
return detour.Tag()
}
func (s *Selector) onProviderUpdated(tag string) error {
_, loaded := s.providers[tag]
if !loaded {
return E.New(s.Tag(), ": ", "outbound provider not found: ", tag)
}
var (
tags = s.Dependencies()
outboundByTag = make(map[string]adapter.Outbound)
)
for _, tag := range tags {
outboundByTag[tag] = s.outbounds[tag]
}
for _, providerTag := range s.providerTags {
if providerTag != tag && s.outboundsCache[providerTag] != nil {
for _, detour := range s.outboundsCache[providerTag] {
tags = append(tags, detour.Tag())
outboundByTag[detour.Tag()] = detour
}
continue
}
provider := s.providers[providerTag]
var cache []adapter.Outbound
for _, detour := range provider.Outbounds() {
tag := detour.Tag()
if s.exclude != nil && s.exclude.MatchString(tag) {
continue
}
if s.include != nil && !s.include.MatchString(tag) {
continue
}
tags = append(tags, tag)
cache = append(cache, detour)
outboundByTag[tag] = detour
}
s.outboundsCache[providerTag] = cache
}
if len(tags) == 0 {
detour, _ := s.outbound.Outbound("Compatible")
tags = append(tags, detour.Tag())
outboundByTag[detour.Tag()] = detour
}
s.tags, s.outbounds = tags, outboundByTag
detour, _ := s.outboundSelect()
if s.selected.Swap(detour) != detour {
s.interruptGroup.Interrupt(s.interruptExternalConnections)
}
return nil
}
func (s *Selector) outboundSelect() (adapter.Outbound, error) {
if s.Tag() != "" {
cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
if cacheFile != nil {
selected := cacheFile.LoadSelected(s.Tag())
if selected != "" {
detour, loaded := s.outbounds[selected]
if loaded {
return detour, nil
}
}
}
}
if s.defaultTag != "" {
detour, loaded := s.outbounds[s.defaultTag]
if !loaded {
return nil, E.New("default outbound not found: ", s.defaultTag)
}
return detour, nil
}
return s.outbounds[s.tags[0]], nil
}

View File

@@ -3,6 +3,7 @@ package group
import (
"context"
"net"
"regexp"
"sync"
"sync/atomic"
"time"
@@ -45,6 +46,16 @@ type URLTest struct {
idleTimeout time.Duration
group *URLTestGroup
interruptExternalConnections bool
provider adapter.ProviderManager
providers map[string]adapter.Provider
outboundsCache map[string][]adapter.Outbound
cancel context.CancelFunc
providerTags []string
exclude *regexp.Regexp
include *regexp.Regexp
useAllProviders bool
}
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) {
@@ -61,14 +72,42 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
tolerance: options.Tolerance,
idleTimeout: time.Duration(options.IdleTimeout),
interruptExternalConnections: options.InterruptExistConnections,
}
if len(outbound.tags) == 0 {
return nil, E.New("missing tags")
provider: service.FromContext[adapter.ProviderManager](ctx),
providers: make(map[string]adapter.Provider),
outboundsCache: make(map[string][]adapter.Outbound),
providerTags: options.Providers,
exclude: (*regexp.Regexp)(options.Exclude),
include: (*regexp.Regexp)(options.Include),
useAllProviders: options.UseAllProviders,
}
return outbound, nil
}
func (s *URLTest) Start() error {
if s.useAllProviders {
var providerTags []string
for _, provider := range s.provider.Providers() {
providerTags = append(providerTags, provider.Tag())
s.providers[provider.Tag()] = provider
provider.RegisterCallback(s.onProviderUpdated)
}
s.providerTags = providerTags
} else {
for i, tag := range s.providerTags {
provider, loaded := s.provider.Get(tag)
if !loaded {
return E.New("outbound provider ", i, " not found: ", tag)
}
s.providers[tag] = provider
provider.RegisterCallback(s.onProviderUpdated)
}
}
if len(s.tags)+len(s.providerTags) == 0 {
return E.New("missing outbound and provider tags")
}
outbounds := make([]adapter.Outbound, 0, len(s.tags))
for i, tag := range s.tags {
detour, loaded := s.outbound.Outbound(tag)
@@ -77,6 +116,11 @@ func (s *URLTest) Start() error {
}
outbounds = append(outbounds, detour)
}
if len(s.tags) == 0 {
detour, _ := s.outbound.Outbound("Compatible")
s.tags = append(s.tags, detour.Tag())
outbounds = append(outbounds, detour)
}
group, err := NewURLTestGroup(s.ctx, s.outbound, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections)
if err != nil {
return err
@@ -136,7 +180,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M
}
conn, err := outbound.DialContext(ctx, network, destination)
if err == nil {
return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil
}
s.logger.ErrorContext(ctx, err)
s.group.history.DeleteURLTestHistory(outbound.Tag())
@@ -154,7 +198,7 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
}
conn, err := outbound.ListenPacket(ctx, destination)
if err == nil {
return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil
}
s.logger.ErrorContext(ctx, err)
s.group.history.DeleteURLTestHistory(outbound.Tag())
@@ -163,13 +207,13 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
func (s *URLTest) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
ctx = interrupt.ContextWithIsExternalConnection(ctx)
conn = s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx))
conn = s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx))
s.connection.NewConnection(ctx, s, conn, metadata, onClose)
}
func (s *URLTest) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
ctx = interrupt.ContextWithIsExternalConnection(ctx)
conn = s.group.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx))
conn = s.group.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx))
s.connection.NewPacketConnection(ctx, s, conn, metadata, onClose)
}
@@ -188,6 +232,63 @@ func (s *URLTest) NewDirectRouteConnection(metadata adapter.InboundContext, rout
return selected.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext, timeout)
}
func (s *URLTest) onProviderUpdated(tag string) error {
_, loaded := s.providers[tag]
if !loaded {
return E.New("outbound provider not found: ", tag)
}
var (
tags = s.Dependencies()
outbounds []adapter.Outbound
)
for _, tag := range tags {
detour, _ := s.outbound.Outbound(tag)
outbounds = append(outbounds, detour)
}
for _, providerTag := range s.providerTags {
if providerTag != tag && s.outboundsCache[providerTag] != nil {
for _, detour := range s.outboundsCache[providerTag] {
tags = append(tags, detour.Tag())
outbounds = append(outbounds, detour)
}
continue
}
provider := s.providers[providerTag]
var cache []adapter.Outbound
for _, detour := range provider.Outbounds() {
tag := detour.Tag()
if s.exclude != nil && s.exclude.MatchString(tag) {
continue
}
if s.include != nil && !s.include.MatchString(tag) {
continue
}
tags = append(tags, tag)
cache = append(cache, detour)
}
outbounds = append(outbounds, cache...)
s.outboundsCache[providerTag] = cache
}
if len(tags) == 0 {
detour, _ := s.outbound.Outbound("Compatible")
tags = append(tags, detour.Tag())
outbounds = append(outbounds, detour)
}
s.tags, s.group.outbounds = tags, outbounds
s.group.access.Lock()
if s.group.ticker != nil {
s.group.ticker.Reset(s.group.interval)
}
s.group.access.Unlock()
ctx, cancel := context.WithCancel(s.ctx)
if s.cancel != nil {
s.cancel()
}
s.cancel = cancel
s.URLTest(ctx)
return nil
}
type URLTestGroup struct {
ctx context.Context
router adapter.Router
@@ -407,7 +508,11 @@ func (g *URLTestGroup) urlTest(ctx context.Context, force bool) (map[string]uint
})
}
b.Wait()
g.performUpdateCheck()
select {
case <-ctx.Done():
default:
g.performUpdateCheck()
}
return result, nil
}

View File

@@ -0,0 +1,119 @@
package bandwidth
import (
"context"
"net"
"golang.org/x/time/rate"
)
type connWithDownloadBandwidthLimiter struct {
net.Conn
ctx context.Context
limiter Limiter
}
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithDownloadBandwidthLimiter {
return &connWithDownloadBandwidthLimiter{conn, ctx, limiter}
}
func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error) {
err = conn.limiter.WaitN(conn.ctx, len(p))
if err != nil {
return
}
return conn.Conn.Write(p)
}
type connWithUploadBandwidthLimiter struct {
net.Conn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithUploadBandwidthLimiter {
return &connWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) {
n, err = conn.Conn.Read(p)
if err != nil {
return
}
err = conn.limiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return n, err
}
type connWithCloseHandler struct {
net.Conn
onClose CloseHandlerFunc
}
func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler {
return &connWithCloseHandler{conn, onClose}
}
func (conn *connWithCloseHandler) Close() error {
conn.onClose()
return conn.Conn.Close()
}
type packetConnWithDownloadBandwidthLimiter struct {
net.PacketConn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithDownloadBandwidthLimiter {
return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
err = conn.limiter.WaitN(conn.ctx, len(p))
if err != nil {
return
}
return conn.PacketConn.WriteTo(p, addr)
}
type packetConnWithUploadBandwidthLimiter struct {
net.PacketConn
ctx context.Context
limiter Limiter
burst int
}
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithUploadBandwidthLimiter {
return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = conn.PacketConn.ReadFrom(p)
if err != nil {
return
}
err = conn.limiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return
}
type packetConnWithCloseHandler struct {
net.PacketConn
onClose CloseHandlerFunc
}
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler {
return &packetConnWithCloseHandler{conn, onClose}
}
func (conn *packetConnWithCloseHandler) Close() error {
conn.onClose()
return conn.PacketConn.Close()
}

View File

@@ -2,157 +2,8 @@ package bandwidth
import (
"context"
"net"
"golang.org/x/time/rate"
)
type connWithDownloadBandwidthLimiter struct {
net.Conn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithDownloadBandwidthLimiter {
return &connWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error) {
var nn int
for {
end := len(p)
if end == 0 {
break
}
if conn.burst < len(p) {
end = conn.burst
}
err = conn.limiter.WaitN(conn.ctx, end)
if err != nil {
return
}
nn, err = conn.Conn.Write(p[:end])
n += nn
if err != nil {
return
}
p = p[end:]
}
return
}
type connWithUploadBandwidthLimiter struct {
net.Conn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithUploadBandwidthLimiter {
return &connWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) {
if conn.burst < len(p) {
p = p[:conn.burst]
}
n, err = conn.Conn.Read(p)
if err != nil {
return
}
err = conn.limiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return
}
type connWithCloseHandler struct {
net.Conn
onClose CloseHandlerFunc
}
func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler {
return &connWithCloseHandler{conn, onClose}
}
func (conn *connWithCloseHandler) Close() error {
conn.onClose()
return conn.Conn.Close()
}
type packetConnWithDownloadBandwidthLimiter struct {
net.PacketConn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithDownloadBandwidthLimiter {
return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
var nn int
for {
end := len(p)
if end == 0 {
break
}
if conn.burst < len(p) {
end = conn.burst
}
err = conn.limiter.WaitN(conn.ctx, end)
if err != nil {
return
}
nn, err = conn.PacketConn.WriteTo(p[:end], addr)
n += nn
if err != nil {
return
}
p = p[end:]
}
return
}
type packetConnWithUploadBandwidthLimiter struct {
net.PacketConn
ctx context.Context
limiter *rate.Limiter
burst int
}
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithUploadBandwidthLimiter {
return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
}
func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if conn.burst < len(p) {
p = p[:conn.burst]
}
n, addr, err = conn.PacketConn.ReadFrom(p)
if err != nil {
return
}
err = conn.limiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return
}
type packetConnWithCloseHandler struct {
net.PacketConn
onClose CloseHandlerFunc
}
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler {
return &packetConnWithCloseHandler{conn, onClose}
}
func (conn *packetConnWithCloseHandler) Close() error {
conn.onClose()
return conn.PacketConn.Close()
type Limiter interface {
WaitN(ctx context.Context, n int) (err error)
}

View File

@@ -226,7 +226,7 @@ func CreateStrategy(strategy string, mode string, connectionType string, speed u
}
func createSpeedLimiter(speed uint64) *rate.Limiter {
return rate.NewLimiter(rate.Limit(float64(speed)), 10000)
return rate.NewLimiter(rate.Limit(float64(speed)), 65536)
}
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter *rate.Limiter, reverse bool) net.Conn {

89
protocol/masque/config.go Normal file
View File

@@ -0,0 +1,89 @@
package masque
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"net"
)
type Config struct {
PrivateKey string `json:"private_key"` // Base64-encoded ECDSA private key
EndpointV4 string `json:"endpoint_v4"` // IPv4 address of the endpoint
EndpointV6 string `json:"endpoint_v6"` // IPv6 address of the endpoint
EndpointH2V4 string `json:"endpoint_h2_v4"` // IPv4 address used in HTTP/2 mode
EndpointH2V6 string `json:"endpoint_h2_v6"` // IPv6 address used in HTTP/2 mode
EndpointPubKey string `json:"endpoint_pub_key"` // PEM-encoded ECDSA public key of the endpoint to verify against
License string `json:"license"` // Application license key
ID string `json:"id"` // Device unique identifier
AccessToken string `json:"access_token"` // Authentication token for API access
IPv4 string `json:"ipv4"` // Assigned IPv4 address
IPv6 string `json:"ipv6"` // Assigned IPv6 address
}
func (c *Config) GetEcPrivateKey() (*ecdsa.PrivateKey, error) {
privKeyB64, err := base64.StdEncoding.DecodeString(c.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to decode private key: %v", err)
}
privKey, err := x509.ParseECPrivateKey(privKeyB64)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}
return privKey, nil
}
func (c *Config) GetEcEndpointPublicKey() (*ecdsa.PublicKey, error) {
endpointPubKeyB64, _ := pem.Decode([]byte(c.EndpointPubKey))
if endpointPubKeyB64 == nil {
return nil, fmt.Errorf("failed to decode endpoint public key")
}
pubKey, err := x509.ParsePKIXPublicKey(endpointPubKeyB64.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %v", err)
}
ecPubKey, ok := pubKey.(*ecdsa.PublicKey)
if !ok {
return nil, fmt.Errorf("failed to assert public key as ECDSA")
}
return ecPubKey, nil
}
func (c *Config) SelectEndpointFromConfig(useHTTP2 bool, useIPv6 bool, port int) (net.Addr, error) {
if useHTTP2 {
if useIPv6 {
if c.EndpointH2V6 == "" {
return nil, fmt.Errorf("--http2 with --ipv6 requires config endpoint_h2_v6 to be set")
}
ip := net.ParseIP(c.EndpointH2V6)
if ip == nil {
return nil, fmt.Errorf("invalid endpoint_h2_v6 value %q", c.EndpointH2V6)
}
return &net.TCPAddr{IP: ip, Port: port}, nil
}
v4 := c.EndpointH2V4
ip := net.ParseIP(v4)
if ip == nil {
return nil, fmt.Errorf("invalid endpoint_h2_v4 value %q")
}
return &net.TCPAddr{IP: ip, Port: port}, nil
}
if useIPv6 {
ip := net.ParseIP(c.EndpointV6)
if ip == nil {
return nil, fmt.Errorf("invalid endpoint_v6 value %q", c.EndpointV6)
}
return &net.UDPAddr{IP: ip, Port: port}, nil
}
ip := net.ParseIP(c.EndpointV4)
if ip == nil {
return nil, fmt.Errorf("invalid endpoint_v4 value %q", c.EndpointV4)
}
return &net.UDPAddr{IP: ip, Port: port}, nil
}

300
protocol/masque/outbound.go Normal file
View File

@@ -0,0 +1,300 @@
package masque
import (
"context"
"encoding/base64"
"encoding/json"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/cloudflare"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/masque"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.MASQUEOutboundOptions](registry, C.TypeMASQUE, NewOutbound)
}
type Outbound struct {
outbound.Adapter
ctx context.Context
dnsRouter adapter.DNSRouter
logger logger.ContextLogger
options option.MASQUEOutboundOptions
tunnel *masque.Tunnel
startHandler func()
await chan struct{}
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MASQUEOutboundOptions) (adapter.Outbound, error) {
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeMASQUE, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions),
ctx: ctx,
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
options: options,
await: make(chan struct{}),
}
outbound.startHandler = func() {
defer close(outbound.await)
cacheFile := service.FromContext[adapter.CacheFile](ctx)
var appConfig *Config
var err error
if !options.Profile.Recreate && cacheFile != nil && cacheFile.StoreMASQUEConfig() {
savedProfile := cacheFile.LoadMASQUEConfig(tag)
if savedProfile != nil {
if err = json.Unmarshal(savedProfile.Content, &appConfig); err != nil {
logger.ErrorContext(ctx, err)
return
}
}
}
if appConfig == nil {
appConfig, err = outbound.createConfig()
if err != nil {
logger.ErrorContext(ctx, err)
return
}
if cacheFile != nil && cacheFile.StoreMASQUEConfig() {
content, err := json.Marshal(appConfig)
if err != nil {
logger.ErrorContext(ctx, err)
return
}
cacheFile.SaveMASQUEConfig(tag, &adapter.SavedBinary{
LastUpdated: time.Now(),
Content: content,
LastEtag: "",
})
}
}
privKey, err := appConfig.GetEcPrivateKey()
if err != nil {
logger.ErrorContext(ctx, E.New("failed to get private key: ", err))
return
}
peerPubKey, err := appConfig.GetEcEndpointPublicKey()
if err != nil {
logger.ErrorContext(ctx, E.New("failed to get public key: ", err))
return
}
cert, err := masque.GenerateCert(privKey, &privKey.PublicKey)
if err != nil {
logger.ErrorContext(ctx, E.New("failed to generate cert: ", err))
return
}
tlsConfig, err := tls.NewMASQUEClient(ctx, logger, "consumer-masque.cloudflareclient.com", cert, privKey, peerPubKey, options.MASQUEOutboundTLSOptions)
if err != nil {
logger.ErrorContext(ctx, E.New("failed to prepare TLS config: ", err))
return
}
endpoint, err := appConfig.SelectEndpointFromConfig(options.UseHTTP2, options.UseIPv6, 443)
if err != nil {
logger.ErrorContext(ctx, E.New("failed to select endpoint: ", err))
return
}
var udpTimeout time.Duration
if options.UDPTimeout != 0 {
udpTimeout = time.Duration(options.UDPTimeout)
} else {
udpTimeout = C.UDPTimeout
}
var udpKeepalivePeriod time.Duration
if options.UDPKeepalivePeriod != 0 {
udpKeepalivePeriod = time.Duration(options.UDPKeepalivePeriod)
} else {
udpKeepalivePeriod = time.Second * 30
}
outboundDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
RemoteIsDomain: false,
ResolverOnDetour: true,
})
if err != nil {
logger.ErrorContext(ctx, err)
return
}
tunnel, err := masque.NewTunnel(
ctx,
logger,
masque.TunnelOptions{
Dialer: outboundDialer,
Address: []netip.Prefix{
netip.MustParsePrefix(appConfig.IPv4 + "/32"),
netip.MustParsePrefix(appConfig.IPv6 + "/128"),
},
Endpoint: endpoint,
TLSConfig: tlsConfig,
UseHTTP2: options.UseHTTP2,
UDPTimeout: udpTimeout,
UDPKeepalivePeriod: udpKeepalivePeriod,
UDPInitialPacketSize: options.UDPInitialPacketSize,
ReconnectDelay: options.ReconnectDelay.Build(),
})
if err != nil {
logger.ErrorContext(ctx, err)
return
}
outbound.tunnel = tunnel
if err = outbound.tunnel.Start(false); err != nil {
logger.ErrorContext(ctx, err)
return
}
if err = outbound.tunnel.Start(true); err != nil {
logger.ErrorContext(ctx, err)
return
}
}
return outbound, nil
}
func (w *Outbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStatePostStart {
return nil
}
go w.startHandler()
return nil
}
func (w *Outbound) Close() error {
if err := w.isTunnelInitialized(w.ctx); err != nil {
return err
}
return w.tunnel.Close()
}
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if err := w.isTunnelInitialized(ctx); err != nil {
return nil, err
}
switch network {
case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsDomain() {
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
return N.DialSerial(ctx, w.tunnel, network, destination, destinationAddresses)
} else if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
}
return w.tunnel.DialContext(ctx, network, destination)
}
func (w *Outbound) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) {
if err := w.isTunnelInitialized(ctx); err != nil {
return nil, netip.Addr{}, err
}
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsDomain() {
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, netip.Addr{}, err
}
return N.ListenSerial(ctx, w.tunnel, destination, destinationAddresses)
}
packetConn, err := w.tunnel.ListenPacket(ctx, destination)
if err != nil {
return nil, netip.Addr{}, err
}
if destination.IsIP() {
return packetConn, destination.Addr, nil
}
return packetConn, netip.Addr{}, nil
}
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
packetConn, destinationAddress, err := w.ListenPacketWithDestination(ctx, destination)
if err != nil {
return nil, err
}
if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) {
return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
}
return packetConn, nil
}
func (w *Outbound) isTunnelInitialized(ctx context.Context) error {
select {
case <-w.await:
case <-ctx.Done():
return ctx.Err()
}
if w.tunnel == nil {
return E.New("tunnel not initialized")
}
return nil
}
func (w *Outbound) createConfig() (*Config, error) {
opts := make([]cloudflare.CloudflareApiOption, 0, 1)
if w.options.Profile.Detour != "" {
detour, ok := service.FromContext[adapter.OutboundManager](w.ctx).Outbound(w.options.Profile.Detour)
if !ok {
return nil, E.New("outbound detour not found: ", w.options.Profile.Detour)
}
opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) {
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
}))
}
api := cloudflare.NewCloudflareApi(opts...)
var profile *cloudflare.CloudflareProfile
var err error
if w.options.Profile.AuthToken != "" && w.options.Profile.ID != "" {
profile, err = api.GetProfile(w.ctx, w.options.Profile.AuthToken, w.options.Profile.ID)
if err != nil {
return nil, err
}
} else {
wgPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
profile, err = api.CreateProfile(w.ctx, wgPrivateKey.PublicKey().String())
if err != nil {
return nil, err
}
}
privateKey, publicKey, err := masque.GenerateEcKeyPair()
if err != nil {
return nil, E.New("failed to generate key pair: ", err)
}
updatedProfile, err := api.EnrollKey(w.ctx, profile.Token, profile.ID, cloudflare.KeyTypeMasque, cloudflare.TunTypeMasque, base64.StdEncoding.EncodeToString(publicKey))
if err != nil {
return nil, err
}
return &Config{
PrivateKey: base64.StdEncoding.EncodeToString(privateKey),
EndpointV4: updatedProfile.Config.Peers[0].Endpoint.V4[:len(updatedProfile.Config.Peers[0].Endpoint.V4)-2],
EndpointV6: updatedProfile.Config.Peers[0].Endpoint.V6[1 : len(updatedProfile.Config.Peers[0].Endpoint.V6)-3],
EndpointH2V4: cloudflare.DefaultEndpointH2V4,
EndpointH2V6: cloudflare.DefaultEndpointH2V6,
EndpointPubKey: updatedProfile.Config.Peers[0].PublicKey,
License: updatedProfile.Account.License,
ID: updatedProfile.ID,
AccessToken: profile.Token,
IPv4: updatedProfile.Config.Interface.Addresses.V4,
IPv6: updatedProfile.Config.Interface.Addresses.V6,
}, nil
}

View File

@@ -0,0 +1,38 @@
package mtproxy
import (
"context"
"net"
"github.com/sagernet/sing-box/adapter"
M "github.com/sagernet/sing/common/metadata"
)
type Dialer struct {
handler adapter.ConnectionHandlerFuncEx
}
func NewDialer(handler adapter.ConnectionHandlerFuncEx) *Dialer {
return &Dialer{handler}
}
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
inConn, outConn := net.Pipe()
var metadata adapter.InboundContext
if streamContext, ok := ctx.(streamContext); ok {
metadata.Source = M.SocksaddrFromNet(streamContext.ClientAddr())
metadata.User = streamContext.SecretName()
}
metadata.Destination = M.ParseSocksaddr(address)
d.handler(ctx, inConn, metadata, func(error) {})
return outConn, nil
}
type streamContext interface {
ClientAddr() net.Addr
SecretName() string
}

132
protocol/mtproxy/inbound.go Normal file
View File

@@ -0,0 +1,132 @@
package mtproxy
import (
"context"
"net"
"github.com/dolonet/mtg-multi/antireplay"
"github.com/dolonet/mtg-multi/events"
"github.com/dolonet/mtg-multi/ipblocklist"
"github.com/dolonet/mtg-multi/mtglib"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/listener"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.MTProxyInboundOptions](registry, C.TypeMTProxy, NewInbound)
}
type Inbound struct {
inbound.Adapter
ctx context.Context
router adapter.ConnectionRouterEx
logger logger.ContextLogger
listener *listener.Listener
proxy *mtglib.Proxy
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MTProxyInboundOptions) (adapter.Inbound, error) {
inbound := &Inbound{
Adapter: inbound.NewAdapter(C.TypeMTProxy, tag),
ctx: ctx,
router: router,
logger: logger,
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Listen: options.ListenOptions,
}),
}
mtgLogger := NewLoggerAdapter(logger)
secrets := make(map[string]mtglib.Secret, len(options.Users))
for _, user := range options.Users {
secret := mtglib.Secret{}
err := secret.Set(user.Secret)
if err != nil {
return nil, err
}
secrets[user.Name] = secret
}
opts := mtglib.ProxyOpts{
Logger: mtgLogger,
Network: NewNetworkAdapter(ctx, NewDialer(inbound.newConnection)),
AntiReplayCache: antireplay.NewNoop(),
IPBlocklist: ipblocklist.NewNoop(),
IPAllowlist: ipblocklist.NewNoop(),
EventStream: events.NewNoopStream(),
Secrets: secrets,
Concurrency: options.GetConcurrency(),
DomainFrontingPort: options.GetDomainFrontingPort(),
DomainFrontingIP: options.DomainFrontingIP,
DomainFrontingProxyProtocol: options.DomainFrontingProxyProtocol,
PreferIP: options.GetPreferIP(),
AutoUpdate: options.AutoUpdate,
AllowFallbackOnUnknownDC: options.AllowFallbackOnUnknownDC,
TolerateTimeSkewness: options.TolerateTimeSkewness.Build(),
IdleTimeout: options.GetIdleTimeout(),
HandshakeTimeout: options.GetHandshakeTimeout(),
DoppelGangerURLs: options.DoppelGangerURLs,
DoppelGangerPerRaid: options.GetDoppelGangerPerRaid(),
DoppelGangerEach: options.GetDoppelGangerEach(),
DoppelGangerDRS: options.DoppelGangerDRS,
ThrottleMaxConnections: options.ThrottleMaxConnections,
ThrottleCheckInterval: options.GetThrottleCheckInterval(),
}
proxy, err := mtglib.NewProxy(opts)
if err != nil {
return nil, E.New("cannot create a proxy: ", err)
}
inbound.proxy = proxy
return inbound, nil
}
func (n *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
listener, err := n.listener.ListenTCP()
if err != nil {
return err
}
go n.proxy.Serve(listener)
return nil
}
func (n *Inbound) Close() error {
n.proxy.Shutdown()
return common.Close(
&n.listener,
)
}
func (h *Inbound) UpdateUsers(users []option.MTProxyUser) {
secrets := make(map[string]mtglib.Secret, len(users))
for _, user := range users {
secret := mtglib.Secret{}
err := secret.Set(user.Secret)
if err != nil {
return
}
secrets[user.Name] = secret
}
h.proxy.UpdateUsers(secrets)
}
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
h.logger.InfoContext(ctx, "[", metadata.User, "] inbound connection to ", metadata.Destination)
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}

View File

@@ -0,0 +1,60 @@
package mtproxy
import (
"fmt"
"github.com/dolonet/mtg-multi/mtglib"
"github.com/sagernet/sing/common/logger"
)
type LoggerAdapter struct {
logger logger.Logger
}
func NewLoggerAdapter(logger logger.Logger) *LoggerAdapter {
return &LoggerAdapter{logger}
}
func (l *LoggerAdapter) Named(name string) mtglib.Logger {
return l
}
func (l *LoggerAdapter) BindInt(name string, value int) mtglib.Logger {
return l
}
func (l *LoggerAdapter) BindStr(name, value string) mtglib.Logger {
return l
}
func (l *LoggerAdapter) BindJSON(name, value string) mtglib.Logger {
return l
}
func (l *LoggerAdapter) Printf(format string, args ...any) {
l.logger.Info(fmt.Sprintf(format, args...))
}
func (l *LoggerAdapter) Info(msg string) {
l.logger.Info(msg)
}
func (l *LoggerAdapter) InfoError(msg string, err error) {
l.logger.Error(msg, err)
}
func (l *LoggerAdapter) Warning(msg string) {
l.logger.Warn(msg)
}
func (l *LoggerAdapter) WarningError(msg string, err error) {
l.logger.Warn(msg, err)
}
func (l *LoggerAdapter) Debug(msg string) {
l.logger.Debug(msg)
}
func (l *LoggerAdapter) DebugError(msg string, err error) {
l.logger.Debug(msg, err)
}

View File

@@ -0,0 +1,43 @@
package mtproxy
import (
"context"
"net"
"net/http"
"github.com/dolonet/mtg-multi/essentials"
)
type NetworkAdapter struct {
ctx context.Context
dialer essentials.Dialer
}
func NewNetworkAdapter(ctx context.Context, dialer essentials.Dialer) *NetworkAdapter {
return &NetworkAdapter{ctx, dialer}
}
func (a *NetworkAdapter) Dial(network, address string) (essentials.Conn, error) {
return a.DialContext(a.ctx, network, address)
}
func (a *NetworkAdapter) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
conn, err := a.dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
return essentials.WrapNetConn(conn), nil
}
func (a *NetworkAdapter) MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client {
return &http.Client{
Timeout: 10,
Transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return a.DialContext(ctx, network, addr)
}},
}
}
func (a *NetworkAdapter) NativeDialer() essentials.Dialer {
return a.dialer
}

View File

@@ -0,0 +1,38 @@
package parser
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/parser/link"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.ParserOutboundOptions](registry, C.TypeParser, NewOutbound)
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ParserOutboundOptions) (adapter.Outbound, error) {
if options.Link == "" {
return nil, E.New("missing link")
}
outboundOptions, err := link.ParseSubscriptionLink(options.Link)
if err != nil {
return nil, err
}
if dialerOptions, ok := outboundOptions.Options.(option.DialerOptionsWrapper); ok {
dialerOptions.ReplaceDialerOptions(options.DialerOptions)
}
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
outbound, err := outboundRegistry.UnsafeCreateOutbound(ctx, router, logger, tag, outboundOptions.Type, outboundOptions.Options)
if err != nil {
return nil, err
}
return outbound, nil
}

View File

View File

@@ -49,6 +49,7 @@ type DNSTransport struct {
dnsRouter adapter.DNSRouter
endpointManager adapter.EndpointManager
endpoint *Endpoint
access sync.RWMutex
routePrefixes []netip.Prefix
routes map[string][]adapter.DNSTransport
hosts map[string][]netip.Addr
@@ -91,6 +92,12 @@ func (t *DNSTransport) Start(stage adapter.StartStage) error {
}
func (t *DNSTransport) Reset() {
t.access.RLock()
transports := t.collectResolversLocked()
t.access.RUnlock()
for _, transport := range transports {
transport.Reset()
}
}
func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) {
@@ -101,7 +108,7 @@ func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, d
}
func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error {
t.routePrefixes = buildRoutePrefixes(routeConfig)
routePrefixes := buildRoutePrefixes(routeConfig)
directDialerOnce := sync.OnceValue(func() N.Dialer {
directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{}))
return &DNSDialer{transport: t, fallbackDialer: directDialer}
@@ -130,9 +137,19 @@ func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *n
}
defaultResolvers = append(defaultResolvers, myResolver)
}
t.access.Lock()
oldResolvers := t.collectResolversLocked()
t.routePrefixes = routePrefixes
t.routes = routes
t.hosts = hosts
t.defaultResolvers = defaultResolvers
t.access.Unlock()
for _, transport := range oldResolvers {
transport.Close()
}
if len(defaultResolvers) > 0 {
t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ",
strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " "))
@@ -207,7 +224,22 @@ func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix {
}
func (t *DNSTransport) Close() error {
return nil
t.access.Lock()
transports := t.collectResolversLocked()
t.routePrefixes = nil
t.routes = nil
t.hosts = nil
t.defaultResolvers = nil
t.access.Unlock()
var err error
for _, transport := range transports {
name := "resolver/" + transport.Type() + "[" + transport.Tag() + "]"
err = E.Append(err, transport.Close(), func(err error) error {
return E.Cause(err, "close ", name)
})
}
return err
}
func (t *DNSTransport) Raw() bool {
@@ -219,7 +251,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
return nil, os.ErrInvalid
}
question := message.Question[0]
addresses, hostsLoaded := t.hosts[question.Name]
t.access.RLock()
hosts := t.hosts
routes := t.routes
defaultResolvers := t.defaultResolvers
acceptDefaultResolvers := t.acceptDefaultResolvers
t.access.RUnlock()
addresses, hostsLoaded := hosts[question.Name]
if hostsLoaded {
switch question.Qtype {
case mDNS.TypeA:
@@ -238,7 +278,7 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
}
}
}
for domainSuffix, transports := range t.routes {
for domainSuffix, transports := range routes {
if strings.HasSuffix(question.Name, domainSuffix) {
if len(transports) == 0 {
return &mDNS.Msg{
@@ -262,10 +302,10 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
return nil, lastErr
}
}
if t.acceptDefaultResolvers {
if len(t.defaultResolvers) > 0 {
if acceptDefaultResolvers {
if len(defaultResolvers) > 0 {
var lastErr error
for _, resolver := range t.defaultResolvers {
for _, resolver := range defaultResolvers {
response, err := resolver.Exchange(ctx, message)
if err != nil {
lastErr = err
@@ -281,6 +321,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
return nil, dns.RcodeNameError
}
func (t *DNSTransport) collectResolversLocked() []adapter.DNSTransport {
var transports []adapter.DNSTransport
for _, resolvers := range t.routes {
transports = append(transports, resolvers...)
}
transports = append(transports, t.defaultResolvers...)
return transports
}
type DNSDialer struct {
transport *DNSTransport
fallbackDialer N.Dialer
@@ -290,7 +339,8 @@ func (d *DNSDialer) DialContext(ctx context.Context, network string, destination
if destination.IsDomain() {
panic("invalid request here")
}
for _, prefix := range d.transport.routePrefixes {
routePrefixes := d.transport.routePrefixesSnapshot()
for _, prefix := range routePrefixes {
if prefix.Contains(destination.Addr) {
return d.transport.endpoint.DialContext(ctx, network, destination)
}
@@ -302,10 +352,17 @@ func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (
if destination.IsDomain() {
panic("invalid request here")
}
for _, prefix := range d.transport.routePrefixes {
routePrefixes := d.transport.routePrefixesSnapshot()
for _, prefix := range routePrefixes {
if prefix.Contains(destination.Addr) {
return d.transport.endpoint.ListenPacket(ctx, destination)
}
}
return d.fallbackDialer.ListenPacket(ctx, destination)
}
func (t *DNSTransport) routePrefixesSnapshot() []netip.Prefix {
t.access.RLock()
defer t.access.RUnlock()
return append([]netip.Prefix(nil), t.routePrefixes...)
}

View File

@@ -1,91 +0,0 @@
package tunnel
import (
"encoding/binary"
"io"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
const (
Version = 0
)
const (
CommandInbound = 1
CommandTCP = 2
)
var Destination = M.Socksaddr{
Fqdn: "sp.tunnel.sing-box.arpa",
Port: 444,
}
var AddressSerializer = M.NewSerializer(
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
M.AddressFamilyByte(0x03, M.AddressFamilyIPv6),
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
M.PortThenAddress(),
)
type Request struct {
UUID uuid.UUID
Command byte
DestinationUUID uuid.UUID
Destination M.Socksaddr
}
func ReadRequest(reader io.Reader) (*Request, error) {
var request Request
var version uint8
err := binary.Read(reader, binary.BigEndian, &version)
if err != nil {
return nil, err
}
if version != Version {
return nil, E.New("unknown version: ", version)
}
_, err = io.ReadFull(reader, request.UUID[:])
if err != nil {
return nil, err
}
err = binary.Read(reader, binary.BigEndian, &request.Command)
if err != nil {
return nil, err
}
_, err = io.ReadFull(reader, request.DestinationUUID[:])
if err != nil {
return nil, err
}
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
return &request, nil
}
func WriteRequest(writer io.Writer, request *Request) error {
var requestLen int
requestLen += 1 // version
requestLen += 16 // UUID
requestLen += 16 // destinationUUID
requestLen += 1 // command
requestLen += AddressSerializer.AddrPortLen(request.Destination)
buffer := buf.NewSize(requestLen)
defer buffer.Release()
common.Must(
buffer.WriteByte(Version),
common.Error(buffer.Write(request.UUID[:])),
buffer.WriteByte(request.Command),
common.Error(buffer.Write(request.DestinationUUID[:])),
)
err := AddressSerializer.WriteAddrPort(buffer, request.Destination)
if err != nil {
return err
}
return common.Error(writer.Write(buffer.Bytes()))
}

View File

@@ -1,201 +0,0 @@
package tunnel
import (
"context"
"net"
"time"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound"
sbUot "github.com/sagernet/sing-box/common/uot"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
)
func RegisterServerEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.TunnelServerEndpointOptions](registry, C.TypeTunnelServer, NewServerEndpoint)
}
type ServerEndpoint struct {
outbound.Adapter
logger logger.ContextLogger
inbound adapter.Inbound
router adapter.ConnectionRouterEx
uuid uuid.UUID
users map[uuid.UUID]uuid.UUID
keys map[uuid.UUID]uuid.UUID
conns map[uuid.UUID]chan net.Conn
timeout time.Duration
uotClient *uot.Client
}
func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelServerEndpointOptions) (adapter.Endpoint, error) {
serverUUID, err := uuid.FromString(options.UUID)
if err != nil {
return nil, err
}
server := &ServerEndpoint{
Adapter: outbound.NewAdapter(C.TypeTunnelServer, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
logger: logger,
router: sbUot.NewRouter(router, logger),
uuid: serverUUID,
}
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
inbound, err := inboundRegistry.Create(ctx, NewRouter(router, logger, server.connHandler), logger, options.Inbound.Tag, options.Inbound.Type, options.Inbound.Options)
if err != nil {
return nil, err
}
server.inbound = inbound
server.users = make(map[uuid.UUID]uuid.UUID, len(options.Users))
server.keys = make(map[uuid.UUID]uuid.UUID, len(options.Users))
server.conns = make(map[uuid.UUID]chan net.Conn)
for _, user := range options.Users {
key, err := uuid.FromString(user.Key)
if err != nil {
return nil, err
}
uuid, err := uuid.FromString(user.UUID)
if err != nil {
return nil, err
}
server.users[key] = uuid
server.keys[uuid] = key
server.conns[uuid] = make(chan net.Conn, 10)
}
if options.ConnectTimeout != 0 {
server.timeout = time.Duration(options.ConnectTimeout)
} else {
server.timeout = C.TCPConnectTimeout
}
server.uotClient = &uot.Client{
Dialer: server,
Version: uot.Version,
}
return server, nil
}
func (s *ServerEndpoint) Start(stage adapter.StartStage) error {
return s.inbound.Start(stage)
}
func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if N.NetworkName(network) == N.NetworkUDP {
return s.uotClient.DialContext(ctx, network, destination)
}
var sourceUUID *uuid.UUID
var ch chan net.Conn
if metadata := adapter.ContextFrom(ctx); metadata != nil {
if metadata.TunnelDestination != "" {
tunnelDestination, err := uuid.FromString(metadata.TunnelDestination)
if err != nil {
return nil, err
}
var ok bool
ch, ok = s.conns[tunnelDestination]
if !ok {
return nil, E.New("user ", metadata.TunnelDestination, " not found")
}
}
if metadata.TunnelSource != "" {
tunnelSource, err := uuid.FromString(metadata.TunnelSource)
if err != nil {
return nil, err
}
sourceUUID = &tunnelSource
}
}
if ch == nil {
return nil, E.New("tunnel destination not set")
}
if sourceUUID == nil {
sourceUUID = &s.uuid
}
ctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
select {
case conn := <-ch:
err := WriteRequest(conn, &Request{UUID: *sourceUUID, Command: CommandTCP, Destination: destination})
if err != nil {
conn.Close()
s.logger.ErrorContext(ctx, err)
continue
}
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
func (s *ServerEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return s.uotClient.ListenPacket(ctx, destination)
}
func (s *ServerEndpoint) Close() error {
return common.Close(s.inbound)
}
func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
if metadata.Destination != Destination {
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return nil
}
request, err := ReadRequest(conn)
if err != nil {
return err
}
if request.Command == CommandInbound {
uuid, ok := s.users[request.UUID]
if !ok {
return E.New("key ", request.UUID.String(), " not found")
}
ch := s.conns[uuid]
select {
case ch <- conn:
default:
oldConn := <-ch
oldConn.Close()
ch <- conn
}
return nil
}
if request.Command == CommandTCP {
sourceUUID, ok := s.users[request.UUID]
if !ok {
return E.New("key ", request.UUID, " not found")
}
if sourceUUID == request.DestinationUUID {
return E.New("routing loop on ", sourceUUID)
}
if request.DestinationUUID != s.uuid {
_, ok = s.keys[request.DestinationUUID]
if !ok {
return E.New("user ", request.DestinationUUID, " not found")
}
}
metadata.Inbound = s.Tag()
metadata.InboundType = C.TypeTunnelServer
metadata.Destination = request.Destination
metadata.TunnelSource = sourceUUID.String()
metadata.TunnelDestination = request.DestinationUUID.String()
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return nil
}
return E.New("command ", request.Command, " not found")
}

View File

@@ -1,8 +1,9 @@
package tunnel
package vpn
import (
"context"
"net"
"net/netip"
"time"
"github.com/gofrs/uuid/v5"
@@ -23,7 +24,7 @@ import (
)
func RegisterClientEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.TunnelClientEndpointOptions](registry, C.TypeTunnelClient, NewClientEndpoint)
endpoint.Register[option.VPNClientEndpointOptions](registry, C.TypeVPNClient, NewClientEndpoint)
}
type ClientEndpoint struct {
@@ -32,27 +33,27 @@ type ClientEndpoint struct {
outbound adapter.Outbound
router adapter.ConnectionRouterEx
logger logger.ContextLogger
uuid uuid.UUID
address IPv4
key uuid.UUID
uotClient *uot.Client
}
func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelClientEndpointOptions) (adapter.Endpoint, error) {
clientUUID, err := uuid.FromString(options.UUID)
if err != nil {
return nil, err
func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VPNClientEndpointOptions) (adapter.Endpoint, error) {
address := options.Address
if !address.Is4() {
return nil, E.New("invalid address: ", address)
}
clientKey, err := uuid.FromString(options.Key)
key, err := uuid.FromString(options.Key)
if err != nil {
return nil, err
}
client := &ClientEndpoint{
Adapter: outbound.NewAdapter(C.TypeTunnelClient, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
Adapter: outbound.NewAdapter(C.TypeVPNClient, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
ctx: ctx,
router: sbUot.NewRouter(router, logger),
logger: logger,
uuid: clientUUID,
key: clientKey,
address: address.As4(),
key: key,
}
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
outbound, err := outboundRegistry.CreateOutbound(ctx, router, logger, options.Outbound.Tag, options.Outbound.Type, options.Outbound.Options)
@@ -94,27 +95,31 @@ func (c *ClientEndpoint) DialContext(ctx context.Context, network string, destin
if N.NetworkName(network) == N.NetworkUDP {
return c.uotClient.DialContext(ctx, network, destination)
}
var destinationUUID *uuid.UUID
if metadata := adapter.ContextFrom(ctx); metadata != nil {
if metadata.TunnelDestination != "" {
uuid, err := uuid.FromString(metadata.TunnelDestination)
if err != nil {
return nil, err
}
destinationUUID = &uuid
}
}
if destinationUUID == nil {
return nil, E.New("tunnel destination not set")
}
if *destinationUUID == c.uuid {
return nil, E.New("routing loop")
if destination.Addr.Is4() && destination.Addr.As4() == c.address {
return nil, E.New("routing loop on ", destination.Addr)
}
conn, err := c.outbound.DialContext(ctx, N.NetworkTCP, Destination)
if err != nil {
return nil, err
}
err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandTCP, DestinationUUID: *destinationUUID, Destination: destination})
gateway := Loopback.As4()
if metadata := adapter.ContextFrom(ctx); metadata != nil {
if metadata.Gateway != nil {
gateway = metadata.Gateway.As4()
if gateway == c.address {
return nil, E.New("routing loop on ", destination.Addr)
}
}
}
err = WriteClientRequest(
conn,
&ClientRequest{
Key: c.key,
Command: CommandTCP,
Gateway: gateway,
Destination: destination,
},
)
if err != nil {
return nil, err
}
@@ -134,29 +139,22 @@ func (c *ClientEndpoint) startInboundConn() error {
if err != nil {
return err
}
err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandInbound, Destination: Destination})
err = WriteClientRequest(conn, &ClientRequest{Key: c.key, Command: CommandInbound, Destination: Destination})
if err != nil {
return err
}
request, err := ReadRequest(conn)
request, err := ReadServerRequest(conn)
if err != nil {
return err
}
go c.connHandler(conn, request)
return nil
}
func (c *ClientEndpoint) connHandler(conn net.Conn, request *Request) {
if request.Source == c.address {
return E.New("routing loop")
}
metadata := adapter.InboundContext{
Inbound: c.Tag(),
Source: M.ParseSocksaddr(conn.RemoteAddr().String()),
Source: M.Socksaddr{Addr: netip.AddrFrom4(request.Source)},
Destination: request.Destination,
}
if request.UUID == c.uuid {
c.logger.ErrorContext(c.ctx, "routing loop")
conn.Close()
return
}
metadata.TunnelSource = request.UUID.String()
c.router.RouteConnectionEx(c.ctx, conn, metadata, func(it error) {})
go c.router.RouteConnectionEx(c.ctx, conn, metadata, func(it error) {})
return nil
}

124
protocol/vpn/protocol.go Normal file
View File

@@ -0,0 +1,124 @@
package vpn
import (
"encoding/binary"
"io"
"net/netip"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
const (
Version = 0
)
const (
CommandInbound = 1
CommandTCP = 2
)
type IPv4 [4]byte
var Destination = M.Socksaddr{
Fqdn: "sp.vpn.sing-box.arpa",
Port: 444,
}
var Loopback = netip.AddrFrom4([4]byte{127, 0, 0, 1})
var AddressSerializer = M.NewSerializer(
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
M.AddressFamilyByte(0x03, M.AddressFamilyIPv6),
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
M.PortThenAddress(),
)
type ClientRequest struct {
Key uuid.UUID
Command byte
Gateway IPv4
Destination M.Socksaddr
}
func ReadClientRequest(reader io.Reader) (*ClientRequest, error) {
var request ClientRequest
var version uint8
err := binary.Read(reader, binary.BigEndian, &version)
if err != nil {
return nil, err
}
if version != Version {
return nil, E.New("unknown version: ", version)
}
_, err = io.ReadFull(reader, request.Key[:])
if err != nil {
return nil, err
}
err = binary.Read(reader, binary.BigEndian, &request.Command)
if err != nil {
return nil, err
}
_, err = io.ReadFull(reader, request.Gateway[:])
if err != nil {
return nil, err
}
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
return &request, nil
}
func WriteClientRequest(writer io.Writer, request *ClientRequest) error {
var requestLen int
requestLen += 1 // version
requestLen += 16 // key
requestLen += 1 // command
requestLen += 4 // gateway
requestLen += AddressSerializer.AddrPortLen(request.Destination)
buffer := buf.NewSize(requestLen)
defer buffer.Release()
common.Must(
buffer.WriteByte(Version),
common.Error(buffer.Write(request.Key[:])),
buffer.WriteByte(request.Command),
common.Error(buffer.Write(request.Gateway[:])),
AddressSerializer.WriteAddrPort(buffer, request.Destination),
)
return common.Error(writer.Write(buffer.Bytes()))
}
type ServerRequest struct {
Source IPv4
Destination M.Socksaddr
}
func ReadServerRequest(reader io.Reader) (*ServerRequest, error) {
var request ServerRequest
_, err := io.ReadFull(reader, request.Source[:])
if err != nil {
return nil, err
}
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
return &request, nil
}
func WriteServerRequest(writer io.Writer, request *ServerRequest) error {
var requestLen int
requestLen += 4 // source
requestLen += AddressSerializer.AddrPortLen(request.Destination)
buffer := buf.NewSize(requestLen)
defer buffer.Release()
common.Must(
common.Error(buffer.Write(request.Source[:])),
AddressSerializer.WriteAddrPort(buffer, request.Destination),
)
return common.Error(writer.Write(buffer.Bytes()))
}

View File

@@ -1,4 +1,4 @@
package tunnel
package vpn
import (
"context"
@@ -21,14 +21,24 @@ func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func(
}
func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
if metadata.Destination != Destination {
return r.Router.RouteConnection(ctx, conn, metadata)
}
return r.handler(ctx, conn, metadata, func(error) {})
}
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
if metadata.Destination != Destination {
return r.Router.RoutePacketConnection(ctx, conn, metadata)
}
return os.ErrInvalid
}
func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if metadata.Destination != Destination {
r.Router.RouteConnectionEx(ctx, conn, metadata, onClose)
return
}
if err := r.handler(ctx, conn, metadata, onClose); err != nil {
r.logger.ErrorContext(ctx, err)
N.CloseOnHandshakeFailure(conn, onClose, err)
@@ -36,6 +46,10 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata
}
func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if metadata.Destination != Destination {
r.Router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
return
}
r.logger.ErrorContext(ctx, os.ErrInvalid)
N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid)
}

235
protocol/vpn/server.go Normal file
View File

@@ -0,0 +1,235 @@
package vpn
import (
"context"
"errors"
"net"
"net/netip"
"time"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound"
sbUot "github.com/sagernet/sing-box/common/uot"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
)
func RegisterServerEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.VPNServerEndpointOptions](registry, C.TypeVPNServer, NewServerEndpoint)
}
type ServerEndpoint struct {
outbound.Adapter
logger logger.ContextLogger
inbounds []adapter.Inbound
router adapter.ConnectionRouterEx
address IPv4
addresses map[uuid.UUID]IPv4
keys map[IPv4]uuid.UUID
conns map[IPv4]chan net.Conn
timeout time.Duration
uotClient *uot.Client
}
func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VPNServerEndpointOptions) (adapter.Endpoint, error) {
address := options.Address
if !address.Is4() {
return nil, E.New("invalid address: ", address)
}
server := &ServerEndpoint{
Adapter: outbound.NewAdapter(C.TypeVPNServer, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
logger: logger,
router: sbUot.NewRouter(router, logger),
address: address.As4(),
}
router = NewRouter(router, logger, server.connHandler)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
inbounds := make([]adapter.Inbound, len(options.Inbounds))
for i, inboundOptions := range options.Inbounds {
inbound, err := inboundRegistry.Create(ctx, router, logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options)
if err != nil {
return nil, err
}
inbounds[i] = inbound
}
server.inbounds = inbounds
server.addresses = make(map[uuid.UUID]IPv4, len(options.Users))
server.keys = make(map[IPv4]uuid.UUID, len(options.Users))
server.conns = make(map[IPv4]chan net.Conn)
for _, user := range options.Users {
key, err := uuid.FromString(user.Key)
if err != nil {
return nil, err
}
if !user.Address.Is4() {
return nil, E.New("invalid address: ", user.Address)
}
address := user.Address.As4()
server.addresses[key] = address
server.keys[address] = key
server.conns[address] = make(chan net.Conn, 10)
}
if options.ConnectTimeout != 0 {
server.timeout = time.Duration(options.ConnectTimeout)
} else {
server.timeout = C.TCPConnectTimeout
}
server.uotClient = &uot.Client{
Dialer: server,
Version: uot.Version,
}
return server, nil
}
func (s *ServerEndpoint) Start(stage adapter.StartStage) error {
for _, inbound := range s.inbounds {
err := inbound.Start(stage)
if err != nil {
return err
}
}
return nil
}
func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if N.NetworkName(network) == N.NetworkUDP {
return s.uotClient.DialContext(ctx, network, destination)
}
source := s.address
var gateway *netip.Addr
if metadata := adapter.ContextFrom(ctx); metadata != nil {
if metadata.Source.IsIPv4() {
address := metadata.Source.Addr.As4()
if _, ok := s.conns[address]; ok {
source = address
}
}
if metadata.Gateway != nil {
gateway = metadata.Gateway
}
}
if gateway == nil {
if destination.IsIPv4() {
gateway = &destination.Addr
destination = M.Socksaddr{
Addr: Loopback,
Port: destination.Port,
}
} else {
return nil, E.New("missing gateway")
}
} else if destination.Addr.Compare(*gateway) == 0 {
destination = M.Socksaddr{
Addr: Loopback,
Port: destination.Port,
}
}
if gateway.Compare(Loopback) == 0 {
return nil, E.New("invalid gateway")
}
ch, ok := s.conns[gateway.As4()]
if !ok {
return nil, E.New("user with address ", gateway, " not found")
}
ctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
select {
case conn := <-ch:
err := WriteServerRequest(conn, &ServerRequest{Source: source, Destination: destination})
if err != nil {
conn.Close()
s.logger.ErrorContext(ctx, err)
continue
}
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
func (s *ServerEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return s.uotClient.ListenPacket(ctx, destination)
}
func (s *ServerEndpoint) Close() error {
errs := make([]error, 0)
for _, inbound := range s.inbounds {
err := inbound.Close()
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errors.Join(errs...)
}
return nil
}
func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
if metadata.Destination != Destination {
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return nil
}
request, err := ReadClientRequest(conn)
if err != nil {
return err
}
if request.Command == CommandInbound {
address, ok := s.addresses[request.Key]
if !ok {
return E.New("key ", request.Key.String(), " not found")
}
ch := s.conns[address]
select {
case ch <- conn:
default:
oldConn := <-ch
oldConn.Close()
ch <- conn
}
return nil
}
if request.Command == CommandTCP {
source, ok := s.addresses[request.Key]
if !ok {
return E.New("key ", request.Key, " not found")
}
if request.Destination.Addr.Is4() && source == request.Destination.Addr.As4() {
return E.New("routing loop on ", request.Destination)
}
metadata.Inbound = s.Tag()
metadata.InboundType = C.TypeVPNServer
metadata.Source = M.Socksaddr{Addr: netip.AddrFrom4(source)}
if request.Destination.Addr.Is4() && request.Destination.Addr.As4() == s.address {
metadata.Destination = M.Socksaddr{
Addr: Loopback,
Port: request.Destination.Port,
}
} else {
metadata.Destination = request.Destination
if request.Gateway != s.address && request.Gateway != Loopback.As4() {
addr := netip.AddrFrom4(request.Gateway)
metadata.Gateway = &addr
}
}
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return nil
}
return E.New("command ", request.Command, " not found")
}

14
protocol/warp/config.go Normal file
View File

@@ -0,0 +1,14 @@
package warp
import "github.com/sagernet/sing-box/common/cloudflare"
type Config struct {
PrivateKey string `json:"private_key"`
Interface struct {
Addresses struct {
V4 string `json:"v4"`
V6 string `json:"v6"`
} `json:"addresses"`
} `json:"interface"`
Peers []cloudflare.Peer
}

View File

@@ -1,4 +1,4 @@
package wireguard
package warp
import (
"context"
@@ -7,7 +7,6 @@ import (
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
@@ -16,6 +15,7 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/wireguard"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json/badoption"
@@ -25,19 +25,21 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func RegisterWARPEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.WireGuardWARPEndpointOptions](registry, C.TypeWARP, NewWARPEndpoint)
func RegisterEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.WARPEndpointOptions](registry, C.TypeWARP, NewEndpoint)
}
type WARPEndpoint struct {
type Endpoint struct {
endpoint.Adapter
ctx context.Context
options option.WARPEndpointOptions
endpoint adapter.Endpoint
startHandler func()
mtx sync.Mutex
await chan struct{}
}
func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardWARPEndpointOptions) (adapter.Endpoint, error) {
func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WARPEndpointOptions) (adapter.Endpoint, error) {
var dependencies []string
if options.Detour != "" {
dependencies = append(dependencies, options.Detour)
@@ -45,14 +47,16 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
if options.Profile.Detour != "" {
dependencies = append(dependencies, options.Profile.Detour)
}
warpEndpoint := &WARPEndpoint{
endpoint := &Endpoint{
Adapter: endpoint.NewAdapter(C.TypeWARP, tag, []string{N.NetworkTCP, N.NetworkUDP}, dependencies),
ctx: ctx,
options: options,
await: make(chan struct{}),
}
warpEndpoint.mtx.Lock()
warpEndpoint.startHandler = func() {
defer warpEndpoint.mtx.Unlock()
endpoint.startHandler = func() {
defer close(endpoint.await)
cacheFile := service.FromContext[adapter.CacheFile](ctx)
var config *C.WARPConfig
var config *Config
var err error
if !options.Profile.Recreate && cacheFile != nil && cacheFile.StoreWARPConfig() {
savedProfile := cacheFile.LoadWARPConfig(tag)
@@ -64,50 +68,10 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
}
}
if config == nil {
var privateKey wgtypes.Key
if options.Profile.PrivateKey != "" {
privateKey, err = wgtypes.ParseKey(options.Profile.PrivateKey)
if err != nil {
logger.ErrorContext(ctx, err)
return
}
} else {
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.ErrorContext(ctx, err)
return
}
}
opts := make([]cloudflare.CloudflareApiOption, 0, 1)
if options.Profile.Detour != "" {
detour, ok := service.FromContext[adapter.OutboundManager](ctx).Outbound(options.Profile.Detour)
if !ok {
logger.ErrorContext(ctx, E.New("outbound detour not found: ", options.Profile.Detour))
return
}
opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) {
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
}))
}
api := cloudflare.NewCloudflareApi(opts...)
var profile *cloudflare.CloudflareProfile
if options.Profile.AuthToken != "" && options.Profile.ID != "" {
profile, err = api.GetProfile(ctx, options.Profile.AuthToken, options.Profile.ID)
if err != nil {
logger.ErrorContext(ctx, err)
return
}
} else {
profile, err = api.CreateProfile(ctx, privateKey.PublicKey().String())
if err != nil {
logger.ErrorContext(ctx, err)
return
}
}
config = &C.WARPConfig{
PrivateKey: privateKey.String(),
Interface: profile.Config.Interface,
Peers: profile.Config.Peers,
config, err := endpoint.createConfig()
if err != nil {
logger.ErrorContext(ctx, err)
return
}
if cacheFile != nil && cacheFile.StoreWARPConfig() {
content, err := json.Marshal(config)
@@ -124,7 +88,7 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
}
peer := config.Peers[0]
hostParts := strings.Split(peer.Endpoint.Host, ":")
warpEndpoint.endpoint, err = NewEndpoint(
endpoint.endpoint, err = wireguard.NewEndpoint(
ctx,
router,
logger,
@@ -165,19 +129,19 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
logger.ErrorContext(ctx, err)
return
}
if err = warpEndpoint.endpoint.Start(adapter.StartStateStart); err != nil {
if err = endpoint.endpoint.Start(adapter.StartStateStart); err != nil {
logger.ErrorContext(ctx, err)
return
}
if err = warpEndpoint.endpoint.Start(adapter.StartStatePostStart); err != nil {
if err = endpoint.endpoint.Start(adapter.StartStatePostStart); err != nil {
logger.ErrorContext(ctx, err)
return
}
}
return warpEndpoint, nil
return endpoint, nil
}
func (w *WARPEndpoint) Start(stage adapter.StartStage) error {
func (w *Endpoint) Start(stage adapter.StartStage) error {
if stage != adapter.StartStatePostStart {
return nil
}
@@ -185,26 +149,79 @@ func (w *WARPEndpoint) Start(stage adapter.StartStage) error {
return nil
}
func (w *WARPEndpoint) Close() error {
func (w *Endpoint) Close() error {
if err := w.isEndpointInitialized(w.ctx); err != nil {
return err
}
return common.Close(w.endpoint)
}
func (w *WARPEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if ok := w.isEndpointInitialized(); !ok {
return nil, E.New("endpoint not initialized")
func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if err := w.isEndpointInitialized(ctx); err != nil {
return nil, err
}
return w.endpoint.DialContext(ctx, network, destination)
}
func (w *WARPEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if ok := w.isEndpointInitialized(); !ok {
return nil, E.New("endpoint not initialized")
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if err := w.isEndpointInitialized(ctx); err != nil {
return nil, err
}
return w.endpoint.ListenPacket(ctx, destination)
}
func (w *WARPEndpoint) isEndpointInitialized() bool {
w.mtx.Lock()
defer w.mtx.Unlock()
return w.endpoint != nil
func (w *Endpoint) isEndpointInitialized(ctx context.Context) error {
select {
case <-w.await:
case <-ctx.Done():
return ctx.Err()
}
if w.endpoint == nil {
return E.New("endpoint not initialized")
}
return nil
}
func (w *Endpoint) createConfig() (*Config, error) {
var privateKey wgtypes.Key
var err error
if w.options.Profile.PrivateKey != "" {
privateKey, err = wgtypes.ParseKey(w.options.Profile.PrivateKey)
if err != nil {
return nil, err
}
} else {
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
}
opts := make([]cloudflare.CloudflareApiOption, 0, 1)
if w.options.Profile.Detour != "" {
detour, ok := service.FromContext[adapter.OutboundManager](w.ctx).Outbound(w.options.Profile.Detour)
if !ok {
return nil, E.New("outbound detour not found: ", w.options.Profile.Detour)
}
opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) {
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
}))
}
api := cloudflare.NewCloudflareApi(opts...)
var profile *cloudflare.CloudflareProfile
if w.options.Profile.AuthToken != "" && w.options.Profile.ID != "" {
profile, err = api.GetProfile(w.ctx, w.options.Profile.AuthToken, w.options.Profile.ID)
if err != nil {
return nil, err
}
} else {
profile, err = api.CreateProfile(w.ctx, privateKey.PublicKey().String())
if err != nil {
return nil, err
}
}
return &Config{
PrivateKey: privateKey.String(),
Interface: profile.Config.Interface,
Peers: profile.Config.Peers,
}, nil
}