Compare commits

..

101 Commits

Author SHA1 Message Date
世界
c3cc010880 documentation: Bump version 2025-06-06 23:17:46 +08:00
世界
1920c191be Fix systemd package 2025-06-06 23:17:46 +08:00
世界
e0ac459204 Fix missing home for derp service 2025-06-06 23:17:46 +08:00
Zero Clover
09fb897805 documentation: Fix services 2025-06-06 23:17:46 +08:00
世界
a1b3d891a3 Fix dns.client_subnet ignored 2025-06-06 23:17:46 +08:00
世界
d866a40469 documentation: Minor fixes 2025-06-06 23:17:46 +08:00
世界
45cd04b07e Fix tailscale forward 2025-06-06 23:17:46 +08:00
世界
2cf0528c4d Minor fixes 2025-06-06 23:17:46 +08:00
世界
905a2ded93 Add SSM API service 2025-06-06 23:17:46 +08:00
世界
cb3c0829c5 Add resolved service and DNS server 2025-06-06 23:17:45 +08:00
世界
1a8f6e053d Add DERP service 2025-06-06 23:17:27 +08:00
世界
99a09a6ce5 Add service component type 2025-06-06 23:17:27 +08:00
世界
01b4c7fcdd Fix tproxy tcp control 2025-06-06 23:17:27 +08:00
愚者
fe89f946c1 release: Fix build tags for android
Signed-off-by: 愚者 <11926619+FansChou@users.noreply.github.com>
2025-06-06 23:17:27 +08:00
世界
6c17c7a8f5 prevent creation of bind and mark controls on unsupported platforms 2025-06-06 23:17:27 +08:00
PuerNya
ea067e5478 documentation: Fix description of reject DNS action behavior 2025-06-06 23:17:27 +08:00
Restia-Ashbell
75af9a824e Fix TLS record fragment 2025-06-06 23:17:27 +08:00
世界
a5d4a42119 Add missing accept_routes option for Tailscale 2025-06-06 23:17:26 +08:00
世界
9821fbc3e3 Add TLS record fragment support 2025-06-06 23:17:26 +08:00
世界
c0408ad1de release: Update Go to 1.24.3 2025-06-06 23:17:26 +08:00
世界
6b0e861afa Fix set edns0 client subnet 2025-06-06 23:17:26 +08:00
世界
e32d686d6c Update minor dependencies 2025-06-06 23:17:26 +08:00
世界
844308e128 Update certmagic and providers 2025-06-06 23:17:26 +08:00
世界
93c14db281 Update protobuf and grpc 2025-06-06 23:17:26 +08:00
世界
b893a27dfc Add control options for listeners 2025-06-06 23:17:25 +08:00
世界
d39960fa23 Update quic-go to v0.52.0 2025-06-06 23:17:25 +08:00
世界
ba0badd4bf Update utls to v1.7.2 2025-06-06 23:17:25 +08:00
世界
cfbb5d63d5 Handle EDNS version downgrade 2025-06-06 23:16:21 +08:00
世界
8447a3edfe documentation: Fix anytls padding scheme description 2025-06-06 23:16:21 +08:00
安容
1a9747a531 Report invalid DNS address early 2025-06-06 23:16:20 +08:00
世界
583ecbea3b Fix wireguard listen_port 2025-06-06 23:16:20 +08:00
世界
bb6c8535a5 clash-api: Add more meta api 2025-06-06 23:16:19 +08:00
世界
10d90e4acc Fix DNS lookup 2025-06-06 23:16:19 +08:00
世界
e625012219 Fix fetch ECH configs 2025-06-06 23:16:19 +08:00
reletor
670863fd5b documentation: Minor fixes 2025-06-06 23:16:19 +08:00
caelansar
f7cf87142f Fix callback deletion in UDP transport 2025-06-06 23:16:18 +08:00
世界
2597a68a01 documentation: Try to make the play review happy 2025-06-06 23:16:18 +08:00
世界
7354332daa Fix missing handling of legacy domain_strategy options 2025-06-06 23:16:18 +08:00
世界
a0d382fc4e Improve local DNS server 2025-06-06 23:16:18 +08:00
anytls
a6da8b6654 Update anytls
Co-authored-by: anytls <anytls>
2025-06-06 23:16:17 +08:00
世界
7385616cca Fix DNS dialer 2025-06-06 23:16:17 +08:00
世界
4b6784b446 release: Skip override version for iOS 2025-06-06 23:16:16 +08:00
iikira
68579bb93b Fix UDP DNS server crash
Signed-off-by: iikira <i2@mail.iikira.com>
2025-06-06 23:16:16 +08:00
ReleTor
6aace7b1b7 Fix fetch ECH configs 2025-06-06 23:16:16 +08:00
世界
148234b742 Allow direct outbounds without domain_resolver 2025-06-06 23:16:16 +08:00
世界
97b7a451be Fix Tailscale dialer 2025-06-06 23:16:15 +08:00
dyhkwong
73b67e0b48 Fix DNS over QUIC stream close 2025-06-06 23:16:15 +08:00
anytls
88b4d04d59 Update anytls
Co-authored-by: anytls <anytls>
2025-06-06 23:16:15 +08:00
Rambling2076
d1ec6c6dd2 Fix missing with_tailscale in Dockerfile
Signed-off-by: Rambling2076 <Rambling2076@proton.me>
2025-06-06 23:16:14 +08:00
世界
523825336a Fail when default DNS server not found 2025-06-06 23:16:14 +08:00
世界
032565a026 Update gVisor to 20250319.0 2025-06-06 23:16:14 +08:00
世界
aeea24ae30 Explicitly reject detour to empty direct outbounds 2025-06-06 23:16:14 +08:00
世界
af22549f1a Add netns support 2025-06-06 23:16:14 +08:00
世界
57b17ceb4b Add wildcard name support for predefined records 2025-06-06 23:16:13 +08:00
世界
3dd308e7c3 Remove map usage in options 2025-06-06 23:16:13 +08:00
世界
7f75195d86 Fix unhandled DNS loop 2025-06-06 23:16:13 +08:00
世界
2fe4cad905 Add wildcard-sni support for shadow-tls inbound 2025-06-06 23:16:12 +08:00
k9982874
f55eb75a53 Add ntp protocol sniffing 2025-06-06 23:16:12 +08:00
世界
5ffb5b6ad2 option: Fix marshal legacy DNS options 2025-06-06 23:16:12 +08:00
世界
a1d5931759 Make domain_resolver optional when only one DNS server is configured 2025-06-06 23:16:12 +08:00
世界
9e68e909cb Fix DNS lookup context pollution 2025-06-06 23:16:11 +08:00
世界
117e8b76cc Fix http3 DNS server connecting to wrong address 2025-06-06 23:16:11 +08:00
Restia-Ashbell
d2f83bfd50 documentation: Fix typo 2025-06-06 23:16:11 +08:00
anytls
eaef13febe Update sing-anytls
Co-authored-by: anytls <anytls>
2025-06-06 23:16:11 +08:00
k9982874
0110c69dc9 Fix hosts DNS server 2025-06-06 23:16:10 +08:00
世界
fb2f5af1fb Fix UDP DNS server crash 2025-06-06 23:16:10 +08:00
世界
1553923118 documentation: Fix missing ip_accept_any DNS rule option 2025-06-06 23:16:10 +08:00
世界
0ada49489d Fix anytls dialer usage 2025-06-06 23:16:10 +08:00
世界
95d5ca9393 Move predefined DNS server to rule action 2025-06-06 23:16:10 +08:00
世界
6cebbb4590 Fix domain resolver on direct outbound 2025-06-06 23:16:09 +08:00
Zephyruso
0ef81bb5ef Fix missing AnyTLS display name 2025-06-06 23:16:09 +08:00
anytls
0d30a1df9d Update sing-anytls
Co-authored-by: anytls <anytls>
2025-06-06 23:16:09 +08:00
Estel
563499d2f9 documentation: Fix typo
Signed-off-by: Estel <callmebedrockdigger@gmail.com>
2025-06-06 23:16:08 +08:00
TargetLocked
f10c0c1c8d Fix parsing legacy DNS options 2025-06-06 23:16:08 +08:00
世界
428074d88b Fix DNS fallback 2025-06-06 23:16:07 +08:00
世界
fa18832ad2 documentation: Fix missing hosts DNS server 2025-06-06 23:16:07 +08:00
anytls
87bce2de29 Add MinIdleSession option to AnyTLS outbound
Co-authored-by: anytls <anytls>
2025-06-06 23:16:06 +08:00
ReleTor
f5020554e4 documentation: Minor fixes 2025-06-06 23:16:06 +08:00
libtry486
31f3623b8a documentation: Fix typo
fix typo

Signed-off-by: libtry486 <89328481+libtry486@users.noreply.github.com>
2025-06-06 23:16:05 +08:00
Alireza Ahmadi
bb42657177 Fix Outbound deadlock 2025-06-06 23:16:05 +08:00
世界
f19ff7eca7 documentation: Fix AnyTLS doc 2025-06-06 23:16:05 +08:00
anytls
8e45133f2e Add AnyTLS protocol 2025-06-06 23:16:04 +08:00
世界
63df88675f Migrate to stdlib ECH support 2025-06-06 23:16:04 +08:00
世界
0423244298 Add fallback local DNS server for iOS 2025-06-06 23:16:03 +08:00
世界
a5f1af9587 Get darwin local DNS server from libresolv 2025-06-06 23:16:03 +08:00
世界
112817c1a4 Improve resolve action 2025-06-06 23:16:02 +08:00
世界
6e91de51f1 Add back port hopping to hysteria 1 2025-06-06 23:16:02 +08:00
xchacha20-poly1305
efc5c542fb Remove single quotes of raw Moziila certs 2025-06-06 23:16:02 +08:00
世界
f1b569c7d1 Add Tailscale endpoint 2025-06-06 23:16:02 +08:00
世界
a752197d5e Build legacy binaries with latest Go 2025-06-06 23:16:01 +08:00
世界
65517d4513 documentation: Remove outdated icons 2025-06-06 23:16:01 +08:00
世界
ccf4fa4d3a documentation: Certificate store 2025-06-06 23:16:01 +08:00
世界
18dbb823a1 documentation: TLS fragment 2025-06-06 23:16:01 +08:00
世界
4ec058e91a documentation: Outbound domain resolver 2025-06-06 23:16:01 +08:00
世界
6eed06b2c2 documentation: Refactor DNS 2025-06-06 23:16:00 +08:00
世界
dd209cc9d5 Add certificate store 2025-06-06 23:16:00 +08:00
世界
b0c0a6b07d Add TLS fragment support 2025-06-06 23:15:59 +08:00
世界
951a8fabbf refactor: Outbound domain resolver 2025-06-06 23:15:59 +08:00
世界
928298b528 refactor: DNS 2025-06-06 23:15:59 +08:00
世界
5b84fa0137 Fix default network strategy 2025-06-06 14:50:38 +08:00
世界
2bb85ac8a1 Fix slowOpenConn 2025-06-06 14:39:40 +08:00
424 changed files with 21425 additions and 11546 deletions

View File

@@ -8,11 +8,15 @@
--deb-field "Bug: https://github.com/SagerNet/sing-box/issues"
--no-deb-generate-changes
--config-files /etc/sing-box/config.json
--after-install release/config/sing-box.postinst
release/config/config.json=/etc/sing-box/config.json
release/config/sing-box.service=/usr/lib/systemd/system/sing-box.service
release/config/sing-box@.service=/usr/lib/systemd/system/sing-box@.service
release/config/sing-box.sysusers=/usr/lib/sysusers.d/sing-box.conf
release/config/sing-box.rules=usr/share/polkit-1/rules.d/sing-box.rules
release/config/sing-box-split-dns.xml=/usr/share/dbus-1/system.d/sing-box-split-dns.conf
release/completions/sing-box.bash=/usr/share/bash-completion/completions/sing-box.bash
release/completions/sing-box.fish=/usr/share/fish/vendor_completions.d/sing-box.fish

25
.github/setup_legacy_go.sh vendored Executable file
View File

@@ -0,0 +1,25 @@
#!/usr/bin/env bash
VERSION="1.23.6"
mkdir -p $HOME/go
cd $HOME/go
wget "https://dl.google.com/go/go${VERSION}.linux-amd64.tar.gz"
tar -xzf "go${VERSION}.linux-amd64.tar.gz"
mv go go_legacy
cd go_legacy
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.23.x
# that means after golang1.24 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.23/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
curl https://github.com/MetaCubeX/go/commit/9ac42137ef6730e8b7daca016ece831297a1d75b.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/21290de8a4c91408de7c2b5b68757b1e90af49dd.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/6a31d3fa8e47ddabc10bd97bff10d9a85f4cfb76.diff | patch --verbose -p 1
curl https://github.com/MetaCubeX/go/commit/69e2eed6dd0f6d815ebf15797761c13f31213dd6.diff | patch --verbose -p 1

View File

@@ -46,7 +46,7 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.24
go-version: ^1.24.3
- name: Check input version
if: github.event_name == 'workflow_dispatch'
run: |-
@@ -94,7 +94,6 @@ jobs:
- { os: windows, arch: arm64 }
- { os: darwin, arch: amd64 }
- { os: darwin, arch: amd64, legacy_go: true }
- { os: darwin, arch: arm64 }
- { os: android, arch: arm64, ndk: "aarch64-linux-android21" }
@@ -106,16 +105,28 @@ jobs:
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with:
fetch-depth: 0
- name: Setup Go
if: matrix.legacy_go
uses: actions/setup-go@v5
with:
go-version: ~1.20
- name: Setup Go
if: ${{ ! matrix.legacy_go }}
uses: actions/setup-go@v5
with:
go-version: ^1.24
go-version: ^1.24.3
- name: Cache Legacy Go
if: matrix.require_legacy_go
id: cache-legacy-go
uses: actions/cache@v4
with:
path: |
~/go/go_legacy
key: go_legacy_1236
- name: Setup Legacy Go
if: matrix.legacy_go && steps.cache-legacy-go.outputs.cache-hit != 'true'
run: |-
.github/setup_legacy_go.sh
- name: Setup Legacy Go 2
if: matrix.legacy_go
run: |-
echo "PATH=$HOME/go/go_legacy/bin:$PATH" >> $GITHUB_ENV
echo "GOROOT=$HOME/go/go_legacy" >> $GITHUB_ENV
- name: Setup Android NDK
if: matrix.os == 'android'
uses: nttld/setup-ndk@v1
@@ -129,10 +140,7 @@ jobs:
- name: Set build tags
run: |
set -xeuo pipefail
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_reality_server,with_acme,with_clash_api'
if [ ! '${{ matrix.legacy_go }}' = 'true' ]; then
TAGS="${TAGS},with_ech"
fi
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale'
echo "BUILD_TAGS=${TAGS}" >> "${GITHUB_ENV}"
- name: Build
if: matrix.os != 'android'
@@ -286,7 +294,7 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.24
go-version: ^1.24.3
- name: Setup Android NDK
id: setup-ndk
uses: nttld/setup-ndk@v1
@@ -366,7 +374,7 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.24
go-version: ^1.24.3
- name: Setup Android NDK
id: setup-ndk
uses: nttld/setup-ndk@v1
@@ -464,7 +472,7 @@ jobs:
if: matrix.if
uses: actions/setup-go@v5
with:
go-version: ^1.24
go-version: ^1.24.3
- name: Setup Xcode stable
if: matrix.if && github.ref == 'refs/heads/main-next'
run: |-
@@ -541,10 +549,13 @@ jobs:
MACOS_PROJECT_VERSION=$(go run -v ./cmd/internal/app_store_connect next_macos_project_version)
echo "MACOS_PROJECT_VERSION=$MACOS_PROJECT_VERSION"
echo "MACOS_PROJECT_VERSION=$MACOS_PROJECT_VERSION" >> "$GITHUB_ENV"
- name: Update version
if: matrix.if && matrix.name != 'iOS'
run: |-
go run -v ./cmd/internal/update_apple_version --ci
- name: Build
if: matrix.if
run: |-
go run -v ./cmd/internal/update_apple_version --ci
cd clients/apple
xcodebuild archive \
-scheme "${{ matrix.scheme }}" \

View File

@@ -28,7 +28,7 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.24
go-version: ^1.24.3
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:

View File

@@ -25,7 +25,7 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.24
go-version: ^1.24.3
- name: Check input version
if: github.event_name == 'workflow_dispatch'
run: |-
@@ -66,7 +66,7 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.24
go-version: ^1.24.3
- name: Setup Android NDK
if: matrix.os == 'android'
uses: nttld/setup-ndk@v1
@@ -80,10 +80,7 @@ jobs:
- name: Set build tags
run: |
set -xeuo pipefail
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_reality_server,with_acme,with_clash_api'
if [ ! '${{ matrix.legacy_go }}' = 'true' ]; then
TAGS="${TAGS},with_ech"
fi
TAGS='with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale'
echo "BUILD_TAGS=${TAGS}" >> "${GITHUB_ENV}"
- name: Build
run: |

View File

@@ -27,9 +27,7 @@ run:
- with_quic
- with_dhcp
- with_wireguard
- with_ech
- with_utls
- with_reality_server
- with_acme
- with_clash_api

View File

@@ -6,17 +6,18 @@ builds:
- -v
- -trimpath
ldflags:
- -X github.com/sagernet/sing-box/constant.Version={{ .Version }} -s -w -buildid=
- -X github.com/sagernet/sing-box/constant.Version={{ .Version }}
- -s
- -buildid=
tags:
- with_gvisor
- with_quic
- with_dhcp
- with_wireguard
- with_ech
- with_utls
- with_reality_server
- with_acme
- with_clash_api
- with_tailscale
env:
- CGO_ENABLED=0
targets:
@@ -54,6 +55,12 @@ nfpms:
dst: /usr/lib/systemd/system/sing-box.service
- src: release/config/sing-box@.service
dst: /usr/lib/systemd/system/sing-box@.service
- src: release/config/sing-box.sysusers
dst: /usr/lib/sysusers.d/sing-box.conf
- src: release/config/sing-box.rules
dst: /usr/share/polkit-1/rules.d/sing-box.rules
- src: release/config/sing-box-split-dns.xml
dst: /usr/share/dbus-1/system.d/sing-box-split-dns.conf
- src: release/completions/sing-box.bash
dst: /usr/share/bash-completion/completions/sing-box.bash

View File

@@ -16,13 +16,13 @@ builds:
- with_quic
- with_dhcp
- with_wireguard
- with_ech
- with_utls
- with_reality_server
- with_acme
- with_clash_api
- with_tailscale
env:
- CGO_ENABLED=0
- GOTOOLCHAIN=local
targets:
- linux_386
- linux_amd64_v1
@@ -46,21 +46,21 @@ builds:
- with_dhcp
- with_wireguard
- with_utls
- with_reality_server
- with_acme
- with_clash_api
- with_tailscale
env:
- CGO_ENABLED=0
- GOROOT={{ .Env.GOPATH }}/go1.20.14
tool: "{{ .Env.GOPATH }}/go1.20.14/bin/go"
- GOROOT={{ .Env.GOPATH }}/go_legacy
tool: "{{ .Env.GOPATH }}/go_legacy/bin/go"
targets:
- windows_amd64_v1
- windows_386
- darwin_amd64_v1
- id: android
<<: *template
env:
- CGO_ENABLED=1
- GOTOOLCHAIN=local
overrides:
- goos: android
goarch: arm
@@ -136,6 +136,12 @@ nfpms:
dst: /usr/lib/systemd/system/sing-box.service
- src: release/config/sing-box@.service
dst: /usr/lib/systemd/system/sing-box@.service
- src: release/config/sing-box.sysusers
dst: /usr/lib/sysusers.d/sing-box.conf
- src: release/config/sing-box.rules
dst: /usr/share/polkit-1/rules.d/sing-box.rules
- src: release/config/sing-box-split-dns.xml
dst: /usr/share/dbus-1/system.d/sing-box-split-dns.conf
- src: release/completions/sing-box.bash
dst: /usr/share/bash-completion/completions/sing-box.bash

View File

@@ -13,7 +13,7 @@ RUN set -ex \
&& export COMMIT=$(git rev-parse --short HEAD) \
&& export VERSION=$(go run ./cmd/internal/read_tag) \
&& go build -v -trimpath -tags \
"with_gvisor,with_quic,with_dhcp,with_wireguard,with_ech,with_utls,with_reality_server,with_acme,with_clash_api" \
"with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale" \
-o /go/bin/sing-box \
-ldflags "-X \"github.com/sagernet/sing-box/constant.Version=$VERSION\" -s -w -buildid=" \
./cmd/sing-box

View File

@@ -1,13 +1,10 @@
NAME = sing-box
COMMIT = $(shell git rev-parse --short HEAD)
TAGS_GO120 = with_gvisor,with_dhcp,with_wireguard,with_reality_server,with_clash_api,with_quic,with_utls
TAGS_GO121 = with_ech
TAGS ?= $(TAGS_GO118),$(TAGS_GO120),$(TAGS_GO121)
TAGS_TEST ?= with_gvisor,with_quic,with_wireguard,with_grpc,with_ech,with_utls,with_reality_server
TAGS ?= with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale
GOHOSTOS = $(shell go env GOHOSTOS)
GOHOSTARCH = $(shell go env GOHOSTARCH)
VERSION=$(shell CGO_ENABLED=0 GOOS=$(GOHOSTOS) GOARCH=$(GOHOSTARCH) go run ./cmd/internal/read_tag)
VERSION=$(shell CGO_ENABLED=0 GOOS=$(GOHOSTOS) GOARCH=$(GOHOSTARCH) go run github.com/sagernet/sing-box/cmd/internal/read_tag@latest)
PARAMS = -v -trimpath -ldflags "-X 'github.com/sagernet/sing-box/constant.Version=$(VERSION)' -s -w -buildid="
MAIN_PARAMS = $(PARAMS) -tags "$(TAGS)"
@@ -17,14 +14,12 @@ PREFIX ?= $(shell go env GOPATH)
.PHONY: test release docs build
build:
export GOTOOLCHAIN=local && \
go build $(MAIN_PARAMS) $(MAIN)
ci_build_go120:
go build $(PARAMS) $(MAIN)
go build $(PARAMS) -tags "$(TAGS_GO120)" $(MAIN)
ci_build:
go build $(PARAMS) $(MAIN)
export GOTOOLCHAIN=local && \
go build $(PARAMS) $(MAIN) && \
go build $(MAIN_PARAMS) $(MAIN)
generate_completions:
@@ -61,6 +56,9 @@ proto_install:
go install -v google.golang.org/protobuf/cmd/protoc-gen-go@latest
go install -v google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
update_certificates:
go run ./cmd/internal/update_certificates
release:
go run ./cmd/internal/build goreleaser release --clean --skip publish
mkdir dist/release
@@ -227,8 +225,8 @@ lib:
go run ./cmd/internal/build_libbox -target ios
lib_install:
go install -v github.com/sagernet/gomobile/cmd/gomobile@v0.1.4
go install -v github.com/sagernet/gomobile/cmd/gobind@v0.1.4
go install -v github.com/sagernet/gomobile/cmd/gomobile@v0.1.6
go install -v github.com/sagernet/gomobile/cmd/gobind@v0.1.6
docs:
venv/bin/mkdocs serve

21
adapter/certificate.go Normal file
View File

