policy/v2: add localpart:*@domain SSH user compilation

Add support for localpart:*@<domain> entries in SSH policy users.
When a user SSHes into a target, their email local-part becomes the
OS username (e.g. alice@example.com → OS user alice).

Type system (types.go):
- SSHUser.IsLocalpart() and ParseLocalpart() for validation
- SSHUsers.LocalpartEntries(), NormalUsers(), ContainsLocalpart()
- Enforces format: localpart:*@<domain> (wildcard-only)
- UserWildcard.Resolve for user:*@domain SSH source aliases
- acceptEnv passthrough for SSH rules

Compilation (filter.go):
- resolveLocalparts: pure function mapping users to local-parts
  by email domain. No node walking, easy to test.
- groupSourcesByUser: single walk producing per-user principals
  with sorted user IDs, and tagged principals separately.
- ipSetToPrincipals: shared helper replacing 6 inline copies.
- selfPrincipalsForNode: self-access using pre-computed byUser.

The approach separates data gathering from rule assembly. Localpart
rules are interleaved per source user to match Tailscale SaaS
first-match-wins ordering.

Updates #3049
This commit is contained in:
Kristoffer Dalby
2026-02-24 19:39:52 +00:00
parent 414d3bbbd8
commit 0acf09bdd2
5 changed files with 1452 additions and 220 deletions

View File

