From 0acf09bdd26879867e934e980b737e59e06a489f Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 24 Feb 2026 19:39:52 +0000 Subject: [PATCH] policy/v2: add localpart:*@domain SSH user compilation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for localpart:*@ entries in SSH policy users. When a user SSHes into a target, their email local-part becomes the OS username (e.g. alice@example.com → OS user alice). Type system (types.go): - SSHUser.IsLocalpart() and ParseLocalpart() for validation - SSHUsers.LocalpartEntries(), NormalUsers(), ContainsLocalpart() - Enforces format: localpart:*@ (wildcard-only) - UserWildcard.Resolve for user:*@domain SSH source aliases - acceptEnv passthrough for SSH rules Compilation (filter.go): - resolveLocalparts: pure function mapping users to local-parts by email domain. No node walking, easy to test. - groupSourcesByUser: single walk producing per-user principals with sorted user IDs, and tagged principals separately. - ipSetToPrincipals: shared helper replacing 6 inline copies. - selfPrincipalsForNode: self-access using pre-computed byUser. The approach separates data gathering from rule assembly. Localpart rules are interleaved per source user to match Tailscale SaaS first-match-wins ordering. Updates #3049 --- hscontrol/policy/policy_test.go | 119 ++++ hscontrol/policy/v2/filter.go | 321 ++++++++--- hscontrol/policy/v2/filter_test.go | 864 +++++++++++++++++++++++++---- hscontrol/policy/v2/types.go | 68 ++- hscontrol/policy/v2/types_test.go | 300 +++++++++- 5 files changed, 1452 insertions(+), 220 deletions(-) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 536c86f3..ebce8de5 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1077,6 +1077,8 @@ func TestSSHPolicyRules(t *testing.T) { {Name: "user1", Model: gorm.Model{ID: 1}}, {Name: "user2", Model: gorm.Model{ID: 2}}, {Name: "user3", Model: gorm.Model{ID: 3}}, + {Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 4}}, + {Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 5}}, } // Create standard node setups used across tests @@ -1110,6 +1112,20 @@ func TestSSHPolicyRules(t *testing.T) { Tags: []string{"tag:server"}, } + // Nodes for localpart tests (users with email addresses) + nodeAlice := types.Node{ + Hostname: "alice-device", + IPv4: ap("100.64.0.6"), + UserID: new(uint(4)), + User: new(users[3]), + } + nodeBob := types.Node{ + Hostname: "bob-device", + IPv4: ap("100.64.0.7"), + UserID: new(uint(5)), + User: new(users[4]), + } + tests := []struct { name string targetNode types.Node @@ -1446,6 +1462,7 @@ func TestSSHPolicyRules(t *testing.T) { }, SSHUsers: map[string]string{ "debian": "debian", + "root": "", }, Action: &tailcfg.SSHAction{ Accept: true, @@ -1456,6 +1473,108 @@ func TestSSHPolicyRules(t *testing.T) { }, }}, }, + { + name: "localpart-maps-email-to-os-user", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeAlice, &nodeBob}, + policy: `{ + "tagOwners": { + "tag:server": ["alice@example.com"] + }, + "ssh": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:server"], + "users": ["localpart:*@example.com"] + } + ] + }`, + // Per-user common+localpart interleaved: each user gets root deny then localpart. + wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}}, + SSHUsers: map[string]string{"root": ""}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}}, + SSHUsers: map[string]string{"alice": "alice"}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.7"}}, + SSHUsers: map[string]string{"root": ""}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.7"}}, + SSHUsers: map[string]string{"bob": "bob"}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }}, + }, + { + name: "localpart-combined-with-root", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeAlice}, + policy: `{ + "tagOwners": { + "tag:server": ["alice@example.com"] + }, + "ssh": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:server"], + "users": ["localpart:*@example.com", "root"] + } + ] + }`, + // Common root rule followed by alice's per-user localpart rule (interleaved). + wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}}, + SSHUsers: map[string]string{"root": "root"}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}}, + SSHUsers: map[string]string{"alice": "alice"}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }}, + }, } for _, tt := range tests { diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 9df62525..a0888836 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -362,7 +362,6 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { } } -//nolint:gocyclo // complex SSH policy compilation logic func (pol *Policy) compileSSHPolicy( baseURL string, users types.Users, @@ -378,13 +377,7 @@ func (pol *Policy) compileSSHPolicy( var rules []*tailcfg.SSHRule for index, rule := range pol.SSHs { - // Separate destinations into autogroup:self and others - // This is needed because autogroup:self requires filtering sources to same-user only, - // while other destinations should use all resolved sources - var ( - autogroupSelfDests []Alias - otherDests []Alias - ) + var autogroupSelfDests, otherDests []Alias for _, dst := range rule.Destinations { if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -394,12 +387,11 @@ func (pol *Policy) compileSSHPolicy( } } - // Note: Tagged nodes can't match autogroup:self destinations, but can still match other destinations - - // Resolve sources once - we'll use them differently for each destination type srcIPs, err := rule.Sources.Resolve(pol, users, nodes) if err != nil { - log.Trace().Caller().Err(err).Msgf("ssh policy compilation failed resolving source ips for rule %+v", rule) + log.Trace().Caller().Err(err).Msgf( + "ssh policy compilation failed resolving source ips for rule %+v", rule, + ) } if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { @@ -414,95 +406,87 @@ func (pol *Policy) compileSSHPolicy( case SSHActionCheck: action = sshCheck(baseURL, checkPeriodFromRule(rule)) default: - return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) + return nil, fmt.Errorf( + "parsing SSH policy, unknown action %q, index: %d: %w", + rule.Action, index, err, + ) } - userMap := make(map[string]string, len(rule.Users)) + acceptEnv := rule.AcceptEnv + + // Build the common userMap (always has at least a root entry). + const rootUser = "root" + + baseUserMap := make(map[string]string, len(rule.Users)) if rule.Users.ContainsNonRoot() { - userMap["*"] = "=" - // by default, we do not allow root unless explicitly stated - userMap["root"] = "" + baseUserMap["*"] = "=" } if rule.Users.ContainsRoot() { - userMap["root"] = "root" + baseUserMap[rootUser] = rootUser + } else { + baseUserMap[rootUser] = "" } for _, u := range rule.Users.NormalUsers() { - userMap[u.String()] = u.String() + baseUserMap[u.String()] = u.String() } - // Handle autogroup:self destinations (if any) - // Note: Tagged nodes can't match autogroup:self, so skip this block for tagged nodes - if len(autogroupSelfDests) > 0 && !node.IsTagged() { - // Build destination set for autogroup:self (same-user untagged devices only) - var dest netipx.IPSetBuilder + hasLocalpart := rule.Users.ContainsLocalpart() - for _, n := range nodes.All() { - if !n.IsTagged() && n.User().ID() == node.User().ID() { - n.AppendToIPSet(&dest) - } - } + var localpartByUser map[uint]string + if hasLocalpart { + localpartByUser = resolveLocalparts( + rule.Users.LocalpartEntries(), users, + ) + } - destSet, err := dest.IPSet() - if err != nil { - return nil, err - } + userIDs, principalsByUser, taggedPrincipals := groupSourcesByUser( + nodes, srcIPs, + ) - // Only create rule if this node is in the destination set - if node.InIPSet(destSet) { - // Filter sources to only same-user untagged devices - // Pre-filter to same-user untagged devices for efficiency - sameUserNodes := make([]types.NodeView, 0) + // appendRules emits a common rule and, if the user has a + // localpart match, a per-user localpart rule. + appendRules := func(principals []*tailcfg.SSHPrincipal, uid uint, hasUID bool) { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: baseUserMap, + Action: &action, + AcceptEnv: acceptEnv, + }) - for _, n := range nodes.All() { - if !n.IsTagged() && n.User().ID() == node.User().ID() { - sameUserNodes = append(sameUserNodes, n) - } - } - - var filteredSrcIPs netipx.IPSetBuilder - - for _, n := range sameUserNodes { - // Check if any of this node's IPs are in the source set - if slices.ContainsFunc(n.IPs(), srcIPs.Contains) { - n.AppendToIPSet(&filteredSrcIPs) // Found this node, move to next - } - } - - filteredSrcSet, err := filteredSrcIPs.IPSet() - if err != nil { - return nil, err - } - - if filteredSrcSet != nil && len(filteredSrcSet.Prefixes()) > 0 { - var principals []*tailcfg.SSHPrincipal - for addr := range util.IPSetAddrIter(filteredSrcSet) { - principals = append(principals, &tailcfg.SSHPrincipal{ - NodeIP: addr.String(), - }) - } - - if len(principals) > 0 { - rules = append(rules, &tailcfg.SSHRule{ - Principals: principals, - SSHUsers: userMap, - Action: &action, - }) - } + if hasUID { + if lp, ok := localpartByUser[uid]; ok { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: map[string]string{lp: lp}, + Action: &action, + AcceptEnv: acceptEnv, + }) } } } - // Handle other destinations (if any) + // Handle autogroup:self destinations. + // Tagged nodes can't match autogroup:self. + if len(autogroupSelfDests) > 0 && + !node.IsTagged() && node.User().Valid() { + uid := node.User().ID() + + if principals := principalsByUser[uid]; len(principals) > 0 { + appendRules(principals, uid, true) + } + } + + // Handle other destinations. if len(otherDests) > 0 { - // Build destination set for other destinations var dest netipx.IPSetBuilder for _, dst := range otherDests { ips, err := dst.Resolve(pol, users, nodes) if err != nil { - log.Trace().Caller().Err(err).Msgf("resolving destination ips") + log.Trace().Caller().Err(err). + Msgf("resolving destination ips") } if ips != nil { @@ -515,22 +499,48 @@ func (pol *Policy) compileSSHPolicy( return nil, err } - // Only create rule if this node is in the destination set if node.InIPSet(destSet) { - // For non-autogroup:self destinations, use all resolved sources (no filtering) - var principals []*tailcfg.SSHPrincipal - for addr := range util.IPSetAddrIter(srcIPs) { - principals = append(principals, &tailcfg.SSHPrincipal{ - NodeIP: addr.String(), - }) - } + // Node is a destination — emit rules. + // When localpart entries exist, interleave common + // and localpart rules per source user to match + // Tailscale SaaS first-match-wins ordering. + if hasLocalpart { + for _, uid := range userIDs { + appendRules(principalsByUser[uid], uid, true) + } - if len(principals) > 0 { - rules = append(rules, &tailcfg.SSHRule{ - Principals: principals, - SSHUsers: userMap, - Action: &action, - }) + if len(taggedPrincipals) > 0 { + appendRules(taggedPrincipals, 0, false) + } + } else { + if principals := ipSetToPrincipals(srcIPs); len(principals) > 0 { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: baseUserMap, + Action: &action, + AcceptEnv: acceptEnv, + }) + } + } + } else if hasLocalpart && node.InIPSet(srcIPs) { + // Self-access: source node not in destination set + // receives rules scoped to its own user. + if node.IsTagged() { + var builder netipx.IPSetBuilder + + node.AppendToIPSet(&builder) + + ipSet, err := builder.IPSet() + if err == nil && ipSet != nil { + if principals := ipSetToPrincipals(ipSet); len(principals) > 0 { + appendRules(principals, 0, false) + } + } + } else if node.User().Valid() { + uid := node.User().ID() + if principals := principalsByUser[uid]; len(principals) > 0 { + appendRules(principals, uid, true) + } } } } @@ -558,6 +568,137 @@ func (pol *Policy) compileSSHPolicy( }, nil } +// ipSetToPrincipals converts an IPSet into SSH principals, one per address. +func ipSetToPrincipals(ipSet *netipx.IPSet) []*tailcfg.SSHPrincipal { + if ipSet == nil { + return nil + } + + var principals []*tailcfg.SSHPrincipal + + for addr := range util.IPSetAddrIter(ipSet) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + return principals +} + +// resolveLocalparts maps each user whose email matches a localpart:*@ +// entry to their email local-part. Returns userID → localPart (e.g. {1: "alice"}). +// This is a pure data function — no node walking or IP resolution. +func resolveLocalparts( + entries []SSHUser, + users types.Users, +) map[uint]string { + if len(entries) == 0 { + return nil + } + + result := make(map[uint]string) + + for _, entry := range entries { + domain, err := entry.ParseLocalpart() + if err != nil { + log.Warn().Err(err).Msgf( + "skipping invalid localpart entry %q during SSH compilation", + entry, + ) + + continue + } + + for _, user := range users { + if user.Email == "" { + continue + } + + atIdx := strings.LastIndex(user.Email, "@") + if atIdx < 0 { + continue + } + + if !strings.EqualFold(user.Email[atIdx+1:], domain) { + continue + } + + result[user.ID] = user.Email[:atIdx] + } + } + + return result +} + +// groupSourcesByUser groups source node IPs by user ownership. Returns sorted +// user IDs for deterministic iteration, per-user principals, and tagged principals. +// Only includes nodes whose IPs are in the srcIPs set. +func groupSourcesByUser( + nodes views.Slice[types.NodeView], + srcIPs *netipx.IPSet, +) ([]uint, map[uint][]*tailcfg.SSHPrincipal, []*tailcfg.SSHPrincipal) { + userIPSets := make(map[uint]*netipx.IPSetBuilder) + + var taggedIPSet netipx.IPSetBuilder + + hasTagged := false + + for _, n := range nodes.All() { + if !slices.ContainsFunc(n.IPs(), srcIPs.Contains) { + continue + } + + if n.IsTagged() { + n.AppendToIPSet(&taggedIPSet) + + hasTagged = true + + continue + } + + if !n.User().Valid() { + continue + } + + uid := n.User().ID() + + if _, ok := userIPSets[uid]; !ok { + userIPSets[uid] = &netipx.IPSetBuilder{} + } + + n.AppendToIPSet(userIPSets[uid]) + } + + var userIDs []uint + + principalsByUser := make(map[uint][]*tailcfg.SSHPrincipal, len(userIPSets)) + + for uid, builder := range userIPSets { + ipSet, err := builder.IPSet() + if err != nil || ipSet == nil { + continue + } + + if principals := ipSetToPrincipals(ipSet); len(principals) > 0 { + principalsByUser[uid] = principals + userIDs = append(userIDs, uid) + } + } + + slices.Sort(userIDs) + + var tagged []*tailcfg.SSHPrincipal + + if hasTagged { + taggedSet, err := taggedIPSet.IPSet() + if err == nil && taggedSet != nil { + tagged = ipSetToPrincipals(taggedSet) + } + } + + return userIDs, principalsByUser, tagged +} + func ipSetToPrefixStringList(ips *netipx.IPSet) []string { var out []string diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index da76e0f8..b6080eae 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -12,6 +12,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go4.org/netipx" "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -431,12 +432,19 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { nodes := types.Nodes{&nodeTaggedServer, &nodeTaggedDB, &nodeUser2Untagged} + acceptAction := &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + } + user2Principal := []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}} + tests := []struct { - name string - targetNode types.Node - policy *Policy - wantSSHUsers map[string]string - wantEmpty bool + name string + targetNode types.Node + policy *Policy + want *tailcfg.SSHPolicy }{ { name: "specific user mapping", @@ -457,9 +465,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { }, }, }, - wantSSHUsers: map[string]string{ - "ssh-it-user": "ssh-it-user", - }, + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: user2Principal, + SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user"}, + Action: acceptAction, + }, + }}, }, { name: "multiple specific users", @@ -480,11 +492,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { }, }, }, - wantSSHUsers: map[string]string{ - "ubuntu": "ubuntu", - "admin": "admin", - "deploy": "deploy", - }, + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: user2Principal, + SSHUsers: map[string]string{"root": "", "ubuntu": "ubuntu", "admin": "admin", "deploy": "deploy"}, + Action: acceptAction, + }, + }}, }, { name: "autogroup:nonroot only", @@ -505,10 +519,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { }, }, }, - wantSSHUsers: map[string]string{ - "*": "=", - "root": "", - }, + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: user2Principal, + SSHUsers: map[string]string{"*": "=", "root": ""}, + Action: acceptAction, + }, + }}, }, { name: "root only", @@ -529,9 +546,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { }, }, }, - wantSSHUsers: map[string]string{ - "root": "root", - }, + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: user2Principal, + SSHUsers: map[string]string{"root": "root"}, + Action: acceptAction, + }, + }}, }, { name: "autogroup:nonroot plus root", @@ -552,10 +573,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { }, }, }, - wantSSHUsers: map[string]string{ - "*": "=", - "root": "root", - }, + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: user2Principal, + SSHUsers: map[string]string{"*": "=", "root": "root"}, + Action: acceptAction, + }, + }}, }, { name: "mixed specific users and autogroups", @@ -576,12 +600,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { }, }, }, - wantSSHUsers: map[string]string{ - "*": "=", - "root": "root", - "ubuntu": "ubuntu", - "admin": "admin", - }, + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: user2Principal, + SSHUsers: map[string]string{"*": "=", "root": "root", "ubuntu": "ubuntu", "admin": "admin"}, + Action: acceptAction, + }, + }}, }, { name: "no matching destination", @@ -603,45 +628,387 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { }, }, }, - wantEmpty: true, + want: &tailcfg.SSHPolicy{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Validate the policy - err := tt.policy.validate() + require.NoError(t, tt.policy.validate()) + + got, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice()) require.NoError(t, err) - // Compile SSH policy - sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice()) - require.NoError(t, err) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff) + } + }) + } +} - if tt.wantEmpty { - if sshPolicy == nil { - return // Expected empty result +func TestCompileSSHPolicy_LocalpartMapping(t *testing.T) { + users := types.Users{ + {Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}, + {Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}}, + {Name: "charlie", Email: "charlie@other.com", Model: gorm.Model{ID: 3}}, + {Name: "dave", Model: gorm.Model{ID: 4}}, // CLI user, no email + } + + nodeTaggedServer := types.Node{ + Hostname: "tagged-server", + IPv4: createAddr("100.64.0.1"), + UserID: new(users[0].ID), + User: new(users[0]), + Tags: []string{"tag:server"}, + } + nodeAlice := types.Node{ + Hostname: "alice-device", + IPv4: createAddr("100.64.0.2"), + UserID: new(users[0].ID), + User: new(users[0]), + } + nodeBob := types.Node{ + Hostname: "bob-device", + IPv4: createAddr("100.64.0.3"), + UserID: new(users[1].ID), + User: new(users[1]), + } + nodeCharlie := types.Node{ + Hostname: "charlie-device", + IPv4: createAddr("100.64.0.4"), + UserID: new(users[2].ID), + User: new(users[2]), + } + nodeDave := types.Node{ + Hostname: "dave-device", + IPv4: createAddr("100.64.0.5"), + UserID: new(users[3].ID), + User: new(users[3]), + } + + nodes := types.Nodes{&nodeTaggedServer, &nodeAlice, &nodeBob, &nodeCharlie, &nodeDave} + + acceptAction := &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + } + + tests := []struct { + name string + users types.Users // nil → use default users + nodes types.Nodes // nil → use default nodes + targetNode types.Node + policy *Policy + want *tailcfg.SSHPolicy + }{ + { + name: "localpart only", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("alice@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + }, + // Per-user common+localpart rules interleaved, then non-matching users. + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}}, + SSHUsers: map[string]string{"alice": "alice"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}, + SSHUsers: map[string]string{"bob": "bob"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + }}, + }, + { + name: "localpart with root", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("alice@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@example.com"), "root"}, + }, + }, + }, + // Per-user common+localpart rules interleaved, then non-matching users. + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}}, + SSHUsers: map[string]string{"root": "root"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}}, + SSHUsers: map[string]string{"alice": "alice"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}, + SSHUsers: map[string]string{"root": "root"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}, + SSHUsers: map[string]string{"bob": "bob"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}}, + SSHUsers: map[string]string{"root": "root"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}}, + SSHUsers: map[string]string{"root": "root"}, + Action: acceptAction, + }, + }}, + }, + { + name: "localpart no matching users in domain", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("alice@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@nonexistent.com")}, + }, + }, + }, + // No localpart matches, but per-user common rules still emitted (root deny) + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + }}, + }, + { + name: "localpart with special chars in email", + users: types.Users{ + {Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}}, + }, + nodes: func() types.Nodes { + specialUser := types.User{Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}} + n := types.Node{ + Hostname: "special-device", + IPv4: createAddr("100.64.0.10"), + UserID: new(specialUser.ID), + User: &specialUser, } - assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match") + return types.Nodes{&nodeTaggedServer, &n} + }(), + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("dave+sshuser@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + }, + // Per-user common rule (root deny), then separate localpart rule. + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.10"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.10"}}, + SSHUsers: map[string]string{"dave+sshuser": "dave+sshuser"}, + Action: acceptAction, + }, + }}, + }, + { + name: "localpart excludes CLI users without email", + users: types.Users{ + {Name: "dave", Model: gorm.Model{ID: 4}}, + }, + nodes: func() types.Nodes { + cliUser := types.User{Name: "dave", Model: gorm.Model{ID: 4}} + n := types.Node{ + Hostname: "dave-cli-device", + IPv4: createAddr("100.64.0.5"), + UserID: new(cliUser.ID), + User: &cliUser, + } - return + return types.Nodes{&nodeTaggedServer, &n} + }(), + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("dave@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + }, + // No localpart matches (CLI user, no email), but implicit root deny emits common rule + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + }}, + }, + { + name: "localpart with multiple domains", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("alice@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{ + SSHUser("localpart:*@example.com"), + SSHUser("localpart:*@other.com"), + }, + }, + }, + }, + // Per-user common+localpart rules interleaved: + // alice/bob match *@example.com, charlie matches *@other.com. + want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}}, + SSHUsers: map[string]string{"alice": "alice"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}, + SSHUsers: map[string]string{"bob": "bob"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}}, + SSHUsers: map[string]string{"charlie": "charlie"}, + Action: acceptAction, + }, + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}}, + SSHUsers: map[string]string{"root": ""}, + Action: acceptAction, + }, + }}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testUsers := users + if tt.users != nil { + testUsers = tt.users } - require.NotNil(t, sshPolicy) - require.Len(t, sshPolicy.Rules, 1, "Should have exactly one SSH rule") + testNodes := nodes + if tt.nodes != nil { + testNodes = tt.nodes + } - rule := sshPolicy.Rules[0] - assert.Equal(t, tt.wantSSHUsers, rule.SSHUsers, "SSH users mapping should match expected") + require.NoError(t, tt.policy.validate()) - // Verify principals are set correctly (should contain user2's untagged device IP since that's the source) - require.Len(t, rule.Principals, 1) - assert.Equal(t, "100.64.0.3", rule.Principals[0].NodeIP) + got, err := tt.policy.compileSSHPolicy( + "unused-server-url", testUsers, tt.targetNode.View(), testNodes.ViewSlice(), + ) + require.NoError(t, err) - // Verify action is set correctly - assert.True(t, rule.Action.Accept) - assert.True(t, rule.Action.AllowAgentForwarding) - assert.True(t, rule.Action.AllowLocalPortForwarding) - assert.True(t, rule.Action.AllowRemotePortForwarding) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("compileSSHPolicy() unexpected result (-want +got):\n%s", diff) + } }) } } @@ -687,8 +1054,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { }, } - err := policy.validate() - require.NoError(t, err) + require.NoError(t, policy.validate()) sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice()) require.NoError(t, err) @@ -700,6 +1066,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { // Verify SSH users are correctly mapped expectedUsers := map[string]string{ "ssh-it-user": "ssh-it-user", + "root": "", } assert.Equal(t, expectedUsers, rule.SSHUsers) @@ -833,28 +1200,28 @@ func TestSSHIntegrationReproduction(t *testing.T) { }, } - // Validate policy - err := policy.validate() - require.NoError(t, err) + require.NoError(t, policy.validate()) // Test SSH policy compilation for node2 (owned by user2, who is in the group) - sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice()) + got, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice()) require.NoError(t, err) - require.NotNil(t, sshPolicy) - require.Len(t, sshPolicy.Rules, 1) - rule := sshPolicy.Rules[0] + want := &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}}, + SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user"}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }} - // This was the failing assertion in integration test - sshUsers was empty - assert.NotEmpty(t, rule.SSHUsers, "SSH users should not be empty") - assert.Contains(t, rule.SSHUsers, "ssh-it-user", "ssh-it-user should be present in SSH users") - assert.Equal(t, "ssh-it-user", rule.SSHUsers["ssh-it-user"], "ssh-it-user should map to itself") - - // Verify that ssh-it-user is correctly mapped - expectedUsers := map[string]string{ - "ssh-it-user": "ssh-it-user", + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff) } - assert.Equal(t, expectedUsers, rule.SSHUsers, "ssh-it-user should be mapped to itself") } // TestSSHJSONSerialization verifies that the SSH policy can be properly serialized @@ -885,40 +1252,38 @@ func TestSSHJSONSerialization(t *testing.T) { }, } - err := policy.validate() + require.NoError(t, policy.validate()) + + got, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice()) require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice()) - require.NoError(t, err) - require.NotNil(t, sshPolicy) + want := &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.1"}}, + SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user", "ubuntu": "ubuntu", "admin": "admin"}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }} - // Serialize to JSON to verify structure - jsonData, err := json.MarshalIndent(sshPolicy, "", " ") - require.NoError(t, err) - - // Parse back to verify structure - var parsed tailcfg.SSHPolicy - - err = json.Unmarshal(jsonData, &parsed) - require.NoError(t, err) - - // Verify the parsed structure has the expected SSH users - require.Len(t, parsed.Rules, 1) - rule := parsed.Rules[0] - - expectedUsers := map[string]string{ - "ssh-it-user": "ssh-it-user", - "ubuntu": "ubuntu", - "admin": "admin", + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff) } - assert.Equal(t, expectedUsers, rule.SSHUsers, "SSH users should survive JSON round-trip") - // Verify JSON contains the SSH users (not empty) - assert.Contains(t, string(jsonData), `"ssh-it-user"`) - assert.Contains(t, string(jsonData), `"ubuntu"`) - assert.Contains(t, string(jsonData), `"admin"`) - assert.NotContains(t, string(jsonData), `"sshUsers": {}`, "SSH users should not be empty") - assert.NotContains(t, string(jsonData), `"sshUsers": null`, "SSH users should not be null") + // Verify JSON round-trip preserves the full structure + jsonData, err := json.MarshalIndent(got, "", " ") + require.NoError(t, err) + + var parsed tailcfg.SSHPolicy + require.NoError(t, json.Unmarshal(jsonData, &parsed)) + + if diff := cmp.Diff(want, &parsed); diff != "" { + t.Errorf("JSON round-trip mismatch (-want +got):\n%s", diff) + } } func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { @@ -2244,6 +2609,89 @@ func TestCompileSSHPolicy_CheckPeriodVariants(t *testing.T) { } } +func TestIPSetToPrincipals(t *testing.T) { + tests := []struct { + name string + ips []string // IPs to add to the set + want []*tailcfg.SSHPrincipal + }{ + { + name: "nil input", + ips: nil, + want: nil, + }, + { + name: "single IPv4", + ips: []string{"100.64.0.1"}, + want: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.1"}}, + }, + { + name: "multiple IPs", + ips: []string{"100.64.0.1", "100.64.0.2"}, + want: []*tailcfg.SSHPrincipal{ + {NodeIP: "100.64.0.1"}, + {NodeIP: "100.64.0.2"}, + }, + }, + { + name: "IPv6", + ips: []string{"fd7a:115c:a1e0::1"}, + want: []*tailcfg.SSHPrincipal{{NodeIP: "fd7a:115c:a1e0::1"}}, + }, + { + name: "mixed IPv4 and IPv6", + ips: []string{"100.64.0.1", "fd7a:115c:a1e0::1"}, + want: []*tailcfg.SSHPrincipal{ + {NodeIP: "100.64.0.1"}, + {NodeIP: "fd7a:115c:a1e0::1"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ipSet *netipx.IPSet + + if tt.ips != nil { + var builder netipx.IPSetBuilder + + for _, ip := range tt.ips { + addr := netip.MustParseAddr(ip) + builder.Add(addr) + } + + var err error + + ipSet, err = builder.IPSet() + require.NoError(t, err) + } + + got := ipSetToPrincipals(ipSet) + + // Sort for deterministic comparison + sortPrincipals := func(p []*tailcfg.SSHPrincipal) { + slices.SortFunc(p, func(a, b *tailcfg.SSHPrincipal) int { + if a.NodeIP < b.NodeIP { + return -1 + } + + if a.NodeIP > b.NodeIP { + return 1 + } + + return 0 + }) + } + sortPrincipals(got) + sortPrincipals(tt.want) + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("ipSetToPrincipals() mismatch (-want +got):\n%s", diff) + } + }) + } +} + func TestSSHCheckParams(t *testing.T) { users := types.Users{ {Name: "user1", Model: gorm.Model{ID: 1}}, @@ -2424,3 +2872,219 @@ func TestSSHCheckParams(t *testing.T) { }) } } + +func TestResolveLocalparts(t *testing.T) { + tests := []struct { + name string + entries []SSHUser + users types.Users + want map[uint]string + }{ + { + name: "no entries", + entries: nil, + users: types.Users{{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}}, + want: nil, + }, + { + name: "single match", + entries: []SSHUser{"localpart:*@example.com"}, + users: types.Users{ + {Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}, + }, + want: map[uint]string{1: "alice"}, + }, + { + name: "domain mismatch", + entries: []SSHUser{"localpart:*@other.com"}, + users: types.Users{ + {Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}, + }, + want: map[uint]string{}, + }, + { + name: "case insensitive domain", + entries: []SSHUser{"localpart:*@EXAMPLE.COM"}, + users: types.Users{ + {Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}, + }, + want: map[uint]string{1: "alice"}, + }, + { + name: "user without email skipped", + entries: []SSHUser{"localpart:*@example.com"}, + users: types.Users{ + {Name: "cli-user", Model: gorm.Model{ID: 1}}, + }, + want: map[uint]string{}, + }, + { + name: "multiple domains multiple users", + entries: []SSHUser{ + "localpart:*@example.com", + "localpart:*@other.com", + }, + users: types.Users{ + {Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}, + {Name: "bob", Email: "bob@other.com", Model: gorm.Model{ID: 2}}, + {Name: "charlie", Email: "charlie@nope.com", Model: gorm.Model{ID: 3}}, + }, + want: map[uint]string{1: "alice", 2: "bob"}, + }, + { + name: "special chars in local part", + entries: []SSHUser{"localpart:*@example.com"}, + users: types.Users{ + {Name: "d", Email: "dave+ssh@example.com", Model: gorm.Model{ID: 1}}, + }, + want: map[uint]string{1: "dave+ssh"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resolveLocalparts(tt.entries, tt.users) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("resolveLocalparts() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestGroupSourcesByUser(t *testing.T) { + alice := types.User{ + Name: "alice", Email: "alice@example.com", + Model: gorm.Model{ID: 1}, + } + bob := types.User{ + Name: "bob", Email: "bob@example.com", + Model: gorm.Model{ID: 2}, + } + + nodeAlice := types.Node{ + Hostname: "alice-dev", + IPv4: createAddr("100.64.0.1"), + UserID: &alice.ID, + User: &alice, + } + nodeBob := types.Node{ + Hostname: "bob-dev", + IPv4: createAddr("100.64.0.2"), + UserID: &bob.ID, + User: &bob, + } + nodeTagged := types.Node{ + Hostname: "tagged", + IPv4: createAddr("100.64.0.3"), + UserID: &alice.ID, + User: &alice, + Tags: []string{"tag:server"}, + } + + // Build an IPSet that includes all node IPs + allIPs := func() *netipx.IPSet { + var b netipx.IPSetBuilder + b.AddPrefix(netip.MustParsePrefix("100.64.0.0/24")) + + s, _ := b.IPSet() + + return s + }() + + tests := []struct { + name string + nodes types.Nodes + srcIPs *netipx.IPSet + wantUIDs []uint + wantUserCount int + wantHasTagged bool + wantTaggedLen int + wantAliceIP string + wantBobIP string + wantTaggedIP string + }{ + { + name: "user-owned only", + nodes: types.Nodes{&nodeAlice, &nodeBob}, + srcIPs: allIPs, + wantUIDs: []uint{1, 2}, + wantUserCount: 2, + wantAliceIP: "100.64.0.1", + wantBobIP: "100.64.0.2", + }, + { + name: "mixed user and tagged", + nodes: types.Nodes{&nodeAlice, &nodeTagged}, + srcIPs: allIPs, + wantUIDs: []uint{1}, + wantUserCount: 1, + wantHasTagged: true, + wantTaggedLen: 1, + wantAliceIP: "100.64.0.1", + wantTaggedIP: "100.64.0.3", + }, + { + name: "tagged only", + nodes: types.Nodes{&nodeTagged}, + srcIPs: allIPs, + wantUIDs: nil, + wantUserCount: 0, + wantHasTagged: true, + wantTaggedLen: 1, + }, + { + name: "node not in srcIPs excluded", + nodes: types.Nodes{&nodeAlice, &nodeBob}, + srcIPs: func() *netipx.IPSet { + var b netipx.IPSetBuilder + b.Add(netip.MustParseAddr("100.64.0.1")) // only alice + + s, _ := b.IPSet() + + return s + }(), + wantUIDs: []uint{1}, + wantUserCount: 1, + wantAliceIP: "100.64.0.1", + }, + { + name: "sorted by user ID", + nodes: types.Nodes{&nodeBob, &nodeAlice}, // reverse order + srcIPs: allIPs, + wantUIDs: []uint{1, 2}, // still sorted + wantUserCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sortedUIDs, byUser, tagged := groupSourcesByUser( + tt.nodes.ViewSlice(), tt.srcIPs, + ) + + assert.Equal(t, tt.wantUIDs, sortedUIDs, "sortedUIDs") + assert.Len(t, byUser, tt.wantUserCount, "byUser count") + + if tt.wantHasTagged { + assert.Len(t, tagged, tt.wantTaggedLen, "tagged count") + } else { + assert.Empty(t, tagged, "tagged should be empty") + } + + if tt.wantAliceIP != "" { + require.Contains(t, byUser, uint(1)) + assert.Equal(t, tt.wantAliceIP, byUser[1][0].NodeIP) + } + + if tt.wantBobIP != "" { + require.Contains(t, byUser, uint(2)) + assert.Equal(t, tt.wantBobIP, byUser[2][0].NodeIP) + } + + if tt.wantTaggedIP != "" { + require.NotEmpty(t, tagged) + assert.Equal(t, tt.wantTaggedIP, tagged[0].NodeIP) + } + }) + } +} diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 1596e09c..8d7df81f 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -47,6 +47,7 @@ var ( ErrSSHCheckPeriodBelowMin = errors.New("checkPeriod below minimum of 1 minute") ErrSSHCheckPeriodAboveMax = errors.New("checkPeriod above maximum of 168 hours (1 week)") ErrSSHCheckPeriodOnNonCheck = errors.New("checkPeriod is only valid with action \"check\"") + ErrInvalidLocalpart = errors.New("invalid localpart format, must be localpart:*@") ) // SSH check period constants per Tailscale docs: @@ -1965,6 +1966,14 @@ func (p *Policy) validate() error { continue } } + + if user.IsLocalpart() { + _, err := user.ParseLocalpart() + if err != nil { + errs = append(errs, err) + continue + } + } } for _, src := range ssh.Sources { @@ -2191,6 +2200,7 @@ type SSH struct { Destinations SSHDstAliases `json:"dst"` Users SSHUsers `json:"users"` CheckPeriod *SSHCheckPeriod `json:"checkPeriod,omitempty"` + AcceptEnv []string `json:"acceptEnv,omitempty"` } // SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule. @@ -2342,6 +2352,11 @@ type SSHDstAliases []Alias type SSHUsers []SSHUser +// SSHUserLocalpartPrefix is the prefix for localpart SSH user entries. +// Format: localpart:*@ +// See: https://tailscale.com/docs/features/tailscale-ssh#users +const SSHUserLocalpartPrefix = "localpart:" + func (u SSHUsers) ContainsRoot() bool { return slices.Contains(u, "root") } @@ -2350,9 +2365,25 @@ func (u SSHUsers) ContainsNonRoot() bool { return slices.Contains(u, SSHUser(AutoGroupNonRoot)) } +// ContainsLocalpart returns true if any entry has the localpart: prefix. +func (u SSHUsers) ContainsLocalpart() bool { + return slices.ContainsFunc(u, func(user SSHUser) bool { + return user.IsLocalpart() + }) +} + +// NormalUsers returns all SSH users that are not root, autogroup:nonroot, +// or localpart: entries. func (u SSHUsers) NormalUsers() []SSHUser { return slicesx.Filter(nil, u, func(user SSHUser) bool { - return user != "root" && user != SSHUser(AutoGroupNonRoot) + return user != "root" && user != SSHUser(AutoGroupNonRoot) && !user.IsLocalpart() + }) +} + +// LocalpartEntries returns only the localpart: prefixed entries. +func (u SSHUsers) LocalpartEntries() []SSHUser { + return slicesx.Filter(nil, u, func(user SSHUser) bool { + return user.IsLocalpart() }) } @@ -2362,6 +2393,41 @@ func (u SSHUser) String() string { return string(u) } +// IsLocalpart returns true if the SSHUser has the localpart: prefix. +func (u SSHUser) IsLocalpart() bool { + return strings.HasPrefix(string(u), SSHUserLocalpartPrefix) +} + +// ParseLocalpart validates and extracts the domain from a localpart: entry. +// The expected format is localpart:*@. +// Returns the domain part or an error if the format is invalid. +func (u SSHUser) ParseLocalpart() (string, error) { + if !u.IsLocalpart() { + return "", fmt.Errorf("%w: missing prefix %q in %q", ErrInvalidLocalpart, SSHUserLocalpartPrefix, u) + } + + pattern := strings.TrimPrefix(string(u), SSHUserLocalpartPrefix) + + // Must be *@ + atIdx := strings.LastIndex(pattern, "@") + if atIdx < 0 { + return "", fmt.Errorf("%w: missing @ in %q", ErrInvalidLocalpart, u) + } + + localPart := pattern[:atIdx] + domain := pattern[atIdx+1:] + + if localPart != "*" { + return "", fmt.Errorf("%w: local part must be *, got %q in %q", ErrInvalidLocalpart, localPart, u) + } + + if domain == "" { + return "", fmt.Errorf("%w: empty domain in %q", ErrInvalidLocalpart, u) + } + + return domain, nil +} + // MarshalJSON marshals the SSHUser to JSON. func (u SSHUser) MarshalJSON() ([]byte, error) { return json.Marshal(string(u)) diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index a68259a3..f0b9c9a1 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1794,6 +1794,33 @@ func TestUnmarshalPolicy(t *testing.T) { }, }, }, + { + name: "ssh-localpart-valid", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@example.com"] + }] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{up("admin@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:prod")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + }, + }, { name: "2754-bracketed-ipv6-multiple-ports", input: ` @@ -1825,6 +1852,33 @@ func TestUnmarshalPolicy(t *testing.T) { }, }, }, + { + name: "ssh-localpart-with-other-users", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@example.com", "root", "autogroup:nonroot"] + }] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{up("admin@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:prod")}, + Users: []SSHUser{SSHUser("localpart:*@example.com"), "root", SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, { name: "2754-bracketed-ipv6-wildcard-port", input: ` @@ -1951,6 +2005,51 @@ func TestUnmarshalPolicy(t *testing.T) { `, wantErr: "square brackets are only valid around IPv6 addresses", }, + { + name: "ssh-localpart-invalid-no-at-sign", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:foo"] + }] +} +`, + wantErr: "invalid localpart format", + }, + { + name: "ssh-localpart-invalid-non-wildcard", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:alice@example.com"] + }] +} +`, + wantErr: "invalid localpart format", + }, + { + name: "ssh-localpart-invalid-empty-domain", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@"] + }] +} +`, + wantErr: "invalid localpart format", + }, } cmps := append(util.Comparers, @@ -2635,56 +2734,63 @@ func TestResolveAutoApprovers(t *testing.T) { func TestSSHUsers_NormalUsers(t *testing.T) { tests := []struct { - name string - users SSHUsers - expected []SSHUser + name string + users SSHUsers + want []SSHUser }{ { - name: "empty users", - users: SSHUsers{}, - expected: []SSHUser{}, + name: "empty users", + users: SSHUsers{}, + want: nil, }, { - name: "only root", - users: SSHUsers{"root"}, - expected: []SSHUser{}, + name: "only root", + users: SSHUsers{"root"}, + want: nil, }, { - name: "only autogroup:nonroot", - users: SSHUsers{SSHUser(AutoGroupNonRoot)}, - expected: []SSHUser{}, + name: "only autogroup:nonroot", + users: SSHUsers{SSHUser(AutoGroupNonRoot)}, + want: nil, }, { - name: "only normal user", - users: SSHUsers{"ssh-it-user"}, - expected: []SSHUser{"ssh-it-user"}, + name: "only normal user", + users: SSHUsers{"ssh-it-user"}, + want: []SSHUser{"ssh-it-user"}, }, { - name: "multiple normal users", - users: SSHUsers{"ubuntu", "admin", "user1"}, - expected: []SSHUser{"ubuntu", "admin", "user1"}, + name: "multiple normal users", + users: SSHUsers{"ubuntu", "admin", "user1"}, + want: []SSHUser{"ubuntu", "admin", "user1"}, }, { - name: "mixed users with root", - users: SSHUsers{"ubuntu", "root", "admin"}, - expected: []SSHUser{"ubuntu", "admin"}, + name: "mixed users with root", + users: SSHUsers{"ubuntu", "root", "admin"}, + want: []SSHUser{"ubuntu", "admin"}, }, { - name: "mixed users with autogroup:nonroot", - users: SSHUsers{"ubuntu", SSHUser(AutoGroupNonRoot), "admin"}, - expected: []SSHUser{"ubuntu", "admin"}, + name: "mixed users with autogroup:nonroot", + users: SSHUsers{"ubuntu", SSHUser(AutoGroupNonRoot), "admin"}, + want: []SSHUser{"ubuntu", "admin"}, }, { - name: "mixed users with both root and autogroup:nonroot", - users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), "admin"}, - expected: []SSHUser{"ubuntu", "admin"}, + name: "mixed users with both root and autogroup:nonroot", + users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), "admin"}, + want: []SSHUser{"ubuntu", "admin"}, + }, + { + name: "excludes localpart entries", + users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), SSHUser("localpart:*@example.com"), "admin"}, + want: []SSHUser{"ubuntu", "admin"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := tt.users.NormalUsers() - assert.ElementsMatch(t, tt.expected, result, "NormalUsers() should return expected normal users") + got := tt.users.NormalUsers() + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("NormalUsers() unexpected result (-want +got):\n%s", diff) + } }) } } @@ -2761,6 +2867,142 @@ func TestSSHUsers_ContainsNonRoot(t *testing.T) { } } +func TestSSHUsers_ContainsLocalpart(t *testing.T) { + tests := []struct { + name string + users SSHUsers + expected bool + }{ + { + name: "empty users", + users: SSHUsers{}, + expected: false, + }, + { + name: "contains localpart", + users: SSHUsers{SSHUser("localpart:*@example.com")}, + expected: true, + }, + { + name: "does not contain localpart", + users: SSHUsers{"ubuntu", "admin", "root"}, + expected: false, + }, + { + name: "contains localpart among others", + users: SSHUsers{"ubuntu", SSHUser("localpart:*@example.com"), "admin"}, + expected: true, + }, + { + name: "multiple localpart entries", + users: SSHUsers{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.users.ContainsLocalpart() + assert.Equal(t, tt.expected, result, "ContainsLocalpart() should return expected result") + }) + } +} + +func TestSSHUsers_LocalpartEntries(t *testing.T) { + tests := []struct { + name string + users SSHUsers + want []SSHUser + }{ + { + name: "empty users", + users: SSHUsers{}, + want: nil, + }, + { + name: "no localpart entries", + users: SSHUsers{"root", "ubuntu", SSHUser(AutoGroupNonRoot)}, + want: nil, + }, + { + name: "single localpart entry", + users: SSHUsers{"root", SSHUser("localpart:*@example.com"), "ubuntu"}, + want: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + { + name: "multiple localpart entries", + users: SSHUsers{SSHUser("localpart:*@a.com"), "root", SSHUser("localpart:*@b.com")}, + want: []SSHUser{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.users.LocalpartEntries() + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("LocalpartEntries() unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func TestSSHUser_ParseLocalpart(t *testing.T) { + tests := []struct { + name string + user SSHUser + expectedDomain string + expectErr bool + }{ + { + name: "valid localpart", + user: SSHUser("localpart:*@example.com"), + expectedDomain: "example.com", + }, + { + name: "valid localpart with subdomain", + user: SSHUser("localpart:*@corp.example.com"), + expectedDomain: "corp.example.com", + }, + { + name: "missing prefix", + user: SSHUser("ubuntu"), + expectErr: true, + }, + { + name: "missing @ sign", + user: SSHUser("localpart:foo"), + expectErr: true, + }, + { + name: "non-wildcard local part", + user: SSHUser("localpart:alice@example.com"), + expectErr: true, + }, + { + name: "empty domain", + user: SSHUser("localpart:*@"), + expectErr: true, + }, + { + name: "just prefix", + user: SSHUser("localpart:"), + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + domain, err := tt.user.ParseLocalpart() + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedDomain, domain) + } + }) + } +} + func mustIPSet(prefixes ...string) *netipx.IPSet { var builder netipx.IPSetBuilder for _, p := range prefixes {