Files
sing-box-extended/transport/openvpn/control.go

476 lines
10 KiB
Go

package openvpn
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sort"
	"sync"
	"time"
)
// PacketIO is a packet-oriented transport used by ControlChannel: every
// ReadPacket yields one whole packet and every WritePacket sends one.
type PacketIO interface {
	// ReadPacket returns the next packet, giving up when ctx is done.
	ReadPacket(ctx context.Context) ([]byte, error)
	// WritePacket sends packet as a single unit, giving up when ctx is done.
	WritePacket(ctx context.Context, packet []byte) error
	// Close releases the underlying transport.
	Close() error
	// LocalAddr reports the local endpoint of the underlying transport.
	LocalAddr() net.Addr
	// RemoteAddr reports the remote endpoint of the underlying transport.
	RemoteAddr() net.Addr
}
// controlChannelRecvWindow is how many message IDs beyond the next expected
// one Read is willing to buffer. Reliable packets further ahead are dropped
// without an acknowledgement so the peer retransmits them after the gap closes.
const controlChannelRecvWindow = 32

// ControlChannel implements the reliability layer of an OpenVPN control
// channel on top of a PacketIO. Outgoing reliable packets get consecutive
// message IDs and are remembered until the peer acknowledges them, the IDs of
// received reliable packets are collected so they can be acknowledged, and
// received messages are handed to the caller in message-ID order with
// retransmitted duplicates removed.
type ControlChannel struct {
	io     PacketIO
	encode func(*ControlPacket, uint32, uint32) ([]byte, error)
	decode func([]byte) (*ControlPacket, uint32, uint32, error)
	clock  func() time.Time
	keyID  uint8
	local  SessionID
	remote SessionID // learned from the peer; guarded by mu

	mu            sync.Mutex
	sendPacketID  uint32                    // last packet ID handed to encode
	sendMessage   uint32                    // next outgoing message ID
	ackPending    []uint32                  // received message IDs not yet acknowledged
	pending       map[uint32]*ControlPacket // sent reliable packets awaiting an ack
	readDeadline  time.Time
	writeDeadline time.Time
	recvMessage   uint32                    // next message ID Read will deliver
	recvBuffer    map[uint32]*ControlPacket // reliable packets that arrived ahead of recvMessage
}

// NewControlChannel returns a ControlChannel that reads and writes through io
// and identifies itself with the given local session ID. When crypt is
// non-nil, packets go through EncodeControlPacketCrypt/DecodeControlPacketCrypt
// (which receive the packet ID and the Unix time); otherwise the plain
// EncodeControlPacket/DecodeControlPacket are used.
func NewControlChannel(io PacketIO, crypt ControlCrypt, local SessionID) *ControlChannel {
	ch := &ControlChannel{
		io:      io,
		clock:   time.Now,
		local:   local,
		pending: make(map[uint32]*ControlPacket),
	}
	if crypt != nil {
		ch.encode = func(p *ControlPacket, pid uint32, t uint32) ([]byte, error) {
			return EncodeControlPacketCrypt(*p, crypt, pid, t)
		}
		ch.decode = func(pkt []byte) (*ControlPacket, uint32, uint32, error) {
			return DecodeControlPacketCrypt(crypt, pkt)
		}
	} else {
		ch.encode = func(p *ControlPacket, _ uint32, _ uint32) ([]byte, error) {
			return EncodeControlPacket(*p)
		}
		ch.decode = func(pkt []byte) (*ControlPacket, uint32, uint32, error) {
			cp, err := DecodeControlPacket(pkt)
			return cp, 0, 0, err
		}
	}
	return ch
}

// LocalSessionID returns the session ID this side sends with every packet.
func (c *ControlChannel) LocalSessionID() SessionID {
	return c.local
}

// RemoteSessionID returns the peer's session ID, or the zero SessionID if it
// has not been learned or set yet.
func (c *ControlChannel) RemoteSessionID() SessionID {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.remote
}

// SetRemoteSessionID overrides the peer's session ID used in acknowledgements.
func (c *ControlChannel) SetRemoteSessionID(id SessionID) {
	c.mu.Lock()
	c.remote = id
	c.mu.Unlock()
}

