Fix tests

This commit is contained in:
世界
2022-08-22 12:02:16 +08:00
parent 6109040c4c
commit 045e6d1976
7 changed files with 77 additions and 34 deletions

View File

@@ -37,6 +37,7 @@ type Hysteria struct {
xplusKey []byte
sendBPS uint64
recvBPS uint64
udpListener net.PacketConn
listener quic.Listener
udpAccess sync.RWMutex
udpSessionId uint32
@@ -146,6 +147,7 @@ func (h *Hysteria) Start() error {
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
}
h.udpListener = packetConn
err = h.tlsConfig.Start()
if err != nil {
return err
@@ -314,6 +316,7 @@ func (h *Hysteria) Close() error {
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
h.udpAccess.Unlock()
return common.Close(
h.udpListener,
h.listener,
common.PtrOrNil(h.tlsConfig),
)

View File

@@ -10,7 +10,6 @@ import (
"net/http"
"net/netip"
"os"
"runtime"
"strings"
"time"
@@ -250,8 +249,7 @@ type naivePaddingConn struct {
func (c *naivePaddingConn) Read(p []byte) (n int, err error) {
n, err = c.read(p)
err = wrapHttpError(err)
return
return n, wrapHttpError(err)
}
func (c *naivePaddingConn) read(p []byte) (n int, err error) {
@@ -259,7 +257,7 @@ func (c *naivePaddingConn) read(p []byte) (n int, err error) {
if len(p) > c.readRemaining {
p = p[:c.readRemaining]
}
n, err = c.read(p)
n, err = c.reader.Read(p)
if err != nil {
return
}
@@ -297,35 +295,69 @@ func (c *naivePaddingConn) read(p []byte) (n int, err error) {
}
func (c *naivePaddingConn) Write(p []byte) (n int, err error) {
n, err = c.write(p)
for pLen := len(p); pLen > 0; {
var data []byte
if pLen > 65535 {
data = p[:65535]
p = p[65535:]
pLen -= 65535
} else {
data = p
pLen = 0
}
var writeN int
writeN, err = c.write(data)
n += writeN
if err != nil {
break
}
}
if err == nil {
c.flusher.Flush()
}
err = wrapHttpError(err)
return
return n, wrapHttpError(err)
}
func (c *naivePaddingConn) write(p []byte) (n int, err error) {
if c.writePadding < kFirstPaddings {
paddingSize := rand.Intn(256)
_buffer := buf.Make(3 + len(p) + paddingSize)
defer runtime.KeepAlive(_buffer)
_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
binary.BigEndian.PutUint16(buffer, uint16(len(p)))
buffer[2] = byte(paddingSize)
copy(buffer[3:], p)
_, err = c.writer.Write(buffer)
if err != nil {
return
defer buffer.Release()
header := buffer.Extend(3)
binary.BigEndian.PutUint16(header, uint16(len(p)))
header[2] = byte(paddingSize)
common.Must1(buffer.Write(p))
_, err = c.writer.Write(buffer.Bytes())
if err == nil {
n = len(p)
}
c.writePadding++
return
}
return c.writer.Write(p)
}
func (c *naivePaddingConn) FrontHeadroom() int {
if c.writePadding < kFirstPaddings {
return 3 + 255
return 3
}
return 0
}
func (c *naivePaddingConn) RearHeadroom() int {
if c.writePadding < kFirstPaddings {
return 255
}
return 0
}
func (c *naivePaddingConn) WriterMTU() int {
if c.writePadding < kFirstPaddings {
return 65535
}
return 0
}
@@ -334,6 +366,9 @@ func (c *naivePaddingConn) WriteBuffer(buffer *buf.Buffer) error {
defer buffer.Release()
if c.writePadding < kFirstPaddings {
bufferLen := buffer.Len()
if bufferLen > 65535 {
return common.Error(c.Write(buffer.Bytes()))
}
paddingSize := rand.Intn(256)
header := buffer.ExtendHeader(3)
binary.BigEndian.PutUint16(header, uint16(bufferLen))
@@ -350,16 +385,20 @@ func (c *naivePaddingConn) WriteBuffer(buffer *buf.Buffer) error {
func (c *naivePaddingConn) WriteTo(w io.Writer) (n int64, err error) {
if c.readPadding < kFirstPaddings {
return bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
} else {
n, err = bufio.Copy(w, c.reader)
}
return bufio.Copy(w, c.reader)
return n, wrapHttpError(err)
}
func (c *naivePaddingConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writePadding < kFirstPaddings {
return bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
} else {
n, err = bufio.Copy(c.writer, r)
}
return bufio.Copy(c.writer, r)
return n, wrapHttpError(err)
}
func (c *naivePaddingConn) Close() error {
@@ -389,14 +428,14 @@ func (c *naivePaddingConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
var http2errClientDisconnected = "client disconnected"
func wrapHttpError(err error) error {
if err == nil {
return err
}
switch err.Error() {
case http2errClientDisconnected:
if strings.Contains(err.Error(), "client disconnected") {
return net.ErrClosed
}
if strings.Contains(err.Error(), "body closed by handler") {
return net.ErrClosed
}
return err