@@ -0,0 +1,21 @@
package adapter
import (
"context"
"crypto/x509"
"github.com/sagernet/sing/service"
)
type CertificateStore interface {
LifecycleService
Pool() *x509.CertPool
}
func RootPoolFromContext(ctx context.Context) *x509.CertPool {
store := service.FromContext[CertificateStore](ctx)
if store == nil {
return nil
}
return store.Pool()
}

94
adapter/dns.go Normal file
View File

@@ -0,0 +1,94 @@
package adapter
import (
"context"
"net/netip"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
"github.com/miekg/dns"
)
type DNSRouter interface {
Lifecycle
Exchange(ctx context.Context, message *dns.Msg, options DNSQueryOptions) (*dns.Msg, error)
Lookup(ctx context.Context, domain string, options DNSQueryOptions) ([]netip.Addr, error)
ClearCache()
LookupReverseMapping(ip netip.Addr) (string, bool)
ResetNetwork()
}
type DNSClient interface {
Start()
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error)
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error)
LookupCache(domain string, strategy C.DomainStrategy) ([]netip.Addr, bool)
ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool)
ClearCache()
}
type DNSQueryOptions struct {
Transport DNSTransport
Strategy C.DomainStrategy
LookupStrategy C.DomainStrategy
DisableCache bool
RewriteTTL *uint32
ClientSubnet netip.Prefix
}
func DNSQueryOptionsFrom(ctx context.Context, options *option.DomainResolveOptions) (*DNSQueryOptions, error) {
if options == nil {
return &DNSQueryOptions{}, nil
}
transportManager := service.FromContext[DNSTransportManager](ctx)
transport, loaded := transportManager.Transport(options.Server)
if !loaded {
return nil, E.New("domain resolver not found: " + options.Server)
}
return &DNSQueryOptions{
Transport: transport,
Strategy: C.DomainStrategy(options.Strategy),
DisableCache: options.DisableCache,
RewriteTTL: options.RewriteTTL,
ClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
}, nil
}
type RDRCStore interface {
LoadRDRC(transportName string, qName string, qType uint16) (rejected bool)
SaveRDRC(transportName string, qName string, qType uint16) error
SaveRDRCAsync(transportName string, qName string, qType uint16, logger logger.Logger)
}
type DNSTransport interface {
Lifecycle
Type() string
Tag() string
Dependencies() []string
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
}
type LegacyDNSTransport interface {
LegacyStrategy() C.DomainStrategy
LegacyClientSubnet() netip.Prefix
}
type DNSTransportRegistry interface {
option.DNSTransportOptionsRegistry
CreateDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) (DNSTransport, error)
}
type DNSTransportManager interface {
Lifecycle
Transports() []DNSTransport
Transport(tag string) (DNSTransport, bool)
Default() DNSTransport
FakeIP() FakeIPTransport
Remove(tag string) error
Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) error
}

View File

@@ -35,6 +35,7 @@ func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Ma
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
@@ -42,12 +43,9 @@ func (m *Manager) Start(stage adapter.StartStage) error {
m.stage = stage
if stage == adapter.StartStateStart {
// started with outbound manager
m.access.Unlock()
return nil
}
endpoints := m.endpoints
m.access.Unlock()
for _, endpoint := range endpoints {
for _, endpoint := range m.endpoints {
err := adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")

View File

@@ -6,8 +6,6 @@ import (
"encoding/binary"
"time"
"github.com/sagernet/sing-box/common/urltest"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/varbin"
)
@@ -16,7 +14,20 @@ type ClashServer interface {
ConnectionTracker
Mode() string
ModeList() []string
HistoryStorage() *urltest.HistoryStorage
HistoryStorage() URLTestHistoryStorage
}
type URLTestHistory struct {
Time time.Time `json:"time"`
Delay uint16 `json:"delay"`
}
type URLTestHistoryStorage interface {
SetHook(hook chan<- struct{})
LoadURLTestHistory(tag string) *URLTestHistory
DeleteURLTestHistory(tag string)
StoreURLTestHistory(tag string, history *URLTestHistory)
Close() error
}
type V2RayServer interface {
@@ -31,9 +42,7 @@ type CacheFile interface {
FakeIPStorage
StoreRDRC() bool
dns.RDRCStore
StoreWARPConfig() bool
RDRCStore
LoadMode() string
StoreMode(mode string) error
@@ -43,8 +52,6 @@ type CacheFile interface {
StoreGroupExpand(group string, expand bool) error
LoadRuleSet(tag string) *SavedBinary
SaveRuleSet(tag string, set *SavedBinary) error
LoadWARPConfig(tag string) *SavedBinary
SaveWARPConfig(tag string, set *SavedBinary) error
}
type SavedBinary struct {

View File

@@ -3,12 +3,11 @@ package adapter
import (
"net/netip"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/logger"
)
type FakeIPStore interface {
Service
SimpleLifecycle
Contains(address netip.Addr) bool
Create(domain string, isIPv6 bool) (netip.Addr, error)
Lookup(address netip.Addr) (string, bool)
@@ -27,6 +26,6 @@ type FakeIPStorage interface {
}
type FakeIPTransport interface {
dns.Transport
DNSTransport
Store() FakeIPStore
}

View File

@@ -42,16 +42,14 @@ type InboundManager interface {
}
type InboundContext struct {
Inbound string
InboundType string
IPVersion uint8
Network string
Source M.Socksaddr
Destination M.Socksaddr
TunnelSource string
TunnelDestination string
User string
Outbound string
Inbound string
InboundType string
IPVersion uint8
Network string
Source M.Socksaddr
Destination M.Socksaddr
User string
Outbound string
// sniffer
@@ -74,14 +72,15 @@ type InboundContext struct {
UDPDisableDomainUnmapping bool
UDPConnect bool
UDPTimeout time.Duration
TLSFragment bool
TLSFragmentFallbackDelay time.Duration
TLSRecordFragment bool
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
DNSServer string
DestinationAddresses []netip.Addr
SourceGeoIPCode string
GeoIPCode string

View File

@@ -37,13 +37,14 @@ func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endp
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
for _, inbound := range m.inbounds {
inbounds := m.inbounds
m.access.Unlock()
for _, inbound := range inbounds {
err := adapter.LegacyStart(inbound, stage)
if err != nil {
return E.Cause(err, stage, " inbound/", inbound.Type(), "[", inbound.Tag(), "]")

View File

@@ -2,6 +2,11 @@ package adapter
import E "github.com/sagernet/sing/common/exceptions"
type SimpleLifecycle interface {
Start() error
Close() error
}
type StartStage uint8
const (

View File

@@ -28,14 +28,14 @@ func LegacyStart(starter any, stage StartStage) error {
}
type lifecycleServiceWrapper struct {
Service
SimpleLifecycle
name string
}
func NewLifecycleService(service Service, name string) LifecycleService {
func NewLifecycleService(service SimpleLifecycle, name string) LifecycleService {
return &lifecycleServiceWrapper{
Service: service,
name: name,
SimpleLifecycle: service,
name: name,
}
}
@@ -44,9 +44,9 @@ func (l *lifecycleServiceWrapper) Name() string {
}
func (l *lifecycleServiceWrapper) Start(stage StartStage) error {
return LegacyStart(l.Service, stage)
return LegacyStart(l.SimpleLifecycle, stage)
}
func (l *lifecycleServiceWrapper) Close() error {
return l.Service.Close()
return l.SimpleLifecycle.Close()
}

View File

@@ -29,12 +29,14 @@ type NetworkManager interface {
}
type NetworkOptions struct {
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
BindInterface string
RoutingMark uint32
BindInterface string
RoutingMark uint32
DomainResolver string
DomainResolveOptions DNSQueryOptions
NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType
FallbackNetworkType []C.InterfaceType
FallbackDelay time.Duration
}
type InterfaceUpdateListener interface {

View File

@@ -23,7 +23,7 @@ type Manager struct {
registry adapter.OutboundRegistry
endpoint adapter.EndpointManager
defaultTag string
access sync.Mutex
access sync.RWMutex
started bool
stage adapter.StartStage
outbounds []adapter.Outbound
@@ -169,15 +169,15 @@ func (m *Manager) Close() error {
}
func (m *Manager) Outbounds() []adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
m.access.RLock()
defer m.access.RUnlock()
return m.outbounds
}
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.Lock()
m.access.RLock()
outbound, found := m.outboundByTag[tag]
m.access.Unlock()
m.access.RUnlock()
if found {
return outbound, true
}
@@ -185,8 +185,8 @@ func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
}
func (m *Manager) Default() adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
m.access.RLock()
defer m.access.RUnlock()
if m.defaultOutbound != nil {
return m.defaultOutbound
} else {
@@ -196,9 +196,9 @@ func (m *Manager) Default() adapter.Outbound {
func (m *Manager) Remove(tag string) error {
m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.outboundByTag, tag)
@@ -232,7 +232,6 @@ func (m *Manager) Remove(tag string) error {
})
}
}
m.access.Unlock()
if started {
return common.Close(outbound)
}
@@ -247,8 +246,6 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(outbound, stage)
@@ -257,6 +254,8 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.
}
}
}
m.access.Lock()
defer m.access.Unlock()
if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
if m.started {
err = common.Close(existsOutbound)

View File

@@ -2,44 +2,29 @@ package adapter
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/netip"
"sync"
"github.com/sagernet/sing-box/common/geoip"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/x/list"
mdns "github.com/miekg/dns"
"go4.org/netipx"
)
type Router interface {
Lifecycle
FakeIPStore() FakeIPStore
ConnectionRouter
PreMatch(metadata InboundContext) error
ConnectionRouterEx
GeoIPReader() *geoip.Reader
LoadGeosite(code string) (Rule, error)
RuleSet(tag string) (RuleSet, bool)
NeedWIFIState() bool
Exchange(ctx context.Context, message *mdns.Msg) (*mdns.Msg, error)
Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error)
LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error)
ClearDNSCache()
Rules() []Rule
AppendTracker(tracker ConnectionTracker)
ResetNetwork()
}
@@ -83,12 +68,14 @@ type RuleSetMetadata struct {
ContainsIPCIDRRule bool
}
type HTTPStartContext struct {
ctx context.Context
access sync.Mutex
httpClientCache map[string]*http.Client
}
func NewHTTPStartContext() *HTTPStartContext {
func NewHTTPStartContext(ctx context.Context) *HTTPStartContext {
return &HTTPStartContext{
ctx: ctx,
httpClientCache: make(map[string]*http.Client),
}
}
@@ -106,6 +93,10 @@ func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Clie
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(c.ctx),
RootCAs: RootPoolFromContext(c.ctx),
},
},
}
c.httpClientCache[detour] = httpClient

View File

@@ -11,9 +11,8 @@ type HeadlessRule interface {
type Rule interface {
HeadlessRule
Service
SimpleLifecycle
Type() string
UpdateGeosite() error
Action() RuleAction
}

View File

@@ -1,6 +1,27 @@
package adapter
import (
"context"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
type Service interface {
Start() error
Close() error
Lifecycle
Type() string
Tag() string
}
type ServiceRegistry interface {
option.ServiceOptionsRegistry
Create(ctx context.Context, logger log.ContextLogger, tag string, serviceType string, options any) (Service, error)
}
type ServiceManager interface {
Lifecycle
Services() []Service
Get(tag string) (Service, bool)
Remove(tag string) error
Create(ctx context.Context, logger log.ContextLogger, tag string, serviceType string, options any) error
}

View File

@@ -0,0 +1,21 @@
package service
type Adapter struct {
serviceType string
serviceTag string
}
func NewAdapter(serviceType string, serviceTag string) Adapter {
return Adapter{
serviceType: serviceType,
serviceTag: serviceTag,
}
}
func (a *Adapter) Type() string {
return a.serviceType
}
func (a *Adapter) Tag() string {
return a.serviceTag
}

144
adapter/service/manager.go Normal file
View File

@@ -0,0 +1,144 @@
package service
import (
"context"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
var _ adapter.ServiceManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.ServiceRegistry
access sync.Mutex
started bool
stage adapter.StartStage
services []adapter.Service
serviceByTag map[string]adapter.Service
}
func NewManager(logger log.ContextLogger, registry adapter.ServiceRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
serviceByTag: make(map[string]adapter.Service),
}
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
services := m.services
m.access.Unlock()
for _, service := range services {
err := adapter.LegacyStart(service, stage)
if err != nil {
return E.Cause(err, stage, " service/", service.Type(), "[", service.Tag(), "]")
}
}
return nil
}
func (m *Manager) Close() error {
m.access.Lock()
defer m.access.Unlock()
if !m.started {
return nil
}
m.started = false
services := m.services
m.services = nil
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, service := range services {
monitor.Start("close service/", service.Type(), "[", service.Tag(), "]")
err = E.Append(err, service.Close(), func(err error) error {
return E.Cause(err, "close service/", service.Type(), "[", service.Tag(), "]")
})
monitor.Finish()
}
return nil
}
func (m *Manager) Services() []adapter.Service {
m.access.Lock()
defer m.access.Unlock()
return m.services
}
func (m *Manager) Get(tag string) (adapter.Service, bool) {
m.access.Lock()
service, found := m.serviceByTag[tag]
m.access.Unlock()
return service, found
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
service, found := m.serviceByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.serviceByTag, tag)
index := common.Index(m.services, func(it adapter.Service) bool {
return it == service
})
if index == -1 {
panic("invalid service index")
}
m.services = append(m.services[:index], m.services[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return service.Close()
}
return nil
}
func (m *Manager) Create(ctx context.Context, logger log.ContextLogger, tag string, serviceType string, options any) error {
service, err := m.registry.Create(ctx, logger, tag, serviceType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(service, stage)
if err != nil {
return E.Cause(err, stage, " service/", service.Type(), "[", service.Tag(), "]")
}
}
}
if existsService, loaded := m.serviceByTag[tag]; loaded {
if m.started {
err = existsService.Close()
if err != nil {
return E.Cause(err, "close service/", existsService.Type(), "[", existsService.Tag(), "]")
}
}
existsIndex := common.Index(m.services, func(it adapter.Service) bool {
return it == existsService
})
if existsIndex == -1 {
panic("invalid service index")
}
m.services = append(m.services[:existsIndex], m.services[existsIndex+1:]...)
}
m.services = append(m.services, service)
m.serviceByTag[tag] = service
return nil
}

View File

@@ -0,0 +1,72 @@
package service
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type ConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.Service, error)
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
registry.register(outboundType, func() any {
return new(Options)
}, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.Service, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, logger, tag, common.PtrValueOrDefault(options))
})
}
var _ adapter.ServiceRegistry = (*Registry)(nil)
type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.Service, error)
)
type Registry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructor map[string]constructorFunc
}
func NewRegistry() *Registry {
return &Registry{
optionsType: make(map[string]optionsConstructorFunc),
constructor: make(map[string]constructorFunc),
}
}
func (m *Registry) CreateOptions(outboundType string) (any, bool) {
m.access.Lock()
defer m.access.Unlock()
optionsConstructor, loaded := m.optionsType[outboundType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}
func (m *Registry) Create(ctx context.Context, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Service, error) {
m.access.Lock()
defer m.access.Unlock()
constructor, loaded := m.constructor[outboundType]
if !loaded {
return nil, E.New("outbound type not found: " + outboundType)
}
return constructor(ctx, logger, tag, options)
}
func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
m.access.Lock()
defer m.access.Unlock()
m.optionsType[outboundType] = optionsConstructor
m.constructor[outboundType] = constructor
}

18
adapter/ssm.go Normal file
View File

@@ -0,0 +1,18 @@
package adapter
import (
"net"
N "github.com/sagernet/sing/common/network"
)
type ManagedSSMServer interface {
Inbound
SetTracker(tracker SSMTracker)
UpdateUsers(users []string, uPSKs []string) error
}
type SSMTracker interface {
TrackConnection(conn net.Conn, metadata InboundContext) net.Conn
TrackPacketConnection(conn N.PacketConn, metadata InboundContext) N.PacketConn
}

View File

@@ -3,6 +3,6 @@ package adapter
import "time"
type TimeService interface {
Service
SimpleLifecycle
TimeFunc() func() time.Time
}

192
box.go
View File