// SendReset sends an empty PControlHardResetClientV2 as a reliable message.
func (c *ControlChannel) SendReset(ctx context.Context) error {
	_, err := c.Send(ctx, PControlHardResetClientV2, nil)
	return err
}

// Send transmits a copy of payload as a new reliable packet with the given
// opcode and returns the message ID it was assigned. Queued acknowledgements
// are piggybacked on it. The packet is recorded as pending before the write,
// so RetransmitPending will resend it even if this write fails; in that case
// Send returns message ID 0 and the write error. Opcodes that cannot carry a
// message ID are rejected.
func (c *ControlChannel) Send(ctx context.Context, opcode Opcode, payload []byte) (uint32, error) {
	if !opcode.HasMessageID() {
		return 0, fmt.Errorf("opcode %s cannot carry a reliable message", opcode)
	}
	c.mu.Lock()
	messageID := c.sendMessage
	c.sendMessage++
	packet := &ControlPacket{
		Opcode:           opcode,
		KeyID:            c.keyID,
		LocalSession:     c.local,
		AckIDs:           append([]uint32(nil), c.ackPending...),
		AckRemoteSession: c.remote,
		MessageID:        messageID,
		Payload:          cloneBytes(payload),
	}
	c.ackPending = nil
	c.pending[messageID] = packet
	c.mu.Unlock()
	if err := c.writeControlPacket(ctx, packet); err != nil {
		return 0, err
	}
	return messageID, nil
}

// SendAck flushes queued acknowledgements in a standalone PAckV1 packet. It
// does nothing when no acknowledgements are queued.
func (c *ControlChannel) SendAck(ctx context.Context) error {
	c.mu.Lock()
	if len(c.ackPending) == 0 {
		c.mu.Unlock()
		return nil
	}
	packet := &ControlPacket{
		Opcode:           PAckV1,
		KeyID:            c.keyID,
		LocalSession:     c.local,
		AckIDs:           append([]uint32(nil), c.ackPending...),
		AckRemoteSession: c.remote,
	}
	c.ackPending = nil
	c.mu.Unlock()
	return c.writeControlPacket(ctx, packet)
}

// Read returns the next control packet for the caller.
//
// Every packet read updates the reliability state. The peer's session ID is
// learned from the first packet not carrying our own ID. Message IDs the peer
// acknowledges are removed from the retransmission set. Pure PAckV1 packets
// are consumed silently.
//
// Reliable packets are delivered strictly in message-ID order:
//   - retransmissions of already-delivered messages are re-acknowledged and
//     dropped;
//   - packets that arrive ahead of a gap are acknowledged and held back until
//     the gap is filled;
//   - packets too far ahead to buffer are dropped unacknowledged so the peer
//     resends them later.
//
// Other packets without a message ID are returned as they arrive.
func (c *ControlChannel) Read(ctx context.Context) (*ControlPacket, error) {
	for {
		if packet := c.popInOrder(); packet != nil {
			return packet, nil
		}
		packet, err := c.readControlPacket(ctx)
		if err != nil {
			return nil, err
		}
		if c.acceptReceived(packet) {
			return packet, nil
		}
	}
}

// popInOrder removes and returns the buffered packet carrying the next
// expected message ID, or nil if it has not arrived yet.
func (c *ControlChannel) popInOrder() *ControlPacket {
	c.mu.Lock()
	defer c.mu.Unlock()
	packet, ok := c.recvBuffer[c.recvMessage]
	if !ok {
		return nil
	}
	delete(c.recvBuffer, c.recvMessage)
	c.recvMessage++
	return packet
}

