policy: fix autogroup:self propagation and optimize cache invalidation (#2807)
Some checks failed
Build / build-nix (push) Has been cancelled
Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Has been cancelled
Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Has been cancelled
Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Has been cancelled
Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Has been cancelled
Check Generated Files / check-generated (push) Has been cancelled
Tests / test (push) Has been cancelled
Close inactive issues / close-issues (push) Has been cancelled

This commit is contained in:
Kristoffer Dalby
2025-10-23 17:57:41 +02:00
committed by GitHub
parent 66826232ff
commit 2bf1200483
32 changed files with 3318 additions and 1770 deletions

View File

@@ -28,6 +28,7 @@ jobs:
- TestAPIAuthenticationBypassCurl
- TestGRPCAuthenticationBypass
- TestCLIWithConfigAuthenticationBypass
- TestACLPolicyPropagationOverTime
- TestAuthKeyLogoutAndReloginSameUser
- TestAuthKeyLogoutAndReloginNewUser
- TestAuthKeyLogoutAndReloginSameUserExpiredKey

View File

@@ -81,7 +81,7 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err)
}
// Create file
outFile, err := os.Create(targetPath)
if err != nil {

View File

@@ -1,6 +1,6 @@
package capver
//Generated DO NOT EDIT
// Generated DO NOT EDIT
import "tailscale.com/tailcfg"
@@ -37,16 +37,15 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.84.2": 116,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
116: "v1.84.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
116: "v1.84.0",
}

View File

