Add tunnel

This commit is contained in:
Shtorm
2025-07-06 18:31:06 +03:00
parent 54b373a73e
commit 3f97424224
29 changed files with 1262 additions and 10 deletions

151
protocol/tunnel/client.go Normal file
View File

@@ -0,0 +1,151 @@
package tunnel
import (
"context"
"net"
"os"
"time"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterClientEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.TunnelClientEndpointOptions](registry, C.TypeTunnelClient, NewClientEndpoint)
}
type ClientEndpoint struct {
outbound.Adapter
ctx context.Context
outbound adapter.Outbound
router adapter.ConnectionRouterEx
logger logger.ContextLogger
uuid uuid.UUID
key uuid.UUID
}
func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelClientEndpointOptions) (adapter.Endpoint, error) {
clientUUID, err := uuid.FromString(options.UUID)
if err != nil {
return nil, err
}
clientKey, err := uuid.FromString(options.Key)
if err != nil {
return nil, err
}
client := &ClientEndpoint{
Adapter: outbound.NewAdapter(C.TypeTunnelClient, tag, []string{N.NetworkTCP}, []string{}),
ctx: ctx,
router: router,
logger: logger,
uuid: clientUUID,
key: clientKey,
}
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
outbound, err := outboundRegistry.CreateOutbound(ctx, router, logger, options.Outbound.Tag, options.Outbound.Type, options.Outbound.Options)
if err != nil {
return nil, err
}
client.outbound = outbound
return client, nil
}
func (c *ClientEndpoint) Start(stage adapter.StartStage) error {
if stage != adapter.StartStatePostStart {
return nil
}
for range 5 {
go func() {
for {
select {
case <-c.ctx.Done():
return
default:
err := c.startInboundConn()
if err != nil {
c.logger.ErrorContext(c.ctx, err)
time.Sleep(time.Second * 5)
}
}
}
}()
}
return nil
}
func (c *ClientEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if network != N.NetworkTCP {
return nil, os.ErrInvalid
}
var destinationUUID *uuid.UUID
if metadata := adapter.ContextFrom(ctx); metadata != nil {
if metadata.TunnelDestination != "" {
uuid, err := uuid.FromString(metadata.TunnelDestination)
if err != nil {
return nil, err
}
destinationUUID = &uuid
}
}
if destinationUUID == nil {
return nil, E.New("tunnel destination not set")
}
if *destinationUUID == c.uuid {
return nil, E.New("routing loop")
}
conn, err := c.outbound.DialContext(ctx, N.NetworkTCP, Destination)
if err != nil {
return nil, err
}
err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandTCP, DestinationUUID: *destinationUUID, Destination: destination})
return conn, err
}
func (c *ClientEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return nil, os.ErrInvalid
}
func (c *ClientEndpoint) Close() error {
return nil
}
func (c *ClientEndpoint) startInboundConn() error {
conn, err := c.outbound.DialContext(c.ctx, N.NetworkTCP, Destination)
if err != nil {
return err
}
err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandInbound, Destination: Destination})
if err != nil {
return err
}
request, err := ReadRequest(conn)
if err != nil {
return err
}
go c.connHandler(conn, request)
return nil
}
func (c *ClientEndpoint) connHandler(conn net.Conn, request *Request) {
metadata := adapter.InboundContext{
Source: M.ParseSocksaddr(conn.RemoteAddr().String()),
Destination: request.Destination,
}
if request.UUID == c.uuid {
c.logger.ErrorContext(c.ctx, "routing loop")
conn.Close()
return
}
metadata.TunnelSource = request.UUID.String()
c.router.RouteConnectionEx(c.ctx, conn, metadata, func(it error) {})
}

View File

@@ -0,0 +1,91 @@
package tunnel
import (
"encoding/binary"
"io"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
const (
Version = 0
)
const (
CommandInbound = 1
CommandTCP = 2
)
var Destination = M.Socksaddr{
Fqdn: "sp.tunnel.sing-box.arpa",
Port: 444,
}
var AddressSerializer = M.NewSerializer(
M.AddressFamilyByte(0x01, M.AddressFamilyIPv4),
M.AddressFamilyByte(0x03, M.AddressFamilyIPv6),
M.AddressFamilyByte(0x02, M.AddressFamilyFqdn),
M.PortThenAddress(),
)
type Request struct {
UUID uuid.UUID
Command byte
DestinationUUID uuid.UUID
Destination M.Socksaddr
}
func ReadRequest(reader io.Reader) (*Request, error) {
var request Request
var version uint8
err := binary.Read(reader, binary.BigEndian, &version)
if err != nil {
return nil, err
}
if version != Version {
return nil, E.New("unknown version: ", version)
}
_, err = io.ReadFull(reader, request.UUID[:])
if err != nil {
return nil, err
}
err = binary.Read(reader, binary.BigEndian, &request.Command)
if err != nil {
return nil, err
}
_, err = io.ReadFull(reader, request.DestinationUUID[:])
if err != nil {
return nil, err
}
request.Destination, err = AddressSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
return &request, nil
}
func WriteRequest(writer io.Writer, request *Request) error {
var requestLen int
requestLen += 1 // version
requestLen += 16 // UUID
requestLen += 16 // destinationUUID
requestLen += 1 // command
requestLen += AddressSerializer.AddrPortLen(request.Destination)
buffer := buf.NewSize(requestLen)
defer buffer.Release()
common.Must(
buffer.WriteByte(Version),
common.Error(buffer.Write(request.UUID[:])),
buffer.WriteByte(request.Command),
common.Error(buffer.Write(request.DestinationUUID[:])),
)
err := AddressSerializer.WriteAddrPort(buffer, request.Destination)
if err != nil {
return err
}
return common.Error(writer.Write(buffer.Bytes()))
}

41
protocol/tunnel/router.go Normal file
View File

@@ -0,0 +1,41 @@
package tunnel
import (
"context"
"net"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
)
type Router struct {
adapter.Router
logger logger.ContextLogger
handler func(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) error
}
func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) error) *Router {
return &Router{Router: router, logger: logger, handler: handler}
}
func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return r.handler(ctx, conn, metadata, func(error) {})
}
func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return os.ErrInvalid
}
func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if err := r.handler(ctx, conn, metadata, onClose); err != nil {
r.logger.ErrorContext(ctx, err)
N.CloseOnHandshakeFailure(conn, onClose, err)
}
}
func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
r.logger.ErrorContext(ctx, os.ErrInvalid)
N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid)
}