// acceptReceived applies the bookkeeping for one packet read from the wire
// and reports whether Read should return it immediately.
func (c *ControlChannel) acceptReceived(packet *ControlPacket) bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.remote == (SessionID{}) && packet.LocalSession != c.local {
		c.remote = packet.LocalSession
	}
	for _, ackID := range packet.AckIDs {
		delete(c.pending, ackID)
	}
	if packet.Opcode == PAckV1 {
		return false
	}
	if !packet.Opcode.HasMessageID() {
		return true
	}
	// Serial-number arithmetic: the difference wraps, so IDs behind the
	// expected one show up as negative when viewed as int32.
	offset := packet.MessageID - c.recvMessage
	switch {
	case offset == 0:
		c.ackPending = appendAck(c.ackPending, packet.MessageID)
		c.recvMessage++
		return true
	case int32(offset) < 0:
		// Already delivered: the peer evidently missed our ack, so ack again.
		c.ackPending = appendAck(c.ackPending, packet.MessageID)
		return false
	case offset < controlChannelRecvWindow:
		if c.recvBuffer == nil {
			c.recvBuffer = make(map[uint32]*ControlPacket)
		}
		c.recvBuffer[packet.MessageID] = packet
		c.ackPending = appendAck(c.ackPending, packet.MessageID)
		return false
	default:
		// Too far ahead to buffer; leave it unacknowledged.
		return false
	}
}
// PendingMessages reports how many sent reliable packets are still waiting
// for the peer's acknowledgement.
func (c *ControlChannel) PendingMessages() int {
	c.mu.Lock()
	count := len(c.pending)
	c.mu.Unlock()
	return count
}
// RetransmitPending resends every reliable packet the peer has not yet
// acknowledged, oldest message ID first.
//
// Queued acknowledgements are piggybacked on each retransmitted copy, and
// every copy is addressed to the currently known remote session. The stored
// originals are left untouched. If nothing is awaiting acknowledgement, the
// queued acks are kept for a later Send or SendAck rather than discarded.
// The first write error aborts the loop and is returned.
func (c *ControlChannel) RetransmitPending(ctx context.Context) error {
	c.mu.Lock()
	if len(c.pending) == 0 {
		// No packet to carry the queued acks; clearing them here would
		// silently drop them.
		c.mu.Unlock()
		return nil
	}
	packets := make([]*ControlPacket, 0, len(c.pending))
	for _, packet := range c.pending {
		cp := *packet
		cp.AckIDs = append([]uint32(nil), c.ackPending...)
		cp.AckRemoteSession = c.remote
		packets = append(packets, &cp)
	}
	c.ackPending = nil
	c.mu.Unlock()
	// Map iteration order is random; resend in the order the messages were
	// originally sent.
	sort.Slice(packets, func(i, j int) bool {
		return packets[i].MessageID < packets[j].MessageID
	})
	for _, packet := range packets {
		if err := c.writeControlPacket(ctx, packet); err != nil {
			return err
		}
	}
	return nil
}
// writeControlPacket encodes packet with a freshly reserved packet ID and the
// current clock time (Unix seconds), then writes it through the PacketIO. A
// deadline configured with SetDeadline/SetWriteDeadline further bounds ctx.
func (c *ControlChannel) writeControlPacket(ctx context.Context, packet *ControlPacket) error {
	c.mu.Lock()
	writeBy := c.writeDeadline
	c.sendPacketID++
	id, now := c.sendPacketID, c.clock()
	c.mu.Unlock()

	wire, err := c.encode(packet, id, uint32(now.Unix()))
	if err != nil {
		return err
	}
	if writeBy.IsZero() {
		return c.io.WritePacket(ctx, wire)
	}
	boundedCtx, cancel := context.WithDeadline(ctx, writeBy)
	defer cancel()
	return c.io.WritePacket(boundedCtx, wire)
}
// encodeAndWrap reserves the next packet ID and encodes packet with it and
// the current clock time, without writing anything. The context argument is
// currently unused.
func (c *ControlChannel) encodeAndWrap(_ context.Context, packet *ControlPacket) ([]byte, error) {
	id, stamp := func() (uint32, uint32) {
		c.mu.Lock()
		defer c.mu.Unlock()
		c.sendPacketID++
		return c.sendPacketID, uint32(c.clock().Unix())
	}()
	return c.encode(packet, id, stamp)
}
// readControlPacket reads one raw packet from the PacketIO and decodes it.
// A deadline configured with SetDeadline/SetReadDeadline further bounds ctx.
// The packet ID and timestamp returned by decode are discarded.
func (c *ControlChannel) readControlPacket(ctx context.Context) (*ControlPacket, error) {
	c.mu.Lock()
	readBy := c.readDeadline
	c.mu.Unlock()

	readCtx := ctx
	if !readBy.IsZero() {
		var cancel context.CancelFunc
		readCtx, cancel = context.WithDeadline(ctx, readBy)
		defer cancel()
	}
	raw, err := c.io.ReadPacket(readCtx)
	if err != nil {
		return nil, err
	}
	decoded, _, _, decodeErr := c.decode(raw)
	return decoded, decodeErr
}
// SetDeadline sets both the read and the write deadline to t. The zero time
// removes them. It always returns nil.
func (c *ControlChannel) SetDeadline(t time.Time) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.readDeadline, c.writeDeadline = t, t
	return nil
}

