mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-26 20:29:03 +03:00
Add Snell protocol. Refactor MASQUE HTTP/2, Fair Queue. Update XHTTP, OpenVPN, Sudoku, Fallback. Fixes
This commit is contained in:
331
transport/masque/client_h2.go
Normal file
331
transport/masque/client_h2.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package masque
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/sagernet/quic-go/quicvarint"
|
||||
"github.com/yosida95/uritemplate/v3"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const h2DatagramCapsuleType uint64 = 0
|
||||
|
||||
const (
|
||||
ipv4HeaderLen = 20
|
||||
ipv6HeaderLen = 40
|
||||
)
|
||||
|
||||
func ConnectTunnelH2(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, endpoint *net.TCPAddr, connectUri string) (io.Closer, IpConn, *http.Response, error) {
|
||||
if endpoint == nil {
|
||||
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
}
|
||||
|
||||
tlsConfig.SetNextProtos([]string{"h2"})
|
||||
|
||||
conn, err := dialer.DialContext(ctx, N.NetworkTCP, M.SocksaddrFromNetIP(endpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
tlsConn, err := tlsConfig.Client(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if err = tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
tr := &http2.Transport{
|
||||
ReadIdleTimeout: 30 * time.Second,
|
||||
}
|
||||
cc, err := tr.NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
_ = tlsConn.Close()
|
||||
return nil, nil, nil, fmt.Errorf("connect-ip: failed to create client connection: %w", err)
|
||||
}
|
||||
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
}
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
|
||||
h2Headers := additionalHeaders.Clone()
|
||||
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
|
||||
h2Headers.Set("pq-enabled", "false")
|
||||
|
||||
ipConn, rsp, err := dialH2(ctx, cc, template, h2Headers)
|
||||
if err != nil {
|
||||
_ = cc.Close()
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
|
||||
}
|
||||
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
_ = ipConn.Close()
|
||||
_ = cc.Close()
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status)
|
||||
}
|
||||
|
||||
return cc, ipConn, rsp, nil
|
||||
}
|
||||
|
||||
func dialH2(ctx context.Context, rt http.RoundTripper, template *uritemplate.Template, additionalHeaders http.Header) (*h2IpConn, *http.Response, error) {
|
||||
if len(template.Varnames()) > 0 {
|
||||
return nil, nil, errors.New("connect-ip: IP flow forwarding not supported")
|
||||
}
|
||||
|
||||
u, err := url.Parse(template.Raw())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to parse URI: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, u.String(), pr)
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to create request: %w", err)
|
||||
}
|
||||
req.Host = authorityFromURL(u)
|
||||
req.ContentLength = -1
|
||||
req.Header = make(http.Header)
|
||||
for k, v := range additionalHeaders {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
stop := context.AfterFunc(ctx, cancel)
|
||||
rsp, err := rt.RoundTrip(req)
|
||||
stop()
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
return nil, nil, fmt.Errorf("connect-ip: failed to send request: %w", err)
|
||||
}
|
||||
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
|
||||
cancel()
|
||||
_ = pr.Close()
|
||||
_ = pw.Close()
|
||||
_ = rsp.Body.Close()
|
||||
return nil, rsp, fmt.Errorf("connect-ip: server responded with %d", rsp.StatusCode)
|
||||
}
|
||||
|
||||
stream := &h2DatagramStream{
|
||||
requestBody: pw,
|
||||
responseBody: rsp.Body,
|
||||
cancel: cancel,
|
||||
}
|
||||
return &h2IpConn{
|
||||
str: stream,
|
||||
closeChan: make(chan struct{}),
|
||||
}, rsp, nil
|
||||
}
|
||||
|
||||
func authorityFromURL(u *url.URL) string {
|
||||
if u.Port() != "" {
|
||||
return u.Host
|
||||
}
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return u.Host
|
||||
}
|
||||
return host + ":443"
|
||||
}
|
||||
|
||||
type h2IpConn struct {
|
||||
str *h2DatagramStream
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
closeChan chan struct{}
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (c *h2IpConn) ReadPacket() (b []byte, err error) {
|
||||
start:
|
||||
data, err := c.str.ReceiveDatagram(context.Background())
|
||||
if err != nil {
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
select {
|
||||
case <-c.closeChan:
|
||||
return nil, c.closeErr
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := c.handleIncomingProxiedPacket(data); err != nil {
|
||||
goto start
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) handleIncomingProxiedPacket(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return errors.New("connect-ip: empty packet")
|
||||
}
|
||||
switch v := ipVersion(data); v {
|
||||
default:
|
||||
return fmt.Errorf("connect-ip: unknown IP versions: %d", v)
|
||||
case 4:
|
||||
if len(data) < ipv4HeaderLen {
|
||||
return fmt.Errorf("connect-ip: malformed datagram: too short")
|
||||
}
|
||||
case 6:
|
||||
if len(data) < ipv6HeaderLen {
|
||||
return fmt.Errorf("connect-ip: malformed datagram: too short")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) WritePacket(b []byte) (icmp []byte, err error) {
|
||||
data, err := c.composeDatagram(b)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err := c.str.SendDatagram(data); err != nil {
|
||||
select {
|
||||
case <-c.closeChan:
|
||||
return nil, c.closeErr
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) composeDatagram(b []byte) ([]byte, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
switch v := ipVersion(b); v {
|
||||
default:
|
||||
return nil, fmt.Errorf("connect-ip: unknown IP versions: %d", v)
|
||||
case 4:
|
||||
if len(b) < ipv4HeaderLen {
|
||||
return nil, fmt.Errorf("connect-ip: IPv4 packet too short")
|
||||
}
|
||||
ttl := b[8]
|
||||
if ttl <= 1 {
|
||||
return nil, fmt.Errorf("connect-ip: datagram TTL too small: %d", ttl)
|
||||
}
|
||||
b[8]--
|
||||
binary.BigEndian.PutUint16(b[10:12], calculateIPv4Checksum(([ipv4HeaderLen]byte)(b[:ipv4HeaderLen])))
|
||||
case 6:
|
||||
if len(b) < ipv6HeaderLen {
|
||||
return nil, fmt.Errorf("connect-ip: IPv6 packet too short")
|
||||
}
|
||||
hopLimit := b[7]
|
||||
if hopLimit <= 1 {
|
||||
return nil, fmt.Errorf("connect-ip: datagram Hop Limit too small: %d", hopLimit)
|
||||
}
|
||||
b[7]--
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (c *h2IpConn) Close() error {
|
||||
c.mu.Lock()
|
||||
if c.closeErr == nil {
|
||||
c.closeErr = net.ErrClosed
|
||||
close(c.closeChan)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
err := c.str.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func ipVersion(b []byte) uint8 { return b[0] >> 4 }
|
||||
|
||||
func calculateIPv4Checksum(header [ipv4HeaderLen]byte) uint16 {
|
||||
var sum uint32
|
||||
for i := 0; i < len(header); i += 2 {
|
||||
if i == 10 {
|
||||
continue
|
||||
}
|
||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||
}
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xffff) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
type h2DatagramStream struct {
|
||||
requestBody *io.PipeWriter
|
||||
responseBody io.ReadCloser
|
||||
cancel context.CancelFunc
|
||||
|
||||
readMu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) ReceiveDatagram(_ context.Context) ([]byte, error) {
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
|
||||
reader := quicvarint.NewReader(s.responseBody)
|
||||
for {
|
||||
capsuleType, err := quicvarint.Read(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen, err := quicvarint.Read(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := make([]byte, payloadLen)
|
||||
_, err = io.ReadFull(reader, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if capsuleType != h2DatagramCapsuleType {
|
||||
continue
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) SendDatagram(data []byte) error {
|
||||
frame := make([]byte, 0, quicvarint.Len(h2DatagramCapsuleType)+quicvarint.Len(uint64(len(data)))+len(data))
|
||||
frame = quicvarint.Append(frame, h2DatagramCapsuleType)
|
||||
frame = quicvarint.Append(frame, uint64(len(data)))
|
||||
frame = append(frame, data...)
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
_, err := s.requestBody.Write(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect-ip: failed to send datagram capsule: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *h2DatagramStream) Close() error {
|
||||
_ = s.requestBody.Close()
|
||||
err := s.responseBody.Close()
|
||||
s.cancel()
|
||||
return err
|
||||
}
|
||||
@@ -2,9 +2,9 @@ package masque
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -12,13 +12,13 @@ import (
|
||||
|
||||
connectip "github.com/Diniboy1123/connect-ip-go"
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
qtls "github.com/sagernet/sing-quic"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
"github.com/yosida95/uritemplate/v3"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -26,39 +26,60 @@ type (
|
||||
ListenPacket func(network string, address string) (net.PacketConn, error)
|
||||
)
|
||||
|
||||
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool) (net.PacketConn, *http3.Transport, *connectip.Conn, *http.Response, error) {
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
type IpConn interface {
|
||||
ReadPacket() (b []byte, err error)
|
||||
WritePacket(b []byte) (icmp []byte, err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type closerFunc func() error
|
||||
|
||||
func (f closerFunc) Close() error { return f() }
|
||||
|
||||
type quicIpConn struct {
|
||||
conn *connectip.Conn
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newQuicIpConn(conn *connectip.Conn) *quicIpConn {
|
||||
return &quicIpConn{
|
||||
conn: conn,
|
||||
buf: make([]byte, 0xFFFF),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *quicIpConn) ReadPacket() ([]byte, error) {
|
||||
n, err := c.conn.ReadPacket(c.buf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.buf[:n], nil
|
||||
}
|
||||
|
||||
func (c *quicIpConn) WritePacket(b []byte) (icmp []byte, err error) {
|
||||
return c.conn.WritePacket(b)
|
||||
}
|
||||
|
||||
func (c *quicIpConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool, congestionControl func(conn *quic.Conn) congestion.CongestionControl) (io.Closer, IpConn, *http.Response, error) {
|
||||
if useHTTP2 {
|
||||
h2Endpoint, ok := endpoint.(*net.TCPAddr)
|
||||
if !ok || h2Endpoint == nil {
|
||||
return nil, nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
return nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint")
|
||||
}
|
||||
h2Headers := additionalHeaders.Clone()
|
||||
h2Headers.Set("cf-connect-proto", "cf-connect-ip")
|
||||
h2Headers.Set("pq-enabled", "false")
|
||||
h2Client, err := newHTTP2Client(dialer, tlsConfig, h2Endpoint, connectUri)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to create HTTP/2 client: %w", err)
|
||||
}
|
||||
ipConn, rsp, err := connectip.DialH2(ctx, h2Client, template, h2Headers)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return nil, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err)
|
||||
}
|
||||
return nil, nil, ipConn, rsp, nil
|
||||
return ConnectTunnelH2(ctx, dialer, tlsConfig, h2Endpoint, connectUri)
|
||||
}
|
||||
|
||||
quicEndpoint, ok := endpoint.(*net.UDPAddr)
|
||||
if !ok || quicEndpoint == nil {
|
||||
return nil, nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
|
||||
return nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint")
|
||||
}
|
||||
udpConn, err := dialer.ListenPacket(ctx, M.SocksaddrFromNetIP(quicEndpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
conn, err := qtls.Dial(
|
||||
ctx,
|
||||
@@ -68,28 +89,34 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
|
||||
quicConfig,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
_ = udpConn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if congestionControl != nil {
|
||||
conn.SetCongestionControl(congestionControl(conn))
|
||||
}
|
||||
tr := &http3.Transport{
|
||||
EnableDatagrams: true,
|
||||
AdditionalSettings: map[uint64]uint64{
|
||||
// official client still sends this out as well, even though
|
||||
// it's deprecated, see https://datatracker.ietf.org/doc/draft-ietf-masque-h3-datagram/00/
|
||||
// SETTINGS_H3_DATAGRAM_00 = 0x0000000000000276
|
||||
// https://github.com/cloudflare/quiche/blob/7c66757dbc55b8d0c3653d4b345c6785a181f0b7/quiche/src/h3/frame.rs#L46
|
||||
0x276: 1,
|
||||
},
|
||||
DisableCompression: true,
|
||||
}
|
||||
hconn := tr.NewClientConn(conn)
|
||||
|
||||
template := uritemplate.MustNew(connectUri)
|
||||
additionalHeaders := http.Header{
|
||||
"User-Agent": []string{""},
|
||||
}
|
||||
ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true)
|
||||
if err != nil {
|
||||
_ = tr.Close()
|
||||
_ = conn.CloseWithError(0, "connect-ip dial failed")
|
||||
_ = udpConn.Close()
|
||||
if strings.Contains(err.Error(), "tls: access denied") {
|
||||
return udpConn, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
return nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service")
|
||||
}
|
||||
return udpConn, nil, nil, nil, fmt.Errorf("failed to dial connect-ip: %w", err)
|
||||
return nil, nil, rsp, fmt.Errorf("failed to dial connect-ip: %w", err)
|
||||
}
|
||||
err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{
|
||||
{
|
||||
@@ -109,34 +136,16 @@ func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return udpConn, nil, nil, nil, err
|
||||
_ = ipConn.Close()
|
||||
_ = tr.Close()
|
||||
_ = udpConn.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return udpConn, tr, ipConn, rsp, nil
|
||||
}
|
||||
|
||||
func newHTTP2Client(dialer N.Dialer, baseTLSConfig aTLS.Config, endpoint *net.TCPAddr, connectURI string) (*http.Client, error) {
|
||||
if endpoint == nil {
|
||||
return nil, errors.New("missing HTTP/2 endpoint")
|
||||
}
|
||||
tlsConfig := baseTLSConfig.Clone()
|
||||
tlsConfig.SetNextProtos([]string{"h2"})
|
||||
return &http.Client{
|
||||
Transport: &http2.Transport{
|
||||
DialTLSContext: func(ctx context.Context, network, _ string, _ *tls.Config) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctx, network, M.SocksaddrFromNetIP(endpoint.AddrPort()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConn, err := tlsConfig.Client(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return tlsConn, nil
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
closer := closerFunc(func() error {
|
||||
_ = tr.Close()
|
||||
_ = udpConn.Close()
|
||||
return nil
|
||||
})
|
||||
return closer, newQuicIpConn(ipConn), rsp, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
@@ -23,4 +25,5 @@ type TunnelOptions struct {
|
||||
UDPKeepalivePeriod time.Duration
|
||||
UDPInitialPacketSize uint16
|
||||
ReconnectDelay time.Duration
|
||||
CongestionControl func(conn *quic.Conn) congestion.CongestionControl
|
||||
}
|
||||
|
||||
@@ -4,13 +4,12 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
connectip "github.com/Diniboy1123/connect-ip-go"
|
||||
"github.com/sagernet/quic-go/http3"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
@@ -22,9 +21,8 @@ type Tunnel struct {
|
||||
options TunnelOptions
|
||||
device Device
|
||||
|
||||
udpConn net.PacketConn
|
||||
tr *http3.Transport
|
||||
ipConn *connectip.Conn
|
||||
closer io.Closer
|
||||
ipConn IpConn
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
@@ -83,13 +81,11 @@ func (e *Tunnel) Close() error {
|
||||
defer e.mtx.Unlock()
|
||||
if e.ipConn != nil {
|
||||
e.ipConn.Close()
|
||||
if e.udpConn != nil {
|
||||
e.udpConn.Close()
|
||||
}
|
||||
if e.tr != nil {
|
||||
e.tr.Close()
|
||||
if e.closer != nil {
|
||||
e.closer.Close()
|
||||
}
|
||||
e.ipConn = nil
|
||||
e.closer = nil
|
||||
}
|
||||
return e.device.Close()
|
||||
}
|
||||
@@ -124,7 +120,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
icmp, err := ipConn.WritePacket(packet)
|
||||
if err != nil {
|
||||
if errors.As(err, new(*connectip.CloseError)) {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
if ok := e.closeIpConn(ipConn); ok {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing to IP connection: %w", err))
|
||||
}
|
||||
@@ -135,7 +131,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
if len(icmp) > 0 {
|
||||
if _, err := e.device.Write([][]byte{icmp}, 0); err != nil {
|
||||
if errors.As(err, new(*connectip.CloseError)) {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while writing ICMP to TUN device: %v", err))
|
||||
continue
|
||||
}
|
||||
@@ -145,15 +141,14 @@ func (e *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
buf := make([]byte, 1280)
|
||||
for e.ctx.Err() == nil {
|
||||
ipConn, err := e.getIpConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n, err := ipConn.ReadPacket(buf, true)
|
||||
packet, err := ipConn.ReadPacket()
|
||||
if err != nil {
|
||||
if e.options.UseHTTP2 || errors.As(err, new(*connectip.CloseError)) {
|
||||
if e.options.UseHTTP2 || errors.Is(err, net.ErrClosed) {
|
||||
if ok := e.closeIpConn(ipConn); ok {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("connection closed while reading from IP connection: %v", err))
|
||||
}
|
||||
@@ -162,7 +157,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Error reading from IP connection: %v, continuine...", err))
|
||||
continue
|
||||
}
|
||||
if _, err := e.device.Write([][]byte{buf[:n]}, 0); err != nil {
|
||||
if _, err := e.device.Write([][]byte{packet}, 0); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -170,7 +165,7 @@ func (e *Tunnel) maintainTunnel() {
|
||||
<-e.ctx.Done()
|
||||
}
|
||||
|
||||
func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
func (e *Tunnel) getIpConn() (IpConn, error) {
|
||||
e.mtx.Lock()
|
||||
defer e.mtx.Unlock()
|
||||
if e.ctx.Err() != nil {
|
||||
@@ -184,7 +179,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
defer timer.Stop()
|
||||
for {
|
||||
e.logger.NoticeContext(e.ctx, fmt.Errorf("Establishing MASQUE connection to %s", e.options.Endpoint))
|
||||
udpConn, tr, ipConn, rsp, err := ConnectTunnel(
|
||||
closer, ipConn, rsp, err := ConnectTunnel(
|
||||
e.ctx,
|
||||
e.options.Dialer,
|
||||
e.options.TLSConfig,
|
||||
@@ -192,6 +187,7 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
"https://cloudflareaccess.com",
|
||||
e.options.Endpoint,
|
||||
e.options.UseHTTP2,
|
||||
e.options.CongestionControl,
|
||||
)
|
||||
if err != nil {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Failed to connect tunnel: %v", err))
|
||||
@@ -206,11 +202,8 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
if rsp.StatusCode != 200 {
|
||||
e.logger.ErrorContext(e.ctx, fmt.Errorf("Tunnel connection failed: %s", rsp.Status))
|
||||
ipConn.Close()
|
||||
if udpConn != nil {
|
||||
udpConn.Close()
|
||||
}
|
||||
if tr != nil {
|
||||
tr.Close()
|
||||
if closer != nil {
|
||||
closer.Close()
|
||||
}
|
||||
timer.Reset(e.options.ReconnectDelay)
|
||||
select {
|
||||
@@ -220,26 +213,23 @@ func (e *Tunnel) getIpConn() (*connectip.Conn, error) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
e.udpConn = udpConn
|
||||
e.tr = tr
|
||||
e.closer = closer
|
||||
e.ipConn = ipConn
|
||||
e.logger.NoticeContext(e.ctx, "Connected to MASQUE server ", e.options.Endpoint)
|
||||
return ipConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Tunnel) closeIpConn(ipConn *connectip.Conn) bool {
|
||||
func (e *Tunnel) closeIpConn(ipConn IpConn) bool {
|
||||
e.mtx.Lock()
|
||||
defer e.mtx.Unlock()
|
||||
if ipConn == e.ipConn {
|
||||
e.ipConn.Close()
|
||||
if e.udpConn != nil {
|
||||
e.udpConn.Close()
|
||||
}
|
||||
if e.tr != nil {
|
||||
e.tr.Close()
|
||||
if e.closer != nil {
|
||||
e.closer.Close()
|
||||
}
|
||||
e.ipConn = nil
|
||||
e.closer = nil
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
Reference in New Issue
Block a user