diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index 58848883..0410e16a 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -2,6 +2,7 @@ package mapper import ( "net/netip" + "slices" "sort" "time" @@ -78,7 +79,10 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { tailnode, err := nv.TailNode( b.capVer, func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(nv, b.mapper.state.GetNodePrimaryRoutes(id), matchers) + // Self node: include own primaries + exit routes (no via steering for self). + primaries := policy.ReduceRoutes(nv, b.mapper.state.GetNodePrimaryRoutes(id), matchers) + + return slices.Concat(primaries, nv.ExitRoutes()) }, b.mapper.cfg) if err != nil { @@ -251,14 +255,18 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ( changedViews = peers } - tailPeers, err := types.TailNodes( - changedViews, b.capVer, - func(id types.NodeID) []netip.Prefix { - return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers) - }, - b.mapper.cfg) - if err != nil { - return nil, err + // Build tail nodes with per-peer via-aware route function. + tailPeers := make([]*tailcfg.Node, 0, changedViews.Len()) + + for _, peer := range changedViews.All() { + tn, err := peer.TailNode(b.capVer, func(_ types.NodeID) []netip.Prefix { + return b.mapper.state.RoutesForPeer(node, peer, matchers) + }, b.mapper.cfg) + if err != nil { + return nil, err + } + + tailPeers = append(tailPeers, tn) } // Peers is always returned sorted by Node.ID. diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 81c64115..1e8a899e 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -3,6 +3,7 @@ package mapper import ( "encoding/json" "net/netip" + "slices" "testing" "time" @@ -214,10 +215,13 @@ func TestTailNode(t *testing.T) { // This is a hack to avoid having a second node to test the primary route. // This should be baked into the test case proper if it is extended in the future. _ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24")) - got, err := tt.node.View().TailNode( + nv := tt.node.View() + got, err := nv.TailNode( 0, func(id types.NodeID) []netip.Prefix { - return primary.PrimaryRoutes(id) + // Route function returns primaries + exit routes + // (matching the real caller contract). + return slices.Concat(primary.PrimaryRoutes(id), nv.ExitRoutes()) }, cfg, ) diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 0c69160f..956ecb92 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -36,6 +36,12 @@ type PolicyManager interface { // NodeCanApproveRoute reports whether the given node can approve the given route. NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool + // ViaRoutesForPeer computes via grant effects for a viewer-peer pair. + // It returns which routes should be included (peer is via-designated for viewer) + // and excluded (steered to a different peer). When no via grants apply, + // both fields are empty and the caller falls back to existing behavior. + ViaRoutesForPeer(viewer, peer types.NodeView) types.ViaRouteResult + Version() int DebugString() string } diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 3ef72a16..76f5aed5 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -821,6 +821,100 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr return false } +// ViaRoutesForPeer computes via grant effects for a viewer-peer pair. +// For each via grant where the viewer matches the source, it checks whether the +// peer advertises any of the grant's destination prefixes. If the peer has the +// via tag, those prefixes go into Include; otherwise into Exclude. +func (pm *PolicyManager) ViaRoutesForPeer(viewer, peer types.NodeView) types.ViaRouteResult { + var result types.ViaRouteResult + + if pm == nil || pm.pol == nil { + return result + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // Self-steering doesn't apply. + if viewer.ID() == peer.ID() { + return result + } + + grants := pm.pol.Grants + for _, acl := range pm.pol.ACLs { + grants = append(grants, aclToGrants(acl)...) + } + + for _, grant := range grants { + if len(grant.Via) == 0 { + continue + } + + // Check if viewer matches any grant source. + viewerMatches := false + + for _, src := range grant.Sources { + ips, err := src.Resolve(pm.pol, pm.users, pm.nodes) + if err != nil { + continue + } + + if ips != nil && slices.ContainsFunc(viewer.IPs(), ips.Contains) { + viewerMatches = true + + break + } + } + + if !viewerMatches { + continue + } + + // Collect destination prefixes that the peer actually advertises. + peerSubnetRoutes := peer.SubnetRoutes() + peerExitRoutes := peer.ExitRoutes() + + var matchedPrefixes []netip.Prefix + + for _, dst := range grant.Destinations { + switch d := dst.(type) { + case *Prefix: + dstPrefix := netip.Prefix(*d) + if slices.Contains(peerSubnetRoutes, dstPrefix) { + matchedPrefixes = append(matchedPrefixes, dstPrefix) + } + case *AutoGroup: + if d.Is(AutoGroupInternet) && len(peerExitRoutes) > 0 { + matchedPrefixes = append(matchedPrefixes, peerExitRoutes...) + } + } + } + + if len(matchedPrefixes) == 0 { + continue + } + + // Check if peer has any of the via tags. + peerHasVia := false + + for _, viaTag := range grant.Via { + if peer.HasTag(string(viaTag)) { + peerHasVia = true + + break + } + } + + if peerHasVia { + result.Include = append(result.Include, matchedPrefixes...) + } else { + result.Exclude = append(result.Exclude, matchedPrefixes...) + } + } + + return result +} + func (pm *PolicyManager) Version() int { return 2 } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 89b06f9f..7d00b67f 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1061,6 +1061,44 @@ func (s *State) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix { return s.primaryRoutes.PrimaryRoutes(nodeID) } +// RoutesForPeer computes the routes a peer should advertise to a specific viewer, +// applying via grant steering on top of global primary election and exit routes. +// When no via grants apply, this falls back to existing behavior (global primaries + exit routes). +func (s *State) RoutesForPeer( + viewer, peer types.NodeView, + matchers []matcher.Match, +) []netip.Prefix { + viaResult := s.polMan.ViaRoutesForPeer(viewer, peer) + + globalPrimaries := s.primaryRoutes.PrimaryRoutes(peer.ID()) + exitRoutes := peer.ExitRoutes() + + // Fast path: no via grants affect this pair — existing behavior. + if len(viaResult.Include) == 0 && len(viaResult.Exclude) == 0 { + allRoutes := slices.Concat(globalPrimaries, exitRoutes) + + return policy.ReduceRoutes(viewer, allRoutes, matchers) + } + + // Remove excluded routes (steered to a different peer for this viewer). + var routes []netip.Prefix + + for _, p := range slices.Concat(globalPrimaries, exitRoutes) { + if !slices.Contains(viaResult.Exclude, p) { + routes = append(routes, p) + } + } + + // Add included routes (this peer is via-designated for this viewer). + for _, p := range viaResult.Include { + if !slices.Contains(routes, p) { + routes = append(routes, p) + } + } + + return policy.ReduceRoutes(viewer, routes, matchers) +} + // PrimaryRoutesString returns a string representation of all primary routes. func (s *State) PrimaryRoutesString() string { return s.primaryRoutes.String() diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 9e88fce7..5058450c 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -36,9 +36,17 @@ var ( ) // RouteFunc is a function that takes a node ID and returns a list of -// netip.Prefixes representing the primary routes for that node. +// netip.Prefixes representing the routes for that node. type RouteFunc func(id NodeID) []netip.Prefix +// ViaRouteResult describes via grant effects for a viewer-peer pair. +type ViaRouteResult struct { + // Include contains prefixes this peer should serve to this viewer (via-designated). + Include []netip.Prefix + // Exclude contains prefixes steered to OTHER peers (suppress from global primary). + Exclude []netip.Prefix +} + type ( NodeID uint64 NodeIDs []NodeID @@ -1110,10 +1118,20 @@ func (nv NodeView) TailNode( keyExpiry = nv.Expiry().Get() } - primaryRoutes := primaryRouteFunc(nv.ID()) - allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes()) + // routeFunc returns ALL routes (subnet + exit) for this node. + allRoutes := primaryRouteFunc(nv.ID()) + allowedIPs := slices.Concat(nv.Prefixes(), allRoutes) slices.SortFunc(allowedIPs, netip.Prefix.Compare) + // PrimaryRoutes only includes non-exit subnet routes for HA tracking. + var primaryRoutes []netip.Prefix + + for _, r := range allRoutes { + if !tsaddr.IsExitRoute(r) { + primaryRoutes = append(primaryRoutes, r) + } + } + capMap := tailcfg.NodeCapMap{ tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, tailcfg.CapabilitySSH: []tailcfg.RawMessage{},