mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-27 04:39:02 +03:00
Add Snell protocol. Refactor MASQUE HTTP/2, Fair Queue. Update XHTTP, OpenVPN, Sudoku, Fallback. Fixes
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
@@ -23,7 +24,7 @@ const (
|
||||
|
||||
type DataCipher interface {
|
||||
Encrypt(header []byte, packetID uint32, payload []byte) ([]byte, error)
|
||||
Decrypt(packet []byte, headerSize int) ([]byte, error)
|
||||
Decrypt(packet []byte, headerSize int) (plaintext []byte, packetID uint32, err error)
|
||||
}
|
||||
|
||||
type AEADDataCipher struct {
|
||||
@@ -86,9 +87,9 @@ func (g *AEADDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) {
|
||||
func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
|
||||
if len(packet) < headerSize+4+AESGCMTagSize+1 {
|
||||
return nil, errors.New("openvpn gcm data packet too short")
|
||||
return nil, 0, errors.New("openvpn gcm data packet too short")
|
||||
}
|
||||
header := packet[:headerSize]
|
||||
pidBytes := packet[headerSize : headerSize+4]
|
||||
@@ -96,8 +97,13 @@ func (g *AEADDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error)
|
||||
ciphertext := packet[headerSize+4+AESGCMTagSize:]
|
||||
combined := append(ciphertext, tag...)
|
||||
ad := append(header, pidBytes...)
|
||||
nonce := g.nonce(binary.BigEndian.Uint32(pidBytes), g.recvImplicitIV)
|
||||
return g.recv.Open(nil, nonce[:], combined, ad)
|
||||
packetID := binary.BigEndian.Uint32(pidBytes)
|
||||
nonce := g.nonce(packetID, g.recvImplicitIV)
|
||||
plain, err := g.recv.Open(nil, nonce[:], combined, ad)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return plain, packetID, nil
|
||||
}
|
||||
|
||||
func (g *AEADDataCipher) nonce(packetID uint32, implicit [AESGCMIVSize]byte) [AESGCMIVSize]byte {
|
||||
@@ -127,6 +133,9 @@ func NewCBCCipher(keys *KeyMaterial, auth string) (*CBCDataCipher, error) {
|
||||
var newHash func() hash.Hash
|
||||
var hmacSize int
|
||||
switch auth {
|
||||
case AuthMD5:
|
||||
newHash = md5.New
|
||||
hmacSize = md5.Size
|
||||
case AuthSHA256:
|
||||
newHash = sha256.New
|
||||
hmacSize = sha256.Size
|
||||
@@ -176,34 +185,35 @@ func (c *CBCDataCipher) Encrypt(header []byte, packetID uint32, payload []byte)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, error) {
|
||||
func (c *CBCDataCipher) Decrypt(packet []byte, headerSize int) ([]byte, uint32, error) {
|
||||
minSize := headerSize + c.hmacSize + CBCIVSize + aes.BlockSize
|
||||
if len(packet) < minSize {
|
||||
return nil, errors.New("openvpn cbc data packet too short")
|
||||
return nil, 0, errors.New("openvpn cbc data packet too short")
|
||||
}
|
||||
tag := packet[headerSize : headerSize+c.hmacSize]
|
||||
iv := packet[headerSize+c.hmacSize : headerSize+c.hmacSize+CBCIVSize]
|
||||
ct := packet[headerSize+c.hmacSize+CBCIVSize:]
|
||||
if len(ct)%aes.BlockSize != 0 {
|
||||
return nil, errors.New("openvpn cbc ciphertext not block-aligned")
|
||||
return nil, 0, errors.New("openvpn cbc ciphertext not block-aligned")
|
||||
}
|
||||
mac := hmac.New(c.newHash, c.recvHMAC)
|
||||
mac.Write(iv)
|
||||
mac.Write(ct)
|
||||
if !hmac.Equal(tag, mac.Sum(nil)) {
|
||||
return nil, errors.New("openvpn cbc hmac verification failed")
|
||||
return nil, 0, errors.New("openvpn cbc hmac verification failed")
|
||||
}
|
||||
plain := make([]byte, len(ct))
|
||||
cipher.NewCBCDecrypter(c.recvBlock, iv).CryptBlocks(plain, ct)
|
||||
padLen := int(plain[len(plain)-1])
|
||||
if padLen < 1 || padLen > aes.BlockSize {
|
||||
return nil, errors.New("openvpn cbc invalid padding")
|
||||
return nil, 0, errors.New("openvpn cbc invalid padding")
|
||||
}
|
||||
plain = plain[:len(plain)-padLen]
|
||||
if len(plain) < 4 {
|
||||
return nil, errors.New("openvpn cbc payload too short")
|
||||
return nil, 0, errors.New("openvpn cbc payload too short")
|
||||
}
|
||||
return plain[4:], nil
|
||||
packetID := binary.BigEndian.Uint32(plain[:4])
|
||||
return plain[4:], packetID, nil
|
||||
}
|
||||
|
||||
func CipherKeyLength(cipher string) int {
|
||||
|
||||
@@ -8,12 +8,16 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
const defaultHandshakeTimeout = 30 * time.Second
|
||||
const (
|
||||
defaultHandshakeTimeout = 30 * time.Second
|
||||
controlRetransmitDelay = time.Second
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
config *ClientConfig
|
||||
@@ -26,6 +30,8 @@ type Client struct {
|
||||
push *PushReply
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
lastReceiveNano atomic.Int64
|
||||
}
|
||||
|
||||
func NewClient(config *ClientConfig, io PacketIO, tlsConfig tls.Config) (*Client, error) {
|
||||
@@ -154,6 +160,7 @@ func (c *Client) Handshake(ctx context.Context) (*PushReply, error) {
|
||||
return nil, err
|
||||
}
|
||||
c.data = NewDataChannel(cipher, push.PeerID, push.CompLZO)
|
||||
c.markReceive()
|
||||
return push, nil
|
||||
}
|
||||
|
||||
@@ -181,10 +188,21 @@ func (c *Client) ReadIPPacket(ctx context.Context) ([]byte, error) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c.markReceive()
|
||||
return plain, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SinceReceive() time.Duration {
|
||||
return time.Duration(int64(time.Since(clientStart)) - c.lastReceiveNano.Load())
|
||||
}
|
||||
|
||||
func (c *Client) markReceive() {
|
||||
c.lastReceiveNano.Store(int64(time.Since(clientStart)))
|
||||
}
|
||||
|
||||
var clientStart = time.Now().Add(-time.Hour)
|
||||
|
||||
func (c *Client) Close() error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
@@ -199,10 +217,24 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) waitServerReset(ctx context.Context) error {
|
||||
retransmits := 0
|
||||
for {
|
||||
packet, err := c.control.Read(ctx)
|
||||
readCtx := ctx
|
||||
cancel := func() {}
|
||||
if c.config.Proto == ProtoUDP {
|
||||
readCtx, cancel = context.WithTimeout(ctx, controlRetransmitDelay)
|
||||
}
|
||||
packet, err := c.control.Read(readCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read hard reset response: %w", err)
|
||||
if c.config.Proto == ProtoUDP && errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil {
|
||||
if err := c.control.RetransmitPending(ctx); err != nil {
|
||||
return fmt.Errorf("retransmit hard reset: %w", err)
|
||||
}
|
||||
retransmits++
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read hard reset response after %d retransmits: %w", retransmits, err)
|
||||
}
|
||||
switch packet.Opcode {
|
||||
case PControlHardResetServerV2:
|
||||
|
||||
@@ -20,6 +20,7 @@ const (
|
||||
CipherAES256CBC = "AES-256-CBC"
|
||||
CipherCHACHA20POLY = "CHACHA20-POLY1305"
|
||||
|
||||
AuthMD5 = "MD5"
|
||||
AuthSHA1 = "SHA1"
|
||||
AuthSHA256 = "SHA256"
|
||||
AuthSHA384 = "SHA384"
|
||||
@@ -107,7 +108,7 @@ func isValidCipher(cipher string) bool {
|
||||
|
||||
func isValidAuth(auth string) bool {
|
||||
switch auth {
|
||||
case AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
|
||||
case AuthMD5, AuthSHA1, AuthSHA256, AuthSHA384, AuthSHA512:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -30,8 +30,10 @@ type ControlChannel struct {
|
||||
mu sync.Mutex
|
||||
sendPacketID uint32
|
||||
sendMessage uint32
|
||||
recvMessage uint32
|
||||
ackPending []uint32
|
||||
pending map[uint32]*ControlPacket
|
||||
recvPending map[uint32]*ControlPacket
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
@@ -40,9 +42,10 @@ func NewControlChannel(io PacketIO, crypt ControlCrypt, local SessionID) *Contro
|
||||
ch := &ControlChannel{
|
||||
io: io,
|
||||
|
||||
clock: time.Now,
|
||||
local: local,
|
||||
pending: make(map[uint32]*ControlPacket),
|
||||
clock: time.Now,
|
||||
local: local,
|
||||
pending: make(map[uint32]*ControlPacket),
|
||||
recvPending: make(map[uint32]*ControlPacket),
|
||||
}
|
||||
if crypt != nil {
|
||||
ch.encode = func(p *ControlPacket, pid uint32, t uint32) ([]byte, error) {
|
||||
@@ -130,10 +133,23 @@ func (c *ControlChannel) SendAck(ctx context.Context) error {
|
||||
|
||||
func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
|
||||
for {
|
||||
c.mu.Lock()
|
||||
if packet, ok := c.recvPending[c.recvMessage]; ok {
|
||||
delete(c.recvPending, c.recvMessage)
|
||||
c.recvMessage++
|
||||
c.mu.Unlock()
|
||||
return packet, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
packet, err := c.readControlPacket(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var deliver *ControlPacket
|
||||
sendAck := false
|
||||
|
||||
c.mu.Lock()
|
||||
if c.remote == (SessionID{}) && packet.LocalSession != c.local {
|
||||
c.remote = packet.LocalSession
|
||||
@@ -144,11 +160,33 @@ func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
|
||||
if packet.Opcode.HasMessageID() {
|
||||
c.ackPending = appendAck(c.ackPending, packet.MessageID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
if packet.Opcode == PAckV1 {
|
||||
continue
|
||||
|
||||
switch {
|
||||
case packet.Opcode == PAckV1:
|
||||
case !packet.Opcode.HasMessageID():
|
||||
deliver = packet
|
||||
case packet.MessageID < c.recvMessage:
|
||||
sendAck = true
|
||||
case packet.MessageID == c.recvMessage:
|
||||
deliver = packet
|
||||
c.recvMessage++
|
||||
default:
|
||||
if _, exists := c.recvPending[packet.MessageID]; !exists {
|
||||
c.recvPending[packet.MessageID] = packet
|
||||
}
|
||||
sendAck = true
|
||||
}
|
||||
|
||||
c.mu.Unlock()
|
||||
|
||||
if deliver != nil {
|
||||
return deliver, nil
|
||||
}
|
||||
if sendAck {
|
||||
if err := c.SendAck(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -349,11 +387,17 @@ func (c *ControlConn) SetWriteDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
type streamPacketIO struct {
|
||||
conn net.Conn
|
||||
conn net.Conn
|
||||
deadlineMu sync.Mutex
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
|
||||
type datagramPacketIO struct {
|
||||
conn net.Conn
|
||||
conn net.Conn
|
||||
deadlineMu sync.Mutex
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
|
||||
func NewDatagramPacketIO(conn net.Conn) PacketIO {
|
||||
@@ -361,40 +405,23 @@ func NewDatagramPacketIO(conn net.Conn) PacketIO {
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
|
||||
done := make(chan struct{})
|
||||
var (
|
||||
packet []byte
|
||||
err error
|
||||
)
|
||||
go func() {
|
||||
defer close(done)
|
||||
buf := make([]byte, 64*1024)
|
||||
var n int
|
||||
n, err = d.conn.Read(buf)
|
||||
if err == nil {
|
||||
packet = cloneBytes(buf[:n])
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return packet, err
|
||||
if err := setReadDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.readDeadline); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := make([]byte, 64*1024)
|
||||
n, err := d.conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
return buf[:n], nil
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) WritePacket(ctx context.Context, packet []byte) error {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := d.conn.Write(packet)
|
||||
done <- err
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
if err := setWriteDeadlineFromContext(d.conn, ctx, &d.deadlineMu, &d.writeDeadline); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := d.conn.Write(packet)
|
||||
return contextIOError(ctx, err)
|
||||
}
|
||||
|
||||
func (d *datagramPacketIO) Close() error {
|
||||
@@ -414,52 +441,37 @@ func NewTCPPacketIO(conn net.Conn) PacketIO {
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
|
||||
done := make(chan struct{})
|
||||
var (
|
||||
packet []byte
|
||||
err error
|
||||
)
|
||||
go func() {
|
||||
defer close(done)
|
||||
var lenBuf [2]byte
|
||||
if _, err = io.ReadFull(s.conn, lenBuf[:]); err != nil {
|
||||
return
|
||||
}
|
||||
size := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if size == 0 {
|
||||
err = errors.New("empty openvpn tcp packet")
|
||||
return
|
||||
}
|
||||
packet = make([]byte, size)
|
||||
_, err = io.ReadFull(s.conn, packet)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
return packet, err
|
||||
if err := setReadDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.readDeadline); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var lenBuf [2]byte
|
||||
if _, err := io.ReadFull(s.conn, lenBuf[:]); err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
size := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if size == 0 {
|
||||
return nil, errors.New("empty openvpn tcp packet")
|
||||
}
|
||||
packet := make([]byte, size)
|
||||
if _, err := io.ReadFull(s.conn, packet); err != nil {
|
||||
return nil, contextIOError(ctx, err)
|
||||
}
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) WritePacket(ctx context.Context, packet []byte) error {
|
||||
if len(packet) > 0xffff {
|
||||
return fmt.Errorf("openvpn tcp packet too large: %d", len(packet))
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
frame := make([]byte, 2+len(packet))
|
||||
frame[0] = byte(len(packet) >> 8)
|
||||
frame[1] = byte(len(packet))
|
||||
copy(frame[2:], packet)
|
||||
_, err := s.conn.Write(frame)
|
||||
done <- err
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
if err := setWriteDeadlineFromContext(s.conn, ctx, &s.deadlineMu, &s.writeDeadline); err != nil {
|
||||
return err
|
||||
}
|
||||
frame := make([]byte, 2+len(packet))
|
||||
frame[0] = byte(len(packet) >> 8)
|
||||
frame[1] = byte(len(packet))
|
||||
copy(frame[2:], packet)
|
||||
_, err := s.conn.Write(frame)
|
||||
return contextIOError(ctx, err)
|
||||
}
|
||||
|
||||
func (s *streamPacketIO) Close() error {
|
||||
@@ -473,3 +485,50 @@ func (s *streamPacketIO) LocalAddr() net.Addr {
|
||||
func (s *streamPacketIO) RemoteAddr() net.Addr {
|
||||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func setReadDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if current.Equal(deadline) {
|
||||
return nil
|
||||
}
|
||||
if hasDeadline {
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return err
|
||||
}
|
||||
*current = deadline
|
||||
return nil
|
||||
}
|
||||
|
||||
func setWriteDeadlineFromContext(conn net.Conn, ctx context.Context, mu *sync.Mutex, current *time.Time) error {
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if current.Equal(deadline) {
|
||||
return nil
|
||||
}
|
||||
if hasDeadline {
|
||||
if err := conn.SetWriteDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := conn.SetWriteDeadline(time.Time{}); err != nil {
|
||||
return err
|
||||
}
|
||||
*current = deadline
|
||||
return nil
|
||||
}
|
||||
|
||||
func contextIOError(ctx context.Context, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() && ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,15 +8,21 @@ import (
|
||||
|
||||
const (
|
||||
PeerIDUnset uint32 = 0xffffff
|
||||
|
||||
dataChannelReplayWindow = 64
|
||||
)
|
||||
|
||||
type DataChannel struct {
|
||||
cipher DataCipher
|
||||
keyID uint8
|
||||
peerID uint32
|
||||
compLZO bool
|
||||
cipher DataCipher
|
||||
keyID uint8
|
||||
peerID uint32
|
||||
compLZO bool
|
||||
|
||||
mu sync.Mutex
|
||||
sendPacketID uint32
|
||||
recvHighest uint32
|
||||
recvWindow uint64
|
||||
recvSeen bool
|
||||
}
|
||||
|
||||
func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel {
|
||||
@@ -29,10 +35,11 @@ func NewDataChannel(cipher DataCipher, peerID uint32, compLZO bool) *DataChannel
|
||||
|
||||
func (d *DataChannel) Encrypt(packet []byte) ([]byte, error) {
|
||||
if d.compLZO {
|
||||
p := make([]byte, 1+len(packet))
|
||||
p[0] = 0xFA
|
||||
copy(p[1:], packet)
|
||||
packet = p
|
||||
compressed, err := lzo1xCompressSafe(packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packet = compressed
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.sendPacketID++
|
||||
@@ -50,18 +57,15 @@ func (d *DataChannel) Decrypt(packet []byte) ([]byte, error) {
|
||||
if opcode == PDataV2 {
|
||||
headerSize = 4
|
||||
}
|
||||
plain, err := d.cipher.Decrypt(packet, headerSize)
|
||||
plain, packetID, err := d.cipher.Decrypt(packet, headerSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.acceptPacketID(packetID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d.compLZO {
|
||||
if len(plain) < 1 {
|
||||
return nil, errors.New("openvpn comp-lzo packet too short")
|
||||
}
|
||||
if plain[0] != 0xFA {
|
||||
return nil, fmt.Errorf("openvpn compressed packet not supported (byte: 0x%02x)", plain[0])
|
||||
}
|
||||
plain = plain[1:]
|
||||
return lzo1xDecompressSafe(plain)
|
||||
}
|
||||
return plain, nil
|
||||
}
|
||||
@@ -78,6 +82,40 @@ func (d *DataChannel) dataHeader() []byte {
|
||||
return []byte{opcodeKeyID(PDataV1, d.keyID)}
|
||||
}
|
||||
|
||||
func (d *DataChannel) acceptPacketID(packetID uint32) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if !d.recvSeen {
|
||||
d.recvHighest = packetID
|
||||
d.recvWindow = 1
|
||||
d.recvSeen = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if packetID > d.recvHighest {
|
||||
shift := packetID - d.recvHighest
|
||||
if shift >= dataChannelReplayWindow {
|
||||
d.recvWindow = 1
|
||||
} else {
|
||||
d.recvWindow = d.recvWindow<<shift | 1
|
||||
}
|
||||
d.recvHighest = packetID
|
||||
return nil
|
||||
}
|
||||
|
||||
diff := d.recvHighest - packetID
|
||||
if diff >= dataChannelReplayWindow {
|
||||
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
|
||||
}
|
||||
mask := uint64(1) << diff
|
||||
if d.recvWindow&mask != 0 {
|
||||
return fmt.Errorf("openvpn replayed data packet id %d", packetID)
|
||||
}
|
||||
d.recvWindow |= mask
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParsePeerID(options string) uint32 {
|
||||
for _, field := range splitPushOptions(options) {
|
||||
if len(field) > len("peer-id ") && field[:len("peer-id ")] == "peer-id " {
|
||||
|
||||
444
transport/openvpn/e2e_test.go
Normal file
444
transport/openvpn/e2e_test.go
Normal file
@@ -0,0 +1,444 @@
|
||||
//go:build with_openvpn && with_gvisor
|
||||
|
||||
// OpenVPN E2E tests. Require a local OpenVPN server setup.
|
||||
//
|
||||
// Setup (run once before testing):
|
||||
//
|
||||
// # Generate PKI
|
||||
// mkdir -p /tmp/ovpn-e2e/pki/{issued,private}
|
||||
// cd /tmp/ovpn-e2e/pki
|
||||
// openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -days 1 -nodes -keyout ca.key -out ca.crt -subj "/CN=E2ETestCA"
|
||||
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/server.key -out server.csr -subj "/CN=server"
|
||||
// openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/server.crt -days 1
|
||||
// openssl req -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 -nodes -keyout private/client.key -out client.csr -subj "/CN=client"
|
||||
// openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out issued/client.crt -days 1
|
||||
// openvpn --genkey secret ta.key
|
||||
// openvpn --genkey secret ta-auth.key
|
||||
//
|
||||
// # Start servers (4 instances: TCP/UDP × tls-crypt/tls-auth)
|
||||
// # TCP + tls-crypt on :11940, subnet 10.99.0.0/24
|
||||
// # UDP + tls-crypt on :11941, subnet 10.99.1.0/24
|
||||
// # TCP + tls-auth on :11942, subnet 10.99.2.0/24
|
||||
// # UDP + tls-auth on :11943, subnet 10.99.3.0/24
|
||||
// #
|
||||
// # Each server config needs: topology subnet, duplicate-cn, persist-tun,
|
||||
// # data-ciphers AES-256-GCM:AES-128-GCM:AES-192-GCM:CHACHA20-POLY1305:AES-256-CBC:AES-128-CBC:AES-192-CBC
|
||||
// # auth SHA256, keepalive 10 60, ca/cert/key from above PKI.
|
||||
// # tls-auth servers use: tls-auth ta-auth.key 0
|
||||
// # tls-crypt servers use: tls-crypt ta.key
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-crypt.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-crypt.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-tcp-auth.conf --daemon
|
||||
// sudo openvpn --config /tmp/ovpn-e2e/server-udp-auth.conf --daemon
|
||||
//
|
||||
// # Start HTTP servers on each VPN subnet
|
||||
// for ip in 10.99.0.1 10.99.1.1 10.99.2.1 10.99.3.1; do
|
||||
// mkdir -p /tmp/ovpn-e2e/$ip && echo "hello" > /tmp/ovpn-e2e/$ip/index.html
|
||||
// cd /tmp/ovpn-e2e/$ip && python3 -m http.server 8080 --bind $ip &
|
||||
// done
|
||||
//
|
||||
// Run tests:
|
||||
//
|
||||
// go test -tags "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_masque,with_mtproxy,with_manager,with_admin_panel,with_v2ray_api,with_ccm,with_ocm,with_profiler,with_openvpn,with_sudoku,with_trusttunnel" \
|
||||
// -run TestE2E -v -count=1 ./transport/openvpn/ -timeout 300s
|
||||
//
|
||||
// Tests all 28 combinations: 2 protos (tcp/udp) × 2 TLS modes (tls-crypt/tls-auth) × 7 ciphers.
|
||||
|
||||
package openvpn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box"
|
||||
"github.com/sagernet/sing-box/include"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
// Servers (started externally):
|
||||
// TCP+tls-crypt :11940 subnet 10.99.0.0/24
|
||||
// UDP+tls-crypt :11941 subnet 10.99.1.0/24
|
||||
// TCP+tls-auth :11942 subnet 10.99.2.0/24
|
||||
// UDP+tls-auth :11943 subnet 10.99.3.0/24
|
||||
// TCP+plain :11944 subnet 10.99.4.0/24
|
||||
// UDP+plain :11945 subnet 10.99.5.0/24
|
||||
// TCP+tls-crypt+SHA1 :11946 subnet 10.99.6.0/24 (CBC only)
|
||||
// TCP+tls-crypt+SHA512 :11947 subnet 10.99.7.0/24 (CBC only)
|
||||
// Each has HTTP on .1:8080 serving "hello"
|
||||
|
||||
const pkiDir = "/tmp/ovpn-e2e/pki"
|
||||
|
||||
type serverConfig struct {
|
||||
proto string
|
||||
port uint16
|
||||
tlsMode string // "tls-crypt" or "tls-auth"
|
||||
httpAddr string
|
||||
}
|
||||
|
||||
var servers = []serverConfig{
|
||||
{"tcp", 11940, "tls-crypt", "10.99.0.1:8080"},
|
||||
{"udp", 11941, "tls-crypt", "10.99.1.1:8080"},
|
||||
{"tcp", 11942, "tls-auth", "10.99.2.1:8080"},
|
||||
{"udp", 11943, "tls-auth", "10.99.3.1:8080"},
|
||||
}
|
||||
|
||||
var ciphers = []string{
|
||||
"AES-128-GCM",
|
||||
"AES-192-GCM",
|
||||
"AES-256-GCM",
|
||||
"CHACHA20-POLY1305",
|
||||
"AES-128-CBC",
|
||||
"AES-192-CBC",
|
||||
"AES-256-CBC",
|
||||
}
|
||||
|
||||
var portCounter atomic.Uint32
|
||||
|
||||
func init() { portCounter.Store(18100) }
|
||||
|
||||
func nextPort() uint16 { return uint16(portCounter.Add(1)) }
|
||||
|
||||
func readFile(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Skipf("PKI not found: %v", err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func testCombo(t *testing.T, srv serverConfig, cipher string) {
|
||||
t.Helper()
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
|
||||
ovpnOpts := &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: srv.port}},
|
||||
Proto: srv.proto,
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
}
|
||||
|
||||
switch srv.tlsMode {
|
||||
case "tls-crypt":
|
||||
ovpnOpts.TLSCrypt = readFile(t, pkiDir+"/ta.key")
|
||||
case "tls-auth":
|
||||
ovpnOpts.TLSAuth = readFile(t, pkiDir+"/ta-auth.key")
|
||||
ovpnOpts.KeyDirection = 1
|
||||
}
|
||||
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(srv.httpAddr))
|
||||
if err != nil {
|
||||
t.Fatal("dial:", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
|
||||
if err != nil {
|
||||
t.Fatal("write:", err)
|
||||
}
|
||||
body, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal("read:", err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
|
||||
}
|
||||
}
|
||||
|
||||
// 4 servers × 7 ciphers = 28 combinations
|
||||
func TestE2E(t *testing.T) {
|
||||
for _, srv := range servers {
|
||||
for _, cipher := range ciphers {
|
||||
name := fmt.Sprintf("%s/%s/%s", srv.proto, srv.tlsMode, cipher)
|
||||
srv, cipher := srv, cipher
|
||||
t.Run(name, func(t *testing.T) {
|
||||
testCombo(t, srv, cipher)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test CBC ciphers with different auth algorithms (SHA1, SHA512)
|
||||
func TestE2E_Auth(t *testing.T) {
|
||||
type authServer struct {
|
||||
port uint16
|
||||
auth string
|
||||
httpAddr string
|
||||
}
|
||||
authServers := []authServer{
|
||||
{11946, "SHA1", "10.99.6.1:8080"},
|
||||
{11947, "SHA512", "10.99.7.1:8080"},
|
||||
}
|
||||
cbcCiphers := []string{"AES-128-CBC", "AES-256-CBC"}
|
||||
|
||||
for _, as := range authServers {
|
||||
for _, cipher := range cbcCiphers {
|
||||
name := fmt.Sprintf("auth-%s/%s", as.auth, cipher)
|
||||
as, cipher := as, cipher
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{
|
||||
Type: "openvpn", Tag: "vpn",
|
||||
Options: &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: as.port}},
|
||||
Proto: "tcp",
|
||||
Cipher: cipher,
|
||||
Auth: as.auth,
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
time.Sleep(2 * time.Second)
|
||||
doHTTPCheck(t, port, as.httpAddr)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test tunnel stability with multiple sequential requests
|
||||
func TestE2E_BulkData(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{
|
||||
Type: "openvpn", Tag: "vpn",
|
||||
Options: &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11940}},
|
||||
Proto: "tcp",
|
||||
Cipher: "AES-256-GCM",
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer instance.Close()
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", port), socks.Version5, "", "")
|
||||
for i := 0; i < 10; i++ {
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr("10.99.0.1:8080"))
|
||||
if err != nil {
|
||||
t.Fatalf("request %d dial: %v", i, err)
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
fmt.Fprintf(conn, "GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n")
|
||||
body, err := io.ReadAll(conn)
|
||||
conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("request %d read: %v", i, err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("request %d: no 'hello'", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doHTTPCheck(t *testing.T, socksPort uint16, httpAddr string) {
|
||||
t.Helper()
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", socksPort), socks.Version5, "", "")
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddr(httpAddr))
|
||||
if err != nil {
|
||||
t.Fatal("dial:", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err = conn.Write([]byte("GET /index.html HTTP/1.0\r\nHost: test\r\n\r\n"))
|
||||
if err != nil {
|
||||
t.Fatal("write:", err)
|
||||
}
|
||||
body, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal("read:", err)
|
||||
}
|
||||
if !strings.Contains(string(body), "hello") {
|
||||
t.Fatalf("no 'hello' in response: %s", string(body)[:min(len(body), 200)])
|
||||
}
|
||||
}
|
||||
|
||||
func startInstance(t *testing.T, ovpnOpts *option.OpenVPNOutboundOptions) uint16 {
|
||||
t.Helper()
|
||||
port := nextPort()
|
||||
opts := option.Options{
|
||||
Log: &option.LogOptions{Level: "error"},
|
||||
Inbounds: []option.Inbound{{
|
||||
Type: "socks",
|
||||
Options: &option.SocksInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: (*badoption.Addr)(&badoption.Addr{}),
|
||||
ListenPort: port,
|
||||
},
|
||||
},
|
||||
}},
|
||||
Outbounds: []option.Outbound{{Type: "openvpn", Tag: "vpn", Options: ovpnOpts}},
|
||||
Route: &option.RouteOptions{Final: "vpn"},
|
||||
}
|
||||
ctx := include.Context(context.Background())
|
||||
instance, err := box.New(box.Options{Context: ctx, Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := instance.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { instance.Close() })
|
||||
time.Sleep(2 * time.Second)
|
||||
return port
|
||||
}
|
||||
|
||||
func TestE2E_CompLZO(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
|
||||
for _, cipher := range ciphers {
|
||||
cipher := cipher
|
||||
t.Run(cipher, func(t *testing.T) {
|
||||
port := startInstance(t, &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: 11948}},
|
||||
Proto: "udp",
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
})
|
||||
doHTTPCheck(t, port, "10.99.8.1:8080")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestE2E_AES192(t *testing.T) {
|
||||
ca := readFile(t, pkiDir+"/ca.crt")
|
||||
cert := readFile(t, pkiDir+"/issued/client.crt")
|
||||
key := readFile(t, pkiDir+"/private/client.key")
|
||||
tlsCrypt := readFile(t, pkiDir+"/ta.key")
|
||||
|
||||
type combo struct {
|
||||
proto string
|
||||
port uint16
|
||||
httpAddr string
|
||||
}
|
||||
for _, c := range []combo{
|
||||
{"tcp", 11940, "10.99.0.1:8080"},
|
||||
{"udp", 11941, "10.99.1.1:8080"},
|
||||
} {
|
||||
for _, cipher := range []string{"AES-192-GCM", "AES-192-CBC"} {
|
||||
c, cipher := c, cipher
|
||||
t.Run(fmt.Sprintf("%s/%s", c.proto, cipher), func(t *testing.T) {
|
||||
port := startInstance(t, &option.OpenVPNOutboundOptions{
|
||||
Servers: []option.ServerOptions{{Server: "127.0.0.1", ServerPort: c.port}},
|
||||
Proto: c.proto,
|
||||
Cipher: cipher,
|
||||
Auth: "SHA256",
|
||||
TLSCrypt: tlsCrypt,
|
||||
OpenVPNOutboundTLSOptionsContainer: option.OpenVPNOutboundTLSOptionsContainer{
|
||||
TLS: &option.OpenVPNTLSOptions{CA: ca, Certificate: cert, Key: key},
|
||||
},
|
||||
})
|
||||
doHTTPCheck(t, port, c.httpAddr)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ net.Conn
|
||||
@@ -114,7 +114,7 @@ func ParseServerKeyMethod2Record(packet []byte) (*KeyMethod2Record, error) {
|
||||
}
|
||||
|
||||
func DeriveClientKeyMaterial(sources KeySource2, clientSession, serverSession SessionID, cipherKeyLen int) (*KeyMaterial, error) {
|
||||
if cipherKeyLen != 16 && cipherKeyLen != 32 {
|
||||
if cipherKeyLen != 16 && cipherKeyLen != 24 && cipherKeyLen != 32 {
|
||||
return nil, fmt.Errorf("unsupported data cipher key length %d", cipherKeyLen)
|
||||
}
|
||||
var master [48]byte
|
||||
|
||||
48
transport/openvpn/lzo.go
Normal file
48
transport/openvpn/lzo.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/rasky/go-lzo"
|
||||
)
|
||||
|
||||
const (
|
||||
lzoCompressNone = 0xFA
|
||||
lzoCompressLZO = 0x66
|
||||
)
|
||||
|
||||
var ErrLZODecompress = errors.New("lzo decompression failed")
|
||||
|
||||
func lzo1xDecompressSafe(src []byte) ([]byte, error) {
|
||||
if len(src) == 0 {
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
|
||||
switch src[0] {
|
||||
case lzoCompressNone:
|
||||
if len(src) > 1 {
|
||||
return src[1:], nil
|
||||
}
|
||||
return nil, nil
|
||||
case lzoCompressLZO:
|
||||
if len(src) > 1 {
|
||||
r := bytes.NewReader(src[1:])
|
||||
out, err := lzo.Decompress1X(r, len(src)-1, 0)
|
||||
if err != nil {
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, ErrLZODecompress
|
||||
}
|
||||
}
|
||||
|
||||
func lzo1xCompressSafe(src []byte) ([]byte, error) {
|
||||
lzoPacket := make([]byte, 1+len(src))
|
||||
lzoPacket[0] = lzoCompressNone
|
||||
copy(lzoPacket[1:], src)
|
||||
return lzoPacket, nil
|
||||
}
|
||||
@@ -10,16 +10,17 @@ import (
|
||||
const PushRequest = "PUSH_REQUEST"
|
||||
|
||||
type PushReply struct {
|
||||
Raw string
|
||||
Prefixes []netip.Prefix
|
||||
DNS []netip.Addr
|
||||
PeerID uint32
|
||||
Cipher string
|
||||
Ping uint32
|
||||
MTU uint32
|
||||
CompLZO bool
|
||||
Redirect bool
|
||||
BlockIPv6 bool
|
||||
Raw string
|
||||
Prefixes []netip.Prefix
|
||||
DNS []netip.Addr
|
||||
PeerID uint32
|
||||
Cipher string
|
||||
Ping uint32
|
||||
PingRestart uint32
|
||||
MTU uint32
|
||||
CompLZO bool
|
||||
Redirect bool
|
||||
BlockIPv6 bool
|
||||
}
|
||||
|
||||
func ParsePushReply(message string) (*PushReply, error) {
|
||||
@@ -81,6 +82,12 @@ func ParsePushReply(message string) (*PushReply, error) {
|
||||
reply.Ping = uint32(v)
|
||||
}
|
||||
}
|
||||
case "ping-restart":
|
||||
if len(fields) >= 2 {
|
||||
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
|
||||
reply.PingRestart = uint32(v)
|
||||
}
|
||||
}
|
||||
case "tun-mtu":
|
||||
if len(fields) >= 2 {
|
||||
if v, err := strconv.ParseUint(fields[1], 10, 32); err == nil {
|
||||
@@ -113,27 +120,44 @@ func splitPushOptions(message string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseIPv4Ifconfig(address, mask string) (netip.Prefix, error) {
|
||||
func parseIPv4Ifconfig(address, maskOrPeer string) (netip.Prefix, error) {
|
||||
addr, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 address %q: %w", address, err)
|
||||
}
|
||||
maskAddr, err := netip.ParseAddr(mask)
|
||||
maskAddr, err := netip.ParseAddr(maskOrPeer)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", mask, err)
|
||||
return netip.Prefix{}, fmt.Errorf("parse pushed ipv4 mask %q: %w", maskOrPeer, err)
|
||||
}
|
||||
if !addr.Is4() || !maskAddr.Is4() {
|
||||
return netip.Prefix{}, fmt.Errorf("openvpn ifconfig requires ipv4 address and mask")
|
||||
}
|
||||
maskBytes := maskAddr.As4()
|
||||
|
||||
if ones, ok := ipv4MaskSize(maskAddr); ok {
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
}
|
||||
|
||||
// Some servers, including SoftEther/VPNGate in net30/p2p mode, push
|
||||
// "ifconfig <local> <remote>" rather than "ifconfig <local> <netmask>".
|
||||
// Use a host prefix for that local tunnel address.
|
||||
return netip.PrefixFrom(addr, 32), nil
|
||||
}
|
||||
|
||||
func ipv4MaskSize(mask netip.Addr) (int, bool) {
|
||||
maskBytes := mask.As4()
|
||||
ones := 0
|
||||
seenZero := false
|
||||
for _, b := range maskBytes {
|
||||
for i := 7; i >= 0; i-- {
|
||||
if b&(1<<i) == 0 {
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
seenZero = true
|
||||
continue
|
||||
}
|
||||
if seenZero {
|
||||
return 0, false
|
||||
}
|
||||
ones++
|
||||
}
|
||||
}
|
||||
return netip.PrefixFrom(addr, ones), nil
|
||||
return ones, true
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openvpn
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
@@ -35,6 +36,9 @@ func NewTLSAuth(staticKey []byte, keyDirection int, auth string) (*TLSAuth, erro
|
||||
var newHash func() hash.Hash
|
||||
var hmacSize int
|
||||
switch auth {
|
||||
case AuthMD5:
|
||||
newHash = md5.New
|
||||
hmacSize = md5.Size
|
||||
case AuthSHA256:
|
||||
newHash = sha256.New
|
||||
hmacSize = sha256.Size
|
||||
|
||||
@@ -30,16 +30,19 @@ type TunnelOptions struct {
|
||||
UDPTimeout time.Duration
|
||||
ReconnectDelay time.Duration
|
||||
PingInterval time.Duration
|
||||
PingRestart time.Duration
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger logger.ContextLogger
|
||||
options TunnelOptions
|
||||
device Device
|
||||
client *Client
|
||||
mtu uint32
|
||||
serverIndex int
|
||||
wg sync.WaitGroup
|
||||
|
||||
await chan struct{}
|
||||
mu sync.Mutex
|
||||
@@ -49,8 +52,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
|
||||
if options.ReconnectDelay == 0 {
|
||||
options.ReconnectDelay = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Tunnel{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
options: options,
|
||||
await: make(chan struct{}),
|
||||
@@ -59,10 +64,10 @@ func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelO
|
||||
|
||||
func (t *Tunnel) Start() error {
|
||||
go func() {
|
||||
defer close(t.await)
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
t.logger.Error("OpenVPN connect: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
t.mtu = 1500
|
||||
@@ -84,20 +89,26 @@ func (t *Tunnel) Start() error {
|
||||
if err != nil {
|
||||
client.Close()
|
||||
t.logger.Error("create OpenVPN device: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
t.device = device
|
||||
if err := device.Start(); err != nil {
|
||||
client.Close()
|
||||
t.logger.Error("start OpenVPN device: ", err)
|
||||
close(t.await)
|
||||
return
|
||||
}
|
||||
close(t.await)
|
||||
t.maintainTunnel()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if err := t.isTunnelInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
@@ -105,6 +116,9 @@ func (t *Tunnel) DialContext(ctx context.Context, network string, destination M.
|
||||
}
|
||||
|
||||
func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if err := t.isTunnelInitialized(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
@@ -112,15 +126,18 @@ func (t *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
|
||||
}
|
||||
|
||||
func (t *Tunnel) Close() error {
|
||||
t.cancel()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.client != nil {
|
||||
t.client.Close()
|
||||
t.client = nil
|
||||
}
|
||||
if t.device != nil {
|
||||
return t.device.Close()
|
||||
t.device.Close()
|
||||
t.device = nil
|
||||
}
|
||||
t.mu.Unlock()
|
||||
t.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,7 +154,9 @@ func (t *Tunnel) isTunnelInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (t *Tunnel) maintainTunnel() {
|
||||
t.wg.Add(2)
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
bufs := make([][]byte, 1)
|
||||
bufs[0] = make([]byte, t.mtu)
|
||||
sizes := make([]int, 1)
|
||||
@@ -161,6 +180,7 @@ func (t *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
for t.ctx.Err() == nil {
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
@@ -179,10 +199,14 @@ func (t *Tunnel) maintainTunnel() {
|
||||
if bytes.Equal(packet, pingPayload) {
|
||||
continue
|
||||
}
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if _, err := t.device.Write([][]byte{packet}, 0); err != nil {
|
||||
if t.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -208,6 +232,34 @@ func (t *Tunnel) maintainTunnel() {
|
||||
}
|
||||
}()
|
||||
}
|
||||
pingRestart := t.options.PingRestart
|
||||
if pingRestart == 0 && t.client != nil && t.client.push.PingRestart > 0 {
|
||||
pingRestart = time.Duration(t.client.push.PingRestart) * time.Second
|
||||
}
|
||||
if pingRestart > 0 {
|
||||
t.wg.Add(1)
|
||||
go func() {
|
||||
defer t.wg.Done()
|
||||
ticker := time.NewTicker(pingRestart)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if client.SinceReceive() >= pingRestart {
|
||||
if ok := t.closeClient(client); ok {
|
||||
t.logger.ErrorContext(t.ctx, fmt.Errorf("ping-restart timeout: no packet received for %s", pingRestart))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
<-t.ctx.Done()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user