Add NDIS inbound

This commit is contained in:
世界
2025-01-03 18:34:07 +08:00
parent e483c909b4
commit 79d3649a8b
32 changed files with 1339 additions and 572 deletions

110
protocol/ndis/endpoint.go Normal file
View File

@@ -0,0 +1,110 @@
//go:build windows
package ndis
import (
"sync"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/wiresock/ndisapi-go"
"github.com/wiresock/ndisapi-go/driver"
)
var _ stack.LinkEndpoint = (*ndisEndpoint)(nil)
type ndisEndpoint struct {
filter *driver.QueuedPacketFilter
mtu uint32
address tcpip.LinkAddress
dispatcher stack.NetworkDispatcher
}
func (e *ndisEndpoint) MTU() uint32 {
return e.mtu
}
func (e *ndisEndpoint) SetMTU(mtu uint32) {
}
func (e *ndisEndpoint) MaxHeaderLength() uint16 {
return header.EthernetMinimumSize
}
func (e *ndisEndpoint) LinkAddress() tcpip.LinkAddress {
return e.address
}
func (e *ndisEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
}
func (e *ndisEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return 0
}
func (e *ndisEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
func (e *ndisEndpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *ndisEndpoint) Wait() {
}
func (e *ndisEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareEther
}
func (e *ndisEndpoint) AddHeader(pkt *stack.PacketBuffer) {
eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
fields := header.EthernetFields{
SrcAddr: pkt.EgressRoute.LocalLinkAddress,
DstAddr: pkt.EgressRoute.RemoteLinkAddress,
Type: pkt.NetworkProtocolNumber,
}
eth.Encode(&fields)
}
func (e *ndisEndpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
_, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
return ok
}
func (e *ndisEndpoint) Close() {
}
func (e *ndisEndpoint) SetOnCloseAction(f func()) {
}
var bufferPool = sync.Pool{
New: func() any {
return new(ndisapi.IntermediateBuffer)
},
}
func (e *ndisEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
ndisBuf := bufferPool.Get().(*ndisapi.IntermediateBuffer)
viewList, offset := packetBuffer.AsViewList()
var view *buffer.View
for view = viewList.Front(); view != nil && offset >= view.Size(); view = view.Next() {
offset -= view.Size()
}
index := copy(ndisBuf.Buffer[:], view.AsSlice()[offset:])
for view = view.Next(); view != nil; view = view.Next() {
index += copy(ndisBuf.Buffer[index:], view.AsSlice())
}
ndisBuf.Length = uint32(index)
err := e.filter.InsertPacketToMstcp(ndisBuf)
bufferPool.Put(ndisBuf)
if err != nil {
return 0, &tcpip.ErrAborted{}
}
}
return list.Len(), nil
}

203
protocol/ndis/inbound.go Normal file
View File