@@ -185,7 +185,6 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
}
})
}
}
func TestShuffleDERPMapEdgeCases(t *testing.T) {

View File

@@ -73,7 +73,6 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
// Use the worker pool for controlled concurrency instead of direct generation
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
if err != nil {
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
nodeConn.removeConnectionByChannel(c)

View File

@@ -7,7 +7,6 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
@@ -181,6 +180,9 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
return b
}
// FilterForNode returns rules already reduced to only those relevant for this node.
// For autogroup:self policies, it returns per-node compiled rules.
// For global policies, it returns the global filter reduced for this node.
filter, err := b.mapper.state.FilterForNode(node)
if err != nil {
b.addError(err)
@@ -192,7 +194,7 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
// new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block.
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node, filter),
"base": filter,
}
return b
@@ -231,18 +233,19 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) (
return nil, errors.New("node not found")
}
// Use per-node filter to handle autogroup:self
filter, err := b.mapper.state.FilterForNode(node)
// Get unreduced matchers for peer relationship determination.
// MatchersForNode returns unreduced matchers that include all rules where the node
// could be either source or destination. This is different from FilterForNode which
// returns reduced rules for packet filtering (only rules where node is destination).
matchers, err := b.mapper.state.MatchersForNode(node)
if err != nil {
return nil, err
}
matchers := matcher.MatchesFromFilterRules(filter)
// If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers.
var changedViews views.Slice[types.NodeView]
if len(filter) > 0 {
if len(matchers) > 0 {
changedViews = policy.ReduceNodes(node, peers, matchers)
} else {
changedViews = peers

View File

@@ -15,6 +15,10 @@ type PolicyManager interface {
Filter() ([]tailcfg.FilterRule, []matcher.Match)
// FilterForNode returns filter rules for a specific node, handling autogroup:self
FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error)
// MatchersForNode returns matchers for peer relationship determination (unreduced)
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
// BuildPeerMap constructs peer relationship maps for the given nodes
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error)

View File

@@ -10,7 +10,6 @@ import (
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
)
@@ -79,66 +78,6 @@ func BuildPeerMap(
return ret
}
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
// that are not relevant to that particular node.
func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
ret := []tailcfg.FilterRule{}
for _, rule := range rules {
// record if the rule is actually relevant for the given node.
var dests []tailcfg.NetPortRange
DEST_LOOP:
for _, dest := range rule.DstPorts {
expanded, err := util.ParseIPSet(dest.IP, nil)
// Fail closed, if we can't parse it, then we should not allow
// access.
if err != nil {
continue DEST_LOOP
}
if node.InIPSet(expanded) {
dests = append(dests, dest)
continue DEST_LOOP
}
// If the node exposes routes, ensure they are note removed
// when the filters are reduced.
if node.Hostinfo().Valid() {
routableIPs := node.Hostinfo().RoutableIPs()
if routableIPs.Len() > 0 {
for _, routableIP := range routableIPs.All() {
if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
}
// Also check approved subnet routes - nodes should have access
// to subnets they're approved to route traffic for.
subnetRoutes := node.SubnetRoutes()
for _, subnetRoute := range subnetRoutes {
if expanded.OverlapsPrefix(subnetRoute) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
if len(dests) > 0 {
ret = append(ret, tailcfg.FilterRule{
SrcIPs: rule.SrcIPs,
DstPorts: dests,
IPProto: rule.IPProto,
})
}
}
return ret
}
// ApproveRoutesWithPolicy checks if the node can approve the announced routes
// and returns the new list of approved routes.
// The approved routes will include:

View File

@@ -1,7 +1,6 @@
package policy
import (
"encoding/json"
"fmt"
"net/netip"
"testing"
@@ -11,12 +10,9 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/must"
)
var ap = func(ipStr string) *netip.Addr {
@@ -29,817 +25,6 @@ var p = func(prefStr string) netip.Prefix {
return ip
}
// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when
// we use headscale "autogroup:internet".
var hsExitNodeDestForTest = []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "96.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "100.0.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "100.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "101.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "102.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "104.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "112.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "169.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "169.128.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "169.192.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "169.224.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "169.240.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "169.248.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "169.252.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "169.255.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "170.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "224.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "2000::/3", Ports: tailcfg.PortRangeAny},
}
func TestTheInternet(t *testing.T) {
internetSet := util.TheInternet()
internetPrefs := internetSet.Prefixes()
for i := range internetPrefs {
if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP {
t.Errorf(
"prefix from internet set %q != hsExit list %q",
internetPrefs[i].String(),
hsExitNodeDestForTest[i].IP,
)
}
}
if len(internetPrefs) != len(hsExitNodeDestForTest) {
t.Fatalf(
"expected same length of prefixes, internet: %d, hsExit: %d",
len(internetPrefs),
len(hsExitNodeDestForTest),
)
}
}
func TestReduceFilterRules(t *testing.T) {
users := types.Users{
types.User{Model: gorm.Model{ID: 1}, Name: "mickael"},
types.User{Model: gorm.Model{ID: 2}, Name: "user1"},
types.User{Model: gorm.Model{ID: 3}, Name: "user2"},
types.User{Model: gorm.Model{ID: 4}, Name: "user100"},
types.User{Model: gorm.Model{ID: 5}, Name: "user3"},
}
tests := []struct {
name string
node *types.Node
peers types.Nodes
pol string
want []tailcfg.FilterRule
}{
{
name: "host1-can-reach-host2-no-rules",
pol: `
{
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"100.64.0.1"
],
"dst": [
"100.64.0.2:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: users[0],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: users[0],
},
},
want: []tailcfg.FilterRule{},
},
{
name: "1604-subnet-routers-are-preserved",
pol: `
{
"groups": {
"group:admins": [
"user1@"
]
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"group:admins:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"10.33.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.1/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::1/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "10.33.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-client",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
// "internal" exit node
&types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
},
want: []tailcfg.FilterRule{},
},
{
name: "1786-reducing-breaks-exit-nodes-the-exit",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: hsExitNodeDestForTest,
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-example-from-issue",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"0.0.0.0/5:*",
"8.0.0.0/7:*",
"11.0.0.0/8:*",
"12.0.0.0/6:*",
"16.0.0.0/4:*",
"32.0.0.0/3:*",
"64.0.0.0/2:*",
"128.0.0.0/3:*",
"160.0.0.0/5:*",
"168.0.0.0/6:*",
"172.0.0.0/12:*",
"172.32.0.0/11:*",
"172.64.0.0/10:*",
"172.128.0.0/9:*",
"173.0.0.0/8:*",
"174.0.0.0/7:*",
"176.0.0.0/4:*",
"192.0.0.0/9:*",
"192.128.0.0/11:*",
"192.160.0.0/13:*",
"192.169.0.0/16:*",
"192.170.0.0/15:*",
"192.172.0.0/14:*",
"192.176.0.0/12:*",
"192.192.0.0/10:*",
"193.0.0.0/8:*",
"194.0.0.0/7:*",
"196.0.0.0/6:*",
"200.0.0.0/5:*",
"208.0.0.0/4:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/8:*",
"16.0.0.0/8:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "8.0.0.0/8",
Ports: tailcfg.PortRangeAny,
},
{
IP: "16.0.0.0/8",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like2",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/16:*",
"16.0.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "8.0.0.0/16",
Ports: tailcfg.PortRangeAny,
},
{
IP: "16.0.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1817-reduce-breaks-32-mask",
pol: `
{
"tagOwners": {
"tag:access-servers": ["user100@"],
},
"groups": {
"group:access": [
"user1@"
]
},
"hosts": {
"dns1": "172.16.0.21/32",
"vlan1": "172.16.0.0/24"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:access"
],
"dst": [
"tag:access-servers:*",
"dns1:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
ForcedTags: []string{"tag:access-servers"},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0::1/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
{
IP: "172.16.0.21/32",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "2365-only-route-policy",
pol: `
{
"hosts": {
"router": "100.64.0.1/32",
"node": "100.64.0.2/32"
},
"acls": [
{
"action": "accept",
"src": [
"*"
],
"dst": [
"router:8000"
]
},
{
"action": "accept",
"src": [
"node"
],
"dst": [
"172.26.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[3],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
ApprovedRoutes: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
},
want: []tailcfg.FilterRule{},
},
}
for _, tt := range tests {
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.pol)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm PolicyManager
var err error
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
require.NoError(t, err)
got, _ := pm.Filter()
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
got = ReduceFilterRules(tt.node.View(), got)
if diff := cmp.Diff(tt.want, got); diff != "" {
log.Trace().Interface("got", got).Msg("result")
t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff)
}
})
}
}
}
func TestReduceNodes(t *testing.T) {
type args struct {
nodes types.Nodes

View File

@@ -0,0 +1,71 @@
package policyutil
import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
)
// ReduceFilterRules takes a node and a set of global filter rules and removes all rules
// and destinations that are not relevant to that particular node.
//
// IMPORTANT: This function is designed for global filters only. Per-node filters
// (from autogroup:self policies) are already node-specific and should not be passed
// to this function. Use PolicyManager.FilterForNode() instead, which handles both cases.
func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
ret := []tailcfg.FilterRule{}
for _, rule := range rules {
// record if the rule is actually relevant for the given node.
var dests []tailcfg.NetPortRange
DEST_LOOP:
for _, dest := range rule.DstPorts {
expanded, err := util.ParseIPSet(dest.IP, nil)
// Fail closed, if we can't parse it, then we should not allow
// access.
if err != nil {
continue DEST_LOOP
}
if node.InIPSet(expanded) {
dests = append(dests, dest)
continue DEST_LOOP
}
// If the node exposes routes, ensure they are note removed
// when the filters are reduced.
if node.Hostinfo().Valid() {
routableIPs := node.Hostinfo().RoutableIPs()
if routableIPs.Len() > 0 {
for _, routableIP := range routableIPs.All() {
if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
}
// Also check approved subnet routes - nodes should have access
// to subnets they're approved to route traffic for.
subnetRoutes := node.SubnetRoutes()
for _, subnetRoute := range subnetRoutes {
if expanded.OverlapsPrefix(subnetRoute) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
if len(dests) > 0 {
ret = append(ret, tailcfg.FilterRule{
SrcIPs: rule.SrcIPs,
DstPorts: dests,
IPProto: rule.IPProto,
})
}
}
return ret
}

View File

@@ -0,0 +1,841 @@
package policyutil_test
import (
"encoding/json"
"fmt"
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policy/policyutil"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/must"
)
var ap = func(ipStr string) *netip.Addr {
ip := netip.MustParseAddr(ipStr)
return &ip
}
var p = func(prefStr string) netip.Prefix {
ip := netip.MustParsePrefix(prefStr)
return ip
}
// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when
// we use headscale "autogroup:internet".
var hsExitNodeDestForTest = []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "96.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "100.0.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "100.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "101.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "102.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "104.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "112.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "169.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "169.128.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "169.192.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "169.224.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "169.240.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "169.248.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "169.252.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "169.255.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "170.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "224.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "2000::/3", Ports: tailcfg.PortRangeAny},
}
func TestTheInternet(t *testing.T) {
internetSet := util.TheInternet()
internetPrefs := internetSet.Prefixes()
for i := range internetPrefs {
if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP {
t.Errorf(
"prefix from internet set %q != hsExit list %q",
internetPrefs[i].String(),
hsExitNodeDestForTest[i].IP,
)
}
}
if len(internetPrefs) != len(hsExitNodeDestForTest) {
t.Fatalf(
"expected same length of prefixes, internet: %d, hsExit: %d",
len(internetPrefs),
len(hsExitNodeDestForTest),
)
}
}
func TestReduceFilterRules(t *testing.T) {
users := types.Users{
types.User{Model: gorm.Model{ID: 1}, Name: "mickael"},
types.User{Model: gorm.Model{ID: 2}, Name: "user1"},
types.User{Model: gorm.Model{ID: 3}, Name: "user2"},
types.User{Model: gorm.Model{ID: 4}, Name: "user100"},
types.User{Model: gorm.Model{ID: 5}, Name: "user3"},
}
tests := []struct {
name string
node *types.Node
peers types.Nodes
pol string
want []tailcfg.FilterRule
}{
{
name: "host1-can-reach-host2-no-rules",
pol: `
{
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"100.64.0.1"
],
"dst": [
"100.64.0.2:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: users[0],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: users[0],
},
},
want: []tailcfg.FilterRule{},
},
{
name: "1604-subnet-routers-are-preserved",
pol: `
{
"groups": {
"group:admins": [
"user1@"
]
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"group:admins:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"10.33.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.1/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::1/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "10.33.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-client",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
// "internal" exit node
&types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
},
want: []tailcfg.FilterRule{},
},
{
name: "1786-reducing-breaks-exit-nodes-the-exit",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: hsExitNodeDestForTest,
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-example-from-issue",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"0.0.0.0/5:*",
"8.0.0.0/7:*",
"11.0.0.0/8:*",
"12.0.0.0/6:*",
"16.0.0.0/4:*",
"32.0.0.0/3:*",
"64.0.0.0/2:*",
"128.0.0.0/3:*",
"160.0.0.0/5:*",
"168.0.0.0/6:*",
"172.0.0.0/12:*",
"172.32.0.0/11:*",
"172.64.0.0/10:*",
"172.128.0.0/9:*",
"173.0.0.0/8:*",
"174.0.0.0/7:*",
"176.0.0.0/4:*",
"192.0.0.0/9:*",
"192.128.0.0/11:*",
"192.160.0.0/13:*",
"192.169.0.0/16:*",
"192.170.0.0/15:*",
"192.172.0.0/14:*",
"192.176.0.0/12:*",
"192.192.0.0/10:*",
"193.0.0.0/8:*",
"194.0.0.0/7:*",
"196.0.0.0/6:*",
"200.0.0.0/5:*",
"208.0.0.0/4:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/8:*",
"16.0.0.0/8:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "8.0.0.0/8",
Ports: tailcfg.PortRangeAny,
},
{
IP: "16.0.0.0/8",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like2",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/16:*",
"16.0.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "8.0.0.0/16",
Ports: tailcfg.PortRangeAny,
},
{
IP: "16.0.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1817-reduce-breaks-32-mask",
pol: `
{
"tagOwners": {
"tag:access-servers": ["user100@"],
},
"groups": {
"group:access": [
"user1@"
]
},
"hosts": {
"dns1": "172.16.0.21/32",
"vlan1": "172.16.0.0/24"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:access"
],
"dst": [
"tag:access-servers:*",
"dns1:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
ForcedTags: []string{"tag:access-servers"},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0::1/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
{
IP: "172.16.0.21/32",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "2365-only-route-policy",
pol: `
{
"hosts": {
"router": "100.64.0.1/32",
"node": "100.64.0.2/32"
},
"acls": [
{
"action": "accept",
"src": [
"*"
],
"dst": [
"router:8000"
]
},
{
"action": "accept",
"src": [
"node"
],
"dst": [
"172.26.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[3],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
ApprovedRoutes: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
},
want: []tailcfg.FilterRule{},
},
}
for _, tt := range tests {
for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm policy.PolicyManager
var err error
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
require.NoError(t, err)
got, _ := pm.Filter()
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
got = policyutil.ReduceFilterRules(tt.node.View(), got)
if diff := cmp.Diff(tt.want, got); diff != "" {
log.Trace().Interface("got", got).Msg("result")
t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff)
}
})
}
}
}

View File

@@ -854,7 +854,6 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
node1 := nodes[0].View()
rules, err := policy2.compileFilterRulesForNode(users, node1, nodes.ViewSlice())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

View File

@@ -9,6 +9,7 @@ import (
"sync"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/policy/policyutil"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"go4.org/netipx"
@@ -39,7 +40,9 @@ type PolicyManager struct {
// Lazy map of SSH policies
sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy
// Lazy map of per-node filter rules (when autogroup:self is used)
// Lazy map of per-node compiled filter rules (unreduced, for autogroup:self)
compiledFilterRulesMap map[types.NodeID][]tailcfg.FilterRule
// Lazy map of per-node filter rules (reduced, for packet filters)
filterRulesMap map[types.NodeID][]tailcfg.FilterRule
usesAutogroupSelf bool
}
@@ -54,12 +57,13 @@ func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.Node
}
pm := PolicyManager{
pol: policy,
users: users,
nodes: nodes,
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()),
filterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()),
usesAutogroupSelf: policy.usesAutogroupSelf(),
pol: policy,
users: users,
nodes: nodes,
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()),
compiledFilterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()),
filterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()),
usesAutogroupSelf: policy.usesAutogroupSelf(),
}
_, err = pm.updateLocked()
@@ -78,6 +82,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
// policies for nodes that have changed. Particularly if the only difference is
// that nodes has been added or removed.
clear(pm.sshPolicyMap)
clear(pm.compiledFilterRulesMap)
clear(pm.filterRulesMap)
// Check if policy uses autogroup:self
@@ -233,9 +238,157 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
return pm.filter, pm.matchers
}
// FilterForNode returns the filter rules for a specific node.
// If the policy uses autogroup:self, this returns node-specific rules for security.
// Otherwise, it returns the global filter rules for efficiency.
// BuildPeerMap constructs peer relationship maps for the given nodes.
// For global filters, it uses the global filter matchers for all nodes.
// For autogroup:self policies (empty global filter), it builds per-node
// peer maps using each node's specific filter rules.
func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView {
if pm == nil {
return nil
}
pm.mu.Lock()
defer pm.mu.Unlock()
// If we have a global filter, use it for all nodes (normal case)
if !pm.usesAutogroupSelf {
ret := make(map[types.NodeID][]types.NodeView, nodes.Len())
// Build the map of all peers according to the matchers.
// Compared to ReduceNodes, which builds the list per node, we end up with doing
// the full work for every node O(n^2), while this will reduce the list as we see
// relationships while building the map, making it O(n^2/2) in the end, but with less work per node.
for i := range nodes.Len() {
for j := i + 1; j < nodes.Len(); j++ {
if nodes.At(i).ID() == nodes.At(j).ID() {
continue
}
if nodes.At(i).CanAccess(pm.matchers, nodes.At(j)) || nodes.At(j).CanAccess(pm.matchers, nodes.At(i)) {
ret[nodes.At(i).ID()] = append(ret[nodes.At(i).ID()], nodes.At(j))
ret[nodes.At(j).ID()] = append(ret[nodes.At(j).ID()], nodes.At(i))
}
}
}
return ret
}
// For autogroup:self (empty global filter), build per-node peer relationships
ret := make(map[types.NodeID][]types.NodeView, nodes.Len())
// Pre-compute per-node matchers using unreduced compiled rules
// We need unreduced rules to determine peer relationships correctly.
// Reduced rules only show destinations where the node is the target,
// but peer relationships require the full bidirectional access rules.
nodeMatchers := make(map[types.NodeID][]matcher.Match, nodes.Len())
for _, node := range nodes.All() {
filter, err := pm.compileFilterRulesForNodeLocked(node)
if err != nil || len(filter) == 0 {
continue
}
nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter)
}
// Check each node pair for peer relationships.
// Start j at i+1 to avoid checking the same pair twice and creating duplicates.
// We check both directions (i->j and j->i) since ACLs can be asymmetric.
for i := range nodes.Len() {
nodeI := nodes.At(i)
matchersI, hasFilterI := nodeMatchers[nodeI.ID()]
for j := i + 1; j < nodes.Len(); j++ {
nodeJ := nodes.At(j)
matchersJ, hasFilterJ := nodeMatchers[nodeJ.ID()]
// Check if nodeI can access nodeJ
if hasFilterI && nodeI.CanAccess(matchersI, nodeJ) {
ret[nodeI.ID()] = append(ret[nodeI.ID()], nodeJ)
}
// Check if nodeJ can access nodeI
if hasFilterJ && nodeJ.CanAccess(matchersJ, nodeI) {
ret[nodeJ.ID()] = append(ret[nodeJ.ID()], nodeI)
}
}
}
return ret
}
// compileFilterRulesForNodeLocked returns the unreduced compiled filter rules for a node
// when using autogroup:self. This is used by BuildPeerMap to determine peer relationships.
// For packet filters sent to nodes, use filterForNodeLocked which returns reduced rules.
func (pm *PolicyManager) compileFilterRulesForNodeLocked(node types.NodeView) ([]tailcfg.FilterRule, error) {
if pm == nil {
return nil, nil
}
// Check if we have cached compiled rules
if rules, ok := pm.compiledFilterRulesMap[node.ID()]; ok {
return rules, nil
}
// Compile per-node rules with autogroup:self expanded
rules, err := pm.pol.compileFilterRulesForNode(pm.users, node, pm.nodes)
if err != nil {
return nil, fmt.Errorf("compiling filter rules for node: %w", err)
}
// Cache the unreduced compiled rules
pm.compiledFilterRulesMap[node.ID()] = rules
return rules, nil
}
// filterForNodeLocked returns the filter rules for a specific node, already reduced
// to only include rules relevant to that node.
// This is a lock-free version of FilterForNode for internal use when the lock is already held.
// BuildPeerMap already holds the lock, so we need a version that doesn't re-acquire it.
func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.FilterRule, error) {
if pm == nil {
return nil, nil
}
if !pm.usesAutogroupSelf {
// For global filters, reduce to only rules relevant to this node.
// Cache the reduced filter per node for efficiency.
if rules, ok := pm.filterRulesMap[node.ID()]; ok {
return rules, nil
}
// Use policyutil.ReduceFilterRules for global filter reduction.
reducedFilter := policyutil.ReduceFilterRules(node, pm.filter)
pm.filterRulesMap[node.ID()] = reducedFilter
return reducedFilter, nil
}
// For autogroup:self, compile per-node rules then reduce them.
// Check if we have cached reduced rules for this node.
if rules, ok := pm.filterRulesMap[node.ID()]; ok {
return rules, nil
}
// Get unreduced compiled rules
compiledRules, err := pm.compileFilterRulesForNodeLocked(node)
if err != nil {
return nil, err
}
// Reduce the compiled rules to only destinations relevant to this node
reducedFilter := policyutil.ReduceFilterRules(node, compiledRules)
// Cache the reduced filter
pm.filterRulesMap[node.ID()] = reducedFilter
return reducedFilter, nil
}
// FilterForNode returns the filter rules for a specific node, already reduced
// to only include rules relevant to that node.
// If the policy uses autogroup:self, this returns node-specific compiled rules.
// Otherwise, it returns the global filter reduced for this node.
func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error) {
if pm == nil {
return nil, nil
@@ -244,22 +397,36 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filterForNodeLocked(node)
}
// MatchersForNode returns the matchers for peer relationship determination for a specific node.
// These are UNREDUCED matchers - they include all rules where the node could be either source or destination.
// This is different from FilterForNode which returns REDUCED rules for packet filtering.
//
// For global policies: returns the global matchers (same for all nodes)
// For autogroup:self: returns node-specific matchers from unreduced compiled rules
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
if pm == nil {
return nil, nil
}
pm.mu.Lock()
defer pm.mu.Unlock()
// For global policies, return the shared global matchers
if !pm.usesAutogroupSelf {
return pm.filter, nil
return pm.matchers, nil
}
if rules, ok := pm.filterRulesMap[node.ID()]; ok {
return rules, nil
}
rules, err := pm.pol.compileFilterRulesForNode(pm.users, node, pm.nodes)
// For autogroup:self, get unreduced compiled rules and create matchers
compiledRules, err := pm.compileFilterRulesForNodeLocked(node)
if err != nil {
return nil, fmt.Errorf("compiling filter rules for node: %w", err)
return nil, err
}
pm.filterRulesMap[node.ID()] = rules
return rules, nil
// Create matchers from unreduced rules for peer relationship determination
return matcher.MatchesFromFilterRules(compiledRules), nil
}
// SetUsers updates the users in the policy manager and updates the filter rules.
@@ -300,22 +467,40 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.mu.Lock()
defer pm.mu.Unlock()
// Clear cache based on what actually changed
if pm.usesAutogroupSelf {
// For autogroup:self, we need granular invalidation since rules depend on:
// - User ownership (node.User().ID)
// - Tag status (node.IsTagged())
// - IP addresses (node.IPs())
// - Node existence (added/removed)
pm.invalidateAutogroupSelfCache(pm.nodes, nodes)
} else {
// For non-autogroup:self policies, we can clear everything
clear(pm.filterRulesMap)
}
oldNodeCount := pm.nodes.Len()
newNodeCount := nodes.Len()
// Invalidate cache entries for nodes that changed.
// For autogroup:self: invalidate all nodes belonging to affected users (peer changes).
// For global policies: invalidate only nodes whose properties changed (IPs, routes).
pm.invalidateNodeCache(nodes)
pm.nodes = nodes
return pm.updateLocked()
nodesChanged := oldNodeCount != newNodeCount
// When nodes are added/removed, we must recompile filters because:
// 1. User/group aliases (like "user1@") resolve to node IPs
// 2. Filter compilation needs nodes to generate rules
// 3. Without nodes, filters compile to empty (0 rules)
//
// For autogroup:self: return true when nodes change even if the global filter
// hash didn't change. The global filter is empty for autogroup:self (each node
// has its own filter), so the hash never changes. But peer relationships DO
// change when nodes are added/removed, so we must signal this to trigger updates.
// For global policies: the filter must be recompiled to include the new nodes.
if nodesChanged {
// Recompile filter with the new node list
_, err := pm.updateLocked()
if err != nil {
return false, err
}
// Always return true when nodes changed, even if filter hash didn't change
// (can happen with autogroup:self or when nodes are added but don't affect rules)
return true, nil
}
return false, nil
}
func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool {
@@ -552,10 +737,12 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
// If we found the user and they're affected, clear this cache entry
if found {
if _, affected := affectedUsers[nodeUserID]; affected {
delete(pm.compiledFilterRulesMap, nodeID)
delete(pm.filterRulesMap, nodeID)
}
} else {
// Node not found in either old or new list, clear it
delete(pm.compiledFilterRulesMap, nodeID)
delete(pm.filterRulesMap, nodeID)
}
}
@@ -567,3 +754,50 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
Msg("Selectively cleared autogroup:self cache for affected users")
}
}
// invalidateNodeCache invalidates cache entries based on what changed.
func (pm *PolicyManager) invalidateNodeCache(newNodes views.Slice[types.NodeView]) {
if pm.usesAutogroupSelf {
// For autogroup:self, a node's filter depends on its peers (same user).
// When any node in a user changes, all nodes for that user need invalidation.
pm.invalidateAutogroupSelfCache(pm.nodes, newNodes)
} else {
// For global policies, a node's filter depends only on its own properties.
// Only invalidate nodes whose properties actually changed.
pm.invalidateGlobalPolicyCache(newNodes)
}
}
// invalidateGlobalPolicyCache invalidates only nodes whose properties affecting
// ReduceFilterRules changed. For global policies, each node's filter is independent.
func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types.NodeView]) {
oldNodeMap := make(map[types.NodeID]types.NodeView)
for _, node := range pm.nodes.All() {
oldNodeMap[node.ID()] = node
}
newNodeMap := make(map[types.NodeID]types.NodeView)
for _, node := range newNodes.All() {
newNodeMap[node.ID()] = node
}
// Invalidate nodes whose properties changed
for nodeID, newNode := range newNodeMap {
oldNode, existed := oldNodeMap[nodeID]
if !existed {
// New node - no cache entry yet, will be lazily calculated
continue
}
if newNode.HasNetworkChanges(oldNode) {
delete(pm.filterRulesMap, nodeID)
}
}
// Remove deleted nodes from cache
for nodeID := range pm.filterRulesMap {
if _, exists := newNodeMap[nodeID]; !exists {
delete(pm.filterRulesMap, nodeID)
}
}
}

