mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-08 04:04:55 +03:00
323 lines
9.2 KiB
Go
323 lines
9.2 KiB
Go
package ssh
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
"github.com/sagernet/sing-box/common/dialer"
|
|
"github.com/sagernet/sing-box/common/onclose"
|
|
C "github.com/sagernet/sing-box/constant"
|
|
"github.com/sagernet/sing-box/option"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/logger"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
var _ Service = (*Fallback)(nil)
|
|
|
|
type Fallback struct {
|
|
Service
|
|
ctx context.Context
|
|
logger logger.ContextLogger
|
|
dialer N.Dialer
|
|
serverAddr M.Socksaddr
|
|
clientVersion string
|
|
mainSigner ssh.Signer
|
|
issueSigner ssh.Signer
|
|
hostKeys []ssh.PublicKey
|
|
keyAlgorithms []string
|
|
pending map[string]*upstreamConn
|
|
mtx sync.Mutex
|
|
}
|
|
|
|
type upstreamConn struct {
|
|
conn net.Conn
|
|
client ssh.Conn
|
|
channels <-chan ssh.NewChannel
|
|
requests <-chan *ssh.Request
|
|
}
|
|
|
|
func NewFallback(ctx context.Context, logger logger.ContextLogger, inner Service, options *option.SSHFallbackServerOptions) (*Fallback, error) {
|
|
serverAddr := options.Build()
|
|
if serverAddr.Port == 0 {
|
|
serverAddr.Port = 22
|
|
}
|
|
if !serverAddr.Addr.IsValid() && serverAddr.Fqdn == "" {
|
|
return nil, E.New("missing upstream server address")
|
|
}
|
|
upstreamDialer, err := dialer.New(ctx, options.DialerOptions, serverAddr.IsFqdn())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
fallback := &Fallback{
|
|
Service: inner,
|
|
ctx: ctx,
|
|
logger: logger,
|
|
dialer: upstreamDialer,
|
|
serverAddr: serverAddr,
|
|
clientVersion: options.ClientVersion,
|
|
keyAlgorithms: options.HostKeyAlgorithms,
|
|
pending: make(map[string]*upstreamConn),
|
|
}
|
|
if fallback.clientVersion == "" {
|
|
fallback.clientVersion = "SSH-2.0-OpenSSH_9.6"
|
|
}
|
|
if options.CA != nil {
|
|
signer, err := parseCAKey(options.CA)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "parse CA")
|
|
}
|
|
fallback.mainSigner = signer
|
|
}
|
|
if options.IssueCA != nil {
|
|
signer, err := parseCAKey(options.IssueCA)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "parse issue CA")
|
|
}
|
|
fallback.issueSigner = signer
|
|
}
|
|
if fallback.issueSigner == nil && fallback.mainSigner != nil {
|
|
fallback.issueSigner = fallback.mainSigner
|
|
}
|
|
for _, hostKey := range options.HostKey {
|
|
key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey))
|
|
if err != nil {
|
|
return nil, E.New("parse upstream host key ", hostKey)
|
|
}
|
|
fallback.hostKeys = append(fallback.hostKeys, key)
|
|
}
|
|
for _, hostKeyPath := range options.HostKeyPath {
|
|
content, err := os.ReadFile(os.ExpandEnv(hostKeyPath))
|
|
if err != nil {
|
|
return nil, E.Cause(err, "read upstream host key ", hostKeyPath)
|
|
}
|
|
key, _, _, _, err := ssh.ParseAuthorizedKey(content)
|
|
if err != nil {
|
|
return nil, E.Cause(err, "parse upstream host key ", hostKeyPath)
|
|
}
|
|
fallback.hostKeys = append(fallback.hostKeys, key)
|
|
}
|
|
return fallback, nil
|
|
}
|
|
|
|
func (f *Fallback) PasswordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
if permissions, err := f.Service.PasswordCallback(conn, password); err == nil {
|
|
return permissions, nil
|
|
}
|
|
if err := f.dial(string(conn.SessionID()), conn.User(), ssh.Password(string(password))); err != nil {
|
|
return nil, E.Cause(err, "upstream authentication failed for user ", conn.User())
|
|
}
|
|
return &ssh.Permissions{Extensions: map[string]string{"user": conn.User(), "fallback": "1"}}, nil
|
|
}
|
|
|
|
func (f *Fallback) PublicKeyCallback(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
if permissions, err := f.Service.PublicKeyCallback(conn, key); err == nil {
|
|
return permissions, nil
|
|
}
|
|
if verifyCertificate(f.mainSigner, conn, key) {
|
|
signer, err := issueCertificate(f.issueSigner, conn.User())
|
|
if err != nil {
|
|
return nil, E.Cause(err, "upstream authentication failed for user ", conn.User())
|
|
}
|
|
if err := f.dial(string(conn.SessionID()), conn.User(), ssh.PublicKeys(signer)); err != nil {
|
|
return nil, E.Cause(err, "upstream authentication failed for user ", conn.User())
|
|
}
|
|
return &ssh.Permissions{Extensions: map[string]string{"user": conn.User(), "fallback": "1"}}, nil
|
|
}
|
|
return nil, E.New("public key authentication failed for user ", conn.User())
|
|
}
|
|
|
|
func (f *Fallback) Handle(ctx context.Context, serverConn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request, metadata adapter.InboundContext, user string) {
|
|
if serverConn.Permissions == nil || serverConn.Permissions.Extensions["fallback"] != "1" {
|
|
f.Service.Handle(ctx, serverConn, channels, requests, metadata, user)
|
|
return
|
|
}
|
|
sessionID := string(serverConn.SessionID())
|
|
f.mtx.Lock()
|
|
upstream := f.pending[sessionID]
|
|
delete(f.pending, sessionID)
|
|
f.mtx.Unlock()
|
|
if upstream == nil {
|
|
serverConn.Close()
|
|
return
|
|
}
|
|
f.logger.InfoContext(ctx, "[", user, "] forwarded SSH connection from ", metadata.Source)
|
|
go proxyDownstreamRequests(requests, upstream.client)
|
|
go proxyGlobalRequests(upstream.requests, serverConn)
|
|
go func() {
|
|
for newChannel := range upstream.channels {
|
|
go proxyChannel(newChannel, serverConn)
|
|
}
|
|
}()
|
|
var wg sync.WaitGroup
|
|
for newChannel := range channels {
|
|
wg.Go(func() {
|
|
proxyChannel(newChannel, upstream.client)
|
|
})
|
|
}
|
|
wg.Wait()
|
|
upstream.client.Close()
|
|
upstream.conn.Close()
|
|
serverConn.Close()
|
|
}
|
|
|
|
func (f *Fallback) Close() error {
|
|
f.mtx.Lock()
|
|
connections := make([]net.Conn, 0, len(f.pending))
|
|
for id, upstream := range f.pending {
|
|
if upstream != nil {
|
|
connections = append(connections, upstream.conn)
|
|
}
|
|
delete(f.pending, id)
|
|
}
|
|
f.mtx.Unlock()
|
|
for _, conn := range connections {
|
|
conn.Close()
|
|
}
|
|
return f.Service.Close()
|
|
}
|
|
|
|
func (f *Fallback) dial(sessionID string, user string, auth ssh.AuthMethod) error {
|
|
f.mtx.Lock()
|
|
if _, attempted := f.pending[sessionID]; attempted {
|
|
f.mtx.Unlock()
|
|
return E.New("fallback already attempted")
|
|
}
|
|
f.pending[sessionID] = nil
|
|
f.mtx.Unlock()
|
|
conn, err := f.dialer.DialContext(f.ctx, N.NetworkTCP, f.serverAddr)
|
|
if err != nil {
|
|
f.mtx.Lock()
|
|
delete(f.pending, sessionID)
|
|
f.mtx.Unlock()
|
|
return err
|
|
}
|
|
conn = onclose.NewConn(conn, func() {
|
|
f.mtx.Lock()
|
|
delete(f.pending, sessionID)
|
|
f.mtx.Unlock()
|
|
})
|
|
config := &ssh.ClientConfig{
|
|
User: user,
|
|
Auth: []ssh.AuthMethod{auth},
|
|
ClientVersion: f.clientVersion,
|
|
HostKeyAlgorithms: f.keyAlgorithms,
|
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
if len(f.hostKeys) == 0 {
|
|
return nil
|
|
}
|
|
serverKey := key.Marshal()
|
|
for _, hostKey := range f.hostKeys {
|
|
if bytes.Equal(serverKey, hostKey.Marshal()) {
|
|
return nil
|
|
}
|
|
}
|
|
return E.New("upstream host key mismatch, server sent ", key.Type(), " ", base64.StdEncoding.EncodeToString(serverKey))
|
|
},
|
|
Timeout: C.TCPTimeout,
|
|
}
|
|
client, channels, requests, err := ssh.NewClientConn(conn, f.serverAddr.String(), config)
|
|
if err != nil {
|
|
conn.Close()
|
|
return err
|
|
}
|
|
f.mtx.Lock()
|
|
f.pending[sessionID] = &upstreamConn{conn: conn, client: client, channels: channels, requests: requests}
|
|
f.mtx.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func proxyChannel(newChannel ssh.NewChannel, target ssh.Conn) {
|
|
targetChannel, targetRequests, err := target.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData())
|
|
if err != nil {
|
|
if openErr, ok := err.(*ssh.OpenChannelError); ok {
|
|
newChannel.Reject(openErr.Reason, openErr.Message)
|
|
} else {
|
|
newChannel.Reject(ssh.ConnectionFailed, err.Error())
|
|
}
|
|
return
|
|
}
|
|
sourceChannel, sourceRequests, err := newChannel.Accept()
|
|
if err != nil {
|
|
targetChannel.Close()
|
|
return
|
|
}
|
|
go proxyChannelRequests(sourceRequests, targetChannel)
|
|
go io.Copy(targetChannel.Stderr(), sourceChannel.Stderr())
|
|
go io.Copy(sourceChannel.Stderr(), targetChannel.Stderr())
|
|
go func() {
|
|
io.Copy(targetChannel, sourceChannel)
|
|
targetChannel.CloseWrite()
|
|
}()
|
|
go func() {
|
|
io.Copy(sourceChannel, targetChannel)
|
|
sourceChannel.CloseWrite()
|
|
}()
|
|
proxyChannelRequests(targetRequests, sourceChannel)
|
|
sourceChannel.Close()
|
|
targetChannel.Close()
|
|
}
|
|
|
|
func proxyGlobalRequests(requests <-chan *ssh.Request, target ssh.Conn) {
|
|
for request := range requests {
|
|
if request.Type == "hostkeys-00@openssh.com" {
|
|
if request.WantReply {
|
|
request.Reply(false, nil)
|
|
}
|
|
continue
|
|
}
|
|
ok, payload, err := target.SendRequest(request.Type, request.WantReply, request.Payload)
|
|
if request.WantReply {
|
|
if err != nil {
|
|
request.Reply(false, nil)
|
|
} else {
|
|
request.Reply(ok, payload)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func proxyDownstreamRequests(requests <-chan *ssh.Request, target ssh.Conn) {
|
|
for request := range requests {
|
|
switch request.Type {
|
|
case "no-more-sessions@openssh.com", "hostkeys-prove-00@openssh.com":
|
|
if request.WantReply {
|
|
request.Reply(false, nil)
|
|
}
|
|
continue
|
|
}
|
|
ok, payload, err := target.SendRequest(request.Type, request.WantReply, request.Payload)
|
|
if request.WantReply {
|
|
if err != nil {
|
|
request.Reply(false, nil)
|
|
} else {
|
|
request.Reply(ok, payload)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func proxyChannelRequests(requests <-chan *ssh.Request, target ssh.Channel) {
|
|
for request := range requests {
|
|
ok, err := target.SendRequest(request.Type, request.WantReply, request.Payload)
|
|
if request.WantReply {
|
|
if err != nil {
|
|
request.Reply(false, nil)
|
|
} else {
|
|
request.Reply(ok, nil)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|