203
protocol/tunnel/server.go Normal file
View File

@@ -0,0 +1,203 @@
package tunnel
import (
"context"
"net"
"os"
"sync"
"time"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterServerEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.TunnelServerEndpointOptions](registry, C.TypeTunnelServer, NewServerEndpoint)
}
type ServerEndpoint struct {
outbound.Adapter
logger logger.ContextLogger
inbound adapter.Inbound
router adapter.Router
uuid uuid.UUID
users map[uuid.UUID]uuid.UUID
keys map[uuid.UUID]uuid.UUID
conns map[uuid.UUID]chan net.Conn
timeout time.Duration
mtx sync.Mutex
}
func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelServerEndpointOptions) (adapter.Endpoint, error) {
serverUUID, err := uuid.FromString(options.UUID)
if err != nil {
return nil, err
}
server := &ServerEndpoint{
Adapter: outbound.NewAdapter(C.TypeTunnelServer, tag, []string{N.NetworkTCP}, []string{}),
logger: logger,
router: router,
uuid: serverUUID,
}
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
inbound, err := inboundRegistry.Create(ctx, NewRouter(router, logger, server.connHandler), logger, options.Inbound.Tag, options.Inbound.Type, options.Inbound.Options)
if err != nil {
return nil, err
}
server.inbound = inbound
server.users = make(map[uuid.UUID]uuid.UUID, len(options.Users))
server.keys = make(map[uuid.UUID]uuid.UUID, len(options.Users))
server.conns = make(map[uuid.UUID]chan net.Conn)
for _, user := range options.Users {
key, err := uuid.FromString(user.Key)
if err != nil {
return nil, err
}
uuid, err := uuid.FromString(user.UUID)
if err != nil {
return nil, err
}
server.users[key] = uuid
server.keys[uuid] = key
server.conns[uuid] = make(chan net.Conn, 10)
}
if options.ConnectTimeout != 0 {
server.timeout = time.Duration(options.ConnectTimeout)
} else {
server.timeout = C.TCPConnectTimeout
}
return server, nil
}
func (s *ServerEndpoint) Start(stage adapter.StartStage) error {
return s.inbound.Start(stage)
}
func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if network != N.NetworkTCP {
return nil, os.ErrInvalid
}
var sourceUUID *uuid.UUID
var ch chan net.Conn
if metadata := adapter.ContextFrom(ctx); metadata != nil {
if metadata.TunnelDestination != "" {
tunnelDestination, err := uuid.FromString(metadata.TunnelDestination)
if err != nil {
return nil, err
}
s.mtx.Lock()
var ok bool
ch, ok = s.conns[tunnelDestination]
if !ok {
return nil, E.New("user ", metadata.TunnelDestination, " not found")
}
s.mtx.Unlock()
}
if metadata.TunnelSource != "" {
tunnelSource, err := uuid.FromString(metadata.TunnelSource)
if err != nil {
return nil, err
}
sourceUUID = &tunnelSource
}
}
if ch == nil {
return nil, E.New("tunnel destination not set")
}
if sourceUUID == nil {
sourceUUID = &s.uuid
}
ctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
select {
case conn := <-ch:
err := WriteRequest(conn, &Request{UUID: *sourceUUID, Command: CommandTCP, Destination: destination})
if err != nil {
s.logger.ErrorContext(ctx, err)
continue
}
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
func (s *ServerEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return nil, os.ErrInvalid
}
func (s *ServerEndpoint) Close() error {
return common.Close(s.inbound)
}
func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error {
if metadata.Destination != Destination {
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return nil
}
request, err := ReadRequest(conn)
if err != nil {
return err
}
if request.Command == CommandInbound {
s.mtx.Lock()
defer s.mtx.Unlock()
uuid, ok := s.users[request.UUID]
if !ok {
return E.New("key ", request.UUID.String(), " not found")
}
ch := s.conns[uuid]
select {
case ch <- conn:
default:
oldConn := <-ch
oldConn.Close()
ch <- conn
}
return nil
}
if request.Command == CommandTCP {
sourceUUID, ok := s.users[request.UUID]
if !ok {
return E.New("key ", request.UUID, " not found")
}
if sourceUUID == request.DestinationUUID {
return E.New("routing loop on ", sourceUUID)
}
s.mtx.Lock()
if request.DestinationUUID != s.uuid {
_, ok = s.keys[request.DestinationUUID]
if !ok {
return E.New("user ", sourceUUID, " not found")
}
}
s.mtx.Unlock()
metadata.Inbound = s.Tag()
metadata.InboundType = C.TypeTunnelServer
metadata.Destination = request.Destination
metadata.TunnelSource = sourceUUID.String()
metadata.TunnelDestination = request.DestinationUUID.String()
s.router.RouteConnectionEx(ctx, conn, metadata, onClose)
return nil
}
return E.New("command ", request.Command, " not found")
}