@@ -1077,6 +1077,8 @@ func TestSSHPolicyRules(t *testing.T) {
{Name: "user1", Model: gorm.Model{ID: 1}},
{Name: "user2", Model: gorm.Model{ID: 2}},
{Name: "user3", Model: gorm.Model{ID: 3}},
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 4}},
{Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 5}},
}
// Create standard node setups used across tests
@@ -1110,6 +1112,20 @@ func TestSSHPolicyRules(t *testing.T) {
Tags: []string{"tag:server"},
}
// Nodes for localpart tests (users with email addresses)
nodeAlice := types.Node{
Hostname: "alice-device",
IPv4: ap("100.64.0.6"),
UserID: new(uint(4)),
User: new(users[3]),
}
nodeBob := types.Node{
Hostname: "bob-device",
IPv4: ap("100.64.0.7"),
UserID: new(uint(5)),
User: new(users[4]),
}
tests := []struct {
name string
targetNode types.Node
@@ -1446,6 +1462,7 @@ func TestSSHPolicyRules(t *testing.T) {
},
SSHUsers: map[string]string{
"debian": "debian",
"root": "",
},
Action: &tailcfg.SSHAction{
Accept: true,
@@ -1456,6 +1473,108 @@ func TestSSHPolicyRules(t *testing.T) {
},
}},
},
{
name: "localpart-maps-email-to-os-user",
targetNode: nodeTaggedServer,
peers: types.Nodes{&nodeAlice, &nodeBob},
policy: `{
"tagOwners": {
"tag:server": ["alice@example.com"]
},
"ssh": [
{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:server"],
"users": ["localpart:*@example.com"]
}
]
}`,
// Per-user common+localpart interleaved: each user gets root deny then localpart.
wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}},
SSHUsers: map[string]string{"root": ""},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}},
SSHUsers: map[string]string{"alice": "alice"},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.7"}},
SSHUsers: map[string]string{"root": ""},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.7"}},
SSHUsers: map[string]string{"bob": "bob"},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
}},
},
{
name: "localpart-combined-with-root",
targetNode: nodeTaggedServer,
peers: types.Nodes{&nodeAlice},
policy: `{
"tagOwners": {
"tag:server": ["alice@example.com"]
},
"ssh": [
{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:server"],
"users": ["localpart:*@example.com", "root"]
}
]
}`,
// Common root rule followed by alice's per-user localpart rule (interleaved).
wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}},
SSHUsers: map[string]string{"root": "root"},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}},
SSHUsers: map[string]string{"alice": "alice"},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
}},
},
}
for _, tt := range tests {

View File

@@ -362,7 +362,6 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction {
}
}
//nolint:gocyclo // complex SSH policy compilation logic
func (pol *Policy) compileSSHPolicy(
baseURL string,
users types.Users,
@@ -378,13 +377,7 @@ func (pol *Policy) compileSSHPolicy(
var rules []*tailcfg.SSHRule
for index, rule := range pol.SSHs {
// Separate destinations into autogroup:self and others
// This is needed because autogroup:self requires filtering sources to same-user only,
// while other destinations should use all resolved sources
var (
autogroupSelfDests []Alias
otherDests []Alias
)
var autogroupSelfDests, otherDests []Alias
for _, dst := range rule.Destinations {
if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
@@ -394,12 +387,11 @@ func (pol *Policy) compileSSHPolicy(
}
}
// Note: Tagged nodes can't match autogroup:self destinations, but can still match other destinations
// Resolve sources once - we'll use them differently for each destination type
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Caller().Err(err).Msgf("ssh policy compilation failed resolving source ips for rule %+v", rule)
log.Trace().Caller().Err(err).Msgf(
"ssh policy compilation failed resolving source ips for rule %+v", rule,
)
}
if srcIPs == nil || len(srcIPs.Prefixes()) == 0 {
@@ -414,95 +406,87 @@ func (pol *Policy) compileSSHPolicy(
case SSHActionCheck:
action = sshCheck(baseURL, checkPeriodFromRule(rule))
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
return nil, fmt.Errorf(
"parsing SSH policy, unknown action %q, index: %d: %w",
rule.Action, index, err,
)
}
userMap := make(map[string]string, len(rule.Users))
acceptEnv := rule.AcceptEnv
// Build the common userMap (always has at least a root entry).
const rootUser = "root"
baseUserMap := make(map[string]string, len(rule.Users))
if rule.Users.ContainsNonRoot() {
userMap["*"] = "="
// by default, we do not allow root unless explicitly stated
userMap["root"] = ""
baseUserMap["*"] = "="
}
if rule.Users.ContainsRoot() {
userMap["root"] = "root"
baseUserMap[rootUser] = rootUser
} else {
baseUserMap[rootUser] = ""
}
for _, u := range rule.Users.NormalUsers() {
userMap[u.String()] = u.String()
baseUserMap[u.String()] = u.String()
}
// Handle autogroup:self destinations (if any)
// Note: Tagged nodes can't match autogroup:self, so skip this block for tagged nodes
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
// Build destination set for autogroup:self (same-user untagged devices only)
var dest netipx.IPSetBuilder
hasLocalpart := rule.Users.ContainsLocalpart()
for _, n := range nodes.All() {
if !n.IsTagged() && n.User().ID() == node.User().ID() {
n.AppendToIPSet(&dest)
}
}
var localpartByUser map[uint]string
if hasLocalpart {
localpartByUser = resolveLocalparts(
rule.Users.LocalpartEntries(), users,
)
}
destSet, err := dest.IPSet()
if err != nil {
return nil, err
}
userIDs, principalsByUser, taggedPrincipals := groupSourcesByUser(
nodes, srcIPs,
)
// Only create rule if this node is in the destination set
if node.InIPSet(destSet) {
// Filter sources to only same-user untagged devices
// Pre-filter to same-user untagged devices for efficiency
sameUserNodes := make([]types.NodeView, 0)
// appendRules emits a common rule and, if the user has a
// localpart match, a per-user localpart rule.
appendRules := func(principals []*tailcfg.SSHPrincipal, uid uint, hasUID bool) {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: baseUserMap,
Action: &action,
AcceptEnv: acceptEnv,
})
for _, n := range nodes.All() {
if !n.IsTagged() && n.User().ID() == node.User().ID() {
sameUserNodes = append(sameUserNodes, n)
}
}
var filteredSrcIPs netipx.IPSetBuilder
for _, n := range sameUserNodes {
// Check if any of this node's IPs are in the source set
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
n.AppendToIPSet(&filteredSrcIPs) // Found this node, move to next
}
}
filteredSrcSet, err := filteredSrcIPs.IPSet()
if err != nil {
return nil, err
}
if filteredSrcSet != nil && len(filteredSrcSet.Prefixes()) > 0 {
var principals []*tailcfg.SSHPrincipal
for addr := range util.IPSetAddrIter(filteredSrcSet) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
}
if len(principals) > 0 {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}
if hasUID {
if lp, ok := localpartByUser[uid]; ok {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: map[string]string{lp: lp},
Action: &action,
AcceptEnv: acceptEnv,
})
}
}
}
// Handle other destinations (if any)
// Handle autogroup:self destinations.
// Tagged nodes can't match autogroup:self.
if len(autogroupSelfDests) > 0 &&
!node.IsTagged() && node.User().Valid() {
uid := node.User().ID()
if principals := principalsByUser[uid]; len(principals) > 0 {
appendRules(principals, uid, true)
}
}
// Handle other destinations.
if len(otherDests) > 0 {
// Build destination set for other destinations
var dest netipx.IPSetBuilder
for _, dst := range otherDests {
ips, err := dst.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
log.Trace().Caller().Err(err).
Msgf("resolving destination ips")
}
if ips != nil {
@@ -515,22 +499,48 @@ func (pol *Policy) compileSSHPolicy(
return nil, err
}
// Only create rule if this node is in the destination set
if node.InIPSet(destSet) {
// For non-autogroup:self destinations, use all resolved sources (no filtering)
var principals []*tailcfg.SSHPrincipal
for addr := range util.IPSetAddrIter(srcIPs) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
}
// Node is a destination — emit rules.
// When localpart entries exist, interleave common
// and localpart rules per source user to match
// Tailscale SaaS first-match-wins ordering.
if hasLocalpart {
for _, uid := range userIDs {
appendRules(principalsByUser[uid], uid, true)
}
if len(principals) > 0 {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
if len(taggedPrincipals) > 0 {
appendRules(taggedPrincipals, 0, false)
}
} else {
if principals := ipSetToPrincipals(srcIPs); len(principals) > 0 {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: baseUserMap,
Action: &action,
AcceptEnv: acceptEnv,
})
}
}
} else if hasLocalpart && node.InIPSet(srcIPs) {
// Self-access: source node not in destination set
// receives rules scoped to its own user.
if node.IsTagged() {
var builder netipx.IPSetBuilder
node.AppendToIPSet(&builder)
ipSet, err := builder.IPSet()
if err == nil && ipSet != nil {
if principals := ipSetToPrincipals(ipSet); len(principals) > 0 {
appendRules(principals, 0, false)
}
}
} else if node.User().Valid() {
uid := node.User().ID()
if principals := principalsByUser[uid]; len(principals) > 0 {
appendRules(principals, uid, true)
}
}
}
}
@@ -558,6 +568,137 @@ func (pol *Policy) compileSSHPolicy(
}, nil
}
// ipSetToPrincipals converts an IPSet into SSH principals, one per address.
func ipSetToPrincipals(ipSet *netipx.IPSet) []*tailcfg.SSHPrincipal {
if ipSet == nil {
return nil
}
var principals []*tailcfg.SSHPrincipal
for addr := range util.IPSetAddrIter(ipSet) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
}
return principals
}
// resolveLocalparts maps each user whose email matches a localpart:*@<domain>
// entry to their email local-part. Returns userID → localPart (e.g. {1: "alice"}).
// This is a pure data function — no node walking or IP resolution.
func resolveLocalparts(
entries []SSHUser,
users types.Users,
) map[uint]string {
if len(entries) == 0 {
return nil
}
result := make(map[uint]string)
for _, entry := range entries {
domain, err := entry.ParseLocalpart()
if err != nil {
log.Warn().Err(err).Msgf(
"skipping invalid localpart entry %q during SSH compilation",
entry,
)
continue
}
for _, user := range users {
if user.Email == "" {
continue
}
atIdx := strings.LastIndex(user.Email, "@")
if atIdx < 0 {
continue
}
if !strings.EqualFold(user.Email[atIdx+1:], domain) {
continue
}
result[user.ID] = user.Email[:atIdx]
}
}
return result
}
// groupSourcesByUser groups source node IPs by user ownership. Returns sorted
// user IDs for deterministic iteration, per-user principals, and tagged principals.
// Only includes nodes whose IPs are in the srcIPs set.
func groupSourcesByUser(
nodes views.Slice[types.NodeView],
srcIPs *netipx.IPSet,
) ([]uint, map[uint][]*tailcfg.SSHPrincipal, []*tailcfg.SSHPrincipal) {
userIPSets := make(map[uint]*netipx.IPSetBuilder)
var taggedIPSet netipx.IPSetBuilder
hasTagged := false
for _, n := range nodes.All() {
if !slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
continue
}
if n.IsTagged() {
n.AppendToIPSet(&taggedIPSet)
hasTagged = true
continue
}
if !n.User().Valid() {
continue
}
uid := n.User().ID()
if _, ok := userIPSets[uid]; !ok {
userIPSets[uid] = &netipx.IPSetBuilder{}
}
n.AppendToIPSet(userIPSets[uid])
}
var userIDs []uint
principalsByUser := make(map[uint][]*tailcfg.SSHPrincipal, len(userIPSets))
for uid, builder := range userIPSets {
ipSet, err := builder.IPSet()
if err != nil || ipSet == nil {
continue
}
if principals := ipSetToPrincipals(ipSet); len(principals) > 0 {
principalsByUser[uid] = principals
userIDs = append(userIDs, uid)
}
}
slices.Sort(userIDs)
var tagged []*tailcfg.SSHPrincipal
if hasTagged {
taggedSet, err := taggedIPSet.IPSet()
if err == nil && taggedSet != nil {
tagged = ipSetToPrincipals(taggedSet)
}
}
return userIDs, principalsByUser, tagged
}
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
var out []string