View File

@@ -1,6 +1,7 @@
package v2
import (
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
@@ -204,3 +205,237 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
})
}
}
// TestInvalidateGlobalPolicyCache tests the cache invalidation logic for global policies.
func TestInvalidateGlobalPolicyCache(t *testing.T) {
mustIPPtr := func(s string) *netip.Addr {
ip := netip.MustParseAddr(s)
return &ip
}
tests := []struct {
name string
oldNodes types.Nodes
newNodes types.Nodes
initialCache map[types.NodeID][]tailcfg.FilterRule
expectedCacheAfter map[types.NodeID]bool // true = should exist, false = should not exist
}{
{
name: "node property changed - invalidates only that node",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
newNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.99")}, // Changed
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // Unchanged
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: false, // Invalidated
2: true, // Preserved
},
},
{
name: "multiple nodes changed",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
&types.Node{ID: 3, IPv4: mustIPPtr("100.64.0.3")},
},
newNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.99")}, // Changed
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // Unchanged
&types.Node{ID: 3, IPv4: mustIPPtr("100.64.0.88")}, // Changed
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
3: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: false, // Invalidated
2: true, // Preserved
3: false, // Invalidated
},
},
{
name: "node deleted - removes from cache",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
newNodes: types.Nodes{
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: false, // Deleted
2: true, // Preserved
},
},
{
name: "node added - no cache invalidation needed",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
},
newNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // New
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: true, // Preserved
2: false, // Not in cache (new node)
},
},
{
name: "no changes - preserves all cache",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
newNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: true,
2: true,
},
},
{
name: "routes changed - invalidates that node only",
oldNodes: types.Nodes{
&types.Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
newNodes: types.Nodes{
&types.Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, // Changed
},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: false, // Invalidated
2: true, // Preserved
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm := &PolicyManager{
nodes: tt.oldNodes.ViewSlice(),
filterRulesMap: tt.initialCache,
usesAutogroupSelf: false,
}
pm.invalidateGlobalPolicyCache(tt.newNodes.ViewSlice())
// Verify cache state
for nodeID, shouldExist := range tt.expectedCacheAfter {
_, exists := pm.filterRulesMap[nodeID]
require.Equal(t, shouldExist, exists, "node %d cache existence mismatch", nodeID)
}
})
}
}
// TestAutogroupSelfReducedVsUnreducedRules verifies that:
// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships
// 2. FilterForNode returns reduced compiled rules for packet filters
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}
users := types.Users{user1, user2}
// Create two nodes
node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1, nil)
node1.ID = 1
node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2, nil)
node2.ID = 2
nodes := types.Nodes{node1, node2}
// Policy with autogroup:self - all members can reach their own devices
policyStr := `{
"acls": [
{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["autogroup:self:*"]
}
]
}`
pm, err := NewPolicyManager([]byte(policyStr), users, nodes.ViewSlice())
require.NoError(t, err)
require.True(t, pm.usesAutogroupSelf, "policy should use autogroup:self")
// Test FilterForNode returns reduced rules
// For node1: should have rules where node1 is in destinations (its own IP)
filterNode1, err := pm.FilterForNode(nodes[0].View())
require.NoError(t, err)
// For node2: should have rules where node2 is in destinations (its own IP)
filterNode2, err := pm.FilterForNode(nodes[1].View())
require.NoError(t, err)
// FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations
// For node1, destinations should only be node1's IPs
node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"}
for _, rule := range filterNode1 {
for _, dst := range rule.DstPorts {
require.Contains(t, node1IPs, dst.IP,
"node1 filter should only contain node1's IPs as destinations")
}
}
// For node2, destinations should only be node2's IPs
node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"}
for _, rule := range filterNode2 {
for _, dst := range rule.DstPorts {
require.Contains(t, node2IPs, dst.IP,
"node2 filter should only contain node2's IPs as destinations")
}
}
// Test BuildPeerMap uses unreduced rules
peerMap := pm.BuildPeerMap(nodes.ViewSlice())
// According to the policy, user1 can reach autogroup:self (which expands to node1's own IPs for node1)
// So node1 should be able to reach itself, but since we're looking at peer relationships,
// node1 should NOT have itself in the peer map (nodes don't peer with themselves)
// node2 should also not have any peers since user2 has no rules allowing it to reach anyone
// Verify peer relationships based on unreduced rules
// With unreduced rules, BuildPeerMap can properly determine that:
// - node1 can access autogroup:self (its own IPs)
// - node2 cannot access node1
require.Empty(t, peerMap[node1.ID], "node1 should have no peers (can only reach itself)")
require.Empty(t, peerMap[node2.ID], "node2 should have no peers")
}