@@ -12,11 +12,14 @@ import (
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/certificate"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/taskmonitor"
"github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/common/urltest"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/dns/transport/local"
"github.com/sagernet/sing-box/experimental"
"github.com/sagernet/sing-box/experimental/cachefile"
"github.com/sagernet/sing-box/experimental/libbox/platform"
@@ -32,20 +35,23 @@ import (
"github.com/sagernet/sing/service/pause"
)
var _ adapter.Service = (*Box)(nil)
var _ adapter.SimpleLifecycle = (*Box)(nil)
type Box struct {
createdAt time.Time
logFactory log.Factory
logger log.ContextLogger
network *route.NetworkManager
endpoint *endpoint.Manager
inbound *inbound.Manager
outbound *outbound.Manager
connection *route.ConnectionManager
router *route.Router
services []adapter.LifecycleService
done chan struct{}
createdAt time.Time
logFactory log.Factory
logger log.ContextLogger
network *route.NetworkManager
endpoint *endpoint.Manager
inbound *inbound.Manager
outbound *outbound.Manager
service *boxService.Manager
dnsTransport *dns.TransportManager
dnsRouter *dns.Router
connection *route.ConnectionManager
router *route.Router
internalService []adapter.LifecycleService
done chan struct{}
}
type Options struct {
@@ -59,6 +65,8 @@ func Context(
inboundRegistry adapter.InboundRegistry,
outboundRegistry adapter.OutboundRegistry,
endpointRegistry adapter.EndpointRegistry,
dnsTransportRegistry adapter.DNSTransportRegistry,
serviceRegistry adapter.ServiceRegistry,
) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil {
@@ -75,6 +83,14 @@ func Context(
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
}
if service.FromContext[adapter.DNSTransportRegistry](ctx) == nil {
ctx = service.ContextWith[option.DNSTransportOptionsRegistry](ctx, dnsTransportRegistry)
ctx = service.ContextWith[adapter.DNSTransportRegistry](ctx, dnsTransportRegistry)
}
if service.FromContext[adapter.ServiceRegistry](ctx) == nil {
ctx = service.ContextWith[option.ServiceOptionsRegistry](ctx, serviceRegistry)
ctx = service.ContextWith[adapter.ServiceRegistry](ctx, serviceRegistry)
}
return ctx
}
@@ -89,6 +105,8 @@ func New(options Options) (*Box, error) {
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx)
serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx)
if endpointRegistry == nil {
return nil, E.New("missing endpoint registry in context")
@@ -99,6 +117,12 @@ func New(options Options) (*Box, error) {
if outboundRegistry == nil {
return nil, E.New("missing outbound registry in context")
}
if dnsTransportRegistry == nil {
return nil, E.New("missing DNS transport registry in context")
}
if serviceRegistry == nil {
return nil, E.New("missing service registry in context")
}
ctx = pause.WithDefaultManager(ctx)
experimentalOptions := common.PtrValueOrDefault(options.Experimental)
@@ -115,9 +139,6 @@ func New(options Options) (*Box, error) {
if experimentalOptions.V2RayAPI != nil && experimentalOptions.V2RayAPI.Listen != "" {
needV2RayAPI = true
}
if experimentalOptions.UnifiedDelay != nil && experimentalOptions.UnifiedDelay.Enabled {
ctx = urltest.ContextWithIsUnifiedDelay(ctx)
}
platformInterface := service.FromContext[platform.Interface](ctx)
var defaultLogWriter io.Writer
if platformInterface != nil {
@@ -135,14 +156,34 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "create log factory")
}
var internalServices []adapter.LifecycleService
certificateOptions := common.PtrValueOrDefault(options.Certificate)
if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem ||
len(certificateOptions.Certificate) > 0 ||
len(certificateOptions.CertificatePath) > 0 ||
len(certificateOptions.CertificateDirectoryPath) > 0 {
certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), certificateOptions)
if err != nil {
return nil, err
}
service.MustRegister[adapter.CertificateStore](ctx, certificateStore)
internalServices = append(internalServices, certificateStore)
}
routeOptions := common.PtrValueOrDefault(options.Route)
dnsOptions := common.PtrValueOrDefault(options.DNS)
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final)
serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry)
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager)
service.MustRegister[adapter.ServiceManager](ctx, serviceManager)
dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions)
service.MustRegister[adapter.DNSRouter](ctx, dnsRouter)
networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions)
if err != nil {
return nil, E.Cause(err, "initialize network manager")
@@ -150,18 +191,40 @@ func New(options Options) (*Box, error) {
service.MustRegister[adapter.NetworkManager](ctx, networkManager)
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection"))
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS))
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
service.MustRegister[adapter.Router](ctx, router)
err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet)
if err != nil {
return nil, E.Cause(err, "initialize router")
}
ntpOptions := common.PtrValueOrDefault(options.NTP)
var timeService *tls.TimeServiceWrapper
if ntpOptions.Enabled {
timeService = new(tls.TimeServiceWrapper)
service.MustRegister[ntp.TimeService](ctx, timeService)
}
for i, transportOptions := range dnsOptions.Servers {
var tag string
if transportOptions.Tag != "" {
tag = transportOptions.Tag
} else {
tag = F.ToString(i)
}
err = dnsTransportManager.Create(
ctx,
logFactory.NewLogger(F.ToString("dns/", transportOptions.Type, "[", tag, "]")),
tag,
transportOptions.Type,
transportOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize DNS server[", i, "]")
}
}
err = dnsRouter.Initialize(dnsOptions.Rules)
if err != nil {
return nil, E.Cause(err, "initialize dns router")
}
for i, endpointOptions := range options.Endpoints {
var tag string
if endpointOptions.Tag != "" {
@@ -185,7 +248,7 @@ func New(options Options) (*Box, error) {
endpointOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]")
return nil, E.Cause(err, "initialize endpoint[", i, "]")
}
}
for i, inboundOptions := range options.Inbounds {
@@ -233,6 +296,24 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize outbound[", i, "]")
}
}
for i, serviceOptions := range options.Services {
var tag string
if serviceOptions.Tag != "" {
tag = serviceOptions.Tag
} else {
tag = F.ToString(i)
}
err = serviceManager.Create(
ctx,
logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")),
tag,
serviceOptions.Type,
serviceOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize service[", i, "]")
}
}
outboundManager.Initialize(common.Must1(
direct.NewOutbound(
ctx,
@@ -242,17 +323,23 @@ func New(options Options) (*Box, error) {
option.DirectOutboundOptions{},
),
))
dnsTransportManager.Initialize(common.Must1(
local.NewTransport(
ctx,
logFactory.NewLogger("dns/local"),
"local",
option.LocalDNSServerOptions{},
)))
if platformInterface != nil {
err = platformInterface.Initialize(networkManager)
if err != nil {
return nil, E.Cause(err, "initialize platform interface")
}
}
var services []adapter.LifecycleService
if needCacheFile {
cacheFile := cachefile.New(ctx, common.PtrValueOrDefault(experimentalOptions.CacheFile))
service.MustRegister[adapter.CacheFile](ctx, cacheFile)
services = append(services, cacheFile)
internalServices = append(internalServices, cacheFile)
}
if needClashAPI {
clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI)
@@ -263,7 +350,7 @@ func New(options Options) (*Box, error) {
}
router.AppendTracker(clashServer)
service.MustRegister[adapter.ClashServer](ctx, clashServer)
services = append(services, clashServer)
internalServices = append(internalServices, clashServer)
}
if needV2RayAPI {
v2rayServer, err := experimental.NewV2RayServer(logFactory.NewLogger("v2ray-api"), common.PtrValueOrDefault(experimentalOptions.V2RayAPI))
@@ -272,12 +359,12 @@ func New(options Options) (*Box, error) {
}
if v2rayServer.StatsService() != nil {
router.AppendTracker(v2rayServer.StatsService())
services = append(services, v2rayServer)
internalServices = append(internalServices, v2rayServer)
service.MustRegister[adapter.V2RayServer](ctx, v2rayServer)
}
}
if ntpOptions.Enabled {
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions, ntpOptions.ServerIsDomain())
if err != nil {
return nil, E.Cause(err, "create NTP service")
}
@@ -290,20 +377,23 @@ func New(options Options) (*Box, error) {
WriteToSystem: ntpOptions.WriteToSystem,
})
timeService.TimeService = ntpService
services = append(services, adapter.NewLifecycleService(ntpService, "ntp service"))
internalServices = append(internalServices, adapter.NewLifecycleService(ntpService, "ntp service"))
}
return &Box{
network: networkManager,
endpoint: endpointManager,
inbound: inboundManager,
outbound: outboundManager,
connection: connectionManager,
router: router,
createdAt: createdAt,
logFactory: logFactory,
logger: logFactory.Logger(),
services: services,
done: make(chan struct{}),
network: networkManager,
endpoint: endpointManager,
inbound: inboundManager,
outbound: outboundManager,
dnsTransport: dnsTransportManager,
service: serviceManager,
dnsRouter: dnsRouter,
connection: connectionManager,
router: router,
createdAt: createdAt,
logFactory: logFactory,
logger: logFactory.Logger(),
internalService: internalServices,
done: make(chan struct{}),
}, nil
}
@@ -353,15 +443,15 @@ func (s *Box) preStart() error {
if err != nil {
return E.Cause(err, "start logger")
}
err = adapter.StartNamed(adapter.StartStateInitialize, s.services) // cache-file clash-api v2ray-api
err = adapter.StartNamed(adapter.StartStateInitialize, s.internalService) // cache-file clash-api v2ray-api
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router)
err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router)
if err != nil {
return err
}
@@ -373,31 +463,27 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.StartNamed(adapter.StartStateStart, s.services)
err = adapter.StartNamed(adapter.StartStateStart, s.internalService)
if err != nil {
return err
}
err = s.inbound.Start(adapter.StartStateStart)
err = adapter.Start(adapter.StartStateStart, s.inbound, s.endpoint, s.service)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStart, s.endpoint)
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint, s.service)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound, s.endpoint)
err = adapter.StartNamed(adapter.StartStatePostStart, s.internalService)
if err != nil {
return err
}
err = adapter.StartNamed(adapter.StartStatePostStart, s.services)
err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
err = adapter.StartNamed(adapter.StartStateStarted, s.services)
err = adapter.StartNamed(adapter.StartStateStarted, s.internalService)
if err != nil {
return err
}
@@ -412,9 +498,9 @@ func (s *Box) Close() error {
close(s.done)
}
err := common.Close(
s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.network,
s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network,
)
for _, lifecycleService := range s.services {
for _, lifecycleService := range s.internalService {
err = E.Append(err, lifecycleService.Close(), func(err error) error {
return E.Cause(err, "close ", lifecycleService.Name())
})

View File

@@ -105,7 +105,7 @@ func publishTestflight(ctx context.Context) error {
return err
}
tag := tagVersion.VersionString()
client := createClient(10 * time.Minute)
client := createClient(20 * time.Minute)
log.Info(tag, " list build IDs")
buildIDsResponse, _, err := client.TestFlight.ListBuildIDsForBetaGroup(ctx, groupID, nil)
@@ -145,7 +145,7 @@ func publishTestflight(ctx context.Context) error {
return err
}
build := builds.Data[0]
if common.Contains(buildIDs, build.ID) || time.Since(build.Attributes.UploadedDate.Time) > 5*time.Minute {
if common.Contains(buildIDs, build.ID) || time.Since(build.Attributes.UploadedDate.Time) > 30*time.Minute {
log.Info(string(platform), " ", tag, " waiting for process")
time.Sleep(15 * time.Second)
continue

View File

@@ -45,6 +45,7 @@ var (
debugFlags []string
sharedTags []string
iosTags []string
memcTags []string
debugTags []string
)
@@ -58,8 +59,9 @@ func init() {
sharedFlags = append(sharedFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag+" -s -w -buildid=")
debugFlags = append(debugFlags, "-ldflags", "-X github.com/sagernet/sing-box/constant.Version="+currentTag)
sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_ech", "with_utls", "with_clash_api")
iosTags = append(iosTags, "with_dhcp", "with_low_memory", "with_conntrack")
sharedTags = append(sharedTags, "with_gvisor", "with_quic", "with_wireguard", "with_utls", "with_clash_api", "with_conntrack")
iosTags = append(iosTags, "with_dhcp", "with_low_memory")
memcTags = append(memcTags, "with_tailscale")
debugTags = append(debugTags, "debug")
}
@@ -99,18 +101,19 @@ func buildAndroid() {
"-javapkg=io.nekohasekai",
"-libname=box",
}
if !debugEnabled {
args = append(args, sharedFlags...)
} else {
args = append(args, debugFlags...)
}
args = append(args, "-tags")
if !debugEnabled {
args = append(args, strings.Join(sharedTags, ","))
} else {
args = append(args, strings.Join(append(sharedTags, debugTags...), ","))
tags := append(sharedTags, memcTags...)
if debugEnabled {
tags = append(tags, debugTags...)
}
args = append(args, "-tags", strings.Join(tags, ","))
args = append(args, "./experimental/libbox")
command := exec.Command(build_shared.GoBinPath+"/gomobile", args...)
@@ -148,7 +151,9 @@ func buildApple() {
"-v",
"-target", bindTarget,
"-libname=box",
"-tags-macos=" + strings.Join(memcTags, ","),
}
if !debugEnabled {
args = append(args, sharedFlags...)
} else {
@@ -156,12 +161,11 @@ func buildApple() {
}
tags := append(sharedTags, iosTags...)
args = append(args, "-tags")
if !debugEnabled {
args = append(args, strings.Join(tags, ","))
} else {
args = append(args, strings.Join(append(tags, debugTags...), ","))
if debugEnabled {
tags = append(tags, debugTags...)
}
args = append(args, "-tags", strings.Join(tags, ","))
args = append(args, "./experimental/libbox")
command := exec.Command(build_shared.GoBinPath+"/gomobile", args...)

View File

@@ -0,0 +1,71 @@
package main
import (
"encoding/csv"
"io"
"net/http"
"os"
"strings"
"github.com/sagernet/sing-box/log"
"golang.org/x/exp/slices"
)
func main() {
err := updateMozillaIncludedRootCAs()
if err != nil {
log.Error(err)
}
}
func updateMozillaIncludedRootCAs() error {
response, err := http.Get("https://ccadb.my.salesforce-sites.com/mozilla/IncludedCACertificateReportPEMCSV")
if err != nil {
return err
}
defer response.Body.Close()
reader := csv.NewReader(response.Body)
header, err := reader.Read()
if err != nil {
return err
}
geoIndex := slices.Index(header, "Geographic Focus")
nameIndex := slices.Index(header, "Common Name or Certificate Name")
certIndex := slices.Index(header, "PEM Info")
generated := strings.Builder{}
generated.WriteString(`// Code generated by 'make update_certificates'. DO NOT EDIT.
package certificate
import "crypto/x509"
var mozillaIncluded *x509.CertPool
func init() {
mozillaIncluded = x509.NewCertPool()
`)
for {
record, err := reader.Read()
if err == io.EOF {
break
} else if err != nil {
return err
}
if record[geoIndex] == "China" {
continue
}
generated.WriteString("\n // ")
generated.WriteString(record[nameIndex])
generated.WriteString("\n")
generated.WriteString(" mozillaIncluded.AppendCertsFromPEM([]byte(`")
cert := record[certIndex]
// Remove single quotes
cert = cert[1 : len(cert)-1]
generated.WriteString(cert)
generated.WriteString("`))\n")
}
generated.WriteString("}\n")
return os.WriteFile("common/certificate/mozilla.go", []byte(generated.String()), 0o644)
}

View File

@@ -7,7 +7,6 @@ import (
"strconv"
"time"
"github.com/sagernet/sing-box"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/include"
"github.com/sagernet/sing-box/log"
@@ -68,6 +67,5 @@ func preRun(cmd *cobra.Command, args []string) {
if len(configPaths) == 0 && len(configDirectories) == 0 {
configPaths = append(configPaths, "config.json")
}
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
globalCtx = include.Context(service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger())))
}

View File

@@ -9,8 +9,6 @@ import (
"github.com/spf13/cobra"
)
var pqSignatureSchemesEnabled bool
var commandGenerateECHKeyPair = &cobra.Command{
Use: "ech-keypair <plain_server_name>",
Short: "Generate TLS ECH key pair",
@@ -24,12 +22,11 @@ var commandGenerateECHKeyPair = &cobra.Command{
}
func init() {
commandGenerateECHKeyPair.Flags().BoolVar(&pqSignatureSchemesEnabled, "pq-signature-schemes-enabled", false, "Enable PQ signature schemes")
commandGenerate.AddCommand(commandGenerateECHKeyPair)
}
func generateECHKeyPair(serverName string) error {
configPem, keyPem, err := tls.ECHKeygenDefault(serverName, pqSignatureSchemesEnabled)
configPem, keyPem, err := tls.ECHKeygenDefault(serverName)
if err != nil {
return err
}

View File

@@ -1,31 +0,0 @@
//go:build go1.21 && !without_badtls && with_ech
package badtls
import (
"net"
_ "unsafe"
"github.com/sagernet/cloudflare-tls"
"github.com/sagernet/sing/common"
)
func init() {
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
tlsConn, loaded := common.Cast[*tls.Conn](conn)
if !loaded {
return
}
return true, func() error {
return echReadRecord(tlsConn)
}, func() error {
return echHandlePostHandshakeMessage(tlsConn)
}
})
}
//go:linkname echReadRecord github.com/sagernet/cloudflare-tls.(*Conn).readRecord
func echReadRecord(c *tls.Conn) error
//go:linkname echHandlePostHandshakeMessage github.com/sagernet/cloudflare-tls.(*Conn).handlePostHandshakeMessage
func echHandlePostHandshakeMessage(c *tls.Conn) error

View File

@@ -7,7 +7,8 @@ import (
_ "unsafe"
"github.com/sagernet/sing/common"
"github.com/sagernet/utls"
"github.com/metacubex/utls"
)
func init() {
@@ -24,8 +25,8 @@ func init() {
})
}
//go:linkname utlsReadRecord github.com/sagernet/utls.(*Conn).readRecord
//go:linkname utlsReadRecord github.com/metacubex/utls.(*Conn).readRecord
func utlsReadRecord(c *tls.Conn) error
//go:linkname utlsHandlePostHandshakeMessage github.com/sagernet/utls.(*Conn).handlePostHandshakeMessage
//go:linkname utlsHandlePostHandshakeMessage github.com/metacubex/utls.(*Conn).handlePostHandshakeMessage
func utlsHandlePostHandshakeMessage(c *tls.Conn) error

File diff suppressed because it is too large Load Diff

185
common/certificate/store.go Normal file
View File

@@ -0,0 +1,185 @@
package certificate
import (
"context"
"crypto/x509"
"io/fs"
"os"
"path/filepath"
"strings"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
)
var _ adapter.CertificateStore = (*Store)(nil)
type Store struct {
systemPool *x509.CertPool
currentPool *x509.CertPool
certificate string
certificatePaths []string
certificateDirectoryPaths []string
watcher *fswatch.Watcher
}
func NewStore(ctx context.Context, logger logger.Logger, options option.CertificateOptions) (*Store, error) {
var systemPool *x509.CertPool
switch options.Store {
case C.CertificateStoreSystem, "":
systemPool = x509.NewCertPool()
platformInterface := service.FromContext[platform.Interface](ctx)
var systemValid bool
if platformInterface != nil {
for _, cert := range platformInterface.SystemCertificates() {
if systemPool.AppendCertsFromPEM([]byte(cert)) {
systemValid = true
}
}
}
if !systemValid {
certPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
systemPool = certPool
}
case C.CertificateStoreMozilla:
systemPool = mozillaIncluded
case C.CertificateStoreNone:
systemPool = nil
default:
return nil, E.New("unknown certificate store: ", options.Store)
}
store := &Store{
systemPool: systemPool,
certificate: strings.Join(options.Certificate, "\n"),
certificatePaths: options.CertificatePath,
certificateDirectoryPaths: options.CertificateDirectoryPath,
}
var watchPaths []string
for _, target := range options.CertificatePath {
watchPaths = append(watchPaths, target)
}
for _, target := range options.CertificateDirectoryPath {
watchPaths = append(watchPaths, target)
}
if len(watchPaths) > 0 {
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPaths,
Logger: logger,
Callback: func(_ string) {
err := store.update()
if err != nil {
logger.Error(E.Cause(err, "reload certificates"))
}
},
})
if err != nil {
return nil, E.Cause(err, "fswatch: create fsnotify watcher")
}
store.watcher = watcher
}
err := store.update()
if err != nil {
return nil, E.Cause(err, "initializing certificate store")
}
return store, nil
}
func (s *Store) Name() string {
return "certificate"
}
func (s *Store) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if s.watcher != nil {
return s.watcher.Start()
}
return nil
}
func (s *Store) Close() error {
if s.watcher != nil {
return s.watcher.Close()
}
return nil
}
func (s *Store) Pool() *x509.CertPool {
return s.currentPool
}
func (s *Store) update() error {
var currentPool *x509.CertPool
if s.systemPool == nil {
currentPool = x509.NewCertPool()
} else {
currentPool = s.systemPool.Clone()
}
if s.certificate != "" {
if !currentPool.AppendCertsFromPEM([]byte(s.certificate)) {
return E.New("invalid certificate PEM strings")
}
}
for _, path := range s.certificatePaths {
pemContent, err := os.ReadFile(path)
if err != nil {
return err
}
if !currentPool.AppendCertsFromPEM(pemContent) {
return E.New("invalid certificate PEM file: ", path)
}
}
var firstErr error
for _, directoryPath := range s.certificateDirectoryPaths {
directoryEntries, err := readUniqueDirectoryEntries(directoryPath)
if err != nil {
if firstErr == nil && !os.IsNotExist(err) {
firstErr = E.Cause(err, "invalid certificate directory: ", directoryPath)
}
continue
}
for _, directoryEntry := range directoryEntries {
pemContent, err := os.ReadFile(filepath.Join(directoryPath, directoryEntry.Name()))
if err == nil {
currentPool.AppendCertsFromPEM(pemContent)
}
}
}
if firstErr != nil {
return firstErr
}
s.currentPool = currentPool
return nil
}
func readUniqueDirectoryEntries(dir string) ([]fs.DirEntry, error) {
files, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
uniq := files[:0]
for _, f := range files {
if !isSameDirSymlink(f, dir) {
uniq = append(uniq, f)
}
}
return uniq, nil
}
func isSameDirSymlink(f fs.DirEntry, dir string) bool {
if f.Type()&fs.ModeSymlink == 0 {
return false
}
target, err := os.Readlink(filepath.Join(dir, f.Name()))
return err == nil && !strings.Contains(target, "/")
}

View File

@@ -1,74 +0,0 @@
package cloudflare
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/tidwall/gjson"
)
type CloudflareApi struct {
client http.Client
}
func NewCloudflareApi(opts ...CloudflareApiOption) *CloudflareApi {
api := &CloudflareApi{http.Client{Timeout: 30 * time.Second}}
for _, opt := range opts {
opt(api)
}
return api
}
func (api *CloudflareApi) CreateProfile(ctx context.Context, publicKey string) (*CloudflareProfile, error) {
request, err := http.NewRequest("POST", "https://api.cloudflareclient.com/v0i1909051800/reg", strings.NewReader(
fmt.Sprintf(
"{\"install_id\":\"\",\"tos\":\"%s\",\"key\":\"%s\",\"fcm_token\":\"\",\"type\":\"ios\",\"locale\":\"en_US\"}",
time.Now().Format("2006-01-02T15:04:05.000Z"),
publicKey,
),
))
if err != nil {
return nil, err
}
response, err := api.client.Do(request.WithContext(ctx))
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != 200 {
return nil, fmt.Errorf("status code is not 200")
}
content, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
profile := new(CloudflareProfile)
return profile, json.NewDecoder(strings.NewReader(gjson.Get(string(content), "result").Raw)).Decode(profile)
}
func (api *CloudflareApi) GetProfile(ctx context.Context, authToken string, id string) (*CloudflareProfile, error) {
request, err := http.NewRequest("GET", "https://api.cloudflareclient.com/v0i1909051800/reg/"+id, nil)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", "Bearer "+authToken)
response, err := api.client.Do(request.WithContext(ctx))
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != 200 {
return nil, fmt.Errorf("status code is not 200")
}
content, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
profile := new(CloudflareProfile)
return profile, json.NewDecoder(strings.NewReader(gjson.Get(string(content), "result").Raw)).Decode(profile)
}

View File

@@ -1,17 +0,0 @@
package cloudflare
import (
"context"
"net"
"net/http"
)
type CloudflareApiOption func(api *CloudflareApi)
func WithDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) CloudflareApiOption {
return func(api *CloudflareApi) {
api.client.Transport = &http.Transport{
DialContext: dialContext,
}
}
}

View File

@@ -1,64 +0,0 @@
package cloudflare
import "time"
type CloudflareProfile struct {
ID string `json:"id"`
Type string `json:"type"`
Name string `json:"name"`
Key string `json:"key"`
Account struct {
ID string `json:"id"`
AccountType string `json:"account_type"`
Created time.Time `json:"created"`
Updated time.Time `json:"updated"`
PremiumData int `json:"premium_data"`
Quota int `json:"quota"`
Usage int `json:"usage"`
WARPPlus bool `json:"warp_plus"`
ReferralCount int `json:"referral_count"`
ReferralRenewalCountdown int `json:"referral_renewal_countdown"`
Role string `json:"role"`
License string `json:"license"`
TTL time.Time `json:"ttl"`
} `json:"account"`
Config struct {
ClientID string `json:"client_id"`
Interface struct {
Addresses struct {
V4 string `json:"v4"`
V6 string `json:"v6"`
} `json:"addresses"`
} `json:"interface"`
Peers []struct {
PublicKey string `json:"public_key"`
Endpoint struct {
V4 string `json:"v4"`
V6 string `json:"v6"`
Host string `json:"host"`
Ports []int `json:"ports"`
} `json:"endpoint"`
} `json:"peers"`
Services struct {
HTTPProxy string `json:"http_proxy"`
} `json:"services"`
Metrics struct {
Ping int `json:"ping"`
Report int `json:"report"`
} `json:"metrics"`
} `json:"config"`
Token string `json:"token"`
WARPEnabled bool `json:"warp_enabled"`
WaitlistEnabled bool `json:"waitlist_enabled"`
Created time.Time `json:"created"`
Updated time.Time `json:"updated"`
Tos time.Time `json:"tos"`
Place int `json:"place"`
Locale string `json:"locale"`
Enabled bool `json:"enabled"`
InstallID string `json:"install_id"`
FcmToken string `json:"fcm_token"`
Policy struct {
TunnelProtocol string `json:"tunnel_protocol"`
} `json:"policy"`
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/listener"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/option"
@@ -35,7 +36,7 @@ type DefaultDialer struct {
udpListener net.ListenConfig
udpAddr4 string
udpAddr6 string
isWireGuardListener bool
netns string
networkManager adapter.NetworkManager
networkStrategy *C.NetworkStrategy
defaultNetworkStrategy bool
@@ -65,23 +66,19 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
interfaceFinder = control.NewDefaultInterfaceFinder()
}
if options.BindInterface != "" {
if !(C.IsLinux || C.IsDarwin || C.IsWindows) {
return nil, E.New("`bind_interface` is only supported on Linux, macOS and Windows")
}
bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
dialer.Control = control.Append(dialer.Control, bindFunc)
listener.Control = control.Append(listener.Control, bindFunc)
}
if options.RoutingMark > 0 {
dialer.Control = control.Append(dialer.Control, control.RoutingMark(uint32(options.RoutingMark)))
listener.Control = control.Append(listener.Control, control.RoutingMark(uint32(options.RoutingMark)))
}
if networkManager != nil {
autoRedirectOutputMark := networkManager.AutoRedirectOutputMark()
if autoRedirectOutputMark > 0 {
if options.RoutingMark > 0 {
return nil, E.New("`routing_mark` is conflict with `tun.auto_redirect` with `tun.route_[_exclude]_address_set")
}
dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark))
listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark))
if !C.IsLinux {
return nil, E.New("`routing_mark` is only supported on Linux")
}
dialer.Control = control.Append(dialer.Control, setMarkWrapper(networkManager, uint32(options.RoutingMark), false))
listener.Control = control.Append(listener.Control, setMarkWrapper(networkManager, uint32(options.RoutingMark), false))
}
disableDefaultBind := options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil
if disableDefaultBind || options.TCPFastOpen {
@@ -126,8 +123,8 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
}
}
if options.RoutingMark == 0 && defaultOptions.RoutingMark != 0 {
dialer.Control = control.Append(dialer.Control, control.RoutingMark(defaultOptions.RoutingMark))
listener.Control = control.Append(listener.Control, control.RoutingMark(defaultOptions.RoutingMark))
dialer.Control = control.Append(dialer.Control, setMarkWrapper(networkManager, defaultOptions.RoutingMark, true))
listener.Control = control.Append(listener.Control, setMarkWrapper(networkManager, defaultOptions.RoutingMark, true))
}
}
if options.ReuseAddr {
@@ -183,11 +180,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
}
setMultiPathTCP(&dialer4)
}
if options.IsWireGuardListener {
for _, controlFn := range WgControlFns {
listener.Control = control.Append(listener.Control, controlFn)
}
}
tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
if err != nil {
return nil, err
@@ -204,7 +196,7 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
udpListener: listener,
udpAddr4: udpAddr4,
udpAddr6: udpAddr6,
isWireGuardListener: options.IsWireGuardListener,
netns: options.NetNs,
networkManager: networkManager,
networkStrategy: networkStrategy,
defaultNetworkStrategy: defaultNetworkStrategy,
@@ -214,24 +206,44 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
}, nil
}
func setMarkWrapper(networkManager adapter.NetworkManager, mark uint32, isDefault bool) control.Func {
if networkManager == nil {
return control.RoutingMark(mark)
}
return func(network, address string, conn syscall.RawConn) error {
if networkManager.AutoRedirectOutputMark() != 0 {
if isDefault {
return E.New("`route.default_mark` is conflict with `tun.auto_redirect`")
} else {
return E.New("`routing_mark` is conflict with `tun.auto_redirect`")
}
}
return control.RoutingMark(mark)(network, address, conn)
}
}
func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
if !address.IsValid() {
return nil, E.New("invalid address")
} else if address.IsFqdn() {
return nil, E.New("domain not resolved")
}
if d.networkStrategy == nil {
switch N.NetworkName(network) {
case N.NetworkUDP:
if !address.IsIPv6() {
return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
} else {
return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
return trackConn(listener.ListenNetworkNamespace[net.Conn](d.netns, func() (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkUDP:
if !address.IsIPv6() {
return d.udpDialer4.DialContext(ctx, network, address.String())
} else {
return d.udpDialer6.DialContext(ctx, network, address.String())
}
}
}
if !address.IsIPv6() {
return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
} else {
return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
}
if !address.IsIPv6() {
return DialSlowContext(&d.dialer4, ctx, network, address)
} else {
return DialSlowContext(&d.dialer6, ctx, network, address)
}
}))
} else {
return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
}
@@ -287,13 +299,15 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin
func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if d.networkStrategy == nil {
if destination.IsIPv6() {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
} else {
return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
}
return trackPacketConn(listener.ListenNetworkNamespace[net.PacketConn](d.netns, func() (net.PacketConn, error) {
if destination.IsIPv6() {
return d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6)
} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
return d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4)
} else {
return d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4)
}
}))
} else {
return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
}