View File

@@ -12,6 +12,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go4.org/netipx"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
@@ -431,12 +432,19 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
nodes := types.Nodes{&nodeTaggedServer, &nodeTaggedDB, &nodeUser2Untagged}
acceptAction := &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
}
user2Principal := []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}
tests := []struct {
name string
targetNode types.Node
policy *Policy
wantSSHUsers map[string]string
wantEmpty bool
name string
targetNode types.Node
policy *Policy
want *tailcfg.SSHPolicy
}{
{
name: "specific user mapping",
@@ -457,9 +465,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
},
},
},
wantSSHUsers: map[string]string{
"ssh-it-user": "ssh-it-user",
},
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: user2Principal,
SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user"},
Action: acceptAction,
},
}},
},
{
name: "multiple specific users",
@@ -480,11 +492,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
},
},
},
wantSSHUsers: map[string]string{
"ubuntu": "ubuntu",
"admin": "admin",
"deploy": "deploy",
},
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: user2Principal,
SSHUsers: map[string]string{"root": "", "ubuntu": "ubuntu", "admin": "admin", "deploy": "deploy"},
Action: acceptAction,
},
}},
},
{
name: "autogroup:nonroot only",
@@ -505,10 +519,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
},
},
},
wantSSHUsers: map[string]string{
"*": "=",
"root": "",
},
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: user2Principal,
SSHUsers: map[string]string{"*": "=", "root": ""},
Action: acceptAction,
},
}},
},
{
name: "root only",
@@ -529,9 +546,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
},
},
},
wantSSHUsers: map[string]string{
"root": "root",
},
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: user2Principal,
SSHUsers: map[string]string{"root": "root"},
Action: acceptAction,
},
}},
},
{
name: "autogroup:nonroot plus root",
@@ -552,10 +573,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
},
},
},
wantSSHUsers: map[string]string{
"*": "=",
"root": "root",
},
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: user2Principal,
SSHUsers: map[string]string{"*": "=", "root": "root"},
Action: acceptAction,
},
}},
},
{
name: "mixed specific users and autogroups",
@@ -576,12 +600,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
},
},
},
wantSSHUsers: map[string]string{
"*": "=",
"root": "root",
"ubuntu": "ubuntu",
"admin": "admin",
},
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: user2Principal,
SSHUsers: map[string]string{"*": "=", "root": "root", "ubuntu": "ubuntu", "admin": "admin"},
Action: acceptAction,
},
}},
},
{
name: "no matching destination",
@@ -603,45 +628,387 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
},
},
},
wantEmpty: true,
want: &tailcfg.SSHPolicy{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate the policy
err := tt.policy.validate()
require.NoError(t, tt.policy.validate())
got, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice())
require.NoError(t, err)
// Compile SSH policy
sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice())
require.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff)
}
})
}
}
if tt.wantEmpty {
if sshPolicy == nil {
return // Expected empty result
func TestCompileSSHPolicy_LocalpartMapping(t *testing.T) {
users := types.Users{
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
{Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}},
{Name: "charlie", Email: "charlie@other.com", Model: gorm.Model{ID: 3}},
{Name: "dave", Model: gorm.Model{ID: 4}}, // CLI user, no email
}
nodeTaggedServer := types.Node{
Hostname: "tagged-server",
IPv4: createAddr("100.64.0.1"),
UserID: new(users[0].ID),
User: new(users[0]),
Tags: []string{"tag:server"},
}
nodeAlice := types.Node{
Hostname: "alice-device",
IPv4: createAddr("100.64.0.2"),
UserID: new(users[0].ID),
User: new(users[0]),
}
nodeBob := types.Node{
Hostname: "bob-device",
IPv4: createAddr("100.64.0.3"),
UserID: new(users[1].ID),
User: new(users[1]),
}
nodeCharlie := types.Node{
Hostname: "charlie-device",
IPv4: createAddr("100.64.0.4"),
UserID: new(users[2].ID),
User: new(users[2]),
}
nodeDave := types.Node{
Hostname: "dave-device",
IPv4: createAddr("100.64.0.5"),
UserID: new(users[3].ID),
User: new(users[3]),
}
nodes := types.Nodes{&nodeTaggedServer, &nodeAlice, &nodeBob, &nodeCharlie, &nodeDave}
acceptAction := &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
}
tests := []struct {
name string
users types.Users // nil → use default users
nodes types.Nodes // nil → use default nodes
targetNode types.Node
policy *Policy
want *tailcfg.SSHPolicy
}{
{
name: "localpart only",
targetNode: nodeTaggedServer,
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("alice@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@example.com")},
},
},
},
// Per-user common+localpart rules interleaved, then non-matching users.
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
SSHUsers: map[string]string{"alice": "alice"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
SSHUsers: map[string]string{"bob": "bob"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
}},
},
{
name: "localpart with root",
targetNode: nodeTaggedServer,
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("alice@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@example.com"), "root"},
},
},
},
// Per-user common+localpart rules interleaved, then non-matching users.
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
SSHUsers: map[string]string{"root": "root"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
SSHUsers: map[string]string{"alice": "alice"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
SSHUsers: map[string]string{"root": "root"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
SSHUsers: map[string]string{"bob": "bob"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
SSHUsers: map[string]string{"root": "root"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
SSHUsers: map[string]string{"root": "root"},
Action: acceptAction,
},
}},
},
{
name: "localpart no matching users in domain",
targetNode: nodeTaggedServer,
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("alice@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@nonexistent.com")},
},
},
},
// No localpart matches, but per-user common rules still emitted (root deny)
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
}},
},
{
name: "localpart with special chars in email",
users: types.Users{
{Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}},
},
nodes: func() types.Nodes {
specialUser := types.User{Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}}
n := types.Node{
Hostname: "special-device",
IPv4: createAddr("100.64.0.10"),
UserID: new(specialUser.ID),
User: &specialUser,
}
assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match")
return types.Nodes{&nodeTaggedServer, &n}
}(),
targetNode: nodeTaggedServer,
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("dave+sshuser@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@example.com")},
},
},
},
// Per-user common rule (root deny), then separate localpart rule.
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.10"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.10"}},
SSHUsers: map[string]string{"dave+sshuser": "dave+sshuser"},
Action: acceptAction,
},
}},
},
{
name: "localpart excludes CLI users without email",
users: types.Users{
{Name: "dave", Model: gorm.Model{ID: 4}},
},
nodes: func() types.Nodes {
cliUser := types.User{Name: "dave", Model: gorm.Model{ID: 4}}
n := types.Node{
Hostname: "dave-cli-device",
IPv4: createAddr("100.64.0.5"),
UserID: new(cliUser.ID),
User: &cliUser,
}
return
return types.Nodes{&nodeTaggedServer, &n}
}(),
targetNode: nodeTaggedServer,
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("dave@")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@example.com")},
},
},
},
// No localpart matches (CLI user, no email), but implicit root deny emits common rule
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
}},
},
{
name: "localpart with multiple domains",
targetNode: nodeTaggedServer,
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("alice@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{
SSHUser("localpart:*@example.com"),
SSHUser("localpart:*@other.com"),
},
},
},
},
// Per-user common+localpart rules interleaved:
// alice/bob match *@example.com, charlie matches *@other.com.
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
SSHUsers: map[string]string{"alice": "alice"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
SSHUsers: map[string]string{"bob": "bob"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
SSHUsers: map[string]string{"charlie": "charlie"},
Action: acceptAction,
},
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
SSHUsers: map[string]string{"root": ""},
Action: acceptAction,
},
}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testUsers := users
if tt.users != nil {
testUsers = tt.users
}
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1, "Should have exactly one SSH rule")
testNodes := nodes
if tt.nodes != nil {
testNodes = tt.nodes
}
rule := sshPolicy.Rules[0]
assert.Equal(t, tt.wantSSHUsers, rule.SSHUsers, "SSH users mapping should match expected")
require.NoError(t, tt.policy.validate())
// Verify principals are set correctly (should contain user2's untagged device IP since that's the source)
require.Len(t, rule.Principals, 1)
assert.Equal(t, "100.64.0.3", rule.Principals[0].NodeIP)
got, err := tt.policy.compileSSHPolicy(
"unused-server-url", testUsers, tt.targetNode.View(), testNodes.ViewSlice(),
)
require.NoError(t, err)
// Verify action is set correctly
assert.True(t, rule.Action.Accept)
assert.True(t, rule.Action.AllowAgentForwarding)
assert.True(t, rule.Action.AllowLocalPortForwarding)
assert.True(t, rule.Action.AllowRemotePortForwarding)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("compileSSHPolicy() unexpected result (-want +got):\n%s", diff)
}
})
}
}
@@ -687,8 +1054,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
},
}
err := policy.validate()
require.NoError(t, err)
require.NoError(t, policy.validate())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice())
require.NoError(t, err)
@@ -700,6 +1066,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
// Verify SSH users are correctly mapped
expectedUsers := map[string]string{
"ssh-it-user": "ssh-it-user",
"root": "",
}
assert.Equal(t, expectedUsers, rule.SSHUsers)
@@ -833,28 +1200,28 @@ func TestSSHIntegrationReproduction(t *testing.T) {
},
}
// Validate policy
err := policy.validate()
require.NoError(t, err)
require.NoError(t, policy.validate())
// Test SSH policy compilation for node2 (owned by user2, who is in the group)
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice())
got, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
rule := sshPolicy.Rules[0]
want := &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user"},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
}}
// This was the failing assertion in integration test - sshUsers was empty
assert.NotEmpty(t, rule.SSHUsers, "SSH users should not be empty")
assert.Contains(t, rule.SSHUsers, "ssh-it-user", "ssh-it-user should be present in SSH users")
assert.Equal(t, "ssh-it-user", rule.SSHUsers["ssh-it-user"], "ssh-it-user should map to itself")
// Verify that ssh-it-user is correctly mapped
expectedUsers := map[string]string{
"ssh-it-user": "ssh-it-user",
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff)
}
assert.Equal(t, expectedUsers, rule.SSHUsers, "ssh-it-user should be mapped to itself")
}
// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized
@@ -885,40 +1252,38 @@ func TestSSHJSONSerialization(t *testing.T) {
},
}
err := policy.validate()
require.NoError(t, policy.validate())
got, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice())
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
want := &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.1"}},
SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user", "ubuntu": "ubuntu", "admin": "admin"},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
}}
// Serialize to JSON to verify structure
jsonData, err := json.MarshalIndent(sshPolicy, "", " ")
require.NoError(t, err)
// Parse back to verify structure
var parsed tailcfg.SSHPolicy
err = json.Unmarshal(jsonData, &parsed)
require.NoError(t, err)
// Verify the parsed structure has the expected SSH users
require.Len(t, parsed.Rules, 1)
rule := parsed.Rules[0]
expectedUsers := map[string]string{
"ssh-it-user": "ssh-it-user",
"ubuntu": "ubuntu",
"admin": "admin",
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff)
}
assert.Equal(t, expectedUsers, rule.SSHUsers, "SSH users should survive JSON round-trip")
// Verify JSON contains the SSH users (not empty)
assert.Contains(t, string(jsonData), `"ssh-it-user"`)
assert.Contains(t, string(jsonData), `"ubuntu"`)
assert.Contains(t, string(jsonData), `"admin"`)
assert.NotContains(t, string(jsonData), `"sshUsers": {}`, "SSH users should not be empty")
assert.NotContains(t, string(jsonData), `"sshUsers": null`, "SSH users should not be null")
// Verify JSON round-trip preserves the full structure
jsonData, err := json.MarshalIndent(got, "", " ")
require.NoError(t, err)
var parsed tailcfg.SSHPolicy
require.NoError(t, json.Unmarshal(jsonData, &parsed))
if diff := cmp.Diff(want, &parsed); diff != "" {
t.Errorf("JSON round-trip mismatch (-want +got):\n%s", diff)
}
}
func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
@@ -2244,6 +2609,89 @@ func TestCompileSSHPolicy_CheckPeriodVariants(t *testing.T) {
}
}
func TestIPSetToPrincipals(t *testing.T) {
tests := []struct {
name string
ips []string // IPs to add to the set
want []*tailcfg.SSHPrincipal
}{
{
name: "nil input",
ips: nil,
want: nil,
},
{
name: "single IPv4",
ips: []string{"100.64.0.1"},
want: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.1"}},
},
{
name: "multiple IPs",
ips: []string{"100.64.0.1", "100.64.0.2"},
want: []*tailcfg.SSHPrincipal{
{NodeIP: "100.64.0.1"},
{NodeIP: "100.64.0.2"},
},
},
{
name: "IPv6",
ips: []string{"fd7a:115c:a1e0::1"},
want: []*tailcfg.SSHPrincipal{{NodeIP: "fd7a:115c:a1e0::1"}},
},
{
name: "mixed IPv4 and IPv6",
ips: []string{"100.64.0.1", "fd7a:115c:a1e0::1"},
want: []*tailcfg.SSHPrincipal{
{NodeIP: "100.64.0.1"},
{NodeIP: "fd7a:115c:a1e0::1"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ipSet *netipx.IPSet
if tt.ips != nil {
var builder netipx.IPSetBuilder
for _, ip := range tt.ips {
addr := netip.MustParseAddr(ip)
builder.Add(addr)
}
var err error
ipSet, err = builder.IPSet()
require.NoError(t, err)
}
got := ipSetToPrincipals(ipSet)
// Sort for deterministic comparison
sortPrincipals := func(p []*tailcfg.SSHPrincipal) {
slices.SortFunc(p, func(a, b *tailcfg.SSHPrincipal) int {
if a.NodeIP < b.NodeIP {
return -1
}
if a.NodeIP > b.NodeIP {
return 1
}
return 0
})
}
sortPrincipals(got)
sortPrincipals(tt.want)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("ipSetToPrincipals() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestSSHCheckParams(t *testing.T) {
users := types.Users{
{Name: "user1", Model: gorm.Model{ID: 1}},
@@ -2424,3 +2872,219 @@ func TestSSHCheckParams(t *testing.T) {
})
}
}
func TestResolveLocalparts(t *testing.T) {
tests := []struct {
name string
entries []SSHUser
users types.Users
want map[uint]string
}{
{
name: "no entries",
entries: nil,
users: types.Users{{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}},
want: nil,
},
{
name: "single match",
entries: []SSHUser{"localpart:*@example.com"},
users: types.Users{
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
},
want: map[uint]string{1: "alice"},
},
{
name: "domain mismatch",
entries: []SSHUser{"localpart:*@other.com"},
users: types.Users{
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
},
want: map[uint]string{},
},
{
name: "case insensitive domain",
entries: []SSHUser{"localpart:*@EXAMPLE.COM"},
users: types.Users{
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
},
want: map[uint]string{1: "alice"},
},
{
name: "user without email skipped",
entries: []SSHUser{"localpart:*@example.com"},
users: types.Users{
{Name: "cli-user", Model: gorm.Model{ID: 1}},
},
want: map[uint]string{},
},
{
name: "multiple domains multiple users",
entries: []SSHUser{
"localpart:*@example.com",
"localpart:*@other.com",
},
users: types.Users{
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
{Name: "bob", Email: "bob@other.com", Model: gorm.Model{ID: 2}},
{Name: "charlie", Email: "charlie@nope.com", Model: gorm.Model{ID: 3}},
},
want: map[uint]string{1: "alice", 2: "bob"},
},
{
name: "special chars in local part",
entries: []SSHUser{"localpart:*@example.com"},
users: types.Users{
{Name: "d", Email: "dave+ssh@example.com", Model: gorm.Model{ID: 1}},
},
want: map[uint]string{1: "dave+ssh"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := resolveLocalparts(tt.entries, tt.users)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("resolveLocalparts() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestGroupSourcesByUser(t *testing.T) {
alice := types.User{
Name: "alice", Email: "alice@example.com",
Model: gorm.Model{ID: 1},
}
bob := types.User{
Name: "bob", Email: "bob@example.com",
Model: gorm.Model{ID: 2},
}
nodeAlice := types.Node{
Hostname: "alice-dev",
IPv4: createAddr("100.64.0.1"),
UserID: &alice.ID,
User: &alice,
}
nodeBob := types.Node{
Hostname: "bob-dev",
IPv4: createAddr("100.64.0.2"),
UserID: &bob.ID,
User: &bob,
}
nodeTagged := types.Node{
Hostname: "tagged",
IPv4: createAddr("100.64.0.3"),
UserID: &alice.ID,
User: &alice,
Tags: []string{"tag:server"},
}
// Build an IPSet that includes all node IPs
allIPs := func() *netipx.IPSet {
var b netipx.IPSetBuilder
b.AddPrefix(netip.MustParsePrefix("100.64.0.0/24"))
s, _ := b.IPSet()
return s
}()
tests := []struct {
name string
nodes types.Nodes
srcIPs *netipx.IPSet
wantUIDs []uint
wantUserCount int
wantHasTagged bool
wantTaggedLen int
wantAliceIP string
wantBobIP string
wantTaggedIP string
}{
{
name: "user-owned only",
nodes: types.Nodes{&nodeAlice, &nodeBob},
srcIPs: allIPs,
wantUIDs: []uint{1, 2},
wantUserCount: 2,
wantAliceIP: "100.64.0.1",
wantBobIP: "100.64.0.2",
},
{
name: "mixed user and tagged",
nodes: types.Nodes{&nodeAlice, &nodeTagged},
srcIPs: allIPs,
wantUIDs: []uint{1},
wantUserCount: 1,
wantHasTagged: true,
wantTaggedLen: 1,
wantAliceIP: "100.64.0.1",
wantTaggedIP: "100.64.0.3",
},
{
name: "tagged only",
nodes: types.Nodes{&nodeTagged},
srcIPs: allIPs,
wantUIDs: nil,
wantUserCount: 0,
wantHasTagged: true,
wantTaggedLen: 1,
},
{
name: "node not in srcIPs excluded",
nodes: types.Nodes{&nodeAlice, &nodeBob},
srcIPs: func() *netipx.IPSet {
var b netipx.IPSetBuilder
b.Add(netip.MustParseAddr("100.64.0.1")) // only alice
s, _ := b.IPSet()
return s
}(),
wantUIDs: []uint{1},
wantUserCount: 1,
wantAliceIP: "100.64.0.1",
},
{
name: "sorted by user ID",
nodes: types.Nodes{&nodeBob, &nodeAlice}, // reverse order
srcIPs: allIPs,
wantUIDs: []uint{1, 2}, // still sorted
wantUserCount: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sortedUIDs, byUser, tagged := groupSourcesByUser(
tt.nodes.ViewSlice(), tt.srcIPs,
)
assert.Equal(t, tt.wantUIDs, sortedUIDs, "sortedUIDs")
assert.Len(t, byUser, tt.wantUserCount, "byUser count")
if tt.wantHasTagged {
assert.Len(t, tagged, tt.wantTaggedLen, "tagged count")
} else {
assert.Empty(t, tagged, "tagged should be empty")
}
if tt.wantAliceIP != "" {
require.Contains(t, byUser, uint(1))
assert.Equal(t, tt.wantAliceIP, byUser[1][0].NodeIP)
}
if tt.wantBobIP != "" {
require.Contains(t, byUser, uint(2))
assert.Equal(t, tt.wantBobIP, byUser[2][0].NodeIP)
}
if tt.wantTaggedIP != "" {
require.NotEmpty(t, tagged)
assert.Equal(t, tt.wantTaggedIP, tagged[0].NodeIP)
}
})
}
}

View File

@@ -47,6 +47,7 @@ var (
ErrSSHCheckPeriodBelowMin = errors.New("checkPeriod below minimum of 1 minute")
ErrSSHCheckPeriodAboveMax = errors.New("checkPeriod above maximum of 168 hours (1 week)")
ErrSSHCheckPeriodOnNonCheck = errors.New("checkPeriod is only valid with action \"check\"")
ErrInvalidLocalpart = errors.New("invalid localpart format, must be localpart:*@<domain>")
)
// SSH check period constants per Tailscale docs:
@@ -1965,6 +1966,14 @@ func (p *Policy) validate() error {
continue
}
}
if user.IsLocalpart() {
_, err := user.ParseLocalpart()
if err != nil {
errs = append(errs, err)
continue
}
}
}
for _, src := range ssh.Sources {
@@ -2191,6 +2200,7 @@ type SSH struct {
Destinations SSHDstAliases `json:"dst"`
Users SSHUsers `json:"users"`
CheckPeriod *SSHCheckPeriod `json:"checkPeriod,omitempty"`
AcceptEnv []string `json:"acceptEnv,omitempty"`
}
// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
@@ -2342,6 +2352,11 @@ type SSHDstAliases []Alias
type SSHUsers []SSHUser
// SSHUserLocalpartPrefix is the prefix for localpart SSH user entries.
// Format: localpart:*@<domain>
// See: https://tailscale.com/docs/features/tailscale-ssh#users
const SSHUserLocalpartPrefix = "localpart:"
func (u SSHUsers) ContainsRoot() bool {
return slices.Contains(u, "root")
}
@@ -2350,9 +2365,25 @@ func (u SSHUsers) ContainsNonRoot() bool {
return slices.Contains(u, SSHUser(AutoGroupNonRoot))
}
// ContainsLocalpart returns true if any entry has the localpart: prefix.
func (u SSHUsers) ContainsLocalpart() bool {
return slices.ContainsFunc(u, func(user SSHUser) bool {
return user.IsLocalpart()
})
}
// NormalUsers returns all SSH users that are not root, autogroup:nonroot,
// or localpart: entries.
func (u SSHUsers) NormalUsers() []SSHUser {
return slicesx.Filter(nil, u, func(user SSHUser) bool {
return user != "root" && user != SSHUser(AutoGroupNonRoot)
return user != "root" && user != SSHUser(AutoGroupNonRoot) && !user.IsLocalpart()
})
}
// LocalpartEntries returns only the localpart: prefixed entries.
func (u SSHUsers) LocalpartEntries() []SSHUser {
return slicesx.Filter(nil, u, func(user SSHUser) bool {
return user.IsLocalpart()
})
}
@@ -2362,6 +2393,41 @@ func (u SSHUser) String() string {
return string(u)
}
// IsLocalpart returns true if the SSHUser has the localpart: prefix.
func (u SSHUser) IsLocalpart() bool {
return strings.HasPrefix(string(u), SSHUserLocalpartPrefix)
}
// ParseLocalpart validates and extracts the domain from a localpart: entry.
// The expected format is localpart:*@<domain>.
// Returns the domain part or an error if the format is invalid.
func (u SSHUser) ParseLocalpart() (string, error) {
if !u.IsLocalpart() {
return "", fmt.Errorf("%w: missing prefix %q in %q", ErrInvalidLocalpart, SSHUserLocalpartPrefix, u)
}
pattern := strings.TrimPrefix(string(u), SSHUserLocalpartPrefix)
// Must be *@<domain>
atIdx := strings.LastIndex(pattern, "@")
if atIdx < 0 {
return "", fmt.Errorf("%w: missing @ in %q", ErrInvalidLocalpart, u)
}
localPart := pattern[:atIdx]
domain := pattern[atIdx+1:]
if localPart != "*" {
return "", fmt.Errorf("%w: local part must be *, got %q in %q", ErrInvalidLocalpart, localPart, u)
}
if domain == "" {
return "", fmt.Errorf("%w: empty domain in %q", ErrInvalidLocalpart, u)
}
return domain, nil
}
// MarshalJSON marshals the SSHUser to JSON.
func (u SSHUser) MarshalJSON() ([]byte, error) {
return json.Marshal(string(u))

View File

@@ -1794,6 +1794,33 @@ func TestUnmarshalPolicy(t *testing.T) {
},
},
},
{
name: "ssh-localpart-valid",
input: `
{
"tagOwners": {"tag:prod": ["admin@"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:prod"],
"users": ["localpart:*@example.com"]
}]
}
`,
want: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{up("admin@")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:prod")},
Users: []SSHUser{SSHUser("localpart:*@example.com")},
},
},
},
},
{
name: "2754-bracketed-ipv6-multiple-ports",
input: `
@@ -1825,6 +1852,33 @@ func TestUnmarshalPolicy(t *testing.T) {
},
},
},
{
name: "ssh-localpart-with-other-users",
input: `
{
"tagOwners": {"tag:prod": ["admin@"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:prod"],
"users": ["localpart:*@example.com", "root", "autogroup:nonroot"]
}]
}
`,
want: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{up("admin@")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:prod")},
Users: []SSHUser{SSHUser("localpart:*@example.com"), "root", SSHUser(AutoGroupNonRoot)},
},
},
},
},
{
name: "2754-bracketed-ipv6-wildcard-port",
input: `
@@ -1951,6 +2005,51 @@ func TestUnmarshalPolicy(t *testing.T) {
`,
wantErr: "square brackets are only valid around IPv6 addresses",
},
{
name: "ssh-localpart-invalid-no-at-sign",
input: `
{
"tagOwners": {"tag:prod": ["admin@"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:prod"],
"users": ["localpart:foo"]
}]
}
`,
wantErr: "invalid localpart format",
},
{
name: "ssh-localpart-invalid-non-wildcard",
input: `
{
"tagOwners": {"tag:prod": ["admin@"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:prod"],
"users": ["localpart:alice@example.com"]
}]
}
`,
wantErr: "invalid localpart format",
},
{
name: "ssh-localpart-invalid-empty-domain",
input: `
{
"tagOwners": {"tag:prod": ["admin@"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:prod"],
"users": ["localpart:*@"]
}]
}
`,
wantErr: "invalid localpart format",
},
}
cmps := append(util.Comparers,
@@ -2635,56 +2734,63 @@ func TestResolveAutoApprovers(t *testing.T) {
func TestSSHUsers_NormalUsers(t *testing.T) {
tests := []struct {
name string
users SSHUsers
expected []SSHUser
name string
users SSHUsers
want []SSHUser
}{
{
name: "empty users",
users: SSHUsers{},
expected: []SSHUser{},
name: "empty users",
users: SSHUsers{},
want: nil,
},
{
name: "only root",
users: SSHUsers{"root"},
expected: []SSHUser{},
name: "only root",
users: SSHUsers{"root"},
want: nil,
},
{
name: "only autogroup:nonroot",
users: SSHUsers{SSHUser(AutoGroupNonRoot)},
expected: []SSHUser{},
name: "only autogroup:nonroot",
users: SSHUsers{SSHUser(AutoGroupNonRoot)},
want: nil,
},
{
name: "only normal user",
users: SSHUsers{"ssh-it-user"},
expected: []SSHUser{"ssh-it-user"},
name: "only normal user",
users: SSHUsers{"ssh-it-user"},
want: []SSHUser{"ssh-it-user"},
},
{
name: "multiple normal users",
users: SSHUsers{"ubuntu", "admin", "user1"},
expected: []SSHUser{"ubuntu", "admin", "user1"},
name: "multiple normal users",
users: SSHUsers{"ubuntu", "admin", "user1"},
want: []SSHUser{"ubuntu", "admin", "user1"},
},
{
name: "mixed users with root",
users: SSHUsers{"ubuntu", "root", "admin"},
expected: []SSHUser{"ubuntu", "admin"},
name: "mixed users with root",
users: SSHUsers{"ubuntu", "root", "admin"},
want: []SSHUser{"ubuntu", "admin"},
},
{
name: "mixed users with autogroup:nonroot",
users: SSHUsers{"ubuntu", SSHUser(AutoGroupNonRoot), "admin"},
expected: []SSHUser{"ubuntu", "admin"},
name: "mixed users with autogroup:nonroot",
users: SSHUsers{"ubuntu", SSHUser(AutoGroupNonRoot), "admin"},
want: []SSHUser{"ubuntu", "admin"},
},
{
name: "mixed users with both root and autogroup:nonroot",
users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), "admin"},
expected: []SSHUser{"ubuntu", "admin"},
name: "mixed users with both root and autogroup:nonroot",
users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), "admin"},
want: []SSHUser{"ubuntu", "admin"},
},
{
name: "excludes localpart entries",
users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), SSHUser("localpart:*@example.com"), "admin"},
want: []SSHUser{"ubuntu", "admin"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.users.NormalUsers()
assert.ElementsMatch(t, tt.expected, result, "NormalUsers() should return expected normal users")
got := tt.users.NormalUsers()
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("NormalUsers() unexpected result (-want +got):\n%s", diff)
}
})
}
}
@@ -2761,6 +2867,142 @@ func TestSSHUsers_ContainsNonRoot(t *testing.T) {
}
}
func TestSSHUsers_ContainsLocalpart(t *testing.T) {
tests := []struct {
name string
users SSHUsers
expected bool
}{
{
name: "empty users",
users: SSHUsers{},
expected: false,
},
{
name: "contains localpart",
users: SSHUsers{SSHUser("localpart:*@example.com")},
expected: true,
},
{
name: "does not contain localpart",
users: SSHUsers{"ubuntu", "admin", "root"},
expected: false,
},
{
name: "contains localpart among others",
users: SSHUsers{"ubuntu", SSHUser("localpart:*@example.com"), "admin"},
expected: true,
},
{
name: "multiple localpart entries",
users: SSHUsers{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.users.ContainsLocalpart()
assert.Equal(t, tt.expected, result, "ContainsLocalpart() should return expected result")
})
}
}
func TestSSHUsers_LocalpartEntries(t *testing.T) {
tests := []struct {
name string
users SSHUsers
want []SSHUser
}{
{
name: "empty users",
users: SSHUsers{},
want: nil,
},
{
name: "no localpart entries",
users: SSHUsers{"root", "ubuntu", SSHUser(AutoGroupNonRoot)},
want: nil,
},
{
name: "single localpart entry",
users: SSHUsers{"root", SSHUser("localpart:*@example.com"), "ubuntu"},
want: []SSHUser{SSHUser("localpart:*@example.com")},
},
{
name: "multiple localpart entries",
users: SSHUsers{SSHUser("localpart:*@a.com"), "root", SSHUser("localpart:*@b.com")},
want: []SSHUser{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.users.LocalpartEntries()
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("LocalpartEntries() unexpected result (-want +got):\n%s", diff)
}
})
}
}
func TestSSHUser_ParseLocalpart(t *testing.T) {
tests := []struct {
name string
user SSHUser
expectedDomain string
expectErr bool
}{
{
name: "valid localpart",
user: SSHUser("localpart:*@example.com"),
expectedDomain: "example.com",
},
{
name: "valid localpart with subdomain",
user: SSHUser("localpart:*@corp.example.com"),
expectedDomain: "corp.example.com",
},
{
name: "missing prefix",
user: SSHUser("ubuntu"),
expectErr: true,
},
{
name: "missing @ sign",
user: SSHUser("localpart:foo"),
expectErr: true,
},
{
name: "non-wildcard local part",
user: SSHUser("localpart:alice@example.com"),
expectErr: true,
},
{
name: "empty domain",
user: SSHUser("localpart:*@"),
expectErr: true,
},
{
name: "just prefix",
user: SSHUser("localpart:"),
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
domain, err := tt.user.ParseLocalpart()
if tt.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expectedDomain, domain)
}
})
}
}
func mustIPSet(prefixes ...string) *netipx.IPSet {
var builder netipx.IPSetBuilder
for _, p := range prefixes {