mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-05-14 00:51:12 +03:00
Update sing-box core
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/sagernet/sing-vmess"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
@@ -55,7 +56,7 @@ func newV2RayPlugin(ctx context.Context, pluginOpts Args, router adapter.Router,
|
||||
var tlsClient tls.Config
|
||||
var err error
|
||||
if tlsOptions.Enabled {
|
||||
tlsClient, err = tls.NewClient(ctx, serverAddr.AddrString(), tlsOptions)
|
||||
tlsClient, err = tls.NewClient(ctx, logger.NOP(), serverAddr.AddrString(), tlsOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -91,7 +92,7 @@ func newV2RayPlugin(ctx context.Context, pluginOpts Args, router adapter.Router,
|
||||
return nil, E.New("v2ray-plugin: unknown mode: " + mode)
|
||||
}
|
||||
|
||||
transport, err := v2ray.NewClientTransport(context.Background(), dialer, serverAddr, transportOptions, tlsClient)
|
||||
transport, err := v2ray.NewClientTransport(context.Background(), logger.NOP(), dialer, serverAddr, transportOptions, tlsClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ const (
|
||||
|
||||
var CRLF = []byte{'\r', '\n'}
|
||||
|
||||
var _ N.EarlyConn = (*ClientConn)(nil)
|
||||
var _ N.EarlyWriter = (*ClientConn)(nil)
|
||||
|
||||
type ClientConn struct {
|
||||
N.ExtendedConn
|
||||
@@ -43,7 +43,7 @@ func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) NeedHandshake() bool {
|
||||
func (c *ClientConn) NeedHandshakeForWrite() bool {
|
||||
return !c.headerWritten
|
||||
}
|
||||
|
||||
@@ -83,6 +83,14 @@ func (c *ClientConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func (c *ClientConn) ReaderReplaceable() bool {
|
||||
return c.headerWritten
|
||||
}
|
||||
|
||||
func (c *ClientConn) WriterReplaceable() bool {
|
||||
return c.headerWritten
|
||||
}
|
||||
|
||||
type ClientPacketConn struct {
|
||||
net.Conn
|
||||
access sync.Mutex
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/v2rayhttp"
|
||||
"github.com/sagernet/sing-box/transport/v2rayhttpupgrade"
|
||||
@@ -50,7 +51,7 @@ func NewServerTransport(ctx context.Context, logger logger.ContextLogger, option
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayTransportOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
|
||||
func NewClientTransport(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayTransportOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
|
||||
if options.Type == "" {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -69,7 +70,7 @@ func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socks
|
||||
case C.V2RayTransportTypeHTTPUpgrade:
|
||||
return v2rayhttpupgrade.NewClient(ctx, dialer, serverAddr, options.HTTPUpgradeOptions, tlsConfig)
|
||||
case C.V2RayTransportTypeXHTTP:
|
||||
return xhttp.NewClient(ctx, dialer, serverAddr, options.XHTTPOptions, tlsConfig)
|
||||
return xhttp.NewClient(ctx, logger, dialer, serverAddr, options.XHTTPOptions, tlsConfig)
|
||||
case C.V2RayTransportTypeKCP:
|
||||
return v2raykcp.NewClient(ctx, dialer, serverAddr, options.KCPOptions, tlsConfig)
|
||||
default:
|
||||
|
||||
@@ -106,7 +106,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
cancel(err)
|
||||
return nil, err
|
||||
}
|
||||
return NewGRPCConn(stream), nil
|
||||
return NewGRPCConn(stream, cancel), nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package v2raygrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/baderror"
|
||||
@@ -14,16 +16,19 @@ var _ net.Conn = (*GRPCConn)(nil)
|
||||
|
||||
type GRPCConn struct {
|
||||
GunService
|
||||
cache []byte
|
||||
cache []byte
|
||||
cancel context.CancelCauseFunc
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewGRPCConn(service GunService) *GRPCConn {
|
||||
func NewGRPCConn(service GunService, cancel context.CancelCauseFunc) *GRPCConn {
|
||||
//nolint:staticcheck
|
||||
if client, isClient := service.(GunService_TunClient); isClient {
|
||||
service = &clientConnWrapper{client}
|
||||
}
|
||||
return &GRPCConn{
|
||||
GunService: service,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +59,11 @@ func (c *GRPCConn) Write(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *GRPCConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
if c.cancel != nil {
|
||||
c.cancel(nil)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
|
||||
}
|
||||
|
||||
func (s *Server) Tun(server GunService_TunServer) error {
|
||||
conn := NewGRPCConn(server)
|
||||
conn := NewGRPCConn(server, nil)
|
||||
var source M.Socksaddr
|
||||
if remotePeer, loaded := peer.FromContext(server.Context()); loaded {
|
||||
source = M.SocksaddrFromNet(remotePeer.Addr)
|
||||
|
||||
@@ -61,7 +61,7 @@ type GunServiceServer interface {
|
||||
type UnimplementedGunServiceServer struct{}
|
||||
|
||||
func (UnimplementedGunServiceServer) Tun(grpc.BidiStreamingServer[Hunk, Hunk]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Tun not implemented")
|
||||
return status.Error(codes.Unimplemented, "method Tun not implemented")
|
||||
}
|
||||
func (UnimplementedGunServiceServer) mustEmbedUnimplementedGunServiceServer() {}
|
||||
func (UnimplementedGunServiceServer) testEmbeddedByValue() {}
|
||||
@@ -74,7 +74,7 @@ type UnsafeGunServiceServer interface {
|
||||
}
|
||||
|
||||
func RegisterGunServiceServer(s grpc.ServiceRegistrar, srv GunServiceServer) {
|
||||
// If the following call pancis, it indicates UnimplementedGunServiceServer was
|
||||
// If the following call panics, it indicates UnimplementedGunServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
|
||||
@@ -29,7 +29,6 @@ var defaultClientHeader = http.Header{
|
||||
|
||||
type Client struct {
|
||||
ctx context.Context
|
||||
dialer N.Dialer
|
||||
serverAddr M.Socksaddr
|
||||
transport *http2.Transport
|
||||
options option.V2RayGRPCOptions
|
||||
@@ -46,7 +45,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
}
|
||||
client := &Client{
|
||||
ctx: ctx,
|
||||
dialer: dialer,
|
||||
serverAddr: serverAddr,
|
||||
options: options,
|
||||
transport: &http2.Transport{
|
||||
@@ -62,7 +60,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
},
|
||||
host: host,
|
||||
}
|
||||
|
||||
if tlsConfig == nil {
|
||||
client.transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
@@ -71,12 +68,9 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
if len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
|
||||
}
|
||||
tlsDialer := tls.NewDialer(dialer, tlsConfig)
|
||||
client.transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tls.ClientHandshake(ctx, conn, tlsConfig)
|
||||
return tlsDialer.DialTLSContext(ctx, M.ParseSocksaddr(addr))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -47,15 +47,12 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
if len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
|
||||
}
|
||||
tlsDialer := tls.NewDialer(dialer, tlsConfig)
|
||||
transport = &http2.Transport{
|
||||
ReadIdleTimeout: time.Duration(options.IdleTimeout),
|
||||
PingTimeout: time.Duration(options.PingTimeout),
|
||||
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.STDConfig) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tls.ClientHandshake(ctx, conn, tlsConfig)
|
||||
return tlsDialer.DialTLSContext(ctx, M.ParseSocksaddr(addr))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,10 +136,12 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
s.handler.NewConnectionEx(HWIDContext(DupContext(request.Context()), request.Header), conn, source, M.Socksaddr{}, nil)
|
||||
} else {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
flusher := writer.(http.Flusher)
|
||||
flusher.Flush()
|
||||
done := make(chan struct{})
|
||||
conn := NewHTTP2Wrapper(&ServerHTTPConn{
|
||||
NewHTTPConn(request.Body, writer),
|
||||
writer.(http.Flusher),
|
||||
flusher,
|
||||
})
|
||||
s.handler.NewConnectionEx(HWIDContext(request.Context(), request.Header), conn, source, M.Socksaddr{}, N.OnceClose(func(it error) {
|
||||
close(done)
|
||||
|
||||
@@ -23,7 +23,6 @@ var _ adapter.V2RayClientTransport = (*Client)(nil)
|
||||
|
||||
type Client struct {
|
||||
dialer N.Dialer
|
||||
tlsConfig tls.Config
|
||||
serverAddr M.Socksaddr
|
||||
requestURL url.URL
|
||||
headers http.Header
|
||||
@@ -35,6 +34,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
if len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{"http/1.1"})
|
||||
}
|
||||
dialer = tls.NewDialer(dialer, tlsConfig)
|
||||
}
|
||||
var host string
|
||||
if options.Host != "" {
|
||||
@@ -65,7 +65,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
}
|
||||
return &Client{
|
||||
dialer: dialer,
|
||||
tlsConfig: tlsConfig,
|
||||
serverAddr: serverAddr,
|
||||
requestURL: requestURL,
|
||||
headers: headers,
|
||||
@@ -78,12 +77,6 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.tlsConfig != nil {
|
||||
conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
request := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &c.requestURL,
|
||||
|
||||
@@ -29,7 +29,7 @@ type Client struct {
|
||||
tlsConfig tls.Config
|
||||
quicConfig *quic.Config
|
||||
connAccess sync.Mutex
|
||||
conn common.TypedValue[quic.Connection]
|
||||
conn common.TypedValue[*quic.Conn]
|
||||
rawConn net.Conn
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) offer() (quic.Connection, error) {
|
||||
func (c *Client) offer() (*quic.Conn, error) {
|
||||
conn := c.conn.Load()
|
||||
if conn != nil && !common.Done(conn.Context()) {
|
||||
return conn, nil
|
||||
@@ -67,7 +67,7 @@ func (c *Client) offer() (quic.Connection, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) offerNew() (quic.Connection, error) {
|
||||
func (c *Client) offerNew() (*quic.Conn, error) {
|
||||
udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -84,11 +84,11 @@ func (s *Server) acceptLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) streamAcceptLoop(conn quic.Connection) error {
|
||||
func (s *Server) streamAcceptLoop(conn *quic.Conn) error {
|
||||
for {
|
||||
stream, err := conn.AcceptStream(s.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return qtls.WrapError(err)
|
||||
}
|
||||
go s.handler.NewConnectionEx(conn.Context(), &StreamWrapper{Conn: conn, Stream: stream}, M.SocksaddrFromNet(conn.RemoteAddr()), M.Socksaddr{}, nil)
|
||||
}
|
||||
|
||||
@@ -4,24 +4,22 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing/common/baderror"
|
||||
qtls "github.com/sagernet/sing-quic"
|
||||
)
|
||||
|
||||
type StreamWrapper struct {
|
||||
Conn quic.Connection
|
||||
quic.Stream
|
||||
Conn *quic.Conn
|
||||
*quic.Stream
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) Read(p []byte) (n int, err error) {
|
||||
n, err = s.Stream.Read(p)
|
||||
//nolint:staticcheck
|
||||
return n, baderror.WrapQUIC(err)
|
||||
return n, qtls.WrapError(err)
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) Write(p []byte) (n int, err error) {
|
||||
n, err = s.Stream.Write(p)
|
||||
//nolint:staticcheck
|
||||
return n, baderror.WrapQUIC(err)
|
||||
return n, qtls.WrapError(err)
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) LocalAddr() net.Addr {
|
||||
|
||||
@@ -26,7 +26,6 @@ var _ adapter.V2RayClientTransport = (*Client)(nil)
|
||||
|
||||
type Client struct {
|
||||
dialer N.Dialer
|
||||
tlsConfig tls.Config
|
||||
serverAddr M.Socksaddr
|
||||
requestURL url.URL
|
||||
headers http.Header
|
||||
@@ -39,6 +38,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
if len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{"http/1.1"})
|
||||
}
|
||||
dialer = tls.NewDialer(dialer, tlsConfig)
|
||||
}
|
||||
var requestURL url.URL
|
||||
if tlsConfig == nil {
|
||||
@@ -65,7 +65,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
}
|
||||
return &Client{
|
||||
dialer,
|
||||
tlsConfig,
|
||||
serverAddr,
|
||||
requestURL,
|
||||
headers,
|
||||
@@ -79,12 +78,6 @@ func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers h
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.tlsConfig != nil {
|
||||
conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var deadlineConn net.Conn
|
||||
if deadline.NeedAdditionalReadDeadline(conn) {
|
||||
deadlineConn = deadline.NewConn(conn)
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/sagernet/sing-box/common/xray/pipe"
|
||||
"github.com/sagernet/sing-box/common/xray/signal/done"
|
||||
"github.com/sagernet/sing-box/common/xray/uuid"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
qtls "github.com/sagernet/sing-quic"
|
||||
"github.com/sagernet/sing/common"
|
||||
@@ -43,7 +44,7 @@ type Client struct {
|
||||
getHTTPClient2 func() (DialerClient, *XmuxClient)
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayXHTTPOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
|
||||
func NewClient(ctx context.Context, logger log.ContextLogger, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayXHTTPOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
|
||||
if options.Mode == "" {
|
||||
return nil, E.New("mode is not set")
|
||||
}
|
||||
@@ -78,7 +79,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
||||
dest2 := options2.ServerOptions.Build()
|
||||
var tlsConfig2 tls.Config
|
||||
if options2.TLS != nil {
|
||||
tlsConfig2, err = tls.NewClient(ctx, options2.Server, common.PtrValueOrDefault(options2.TLS))
|
||||
tlsConfig2, err = tls.NewClient(ctx, logger, options2.Server, common.PtrValueOrDefault(options2.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -315,7 +316,7 @@ func createHTTPClient(dest M.Socksaddr, dialer N.Dialer, options *option.V2RayXH
|
||||
}
|
||||
transport = &http3.Transport{
|
||||
QUICConfig: quicConfig,
|
||||
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
udpConn, dErr := dialer.DialContext(ctx, N.NetworkUDP, dest)
|
||||
if dErr != nil {
|
||||
return nil, dErr
|
||||
|
||||
@@ -162,7 +162,7 @@ func (c *ClientBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint, offset int) error {
|
||||
udpConn, err := c.connect()
|
||||
if err != nil {
|
||||
c.pauseManager.WaitActive()
|
||||
@@ -170,15 +170,18 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||
return err
|
||||
}
|
||||
destination := netip.AddrPort(ep.(remoteEndpoint))
|
||||
for _, b := range bufs {
|
||||
if len(b) > 3 {
|
||||
for _, buf := range bufs {
|
||||
if offset > 0 {
|
||||
buf = buf[offset:]
|
||||
}
|
||||
if len(buf) > 3 {
|
||||
reserved, loaded := c.reservedForEndpoint[destination]
|
||||
if !loaded {
|
||||
reserved = c.reserved
|
||||
}
|
||||
copy(b[1:4], reserved[:])
|
||||
copy(buf[1:4], reserved[:])
|
||||
}
|
||||
_, err = udpConn.WriteToUDPAddrPort(b, destination)
|
||||
_, err = udpConn.WriteToUDPAddrPort(buf, destination)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
return err
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
@@ -17,6 +18,8 @@ type Device interface {
|
||||
N.Dialer
|
||||
Start() error
|
||||
SetDevice(device *device.Device)
|
||||
Inet4Address() netip.Addr
|
||||
Inet6Address() netip.Addr
|
||||
}
|
||||
|
||||
type DeviceOptions struct {
|
||||
@@ -35,9 +38,14 @@ type DeviceOptions struct {
|
||||
func NewDevice(options DeviceOptions) (Device, error) {
|
||||
if !options.System {
|
||||
return newStackDevice(options)
|
||||
} else if options.Handler == nil {
|
||||
} else if !tun.WithGVisor {
|
||||
return newSystemDevice(options)
|
||||
} else {
|
||||
return newSystemStackDevice(options)
|
||||
}
|
||||
}
|
||||
|
||||
type NatDevice interface {
|
||||
Device
|
||||
CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error)
|
||||
}
|
||||
|
||||
103
transport/wireguard/device_nat.go
Normal file
103
transport/wireguard/device_nat.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-tun/ping"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
var _ Device = (*natDeviceWrapper)(nil)
|
||||
|
||||
type natDeviceWrapper struct {
|
||||
Device
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
packetOutbound chan *buf.Buffer
|
||||
rewriter *ping.SourceRewriter
|
||||
buffer [][]byte
|
||||
}
|
||||
|
||||
func NewNATDevice(ctx context.Context, logger logger.ContextLogger, upstream Device) NatDevice {
|
||||
wrapper := &natDeviceWrapper{
|
||||
Device: upstream,
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
packetOutbound: make(chan *buf.Buffer, 256),
|
||||
rewriter: ping.NewSourceRewriter(ctx, logger, upstream.Inet4Address(), upstream.Inet6Address()),
|
||||
}
|
||||
return wrapper
|
||||
}
|
||||
|
||||
func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
select {
|
||||
case packet := <-d.packetOutbound:
|
||||
defer packet.Release()
|
||||
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
|
||||
return 1, nil
|
||||
default:
|
||||
}
|
||||
return d.Device.Read(bufs, sizes, offset)
|
||||
}
|
||||
|
||||
func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
|
||||
for _, buffer := range bufs {
|
||||
handled, err := d.rewriter.WriteBack(buffer[offset:])
|
||||
if handled {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
d.buffer = append(d.buffer, buffer)
|
||||
}
|
||||
}
|
||||
if len(d.buffer) > 0 {
|
||||
_, err := d.Device.Write(d.buffer, offset)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
d.buffer = d.buffer[:0]
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *natDeviceWrapper) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
|
||||
ctx := log.ContextWithNewID(d.ctx)
|
||||
session := tun.DirectRouteSession{
|
||||
Source: metadata.Source.Addr,
|
||||
Destination: metadata.Destination.Addr,
|
||||
}
|
||||
d.rewriter.CreateSession(session, routeContext)
|
||||
d.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
|
||||
return &natDestination{device: d, session: session}, nil
|
||||
}
|
||||
|
||||
var _ tun.DirectRouteDestination = (*natDestination)(nil)
|
||||
|
||||
type natDestination struct {
|
||||
device *natDeviceWrapper
|
||||
session tun.DirectRouteSession
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (d *natDestination) WritePacket(buffer *buf.Buffer) error {
|
||||
d.device.rewriter.RewritePacket(buffer.Bytes())
|
||||
d.device.packetOutbound <- buffer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *natDestination) Close() error {
|
||||
d.closed.Store(true)
|
||||
d.device.rewriter.DeleteSession(d.session)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *natDestination) IsClosed() bool {
|
||||
return d.closed.Load()
|
||||
}
|
||||
@@ -5,7 +5,9 @@ package wireguard
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
@@ -14,9 +16,14 @@ import (
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-tun/ping"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
@@ -24,30 +31,40 @@ import (
|
||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||
)
|
||||
|
||||
var _ Device = (*stackDevice)(nil)
|
||||
var _ NatDevice = (*stackDevice)(nil)
|
||||
|
||||
type stackDevice struct {
|
||||
stack *stack.Stack
|
||||
mtu uint32
|
||||
events chan wgTun.Event
|
||||
outbound chan *stack.PacketBuffer
|
||||
done chan struct{}
|
||||
dispatcher stack.NetworkDispatcher
|
||||
addr4 tcpip.Address
|
||||
addr6 tcpip.Address
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
stack *stack.Stack
|
||||
mtu uint32
|
||||
events chan wgTun.Event
|
||||
outbound chan *stack.PacketBuffer
|
||||
packetOutbound chan *buf.Buffer
|
||||
done chan struct{}
|
||||
dispatcher stack.NetworkDispatcher
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
}
|
||||
|
||||
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
||||
tunDevice := &stackDevice{
|
||||
mtu: options.MTU,
|
||||
events: make(chan wgTun.Event, 1),
|
||||
outbound: make(chan *stack.PacketBuffer, 256),
|
||||
done: make(chan struct{}),
|
||||
ctx: options.Context,
|
||||
logger: options.Logger,
|
||||
mtu: options.MTU,
|
||||
events: make(chan wgTun.Event, 1),
|
||||
outbound: make(chan *stack.PacketBuffer, 256),
|
||||
packetOutbound: make(chan *buf.Buffer, 256),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
|
||||
ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var (
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
)
|
||||
for _, prefix := range options.Address {
|
||||
addr := tun.AddressFromAddr(prefix.Addr())
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
@@ -57,10 +74,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
||||
},
|
||||
}
|
||||
if prefix.Addr().Is4() {
|
||||
tunDevice.addr4 = addr
|
||||
inet4Address = prefix.Addr()
|
||||
tunDevice.inet4Address = inet4Address
|
||||
protoAddr.Protocol = ipv4.ProtocolNumber
|
||||
} else {
|
||||
tunDevice.addr6 = addr
|
||||
inet6Address = prefix.Addr()
|
||||
tunDevice.inet6Address = inet6Address
|
||||
protoAddr.Protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
|
||||
@@ -72,6 +91,10 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
||||
if options.Handler != nil {
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||
icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
|
||||
icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
|
||||
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
|
||||
}
|
||||
return tunDevice, nil
|
||||
}
|
||||
@@ -87,11 +110,17 @@ func (w *stackDevice) DialContext(ctx context.Context, network string, destinati
|
||||
}
|
||||
var networkProtocol tcpip.NetworkProtocolNumber
|
||||
if destination.IsIPv4() {
|
||||
if !w.inet4Address.IsValid() {
|
||||
return nil, E.New("missing IPv4 local address")
|
||||
}
|
||||
networkProtocol = header.IPv4ProtocolNumber
|
||||
bind.Addr = w.addr4
|
||||
bind.Addr = tun.AddressFromAddr(w.inet4Address)
|
||||
} else {
|
||||
if !w.inet6Address.IsValid() {
|
||||
return nil, E.New("missing IPv6 local address")
|
||||
}
|
||||
networkProtocol = header.IPv6ProtocolNumber
|
||||
bind.Addr = w.addr6
|
||||
bind.Addr = tun.AddressFromAddr(w.inet6Address)
|
||||
}
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
@@ -118,10 +147,10 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
|
||||
var networkProtocol tcpip.NetworkProtocolNumber
|
||||
if destination.IsIPv4() {
|
||||
networkProtocol = header.IPv4ProtocolNumber
|
||||
bind.Addr = w.addr4
|
||||
bind.Addr = tun.AddressFromAddr(w.inet4Address)
|
||||
} else {
|
||||
networkProtocol = header.IPv6ProtocolNumber
|
||||
bind.Addr = w.addr6
|
||||
bind.Addr = tun.AddressFromAddr(w.inet4Address)
|
||||
}
|
||||
udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
|
||||
if err != nil {
|
||||
@@ -130,6 +159,14 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func (w *stackDevice) Inet4Address() netip.Addr {
|
||||
return w.inet4Address
|
||||
}
|
||||
|
||||
func (w *stackDevice) Inet6Address() netip.Addr {
|
||||
return w.inet6Address
|
||||
}
|
||||
|
||||
func (w *stackDevice) SetDevice(device *device.Device) {
|
||||
}
|
||||
|
||||
@@ -144,20 +181,24 @@ func (w *stackDevice) File() *os.File {
|
||||
|
||||
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
select {
|
||||
case packetBuffer, ok := <-w.outbound:
|
||||
case packet, ok := <-w.outbound:
|
||||
if !ok {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
defer packetBuffer.DecRef()
|
||||
p := bufs[0]
|
||||
p = p[offset:]
|
||||
n := 0
|
||||
for _, slice := range packetBuffer.AsSlices() {
|
||||
n += copy(p[n:], slice)
|
||||
defer packet.DecRef()
|
||||
var copyN int
|
||||
/*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) {
|
||||
copyN += copy(bufs[0][offset+copyN:], view.AsSlice())
|
||||
})*/
|
||||
for _, view := range packet.AsSlices() {
|
||||
copyN += copy(bufs[0][offset+copyN:], view)
|
||||
}
|
||||
sizes[0] = n
|
||||
count = 1
|
||||
return
|
||||
sizes[0] = copyN
|
||||
return 1, nil
|
||||
case packet := <-w.packetOutbound:
|
||||
defer packet.Release()
|
||||
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
|
||||
return 1, nil
|
||||
case <-w.done:
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
@@ -217,6 +258,23 @@ func (w *stackDevice) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
|
||||
ctx := log.ContextWithNewID(w.ctx)
|
||||
destination, err := ping.ConnectGVisor(
|
||||
ctx, w.logger,
|
||||
metadata.Source.Addr, metadata.Destination.Addr,
|
||||
routeContext,
|
||||
w.stack,
|
||||
w.inet4Address, w.inet6Address,
|
||||
timeout,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
|
||||
return destination, nil
|
||||
}
|
||||
|
||||
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
|
||||
|
||||
type wireEndpoint stackDevice
|
||||
|
||||
@@ -22,22 +22,42 @@ import (
|
||||
var _ Device = (*systemDevice)(nil)
|
||||
|
||||
type systemDevice struct {
|
||||
options DeviceOptions
|
||||
dialer N.Dialer
|
||||
device tun.Tun
|
||||
batchDevice tun.LinuxTUN
|
||||
events chan wgTun.Event
|
||||
closeOnce sync.Once
|
||||
options DeviceOptions
|
||||
dialer N.Dialer
|
||||
device tun.Tun
|
||||
batchDevice tun.LinuxTUN
|
||||
events chan wgTun.Event
|
||||
closeOnce sync.Once
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
}
|
||||
|
||||
func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
|
||||
if options.Name == "" {
|
||||
options.Name = tun.CalculateInterfaceName("wg")
|
||||
}
|
||||
var inet4Address netip.Addr
|
||||
var inet6Address netip.Addr
|
||||
if len(options.Address) > 0 {
|
||||
if prefix := common.Find(options.Address, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is4()
|
||||
}); prefix.IsValid() {
|
||||
inet4Address = prefix.Addr()
|
||||
}
|
||||
}
|
||||
if len(options.Address) > 0 {
|
||||
if prefix := common.Find(options.Address, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is6()
|
||||
}); prefix.IsValid() {
|
||||
inet6Address = prefix.Addr()
|
||||
}
|
||||
}
|
||||
return &systemDevice{
|
||||
options: options,
|
||||
dialer: options.CreateDialer(options.Name),
|
||||
events: make(chan wgTun.Event, 1),
|
||||
options: options,
|
||||
dialer: options.CreateDialer(options.Name),
|
||||
events: make(chan wgTun.Event, 1),
|
||||
inet4Address: inet4Address,
|
||||
inet6Address: inet6Address,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -49,6 +69,14 @@ func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr
|
||||
return w.dialer.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (w *systemDevice) Inet4Address() netip.Addr {
|
||||
return w.inet4Address
|
||||
}
|
||||
|
||||
func (w *systemDevice) Inet6Address() netip.Addr {
|
||||
return w.inet6Address
|
||||
}
|
||||
|
||||
func (w *systemDevice) SetDevice(device *device.Device) {
|
||||
}
|
||||
|
||||
|
||||
@@ -3,16 +3,26 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-tun/ping"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
)
|
||||
|
||||
@@ -20,6 +30,8 @@ var _ Device = (*systemStackDevice)(nil)
|
||||
|
||||
type systemStackDevice struct {
|
||||
*systemDevice
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
stack *stack.Stack
|
||||
endpoint *deviceEndpoint
|
||||
writeBufs [][]byte
|
||||
@@ -34,13 +46,45 @@ func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
|
||||
mtu: options.MTU,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
ipStack, err := tun.NewGVisorStack(endpoint)
|
||||
ipStack, err := tun.NewGVisorStackWithOptions(endpoint, stack.NICOptions{}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||
var (
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
)
|
||||
for _, prefix := range options.Address {
|
||||
addr := tun.AddressFromAddr(prefix.Addr())
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: addr,
|
||||
PrefixLen: prefix.Bits(),
|
||||
},
|
||||
}
|
||||
if prefix.Addr().Is4() {
|
||||
inet4Address = prefix.Addr()
|
||||
protoAddr.Protocol = ipv4.ProtocolNumber
|
||||
} else {
|
||||
inet6Address = prefix.Addr()
|
||||
protoAddr.Protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
|
||||
if gErr != nil {
|
||||
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
|
||||
}
|
||||
}
|
||||
if options.Handler != nil {
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||
icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
|
||||
icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
|
||||
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
|
||||
}
|
||||
return &systemStackDevice{
|
||||
ctx: options.Context,
|
||||
logger: options.Logger,
|
||||
systemDevice: system,
|
||||
stack: ipStack,
|
||||
endpoint: endpoint,
|
||||
@@ -116,6 +160,23 @@ func (w *systemStackDevice) writeStack(packet []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
|
||||
ctx := log.ContextWithNewID(w.ctx)
|
||||
destination, err := ping.ConnectGVisor(
|
||||
ctx, w.logger,
|
||||
metadata.Source.Addr, metadata.Destination.Addr,
|
||||
routeContext,
|
||||
w.stack,
|
||||
w.inet4Address, w.inet6Address,
|
||||
timeout,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
|
||||
return destination, nil
|
||||
}
|
||||
|
||||
type deviceEndpoint struct {
|
||||
mtu uint32
|
||||
done chan struct{}
|
||||
|
||||
@@ -8,9 +8,15 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
@@ -30,7 +36,9 @@ type Endpoint struct {
|
||||
ipcConf string
|
||||
allowedAddress []netip.Prefix
|
||||
tunDevice Device
|
||||
natDevice NatDevice
|
||||
device *device.Device
|
||||
allowedIPs *device.AllowedIPs
|
||||
pause pause.Manager
|
||||
pauseCallback *list.Element[pause.Callback]
|
||||
}
|
||||
@@ -112,12 +120,17 @@ func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create WireGuard device")
|
||||
}
|
||||
natDevice, isNatDevice := tunDevice.(NatDevice)
|
||||
if !isNatDevice {
|
||||
natDevice = NewNATDevice(options.Context, options.Logger, tunDevice)
|
||||
}
|
||||
return &Endpoint{
|
||||
options: options,
|
||||
peers: peers,
|
||||
ipcConf: ipcConf,
|
||||
allowedAddress: allowedAddresses,
|
||||
tunDevice: tunDevice,
|
||||
natDevice: natDevice,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -142,9 +155,9 @@ func (e *Endpoint) Start(resolve bool) error {
|
||||
return nil
|
||||
}
|
||||
var bind conn.Bind
|
||||
wgListener, isWgListener := common.Cast[conn.Listener](e.options.Dialer)
|
||||
wgListener, isWgListener := common.Cast[dialer.WireGuardListener](e.options.Dialer)
|
||||
if isWgListener {
|
||||
bind = conn.NewStdNetBind(wgListener)
|
||||
bind = conn.NewStdNetBind(wgListener.WireGuardControl())
|
||||
} else {
|
||||
var (
|
||||
isConnect bool
|
||||
@@ -177,7 +190,13 @@ func (e *Endpoint) Start(resolve bool) error {
|
||||
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
}
|
||||
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers, e.options.PreallocatedBuffersPerPool, e.options.DisablePauses)
|
||||
var deviceInput Device
|
||||
if e.natDevice != nil {
|
||||
deviceInput = e.natDevice
|
||||
} else {
|
||||
deviceInput = e.tunDevice
|
||||
}
|
||||
wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers, e.options.PreallocatedBuffersPerPool, e.options.DisablePauses)
|
||||
e.tunDevice.SetDevice(wgDevice)
|
||||
ipcConf := e.ipcConf
|
||||
if e.options.Amnezia != nil {
|
||||
@@ -254,6 +273,7 @@ func (e *Endpoint) Start(resolve bool) error {
|
||||
if e.pause != nil {
|
||||
e.pauseCallback = e.pause.RegisterCallback(e.onPauseUpdated)
|
||||
}
|
||||
e.allowedIPs = (*device.AllowedIPs)(unsafe.Pointer(reflect.Indirect(reflect.ValueOf(wgDevice)).FieldByName("allowedips").UnsafeAddr()))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -281,6 +301,20 @@ func (e *Endpoint) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) Lookup(address netip.Addr) *device.Peer {
|
||||
if e.allowedIPs == nil {
|
||||
return nil
|
||||
}
|
||||
return e.allowedIPs.Lookup(address.AsSlice())
|
||||
}
|
||||
|
||||
func (e *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
|
||||
if e.natDevice == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
return e.natDevice.CreateDestination(metadata, routeContext, timeout)
|
||||
}
|
||||
|
||||
func (e *Endpoint) onPauseUpdated(event int) {
|
||||
switch event {
|
||||
case pause.EventDevicePaused, pause.EventNetworkPause:
|
||||
|
||||
Reference in New Issue
Block a user