// SetReadDeadline bounds subsequent packet reads by t (zero removes the
// bound). It always returns nil.
func (c *ControlChannel) SetReadDeadline(t time.Time) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.readDeadline = t
	return nil
}

// SetWriteDeadline bounds subsequent packet writes by t (zero removes the
// bound). It always returns nil.
func (c *ControlChannel) SetWriteDeadline(t time.Time) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writeDeadline = t
	return nil
}
// appendAck adds ack to acks unless it is already present, returning the
// (possibly unchanged) slice. Insertion order is preserved.
func appendAck(acks []uint32, ack uint32) []uint32 {
	for i := range acks {
		if acks[i] == ack {
			return acks
		}
	}
	return append(acks, ack)
}
// ControlConn exposes the application payload of a ControlChannel as a byte
// stream with the net.Conn method set: writes become PControlV1 messages and
// reads return the payloads of received PControlV1 messages.
type ControlConn struct {
	channel *ControlChannel
	readBuf []byte // payload bytes received but not yet returned by Read
	closed  bool
	mu      sync.Mutex // guards readBuf and closed
}

// NewControlConn wraps channel in a ControlConn.
func NewControlConn(channel *ControlChannel) *ControlConn {
	conn := new(ControlConn)
	conn.channel = channel
	return conn
}
// Read fills b with payload bytes, first from leftovers of an earlier
// message and otherwise from the next non-empty PControlV1 message. Any
// reliable message read is acknowledged right away. Other opcodes and empty
// payloads are skipped. Bytes that do not fit in b are kept for the next call.
func (c *ControlConn) Read(b []byte) (int, error) {
	c.mu.Lock()
	switch {
	case c.closed:
		c.mu.Unlock()
		return 0, net.ErrClosed
	case len(c.readBuf) > 0:
		n := copy(b, c.readBuf)
		c.readBuf = c.readBuf[n:]
		c.mu.Unlock()
		return n, nil
	}
	c.mu.Unlock()

	ctx := context.Background()
	for {
		packet, err := c.channel.Read(ctx)
		if err != nil {
			return 0, err
		}
		// Acknowledge what was just read whether or not it carries data.
		if ackErr := c.channel.SendAck(ctx); ackErr != nil {
			return 0, ackErr
		}
		if packet.Opcode != PControlV1 || len(packet.Payload) == 0 {
			continue
		}
		n := copy(b, packet.Payload)
		if rest := packet.Payload[n:]; len(rest) > 0 {
			c.mu.Lock()
			c.readBuf = append(c.readBuf, rest...)
			c.mu.Unlock()
		}
		return n, nil
	}
}
// Write sends all of b as one reliable PControlV1 message and reports len(b)
// on success. It fails with net.ErrClosed after Close.
//
// NOTE(review): b is not split into several control packets; confirm that
// callers keep writes small enough for the transport's MTU.
func (c *ControlConn) Write(b []byte) (int, error) {
	c.mu.Lock()
	isClosed := c.closed
	c.mu.Unlock()
	if isClosed {
		return 0, net.ErrClosed
	}
	_, err := c.channel.Send(context.Background(), PControlV1, b)
	if err != nil {
		return 0, err
	}
	return len(b), nil
}
// Close marks the connection closed and closes the channel's underlying
// PacketIO. Calls after the first are no-ops returning nil.
func (c *ControlConn) Close() error {
	c.mu.Lock()
	wasClosed := c.closed
	c.closed = true
	c.mu.Unlock()
	if wasClosed {
		return nil
	}
	return c.channel.io.Close()
}
// LocalAddr returns the local address of the underlying PacketIO.
func (c *ControlConn) LocalAddr() net.Addr {
	transport := c.channel.io
	return transport.LocalAddr()
}