View File

@@ -6,39 +6,63 @@ import (
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type DirectDialer interface {
IsEmpty() bool
}
type DetourDialer struct {
outboundManager adapter.OutboundManager
detour string
legacyDNSDialer bool
dialer N.Dialer
initOnce sync.Once
initErr error
}
func NewDetour(outboundManager adapter.OutboundManager, detour string) N.Dialer {
return &DetourDialer{outboundManager: outboundManager, detour: detour}
func NewDetour(outboundManager adapter.OutboundManager, detour string, legacyDNSDialer bool) N.Dialer {
return &DetourDialer{
outboundManager: outboundManager,
detour: detour,
legacyDNSDialer: legacyDNSDialer,
}
}
func (d *DetourDialer) Start() error {
_, err := d.Dialer()
return err
func InitializeDetour(dialer N.Dialer) error {
detourDialer, isDetour := common.Cast[*DetourDialer](dialer)
if !isDetour {
return nil
}
return common.Error(detourDialer.Dialer())
}
func (d *DetourDialer) Dialer() (N.Dialer, error) {
d.initOnce.Do(func() {
var loaded bool
d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour)
}
})
d.initOnce.Do(d.init)
return d.dialer, d.initErr
}
func (d *DetourDialer) init() {
dialer, loaded := d.outboundManager.Outbound(d.detour)
if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour)
return
}
if !d.legacyDNSDialer {
if directDialer, isDirect := dialer.(DirectDialer); isDirect {
if directDialer.IsEmpty() {
d.initErr = E.New("detour to an empty direct outbound makes no sense")
return
}
}
}
d.dialer = dialer
}
func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
dialer, err := d.Dialer()
if err != nil {

View File

@@ -8,68 +8,133 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
if options.IsWireGuardListener {
return NewDefault(ctx, options)
}
type Options struct {
Context context.Context
Options option.DialerOptions
RemoteIsDomain bool
DirectResolver bool
ResolverOnDetour bool
NewDialer bool
LegacyDNSDialer bool
DirectOutbound bool
}
// TODO: merge with NewWithOptions
func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
return NewWithOptions(Options{
Context: ctx,
Options: options,
RemoteIsDomain: remoteIsDomain,
})
}
func NewWithOptions(options Options) (N.Dialer, error) {
dialOptions := options.Options
var (
dialer N.Dialer
err error
)
if options.Detour == "" {
dialer, err = NewDefault(ctx, options)
if err != nil {
return nil, err
}
} else {
outboundManager := service.FromContext[adapter.OutboundManager](ctx)
if dialOptions.Detour != "" {
outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
if outboundManager == nil {
return nil, E.New("missing outbound manager")
}
dialer = NewDetour(outboundManager, options.Detour)
}
if options.Detour == "" {
router := service.FromContext[adapter.Router](ctx)
if router != nil {
dialer = NewResolveDialer(
router,
dialer,
options.Detour == "" && !options.TCPFastOpen,
dns.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay))
dialer = NewDetour(outboundManager, dialOptions.Detour, options.LegacyDNSDialer)
} else {
dialer, err = NewDefault(options.Context, dialOptions)
if err != nil {
return nil, err
}
}
if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour || dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "") {
networkManager := service.FromContext[adapter.NetworkManager](options.Context)
dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
var defaultOptions adapter.NetworkOptions
if networkManager != nil {
defaultOptions = networkManager.DefaultOptions()
}
var (
server string
dnsQueryOptions adapter.DNSQueryOptions
resolveFallbackDelay time.Duration
)
if dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "" {
var transport adapter.DNSTransport
if !options.DirectResolver {
var loaded bool
transport, loaded = dnsTransport.Transport(dialOptions.DomainResolver.Server)
if !loaded {
return nil, E.New("domain resolver not found: " + dialOptions.DomainResolver.Server)
}
}
var strategy C.DomainStrategy
if dialOptions.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
strategy = C.DomainStrategy(dialOptions.DomainResolver.Strategy)
} else if
//nolint:staticcheck
dialOptions.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
//nolint:staticcheck
strategy = C.DomainStrategy(dialOptions.DomainStrategy)
deprecated.Report(options.Context, deprecated.OptionLegacyDomainStrategyOptions)
}
server = dialOptions.DomainResolver.Server
dnsQueryOptions = adapter.DNSQueryOptions{
Transport: transport,
Strategy: strategy,
DisableCache: dialOptions.DomainResolver.DisableCache,
RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
}
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else if options.DirectResolver {
return nil, E.New("missing domain resolver for domain server address")
} else {
if defaultOptions.DomainResolver != "" {
dnsQueryOptions = defaultOptions.DomainResolveOptions
transport, loaded := dnsTransport.Transport(defaultOptions.DomainResolver)
if !loaded {
return nil, E.New("default domain resolver not found: " + defaultOptions.DomainResolver)
}
dnsQueryOptions.Transport = transport
resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else {
transports := dnsTransport.Transports()
if len(transports) < 2 {
dnsQueryOptions.Transport = dnsTransport.Default()
} else if options.NewDialer {
return nil, E.New("missing domain resolver for domain server address")
} else if !options.DirectOutbound {
deprecated.Report(options.Context, deprecated.OptionMissingDomainResolver)
}
}
if
//nolint:staticcheck
dialOptions.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
//nolint:staticcheck
dnsQueryOptions.Strategy = C.DomainStrategy(dialOptions.DomainStrategy)
deprecated.Report(options.Context, deprecated.OptionLegacyDomainStrategyOptions)
}
}
dialer = NewResolveDialer(
options.Context,
dialer,
dialOptions.Detour == "" && !dialOptions.TCPFastOpen,
server,
dnsQueryOptions,
resolveFallbackDelay,
)
}
return dialer, nil
}
func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInterfaceDialer, error) {
if options.Detour != "" {
return nil, E.New("`detour` is not supported in direct context")
}
if options.IsWireGuardListener {
return NewDefault(ctx, options)
}
dialer, err := NewDefault(ctx, options)
if err != nil {
return nil, err
}
return NewResolveParallelInterfaceDialer(
service.FromContext[adapter.Router](ctx),
dialer,
true,
dns.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay),
), nil
}
type ParallelInterfaceDialer interface {
N.Dialer
DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error)

View File

@@ -3,16 +3,17 @@ package dialer
import (
"context"
"net"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
var (
@@ -20,21 +21,51 @@ var (
_ ParallelInterfaceDialer = (*resolveParallelNetworkDialer)(nil)
)
type ResolveDialer interface {
N.Dialer
QueryOptions() adapter.DNSQueryOptions
}
type ParallelInterfaceResolveDialer interface {
ParallelInterfaceDialer
QueryOptions() adapter.DNSQueryOptions
}
type resolveDialer struct {
transport adapter.DNSTransportManager
router adapter.DNSRouter
dialer N.Dialer
parallel bool
router adapter.Router
strategy dns.DomainStrategy
server string
initOnce sync.Once
initErr error
queryOptions adapter.DNSQueryOptions
fallbackDelay time.Duration
}
func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) N.Dialer {
func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer {
if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel {
return &resolveParallelNetworkDialer{
resolveDialer{
transport: service.FromContext[adapter.DNSTransportManager](ctx),
router: service.FromContext[adapter.DNSRouter](ctx),
dialer: dialer,
parallel: parallel,
server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
},
parallelDialer,
}
}
return &resolveDialer{
dialer,
parallel,
router,
strategy,
fallbackDelay,
transport: service.FromContext[adapter.DNSTransportManager](ctx),
router: service.FromContext[adapter.DNSRouter](ctx),
dialer: dialer,
parallel: parallel,
server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
}
}
@@ -43,59 +74,53 @@ type resolveParallelNetworkDialer struct {
dialer ParallelInterfaceDialer
}
func NewResolveParallelInterfaceDialer(router adapter.Router, dialer ParallelInterfaceDialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer {
return &resolveParallelNetworkDialer{
resolveDialer{
dialer,
parallel,
router,
strategy,
fallbackDelay,
},
dialer,
func (d *resolveDialer) initialize() error {
d.initOnce.Do(d.initServer)
return d.initErr
}
func (d *resolveDialer) initServer() {
if d.server == "" {
return
}
transport, loaded := d.transport.Transport(d.server)
if !loaded {
d.initErr = E.New("domain resolver not found: " + d.server)
return
}
d.queryOptions.Transport = transport
}
func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
if err != nil {
return nil, err
}
if d.parallel {
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, d.fallbackDelay)
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.queryOptions.Strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
} else {
return N.DialSerial(ctx, d.dialer, network, destination, addresses)
}
}
func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
if err != nil {
return nil, err
}
@@ -106,21 +131,24 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
}
func (d *resolveDialer) QueryOptions() adapter.DNSQueryOptions {
return d.queryOptions
}
func (d *resolveDialer) Upstream() any {
return d.dialer
}
func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
if err != nil {
return nil, err
}
@@ -128,30 +156,28 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
fallbackDelay = d.fallbackDelay
}
if d.parallel {
return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.queryOptions.Strategy == C.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
} else {
return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
}
}
func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
metadata.Destination = destination
metadata.Domain = ""
var addresses []netip.Addr
var err error
if d.strategy == dns.DomainStrategyAsIS {
addresses, err = d.router.LookupDefault(ctx, destination.Fqdn)
} else {
addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy)
}
addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
if err != nil {
return nil, err
}
if fallbackDelay == 0 {
fallbackDelay = d.fallbackDelay
}
conn, destinationAddress, err := ListenSerialNetworkPacket(ctx, d.dialer, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
if err != nil {
return nil, err
@@ -159,6 +185,10 @@ func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.C
return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
}
func (d *resolveDialer) Upstream() any {
func (d *resolveParallelNetworkDialer) QueryOptions() adapter.DNSQueryOptions {
return d.queryOptions
}
func (d *resolveParallelNetworkDialer) Upstream() any {
return d.dialer
}

View File

@@ -7,24 +7,27 @@ import (
"github.com/sagernet/sing-box/adapter"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
type DefaultOutboundDialer struct {
outboundManager adapter.OutboundManager
outbound adapter.OutboundManager
}
func NewDefaultOutbound(outboundManager adapter.OutboundManager) N.Dialer {
return &DefaultOutboundDialer{outboundManager: outboundManager}
func NewDefaultOutbound(ctx context.Context) N.Dialer {
return &DefaultOutboundDialer{
outbound: service.FromContext[adapter.OutboundManager](ctx),
}
}
func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return d.outboundManager.Default().DialContext(ctx, network, destination)
return d.outbound.Default().DialContext(ctx, network, destination)
}
func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return d.outboundManager.Default().ListenPacket(ctx, destination)
return d.outbound.Default().ListenPacket(ctx, destination)
}
func (d *DefaultOutboundDialer) Upstream() any {
return d.outboundManager.Default()
return d.outbound.Default()
}

View File

@@ -10,6 +10,7 @@ import (
"sync"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@@ -24,9 +25,7 @@ type slowOpenConn struct {
destination M.Socksaddr
conn net.Conn
create chan struct{}
done chan struct{}
access sync.Mutex
closeOnce sync.Once
err error
}
@@ -45,7 +44,6 @@ func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, des
network: network,
destination: destination,
create: make(chan struct{}),
done: make(chan struct{}),
}, nil
}
@@ -56,8 +54,8 @@ func (c *slowOpenConn) Read(b []byte) (n int, err error) {
if c.err != nil {
return 0, c.err
}
case <-c.done:
return 0, os.ErrClosed
case <-c.ctx.Done():
return 0, c.ctx.Err()
}
}
return c.conn.Read(b)
@@ -75,8 +73,6 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
return 0, c.err
}
return c.conn.Write(b)
case <-c.done:
return 0, os.ErrClosed
default:
}
conn, err := c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b)
@@ -91,13 +87,7 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) {
}
func (c *slowOpenConn) Close() error {
c.closeOnce.Do(func() {
close(c.done)
if c.conn != nil {
c.conn.Close()
}
})
return nil
return common.Close(c.conn)
}
func (c *slowOpenConn) LocalAddr() net.Addr {
@@ -162,8 +152,8 @@ func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
if c.err != nil {
return 0, c.err
}
case <-c.done:
return 0, c.err
case <-c.ctx.Done():
return 0, c.ctx.Err()
}
}
return bufio.Copy(w, c.conn)

View File

@@ -3,7 +3,6 @@ package interrupt
import (
"net"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)
@@ -74,32 +73,3 @@ func (c *PacketConn) WriterReplaceable() bool {
func (c *PacketConn) Upstream() any {
return c.PacketConn
}
type SingPacketConn struct {
N.PacketConn
group *Group
element *list.Element[*groupConnItem]
}
/*func (c *SingPacketConn) MarkAsInternal() {
c.element.Value.internal = true
}*/
func (c *SingPacketConn) Close() error {
c.group.access.Lock()
defer c.group.access.Unlock()
c.group.connections.Remove(c.element)
return c.PacketConn.Close()
}
func (c *SingPacketConn) ReaderReplaceable() bool {
return true
}
func (c *SingPacketConn) WriterReplaceable() bool {
return true
}
func (c *SingPacketConn) Upstream() any {
return c.PacketConn
}

View File

@@ -5,7 +5,6 @@ import (
"net"
"sync"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)
@@ -37,13 +36,6 @@ func (g *Group) NewPacketConn(conn net.PacketConn, isExternal bool) net.PacketCo
return &PacketConn{PacketConn: conn, group: g, element: item}
}
func (g *Group) NewSingPacketConn(conn N.PacketConn, isExternal bool) N.PacketConn {
g.access.Lock()
defer g.access.Unlock()
item := g.connections.PushBack(&groupConnItem{conn, isExternal})
return &SingPacketConn{PacketConn: conn, group: g, element: item}
}
func (g *Group) Interrupt(interruptExternalConnections bool) {
g.access.Lock()
defer g.access.Unlock()

View File

@@ -4,6 +4,8 @@ import (
"context"
"net"
"net/netip"
"runtime"
"strings"
"sync/atomic"
"github.com/sagernet/sing-box/adapter"
@@ -14,6 +16,8 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/vishvananda/netns"
)
type Listener struct {
@@ -28,6 +32,7 @@ type Listener struct {
disablePacketOutput bool
setSystemProxy bool
systemProxySOCKS bool
tproxy bool
tcpListener net.Listener
systemProxy settings.SystemProxy
@@ -50,6 +55,7 @@ type Options struct {
DisablePacketOutput bool
SetSystemProxy bool
SystemProxySOCKS bool
TProxy bool
}
func New(
@@ -67,6 +73,7 @@ func New(
disablePacketOutput: options.DisablePacketOutput,
setSystemProxy: options.SetSystemProxy,
systemProxySOCKS: options.SystemProxySOCKS,
tproxy: options.TProxy,
}
}
@@ -135,3 +142,30 @@ func (l *Listener) UDPConn() *net.UDPConn {
func (l *Listener) ListenOptions() option.ListenOptions {
return l.listenOptions
}
func ListenNetworkNamespace[T any](nameOrPath string, block func() (T, error)) (T, error) {
if nameOrPath != "" {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
currentNs, err := netns.Get()
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "get current netns")
}
defer netns.Set(currentNs)
var targetNs netns.NsHandle
if strings.HasPrefix(nameOrPath, "/") {
targetNs, err = netns.GetFromPath(nameOrPath)
} else {
targetNs, err = netns.GetFromName(nameOrPath)
}
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "get netns ", nameOrPath)
}
defer targetNs.Close()
err = netns.Set(targetNs)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "set netns to ", nameOrPath)
}
}
return block()
}

View File

