mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-05-14 00:51:12 +03:00
Add MTProxy, MASQUE, VPN, Link parser. Update AmneziaWG. Remove Tunneling
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package bond
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -13,9 +14,7 @@ type bondedConn struct {
|
||||
downloadRatios []uint8
|
||||
uploadRatios []uint8
|
||||
|
||||
readBuffer []byte
|
||||
readOffset int
|
||||
readSize int
|
||||
readBuffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func NewBondedConn(conns []net.Conn, downloadRatios, uploadRatios []uint8) *bondedConn {
|
||||
@@ -23,12 +22,13 @@ func NewBondedConn(conns []net.Conn, downloadRatios, uploadRatios []uint8) *bond
|
||||
conns: conns,
|
||||
downloadRatios: downloadRatios,
|
||||
uploadRatios: uploadRatios,
|
||||
readBuffer: make([]byte, 65535),
|
||||
readBuffer: bytes.NewBuffer(make([]byte, 0, 65536)),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *bondedConn) Read(b []byte) (n int, err error) {
|
||||
if c.readOffset == c.readSize {
|
||||
if c.readBuffer.Len() == 0 {
|
||||
c.readBuffer.Reset()
|
||||
var header [2]byte
|
||||
_, err := io.ReadFull(c.conns[0], header[:])
|
||||
if err != nil {
|
||||
@@ -41,19 +41,14 @@ func (c *bondedConn) Read(b []byte) (n int, err error) {
|
||||
if chunkLen == 0 {
|
||||
continue
|
||||
}
|
||||
chunk := c.readBuffer[total : total+chunkLen]
|
||||
n, err := io.ReadFull(c.conns[i], chunk)
|
||||
total += n
|
||||
n, err := io.CopyN(c.readBuffer, c.conns[i], int64(chunkLen))
|
||||
total += int(n)
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
c.readOffset = 0
|
||||
c.readSize = size
|
||||
}
|
||||
n = copy(b, c.readBuffer[c.readOffset:c.readSize])
|
||||
c.readOffset += n
|
||||
return n, nil
|
||||
return c.readBuffer.Read(b)
|
||||
}
|
||||
|
||||
func (c *bondedConn) Write(b []byte) (n int, err error) {
|
||||
|
||||
@@ -6,7 +6,8 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/patrickmn/go-cache/v2"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/common/kmutex"
|
||||
@@ -29,9 +30,9 @@ type Inbound struct {
|
||||
logger logger.ContextLogger
|
||||
router adapter.ConnectionRouterEx
|
||||
inbounds []adapter.Inbound
|
||||
conns *cache.Cache
|
||||
conns *cache.Cache[uuid.UUID, map[uint8]*ratioConn]
|
||||
|
||||
mtx *kmutex.Kmutex[string]
|
||||
mtx *kmutex.Kmutex[uuid.UUID]
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BondInboundOptions) (adapter.Inbound, error) {
|
||||
@@ -42,23 +43,23 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
Adapter: inbound.NewAdapter(C.TypeBond, tag),
|
||||
logger: logger,
|
||||
router: uot.NewRouter(router, logger),
|
||||
conns: cache.New(C.TCPConnectTimeout, time.Second),
|
||||
mtx: kmutex.New[string](),
|
||||
conns: cache.New[uuid.UUID, map[uint8]*ratioConn](C.TCPConnectTimeout, time.Second),
|
||||
mtx: kmutex.New[uuid.UUID](),
|
||||
}
|
||||
router = NewRouter(router, logger, inbound.connHandler)
|
||||
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
||||
inbounds := make([]adapter.Inbound, len(options.Inbounds))
|
||||
for i, inboundOptions := range options.Inbounds {
|
||||
inbound, err := inboundRegistry.UnsafeCreate(ctx, NewRouter(router, logger, inbound.connHandler), logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options)
|
||||
inbound, err := inboundRegistry.UnsafeCreate(ctx, router, logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbounds[i] = inbound
|
||||
}
|
||||
inbound.inbounds = inbounds
|
||||
inbound.conns.OnEvicted(func(s string, i interface{}) {
|
||||
inbound.conns.OnEvicted(func(s uuid.UUID, ratioConns map[uint8]*ratioConn) {
|
||||
inbound.mtx.Lock(s)
|
||||
defer inbound.mtx.Unlock(s)
|
||||
ratioConns := i.(map[uint8]*ratioConn)
|
||||
for _, ratioConn := range ratioConns {
|
||||
if ratioConn != nil {
|
||||
ratioConn.conn.Close()
|
||||
@@ -93,21 +94,15 @@ func (h *Inbound) Close() error {
|
||||
}
|
||||
|
||||
func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
|
||||
if metadata.Destination != Destination {
|
||||
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
request, err := ReadRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
requestUUID := request.UUID.String()
|
||||
requestUUID := request.UUID
|
||||
h.mtx.Lock(requestUUID)
|
||||
var ratioConns map[uint8]*ratioConn
|
||||
rawRatioConns, ok := h.conns.Get(requestUUID)
|
||||
if ok {
|
||||
ratioConns = rawRatioConns.(map[uint8]*ratioConn)
|
||||
} else {
|
||||
ratioConns, ok := h.conns.Get(requestUUID)
|
||||
if !ok {
|
||||
ratioConns = make(map[uint8]*ratioConn, request.Count)
|
||||
h.conns.SetDefault(requestUUID, ratioConns)
|
||||
}
|
||||
|
||||
@@ -21,14 +21,24 @@ func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func(
|
||||
}
|
||||
|
||||
func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
if metadata.Destination != Destination {
|
||||
return r.Router.RouteConnection(ctx, conn, metadata)
|
||||
}
|
||||
return r.handler(ctx, conn, metadata, func(error) {})
|
||||
}
|
||||
|
||||
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
if metadata.Destination != Destination {
|
||||
return r.Router.RoutePacketConnection(ctx, conn, metadata)
|
||||
}
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
if metadata.Destination != Destination {
|
||||
r.Router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return
|
||||
}
|
||||
if err := r.handler(ctx, conn, metadata, onClose); err != nil {
|
||||
r.logger.ErrorContext(ctx, err)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
@@ -36,6 +46,10 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata
|
||||
}
|
||||
|
||||
func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
if metadata.Destination != Destination {
|
||||
r.Router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
return
|
||||
}
|
||||
r.logger.ErrorContext(ctx, os.ErrInvalid)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid)
|
||||
}
|
||||
|
||||
@@ -17,15 +17,15 @@ import (
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterFailover(registry *outbound.Registry) {
|
||||
outbound.Register[option.FailoverOutboundOptions](registry, C.TypeFailover, NewFailover)
|
||||
func RegisterFallback(registry *outbound.Registry) {
|
||||
outbound.Register[option.FallbackOutboundOptions](registry, C.TypeFallback, NewFallback)
|
||||
}
|
||||
|
||||
var (
|
||||
_ adapter.OutboundGroup = (*Failover)(nil)
|
||||
_ adapter.OutboundGroup = (*Fallback)(nil)
|
||||
)
|
||||
|
||||
type Failover struct {
|
||||
type Fallback struct {
|
||||
outbound.Adapter
|
||||
ctx context.Context
|
||||
outbound adapter.OutboundManager
|
||||
@@ -37,12 +37,12 @@ type Failover struct {
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewFailover(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FailoverOutboundOptions) (adapter.Outbound, error) {
|
||||
func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FallbackOutboundOptions) (adapter.Outbound, error) {
|
||||
if len(options.Outbounds) == 0 {
|
||||
return nil, E.New("missing tags")
|
||||
}
|
||||
outbound := &Failover{
|
||||
Adapter: outbound.NewAdapter(C.TypeFailover, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
|
||||
outbound := &Fallback{
|
||||
Adapter: outbound.NewAdapter(C.TypeFallback, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
|
||||
ctx: ctx,
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
logger: logger,
|
||||
@@ -53,7 +53,7 @@ func NewFailover(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
func (s *Failover) Start() error {
|
||||
func (s *Fallback) Start() error {
|
||||
for i, tag := range s.tags {
|
||||
outbound, loaded := s.outbound.Outbound(tag)
|
||||
if !loaded {
|
||||
@@ -64,17 +64,17 @@ func (s *Failover) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Failover) Now() string {
|
||||
func (s *Fallback) Now() string {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
return s.lastUsedOutbound
|
||||
}
|
||||
|
||||
func (s *Failover) All() []string {
|
||||
func (s *Fallback) All() []string {
|
||||
return s.tags
|
||||
}
|
||||
|
||||
func (s *Failover) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
func (s *Fallback) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
for _, outbound := range s.outbounds {
|
||||
@@ -91,7 +91,7 @@ func (s *Failover) DialContext(ctx context.Context, network string, destination
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *Failover) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
func (s *Fallback) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
var conn net.PacketConn
|
||||
var err error
|
||||
for _, outbound := range s.outbounds {
|
||||
@@ -3,6 +3,7 @@ package group
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -42,11 +43,20 @@ type Selector struct {
|
||||
selected common.TypedValue[adapter.Outbound]
|
||||
interruptGroup *interrupt.Group
|
||||
interruptExternalConnections bool
|
||||
|
||||
provider adapter.ProviderManager
|
||||
providers map[string]adapter.Provider
|
||||
outboundsCache map[string][]adapter.Outbound
|
||||
|
||||
providerTags []string
|
||||
exclude *regexp.Regexp
|
||||
include *regexp.Regexp
|
||||
useAllProviders bool
|
||||
}
|
||||
|
||||
func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) {
|
||||
outbound := &Selector{
|
||||
Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds),
|
||||
Adapter: outbound.NewAdapter(C.TypeSelector, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
|
||||
ctx: ctx,
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
connection: service.FromContext[adapter.ConnectionManager](ctx),
|
||||
@@ -56,9 +66,15 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
outbounds: make(map[string]adapter.Outbound),
|
||||
interruptGroup: interrupt.NewGroup(),
|
||||
interruptExternalConnections: options.InterruptExistConnections,
|
||||
}
|
||||
if len(outbound.tags) == 0 {
|
||||
return nil, E.New("missing tags")
|
||||
|
||||
provider: service.FromContext[adapter.ProviderManager](ctx),
|
||||
providers: make(map[string]adapter.Provider),
|
||||
outboundsCache: make(map[string][]adapter.Outbound),
|
||||
|
||||
providerTags: options.Providers,
|
||||
exclude: (*regexp.Regexp)(options.Exclude),
|
||||
include: (*regexp.Regexp)(options.Include),
|
||||
useAllProviders: options.UseAllProviders,
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
@@ -72,6 +88,28 @@ func (s *Selector) Network() []string {
|
||||
}
|
||||
|
||||
func (s *Selector) Start() error {
|
||||
if s.useAllProviders {
|
||||
var providerTags []string
|
||||
for _, provider := range s.provider.Providers() {
|
||||
providerTags = append(providerTags, provider.Tag())
|
||||
s.providers[provider.Tag()] = provider
|
||||
provider.RegisterCallback(s.onProviderUpdated)
|
||||
}
|
||||
s.providerTags = providerTags
|
||||
} else {
|
||||
for i, tag := range s.providerTags {
|
||||
provider, loaded := s.provider.Get(tag)
|
||||
if !loaded {
|
||||
return E.New("outbound provider ", i, " not found: ", tag)
|
||||
}
|
||||
s.providers[tag] = provider
|
||||
provider.RegisterCallback(s.onProviderUpdated)
|
||||
}
|
||||
}
|
||||
if len(s.tags)+len(s.providerTags) == 0 {
|
||||
return E.New("missing outbound and provider tags")
|
||||
}
|
||||
|
||||
for i, tag := range s.tags {
|
||||
detour, loaded := s.outbound.Outbound(tag)
|
||||
if !loaded {
|
||||
@@ -79,31 +117,16 @@ func (s *Selector) Start() error {
|
||||
}
|
||||
s.outbounds[tag] = detour
|
||||
}
|
||||
|
||||
if s.Tag() != "" {
|
||||
cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
|
||||
if cacheFile != nil {
|
||||
selected := cacheFile.LoadSelected(s.Tag())
|
||||
if selected != "" {
|
||||
detour, loaded := s.outbounds[selected]
|
||||
if loaded {
|
||||
s.selected.Store(detour)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(s.tags) == 0 {
|
||||
detour, _ := s.outbound.Outbound("Compatible")
|
||||
s.tags = append(s.tags, detour.Tag())
|
||||
s.outbounds[detour.Tag()] = detour
|
||||
}
|
||||
|
||||
if s.defaultTag != "" {
|
||||
detour, loaded := s.outbounds[s.defaultTag]
|
||||
if !loaded {
|
||||
return E.New("default outbound not found: ", s.defaultTag)
|
||||
}
|
||||
s.selected.Store(detour)
|
||||
return nil
|
||||
outbound, err := s.outboundSelect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.selected.Store(s.outbounds[s.tags[0]])
|
||||
s.selected.Store(outbound)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -145,7 +168,7 @@ func (s *Selector) DialContext(ctx context.Context, network string, destination
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
|
||||
return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil
|
||||
}
|
||||
|
||||
func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
@@ -153,13 +176,13 @@ func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
|
||||
return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil
|
||||
}
|
||||
|
||||
func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
ctx = interrupt.ContextWithIsExternalConnection(ctx)
|
||||
selected := s.selected.Load()
|
||||
conn = s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx))
|
||||
conn = s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx))
|
||||
if outboundHandler, isHandler := selected.(adapter.ConnectionHandlerEx); isHandler {
|
||||
outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose)
|
||||
} else {
|
||||
@@ -170,7 +193,7 @@ func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata
|
||||
func (s *Selector) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
ctx = interrupt.ContextWithIsExternalConnection(ctx)
|
||||
selected := s.selected.Load()
|
||||
conn = s.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx))
|
||||
conn = s.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx))
|
||||
if outboundHandler, isHandler := selected.(adapter.PacketConnectionHandlerEx); isHandler {
|
||||
outboundHandler.NewPacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
} else {
|
||||
@@ -192,3 +215,77 @@ func RealTag(detour adapter.Outbound) string {
|
||||
}
|
||||
return detour.Tag()
|
||||
}
|
||||
|
||||
func (s *Selector) onProviderUpdated(tag string) error {
|
||||
_, loaded := s.providers[tag]
|
||||
if !loaded {
|
||||
return E.New(s.Tag(), ": ", "outbound provider not found: ", tag)
|
||||
}
|
||||
var (
|
||||
tags = s.Dependencies()
|
||||
outboundByTag = make(map[string]adapter.Outbound)
|
||||
)
|
||||
for _, tag := range tags {
|
||||
outboundByTag[tag] = s.outbounds[tag]
|
||||
}
|
||||
for _, providerTag := range s.providerTags {
|
||||
if providerTag != tag && s.outboundsCache[providerTag] != nil {
|
||||
for _, detour := range s.outboundsCache[providerTag] {
|
||||
tags = append(tags, detour.Tag())
|
||||
outboundByTag[detour.Tag()] = detour
|
||||
}
|
||||
continue
|
||||
}
|
||||
provider := s.providers[providerTag]
|
||||
var cache []adapter.Outbound
|
||||
for _, detour := range provider.Outbounds() {
|
||||
tag := detour.Tag()
|
||||
if s.exclude != nil && s.exclude.MatchString(tag) {
|
||||
continue
|
||||
}
|
||||
if s.include != nil && !s.include.MatchString(tag) {
|
||||
continue
|
||||
}
|
||||
tags = append(tags, tag)
|
||||
cache = append(cache, detour)
|
||||
outboundByTag[tag] = detour
|
||||
}
|
||||
s.outboundsCache[providerTag] = cache
|
||||
}
|
||||
if len(tags) == 0 {
|
||||
detour, _ := s.outbound.Outbound("Compatible")
|
||||
tags = append(tags, detour.Tag())
|
||||
outboundByTag[detour.Tag()] = detour
|
||||
}
|
||||
s.tags, s.outbounds = tags, outboundByTag
|
||||
detour, _ := s.outboundSelect()
|
||||
if s.selected.Swap(detour) != detour {
|
||||
s.interruptGroup.Interrupt(s.interruptExternalConnections)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Selector) outboundSelect() (adapter.Outbound, error) {
|
||||
if s.Tag() != "" {
|
||||
cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
|
||||
if cacheFile != nil {
|
||||
selected := cacheFile.LoadSelected(s.Tag())
|
||||
if selected != "" {
|
||||
detour, loaded := s.outbounds[selected]
|
||||
if loaded {
|
||||
return detour, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.defaultTag != "" {
|
||||
detour, loaded := s.outbounds[s.defaultTag]
|
||||
if !loaded {
|
||||
return nil, E.New("default outbound not found: ", s.defaultTag)
|
||||
}
|
||||
return detour, nil
|
||||
}
|
||||
|
||||
return s.outbounds[s.tags[0]], nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package group
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"regexp"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -45,6 +46,16 @@ type URLTest struct {
|
||||
idleTimeout time.Duration
|
||||
group *URLTestGroup
|
||||
interruptExternalConnections bool
|
||||
|
||||
provider adapter.ProviderManager
|
||||
providers map[string]adapter.Provider
|
||||
outboundsCache map[string][]adapter.Outbound
|
||||
cancel context.CancelFunc
|
||||
|
||||
providerTags []string
|
||||
exclude *regexp.Regexp
|
||||
include *regexp.Regexp
|
||||
useAllProviders bool
|
||||
}
|
||||
|
||||
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) {
|
||||
@@ -61,14 +72,42 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
tolerance: options.Tolerance,
|
||||
idleTimeout: time.Duration(options.IdleTimeout),
|
||||
interruptExternalConnections: options.InterruptExistConnections,
|
||||
}
|
||||
if len(outbound.tags) == 0 {
|
||||
return nil, E.New("missing tags")
|
||||
|
||||
provider: service.FromContext[adapter.ProviderManager](ctx),
|
||||
providers: make(map[string]adapter.Provider),
|
||||
outboundsCache: make(map[string][]adapter.Outbound),
|
||||
|
||||
providerTags: options.Providers,
|
||||
exclude: (*regexp.Regexp)(options.Exclude),
|
||||
include: (*regexp.Regexp)(options.Include),
|
||||
useAllProviders: options.UseAllProviders,
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
func (s *URLTest) Start() error {
|
||||
if s.useAllProviders {
|
||||
var providerTags []string
|
||||
for _, provider := range s.provider.Providers() {
|
||||
providerTags = append(providerTags, provider.Tag())
|
||||
s.providers[provider.Tag()] = provider
|
||||
provider.RegisterCallback(s.onProviderUpdated)
|
||||
}
|
||||
s.providerTags = providerTags
|
||||
} else {
|
||||
for i, tag := range s.providerTags {
|
||||
provider, loaded := s.provider.Get(tag)
|
||||
if !loaded {
|
||||
return E.New("outbound provider ", i, " not found: ", tag)
|
||||
}
|
||||
s.providers[tag] = provider
|
||||
provider.RegisterCallback(s.onProviderUpdated)
|
||||
}
|
||||
}
|
||||
if len(s.tags)+len(s.providerTags) == 0 {
|
||||
return E.New("missing outbound and provider tags")
|
||||
}
|
||||
|
||||
outbounds := make([]adapter.Outbound, 0, len(s.tags))
|
||||
for i, tag := range s.tags {
|
||||
detour, loaded := s.outbound.Outbound(tag)
|
||||
@@ -77,6 +116,11 @@ func (s *URLTest) Start() error {
|
||||
}
|
||||
outbounds = append(outbounds, detour)
|
||||
}
|
||||
if len(s.tags) == 0 {
|
||||
detour, _ := s.outbound.Outbound("Compatible")
|
||||
s.tags = append(s.tags, detour.Tag())
|
||||
outbounds = append(outbounds, detour)
|
||||
}
|
||||
group, err := NewURLTestGroup(s.ctx, s.outbound, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -136,7 +180,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M
|
||||
}
|
||||
conn, err := outbound.DialContext(ctx, network, destination)
|
||||
if err == nil {
|
||||
return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
|
||||
return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil
|
||||
}
|
||||
s.logger.ErrorContext(ctx, err)
|
||||
s.group.history.DeleteURLTestHistory(outbound.Tag())
|
||||
@@ -154,7 +198,7 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
|
||||
}
|
||||
conn, err := outbound.ListenPacket(ctx, destination)
|
||||
if err == nil {
|
||||
return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
|
||||
return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil
|
||||
}
|
||||
s.logger.ErrorContext(ctx, err)
|
||||
s.group.history.DeleteURLTestHistory(outbound.Tag())
|
||||
@@ -163,13 +207,13 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
|
||||
|
||||
func (s *URLTest) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
ctx = interrupt.ContextWithIsExternalConnection(ctx)
|
||||
conn = s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx))
|
||||
conn = s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx))
|
||||
s.connection.NewConnection(ctx, s, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
func (s *URLTest) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
ctx = interrupt.ContextWithIsExternalConnection(ctx)
|
||||
conn = s.group.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx))
|
||||
conn = s.group.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx))
|
||||
s.connection.NewPacketConnection(ctx, s, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
@@ -188,6 +232,63 @@ func (s *URLTest) NewDirectRouteConnection(metadata adapter.InboundContext, rout
|
||||
return selected.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext, timeout)
|
||||
}
|
||||
|
||||
func (s *URLTest) onProviderUpdated(tag string) error {
|
||||
_, loaded := s.providers[tag]
|
||||
if !loaded {
|
||||
return E.New("outbound provider not found: ", tag)
|
||||
}
|
||||
var (
|
||||
tags = s.Dependencies()
|
||||
outbounds []adapter.Outbound
|
||||
)
|
||||
for _, tag := range tags {
|
||||
detour, _ := s.outbound.Outbound(tag)
|
||||
outbounds = append(outbounds, detour)
|
||||
}
|
||||
for _, providerTag := range s.providerTags {
|
||||
if providerTag != tag && s.outboundsCache[providerTag] != nil {
|
||||
for _, detour := range s.outboundsCache[providerTag] {
|
||||
tags = append(tags, detour.Tag())
|
||||
outbounds = append(outbounds, detour)
|
||||
}
|
||||
continue
|
||||
}
|
||||
provider := s.providers[providerTag]
|
||||
var cache []adapter.Outbound
|
||||
for _, detour := range provider.Outbounds() {
|
||||
tag := detour.Tag()
|
||||
if s.exclude != nil && s.exclude.MatchString(tag) {
|
||||
continue
|
||||
}
|
||||
if s.include != nil && !s.include.MatchString(tag) {
|
||||
continue
|
||||
}
|
||||
tags = append(tags, tag)
|
||||
cache = append(cache, detour)
|
||||
}
|
||||
outbounds = append(outbounds, cache...)
|
||||
s.outboundsCache[providerTag] = cache
|
||||
}
|
||||
if len(tags) == 0 {
|
||||
detour, _ := s.outbound.Outbound("Compatible")
|
||||
tags = append(tags, detour.Tag())
|
||||
outbounds = append(outbounds, detour)
|
||||
}
|
||||
s.tags, s.group.outbounds = tags, outbounds
|
||||
s.group.access.Lock()
|
||||
if s.group.ticker != nil {
|
||||
s.group.ticker.Reset(s.group.interval)
|
||||
}
|
||||
s.group.access.Unlock()
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
s.cancel = cancel
|
||||
s.URLTest(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
type URLTestGroup struct {
|
||||
ctx context.Context
|
||||
router adapter.Router
|
||||
@@ -407,7 +508,11 @@ func (g *URLTestGroup) urlTest(ctx context.Context, force bool) (map[string]uint
|
||||
})
|
||||
}
|
||||
b.Wait()
|
||||
g.performUpdateCheck()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
default:
|
||||
g.performUpdateCheck()
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
119
protocol/limiter/bandwidth/conn.go
Normal file
119
protocol/limiter/bandwidth/conn.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package bandwidth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type connWithDownloadBandwidthLimiter struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
}
|
||||
|
||||
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithDownloadBandwidthLimiter {
|
||||
return &connWithDownloadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error) {
|
||||
err = conn.limiter.WaitN(conn.ctx, len(p))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return conn.Conn.Write(p)
|
||||
}
|
||||
|
||||
type connWithUploadBandwidthLimiter struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
limiter *rate.Limiter
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithUploadBandwidthLimiter {
|
||||
return &connWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
|
||||
}
|
||||
|
||||
func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) {
|
||||
n, err = conn.Conn.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = conn.limiter.WaitN(conn.ctx, n)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
type connWithCloseHandler struct {
|
||||
net.Conn
|
||||
onClose CloseHandlerFunc
|
||||
}
|
||||
|
||||
func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler {
|
||||
return &connWithCloseHandler{conn, onClose}
|
||||
}
|
||||
|
||||
func (conn *connWithCloseHandler) Close() error {
|
||||
conn.onClose()
|
||||
return conn.Conn.Close()
|
||||
}
|
||||
|
||||
type packetConnWithDownloadBandwidthLimiter struct {
|
||||
net.PacketConn
|
||||
ctx context.Context
|
||||
limiter *rate.Limiter
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithDownloadBandwidthLimiter {
|
||||
return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
|
||||
}
|
||||
|
||||
func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
err = conn.limiter.WaitN(conn.ctx, len(p))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return conn.PacketConn.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
type packetConnWithUploadBandwidthLimiter struct {
|
||||
net.PacketConn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithUploadBandwidthLimiter {
|
||||
return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
|
||||
}
|
||||
|
||||
func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = conn.PacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = conn.limiter.WaitN(conn.ctx, n)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type packetConnWithCloseHandler struct {
|
||||
net.PacketConn
|
||||
onClose CloseHandlerFunc
|
||||
}
|
||||
|
||||
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler {
|
||||
return &packetConnWithCloseHandler{conn, onClose}
|
||||
}
|
||||
|
||||
func (conn *packetConnWithCloseHandler) Close() error {
|
||||
conn.onClose()
|
||||
return conn.PacketConn.Close()
|
||||
}
|
||||
@@ -2,157 +2,8 @@ package bandwidth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type connWithDownloadBandwidthLimiter struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
limiter *rate.Limiter
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithDownloadBandwidthLimiter {
|
||||
return &connWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
|
||||
}
|
||||
|
||||
func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error) {
|
||||
var nn int
|
||||
for {
|
||||
end := len(p)
|
||||
if end == 0 {
|
||||
break
|
||||
}
|
||||
if conn.burst < len(p) {
|
||||
end = conn.burst
|
||||
}
|
||||
err = conn.limiter.WaitN(conn.ctx, end)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
nn, err = conn.Conn.Write(p[:end])
|
||||
n += nn
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p = p[end:]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type connWithUploadBandwidthLimiter struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
limiter *rate.Limiter
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithUploadBandwidthLimiter {
|
||||
return &connWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
|
||||
}
|
||||
|
||||
func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) {
|
||||
if conn.burst < len(p) {
|
||||
p = p[:conn.burst]
|
||||
}
|
||||
n, err = conn.Conn.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = conn.limiter.WaitN(conn.ctx, n)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type connWithCloseHandler struct {
|
||||
net.Conn
|
||||
onClose CloseHandlerFunc
|
||||
}
|
||||
|
||||
func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler {
|
||||
return &connWithCloseHandler{conn, onClose}
|
||||
}
|
||||
|
||||
func (conn *connWithCloseHandler) Close() error {
|
||||
conn.onClose()
|
||||
return conn.Conn.Close()
|
||||
}
|
||||
|
||||
type packetConnWithDownloadBandwidthLimiter struct {
|
||||
net.PacketConn
|
||||
ctx context.Context
|
||||
limiter *rate.Limiter
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithDownloadBandwidthLimiter {
|
||||
return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
|
||||
}
|
||||
|
||||
func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
var nn int
|
||||
for {
|
||||
end := len(p)
|
||||
if end == 0 {
|
||||
break
|
||||
}
|
||||
if conn.burst < len(p) {
|
||||
end = conn.burst
|
||||
}
|
||||
err = conn.limiter.WaitN(conn.ctx, end)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
nn, err = conn.PacketConn.WriteTo(p[:end], addr)
|
||||
n += nn
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p = p[end:]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type packetConnWithUploadBandwidthLimiter struct {
|
||||
net.PacketConn
|
||||
ctx context.Context
|
||||
limiter *rate.Limiter
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithUploadBandwidthLimiter {
|
||||
return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()}
|
||||
}
|
||||
|
||||
func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
if conn.burst < len(p) {
|
||||
p = p[:conn.burst]
|
||||
}
|
||||
n, addr, err = conn.PacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = conn.limiter.WaitN(conn.ctx, n)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type packetConnWithCloseHandler struct {
|
||||
net.PacketConn
|
||||
onClose CloseHandlerFunc
|
||||
}
|
||||
|
||||
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler {
|
||||
return &packetConnWithCloseHandler{conn, onClose}
|
||||
}
|
||||
|
||||
func (conn *packetConnWithCloseHandler) Close() error {
|
||||
conn.onClose()
|
||||
return conn.PacketConn.Close()
|
||||
type Limiter interface {
|
||||
WaitN(ctx context.Context, n int) (err error)
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ func CreateStrategy(strategy string, mode string, connectionType string, speed u
|
||||
}
|
||||
|
||||
func createSpeedLimiter(speed uint64) *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Limit(float64(speed)), 10000)
|
||||
return rate.NewLimiter(rate.Limit(float64(speed)), 65536)
|
||||
}
|
||||
|
||||
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter *rate.Limiter, reverse bool) net.Conn {
|
||||
|
||||
89
protocol/masque/config.go
Normal file
89
protocol/masque/config.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package masque
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
PrivateKey string `json:"private_key"` // Base64-encoded ECDSA private key
|
||||
EndpointV4 string `json:"endpoint_v4"` // IPv4 address of the endpoint
|
||||
EndpointV6 string `json:"endpoint_v6"` // IPv6 address of the endpoint
|
||||
EndpointH2V4 string `json:"endpoint_h2_v4"` // IPv4 address used in HTTP/2 mode
|
||||
EndpointH2V6 string `json:"endpoint_h2_v6"` // IPv6 address used in HTTP/2 mode
|
||||
EndpointPubKey string `json:"endpoint_pub_key"` // PEM-encoded ECDSA public key of the endpoint to verify against
|
||||
License string `json:"license"` // Application license key
|
||||
ID string `json:"id"` // Device unique identifier
|
||||
AccessToken string `json:"access_token"` // Authentication token for API access
|
||||
IPv4 string `json:"ipv4"` // Assigned IPv4 address
|
||||
IPv6 string `json:"ipv6"` // Assigned IPv6 address
|
||||
}
|
||||
|
||||
func (c *Config) GetEcPrivateKey() (*ecdsa.PrivateKey, error) {
|
||||
privKeyB64, err := base64.StdEncoding.DecodeString(c.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode private key: %v", err)
|
||||
}
|
||||
privKey, err := x509.ParseECPrivateKey(privKeyB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %v", err)
|
||||
}
|
||||
return privKey, nil
|
||||
}
|
||||
|
||||
func (c *Config) GetEcEndpointPublicKey() (*ecdsa.PublicKey, error) {
|
||||
endpointPubKeyB64, _ := pem.Decode([]byte(c.EndpointPubKey))
|
||||
if endpointPubKeyB64 == nil {
|
||||
return nil, fmt.Errorf("failed to decode endpoint public key")
|
||||
}
|
||||
|
||||
pubKey, err := x509.ParsePKIXPublicKey(endpointPubKeyB64.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
ecPubKey, ok := pubKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to assert public key as ECDSA")
|
||||
}
|
||||
|
||||
return ecPubKey, nil
|
||||
}
|
||||
|
||||
func (c *Config) SelectEndpointFromConfig(useHTTP2 bool, useIPv6 bool, port int) (net.Addr, error) {
|
||||
if useHTTP2 {
|
||||
if useIPv6 {
|
||||
if c.EndpointH2V6 == "" {
|
||||
return nil, fmt.Errorf("--http2 with --ipv6 requires config endpoint_h2_v6 to be set")
|
||||
}
|
||||
ip := net.ParseIP(c.EndpointH2V6)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid endpoint_h2_v6 value %q", c.EndpointH2V6)
|
||||
}
|
||||
|
||||
return &net.TCPAddr{IP: ip, Port: port}, nil
|
||||
}
|
||||
v4 := c.EndpointH2V4
|
||||
ip := net.ParseIP(v4)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid endpoint_h2_v4 value %q")
|
||||
}
|
||||
return &net.TCPAddr{IP: ip, Port: port}, nil
|
||||
}
|
||||
if useIPv6 {
|
||||
ip := net.ParseIP(c.EndpointV6)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid endpoint_v6 value %q", c.EndpointV6)
|
||||
}
|
||||
return &net.UDPAddr{IP: ip, Port: port}, nil
|
||||
}
|
||||
ip := net.ParseIP(c.EndpointV4)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid endpoint_v4 value %q", c.EndpointV4)
|
||||
}
|
||||
return &net.UDPAddr{IP: ip, Port: port}, nil
|
||||
}
|
||||
300
protocol/masque/outbound.go
Normal file
300
protocol/masque/outbound.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package masque
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/cloudflare"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/masque"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func RegisterOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.MASQUEOutboundOptions](registry, C.TypeMASQUE, NewOutbound)
|
||||
}
|
||||
|
||||
type Outbound struct {
|
||||
outbound.Adapter
|
||||
ctx context.Context
|
||||
dnsRouter adapter.DNSRouter
|
||||
logger logger.ContextLogger
|
||||
options option.MASQUEOutboundOptions
|
||||
tunnel *masque.Tunnel
|
||||
startHandler func()
|
||||
|
||||
await chan struct{}
|
||||
}
|
||||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MASQUEOutboundOptions) (adapter.Outbound, error) {
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeMASQUE, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions),
|
||||
ctx: ctx,
|
||||
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
|
||||
logger: logger,
|
||||
options: options,
|
||||
await: make(chan struct{}),
|
||||
}
|
||||
outbound.startHandler = func() {
|
||||
defer close(outbound.await)
|
||||
cacheFile := service.FromContext[adapter.CacheFile](ctx)
|
||||
var appConfig *Config
|
||||
var err error
|
||||
if !options.Profile.Recreate && cacheFile != nil && cacheFile.StoreMASQUEConfig() {
|
||||
savedProfile := cacheFile.LoadMASQUEConfig(tag)
|
||||
if savedProfile != nil {
|
||||
if err = json.Unmarshal(savedProfile.Content, &appConfig); err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if appConfig == nil {
|
||||
appConfig, err = outbound.createConfig()
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
if cacheFile != nil && cacheFile.StoreMASQUEConfig() {
|
||||
content, err := json.Marshal(appConfig)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
cacheFile.SaveMASQUEConfig(tag, &adapter.SavedBinary{
|
||||
LastUpdated: time.Now(),
|
||||
Content: content,
|
||||
LastEtag: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
privKey, err := appConfig.GetEcPrivateKey()
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, E.New("failed to get private key: ", err))
|
||||
return
|
||||
}
|
||||
peerPubKey, err := appConfig.GetEcEndpointPublicKey()
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, E.New("failed to get public key: ", err))
|
||||
return
|
||||
}
|
||||
cert, err := masque.GenerateCert(privKey, &privKey.PublicKey)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, E.New("failed to generate cert: ", err))
|
||||
return
|
||||
}
|
||||
tlsConfig, err := tls.NewMASQUEClient(ctx, logger, "consumer-masque.cloudflareclient.com", cert, privKey, peerPubKey, options.MASQUEOutboundTLSOptions)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, E.New("failed to prepare TLS config: ", err))
|
||||
return
|
||||
}
|
||||
endpoint, err := appConfig.SelectEndpointFromConfig(options.UseHTTP2, options.UseIPv6, 443)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, E.New("failed to select endpoint: ", err))
|
||||
return
|
||||
}
|
||||
var udpTimeout time.Duration
|
||||
if options.UDPTimeout != 0 {
|
||||
udpTimeout = time.Duration(options.UDPTimeout)
|
||||
} else {
|
||||
udpTimeout = C.UDPTimeout
|
||||
}
|
||||
var udpKeepalivePeriod time.Duration
|
||||
if options.UDPKeepalivePeriod != 0 {
|
||||
udpKeepalivePeriod = time.Duration(options.UDPKeepalivePeriod)
|
||||
} else {
|
||||
udpKeepalivePeriod = time.Second * 30
|
||||
}
|
||||
outboundDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: options.DialerOptions,
|
||||
RemoteIsDomain: false,
|
||||
ResolverOnDetour: true,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
tunnel, err := masque.NewTunnel(
|
||||
ctx,
|
||||
logger,
|
||||
masque.TunnelOptions{
|
||||
Dialer: outboundDialer,
|
||||
Address: []netip.Prefix{
|
||||
netip.MustParsePrefix(appConfig.IPv4 + "/32"),
|
||||
netip.MustParsePrefix(appConfig.IPv6 + "/128"),
|
||||
},
|
||||
Endpoint: endpoint,
|
||||
TLSConfig: tlsConfig,
|
||||
UseHTTP2: options.UseHTTP2,
|
||||
UDPTimeout: udpTimeout,
|
||||
UDPKeepalivePeriod: udpKeepalivePeriod,
|
||||
UDPInitialPacketSize: options.UDPInitialPacketSize,
|
||||
ReconnectDelay: options.ReconnectDelay.Build(),
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
outbound.tunnel = tunnel
|
||||
if err = outbound.tunnel.Start(false); err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
if err = outbound.tunnel.Start(true); err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
func (w *Outbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStatePostStart {
|
||||
return nil
|
||||
}
|
||||
go w.startHandler()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Outbound) Close() error {
|
||||
if err := w.isTunnelInitialized(w.ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.tunnel.Close()
|
||||
}
|
||||
|
||||
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if err := w.isTunnelInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch network {
|
||||
case N.NetworkTCP:
|
||||
w.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
case N.NetworkUDP:
|
||||
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
}
|
||||
if destination.IsDomain() {
|
||||
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return N.DialSerial(ctx, w.tunnel, network, destination, destinationAddresses)
|
||||
} else if !destination.Addr.IsValid() {
|
||||
return nil, E.New("invalid destination: ", destination)
|
||||
}
|
||||
return w.tunnel.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (w *Outbound) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) {
|
||||
if err := w.isTunnelInitialized(ctx); err != nil {
|
||||
return nil, netip.Addr{}, err
|
||||
}
|
||||
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
if destination.IsDomain() {
|
||||
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, err
|
||||
}
|
||||
return N.ListenSerial(ctx, w.tunnel, destination, destinationAddresses)
|
||||
}
|
||||
packetConn, err := w.tunnel.ListenPacket(ctx, destination)
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, err
|
||||
}
|
||||
if destination.IsIP() {
|
||||
return packetConn, destination.Addr, nil
|
||||
}
|
||||
return packetConn, netip.Addr{}, nil
|
||||
}
|
||||
|
||||
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
packetConn, destinationAddress, err := w.ListenPacketWithDestination(ctx, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) {
|
||||
return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
|
||||
}
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (w *Outbound) isTunnelInitialized(ctx context.Context) error {
|
||||
select {
|
||||
case <-w.await:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
if w.tunnel == nil {
|
||||
return E.New("tunnel not initialized")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Outbound) createConfig() (*Config, error) {
|
||||
opts := make([]cloudflare.CloudflareApiOption, 0, 1)
|
||||
if w.options.Profile.Detour != "" {
|
||||
detour, ok := service.FromContext[adapter.OutboundManager](w.ctx).Outbound(w.options.Profile.Detour)
|
||||
if !ok {
|
||||
return nil, E.New("outbound detour not found: ", w.options.Profile.Detour)
|
||||
}
|
||||
opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
}))
|
||||
}
|
||||
api := cloudflare.NewCloudflareApi(opts...)
|
||||
var profile *cloudflare.CloudflareProfile
|
||||
var err error
|
||||
if w.options.Profile.AuthToken != "" && w.options.Profile.ID != "" {
|
||||
profile, err = api.GetProfile(w.ctx, w.options.Profile.AuthToken, w.options.Profile.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
wgPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile, err = api.CreateProfile(w.ctx, wgPrivateKey.PublicKey().String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
privateKey, publicKey, err := masque.GenerateEcKeyPair()
|
||||
if err != nil {
|
||||
return nil, E.New("failed to generate key pair: ", err)
|
||||
}
|
||||
updatedProfile, err := api.EnrollKey(w.ctx, profile.Token, profile.ID, cloudflare.KeyTypeMasque, cloudflare.TunTypeMasque, base64.StdEncoding.EncodeToString(publicKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{
|
||||
PrivateKey: base64.StdEncoding.EncodeToString(privateKey),
|
||||
EndpointV4: updatedProfile.Config.Peers[0].Endpoint.V4[:len(updatedProfile.Config.Peers[0].Endpoint.V4)-2],
|
||||
EndpointV6: updatedProfile.Config.Peers[0].Endpoint.V6[1 : len(updatedProfile.Config.Peers[0].Endpoint.V6)-3],
|
||||
EndpointH2V4: cloudflare.DefaultEndpointH2V4,
|
||||
EndpointH2V6: cloudflare.DefaultEndpointH2V6,
|
||||
EndpointPubKey: updatedProfile.Config.Peers[0].PublicKey,
|
||||
License: updatedProfile.Account.License,
|
||||
ID: updatedProfile.ID,
|
||||
AccessToken: profile.Token,
|
||||
IPv4: updatedProfile.Config.Interface.Addresses.V4,
|
||||
IPv6: updatedProfile.Config.Interface.Addresses.V6,
|
||||
}, nil
|
||||
}
|
||||
38
protocol/mtproxy/dialer.go
Normal file
38
protocol/mtproxy/dialer.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package mtproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
handler adapter.ConnectionHandlerFuncEx
|
||||
}
|
||||
|
||||
func NewDialer(handler adapter.ConnectionHandlerFuncEx) *Dialer {
|
||||
return &Dialer{handler}
|
||||
}
|
||||
|
||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
inConn, outConn := net.Pipe()
|
||||
var metadata adapter.InboundContext
|
||||
if streamContext, ok := ctx.(streamContext); ok {
|
||||
metadata.Source = M.SocksaddrFromNet(streamContext.ClientAddr())
|
||||
metadata.User = streamContext.SecretName()
|
||||
}
|
||||
metadata.Destination = M.ParseSocksaddr(address)
|
||||
d.handler(ctx, inConn, metadata, func(error) {})
|
||||
return outConn, nil
|
||||
}
|
||||
|
||||
type streamContext interface {
|
||||
ClientAddr() net.Addr
|
||||
SecretName() string
|
||||
}
|
||||
132
protocol/mtproxy/inbound.go
Normal file
132
protocol/mtproxy/inbound.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package mtproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/dolonet/mtg-multi/antireplay"
|
||||
"github.com/dolonet/mtg-multi/events"
|
||||
"github.com/dolonet/mtg-multi/ipblocklist"
|
||||
"github.com/dolonet/mtg-multi/mtglib"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func RegisterInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.MTProxyInboundOptions](registry, C.TypeMTProxy, NewInbound)
|
||||
}
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
ctx context.Context
|
||||
router adapter.ConnectionRouterEx
|
||||
logger logger.ContextLogger
|
||||
listener *listener.Listener
|
||||
proxy *mtglib.Proxy
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MTProxyInboundOptions) (adapter.Inbound, error) {
|
||||
inbound := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeMTProxy, tag),
|
||||
ctx: ctx,
|
||||
router: router,
|
||||
logger: logger,
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
}
|
||||
mtgLogger := NewLoggerAdapter(logger)
|
||||
secrets := make(map[string]mtglib.Secret, len(options.Users))
|
||||
for _, user := range options.Users {
|
||||
secret := mtglib.Secret{}
|
||||
err := secret.Set(user.Secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
secrets[user.Name] = secret
|
||||
}
|
||||
opts := mtglib.ProxyOpts{
|
||||
Logger: mtgLogger,
|
||||
Network: NewNetworkAdapter(ctx, NewDialer(inbound.newConnection)),
|
||||
AntiReplayCache: antireplay.NewNoop(),
|
||||
IPBlocklist: ipblocklist.NewNoop(),
|
||||
IPAllowlist: ipblocklist.NewNoop(),
|
||||
EventStream: events.NewNoopStream(),
|
||||
|
||||
Secrets: secrets,
|
||||
Concurrency: options.GetConcurrency(),
|
||||
DomainFrontingPort: options.GetDomainFrontingPort(),
|
||||
DomainFrontingIP: options.DomainFrontingIP,
|
||||
DomainFrontingProxyProtocol: options.DomainFrontingProxyProtocol,
|
||||
PreferIP: options.GetPreferIP(),
|
||||
AutoUpdate: options.AutoUpdate,
|
||||
|
||||
AllowFallbackOnUnknownDC: options.AllowFallbackOnUnknownDC,
|
||||
TolerateTimeSkewness: options.TolerateTimeSkewness.Build(),
|
||||
IdleTimeout: options.GetIdleTimeout(),
|
||||
HandshakeTimeout: options.GetHandshakeTimeout(),
|
||||
|
||||
DoppelGangerURLs: options.DoppelGangerURLs,
|
||||
DoppelGangerPerRaid: options.GetDoppelGangerPerRaid(),
|
||||
DoppelGangerEach: options.GetDoppelGangerEach(),
|
||||
DoppelGangerDRS: options.DoppelGangerDRS,
|
||||
|
||||
ThrottleMaxConnections: options.ThrottleMaxConnections,
|
||||
ThrottleCheckInterval: options.GetThrottleCheckInterval(),
|
||||
}
|
||||
proxy, err := mtglib.NewProxy(opts)
|
||||
if err != nil {
|
||||
return nil, E.New("cannot create a proxy: ", err)
|
||||
}
|
||||
inbound.proxy = proxy
|
||||
return inbound, nil
|
||||
}
|
||||
|
||||
func (n *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
listener, err := n.listener.ListenTCP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go n.proxy.Serve(listener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Inbound) Close() error {
|
||||
n.proxy.Shutdown()
|
||||
return common.Close(
|
||||
&n.listener,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *Inbound) UpdateUsers(users []option.MTProxyUser) {
|
||||
secrets := make(map[string]mtglib.Secret, len(users))
|
||||
for _, user := range users {
|
||||
secret := mtglib.Secret{}
|
||||
err := secret.Set(user.Secret)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
secrets[user.Name] = secret
|
||||
}
|
||||
h.proxy.UpdateUsers(secrets)
|
||||
}
|
||||
|
||||
func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
metadata.Inbound = h.Tag()
|
||||
metadata.InboundType = h.Type()
|
||||
h.logger.InfoContext(ctx, "[", metadata.User, "] inbound connection to ", metadata.Destination)
|
||||
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
}
|
||||
60
protocol/mtproxy/logger.go
Normal file
60
protocol/mtproxy/logger.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package mtproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/dolonet/mtg-multi/mtglib"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
type LoggerAdapter struct {
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewLoggerAdapter(logger logger.Logger) *LoggerAdapter {
|
||||
return &LoggerAdapter{logger}
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) Named(name string) mtglib.Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) BindInt(name string, value int) mtglib.Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) BindStr(name, value string) mtglib.Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) BindJSON(name, value string) mtglib.Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) Printf(format string, args ...any) {
|
||||
l.logger.Info(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) Info(msg string) {
|
||||
l.logger.Info(msg)
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) InfoError(msg string, err error) {
|
||||
l.logger.Error(msg, err)
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) Warning(msg string) {
|
||||
l.logger.Warn(msg)
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) WarningError(msg string, err error) {
|
||||
l.logger.Warn(msg, err)
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) Debug(msg string) {
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
func (l *LoggerAdapter) DebugError(msg string, err error) {
|
||||
l.logger.Debug(msg, err)
|
||||
}
|
||||
43
protocol/mtproxy/network.go
Normal file
43
protocol/mtproxy/network.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package mtproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/dolonet/mtg-multi/essentials"
|
||||
)
|
||||
|
||||
type NetworkAdapter struct {
|
||||
ctx context.Context
|
||||
dialer essentials.Dialer
|
||||
}
|
||||
|
||||
func NewNetworkAdapter(ctx context.Context, dialer essentials.Dialer) *NetworkAdapter {
|
||||
return &NetworkAdapter{ctx, dialer}
|
||||
}
|
||||
|
||||
func (a *NetworkAdapter) Dial(network, address string) (essentials.Conn, error) {
|
||||
return a.DialContext(a.ctx, network, address)
|
||||
}
|
||||
|
||||
func (a *NetworkAdapter) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
|
||||
conn, err := a.dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return essentials.WrapNetConn(conn), nil
|
||||
}
|
||||
|
||||
func (a *NetworkAdapter) MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: 10,
|
||||
Transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return a.DialContext(ctx, network, addr)
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *NetworkAdapter) NativeDialer() essentials.Dialer {
|
||||
return a.dialer
|
||||
}
|
||||
38
protocol/parser/outbound.go
Normal file
38
protocol/parser/outbound.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/parser/link"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.ParserOutboundOptions](registry, C.TypeParser, NewOutbound)
|
||||
}
|
||||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ParserOutboundOptions) (adapter.Outbound, error) {
|
||||
if options.Link == "" {
|
||||
return nil, E.New("missing link")
|
||||
}
|
||||
outboundOptions, err := link.ParseSubscriptionLink(options.Link)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dialerOptions, ok := outboundOptions.Options.(option.DialerOptionsWrapper); ok {
|
||||
dialerOptions.ReplaceDialerOptions(options.DialerOptions)
|
||||
}
|
||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||
outbound, err := outboundRegistry.UnsafeCreateOutbound(ctx, router, logger, tag, outboundOptions.Type, outboundOptions.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
0
protocol/relay/outbound.go
Normal file
0
protocol/relay/outbound.go
Normal file
@@ -49,6 +49,7 @@ type DNSTransport struct {
|
||||
dnsRouter adapter.DNSRouter
|
||||
endpointManager adapter.EndpointManager
|
||||
endpoint *Endpoint
|
||||
access sync.RWMutex
|
||||
routePrefixes []netip.Prefix
|
||||
routes map[string][]adapter.DNSTransport
|
||||
hosts map[string][]netip.Addr
|
||||
@@ -91,6 +92,12 @@ func (t *DNSTransport) Start(stage adapter.StartStage) error {
|
||||
}
|
||||
|
||||
func (t *DNSTransport) Reset() {
|
||||
t.access.RLock()
|
||||
transports := t.collectResolversLocked()
|
||||
t.access.RUnlock()
|
||||
for _, transport := range transports {
|
||||
transport.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) {
|
||||
@@ -101,7 +108,7 @@ func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, d
|
||||
}
|
||||
|
||||
func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error {
|
||||
t.routePrefixes = buildRoutePrefixes(routeConfig)
|
||||
routePrefixes := buildRoutePrefixes(routeConfig)
|
||||
directDialerOnce := sync.OnceValue(func() N.Dialer {
|
||||
directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{}))
|
||||
return &DNSDialer{transport: t, fallbackDialer: directDialer}
|
||||
@@ -130,9 +137,19 @@ func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *n
|
||||
}
|
||||
defaultResolvers = append(defaultResolvers, myResolver)
|
||||
}
|
||||
|
||||
t.access.Lock()
|
||||
oldResolvers := t.collectResolversLocked()
|
||||
t.routePrefixes = routePrefixes
|
||||
t.routes = routes
|
||||
t.hosts = hosts
|
||||
t.defaultResolvers = defaultResolvers
|
||||
t.access.Unlock()
|
||||
|
||||
for _, transport := range oldResolvers {
|
||||
transport.Close()
|
||||
}
|
||||
|
||||
if len(defaultResolvers) > 0 {
|
||||
t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ",
|
||||
strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " "))
|
||||
@@ -207,7 +224,22 @@ func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix {
|
||||
}
|
||||
|
||||
func (t *DNSTransport) Close() error {
|
||||
return nil
|
||||
t.access.Lock()
|
||||
transports := t.collectResolversLocked()
|
||||
t.routePrefixes = nil
|
||||
t.routes = nil
|
||||
t.hosts = nil
|
||||
t.defaultResolvers = nil
|
||||
t.access.Unlock()
|
||||
|
||||
var err error
|
||||
for _, transport := range transports {
|
||||
name := "resolver/" + transport.Type() + "[" + transport.Tag() + "]"
|
||||
err = E.Append(err, transport.Close(), func(err error) error {
|
||||
return E.Cause(err, "close ", name)
|
||||
})
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *DNSTransport) Raw() bool {
|
||||
@@ -219,7 +251,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
question := message.Question[0]
|
||||
addresses, hostsLoaded := t.hosts[question.Name]
|
||||
|
||||
t.access.RLock()
|
||||
hosts := t.hosts
|
||||
routes := t.routes
|
||||
defaultResolvers := t.defaultResolvers
|
||||
acceptDefaultResolvers := t.acceptDefaultResolvers
|
||||
t.access.RUnlock()
|
||||
|
||||
addresses, hostsLoaded := hosts[question.Name]
|
||||
if hostsLoaded {
|
||||
switch question.Qtype {
|
||||
case mDNS.TypeA:
|
||||
@@ -238,7 +278,7 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
||||
}
|
||||
}
|
||||
}
|
||||
for domainSuffix, transports := range t.routes {
|
||||
for domainSuffix, transports := range routes {
|
||||
if strings.HasSuffix(question.Name, domainSuffix) {
|
||||
if len(transports) == 0 {
|
||||
return &mDNS.Msg{
|
||||
@@ -262,10 +302,10 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
||||
return nil, lastErr
|
||||
}
|
||||
}
|
||||
if t.acceptDefaultResolvers {
|
||||
if len(t.defaultResolvers) > 0 {
|
||||
if acceptDefaultResolvers {
|
||||
if len(defaultResolvers) > 0 {
|
||||
var lastErr error
|
||||
for _, resolver := range t.defaultResolvers {
|
||||
for _, resolver := range defaultResolvers {
|
||||
response, err := resolver.Exchange(ctx, message)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
@@ -281,6 +321,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
||||
return nil, dns.RcodeNameError
|
||||
}
|
||||
|
||||
func (t *DNSTransport) collectResolversLocked() []adapter.DNSTransport {
|
||||
var transports []adapter.DNSTransport
|
||||
for _, resolvers := range t.routes {
|
||||
transports = append(transports, resolvers...)
|
||||
}
|
||||
transports = append(transports, t.defaultResolvers...)
|
||||
return transports
|
||||
}
|
||||
|
||||
type DNSDialer struct {
|
||||
transport *DNSTransport
|
||||
fallbackDialer N.Dialer
|
||||
@@ -290,7 +339,8 @@ func (d *DNSDialer) DialContext(ctx context.Context, network string, destination
|
||||
if destination.IsDomain() {
|
||||
panic("invalid request here")
|
||||
}
|
||||
for _, prefix := range d.transport.routePrefixes {
|
||||
routePrefixes := d.transport.routePrefixesSnapshot()
|
||||
for _, prefix := range routePrefixes {
|
||||
if prefix.Contains(destination.Addr) {
|
||||
return d.transport.endpoint.DialContext(ctx, network, destination)
|
||||
}
|
||||
@@ -302,10 +352,17 @@ func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (
|
||||
if destination.IsDomain() {
|
||||
panic("invalid request here")
|
||||
}
|
||||
for _, prefix := range d.transport.routePrefixes {
|
||||
routePrefixes := d.transport.routePrefixesSnapshot()
|
||||
for _, prefix := range routePrefixes {
|
||||
if prefix.Contains(destination.Addr) {
|
||||
return d.transport.endpoint.ListenPacket(ctx, destination)
|
||||
}
|
||||
}
|
||||
return d.fallbackDialer.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (t *DNSTransport) routePrefixesSnapshot() []netip.Prefix {
|
||||
t.access.RLock()
|
||||
defer t.access.RUnlock()
|
||||
return append([]netip.Prefix(nil), t.routePrefixes...)
|
||||
}
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = 0
|
||||
)
|
||||
|
||||
const (
|
||||
CommandInbound = 1
|
||||
CommandTCP = 2
|
||||
)
|
||||
|
||||
var Destination = M.Socksaddr{
|
||||
Fqdn: "sp.tunnel.sing-box.arpa",
|
||||
Port: 444,
|
||||
}
|
||||
|
||||
var AddressSerializer = M.NewSerializer(
|
||||
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
|
||||
M.AddressFamilyByte(0x03, M.AddressFamilyIPv6),
|
||||
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
|
||||
M.PortThenAddress(),
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
UUID uuid.UUID
|
||||
Command byte
|
||||
DestinationUUID uuid.UUID
|
||||
Destination M.Socksaddr
|
||||
}
|
||||
|
||||
func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
var request Request
|
||||
var version uint8
|
||||
err := binary.Read(reader, binary.BigEndian, &version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != Version {
|
||||
return nil, E.New("unknown version: ", version)
|
||||
}
|
||||
_, err = io.ReadFull(reader, request.UUID[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &request.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = io.ReadFull(reader, request.DestinationUUID[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &request, nil
|
||||
}
|
||||
|
||||
func WriteRequest(writer io.Writer, request *Request) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // version
|
||||
requestLen += 16 // UUID
|
||||
requestLen += 16 // destinationUUID
|
||||
requestLen += 1 // command
|
||||
requestLen += AddressSerializer.AddrPortLen(request.Destination)
|
||||
buffer := buf.NewSize(requestLen)
|
||||
defer buffer.Release()
|
||||
common.Must(
|
||||
buffer.WriteByte(Version),
|
||||
common.Error(buffer.Write(request.UUID[:])),
|
||||
buffer.WriteByte(request.Command),
|
||||
common.Error(buffer.Write(request.DestinationUUID[:])),
|
||||
)
|
||||
err := AddressSerializer.WriteAddrPort(buffer, request.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Error(writer.Write(buffer.Bytes()))
|
||||
}
|
||||
@@ -1,201 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
sbUot "github.com/sagernet/sing-box/common/uot"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/uot"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterServerEndpoint(registry *endpoint.Registry) {
|
||||
endpoint.Register[option.TunnelServerEndpointOptions](registry, C.TypeTunnelServer, NewServerEndpoint)
|
||||
}
|
||||
|
||||
type ServerEndpoint struct {
|
||||
outbound.Adapter
|
||||
logger logger.ContextLogger
|
||||
inbound adapter.Inbound
|
||||
router adapter.ConnectionRouterEx
|
||||
uuid uuid.UUID
|
||||
users map[uuid.UUID]uuid.UUID
|
||||
keys map[uuid.UUID]uuid.UUID
|
||||
conns map[uuid.UUID]chan net.Conn
|
||||
timeout time.Duration
|
||||
uotClient *uot.Client
|
||||
}
|
||||
|
||||
func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelServerEndpointOptions) (adapter.Endpoint, error) {
|
||||
serverUUID, err := uuid.FromString(options.UUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server := &ServerEndpoint{
|
||||
Adapter: outbound.NewAdapter(C.TypeTunnelServer, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
|
||||
logger: logger,
|
||||
router: sbUot.NewRouter(router, logger),
|
||||
uuid: serverUUID,
|
||||
}
|
||||
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
||||
inbound, err := inboundRegistry.Create(ctx, NewRouter(router, logger, server.connHandler), logger, options.Inbound.Tag, options.Inbound.Type, options.Inbound.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.inbound = inbound
|
||||
server.users = make(map[uuid.UUID]uuid.UUID, len(options.Users))
|
||||
server.keys = make(map[uuid.UUID]uuid.UUID, len(options.Users))
|
||||
server.conns = make(map[uuid.UUID]chan net.Conn)
|
||||
for _, user := range options.Users {
|
||||
key, err := uuid.FromString(user.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
uuid, err := uuid.FromString(user.UUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.users[key] = uuid
|
||||
server.keys[uuid] = key
|
||||
server.conns[uuid] = make(chan net.Conn, 10)
|
||||
}
|
||||
if options.ConnectTimeout != 0 {
|
||||
server.timeout = time.Duration(options.ConnectTimeout)
|
||||
} else {
|
||||
server.timeout = C.TCPConnectTimeout
|
||||
}
|
||||
server.uotClient = &uot.Client{
|
||||
Dialer: server,
|
||||
Version: uot.Version,
|
||||
}
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) Start(stage adapter.StartStage) error {
|
||||
return s.inbound.Start(stage)
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if N.NetworkName(network) == N.NetworkUDP {
|
||||
return s.uotClient.DialContext(ctx, network, destination)
|
||||
}
|
||||
var sourceUUID *uuid.UUID
|
||||
var ch chan net.Conn
|
||||
if metadata := adapter.ContextFrom(ctx); metadata != nil {
|
||||
if metadata.TunnelDestination != "" {
|
||||
tunnelDestination, err := uuid.FromString(metadata.TunnelDestination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ok bool
|
||||
ch, ok = s.conns[tunnelDestination]
|
||||
if !ok {
|
||||
return nil, E.New("user ", metadata.TunnelDestination, " not found")
|
||||
}
|
||||
}
|
||||
if metadata.TunnelSource != "" {
|
||||
tunnelSource, err := uuid.FromString(metadata.TunnelSource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sourceUUID = &tunnelSource
|
||||
}
|
||||
}
|
||||
if ch == nil {
|
||||
return nil, E.New("tunnel destination not set")
|
||||
}
|
||||
if sourceUUID == nil {
|
||||
sourceUUID = &s.uuid
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, s.timeout)
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case conn := <-ch:
|
||||
err := WriteRequest(conn, &Request{UUID: *sourceUUID, Command: CommandTCP, Destination: destination})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, err)
|
||||
continue
|
||||
}
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return s.uotClient.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) Close() error {
|
||||
return common.Close(s.inbound)
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
|
||||
if metadata.Destination != Destination {
|
||||
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
request, err := ReadRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if request.Command == CommandInbound {
|
||||
uuid, ok := s.users[request.UUID]
|
||||
if !ok {
|
||||
return E.New("key ", request.UUID.String(), " not found")
|
||||
}
|
||||
ch := s.conns[uuid]
|
||||
select {
|
||||
case ch <- conn:
|
||||
default:
|
||||
oldConn := <-ch
|
||||
oldConn.Close()
|
||||
ch <- conn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if request.Command == CommandTCP {
|
||||
sourceUUID, ok := s.users[request.UUID]
|
||||
if !ok {
|
||||
return E.New("key ", request.UUID, " not found")
|
||||
}
|
||||
if sourceUUID == request.DestinationUUID {
|
||||
return E.New("routing loop on ", sourceUUID)
|
||||
}
|
||||
if request.DestinationUUID != s.uuid {
|
||||
_, ok = s.keys[request.DestinationUUID]
|
||||
if !ok {
|
||||
return E.New("user ", request.DestinationUUID, " not found")
|
||||
}
|
||||
}
|
||||
metadata.Inbound = s.Tag()
|
||||
metadata.InboundType = C.TypeTunnelServer
|
||||
metadata.Destination = request.Destination
|
||||
metadata.TunnelSource = sourceUUID.String()
|
||||
metadata.TunnelDestination = request.DestinationUUID.String()
|
||||
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
return E.New("command ", request.Command, " not found")
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package tunnel
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
@@ -23,7 +24,7 @@ import (
|
||||
)
|
||||
|
||||
func RegisterClientEndpoint(registry *endpoint.Registry) {
|
||||
endpoint.Register[option.TunnelClientEndpointOptions](registry, C.TypeTunnelClient, NewClientEndpoint)
|
||||
endpoint.Register[option.VPNClientEndpointOptions](registry, C.TypeVPNClient, NewClientEndpoint)
|
||||
}
|
||||
|
||||
type ClientEndpoint struct {
|
||||
@@ -32,27 +33,27 @@ type ClientEndpoint struct {
|
||||
outbound adapter.Outbound
|
||||
router adapter.ConnectionRouterEx
|
||||
logger logger.ContextLogger
|
||||
uuid uuid.UUID
|
||||
address IPv4
|
||||
key uuid.UUID
|
||||
uotClient *uot.Client
|
||||
}
|
||||
|
||||
func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelClientEndpointOptions) (adapter.Endpoint, error) {
|
||||
clientUUID, err := uuid.FromString(options.UUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VPNClientEndpointOptions) (adapter.Endpoint, error) {
|
||||
address := options.Address
|
||||
if !address.Is4() {
|
||||
return nil, E.New("invalid address: ", address)
|
||||
}
|
||||
clientKey, err := uuid.FromString(options.Key)
|
||||
key, err := uuid.FromString(options.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &ClientEndpoint{
|
||||
Adapter: outbound.NewAdapter(C.TypeTunnelClient, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
|
||||
Adapter: outbound.NewAdapter(C.TypeVPNClient, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
|
||||
ctx: ctx,
|
||||
router: sbUot.NewRouter(router, logger),
|
||||
logger: logger,
|
||||
uuid: clientUUID,
|
||||
key: clientKey,
|
||||
address: address.As4(),
|
||||
key: key,
|
||||
}
|
||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||
outbound, err := outboundRegistry.CreateOutbound(ctx, router, logger, options.Outbound.Tag, options.Outbound.Type, options.Outbound.Options)
|
||||
@@ -94,27 +95,31 @@ func (c *ClientEndpoint) DialContext(ctx context.Context, network string, destin
|
||||
if N.NetworkName(network) == N.NetworkUDP {
|
||||
return c.uotClient.DialContext(ctx, network, destination)
|
||||
}
|
||||
var destinationUUID *uuid.UUID
|
||||
if metadata := adapter.ContextFrom(ctx); metadata != nil {
|
||||
if metadata.TunnelDestination != "" {
|
||||
uuid, err := uuid.FromString(metadata.TunnelDestination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destinationUUID = &uuid
|
||||
}
|
||||
}
|
||||
if destinationUUID == nil {
|
||||
return nil, E.New("tunnel destination not set")
|
||||
}
|
||||
if *destinationUUID == c.uuid {
|
||||
return nil, E.New("routing loop")
|
||||
if destination.Addr.Is4() && destination.Addr.As4() == c.address {
|
||||
return nil, E.New("routing loop on ", destination.Addr)
|
||||
}
|
||||
conn, err := c.outbound.DialContext(ctx, N.NetworkTCP, Destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandTCP, DestinationUUID: *destinationUUID, Destination: destination})
|
||||
gateway := Loopback.As4()
|
||||
if metadata := adapter.ContextFrom(ctx); metadata != nil {
|
||||
if metadata.Gateway != nil {
|
||||
gateway = metadata.Gateway.As4()
|
||||
if gateway == c.address {
|
||||
return nil, E.New("routing loop on ", destination.Addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
err = WriteClientRequest(
|
||||
conn,
|
||||
&ClientRequest{
|
||||
Key: c.key,
|
||||
Command: CommandTCP,
|
||||
Gateway: gateway,
|
||||
Destination: destination,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -134,29 +139,22 @@ func (c *ClientEndpoint) startInboundConn() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandInbound, Destination: Destination})
|
||||
err = WriteClientRequest(conn, &ClientRequest{Key: c.key, Command: CommandInbound, Destination: Destination})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
request, err := ReadRequest(conn)
|
||||
request, err := ReadServerRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go c.connHandler(conn, request)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientEndpoint) connHandler(conn net.Conn, request *Request) {
|
||||
if request.Source == c.address {
|
||||
return E.New("routing loop")
|
||||
}
|
||||
metadata := adapter.InboundContext{
|
||||
Inbound: c.Tag(),
|
||||
Source: M.ParseSocksaddr(conn.RemoteAddr().String()),
|
||||
Source: M.Socksaddr{Addr: netip.AddrFrom4(request.Source)},
|
||||
Destination: request.Destination,
|
||||
}
|
||||
if request.UUID == c.uuid {
|
||||
c.logger.ErrorContext(c.ctx, "routing loop")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
metadata.TunnelSource = request.UUID.String()
|
||||
c.router.RouteConnectionEx(c.ctx, conn, metadata, func(it error) {})
|
||||
go c.router.RouteConnectionEx(c.ctx, conn, metadata, func(it error) {})
|
||||
return nil
|
||||
}
|
||||
124
protocol/vpn/protocol.go
Normal file
124
protocol/vpn/protocol.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net/netip"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = 0
|
||||
)
|
||||
|
||||
const (
|
||||
CommandInbound = 1
|
||||
CommandTCP = 2
|
||||
)
|
||||
|
||||
type IPv4 [4]byte
|
||||
|
||||
var Destination = M.Socksaddr{
|
||||
Fqdn: "sp.vpn.sing-box.arpa",
|
||||
Port: 444,
|
||||
}
|
||||
|
||||
var Loopback = netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
|
||||
var AddressSerializer = M.NewSerializer(
|
||||
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
|
||||
M.AddressFamilyByte(0x03, M.AddressFamilyIPv6),
|
||||
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
|
||||
M.PortThenAddress(),
|
||||
)
|
||||
|
||||
type ClientRequest struct {
|
||||
Key uuid.UUID
|
||||
Command byte
|
||||
Gateway IPv4
|
||||
Destination M.Socksaddr
|
||||
}
|
||||
|
||||
func ReadClientRequest(reader io.Reader) (*ClientRequest, error) {
|
||||
var request ClientRequest
|
||||
var version uint8
|
||||
err := binary.Read(reader, binary.BigEndian, &version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != Version {
|
||||
return nil, E.New("unknown version: ", version)
|
||||
}
|
||||
_, err = io.ReadFull(reader, request.Key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &request.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = io.ReadFull(reader, request.Gateway[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &request, nil
|
||||
}
|
||||
|
||||
func WriteClientRequest(writer io.Writer, request *ClientRequest) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // version
|
||||
requestLen += 16 // key
|
||||
requestLen += 1 // command
|
||||
requestLen += 4 // gateway
|
||||
requestLen += AddressSerializer.AddrPortLen(request.Destination)
|
||||
buffer := buf.NewSize(requestLen)
|
||||
defer buffer.Release()
|
||||
common.Must(
|
||||
buffer.WriteByte(Version),
|
||||
common.Error(buffer.Write(request.Key[:])),
|
||||
buffer.WriteByte(request.Command),
|
||||
common.Error(buffer.Write(request.Gateway[:])),
|
||||
AddressSerializer.WriteAddrPort(buffer, request.Destination),
|
||||
)
|
||||
return common.Error(writer.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
type ServerRequest struct {
|
||||
Source IPv4
|
||||
Destination M.Socksaddr
|
||||
}
|
||||
|
||||
func ReadServerRequest(reader io.Reader) (*ServerRequest, error) {
|
||||
var request ServerRequest
|
||||
_, err := io.ReadFull(reader, request.Source[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &request, nil
|
||||
}
|
||||
|
||||
func WriteServerRequest(writer io.Writer, request *ServerRequest) error {
|
||||
var requestLen int
|
||||
requestLen += 4 // source
|
||||
requestLen += AddressSerializer.AddrPortLen(request.Destination)
|
||||
buffer := buf.NewSize(requestLen)
|
||||
defer buffer.Release()
|
||||
common.Must(
|
||||
common.Error(buffer.Write(request.Source[:])),
|
||||
AddressSerializer.WriteAddrPort(buffer, request.Destination),
|
||||
)
|
||||
return common.Error(writer.Write(buffer.Bytes()))
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package tunnel
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -21,14 +21,24 @@ func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func(
|
||||
}
|
||||
|
||||
func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
if metadata.Destination != Destination {
|
||||
return r.Router.RouteConnection(ctx, conn, metadata)
|
||||
}
|
||||
return r.handler(ctx, conn, metadata, func(error) {})
|
||||
}
|
||||
|
||||
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
if metadata.Destination != Destination {
|
||||
return r.Router.RoutePacketConnection(ctx, conn, metadata)
|
||||
}
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
if metadata.Destination != Destination {
|
||||
r.Router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return
|
||||
}
|
||||
if err := r.handler(ctx, conn, metadata, onClose); err != nil {
|
||||
r.logger.ErrorContext(ctx, err)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
@@ -36,6 +46,10 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata
|
||||
}
|
||||
|
||||
func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
if metadata.Destination != Destination {
|
||||
r.Router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
return
|
||||
}
|
||||
r.logger.ErrorContext(ctx, os.ErrInvalid)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid)
|
||||
}
|
||||
235
protocol/vpn/server.go
Normal file
235
protocol/vpn/server.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
sbUot "github.com/sagernet/sing-box/common/uot"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/uot"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterServerEndpoint(registry *endpoint.Registry) {
|
||||
endpoint.Register[option.VPNServerEndpointOptions](registry, C.TypeVPNServer, NewServerEndpoint)
|
||||
}
|
||||
|
||||
type ServerEndpoint struct {
|
||||
outbound.Adapter
|
||||
logger logger.ContextLogger
|
||||
inbounds []adapter.Inbound
|
||||
router adapter.ConnectionRouterEx
|
||||
address IPv4
|
||||
addresses map[uuid.UUID]IPv4
|
||||
keys map[IPv4]uuid.UUID
|
||||
conns map[IPv4]chan net.Conn
|
||||
timeout time.Duration
|
||||
uotClient *uot.Client
|
||||
}
|
||||
|
||||
func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VPNServerEndpointOptions) (adapter.Endpoint, error) {
|
||||
address := options.Address
|
||||
if !address.Is4() {
|
||||
return nil, E.New("invalid address: ", address)
|
||||
}
|
||||
server := &ServerEndpoint{
|
||||
Adapter: outbound.NewAdapter(C.TypeVPNServer, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
|
||||
logger: logger,
|
||||
router: sbUot.NewRouter(router, logger),
|
||||
address: address.As4(),
|
||||
}
|
||||
router = NewRouter(router, logger, server.connHandler)
|
||||
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
||||
inbounds := make([]adapter.Inbound, len(options.Inbounds))
|
||||
for i, inboundOptions := range options.Inbounds {
|
||||
inbound, err := inboundRegistry.Create(ctx, router, logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbounds[i] = inbound
|
||||
}
|
||||
server.inbounds = inbounds
|
||||
server.addresses = make(map[uuid.UUID]IPv4, len(options.Users))
|
||||
server.keys = make(map[IPv4]uuid.UUID, len(options.Users))
|
||||
server.conns = make(map[IPv4]chan net.Conn)
|
||||
for _, user := range options.Users {
|
||||
key, err := uuid.FromString(user.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !user.Address.Is4() {
|
||||
return nil, E.New("invalid address: ", user.Address)
|
||||
}
|
||||
address := user.Address.As4()
|
||||
server.addresses[key] = address
|
||||
server.keys[address] = key
|
||||
server.conns[address] = make(chan net.Conn, 10)
|
||||
}
|
||||
if options.ConnectTimeout != 0 {
|
||||
server.timeout = time.Duration(options.ConnectTimeout)
|
||||
} else {
|
||||
server.timeout = C.TCPConnectTimeout
|
||||
}
|
||||
server.uotClient = &uot.Client{
|
||||
Dialer: server,
|
||||
Version: uot.Version,
|
||||
}
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) Start(stage adapter.StartStage) error {
|
||||
for _, inbound := range s.inbounds {
|
||||
err := inbound.Start(stage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if N.NetworkName(network) == N.NetworkUDP {
|
||||
return s.uotClient.DialContext(ctx, network, destination)
|
||||
}
|
||||
source := s.address
|
||||
var gateway *netip.Addr
|
||||
if metadata := adapter.ContextFrom(ctx); metadata != nil {
|
||||
if metadata.Source.IsIPv4() {
|
||||
address := metadata.Source.Addr.As4()
|
||||
if _, ok := s.conns[address]; ok {
|
||||
source = address
|
||||
}
|
||||
}
|
||||
if metadata.Gateway != nil {
|
||||
gateway = metadata.Gateway
|
||||
}
|
||||
}
|
||||
if gateway == nil {
|
||||
if destination.IsIPv4() {
|
||||
gateway = &destination.Addr
|
||||
destination = M.Socksaddr{
|
||||
Addr: Loopback,
|
||||
Port: destination.Port,
|
||||
}
|
||||
} else {
|
||||
return nil, E.New("missing gateway")
|
||||
}
|
||||
} else if destination.Addr.Compare(*gateway) == 0 {
|
||||
destination = M.Socksaddr{
|
||||
Addr: Loopback,
|
||||
Port: destination.Port,
|
||||
}
|
||||
}
|
||||
if gateway.Compare(Loopback) == 0 {
|
||||
return nil, E.New("invalid gateway")
|
||||
}
|
||||
ch, ok := s.conns[gateway.As4()]
|
||||
if !ok {
|
||||
return nil, E.New("user with address ", gateway, " not found")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, s.timeout)
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case conn := <-ch:
|
||||
err := WriteServerRequest(conn, &ServerRequest{Source: source, Destination: destination})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.logger.ErrorContext(ctx, err)
|
||||
continue
|
||||
}
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return s.uotClient.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) Close() error {
|
||||
errs := make([]error, 0)
|
||||
for _, inbound := range s.inbounds {
|
||||
err := inbound.Close()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if len(errs) != 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
|
||||
if metadata.Destination != Destination {
|
||||
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
request, err := ReadClientRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if request.Command == CommandInbound {
|
||||
address, ok := s.addresses[request.Key]
|
||||
if !ok {
|
||||
return E.New("key ", request.Key.String(), " not found")
|
||||
}
|
||||
ch := s.conns[address]
|
||||
select {
|
||||
case ch <- conn:
|
||||
default:
|
||||
oldConn := <-ch
|
||||
oldConn.Close()
|
||||
ch <- conn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if request.Command == CommandTCP {
|
||||
source, ok := s.addresses[request.Key]
|
||||
if !ok {
|
||||
return E.New("key ", request.Key, " not found")
|
||||
}
|
||||
if request.Destination.Addr.Is4() && source == request.Destination.Addr.As4() {
|
||||
return E.New("routing loop on ", request.Destination)
|
||||
}
|
||||
metadata.Inbound = s.Tag()
|
||||
metadata.InboundType = C.TypeVPNServer
|
||||
metadata.Source = M.Socksaddr{Addr: netip.AddrFrom4(source)}
|
||||
if request.Destination.Addr.Is4() && request.Destination.Addr.As4() == s.address {
|
||||
metadata.Destination = M.Socksaddr{
|
||||
Addr: Loopback,
|
||||
Port: request.Destination.Port,
|
||||
}
|
||||
} else {
|
||||
metadata.Destination = request.Destination
|
||||
if request.Gateway != s.address && request.Gateway != Loopback.As4() {
|
||||
addr := netip.AddrFrom4(request.Gateway)
|
||||
metadata.Gateway = &addr
|
||||
}
|
||||
}
|
||||
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
return nil
|
||||
}
|
||||
return E.New("command ", request.Command, " not found")
|
||||
}
|
||||
14
protocol/warp/config.go
Normal file
14
protocol/warp/config.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package warp
|
||||
|
||||
import "github.com/sagernet/sing-box/common/cloudflare"
|
||||
|
||||
type Config struct {
|
||||
PrivateKey string `json:"private_key"`
|
||||
Interface struct {
|
||||
Addresses struct {
|
||||
V4 string `json:"v4"`
|
||||
V6 string `json:"v6"`
|
||||
} `json:"addresses"`
|
||||
} `json:"interface"`
|
||||
Peers []cloudflare.Peer
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package wireguard
|
||||
package warp
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/protocol/wireguard"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
@@ -25,19 +25,21 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func RegisterWARPEndpoint(registry *endpoint.Registry) {
|
||||
endpoint.Register[option.WireGuardWARPEndpointOptions](registry, C.TypeWARP, NewWARPEndpoint)
|
||||
func RegisterEndpoint(registry *endpoint.Registry) {
|
||||
endpoint.Register[option.WARPEndpointOptions](registry, C.TypeWARP, NewEndpoint)
|
||||
}
|
||||
|
||||
type WARPEndpoint struct {
|
||||
type Endpoint struct {
|
||||
endpoint.Adapter
|
||||
ctx context.Context
|
||||
options option.WARPEndpointOptions
|
||||
endpoint adapter.Endpoint
|
||||
startHandler func()
|
||||
|
||||
mtx sync.Mutex
|
||||
await chan struct{}
|
||||
}
|
||||
|
||||
func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardWARPEndpointOptions) (adapter.Endpoint, error) {
|
||||
func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WARPEndpointOptions) (adapter.Endpoint, error) {
|
||||
var dependencies []string
|
||||
if options.Detour != "" {
|
||||
dependencies = append(dependencies, options.Detour)
|
||||
@@ -45,14 +47,16 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
|
||||
if options.Profile.Detour != "" {
|
||||
dependencies = append(dependencies, options.Profile.Detour)
|
||||
}
|
||||
warpEndpoint := &WARPEndpoint{
|
||||
endpoint := &Endpoint{
|
||||
Adapter: endpoint.NewAdapter(C.TypeWARP, tag, []string{N.NetworkTCP, N.NetworkUDP}, dependencies),
|
||||
ctx: ctx,
|
||||
options: options,
|
||||
await: make(chan struct{}),
|
||||
}
|
||||
warpEndpoint.mtx.Lock()
|
||||
warpEndpoint.startHandler = func() {
|
||||
defer warpEndpoint.mtx.Unlock()
|
||||
endpoint.startHandler = func() {
|
||||
defer close(endpoint.await)
|
||||
cacheFile := service.FromContext[adapter.CacheFile](ctx)
|
||||
var config *C.WARPConfig
|
||||
var config *Config
|
||||
var err error
|
||||
if !options.Profile.Recreate && cacheFile != nil && cacheFile.StoreWARPConfig() {
|
||||
savedProfile := cacheFile.LoadWARPConfig(tag)
|
||||
@@ -64,50 +68,10 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
|
||||
}
|
||||
}
|
||||
if config == nil {
|
||||
var privateKey wgtypes.Key
|
||||
if options.Profile.PrivateKey != "" {
|
||||
privateKey, err = wgtypes.ParseKey(options.Profile.PrivateKey)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
opts := make([]cloudflare.CloudflareApiOption, 0, 1)
|
||||
if options.Profile.Detour != "" {
|
||||
detour, ok := service.FromContext[adapter.OutboundManager](ctx).Outbound(options.Profile.Detour)
|
||||
if !ok {
|
||||
logger.ErrorContext(ctx, E.New("outbound detour not found: ", options.Profile.Detour))
|
||||
return
|
||||
}
|
||||
opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
}))
|
||||
}
|
||||
api := cloudflare.NewCloudflareApi(opts...)
|
||||
var profile *cloudflare.CloudflareProfile
|
||||
if options.Profile.AuthToken != "" && options.Profile.ID != "" {
|
||||
profile, err = api.GetProfile(ctx, options.Profile.AuthToken, options.Profile.ID)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
profile, err = api.CreateProfile(ctx, privateKey.PublicKey().String())
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
config = &C.WARPConfig{
|
||||
PrivateKey: privateKey.String(),
|
||||
Interface: profile.Config.Interface,
|
||||
Peers: profile.Config.Peers,
|
||||
config, err := endpoint.createConfig()
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
if cacheFile != nil && cacheFile.StoreWARPConfig() {
|
||||
content, err := json.Marshal(config)
|
||||
@@ -124,7 +88,7 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
|
||||
}
|
||||
peer := config.Peers[0]
|
||||
hostParts := strings.Split(peer.Endpoint.Host, ":")
|
||||
warpEndpoint.endpoint, err = NewEndpoint(
|
||||
endpoint.endpoint, err = wireguard.NewEndpoint(
|
||||
ctx,
|
||||
router,
|
||||
logger,
|
||||
@@ -165,19 +129,19 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
if err = warpEndpoint.endpoint.Start(adapter.StartStateStart); err != nil {
|
||||
if err = endpoint.endpoint.Start(adapter.StartStateStart); err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
if err = warpEndpoint.endpoint.Start(adapter.StartStatePostStart); err != nil {
|
||||
if err = endpoint.endpoint.Start(adapter.StartStatePostStart); err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
return warpEndpoint, nil
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
func (w *WARPEndpoint) Start(stage adapter.StartStage) error {
|
||||
func (w *Endpoint) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStatePostStart {
|
||||
return nil
|
||||
}
|
||||
@@ -185,26 +149,79 @@ func (w *WARPEndpoint) Start(stage adapter.StartStage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WARPEndpoint) Close() error {
|
||||
func (w *Endpoint) Close() error {
|
||||
if err := w.isEndpointInitialized(w.ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return common.Close(w.endpoint)
|
||||
}
|
||||
|
||||
func (w *WARPEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if ok := w.isEndpointInitialized(); !ok {
|
||||
return nil, E.New("endpoint not initialized")
|
||||
func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if err := w.isEndpointInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return w.endpoint.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (w *WARPEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if ok := w.isEndpointInitialized(); !ok {
|
||||
return nil, E.New("endpoint not initialized")
|
||||
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if err := w.isEndpointInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return w.endpoint.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (w *WARPEndpoint) isEndpointInitialized() bool {
|
||||
w.mtx.Lock()
|
||||
defer w.mtx.Unlock()
|
||||
return w.endpoint != nil
|
||||
func (w *Endpoint) isEndpointInitialized(ctx context.Context) error {
|
||||
select {
|
||||
case <-w.await:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
if w.endpoint == nil {
|
||||
return E.New("endpoint not initialized")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Endpoint) createConfig() (*Config, error) {
|
||||
var privateKey wgtypes.Key
|
||||
var err error
|
||||
if w.options.Profile.PrivateKey != "" {
|
||||
privateKey, err = wgtypes.ParseKey(w.options.Profile.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
opts := make([]cloudflare.CloudflareApiOption, 0, 1)
|
||||
if w.options.Profile.Detour != "" {
|
||||
detour, ok := service.FromContext[adapter.OutboundManager](w.ctx).Outbound(w.options.Profile.Detour)
|
||||
if !ok {
|
||||
return nil, E.New("outbound detour not found: ", w.options.Profile.Detour)
|
||||
}
|
||||
opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
}))
|
||||
}
|
||||
api := cloudflare.NewCloudflareApi(opts...)
|
||||
var profile *cloudflare.CloudflareProfile
|
||||
if w.options.Profile.AuthToken != "" && w.options.Profile.ID != "" {
|
||||
profile, err = api.GetProfile(w.ctx, w.options.Profile.AuthToken, w.options.Profile.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
profile, err = api.CreateProfile(w.ctx, privateKey.PublicKey().String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Config{
|
||||
PrivateKey: privateKey.String(),
|
||||
Interface: profile.Config.Interface,
|
||||
Peers: profile.Config.Peers,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user