View File

@@ -20,9 +20,10 @@ const (
)
const (
put = 1
del = 2
update = 3
put = 1
del = 2
update = 3
rebuildPeerMaps = 4
)
const prometheusNamespace = "headscale"
@@ -142,6 +143,8 @@ type work struct {
updateFn UpdateNodeFunc
result chan struct{}
nodeResult chan types.NodeView // Channel to return the resulting node after batch application
// For rebuildPeerMaps operation
rebuildResult chan struct{}
}
// PutNode adds or updates a node in the store.
@@ -298,6 +301,9 @@ func (s *NodeStore) applyBatch(batch []work) {
// Track which work items need node results
nodeResultRequests := make(map[types.NodeID][]*work)
// Track rebuildPeerMaps operations
var rebuildOps []*work
for i := range batch {
w := &batch[i]
switch w.op {
@@ -321,6 +327,10 @@ func (s *NodeStore) applyBatch(batch []work) {
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
case rebuildPeerMaps:
// rebuildPeerMaps doesn't modify nodes, it just forces the snapshot rebuild
// below to recalculate peer relationships using the current peersFunc
rebuildOps = append(rebuildOps, w)
}
}
@@ -347,9 +357,16 @@ func (s *NodeStore) applyBatch(batch []work) {
}
}
// Signal completion for all work items
// Signal completion for rebuildPeerMaps operations
for _, w := range rebuildOps {
close(w.rebuildResult)
}
// Signal completion for all other work items
for _, w := range batch {
close(w.result)
if w.op != rebuildPeerMaps {
close(w.result)
}
}
}
@@ -546,6 +563,22 @@ func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
return views.SliceOf(s.data.Load().peersByNode[id])
}
// RebuildPeerMaps rebuilds the peer relationship map using the current peersFunc.
// This must be called after policy changes because peersFunc uses PolicyManager's
// filters to determine which nodes can see each other. Without rebuilding, the
// peer map would use stale filter data until the next node add/delete.
func (s *NodeStore) RebuildPeerMaps() {
result := make(chan struct{})
w := work{
op: rebuildPeerMaps,
rebuildResult: result,
}
s.writeQueue <- w
<-result
}
// ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))

View File

@@ -132,9 +132,10 @@ func NewState(cfg *types.Config) (*State, error) {
return nil, fmt.Errorf("init policy manager: %w", err)
}
// PolicyManager.BuildPeerMap handles both global and per-node filter complexity.
// This moves the complex peer relationship logic into the policy package where it belongs.
nodeStore := NewNodeStore(nodes, func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
_, matchers := polMan.Filter()
return policy.BuildPeerMap(views.SliceOf(nodes), matchers)
return polMan.BuildPeerMap(views.SliceOf(nodes))
})
nodeStore.Start()
@@ -225,6 +226,12 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) {
return nil, fmt.Errorf("setting policy: %w", err)
}
// Rebuild peer maps after policy changes because the peersFunc in NodeStore
// uses the PolicyManager's filters. Without this, nodes won't see newly allowed
// peers until a node is added/removed, causing autogroup:self policies to not
// propagate correctly when switching between policy types.
s.nodeStore.RebuildPeerMaps()
cs := []change.ChangeSet{change.PolicyChange()}
// Always call autoApproveNodes during policy reload, regardless of whether
@@ -797,6 +804,11 @@ func (s *State) FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error)
return s.polMan.FilterForNode(node)
}
// MatchersForNode returns matchers for peer relationship determination (unreduced).
func (s *State) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
return s.polMan.MatchersForNode(node)
}
// NodeCanHaveTag checks if a node is allowed to have a specific tag.
func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool {
return s.polMan.NodeCanHaveTag(node, tag)

View File

@@ -340,11 +340,11 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
log.Warn().Msg("No config file found, using defaults")
return nil
}
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
log.Warn().Msg("No config file found, using defaults")
return nil
}
return fmt.Errorf("fatal error reading config file: %w", err)
}

View File

@@ -855,3 +855,22 @@ func (v NodeView) IPsAsString() []string {
}
return v.ж.IPsAsString()
}
// HasNetworkChanges checks if the node has network-related changes.
// Returns true if IPs, announced routes, or approved routes changed.
// This is primarily used for policy cache invalidation.
func (v NodeView) HasNetworkChanges(other NodeView) bool {
if !slices.Equal(v.IPs(), other.IPs()) {
return true
}
if !slices.Equal(v.AnnouncedRoutes(), other.AnnouncedRoutes()) {
return true
}
if !slices.Equal(v.SubnetRoutes(), other.SubnetRoutes()) {
return true
}
return false
}

