diff --git a/Makefile b/Makefile index e83dc992..f78f39ed 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,7 @@ lint: GOOS=android golangci-lint run ./... GOOS=windows golangci-lint run ./... GOOS=darwin golangci-lint run ./... - GOOS=freebsd golangci-lint run ./... +# GOOS=freebsd golangci-lint run ./... lint_install: go install -v github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest diff --git a/cmd/internal/protogen/main.go b/cmd/internal/protogen/main.go index 4d5023f7..1a5d59b0 100644 --- a/cmd/internal/protogen/main.go +++ b/cmd/internal/protogen/main.go @@ -48,8 +48,8 @@ func GetRuntimeEnv(key string) (string, error) { if readErr != nil { return "", readErr } - envStrings := strings.Split(string(data), "\n") - for _, envItem := range envStrings { + envStrings := strings.SplitSeq(string(data), "\n") + for envItem := range envStrings { envItem = strings.TrimSuffix(envItem, "\r") envKeyValue := strings.Split(envItem, "=") if strings.EqualFold(strings.TrimSpace(envKeyValue[0]), key) { diff --git a/cmd/internal/update_android_version/main.go b/cmd/internal/update_android_version/main.go index 4850fce0..2278eeac 100644 --- a/cmd/internal/update_android_version/main.go +++ b/cmd/internal/update_android_version/main.go @@ -39,7 +39,7 @@ func main() { common.Must(os.Chdir(androidPath)) localProps := common.Must1(os.ReadFile("version.properties")) var propsList [][]string - for _, propLine := range strings.Split(string(localProps), "\n") { + for propLine := range strings.SplitSeq(string(localProps), "\n") { propsList = append(propsList, strings.Split(propLine, "=")) } var ( diff --git a/cmd/sing-box/cmd_geoip_export.go b/cmd/sing-box/cmd_geoip_export.go index b80e5cd3..6f59b4d5 100644 --- a/cmd/sing-box/cmd_geoip_export.go +++ b/cmd/sing-box/cmd_geoip_export.go @@ -61,16 +61,17 @@ func geoipExport(countryCode string) error { outputFile *os.File outputWriter io.Writer ) - if flagGeoipExportOutput == "stdout" { + switch flagGeoipExportOutput { + case "stdout": outputWriter = os.Stdout - } else if flagGeoipExportOutput == flagGeoipExportDefaultOutput { + case flagGeoipExportDefaultOutput: outputFile, err = os.Create("geoip-" + countryCode + ".json") if err != nil { return err } defer outputFile.Close() outputWriter = outputFile - } else { + default: outputFile, err = os.Create(flagGeoipExportOutput) if err != nil { return err diff --git a/cmd/sing-box/cmd_geosite_export.go b/cmd/sing-box/cmd_geosite_export.go index 90a7955b..573cc1df 100644 --- a/cmd/sing-box/cmd_geosite_export.go +++ b/cmd/sing-box/cmd_geosite_export.go @@ -43,16 +43,17 @@ func geositeExport(category string) error { outputFile *os.File outputWriter io.Writer ) - if commandGeositeExportOutput == "stdout" { + switch commandGeositeExportOutput { + case "stdout": outputWriter = os.Stdout - } else if commandGeositeExportOutput == commandGeositeExportDefaultOutput { + case commandGeositeExportDefaultOutput: outputFile, err = os.Create("geosite-" + category + ".json") if err != nil { return err } defer outputFile.Close() outputWriter = outputFile - } else { + default: outputFile, err = os.Create(commandGeositeExportOutput) if err != nil { return err diff --git a/common/badversion/version.go b/common/badversion/version.go index a8404297..3da6766c 100644 --- a/common/badversion/version.go +++ b/common/badversion/version.go @@ -112,9 +112,7 @@ func IsValid(versionName string) bool { } func Parse(versionName string) (version Version) { - if strings.HasPrefix(versionName, "v") { - versionName = versionName[1:] - } + versionName = strings.TrimPrefix(versionName, "v") if strings.Contains(versionName, "-") { parts := strings.Split(versionName, "-") versionName = parts[0] diff --git a/common/convertor/adguard/convertor.go b/common/convertor/adguard/convertor.go index 3e6d0254..187c4f4d 100644 --- a/common/convertor/adguard/convertor.go +++ b/common/convertor/adguard/convertor.go @@ -63,9 +63,7 @@ parseLine: } continue } - if strings.HasSuffix(ruleLine, "|") { - ruleLine = ruleLine[:len(ruleLine)-1] - } + ruleLine = strings.TrimSuffix(ruleLine, "|") var ( isExclude bool isSuffix bool @@ -76,7 +74,7 @@ parseLine: ) if !strings.HasPrefix(ruleLine, "/") && strings.Contains(ruleLine, "$") { params := common.SubstringAfter(ruleLine, "$") - for _, param := range strings.Split(params, ",") { + for param := range strings.SplitSeq(params, ",") { paramParts := strings.Split(param, "=") var ignored bool if len(paramParts) > 0 && len(paramParts) <= 2 { @@ -106,9 +104,7 @@ parseLine: ruleLine = ruleLine[2:] isExclude = true } - if strings.HasSuffix(ruleLine, "|") { - ruleLine = ruleLine[:len(ruleLine)-1] - } + ruleLine = strings.TrimSuffix(ruleLine, "|") if strings.HasPrefix(ruleLine, "||") { ruleLine = ruleLine[2:] isSuffix = true @@ -414,18 +410,18 @@ func ignoreIPCIDRRegexp(ruleLine string) bool { } func parseAdGuardHostLine(ruleLine string) (string, error) { - idx := strings.Index(ruleLine, " ") - if idx == -1 { + before, after, ok := strings.Cut(ruleLine, " ") + if !ok { return "", os.ErrInvalid } - address, err := netip.ParseAddr(ruleLine[:idx]) + address, err := netip.ParseAddr(before) if err != nil { return "", err } if !address.IsUnspecified() { return "", nil } - domain := ruleLine[idx+1:] + domain := after if !M.IsDomainName(domain) { return "", E.New("invalid domain name: ", domain) } diff --git a/common/dialer/default_parallel_interface.go b/common/dialer/default_parallel_interface.go index eafab75a..e91abc28 100644 --- a/common/dialer/default_parallel_interface.go +++ b/common/dialer/default_parallel_interface.go @@ -136,18 +136,16 @@ func (d *DefaultDialer) dialParallelInterfaceFastFallback(ctx context.Context, d go startRacer(fallbackCtx, false, iif) } var errors []error - for { - select { - case res := <-results: - if res.error == nil { - return res.Conn, res.primary, nil - } - errors = append(errors, res.error) - if len(errors) == len(primaryInterfaces)+len(fallbackInterfaces) { - return nil, false, E.Errors(errors...) - } + for res := range results { + if res.error == nil { + return res.Conn, res.primary, nil + } + errors = append(errors, res.error) + if len(errors) == len(primaryInterfaces)+len(fallbackInterfaces) { + return nil, false, E.Errors(errors...) } } + return nil, false, E.Errors(errors...) } func (d *DefaultDialer) listenSerialInterfacePacket(ctx context.Context, listener net.ListenConfig, network string, addr string, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { diff --git a/common/geosite/compat_test.go b/common/geosite/compat_test.go index 1a55c644..9c66aea3 100644 --- a/common/geosite/compat_test.go +++ b/common/geosite/compat_test.go @@ -19,11 +19,6 @@ func oldWriteString(writer varbin.Writer, value string) error { return varbin.Write(writer, binary.BigEndian, value) } -func oldWriteItem(writer varbin.Writer, item Item) error { - //nolint:staticcheck - return varbin.Write(writer, binary.BigEndian, item) -} - func oldReadString(reader varbin.Reader) (string, error) { //nolint:staticcheck return varbin.ReadValue[string](reader, binary.BigEndian) @@ -224,7 +219,7 @@ func TestGeositeWriteReadCompat(t *testing.T) { func generateLargeItems(count int) map[string][]Item { items := make([]Item, count) - for i := 0; i < count; i++ { + for i := range count { items[i] = Item{ Type: ItemType(i % 4), Value: strings.Repeat("x", i%200) + ".com", diff --git a/common/geosite/reader.go b/common/geosite/reader.go index ef99837d..ecd63a7e 100644 --- a/common/geosite/reader.go +++ b/common/geosite/reader.go @@ -48,12 +48,6 @@ func NewReader(readSeeker io.ReadSeeker) (*Reader, []string, error) { return reader, codes, nil } -type geositeMetadata struct { - Code string - Index uint64 - Length uint64 -} - func (r *Reader) readMetadata() error { counter := &readCounter{Reader: r.reader} reader := bufio.NewReader(counter) @@ -101,6 +95,9 @@ func (r *Reader) readMetadata() error { } func (r *Reader) Read(code string) ([]Item, error) { + r.access.Lock() + defer r.access.Unlock() + index, exists := r.domainIndex[code] if !exists { return nil, E.New("code ", code, " not exists!") diff --git a/common/ja3/parser.go b/common/ja3/parser.go index f9cca603..a4bd7123 100644 --- a/common/ja3/parser.go +++ b/common/ja3/parser.go @@ -131,7 +131,7 @@ func (j *ClientHello) parseHandshake(hs []byte) error { return &ParseError{LengthErr, 7} } - for i := 0; i < numCiphers; i++ { + for i := range numCiphers { cipherSuite := uint16(cs[2+i<<1])<<8 | uint16(cs[3+i<<1]) cipherSuites = append(cipherSuites, cipherSuite) } @@ -234,7 +234,7 @@ func (j *ClientHello) parseExtensions(exs []byte) error { return &ParseError{LengthErr, 16} } - for i := 0; i < numCurves; i++ { + for i := range numCurves { ecType := uint16(sex[i*2])<<8 | uint16(sex[1+i*2]) ellipticCurves = append(ellipticCurves, ecType) } @@ -256,7 +256,7 @@ func (j *ClientHello) parseExtensions(exs []byte) error { return &ParseError{LengthErr, 18} } - for i := 0; i < numPF; i++ { + for i := range numPF { ellipticCurvePF[i] = uint8(sex[i]) } case versionExtensionType: diff --git a/common/ktls/ktls_handshake_messages.go b/common/ktls/ktls_handshake_messages.go index f44958c0..adf023bb 100644 --- a/common/ktls/ktls_handshake_messages.go +++ b/common/ktls/ktls_handshake_messages.go @@ -6,48 +6,7 @@ package ktls -import ( - "fmt" - - "golang.org/x/crypto/cryptobyte" -) - -// The marshalingFunction type is an adapter to allow the use of ordinary -// functions as cryptobyte.MarshalingValue. -type marshalingFunction func(b *cryptobyte.Builder) error - -func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { - return f(b) -} - -// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If -// the length of the sequence is not the value specified, it produces an error. -func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { - b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { - if len(v) != n { - return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) - } - b.AddBytes(v) - return nil - })) -} - -// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. -func addUint64(b *cryptobyte.Builder, v uint64) { - b.AddUint32(uint32(v >> 32)) - b.AddUint32(uint32(v)) -} - -// readUint64 decodes a big-endian, 64-bit value into out and advances over it. -// It reports whether the read was successful. -func readUint64(s *cryptobyte.String, out *uint64) bool { - var hi, lo uint32 - if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { - return false - } - *out = uint64(hi)<<32 | uint64(lo) - return true -} +import "golang.org/x/crypto/cryptobyte" // readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a // []byte instead of a cryptobyte.String. @@ -61,12 +20,6 @@ func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) } -// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) -} - type keyUpdateMsg struct { updateRequested bool } @@ -125,11 +78,6 @@ const ( typeMessageHash uint8 = 254 // synthetic message ) -// TLS compression types. -const ( - compressionNone uint8 = 0 -) - // TLS extension numbers const ( extensionServerName uint16 = 0 diff --git a/common/ktls/ktls_write.go b/common/ktls/ktls_write.go index 76533b4a..f4e0f65d 100644 --- a/common/ktls/ktls_write.go +++ b/common/ktls/ktls_write.go @@ -77,78 +77,5 @@ func (c *Conn) writeRecordLocked(typ uint16, data []byte) (n int, err error) { if !c.kernelTx { return c.rawConn.WriteRecordLocked(typ, data) } - /*for len(data) > 0 { - m := len(data) - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - _, err = c.writeKernelRecord(typ, data[:m]) - if err != nil { - return - } - n += m - data = data[m:] - }*/ return c.writeKernelRecord(typ, data) } - -const ( - // tcpMSSEstimate is a conservative estimate of the TCP maximum segment - // size (MSS). A constant is used, rather than querying the kernel for - // the actual MSS, to avoid complexity. The value here is the IPv6 - // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 - // bytes) and a TCP header with timestamps (32 bytes). - tcpMSSEstimate = 1208 - - // recordSizeBoostThreshold is the number of bytes of application data - // sent after which the TLS record size will be increased to the - // maximum. - recordSizeBoostThreshold = 128 * 1024 -) - -func (c *Conn) maxPayloadSizeForWrite(typ uint16) int { - if /*c.config.DynamicRecordSizingDisabled ||*/ typ != recordTypeApplicationData { - return maxPlaintext - } - - if *c.rawConn.PacketsSent >= recordSizeBoostThreshold { - return maxPlaintext - } - - // Subtract TLS overheads to get the maximum payload size. - payloadBytes := tcpMSSEstimate - recordHeaderLen - c.rawConn.Out.ExplicitNonceLen() - if rawCipher := *c.rawConn.Out.Cipher; rawCipher != nil { - switch ciph := rawCipher.(type) { - case cipher.Stream: - payloadBytes -= (*c.rawConn.Out.Mac).Size() - case cipher.AEAD: - payloadBytes -= ciph.Overhead() - /*case cbcMode: - blockSize := ciph.BlockSize() - // The payload must fit in a multiple of blockSize, with - // room for at least one padding byte. - payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 - // The RawMac is appended before padding so affects the - // payload size directly. - payloadBytes -= c.out.mac.Size()*/ - default: - panic("unknown cipher type") - } - } - if *c.rawConn.Vers == tls.VersionTLS13 { - payloadBytes-- // encrypted ContentType - } - - // Allow packet growth in arithmetic progression up to max. - pkt := *c.rawConn.PacketsSent - *c.rawConn.PacketsSent++ - if pkt > 1000 { - return maxPlaintext // avoid overflow in multiply below - } - - n := payloadBytes * int(pkt+1) - if n > maxPlaintext { - n = maxPlaintext - } - return n -} diff --git a/common/process/searcher_darwin_shared.go b/common/process/searcher_darwin_shared.go index 05925530..0557ae67 100644 --- a/common/process/searcher_darwin_shared.go +++ b/common/process/searcher_darwin_shared.go @@ -81,7 +81,7 @@ func (f *darwinConnectionFinder) find(network string, source netip.AddrPort, des source = normalizeDarwinAddrPort(source) destination = normalizeDarwinAddrPort(destination) var lastOwner *adapter.ConnectionOwner - for attempt := 0; attempt < 2; attempt++ { + for attempt := range 2 { snapshot, fromCache, err := f.loadSnapshot(networkName, attempt > 0) if err != nil { return nil, err diff --git a/common/process/searcher_linux_shared.go b/common/process/searcher_linux_shared.go index cd0601bc..9e868f36 100644 --- a/common/process/searcher_linux_shared.go +++ b/common/process/searcher_linux_shared.go @@ -1,5 +1,6 @@ //go:build linux +//nolint:unused package process import ( @@ -117,7 +118,7 @@ func (c *socketDiagConn) query(source netip.AddrPort, destination netip.AddrPort c.access.Lock() defer c.access.Unlock() request := packSocketDiagRequest(c.family, c.protocol, source, destination, false) - for attempt := 0; attempt < 2; attempt++ { + for range 2 { err = c.ensureOpenLocked() if err != nil { return 0, 0, E.Cause(err, "dial netlink") diff --git a/common/settings/proxy_darwin.go b/common/settings/proxy_darwin.go index 53ed0fe0..baaf6ced 100644 --- a/common/settings/proxy_darwin.go +++ b/common/settings/proxy_darwin.go @@ -109,7 +109,7 @@ func getInterfaceDisplayName(name string) (string, error) { if err != nil { return "", err } - for _, deviceSpan := range strings.Split(string(content), "Ethernet Address") { + for deviceSpan := range strings.SplitSeq(string(content), "Ethernet Address") { if strings.Contains(deviceSpan, "Device: "+name) { substr := "Hardware Port: " deviceSpan = deviceSpan[strings.Index(deviceSpan, substr)+len(substr):] diff --git a/common/settings/wifi_linux_connman.go b/common/settings/wifi_linux_connman.go index 74706a7b..46f6ea17 100644 --- a/common/settings/wifi_linux_connman.go +++ b/common/settings/wifi_linux_connman.go @@ -40,14 +40,14 @@ func (m *connmanMonitor) ReadWIFIState() adapter.WIFIState { defer cancel() cmObj := m.conn.Object("net.connman", "/") - var services []interface{} + var services []any err := cmObj.CallWithContext(ctx, "net.connman.Manager.GetServices", 0).Store(&services) if err != nil { return adapter.WIFIState{} } for _, service := range services { - servicePair, ok := service.([]interface{}) + servicePair, ok := service.([]any) if !ok || len(servicePair) != 2 { continue } diff --git a/common/settings/wifi_linux_wpa.go b/common/settings/wifi_linux_wpa.go index 51e76c1c..192c2f01 100644 --- a/common/settings/wifi_linux_wpa.go +++ b/common/settings/wifi_linux_wpa.go @@ -1,3 +1,4 @@ +//nolint:unused package settings import ( @@ -73,13 +74,13 @@ func (m *wpaSupplicantMonitor) ReadWIFIState() adapter.WIFIState { scanner := bufio.NewScanner(strings.NewReader(status)) for scanner.Scan() { line := scanner.Text() - if strings.HasPrefix(line, "wpa_state=") { - state := strings.TrimPrefix(line, "wpa_state=") + if after, ok := strings.CutPrefix(line, "wpa_state="); ok { + state := after connected = state == "COMPLETED" - } else if strings.HasPrefix(line, "ssid=") { - ssid = strings.TrimPrefix(line, "ssid=") - } else if strings.HasPrefix(line, "bssid=") { - bssid = strings.TrimPrefix(line, "bssid=") + } else if after, ok := strings.CutPrefix(line, "ssid="); ok { + ssid = after + } else if after, ok := strings.CutPrefix(line, "bssid="); ok { + bssid = after } } diff --git a/common/settings/wifi_stub.go b/common/settings/wifi_stub.go index fd39af9e..499212e4 100644 --- a/common/settings/wifi_stub.go +++ b/common/settings/wifi_stub.go @@ -1,5 +1,6 @@ //go:build !linux && !windows +//nolint:unused package settings import ( diff --git a/common/sniff/internal/qtls/qtls.go b/common/sniff/internal/qtls/qtls.go index 9742de1e..72414c61 100644 --- a/common/sniff/internal/qtls/qtls.go +++ b/common/sniff/internal/qtls/qtls.go @@ -54,9 +54,8 @@ type xorNonceAEAD struct { aead cipher.AEAD } -func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number -func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { for i, b := range nonce { diff --git a/common/sniff/quic_blacklist.go b/common/sniff/quic_blacklist.go index 56a15152..bdf9cdb1 100644 --- a/common/sniff/quic_blacklist.go +++ b/common/sniff/quic_blacklist.go @@ -1,6 +1,8 @@ package sniff import ( + "slices" + "github.com/sagernet/sing-box/common/ja3" ) @@ -15,15 +17,8 @@ const ( // Note: uQUIC with Chromium mimicry cannot be reliably distinguished from real Chromium // since it uses the same TLS fingerprint, so it will be identified as Chromium. func isQUICGo(fingerprint *ja3.ClientHello) bool { - for _, curve := range fingerprint.EllipticCurves { - if curve == x25519Kyber768Draft00 { - return true - } + if slices.Contains(fingerprint.EllipticCurves, x25519Kyber768Draft00) { + return true } - for _, ext := range fingerprint.Extensions { - if ext == extensionRenegotiationInfo { - return true - } - } - return false + return slices.Contains(fingerprint.Extensions, extensionRenegotiationInfo) } diff --git a/common/sniff/quic_capture_test.go b/common/sniff/quic_capture_test.go index 4c9eb838..7af3b6a2 100644 --- a/common/sniff/quic_capture_test.go +++ b/common/sniff/quic_capture_test.go @@ -30,7 +30,7 @@ func TestSniffQUICQuicGoFingerprint(t *testing.T) { go func() { var packets [][]byte udpConn.SetReadDeadline(time.Now().Add(3 * time.Second)) - for i := 0; i < 10; i++ { + for range 10 { buf := make([]byte, 2048) n, _, err := udpConn.ReadFromUDP(buf) if err != nil { @@ -104,7 +104,7 @@ func TestSniffQUICInitialFromQuicGo(t *testing.T) { go func() { var packets [][]byte udpConn.SetReadDeadline(time.Now().Add(3 * time.Second)) - for i := 0; i < 5; i++ { // Capture up to 5 packets + for range 5 { // Capture up to 5 packets buf := make([]byte, 2048) n, _, err := udpConn.ReadFromUDP(buf) if err != nil { diff --git a/common/srs/binary.go b/common/srs/binary.go index ca12fff0..d2c865e1 100644 --- a/common/srs/binary.go +++ b/common/srs/binary.go @@ -78,7 +78,7 @@ func Read(reader io.Reader, recover bool) (ruleSetCompat option.PlainRuleSetComp } ruleSetCompat.Version = version ruleSetCompat.Options.Rules = make([]option.HeadlessRule, length) - for i := uint64(0); i < length; i++ { + for i := range length { ruleSetCompat.Options.Rules[i], err = readRule(bReader, recover) if err != nil { err = E.Cause(err, "read rule[", i, "]") @@ -644,7 +644,7 @@ func readLogicalRule(reader varbin.Reader, recovery bool) (logicalRule option.Lo return } logicalRule.Rules = make([]option.HeadlessRule, length) - for i := uint64(0); i < length; i++ { + for i := range length { logicalRule.Rules[i], err = readRule(reader, recovery) if err != nil { err = E.Cause(err, "read logical rule [", i, "]") diff --git a/common/srs/compat_test.go b/common/srs/compat_test.go index 98552b32..46f3c114 100644 --- a/common/srs/compat_test.go +++ b/common/srs/compat_test.go @@ -450,7 +450,7 @@ func buildIPSet(cidrs ...string) *netipx.IPSet { func buildLargeIPSet(count int) *netipx.IPSet { var builder netipx.IPSetBuilder - for i := 0; i < count; i++ { + for i := range count { prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{10, byte(i / 256), byte(i % 256), 0}), 24) builder.AddPrefix(prefix) } diff --git a/common/tls/reality_client.go b/common/tls/reality_client.go index 9362d2f8..d8328770 100644 --- a/common/tls/reality_client.go +++ b/common/tls/reality_client.go @@ -267,8 +267,8 @@ type realityVerifier struct { } func (c *realityVerifier) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - p, _ := reflect.TypeOf(c.Conn).Elem().FieldByName("peerCertificates") - certs := *(*([]*x509.Certificate))(unsafe.Pointer(uintptr(unsafe.Pointer(c.Conn)) + p.Offset)) + p, _ := reflect.TypeFor[utls.Conn]().FieldByName("peerCertificates") + certs := *(*([]*x509.Certificate))(unsafe.Add(unsafe.Pointer(c.Conn), p.Offset)) if pub, ok := certs[0].PublicKey.(ed25519.PublicKey); ok { h := hmac.New(sha512.New, c.authKey) h.Write(pub) diff --git a/common/tls/std_server.go b/common/tls/std_server.go index 760c4b3a..a1a2a611 100644 --- a/common/tls/std_server.go +++ b/common/tls/std_server.go @@ -141,13 +141,14 @@ func (c *STDServerConfig) startWatcher() error { func (c *STDServerConfig) certificateUpdated(path string) error { if path == c.certificatePath || path == c.keyPath { - if path == c.certificatePath { + switch path { + case c.certificatePath: certificate, err := os.ReadFile(c.certificatePath) if err != nil { return E.Cause(err, "reload certificate from ", c.certificatePath) } c.certificate = certificate - } else if path == c.keyPath { + case c.keyPath: key, err := os.ReadFile(c.keyPath) if err != nil { return E.Cause(err, "reload key from ", c.keyPath) @@ -338,9 +339,10 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option. } tlsConfig.ClientCAs = clientCertificateCA } else if len(options.ClientCertificatePublicKeySHA256) > 0 { - if tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert { + switch tlsConfig.ClientAuth { + case tls.RequireAndVerifyClientCert: tlsConfig.ClientAuth = tls.RequireAnyClientCert - } else if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven { + case tls.VerifyClientCertIfGiven: tlsConfig.ClientAuth = tls.RequestClientCert } tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { diff --git a/daemon/started_service.go b/daemon/started_service.go index c260e8cb..3af0ea5a 100644 --- a/daemon/started_service.go +++ b/daemon/started_service.go @@ -603,10 +603,7 @@ func (s *StartedService) URLTest(ctx context.Context, request *URLTestRequest) ( return false } _, isGroup := it.(adapter.OutboundGroup) - if isGroup { - return false - } - return true + return !isGroup }) b, _ := batch.New(boxService.ctx, batch.WithConcurrencyNum[any](10)) for _, detour := range outbounds { diff --git a/dns/client.go b/dns/client.go index 1a2ee8f8..89b6170c 100644 --- a/dns/client.go +++ b/dns/client.go @@ -70,10 +70,7 @@ func NewClient(options ClientOptions) *Client { if client.timeout == 0 { client.timeout = C.DNSTimeout } - cacheCapacity := options.CacheCapacity - if cacheCapacity < 1024 { - cacheCapacity = 1024 - } + cacheCapacity := max(options.CacheCapacity, 1024) if !client.disableCache { if !client.independentCache { client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32)) @@ -334,9 +331,10 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom if options.LookupStrategy != C.DomainStrategyAsIS { lookupOptions.Strategy = strategy } - if strategy == C.DomainStrategyIPv4Only { + switch strategy { + case C.DomainStrategyIPv4Only: return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, lookupOptions, responseChecker) - } else if strategy == C.DomainStrategyIPv6Only { + case C.DomainStrategyIPv6Only: return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, lookupOptions, responseChecker) } var response4 []netip.Addr @@ -500,10 +498,7 @@ func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransp } } } - nowTTL := int(expireAt.Sub(timeNow).Seconds()) - if nowTTL < 0 { - nowTTL = 0 - } + nowTTL := max(int(expireAt.Sub(timeNow).Seconds()), 0) response = response.Copy() if originTTL > 0 { duration := uint32(originTTL - nowTTL) @@ -551,18 +546,6 @@ func MessageToAddresses(response *dns.Msg) []netip.Addr { return addresses } -func wrapError(err error) error { - switch dnsErr := err.(type) { - case *net.DNSError: - if dnsErr.IsNotFound { - return RcodeNameError - } - case *net.AddrError: - return RcodeNameError - } - return err -} - type transportKey struct{} func contextWithTransportTag(ctx context.Context, transportTag string) context.Context { diff --git a/dns/transport/dhcp/dhcp.go b/dns/transport/dhcp/dhcp.go index 3f4eb721..8dc22c49 100644 --- a/dns/transport/dhcp/dhcp.go +++ b/dns/transport/dhcp/dhcp.go @@ -222,7 +222,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface) packetConn net.PacketConn err error ) - for i := 0; i < 5; i++ { + for range 5 { packetConn, err = listener.ListenPacket(t.ctx, "udp4", listenAddr) if err == nil || !errors.Is(err, syscall.EADDRINUSE) { break diff --git a/dns/transport/dhcp/dhcp_shared.go b/dns/transport/dhcp/dhcp_shared.go index 20cd50c5..16e319ba 100644 --- a/dns/transport/dhcp/dhcp_shared.go +++ b/dns/transport/dhcp/dhcp_shared.go @@ -72,7 +72,7 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn sLen := len(servers) var lastErr error for i := 0; i < t.attempts; i++ { - for j := 0; j < sLen; j++ { + for j := range sLen { server := servers[j] question := message.Question[0] question.Name = fqdn diff --git a/dns/transport/local/local_resolved_stub.go b/dns/transport/local/local_resolved_stub.go index 2e011851..e3bf8432 100644 --- a/dns/transport/local/local_resolved_stub.go +++ b/dns/transport/local/local_resolved_stub.go @@ -1,5 +1,6 @@ //go:build !linux +//nolint:unused package local import ( diff --git a/dns/transport/local/local_shared.go b/dns/transport/local/local_shared.go index 77635458..07040911 100644 --- a/dns/transport/local/local_shared.go +++ b/dns/transport/local/local_shared.go @@ -82,7 +82,7 @@ func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn stri sLen := uint32(len(config.servers)) var lastErr error for i := 0; i < config.attempts; i++ { - for j := uint32(0); j < sLen; j++ { + for j := range sLen { server := config.servers[(serverOffset+j)%sLen] question := message.Question[0] question.Name = fqdn diff --git a/dns/transport/local/resolv.go b/dns/transport/local/resolv.go index 3586cbbf..4aa10a64 100644 --- a/dns/transport/local/resolv.go +++ b/dns/transport/local/resolv.go @@ -1,3 +1,4 @@ +//nolint:unused package local import ( diff --git a/dns/transport/local/resolv_default.go b/dns/transport/local/resolv_default.go index 0a7d8810..9c5e8fa2 100644 --- a/dns/transport/local/resolv_default.go +++ b/dns/transport/local/resolv_default.go @@ -1,3 +1,4 @@ +//nolint:unused package local import ( diff --git a/dns/transport/quic/quic.go b/dns/transport/quic/quic.go index 3a7b6163..3bb93e41 100644 --- a/dns/transport/quic/quic.go +++ b/dns/transport/quic/quic.go @@ -100,7 +100,7 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, err error response *mDNS.Msg ) - for i := 0; i < 2; i++ { + for range 2 { conn, _, err = t.connection.Acquire(ctx, func(ctx context.Context) (*quic.Conn, error) { rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) if err != nil { diff --git a/dns/transport/tls.go b/dns/transport/tls.go index 8ce41514..fdb48563 100644 --- a/dns/transport/tls.go +++ b/dns/transport/tls.go @@ -102,7 +102,7 @@ func (t *TLSTransport) Reset() { func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { var lastErr error - for attempt := 0; attempt < 2; attempt++ { + for range 2 { conn, created, err := t.connections.Acquire(ctx, func(ctx context.Context) (*tlsDNSConn, error) { tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr) if err != nil { diff --git a/experimental/cachefile/cache.go b/experimental/cachefile/cache.go index ac2d7002..e52dd3ed 100644 --- a/experimental/cachefile/cache.go +++ b/experimental/cachefile/cache.go @@ -112,7 +112,7 @@ func (c *CacheFile) Start(stage adapter.StartStage) error { db *bbolt.DB err error ) - for i := 0; i < 10; i++ { + for range 10 { db, err = bbolt.Open(c.path, fileMode, &options) if err == nil { break diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index ec40a95f..950dca04 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -164,7 +164,7 @@ func (s *Server) Start(stage adapter.StartStage) error { listener net.Listener err error ) - for i := 0; i < 3; i++ { + for range 3 { listener, err = net.Listen("tcp", s.httpServer.Addr) if runtime.GOOS == "android" && errors.Is(err, syscall.EADDRINUSE) { time.Sleep(100 * time.Millisecond) diff --git a/experimental/libbox/command_client.go b/experimental/libbox/command_client.go index a5077bea..2f347bdd 100644 --- a/experimental/libbox/command_client.go +++ b/experimental/libbox/command_client.go @@ -147,7 +147,7 @@ func (c *CommandClient) dialWithRetry(target string, contextDialer func(context. var client daemon.StartedServiceClient var lastError error - for attempt := 0; attempt < commandClientDialAttempts; attempt++ { + for attempt := range commandClientDialAttempts { if connection == nil { options := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), diff --git a/experimental/libbox/command_server.go b/experimental/libbox/command_server.go index 1c2412b6..7eca1194 100644 --- a/experimental/libbox/command_server.go +++ b/experimental/libbox/command_server.go @@ -114,7 +114,7 @@ func (s *CommandServer) Start() error { if sCommandServerListenPort == 0 { sockPath := filepath.Join(sBasePath, "command.sock") os.Remove(sockPath) - for i := 0; i < 30; i++ { + for range 30 { listener, err = net.ListenUnix("unix", &net.UnixAddr{ Name: sockPath, Net: "unix", diff --git a/experimental/libbox/command_types.go b/experimental/libbox/command_types.go index c330dd4b..b811aaf4 100644 --- a/experimental/libbox/command_types.go +++ b/experimental/libbox/command_types.go @@ -418,13 +418,3 @@ func systemProxyStatusFromGRPC(status *daemon.SystemProxyStatus) *SystemProxySta Enabled: status.Enabled, } } - -func systemProxyStatusToGRPC(status *SystemProxyStatus) *daemon.SystemProxyStatus { - if status == nil { - return nil - } - return &daemon.SystemProxyStatus{ - Available: status.Available, - Enabled: status.Enabled, - } -} diff --git a/experimental/libbox/log.go b/experimental/libbox/log.go index ff33f081..aa12f8f2 100644 --- a/experimental/libbox/log.go +++ b/experimental/libbox/log.go @@ -8,8 +8,6 @@ import ( "runtime/debug" ) -var crashOutputFile *os.File - func RedirectStderr(path string) error { if stats, err := os.Stat(path); err == nil && stats.Size() > 0 { _ = os.Rename(path, path+".old") @@ -32,6 +30,5 @@ func RedirectStderr(path string) error { os.Remove(outputFile.Name()) return err } - crashOutputFile = outputFile - return nil + return outputFile.Close() } diff --git a/experimental/libbox/monitor.go b/experimental/libbox/monitor.go index 2deedb2e..62f91613 100644 --- a/experimental/libbox/monitor.go +++ b/experimental/libbox/monitor.go @@ -16,7 +16,6 @@ var ( type platformDefaultInterfaceMonitor struct { *platformInterfaceWrapper logger logger.Logger - element *list.Element[tun.NetworkUpdateCallback] callbacks list.List[tun.DefaultInterfaceUpdateCallback] myInterface string } diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index 4db32a22..b82121b7 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -1,9 +1,6 @@ package libbox -import ( - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/option" -) +import C "github.com/sagernet/sing-box/constant" type PlatformInterface interface { LocalDNSTransport() LocalDNSTransport @@ -98,37 +95,3 @@ type OnDemandRuleIterator interface { Next() OnDemandRule HasNext() bool } - -type onDemandRule struct { - option.OnDemandRule -} - -func (r *onDemandRule) Target() int32 { - if r.OnDemandRule.Action == nil { - return -1 - } - return int32(*r.OnDemandRule.Action) -} - -func (r *onDemandRule) DNSSearchDomainMatch() StringIterator { - return newIterator(r.OnDemandRule.DNSSearchDomainMatch) -} - -func (r *onDemandRule) DNSServerAddressMatch() StringIterator { - return newIterator(r.OnDemandRule.DNSServerAddressMatch) -} - -func (r *onDemandRule) InterfaceTypeMatch() int32 { - if r.OnDemandRule.InterfaceTypeMatch == nil { - return -1 - } - return int32(*r.OnDemandRule.InterfaceTypeMatch) -} - -func (r *onDemandRule) SSIDMatch() StringIterator { - return newIterator(r.OnDemandRule.SSIDMatch) -} - -func (r *onDemandRule) ProbeURL() string { - return r.OnDemandRule.ProbeURL -} diff --git a/experimental/libbox/tun_darwin.go b/experimental/libbox/tun_darwin.go index e312cb91..b6c6d56c 100644 --- a/experimental/libbox/tun_darwin.go +++ b/experimental/libbox/tun_darwin.go @@ -11,7 +11,7 @@ const utunControlName = "com.apple.net.utun_control" func GetTunnelFileDescriptor() int32 { ctlInfo := &unix.CtlInfo{} copy(ctlInfo.Name[:], utunControlName) - for fd := 0; fd < 1024; fd++ { + for fd := range 1024 { addr, err := unix.Getpeername(fd) if err != nil { continue diff --git a/log/id.go b/log/id.go index 7cac29d2..e23ed3d5 100644 --- a/log/id.go +++ b/log/id.go @@ -4,14 +4,8 @@ import ( "context" "math/rand" "time" - - "github.com/sagernet/sing/common/random" ) -func init() { - random.InitializeSeed() -} - type idKey struct{} type ID struct { diff --git a/option/types.go b/option/types.go index fe7d4b3d..87cf382c 100644 --- a/option/types.go +++ b/option/types.go @@ -28,7 +28,6 @@ func (v *NetworkList) UnmarshalJSON(content []byte) error { for _, networkName := range networkList { switch networkName { case N.NetworkTCP, N.NetworkUDP: - break default: return E.New("unknown network: " + networkName) } diff --git a/protocol/direct/loopback_detect.go b/protocol/direct/loopback_detect.go deleted file mode 100644 index 7a62164e..00000000 --- a/protocol/direct/loopback_detect.go +++ /dev/null @@ -1,186 +0,0 @@ -package direct - -import ( - "net" - "net/netip" - "sync" - - "github.com/sagernet/sing-box/adapter" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -type loopBackDetector struct { - networkManager adapter.NetworkManager - connAccess sync.RWMutex - packetConnAccess sync.RWMutex - connMap map[netip.AddrPort]netip.AddrPort - packetConnMap map[uint16]uint16 -} - -func newLoopBackDetector(networkManager adapter.NetworkManager) *loopBackDetector { - return &loopBackDetector{ - networkManager: networkManager, - connMap: make(map[netip.AddrPort]netip.AddrPort), - packetConnMap: make(map[uint16]uint16), - } -} - -func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn { - source := M.AddrPortFromNet(conn.LocalAddr()) - if !source.IsValid() { - return conn - } - if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn { - if !source.Addr().IsLoopback() { - _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr()) - if err != nil { - return conn - } - } - if !N.IsPublicAddr(source.Addr()) { - return conn - } - l.packetConnAccess.Lock() - l.packetConnMap[source.Port()] = M.AddrPortFromNet(conn.RemoteAddr()).Port() - l.packetConnAccess.Unlock() - return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connPort: source.Port()} - } else { - l.connAccess.Lock() - l.connMap[source] = M.AddrPortFromNet(conn.RemoteAddr()) - l.connAccess.Unlock() - return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: source} - } -} - -func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Socksaddr) N.NetPacketConn { - source := M.AddrPortFromNet(conn.LocalAddr()) - if !source.IsValid() { - return conn - } - if !source.Addr().IsLoopback() { - _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr()) - if err != nil { - return conn - } - } - l.packetConnAccess.Lock() - l.packetConnMap[source.Port()] = destination.AddrPort().Port() - l.packetConnAccess.Unlock() - return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connPort: source.Port()} -} - -func (l *loopBackDetector) CheckConn(source netip.AddrPort, local netip.AddrPort) bool { - l.connAccess.RLock() - defer l.connAccess.RUnlock() - destination, loaded := l.connMap[source] - return loaded && destination != local -} - -func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.AddrPort) bool { - if !source.IsValid() { - return false - } - if !source.Addr().IsLoopback() { - _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr()) - if err != nil { - return false - } - } - if N.IsPublicAddr(source.Addr()) { - return false - } - l.packetConnAccess.RLock() - defer l.packetConnAccess.RUnlock() - destinationPort, loaded := l.packetConnMap[source.Port()] - return loaded && destinationPort != local.Port() -} - -type loopBackDetectWrapper struct { - net.Conn - detector *loopBackDetector - connAddr netip.AddrPort - closeOnce sync.Once -} - -func (w *loopBackDetectWrapper) Close() error { - w.closeOnce.Do(func() { - w.detector.connAccess.Lock() - delete(w.detector.connMap, w.connAddr) - w.detector.connAccess.Unlock() - }) - return w.Conn.Close() -} - -func (w *loopBackDetectWrapper) ReaderReplaceable() bool { - return true -} - -func (w *loopBackDetectWrapper) WriterReplaceable() bool { - return true -} - -func (w *loopBackDetectWrapper) Upstream() any { - return w.Conn -} - -type loopBackDetectPacketWrapper struct { - N.NetPacketConn - detector *loopBackDetector - connPort uint16 - closeOnce sync.Once -} - -func (w *loopBackDetectPacketWrapper) Close() error { - w.closeOnce.Do(func() { - w.detector.packetConnAccess.Lock() - delete(w.detector.packetConnMap, w.connPort) - w.detector.packetConnAccess.Unlock() - }) - return w.NetPacketConn.Close() -} - -func (w *loopBackDetectPacketWrapper) ReaderReplaceable() bool { - return true -} - -func (w *loopBackDetectPacketWrapper) WriterReplaceable() bool { - return true -} - -func (w *loopBackDetectPacketWrapper) Upstream() any { - return w.NetPacketConn -} - -type abstractUDPConn interface { - net.Conn - net.PacketConn -} - -type loopBackDetectUDPWrapper struct { - abstractUDPConn - detector *loopBackDetector - connPort uint16 - closeOnce sync.Once -} - -func (w *loopBackDetectUDPWrapper) Close() error { - w.closeOnce.Do(func() { - w.detector.packetConnAccess.Lock() - delete(w.detector.packetConnMap, w.connPort) - w.detector.packetConnAccess.Unlock() - }) - return w.abstractUDPConn.Close() -} - -func (w *loopBackDetectUDPWrapper) ReaderReplaceable() bool { - return true -} - -func (w *loopBackDetectUDPWrapper) WriterReplaceable() bool { - return true -} - -func (w *loopBackDetectUDPWrapper) Upstream() any { - return w.abstractUDPConn -} diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index 9d24f31a..630a6755 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -41,7 +41,6 @@ type Outbound struct { domainStrategy C.DomainStrategy fallbackDelay time.Duration isEmpty bool - // loopBack *loopBackDetector } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) { @@ -67,7 +66,6 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL fallbackDelay: time.Duration(options.FallbackDelay), dialer: outboundDialer.(dialer.ParallelInterfaceDialer), isEmpty: reflect.DeepEqual(options.DialerOptions, option.DialerOptions{UDPFragmentDefault: true}), - // loopBack: newLoopBackDetector(router), } //nolint:staticcheck if options.ProxyProtocol != 0 { @@ -87,11 +85,6 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination case N.NetworkUDP: h.logger.InfoContext(ctx, "outbound packet connection to ", destination) } - /*conn, err := h.dialer.DialContext(ctx, network, destination) - if err != nil { - return nil, err - } - return h.loopBack.NewConn(conn), nil*/ return h.dialer.DialContext(ctx, network, destination) } @@ -104,7 +97,6 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n if err != nil { return nil, err } - // conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn), destination) return conn, nil } @@ -161,18 +153,3 @@ func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M. func (h *Outbound) IsEmpty() bool { return h.isEmpty } - -/*func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) { - return E.New("reject loopback connection to ", metadata.Destination) - } - return NewConnection(ctx, h, conn, metadata) -} - -func (h *Outbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - if h.loopBack.CheckPacketConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) { - return E.New("reject loopback packet connection to ", metadata.Destination) - } - return NewPacketConnection(ctx, h, conn, metadata) -} -*/ diff --git a/protocol/dns/handle.go b/protocol/dns/handle.go index e1323509..d7d89ca8 100644 --- a/protocol/dns/handle.go +++ b/protocol/dns/handle.go @@ -82,7 +82,7 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn } break } - fastClose, cancel := common.ContextWithCancelCause(ctx) + fastClose, cancel := context.WithCancelCause(ctx) timeout := canceler.New(fastClose, cancel, C.DNSTimeout) var group task.Group group.Append0(func(_ context.Context) error { @@ -150,7 +150,7 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn } func newDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { - fastClose, cancel := common.ContextWithCancelCause(ctx) + fastClose, cancel := context.WithCancelCause(ctx) timeout := canceler.New(fastClose, cancel, C.DNSTimeout) var group task.Group group.Append0(func(_ context.Context) error { diff --git a/protocol/group/urltest.go b/protocol/group/urltest.go index 26967279..730040f7 100644 --- a/protocol/group/urltest.go +++ b/protocol/group/urltest.go @@ -34,7 +34,6 @@ var _ adapter.OutboundGroup = (*URLTest)(nil) type URLTest struct { outbound.Adapter ctx context.Context - router adapter.Router outbound adapter.OutboundManager connection adapter.ConnectionManager logger log.ContextLogger @@ -51,7 +50,6 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo outbound := &URLTest{ Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), ctx: ctx, - router: router, outbound: service.FromContext[adapter.OutboundManager](ctx), connection: service.FromContext[adapter.ConnectionManager](ctx), logger: logger, @@ -188,7 +186,6 @@ func (s *URLTest) NewDirectRouteConnection(metadata adapter.InboundContext, rout type URLTestGroup struct { ctx context.Context - router adapter.Router outbound adapter.OutboundManager pause pause.Manager pauseCallback *list.Element[pause.Callback] @@ -267,9 +264,10 @@ func (g *URLTestGroup) Touch() { g.lastActive.Store(time.Now()) return } - g.ticker = time.NewTicker(g.interval) - go g.loopCheck() - g.pauseCallback = pause.RegisterTicker(g.pause, g.ticker, g.interval, nil) + ticker := time.NewTicker(g.interval) + g.ticker = ticker + g.pauseCallback = pause.RegisterTicker(g.pause, ticker, g.interval, nil) + go g.loopCheck(ticker, g.close) } func (g *URLTestGroup) Close() error { @@ -279,7 +277,9 @@ func (g *URLTestGroup) Close() error { return nil } g.ticker.Stop() + g.ticker = nil g.pause.UnregisterCallback(g.pauseCallback) + g.pauseCallback = nil close(g.close) return nil } @@ -328,23 +328,25 @@ func (g *URLTestGroup) Select(network string) (adapter.Outbound, bool) { return minOutbound, true } -func (g *URLTestGroup) loopCheck() { +func (g *URLTestGroup) loopCheck(ticker *time.Ticker, closeChan <-chan struct{}) { if time.Since(g.lastActive.Load()) > g.interval { g.lastActive.Store(time.Now()) g.CheckOutbounds(false) } for { select { - case <-g.close: + case <-closeChan: return - case <-g.ticker.C: + case <-ticker.C: } if time.Since(g.lastActive.Load()) > g.idleTimeout { g.access.Lock() - g.ticker.Stop() - g.ticker = nil - g.pause.UnregisterCallback(g.pauseCallback) - g.pauseCallback = nil + if g.ticker == ticker { + g.ticker.Stop() + g.ticker = nil + g.pause.UnregisterCallback(g.pauseCallback) + g.pauseCallback = nil + } g.access.Unlock() return } diff --git a/protocol/naive/inbound_conn.go b/protocol/naive/inbound_conn.go index 8cc3ded2..77500435 100644 --- a/protocol/naive/inbound_conn.go +++ b/protocol/naive/inbound_conn.go @@ -22,7 +22,7 @@ func generatePaddingHeader() string { paddingLen := rand.Intn(32) + 30 padding := make([]byte, paddingLen) bits := rand.Uint64() - for i := 0; i < 16; i++ { + for i := range 16 { padding[i] = "!#$()+<>?@[]^`{}"[bits&15] bits >>= 4 } diff --git a/protocol/tailscale/tun_device_unix.go b/protocol/tailscale/tun_device_unix.go index a8d237ab..d4bc7ced 100644 --- a/protocol/tailscale/tun_device_unix.go +++ b/protocol/tailscale/tun_device_unix.go @@ -11,7 +11,6 @@ import ( "sync/atomic" singTun "github.com/sagernet/sing-tun" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/logger" wgTun "github.com/sagernet/wireguard-go/tun" ) @@ -57,7 +56,7 @@ func (a *tunDeviceAdapter) Read(bufs [][]byte, sizes []int, offset int) (count i if a.linuxTUN != nil { n, err := a.linuxTUN.BatchRead(bufs, offset-singTun.PacketOffset, sizes) if err == nil { - for i := 0; i < n; i++ { + for i := range n { a.debugPacket("read", bufs[i][offset:offset+sizes[i]]) } } @@ -92,7 +91,7 @@ func (a *tunDeviceAdapter) Write(bufs [][]byte, offset int) (count int, err erro for _, packet := range bufs { a.debugPacket("write", packet[offset:]) if singTun.PacketOffset > 0 { - common.ClearArray(packet[offset-singTun.PacketOffset : offset]) + clear(packet[offset-singTun.PacketOffset : offset]) singTun.PacketFillHeader(packet[offset-singTun.PacketOffset:], singTun.PacketIPVersion(packet[offset:])) } _, err = a.tun.Write(packet[offset-singTun.PacketOffset:]) diff --git a/route/process_cache.go b/route/process_cache.go index 44ee3fcf..f99cebad 100644 --- a/route/process_cache.go +++ b/route/process_cache.go @@ -3,6 +3,7 @@ package route import ( "context" "net/netip" + "slices" "strings" "github.com/sagernet/sing-box/adapter" @@ -78,10 +79,8 @@ func (r *Router) isLocalSource(source netip.Addr) bool { return true } if r.platformInterface != nil { - for _, addr := range r.platformInterface.MyInterfaceAddress() { - if addr == source { - return true - } + if slices.Contains(r.platformInterface.MyInterfaceAddress(), source) { + return true } } for _, netInterface := range r.network.InterfaceFinder().Interfaces() { diff --git a/route/route.go b/route/route.go index 7c24219e..aec5281c 100644 --- a/route/route.go +++ b/route/route.go @@ -31,7 +31,7 @@ import ( // Deprecated: use RouteConnectionEx instead. func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - done := make(chan interface{}) + done := make(chan any) err := r.routeConnection(ctx, conn, metadata, N.OnceClose(func(it error) { close(done) })) @@ -161,7 +161,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad } func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - done := make(chan interface{}) + done := make(chan any) err := r.routePacketConnection(ctx, conn, metadata, N.OnceClose(func(it error) { close(done) })) diff --git a/route/rule/match_state.go b/route/rule/match_state.go index feac8418..0d2e4b0b 100644 --- a/route/rule/match_state.go +++ b/route/rule/match_state.go @@ -42,11 +42,11 @@ func (s ruleMatchStateSet) combine(other ruleMatchStateSet) ruleMatchStateSet { return 0 } var combined ruleMatchStateSet - for left := ruleMatchState(0); left < 16; left++ { + for left := range ruleMatchState(16) { if !s.contains(left) { continue } - for right := ruleMatchState(0); right < 16; right++ { + for right := range ruleMatchState(16) { if !other.contains(right) { continue } @@ -61,7 +61,7 @@ func (s ruleMatchStateSet) withBase(base ruleMatchState) ruleMatchStateSet { return 0 } var withBase ruleMatchStateSet - for state := ruleMatchState(0); state < 16; state++ { + for state := range ruleMatchState(16) { if !s.contains(state) { continue } @@ -72,7 +72,7 @@ func (s ruleMatchStateSet) withBase(base ruleMatchState) ruleMatchStateSet { func (s ruleMatchStateSet) filter(allowed func(ruleMatchState) bool) ruleMatchStateSet { var filtered ruleMatchStateSet - for state := ruleMatchState(0); state < 16; state++ { + for state := range ruleMatchState(16) { if !s.contains(state) { continue } @@ -91,10 +91,6 @@ type ruleStateMatcherWithBase interface { matchStatesWithBase(metadata *adapter.InboundContext, base ruleMatchState) ruleMatchStateSet } -func matchHeadlessRuleStates(rule adapter.HeadlessRule, metadata *adapter.InboundContext) ruleMatchStateSet { - return matchHeadlessRuleStatesWithBase(rule, metadata, 0) -} - func matchHeadlessRuleStatesWithBase(rule adapter.HeadlessRule, metadata *adapter.InboundContext, base ruleMatchState) ruleMatchStateSet { if matcher, isStateMatcher := rule.(ruleStateMatcherWithBase); isStateMatcher { return matcher.matchStatesWithBase(metadata, base) @@ -108,10 +104,6 @@ func matchHeadlessRuleStatesWithBase(rule adapter.HeadlessRule, metadata *adapte return 0 } -func matchRuleItemStates(item RuleItem, metadata *adapter.InboundContext) ruleMatchStateSet { - return matchRuleItemStatesWithBase(item, metadata, 0) -} - func matchRuleItemStatesWithBase(item RuleItem, metadata *adapter.InboundContext, base ruleMatchState) ruleMatchStateSet { if matcher, isStateMatcher := item.(ruleStateMatcherWithBase); isStateMatcher { return matcher.matchStatesWithBase(metadata, base) diff --git a/route/rule/rule_abstract_test.go b/route/rule/rule_abstract_test.go index ace3dec6..a5bbc353 100644 --- a/route/rule/rule_abstract_test.go +++ b/route/rule/rule_abstract_test.go @@ -141,7 +141,6 @@ func TestAbstractLogicalRule_And_WithRuleSetInvert(t *testing.T) { }, } for _, testCase := range testCases { - testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() logicalRule := &abstractLogicalRule{ diff --git a/route/rule/rule_item_cidr.go b/route/rule/rule_item_cidr.go index c823dcf3..1b0105ea 100644 --- a/route/rule/rule_item_cidr.go +++ b/route/rule/rule_item_cidr.go @@ -2,6 +2,7 @@ package rule import ( "net/netip" + "slices" "strings" "github.com/sagernet/sing-box/adapter" @@ -80,12 +81,7 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { return r.ipSet.Contains(metadata.Destination.Addr) } if len(metadata.DestinationAddresses) > 0 { - for _, address := range metadata.DestinationAddresses { - if r.ipSet.Contains(address) { - return true - } - } - return false + return slices.ContainsFunc(metadata.DestinationAddresses, r.ipSet.Contains) } return metadata.IPCIDRAcceptEmpty } diff --git a/route/rule/rule_item_domain.go b/route/rule/rule_item_domain.go index af790aa3..7e6484ea 100644 --- a/route/rule/rule_item_domain.go +++ b/route/rule/rule_item_domain.go @@ -1,6 +1,7 @@ package rule import ( + "slices" "strings" "github.com/sagernet/sing-box/adapter" @@ -16,15 +17,11 @@ type DomainItem struct { } func NewDomainItem(domains []string, domainSuffixes []string) (*DomainItem, error) { - for _, domainItem := range domains { - if domainItem == "" { - return nil, E.New("domain: empty item is not allowed") - } + if slices.Contains(domains, "") { + return nil, E.New("domain: empty item is not allowed") } - for _, domainSuffixItem := range domainSuffixes { - if domainSuffixItem == "" { - return nil, E.New("domain_suffix: empty item is not allowed") - } + if slices.Contains(domainSuffixes, "") { + return nil, E.New("domain_suffix: empty item is not allowed") } var description string if dLen := len(domains); dLen > 0 { diff --git a/route/rule/rule_set_semantics_test.go b/route/rule/rule_set_semantics_test.go index a01defe6..f1985015 100644 --- a/route/rule/rule_set_semantics_test.go +++ b/route/rule/rule_set_semantics_test.go @@ -57,7 +57,6 @@ func TestRouteRuleSetMergeDestinationAddressGroup(t *testing.T) { }, } for _, testCase := range testCases { - testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() ruleSet := newLocalRuleSetForTest("merge-destination", testCase.inner) @@ -223,7 +222,6 @@ func TestRouteRuleSetOuterGroupedStateMergesIntoSameGroup(t *testing.T) { }, } for _, testCase := range testCases { - testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() ruleSet := newLocalRuleSetForTest("outer-merge-"+testCase.name, headlessDefaultRule(t, func(rule *abstractDefaultRule) { @@ -652,7 +650,6 @@ func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) { }, } for _, testCase := range testCases { - testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() rule := dnsRuleForTest(func(rule *abstractDefaultRule) { diff --git a/service/ccm/credential_other.go b/service/ccm/credential_other.go index 11888b50..828c78c0 100644 --- a/service/ccm/credential_other.go +++ b/service/ccm/credential_other.go @@ -1,4 +1,4 @@ -//go:build !darwin +//go:build !darwin || !cgo package ccm diff --git a/service/ccm/service.go b/service/ccm/service.go index 34c38824..3aca535d 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -124,8 +124,6 @@ type Service struct { userManager *UserManager accessMutex sync.RWMutex usageTracker *AggregatedUsage - trackingGroup sync.WaitGroup - shuttingDown bool } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) { @@ -283,8 +281,8 @@ func (s *Service) getAccessToken() (string, error) { func detectContextWindow(betaHeader string, totalInputTokens int64) int { if totalInputTokens > premiumContextThreshold { - features := strings.Split(betaHeader, ",") - for _, feature := range features { + features := strings.SplitSeq(betaHeader, ",") + for feature := range features { if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") { return contextWindowPremium } @@ -507,8 +505,8 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons continue } - if bytes.HasPrefix(line, []byte("data: ")) { - eventData := bytes.TrimPrefix(line, []byte("data: ")) + if after, ok0 := bytes.CutPrefix(line, []byte("data: ")); ok0 { + eventData := after if bytes.Equal(eventData, []byte("[DONE]")) { continue } diff --git a/service/ocm/service.go b/service/ocm/service.go index 8b66964a..18bae457 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -556,8 +556,8 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons continue } - if bytes.HasPrefix(line, []byte("data: ")) { - eventData := bytes.TrimPrefix(line, []byte("data: ")) + if after, ok0 := bytes.CutPrefix(line, []byte("data: ")); ok0 { + eventData := after if bytes.Equal(eventData, []byte("[DONE]")) { continue } diff --git a/service/ocm/service_usage.go b/service/ocm/service_usage.go index 589fd093..18696f3b 100644 --- a/service/ocm/service_usage.go +++ b/service/ocm/service_usage.go @@ -851,10 +851,7 @@ func normalizeGPT5Model(model string) string { func calculateCost(stats UsageStats, model string, serviceTier string, contextWindow int) float64 { pricing := getPricing(model, serviceTier, contextWindow) - regularInputTokens := stats.InputTokens - stats.CachedTokens - if regularInputTokens < 0 { - regularInputTokens = 0 - } + regularInputTokens := max(stats.InputTokens-stats.CachedTokens, 0) cost := (float64(regularInputTokens)*pricing.InputPrice + float64(stats.OutputTokens)*pricing.OutputPrice + diff --git a/service/oomkiller/service.go b/service/oomkiller/service.go index c3612d92..ff90f6e4 100644 --- a/service/oomkiller/service.go +++ b/service/oomkiller/service.go @@ -96,6 +96,7 @@ func (s *Service) Start(stage adapter.StartStage) error { if s.hasTimerMode { s.adaptiveTimer = newAdaptiveTimer(s.logger, s.router, s.timerConfig) + s.adaptiveTimer.start(false) if s.memoryLimit > 0 { s.logger.Info("started memory monitor with limit: ", s.memoryLimit/(1024*1024), " MiB") } else { @@ -164,7 +165,7 @@ func goMemoryPressureCallback(status C.ulong) { if isCritical { s.logger.Warn("memory pressure: ", level, ", usage: ", usage/(1024*1024), " MiB") if s.adaptiveTimer != nil { - s.adaptiveTimer.startNow() + s.adaptiveTimer.start(true) } } else if isWarning { s.logger.Warn("memory pressure: ", level, ", usage: ", usage/(1024*1024), " MiB") diff --git a/service/oomkiller/service_stub.go b/service/oomkiller/service_stub.go index 13348bac..7c1b84e8 100644 --- a/service/oomkiller/service_stub.go +++ b/service/oomkiller/service_stub.go @@ -64,7 +64,7 @@ func (s *Service) Start(stage adapter.StartStage) error { return E.New("memory pressure monitoring is not available on this platform without memory_limit") } s.adaptiveTimer = newAdaptiveTimer(s.logger, s.router, s.timerConfig) - s.adaptiveTimer.start(0) + s.adaptiveTimer.start(false) if s.useAvailable { s.logger.Info("started memory monitor with available memory detection") } else { diff --git a/service/oomkiller/service_timer.go b/service/oomkiller/service_timer.go index 315e1715..9f6a06c7 100644 --- a/service/oomkiller/service_timer.go +++ b/service/oomkiller/service_timer.go @@ -55,17 +55,13 @@ func newAdaptiveTimer(logger log.ContextLogger, router adapter.Router, config ti } } -func (t *adaptiveTimer) start(_ uint64) { - t.access.Lock() - defer t.access.Unlock() - t.startLocked() -} - -func (t *adaptiveTimer) startNow() { +func (t *adaptiveTimer) start(immediate bool) { t.access.Lock() t.startLocked() t.access.Unlock() - t.poll() + if immediate { + t.poll() + } } func (t *adaptiveTimer) startLocked() { @@ -90,12 +86,6 @@ func (t *adaptiveTimer) stopLocked() { } } -func (t *adaptiveTimer) running() bool { - t.access.Lock() - defer t.access.Unlock() - return t.timer != nil -} - func (t *adaptiveTimer) poll() { t.access.Lock() defer t.access.Unlock() @@ -144,13 +134,8 @@ func (t *adaptiveTimer) poll() { interval = t.maxInterval } else { timeToLimit := time.Duration(float64(remaining) / float64(delta) * float64(t.lastInterval)) - interval = timeToLimit / time.Duration(t.checksBeforeLimit) - if interval < t.minInterval { - interval = t.minInterval - } - if interval > t.maxInterval { - interval = t.maxInterval - } + interval = max(timeToLimit/time.Duration(t.checksBeforeLimit), t.minInterval) + interval = min(interval, t.maxInterval) } t.lastInterval = interval diff --git a/service/resolved/resolve1.go b/service/resolved/resolve1.go index ed1ee41a..6b347060 100644 --- a/service/resolved/resolve1.go +++ b/service/resolved/resolve1.go @@ -10,6 +10,7 @@ import ( "os" "os/user" "path/filepath" + "slices" "strconv" "strings" "syscall" @@ -127,7 +128,7 @@ func (t *resolve1Manager) createMetadata(sender dbus.Sender) adapter.InboundCont var uidFound bool statusContent, err := os.ReadFile(F.ToString("/proc/", senderPid, "/status")) if err == nil { - for _, line := range strings.Split(string(statusContent), "\n") { + for line := range strings.SplitSeq(string(statusContent), "\n") { line = strings.TrimSpace(line) if strings.HasPrefix(line, "Uid:") { fields := strings.Fields(line) @@ -255,8 +256,8 @@ func (t *resolve1Manager) ResolveAddress(sender dbus.Sender, ifIndex int32, fami return } var nibbles []string - for i := len(address) - 1; i >= 0; i-- { - b := address[i] + for _, v := range slices.Backward(address) { + b := v nibbles = append(nibbles, fmt.Sprintf("%x", b&0x0F)) nibbles = append(nibbles, fmt.Sprintf("%x", b>>4)) } diff --git a/service/resolved/transport.go b/service/resolved/transport.go index ac20663a..bdc35551 100644 --- a/service/resolved/transport.go +++ b/service/resolved/transport.go @@ -248,7 +248,7 @@ func (t *Transport) tryOneName(ctx context.Context, servers *LinkServers, messag sLen := uint32(len(servers.Servers)) var lastErr error for i := 0; i < t.attempts; i++ { - for j := uint32(0); j < sLen; j++ { + for j := range sLen { server := servers.Servers[(serverOffset+j)%sLen] question := message.Question[0] question.Name = fqdn diff --git a/transport/sip003/args.go b/transport/sip003/args.go index b9fae3da..de6113f7 100644 --- a/transport/sip003/args.go +++ b/transport/sip003/args.go @@ -105,15 +105,3 @@ func ParsePluginOptions(s string) (opts Args, err error) { } return opts, nil } - -// Escape backslashes and all the bytes that are in set. -func backslashEscape(s string, set []byte) string { - var buf bytes.Buffer - for _, b := range []byte(s) { - if b == '\\' || bytes.IndexByte(set, b) != -1 { - buf.WriteByte('\\') - } - buf.WriteByte(b) - } - return buf.String() -} diff --git a/transport/v2raygrpc/client.go b/transport/v2raygrpc/client.go index 5af53856..a915e1eb 100644 --- a/transport/v2raygrpc/client.go +++ b/transport/v2raygrpc/client.go @@ -10,7 +10,6 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -100,7 +99,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { return nil, err } client := NewGunServiceClient(clientConn).(GunServiceCustomNameClient) - ctx, cancel := common.ContextWithCancelCause(ctx) + ctx, cancel := context.WithCancelCause(ctx) stream, err := client.TunCustomName(ctx, c.serviceName) if err != nil { cancel(err) diff --git a/transport/v2raygrpc/credentials/credentials.go b/transport/v2raygrpc/credentials/credentials.go index 32c9b590..9deee7f6 100644 --- a/transport/v2raygrpc/credentials/credentials.go +++ b/transport/v2raygrpc/credentials/credentials.go @@ -25,12 +25,12 @@ import ( type requestInfoKey struct{} // NewRequestInfoContext creates a context with ri. -func NewRequestInfoContext(ctx context.Context, ri interface{}) context.Context { +func NewRequestInfoContext(ctx context.Context, ri any) context.Context { return context.WithValue(ctx, requestInfoKey{}, ri) } // RequestInfoFromContext extracts the RequestInfo from ctx. -func RequestInfoFromContext(ctx context.Context) interface{} { +func RequestInfoFromContext(ctx context.Context) any { return ctx.Value(requestInfoKey{}) } @@ -39,11 +39,11 @@ func RequestInfoFromContext(ctx context.Context) interface{} { type clientHandshakeInfoKey struct{} // ClientHandshakeInfoFromContext extracts the ClientHandshakeInfo from ctx. -func ClientHandshakeInfoFromContext(ctx context.Context) interface{} { +func ClientHandshakeInfoFromContext(ctx context.Context) any { return ctx.Value(clientHandshakeInfoKey{}) } // NewClientHandshakeInfoContext creates a context with chi. -func NewClientHandshakeInfoContext(ctx context.Context, chi interface{}) context.Context { +func NewClientHandshakeInfoContext(ctx context.Context, chi any) context.Context { return context.WithValue(ctx, clientHandshakeInfoKey{}, chi) } diff --git a/transport/v2raygrpc/credentials/util.go b/transport/v2raygrpc/credentials/util.go index f792fd22..ab864977 100644 --- a/transport/v2raygrpc/credentials/util.go +++ b/transport/v2raygrpc/credentials/util.go @@ -20,16 +20,15 @@ package credentials import ( "crypto/tls" + "slices" ) const alpnProtoStrH2 = "h2" // AppendH2ToNextProtos appends h2 to next protos. func AppendH2ToNextProtos(ps []string) []string { - for _, p := range ps { - if p == alpnProtoStrH2 { - return ps - } + if slices.Contains(ps, alpnProtoStrH2) { + return ps } ret := make([]string, 0, len(ps)+1) ret = append(ret, ps...) diff --git a/transport/v2raygrpc/server.go b/transport/v2raygrpc/server.go index 4d426aa1..6160c2f3 100644 --- a/transport/v2raygrpc/server.go +++ b/transport/v2raygrpc/server.go @@ -60,7 +60,7 @@ func (s *Server) Tun(server GunService_TunServer) error { if grpcMetadata, loaded := gM.FromIncomingContext(server.Context()); loaded { forwardFrom := strings.Join(grpcMetadata.Get("X-Forwarded-For"), ",") if forwardFrom != "" { - for _, from := range strings.Split(forwardFrom, ",") { + for from := range strings.SplitSeq(forwardFrom, ",") { originAddr := M.ParseSocksaddr(from) if originAddr.IsValid() { source = originAddr.Unwrap() diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 54b7be86..e0e8b645 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -136,7 +136,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) sizes[0] = n if n > 3 { b := packets[0] - common.ClearArray(b[1:4]) + clear(b[1:4]) } eps[0] = remoteEndpoint(M.SocksaddrFromNet(addr).Unwrap().AddrPort()) count = 1 diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index a190baba..373a050d 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -7,6 +7,7 @@ import ( "net" "net/netip" "os" + "sync" "time" "github.com/sagernet/gvisor/pkg/buffer" @@ -42,6 +43,7 @@ type stackDevice struct { outbound chan *stack.PacketBuffer packetOutbound chan *buf.Buffer done chan struct{} + closeOnce sync.Once dispatcher stack.NetworkDispatcher inet4Address netip.Addr inet6Address netip.Addr @@ -146,11 +148,17 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) } var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { + if !w.inet4Address.IsValid() { + return nil, E.New("missing IPv4 local address") + } networkProtocol = header.IPv4ProtocolNumber bind.Addr = tun.AddressFromAddr(w.inet4Address) } else { + if !w.inet6Address.IsValid() { + return nil, E.New("missing IPv6 local address") + } networkProtocol = header.IPv6ProtocolNumber - bind.Addr = tun.AddressFromAddr(w.inet4Address) + bind.Addr = tun.AddressFromAddr(w.inet6Address) } udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol) if err != nil { @@ -244,13 +252,15 @@ func (w *stackDevice) Events() <-chan wgTun.Event { } func (w *stackDevice) Close() error { - close(w.done) - close(w.events) - w.stack.Close() - for _, endpoint := range w.stack.CleanupEndpoints() { - endpoint.Abort() - } - w.stack.Wait() + w.closeOnce.Do(func() { + close(w.done) + close(w.events) + w.stack.Close() + for _, endpoint := range w.stack.CleanupEndpoints() { + endpoint.Abort() + } + w.stack.Wait() + }) return nil } diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index dcf2959b..1c0b8b6c 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -111,6 +111,7 @@ func (w *systemDevice) Start() error { } err = tunInterface.Start() if err != nil { + tunInterface.Close() return err } w.options.Logger.Info("started at ", w.options.Name) @@ -147,7 +148,7 @@ func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) { } else { for _, packet := range bufs { if tun.PacketOffset > 0 { - common.ClearArray(packet[offset-tun.PacketOffset : offset]) + clear(packet[offset-tun.PacketOffset : offset]) tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:])) } _, err = w.device.Write(packet[offset-tun.PacketOffset:]) @@ -177,8 +178,14 @@ func (w *systemDevice) Events() <-chan wgTun.Event { } func (w *systemDevice) Close() error { - close(w.events) - return w.device.Close() + var err error + w.closeOnce.Do(func() { + close(w.events) + if w.device != nil { + err = w.device.Close() + } + }) + return err } func (w *systemDevice) BatchSize() int { diff --git a/transport/wireguard/device_system_stack.go b/transport/wireguard/device_system_stack.go index 94fd6f4f..59c5f4ab 100644 --- a/transport/wireguard/device_system_stack.go +++ b/transport/wireguard/device_system_stack.go @@ -5,6 +5,7 @@ package wireguard import ( "context" "net/netip" + "sync" "time" "github.com/sagernet/gvisor/pkg/buffer" @@ -20,7 +21,6 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun/ping" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" "github.com/sagernet/wireguard-go/device" @@ -35,6 +35,7 @@ type systemStackDevice struct { stack *stack.Stack endpoint *deviceEndpoint writeBufs [][]byte + closeOnce sync.Once } func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) { @@ -104,13 +105,13 @@ func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err err } } if len(w.writeBufs) > 0 { - return w.batchDevice.BatchWrite(bufs, offset) + return w.batchDevice.BatchWrite(w.writeBufs, offset) } } else { for _, packet := range bufs { if !w.writeStack(packet[offset:]) { if tun.PacketOffset > 0 { - common.ClearArray(packet[offset-tun.PacketOffset : offset]) + clear(packet[offset-tun.PacketOffset : offset]) tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:])) } _, err = w.device.Write(packet[offset-tun.PacketOffset:]) @@ -125,13 +126,17 @@ func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err err } func (w *systemStackDevice) Close() error { - close(w.endpoint.done) - w.stack.Close() - for _, endpoint := range w.stack.CleanupEndpoints() { - endpoint.Abort() - } - w.stack.Wait() - return w.systemDevice.Close() + var err error + w.closeOnce.Do(func() { + close(w.endpoint.done) + w.stack.Close() + for _, endpoint := range w.stack.CleanupEndpoints() { + endpoint.Abort() + } + w.stack.Wait() + err = w.systemDevice.Close() + }) + return err } func (w *systemStackDevice) writeStack(packet []byte) bool { diff --git a/transport/wireguard/endpoint.go b/transport/wireguard/endpoint.go index 3a02e17a..84d9fe72 100644 --- a/transport/wireguard/endpoint.go +++ b/transport/wireguard/endpoint.go @@ -182,10 +182,10 @@ func (e *Endpoint) Start(resolve bool) error { return err } logger := &device.Logger{ - Verbosef: func(format string, args ...interface{}) { + Verbosef: func(format string, args ...any) { e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...)) }, - Errorf: func(format string, args ...interface{}) { + Errorf: func(format string, args ...any) { e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...)) }, } @@ -197,13 +197,15 @@ func (e *Endpoint) Start(resolve bool) error { } wgDevice := device.NewDevice(e.options.Context, deviceInput, bind, logger, e.options.Workers) e.tunDevice.SetDevice(wgDevice) - ipcConf := e.ipcConf + var ipcConf strings.Builder + ipcConf.WriteString(e.ipcConf) for _, peer := range e.peers { - ipcConf += peer.GenerateIpcLines() + ipcConf.WriteString(peer.GenerateIpcLines()) } - err = wgDevice.IpcSet(ipcConf) + err = wgDevice.IpcSet(ipcConf.String()) if err != nil { - return E.Cause(err, "setup wireguard: \n", ipcConf) + wgDevice.Close() + return E.Cause(err, "setup wireguard: \n", ipcConf.String()) } e.device = wgDevice e.pause = service.FromContext[pause.Manager](e.options.Context) @@ -231,10 +233,12 @@ func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n func (e *Endpoint) Close() error { if e.pauseCallback != nil { e.pause.UnregisterCallback(e.pauseCallback) + e.pauseCallback = nil } if e.device != nil { e.device.Down() e.device.Close() + e.device = nil } return nil } @@ -273,18 +277,19 @@ type peerConfig struct { } func (c peerConfig) GenerateIpcLines() string { - ipcLines := "\npublic_key=" + c.publicKeyHex + var ipcLines strings.Builder + ipcLines.WriteString("\npublic_key=" + c.publicKeyHex) if c.endpoint.IsValid() { - ipcLines += "\nendpoint=" + c.endpoint.String() + ipcLines.WriteString("\nendpoint=" + c.endpoint.String()) } if c.preSharedKeyHex != "" { - ipcLines += "\npreshared_key=" + c.preSharedKeyHex + ipcLines.WriteString("\npreshared_key=" + c.preSharedKeyHex) } for _, allowedIP := range c.allowedIPs { - ipcLines += "\nallowed_ip=" + allowedIP.String() + ipcLines.WriteString("\nallowed_ip=" + allowedIP.String()) } if c.keepalive > 0 { - ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive) + ipcLines.WriteString("\npersistent_keepalive_interval=" + F.ToString(c.keepalive)) } - return ipcLines + return ipcLines.String() }