mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-12 05:05:15 +09:00
policy/v2: add localpart:*@domain SSH user compilation
Add support for localpart:*@<domain> entries in SSH policy users. When a user SSHes into a target, their email local-part becomes the OS username (e.g. alice@example.com → OS user alice). Type system (types.go): - SSHUser.IsLocalpart() and ParseLocalpart() for validation - SSHUsers.LocalpartEntries(), NormalUsers(), ContainsLocalpart() - Enforces format: localpart:*@<domain> (wildcard-only) - UserWildcard.Resolve for user:*@domain SSH source aliases - acceptEnv passthrough for SSH rules Compilation (filter.go): - resolveLocalparts: pure function mapping users to local-parts by email domain. No node walking, easy to test. - groupSourcesByUser: single walk producing per-user principals with sorted user IDs, and tagged principals separately. - ipSetToPrincipals: shared helper replacing 6 inline copies. - selfPrincipalsForNode: self-access using pre-computed byUser. The approach separates data gathering from rule assembly. Localpart rules are interleaved per source user to match Tailscale SaaS first-match-wins ordering. Updates #3049
This commit is contained in:
@@ -1077,6 +1077,8 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
{Name: "user1", Model: gorm.Model{ID: 1}},
|
||||
{Name: "user2", Model: gorm.Model{ID: 2}},
|
||||
{Name: "user3", Model: gorm.Model{ID: 3}},
|
||||
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 4}},
|
||||
{Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 5}},
|
||||
}
|
||||
|
||||
// Create standard node setups used across tests
|
||||
@@ -1110,6 +1112,20 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
|
||||
// Nodes for localpart tests (users with email addresses)
|
||||
nodeAlice := types.Node{
|
||||
Hostname: "alice-device",
|
||||
IPv4: ap("100.64.0.6"),
|
||||
UserID: new(uint(4)),
|
||||
User: new(users[3]),
|
||||
}
|
||||
nodeBob := types.Node{
|
||||
Hostname: "bob-device",
|
||||
IPv4: ap("100.64.0.7"),
|
||||
UserID: new(uint(5)),
|
||||
User: new(users[4]),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
targetNode types.Node
|
||||
@@ -1446,6 +1462,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
},
|
||||
SSHUsers: map[string]string{
|
||||
"debian": "debian",
|
||||
"root": "",
|
||||
},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
@@ -1456,6 +1473,108 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "localpart-maps-email-to-os-user",
|
||||
targetNode: nodeTaggedServer,
|
||||
peers: types.Nodes{&nodeAlice, &nodeBob},
|
||||
policy: `{
|
||||
"tagOwners": {
|
||||
"tag:server": ["alice@example.com"]
|
||||
},
|
||||
"ssh": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:server"],
|
||||
"users": ["localpart:*@example.com"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
// Per-user common+localpart interleaved: each user gets root deny then localpart.
|
||||
wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}},
|
||||
SSHUsers: map[string]string{"alice": "alice"},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.7"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.7"}},
|
||||
SSHUsers: map[string]string{"bob": "bob"},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "localpart-combined-with-root",
|
||||
targetNode: nodeTaggedServer,
|
||||
peers: types.Nodes{&nodeAlice},
|
||||
policy: `{
|
||||
"tagOwners": {
|
||||
"tag:server": ["alice@example.com"]
|
||||
},
|
||||
"ssh": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:server"],
|
||||
"users": ["localpart:*@example.com", "root"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
// Common root rule followed by alice's per-user localpart rule (interleaved).
|
||||
wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}},
|
||||
SSHUsers: map[string]string{"root": "root"},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.6"}},
|
||||
SSHUsers: map[string]string{"alice": "alice"},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -362,7 +362,6 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocyclo // complex SSH policy compilation logic
|
||||
func (pol *Policy) compileSSHPolicy(
|
||||
baseURL string,
|
||||
users types.Users,
|
||||
@@ -378,13 +377,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
for index, rule := range pol.SSHs {
|
||||
// Separate destinations into autogroup:self and others
|
||||
// This is needed because autogroup:self requires filtering sources to same-user only,
|
||||
// while other destinations should use all resolved sources
|
||||
var (
|
||||
autogroupSelfDests []Alias
|
||||
otherDests []Alias
|
||||
)
|
||||
var autogroupSelfDests, otherDests []Alias
|
||||
|
||||
for _, dst := range rule.Destinations {
|
||||
if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
@@ -394,12 +387,11 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Tagged nodes can't match autogroup:self destinations, but can still match other destinations
|
||||
|
||||
// Resolve sources once - we'll use them differently for each destination type
|
||||
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("ssh policy compilation failed resolving source ips for rule %+v", rule)
|
||||
log.Trace().Caller().Err(err).Msgf(
|
||||
"ssh policy compilation failed resolving source ips for rule %+v", rule,
|
||||
)
|
||||
}
|
||||
|
||||
if srcIPs == nil || len(srcIPs.Prefixes()) == 0 {
|
||||
@@ -414,95 +406,87 @@ func (pol *Policy) compileSSHPolicy(
|
||||
case SSHActionCheck:
|
||||
action = sshCheck(baseURL, checkPeriodFromRule(rule))
|
||||
default:
|
||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
|
||||
return nil, fmt.Errorf(
|
||||
"parsing SSH policy, unknown action %q, index: %d: %w",
|
||||
rule.Action, index, err,
|
||||
)
|
||||
}
|
||||
|
||||
userMap := make(map[string]string, len(rule.Users))
|
||||
acceptEnv := rule.AcceptEnv
|
||||
|
||||
// Build the common userMap (always has at least a root entry).
|
||||
const rootUser = "root"
|
||||
|
||||
baseUserMap := make(map[string]string, len(rule.Users))
|
||||
if rule.Users.ContainsNonRoot() {
|
||||
userMap["*"] = "="
|
||||
// by default, we do not allow root unless explicitly stated
|
||||
userMap["root"] = ""
|
||||
baseUserMap["*"] = "="
|
||||
}
|
||||
|
||||
if rule.Users.ContainsRoot() {
|
||||
userMap["root"] = "root"
|
||||
baseUserMap[rootUser] = rootUser
|
||||
} else {
|
||||
baseUserMap[rootUser] = ""
|
||||
}
|
||||
|
||||
for _, u := range rule.Users.NormalUsers() {
|
||||
userMap[u.String()] = u.String()
|
||||
baseUserMap[u.String()] = u.String()
|
||||
}
|
||||
|
||||
// Handle autogroup:self destinations (if any)
|
||||
// Note: Tagged nodes can't match autogroup:self, so skip this block for tagged nodes
|
||||
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
|
||||
// Build destination set for autogroup:self (same-user untagged devices only)
|
||||
var dest netipx.IPSetBuilder
|
||||
hasLocalpart := rule.Users.ContainsLocalpart()
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.IsTagged() && n.User().ID() == node.User().ID() {
|
||||
n.AppendToIPSet(&dest)
|
||||
}
|
||||
}
|
||||
var localpartByUser map[uint]string
|
||||
if hasLocalpart {
|
||||
localpartByUser = resolveLocalparts(
|
||||
rule.Users.LocalpartEntries(), users,
|
||||
)
|
||||
}
|
||||
|
||||
destSet, err := dest.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs, principalsByUser, taggedPrincipals := groupSourcesByUser(
|
||||
nodes, srcIPs,
|
||||
)
|
||||
|
||||
// Only create rule if this node is in the destination set
|
||||
if node.InIPSet(destSet) {
|
||||
// Filter sources to only same-user untagged devices
|
||||
// Pre-filter to same-user untagged devices for efficiency
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
// appendRules emits a common rule and, if the user has a
|
||||
// localpart match, a per-user localpart rule.
|
||||
appendRules := func(principals []*tailcfg.SSHPrincipal, uid uint, hasUID bool) {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: baseUserMap,
|
||||
Action: &action,
|
||||
AcceptEnv: acceptEnv,
|
||||
})
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.IsTagged() && n.User().ID() == node.User().ID() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
}
|
||||
}
|
||||
|
||||
var filteredSrcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
|
||||
n.AppendToIPSet(&filteredSrcIPs) // Found this node, move to next
|
||||
}
|
||||
}
|
||||
|
||||
filteredSrcSet, err := filteredSrcIPs.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if filteredSrcSet != nil && len(filteredSrcSet.Prefixes()) > 0 {
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
for addr := range util.IPSetAddrIter(filteredSrcSet) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
}
|
||||
|
||||
if len(principals) > 0 {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: userMap,
|
||||
Action: &action,
|
||||
})
|
||||
}
|
||||
if hasUID {
|
||||
if lp, ok := localpartByUser[uid]; ok {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: map[string]string{lp: lp},
|
||||
Action: &action,
|
||||
AcceptEnv: acceptEnv,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle other destinations (if any)
|
||||
// Handle autogroup:self destinations.
|
||||
// Tagged nodes can't match autogroup:self.
|
||||
if len(autogroupSelfDests) > 0 &&
|
||||
!node.IsTagged() && node.User().Valid() {
|
||||
uid := node.User().ID()
|
||||
|
||||
if principals := principalsByUser[uid]; len(principals) > 0 {
|
||||
appendRules(principals, uid, true)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle other destinations.
|
||||
if len(otherDests) > 0 {
|
||||
// Build destination set for other destinations
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, dst := range otherDests {
|
||||
ips, err := dst.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
|
||||
log.Trace().Caller().Err(err).
|
||||
Msgf("resolving destination ips")
|
||||
}
|
||||
|
||||
if ips != nil {
|
||||
@@ -515,22 +499,48 @@ func (pol *Policy) compileSSHPolicy(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only create rule if this node is in the destination set
|
||||
if node.InIPSet(destSet) {
|
||||
// For non-autogroup:self destinations, use all resolved sources (no filtering)
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
}
|
||||
// Node is a destination — emit rules.
|
||||
// When localpart entries exist, interleave common
|
||||
// and localpart rules per source user to match
|
||||
// Tailscale SaaS first-match-wins ordering.
|
||||
if hasLocalpart {
|
||||
for _, uid := range userIDs {
|
||||
appendRules(principalsByUser[uid], uid, true)
|
||||
}
|
||||
|
||||
if len(principals) > 0 {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: userMap,
|
||||
Action: &action,
|
||||
})
|
||||
if len(taggedPrincipals) > 0 {
|
||||
appendRules(taggedPrincipals, 0, false)
|
||||
}
|
||||
} else {
|
||||
if principals := ipSetToPrincipals(srcIPs); len(principals) > 0 {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: baseUserMap,
|
||||
Action: &action,
|
||||
AcceptEnv: acceptEnv,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if hasLocalpart && node.InIPSet(srcIPs) {
|
||||
// Self-access: source node not in destination set
|
||||
// receives rules scoped to its own user.
|
||||
if node.IsTagged() {
|
||||
var builder netipx.IPSetBuilder
|
||||
|
||||
node.AppendToIPSet(&builder)
|
||||
|
||||
ipSet, err := builder.IPSet()
|
||||
if err == nil && ipSet != nil {
|
||||
if principals := ipSetToPrincipals(ipSet); len(principals) > 0 {
|
||||
appendRules(principals, 0, false)
|
||||
}
|
||||
}
|
||||
} else if node.User().Valid() {
|
||||
uid := node.User().ID()
|
||||
if principals := principalsByUser[uid]; len(principals) > 0 {
|
||||
appendRules(principals, uid, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -558,6 +568,137 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ipSetToPrincipals converts an IPSet into SSH principals, one per address.
|
||||
func ipSetToPrincipals(ipSet *netipx.IPSet) []*tailcfg.SSHPrincipal {
|
||||
if ipSet == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
|
||||
for addr := range util.IPSetAddrIter(ipSet) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return principals
|
||||
}
|
||||
|
||||
// resolveLocalparts maps each user whose email matches a localpart:*@<domain>
|
||||
// entry to their email local-part. Returns userID → localPart (e.g. {1: "alice"}).
|
||||
// This is a pure data function — no node walking or IP resolution.
|
||||
func resolveLocalparts(
|
||||
entries []SSHUser,
|
||||
users types.Users,
|
||||
) map[uint]string {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[uint]string)
|
||||
|
||||
for _, entry := range entries {
|
||||
domain, err := entry.ParseLocalpart()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msgf(
|
||||
"skipping invalid localpart entry %q during SSH compilation",
|
||||
entry,
|
||||
)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if user.Email == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
atIdx := strings.LastIndex(user.Email, "@")
|
||||
if atIdx < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.EqualFold(user.Email[atIdx+1:], domain) {
|
||||
continue
|
||||
}
|
||||
|
||||
result[user.ID] = user.Email[:atIdx]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// groupSourcesByUser groups source node IPs by user ownership. Returns sorted
|
||||
// user IDs for deterministic iteration, per-user principals, and tagged principals.
|
||||
// Only includes nodes whose IPs are in the srcIPs set.
|
||||
func groupSourcesByUser(
|
||||
nodes views.Slice[types.NodeView],
|
||||
srcIPs *netipx.IPSet,
|
||||
) ([]uint, map[uint][]*tailcfg.SSHPrincipal, []*tailcfg.SSHPrincipal) {
|
||||
userIPSets := make(map[uint]*netipx.IPSetBuilder)
|
||||
|
||||
var taggedIPSet netipx.IPSetBuilder
|
||||
|
||||
hasTagged := false
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
|
||||
continue
|
||||
}
|
||||
|
||||
if n.IsTagged() {
|
||||
n.AppendToIPSet(&taggedIPSet)
|
||||
|
||||
hasTagged = true
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !n.User().Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
uid := n.User().ID()
|
||||
|
||||
if _, ok := userIPSets[uid]; !ok {
|
||||
userIPSets[uid] = &netipx.IPSetBuilder{}
|
||||
}
|
||||
|
||||
n.AppendToIPSet(userIPSets[uid])
|
||||
}
|
||||
|
||||
var userIDs []uint
|
||||
|
||||
principalsByUser := make(map[uint][]*tailcfg.SSHPrincipal, len(userIPSets))
|
||||
|
||||
for uid, builder := range userIPSets {
|
||||
ipSet, err := builder.IPSet()
|
||||
if err != nil || ipSet == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if principals := ipSetToPrincipals(ipSet); len(principals) > 0 {
|
||||
principalsByUser[uid] = principals
|
||||
userIDs = append(userIDs, uid)
|
||||
}
|
||||
}
|
||||
|
||||
slices.Sort(userIDs)
|
||||
|
||||
var tagged []*tailcfg.SSHPrincipal
|
||||
|
||||
if hasTagged {
|
||||
taggedSet, err := taggedIPSet.IPSet()
|
||||
if err == nil && taggedSet != nil {
|
||||
tagged = ipSetToPrincipals(taggedSet)
|
||||
}
|
||||
}
|
||||
|
||||
return userIDs, principalsByUser, tagged
|
||||
}
|
||||
|
||||
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||
var out []string
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go4.org/netipx"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
@@ -431,12 +432,19 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
|
||||
nodes := types.Nodes{&nodeTaggedServer, &nodeTaggedDB, &nodeUser2Untagged}
|
||||
|
||||
acceptAction := &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
}
|
||||
user2Principal := []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
targetNode types.Node
|
||||
policy *Policy
|
||||
wantSSHUsers map[string]string
|
||||
wantEmpty bool
|
||||
name string
|
||||
targetNode types.Node
|
||||
policy *Policy
|
||||
want *tailcfg.SSHPolicy
|
||||
}{
|
||||
{
|
||||
name: "specific user mapping",
|
||||
@@ -457,9 +465,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantSSHUsers: map[string]string{
|
||||
"ssh-it-user": "ssh-it-user",
|
||||
},
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: user2Principal,
|
||||
SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "multiple specific users",
|
||||
@@ -480,11 +492,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantSSHUsers: map[string]string{
|
||||
"ubuntu": "ubuntu",
|
||||
"admin": "admin",
|
||||
"deploy": "deploy",
|
||||
},
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: user2Principal,
|
||||
SSHUsers: map[string]string{"root": "", "ubuntu": "ubuntu", "admin": "admin", "deploy": "deploy"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "autogroup:nonroot only",
|
||||
@@ -505,10 +519,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantSSHUsers: map[string]string{
|
||||
"*": "=",
|
||||
"root": "",
|
||||
},
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: user2Principal,
|
||||
SSHUsers: map[string]string{"*": "=", "root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "root only",
|
||||
@@ -529,9 +546,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantSSHUsers: map[string]string{
|
||||
"root": "root",
|
||||
},
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: user2Principal,
|
||||
SSHUsers: map[string]string{"root": "root"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "autogroup:nonroot plus root",
|
||||
@@ -552,10 +573,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantSSHUsers: map[string]string{
|
||||
"*": "=",
|
||||
"root": "root",
|
||||
},
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: user2Principal,
|
||||
SSHUsers: map[string]string{"*": "=", "root": "root"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "mixed specific users and autogroups",
|
||||
@@ -576,12 +600,13 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantSSHUsers: map[string]string{
|
||||
"*": "=",
|
||||
"root": "root",
|
||||
"ubuntu": "ubuntu",
|
||||
"admin": "admin",
|
||||
},
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: user2Principal,
|
||||
SSHUsers: map[string]string{"*": "=", "root": "root", "ubuntu": "ubuntu", "admin": "admin"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "no matching destination",
|
||||
@@ -603,45 +628,387 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: true,
|
||||
want: &tailcfg.SSHPolicy{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate the policy
|
||||
err := tt.policy.validate()
|
||||
require.NoError(t, tt.policy.validate())
|
||||
|
||||
got, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compile SSH policy
|
||||
sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantEmpty {
|
||||
if sshPolicy == nil {
|
||||
return // Expected empty result
|
||||
func TestCompileSSHPolicy_LocalpartMapping(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
|
||||
{Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}},
|
||||
{Name: "charlie", Email: "charlie@other.com", Model: gorm.Model{ID: 3}},
|
||||
{Name: "dave", Model: gorm.Model{ID: 4}}, // CLI user, no email
|
||||
}
|
||||
|
||||
nodeTaggedServer := types.Node{
|
||||
Hostname: "tagged-server",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
nodeAlice := types.Node{
|
||||
Hostname: "alice-device",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
}
|
||||
nodeBob := types.Node{
|
||||
Hostname: "bob-device",
|
||||
IPv4: createAddr("100.64.0.3"),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
nodeCharlie := types.Node{
|
||||
Hostname: "charlie-device",
|
||||
IPv4: createAddr("100.64.0.4"),
|
||||
UserID: new(users[2].ID),
|
||||
User: new(users[2]),
|
||||
}
|
||||
nodeDave := types.Node{
|
||||
Hostname: "dave-device",
|
||||
IPv4: createAddr("100.64.0.5"),
|
||||
UserID: new(users[3].ID),
|
||||
User: new(users[3]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{&nodeTaggedServer, &nodeAlice, &nodeBob, &nodeCharlie, &nodeDave}
|
||||
|
||||
acceptAction := &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
users types.Users // nil → use default users
|
||||
nodes types.Nodes // nil → use default nodes
|
||||
targetNode types.Node
|
||||
policy *Policy
|
||||
want *tailcfg.SSHPolicy
|
||||
}{
|
||||
{
|
||||
name: "localpart only",
|
||||
targetNode: nodeTaggedServer,
|
||||
policy: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("alice@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Per-user common+localpart rules interleaved, then non-matching users.
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
|
||||
SSHUsers: map[string]string{"alice": "alice"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
|
||||
SSHUsers: map[string]string{"bob": "bob"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "localpart with root",
|
||||
targetNode: nodeTaggedServer,
|
||||
policy: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("alice@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com"), "root"},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Per-user common+localpart rules interleaved, then non-matching users.
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
|
||||
SSHUsers: map[string]string{"root": "root"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
|
||||
SSHUsers: map[string]string{"alice": "alice"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
|
||||
SSHUsers: map[string]string{"root": "root"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
|
||||
SSHUsers: map[string]string{"bob": "bob"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
|
||||
SSHUsers: map[string]string{"root": "root"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
|
||||
SSHUsers: map[string]string{"root": "root"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "localpart no matching users in domain",
|
||||
targetNode: nodeTaggedServer,
|
||||
policy: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("alice@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@nonexistent.com")},
|
||||
},
|
||||
},
|
||||
},
|
||||
// No localpart matches, but per-user common rules still emitted (root deny)
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "localpart with special chars in email",
|
||||
users: types.Users{
|
||||
{Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}},
|
||||
},
|
||||
nodes: func() types.Nodes {
|
||||
specialUser := types.User{Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}}
|
||||
n := types.Node{
|
||||
Hostname: "special-device",
|
||||
IPv4: createAddr("100.64.0.10"),
|
||||
UserID: new(specialUser.ID),
|
||||
User: &specialUser,
|
||||
}
|
||||
|
||||
assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match")
|
||||
return types.Nodes{&nodeTaggedServer, &n}
|
||||
}(),
|
||||
targetNode: nodeTaggedServer,
|
||||
policy: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("dave+sshuser@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Per-user common rule (root deny), then separate localpart rule.
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.10"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.10"}},
|
||||
SSHUsers: map[string]string{"dave+sshuser": "dave+sshuser"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "localpart excludes CLI users without email",
|
||||
users: types.Users{
|
||||
{Name: "dave", Model: gorm.Model{ID: 4}},
|
||||
},
|
||||
nodes: func() types.Nodes {
|
||||
cliUser := types.User{Name: "dave", Model: gorm.Model{ID: 4}}
|
||||
n := types.Node{
|
||||
Hostname: "dave-cli-device",
|
||||
IPv4: createAddr("100.64.0.5"),
|
||||
UserID: new(cliUser.ID),
|
||||
User: &cliUser,
|
||||
}
|
||||
|
||||
return
|
||||
return types.Nodes{&nodeTaggedServer, &n}
|
||||
}(),
|
||||
targetNode: nodeTaggedServer,
|
||||
policy: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("dave@")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
},
|
||||
},
|
||||
// No localpart matches (CLI user, no email), but implicit root deny emits common rule
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "localpart with multiple domains",
|
||||
targetNode: nodeTaggedServer,
|
||||
policy: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("alice@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{
|
||||
SSHUser("localpart:*@example.com"),
|
||||
SSHUser("localpart:*@other.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Per-user common+localpart rules interleaved:
|
||||
// alice/bob match *@example.com, charlie matches *@other.com.
|
||||
want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
|
||||
SSHUsers: map[string]string{"alice": "alice"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.3"}},
|
||||
SSHUsers: map[string]string{"bob": "bob"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.4"}},
|
||||
SSHUsers: map[string]string{"charlie": "charlie"},
|
||||
Action: acceptAction,
|
||||
},
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.5"}},
|
||||
SSHUsers: map[string]string{"root": ""},
|
||||
Action: acceptAction,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testUsers := users
|
||||
if tt.users != nil {
|
||||
testUsers = tt.users
|
||||
}
|
||||
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1, "Should have exactly one SSH rule")
|
||||
testNodes := nodes
|
||||
if tt.nodes != nil {
|
||||
testNodes = tt.nodes
|
||||
}
|
||||
|
||||
rule := sshPolicy.Rules[0]
|
||||
assert.Equal(t, tt.wantSSHUsers, rule.SSHUsers, "SSH users mapping should match expected")
|
||||
require.NoError(t, tt.policy.validate())
|
||||
|
||||
// Verify principals are set correctly (should contain user2's untagged device IP since that's the source)
|
||||
require.Len(t, rule.Principals, 1)
|
||||
assert.Equal(t, "100.64.0.3", rule.Principals[0].NodeIP)
|
||||
got, err := tt.policy.compileSSHPolicy(
|
||||
"unused-server-url", testUsers, tt.targetNode.View(), testNodes.ViewSlice(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify action is set correctly
|
||||
assert.True(t, rule.Action.Accept)
|
||||
assert.True(t, rule.Action.AllowAgentForwarding)
|
||||
assert.True(t, rule.Action.AllowLocalPortForwarding)
|
||||
assert.True(t, rule.Action.AllowRemotePortForwarding)
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("compileSSHPolicy() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -687,8 +1054,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := policy.validate()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
@@ -700,6 +1066,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
||||
// Verify SSH users are correctly mapped
|
||||
expectedUsers := map[string]string{
|
||||
"ssh-it-user": "ssh-it-user",
|
||||
"root": "",
|
||||
}
|
||||
assert.Equal(t, expectedUsers, rule.SSHUsers)
|
||||
|
||||
@@ -833,28 +1200,28 @@ func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Validate policy
|
||||
err := policy.validate()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
// Test SSH policy compilation for node2 (owned by user2, who is in the group)
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice())
|
||||
got, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1)
|
||||
|
||||
rule := sshPolicy.Rules[0]
|
||||
want := &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.2"}},
|
||||
SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user"},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
// This was the failing assertion in integration test - sshUsers was empty
|
||||
assert.NotEmpty(t, rule.SSHUsers, "SSH users should not be empty")
|
||||
assert.Contains(t, rule.SSHUsers, "ssh-it-user", "ssh-it-user should be present in SSH users")
|
||||
assert.Equal(t, "ssh-it-user", rule.SSHUsers["ssh-it-user"], "ssh-it-user should map to itself")
|
||||
|
||||
// Verify that ssh-it-user is correctly mapped
|
||||
expectedUsers := map[string]string{
|
||||
"ssh-it-user": "ssh-it-user",
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
assert.Equal(t, expectedUsers, rule.SSHUsers, "ssh-it-user should be mapped to itself")
|
||||
}
|
||||
|
||||
// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized
|
||||
@@ -885,40 +1252,38 @@ func TestSSHJSONSerialization(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := policy.validate()
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
got, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
want := &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.1"}},
|
||||
SSHUsers: map[string]string{"root": "", "ssh-it-user": "ssh-it-user", "ubuntu": "ubuntu", "admin": "admin"},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
// Serialize to JSON to verify structure
|
||||
jsonData, err := json.MarshalIndent(sshPolicy, "", " ")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse back to verify structure
|
||||
var parsed tailcfg.SSHPolicy
|
||||
|
||||
err = json.Unmarshal(jsonData, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the parsed structure has the expected SSH users
|
||||
require.Len(t, parsed.Rules, 1)
|
||||
rule := parsed.Rules[0]
|
||||
|
||||
expectedUsers := map[string]string{
|
||||
"ssh-it-user": "ssh-it-user",
|
||||
"ubuntu": "ubuntu",
|
||||
"admin": "admin",
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("compileSSHPolicy() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
assert.Equal(t, expectedUsers, rule.SSHUsers, "SSH users should survive JSON round-trip")
|
||||
|
||||
// Verify JSON contains the SSH users (not empty)
|
||||
assert.Contains(t, string(jsonData), `"ssh-it-user"`)
|
||||
assert.Contains(t, string(jsonData), `"ubuntu"`)
|
||||
assert.Contains(t, string(jsonData), `"admin"`)
|
||||
assert.NotContains(t, string(jsonData), `"sshUsers": {}`, "SSH users should not be empty")
|
||||
assert.NotContains(t, string(jsonData), `"sshUsers": null`, "SSH users should not be null")
|
||||
// Verify JSON round-trip preserves the full structure
|
||||
jsonData, err := json.MarshalIndent(got, "", " ")
|
||||
require.NoError(t, err)
|
||||
|
||||
var parsed tailcfg.SSHPolicy
|
||||
require.NoError(t, json.Unmarshal(jsonData, &parsed))
|
||||
|
||||
if diff := cmp.Diff(want, &parsed); diff != "" {
|
||||
t.Errorf("JSON round-trip mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
@@ -2244,6 +2609,89 @@ func TestCompileSSHPolicy_CheckPeriodVariants(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPSetToPrincipals(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ips []string // IPs to add to the set
|
||||
want []*tailcfg.SSHPrincipal
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
ips: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single IPv4",
|
||||
ips: []string{"100.64.0.1"},
|
||||
want: []*tailcfg.SSHPrincipal{{NodeIP: "100.64.0.1"}},
|
||||
},
|
||||
{
|
||||
name: "multiple IPs",
|
||||
ips: []string{"100.64.0.1", "100.64.0.2"},
|
||||
want: []*tailcfg.SSHPrincipal{
|
||||
{NodeIP: "100.64.0.1"},
|
||||
{NodeIP: "100.64.0.2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "IPv6",
|
||||
ips: []string{"fd7a:115c:a1e0::1"},
|
||||
want: []*tailcfg.SSHPrincipal{{NodeIP: "fd7a:115c:a1e0::1"}},
|
||||
},
|
||||
{
|
||||
name: "mixed IPv4 and IPv6",
|
||||
ips: []string{"100.64.0.1", "fd7a:115c:a1e0::1"},
|
||||
want: []*tailcfg.SSHPrincipal{
|
||||
{NodeIP: "100.64.0.1"},
|
||||
{NodeIP: "fd7a:115c:a1e0::1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ipSet *netipx.IPSet
|
||||
|
||||
if tt.ips != nil {
|
||||
var builder netipx.IPSetBuilder
|
||||
|
||||
for _, ip := range tt.ips {
|
||||
addr := netip.MustParseAddr(ip)
|
||||
builder.Add(addr)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
ipSet, err = builder.IPSet()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
got := ipSetToPrincipals(ipSet)
|
||||
|
||||
// Sort for deterministic comparison
|
||||
sortPrincipals := func(p []*tailcfg.SSHPrincipal) {
|
||||
slices.SortFunc(p, func(a, b *tailcfg.SSHPrincipal) int {
|
||||
if a.NodeIP < b.NodeIP {
|
||||
return -1
|
||||
}
|
||||
|
||||
if a.NodeIP > b.NodeIP {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
})
|
||||
}
|
||||
sortPrincipals(got)
|
||||
sortPrincipals(tt.want)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("ipSetToPrincipals() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCheckParams(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "user1", Model: gorm.Model{ID: 1}},
|
||||
@@ -2424,3 +2872,219 @@ func TestSSHCheckParams(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLocalparts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
entries []SSHUser
|
||||
users types.Users
|
||||
want map[uint]string
|
||||
}{
|
||||
{
|
||||
name: "no entries",
|
||||
entries: nil,
|
||||
users: types.Users{{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single match",
|
||||
entries: []SSHUser{"localpart:*@example.com"},
|
||||
users: types.Users{
|
||||
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
|
||||
},
|
||||
want: map[uint]string{1: "alice"},
|
||||
},
|
||||
{
|
||||
name: "domain mismatch",
|
||||
entries: []SSHUser{"localpart:*@other.com"},
|
||||
users: types.Users{
|
||||
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
|
||||
},
|
||||
want: map[uint]string{},
|
||||
},
|
||||
{
|
||||
name: "case insensitive domain",
|
||||
entries: []SSHUser{"localpart:*@EXAMPLE.COM"},
|
||||
users: types.Users{
|
||||
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
|
||||
},
|
||||
want: map[uint]string{1: "alice"},
|
||||
},
|
||||
{
|
||||
name: "user without email skipped",
|
||||
entries: []SSHUser{"localpart:*@example.com"},
|
||||
users: types.Users{
|
||||
{Name: "cli-user", Model: gorm.Model{ID: 1}},
|
||||
},
|
||||
want: map[uint]string{},
|
||||
},
|
||||
{
|
||||
name: "multiple domains multiple users",
|
||||
entries: []SSHUser{
|
||||
"localpart:*@example.com",
|
||||
"localpart:*@other.com",
|
||||
},
|
||||
users: types.Users{
|
||||
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
|
||||
{Name: "bob", Email: "bob@other.com", Model: gorm.Model{ID: 2}},
|
||||
{Name: "charlie", Email: "charlie@nope.com", Model: gorm.Model{ID: 3}},
|
||||
},
|
||||
want: map[uint]string{1: "alice", 2: "bob"},
|
||||
},
|
||||
{
|
||||
name: "special chars in local part",
|
||||
entries: []SSHUser{"localpart:*@example.com"},
|
||||
users: types.Users{
|
||||
{Name: "d", Email: "dave+ssh@example.com", Model: gorm.Model{ID: 1}},
|
||||
},
|
||||
want: map[uint]string{1: "dave+ssh"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := resolveLocalparts(tt.entries, tt.users)
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("resolveLocalparts() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupSourcesByUser(t *testing.T) {
|
||||
alice := types.User{
|
||||
Name: "alice", Email: "alice@example.com",
|
||||
Model: gorm.Model{ID: 1},
|
||||
}
|
||||
bob := types.User{
|
||||
Name: "bob", Email: "bob@example.com",
|
||||
Model: gorm.Model{ID: 2},
|
||||
}
|
||||
|
||||
nodeAlice := types.Node{
|
||||
Hostname: "alice-dev",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: &alice.ID,
|
||||
User: &alice,
|
||||
}
|
||||
nodeBob := types.Node{
|
||||
Hostname: "bob-dev",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: &bob.ID,
|
||||
User: &bob,
|
||||
}
|
||||
nodeTagged := types.Node{
|
||||
Hostname: "tagged",
|
||||
IPv4: createAddr("100.64.0.3"),
|
||||
UserID: &alice.ID,
|
||||
User: &alice,
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
|
||||
// Build an IPSet that includes all node IPs
|
||||
allIPs := func() *netipx.IPSet {
|
||||
var b netipx.IPSetBuilder
|
||||
b.AddPrefix(netip.MustParsePrefix("100.64.0.0/24"))
|
||||
|
||||
s, _ := b.IPSet()
|
||||
|
||||
return s
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nodes types.Nodes
|
||||
srcIPs *netipx.IPSet
|
||||
wantUIDs []uint
|
||||
wantUserCount int
|
||||
wantHasTagged bool
|
||||
wantTaggedLen int
|
||||
wantAliceIP string
|
||||
wantBobIP string
|
||||
wantTaggedIP string
|
||||
}{
|
||||
{
|
||||
name: "user-owned only",
|
||||
nodes: types.Nodes{&nodeAlice, &nodeBob},
|
||||
srcIPs: allIPs,
|
||||
wantUIDs: []uint{1, 2},
|
||||
wantUserCount: 2,
|
||||
wantAliceIP: "100.64.0.1",
|
||||
wantBobIP: "100.64.0.2",
|
||||
},
|
||||
{
|
||||
name: "mixed user and tagged",
|
||||
nodes: types.Nodes{&nodeAlice, &nodeTagged},
|
||||
srcIPs: allIPs,
|
||||
wantUIDs: []uint{1},
|
||||
wantUserCount: 1,
|
||||
wantHasTagged: true,
|
||||
wantTaggedLen: 1,
|
||||
wantAliceIP: "100.64.0.1",
|
||||
wantTaggedIP: "100.64.0.3",
|
||||
},
|
||||
{
|
||||
name: "tagged only",
|
||||
nodes: types.Nodes{&nodeTagged},
|
||||
srcIPs: allIPs,
|
||||
wantUIDs: nil,
|
||||
wantUserCount: 0,
|
||||
wantHasTagged: true,
|
||||
wantTaggedLen: 1,
|
||||
},
|
||||
{
|
||||
name: "node not in srcIPs excluded",
|
||||
nodes: types.Nodes{&nodeAlice, &nodeBob},
|
||||
srcIPs: func() *netipx.IPSet {
|
||||
var b netipx.IPSetBuilder
|
||||
b.Add(netip.MustParseAddr("100.64.0.1")) // only alice
|
||||
|
||||
s, _ := b.IPSet()
|
||||
|
||||
return s
|
||||
}(),
|
||||
wantUIDs: []uint{1},
|
||||
wantUserCount: 1,
|
||||
wantAliceIP: "100.64.0.1",
|
||||
},
|
||||
{
|
||||
name: "sorted by user ID",
|
||||
nodes: types.Nodes{&nodeBob, &nodeAlice}, // reverse order
|
||||
srcIPs: allIPs,
|
||||
wantUIDs: []uint{1, 2}, // still sorted
|
||||
wantUserCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sortedUIDs, byUser, tagged := groupSourcesByUser(
|
||||
tt.nodes.ViewSlice(), tt.srcIPs,
|
||||
)
|
||||
|
||||
assert.Equal(t, tt.wantUIDs, sortedUIDs, "sortedUIDs")
|
||||
assert.Len(t, byUser, tt.wantUserCount, "byUser count")
|
||||
|
||||
if tt.wantHasTagged {
|
||||
assert.Len(t, tagged, tt.wantTaggedLen, "tagged count")
|
||||
} else {
|
||||
assert.Empty(t, tagged, "tagged should be empty")
|
||||
}
|
||||
|
||||
if tt.wantAliceIP != "" {
|
||||
require.Contains(t, byUser, uint(1))
|
||||
assert.Equal(t, tt.wantAliceIP, byUser[1][0].NodeIP)
|
||||
}
|
||||
|
||||
if tt.wantBobIP != "" {
|
||||
require.Contains(t, byUser, uint(2))
|
||||
assert.Equal(t, tt.wantBobIP, byUser[2][0].NodeIP)
|
||||
}
|
||||
|
||||
if tt.wantTaggedIP != "" {
|
||||
require.NotEmpty(t, tagged)
|
||||
assert.Equal(t, tt.wantTaggedIP, tagged[0].NodeIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ var (
|
||||
ErrSSHCheckPeriodBelowMin = errors.New("checkPeriod below minimum of 1 minute")
|
||||
ErrSSHCheckPeriodAboveMax = errors.New("checkPeriod above maximum of 168 hours (1 week)")
|
||||
ErrSSHCheckPeriodOnNonCheck = errors.New("checkPeriod is only valid with action \"check\"")
|
||||
ErrInvalidLocalpart = errors.New("invalid localpart format, must be localpart:*@<domain>")
|
||||
)
|
||||
|
||||
// SSH check period constants per Tailscale docs:
|
||||
@@ -1965,6 +1966,14 @@ func (p *Policy) validate() error {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if user.IsLocalpart() {
|
||||
_, err := user.ParseLocalpart()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, src := range ssh.Sources {
|
||||
@@ -2191,6 +2200,7 @@ type SSH struct {
|
||||
Destinations SSHDstAliases `json:"dst"`
|
||||
Users SSHUsers `json:"users"`
|
||||
CheckPeriod *SSHCheckPeriod `json:"checkPeriod,omitempty"`
|
||||
AcceptEnv []string `json:"acceptEnv,omitempty"`
|
||||
}
|
||||
|
||||
// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
|
||||
@@ -2342,6 +2352,11 @@ type SSHDstAliases []Alias
|
||||
|
||||
type SSHUsers []SSHUser
|
||||
|
||||
// SSHUserLocalpartPrefix is the prefix for localpart SSH user entries.
|
||||
// Format: localpart:*@<domain>
|
||||
// See: https://tailscale.com/docs/features/tailscale-ssh#users
|
||||
const SSHUserLocalpartPrefix = "localpart:"
|
||||
|
||||
func (u SSHUsers) ContainsRoot() bool {
|
||||
return slices.Contains(u, "root")
|
||||
}
|
||||
@@ -2350,9 +2365,25 @@ func (u SSHUsers) ContainsNonRoot() bool {
|
||||
return slices.Contains(u, SSHUser(AutoGroupNonRoot))
|
||||
}
|
||||
|
||||
// ContainsLocalpart returns true if any entry has the localpart: prefix.
|
||||
func (u SSHUsers) ContainsLocalpart() bool {
|
||||
return slices.ContainsFunc(u, func(user SSHUser) bool {
|
||||
return user.IsLocalpart()
|
||||
})
|
||||
}
|
||||
|
||||
// NormalUsers returns all SSH users that are not root, autogroup:nonroot,
|
||||
// or localpart: entries.
|
||||
func (u SSHUsers) NormalUsers() []SSHUser {
|
||||
return slicesx.Filter(nil, u, func(user SSHUser) bool {
|
||||
return user != "root" && user != SSHUser(AutoGroupNonRoot)
|
||||
return user != "root" && user != SSHUser(AutoGroupNonRoot) && !user.IsLocalpart()
|
||||
})
|
||||
}
|
||||
|
||||
// LocalpartEntries returns only the localpart: prefixed entries.
|
||||
func (u SSHUsers) LocalpartEntries() []SSHUser {
|
||||
return slicesx.Filter(nil, u, func(user SSHUser) bool {
|
||||
return user.IsLocalpart()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2362,6 +2393,41 @@ func (u SSHUser) String() string {
|
||||
return string(u)
|
||||
}
|
||||
|
||||
// IsLocalpart returns true if the SSHUser has the localpart: prefix.
|
||||
func (u SSHUser) IsLocalpart() bool {
|
||||
return strings.HasPrefix(string(u), SSHUserLocalpartPrefix)
|
||||
}
|
||||
|
||||
// ParseLocalpart validates and extracts the domain from a localpart: entry.
|
||||
// The expected format is localpart:*@<domain>.
|
||||
// Returns the domain part or an error if the format is invalid.
|
||||
func (u SSHUser) ParseLocalpart() (string, error) {
|
||||
if !u.IsLocalpart() {
|
||||
return "", fmt.Errorf("%w: missing prefix %q in %q", ErrInvalidLocalpart, SSHUserLocalpartPrefix, u)
|
||||
}
|
||||
|
||||
pattern := strings.TrimPrefix(string(u), SSHUserLocalpartPrefix)
|
||||
|
||||
// Must be *@<domain>
|
||||
atIdx := strings.LastIndex(pattern, "@")
|
||||
if atIdx < 0 {
|
||||
return "", fmt.Errorf("%w: missing @ in %q", ErrInvalidLocalpart, u)
|
||||
}
|
||||
|
||||
localPart := pattern[:atIdx]
|
||||
domain := pattern[atIdx+1:]
|
||||
|
||||
if localPart != "*" {
|
||||
return "", fmt.Errorf("%w: local part must be *, got %q in %q", ErrInvalidLocalpart, localPart, u)
|
||||
}
|
||||
|
||||
if domain == "" {
|
||||
return "", fmt.Errorf("%w: empty domain in %q", ErrInvalidLocalpart, u)
|
||||
}
|
||||
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the SSHUser to JSON.
|
||||
func (u SSHUser) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(string(u))
|
||||
|
||||
@@ -1794,6 +1794,33 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-valid",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:*@example.com"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
want: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:prod"): Owners{up("admin@")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:prod")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2754-bracketed-ipv6-multiple-ports",
|
||||
input: `
|
||||
@@ -1825,6 +1852,33 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-with-other-users",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:*@example.com", "root", "autogroup:nonroot"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
want: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:prod"): Owners{up("admin@")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:prod")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com"), "root", SSHUser(AutoGroupNonRoot)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2754-bracketed-ipv6-wildcard-port",
|
||||
input: `
|
||||
@@ -1951,6 +2005,51 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
`,
|
||||
wantErr: "square brackets are only valid around IPv6 addresses",
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-invalid-no-at-sign",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:foo"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
wantErr: "invalid localpart format",
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-invalid-non-wildcard",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:alice@example.com"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
wantErr: "invalid localpart format",
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-invalid-empty-domain",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:*@"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
wantErr: "invalid localpart format",
|
||||
},
|
||||
}
|
||||
|
||||
cmps := append(util.Comparers,
|
||||
@@ -2635,56 +2734,63 @@ func TestResolveAutoApprovers(t *testing.T) {
|
||||
|
||||
func TestSSHUsers_NormalUsers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
users SSHUsers
|
||||
expected []SSHUser
|
||||
name string
|
||||
users SSHUsers
|
||||
want []SSHUser
|
||||
}{
|
||||
{
|
||||
name: "empty users",
|
||||
users: SSHUsers{},
|
||||
expected: []SSHUser{},
|
||||
name: "empty users",
|
||||
users: SSHUsers{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "only root",
|
||||
users: SSHUsers{"root"},
|
||||
expected: []SSHUser{},
|
||||
name: "only root",
|
||||
users: SSHUsers{"root"},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "only autogroup:nonroot",
|
||||
users: SSHUsers{SSHUser(AutoGroupNonRoot)},
|
||||
expected: []SSHUser{},
|
||||
name: "only autogroup:nonroot",
|
||||
users: SSHUsers{SSHUser(AutoGroupNonRoot)},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "only normal user",
|
||||
users: SSHUsers{"ssh-it-user"},
|
||||
expected: []SSHUser{"ssh-it-user"},
|
||||
name: "only normal user",
|
||||
users: SSHUsers{"ssh-it-user"},
|
||||
want: []SSHUser{"ssh-it-user"},
|
||||
},
|
||||
{
|
||||
name: "multiple normal users",
|
||||
users: SSHUsers{"ubuntu", "admin", "user1"},
|
||||
expected: []SSHUser{"ubuntu", "admin", "user1"},
|
||||
name: "multiple normal users",
|
||||
users: SSHUsers{"ubuntu", "admin", "user1"},
|
||||
want: []SSHUser{"ubuntu", "admin", "user1"},
|
||||
},
|
||||
{
|
||||
name: "mixed users with root",
|
||||
users: SSHUsers{"ubuntu", "root", "admin"},
|
||||
expected: []SSHUser{"ubuntu", "admin"},
|
||||
name: "mixed users with root",
|
||||
users: SSHUsers{"ubuntu", "root", "admin"},
|
||||
want: []SSHUser{"ubuntu", "admin"},
|
||||
},
|
||||
{
|
||||
name: "mixed users with autogroup:nonroot",
|
||||
users: SSHUsers{"ubuntu", SSHUser(AutoGroupNonRoot), "admin"},
|
||||
expected: []SSHUser{"ubuntu", "admin"},
|
||||
name: "mixed users with autogroup:nonroot",
|
||||
users: SSHUsers{"ubuntu", SSHUser(AutoGroupNonRoot), "admin"},
|
||||
want: []SSHUser{"ubuntu", "admin"},
|
||||
},
|
||||
{
|
||||
name: "mixed users with both root and autogroup:nonroot",
|
||||
users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), "admin"},
|
||||
expected: []SSHUser{"ubuntu", "admin"},
|
||||
name: "mixed users with both root and autogroup:nonroot",
|
||||
users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), "admin"},
|
||||
want: []SSHUser{"ubuntu", "admin"},
|
||||
},
|
||||
{
|
||||
name: "excludes localpart entries",
|
||||
users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), SSHUser("localpart:*@example.com"), "admin"},
|
||||
want: []SSHUser{"ubuntu", "admin"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.users.NormalUsers()
|
||||
assert.ElementsMatch(t, tt.expected, result, "NormalUsers() should return expected normal users")
|
||||
got := tt.users.NormalUsers()
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("NormalUsers() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2761,6 +2867,142 @@ func TestSSHUsers_ContainsNonRoot(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHUsers_ContainsLocalpart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
users SSHUsers
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty users",
|
||||
users: SSHUsers{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "contains localpart",
|
||||
users: SSHUsers{SSHUser("localpart:*@example.com")},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "does not contain localpart",
|
||||
users: SSHUsers{"ubuntu", "admin", "root"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "contains localpart among others",
|
||||
users: SSHUsers{"ubuntu", SSHUser("localpart:*@example.com"), "admin"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple localpart entries",
|
||||
users: SSHUsers{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.users.ContainsLocalpart()
|
||||
assert.Equal(t, tt.expected, result, "ContainsLocalpart() should return expected result")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHUsers_LocalpartEntries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
users SSHUsers
|
||||
want []SSHUser
|
||||
}{
|
||||
{
|
||||
name: "empty users",
|
||||
users: SSHUsers{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no localpart entries",
|
||||
users: SSHUsers{"root", "ubuntu", SSHUser(AutoGroupNonRoot)},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single localpart entry",
|
||||
users: SSHUsers{"root", SSHUser("localpart:*@example.com"), "ubuntu"},
|
||||
want: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
{
|
||||
name: "multiple localpart entries",
|
||||
users: SSHUsers{SSHUser("localpart:*@a.com"), "root", SSHUser("localpart:*@b.com")},
|
||||
want: []SSHUser{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.users.LocalpartEntries()
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("LocalpartEntries() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHUser_ParseLocalpart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user SSHUser
|
||||
expectedDomain string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid localpart",
|
||||
user: SSHUser("localpart:*@example.com"),
|
||||
expectedDomain: "example.com",
|
||||
},
|
||||
{
|
||||
name: "valid localpart with subdomain",
|
||||
user: SSHUser("localpart:*@corp.example.com"),
|
||||
expectedDomain: "corp.example.com",
|
||||
},
|
||||
{
|
||||
name: "missing prefix",
|
||||
user: SSHUser("ubuntu"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing @ sign",
|
||||
user: SSHUser("localpart:foo"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-wildcard local part",
|
||||
user: SSHUser("localpart:alice@example.com"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty domain",
|
||||
user: SSHUser("localpart:*@"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "just prefix",
|
||||
user: SSHUser("localpart:"),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
domain, err := tt.user.ParseLocalpart()
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedDomain, domain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustIPSet(prefixes ...string) *netipx.IPSet {
|
||||
var builder netipx.IPSetBuilder
|
||||
for _, p := range prefixes {
|
||||
|
||||
Reference in New Issue
Block a user