mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-05-14 08:52:47 +03:00
327 lines
9.6 KiB
Go
327 lines
9.6 KiB
Go
package masque
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"errors"
|
|
"log"
|
|
"math/big"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/sagernet/quic-go"
|
|
)
|
|
|
|
// PortMapping describes a single port forwarding rule: traffic arriving on
// BindAddress:LocalPort is destined for RemoteIP:RemotePort.
type PortMapping struct {
	// BindAddress is the address the local port is bound to.
	BindAddress string
	// LocalPort is the local port number.
	LocalPort int
	// RemoteIP is the remote destination IP address.
	RemoteIP string
	// RemotePort is the remote destination port number.
	RemotePort int
}
|
|
|
|
// GenerateRandomAndroidSerial produces a random Android-like device serial.
//
// Eight bytes are read from the cryptographic random source and rendered as
// a hexadecimal string.
//
// Returns:
//   - string: The 16-character hexadecimal serial number.
//   - error: The error reported by the random source, if reading fails.
func GenerateRandomAndroidSerial() (string, error) {
	var raw [8]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|
|
|
|
// GenerateRandomWgPubkey produces 32 random bytes shaped like a WireGuard
// public key and returns them base64-encoded (standard alphabet, padded).
//
// The bytes are purely random; no matching private key is derived.
//
// Returns:
//   - string: The base64 encoding of the 32 random bytes.
//   - error: The error reported by the random source, if reading fails.
func GenerateRandomWgPubkey() (string, error) {
	var key [32]byte
	if _, err := rand.Read(key[:]); err != nil {
		return "", err
	}
	encoded := base64.StdEncoding.EncodeToString(key[:])
	return encoded, nil
}
|
|
|
|
// TimeAsCfString renders t in the timestamp layout used when talking to
// Cloudflare: "YYYY-MM-DDTHH:MM:SS.sss-07:00", i.e. millisecond precision
// followed by a numeric UTC offset.
//
// Parameters:
//   - t: time.Time - The instant to format.
//
// Returns:
//   - string: The formatted timestamp.
func TimeAsCfString(t time.Time) string {
	const cfLayout = "2006-01-02T15:04:05.000-07:00"
	return t.Format(cfLayout)
}
|
|
|
|
// GenerateEcKeyPair creates a fresh ECDSA key pair on the P-256 curve and
// returns both halves in serialized form.
//
// Returns:
//   - []byte: The private key marshalled with x509.MarshalECPrivateKey (ASN.1 DER).
//   - []byte: The public key marshalled with x509.MarshalPKIXPublicKey (PKIX DER).
//   - error: An error if key generation or either marshalling step fails.
func GenerateEcKeyPair() ([]byte, []byte, error) {
	key, genErr := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if genErr != nil {
		return nil, nil, genErr
	}

	privDER, privErr := x509.MarshalECPrivateKey(key)
	if privErr != nil {
		return nil, nil, privErr
	}

	pubDER, pubErr := x509.MarshalPKIXPublicKey(&key.PublicKey)
	if pubErr != nil {
		return nil, nil, pubErr
	}

	return privDER, pubDER, nil
}
|
|
|
|
// GenerateCert creates a certificate signed with the provided ECDSA private key.
//
// The certificate carries serial number 0, is valid for 24 hours starting at
// the time of the call, and embeds the public key that belongs to privKey.
//
// Parameters:
//   - privKey: *ecdsa.PrivateKey - The private key used to sign the certificate;
//     its public half is the key placed in the certificate. Must not be nil.
//   - pubKey: *ecdsa.PublicKey - Kept for API compatibility. It is currently not
//     consulted: the certificate always uses privKey's own public key.
//
// Returns:
//   - [][]byte: A single-element slice containing the certificate in DER format.
//   - error: An error if privKey is nil or certificate generation fails.
func GenerateCert(privKey *ecdsa.PrivateKey, pubKey *ecdsa.PublicKey) ([][]byte, error) {
	// Dereferencing a nil key below would panic; report it as an error instead.
	if privKey == nil {
		return nil, errors.New("private key is nil")
	}

	// Take one timestamp so the validity window is exactly 24 hours.
	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: big.NewInt(0),
		NotBefore:    now,
		NotAfter:     now.Add(24 * time.Hour),
	}

	// The parent is an empty certificate, so the issuer name is left blank.
	der, err := x509.CreateCertificate(rand.Reader, template, &x509.Certificate{}, &privKey.PublicKey, privKey)
	if err != nil {
		return nil, err
	}

	return [][]byte{der}, nil
}
|
|
|
|
// DefaultQuicConfig returns a MASQUE-compatible default QUIC configuration.
|
|
//
|
|
// When initialPacketSize is 0, Path MTU Discovery remains enabled.
|
|
//
|
|
// Parameters:
|
|
// - keepalivePeriod: time.Duration - The duration for sending QUIC keep-alive packets.
|
|
// - initialPacketSize: uint16 - The custom initial size of QUIC packets (0 = auto with PMTU discovery).
|
|
//
|
|
// Returns:
|
|
// - *quic.Config: A pointer to a configured QUIC configuration object.
|
|
func DefaultQuicConfig(keepalivePeriod time.Duration, initialPacketSize uint16) *quic.Config {
|
|
cfg := &quic.Config{
|
|
EnableDatagrams: true,
|
|
KeepAlivePeriod: keepalivePeriod,
|
|
}
|
|
|
|
if initialPacketSize > 0 {
|
|
cfg.InitialPacketSize = initialPacketSize
|
|
cfg.DisablePathMTUDiscovery = true
|
|
}
|
|
|
|
return cfg
|
|
}
|
|
|
|
// parsePortMapping is an internal helper function that parses a port mapping string into its components.
|
|
//
|
|
// It handles IPv6 addresses enclosed in brackets and various format edge cases.
|
|
//
|
|
// Parameters:
|
|
// - port: string - The port mapping string.
|
|
//
|
|
// Returns:
|
|
// - string: The bind address.
|
|
// - int: The local port.
|
|
// - string: The remote hostname/IP.
|
|
// - int: The remote port.
|
|
// - error: An error if parsing fails.
|
|
func parsePortMapping(port string) (bindAddress string, localPort int, remoteHost string, remotePort int, err error) {
|
|
parts := strings.Split(port, ":")
|
|
|
|
// Handle IPv6 addresses (which are enclosed in brackets)
|
|
if len(parts) >= 4 && strings.HasPrefix(parts[0], "[") && strings.Contains(parts[0], "]") {
|
|
bindAddress = parts[0]
|
|
parts = parts[1:] // Shift parts forward
|
|
} else if len(parts) == 3 {
|
|
bindAddress = "localhost" // Default to localhost
|
|
} else if len(parts) == 4 {
|
|
bindAddress = parts[0]
|
|
parts = parts[1:] // Shift forward
|
|
} else {
|
|
return "", 0, "", 0, errors.New("invalid port mapping format (expected format: [bind_address:]local_port:remote_host:remote_port)")
|
|
}
|
|
|
|
// Parse local port
|
|
localPort, err = strconv.Atoi(parts[0])
|
|
if err != nil || localPort <= 0 || localPort > 65535 {
|
|
return "", 0, "", 0, errors.New("invalid local port")
|
|
}
|
|
|
|
// Validate remote host (allow both hostnames and IPs)
|
|
remoteHost = parts[1]
|
|
if net.ParseIP(remoteHost) == nil && !isValidHostname(remoteHost) {
|
|
return "", 0, "", 0, errors.New("invalid remote hostname/IP")
|
|
}
|
|
|
|
// Parse remote port
|
|
remotePort, err = strconv.Atoi(parts[2])
|
|
if err != nil || remotePort <= 0 || remotePort > 65535 {
|
|
return "", 0, "", 0, errors.New("invalid remote port")
|
|
}
|
|
|
|
// If bindAddress is an IPv6 address, remove brackets for proper binding
|
|
if strings.HasPrefix(bindAddress, "[") && strings.HasSuffix(bindAddress, "]") {
|
|
bindAddress = strings.Trim(bindAddress, "[]")
|
|
}
|
|
|
|
// Convert "localhost" or hostnames to actual addresses
|
|
if bindAddress == "*" {
|
|
bindAddress = "0.0.0.0" // Allow all interfaces
|
|
}
|
|
|
|
// Validate bind address (support both IPs and hostnames)
|
|
bindAddress, err = resolveBindAddress(bindAddress)
|
|
if err != nil {
|
|
return "", 0, "", 0, errors.New("invalid local address: " + err.Error())
|
|
}
|
|
|
|
remoteHost, err = resolveBindAddress(remoteHost)
|
|
if err != nil {
|
|
return "", 0, "", 0, errors.New("invalid remote address: " + err.Error())
|
|
}
|
|
|
|
return bindAddress, localPort, remoteHost, remotePort, nil
|
|
}
|
|
|
|
// ParsePortMapping parses a port mapping string into a structured PortMapping.
|
|
//
|
|
// The expected format is: `[bind_address:]local_port:remote_host:remote_port`.
|
|
//
|
|
// Parameters:
|
|
// - port: string - The port mapping string.
|
|
//
|
|
// Returns:
|
|
// - PortMapping: A structured representation of the parsed port mapping.
|
|
// - error: An error if the parsing fails.
|
|
func ParsePortMapping(port string) (PortMapping, error) {
|
|
bindAddress, localPort, remoteHost, remotePort, err := parsePortMapping(port)
|
|
if err != nil {
|
|
return PortMapping{}, err
|
|
}
|
|
|
|
return PortMapping{
|
|
BindAddress: bindAddress,
|
|
LocalPort: localPort,
|
|
RemoteIP: remoteHost,
|
|
RemotePort: remotePort,
|
|
}, nil
|
|
}
|
|
|
|
// resolveBindAddress resolves a hostname or IP to its string representation.
|
|
//
|
|
// Parameters:
|
|
// - addr: string - The hostname or IP.
|
|
//
|
|
// Returns:
|
|
// - string: The resolved IP address.
|
|
// - error: An error if resolution fails.
|
|
func resolveBindAddress(addr string) (string, error) {
|
|
tcpAddr, err := net.ResolveTCPAddr("tcp", addr+":0") // Resolve the address
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return tcpAddr.IP.String(), nil // Return resolved IP
|
|
}
|
|
|
|
// isValidHostname checks if a given hostname is valid.
|
|
// Pretty ugly for now, needs to be refactored.
|
|
//
|
|
// Parameters:
|
|
// - hostname: string - The hostname to validate.
|
|
//
|
|
// Returns:
|
|
// - bool: True if valid, false otherwise.
|
|
func isValidHostname(hostname string) bool {
|
|
// Must contain at least one dot (.) unless it's "localhost"
|
|
if hostname == "localhost" {
|
|
return true
|
|
}
|
|
return strings.Contains(hostname, ".")
|
|
}
|
|
|
|
// LoginToBase64 joins username and password as "username:password" and
// returns the result in standard, padded base64.
// This is the credential form commonly used for HTTP Basic Authentication.
//
// Parameters:
//   - username: string - The username to encode.
//   - password: string - The password to encode.
//
// Returns:
//   - string: The base64-encoded "username:password" string.
func LoginToBase64(username, password string) string {
	credentials := []byte(username + ":" + password)
	return base64.StdEncoding.EncodeToString(credentials)
}
|
|
|
|
// CheckIfname validates a network interface name according to the following rules:
//   - Must not be empty (error).
//   - Should not exceed 15 bytes (only a warning is logged if it does).
//   - Should not contain non-ASCII characters (only a warning is logged if it does).
//   - Must not contain '/', a space, or one of the whitespace control
//     characters \t, \n, \v, \f, \r (error).
//
// The whole name is always scanned, so an invalid character is reported even
// when it follows a non-ASCII character.
//
// Parameters:
//   - name: string - The interface name to validate.
//
// Returns:
//   - error: An error if the name is invalid, or nil if valid.
func CheckIfname(name string) error {
	const maxIfnameLen = 15

	if name == "" {
		return errors.New("interface name cannot be empty")
	}

	if len(name) > maxIfnameLen {
		log.Printf("Warning: interface name '%s' is longer than %d characters", name, maxIfnameLen)
	}

	nonASCII := false
	for _, r := range name {
		switch r {
		case '/', ' ', '\t', '\n', '\v', '\f', '\r':
			return errors.New("interface name contains invalid character: '/' or whitespace")
		}
		// Keep scanning after a non-ASCII rune: it only warrants a warning,
		// while a later '/' or whitespace must still be rejected.
		if r > 127 {
			nonASCII = true
		}
	}

	if nonASCII {
		log.Printf("Warning: interface name contains non-ASCII character")
	}

	return nil
}
|