@@ -3,23 +3,39 @@ package listener
import (
"net"
"net/netip"
"syscall"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/redir"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/metacubex/tfo-go"
)
func (l *Listener) ListenTCP() (net.Listener, error) {
//nolint:staticcheck
if l.listenOptions.ProxyProtocol || l.listenOptions.ProxyProtocolAcceptNoHeader {
return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0")
}
var err error
bindAddr := M.SocksaddrFrom(l.listenOptions.Listen.Build(netip.AddrFrom4([4]byte{127, 0, 0, 1})), l.listenOptions.ListenPort)
var tcpListener net.Listener
var listenConfig net.ListenConfig
if l.listenOptions.BindInterface != "" {
listenConfig.Control = control.Append(listenConfig.Control, control.BindToInterface(service.FromContext[adapter.NetworkManager](l.ctx).InterfaceFinder(), l.listenOptions.BindInterface, -1))
}
if l.listenOptions.RoutingMark != 0 {
listenConfig.Control = control.Append(listenConfig.Control, control.RoutingMark(uint32(l.listenOptions.RoutingMark)))
}
if l.listenOptions.ReuseAddr {
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
}
if l.listenOptions.TCPKeepAlive >= 0 {
keepIdle := time.Duration(l.listenOptions.TCPKeepAlive)
if keepIdle == 0 {
@@ -37,20 +53,26 @@ func (l *Listener) ListenTCP() (net.Listener, error) {
}
setMultiPathTCP(&listenConfig)
}
if l.listenOptions.TCPFastOpen {
var tfoConfig tfo.ListenConfig
tfoConfig.ListenConfig = listenConfig
tcpListener, err = tfoConfig.Listen(l.ctx, M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.String())
} else {
tcpListener, err = listenConfig.Listen(l.ctx, M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.String())
if l.tproxy {
listenConfig.Control = control.Append(listenConfig.Control, func(network, address string, conn syscall.RawConn) error {
return control.Raw(conn, func(fd uintptr) error {
return redir.TProxy(fd, M.ParseSocksaddr(address).IsIPv6(), false)
})
})
}
if err == nil {
l.logger.Info("tcp server started at ", tcpListener.Addr())
}
//nolint:staticcheck
if l.listenOptions.ProxyProtocol || l.listenOptions.ProxyProtocolAcceptNoHeader {
return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0")
tcpListener, err := ListenNetworkNamespace[net.Listener](l.listenOptions.NetNs, func() (net.Listener, error) {
if l.listenOptions.TCPFastOpen {
var tfoConfig tfo.ListenConfig
tfoConfig.ListenConfig = listenConfig
return tfoConfig.Listen(l.ctx, M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.String())
} else {
return listenConfig.Listen(l.ctx, M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.String())
}
})
if err != nil {
return nil, err
}
l.logger.Info("tcp server started at ", tcpListener.Addr())
l.tcpListener = tcpListener
return tcpListener, err
}

View File

@@ -1,20 +1,34 @@
package listener
import (
"context"
"net"
"net/netip"
"os"
"syscall"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/redir"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func (l *Listener) ListenUDP() (net.PacketConn, error) {
bindAddr := M.SocksaddrFrom(l.listenOptions.Listen.Build(netip.AddrFrom4([4]byte{127, 0, 0, 1})), l.listenOptions.ListenPort)
var lc net.ListenConfig
var listenConfig net.ListenConfig
if l.listenOptions.BindInterface != "" {
listenConfig.Control = control.Append(listenConfig.Control, control.BindToInterface(service.FromContext[adapter.NetworkManager](l.ctx).InterfaceFinder(), l.listenOptions.BindInterface, -1))
}
if l.listenOptions.RoutingMark != 0 {
listenConfig.Control = control.Append(listenConfig.Control, control.RoutingMark(uint32(l.listenOptions.RoutingMark)))
}
if l.listenOptions.ReuseAddr {
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
}
var udpFragment bool
if l.listenOptions.UDPFragment != nil {
udpFragment = *l.listenOptions.UDPFragment
@@ -22,9 +36,18 @@ func (l *Listener) ListenUDP() (net.PacketConn, error) {
udpFragment = l.listenOptions.UDPFragmentDefault
}
if !udpFragment {
lc.Control = control.Append(lc.Control, control.DisableUDPFragment())
listenConfig.Control = control.Append(listenConfig.Control, control.DisableUDPFragment())
}
udpConn, err := lc.ListenPacket(l.ctx, M.NetworkFromNetAddr(N.NetworkUDP, bindAddr.Addr), bindAddr.String())
if l.tproxy {
listenConfig.Control = control.Append(listenConfig.Control, func(network, address string, conn syscall.RawConn) error {
return control.Raw(conn, func(fd uintptr) error {
return redir.TProxy(fd, M.ParseSocksaddr(address).IsIPv6(), true)
})
})
}
udpConn, err := ListenNetworkNamespace[net.PacketConn](l.listenOptions.NetNs, func() (net.PacketConn, error) {
return listenConfig.ListenPacket(l.ctx, M.NetworkFromNetAddr(N.NetworkUDP, bindAddr.Addr), bindAddr.String())
})
if err != nil {
return nil, err
}
@@ -34,6 +57,36 @@ func (l *Listener) ListenUDP() (net.PacketConn, error) {
return udpConn, err
}
func (l *Listener) DialContext(dialer net.Dialer, ctx context.Context, network string, address string) (net.Conn, error) {
return ListenNetworkNamespace[net.Conn](l.listenOptions.NetNs, func() (net.Conn, error) {
if l.listenOptions.BindInterface != "" {
dialer.Control = control.Append(dialer.Control, control.BindToInterface(service.FromContext[adapter.NetworkManager](l.ctx).InterfaceFinder(), l.listenOptions.BindInterface, -1))
}
if l.listenOptions.RoutingMark != 0 {
dialer.Control = control.Append(dialer.Control, control.RoutingMark(uint32(l.listenOptions.RoutingMark)))
}
if l.listenOptions.ReuseAddr {
dialer.Control = control.Append(dialer.Control, control.ReuseAddr())
}
return dialer.DialContext(ctx, network, address)
})
}
func (l *Listener) ListenPacket(listenConfig net.ListenConfig, ctx context.Context, network string, address string) (net.PacketConn, error) {
return ListenNetworkNamespace[net.PacketConn](l.listenOptions.NetNs, func() (net.PacketConn, error) {
if l.listenOptions.BindInterface != "" {
listenConfig.Control = control.Append(listenConfig.Control, control.BindToInterface(service.FromContext[adapter.NetworkManager](l.ctx).InterfaceFinder(), l.listenOptions.BindInterface, -1))
}
if l.listenOptions.RoutingMark != 0 {
listenConfig.Control = control.Append(listenConfig.Control, control.RoutingMark(uint32(l.listenOptions.RoutingMark)))
}
if l.listenOptions.ReuseAddr {
listenConfig.Control = control.Append(listenConfig.Control, control.ReuseAddr())
}
return listenConfig.ListenPacket(ctx, network, address)
})
}
func (l *Listener) UDPAddr() M.Socksaddr {
return l.udpAddr
}

View File

@@ -23,6 +23,7 @@ type Config struct {
}
type Info struct {
ProcessID uint32
ProcessPath string
PackageName string
User string

View File

@@ -76,8 +76,6 @@ func findProcessName(network string, ip netip.Addr, port int) (string, error) {
// rup8(sizeof(xtcpcb_n))
itemSize += 208
}
var fallbackUDPProcess string
// skip the first xinpgen(24 bytes) block
for i := 24; i+itemSize <= len(buf); i += itemSize {
// offset of xinpcb_n and xsocket_n
@@ -92,12 +90,10 @@ func findProcessName(network string, ip netip.Addr, port int) (string, error) {
flag := buf[inp+44]
var srcIP netip.Addr
srcIsIPv4 := false
switch {
case flag&0x1 > 0 && isIPv4:
// ipv4
srcIP = netip.AddrFrom4(*(*[4]byte)(buf[inp+76 : inp+80]))
srcIsIPv4 = true
case flag&0x2 > 0 && !isIPv4:
// ipv6
srcIP = netip.AddrFrom16(*(*[16]byte)(buf[inp+64 : inp+80]))
@@ -105,21 +101,13 @@ func findProcessName(network string, ip netip.Addr, port int) (string, error) {
continue
}
if ip == srcIP {
// xsocket_n.so_last_pid
pid := readNativeUint32(buf[so+68 : so+72])
return getExecPathFromPID(pid)
if ip != srcIP {
continue
}
// udp packet connection may be not equal with srcIP
if network == N.NetworkUDP && srcIP.IsUnspecified() && isIPv4 == srcIsIPv4 {
pid := readNativeUint32(buf[so+68 : so+72])
fallbackUDPProcess, _ = getExecPathFromPID(pid)
}
}
if network == N.NetworkUDP && len(fallbackUDPProcess) > 0 {
return fallbackUDPProcess, nil
// xsocket_n.so_last_pid
pid := readNativeUint32(buf[so+68 : so+72])
return getExecPathFromPID(pid)
}
return "", ErrNotFound

View File

@@ -2,14 +2,11 @@ package process
import (
"context"
"fmt"
"net/netip"
"os"
"syscall"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/winiphlpapi"
"golang.org/x/sys/windows"
)
@@ -26,201 +23,39 @@ func NewSearcher(_ Config) (Searcher, error) {
return &windowsSearcher{}, nil
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable")
procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW")
)
func initWin32API() error {
err := modiphlpapi.Load()
if err != nil {
return E.Cause(err, "load iphlpapi.dll")
}
err = procGetExtendedTcpTable.Find()
if err != nil {
return E.Cause(err, "load iphlpapi::GetExtendedTcpTable")
}
err = procGetExtendedUdpTable.Find()
if err != nil {
return E.Cause(err, "load iphlpapi::GetExtendedUdpTable")
}
err = modkernel32.Load()
if err != nil {
return E.Cause(err, "load kernel32.dll")
}
err = procQueryFullProcessImageNameW.Find()
if err != nil {
return E.Cause(err, "load kernel32::QueryFullProcessImageNameW")
}
return nil
return winiphlpapi.LoadExtendedTable()
}
func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) {
processName, err := findProcessName(network, source.Addr(), int(source.Port()))
pid, err := winiphlpapi.FindPid(network, source)
if err != nil {
return nil, err
}
return &Info{ProcessPath: processName, UserId: -1}, nil
}
func findProcessName(network string, ip netip.Addr, srcPort int) (string, error) {
family := windows.AF_INET
if ip.Is6() {
family = windows.AF_INET6
}
const (
tcpTablePidConn = 4
udpTablePid = 1
)
var class int
var fn uintptr
switch network {
case N.NetworkTCP:
fn = procGetExtendedTcpTable.Addr()
class = tcpTablePidConn
case N.NetworkUDP:
fn = procGetExtendedUdpTable.Addr()
class = udpTablePid
default:
return "", os.ErrInvalid
}
buf, err := getTransportTable(fn, family, class)
path, err := getProcessPath(pid)
if err != nil {
return "", err
return &Info{ProcessID: pid, UserId: -1}, err
}
s := newSearcher(family == windows.AF_INET, network == N.NetworkTCP)
pid, err := s.Search(buf, ip, uint16(srcPort))
if err != nil {
return "", err
}
return getExecPathFromPID(pid)
return &Info{ProcessID: pid, ProcessPath: path, UserId: -1}, nil
}
type searcher struct {
itemSize int
port int
ip int
ipSize int
pid int
tcpState int
}
func (s *searcher) Search(b []byte, ip netip.Addr, port uint16) (uint32, error) {
n := int(readNativeUint32(b[:4]))
itemSize := s.itemSize
for i := 0; i < n; i++ {
row := b[4+itemSize*i : 4+itemSize*(i+1)]
// according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian.
// this field can be illustrated as follows depends on different machine endianess:
// little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB)
// big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB)
// so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32
srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4])))
if srcPort != port {
continue
}
srcIP, _ := netip.AddrFromSlice(row[s.ip : s.ip+s.ipSize])
// windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
if ip != srcIP && (!srcIP.IsUnspecified()) {
continue
}
pid := readNativeUint32(row[s.pid : s.pid+4])
return pid, nil
}
return 0, ErrNotFound
}
func newSearcher(isV4, isTCP bool) *searcher {
var itemSize, port, ip, ipSize, pid int
tcpState := -1
switch {
case isV4 && isTCP:
// struct MIB_TCPROW_OWNER_PID
itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0
case isV4 && !isTCP:
// struct MIB_UDPROW_OWNER_PID
itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8
case !isV4 && isTCP:
// struct MIB_TCP6ROW_OWNER_PID
itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48
case !isV4 && !isTCP:
// struct MIB_UDP6ROW_OWNER_PID
itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24
}
return &searcher{
itemSize: itemSize,
port: port,
ip: ip,
ipSize: ipSize,
pid: pid,
tcpState: tcpState,
}
}
func getTransportTable(fn uintptr, family int, class int) ([]byte, error) {
for size, buf := uint32(8), make([]byte, 8); ; {
ptr := unsafe.Pointer(&buf[0])
err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0)
switch err {
case 0:
return buf, nil
case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER):
buf = make([]byte, size)
default:
return nil, fmt.Errorf("syscall error: %d", err)
}
}
}
func readNativeUint32(b []byte) uint32 {
return *(*uint32)(unsafe.Pointer(&b[0]))
}
func getExecPathFromPID(pid uint32) (string, error) {
// kernel process starts with a colon in order to distinguish with normal processes
func getProcessPath(pid uint32) (string, error) {
switch pid {
case 0:
// reserved pid for system idle process
return ":System Idle Process", nil
case 4:
// reserved pid for windows kernel image
return ":System", nil
}
h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
handle, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
if err != nil {
return "", err
}
defer windows.CloseHandle(h)
defer windows.CloseHandle(handle)
size := uint32(syscall.MAX_LONG_PATH)
buf := make([]uint16, syscall.MAX_LONG_PATH)
size := uint32(len(buf))
r1, _, err := syscall.SyscallN(
procQueryFullProcessImageNameW.Addr(),
uintptr(h),
uintptr(0),
uintptr(unsafe.Pointer(&buf[0])),
uintptr(unsafe.Pointer(&size)),
)
if r1 == 0 {
err = windows.QueryFullProcessImageName(handle, 0, &buf[0], &size)
if err != nil {
return "", err
}
return syscall.UTF16ToString(buf[:size]), nil
return windows.UTF16ToString(buf[:size]), nil
}

View File

@@ -12,7 +12,7 @@ import (
"golang.org/x/sys/unix"
)
func TProxy(fd uintptr, isIPv6 bool) error {
func TProxy(fd uintptr, isIPv6 bool, isUDP bool) error {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
@@ -20,11 +20,13 @@ func TProxy(fd uintptr, isIPv6 bool) error {
if err == nil && isIPv6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
}
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
}
if err == nil && isIPv6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
if isUDP {
if err == nil {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
}
if err == nil && isIPv6 {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
}
}
return err
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/sagernet/sing/common/control"
)
func TProxy(fd uintptr, isIPv6 bool) error {
func TProxy(fd uintptr, isIPv6 bool, isUDP bool) error {
return os.ErrInvalid
}

58
common/sniff/ntp.go Normal file
View File

@@ -0,0 +1,58 @@
package sniff
import (
"context"
"encoding/binary"
"os"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
)
func NTP(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error {
// NTP packets must be at least 48 bytes long (standard NTP header size).
pLen := len(packet)
if pLen < 48 {
return os.ErrInvalid
}
// Check the LI (Leap Indicator) and Version Number (VN) in the first byte.
// We'll primarily focus on ensuring the version is valid for NTP.
// Many NTP versions are used, but let's check for generally accepted ones (3 & 4 for IPv4, plus potential extensions/customizations)
firstByte := packet[0]
li := (firstByte >> 6) & 0x03 // Extract LI
vn := (firstByte >> 3) & 0x07 // Extract VN
mode := firstByte & 0x07 // Extract Mode
// Leap Indicator should be a valid value (0-3).
if li > 3 {
return os.ErrInvalid
}
// Version Check (common NTP versions are 3 and 4)
if vn != 3 && vn != 4 {
return os.ErrInvalid
}
// Check the Mode field for a client request (Mode 3). This validates it *is* a request.
if mode != 3 {
return os.ErrInvalid
}
// Check Root Delay and Root Dispersion. While not strictly *required* for a request,
// we can check if they appear to be reasonable values (not excessively large).
rootDelay := binary.BigEndian.Uint32(packet[4:8])
rootDispersion := binary.BigEndian.Uint32(packet[8:12])
// Check for unreasonably large root delay and dispersion. NTP RFC specifies max values of approximately 16 seconds.
// Convert to milliseconds for easy comparison. Each unit is 1/2^16 seconds.
if float64(rootDelay)/65536.0 > 16.0 {
return os.ErrInvalid
}
if float64(rootDispersion)/65536.0 > 16.0 {
return os.ErrInvalid
}
metadata.Protocol = C.ProtocolNTP
return nil
}

33
common/sniff/ntp_test.go Normal file
View File

@@ -0,0 +1,33 @@
package sniff_test
import (
"context"
"encoding/hex"
"os"
"testing"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/sniff"
C "github.com/sagernet/sing-box/constant"
"github.com/stretchr/testify/require"
)
func TestSniffNTP(t *testing.T) {
t.Parallel()
packet, err := hex.DecodeString("1b0006000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.NTP(context.Background(), &metadata, packet)
require.NoError(t, err)
require.Equal(t, metadata.Protocol, C.ProtocolNTP)
}
func TestSniffNTPFailed(t *testing.T) {
t.Parallel()
packet, err := hex.DecodeString("400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.NTP(context.Background(), &metadata, packet)
require.ErrorIs(t, err, os.ErrInvalid)
}

View File

@@ -16,7 +16,7 @@ import (
"github.com/caddyserver/certmagic"
"github.com/libdns/alidns"
"github.com/libdns/cloudflare"
"github.com/mholt/acmez/acme"
"github.com/mholt/acmez/v3/acme"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
@@ -37,7 +37,7 @@ func (w *acmeWrapper) Close() error {
return nil
}
func startACME(ctx context.Context, options option.InboundACMEOptions) (*tls.Config, adapter.Service, error) {
func startACME(ctx context.Context, options option.InboundACMEOptions) (*tls.Config, adapter.SimpleLifecycle, error) {
var acmeServer string
switch options.Provider {
case "", "letsencrypt":

View File

@@ -11,6 +11,6 @@ import (
E "github.com/sagernet/sing/common/exceptions"
)
func startACME(ctx context.Context, options option.InboundACMEOptions) (*tls.Config, adapter.Service, error) {
func startACME(ctx context.Context, options option.InboundACMEOptions) (*tls.Config, adapter.SimpleLifecycle, error) {
return nil, nil, E.New(`ACME is not included in this build, rebuild with -tags with_acme`)
}

View File

@@ -29,15 +29,12 @@ func NewClient(ctx context.Context, serverAddress string, options option.Outboun
if !options.Enabled {
return nil, nil
}
if options.ECH != nil && options.ECH.Enabled {
return NewECHClient(ctx, serverAddress, options)
} else if options.Reality != nil && options.Reality.Enabled {
if options.Reality != nil && options.Reality.Enabled {
return NewRealityClient(ctx, serverAddress, options)
} else if options.UTLS != nil && options.UTLS.Enabled {
return NewUTLSClient(ctx, serverAddress, options)
} else {
return NewSTDClient(ctx, serverAddress, options)
}
return NewSTDClient(ctx, serverAddress, options)
}
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) {

View File

@@ -18,6 +18,7 @@ type (
STDConfig = tls.Config
STDConn = tls.Conn
ConnectionState = tls.ConnectionState
CurveID = tls.CurveID
)
func ParseTLSVersion(version string) (uint16, error) {

194
common/tls/ech.go Normal file
View File

@@ -0,0 +1,194 @@
//go:build go1.24
package tls
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/pem"
"net"
"os"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns"
"golang.org/x/crypto/cryptobyte"
)
func parseECHClientConfig(ctx context.Context, options option.OutboundTLSOptions, tlsConfig *tls.Config) (Config, error) {
var echConfig []byte
if len(options.ECH.Config) > 0 {
echConfig = []byte(strings.Join(options.ECH.Config, "\n"))
} else if options.ECH.ConfigPath != "" {
content, err := os.ReadFile(options.ECH.ConfigPath)
if err != nil {
return nil, E.Cause(err, "read ECH config")
}
echConfig = content
}
//nolint:staticcheck
if options.ECH.PQSignatureSchemesEnabled || options.ECH.DynamicRecordSizingDisabled {
deprecated.Report(ctx, deprecated.OptionLegacyECHOptions)
}
if len(echConfig) > 0 {
block, rest := pem.Decode(echConfig)
if block == nil || block.Type != "ECH CONFIGS" || len(rest) > 0 {
return nil, E.New("invalid ECH configs pem")
}
tlsConfig.EncryptedClientHelloConfigList = block.Bytes
return &STDClientConfig{tlsConfig}, nil
} else {
return &STDECHClientConfig{
STDClientConfig: STDClientConfig{tlsConfig},
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
}, nil
}
}
func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions, tlsConfig *tls.Config, echKeyPath *string) error {
var echKey []byte
if len(options.ECH.Key) > 0 {
echKey = []byte(strings.Join(options.ECH.Key, "\n"))
} else if options.ECH.KeyPath != "" {
content, err := os.ReadFile(options.ECH.KeyPath)
if err != nil {
return E.Cause(err, "read ECH keys")
}
echKey = content
*echKeyPath = options.ECH.KeyPath
} else {
return E.New("missing ECH keys")
}
block, rest := pem.Decode(echKey)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return E.New("invalid ECH keys pem")
}
echKeys, err := UnmarshalECHKeys(block.Bytes)
if err != nil {
return E.Cause(err, "parse ECH keys")
}
tlsConfig.EncryptedClientHelloKeys = echKeys
//nolint:staticcheck
if options.ECH.PQSignatureSchemesEnabled || options.ECH.DynamicRecordSizingDisabled {
deprecated.Report(ctx, deprecated.OptionLegacyECHOptions)
}
return nil
}
func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error {
echKey, err := os.ReadFile(echKeyPath)
if err != nil {
return E.Cause(err, "reload ECH keys from ", echKeyPath)
}
block, _ := pem.Decode(echKey)
if block == nil || block.Type != "ECH KEYS" {
return E.New("invalid ECH keys pem")
}
echKeys, err := UnmarshalECHKeys(block.Bytes)
if err != nil {
return E.Cause(err, "parse ECH keys")
}
tlsConfig.EncryptedClientHelloKeys = echKeys
return nil
}
type STDECHClientConfig struct {
STDClientConfig
access sync.Mutex
dnsRouter adapter.DNSRouter
lastTTL time.Duration
lastUpdate time.Time
}
func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
tlsConn, err := s.fetchAndHandshake(ctx, conn)
if err != nil {
return nil, err
}
err = tlsConn.HandshakeContext(ctx)
if err != nil {
return nil, err
}
return tlsConn, nil
}
func (s *STDECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
s.access.Lock()
defer s.access.Unlock()
if len(s.config.EncryptedClientHelloConfigList) == 0 || s.lastTTL == 0 || time.Now().Sub(s.lastUpdate) > s.lastTTL {
message := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
RecursionDesired: true,
},
Question: []mDNS.Question{
{
Name: mDNS.Fqdn(s.config.ServerName),
Qtype: mDNS.TypeHTTPS,
Qclass: mDNS.ClassINET,
},
},
}
response, err := s.dnsRouter.Exchange(ctx, message, adapter.DNSQueryOptions{})
if err != nil {
return nil, E.Cause(err, "fetch ECH config list")
}
if response.Rcode != mDNS.RcodeSuccess {
return nil, E.Cause(dns.RcodeError(response.Rcode), "fetch ECH config list")
}
match:
for _, rr := range response.Answer {
switch resource := rr.(type) {
case *mDNS.HTTPS:
for _, value := range resource.Value {
if value.Key().String() == "ech" {
echConfigList, err := base64.StdEncoding.DecodeString(value.String())
if err != nil {
return nil, E.Cause(err, "decode ECH config")
}
s.lastTTL = time.Duration(rr.Header().Ttl) * time.Second
s.lastUpdate = time.Now()
s.config.EncryptedClientHelloConfigList = echConfigList
break match
}
}
}
}
if len(s.config.EncryptedClientHelloConfigList) == 0 {
return nil, E.New("no ECH config found in DNS records")
}
}
return s.Client(conn)
}
func (s *STDECHClientConfig) Clone() Config {
return &STDECHClientConfig{STDClientConfig: STDClientConfig{s.config.Clone()}, dnsRouter: s.dnsRouter, lastUpdate: s.lastUpdate}
}
func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) {
var keys []tls.EncryptedClientHelloKey
rawString := cryptobyte.String(raw)
for !rawString.Empty() {
var key tls.EncryptedClientHelloKey
if !rawString.ReadUint16LengthPrefixed((*cryptobyte.String)(&key.PrivateKey)) {
return nil, E.New("error parsing private key")
}
if !rawString.ReadUint16LengthPrefixed((*cryptobyte.String)(&key.Config)) {
return nil, E.New("error parsing config")
}
keys = append(keys, key)
}
if len(keys) == 0 {
return nil, E.New("empty ECH keys")
}
return keys, nil
}

View File

@@ -1,243 +0,0 @@
//go:build with_ech
package tls
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"net"
"net/netip"
"os"
"strings"
cftls "github.com/sagernet/cloudflare-tls"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns"
)
type echClientConfig struct {
config *cftls.Config
}
func (c *echClientConfig) ServerName() string {
return c.config.ServerName
}
func (c *echClientConfig) SetServerName(serverName string) {
c.config.ServerName = serverName
}
func (c *echClientConfig) NextProtos() []string {
return c.config.NextProtos
}
func (c *echClientConfig) SetNextProtos(nextProto []string) {
c.config.NextProtos = nextProto
}
func (c *echClientConfig) Config() (*STDConfig, error) {
return nil, E.New("unsupported usage for ECH")
}
func (c *echClientConfig) Client(conn net.Conn) (Conn, error) {
return &echConnWrapper{cftls.Client(conn, c.config)}, nil
}
func (c *echClientConfig) Clone() Config {
return &echClientConfig{
config: c.config.Clone(),
}
}
type echConnWrapper struct {
*cftls.Conn
}
func (c *echConnWrapper) ConnectionState() tls.ConnectionState {
state := c.Conn.ConnectionState()
//nolint:staticcheck
return tls.ConnectionState{
Version: state.Version,
HandshakeComplete: state.HandshakeComplete,
DidResume: state.DidResume,
CipherSuite: state.CipherSuite,
NegotiatedProtocol: state.NegotiatedProtocol,
NegotiatedProtocolIsMutual: state.NegotiatedProtocolIsMutual,
ServerName: state.ServerName,
PeerCertificates: state.PeerCertificates,
VerifiedChains: state.VerifiedChains,
SignedCertificateTimestamps: state.SignedCertificateTimestamps,
OCSPResponse: state.OCSPResponse,
TLSUnique: state.TLSUnique,
}
}
func (c *echConnWrapper) Upstream() any {
return c.Conn
}
func NewECHClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
var serverName string
if options.ServerName != "" {
serverName = options.ServerName
} else if serverAddress != "" {
if _, err := netip.ParseAddr(serverName); err != nil {
serverName = serverAddress
}
}
if serverName == "" && !options.Insecure {
return nil, E.New("missing server_name or insecure=true")
}
var tlsConfig cftls.Config
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {
tlsConfig.ServerName = serverName
}
if options.Insecure {
tlsConfig.InsecureSkipVerify = options.Insecure
} else if options.DisableSNI {
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyConnection = func(state cftls.ConnectionState) error {
verifyOptions := x509.VerifyOptions{
DNSName: serverName,
Intermediates: x509.NewCertPool(),
}
for _, cert := range state.PeerCertificates[1:] {
verifyOptions.Intermediates.AddCert(cert)
}
_, err := state.PeerCertificates[0].Verify(verifyOptions)
return err
}
}
if len(options.ALPN) > 0 {
tlsConfig.NextProtos = options.ALPN
}
if options.MinVersion != "" {
minVersion, err := ParseTLSVersion(options.MinVersion)
if err != nil {
return nil, E.Cause(err, "parse min_version")
}
tlsConfig.MinVersion = minVersion
}
if options.MaxVersion != "" {
maxVersion, err := ParseTLSVersion(options.MaxVersion)
if err != nil {
return nil, E.Cause(err, "parse max_version")
}
tlsConfig.MaxVersion = maxVersion
}
if options.CipherSuites != nil {
find:
for _, cipherSuite := range options.CipherSuites {
for _, tlsCipherSuite := range cftls.CipherSuites() {
if cipherSuite == tlsCipherSuite.Name {
tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
continue find
}
}
return nil, E.New("unknown cipher_suite: ", cipherSuite)
}
}
var certificate []byte
if len(options.Certificate) > 0 {
certificate = []byte(strings.Join(options.Certificate, "\n"))
} else if options.CertificatePath != "" {
content, err := os.ReadFile(options.CertificatePath)
if err != nil {
return nil, E.Cause(err, "read certificate")
}
certificate = content
}
if len(certificate) > 0 {
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(certificate) {
return nil, E.New("failed to parse certificate:\n\n", certificate)
}
tlsConfig.RootCAs = certPool
}
// ECH Config
tlsConfig.ECHEnabled = true
tlsConfig.PQSignatureSchemesEnabled = options.ECH.PQSignatureSchemesEnabled
tlsConfig.DynamicRecordSizingDisabled = options.ECH.DynamicRecordSizingDisabled
var echConfig []byte
if len(options.ECH.Config) > 0 {
echConfig = []byte(strings.Join(options.ECH.Config, "\n"))
} else if options.ECH.ConfigPath != "" {
content, err := os.ReadFile(options.ECH.ConfigPath)
if err != nil {
return nil, E.Cause(err, "read ECH config")
}
echConfig = content
}
if len(echConfig) > 0 {
block, rest := pem.Decode(echConfig)
if block == nil || block.Type != "ECH CONFIGS" || len(rest) > 0 {
return nil, E.New("invalid ECH configs pem")
}
echConfigs, err := cftls.UnmarshalECHConfigs(block.Bytes)
if err != nil {
return nil, E.Cause(err, "parse ECH configs")
}
tlsConfig.ClientECHConfigs = echConfigs
} else {
tlsConfig.GetClientECHConfigs = fetchECHClientConfig(ctx)
}
return &echClientConfig{&tlsConfig}, nil
}
func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverName string) ([]cftls.ECHConfig, error) {
return func(_ context.Context, serverName string) ([]cftls.ECHConfig, error) {
message := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
RecursionDesired: true,
},
Question: []mDNS.Question{
{
Name: serverName + ".",
Qtype: mDNS.TypeHTTPS,
Qclass: mDNS.ClassINET,
},
},
}
response, err := service.FromContext[adapter.Router](ctx).Exchange(ctx, message)
if err != nil {
return nil, err
}
if response.Rcode != mDNS.RcodeSuccess {
return nil, dns.RCodeError(response.Rcode)
}
for _, rr := range response.Answer {
switch resource := rr.(type) {
case *mDNS.HTTPS:
for _, value := range resource.Value {
if value.Key().String() == "ech" {
echConfig, err := base64.StdEncoding.DecodeString(value.String())
if err != nil {
return nil, E.Cause(err, "decode ECH config")
}
return cftls.UnmarshalECHConfigs(echConfig)
}
}
default:
return nil, E.New("unknown resource record type: ", resource.Header().Rrtype)
}
}
return nil, E.New("no ECH config found")
}
}

View File

@@ -1,5 +1,3 @@
//go:build with_ech
package tls
import (
@@ -7,14 +5,13 @@ import (
"encoding/binary"
"encoding/pem"
cftls "github.com/sagernet/cloudflare-tls"
E "github.com/sagernet/sing/common/exceptions"
"github.com/cloudflare/circl/hpke"
"github.com/cloudflare/circl/kem"
)
func ECHKeygenDefault(serverName string, pqSignatureSchemesEnabled bool) (configPem string, keyPem string, err error) {
func ECHKeygenDefault(serverName string) (configPem string, keyPem string, err error) {
cipherSuites := []echCipherSuite{
{
kdf: hpke.KDF_HKDF_SHA256,
@@ -24,13 +21,9 @@ func ECHKeygenDefault(serverName string, pqSignatureSchemesEnabled bool) (config
aead: hpke.AEAD_ChaCha20Poly1305,
},
}
keyConfig := []myECHKeyConfig{
{id: 0, kem: hpke.KEM_X25519_HKDF_SHA256},
}
if pqSignatureSchemesEnabled {
keyConfig = append(keyConfig, myECHKeyConfig{id: 1, kem: hpke.KEM_X25519_KYBER768_DRAFT00})
}
keyPairs, err := echKeygen(0xfe0d, serverName, keyConfig, cipherSuites)
if err != nil {
@@ -59,7 +52,6 @@ func ECHKeygenDefault(serverName string, pqSignatureSchemesEnabled bool) (config
type echKeyConfigPair struct {
id uint8
key cftls.EXP_ECHKey
rawKey []byte
conf myECHKeyConfig
rawConf []byte
@@ -155,15 +147,6 @@ func echKeygen(version uint16, serverName string, conf []myECHKeyConfig, suite [
sk = append(sk, secBuf...)
sk = be.AppendUint16(sk, uint16(len(b)))
sk = append(sk, b...)
cfECHKeys, err := cftls.EXP_UnmarshalECHKeys(sk)
if err != nil {
return nil, E.Cause(err, "bug: can't parse generated ECH server key")
}
if len(cfECHKeys) != 1 {
return nil, E.New("bug: unexpected server key count")
}
pair.key = cfECHKeys[0]
pair.rawKey = sk
pairs = append(pairs, pair)

View File

@@ -1,55 +0,0 @@
//go:build with_quic && with_ech
package tls
import (
"context"
"net"
"net/http"
"github.com/sagernet/cloudflare-tls"
"github.com/sagernet/quic-go/ech"
"github.com/sagernet/quic-go/http3_ech"
"github.com/sagernet/sing-quic"
M "github.com/sagernet/sing/common/metadata"
)
var (
_ qtls.Config = (*echClientConfig)(nil)
_ qtls.ServerConfig = (*echServerConfig)(nil)
)
func (c *echClientConfig) Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error) {
return quic.Dial(ctx, conn, addr, c.config, config)
}
func (c *echClientConfig) DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.EarlyConnection, error) {
return quic.DialEarly(ctx, conn, addr, c.config, config)
}
func (c *echClientConfig) CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, quicConfig *quic.Config) http.RoundTripper {
return &http3.Transport{
TLSClientConfig: c.config,
QUICConfig: quicConfig,
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
quicConn, err := quic.DialEarly(ctx, conn, serverAddr.UDPAddr(), tlsCfg, cfg)
if err != nil {
return nil, err
}
*quicConnPtr = quicConn
return quicConn, nil
},
}
}
func (c *echServerConfig) Listen(conn net.PacketConn, config *quic.Config) (qtls.Listener, error) {
return quic.Listen(conn, c.config, config)
}
func (c *echServerConfig) ListenEarly(conn net.PacketConn, config *quic.Config) (qtls.EarlyListener, error) {
return quic.ListenEarly(conn, c.config, config)
}
func (c *echServerConfig) ConfigureHTTP3() {
http3.ConfigureTLSConfig(c.config)
}

View File

@@ -1,278 +0,0 @@
//go:build with_ech
package tls
import (
"context"
"crypto/tls"
"encoding/pem"
"net"
"os"
"strings"
cftls "github.com/sagernet/cloudflare-tls"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
)
type echServerConfig struct {
config *cftls.Config
logger log.Logger
certificate []byte
key []byte
certificatePath string
keyPath string
echKeyPath string
watcher *fswatch.Watcher
}
func (c *echServerConfig) ServerName() string {
return c.config.ServerName
}
func (c *echServerConfig) SetServerName(serverName string) {
c.config.ServerName = serverName
}
func (c *echServerConfig) NextProtos() []string {
return c.config.NextProtos
}
func (c *echServerConfig) SetNextProtos(nextProto []string) {
c.config.NextProtos = nextProto
}
func (c *echServerConfig) Config() (*STDConfig, error) {
return nil, E.New("unsupported usage for ECH")
}
func (c *echServerConfig) Client(conn net.Conn) (Conn, error) {
return &echConnWrapper{cftls.Client(conn, c.config)}, nil
}
func (c *echServerConfig) Server(conn net.Conn) (Conn, error) {
return &echConnWrapper{cftls.Server(conn, c.config)}, nil
}
func (c *echServerConfig) Clone() Config {
return &echServerConfig{
config: c.config.Clone(),
}
}
func (c *echServerConfig) Start() error {
err := c.startWatcher()
if err != nil {
c.logger.Warn("create credentials watcher: ", err)
}
return nil
}
func (c *echServerConfig) startWatcher() error {
var watchPath []string
if c.certificatePath != "" {
watchPath = append(watchPath, c.certificatePath)
}
if c.keyPath != "" {
watchPath = append(watchPath, c.keyPath)
}
if c.echKeyPath != "" {
watchPath = append(watchPath, c.echKeyPath)
}
if len(watchPath) == 0 {
return nil
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPath,
Callback: func(path string) {
err := c.credentialsUpdated(path)
if err != nil {
c.logger.Error(E.Cause(err, "reload credentials from ", path))
}
},
})
if err != nil {
return err
}
err = watcher.Start()
if err != nil {
return err
}
c.watcher = watcher
return nil
}
func (c *echServerConfig) credentialsUpdated(path string) error {
if path == c.certificatePath || path == c.keyPath {
if path == c.certificatePath {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return err
}
c.certificate = certificate
} else {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return err
}
c.key = key
}
keyPair, err := cftls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "parse key pair")
}
c.config.Certificates = []cftls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
} else {
echKeyContent, err := os.ReadFile(c.echKeyPath)
if err != nil {
return err
}
block, rest := pem.Decode(echKeyContent)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return E.New("invalid ECH keys pem")
}
echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
if err != nil {
return E.Cause(err, "parse ECH keys")
}
echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
if err != nil {
return E.Cause(err, "create ECH key set")
}
c.config.ServerECHProvider = echKeySet
c.logger.Info("reloaded ECH keys")
}
return nil
}
func (c *echServerConfig) Close() error {
var err error
if c.watcher != nil {
err = E.Append(err, c.watcher.Close(), func(err error) error {
return E.Cause(err, "close credentials watcher")
})
}
return err
}
func NewECHServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
if !options.Enabled {
return nil, nil
}
var tlsConfig cftls.Config
if options.ACME != nil && len(options.ACME.Domain) > 0 {
return nil, E.New("acme is unavailable in ech")
}
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
if options.ServerName != "" {
tlsConfig.ServerName = options.ServerName
}
if len(options.ALPN) > 0 {
tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...)
}
if options.MinVersion != "" {
minVersion, err := ParseTLSVersion(options.MinVersion)
if err != nil {
return nil, E.Cause(err, "parse min_version")
}
tlsConfig.MinVersion = minVersion
}
if options.MaxVersion != "" {
maxVersion, err := ParseTLSVersion(options.MaxVersion)
if err != nil {
return nil, E.Cause(err, "parse max_version")
}
tlsConfig.MaxVersion = maxVersion
}
if options.CipherSuites != nil {
find:
for _, cipherSuite := range options.CipherSuites {
for _, tlsCipherSuite := range tls.CipherSuites() {
if cipherSuite == tlsCipherSuite.Name {
tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
continue find
}
}
return nil, E.New("unknown cipher_suite: ", cipherSuite)
}
}
var certificate []byte
var key []byte
if len(options.Certificate) > 0 {
certificate = []byte(strings.Join(options.Certificate, "\n"))
} else if options.CertificatePath != "" {
content, err := os.ReadFile(options.CertificatePath)
if err != nil {
return nil, E.Cause(err, "read certificate")
}
certificate = content
}
if len(options.Key) > 0 {
key = []byte(strings.Join(options.Key, "\n"))
} else if options.KeyPath != "" {
content, err := os.ReadFile(options.KeyPath)
if err != nil {
return nil, E.Cause(err, "read key")
}
key = content
}
if certificate == nil {
return nil, E.New("missing certificate")
} else if key == nil {
return nil, E.New("missing key")
}
keyPair, err := cftls.X509KeyPair(certificate, key)
if err != nil {
return nil, E.Cause(err, "parse x509 key pair")
}
tlsConfig.Certificates = []cftls.Certificate{keyPair}
var echKey []byte
if len(options.ECH.Key) > 0 {
echKey = []byte(strings.Join(options.ECH.Key, "\n"))
} else if options.ECH.KeyPath != "" {
content, err := os.ReadFile(options.ECH.KeyPath)
if err != nil {
return nil, E.Cause(err, "read ECH key")
}
echKey = content
} else {
return nil, E.New("missing ECH key")
}
block, rest := pem.Decode(echKey)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return nil, E.New("invalid ECH keys pem")
}
echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
if err != nil {
return nil, E.Cause(err, "parse ECH keys")
}
echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
if err != nil {
return nil, E.Cause(err, "create ECH key set")
}
tlsConfig.ECHEnabled = true
tlsConfig.PQSignatureSchemesEnabled = options.ECH.PQSignatureSchemesEnabled
tlsConfig.DynamicRecordSizingDisabled = options.ECH.DynamicRecordSizingDisabled
tlsConfig.ServerECHProvider = echKeySet
return &echServerConfig{
config: &tlsConfig,
logger: logger,
certificate: certificate,
key: key,
certificatePath: options.CertificatePath,
keyPath: options.KeyPath,
echKeyPath: options.ECH.KeyPath,
}, nil
}