@@ -0,0 +1,203 @@
//go:build windows
package ndis
import (
"context"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/wiresock/ndisapi-go"
"go4.org/netipx"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.NDISInboundOptions](registry, C.TypeNDIS, NewInbound)
}
type Inbound struct {
inbound.Adapter
ctx context.Context
router adapter.Router
logger log.ContextLogger
api *ndisapi.NdisApi
tracker conntrack.Tracker
routeAddress []netip.Prefix
routeExcludeAddress []netip.Prefix
routeRuleSet []adapter.RuleSet
routeRuleSetCallback []*list.Element[adapter.RuleSetUpdateCallback]
routeExcludeRuleSet []adapter.RuleSet
routeExcludeRuleSetCallback []*list.Element[adapter.RuleSetUpdateCallback]
stack *Stack
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NDISInboundOptions) (adapter.Inbound, error) {
api, err := ndisapi.NewNdisApi()
if err != nil {
return nil, E.Cause(err, "create NDIS API")
}
//if !api.IsDriverLoaded() {
// return nil, E.New("missing NDIS driver")
//}
networkManager := service.FromContext[adapter.NetworkManager](ctx)
trackerOut := service.FromContext[conntrack.Tracker](ctx)
var udpTimeout time.Duration
if options.UDPTimeout != 0 {
udpTimeout = time.Duration(options.UDPTimeout)
} else {
udpTimeout = C.UDPTimeout
}
var (
routeRuleSet []adapter.RuleSet
routeExcludeRuleSet []adapter.RuleSet
)
for _, routeAddressSet := range options.RouteAddressSet {
ruleSet, loaded := router.RuleSet(routeAddressSet)
if !loaded {
return nil, E.New("parse route_address_set: rule-set not found: ", routeAddressSet)
}
ruleSet.IncRef()
routeRuleSet = append(routeRuleSet, ruleSet)
}
for _, routeExcludeAddressSet := range options.RouteExcludeAddressSet {
ruleSet, loaded := router.RuleSet(routeExcludeAddressSet)
if !loaded {
return nil, E.New("parse route_exclude_address_set: rule-set not found: ", routeExcludeAddressSet)
}
ruleSet.IncRef()
routeExcludeRuleSet = append(routeExcludeRuleSet, ruleSet)
}
trackerIn := conntrack.NewDefaultTracker(false, 0)
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeNDIS, tag),
ctx: ctx,
router: router,
logger: logger,
api: api,
tracker: trackerIn,
routeRuleSet: routeRuleSet,
routeExcludeRuleSet: routeExcludeRuleSet,
stack: &Stack{
ctx: ctx,
logger: logger,
network: networkManager,
trackerIn: trackerIn,
trackerOut: trackerOut,
api: api,
udpTimeout: udpTimeout,
routeAddress: options.RouteAddress,
routeExcludeAddress: options.RouteExcludeAddress,
},
}, nil
}
func (t *Inbound) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
monitor := taskmonitor.New(t.logger, C.StartTimeout)
var (
routeAddressSet []*netipx.IPSet
routeExcludeAddressSet []*netipx.IPSet
)
for _, routeRuleSet := range t.routeRuleSet {
ipSets := routeRuleSet.ExtractIPSet()
if len(ipSets) == 0 {
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeRuleSet.Name())
}
t.routeRuleSetCallback = append(t.routeRuleSetCallback, routeRuleSet.RegisterCallback(t.updateRouteAddressSet))
routeRuleSet.DecRef()
routeAddressSet = append(routeAddressSet, ipSets...)
}
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
ipSets := routeExcludeRuleSet.ExtractIPSet()
if len(ipSets) == 0 {
t.logger.Warn("route_exclude_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
}
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
routeExcludeRuleSet.DecRef()
routeExcludeAddressSet = append(routeExcludeAddressSet, ipSets...)
}
t.stack.routeAddressSet = routeAddressSet
t.stack.routeExcludeAddressSet = routeExcludeAddressSet
monitor.Start("starting NDIS stack")
t.stack.handler = t
err := t.stack.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "starting NDIS stack")
}
}
return nil
}
func (t *Inbound) Close() error {
if t.api != nil {
t.stack.Close()
t.api.Close()
}
return nil
}
func (t *Inbound) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
return t.router.PreMatch(adapter.InboundContext{
Inbound: t.Tag(),
InboundType: C.TypeNDIS,
Network: network,
Source: source,
Destination: destination,
})
}
func (t *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = t.Tag()
metadata.InboundType = C.TypeNDIS
metadata.Source = source
metadata.Destination = destination
t.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
t.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
done, err := t.tracker.NewConnEx(conn)
if err != nil {
t.logger.ErrorContext(ctx, E.Cause(err, "track inbound connection"))
return
}
t.router.RouteConnectionEx(ctx, conn, metadata, N.AppendClose(onClose, done))
}
func (t *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = t.Tag()
metadata.InboundType = C.TypeNDIS
metadata.Source = source
metadata.Destination = destination
t.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
t.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
done, err := t.tracker.NewPacketConnEx(conn)
if err != nil {
t.logger.ErrorContext(ctx, E.Cause(err, "track inbound connection"))
return
}
t.router.RoutePacketConnectionEx(ctx, conn, metadata, N.AppendClose(onClose, done))
}
func (t *Inbound) updateRouteAddressSet(it adapter.RuleSet) {
t.stack.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
t.stack.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
}

267
protocol/ndis/stack.go Normal file
View File