View File

@@ -793,3 +793,179 @@ func TestNodeRegisterMethodToV1Enum(t *testing.T) {
})
}
}
// TestHasNetworkChanges tests the NodeView method for detecting
// when a node's network properties have changed.
func TestHasNetworkChanges(t *testing.T) {
mustIPPtr := func(s string) *netip.Addr {
ip := netip.MustParseAddr(s)
return &ip
}
tests := []struct {
name string
old *Node
new *Node
changed bool
}{
{
name: "no changes",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
changed: false,
},
{
name: "IPv4 changed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.2"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
},
changed: true,
},
{
name: "IPv6 changed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::2"),
},
changed: true,
},
{
name: "RoutableIPs added",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
},
changed: true,
},
{
name: "RoutableIPs removed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{},
},
changed: true,
},
{
name: "RoutableIPs changed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
},
changed: true,
},
{
name: "SubnetRoutes added",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
changed: true,
},
{
name: "SubnetRoutes removed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{},
},
changed: true,
},
{
name: "SubnetRoutes changed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
changed: true,
},
{
name: "irrelevant property changed (Hostname)",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostname: "old-name",
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostname: "new-name",
},
changed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.new.View().HasNetworkChanges(tt.old.View())
if got != tt.changed {
t.Errorf("HasNetworkChanges() = %v, want %v", got, tt.changed)
}
})
}
}

View File