View File

@@ -1,25 +1,23 @@
//go:build !with_ech
//go:build !go1.24
package tls
import (
"context"
"crypto/tls"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
var errECHNotIncluded = E.New(`ECH is not included in this build, rebuild with -tags with_ech`)
func NewECHServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
return nil, errECHNotIncluded
func parseECHClientConfig(ctx context.Context, options option.OutboundTLSOptions, tlsConfig *tls.Config) (Config, error) {
return nil, E.New("ECH requires go1.24, please recompile your binary.")
}
func NewECHClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return nil, errECHNotIncluded
func parseECHServerConfig(ctx context.Context, options option.InboundTLSOptions, tlsConfig *tls.Config, echKeyPath *string) error {
return E.New("ECH requires go1.24, please recompile your binary.")
}
func ECHKeygenDefault(host string, pqSignatureSchemesEnabled bool) (configPem string, keyPem string, err error) {
return "", "", errECHNotIncluded
func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error {
return E.New("ECH requires go1.24, please recompile your binary.")
}

View File

@@ -0,0 +1,5 @@
//go:build with_ech
package tls
var _ int = "Due to the migration to stdlib, the separate `with_ech` build tag has been deprecated and is no longer needed, please update your build configuration."

View File

@@ -12,6 +12,9 @@ import (
)
func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string) (*tls.Certificate, error) {
if timeFunc == nil {
timeFunc = time.Now
}
privateKeyPem, publicKeyPem, err := GenerateCertificate(parent, parentKey, timeFunc, serverName, timeFunc().Add(time.Hour))
if err != nil {
return nil, err
@@ -24,9 +27,6 @@ func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() ti
}
func GenerateCertificate(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string, expire time.Time) (privateKeyPem []byte, publicKeyPem []byte, err error) {
if timeFunc == nil {
timeFunc = time.Now
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return

View File

@@ -27,12 +27,15 @@ import (
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
aTLS "github.com/sagernet/sing/common/tls"
utls "github.com/sagernet/utls"
utls "github.com/metacubex/utls"
"golang.org/x/crypto/hkdf"
"golang.org/x/net/http2"
)
@@ -40,6 +43,7 @@ import (
var _ ConfigCompat = (*RealityClientConfig)(nil)
type RealityClientConfig struct {
ctx context.Context
uClient *UTLSClientConfig
publicKey []byte
shortID [8]byte
@@ -70,7 +74,7 @@ func NewRealityClient(ctx context.Context, serverAddress string, options option.
if decodedLen > 8 {
return nil, E.New("invalid short_id")
}
return &RealityClientConfig{uClient, publicKey, shortID}, nil
return &RealityClientConfig{ctx, uClient, publicKey, shortID}, nil
}
func (e *RealityClientConfig) ServerName() string {
@@ -111,6 +115,22 @@ func (e *RealityClientConfig) ClientHandshake(ctx context.Context, conn net.Conn
if err != nil {
return nil, err
}
for _, extension := range uConn.Extensions {
if ce, ok := extension.(*utls.SupportedCurvesExtension); ok {
ce.Curves = common.Filter(ce.Curves, func(curveID utls.CurveID) bool {
return curveID != utls.X25519MLKEM768
})
}
if ks, ok := extension.(*utls.KeyShareExtension); ok {
ks.KeyShares = common.Filter(ks.KeyShares, func(share utls.KeyShare) bool {
return share.Group != utls.X25519MLKEM768
})
}
}
err = uConn.BuildHandshakeState()
if err != nil {
return nil, err
}
if len(uConfig.NextProtos) > 0 {
for _, extension := range uConn.Extensions {
@@ -145,9 +165,13 @@ func (e *RealityClientConfig) ClientHandshake(ctx context.Context, conn net.Conn
if err != nil {
return nil, err
}
ecdheKey := uConn.HandshakeState.State13.EcdheKey
keyShareKeys := uConn.HandshakeState.State13.KeyShareKeys
if keyShareKeys == nil {
return nil, E.New("nil KeyShareKeys")
}
ecdheKey := keyShareKeys.Ecdhe
if ecdheKey == nil {
return nil, E.New("nil ecdhe_key")
return nil, E.New("nil ecdheKey")
}
authKey, err := ecdheKey.ECDH(publicKey)
if err != nil {
@@ -180,20 +204,24 @@ func (e *RealityClientConfig) ClientHandshake(ctx context.Context, conn net.Conn
}
if !verifier.verified {
go realityClientFallback(uConn, e.uClient.ServerName(), e.uClient.id)
go realityClientFallback(e.ctx, uConn, e.uClient.ServerName(), e.uClient.id)
return nil, E.New("reality verification failed")
}
return &realityClientConnWrapper{uConn}, nil
}
func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) {
func realityClientFallback(ctx context.Context, uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) {
defer uConn.Close()
client := &http.Client{
Transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
return uConn, nil
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(ctx),
RootCAs: adapter.RootPoolFromContext(ctx),
},
},
}
request, _ := http.NewRequest("GET", "https://"+serverName, nil)
@@ -207,12 +235,9 @@ func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.C
response.Body.Close()
}
func (e *RealityClientConfig) SetSessionIDGenerator(generator func(clientHello []byte, sessionID []byte) error) {
e.uClient.config.SessionIDGenerator = generator
}
func (e *RealityClientConfig) Clone() Config {
return &RealityClientConfig{
e.ctx,
e.uClient.Clone().(*UTLSClientConfig),
e.publicKey,
e.shortID,

View File

@@ -1,4 +1,4 @@
//go:build with_reality_server
//go:build with_utls
package tls
@@ -7,28 +7,29 @@ import (
"crypto/tls"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"time"
"github.com/sagernet/reality"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
utls "github.com/metacubex/utls"
)
var _ ServerConfigCompat = (*RealityServerConfig)(nil)
type RealityServerConfig struct {
config *reality.Config
config *utls.RealityConfig
}
func NewRealityServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (*RealityServerConfig, error) {
var tlsConfig reality.Config
var tlsConfig utls.RealityConfig
if options.ACME != nil && len(options.ACME.Domain) > 0 {
return nil, E.New("acme is unavailable in reality")
@@ -74,6 +75,11 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
}
tlsConfig.SessionTicketsDisabled = true
tlsConfig.Log = func(format string, v ...any) {
if logger != nil {
logger.Trace(fmt.Sprintf(format, v...))
}
}
tlsConfig.Type = N.NetworkTCP
tlsConfig.Dest = options.Reality.Handshake.ServerOptions.Build().String()
@@ -105,7 +111,7 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
}
}
handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions)
handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions, options.Reality.Handshake.ServerIsDomain())
if err != nil {
return nil, err
}
@@ -113,10 +119,6 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
return handshakeDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
}
if debug.Enabled {
tlsConfig.Show = true
}
return &RealityServerConfig{&tlsConfig}, nil
}
@@ -157,7 +159,7 @@ func (c *RealityServerConfig) Server(conn net.Conn) (Conn, error) {
}
func (c *RealityServerConfig) ServerHandshake(ctx context.Context, conn net.Conn) (Conn, error) {
tlsConn, err := reality.Server(ctx, conn, c.config)
tlsConn, err := utls.RealityServer(ctx, conn, c.config)
if err != nil {
return nil, err
}
@@ -173,7 +175,7 @@ func (c *RealityServerConfig) Clone() Config {
var _ Conn = (*realityConnWrapper)(nil)
type realityConnWrapper struct {
*reality.Conn
*utls.Conn
}
func (c *realityConnWrapper) ConnectionState() ConnectionState {

View File

@@ -1,15 +1,5 @@
//go:build !with_reality_server
//go:build with_reality_server
package tls
import (
"context"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func NewRealityServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
return nil, E.New(`reality server is not included in this build, rebuild with -tags with_reality_server`)
}
var _ int = "The separate `with_reality_server` build tag has been merged into `with_utls` and is no longer needed, please update your build configuration."

View File

@@ -16,13 +16,10 @@ func NewServer(ctx context.Context, logger log.Logger, options option.InboundTLS
if !options.Enabled {
return nil, nil
}
if options.ECH != nil && options.ECH.Enabled {
return NewECHServer(ctx, logger, options)
} else if options.Reality != nil && options.Reality.Enabled {
if options.Reality != nil && options.Reality.Enabled {
return NewRealityServer(ctx, logger, options)
} else {
return NewSTDServer(ctx, logger, options)
}
return NewSTDServer(ctx, logger, options)
}
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {

View File

@@ -5,10 +5,10 @@ import (
"crypto/tls"
"crypto/x509"
"net"
"net/netip"
"os"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
@@ -51,9 +51,7 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
if options.ServerName != "" {
serverName = options.ServerName
} else if serverAddress != "" {
if _, err := netip.ParseAddr(serverName); err != nil {
serverName = serverAddress
}
serverName = serverAddress
}
if serverName == "" && !options.Insecure {
return nil, E.New("missing server_name or insecure=true")
@@ -61,6 +59,7 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
var tlsConfig tls.Config
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {
@@ -128,5 +127,8 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb
}
tlsConfig.RootCAs = certPool
}
if options.ECH != nil && options.ECH.Enabled {
return parseECHClientConfig(ctx, options, &tlsConfig)
}
return &STDClientConfig{&tlsConfig}, nil
}

View File

@@ -22,11 +22,12 @@ var errInsecureUnused = E.New("tls: insecure unused")
type STDServerConfig struct {
config *tls.Config
logger log.Logger
acmeService adapter.Service
acmeService adapter.SimpleLifecycle
certificate []byte
key []byte
certificatePath string
keyPath string
echKeyPath string
watcher *fswatch.Watcher
}
@@ -95,12 +96,15 @@ func (c *STDServerConfig) startWatcher() error {
if c.keyPath != "" {
watchPath = append(watchPath, c.keyPath)
}
if c.echKeyPath != "" {
watchPath = append(watchPath, c.echKeyPath)
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: watchPath,
Callback: func(path string) {
err := c.certificateUpdated(path)
if err != nil {
c.logger.Error(err)
c.logger.Error(E.Cause(err, "reload certificate"))
}
},
})
@@ -116,25 +120,33 @@ func (c *STDServerConfig) startWatcher() error {
}
func (c *STDServerConfig) certificateUpdated(path string) error {
if path == c.certificatePath {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
if path == c.certificatePath || path == c.keyPath {
if path == c.certificatePath {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
}
c.certificate = certificate
} else if path == c.keyPath {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
}
c.key = key
}
c.certificate = certificate
} else if path == c.keyPath {
key, err := os.ReadFile(c.keyPath)
keyPair, err := tls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
return E.Cause(err, "reload key pair")
}
c.key = key
c.config.Certificates = []tls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
} else if path == c.echKeyPath {
err := reloadECHKeys(c.echKeyPath, c.config)
if err != nil {
return err
}
c.logger.Info("reloaded ECH keys")
}
keyPair, err := tls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload key pair")
}
c.config.Certificates = []tls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
return nil
}
@@ -153,7 +165,7 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
return nil, nil
}
var tlsConfig *tls.Config
var acmeService adapter.Service
var acmeService adapter.SimpleLifecycle
var err error
if options.ACME != nil && len(options.ACME.Domain) > 0 {
//nolint:staticcheck
@@ -243,6 +255,13 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
tlsConfig.Certificates = []tls.Certificate{keyPair}
}
}
var echKeyPath string
if options.ECH != nil && options.ECH.Enabled {
err = parseECHServerConfig(ctx, options, tlsConfig, &echKeyPath)
if err != nil {
return nil, err
}
}
return &STDServerConfig{
config: tlsConfig,
logger: logger,
@@ -251,5 +270,6 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
key: key,
certificatePath: options.CertificatePath,
keyPath: options.KeyPath,
echKeyPath: echKeyPath,
}, nil
}

View File

@@ -12,11 +12,12 @@ import (
"os"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
utls "github.com/sagernet/utls"
utls "github.com/metacubex/utls"
"golang.org/x/net/http2"
)
@@ -130,6 +131,7 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out
var tlsConfig utls.Config
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {

View File

@@ -5,6 +5,7 @@ package tls
import (
"context"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
@@ -14,5 +15,9 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out
}
func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return nil, E.New(`uTLS, which is required by reality client is not included in this build, rebuild with -tags with_utls`)
return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`)
}
func NewRealityServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
return nil, E.New(`uTLS, which is required by reality is not included in this build, rebuild with -tags with_utls`)
}

129
common/tlsfragment/conn.go Normal file
View File

@@ -0,0 +1,129 @@
package tf
import (
"bytes"
"context"
"encoding/binary"
"math/rand"
"net"
"strings"
"time"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/publicsuffix"
)
type Conn struct {
net.Conn
tcpConn *net.TCPConn
ctx context.Context
firstPacketWritten bool
splitRecord bool
fallbackDelay time.Duration
}
func NewConn(conn net.Conn, ctx context.Context, splitRecord bool, fallbackDelay time.Duration) *Conn {
tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn)
return &Conn{
Conn: conn,
tcpConn: tcpConn,
ctx: ctx,
splitRecord: splitRecord,
fallbackDelay: fallbackDelay,
}
}
func (c *Conn) Write(b []byte) (n int, err error) {
if !c.firstPacketWritten {
defer func() {
c.firstPacketWritten = true
}()
serverName := indexTLSServerName(b)
if serverName != nil {
if !c.splitRecord {
if c.tcpConn != nil {
err = c.tcpConn.SetNoDelay(true)
if err != nil {
return
}
}
}
splits := strings.Split(serverName.ServerName, ".")
currentIndex := serverName.Index
if publicSuffix := publicsuffix.List.PublicSuffix(serverName.ServerName); publicSuffix != "" {
splits = splits[:len(splits)-strings.Count(serverName.ServerName, ".")]
}
if len(splits) > 1 && splits[0] == "..." {
currentIndex += len(splits[0]) + 1
splits = splits[1:]
}
var splitIndexes []int
for i, split := range splits {
splitAt := rand.Intn(len(split))
splitIndexes = append(splitIndexes, currentIndex+splitAt)
currentIndex += len(split)
if i != len(splits)-1 {
currentIndex++
}
}
var buffer bytes.Buffer
for i := 0; i <= len(splitIndexes); i++ {
var payload []byte
if i == 0 {
payload = b[:splitIndexes[i]]
if c.splitRecord {
payload = payload[recordLayerHeaderLen:]
}
} else if i == len(splitIndexes) {
payload = b[splitIndexes[i-1]:]
} else {
payload = b[splitIndexes[i-1]:splitIndexes[i]]
}
if c.splitRecord {
payloadLen := uint16(len(payload))
buffer.Write(b[:3])
binary.Write(&buffer, binary.BigEndian, payloadLen)
buffer.Write(payload)
} else if c.tcpConn != nil && i != len(splitIndexes) {
err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay)
if err != nil {
return
}
} else {
_, err = c.Conn.Write(payload)
if err != nil {
return
}
}
}
if c.splitRecord {
_, err = c.Conn.Write(buffer.Bytes())
if err != nil {
return
}
} else {
if c.tcpConn != nil {
err = c.tcpConn.SetNoDelay(false)
if err != nil {
return
}
}
}
return len(b), nil
}
}
return c.Conn.Write(b)
}
func (c *Conn) ReaderReplaceable() bool {
return true
}
func (c *Conn) WriterReplaceable() bool {
return c.firstPacketWritten
}
func (c *Conn) Upstream() any {
return c.Conn
}

View File

@@ -0,0 +1,32 @@
package tf_test
import (
"context"
"crypto/tls"
"net"
"testing"
tf "github.com/sagernet/sing-box/common/tlsfragment"
"github.com/stretchr/testify/require"
)
func TestTLSFragment(t *testing.T) {
t.Parallel()
tcpConn, err := net.Dial("tcp", "1.1.1.1:443")
require.NoError(t, err)
tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), false, 0), &tls.Config{
ServerName: "www.cloudflare.com",
})
require.NoError(t, tlsConn.Handshake())
}
func TestTLSRecordFragment(t *testing.T) {
t.Parallel()
tcpConn, err := net.Dial("tcp", "1.1.1.1:443")
require.NoError(t, err)
tlsConn := tls.Client(tf.NewConn(tcpConn, context.Background(), true, 0), &tls.Config{
ServerName: "www.cloudflare.com",
})
require.NoError(t, tlsConn.Handshake())
}

131
common/tlsfragment/index.go Normal file
View File

@@ -0,0 +1,131 @@
package tf
import (
"encoding/binary"
)
const (
recordLayerHeaderLen int = 5
handshakeHeaderLen int = 6
randomDataLen int = 32
sessionIDHeaderLen int = 1
cipherSuiteHeaderLen int = 2
compressMethodHeaderLen int = 1
extensionsHeaderLen int = 2
extensionHeaderLen int = 4
sniExtensionHeaderLen int = 5
contentType uint8 = 22
handshakeType uint8 = 1
sniExtensionType uint16 = 0
sniNameDNSHostnameType uint8 = 0
tlsVersionBitmask uint16 = 0xFFFC
tls13 uint16 = 0x0304
)
type myServerName struct {
Index int
Length int
ServerName string
}
func indexTLSServerName(payload []byte) *myServerName {
if len(payload) < recordLayerHeaderLen || payload[0] != contentType {
return nil
}
segmentLen := binary.BigEndian.Uint16(payload[3:5])
if len(payload) < recordLayerHeaderLen+int(segmentLen) {
return nil
}
serverName := indexTLSServerNameFromHandshake(payload[recordLayerHeaderLen : recordLayerHeaderLen+int(segmentLen)])
if serverName == nil {
return nil
}
serverName.Length += recordLayerHeaderLen
return serverName
}
func indexTLSServerNameFromHandshake(hs []byte) *myServerName {
if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen {
return nil
}
if hs[0] != handshakeType {
return nil
}
handshakeLen := uint32(hs[1])<<16 | uint32(hs[2])<<8 | uint32(hs[3])
if len(hs[4:]) != int(handshakeLen) {
return nil
}
tlsVersion := uint16(hs[4])<<8 | uint16(hs[5])
if tlsVersion&tlsVersionBitmask != 0x0300 && tlsVersion != tls13 {
return nil
}
sessionIDLen := hs[38]
if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen) {
return nil
}
cs := hs[handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen):]
if len(cs) < cipherSuiteHeaderLen {
return nil
}
csLen := uint16(cs[0])<<8 | uint16(cs[1])
if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen {
return nil
}
compressMethodLen := uint16(cs[cipherSuiteHeaderLen+int(csLen)])
if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen+int(compressMethodLen) {
return nil
}
currentIndex := cipherSuiteHeaderLen + int(csLen) + compressMethodHeaderLen + int(compressMethodLen)
serverName := indexTLSServerNameFromExtensions(cs[currentIndex:])
if serverName == nil {
return nil
}
serverName.Index += currentIndex
return serverName
}
func indexTLSServerNameFromExtensions(exs []byte) *myServerName {
if len(exs) == 0 {
return nil
}
if len(exs) < extensionsHeaderLen {
return nil
}
exsLen := uint16(exs[0])<<8 | uint16(exs[1])
exs = exs[extensionsHeaderLen:]
if len(exs) < int(exsLen) {
return nil
}
for currentIndex := extensionsHeaderLen; len(exs) > 0; {
if len(exs) < extensionHeaderLen {
return nil
}
exType := uint16(exs[0])<<8 | uint16(exs[1])
exLen := uint16(exs[2])<<8 | uint16(exs[3])
if len(exs) < extensionHeaderLen+int(exLen) {
return nil
}
sex := exs[extensionHeaderLen : extensionHeaderLen+int(exLen)]
switch exType {
case sniExtensionType:
if len(sex) < sniExtensionHeaderLen {
return nil
}
sniType := sex[2]
if sniType != sniNameDNSHostnameType {
return nil
}
sniLen := uint16(sex[3])<<8 | uint16(sex[4])
sex = sex[sniExtensionHeaderLen:]
return &myServerName{
Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen,
Length: int(sniLen),
ServerName: string(sex),
}
}
exs = exs[4+exLen:]
currentIndex += 4 + int(exLen)
}
return nil
}

View File

@@ -0,0 +1,93 @@
package tf
import (
"context"
"net"
"time"
"github.com/sagernet/sing/common/control"
"golang.org/x/sys/unix"
)
/*
const tcpMaxNotifyAck = 10
type tcpNotifyAckID uint32
type tcpNotifyAckComplete struct {
NotifyPending uint32
NotifyCompleteCount uint32
NotifyCompleteID [tcpMaxNotifyAck]tcpNotifyAckID
}
var sizeOfTCPNotifyAckComplete = int(unsafe.Sizeof(tcpNotifyAckComplete{}))
func getsockoptTCPNotifyAckComplete(fd, level, opt int) (*tcpNotifyAckComplete, error) {
var value tcpNotifyAckComplete
vallen := uint32(sizeOfTCPNotifyAckComplete)
err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
return &value, err
}
//go:linkname getsockopt golang.org/x/sys/unix.getsockopt
func getsockopt(s int, level int, name int, val unsafe.Pointer, vallen *uint32) error
func waitAck(ctx context.Context, conn *net.TCPConn, _ time.Duration) error {
const TCP_NOTIFY_ACKNOWLEDGEMENT = 0x212
return control.Conn(conn, func(fd uintptr) error {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, TCP_NOTIFY_ACKNOWLEDGEMENT, 1)
if err != nil {
if errors.Is(err, unix.EINVAL) {
return waitAckFallback(ctx, conn, 0)
}
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
var ackComplete *tcpNotifyAckComplete
ackComplete, err = getsockoptTCPNotifyAckComplete(int(fd), unix.IPPROTO_TCP, TCP_NOTIFY_ACKNOWLEDGEMENT)
if err != nil {
return err
}
if ackComplete.NotifyPending == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
})
}
*/
func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
_, err := conn.Write(payload)
if err != nil {
return err
}
return control.Conn(conn, func(fd uintptr) error {
start := time.Now()
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
unacked, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_NWRITE)
if err != nil {
return err
}
if unacked == 0 {
if time.Since(start) <= 20*time.Millisecond {
// under transparent proxy
time.Sleep(fallbackDelay)
}
return nil
}
time.Sleep(10 * time.Millisecond)
}
})
}

View File

@@ -0,0 +1,40 @@
package tf
import (
"context"
"net"
"time"
"github.com/sagernet/sing/common/control"
"golang.org/x/sys/unix"
)
func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
_, err := conn.Write(payload)
if err != nil {
return err
}
return control.Conn(conn, func(fd uintptr) error {
start := time.Now()
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
tcpInfo, err := unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO)
if err != nil {
return err
}
if tcpInfo.Unacked == 0 {
if time.Since(start) <= 20*time.Millisecond {
// under transparent proxy
time.Sleep(fallbackDelay)
}
return nil
}
time.Sleep(10 * time.Millisecond)
}
})
}

View File

@@ -0,0 +1,14 @@
//go:build !(linux || darwin || windows)
package tf
import (
"context"
"net"
"time"
)
func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
time.Sleep(fallbackDelay)
return nil
}

View File

@@ -0,0 +1,28 @@
package tf
import (
"context"
"errors"
"net"
"time"
"github.com/sagernet/sing/common/winiphlpapi"
"golang.org/x/sys/windows"
)
func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error {
start := time.Now()
err := winiphlpapi.WriteAndWaitAck(ctx, conn, payload)
if err != nil {
if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
time.Sleep(fallbackDelay)
return nil
}
return err
}
if time.Since(start) <= 20*time.Millisecond {
time.Sleep(fallbackDelay)
}
return nil
}

View File

@@ -1,13 +0,0 @@
package urltest
import "context"
type contextKeyIsUnifiedDelay struct{}
func ContextWithIsUnifiedDelay(ctx context.Context) context.Context {
return context.WithValue(ctx, contextKeyIsUnifiedDelay{}, true)
}
func IsUnifiedDelayFromContext(ctx context.Context) bool {
return ctx.Value(contextKeyIsUnifiedDelay{}) != nil
}

View File

@@ -2,32 +2,32 @@ package urltest
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
)
type History struct {
Time time.Time `json:"time"`
Delay uint16 `json:"delay"`
}
var _ adapter.URLTestHistoryStorage = (*HistoryStorage)(nil)
type HistoryStorage struct {
access sync.RWMutex
delayHistory map[string]*History
delayHistory map[string]*adapter.URLTestHistory
updateHook chan<- struct{}
}
func NewHistoryStorage() *HistoryStorage {
return &HistoryStorage{
delayHistory: make(map[string]*History),
delayHistory: make(map[string]*adapter.URLTestHistory),
}
}
@@ -35,7 +35,7 @@ func (s *HistoryStorage) SetHook(hook chan<- struct{}) {
s.updateHook = hook
}
func (s *HistoryStorage) LoadURLTestHistory(tag string) *History {
func (s *HistoryStorage) LoadURLTestHistory(tag string) *adapter.URLTestHistory {
if s == nil {
return nil
}
@@ -51,7 +51,7 @@ func (s *HistoryStorage) DeleteURLTestHistory(tag string) {
s.notifyUpdated()
}
func (s *HistoryStorage) StoreURLTestHistory(tag string, history *History) {
func (s *HistoryStorage) StoreURLTestHistory(tag string, history *adapter.URLTestHistory) {
s.access.Lock()
s.delayHistory[tag] = history
s.access.Unlock()
@@ -110,6 +110,10 @@ func URLTest(ctx context.Context, link string, detour N.Dialer) (t uint16, err e
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return instance, nil
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(ctx),
RootCAs: adapter.RootPoolFromContext(ctx),
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
@@ -122,15 +126,6 @@ func URLTest(ctx context.Context, link string, detour N.Dialer) (t uint16, err e
return
}
resp.Body.Close()
if IsUnifiedDelayFromContext(ctx) {
second := time.Now()
resp, err = client.Do(req)
if err != nil {
return
}
resp.Body.Close()
start = second
}
t = uint16(time.Since(start) / time.Millisecond)
return
}

View File

@@ -1,337 +0,0 @@
package buf
import (
"io"
"github.com/sagernet/sing-box/common/xray/bytespool"
"github.com/sagernet/sing-box/common/xray/net"
E "github.com/sagernet/sing/common/exceptions"
)
const (
// Size of a regular buffer.
Size = 8192
)
var zero = [Size * 10]byte{0}
var pool = bytespool.GetPool(Size)
// ownership represents the data owner of the buffer.
type ownership uint8
const (
managed ownership = iota
unmanaged
bytespools
)
// Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles
// the buffer into an internal buffer pool, in order to recreate a buffer more
// quickly.
type Buffer struct {
v []byte
start int32
end int32
ownership ownership
UDP *net.Destination
}
// New creates a Buffer with 0 length and 8K capacity, managed.
func New() *Buffer {
buf := pool.Get().([]byte)
if cap(buf) >= Size {
buf = buf[:Size]
} else {
buf = make([]byte, Size)
}
return &Buffer{
v: buf,
}
}
// NewExisted creates a standard size Buffer with an existed bytearray, managed.
func NewExisted(b []byte) *Buffer {
if cap(b) < Size {
panic("Invalid buffer")
}
oLen := len(b)
if oLen < Size {
b = b[:Size]
}
return &Buffer{
v: b,
end: int32(oLen),
}
}
// FromBytes creates a Buffer with an existed bytearray, unmanaged.
func FromBytes(b []byte) *Buffer {
return &Buffer{
v: b,
end: int32(len(b)),
ownership: unmanaged,
}
}
// StackNew creates a new Buffer object on stack, managed.
// This method is for buffers that is released in the same function.
func StackNew() Buffer {
buf := pool.Get().([]byte)
if cap(buf) >= Size {
buf = buf[:Size]
} else {
buf = make([]byte, Size)
}
return Buffer{
v: buf,
}
}
// NewWithSize creates a Buffer with 0 length and capacity with at least the given size, bytespool's.
func NewWithSize(size int32) *Buffer {
return &Buffer{
v: bytespool.Alloc(size),
ownership: bytespools,
}
}
// Release recycles the buffer into an internal buffer pool.
func (b *Buffer) Release() {
if b == nil || b.v == nil || b.ownership == unmanaged {
return
}
p := b.v
b.v = nil
b.Clear()
switch b.ownership {
case managed:
if cap(p) == Size {
pool.Put(p)
}
case bytespools:
bytespool.Free(p)
}
b.UDP = nil
}
// Clear clears the content of the buffer, results an empty buffer with
// Len() = 0.
func (b *Buffer) Clear() {
b.start = 0
b.end = 0
}
// Byte returns the bytes at index.
func (b *Buffer) Byte(index int32) byte {
return b.v[b.start+index]
}
// SetByte sets the byte value at index.
func (b *Buffer) SetByte(index int32, value byte) {
b.v[b.start+index] = value
}
// Bytes returns the content bytes of this Buffer.
func (b *Buffer) Bytes() []byte {
return b.v[b.start:b.end]
}
// Extend increases the buffer size by n bytes, and returns the extended part.
// It panics if result size is larger than buf.Size.
func (b *Buffer) Extend(n int32) []byte {
end := b.end + n
if end > int32(len(b.v)) {
panic("extending out of bound")
}
ext := b.v[b.end:end]
b.end = end
copy(ext, zero[:])
return ext
}
// BytesRange returns a slice of this buffer with given from and to boundary.
func (b *Buffer) BytesRange(from, to int32) []byte {
if from < 0 {
from += b.Len()
}
if to < 0 {
to += b.Len()
}
return b.v[b.start+from : b.start+to]
}
// BytesFrom returns a slice of this Buffer starting from the given position.
func (b *Buffer) BytesFrom(from int32) []byte {
if from < 0 {
from += b.Len()
}
return b.v[b.start+from : b.end]
}
// BytesTo returns a slice of this Buffer from start to the given position.
func (b *Buffer) BytesTo(to int32) []byte {
if to < 0 {
to += b.Len()
}
if to < 0 {
to = 0
}
return b.v[b.start : b.start+to]
}
// Check makes sure that 0 <= b.start <= b.end.
func (b *Buffer) Check() {
if b.start < 0 {
b.start = 0
}
if b.end < 0 {
b.end = 0
}
if b.start > b.end {
b.start = b.end
}
}
// Resize cuts the buffer at the given position.
func (b *Buffer) Resize(from, to int32) {
oldEnd := b.end
if from < 0 {
from += b.Len()
}
if to < 0 {
to += b.Len()
}
if to < from {
panic("Invalid slice")
}
b.end = b.start + to
b.start += from
b.Check()
if b.end > oldEnd {
copy(b.v[oldEnd:b.end], zero[:])
}
}
// Advance cuts the buffer at the given position.
func (b *Buffer) Advance(from int32) {
if from < 0 {
from += b.Len()
}
b.start += from
b.Check()
}
// Len returns the length of the buffer content.
func (b *Buffer) Len() int32 {
if b == nil {
return 0
}
return b.end - b.start
}
// Cap returns the capacity of the buffer content.
func (b *Buffer) Cap() int32 {
if b == nil {
return 0
}
return int32(len(b.v))
}
// IsEmpty returns true if the buffer is empty.
func (b *Buffer) IsEmpty() bool {
return b.Len() == 0
}
// IsFull returns true if the buffer has no more room to grow.
func (b *Buffer) IsFull() bool {
return b != nil && b.end == int32(len(b.v))
}
// Write implements Write method in io.Writer.
func (b *Buffer) Write(data []byte) (int, error) {
nBytes := copy(b.v[b.end:], data)
b.end += int32(nBytes)
return nBytes, nil
}
// WriteByte writes a single byte into the buffer.
func (b *Buffer) WriteByte(v byte) error {
if b.IsFull() {
return E.New("buffer full")
}
b.v[b.end] = v
b.end++
return nil
}
// WriteString implements io.StringWriter.
func (b *Buffer) WriteString(s string) (int, error) {
return b.Write([]byte(s))
}
// ReadByte implements io.ByteReader
func (b *Buffer) ReadByte() (byte, error) {
if b.start == b.end {
return 0, io.EOF
}
nb := b.v[b.start]
b.start++
return nb, nil
}
// ReadBytes implements bufio.Reader.ReadBytes
func (b *Buffer) ReadBytes(length int32) ([]byte, error) {
if b.end-b.start < length {
return nil, io.EOF
}
nb := b.v[b.start : b.start+length]
b.start += length
return nb, nil
}
// Read implements io.Reader.Read().
func (b *Buffer) Read(data []byte) (int, error) {
if b.Len() == 0 {
return 0, io.EOF
}
nBytes := copy(data, b.v[b.start:b.end])
if int32(nBytes) == b.Len() {
b.Clear()
} else {
b.start += int32(nBytes)
}
return nBytes, nil
}
// ReadFrom implements io.ReaderFrom.
func (b *Buffer) ReadFrom(reader io.Reader) (int64, error) {
n, err := reader.Read(b.v[b.end:])
b.end += int32(n)
return int64(n), err
}
// ReadFullFrom reads exact size of bytes from given reader, or until error occurs.
func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) {
end := b.end + size
if end > int32(len(b.v)) {
v := end
return 0, E.New("out of bound: ", v)
}
n, err := io.ReadFull(reader, b.v[b.end:end])
b.end += int32(n)
return int64(n), err
}
// String returns the string form of this Buffer.
func (b *Buffer) String() string {
return string(b.Bytes())
}

View File

@@ -1,124 +0,0 @@
package buf
import (
"io"
"time"
"github.com/sagernet/sing-box/common/xray/errors"
"github.com/sagernet/sing-box/common/xray/signal"
E "github.com/sagernet/sing/common/exceptions"
)
type dataHandler func(MultiBuffer)
type copyHandler struct {
onData []dataHandler
}
// SizeCounter is for counting bytes copied by Copy().
type SizeCounter struct {
Size int64
}
// CopyOption is an option for copying data.
type CopyOption func(*copyHandler)
// UpdateActivity is a CopyOption to update activity on each data copy operation.
func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(MultiBuffer) {
timer.Update()
})
}
}
// CountSize is a CopyOption that sums the total size of data copied into the given SizeCounter.
func CountSize(sc *SizeCounter) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(b MultiBuffer) {
sc.Size += int64(b.Len())
})
}
}
type readError struct {
error
}
func (e readError) Error() string {
return e.error.Error()
}
func (e readError) Unwrap() error {
return e.error
}
// IsReadError returns true if the error in Copy() comes from reading.
func IsReadError(err error) bool {
_, ok := err.(readError)
return ok
}
type writeError struct {
error
}
func (e writeError) Error() string {
return e.error.Error()
}
func (e writeError) Unwrap() error {
return e.error
}
// IsWriteError returns true if the error in Copy() comes from writing.
func IsWriteError(err error) bool {
_, ok := err.(writeError)
return ok
}
func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
for {
buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
for _, handler := range handler.onData {
handler(buffer)
}
if werr := writer.WriteMultiBuffer(buffer); werr != nil {
return writeError{werr}
}
}
if err != nil {
return readError{err}
}
}
}
// Copy dumps all payload from reader to writer or stops when an error occurs. It returns nil when EOF.
func Copy(reader Reader, writer Writer, options ...CopyOption) error {
var handler copyHandler
for _, option := range options {
option(&handler)
}
err := copyInternal(reader, writer, &handler)
if err != nil && errors.Cause(err) != io.EOF {
return err
}
return nil
}
var ErrNotTimeoutReader = E.New("not a TimeoutReader")
func CopyOnceTimeout(reader Reader, writer Writer, timeout time.Duration) error {
timeoutReader, ok := reader.(TimeoutReader)
if !ok {
return ErrNotTimeoutReader
}
mb, err := timeoutReader.ReadMultiBufferTimeout(timeout)
if err != nil {
return err
}
return writer.WriteMultiBuffer(mb)
}

View File

@@ -1,126 +0,0 @@
package buf
import (
"io"
"net"
"syscall"
"time"
"github.com/sagernet/sing-box/common/xray/stat"
"github.com/sagernet/sing-box/common/xray/stats"
E "github.com/sagernet/sing/common/exceptions"
)
// Reader extends io.Reader with MultiBuffer.
type Reader interface {
// ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
ReadMultiBuffer() (MultiBuffer, error)
}
// ErrReadTimeout is an error that happens with IO timeout.
var ErrReadTimeout = E.New("IO timeout")
// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
type TimeoutReader interface {
ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
}
// Writer extends io.Writer with MultiBuffer.
type Writer interface {
// WriteMultiBuffer writes a MultiBuffer into underlying writer.
WriteMultiBuffer(MultiBuffer) error
}
// WriteAllBytes ensures all bytes are written into the given writer.
func WriteAllBytes(writer io.Writer, payload []byte, c stats.Counter) error {
wc := 0
defer func() {
if c != nil {
c.Add(int64(wc))
}
}()
for len(payload) > 0 {
n, err := writer.Write(payload)
wc += n
if err != nil {
return err
}
payload = payload[n:]
}
return nil
}
func isPacketReader(reader io.Reader) bool {
_, ok := reader.(net.PacketConn)
return ok
}
// NewReader creates a new Reader.
// The Reader instance doesn't take the ownership of reader.
func NewReader(reader io.Reader) Reader {
if mr, ok := reader.(Reader); ok {
return mr
}
if isPacketReader(reader) {
return &PacketReader{
Reader: reader,
}
}
return &SingleReader{
Reader: reader,
}
}
// NewPacketReader creates a new PacketReader based on the given reader.
func NewPacketReader(reader io.Reader) Reader {
if mr, ok := reader.(Reader); ok {
return mr
}
return &PacketReader{
Reader: reader,
}
}
func isPacketWriter(writer io.Writer) bool {
if _, ok := writer.(net.PacketConn); ok {
return true
}
// If the writer doesn't implement syscall.Conn, it is probably not a TCP connection.
if _, ok := writer.(syscall.Conn); !ok {
return true
}
return false
}
// NewWriter creates a new Writer.
func NewWriter(writer io.Writer) Writer {
if mw, ok := writer.(Writer); ok {
return mw
}
iConn := writer
if statConn, ok := writer.(*stat.CounterConnection); ok {
iConn = statConn.Connection
}
if isPacketWriter(iConn) {
return &SequentialWriter{
Writer: writer,
}
}
var counter stats.Counter
if statConn, ok := writer.(*stat.CounterConnection); ok {
counter = statConn.WriteCounter
}
return &BufferToBytesWriter{
Writer: iConn,
counter: counter,
}
}

View File

@@ -1,310 +0,0 @@
package buf
import (
"io"
"github.com/sagernet/sing-box/common/xray"
"github.com/sagernet/sing-box/common/xray/errors"
"github.com/sagernet/sing-box/common/xray/serial"
)
// ReadAllToBytes reads all content from the reader into a byte array, until EOF.
func ReadAllToBytes(reader io.Reader) ([]byte, error) {
mb, err := ReadFrom(reader)
if err != nil {
return nil, err
}
if mb.Len() == 0 {
return nil, nil
}
b := make([]byte, mb.Len())
mb, _ = SplitBytes(mb, b)
ReleaseMulti(mb)
return b, nil
}
// MultiBuffer is a list of Buffers. The order of Buffer matters.
type MultiBuffer []*Buffer
// MergeMulti merges content from src to dest, and returns the new address of dest and src
func MergeMulti(dest MultiBuffer, src MultiBuffer) (MultiBuffer, MultiBuffer) {
dest = append(dest, src...)
for idx := range src {
src[idx] = nil
}
return dest, src[:0]
}
// MergeBytes merges the given bytes into MultiBuffer and return the new address of the merged MultiBuffer.
func MergeBytes(dest MultiBuffer, src []byte) MultiBuffer {
n := len(dest)
if n > 0 && !(dest)[n-1].IsFull() {
nBytes, _ := (dest)[n-1].Write(src)
src = src[nBytes:]
}
for len(src) > 0 {
b := New()
nBytes, _ := b.Write(src)
src = src[nBytes:]
dest = append(dest, b)
}
return dest
}
// ReleaseMulti releases all content of the MultiBuffer, and returns an empty MultiBuffer.
func ReleaseMulti(mb MultiBuffer) MultiBuffer {
for i := range mb {
mb[i].Release()
mb[i] = nil
}
return mb[:0]
}
// Copy copied the beginning part of the MultiBuffer into the given byte array.
func (mb MultiBuffer) Copy(b []byte) int {
total := 0
for _, bb := range mb {
nBytes := copy(b[total:], bb.Bytes())
total += nBytes
if int32(nBytes) < bb.Len() {
break
}
}
return total
}
// ReadFrom reads all content from reader until EOF.
func ReadFrom(reader io.Reader) (MultiBuffer, error) {
mb := make(MultiBuffer, 0, 16)
for {
b := New()
_, err := b.ReadFullFrom(reader, Size)
if b.IsEmpty() {
b.Release()
} else {
mb = append(mb, b)
}
if err != nil {
if errors.Cause(err) == io.EOF || errors.Cause(err) == io.ErrUnexpectedEOF {
return mb, nil
}
return mb, err
}
}
}
// SplitBytes splits the given amount of bytes from the beginning of the MultiBuffer.
// It returns the new address of MultiBuffer leftover, and number of bytes written into the input byte slice.
func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int) {
totalBytes := 0
endIndex := -1
for i := range mb {
pBuffer := mb[i]
nBytes, _ := pBuffer.Read(b)
totalBytes += nBytes
b = b[nBytes:]
if !pBuffer.IsEmpty() {
endIndex = i
break
}
pBuffer.Release()
mb[i] = nil
}
if endIndex == -1 {
mb = mb[:0]
} else {
mb = mb[endIndex:]
}
return mb, totalBytes
}
// SplitFirstBytes splits the first buffer from MultiBuffer, and then copy its content into the given slice.
func SplitFirstBytes(mb MultiBuffer, p []byte) (MultiBuffer, int) {
mb, b := SplitFirst(mb)
if b == nil {
return mb, 0
}
n := copy(p, b.Bytes())
b.Release()
return mb, n
}
// Compact returns another MultiBuffer by merging all content of the given one together.
func Compact(mb MultiBuffer) MultiBuffer {
if len(mb) == 0 {
return mb
}
mb2 := make(MultiBuffer, 0, len(mb))
last := mb[0]
for i := 1; i < len(mb); i++ {
curr := mb[i]
if last.Len()+curr.Len() > Size {
mb2 = append(mb2, last)
last = curr
} else {
common.Must2(last.ReadFrom(curr))
curr.Release()
}
}
mb2 = append(mb2, last)
return mb2
}
// SplitFirst splits the first Buffer from the beginning of the MultiBuffer.
func SplitFirst(mb MultiBuffer) (MultiBuffer, *Buffer) {
if len(mb) == 0 {
return mb, nil
}
b := mb[0]
mb[0] = nil
mb = mb[1:]
return mb, b
}
// SplitSize splits the beginning of the MultiBuffer into another one, for at most size bytes.
func SplitSize(mb MultiBuffer, size int32) (MultiBuffer, MultiBuffer) {
if len(mb) == 0 {
return mb, nil
}
if mb[0].Len() > size {
b := New()
copy(b.Extend(size), mb[0].BytesTo(size))
mb[0].Advance(size)
return mb, MultiBuffer{b}
}
totalBytes := int32(0)
var r MultiBuffer
endIndex := -1
for i := range mb {
if totalBytes+mb[i].Len() > size {
endIndex = i
break
}
totalBytes += mb[i].Len()
r = append(r, mb[i])
mb[i] = nil
}
if endIndex == -1 {
// To reuse mb array
mb = mb[:0]
} else {
mb = mb[endIndex:]
}
return mb, r
}
// SplitMulti splits the beginning of the MultiBuffer into first one, the index i and after into second one
func SplitMulti(mb MultiBuffer, i int) (MultiBuffer, MultiBuffer) {
mb2 := make(MultiBuffer, 0, len(mb))
if i < len(mb) && i >= 0 {
mb2 = append(mb2, mb[i:]...)
for j := i; j < len(mb); j++ {
mb[j] = nil
}
mb = mb[:i]
}
return mb, mb2
}
// WriteMultiBuffer writes all buffers from the MultiBuffer to the Writer one by one, and return error if any, with leftover MultiBuffer.
func WriteMultiBuffer(writer io.Writer, mb MultiBuffer) (MultiBuffer, error) {
for {
mb2, b := SplitFirst(mb)
mb = mb2
if b == nil {
break
}
_, err := writer.Write(b.Bytes())
b.Release()
if err != nil {
return mb, err
}
}
return nil, nil
}
// Len returns the total number of bytes in the MultiBuffer.
func (mb MultiBuffer) Len() int32 {
if mb == nil {
return 0
}
size := int32(0)
for _, b := range mb {
size += b.Len()
}
return size
}
// IsEmpty returns true if the MultiBuffer has no content.
func (mb MultiBuffer) IsEmpty() bool {
for _, b := range mb {
if !b.IsEmpty() {
return false
}
}
return true
}
// String returns the content of the MultiBuffer in string.
func (mb MultiBuffer) String() string {
v := make([]interface{}, len(mb))
for i, b := range mb {
v[i] = b
}
return serial.Concat(v...)
}
// MultiBufferContainer is a ReadWriteCloser wrapper over MultiBuffer.
type MultiBufferContainer struct {
MultiBuffer
}
// Read implements io.Reader.
func (c *MultiBufferContainer) Read(b []byte) (int, error) {
if c.MultiBuffer.IsEmpty() {
return 0, io.EOF
}
mb, nBytes := SplitBytes(c.MultiBuffer, b)
c.MultiBuffer = mb
return nBytes, nil
}
// ReadMultiBuffer implements Reader.
func (c *MultiBufferContainer) ReadMultiBuffer() (MultiBuffer, error) {
mb := c.MultiBuffer
c.MultiBuffer = nil
return mb, nil
}
// Write implements io.Writer.
func (c *MultiBufferContainer) Write(b []byte) (int, error) {
c.MultiBuffer = MergeBytes(c.MultiBuffer, b)
return len(b), nil
}
// WriteMultiBuffer implements Writer.
func (c *MultiBufferContainer) WriteMultiBuffer(b MultiBuffer) error {
mb, _ := MergeMulti(c.MultiBuffer, b)
c.MultiBuffer = mb
return nil
}
// Close implements io.Closer.
func (c *MultiBufferContainer) Close() error {
c.MultiBuffer = ReleaseMulti(c.MultiBuffer)
return nil
}

View File

@@ -1,38 +0,0 @@
package buf
import (
"github.com/sagernet/sing-box/common/xray/net"
)
type EndpointOverrideReader struct {
Reader
Dest net.Address
OriginalDest net.Address
}
func (r *EndpointOverrideReader) ReadMultiBuffer() (MultiBuffer, error) {
mb, err := r.Reader.ReadMultiBuffer()
if err == nil {
for _, b := range mb {
if b.UDP != nil && b.UDP.Address == r.OriginalDest {
b.UDP.Address = r.Dest
}
}
}
return mb, err
}
type EndpointOverrideWriter struct {
Writer
Dest net.Address
OriginalDest net.Address
}
func (w *EndpointOverrideWriter) WriteMultiBuffer(mb MultiBuffer) error {
for _, b := range mb {
if b.UDP != nil && b.UDP.Address == w.Dest {
b.UDP.Address = w.OriginalDest
}
}
return w.Writer.WriteMultiBuffer(mb)
}

View File

@@ -1,175 +0,0 @@
package buf
import (
"io"
"github.com/sagernet/sing-box/common/xray"
"github.com/sagernet/sing-box/common/xray/errors"
E "github.com/sagernet/sing/common/exceptions"
)
func readOneUDP(r io.Reader) (*Buffer, error) {
b := New()
for i := 0; i < 64; i++ {
_, err := b.ReadFrom(r)
if !b.IsEmpty() {
return b, nil
}
if err != nil {
b.Release()
return nil, err
}
}
b.Release()
return nil, E.New("Reader returns too many empty payloads.")
}
// ReadBuffer reads a Buffer from the given reader.
func ReadBuffer(r io.Reader) (*Buffer, error) {
b := New()
n, err := b.ReadFrom(r)
if n > 0 {
return b, err
}
b.Release()
return nil, err
}
// BufferedReader is a Reader that keeps its internal buffer.
type BufferedReader struct {
// Reader is the underlying reader to be read from
Reader Reader
// Buffer is the internal buffer to be read from first
Buffer MultiBuffer
// Splitter is a function to read bytes from MultiBuffer
Splitter func(MultiBuffer, []byte) (MultiBuffer, int)
}
// BufferedBytes returns the number of bytes that is cached in this reader.
func (r *BufferedReader) BufferedBytes() int32 {
return r.Buffer.Len()
}
// ReadByte implements io.ByteReader.
func (r *BufferedReader) ReadByte() (byte, error) {
var b [1]byte
_, err := r.Read(b[:])
return b[0], err
}
// Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
func (r *BufferedReader) Read(b []byte) (int, error) {
spliter := r.Splitter
if spliter == nil {
spliter = SplitBytes
}
if !r.Buffer.IsEmpty() {
buffer, nBytes := spliter(r.Buffer, b)
r.Buffer = buffer
if r.Buffer.IsEmpty() {
r.Buffer = nil
}
return nBytes, nil
}
mb, err := r.Reader.ReadMultiBuffer()
if err != nil {
return 0, err
}
mb, nBytes := spliter(mb, b)
if !mb.IsEmpty() {
r.Buffer = mb
}
return nBytes, nil
}
// ReadMultiBuffer implements Reader.
func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
if !r.Buffer.IsEmpty() {
mb := r.Buffer
r.Buffer = nil
return mb, nil
}
return r.Reader.ReadMultiBuffer()
}
// ReadAtMost returns a MultiBuffer with at most size.
func (r *BufferedReader) ReadAtMost(size int32) (MultiBuffer, error) {
if r.Buffer.IsEmpty() {
mb, err := r.Reader.ReadMultiBuffer()
if mb.IsEmpty() && err != nil {
return nil, err
}
r.Buffer = mb
}
rb, mb := SplitSize(r.Buffer, size)
r.Buffer = rb
if r.Buffer.IsEmpty() {
r.Buffer = nil
}
return mb, nil
}
func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
mbWriter := NewWriter(writer)
var sc SizeCounter
if r.Buffer != nil {
sc.Size = int64(r.Buffer.Len())
if err := mbWriter.WriteMultiBuffer(r.Buffer); err != nil {
return 0, err
}
r.Buffer = nil
}
err := Copy(r.Reader, mbWriter, CountSize(&sc))
return sc.Size, err
}
// WriteTo implements io.WriterTo.
func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
nBytes, err := r.writeToInternal(writer)
if errors.Cause(err) == io.EOF {
return nBytes, nil
}
return nBytes, err
}
// Interrupt implements common.Interruptible.
func (r *BufferedReader) Interrupt() {
common.Interrupt(r.Reader)
}
// Close implements io.Closer.
func (r *BufferedReader) Close() error {
return common.Close(r.Reader)
}
// SingleReader is a Reader that read one Buffer every time.
type SingleReader struct {
io.Reader
}
// ReadMultiBuffer implements Reader.
func (r *SingleReader) ReadMultiBuffer() (MultiBuffer, error) {
b, err := ReadBuffer(r.Reader)
return MultiBuffer{b}, err
}
// PacketReader is a Reader that read one Buffer every time.
type PacketReader struct {
io.Reader
}
// ReadMultiBuffer implements Reader.
func (r *PacketReader) ReadMultiBuffer() (MultiBuffer, error) {
b, err := readOneUDP(r.Reader)
if err != nil {
return nil, err
}
return MultiBuffer{b}, nil
}

View File

@@ -1,270 +0,0 @@
package buf
import (
"io"
"net"
"sync"
"github.com/sagernet/sing-box/common/xray"
"github.com/sagernet/sing-box/common/xray/errors"
"github.com/sagernet/sing-box/common/xray/stats"
)
// BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
type BufferToBytesWriter struct {
io.Writer
counter stats.Counter
cache [][]byte
}
// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer ReleaseMulti(mb)
size := mb.Len()
if size == 0 {
return nil
}
if len(mb) == 1 {
return WriteAllBytes(w.Writer, mb[0].Bytes(), w.counter)
}
if cap(w.cache) < len(mb) {
w.cache = make([][]byte, 0, len(mb))
}
bs := w.cache
for _, b := range mb {
bs = append(bs, b.Bytes())
}
defer func() {
for idx := range bs {
bs[idx] = nil
}
}()
nb := net.Buffers(bs)
wc := int64(0)
defer func() {
if w.counter != nil {
w.counter.Add(wc)
}
}()
for size > 0 {
n, err := nb.WriteTo(w.Writer)
wc += n
if err != nil {
return err
}
size -= int32(n)
}
return nil
}
// ReadFrom implements io.ReaderFrom.
func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
// BufferedWriter is a Writer with internal buffer.
type BufferedWriter struct {
sync.Mutex
writer Writer
buffer *Buffer
buffered bool
}
// NewBufferedWriter creates a new BufferedWriter.
func NewBufferedWriter(writer Writer) *BufferedWriter {
return &BufferedWriter{
writer: writer,
buffer: New(),
buffered: true,
}
}
// WriteByte implements io.ByteWriter.
func (w *BufferedWriter) WriteByte(c byte) error {
return common.Error2(w.Write([]byte{c}))
}
// Write implements io.Writer.
func (w *BufferedWriter) Write(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
w.Lock()
defer w.Unlock()
if !w.buffered {
if writer, ok := w.writer.(io.Writer); ok {
return writer.Write(b)
}
}
totalBytes := 0
for len(b) > 0 {
if w.buffer == nil {
w.buffer = New()
}
nBytes, err := w.buffer.Write(b)
totalBytes += nBytes
if err != nil {
return totalBytes, err
}
if !w.buffered || w.buffer.IsFull() {
if err := w.flushInternal(); err != nil {
return totalBytes, err
}
}
b = b[nBytes:]
}
return totalBytes, nil
}
// WriteMultiBuffer implements Writer. It takes ownership of the given MultiBuffer.
func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
if b.IsEmpty() {
return nil
}
w.Lock()
defer w.Unlock()
if !w.buffered {
return w.writer.WriteMultiBuffer(b)
}
reader := MultiBufferContainer{
MultiBuffer: b,
}
defer reader.Close()
for !reader.MultiBuffer.IsEmpty() {
if w.buffer == nil {
w.buffer = New()
}
common.Must2(w.buffer.ReadFrom(&reader))
if w.buffer.IsFull() {
if err := w.flushInternal(); err != nil {
return err
}
}
}
return nil
}
// Flush flushes buffered content into underlying writer.
func (w *BufferedWriter) Flush() error {
w.Lock()
defer w.Unlock()
return w.flushInternal()
}
func (w *BufferedWriter) flushInternal() error {
if w.buffer.IsEmpty() {
return nil
}
b := w.buffer
w.buffer = nil
if writer, ok := w.writer.(io.Writer); ok {
err := WriteAllBytes(writer, b.Bytes(), nil)
b.Release()
return err
}
return w.writer.WriteMultiBuffer(MultiBuffer{b})
}
// SetBuffered sets whether the internal buffer is used. If set to false, Flush() will be called to clear the buffer.
func (w *BufferedWriter) SetBuffered(f bool) error {
w.Lock()
defer w.Unlock()
w.buffered = f
if !f {
return w.flushInternal()
}
return nil
}
// ReadFrom implements io.ReaderFrom.
func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
if err := w.SetBuffered(false); err != nil {
return 0, err
}
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
// Close implements io.Closable.
func (w *BufferedWriter) Close() error {
if err := w.Flush(); err != nil {
return err
}
return common.Close(w.writer)
}
// SequentialWriter is a Writer that writes MultiBuffer sequentially into the underlying io.Writer.
type SequentialWriter struct {
io.Writer
}
// WriteMultiBuffer implements Writer.
func (w *SequentialWriter) WriteMultiBuffer(mb MultiBuffer) error {
mb, err := WriteMultiBuffer(w.Writer, mb)
ReleaseMulti(mb)
return err
}
type noOpWriter byte
func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
ReleaseMulti(b)
return nil
}
func (noOpWriter) Write(b []byte) (int, error) {
return len(b), nil
}
func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
b := New()
defer b.Release()
totalBytes := int64(0)
for {
b.Clear()
_, err := b.ReadFrom(reader)
totalBytes += int64(b.Len())
if err != nil {
if errors.Cause(err) == io.EOF {
return totalBytes, nil
}
return totalBytes, err
}
}
}
var (
// Discard is a Writer that swallows all contents written in.
Discard Writer = noOpWriter(0)
// DiscardBytes is an io.Writer that swallows all contents written in.
DiscardBytes io.Writer = noOpWriter(0)
)

View File

@@ -1,72 +0,0 @@
package bytespool
import "sync"
func createAllocFunc(size int32) func() interface{} {
return func() interface{} {
return make([]byte, size)
}
}
// The following parameters controls the size of buffer pools.
// There are numPools pools. Starting from 2k size, the size of each pool is sizeMulti of the previous one.
// Package buf is guaranteed to not use buffers larger than the largest pool.
// Other packets may use larger buffers.
const (
numPools = 4
sizeMulti = 4
)
var (
pool [numPools]sync.Pool
poolSize [numPools]int32
)
func init() {
size := int32(2048)
for i := 0; i < numPools; i++ {
pool[i] = sync.Pool{
New: createAllocFunc(size),
}
poolSize[i] = size
size *= sizeMulti
}
}
// GetPool returns a sync.Pool that generates bytes array with at least the given size.
// It may return nil if no such pool exists.
//
// xray:api:stable
func GetPool(size int32) *sync.Pool {
for idx, ps := range poolSize {
if size <= ps {
return &pool[idx]
}
}
return nil
}
// Alloc returns a byte slice with at least the given size. Minimum size of returned slice is 2048.
//
// xray:api:stable
func Alloc(size int32) []byte {
pool := GetPool(size)
if pool != nil {
return pool.Get().([]byte)
}
return make([]byte, size)
}
// Free puts a byte slice into the internal pool.
//
// xray:api:stable
func Free(b []byte) {
size := int32(cap(b))
b = b[0:cap(b)]
for i := numPools - 1; i >= 0; i-- {
if size >= poolSize[i] {
pool[i].Put(b)
return
}
}
}

View File

@@ -1,19 +0,0 @@
package common
// Must panics if err is not nil.
func Must(err error) {
if err != nil {
panic(err)
}
}
// Must2 panics if the second parameter is not nil, otherwise returns the first parameter.
func Must2(v interface{}, err error) interface{} {
Must(err)
return v
}
// Error2 returns the err from the 2nd parameter.
func Error2(v interface{}, err error) error {
return err
}

View File

@@ -1,14 +0,0 @@
package crypto
import (
"crypto/rand"
"math/big"
)
func RandBetween(from int64, to int64) int64 {
if from == to {
return from
}
bigInt, _ := rand.Int(rand.Reader, big.NewInt(to-from))
return from + bigInt.Int64()
}

Some files were not shown because too many files have changed in this diff Show More