// RemoteAddr returns the remote address of the underlying PacketIO.
func (c *ControlConn) RemoteAddr() net.Addr {
	transport := c.channel.io
	return transport.RemoteAddr()
}

// SetDeadline forwards to ControlChannel.SetDeadline.
func (c *ControlConn) SetDeadline(t time.Time) error {
	ch := c.channel
	return ch.SetDeadline(t)
}

// SetReadDeadline forwards to ControlChannel.SetReadDeadline.
func (c *ControlConn) SetReadDeadline(t time.Time) error {
	ch := c.channel
	return ch.SetReadDeadline(t)
}

// SetWriteDeadline forwards to ControlChannel.SetWriteDeadline.
func (c *ControlConn) SetWriteDeadline(t time.Time) error {
	ch := c.channel
	return ch.SetWriteDeadline(t)
}
// streamPacketIO frames packets for a stream transport such as TCP: each
// packet is preceded by its length as a 16-bit big-endian integer.
type streamPacketIO struct {
	conn net.Conn
	pump packetReadPump
}

// datagramPacketIO maps one packet to one Read/Write on the connection.
type datagramPacketIO struct {
	conn net.Conn
	pump packetReadPump
	buf  []byte // receive buffer; only touched by the pump goroutine
}

// packetReadResult carries the outcome of one blocking packet read.
type packetReadResult struct {
	packet []byte
	err    error
}

// packetReadPump funnels all reads from a connection through a single
// background goroutine that hands each result to ReadPacket over a channel.
//
// Because that goroutine is the only reader, a ReadPacket call that gives up
// when its context is done never abandons a read in flight. Whatever the
// goroutine receives next is kept for the following call instead of being
// discarded, and a stream's length-prefixed framing is never split between
// two concurrent readers.
//
// The zero value is ready to use. The goroutine starts on the first call to
// next. After stop, it exits once its current read returns (closing the
// connection makes that read return).
type packetReadPump struct {
	setupOnce sync.Once
	startOnce sync.Once
	stopOnce  sync.Once
	results   chan packetReadResult
	stopped   chan struct{}
}

// setup lazily creates the channels so the zero value is usable.
func (p *packetReadPump) setup() {
	p.setupOnce.Do(func() {
		p.results = make(chan packetReadResult)
		p.stopped = make(chan struct{})
	})
}

// next returns the next result produced by read. It returns ctx.Err() if ctx
// is done first, and net.ErrClosed once stop has been called. read must
// perform one blocking packet read; only the function passed on the first
// call is used.
func (p *packetReadPump) next(ctx context.Context, read func() ([]byte, error)) ([]byte, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	p.setup()
	select {
	case <-p.stopped:
		return nil, net.ErrClosed
	default:
	}
	p.startOnce.Do(func() { go p.run(read) })
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-p.stopped:
		return nil, net.ErrClosed
	case result := <-p.results:
		return result.packet, result.err
	}
}

// run performs reads until stop is called, blocking after each read until a
// caller takes the result, so nothing read is ever dropped.
func (p *packetReadPump) run(read func() ([]byte, error)) {
	for {
		packet, err := read()
		select {
		case p.results <- packetReadResult{packet: packet, err: err}:
		case <-p.stopped:
			return
		}
	}
}

// stop makes current and future next calls return net.ErrClosed and lets
// the goroutine exit once its in-flight read returns.
func (p *packetReadPump) stop() {
	p.setup()
	p.stopOnce.Do(func() { close(p.stopped) })
}

// writePacketContext writes b on a helper goroutine so the caller can stop
// waiting when ctx is done. An abandoned write still runs to completion in
// the background, so b must not be modified afterwards.
func writePacketContext(ctx context.Context, conn net.Conn, b []byte) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	done := make(chan error, 1)
	go func() {
		_, err := conn.Write(b)
		done <- err
	}()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-done:
		return err
	}
}