@@ -3,12 +3,14 @@ package integration
import (
"fmt"
"net/netip"
"strconv"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/integration/hsic"
@@ -319,12 +321,14 @@ func TestACLHostsInNetMapTable(t *testing.T) {
require.NoError(t, err)
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
user := status.User[status.Self.UserID].LoginName
user := status.User[status.Self.UserID].LoginName
assert.Len(t, status.Peer, (testCase.want[user]))
assert.Len(c, status.Peer, (testCase.want[user]))
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer visibility")
}
})
}
@@ -782,75 +786,87 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn)
// test1 can query test3
result, err := test1.Curl(test3ip4URL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip4URL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test3ip4URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip4URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv4")
result, err = test1.Curl(test3ip6URL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip6URL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test3ip6URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip6URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv6")
result, err = test1.Curl(test3fqdnURL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3fqdnURL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test3fqdnURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3fqdnURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via FQDN")
// test2 can query test3
result, err = test2.Curl(test3ip4URL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip4URL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test3ip4URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip4URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv4")
result, err = test2.Curl(test3ip6URL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip6URL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test3ip6URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip6URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv6")
result, err = test2.Curl(test3fqdnURL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3fqdnURL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test3fqdnURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3fqdnURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via FQDN")
// test3 cannot query test1
result, err = test3.Curl(test1ip4URL)
result, err := test3.Curl(test1ip4URL)
assert.Empty(t, result)
require.Error(t, err)
@@ -876,38 +892,44 @@ func TestACLNamedHostsCanReach(t *testing.T) {
require.Error(t, err)
// test1 can query test2
result, err = test1.Curl(test2ip4URL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2ip4URL,
result,
)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2ip4URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2ip4URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4")
require.NoError(t, err)
result, err = test1.Curl(test2ip6URL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2ip6URL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2ip6URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2ip6URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv6")
result, err = test1.Curl(test2fqdnURL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2fqdnURL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2fqdnURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2fqdnURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN")
// test2 cannot query test1
result, err = test2.Curl(test1ip4URL)
@@ -1050,50 +1072,63 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
// test1 can query test2
result, err := test1.Curl(test2ipURL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2ipURL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2ipURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2ipURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4")
result, err = test1.Curl(test2ip6URL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2ip6URL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2ip6URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2ip6URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv6")
result, err = test1.Curl(test2fqdnURL)
assert.Lenf(
t,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2fqdnURL,
result,
)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2fqdnURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2fqdnURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN")
result, err = test2.Curl(test1ipURL)
assert.Empty(t, result)
require.Error(t, err)
// test2 cannot query test1 (negative test case)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test1ipURL)
assert.Error(c, err)
assert.Empty(c, result)
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv4")
result, err = test2.Curl(test1ip6URL)
assert.Empty(t, result)
require.Error(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test1ip6URL)
assert.Error(c, err)
assert.Empty(c, result)
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv6")
result, err = test2.Curl(test1fqdnURL)
assert.Empty(t, result)
require.Error(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test1fqdnURL)
assert.Error(c, err)
assert.Empty(c, result)
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via FQDN")
})
}
}
@@ -1266,9 +1301,15 @@ func TestACLAutogroupMember(t *testing.T) {
// Test that untagged nodes can access each other
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
if status.Self.Tags != nil && status.Self.Tags.Len() > 0 {
var clientIsUntagged bool
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
clientIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0
assert.True(c, clientIsUntagged, "Expected client %s to be untagged for autogroup:member test", client.Hostname())
}, 10*time.Second, 200*time.Millisecond, "Waiting for client %s to be untagged", client.Hostname())
if !clientIsUntagged {
continue
}
@@ -1277,9 +1318,15 @@ func TestACLAutogroupMember(t *testing.T) {
continue
}
status, err := peer.Status()
require.NoError(t, err)
if status.Self.Tags != nil && status.Self.Tags.Len() > 0 {
var peerIsUntagged bool
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := peer.Status()
assert.NoError(c, err)
peerIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0
assert.True(c, peerIsUntagged, "Expected peer %s to be untagged for autogroup:member test", peer.Hostname())
}, 10*time.Second, 200*time.Millisecond, "Waiting for peer %s to be untagged", peer.Hostname())
if !peerIsUntagged {
continue
}
@@ -1468,21 +1515,23 @@ func TestACLAutogroupTagged(t *testing.T) {
// Explicitly verify tags on tagged nodes
for _, client := range taggedClients {
status, err := client.Status()
require.NoError(t, err)
require.NotNil(t, status.Self.Tags, "tagged node %s should have tags", client.Hostname())
require.Positive(t, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname())
t.Logf("Tagged node %s has tags: %v", client.Hostname(), status.Self.Tags)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
assert.NotNil(c, status.Self.Tags, "tagged node %s should have tags", client.Hostname())
assert.Positive(c, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname())
}, 10*time.Second, 200*time.Millisecond, "Waiting for tags to be applied to tagged nodes")
}
// Verify untagged nodes have no tags
for _, client := range untaggedClients {
status, err := client.Status()
require.NoError(t, err)
if status.Self.Tags != nil {
require.Equal(t, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname())
}
t.Logf("Untagged node %s has no tags", client.Hostname())
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
if status.Self.Tags != nil {
assert.Equal(c, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname())
}
}, 10*time.Second, 200*time.Millisecond, "Waiting to verify untagged nodes have no tags")
}
// Test that tagged nodes can communicate with each other
@@ -1603,9 +1652,11 @@ func TestACLAutogroupSelf(t *testing.T) {
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s (user1) to %s (user1)", client.Hostname(), fqdn)
result, err := client.Curl(url)
assert.Len(t, result, 13)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := client.Curl(url)
assert.NoError(c, err)
assert.Len(c, result, 13)
}, 10*time.Second, 200*time.Millisecond, "user1 device should reach other user1 device")
}
}
@@ -1622,9 +1673,11 @@ func TestACLAutogroupSelf(t *testing.T) {
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s (user2) to %s (user2)", client.Hostname(), fqdn)
result, err := client.Curl(url)
assert.Len(t, result, 13)
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := client.Curl(url)
assert.NoError(c, err)
assert.Len(c, result, 13)
}, 10*time.Second, 200*time.Millisecond, "user2 device should reach other user2 device")
}
}
@@ -1657,3 +1710,388 @@ func TestACLAutogroupSelf(t *testing.T) {
}
}
}
func TestACLPolicyPropagationOverTime(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: 2,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{
// Install iptables to enable packet filtering for ACL tests.
// Packet filters are essential for testing autogroup:self and other ACL policies.
tsic.WithDockerEntrypoint([]string{
"/bin/sh",
"-c",
"/bin/sleep 3 ; apk add python3 curl iptables ip6tables ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev",
}),
tsic.WithDockerWorkdir("/"),
},
hsic.WithTestName("aclpropagation"),
hsic.WithPolicyMode(types.PolicyModeDB),
)
require.NoError(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
require.NoError(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
require.NoError(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
require.NoError(t, err)
allClients := append(user1Clients, user2Clients...)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Define the four policies we'll cycle through
allowAllPolicy := &policyv2.Policy{
ACLs: []policyv2.ACL{
{
Action: "accept",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
}
autogroupSelfPolicy := &policyv2.Policy{
ACLs: []policyv2.ACL{
{
Action: "accept",
Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny),
},
},
},
}
user1ToUser2Policy := &policyv2.Policy{
ACLs: []policyv2.ACL{
{
Action: "accept",
Sources: []policyv2.Alias{usernamep("user1@")},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny),
},
},
},
}
// Run through the policy cycle 5 times
for i := range 5 {
iteration := i + 1 // range 5 gives 0-4, we want 1-5 for logging
t.Logf("=== Iteration %d/5 ===", iteration)
// Phase 1: Allow all policy
t.Logf("Iteration %d: Setting allow-all policy", iteration)
err = headscale.SetPolicy(allowAllPolicy)
require.NoError(t, err)
// Wait for peer lists to sync with allow-all policy
t.Logf("Iteration %d: Phase 1 - Waiting for peer lists to sync with allow-all policy", iteration)
err = scenario.WaitForTailscaleSync()
require.NoError(t, err, "iteration %d: Phase 1 - failed to sync after allow-all policy", iteration)
// Test all-to-all connectivity after state is settled
t.Logf("Iteration %d: Phase 1 - Testing all-to-all connectivity", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
for _, peer := range allClients {
if client.ContainerID() == peer.ContainerID() {
continue
}
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: %s should reach %s with allow-all policy", iteration, client.Hostname(), fqdn)
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn)
}
}
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 1 - all connectivity tests with allow-all policy", iteration)
// Phase 2: Autogroup:self policy (only same user can access)
t.Logf("Iteration %d: Phase 2 - Setting autogroup:self policy", iteration)
err = headscale.SetPolicy(autogroupSelfPolicy)
require.NoError(t, err)
// Wait for peer lists to sync with autogroup:self - ensures cross-user peers are removed
t.Logf("Iteration %d: Phase 2 - Waiting for peer lists to sync with autogroup:self", iteration)
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
require.NoError(t, err, "iteration %d: Phase 2 - failed to sync after autogroup:self policy", iteration)
// Test ALL connectivity (positive and negative) in one block after state is settled
t.Logf("Iteration %d: Phase 2 - Testing all connectivity with autogroup:self", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Positive: user1 can access user1's nodes
for _, client := range user1Clients {
for _, peer := range user1Clients {
if client.ContainerID() == peer.ContainerID() {
continue
}
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
}
}
// Positive: user2 can access user2's nodes
for _, client := range user2Clients {
for _, peer := range user2Clients {
if client.ContainerID() == peer.ContainerID() {
continue
}
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: user2 %s should reach user2's node %s", iteration, client.Hostname(), fqdn)
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn)
}
}
// Negative: user1 cannot access user2's nodes
for _, client := range user1Clients {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.Error(ct, err, "iteration %d: user1 %s should NOT reach user2's node %s with autogroup:self", iteration, client.Hostname(), fqdn)
assert.Empty(ct, result, "iteration %d: user1 %s->user2 %s should fail", iteration, client.Hostname(), fqdn)
}
}
// Negative: user2 cannot access user1's nodes
for _, client := range user2Clients {
for _, peer := range user1Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Empty(ct, result, "iteration %d: user2->user1 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
}
}
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2 - all connectivity tests with autogroup:self", iteration)
// Phase 2b: Add a new node to user1 and validate policy propagation
t.Logf("Iteration %d: Phase 2b - Adding new node to user1 during autogroup:self policy", iteration)
// Add a new node with the same options as the initial setup
// Get the network to use (scenario uses first network in list)
networks := scenario.Networks()
require.NotEmpty(t, networks, "scenario should have at least one network")
newClient := scenario.MustAddAndLoginClient(t, "user1", "all", headscale,
tsic.WithNetfilter("off"),
tsic.WithDockerEntrypoint([]string{
"/bin/sh",
"-c",
"/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev",
}),
tsic.WithDockerWorkdir("/"),
tsic.WithNetwork(networks[0]),
)
t.Logf("Iteration %d: Phase 2b - Added and logged in new node %s", iteration, newClient.Hostname())
// Wait for peer lists to sync after new node addition (now 3 user1 nodes, still autogroup:self)
t.Logf("Iteration %d: Phase 2b - Waiting for peer lists to sync after new node addition", iteration)
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
require.NoError(t, err, "iteration %d: Phase 2b - failed to sync after new node addition", iteration)
// Test ALL connectivity (positive and negative) in one block after state is settled
t.Logf("Iteration %d: Phase 2b - Testing all connectivity after new node addition", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Re-fetch client list to ensure latest state
user1ClientsWithNew, err := scenario.ListTailscaleClients("user1")
assert.NoError(ct, err, "iteration %d: failed to list user1 clients", iteration)
assert.Len(ct, user1ClientsWithNew, 3, "iteration %d: user1 should have 3 nodes", iteration)
// Positive: all user1 nodes can access each other
for _, client := range user1ClientsWithNew {
for _, peer := range user1ClientsWithNew {
if client.ContainerID() == peer.ContainerID() {
continue
}
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
}
}
// Negative: user1 nodes cannot access user2's nodes
for _, client := range user1ClientsWithNew {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.Error(ct, err, "iteration %d: user1 node %s should NOT reach user2 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Empty(ct, result, "iteration %d: user1->user2 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
}
}
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - all connectivity tests after new node addition", iteration)
// Delete the newly added node before Phase 3
t.Logf("Iteration %d: Phase 2b - Deleting the newly added node from user1", iteration)
// Get the node list and find the newest node (highest ID)
var nodeList []*v1.Node
var nodeToDeleteID uint64
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
nodeList, err = headscale.ListNodes("user1")
assert.NoError(ct, err)
assert.Len(ct, nodeList, 3, "should have 3 user1 nodes before deletion")
// Find the node with the highest ID (the newest one)
for _, node := range nodeList {
if node.GetId() > nodeToDeleteID {
nodeToDeleteID = node.GetId()
}
}
}, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - listing nodes before deletion", iteration)
// Delete the node via headscale helper
t.Logf("Iteration %d: Phase 2b - Deleting node ID %d from headscale", iteration, nodeToDeleteID)
err = headscale.DeleteNode(nodeToDeleteID)
require.NoError(t, err, "iteration %d: failed to delete node %d", iteration, nodeToDeleteID)
// Remove the deleted client from the scenario's user.Clients map
// This is necessary for WaitForTailscaleSyncPerUser to calculate correct peer counts
t.Logf("Iteration %d: Phase 2b - Removing deleted client from scenario", iteration)
for clientName, client := range scenario.users["user1"].Clients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
if err != nil {
continue
}
if nodeID == nodeToDeleteID {
delete(scenario.users["user1"].Clients, clientName)
t.Logf("Iteration %d: Phase 2b - Removed client %s (node ID %d) from scenario", iteration, clientName, nodeToDeleteID)
break
}
}
// Verify the node has been deleted
t.Logf("Iteration %d: Phase 2b - Verifying node deletion (expecting 2 user1 nodes)", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
nodeListAfter, err := headscale.ListNodes("user1")
assert.NoError(ct, err, "failed to list nodes after deletion")
assert.Len(ct, nodeListAfter, 2, "iteration %d: should have 2 user1 nodes after deletion, got %d", iteration, len(nodeListAfter))
}, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - node should be deleted", iteration)
// Wait for sync after deletion to ensure peer counts are correct
// Use WaitForTailscaleSyncPerUser because autogroup:self is still active,
// so nodes only see same-user peers, not all nodes
t.Logf("Iteration %d: Phase 2b - Waiting for sync after node deletion (with autogroup:self)", iteration)
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
require.NoError(t, err, "iteration %d: failed to sync after node deletion", iteration)
// Refresh client lists after deletion to ensure we don't reference the deleted node
user1Clients, err = scenario.ListTailscaleClients("user1")
require.NoError(t, err, "iteration %d: failed to refresh user1 client list after deletion", iteration)
user2Clients, err = scenario.ListTailscaleClients("user2")
require.NoError(t, err, "iteration %d: failed to refresh user2 client list after deletion", iteration)
// Create NEW slice instead of appending to old allClients which still has deleted client
allClients = make([]TailscaleClient, 0, len(user1Clients)+len(user2Clients))
allClients = append(allClients, user1Clients...)
allClients = append(allClients, user2Clients...)
t.Logf("Iteration %d: Phase 2b completed - New node added, validated, and removed successfully", iteration)
// Phase 3: User1 can access user2 but not reverse
t.Logf("Iteration %d: Phase 3 - Setting user1->user2 directional policy", iteration)
err = headscale.SetPolicy(user1ToUser2Policy)
require.NoError(t, err)
// Note: Cannot use WaitForTailscaleSync() here because directional policy means
// user2 nodes don't see user1 nodes in their peer list (asymmetric visibility).
// The EventuallyWithT block below will handle waiting for policy propagation.
// Test ALL connectivity (positive and negative) in one block after policy settles
t.Logf("Iteration %d: Phase 3 - Testing all connectivity with directional policy", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Positive: user1 can access user2's nodes
for _, client := range user1Clients {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user2 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
}
}
// Negative: user2 cannot access user1's nodes
for _, client := range user2Clients {
for _, peer := range user1Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Empty(ct, result, "iteration %d: user2->user1 from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
}
}
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 3 - all connectivity tests with directional policy", iteration)
t.Logf("=== Iteration %d/5 completed successfully - All 3 phases passed ===", iteration)
}
t.Log("All 5 iterations completed successfully - ACL propagation is working correctly")
}

View File

