diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 04ea2fda..b321ebad 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -28,6 +28,7 @@ jobs: - TestAPIAuthenticationBypassCurl - TestGRPCAuthenticationBypass - TestCLIWithConfigAuthenticationBypass + - TestACLPolicyPropagationOverTime - TestAuthKeyLogoutAndReloginSameUser - TestAuthKeyLogoutAndReloginNewUser - TestAuthKeyLogoutAndReloginSameUserExpiredKey diff --git a/cmd/hi/tar_utils.go b/cmd/hi/tar_utils.go index f0e1e86b..cfeeef5e 100644 --- a/cmd/hi/tar_utils.go +++ b/cmd/hi/tar_utils.go @@ -81,7 +81,7 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error { if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err) } - + // Create file outFile, err := os.Create(targetPath) if err != nil { diff --git a/hscontrol/capver/capver_generated.go b/hscontrol/capver/capver_generated.go index 79590000..534ead02 100644 --- a/hscontrol/capver/capver_generated.go +++ b/hscontrol/capver/capver_generated.go @@ -1,6 +1,6 @@ package capver -//Generated DO NOT EDIT +// Generated DO NOT EDIT import "tailscale.com/tailcfg" @@ -37,16 +37,15 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{ "v1.84.2": 116, } - var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{ - 90: "v1.64.0", - 95: "v1.66.0", - 97: "v1.68.0", - 102: "v1.70.0", - 104: "v1.72.0", - 106: "v1.74.0", - 109: "v1.78.0", - 113: "v1.80.0", - 115: "v1.82.0", - 116: "v1.84.0", + 90: "v1.64.0", + 95: "v1.66.0", + 97: "v1.68.0", + 102: "v1.70.0", + 104: "v1.72.0", + 106: "v1.74.0", + 109: "v1.78.0", + 113: "v1.80.0", + 115: "v1.82.0", + 116: "v1.84.0", } diff --git a/hscontrol/derp/derp_test.go b/hscontrol/derp/derp_test.go index c8a5e74c..9334de05 100644 --- a/hscontrol/derp/derp_test.go +++ b/hscontrol/derp/derp_test.go @@ -185,7 +185,6 @@ func TestShuffleDERPMapDeterministic(t *testing.T) { } }) } - } func TestShuffleDERPMapEdgeCases(t *testing.T) { diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 20daee6b..d40b36b0 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -73,7 +73,6 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse // Use the worker pool for controlled concurrency instead of direct generation initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id)) - if err != nil { log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") nodeConn.removeConnectionByChannel(c) diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index 981806e7..b85eb908 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -7,7 +7,6 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" "tailscale.com/types/views" @@ -181,6 +180,9 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { return b } + // FilterForNode returns rules already reduced to only those relevant for this node. + // For autogroup:self policies, it returns per-node compiled rules. + // For global policies, it returns the global filter reduced for this node. 
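The new comment spells out the contract WithPacketFilters now relies on: FilterForNode hands back rules already trimmed to the requesting node, so the builder can place them under the "base" key without reducing them itself. A minimal sketch of what that trimming means, using only tailcfg types and made-up addresses (the real reduction happens in policyutil.ReduceFilterRules against the node's IPs and routes):

package main

import (
	"fmt"

	"tailscale.com/tailcfg"
)

func main() {
	// A global rule whose destinations cover two different nodes.
	global := tailcfg.FilterRule{
		SrcIPs: []string{"100.64.0.0/24"},
		DstPorts: []tailcfg.NetPortRange{
			{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny},
			{IP: "100.64.0.2/32", Ports: tailcfg.PortRangeAny},
		},
		IPProto: []int{6, 17},
	}

	// What a FilterForNode call for node 100.64.0.1 is expected to return:
	// the same sources, but only destinations that point at this node
	// (or at routes it announces).
	reduced := tailcfg.FilterRule{
		SrcIPs:   global.SrcIPs,
		DstPorts: global.DstPorts[:1],
		IPProto:  global.IPProto,
	}

	fmt.Printf("global dests: %d, reduced dests: %d\n", len(global.DstPorts), len(reduced.DstPorts))
}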
filter, err := b.mapper.state.FilterForNode(node) if err != nil { b.addError(err) @@ -192,7 +194,7 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { // new PacketFilters field and "base" allows us to send a full update when we // have to send an empty list, avoiding the hack in the else block. b.resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node, filter), + "base": filter, } return b @@ -231,18 +233,19 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ( return nil, errors.New("node not found") } - // Use per-node filter to handle autogroup:self - filter, err := b.mapper.state.FilterForNode(node) + // Get unreduced matchers for peer relationship determination. + // MatchersForNode returns unreduced matchers that include all rules where the node + // could be either source or destination. This is different from FilterForNode which + // returns reduced rules for packet filtering (only rules where node is destination). + matchers, err := b.mapper.state.MatchersForNode(node) if err != nil { return nil, err } - matchers := matcher.MatchesFromFilterRules(filter) - // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. var changedViews views.Slice[types.NodeView] - if len(filter) > 0 { + if len(matchers) > 0 { changedViews = policy.ReduceNodes(node, peers, matchers) } else { changedViews = peers diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 79b4f845..910eb4a2 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -15,6 +15,10 @@ type PolicyManager interface { Filter() ([]tailcfg.FilterRule, []matcher.Match) // FilterForNode returns filter rules for a specific node, handling autogroup:self FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error) + // MatchersForNode returns matchers for peer relationship determination (unreduced) + MatchersForNode(node types.NodeView) ([]matcher.Match, error) + // BuildPeerMap constructs peer relationship maps for the given nodes + BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error) SetPolicy([]byte) (bool, error) SetUsers(users []types.User) (bool, error) diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 6a74e59f..677cb854 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -10,7 +10,6 @@ import ( "github.com/rs/zerolog/log" "github.com/samber/lo" "tailscale.com/net/tsaddr" - "tailscale.com/tailcfg" "tailscale.com/types/views" ) @@ -79,66 +78,6 @@ func BuildPeerMap( return ret } -// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations -// that are not relevant to that particular node. -func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule { - ret := []tailcfg.FilterRule{} - - for _, rule := range rules { - // record if the rule is actually relevant for the given node. - var dests []tailcfg.NetPortRange - DEST_LOOP: - for _, dest := range rule.DstPorts { - expanded, err := util.ParseIPSet(dest.IP, nil) - // Fail closed, if we can't parse it, then we should not allow - // access. - if err != nil { - continue DEST_LOOP - } - - if node.InIPSet(expanded) { - dests = append(dests, dest) - continue DEST_LOOP - } - - // If the node exposes routes, ensure they are note removed - // when the filters are reduced. 
- if node.Hostinfo().Valid() { - routableIPs := node.Hostinfo().RoutableIPs() - if routableIPs.Len() > 0 { - for _, routableIP := range routableIPs.All() { - if expanded.OverlapsPrefix(routableIP) { - dests = append(dests, dest) - continue DEST_LOOP - } - } - } - } - - // Also check approved subnet routes - nodes should have access - // to subnets they're approved to route traffic for. - subnetRoutes := node.SubnetRoutes() - - for _, subnetRoute := range subnetRoutes { - if expanded.OverlapsPrefix(subnetRoute) { - dests = append(dests, dest) - continue DEST_LOOP - } - } - } - - if len(dests) > 0 { - ret = append(ret, tailcfg.FilterRule{ - SrcIPs: rule.SrcIPs, - DstPorts: dests, - IPProto: rule.IPProto, - }) - } - } - - return ret -} - // ApproveRoutesWithPolicy checks if the node can approve the announced routes // and returns the new list of approved routes. // The approved routes will include: diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index c7cd3bcf..b849d470 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1,7 +1,6 @@ package policy import ( - "encoding/json" "fmt" "net/netip" "testing" @@ -11,12 +10,9 @@ import ( "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/util/must" ) var ap = func(ipStr string) *netip.Addr { @@ -29,817 +25,6 @@ var p = func(prefStr string) netip.Prefix { return ip } -// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when -// we use headscale "autogroup:internet". -var hsExitNodeDestForTest = []tailcfg.NetPortRange{ - {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, - {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, - {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, - {IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny}, - {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, - {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, - {IP: "64.0.0.0/3", Ports: tailcfg.PortRangeAny}, - {IP: "96.0.0.0/6", Ports: tailcfg.PortRangeAny}, - {IP: "100.0.0.0/10", Ports: tailcfg.PortRangeAny}, - {IP: "100.128.0.0/9", Ports: tailcfg.PortRangeAny}, - {IP: "101.0.0.0/8", Ports: tailcfg.PortRangeAny}, - {IP: "102.0.0.0/7", Ports: tailcfg.PortRangeAny}, - {IP: "104.0.0.0/5", Ports: tailcfg.PortRangeAny}, - {IP: "112.0.0.0/4", Ports: tailcfg.PortRangeAny}, - {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, - {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, - {IP: "168.0.0.0/8", Ports: tailcfg.PortRangeAny}, - {IP: "169.0.0.0/9", Ports: tailcfg.PortRangeAny}, - {IP: "169.128.0.0/10", Ports: tailcfg.PortRangeAny}, - {IP: "169.192.0.0/11", Ports: tailcfg.PortRangeAny}, - {IP: "169.224.0.0/12", Ports: tailcfg.PortRangeAny}, - {IP: "169.240.0.0/13", Ports: tailcfg.PortRangeAny}, - {IP: "169.248.0.0/14", Ports: tailcfg.PortRangeAny}, - {IP: "169.252.0.0/15", Ports: tailcfg.PortRangeAny}, - {IP: "169.255.0.0/16", Ports: tailcfg.PortRangeAny}, - {IP: "170.0.0.0/7", Ports: tailcfg.PortRangeAny}, - {IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny}, - {IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny}, - {IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny}, - {IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny}, - {IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny}, - {IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny}, - {IP: "176.0.0.0/4", Ports: 
tailcfg.PortRangeAny}, - {IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny}, - {IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny}, - {IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny}, - {IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny}, - {IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny}, - {IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny}, - {IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny}, - {IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny}, - {IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny}, - {IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny}, - {IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny}, - {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, - {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, - {IP: "224.0.0.0/3", Ports: tailcfg.PortRangeAny}, - {IP: "2000::/3", Ports: tailcfg.PortRangeAny}, -} - -func TestTheInternet(t *testing.T) { - internetSet := util.TheInternet() - - internetPrefs := internetSet.Prefixes() - - for i := range internetPrefs { - if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP { - t.Errorf( - "prefix from internet set %q != hsExit list %q", - internetPrefs[i].String(), - hsExitNodeDestForTest[i].IP, - ) - } - } - - if len(internetPrefs) != len(hsExitNodeDestForTest) { - t.Fatalf( - "expected same length of prefixes, internet: %d, hsExit: %d", - len(internetPrefs), - len(hsExitNodeDestForTest), - ) - } -} - -func TestReduceFilterRules(t *testing.T) { - users := types.Users{ - types.User{Model: gorm.Model{ID: 1}, Name: "mickael"}, - types.User{Model: gorm.Model{ID: 2}, Name: "user1"}, - types.User{Model: gorm.Model{ID: 3}, Name: "user2"}, - types.User{Model: gorm.Model{ID: 4}, Name: "user100"}, - types.User{Model: gorm.Model{ID: 5}, Name: "user3"}, - } - - tests := []struct { - name string - node *types.Node - peers types.Nodes - pol string - want []tailcfg.FilterRule - }{ - { - name: "host1-can-reach-host2-no-rules", - pol: ` -{ - "acls": [ - { - "action": "accept", - "proto": "", - "src": [ - "100.64.0.1" - ], - "dst": [ - "100.64.0.2:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: users[0], - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: users[0], - }, - }, - want: []tailcfg.FilterRule{}, - }, - { - name: "1604-subnet-routers-are-preserved", - pol: ` -{ - "groups": { - "group:admins": [ - "user1@" - ] - }, - "acls": [ - { - "action": "accept", - "proto": "", - "src": [ - "group:admins" - ], - "dst": [ - "group:admins:*" - ] - }, - { - "action": "accept", - "proto": "", - "src": [ - "group:admins" - ], - "dst": [ - "10.33.0.0/16:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], - Hostinfo: &tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{ - netip.MustParsePrefix("10.33.0.0/16"), - }, - }, - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[1], - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{ - "100.64.0.1/32", - "100.64.0.2/32", - "fd7a:115c:a1e0::1/128", - "fd7a:115c:a1e0::2/128", - }, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.64.0.1/32", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "fd7a:115c:a1e0::1/128", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - { - SrcIPs: []string{ - "100.64.0.1/32", - "100.64.0.2/32", - "fd7a:115c:a1e0::1/128", - "fd7a:115c:a1e0::2/128", - }, - DstPorts: 
[]tailcfg.NetPortRange{ - { - IP: "10.33.0.0/16", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - }, - }, - { - name: "1786-reducing-breaks-exit-nodes-the-client", - pol: ` -{ - "groups": { - "group:team": [ - "user3@", - "user2@", - "user1@" - ] - }, - "hosts": { - "internal": "100.64.0.100/32" - }, - "acls": [ - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "internal:*" - ] - }, - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "autogroup:internet:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], - }, - // "internal" exit node - &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], - Hostinfo: &tailcfg.Hostinfo{ - RoutableIPs: tsaddr.ExitRoutes(), - }, - }, - }, - want: []tailcfg.FilterRule{}, - }, - { - name: "1786-reducing-breaks-exit-nodes-the-exit", - pol: ` -{ - "groups": { - "group:team": [ - "user3@", - "user2@", - "user1@" - ] - }, - "hosts": { - "internal": "100.64.0.100/32" - }, - "acls": [ - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "internal:*" - ] - }, - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "autogroup:internet:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], - Hostinfo: &tailcfg.Hostinfo{ - RoutableIPs: tsaddr.ExitRoutes(), - }, - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], - }, - &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.64.0.100/32", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "fd7a:115c:a1e0::100/128", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: hsExitNodeDestForTest, - IPProto: []int{6, 17}, - }, - }, - }, - { - name: "1786-reducing-breaks-exit-nodes-the-example-from-issue", - pol: ` -{ - "groups": { - "group:team": [ - "user3@", - "user2@", - "user1@" - ] - }, - "hosts": { - "internal": "100.64.0.100/32" - }, - "acls": [ - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "internal:*" - ] - }, - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "0.0.0.0/5:*", - "8.0.0.0/7:*", - "11.0.0.0/8:*", - "12.0.0.0/6:*", - "16.0.0.0/4:*", - "32.0.0.0/3:*", - "64.0.0.0/2:*", - "128.0.0.0/3:*", - "160.0.0.0/5:*", - "168.0.0.0/6:*", - "172.0.0.0/12:*", - "172.32.0.0/11:*", - "172.64.0.0/10:*", - "172.128.0.0/9:*", - "173.0.0.0/8:*", - "174.0.0.0/7:*", - "176.0.0.0/4:*", - "192.0.0.0/9:*", - "192.128.0.0/11:*", - "192.160.0.0/13:*", - "192.169.0.0/16:*", - "192.170.0.0/15:*", - "192.172.0.0/14:*", - "192.176.0.0/12:*", - "192.192.0.0/10:*", - "193.0.0.0/8:*", - "194.0.0.0/7:*", - "196.0.0.0/6:*", - "200.0.0.0/5:*", - "208.0.0.0/4:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], - Hostinfo: 
&tailcfg.Hostinfo{ - RoutableIPs: tsaddr.ExitRoutes(), - }, - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], - }, - &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.64.0.100/32", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "fd7a:115c:a1e0::100/128", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, - {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, - {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, - {IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny}, - {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, - {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, - {IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny}, - {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, - {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, - {IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny}, - {IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny}, - {IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny}, - {IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny}, - {IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny}, - {IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny}, - {IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny}, - {IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny}, - {IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny}, - {IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny}, - {IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny}, - {IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny}, - {IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny}, - {IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny}, - {IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny}, - {IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny}, - {IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny}, - {IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny}, - {IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny}, - {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, - {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, - }, - IPProto: []int{6, 17}, - }, - }, - }, - { - name: "1786-reducing-breaks-exit-nodes-app-connector-like", - pol: ` -{ - "groups": { - "group:team": [ - "user3@", - "user2@", - "user1@" - ] - }, - "hosts": { - "internal": "100.64.0.100/32" - }, - "acls": [ - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "internal:*" - ] - }, - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "8.0.0.0/8:*", - "16.0.0.0/8:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], - Hostinfo: &tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, - }, - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], - }, - &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.64.0.100/32", - Ports: 
tailcfg.PortRangeAny, - }, - { - IP: "fd7a:115c:a1e0::100/128", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "8.0.0.0/8", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "16.0.0.0/8", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - }, - }, - { - name: "1786-reducing-breaks-exit-nodes-app-connector-like2", - pol: ` -{ - "groups": { - "group:team": [ - "user3@", - "user2@", - "user1@" - ] - }, - "hosts": { - "internal": "100.64.0.100/32" - }, - "acls": [ - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "internal:*" - ] - }, - { - "action": "accept", - "proto": "", - "src": [ - "group:team" - ], - "dst": [ - "8.0.0.0/16:*", - "16.0.0.0/16:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], - Hostinfo: &tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, - }, - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], - }, - &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.64.0.100/32", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "fd7a:115c:a1e0::100/128", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - { - SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "8.0.0.0/16", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "16.0.0.0/16", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - }, - }, - { - name: "1817-reduce-breaks-32-mask", - pol: ` -{ - "tagOwners": { - "tag:access-servers": ["user100@"], - }, - "groups": { - "group:access": [ - "user1@" - ] - }, - "hosts": { - "dns1": "172.16.0.21/32", - "vlan1": "172.16.0.0/24" - }, - "acls": [ - { - "action": "accept", - "proto": "", - "src": [ - "group:access" - ], - "dst": [ - "tag:access-servers:*", - "dns1:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], - Hostinfo: &tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, - }, - ForcedTags: []string{"tag:access-servers"}, - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0::1/128"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.64.0.100/32", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "fd7a:115c:a1e0::100/128", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "172.16.0.21/32", - Ports: tailcfg.PortRangeAny, - }, - }, - IPProto: []int{6, 17}, - }, - }, - }, - { - name: "2365-only-route-policy", - pol: ` -{ - "hosts": { - "router": "100.64.0.1/32", - "node": "100.64.0.2/32" - }, - "acls": [ - { - "action": "accept", - "src": [ - "*" - ], - "dst": [ - "router:8000" - ] - }, - { - "action": "accept", - "src": [ - "node" - ], - "dst": [ - "172.26.0.0/16:*" - ] - } - ], -} -`, - node: &types.Node{ - IPv4: 
ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[3], - }, - peers: types.Nodes{ - &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], - Hostinfo: &tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, - }, - ApprovedRoutes: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, - }, - }, - want: []tailcfg.FilterRule{}, - }, - } - - for _, tt := range tests { - for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.pol)) { - t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error - pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice()) - require.NoError(t, err) - got, _ := pm.Filter() - t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) - got = ReduceFilterRules(tt.node.View(), got) - - if diff := cmp.Diff(tt.want, got); diff != "" { - log.Trace().Interface("got", got).Msg("result") - t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff) - } - }) - } - } -} - func TestReduceNodes(t *testing.T) { type args struct { nodes types.Nodes diff --git a/hscontrol/policy/policyutil/reduce.go b/hscontrol/policy/policyutil/reduce.go new file mode 100644 index 00000000..e4549c10 --- /dev/null +++ b/hscontrol/policy/policyutil/reduce.go @@ -0,0 +1,71 @@ +package policyutil + +import ( + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "tailscale.com/tailcfg" +) + +// ReduceFilterRules takes a node and a set of global filter rules and removes all rules +// and destinations that are not relevant to that particular node. +// +// IMPORTANT: This function is designed for global filters only. Per-node filters +// (from autogroup:self policies) are already node-specific and should not be passed +// to this function. Use PolicyManager.FilterForNode() instead, which handles both cases. +func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule { + ret := []tailcfg.FilterRule{} + + for _, rule := range rules { + // record if the rule is actually relevant for the given node. + var dests []tailcfg.NetPortRange + DEST_LOOP: + for _, dest := range rule.DstPorts { + expanded, err := util.ParseIPSet(dest.IP, nil) + // Fail closed, if we can't parse it, then we should not allow + // access. + if err != nil { + continue DEST_LOOP + } + + if node.InIPSet(expanded) { + dests = append(dests, dest) + continue DEST_LOOP + } + + // If the node exposes routes, ensure they are note removed + // when the filters are reduced. + if node.Hostinfo().Valid() { + routableIPs := node.Hostinfo().RoutableIPs() + if routableIPs.Len() > 0 { + for _, routableIP := range routableIPs.All() { + if expanded.OverlapsPrefix(routableIP) { + dests = append(dests, dest) + continue DEST_LOOP + } + } + } + } + + // Also check approved subnet routes - nodes should have access + // to subnets they're approved to route traffic for. 
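The overlap checks in this function are what keep subnet routers and exit nodes working after reduction: a destination survives if it overlaps a prefix the node announces or is approved to route, not only if it contains one of the node's own addresses. A small sketch of that idea with plain net/netip and made-up prefixes (the real code parses the destination into an IP set via util.ParseIPSet and calls OverlapsPrefix on it):

package main

import (
	"fmt"
	"net/netip"
)

func main() {
	// The node is approved to route 10.33.0.0/16; its own Tailscale IP is 100.64.0.1.
	approved := netip.MustParsePrefix("10.33.0.0/16")
	nodeIP := netip.MustParseAddr("100.64.0.1")

	// A rule destination aimed at part of that subnet.
	dest := netip.MustParsePrefix("10.33.1.0/24")

	// The node's own address is not inside the destination...
	fmt.Println(dest.Contains(nodeIP)) // false

	// ...but the destination overlaps the approved route, so the reduction keeps it.
	fmt.Println(dest.Overlaps(approved)) // true
}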
+ subnetRoutes := node.SubnetRoutes() + + for _, subnetRoute := range subnetRoutes { + if expanded.OverlapsPrefix(subnetRoute) { + dests = append(dests, dest) + continue DEST_LOOP + } + } + } + + if len(dests) > 0 { + ret = append(ret, tailcfg.FilterRule{ + SrcIPs: rule.SrcIPs, + DstPorts: dests, + IPProto: rule.IPProto, + }) + } + } + + return ret +} diff --git a/hscontrol/policy/policyutil/reduce_test.go b/hscontrol/policy/policyutil/reduce_test.go new file mode 100644 index 00000000..973d149c --- /dev/null +++ b/hscontrol/policy/policyutil/reduce_test.go @@ -0,0 +1,841 @@ +package policyutil_test + +import ( + "encoding/json" + "fmt" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/policy/policyutil" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +var ap = func(ipStr string) *netip.Addr { + ip := netip.MustParseAddr(ipStr) + return &ip +} + +var p = func(prefStr string) netip.Prefix { + ip := netip.MustParsePrefix(prefStr) + return ip +} + +// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when +// we use headscale "autogroup:internet". +var hsExitNodeDestForTest = []tailcfg.NetPortRange{ + {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "64.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "96.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "100.0.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "100.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "101.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "102.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "104.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "112.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "168.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "169.0.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "169.128.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "169.192.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "169.224.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "169.240.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "169.248.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: "169.252.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "169.255.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "170.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: 
"192.176.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "224.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "2000::/3", Ports: tailcfg.PortRangeAny}, +} + +func TestTheInternet(t *testing.T) { + internetSet := util.TheInternet() + + internetPrefs := internetSet.Prefixes() + + for i := range internetPrefs { + if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP { + t.Errorf( + "prefix from internet set %q != hsExit list %q", + internetPrefs[i].String(), + hsExitNodeDestForTest[i].IP, + ) + } + } + + if len(internetPrefs) != len(hsExitNodeDestForTest) { + t.Fatalf( + "expected same length of prefixes, internet: %d, hsExit: %d", + len(internetPrefs), + len(hsExitNodeDestForTest), + ) + } +} + +func TestReduceFilterRules(t *testing.T) { + users := types.Users{ + types.User{Model: gorm.Model{ID: 1}, Name: "mickael"}, + types.User{Model: gorm.Model{ID: 2}, Name: "user1"}, + types.User{Model: gorm.Model{ID: 3}, Name: "user2"}, + types.User{Model: gorm.Model{ID: 4}, Name: "user100"}, + types.User{Model: gorm.Model{ID: 5}, Name: "user3"}, + } + + tests := []struct { + name string + node *types.Node + peers types.Nodes + pol string + want []tailcfg.FilterRule + }{ + { + name: "host1-can-reach-host2-no-rules", + pol: ` +{ + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "100.64.0.1" + ], + "dst": [ + "100.64.0.2:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + User: users[0], + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), + User: users[0], + }, + }, + want: []tailcfg.FilterRule{}, + }, + { + name: "1604-subnet-routers-are-preserved", + pol: ` +{ + "groups": { + "group:admins": [ + "user1@" + ] + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:admins" + ], + "dst": [ + "group:admins:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:admins" + ], + "dst": [ + "10.33.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{ + netip.MustParsePrefix("10.33.0.0/16"), + }, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[1], + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.1/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::1/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.33.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-client", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + 
"proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[2], + }, + // "internal" exit node + &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: users[3], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + }, + want: []tailcfg.FilterRule{}, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-exit", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: users[3], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[2], + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: hsExitNodeDestForTest, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-example-from-issue", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "0.0.0.0/5:*", + "8.0.0.0/7:*", + "11.0.0.0/8:*", + "12.0.0.0/6:*", + "16.0.0.0/4:*", + "32.0.0.0/3:*", + "64.0.0.0/2:*", + "128.0.0.0/3:*", + "160.0.0.0/5:*", + "168.0.0.0/6:*", + "172.0.0.0/12:*", + "172.32.0.0/11:*", + "172.64.0.0/10:*", + "172.128.0.0/9:*", + "173.0.0.0/8:*", + "174.0.0.0/7:*", + "176.0.0.0/4:*", + "192.0.0.0/9:*", + "192.128.0.0/11:*", + "192.160.0.0/13:*", + "192.169.0.0/16:*", + "192.170.0.0/15:*", + "192.172.0.0/14:*", + "192.176.0.0/12:*", + "192.192.0.0/10:*", + "193.0.0.0/8:*", + "194.0.0.0/7:*", + "196.0.0.0/6:*", + "200.0.0.0/5:*", + "208.0.0.0/4:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: users[3], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[2], + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", 
"fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny}, + {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-app-connector-like", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "8.0.0.0/8:*", + "16.0.0.0/8:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: users[3], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[2], + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "8.0.0.0/8", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "16.0.0.0/8", + Ports: 
tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-app-connector-like2", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "8.0.0.0/16:*", + "16.0.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: users[3], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[2], + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "8.0.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "16.0.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1817-reduce-breaks-32-mask", + pol: ` +{ + "tagOwners": { + "tag:access-servers": ["user100@"], + }, + "groups": { + "group:access": [ + "user1@" + ] + }, + "hosts": { + "dns1": "172.16.0.21/32", + "vlan1": "172.16.0.0/24" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:access" + ], + "dst": [ + "tag:access-servers:*", + "dns1:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: users[3], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, + }, + ForcedTags: []string{"tag:access-servers"}, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0::1/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "172.16.0.21/32", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "2365-only-route-policy", + pol: ` +{ + "hosts": { + "router": "100.64.0.1/32", + "node": "100.64.0.2/32" + }, + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "router:8000" + ] + }, + { + "action": "accept", + "src": [ + "node" + ], + "dst": [ + "172.26.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[3], + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[1], + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, + }, + ApprovedRoutes: []netip.Prefix{p("172.16.0.0/24"), 
p("10.10.11.0/24"), p("10.10.12.0/24")}, + }, + }, + want: []tailcfg.FilterRule{}, + }, + } + + for _, tt := range tests { + for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) { + t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { + var pm policy.PolicyManager + var err error + pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice()) + require.NoError(t, err) + got, _ := pm.Filter() + t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) + got = policyutil.ReduceFilterRules(tt.node.View(), got) + + if diff := cmp.Diff(tt.want, got); diff != "" { + log.Trace().Interface("got", got).Msg("result") + t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff) + } + }) + } + } +} diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index b904e14d..9f2845ac 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -854,7 +854,6 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { node1 := nodes[0].View() rules, err := policy2.compileFilterRulesForNode(users, node1, nodes.ViewSlice()) - if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 0a37d5c2..27cf70b4 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/policy/policyutil" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "go4.org/netipx" @@ -39,7 +40,9 @@ type PolicyManager struct { // Lazy map of SSH policies sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy - // Lazy map of per-node filter rules (when autogroup:self is used) + // Lazy map of per-node compiled filter rules (unreduced, for autogroup:self) + compiledFilterRulesMap map[types.NodeID][]tailcfg.FilterRule + // Lazy map of per-node filter rules (reduced, for packet filters) filterRulesMap map[types.NodeID][]tailcfg.FilterRule usesAutogroupSelf bool } @@ -54,12 +57,13 @@ func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.Node } pm := PolicyManager{ - pol: policy, - users: users, - nodes: nodes, - sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()), - filterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()), - usesAutogroupSelf: policy.usesAutogroupSelf(), + pol: policy, + users: users, + nodes: nodes, + sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()), + compiledFilterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()), + filterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()), + usesAutogroupSelf: policy.usesAutogroupSelf(), } _, err = pm.updateLocked() @@ -78,6 +82,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { // policies for nodes that have changed. Particularly if the only difference is // that nodes has been added or removed. clear(pm.sshPolicyMap) + clear(pm.compiledFilterRulesMap) clear(pm.filterRulesMap) // Check if policy uses autogroup:self @@ -233,9 +238,157 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { return pm.filter, pm.matchers } -// FilterForNode returns the filter rules for a specific node. -// If the policy uses autogroup:self, this returns node-specific rules for security. -// Otherwise, it returns the global filter rules for efficiency. 
+// BuildPeerMap constructs peer relationship maps for the given nodes. +// For global filters, it uses the global filter matchers for all nodes. +// For autogroup:self policies (empty global filter), it builds per-node +// peer maps using each node's specific filter rules. +func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView { + if pm == nil { + return nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // If we have a global filter, use it for all nodes (normal case) + if !pm.usesAutogroupSelf { + ret := make(map[types.NodeID][]types.NodeView, nodes.Len()) + + // Build the map of all peers according to the matchers. + // Compared to ReduceNodes, which builds the list per node, we end up with doing + // the full work for every node O(n^2), while this will reduce the list as we see + // relationships while building the map, making it O(n^2/2) in the end, but with less work per node. + for i := range nodes.Len() { + for j := i + 1; j < nodes.Len(); j++ { + if nodes.At(i).ID() == nodes.At(j).ID() { + continue + } + + if nodes.At(i).CanAccess(pm.matchers, nodes.At(j)) || nodes.At(j).CanAccess(pm.matchers, nodes.At(i)) { + ret[nodes.At(i).ID()] = append(ret[nodes.At(i).ID()], nodes.At(j)) + ret[nodes.At(j).ID()] = append(ret[nodes.At(j).ID()], nodes.At(i)) + } + } + } + + return ret + } + + // For autogroup:self (empty global filter), build per-node peer relationships + ret := make(map[types.NodeID][]types.NodeView, nodes.Len()) + + // Pre-compute per-node matchers using unreduced compiled rules + // We need unreduced rules to determine peer relationships correctly. + // Reduced rules only show destinations where the node is the target, + // but peer relationships require the full bidirectional access rules. + nodeMatchers := make(map[types.NodeID][]matcher.Match, nodes.Len()) + for _, node := range nodes.All() { + filter, err := pm.compileFilterRulesForNodeLocked(node) + if err != nil || len(filter) == 0 { + continue + } + nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter) + } + + // Check each node pair for peer relationships. + // Start j at i+1 to avoid checking the same pair twice and creating duplicates. + // We check both directions (i->j and j->i) since ACLs can be asymmetric. + for i := range nodes.Len() { + nodeI := nodes.At(i) + matchersI, hasFilterI := nodeMatchers[nodeI.ID()] + + for j := i + 1; j < nodes.Len(); j++ { + nodeJ := nodes.At(j) + matchersJ, hasFilterJ := nodeMatchers[nodeJ.ID()] + + // Check if nodeI can access nodeJ + if hasFilterI && nodeI.CanAccess(matchersI, nodeJ) { + ret[nodeI.ID()] = append(ret[nodeI.ID()], nodeJ) + } + + // Check if nodeJ can access nodeI + if hasFilterJ && nodeJ.CanAccess(matchersJ, nodeI) { + ret[nodeJ.ID()] = append(ret[nodeJ.ID()], nodeI) + } + } + } + + return ret +} + +// compileFilterRulesForNodeLocked returns the unreduced compiled filter rules for a node +// when using autogroup:self. This is used by BuildPeerMap to determine peer relationships. +// For packet filters sent to nodes, use filterForNodeLocked which returns reduced rules. 
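The pairwise loop above starts the inner index at i+1 so every pair is examined once, and records both directions whenever either side can reach the other. A stripped-down sketch of the same pattern, with a toy canAccess predicate standing in for NodeView.CanAccess plus the matchers:

package main

import "fmt"

type nodeID uint64

func buildPeerMap(ids []nodeID, canAccess func(a, b nodeID) bool) map[nodeID][]nodeID {
	ret := make(map[nodeID][]nodeID, len(ids))
	for i := range ids {
		for j := i + 1; j < len(ids); j++ {
			a, b := ids[i], ids[j]
			// ACLs can be asymmetric, so check both directions; if either
			// side can reach the other, each becomes a peer of the other.
			if canAccess(a, b) || canAccess(b, a) {
				ret[a] = append(ret[a], b)
				ret[b] = append(ret[b], a)
			}
		}
	}
	return ret
}

func main() {
	// Toy policy: node 1 can reach everyone, nodes 2 and 3 can reach nothing.
	canAccess := func(a, b nodeID) bool { return a == 1 }
	fmt.Println(buildPeerMap([]nodeID{1, 2, 3}, canAccess))
	// map[1:[2 3] 2:[1] 3:[1]]
}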
+func (pm *PolicyManager) compileFilterRulesForNodeLocked(node types.NodeView) ([]tailcfg.FilterRule, error) { + if pm == nil { + return nil, nil + } + + // Check if we have cached compiled rules + if rules, ok := pm.compiledFilterRulesMap[node.ID()]; ok { + return rules, nil + } + + // Compile per-node rules with autogroup:self expanded + rules, err := pm.pol.compileFilterRulesForNode(pm.users, node, pm.nodes) + if err != nil { + return nil, fmt.Errorf("compiling filter rules for node: %w", err) + } + + // Cache the unreduced compiled rules + pm.compiledFilterRulesMap[node.ID()] = rules + + return rules, nil +} + +// filterForNodeLocked returns the filter rules for a specific node, already reduced +// to only include rules relevant to that node. +// This is a lock-free version of FilterForNode for internal use when the lock is already held. +// BuildPeerMap already holds the lock, so we need a version that doesn't re-acquire it. +func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.FilterRule, error) { + if pm == nil { + return nil, nil + } + + if !pm.usesAutogroupSelf { + // For global filters, reduce to only rules relevant to this node. + // Cache the reduced filter per node for efficiency. + if rules, ok := pm.filterRulesMap[node.ID()]; ok { + return rules, nil + } + + // Use policyutil.ReduceFilterRules for global filter reduction. + reducedFilter := policyutil.ReduceFilterRules(node, pm.filter) + + pm.filterRulesMap[node.ID()] = reducedFilter + return reducedFilter, nil + } + + // For autogroup:self, compile per-node rules then reduce them. + // Check if we have cached reduced rules for this node. + if rules, ok := pm.filterRulesMap[node.ID()]; ok { + return rules, nil + } + + // Get unreduced compiled rules + compiledRules, err := pm.compileFilterRulesForNodeLocked(node) + if err != nil { + return nil, err + } + + // Reduce the compiled rules to only destinations relevant to this node + reducedFilter := policyutil.ReduceFilterRules(node, compiledRules) + + // Cache the reduced filter + pm.filterRulesMap[node.ID()] = reducedFilter + + return reducedFilter, nil +} + +// FilterForNode returns the filter rules for a specific node, already reduced +// to only include rules relevant to that node. +// If the policy uses autogroup:self, this returns node-specific compiled rules. +// Otherwise, it returns the global filter reduced for this node. func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error) { if pm == nil { return nil, nil @@ -244,22 +397,36 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul pm.mu.Lock() defer pm.mu.Unlock() + return pm.filterForNodeLocked(node) +} + +// MatchersForNode returns the matchers for peer relationship determination for a specific node. +// These are UNREDUCED matchers - they include all rules where the node could be either source or destination. +// This is different from FilterForNode which returns REDUCED rules for packet filtering. 
+// +// For global policies: returns the global matchers (same for all nodes) +// For autogroup:self: returns node-specific matchers from unreduced compiled rules +func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) { + if pm == nil { + return nil, nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // For global policies, return the shared global matchers if !pm.usesAutogroupSelf { - return pm.filter, nil + return pm.matchers, nil } - if rules, ok := pm.filterRulesMap[node.ID()]; ok { - return rules, nil - } - - rules, err := pm.pol.compileFilterRulesForNode(pm.users, node, pm.nodes) + // For autogroup:self, get unreduced compiled rules and create matchers + compiledRules, err := pm.compileFilterRulesForNodeLocked(node) if err != nil { - return nil, fmt.Errorf("compiling filter rules for node: %w", err) + return nil, err } - pm.filterRulesMap[node.ID()] = rules - - return rules, nil + // Create matchers from unreduced rules for peer relationship determination + return matcher.MatchesFromFilterRules(compiledRules), nil } // SetUsers updates the users in the policy manager and updates the filter rules. @@ -300,22 +467,40 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro pm.mu.Lock() defer pm.mu.Unlock() - // Clear cache based on what actually changed - if pm.usesAutogroupSelf { - // For autogroup:self, we need granular invalidation since rules depend on: - // - User ownership (node.User().ID) - // - Tag status (node.IsTagged()) - // - IP addresses (node.IPs()) - // - Node existence (added/removed) - pm.invalidateAutogroupSelfCache(pm.nodes, nodes) - } else { - // For non-autogroup:self policies, we can clear everything - clear(pm.filterRulesMap) - } + oldNodeCount := pm.nodes.Len() + newNodeCount := nodes.Len() + + // Invalidate cache entries for nodes that changed. + // For autogroup:self: invalidate all nodes belonging to affected users (peer changes). + // For global policies: invalidate only nodes whose properties changed (IPs, routes). + pm.invalidateNodeCache(nodes) pm.nodes = nodes - return pm.updateLocked() + nodesChanged := oldNodeCount != newNodeCount + + // When nodes are added/removed, we must recompile filters because: + // 1. User/group aliases (like "user1@") resolve to node IPs + // 2. Filter compilation needs nodes to generate rules + // 3. Without nodes, filters compile to empty (0 rules) + // + // For autogroup:self: return true when nodes change even if the global filter + // hash didn't change. The global filter is empty for autogroup:self (each node + // has its own filter), so the hash never changes. But peer relationships DO + // change when nodes are added/removed, so we must signal this to trigger updates. + // For global policies: the filter must be recompiled to include the new nodes. 
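The reduced/unreduced split documented on MatchersForNode and FilterForNode above is the core of this refactor: packet filters only need rules where the node is a destination, while peer visibility needs rules where the node appears on either side. A self-contained illustration of that difference, with exact string matching standing in for the matcher package and IP-set membership checks:

package main

import (
	"fmt"
	"slices"

	"tailscale.com/tailcfg"
)

// relevantForPeers keeps rules where the node shows up as a source or a
// destination (the direction MatchersForNode-style matchers are built for).
func relevantForPeers(nodeIP string, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
	var out []tailcfg.FilterRule
	for _, r := range rules {
		inSrc := slices.Contains(r.SrcIPs, nodeIP)
		inDst := false
		for _, d := range r.DstPorts {
			if d.IP == nodeIP {
				inDst = true
			}
		}
		if inSrc || inDst {
			out = append(out, r)
		}
	}
	return out
}

// relevantForPacketFilter keeps only rules where the node is a destination
// (the FilterForNode / ReduceFilterRules direction).
func relevantForPacketFilter(nodeIP string, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
	var out []tailcfg.FilterRule
	for _, r := range rules {
		for _, d := range r.DstPorts {
			if d.IP == nodeIP {
				out = append(out, r)
				break
			}
		}
	}
	return out
}

func main() {
	rules := []tailcfg.FilterRule{
		{SrcIPs: []string{"100.64.0.1/32"}, DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.2/32", Ports: tailcfg.PortRangeAny}}},
	}
	// Node 100.64.0.1 is only a source: it needs this rule to know that
	// 100.64.0.2 is a peer, but the rule does not belong in its packet filter.
	fmt.Println(len(relevantForPeers("100.64.0.1/32", rules)))        // 1
	fmt.Println(len(relevantForPacketFilter("100.64.0.1/32", rules))) // 0
}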
+ if nodesChanged { + // Recompile filter with the new node list + _, err := pm.updateLocked() + if err != nil { + return false, err + } + // Always return true when nodes changed, even if filter hash didn't change + // (can happen with autogroup:self or when nodes are added but don't affect rules) + return true, nil + } + + return false, nil } func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool { @@ -552,10 +737,12 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S // If we found the user and they're affected, clear this cache entry if found { if _, affected := affectedUsers[nodeUserID]; affected { + delete(pm.compiledFilterRulesMap, nodeID) delete(pm.filterRulesMap, nodeID) } } else { // Node not found in either old or new list, clear it + delete(pm.compiledFilterRulesMap, nodeID) delete(pm.filterRulesMap, nodeID) } } @@ -567,3 +754,50 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S Msg("Selectively cleared autogroup:self cache for affected users") } } + +// invalidateNodeCache invalidates cache entries based on what changed. +func (pm *PolicyManager) invalidateNodeCache(newNodes views.Slice[types.NodeView]) { + if pm.usesAutogroupSelf { + // For autogroup:self, a node's filter depends on its peers (same user). + // When any node in a user changes, all nodes for that user need invalidation. + pm.invalidateAutogroupSelfCache(pm.nodes, newNodes) + } else { + // For global policies, a node's filter depends only on its own properties. + // Only invalidate nodes whose properties actually changed. + pm.invalidateGlobalPolicyCache(newNodes) + } +} + +// invalidateGlobalPolicyCache invalidates only nodes whose properties affecting +// ReduceFilterRules changed. For global policies, each node's filter is independent. +func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types.NodeView]) { + oldNodeMap := make(map[types.NodeID]types.NodeView) + for _, node := range pm.nodes.All() { + oldNodeMap[node.ID()] = node + } + + newNodeMap := make(map[types.NodeID]types.NodeView) + for _, node := range newNodes.All() { + newNodeMap[node.ID()] = node + } + + // Invalidate nodes whose properties changed + for nodeID, newNode := range newNodeMap { + oldNode, existed := oldNodeMap[nodeID] + if !existed { + // New node - no cache entry yet, will be lazily calculated + continue + } + + if newNode.HasNetworkChanges(oldNode) { + delete(pm.filterRulesMap, nodeID) + } + } + + // Remove deleted nodes from cache + for nodeID := range pm.filterRulesMap { + if _, exists := newNodeMap[nodeID]; !exists { + delete(pm.filterRulesMap, nodeID) + } + } +} diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 90e6b506..5191368a 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -1,6 +1,7 @@ package v2 import ( + "net/netip" "testing" "github.com/google/go-cmp/cmp" @@ -204,3 +205,237 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { }) } } + +// TestInvalidateGlobalPolicyCache tests the cache invalidation logic for global policies. 
+func TestInvalidateGlobalPolicyCache(t *testing.T) { + mustIPPtr := func(s string) *netip.Addr { + ip := netip.MustParseAddr(s) + return &ip + } + + tests := []struct { + name string + oldNodes types.Nodes + newNodes types.Nodes + initialCache map[types.NodeID][]tailcfg.FilterRule + expectedCacheAfter map[types.NodeID]bool // true = should exist, false = should not exist + }{ + { + name: "node property changed - invalidates only that node", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.99")}, // Changed + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // Unchanged + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: false, // Invalidated + 2: true, // Preserved + }, + }, + { + name: "multiple nodes changed", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + &types.Node{ID: 3, IPv4: mustIPPtr("100.64.0.3")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.99")}, // Changed + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // Unchanged + &types.Node{ID: 3, IPv4: mustIPPtr("100.64.0.88")}, // Changed + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + 3: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: false, // Invalidated + 2: true, // Preserved + 3: false, // Invalidated + }, + }, + { + name: "node deleted - removes from cache", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: false, // Deleted + 2: true, // Preserved + }, + }, + { + name: "node added - no cache invalidation needed", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // New + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: true, // Preserved + 2: false, // Not in cache (new node) + }, + }, + { + name: "no changes - preserves all cache", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: true, + 2: true, + }, + }, + { + name: "routes changed - invalidates that node only", + oldNodes: types.Nodes{ + &types.Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + }, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + newNodes: types.Nodes{ + &types.Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: 
[]netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, // Changed + }, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: false, // Invalidated + 2: true, // Preserved + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm := &PolicyManager{ + nodes: tt.oldNodes.ViewSlice(), + filterRulesMap: tt.initialCache, + usesAutogroupSelf: false, + } + + pm.invalidateGlobalPolicyCache(tt.newNodes.ViewSlice()) + + // Verify cache state + for nodeID, shouldExist := range tt.expectedCacheAfter { + _, exists := pm.filterRulesMap[nodeID] + require.Equal(t, shouldExist, exists, "node %d cache existence mismatch", nodeID) + } + }) + } +} + +// TestAutogroupSelfReducedVsUnreducedRules verifies that: +// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships +// 2. FilterForNode returns reduced compiled rules for packet filters +func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { + user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"} + user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"} + users := types.Users{user1, user2} + + // Create two nodes + node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1, nil) + node1.ID = 1 + node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2, nil) + node2.ID = 2 + nodes := types.Nodes{node1, node2} + + // Policy with autogroup:self - all members can reach their own devices + policyStr := `{ + "acls": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policyStr), users, nodes.ViewSlice()) + require.NoError(t, err) + require.True(t, pm.usesAutogroupSelf, "policy should use autogroup:self") + + // Test FilterForNode returns reduced rules + // For node1: should have rules where node1 is in destinations (its own IP) + filterNode1, err := pm.FilterForNode(nodes[0].View()) + require.NoError(t, err) + + // For node2: should have rules where node2 is in destinations (its own IP) + filterNode2, err := pm.FilterForNode(nodes[1].View()) + require.NoError(t, err) + + // FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations + // For node1, destinations should only be node1's IPs + node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"} + for _, rule := range filterNode1 { + for _, dst := range rule.DstPorts { + require.Contains(t, node1IPs, dst.IP, + "node1 filter should only contain node1's IPs as destinations") + } + } + + // For node2, destinations should only be node2's IPs + node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"} + for _, rule := range filterNode2 { + for _, dst := range rule.DstPorts { + require.Contains(t, node2IPs, dst.IP, + "node2 filter should only contain node2's IPs as destinations") + } + } + + // Test BuildPeerMap uses unreduced rules + peerMap := pm.BuildPeerMap(nodes.ViewSlice()) + + // According to the policy, user1 can reach autogroup:self (which expands to node1's own IPs for node1) + // So node1 should be able to reach itself, but since we're looking at peer relationships, + // node1 should NOT have itself in the peer map 
(nodes don't peer with themselves) + // node2 should also not have any peers since user2 has no rules allowing it to reach anyone + + // Verify peer relationships based on unreduced rules + // With unreduced rules, BuildPeerMap can properly determine that: + // - node1 can access autogroup:self (its own IPs) + // - node2 cannot access node1 + require.Empty(t, peerMap[node1.ID], "node1 should have no peers (can only reach itself)") + require.Empty(t, peerMap[node2.ID], "node2 should have no peers") +} diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 34bbb24f..a06151a5 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -20,9 +20,10 @@ const ( ) const ( - put = 1 - del = 2 - update = 3 + put = 1 + del = 2 + update = 3 + rebuildPeerMaps = 4 ) const prometheusNamespace = "headscale" @@ -142,6 +143,8 @@ type work struct { updateFn UpdateNodeFunc result chan struct{} nodeResult chan types.NodeView // Channel to return the resulting node after batch application + // For rebuildPeerMaps operation + rebuildResult chan struct{} } // PutNode adds or updates a node in the store. @@ -298,6 +301,9 @@ func (s *NodeStore) applyBatch(batch []work) { // Track which work items need node results nodeResultRequests := make(map[types.NodeID][]*work) + // Track rebuildPeerMaps operations + var rebuildOps []*work + for i := range batch { w := &batch[i] switch w.op { @@ -321,6 +327,10 @@ func (s *NodeStore) applyBatch(batch []work) { if w.nodeResult != nil { nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) } + case rebuildPeerMaps: + // rebuildPeerMaps doesn't modify nodes, it just forces the snapshot rebuild + // below to recalculate peer relationships using the current peersFunc + rebuildOps = append(rebuildOps, w) } } @@ -347,9 +357,16 @@ func (s *NodeStore) applyBatch(batch []work) { } } - // Signal completion for all work items + // Signal completion for rebuildPeerMaps operations + for _, w := range rebuildOps { + close(w.rebuildResult) + } + + // Signal completion for all other work items for _, w := range batch { - close(w.result) + if w.op != rebuildPeerMaps { + close(w.result) + } } } @@ -546,6 +563,22 @@ func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] { return views.SliceOf(s.data.Load().peersByNode[id]) } +// RebuildPeerMaps rebuilds the peer relationship map using the current peersFunc. +// This must be called after policy changes because peersFunc uses PolicyManager's +// filters to determine which nodes can see each other. Without rebuilding, the +// peer map would use stale filter data until the next node add/delete. +func (s *NodeStore) RebuildPeerMaps() { + result := make(chan struct{}) + + w := work{ + op: rebuildPeerMaps, + rebuildResult: result, + } + + s.writeQueue <- w + <-result +} + // ListNodesByUser returns a slice of all nodes for a given user ID. func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] { timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user")) diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 7585c4e3..1d450cb6 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -132,9 +132,10 @@ func NewState(cfg *types.Config) (*State, error) { return nil, fmt.Errorf("init policy manager: %w", err) } + // PolicyManager.BuildPeerMap handles both global and per-node filter complexity. + // This moves the complex peer relationship logic into the policy package where it belongs. 
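
RebuildPeerMaps above rides the NodeStore's existing write queue: it enqueues a work item that changes no node data and blocks until the writer goroutine has applied the batch and rebuilt the derived snapshot. A minimal, self-contained sketch of that pattern with hypothetical store/work types (single-item batches, an integer standing in for the peer map), not the NodeStore implementation:

package main

import "fmt"

type op int

const (
	opPut op = iota
	opRebuild
)

type work struct {
	op   op
	key  string
	val  int
	done chan struct{}
}

type store struct {
	queue   chan work
	data    map[string]int
	derived int // stands in for the peer map snapshot
}

func newStore() *store {
	s := &store{queue: make(chan work), data: map[string]int{}}
	go s.loop()
	return s
}

func (s *store) loop() {
	for w := range s.queue {
		if w.op == opPut {
			s.data[w.key] = w.val
		}
		// Every batch, including a pure rebuild, recomputes the derived view
		// so it reflects the latest external inputs (e.g. policy filters).
		s.derived = len(s.data)
		close(w.done)
	}
}

func (s *store) Put(k string, v int) {
	w := work{op: opPut, key: k, val: v, done: make(chan struct{})}
	s.queue <- w
	<-w.done
}

func (s *store) RebuildDerived() {
	w := work{op: opRebuild, done: make(chan struct{})}
	s.queue <- w
	<-w.done
}

func main() {
	s := newStore()
	s.Put("a", 1)
	s.RebuildDerived() // returns only after the writer has rebuilt the snapshot
	fmt.Println(s.derived)
}

The blocking receive on the done channel is what lets ReloadPolicy assume the peer map is already consistent with the new policy by the time it emits its change set.
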
nodeStore := NewNodeStore(nodes, func(nodes []types.NodeView) map[types.NodeID][]types.NodeView { - _, matchers := polMan.Filter() - return policy.BuildPeerMap(views.SliceOf(nodes), matchers) + return polMan.BuildPeerMap(views.SliceOf(nodes)) }) nodeStore.Start() @@ -225,6 +226,12 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { return nil, fmt.Errorf("setting policy: %w", err) } + // Rebuild peer maps after policy changes because the peersFunc in NodeStore + // uses the PolicyManager's filters. Without this, nodes won't see newly allowed + // peers until a node is added/removed, causing autogroup:self policies to not + // propagate correctly when switching between policy types. + s.nodeStore.RebuildPeerMaps() + cs := []change.ChangeSet{change.PolicyChange()} // Always call autoApproveNodes during policy reload, regardless of whether @@ -797,6 +804,11 @@ func (s *State) FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error) return s.polMan.FilterForNode(node) } +// MatchersForNode returns matchers for peer relationship determination (unreduced). +func (s *State) MatchersForNode(node types.NodeView) ([]matcher.Match, error) { + return s.polMan.MatchersForNode(node) +} + // NodeCanHaveTag checks if a node is allowed to have a specific tag. func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool { return s.polMan.NodeCanHaveTag(node, tag) diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 010e3410..732b4d5a 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -340,11 +340,11 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential)) if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - log.Warn().Msg("No config file found, using defaults") - return nil - } - + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + log.Warn().Msg("No config file found, using defaults") + return nil + } + return fmt.Errorf("fatal error reading config file: %w", err) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index a70861ac..8cf40ced 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -855,3 +855,22 @@ func (v NodeView) IPsAsString() []string { } return v.ж.IPsAsString() } + +// HasNetworkChanges checks if the node has network-related changes. +// Returns true if IPs, announced routes, or approved routes changed. +// This is primarily used for policy cache invalidation. +func (v NodeView) HasNetworkChanges(other NodeView) bool { + if !slices.Equal(v.IPs(), other.IPs()) { + return true + } + + if !slices.Equal(v.AnnouncedRoutes(), other.AnnouncedRoutes()) { + return true + } + + if !slices.Equal(v.SubnetRoutes(), other.SubnetRoutes()) { + return true + } + + return false +} diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 41af5d13..c992219e 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -793,3 +793,179 @@ func TestNodeRegisterMethodToV1Enum(t *testing.T) { }) } } + +// TestHasNetworkChanges tests the NodeView method for detecting +// when a node's network properties have changed. 
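
One property of the slices.Equal comparisons in HasNetworkChanges worth noting: they are order-sensitive, so a route list that merely changes order counts as a network change. For this use that only costs an unnecessary cache invalidation rather than a stale filter. A small standalone illustration:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	a := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("192.168.0.0/24"),
	}
	b := []netip.Prefix{
		netip.MustParsePrefix("192.168.0.0/24"),
		netip.MustParsePrefix("10.0.0.0/24"),
	}
	fmt.Println(slices.Equal(a, a)) // true
	fmt.Println(slices.Equal(a, b)) // false: same set, different order
}
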
+func TestHasNetworkChanges(t *testing.T) { + mustIPPtr := func(s string) *netip.Addr { + ip := netip.MustParseAddr(s) + return &ip + } + + tests := []struct { + name string + old *Node + new *Node + changed bool + }{ + { + name: "no changes", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + changed: false, + }, + { + name: "IPv4 changed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.2"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + }, + changed: true, + }, + { + name: "IPv6 changed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::2"), + }, + changed: true, + }, + { + name: "RoutableIPs added", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + }, + changed: true, + }, + { + name: "RoutableIPs removed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{}, + }, + changed: true, + }, + { + name: "RoutableIPs changed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + }, + changed: true, + }, + { + name: "SubnetRoutes added", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + changed: true, + }, + { + name: "SubnetRoutes removed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{}, + }, + changed: true, + }, + { + name: "SubnetRoutes changed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: 
[]netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + changed: true, + }, + { + name: "irrelevant property changed (Hostname)", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostname: "old-name", + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostname: "new-name", + }, + changed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.new.View().HasNetworkChanges(tt.old.View()) + if got != tt.changed { + t.Errorf("HasNetworkChanges() = %v, want %v", got, tt.changed) + } + }) + } +} diff --git a/integration/acl_test.go b/integration/acl_test.go index fd5d22a0..122eeea7 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -3,12 +3,14 @@ package integration import ( "fmt" "net/netip" + "strconv" "strings" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" @@ -319,12 +321,14 @@ func TestACLHostsInNetMapTable(t *testing.T) { require.NoError(t, err) for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) - user := status.User[status.Self.UserID].LoginName + user := status.User[status.Self.UserID].LoginName - assert.Len(t, status.Peer, (testCase.want[user])) + assert.Len(c, status.Peer, (testCase.want[user])) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer visibility") } }) } @@ -782,75 +786,87 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn) // test1 can query test3 - result, err := test1.Curl(test3ip4URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3ip4URL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test3ip4URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3ip4URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv4") - result, err = test1.Curl(test3ip6URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3ip6URL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test3ip6URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3ip6URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv6") - result, err = test1.Curl(test3fqdnURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3fqdnURL, - result, - ) - require.NoError(t, err) + 
assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test3fqdnURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3fqdnURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via FQDN") // test2 can query test3 - result, err = test2.Curl(test3ip4URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3ip4URL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test3ip4URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3ip4URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv4") - result, err = test2.Curl(test3ip6URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3ip6URL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test3ip6URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3ip6URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv6") - result, err = test2.Curl(test3fqdnURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3fqdnURL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test3fqdnURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3fqdnURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via FQDN") // test3 cannot query test1 - result, err = test3.Curl(test1ip4URL) + result, err := test3.Curl(test1ip4URL) assert.Empty(t, result) require.Error(t, err) @@ -876,38 +892,44 @@ func TestACLNamedHostsCanReach(t *testing.T) { require.Error(t, err) // test1 can query test2 - result, err = test1.Curl(test2ip4URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", - test2ip4URL, - result, - ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2ip4URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", + test2ip4URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4") - require.NoError(t, err) - result, err = test1.Curl(test2ip6URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", - test2ip6URL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2ip6URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", + test2ip6URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 
should reach test2 via IPv6") - result, err = test1.Curl(test2fqdnURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", - test2fqdnURL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2fqdnURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", + test2fqdnURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN") // test2 cannot query test1 result, err = test2.Curl(test1ip4URL) @@ -1050,50 +1072,63 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) // test1 can query test2 - result, err := test1.Curl(test2ipURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", - test2ipURL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2ipURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", + test2ipURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4") - result, err = test1.Curl(test2ip6URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", - test2ip6URL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2ip6URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", + test2ip6URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv6") - result, err = test1.Curl(test2fqdnURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", - test2fqdnURL, - result, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2fqdnURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", + test2fqdnURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN") - result, err = test2.Curl(test1ipURL) - assert.Empty(t, result) - require.Error(t, err) + // test2 cannot query test1 (negative test case) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test1ipURL) + assert.Error(c, err) + assert.Empty(c, result) + }, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv4") - result, err = test2.Curl(test1ip6URL) - assert.Empty(t, result) - require.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test1ip6URL) + assert.Error(c, err) + assert.Empty(c, result) + }, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv6") - result, err = test2.Curl(test1fqdnURL) - assert.Empty(t, result) - require.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test1fqdnURL) + assert.Error(c, err) + assert.Empty(c, result) + }, 10*time.Second, 
200*time.Millisecond, "test2 should NOT reach test1 via FQDN") }) } } @@ -1266,9 +1301,15 @@ func TestACLAutogroupMember(t *testing.T) { // Test that untagged nodes can access each other for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) - if status.Self.Tags != nil && status.Self.Tags.Len() > 0 { + var clientIsUntagged bool + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + clientIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0 + assert.True(c, clientIsUntagged, "Expected client %s to be untagged for autogroup:member test", client.Hostname()) + }, 10*time.Second, 200*time.Millisecond, "Waiting for client %s to be untagged", client.Hostname()) + + if !clientIsUntagged { continue } @@ -1277,9 +1318,15 @@ func TestACLAutogroupMember(t *testing.T) { continue } - status, err := peer.Status() - require.NoError(t, err) - if status.Self.Tags != nil && status.Self.Tags.Len() > 0 { + var peerIsUntagged bool + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := peer.Status() + assert.NoError(c, err) + peerIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0 + assert.True(c, peerIsUntagged, "Expected peer %s to be untagged for autogroup:member test", peer.Hostname()) + }, 10*time.Second, 200*time.Millisecond, "Waiting for peer %s to be untagged", peer.Hostname()) + + if !peerIsUntagged { continue } @@ -1468,21 +1515,23 @@ func TestACLAutogroupTagged(t *testing.T) { // Explicitly verify tags on tagged nodes for _, client := range taggedClients { - status, err := client.Status() - require.NoError(t, err) - require.NotNil(t, status.Self.Tags, "tagged node %s should have tags", client.Hostname()) - require.Positive(t, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname()) - t.Logf("Tagged node %s has tags: %v", client.Hostname(), status.Self.Tags) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + assert.NotNil(c, status.Self.Tags, "tagged node %s should have tags", client.Hostname()) + assert.Positive(c, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname()) + }, 10*time.Second, 200*time.Millisecond, "Waiting for tags to be applied to tagged nodes") } // Verify untagged nodes have no tags for _, client := range untaggedClients { - status, err := client.Status() - require.NoError(t, err) - if status.Self.Tags != nil { - require.Equal(t, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname()) - } - t.Logf("Untagged node %s has no tags", client.Hostname()) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + if status.Self.Tags != nil { + assert.Equal(c, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname()) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting to verify untagged nodes have no tags") } // Test that tagged nodes can communicate with each other @@ -1603,9 +1652,11 @@ func TestACLAutogroupSelf(t *testing.T) { url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s (user1) to %s (user1)", client.Hostname(), fqdn) - result, err := client.Curl(url) - assert.Len(t, result, 13) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 10*time.Second, 
200*time.Millisecond, "user1 device should reach other user1 device") } } @@ -1622,9 +1673,11 @@ func TestACLAutogroupSelf(t *testing.T) { url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s (user2) to %s (user2)", client.Hostname(), fqdn) - result, err := client.Curl(url) - assert.Len(t, result, 13) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 10*time.Second, 200*time.Millisecond, "user2 device should reach other user2 device") } } @@ -1657,3 +1710,388 @@ func TestACLAutogroupSelf(t *testing.T) { } } } + +func TestACLPolicyPropagationOverTime(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + // Install iptables to enable packet filtering for ACL tests. + // Packet filters are essential for testing autogroup:self and other ACL policies. + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", + "-c", + "/bin/sleep 3 ; apk add python3 curl iptables ip6tables ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithTestName("aclpropagation"), + hsic.WithPolicyMode(types.PolicyModeDB), + ) + require.NoError(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + + allClients := append(user1Clients, user2Clients...) 
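
The connectivity phases below, like the earlier acl_test.go changes, wrap every check in assert.EventuallyWithT so the whole closure is retried until policy propagation settles instead of failing on the first racy read. A distilled standalone form of that pattern, with a hypothetical checkSomething helper standing in for the Curl/Status calls:

// eventually_sketch_test.go
package example

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// checkSomething stands in for a call that may fail until the new policy
// has propagated to the client.
func checkSomething() (string, error) { return "hostname12345", nil }

func TestEventuallyPattern(t *testing.T) {
	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		// Every assertion records failures against the CollectT; the closure
		// is re-run until all of them pass or the timeout expires.
		result, err := checkSomething()
		assert.NoError(c, err)
		assert.Len(c, result, 13)
	}, 10*time.Second, 200*time.Millisecond, "waiting for the condition to settle")
}
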
+ + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Define the four policies we'll cycle through + allowAllPolicy := &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + } + + autogroupSelfPolicy := &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), + }, + }, + }, + } + + user1ToUser2Policy := &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + } + + // Run through the policy cycle 5 times + for i := range 5 { + iteration := i + 1 // range 5 gives 0-4, we want 1-5 for logging + t.Logf("=== Iteration %d/5 ===", iteration) + + // Phase 1: Allow all policy + t.Logf("Iteration %d: Setting allow-all policy", iteration) + err = headscale.SetPolicy(allowAllPolicy) + require.NoError(t, err) + + // Wait for peer lists to sync with allow-all policy + t.Logf("Iteration %d: Phase 1 - Waiting for peer lists to sync with allow-all policy", iteration) + err = scenario.WaitForTailscaleSync() + require.NoError(t, err, "iteration %d: Phase 1 - failed to sync after allow-all policy", iteration) + + // Test all-to-all connectivity after state is settled + t.Logf("Iteration %d: Phase 1 - Testing all-to-all connectivity", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + for _, peer := range allClients { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: %s should reach %s with allow-all policy", iteration, client.Hostname(), fqdn) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn) + } + } + }, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 1 - all connectivity tests with allow-all policy", iteration) + + // Phase 2: Autogroup:self policy (only same user can access) + t.Logf("Iteration %d: Phase 2 - Setting autogroup:self policy", iteration) + err = headscale.SetPolicy(autogroupSelfPolicy) + require.NoError(t, err) + + // Wait for peer lists to sync with autogroup:self - ensures cross-user peers are removed + t.Logf("Iteration %d: Phase 2 - Waiting for peer lists to sync with autogroup:self", iteration) + err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond) + require.NoError(t, err, "iteration %d: Phase 2 - failed to sync after autogroup:self policy", iteration) + + // Test ALL connectivity (positive and negative) in one block after state is settled + t.Logf("Iteration %d: Phase 2 - Testing all connectivity with autogroup:self", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Positive: user1 can access user1's nodes + for _, client := range user1Clients { + for _, peer := range user1Clients { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := 
peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname()) + } + } + + // Positive: user2 can access user2's nodes + for _, client := range user2Clients { + for _, peer := range user2Clients { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: user2 %s should reach user2's node %s", iteration, client.Hostname(), fqdn) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn) + } + } + + // Negative: user1 cannot access user2's nodes + for _, client := range user1Clients { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.Error(ct, err, "iteration %d: user1 %s should NOT reach user2's node %s with autogroup:self", iteration, client.Hostname(), fqdn) + assert.Empty(ct, result, "iteration %d: user1 %s->user2 %s should fail", iteration, client.Hostname(), fqdn) + } + } + + // Negative: user2 cannot access user1's nodes + for _, client := range user2Clients { + for _, peer := range user1Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Empty(ct, result, "iteration %d: user2->user1 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname()) + } + } + }, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2 - all connectivity tests with autogroup:self", iteration) + + // Phase 2b: Add a new node to user1 and validate policy propagation + t.Logf("Iteration %d: Phase 2b - Adding new node to user1 during autogroup:self policy", iteration) + + // Add a new node with the same options as the initial setup + // Get the network to use (scenario uses first network in list) + networks := scenario.Networks() + require.NotEmpty(t, networks, "scenario should have at least one network") + + newClient := scenario.MustAddAndLoginClient(t, "user1", "all", headscale, + tsic.WithNetfilter("off"), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", + "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + tsic.WithNetwork(networks[0]), + ) + t.Logf("Iteration %d: Phase 2b - Added and logged in new node %s", iteration, newClient.Hostname()) + + // Wait for peer lists to sync after new node addition (now 3 
user1 nodes, still autogroup:self) + t.Logf("Iteration %d: Phase 2b - Waiting for peer lists to sync after new node addition", iteration) + err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond) + require.NoError(t, err, "iteration %d: Phase 2b - failed to sync after new node addition", iteration) + + // Test ALL connectivity (positive and negative) in one block after state is settled + t.Logf("Iteration %d: Phase 2b - Testing all connectivity after new node addition", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Re-fetch client list to ensure latest state + user1ClientsWithNew, err := scenario.ListTailscaleClients("user1") + assert.NoError(ct, err, "iteration %d: failed to list user1 clients", iteration) + assert.Len(ct, user1ClientsWithNew, 3, "iteration %d: user1 should have 3 nodes", iteration) + + // Positive: all user1 nodes can access each other + for _, client := range user1ClientsWithNew { + for _, peer := range user1ClientsWithNew { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname()) + } + } + + // Negative: user1 nodes cannot access user2's nodes + for _, client := range user1ClientsWithNew { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.Error(ct, err, "iteration %d: user1 node %s should NOT reach user2 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Empty(ct, result, "iteration %d: user1->user2 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname()) + } + } + }, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - all connectivity tests after new node addition", iteration) + + // Delete the newly added node before Phase 3 + t.Logf("Iteration %d: Phase 2b - Deleting the newly added node from user1", iteration) + + // Get the node list and find the newest node (highest ID) + var nodeList []*v1.Node + var nodeToDeleteID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodeList, err = headscale.ListNodes("user1") + assert.NoError(ct, err) + assert.Len(ct, nodeList, 3, "should have 3 user1 nodes before deletion") + + // Find the node with the highest ID (the newest one) + for _, node := range nodeList { + if node.GetId() > nodeToDeleteID { + nodeToDeleteID = node.GetId() + } + } + }, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - listing nodes before deletion", iteration) + + // Delete the node via headscale helper + t.Logf("Iteration %d: Phase 2b - Deleting node ID %d from headscale", iteration, nodeToDeleteID) + err = headscale.DeleteNode(nodeToDeleteID) + require.NoError(t, err, "iteration %d: failed to delete node %d", iteration, nodeToDeleteID) + + // Remove the deleted client from the scenario's user.Clients map + // This is necessary for WaitForTailscaleSyncPerUser to calculate correct 
peer counts + t.Logf("Iteration %d: Phase 2b - Removing deleted client from scenario", iteration) + for clientName, client := range scenario.users["user1"].Clients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + if err != nil { + continue + } + if nodeID == nodeToDeleteID { + delete(scenario.users["user1"].Clients, clientName) + t.Logf("Iteration %d: Phase 2b - Removed client %s (node ID %d) from scenario", iteration, clientName, nodeToDeleteID) + break + } + } + + // Verify the node has been deleted + t.Logf("Iteration %d: Phase 2b - Verifying node deletion (expecting 2 user1 nodes)", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodeListAfter, err := headscale.ListNodes("user1") + assert.NoError(ct, err, "failed to list nodes after deletion") + assert.Len(ct, nodeListAfter, 2, "iteration %d: should have 2 user1 nodes after deletion, got %d", iteration, len(nodeListAfter)) + }, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - node should be deleted", iteration) + + // Wait for sync after deletion to ensure peer counts are correct + // Use WaitForTailscaleSyncPerUser because autogroup:self is still active, + // so nodes only see same-user peers, not all nodes + t.Logf("Iteration %d: Phase 2b - Waiting for sync after node deletion (with autogroup:self)", iteration) + err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond) + require.NoError(t, err, "iteration %d: failed to sync after node deletion", iteration) + + // Refresh client lists after deletion to ensure we don't reference the deleted node + user1Clients, err = scenario.ListTailscaleClients("user1") + require.NoError(t, err, "iteration %d: failed to refresh user1 client list after deletion", iteration) + user2Clients, err = scenario.ListTailscaleClients("user2") + require.NoError(t, err, "iteration %d: failed to refresh user2 client list after deletion", iteration) + // Create NEW slice instead of appending to old allClients which still has deleted client + allClients = make([]TailscaleClient, 0, len(user1Clients)+len(user2Clients)) + allClients = append(allClients, user1Clients...) + allClients = append(allClients, user2Clients...) + + t.Logf("Iteration %d: Phase 2b completed - New node added, validated, and removed successfully", iteration) + + // Phase 3: User1 can access user2 but not reverse + t.Logf("Iteration %d: Phase 3 - Setting user1->user2 directional policy", iteration) + err = headscale.SetPolicy(user1ToUser2Policy) + require.NoError(t, err) + + // Note: Cannot use WaitForTailscaleSync() here because directional policy means + // user2 nodes don't see user1 nodes in their peer list (asymmetric visibility). + // The EventuallyWithT block below will handle waiting for policy propagation. 
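
The comment above gives the stale-client reason for rebuilding allClients with an explicit make; the same pattern also avoids a general Go append gotcha: append(a, b...) can reuse a's backing array when it has spare capacity, so the combined slice may alias the original. A standalone sketch of the gotcha and the safe rebuild:

package main

import "fmt"

func main() {
	user1 := make([]string, 2, 4) // spare capacity, as after earlier appends
	user1[0], user1[1] = "u1-a", "u1-b"
	user2 := []string{"u2-a", "u2-b"}

	all := append(user1, user2...) // fits in user1's backing array: aliased!
	all[0] = "changed"
	fmt.Println(user1[0]) // "changed" — user1 was modified through all

	// Safe pattern: build a fresh slice with its own backing array.
	fresh := make([]string, 0, len(user1)+len(user2))
	fresh = append(fresh, user1...)
	fresh = append(fresh, user2...)
	fresh[0] = "independent"
	fmt.Println(user1[0]) // unaffected by writes through fresh
}
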
+ + // Test ALL connectivity (positive and negative) in one block after policy settles + t.Logf("Iteration %d: Phase 3 - Testing all connectivity with directional policy", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Positive: user1 can access user2's nodes + for _, client := range user1Clients { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: user1 node %s should reach user2 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname()) + } + } + + // Negative: user2 cannot access user1's nodes + for _, client := range user2Clients { + for _, peer := range user1Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Empty(ct, result, "iteration %d: user2->user1 from %s to %s should fail", iteration, client.Hostname(), peer.Hostname()) + } + } + }, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 3 - all connectivity tests with directional policy", iteration) + + t.Logf("=== Iteration %d/5 completed successfully - All 3 phases passed ===", iteration) + } + + t.Log("All 5 iterations completed successfully - ACL propagation is working correctly") +} diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 7f8a9e8f..c6a4f4cf 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -74,14 +74,21 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { clientIPs[client] = ips } - listNodes, err := headscale.ListNodes() - assert.Len(t, allClients, len(listNodes)) - nodeCountBeforeLogout := len(listNodes) - t.Logf("node count before logout: %d", nodeCountBeforeLogout) + var listNodes []*v1.Node + var nodeCountBeforeLogout int + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, len(allClients)) - for _, node := range listNodes { - assertLastSeenSet(t, node) - } + for _, node := range listNodes { + assertLastSeenSetWithCollect(c, node) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout") + + nodeCountBeforeLogout = len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) for _, client := range allClients { err := client.Logout() @@ -188,11 +195,16 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } } - listNodes, err = headscale.ListNodes() - require.Len(t, listNodes, nodeCountBeforeLogout) - for _, node := range listNodes { - assertLastSeenSet(t, node) - } + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, nodeCountBeforeLogout) + + for _, node := range listNodes { + assertLastSeenSetWithCollect(c, node) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for node list after relogin") }) 
} } @@ -238,9 +250,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) - listNodes, err := headscale.ListNodes() - assert.Len(t, allClients, len(listNodes)) - nodeCountBeforeLogout := len(listNodes) + var listNodes []*v1.Node + var nodeCountBeforeLogout int + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, len(allClients)) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout") + + nodeCountBeforeLogout = len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) for _, client := range allClients { @@ -371,9 +390,16 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) - listNodes, err := headscale.ListNodes() - assert.Len(t, allClients, len(listNodes)) - nodeCountBeforeLogout := len(listNodes) + var listNodes []*v1.Node + var nodeCountBeforeLogout int + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, len(allClients)) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout") + + nodeCountBeforeLogout = len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) for _, client := range allClients { diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index c08a5efd..0a0b5b95 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -901,15 +901,18 @@ func TestOIDCFollowUpUrl(t *testing.T) { // a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION time.Sleep(2 * time.Minute) - st, err := ts.Status() - require.NoError(t, err) - assert.Equal(t, "NeedsLogin", st.BackendState) + var newUrl *url.URL + assert.EventuallyWithT(t, func(c *assert.CollectT) { + st, err := ts.Status() + assert.NoError(c, err) + assert.Equal(c, "NeedsLogin", st.BackendState) - // get new AuthURL from daemon - newUrl, err := url.Parse(st.AuthURL) - require.NoError(t, err) + // get new AuthURL from daemon + newUrl, err = url.Parse(st.AuthURL) + assert.NoError(c, err) - assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change") + assert.NotEqual(c, u.String(), st.AuthURL, "AuthURL should change") + }, 10*time.Second, 200*time.Millisecond, "Waiting for registration cache to expire and status to reflect NeedsLogin") _, err = doLoginURL(ts.Hostname(), newUrl) require.NoError(t, err) @@ -943,9 +946,11 @@ func TestOIDCFollowUpUrl(t *testing.T) { t.Fatalf("unexpected users: %s", diff) } - listNodes, err := headscale.ListNodes() - require.NoError(t, err) - assert.Len(t, listNodes, 1) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + listNodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, 1) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login") } // TestOIDCReloginSameNodeSameUser tests the scenario 
where a single Tailscale client diff --git a/integration/cli_test.go b/integration/cli_test.go index d6616d62..37e3c33d 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -127,18 +127,20 @@ func TestUserCommand(t *testing.T) { }, 20*time.Second, 1*time.Second) var listByUsername []*v1.User - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "users", - "list", - "--output", - "json", - "--name=user1", - }, - &listByUsername, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + "--name=user1", + }, + &listByUsername, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user list by username") slices.SortFunc(listByUsername, sortWithID) want := []*v1.User{ @@ -154,18 +156,20 @@ func TestUserCommand(t *testing.T) { } var listByID []*v1.User - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "users", - "list", - "--output", - "json", - "--identifier=1", - }, - &listByID, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + "--identifier=1", + }, + &listByID, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user list by ID") slices.SortFunc(listByID, sortWithID) want = []*v1.User{ @@ -234,19 +238,20 @@ func TestUserCommand(t *testing.T) { assert.Contains(t, deleteResult, "User destroyed") var listAfterNameDelete []v1.User - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "users", - "list", - "--output", - "json", - }, - &listAfterNameDelete, - ) - require.NoError(t, err) - - require.Empty(t, listAfterNameDelete) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listAfterNameDelete, + ) + assert.NoError(c, err) + assert.Empty(c, listAfterNameDelete) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user list after name delete") } func TestPreAuthKeyCommand(t *testing.T) { @@ -274,25 +279,27 @@ func TestPreAuthKeyCommand(t *testing.T) { for index := range count { var preAuthKey v1.PreAuthKey - err := executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - "1", - "create", - "--reusable", - "--expiration", - "24h", - "--output", - "json", - "--tags", - "tag:test1,tag:test2", - }, - &preAuthKey, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err := executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "create", + "--reusable", + "--expiration", + "24h", + "--output", + "json", + "--tags", + "tag:test1,tag:test2", + }, + &preAuthKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth key creation") keys[index] = &preAuthKey } @@ -300,20 +307,22 @@ func TestPreAuthKeyCommand(t *testing.T) { assert.Len(t, keys, 3) var listedPreAuthKeys []v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - "1", - "list", - "--output", - "json", - }, - &listedPreAuthKeys, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + 
"1", + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth keys list") // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 4) @@ -375,20 +384,22 @@ func TestPreAuthKeyCommand(t *testing.T) { require.NoError(t, err) var listedPreAuthKeysAfterExpire []v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - "1", - "list", - "--output", - "json", - }, - &listedPreAuthKeysAfterExpire, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "list", + "--output", + "json", + }, + &listedPreAuthKeysAfterExpire, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth keys list after expire") assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now())) @@ -414,37 +425,41 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { require.NoError(t, err) var preAuthKey v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - "1", - "create", - "--reusable", - "--output", - "json", - }, - &preAuthKey, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "create", + "--reusable", + "--output", + "json", + }, + &preAuthKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth key creation without expiry") var listedPreAuthKeys []v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - "1", - "list", - "--output", - "json", - }, - &listedPreAuthKeys, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth keys list without expiry") // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 2) @@ -475,57 +490,63 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { require.NoError(t, err) var preAuthReusableKey v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - "1", - "create", - "--reusable=true", - "--output", - "json", - }, - &preAuthReusableKey, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "create", + "--reusable=true", + "--output", + "json", + }, + &preAuthReusableKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for reusable preauth key creation") var preAuthEphemeralKey v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - "1", - "create", - "--ephemeral=true", - "--output", - "json", - }, - &preAuthEphemeralKey, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + 
err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "create", + "--ephemeral=true", + "--output", + "json", + }, + &preAuthEphemeralKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for ephemeral preauth key creation") assert.True(t, preAuthEphemeralKey.GetEphemeral()) assert.False(t, preAuthEphemeralKey.GetReusable()) var listedPreAuthKeys []v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - "1", - "list", - "--output", - "json", - }, - &listedPreAuthKeys, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth keys list after reusable/ephemeral creation") // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 3) @@ -562,25 +583,27 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { var user2Key v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - strconv.FormatUint(u2.GetId(), 10), - "create", - "--reusable", - "--expiration", - "24h", - "--output", - "json", - "--tags", - "tag:test1,tag:test2", - }, - &user2Key, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + strconv.FormatUint(u2.GetId(), 10), + "create", + "--reusable", + "--expiration", + "24h", + "--output", + "json", + "--tags", + "tag:test1,tag:test2", + }, + &user2Key, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user2 preauth key creation") var listNodes []*v1.Node assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -674,17 +697,19 @@ func TestApiKeyCommand(t *testing.T) { assert.Len(t, keys, 5) var listedAPIKeys []v1.ApiKey - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "apikeys", - "list", - "--output", - "json", - }, - &listedAPIKeys, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "apikeys", + "list", + "--output", + "json", + }, + &listedAPIKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for API keys list") assert.Len(t, listedAPIKeys, 5) @@ -746,17 +771,19 @@ func TestApiKeyCommand(t *testing.T) { } var listedAfterExpireAPIKeys []v1.ApiKey - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "apikeys", - "list", - "--output", - "json", - }, - &listedAfterExpireAPIKeys, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "apikeys", + "list", + "--output", + "json", + }, + &listedAfterExpireAPIKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for API keys list after expire") for index := range listedAfterExpireAPIKeys { if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok { @@ -785,17 +812,19 @@ func TestApiKeyCommand(t *testing.T) { assert.NoError(t, err) var listedAPIKeysAfterDelete []v1.ApiKey - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - 
"apikeys", - "list", - "--output", - "json", - }, - &listedAPIKeysAfterDelete, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "apikeys", + "list", + "--output", + "json", + }, + &listedAPIKeysAfterDelete, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for API keys list after delete") assert.Len(t, listedAPIKeysAfterDelete, 4) } @@ -843,22 +872,24 @@ func TestNodeTagCommand(t *testing.T) { assert.NoError(t, err) var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "user1", - "register", - "--key", - regID, - "--output", - "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "user1", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node registration") nodes[index] = &node } @@ -867,19 +898,21 @@ func TestNodeTagCommand(t *testing.T) { }, 15*time.Second, 1*time.Second) var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "tag", - "-i", "1", - "-t", "tag:test", - "--output", "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "tag", + "-i", "1", + "-t", "tag:test", + "--output", "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node tag command") assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) @@ -897,17 +930,19 @@ func TestNodeTagCommand(t *testing.T) { // Test list all nodes after added seconds resultMachines := make([]*v1.Node, len(regIDs)) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", "json", - }, - &resultMachines, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", "json", + }, + &resultMachines, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after tagging") found := false for _, node := range resultMachines { if node.GetForcedTags() != nil { @@ -1021,31 +1056,34 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { require.NoError(t, err) // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec.NodesPerUser) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", - }, - &resultMachines, - ) - assert.NoError(t, err) - found := false - for _, node := range resultMachines { - if tags := node.GetValidTags(); tags != nil { - found = slices.Contains(tags, "tag:test") + var resultMachines []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { + resultMachines = make([]*v1.Node, spec.NodesPerUser) + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--tags", + "--output", "json", + }, + &resultMachines, + ) + assert.NoError(c, err) + found := false + for _, node := range resultMachines { + if tags := node.GetValidTags(); tags != nil { + found = slices.Contains(tags, "tag:test") + } } - } - assert.Equalf( - t, - tt.wantTag, - 
found, - "'tag:test' found(%t) is the list of nodes, expected %t", found, tt.wantTag, - ) + assert.Equalf( + c, + tt.wantTag, + found, + "'tag:test' found(%t) is the list of nodes, expected %t", found, tt.wantTag, + ) + }, 10*time.Second, 200*time.Millisecond, "Waiting for tag propagation to nodes") }) } } @@ -1096,22 +1134,24 @@ func TestNodeCommand(t *testing.T) { assert.NoError(t, err) var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "node-user", - "register", - "--key", - regID, - "--output", - "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "node-user", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node registration") nodes[index] = &node } @@ -1176,22 +1216,24 @@ func TestNodeCommand(t *testing.T) { assert.NoError(t, err) var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "other-user", - "register", - "--key", - regID, - "--output", - "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "other-user", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for other-user node registration") otherUserMachines[index] = &node } @@ -1202,18 +1244,20 @@ func TestNodeCommand(t *testing.T) { // Test list all nodes after added otherUser var listAllWithotherUser []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAllWithotherUser, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAllWithotherUser, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after adding other-user nodes") // All nodes, nodes + otherUser assert.Len(t, listAllWithotherUser, 7) @@ -1226,20 +1270,22 @@ func TestNodeCommand(t *testing.T) { // Test list all nodes after added otherUser var listOnlyotherUserMachineUser []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--user", - "other-user", - "--output", - "json", - }, - &listOnlyotherUserMachineUser, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--user", + "other-user", + "--output", + "json", + }, + &listOnlyotherUserMachineUser, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list filtered by other-user") assert.Len(t, listOnlyotherUserMachineUser, 2) @@ -1339,22 +1385,24 @@ func TestNodeExpireCommand(t *testing.T) { assert.NoError(t, err) var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "node-expire-user", - "register", - "--key", - regID, - "--output", - "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( 
+ headscale, + []string{ + "headscale", + "nodes", + "--user", + "node-expire-user", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node-expire-user node registration") nodes[index] = &node } @@ -1362,18 +1410,20 @@ func TestNodeExpireCommand(t *testing.T) { assert.Len(t, nodes, len(regIDs)) var listAll []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAll, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAll, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list in expire test") assert.Len(t, listAll, 5) @@ -1397,18 +1447,20 @@ func TestNodeExpireCommand(t *testing.T) { } var listAllAfterExpiry []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAllAfterExpiry, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAllAfterExpiry, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after expiry") assert.Len(t, listAllAfterExpiry, 5) @@ -1465,22 +1517,24 @@ func TestNodeRenameCommand(t *testing.T) { require.NoError(t, err) var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "node-rename-command", - "register", - "--key", - regID, - "--output", - "json", - }, - &node, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "node-rename-command", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node-rename-command node registration") nodes[index] = &node } @@ -1488,18 +1542,20 @@ func TestNodeRenameCommand(t *testing.T) { assert.Len(t, nodes, len(regIDs)) var listAll []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAll, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAll, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list in rename test") assert.Len(t, listAll, 5) @@ -1526,18 +1582,20 @@ func TestNodeRenameCommand(t *testing.T) { } var listAllAfterRename []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAllAfterRename, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAllAfterRename, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after rename") assert.Len(t, listAllAfterRename, 5) @@ -1561,18 +1619,20 @@ func TestNodeRenameCommand(t *testing.T) { 
assert.ErrorContains(t, err, "must not exceed 63 characters") var listAllAfterRenameAttempt []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAllAfterRenameAttempt, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAllAfterRenameAttempt, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after failed rename attempt") assert.Len(t, listAllAfterRenameAttempt, 5) @@ -1624,22 +1684,24 @@ func TestNodeMoveCommand(t *testing.T) { assert.NoError(t, err) var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "old-user", - "register", - "--key", - regID.String(), - "--output", - "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "old-user", + "register", + "--key", + regID.String(), + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for old-user node registration") assert.Equal(t, uint64(1), node.GetId()) assert.Equal(t, "nomad-node", node.GetName()) @@ -1647,38 +1709,42 @@ func TestNodeMoveCommand(t *testing.T) { nodeID := strconv.FormatUint(node.GetId(), 10) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "move", - "--identifier", - strconv.FormatUint(node.GetId(), 10), - "--user", - strconv.FormatUint(userMap["new-user"].GetId(), 10), - "--output", - "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "move", + "--identifier", + strconv.FormatUint(node.GetId(), 10), + "--user", + strconv.FormatUint(userMap["new-user"].GetId(), 10), + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node move to new-user") assert.Equal(t, "new-user", node.GetUser().GetName()) var allNodes []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &allNodes, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &allNodes, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after move") assert.Len(t, allNodes, 1) @@ -1706,41 +1772,45 @@ func TestNodeMoveCommand(t *testing.T) { ) assert.Equal(t, "new-user", node.GetUser().GetName()) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "move", - "--identifier", - nodeID, - "--user", - strconv.FormatUint(userMap["old-user"].GetId(), 10), - "--output", - "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "move", + "--identifier", + nodeID, + "--user", + strconv.FormatUint(userMap["old-user"].GetId(), 10), + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node move back to old-user") 
assert.Equal(t, "old-user", node.GetUser().GetName()) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "move", - "--identifier", - nodeID, - "--user", - strconv.FormatUint(userMap["old-user"].GetId(), 10), - "--output", - "json", - }, - &node, - ) - assert.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "move", + "--identifier", + nodeID, + "--user", + strconv.FormatUint(userMap["old-user"].GetId(), 10), + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node move to same user") assert.Equal(t, "old-user", node.GetUser().GetName()) } @@ -1808,18 +1878,20 @@ func TestPolicyCommand(t *testing.T) { // Get the current policy and check // if it is the same as the one we set. var output *policyv2.Policy - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "policy", - "get", - "--output", - "json", - }, - &output, - ) - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "policy", + "get", + "--output", + "json", + }, + &output, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for policy get command") assert.Len(t, output.TagOwners, 1) assert.Len(t, output.ACLs, 1) diff --git a/integration/control.go b/integration/control.go index 773ddeb8..e0e67e09 100644 --- a/integration/control.go +++ b/integration/control.go @@ -25,6 +25,7 @@ type ControlServer interface { CreateUser(user string) (*v1.User, error) CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error) ListNodes(users ...string) ([]*v1.Node, error) + DeleteNode(nodeID uint64) error NodesByUser() (map[string][]*v1.Node, error) NodesByName() (map[string]*v1.Node, error) ListUsers() ([]*v1.User, error) @@ -38,4 +39,5 @@ type ControlServer interface { PrimaryRoutes() (*routes.DebugRoutes, error) DebugBatcher() (*hscontrol.DebugBatcherInfo, error) DebugNodeStore() (map[types.NodeID]types.Node, error) + DebugFilter() ([]tailcfg.FilterRule, error) } diff --git a/integration/general_test.go b/integration/general_test.go index 83160e9b..2432db9c 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -541,8 +541,7 @@ func TestUpdateHostnameFromClient(t *testing.T) { // update hostnames using the up command for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) + status := client.MustStatus() command := []string{ "tailscale", @@ -642,8 +641,7 @@ func TestUpdateHostnameFromClient(t *testing.T) { }, 60*time.Second, 2*time.Second) for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) + status := client.MustStatus() command := []string{ "tailscale", @@ -773,26 +771,25 @@ func TestExpireNode(t *testing.T) { // Verify that the expired node has been marked in all peers list. for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) + if client.Hostname() == node.GetName() { + continue + } - if client.Hostname() != node.GetName() { - t.Logf("available peers of %s: %v", client.Hostname(), status.Peers()) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) // Ensures that the node is present, and that it is expired. 
- if peerStatus, ok := status.Peer[expiredNodeKey]; ok { - requireNotNil(t, peerStatus.Expired) - assert.NotNil(t, peerStatus.KeyExpiry) + peerStatus, ok := status.Peer[expiredNodeKey] + assert.True(c, ok, "expired node key should be present in peer list") + + if ok { + assert.NotNil(c, peerStatus.Expired) + assert.NotNil(c, peerStatus.KeyExpiry) - t.Logf( - "node %q should have a key expire before %s, was %s", - peerStatus.HostName, - now.String(), - peerStatus.KeyExpiry, - ) if peerStatus.KeyExpiry != nil { assert.Truef( - t, + c, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, @@ -802,7 +799,7 @@ func TestExpireNode(t *testing.T) { } assert.Truef( - t, + c, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, @@ -811,24 +808,14 @@ func TestExpireNode(t *testing.T) { _, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()}) if !strings.Contains(stderr, "node key has expired") { - t.Errorf( + c.Errorf( "expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname(), ) } - } else { - t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey) } - } else { - if status.Self.KeyExpiry != nil { - assert.Truef(t, status.Self.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", status.Self.HostName, now.String(), status.Self.KeyExpiry) - } - - // NeedsLogin means that the node has understood that it is no longer - // valid. - assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName) - } + }, 10*time.Second, 200*time.Millisecond, "Waiting for expired node status to propagate") } } @@ -866,11 +853,13 @@ func TestNodeOnlineStatus(t *testing.T) { t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps)) for _, client := range allClients { - status, err := client.Status() - require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) - // Assert that we have the original count - self - assert.Len(t, status.Peers(), len(MustTestVersions)-1) + // Assert that we have the original count - self + assert.Len(c, status.Peers(), len(MustTestVersions)-1) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer count") } headscale, err := scenario.Headscale() diff --git a/integration/helpers.go b/integration/helpers.go index 8e81fa9b..133a175b 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -507,6 +507,11 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) { assert.NotNil(t, node.GetLastSeen()) } +func assertLastSeenSetWithCollect(c *assert.CollectT, node *v1.Node) { + assert.NotNil(c, node) + assert.NotNil(c, node.GetLastSeen()) +} + // assertTailscaleNodesLogout verifies that all provided Tailscale clients // are in the logged-out state (NeedsLogin). 
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { @@ -633,50 +638,50 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { t.Logf("Checking netmap of %q", client.Hostname()) - netmap, err := client.Netmap() - if err != nil { - t.Fatalf("getting netmap for %q: %s", client.Hostname(), err) - } + assert.EventuallyWithT(t, func(c *assert.CollectT) { + netmap, err := client.Netmap() + assert.NoError(c, err, "getting netmap for %q", client.Hostname()) - assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) - if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { - assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) - } - - assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) - assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) - - assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname()) - - assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) - assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) - assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) - - for _, peer := range netmap.Peers { - assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) - assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) - - assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) - if hi := peer.Hostinfo(); hi.Valid() { - assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) - - // Netinfo is not always set - // assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname()) - if ni := hi.NetInfo(); ni.Valid() { - assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) - } + assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) + if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { + assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) } - assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(c, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) + assert.NotEmptyf(c, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) - assert.Truef(t, peer.Online().Get(), "peer (%s) of 
%q is not online", peer.ComputedName(), client.Hostname()) + assert.Truef(c, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname()) - assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) - } + assert.Falsef(c, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) + assert.Falsef(c, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) + assert.Falsef(c, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) + + for _, peer := range netmap.Peers { + assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) + assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) + + assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) + if hi := peer.Hostinfo(); hi.Valid() { + assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) + + // Netinfo is not always set + // assert.Truef(c, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname()) + if ni := hi.NetInfo(); ni.Valid() { + assert.NotEqualf(c, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) + } + } + + assert.NotEmptyf(c, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(c, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(c, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) + + assert.Truef(c, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname()) + + assert.Falsef(c, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) + assert.Falsef(c, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname()) + assert.Falsef(c, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for valid netmap for %q", client.Hostname()) } // assertValidStatus validates that a client's status has all required fields for proper operation. @@ -920,3 +925,125 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { EmailVerified: emailVerified, } } + +// GetUserByName retrieves a user by name from the headscale server. +// This is a common pattern used when creating preauth keys or managing users. 
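+//
+// A minimal usage sketch (the fixture names and "user1" are illustrative,
+// following the conventions of the tests in this package):
+//
+//	user, err := GetUserByName(headscale, "user1")
+//	if err != nil {
+//		t.Fatalf("looking up user: %s", err)
+//	}
+//	authKey, err := scenario.CreatePreAuthKey(user.GetId(), true, false)
+//	if err != nil {
+//		t.Fatalf("creating preauth key: %s", err)
+//	}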
+func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { + users, err := headscale.ListUsers() + if err != nil { + return nil, fmt.Errorf("failed to list users: %w", err) + } + + for _, u := range users { + if u.GetName() == username { + return u, nil + } + } + + return nil, fmt.Errorf("user %s not found", username) +} + +// FindNewClient finds a client that is in the new list but not in the original list. +// This is useful when dynamically adding nodes during tests and needing to identify +// which client was just added. +func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) { + for _, client := range updated { + isOriginal := false + for _, origClient := range original { + if client.Hostname() == origClient.Hostname() { + isOriginal = true + break + } + } + if !isOriginal { + return client, nil + } + } + return nil, fmt.Errorf("no new client found") +} + +// AddAndLoginClient adds a new tailscale client to a user and logs it in. +// This combines the common pattern of: +// 1. Creating a new node +// 2. Finding the new node in the client list +// 3. Getting the user to create a preauth key +// 4. Logging in the new node +func (s *Scenario) AddAndLoginClient( + t *testing.T, + username string, + version string, + headscale ControlServer, + tsOpts ...tsic.Option, +) (TailscaleClient, error) { + t.Helper() + + // Get the original client list + originalClients, err := s.ListTailscaleClients(username) + if err != nil { + return nil, fmt.Errorf("failed to list original clients: %w", err) + } + + // Create the new node + err = s.CreateTailscaleNodesInUser(username, version, 1, tsOpts...) + if err != nil { + return nil, fmt.Errorf("failed to create tailscale node: %w", err) + } + + // Wait for the new node to appear in the client list + var newClient TailscaleClient + + _, err = backoff.Retry(t.Context(), func() (struct{}, error) { + updatedClients, err := s.ListTailscaleClients(username) + if err != nil { + return struct{}{}, fmt.Errorf("failed to list updated clients: %w", err) + } + + if len(updatedClients) != len(originalClients)+1 { + return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients)) + } + + newClient, err = FindNewClient(originalClients, updatedClients) + if err != nil { + return struct{}{}, fmt.Errorf("failed to find new client: %w", err) + } + + return struct{}{}, nil + }, backoff.WithBackOff(backoff.NewConstantBackOff(500*time.Millisecond)), backoff.WithMaxElapsedTime(10*time.Second)) + if err != nil { + return nil, fmt.Errorf("timeout waiting for new client: %w", err) + } + + // Get the user and create preauth key + user, err := GetUserByName(headscale, username) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + + authKey, err := s.CreatePreAuthKey(user.GetId(), true, false) + if err != nil { + return nil, fmt.Errorf("failed to create preauth key: %w", err) + } + + // Login the new client + err = newClient.Login(headscale.GetEndpoint(), authKey.GetKey()) + if err != nil { + return nil, fmt.Errorf("failed to login new client: %w", err) + } + + return newClient, nil +} + +// MustAddAndLoginClient is like AddAndLoginClient but fails the test on error. +func (s *Scenario) MustAddAndLoginClient( + t *testing.T, + username string, + version string, + headscale ControlServer, + tsOpts ...tsic.Option, +) TailscaleClient { + t.Helper() + + client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...) 
+ require.NoError(t, err) + return client +} diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 553b8b1c..88fc4da2 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -1082,6 +1082,30 @@ func (t *HeadscaleInContainer) ListNodes( return ret, nil } +func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error { + command := []string{ + "headscale", + "nodes", + "delete", + "--identifier", + fmt.Sprintf("%d", nodeID), + "--output", + "json", + "--force", + } + + _, _, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + return fmt.Errorf("failed to execute delete node command: %w", err) + } + + return nil +} + func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) { nodes, err := t.ListNodes() if err != nil { @@ -1397,3 +1421,38 @@ func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, er return nodeStore, nil } + +// DebugFilter fetches the current filter rules from the debug endpoint. +func (t *HeadscaleInContainer) DebugFilter() ([]tailcfg.FilterRule, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/filter", + } + + result, err := t.Execute(command) + if err != nil { + return nil, fmt.Errorf("fetching filter from debug endpoint: %w", err) + } + + var filterRules []tailcfg.FilterRule + if err := json.Unmarshal([]byte(result), &filterRules); err != nil { + return nil, fmt.Errorf("decoding filter response: %w", err) + } + + return filterRules, nil +} + +// DebugPolicy fetches the current policy from the debug endpoint. +func (t *HeadscaleInContainer) DebugPolicy() (string, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "http://localhost:9090/debug/policy", + } + + result, err := t.Execute(command) + if err != nil { + return "", fmt.Errorf("fetching policy from debug endpoint: %w", err) + } + + return result, nil +} diff --git a/integration/route_test.go b/integration/route_test.go index e1d30750..15b66d6b 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -1358,16 +1358,8 @@ func TestSubnetRouteACL(t *testing.T) { // Sort nodes by ID sort.SliceStable(allClients, func(i, j int) bool { - statusI, err := allClients[i].Status() - if err != nil { - return false - } - - statusJ, err := allClients[j].Status() - if err != nil { - return false - } - + statusI := allClients[i].MustStatus() + statusJ := allClients[j].MustStatus() return statusI.Self.ID < statusJ.Self.ID }) @@ -1475,9 +1467,7 @@ func TestSubnetRouteACL(t *testing.T) { requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes["1"])}) }, 5*time.Second, 200*time.Millisecond, "Verifying client can see subnet routes from router") - clientNm, err := client.Netmap() - require.NoError(t, err) - + // Wait for packet filter updates to propagate to client netmap wantClientFilter := []filter.Match{ { IPProto: views.SliceOf([]ipproto.Proto{ @@ -1503,13 +1493,16 @@ func TestSubnetRouteACL(t *testing.T) { }, } - if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { - t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff) - } + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientNm, err := client.Netmap() + 
assert.NoError(c, err) - subnetNm, err := subRouter1.Netmap() - require.NoError(t, err) + if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { + assert.Fail(c, fmt.Sprintf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for client packet filter to update") + // Wait for packet filter updates to propagate to subnet router netmap wantSubnetFilter := []filter.Match{ { IPProto: views.SliceOf([]ipproto.Proto{ @@ -1553,9 +1546,14 @@ func TestSubnetRouteACL(t *testing.T) { }, } - if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { - t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff) - } + assert.EventuallyWithT(t, func(c *assert.CollectT) { + subnetNm, err := subRouter1.Netmap() + assert.NoError(c, err) + + if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { + assert.Fail(c, fmt.Sprintf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for subnet router packet filter to update") } // TestEnablingExitRoutes tests enabling exit routes for clients. @@ -1592,12 +1590,16 @@ func TestEnablingExitRoutes(t *testing.T) { err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) - nodes, err := headscale.ListNodes() - require.NoError(t, err) - require.Len(t, nodes, 2) + var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) - requireNodeRouteCount(t, nodes[0], 2, 0, 0) - requireNodeRouteCount(t, nodes[1], 2, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[0], 2, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[1], 2, 0, 0) + }, 10*time.Second, 200*time.Millisecond, "Waiting for route advertisements to propagate") // Verify that no routes has been sent to the client, // they are not yet enabled. diff --git a/integration/scenario.go b/integration/scenario.go index b48e3265..aa844a7e 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -693,6 +693,35 @@ func (s *Scenario) WaitForTailscaleSync() error { return err } +// WaitForTailscaleSyncPerUser blocks execution until each TailscaleClient has the expected +// number of peers for its user. This is useful for policies like autogroup:self where nodes +// only see same-user peers, not all nodes in the network. +func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Duration) error { + var allErrors []error + + for _, user := range s.users { + // Calculate expected peer count: number of nodes in this user minus 1 (self) + expectedPeers := len(user.Clients) - 1 + + for _, client := range user.Clients { + c := client + expectedCount := expectedPeers + user.syncWaitGroup.Go(func() error { + return c.WaitForPeers(expectedCount, timeout, retryInterval) + }) + } + if err := user.syncWaitGroup.Wait(); err != nil { + allErrors = append(allErrors, err) + } + } + + if len(allErrors) > 0 { + return multierr.New(allErrors...) + } + + return nil +} + // WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports // to have all other TailscaleClients present in their netmap.NetworkMap. 
func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int, timeout, retryInterval time.Duration) error { diff --git a/integration/tailscale.go b/integration/tailscale.go index 07573e6f..414d08bc 100644 --- a/integration/tailscale.go +++ b/integration/tailscale.go @@ -14,6 +14,7 @@ import ( "tailscale.com/net/netcheck" "tailscale.com/types/key" "tailscale.com/types/netmap" + "tailscale.com/wgengine/filter" ) // nolint @@ -36,6 +37,7 @@ type TailscaleClient interface { MustIPv4() netip.Addr MustIPv6() netip.Addr FQDN() (string, error) + MustFQDN() string Status(...bool) (*ipnstate.Status, error) MustStatus() *ipnstate.Status Netmap() (*netmap.NetworkMap, error) @@ -52,6 +54,7 @@ type TailscaleClient interface { ContainerID() string MustID() types.NodeID ReadFile(path string) ([]byte, error) + PacketFilter() ([]filter.Match, error) // FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client // and a bool indicating if the clients online count and peer count is equal. diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index ddd5027f..f6d8baef 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "github.com/cenkalti/backoff/v5" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" @@ -32,6 +33,7 @@ import ( "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/util/multierr" + "tailscale.com/wgengine/filter" ) const ( @@ -597,28 +599,39 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { return t.ips, nil } - ips := make([]netip.Addr, 0) - - command := []string{ - "tailscale", - "ip", - } - - result, _, err := t.Execute(command) - if err != nil { - return []netip.Addr{}, fmt.Errorf("%s failed to join tailscale client: %w", t.hostname, err) - } - - for address := range strings.SplitSeq(result, "\n") { - address = strings.TrimSuffix(address, "\n") - if len(address) < 1 { - continue + // Retry with exponential backoff to handle eventual consistency + ips, err := backoff.Retry(context.Background(), func() ([]netip.Addr, error) { + command := []string{ + "tailscale", + "ip", } - ip, err := netip.ParseAddr(address) + + result, _, err := t.Execute(command) if err != nil { - return nil, err + return nil, fmt.Errorf("%s failed to get IPs: %w", t.hostname, err) } - ips = append(ips, ip) + + ips := make([]netip.Addr, 0) + for address := range strings.SplitSeq(result, "\n") { + address = strings.TrimSuffix(address, "\n") + if len(address) < 1 { + continue + } + ip, err := netip.ParseAddr(address) + if err != nil { + return nil, fmt.Errorf("failed to parse IP %s: %w", address, err) + } + ips = append(ips, ip) + } + + if len(ips) == 0 { + return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname) + } + + return ips, nil + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + if err != nil { + return nil, fmt.Errorf("failed to get IPs for %s after retries: %w", t.hostname, err) } return ips, nil @@ -629,7 +642,6 @@ func (t *TailscaleInContainer) MustIPs() []netip.Addr { if err != nil { panic(err) } - return ips } @@ -646,16 +658,15 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) { } } - return netip.Addr{}, errors.New("no IPv4 address found") + return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname) } func (t *TailscaleInContainer) MustIPv4() netip.Addr { - for _, ip := range 
t.MustIPs() { - if ip.Is4() { - return ip - } + ip, err := t.IPv4() + if err != nil { + panic(err) } - panic("no ipv4 found") + return ip } func (t *TailscaleInContainer) MustIPv6() netip.Addr { @@ -900,12 +911,33 @@ func (t *TailscaleInContainer) FQDN() (string, error) { return t.fqdn, nil } - status, err := t.Status() + // Retry with exponential backoff to handle eventual consistency + fqdn, err := backoff.Retry(context.Background(), func() (string, error) { + status, err := t.Status() + if err != nil { + return "", fmt.Errorf("failed to get status: %w", err) + } + + if status.Self.DNSName == "" { + return "", fmt.Errorf("FQDN not yet available") + } + + return status.Self.DNSName, nil + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) if err != nil { - return "", fmt.Errorf("failed to get FQDN: %w", err) + return "", fmt.Errorf("failed to get FQDN for %s after retries: %w", t.hostname, err) } - return status.Self.DNSName, nil + return fqdn, nil +} + +// MustFQDN returns the FQDN as a string of the Tailscale instance, panicking on error. +func (t *TailscaleInContainer) MustFQDN() string { + fqdn, err := t.FQDN() + if err != nil { + panic(err) + } + return fqdn } // FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client @@ -1353,3 +1385,18 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { return &p.Persist.PrivateNodeKey, nil } + +// PacketFilter returns the current packet filter rules from the client's network map. +// This is useful for verifying that policy changes have propagated to the client. +func (t *TailscaleInContainer) PacketFilter() ([]filter.Match, error) { + if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { + return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version) + } + + nm, err := t.Netmap() + if err != nil { + return nil, fmt.Errorf("failed to get netmap: %w", err) + } + + return nm.PacketFilter, nil +} diff --git a/tools/capver/main.go b/tools/capver/main.go index 1e4512c1..cbb5435c 100644 --- a/tools/capver/main.go +++ b/tools/capver/main.go @@ -136,7 +136,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion } // Write to file - err = os.WriteFile(outputFile, formatted, 0644) + err = os.WriteFile(outputFile, formatted, 0o644) if err != nil { return fmt.Errorf("error writing file: %w", err) }
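A minimal sketch of how the new DebugFilter and PacketFilter helpers might be combined in an integration test to wait for a policy change to propagate; the fixtures (t, headscale, client) are assumed test scaffolding in the style of the tests above, not part of this change:

	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		// Server-side view: compiled filter rules from headscale's debug endpoint.
		rules, err := headscale.DebugFilter()
		assert.NoError(c, err)
		assert.NotEmpty(c, rules, "headscale should report compiled filter rules")

		// Client-side view: the packet filter carried in the client's netmap.
		matches, err := client.PacketFilter()
		assert.NoError(c, err)
		assert.NotEmpty(c, matches, "client netmap should carry a packet filter")
	}, 10*time.Second, 200*time.Millisecond, "Waiting for the policy to propagate to server and client")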