diff --git a/route/process_cache.go b/route/process_cache.go index 691a4e8e..01b477c4 100644 --- a/route/process_cache.go +++ b/route/process_cache.go @@ -3,6 +3,7 @@ package route import ( "context" "net/netip" + "strings" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/process" @@ -32,3 +33,60 @@ func (r *Router) findProcessInfoCached(ctx context.Context, network string, sour r.processCache.Add(key, processCacheEntry{result: result, err: err}) return result, err } + +func (r *Router) searchProcessInfo(ctx context.Context, metadata *adapter.InboundContext) { + if r.processSearcher == nil || metadata.ProcessInfo != nil || !r.isLocalSource(metadata.Source.Addr) { + return + } + var originDestination netip.AddrPort + if metadata.OriginDestination.IsValid() { + originDestination = metadata.OriginDestination.AddrPort() + } else if metadata.Destination.IsIP() { + originDestination = metadata.Destination.AddrPort() + } + processInfo, err := r.findProcessInfoCached(ctx, metadata.Network, metadata.Source.AddrPort(), originDestination) + if err != nil { + r.logger.InfoContext(ctx, "failed to search process: ", err) + return + } + metadata.ProcessInfo = processInfo + if processInfo.ProcessPath != "" { + if processInfo.UserName != "" { + r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath, ", user: ", processInfo.UserName) + } else if processInfo.UserId != -1 { + r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath, ", user id: ", processInfo.UserId) + } else { + r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath) + } + return + } + if len(processInfo.AndroidPackageNames) > 0 { + r.logger.InfoContext(ctx, "found package name: ", strings.Join(processInfo.AndroidPackageNames, ", ")) + return + } + if processInfo.UserId != -1 { + if processInfo.UserName != "" { + r.logger.InfoContext(ctx, "found user: ", processInfo.UserName) + } else { + r.logger.InfoContext(ctx, "found user id: ", processInfo.UserId) + } + } +} + +func (r *Router) isLocalSource(source netip.Addr) bool { + if !source.IsValid() { + return false + } + source = source.Unmap() + if source.IsLoopback() { + return true + } + for _, netInterface := range r.network.InterfaceFinder().Interfaces() { + for _, prefix := range netInterface.Addresses { + if prefix.Addr().Unmap() == source { + return true + } + } + } + return false +} diff --git a/route/route.go b/route/route.go index 77b66ea4..7c24219e 100644 --- a/route/route.go +++ b/route/route.go @@ -407,37 +407,7 @@ func (r *Router) matchRule( selectedRule adapter.Rule, selectedRuleIndex int, buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error, ) { - if r.processSearcher != nil && metadata.ProcessInfo == nil { - var originDestination netip.AddrPort - if metadata.OriginDestination.IsValid() { - originDestination = metadata.OriginDestination.AddrPort() - } else if metadata.Destination.IsIP() { - originDestination = metadata.Destination.AddrPort() - } - processInfo, fErr := r.findProcessInfoCached(ctx, metadata.Network, metadata.Source.AddrPort(), originDestination) - if fErr != nil { - r.logger.InfoContext(ctx, "failed to search process: ", fErr) - } else { - if processInfo.ProcessPath != "" { - if processInfo.UserName != "" { - r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath, ", user: ", processInfo.UserName) - } else if processInfo.UserId != -1 { - r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath, ", user id: ", processInfo.UserId) - } else { - r.logger.InfoContext(ctx, "found process path: ", processInfo.ProcessPath) - } - } else if len(processInfo.AndroidPackageNames) > 0 { - r.logger.InfoContext(ctx, "found package name: ", strings.Join(processInfo.AndroidPackageNames, ", ")) - } else if processInfo.UserId != -1 { - if processInfo.UserName != "" { - r.logger.InfoContext(ctx, "found user: ", processInfo.UserName) - } else { - r.logger.InfoContext(ctx, "found user id: ", processInfo.UserId) - } - } - metadata.ProcessInfo = processInfo - } - } + r.searchProcessInfo(ctx, metadata) if metadata.Destination.Addr.IsValid() && r.dnsTransport.FakeIP() != nil && r.dnsTransport.FakeIP().Store().Contains(metadata.Destination.Addr) { domain, loaded := r.dnsTransport.FakeIP().Store().Lookup(metadata.Destination.Addr) if !loaded {