Improve rule actions

This commit is contained in:
世界
2024-11-06 17:30:40 +08:00
parent 607f315f91
commit 2156b40f3a
10 changed files with 414 additions and 138 deletions

View File

@@ -87,23 +87,34 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if deadline.NeedAdditionalReadDeadline(conn) {
conn = deadline.NewConn(conn)
}
selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil, -1)
selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil)
if err != nil {
return err
}
var selectedOutbound adapter.Outbound
var selectReturn bool
var (
// selectedOutbound adapter.Outbound
selectedDialer N.Dialer
selectedTag string
selectedDescription string
)
if selectedRule != nil {
switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute:
var loaded bool
selectedOutbound, loaded = r.Outbound(action.Outbound)
selectedOutbound, loaded := r.Outbound(action.Outbound)
if !loaded {
buf.ReleaseMulti(buffers)
return E.New("outbound not found: ", action.Outbound)
}
case *rule.RuleActionReturn:
selectReturn = true
if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) {
buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag())
}
selectedDialer = selectedOutbound
selectedTag = selectedOutbound.Tag()
selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
case *rule.RuleActionDirect:
selectedDialer = action.Dialer
selectedDescription = action.String()
case *rule.RuleActionReject:
buf.ReleaseMulti(buffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@@ -116,17 +127,16 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
return nil
}
}
if selectedRule == nil || selectReturn {
if selectedRule == nil {
if r.defaultOutboundForConnection == nil {
buf.ReleaseMulti(buffers)
return E.New("missing default outbound with TCP support")
}
selectedOutbound = r.defaultOutboundForConnection
}
if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) {
buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag())
selectedDialer = r.defaultOutboundForConnection
selectedTag = r.defaultOutboundForConnection.Tag()
selectedDescription = F.ToString("outbound/", r.defaultOutboundForConnection.Type(), "[", r.defaultOutboundForConnection.Tag(), "]")
}
for _, buffer := range buffers {
conn = bufio.NewCachedConn(conn, buffer)
}
@@ -137,10 +147,10 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
}
if r.v2rayServer != nil {
if statsService := r.v2rayServer.StatsService(); statsService != nil {
conn = statsService.RoutedConnection(metadata.Inbound, selectedOutbound.Tag(), metadata.User, conn)
conn = statsService.RoutedConnection(metadata.Inbound, selectedTag, metadata.User, conn)
}
}
legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler)
legacyOutbound, isLegacy := selectedDialer.(adapter.ConnectionHandler)
if isLegacy {
err = legacyOutbound.NewConnection(ctx, conn, metadata)
if err != nil {
@@ -148,7 +158,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if onClose != nil {
onClose(err)
}
return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
return E.Cause(err, selectedDescription)
} else {
if onClose != nil {
onClose(nil)
@@ -157,13 +167,13 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
return nil
}
// TODO
err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata)
err = outbound.NewConnection(ctx, selectedDialer, conn, metadata)
if err != nil {
conn.Close()
if onClose != nil {
onClose(err)
}
return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
return E.Cause(err, selectedDescription)
} else {
if onClose != nil {
onClose(nil)
@@ -231,24 +241,34 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn))
}*/
selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1)
selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn)
if err != nil {
return err
}
var selectedOutbound adapter.Outbound
var (
selectedDialer N.Dialer
selectedTag string
selectedDescription string
)
var selectReturn bool
if selectedRule != nil {
switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute:
var loaded bool
selectedOutbound, loaded = r.Outbound(action.Outbound)
selectedOutbound, loaded := r.Outbound(action.Outbound)
if !loaded {
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("outbound not found: ", action.Outbound)
}
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
case *rule.RuleActionReturn:
selectReturn = true
if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
}
selectedDialer = selectedOutbound
selectedTag = selectedOutbound.Tag()
selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
case *rule.RuleActionDirect:
selectedDialer = action.Dialer
selectedDescription = action.String()
case *rule.RuleActionReject:
N.ReleaseMultiPacketBuffer(packetBuffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@@ -263,11 +283,9 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("missing default outbound with UDP support")
}
selectedOutbound = r.defaultOutboundForPacketConnection
}
if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) {
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
selectedDialer = r.defaultOutboundForPacketConnection
selectedTag = r.defaultOutboundForPacketConnection.Tag()
selectedDescription = F.ToString("outbound/", r.defaultOutboundForPacketConnection.Type(), "[", r.defaultOutboundForPacketConnection.Tag(), "]")
}
for _, buffer := range packetBuffers {
conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination)
@@ -280,32 +298,32 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
}
if r.v2rayServer != nil {
if statsService := r.v2rayServer.StatsService(); statsService != nil {
conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedOutbound.Tag(), metadata.User, conn)
conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedTag, metadata.User, conn)
}
}
if metadata.FakeIP {
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
}
legacyOutbound, isLegacy := selectedOutbound.(adapter.PacketConnectionHandler)
legacyOutbound, isLegacy := selectedDialer.(adapter.PacketConnectionHandler)
if isLegacy {
err = legacyOutbound.NewPacketConnection(ctx, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil {
return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
return E.Cause(err, selectedDescription)
}
return nil
}
// TODO
err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata)
err = outbound.NewPacketConnection(ctx, selectedDialer, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil {
return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
return E.Cause(err, selectedDescription)
}
return nil
}
func (r *Router) PreMatch(metadata adapter.InboundContext) error {
selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1)
selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil)
if err != nil {
return err
}
@@ -321,7 +339,7 @@ func (r *Router) PreMatch(metadata adapter.InboundContext) error {
func (r *Router) matchRule(
ctx context.Context, metadata *adapter.InboundContext, preMatch bool,
inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int,
inputConn net.Conn, inputPacketConn N.PacketConn,
) (
selectedRule adapter.Rule, selectedRuleIndex int,
buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error,
@@ -416,24 +434,10 @@ func (r *Router) matchRule(
}
match:
for ruleIndex < len(r.rules) {
rules := r.rules
if ruleIndex != -1 {
rules = rules[ruleIndex+1:]
}
var (
currentRule adapter.Rule
currentRuleIndex int
matched bool
)
for currentRuleIndex, currentRule = range rules {
if currentRule.Match(metadata) {
matched = true
break
}
}
if !matched {
break
for currentRuleIndex, currentRule := range r.rules {
metadata.ResetRuleCache()
if !currentRule.Match(metadata) {
continue
}
if !preMatch {
ruleDescription := currentRule.String()
@@ -444,7 +448,7 @@ match:
}
} else {
switch currentRule.Action().Type() {
case C.RuleActionTypeReject, C.RuleActionTypeResolve:
case C.RuleActionTypeReject:
ruleDescription := currentRule.String()
if ruleDescription != "" {
r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action())
@@ -454,6 +458,12 @@ match:
}
}
switch action := currentRule.Action().(type) {
case *rule.RuleActionRoute:
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
metadata.UDPConnect = action.UDPConnect
case *rule.RuleActionRouteOptions:
metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping
metadata.UDPConnect = action.UDPConnect
case *rule.RuleActionSniff:
if !preMatch {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn)
@@ -476,12 +486,16 @@ match:
if fatalErr != nil {
return
}
default:
}
actionType := currentRule.Action().Type()
if actionType == C.RuleActionTypeRoute ||
actionType == C.RuleActionTypeReject ||
actionType == C.RuleActionTypeHijackDNS ||
(actionType == C.RuleActionTypeSniff && preMatch) {
selectedRule = currentRule
selectedRuleIndex = currentRuleIndex
break match
}
ruleIndex = currentRuleIndex
}
if !preMatch && metadata.Destination.Addr.IsUnspecified() {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn)