Refactor geo resources

This commit is contained in:
世界
2022-07-05 13:23:47 +08:00
parent 8392567962
commit 2d9203ee74
12 changed files with 273 additions and 198 deletions

View File

@@ -9,8 +9,8 @@ import (
"path/filepath"
"time"
"github.com/oschwald/geoip2-golang"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/geoip"
"github.com/sagernet/sing-box/common/geosite"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
@@ -36,12 +36,11 @@ type Router struct {
defaultOutboundForConnection adapter.Outbound
defaultOutboundForPacketConnection adapter.Outbound
needGeoIPDatabase bool
geoIPOptions option.GeoIPOptions
geoIPReader *geoip2.Reader
needGeoIPDatabase bool
needGeositeDatabase bool
geoIPOptions option.GeoIPOptions
geositeOptions option.GeositeOptions
geoIPReader *geoip.Reader
geositeReader *geosite.Reader
}
@@ -154,17 +153,28 @@ func (r *Router) Start() error {
return err
}
}
if r.needGeositeDatabase {
for _, rule := range r.rules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
err := common.Close(r.geositeReader)
if err != nil {
return err
}
}
return nil
}
func (r *Router) Close() error {
return common.Close(
common.PtrOrNil(r.geoIPReader),
common.PtrOrNil(r.geositeReader),
)
}
func (r *Router) GeoIPReader() *geoip2.Reader {
func (r *Router) GeoIPReader() *geoip.Reader {
return r.geoIPReader
}
@@ -199,7 +209,7 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def
for i, rule := range r.rules {
if rule.Match(&metadata) {
detour := rule.Outbound()
r.logger.WithContext(ctx).Info("match [", i, "]", rule.String(), " => ", detour)
r.logger.WithContext(ctx).Info("match[", i, "] ", rule.String(), " => ", detour)
if outbound, loaded := r.Outbound(detour); loaded {
return outbound
}
@@ -245,7 +255,7 @@ func (r *Router) prepareGeoIPDatabase() error {
if r.geoIPOptions.Path != "" {
geoPath = r.geoIPOptions.Path
} else {
geoPath = "Country.mmdb"
geoPath = "geoip.db"
if foundPath, loaded := C.Find(geoPath); loaded {
geoPath = foundPath
}
@@ -266,13 +276,12 @@ func (r *Router) prepareGeoIPDatabase() error {
return err
}
}
geoReader, err := geoip2.Open(geoPath)
if err == nil {
r.logger.Info("loaded geoip database")
r.geoIPReader = geoReader
} else {
geoReader, codes, err := geoip.Open(geoPath)
if err != nil {
return E.Cause(err, "open geoip database")
}
r.logger.Info("loaded geoip database: ", len(codes), " codes")
r.geoIPReader = geoReader
return nil
}
@@ -302,9 +311,9 @@ func (r *Router) prepareGeositeDatabase() error {
return err
}
}
geoReader, err := geosite.Open(geoPath)
geoReader, codes, err := geosite.Open(geoPath)
if err == nil {
r.logger.Info("loaded geosite database")
r.logger.Info("loaded geosite database: ", len(codes), " codes")
r.geositeReader = geoReader
} else {
return E.Cause(err, "open geosite database")
@@ -317,7 +326,7 @@ func (r *Router) downloadGeoIPDatabase(savePath string) error {
if r.geoIPOptions.DownloadURL != "" {
downloadURL = r.geoIPOptions.DownloadURL
} else {
downloadURL = "https://cdn.jsdelivr.net/gh/Dreamacro/maxmind-geoip@release/Country.mmdb"
downloadURL = "https://github.com/SagerNet/sing-geoip/releases/latest/download/geoip.db"
}
r.logger.Info("downloading geoip database")
var detour adapter.Outbound
@@ -342,7 +351,6 @@ func (r *Router) downloadGeoIPDatabase(savePath string) error {
defer saveFile.Close()
httpClient := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second,
@@ -390,7 +398,6 @@ func (r *Router) downloadGeositeDatabase(savePath string) error {
defer saveFile.Close()
httpClient := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second,

View File

@@ -44,6 +44,7 @@ type DefaultRule struct {
items []RuleItem
sourceAddressItems []RuleItem
destinationAddressItems []RuleItem
allItems []RuleItem
outbound string
}
@@ -57,12 +58,16 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def
outbound: options.Outbound,
}
if len(options.Inbound) > 0 {
rule.items = append(rule.items, NewInboundRule(options.Inbound))
item := NewInboundRule(options.Inbound)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if options.IPVersion > 0 {
switch options.IPVersion {
case 4, 6:
rule.items = append(rule.items, NewIPVersionItem(options.IPVersion == 6))
item := NewIPVersionItem(options.IPVersion == 6)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
default:
return nil, E.New("invalid ip version: ", options.IPVersion)
}
@@ -70,19 +75,27 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def
if options.Network != "" {
switch options.Network {
case C.NetworkTCP, C.NetworkUDP:
rule.items = append(rule.items, NewNetworkItem(options.Network))
item := NewNetworkItem(options.Network)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
default:
return nil, E.New("invalid network: ", options.Network)
}
}
if len(options.Protocol) > 0 {
rule.items = append(rule.items, NewProtocolItem(options.Protocol))
item := NewProtocolItem(options.Protocol)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 {
rule.destinationAddressItems = append(rule.destinationAddressItems, NewDomainItem(options.Domain, options.DomainSuffix))
item := NewDomainItem(options.Domain, options.DomainSuffix)
rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.DomainKeyword) > 0 {
rule.destinationAddressItems = append(rule.destinationAddressItems, NewDomainKeywordItem(options.DomainKeyword))
item := NewDomainKeywordItem(options.DomainKeyword)
rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.DomainRegex) > 0 {
item, err := NewDomainRegexItem(options.DomainRegex)
@@ -90,15 +103,22 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def
return nil, E.Cause(err, "domain_regex")
}
rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.Geosite) > 0 {
rule.destinationAddressItems = append(rule.destinationAddressItems, NewGeositeItem(router, logger, options.Geosite))
item := NewGeositeItem(router, logger, options.Geosite)
rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.SourceGeoIP) > 0 {
rule.sourceAddressItems = append(rule.sourceAddressItems, NewGeoIPItem(router, logger, true, options.SourceGeoIP))
item := NewGeoIPItem(router, logger, true, options.SourceGeoIP)
rule.sourceAddressItems = append(rule.sourceAddressItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.GeoIP) > 0 {
rule.destinationAddressItems = append(rule.destinationAddressItems, NewGeoIPItem(router, logger, false, options.GeoIP))
item := NewGeoIPItem(router, logger, false, options.GeoIP)
rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.SourceIPCIDR) > 0 {
item, err := NewIPCIDRItem(true, options.SourceIPCIDR)
@@ -106,6 +126,7 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def
return nil, E.Cause(err, "source_ipcidr")
}
rule.sourceAddressItems = append(rule.sourceAddressItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.IPCIDR) > 0 {
item, err := NewIPCIDRItem(false, options.IPCIDR)
@@ -113,30 +134,23 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def
return nil, E.Cause(err, "ipcidr")
}
rule.destinationAddressItems = append(rule.destinationAddressItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.SourcePort) > 0 {
rule.items = append(rule.items, NewPortItem(true, options.SourcePort))
item := NewPortItem(true, options.SourcePort)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.Port) > 0 {
rule.items = append(rule.items, NewPortItem(false, options.Port))
item := NewPortItem(false, options.Port)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
return rule, nil
}
func (r *DefaultRule) Start() error {
for _, item := range r.items {
err := common.Start(item)
if err != nil {
return err
}
}
for _, item := range r.sourceAddressItems {
err := common.Start(item)
if err != nil {
return err
}
}
for _, item := range r.destinationAddressItems {
for _, item := range r.allItems {
err := common.Start(item)
if err != nil {
return err
@@ -146,22 +160,22 @@ func (r *DefaultRule) Start() error {
}
func (r *DefaultRule) Close() error {
for _, item := range r.items {
for _, item := range r.allItems {
err := common.Close(item)
if err != nil {
return err
}
}
for _, item := range r.sourceAddressItems {
err := common.Close(item)
if err != nil {
return err
}
}
for _, item := range r.destinationAddressItems {
err := common.Close(item)
if err != nil {
return err
return nil
}
func (r *DefaultRule) UpdateGeosite() error {
for _, item := range r.allItems {
if geositeItem, isSite := item.(*GeositeItem); isSite {
err := geositeItem.Update()
if err != nil {
return err
}
}
}
return nil
@@ -208,5 +222,5 @@ func (r *DefaultRule) Outbound() string {
}
func (r *DefaultRule) String() string {
return strings.Join(common.Map(r.items, F.ToString0[RuleItem]), " ")
return strings.Join(common.Map(r.allItems, F.ToString0[RuleItem]), " ")
}

View File

@@ -28,11 +28,11 @@ func NewDomainRegexItem(expressions []string) (*DomainRegexItem, error) {
description := "domain_regex="
eLen := len(expressions)
if eLen == 1 {
description = expressions[0]
description += expressions[0]
} else if eLen > 3 {
description = F.ToString("[", strings.Join(expressions[:3], " "), "]")
description += F.ToString("[", strings.Join(expressions[:3], " "), "]")
} else {
description = F.ToString("[", strings.Join(expressions, " "), "]")
description += F.ToString("[", strings.Join(expressions, " "), "]")
}
return &DomainRegexItem{matchers, description}, nil
}

View File

@@ -39,24 +39,14 @@ func (r *GeoIPItem) Match(metadata *adapter.InboundContext) bool {
}
if r.isSource {
if metadata.SourceGeoIPCode == "" {
country, err := geoReader.Country(metadata.Source.Addr.AsSlice())
if err != nil {
r.logger.Error("query geoip for ", metadata.Source.Addr, ": ", err)
return false
}
metadata.SourceGeoIPCode = strings.ToLower(country.Country.IsoCode)
metadata.SourceGeoIPCode = geoReader.Lookup(metadata.Source.Addr)
}
} else {
if metadata.Destination.IsFqdn() {
return false
}
if metadata.GeoIPCode == "" {
country, err := geoReader.Country(metadata.Destination.Addr.AsSlice())
if err != nil {
r.logger.Error("query geoip for ", metadata.Destination.Addr, ": ", err)
return false
}
metadata.GeoIPCode = strings.ToLower(country.Country.IsoCode)
metadata.GeoIPCode = geoReader.Lookup(metadata.Destination.Addr)
}
}
return r.match(metadata)

View File

@@ -27,7 +27,7 @@ func NewGeositeItem(router adapter.Router, logger log.Logger, codes []string) *G
}
}
func (r *GeositeItem) Start() error {
func (r *GeositeItem) Update() error {
geositeReader := r.router.GeositeReader()
if geositeReader == nil {
return E.New("geosite reader is not initialized")
@@ -50,6 +50,9 @@ func (r *GeositeItem) Start() error {
}
func (r *GeositeItem) Match(metadata *adapter.InboundContext) bool {
if r.matcher == nil {
return false
}
return r.matcher.Match(metadata)
}
@@ -57,11 +60,11 @@ func (r *GeositeItem) String() string {
description := "geosite="
cLen := len(r.codes)
if cLen == 1 {
description = r.codes[0]
description += r.codes[0]
} else if cLen > 3 {
description = "[" + strings.Join(r.codes[:3], " ") + "...]"
description += "[" + strings.Join(r.codes[:3], " ") + "...]"
} else {
description = "[" + strings.Join(r.codes, " ") + "]"
description += "[" + strings.Join(r.codes, " ") + "]"
}
return description
}

View File

@@ -20,6 +20,16 @@ type LogicalRule struct {
outbound string
}
func (r *LogicalRule) UpdateGeosite() error {
for _, rule := range r.rules {
err := rule.UpdateGeosite()
if err != nil {
return err
}
}
return nil
}
func (r *LogicalRule) Start() error {
for _, rule := range r.rules {
err := rule.Start()