@@ -0,0 +1,267 @@
//go:build windows
package ndis
import (
"context"
"net/netip"
"time"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/conntrack"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/wiresock/ndisapi-go"
"github.com/wiresock/ndisapi-go/driver"
"go4.org/netipx"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type Stack struct {
ctx context.Context
logger logger.ContextLogger
network adapter.NetworkManager
trackerIn conntrack.Tracker
trackerOut conntrack.Tracker
api *ndisapi.NdisApi
handler tun.Handler
udpTimeout time.Duration
filter *driver.QueuedPacketFilter
stack *stack.Stack
endpoint *ndisEndpoint
routeAddress []netip.Prefix
routeExcludeAddress []netip.Prefix
routeAddressSet []*netipx.IPSet
routeExcludeAddressSet []*netipx.IPSet
currentInterface *control.Interface
}
func (s *Stack) Start() error {
err := s.start(s.network.InterfaceMonitor().DefaultInterface())
if err != nil {
return err
}
s.network.InterfaceMonitor().RegisterCallback(s.updateDefaultInterface)
return nil
}
func (s *Stack) updateDefaultInterface(defaultInterface *control.Interface, flags int) {
if s.currentInterface.Equals(*defaultInterface) {
return
}
err := s.start(defaultInterface)
if err != nil {
s.logger.Error(E.Cause(err, "reconfigure NDIS at: ", defaultInterface.Name))
}
}
func (s *Stack) start(defaultInterface *control.Interface) error {
_ = s.Close()
adapters, err := s.api.GetTcpipBoundAdaptersInfo()
if err != nil {
return err
}
if defaultInterface != nil {
for index := 0; index < int(adapters.AdapterCount); index++ {
name := s.api.ConvertWindows2000AdapterName(string(adapters.AdapterNameList[index][:]))
if name != defaultInterface.Name {
continue
}
s.filter, err = driver.NewQueuedPacketFilter(s.api, adapters, nil, s.processOut)
if err != nil {
return err
}
address := tcpip.LinkAddress(adapters.CurrentAddress[index][:])
mtu := uint32(adapters.MTU[index])
endpoint := &ndisEndpoint{
filter: s.filter,
mtu: mtu,
address: address,
}
s.stack, err = tun.NewGVisorStack(endpoint)
if err != nil {
s.filter = nil
return err
}
s.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(s.ctx, s.stack, s.handler).HandlePacket)
s.stack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(s.ctx, s.stack, s.handler, s.udpTimeout).HandlePacket)
err = s.filter.StartFilter(index)
if err != nil {
s.filter = nil
s.stack.Close()
s.stack = nil
return err
}
s.endpoint = endpoint
s.logger.Info("started at ", defaultInterface.Name)
break
}
}
s.currentInterface = defaultInterface
return nil
}
func (s *Stack) Close() error {
if s.filter != nil {
s.filter.StopFilter()
s.filter.Close()
s.filter = nil
}
if s.stack != nil {
s.stack.Close()
for _, endpoint := range s.stack.CleanupEndpoints() {
endpoint.Abort()
}
s.stack = nil
}
return nil
}
func (s *Stack) processOut(handle ndisapi.Handle, packet *ndisapi.IntermediateBuffer) ndisapi.FilterAction {
if packet.Length < header.EthernetMinimumSize {
return ndisapi.FilterActionPass
}
if s.endpoint.dispatcher == nil || s.filterPacket(packet.Buffer[:packet.Length]) {
return ndisapi.FilterActionPass
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet.Buffer[:packet.Length]),
})
_, ok := packetBuffer.LinkHeader().Consume(header.EthernetMinimumSize)
if !ok {
packetBuffer.DecRef()
return ndisapi.FilterActionPass
}
ethHdr := header.Ethernet(packetBuffer.LinkHeader().Slice())
destinationAddress := ethHdr.DestinationAddress()
if destinationAddress == header.EthernetBroadcastAddress {
packetBuffer.PktType = tcpip.PacketBroadcast
} else if header.IsMulticastEthernetAddress(destinationAddress) {
packetBuffer.PktType = tcpip.PacketMulticast
} else if destinationAddress == s.endpoint.address {
packetBuffer.PktType = tcpip.PacketHost
} else {
packetBuffer.PktType = tcpip.PacketOtherHost
}
s.endpoint.dispatcher.DeliverNetworkPacket(ethHdr.Type(), packetBuffer)
packetBuffer.DecRef()
return ndisapi.FilterActionDrop
}
func (s *Stack) filterPacket(packet []byte) bool {
var ipHdr header.Network
switch header.IPVersion(packet[header.EthernetMinimumSize:]) {
case ipv4.Version:
ipHdr = header.IPv4(packet[header.EthernetMinimumSize:])
case ipv6.Version:
ipHdr = header.IPv6(packet[header.EthernetMinimumSize:])
default:
return true
}
sourceAddr := tun.AddrFromAddress(ipHdr.SourceAddress())
destinationAddr := tun.AddrFromAddress(ipHdr.DestinationAddress())
if !destinationAddr.IsGlobalUnicast() {
return true
}
var (
transportProtocol tcpip.TransportProtocolNumber
transportHdr header.Transport
)
switch ipHdr.TransportProtocol() {
case tcp.ProtocolNumber:
transportProtocol = header.TCPProtocolNumber
transportHdr = header.TCP(ipHdr.Payload())
case udp.ProtocolNumber:
transportProtocol = header.UDPProtocolNumber
transportHdr = header.UDP(ipHdr.Payload())
default:
return false
}
source := netip.AddrPortFrom(sourceAddr, transportHdr.SourcePort())
destination := netip.AddrPortFrom(destinationAddr, transportHdr.DestinationPort())
if transportProtocol == header.TCPProtocolNumber {
if s.trackerIn.CheckConn(source, destination) {
if debug.Enabled {
s.logger.Trace("fall exists TCP ", source, " ", destination)
}
return false
}
} else {
if s.trackerIn.CheckPacketConn(source) {
if debug.Enabled {
s.logger.Trace("fall exists UDP ", source, " ", destination)
}
}
}
if len(s.routeAddress) > 0 {
var match bool
for _, route := range s.routeAddress {
if route.Contains(destinationAddr) {
match = true
}
}
if !match {
return true
}
}
if len(s.routeAddressSet) > 0 {
var match bool
for _, ipSet := range s.routeAddressSet {
if ipSet.Contains(destinationAddr) {
match = true
}
}
if !match {
return true
}
}
if len(s.routeExcludeAddress) > 0 {
for _, address := range s.routeExcludeAddress {
if address.Contains(destinationAddr) {
return true
}
}
}
if len(s.routeExcludeAddressSet) > 0 {
for _, ipSet := range s.routeAddressSet {
if ipSet.Contains(destinationAddr) {
return true
}
}
}
if s.trackerOut.CheckDestination(destination) {
if debug.Enabled {
s.logger.Trace("passing pending ", source, " ", destination)
}
return true
}
if transportProtocol == header.TCPProtocolNumber {
if s.trackerOut.CheckConn(source, destination) {
if debug.Enabled {
s.logger.Trace("passing TCP ", source, " ", destination)
}
return true
}
} else {
if s.trackerOut.CheckPacketConn(source) {
if debug.Enabled {
s.logger.Trace("passing UDP ", source, " ", destination)
}
}
}
if debug.Enabled {
s.logger.Trace("fall ", source, " ", destination)
}
return false
}

View File

@@ -306,7 +306,6 @@ func (t *Inbound) Start(stage adapter.StartStage) error {
t.tunOptions.Name = tun.CalculateInterfaceName("")
}
if t.platformInterface == nil || runtime.GOOS != "android" {
t.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
for _, routeRuleSet := range t.routeRuleSet {
ipSets := routeRuleSet.ExtractIPSet()
if len(ipSets) == 0 {
@@ -316,11 +315,10 @@ func (t *Inbound) Start(stage adapter.StartStage) error {
routeRuleSet.DecRef()
t.routeAddressSet = append(t.routeAddressSet, ipSets...)
}
t.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
ipSets := routeExcludeRuleSet.ExtractIPSet()
if len(ipSets) == 0 {
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
t.logger.Warn("route_exclude_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
}
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
routeExcludeRuleSet.DecRef()