@@ -74,14 +74,21 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
clientIPs[client] = ips
}
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
var listNodes []*v1.Node
var nodeCountBeforeLogout int
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
for _, node := range listNodes {
assertLastSeenSet(t, node)
}
for _, node := range listNodes {
assertLastSeenSetWithCollect(c, node)
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
nodeCountBeforeLogout = len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
for _, client := range allClients {
err := client.Logout()
@@ -188,11 +195,16 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
}
listNodes, err = headscale.ListNodes()
require.Len(t, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSet(t, node)
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSetWithCollect(c, node)
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for node list after relogin")
})
}
}
@@ -238,9 +250,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
var listNodes []*v1.Node
var nodeCountBeforeLogout int
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
nodeCountBeforeLogout = len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
for _, client := range allClients {
@@ -371,9 +390,16 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
var listNodes []*v1.Node
var nodeCountBeforeLogout int
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
nodeCountBeforeLogout = len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
for _, client := range allClients {

View File

@@ -901,15 +901,18 @@ func TestOIDCFollowUpUrl(t *testing.T) {
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
time.Sleep(2 * time.Minute)
st, err := ts.Status()
require.NoError(t, err)
assert.Equal(t, "NeedsLogin", st.BackendState)
var newUrl *url.URL
assert.EventuallyWithT(t, func(c *assert.CollectT) {
st, err := ts.Status()
assert.NoError(c, err)
assert.Equal(c, "NeedsLogin", st.BackendState)
// get new AuthURL from daemon
newUrl, err := url.Parse(st.AuthURL)
require.NoError(t, err)
// get new AuthURL from daemon
newUrl, err = url.Parse(st.AuthURL)
assert.NoError(c, err)
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
assert.NotEqual(c, u.String(), st.AuthURL, "AuthURL should change")
}, 10*time.Second, 200*time.Millisecond, "Waiting for registration cache to expire and status to reflect NeedsLogin")
_, err = doLoginURL(ts.Hostname(), newUrl)
require.NoError(t, err)
@@ -943,9 +946,11 @@ func TestOIDCFollowUpUrl(t *testing.T) {
t.Fatalf("unexpected users: %s", diff)
}
listNodes, err := headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, listNodes, 1)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
listNodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, 1)
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login")
}
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client

File diff suppressed because it is too large Load Diff

View File

@@ -25,6 +25,7 @@ type ControlServer interface {
CreateUser(user string) (*v1.User, error)
CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error)
ListNodes(users ...string) ([]*v1.Node, error)
DeleteNode(nodeID uint64) error
NodesByUser() (map[string][]*v1.Node, error)
NodesByName() (map[string]*v1.Node, error)
ListUsers() ([]*v1.User, error)
@@ -38,4 +39,5 @@ type ControlServer interface {
PrimaryRoutes() (*routes.DebugRoutes, error)
DebugBatcher() (*hscontrol.DebugBatcherInfo, error)
DebugNodeStore() (map[types.NodeID]types.Node, error)
DebugFilter() ([]tailcfg.FilterRule, error)
}

View File

