Files

238 lines
6.6 KiB
Go

package xhttp
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"strings"
"sync"
common "github.com/sagernet/sing-box/common/xray"
"github.com/sagernet/sing-box/common/xray/buf"
"github.com/sagernet/sing-box/common/xray/signal/done"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
// interface to abstract between use of browser dialer, vs net/http
type DialerClient interface {
IsClosed() bool
// ctx, url, sessionId, body, uploadOnly
OpenStream(context.Context, string, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error)
// ctx, url, sessionId, seqStr, payload
PostPacket(context.Context, string, string, string, buf.MultiBuffer) error
}
// implements xhttp.DialerClient in terms of direct network connections
type DefaultDialerClient struct {
options *option.V2RayXHTTPBaseOptions
client *http.Client
closed bool
httpVersion string
// pool of net.Conn, created using dialUploadConn
uploadRawPool *sync.Pool
dialUploadConn func(ctxInner context.Context) (net.Conn, error)
}
func (c *DefaultDialerClient) IsClosed() bool {
return c.closed
}
func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, sessionId string, body io.Reader, uploadOnly bool) (wrc io.ReadCloser, remoteAddr, localAddr net.Addr, err error) {
// this is done when the TCP/UDP connection to the server was established,
// and we can unblock the Dial function and print correct net addresses in
// logs
gotConn := done.New()
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
remoteAddr = connInfo.Conn.RemoteAddr()
localAddr = connInfo.Conn.LocalAddr()
gotConn.Close()
},
})
method := "GET" // stream-down
if body != nil {
method = c.options.GetNormalizedUplinkHTTPMethod() // stream-up/one
}
req, _ := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body)
FillStreamRequest(req, sessionId, "", c.options)
wrc = &WaitReadCloser{Wait: make(chan struct{})}
go func() {
resp, err := c.client.Do(req)
if err != nil {
if !uploadOnly { // stream-down is enough
c.closed = true
}
gotConn.Close()
wrc.Close()
return
}
if resp.StatusCode != 200 || uploadOnly { // stream-up
io.Copy(io.Discard, resp.Body)
resp.Body.Close() // if it is called immediately, the upload will be interrupted also
wrc.Close()
return
}
wrc.(*WaitReadCloser).Set(resp.Body)
}()
<-gotConn.Wait()
return
}
func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, payload buf.MultiBuffer) error {
method := c.options.GetNormalizedUplinkHTTPMethod()
req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, nil)
if err != nil {
return err
}
FillPacketRequest(req, sessionId, seqStr, payload, c.options)
if c.httpVersion != "1.1" {
resp, err := c.client.Do(req)
if err != nil {
c.closed = true
return err
}
io.Copy(io.Discard, resp.Body)
defer resp.Body.Close()
if resp.StatusCode != 200 {
return E.New("bad status code: ", resp.Status)
}
} else {
// stringify the entire HTTP/1.1 request so it can be
// safely retried. if instead req.Write is called multiple
// times, the body is already drained after the first
// request
requestBuff := new(bytes.Buffer)
requestBuff.Grow(512 + int(req.ContentLength))
common.Must(req.Write(requestBuff))
var uploadConn any
var h1UploadConn *H1Conn
for {
uploadConn = c.uploadRawPool.Get()
newConnection := uploadConn == nil
if newConnection {
newConn, err := c.dialUploadConn(context.WithoutCancel(ctx))
if err != nil {
return err
}
h1UploadConn = NewH1Conn(newConn)
uploadConn = h1UploadConn
} else {
h1UploadConn = uploadConn.(*H1Conn)
// TODO: Replace 0 here with a config value later
// Or add some other condition for optimization purposes
if h1UploadConn.UnreadedResponsesCount > 0 {
resp, err := http.ReadResponse(h1UploadConn.RespBufReader, req)
if err != nil {
c.closed = true
return fmt.Errorf("error while reading response: %s", err.Error())
}
io.Copy(io.Discard, resp.Body)
defer resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("got non-200 error response code: %d", resp.StatusCode)
}
}
}
_, err := h1UploadConn.Write(requestBuff.Bytes())
// if the write failed, we try another connection from
// the pool, until the write on a new connection fails.
// failed writes to a pooled connection are normal when
// the connection has been closed in the meantime.
if err == nil {
break
} else if newConnection {
return err
}
}
c.uploadRawPool.Put(uploadConn)
}
return nil
}
type WaitReadCloser struct {
Wait chan struct{}
io.ReadCloser
}
func (w *WaitReadCloser) Set(rc io.ReadCloser) {
w.ReadCloser = rc
defer func() {
if recover() != nil {
rc.Close()
}
}()
close(w.Wait)
}
func (w *WaitReadCloser) Read(b []byte) (int, error) {
if w.ReadCloser == nil {
if <-w.Wait; w.ReadCloser == nil {
return 0, io.ErrClosedPipe
}
}
return w.ReadCloser.Read(b)
}
func (w *WaitReadCloser) Close() error {
if w.ReadCloser != nil {
return w.ReadCloser.Close()
}
defer func() {
if recover() != nil && w.ReadCloser != nil {
w.ReadCloser.Close()
}
}()
close(w.Wait)
return nil
}
func ApplyMetaToRequest(options *option.V2RayXHTTPBaseOptions, req *http.Request, sessionId string, seqStr string) {
sessionPlacement := options.GetNormalizedSessionPlacement()
seqPlacement := options.GetNormalizedSeqPlacement()
sessionKey := options.GetNormalizedSessionKey()
seqKey := options.GetNormalizedSeqKey()
if sessionId != "" {
switch sessionPlacement {
case option.PlacementPath:
req.URL.Path = appendToPath(req.URL.Path, sessionId)
case option.PlacementQuery:
q := req.URL.Query()
q.Set(sessionKey, sessionId)
req.URL.RawQuery = q.Encode()
case option.PlacementHeader:
req.Header.Set(sessionKey, sessionId)
case option.PlacementCookie:
req.AddCookie(&http.Cookie{Name: sessionKey, Value: sessionId})
}
}
if seqStr != "" {
switch seqPlacement {
case option.PlacementPath:
req.URL.Path = appendToPath(req.URL.Path, seqStr)
case option.PlacementQuery:
q := req.URL.Query()
q.Set(seqKey, seqStr)
req.URL.RawQuery = q.Encode()
case option.PlacementHeader:
req.Header.Set(seqKey, seqStr)
case option.PlacementCookie:
req.AddCookie(&http.Cookie{Name: seqKey, Value: seqStr})
}
}
}
func appendToPath(path, value string) string {
if strings.HasSuffix(path, "/") {
return path + value
}
return path + "/" + value
}