// NewDatagramPacketIO wraps a message-oriented connection (for example UDP)
// so that every Read/Write carries exactly one packet.
func NewDatagramPacketIO(conn net.Conn) PacketIO {
	return &datagramPacketIO{conn: conn}
}

// ReadPacket returns the next datagram, or ctx.Err() if ctx is done first.
// A datagram that arrives after a cancelled call is returned by the next call.
func (d *datagramPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
	return d.pump.next(ctx, d.readDatagram)
}

// readDatagram performs one blocking Read into a reused 64 KiB buffer and
// returns a private copy of the bytes received. It runs only on the pump
// goroutine.
func (d *datagramPacketIO) readDatagram() ([]byte, error) {
	if d.buf == nil {
		d.buf = make([]byte, 64*1024)
	}
	n, err := d.conn.Read(d.buf)
	if err != nil {
		return nil, err
	}
	packet := make([]byte, n)
	copy(packet, d.buf[:n])
	return packet, nil
}

// WritePacket sends packet as one datagram. packet must not be modified
// after a cancelled call, since the write may still be in progress.
func (d *datagramPacketIO) WritePacket(ctx context.Context, packet []byte) error {
	return writePacketContext(ctx, d.conn, packet)
}

// Close stops the reader and closes the connection.
func (d *datagramPacketIO) Close() error {
	d.pump.stop()
	return d.conn.Close()
}

// LocalAddr returns the connection's local address.
func (d *datagramPacketIO) LocalAddr() net.Addr {
	return d.conn.LocalAddr()
}

// RemoteAddr returns the connection's remote address.
func (d *datagramPacketIO) RemoteAddr() net.Addr {
	return d.conn.RemoteAddr()
}

// NewTCPPacketIO wraps a stream connection using 16-bit length-prefixed
// framing.
func NewTCPPacketIO(conn net.Conn) PacketIO {
	return &streamPacketIO{conn: conn}
}

// ReadPacket returns the next framed packet, or ctx.Err() if ctx is done
// first. A frame that arrives after a cancelled call is returned by the next
// call, so the framing stays aligned.
func (s *streamPacketIO) ReadPacket(ctx context.Context) ([]byte, error) {
	return s.pump.next(ctx, s.readFrame)
}

// readFrame reads one length-prefixed frame. A zero length is rejected, and
// EOF inside a frame body is reported as io.ErrUnexpectedEOF. It runs only
// on the pump goroutine.
func (s *streamPacketIO) readFrame() ([]byte, error) {
	var lenBuf [2]byte
	if _, err := io.ReadFull(s.conn, lenBuf[:]); err != nil {
		return nil, err
	}
	size := int(lenBuf[0])<<8 | int(lenBuf[1])
	if size == 0 {
		return nil, errors.New("empty openvpn tcp packet")
	}
	packet := make([]byte, size)
	if _, err := io.ReadFull(s.conn, packet); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return packet, nil
}

// WritePacket frames packet with its 16-bit length and writes it in a single
// Write call. Packets longer than 0xffff bytes are rejected.
func (s *streamPacketIO) WritePacket(ctx context.Context, packet []byte) error {
	if len(packet) > 0xffff {
		return fmt.Errorf("openvpn tcp packet too large: %d", len(packet))
	}
	// Build the frame up front so a background write abandoned on
	// cancellation never reads the caller's slice after we return.
	frame := make([]byte, 2+len(packet))
	frame[0] = byte(len(packet) >> 8)
	frame[1] = byte(len(packet))
	copy(frame[2:], packet)
	return writePacketContext(ctx, s.conn, frame)
}

// Close stops the reader and closes the connection.
func (s *streamPacketIO) Close() error {
	s.pump.stop()
	return s.conn.Close()
}

// LocalAddr returns the connection's local address.
func (s *streamPacketIO) LocalAddr() net.Addr {
	return s.conn.LocalAddr()
}

// RemoteAddr returns the connection's remote address.
func (s *streamPacketIO) RemoteAddr() net.Addr {
	return s.conn.RemoteAddr()
}