@@ -541,8 +541,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
// update hostnames using the up command
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
status := client.MustStatus()
command := []string{
"tailscale",
@@ -642,8 +641,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
}, 60*time.Second, 2*time.Second)
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
status := client.MustStatus()
command := []string{
"tailscale",
@@ -773,26 +771,25 @@ func TestExpireNode(t *testing.T) {
// Verify that the expired node has been marked in all peers list.
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
if client.Hostname() == node.GetName() {
continue
}
if client.Hostname() != node.GetName() {
t.Logf("available peers of %s: %v", client.Hostname(), status.Peers())
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
// Ensures that the node is present, and that it is expired.
if peerStatus, ok := status.Peer[expiredNodeKey]; ok {
requireNotNil(t, peerStatus.Expired)
assert.NotNil(t, peerStatus.KeyExpiry)
peerStatus, ok := status.Peer[expiredNodeKey]
assert.True(c, ok, "expired node key should be present in peer list")
if ok {
assert.NotNil(c, peerStatus.Expired)
assert.NotNil(c, peerStatus.KeyExpiry)
t.Logf(
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
now.String(),
peerStatus.KeyExpiry,
)
if peerStatus.KeyExpiry != nil {
assert.Truef(
t,
c,
peerStatus.KeyExpiry.Before(now),
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
@@ -802,7 +799,7 @@ func TestExpireNode(t *testing.T) {
}
assert.Truef(
t,
c,
peerStatus.Expired,
"node %q should be expired, expired is %v",
peerStatus.HostName,
@@ -811,24 +808,14 @@ func TestExpireNode(t *testing.T) {
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
if !strings.Contains(stderr, "node key has expired") {
t.Errorf(
c.Errorf(
"expected to be unable to ping expired host %q from %q",
node.GetName(),
client.Hostname(),
)
}
} else {
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
}
} else {
if status.Self.KeyExpiry != nil {
assert.Truef(t, status.Self.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", status.Self.HostName, now.String(), status.Self.KeyExpiry)
}
// NeedsLogin means that the node has understood that it is no longer
// valid.
assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName)
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for expired node status to propagate")
}
}
@@ -866,11 +853,13 @@ func TestNodeOnlineStatus(t *testing.T) {
t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps))
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
// Assert that we have the original count - self
assert.Len(t, status.Peers(), len(MustTestVersions)-1)
// Assert that we have the original count - self
assert.Len(c, status.Peers(), len(MustTestVersions)-1)
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer count")
}
headscale, err := scenario.Headscale()

View File

@@ -507,6 +507,11 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node.GetLastSeen())
}
func assertLastSeenSetWithCollect(c *assert.CollectT, node *v1.Node) {
assert.NotNil(c, node)
assert.NotNil(c, node.GetLastSeen())
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
// are in the logged-out state (NeedsLogin).
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
@@ -633,50 +638,50 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
t.Logf("Checking netmap of %q", client.Hostname())
netmap, err := client.Netmap()
if err != nil {
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
netmap, err := client.Netmap()
assert.NoError(c, err, "getting netmap for %q", client.Hostname())
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
for _, peer := range netmap.Peers {
assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
// Netinfo is not always set
// assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
if ni := hi.NetInfo(); ni.Valid() {
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
}
assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(c, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(c, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
assert.Truef(c, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
}
assert.Falsef(c, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(c, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
assert.Falsef(c, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
for _, peer := range netmap.Peers {
assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
// Netinfo is not always set
// assert.Truef(c, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
if ni := hi.NetInfo(); ni.Valid() {
assert.NotEqualf(c, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
}
}
assert.NotEmptyf(c, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(c, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(c, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
assert.Truef(c, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
assert.Falsef(c, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
assert.Falsef(c, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
assert.Falsef(c, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for valid netmap for %q", client.Hostname())
}
// assertValidStatus validates that a client's status has all required fields for proper operation.
@@ -920,3 +925,125 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
EmailVerified: emailVerified,
}
}
// GetUserByName retrieves a user by name from the headscale server.
// This is a common pattern used when creating preauth keys or managing users.
func GetUserByName(headscale ControlServer, username string) (*v1.User, error) {
users, err := headscale.ListUsers()
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
for _, u := range users {
if u.GetName() == username {
return u, nil
}
}
return nil, fmt.Errorf("user %s not found", username)
}
// FindNewClient finds a client that is in the new list but not in the original list.
// This is useful when dynamically adding nodes during tests and needing to identify
// which client was just added.
func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) {
for _, client := range updated {
isOriginal := false
for _, origClient := range original {
if client.Hostname() == origClient.Hostname() {
isOriginal = true
break
}
}
if !isOriginal {
return client, nil
}
}
return nil, fmt.Errorf("no new client found")
}
// AddAndLoginClient adds a new tailscale client to a user and logs it in.
// This combines the common pattern of:
// 1. Creating a new node
// 2. Finding the new node in the client list
// 3. Getting the user to create a preauth key
// 4. Logging in the new node
func (s *Scenario) AddAndLoginClient(
t *testing.T,
username string,
version string,
headscale ControlServer,
tsOpts ...tsic.Option,
) (TailscaleClient, error) {
t.Helper()
// Get the original client list
originalClients, err := s.ListTailscaleClients(username)
if err != nil {
return nil, fmt.Errorf("failed to list original clients: %w", err)
}
// Create the new node
err = s.CreateTailscaleNodesInUser(username, version, 1, tsOpts...)
if err != nil {
return nil, fmt.Errorf("failed to create tailscale node: %w", err)
}
// Wait for the new node to appear in the client list
var newClient TailscaleClient
_, err = backoff.Retry(t.Context(), func() (struct{}, error) {
updatedClients, err := s.ListTailscaleClients(username)
if err != nil {
return struct{}{}, fmt.Errorf("failed to list updated clients: %w", err)
}
if len(updatedClients) != len(originalClients)+1 {
return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients))
}
newClient, err = FindNewClient(originalClients, updatedClients)
if err != nil {
return struct{}{}, fmt.Errorf("failed to find new client: %w", err)
}
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewConstantBackOff(500*time.Millisecond)), backoff.WithMaxElapsedTime(10*time.Second))
if err != nil {
return nil, fmt.Errorf("timeout waiting for new client: %w", err)
}
// Get the user and create preauth key
user, err := GetUserByName(headscale, username)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
authKey, err := s.CreatePreAuthKey(user.GetId(), true, false)
if err != nil {
return nil, fmt.Errorf("failed to create preauth key: %w", err)
}
// Login the new client
err = newClient.Login(headscale.GetEndpoint(), authKey.GetKey())
if err != nil {
return nil, fmt.Errorf("failed to login new client: %w", err)
}
return newClient, nil
}
// MustAddAndLoginClient is like AddAndLoginClient but fails the test on error.
func (s *Scenario) MustAddAndLoginClient(
t *testing.T,
username string,
version string,
headscale ControlServer,
tsOpts ...tsic.Option,
) TailscaleClient {
t.Helper()
client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...)
require.NoError(t, err)
return client
}

View File

@@ -1082,6 +1082,30 @@ func (t *HeadscaleInContainer) ListNodes(
return ret, nil
}
func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error {
command := []string{
"headscale",
"nodes",
"delete",
"--identifier",
fmt.Sprintf("%d", nodeID),
"--output",
"json",
"--force",
}
_, _, err := dockertestutil.ExecuteCommand(
t.container,
command,
[]string{},
)
if err != nil {
return fmt.Errorf("failed to execute delete node command: %w", err)
}
return nil
}
func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {
nodes, err := t.ListNodes()
if err != nil {
@@ -1397,3 +1421,38 @@ func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, er
return nodeStore, nil
}
// DebugFilter fetches the current filter rules from the debug endpoint.
func (t *HeadscaleInContainer) DebugFilter() ([]tailcfg.FilterRule, error) {
// Execute curl inside the container to access the debug endpoint locally
command := []string{
"curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/filter",
}
result, err := t.Execute(command)
if err != nil {
return nil, fmt.Errorf("fetching filter from debug endpoint: %w", err)
}
var filterRules []tailcfg.FilterRule
if err := json.Unmarshal([]byte(result), &filterRules); err != nil {
return nil, fmt.Errorf("decoding filter response: %w", err)
}
return filterRules, nil
}
// DebugPolicy fetches the current policy from the debug endpoint.
func (t *HeadscaleInContainer) DebugPolicy() (string, error) {
// Execute curl inside the container to access the debug endpoint locally
command := []string{
"curl", "-s", "http://localhost:9090/debug/policy",
}
result, err := t.Execute(command)
if err != nil {
return "", fmt.Errorf("fetching policy from debug endpoint: %w", err)
}
return result, nil
}

View File

@@ -1358,16 +1358,8 @@ func TestSubnetRouteACL(t *testing.T) {
// Sort nodes by ID
sort.SliceStable(allClients, func(i, j int) bool {
statusI, err := allClients[i].Status()
if err != nil {
return false
}
statusJ, err := allClients[j].Status()
if err != nil {
return false
}
statusI := allClients[i].MustStatus()
statusJ := allClients[j].MustStatus()
return statusI.Self.ID < statusJ.Self.ID
})
@@ -1475,9 +1467,7 @@ func TestSubnetRouteACL(t *testing.T) {
requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes["1"])})
}, 5*time.Second, 200*time.Millisecond, "Verifying client can see subnet routes from router")
clientNm, err := client.Netmap()
require.NoError(t, err)
// Wait for packet filter updates to propagate to client netmap
wantClientFilter := []filter.Match{
{
IPProto: views.SliceOf([]ipproto.Proto{
@@ -1503,13 +1493,16 @@ func TestSubnetRouteACL(t *testing.T) {
},
}
if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
clientNm, err := client.Netmap()
assert.NoError(c, err)
subnetNm, err := subRouter1.Netmap()
require.NoError(t, err)
if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
assert.Fail(c, fmt.Sprintf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff))
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for client packet filter to update")
// Wait for packet filter updates to propagate to subnet router netmap
wantSubnetFilter := []filter.Match{
{
IPProto: views.SliceOf([]ipproto.Proto{
@@ -1553,9 +1546,14 @@ func TestSubnetRouteACL(t *testing.T) {
},
}
if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
subnetNm, err := subRouter1.Netmap()
assert.NoError(c, err)
if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
assert.Fail(c, fmt.Sprintf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff))
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for subnet router packet filter to update")
}
// TestEnablingExitRoutes tests enabling exit routes for clients.
@@ -1592,12 +1590,16 @@ func TestEnablingExitRoutes(t *testing.T) {
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
nodes, err := headscale.ListNodes()
require.NoError(t, err)
require.Len(t, nodes, 2)
var nodes []*v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
requireNodeRouteCount(t, nodes[0], 2, 0, 0)
requireNodeRouteCount(t, nodes[1], 2, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[0], 2, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[1], 2, 0, 0)
}, 10*time.Second, 200*time.Millisecond, "Waiting for route advertisements to propagate")
// Verify that no routes has been sent to the client,
// they are not yet enabled.

View File

@@ -693,6 +693,35 @@ func (s *Scenario) WaitForTailscaleSync() error {
return err
}
// WaitForTailscaleSyncPerUser blocks execution until each TailscaleClient has the expected
// number of peers for its user. This is useful for policies like autogroup:self where nodes
// only see same-user peers, not all nodes in the network.
func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Duration) error {
var allErrors []error
for _, user := range s.users {
// Calculate expected peer count: number of nodes in this user minus 1 (self)
expectedPeers := len(user.Clients) - 1
for _, client := range user.Clients {
c := client
expectedCount := expectedPeers
user.syncWaitGroup.Go(func() error {
return c.WaitForPeers(expectedCount, timeout, retryInterval)
})
}
if err := user.syncWaitGroup.Wait(); err != nil {
allErrors = append(allErrors, err)
}
}
if len(allErrors) > 0 {
return multierr.New(allErrors...)
}
return nil
}
// WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports
// to have all other TailscaleClients present in their netmap.NetworkMap.
func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int, timeout, retryInterval time.Duration) error {

View File

@@ -14,6 +14,7 @@ import (
"tailscale.com/net/netcheck"
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/wgengine/filter"
)
// nolint
@@ -36,6 +37,7 @@ type TailscaleClient interface {
MustIPv4() netip.Addr
MustIPv6() netip.Addr
FQDN() (string, error)
MustFQDN() string
Status(...bool) (*ipnstate.Status, error)
MustStatus() *ipnstate.Status
Netmap() (*netmap.NetworkMap, error)
@@ -52,6 +54,7 @@ type TailscaleClient interface {
ContainerID() string
MustID() types.NodeID
ReadFile(path string) ([]byte, error)
PacketFilter() ([]filter.Match, error)
// FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client
// and a bool indicating if the clients online count and peer count is equal.

View File

@@ -18,6 +18,7 @@ import (
"strings"
"time"
"github.com/cenkalti/backoff/v5"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
@@ -32,6 +33,7 @@ import (
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/util/multierr"
"tailscale.com/wgengine/filter"
)
const (
@@ -597,28 +599,39 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
return t.ips, nil
}
ips := make([]netip.Addr, 0)
command := []string{
"tailscale",
"ip",
}
result, _, err := t.Execute(command)
if err != nil {
return []netip.Addr{}, fmt.Errorf("%s failed to join tailscale client: %w", t.hostname, err)
}
for address := range strings.SplitSeq(result, "\n") {
address = strings.TrimSuffix(address, "\n")
if len(address) < 1 {
continue
// Retry with exponential backoff to handle eventual consistency
ips, err := backoff.Retry(context.Background(), func() ([]netip.Addr, error) {
command := []string{
"tailscale",
"ip",
}
ip, err := netip.ParseAddr(address)
result, _, err := t.Execute(command)
if err != nil {
return nil, err
return nil, fmt.Errorf("%s failed to get IPs: %w", t.hostname, err)
}
ips = append(ips, ip)
ips := make([]netip.Addr, 0)
for address := range strings.SplitSeq(result, "\n") {
address = strings.TrimSuffix(address, "\n")
if len(address) < 1 {
continue
}
ip, err := netip.ParseAddr(address)
if err != nil {
return nil, fmt.Errorf("failed to parse IP %s: %w", address, err)
}
ips = append(ips, ip)
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname)
}
return ips, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
if err != nil {
return nil, fmt.Errorf("failed to get IPs for %s after retries: %w", t.hostname, err)
}
return ips, nil
@@ -629,7 +642,6 @@ func (t *TailscaleInContainer) MustIPs() []netip.Addr {
if err != nil {
panic(err)
}
return ips
}
@@ -646,16 +658,15 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) {
}
}
return netip.Addr{}, errors.New("no IPv4 address found")
return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname)
}
func (t *TailscaleInContainer) MustIPv4() netip.Addr {
for _, ip := range t.MustIPs() {
if ip.Is4() {
return ip
}
ip, err := t.IPv4()
if err != nil {
panic(err)
}
panic("no ipv4 found")
return ip
}
func (t *TailscaleInContainer) MustIPv6() netip.Addr {
@@ -900,12 +911,33 @@ func (t *TailscaleInContainer) FQDN() (string, error) {
return t.fqdn, nil
}
status, err := t.Status()
// Retry with exponential backoff to handle eventual consistency
fqdn, err := backoff.Retry(context.Background(), func() (string, error) {
status, err := t.Status()
if err != nil {
return "", fmt.Errorf("failed to get status: %w", err)
}
if status.Self.DNSName == "" {
return "", fmt.Errorf("FQDN not yet available")
}
return status.Self.DNSName, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
if err != nil {
return "", fmt.Errorf("failed to get FQDN: %w", err)
return "", fmt.Errorf("failed to get FQDN for %s after retries: %w", t.hostname, err)
}
return status.Self.DNSName, nil
return fqdn, nil
}
// MustFQDN returns the FQDN as a string of the Tailscale instance, panicking on error.
func (t *TailscaleInContainer) MustFQDN() string {
fqdn, err := t.FQDN()
if err != nil {
panic(err)
}
return fqdn
}
// FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client
@@ -1353,3 +1385,18 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
return &p.Persist.PrivateNodeKey, nil
}
// PacketFilter returns the current packet filter rules from the client's network map.
// This is useful for verifying that policy changes have propagated to the client.
func (t *TailscaleInContainer) PacketFilter() ([]filter.Match, error) {
if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version)
}
nm, err := t.Netmap()
if err != nil {
return nil, fmt.Errorf("failed to get netmap: %w", err)
}
return nm.PacketFilter, nil
}

View File

@@ -136,7 +136,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
}
// Write to file
err = os.WriteFile(outputFile, formatted, 0644)
err = os.WriteFile(outputFile, formatted, 0o644)
if err != nil {
return fmt.Errorf("error writing file: %w", err)
}