policy/v2: align SSH rule validation with Tailscale

Trim whitespace on action, users, src, dst; reject empty/wildcard users; reject empty acceptEnv; reject negative and over-max checkPeriod; reject hosts-table aliases as SSH dst; reject non-ASCII tag names; tolerate tag-owner cycles; match group-nesting wording.
This commit is contained in:
Kristoffer Dalby
2026-05-13 14:09:17 +00:00
parent 4ad200ab73
commit d600090f2c
70 changed files with 1361660 additions and 308 deletions

View File

@@ -1317,7 +1317,7 @@ func TestSSHPolicyRules(t *testing.T) {
]
}`,
expectErr: true,
errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`,
errorMessage: `"invalid" is not a valid action`,
},
{
name: "invalid-check-period",
@@ -1341,10 +1341,15 @@ func TestSSHPolicyRules(t *testing.T) {
]
}`,
expectErr: true,
errorMessage: "not a valid duration string",
errorMessage: `time: invalid duration "invalid"`,
},
// `autogroup:invalid` as an SSH user is no longer rejected:
// SaaS treats every `autogroup:*` user-string as a literal
// label and compiles it into the SSHUsers map. The compat
// suite covers this via ssh-malformed-user-autogroup-* — no
// dedicated case is needed here.
{
name: "unsupported-autogroup",
name: "ssh-user-unknown-autogroup-as-literal",
targetNode: taggedClient,
peers: types.Nodes{&nodeUser2},
policy: `{
@@ -1363,8 +1368,23 @@ func TestSSHPolicyRules(t *testing.T) {
}
]
}`,
expectErr: true,
errorMessage: "autogroup not supported for SSH user",
wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{
{NodeIP: "100.64.0.2"},
},
SSHUsers: map[string]string{
"autogroup:invalid": "autogroup:invalid",
"root": "",
},
Action: &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
},
},
}},
},
{
name: "autogroup-nonroot-should-use-wildcard-with-root-excluded",

View File

@@ -1467,27 +1467,13 @@ func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types.
}
}
// flattenTags flattens the TagOwners by resolving nested tags and detecting cycles.
// It will return a Owners list where all the Tag types have been resolved to their underlying Owners.
// flattenTags resolves nested tag-owner references. Cycles
// (tag:a -> tag:b -> tag:a, or tag:a -> tag:a) drop the cycle-causing
// edge and contribute no addresses; non-cycle owners on the cycled tags
// still resolve. Undefined-tag references remain a hard error.
func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) {
if visiting[tag] {
cycleStart := 0
for i, t := range chain {
if t == tag {
cycleStart = i
break
}
}
cycleTags := make([]string, len(chain[cycleStart:]))
for i, t := range chain[cycleStart:] {
cycleTags[i] = string(t)
}
slices.Sort(cycleTags)
return nil, fmt.Errorf("%w: %s", ErrCircularReference, strings.Join(cycleTags, " -> "))
return nil, nil
}
visiting[tag] = true

View File

@@ -1,18 +1,9 @@
// This file implements a data-driven test runner for SSH compatibility tests.
// It loads HuJSON golden files from testdata/ssh_results/ssh-*.hujson, captured
// from a Tailscale-hosted control plane, and compares headscale's SSH policy
// compilation against the captured SSH rules.
//
// Each file is a testcapture.Capture containing:
// - The full policy that was POSTed to Tailscale SaaS (we use tf.Input.FullPolicy
// directly instead of reconstructing it from a sub-section)
// - The expected SSH rules for each of the 8 test nodes (in tf.Captures[name].SSHRules)
//
// Tests known to fail due to unimplemented features or known differences are
// skipped with a TODO comment explaining the root cause.
//
// Test data source: testdata/ssh_results/ssh-*.hujson
// Source format: github.com/juanfont/headscale/hscontrol/types/testcapture
// Replay golden HuJSON captures under testdata/ssh_results/ssh-*.hujson:
// the 200 path compares headscale's compileSSHPolicy output node-by-node
// against the captured SSHRules; the non-200 path requires headscale to
// reject the same input with the captured error body as a substring.
// Divergences are listed in sshSkipReasons (200) and sshRejectSkipReasons
// (non-200) with the engine gap each represents.
package v2
@@ -31,16 +22,10 @@ import (
"tailscale.com/tailcfg"
)
// setupSSHDataCompatUsers returns the 3 test users for SSH data-driven
// compatibility tests. Users get norse-god names; nodes get original-151
// pokémon names — matching the anonymized identifiers the capture
// tool writes into the capture files.
//
// odin and freya live on @example.com; thor lives on @example.org so
// that "localpart:*@example.com" resolves to exactly two users
// (matching SaaS output) and the "user on a different email domain"
// case stays covered by scenarios like ssh-d1 that use
// "localpart:*@example.org".
// setupSSHDataCompatUsers returns three users straddling two email
// domains so that "localpart:*@example.com" resolves to exactly two
// users (odin, freya) and the cross-domain case stays covered through
// thor on @example.org.
func setupSSHDataCompatUsers() types.Users {
return types.Users{
{
@@ -126,39 +111,29 @@ func loadSSHTestFile(t *testing.T, path string) *testcapture.Capture {
return c
}
// sshSkipReasons documents why each skipped test fails and what needs to be
// fixed. Tests are grouped by root cause to identify high-impact changes.
// sshSkipReasons documents captures the upstream control plane accepts
// but headscale cannot yet represent. Each entry names the feature gap.
var sshSkipReasons = map[string]string{
// USER_PASSKEY_WILDCARD (2 tests)
//
// headscale does not support passkey authentication and has no
// equivalent for the user:*@passkey wildcard pattern.
"ssh-b5": "user:*@passkey wildcard not supported in headscale",
"ssh-d10": "user:*@passkey wildcard not supported in headscale",
// DOMAIN_NOT_ASSOCIATED (4 tests)
//
// SaaS validates that email domains in user:*@domain and
// localpart:*@domain expressions are configured tailnet domains.
// headscale has no concept of "associated tailnet domains" — it
// only has users with email addresses. These policies are
// legitimately rejected by SaaS but not by headscale.
"ssh-b4": "domain validation: headscale has no 'associated tailnet domains' concept",
"ssh-d1": "domain validation: headscale has no 'associated tailnet domains' concept",
"ssh-e1": "domain validation: headscale has no 'associated tailnet domains' concept",
"ssh-e2": "domain validation: headscale has no 'associated tailnet domains' concept",
"ssh-b5": "headscale has no passkey authentication; user:*@passkey wildcard unsupported",
"ssh-d10": "headscale has no passkey authentication; user:*@passkey wildcard unsupported",
}
// TestSSHDataCompat is a data-driven test that loads all ssh-*.hujson test
// files captured from Tailscale SaaS and compares headscale's SSH policy
// compilation against the real Tailscale behavior.
//
// Each capture file contains:
// - The full policy that was POSTed to the SaaS API (Input.FullPolicy)
// - Expected SSH rules per node (Captures[name].SSHRules)
//
// The test converts Tailscale user email formats to headscale format and runs
// the captured policy through unmarshalPolicy and compileSSHPolicy.
// sshRejectSkipReasons documents captures the upstream control plane
// rejects for reasons headscale cannot apply. Each entry names the
// feature gap.
var sshRejectSkipReasons = map[string]string{
"ssh-b4": "headscale has no associated-tailnet-domains config; user:*@domain / localpart:*@domain are not domain-validated",
"ssh-d1": "headscale has no associated-tailnet-domains config; user:*@domain / localpart:*@domain are not domain-validated",
"ssh-e1": "headscale has no associated-tailnet-domains config; user:*@domain / localpart:*@domain are not domain-validated",
"ssh-e2": "headscale has no associated-tailnet-domains config; user:*@domain / localpart:*@domain are not domain-validated",
"ssh-malformed-user-localpart-multi-glob": "headscale has no associated-tailnet-domains config; user:*@domain / localpart:*@domain are not domain-validated",
}
// TestSSHDataCompat loads every ssh-*.hujson capture, parses the policy
// it pinned, and compiles the same per-node SSH rules to compare against
// the captured shape. Non-200 captures replay the rejection path: the
// recorded error body must appear as a substring of headscale's
// rejection.
func TestSSHDataCompat(t *testing.T) {
t.Parallel()
@@ -192,39 +167,61 @@ func TestSSHDataCompat(t *testing.T) {
t.Run(tf.TestID, func(t *testing.T) {
t.Parallel()
// Check if this test is in the skip list
if reason, ok := sshSkipReasons[tf.TestID]; ok {
t.Skipf(
"TODO: %s — see sshSkipReasons comments for details",
reason,
)
return
}
// SaaS rejected this policy — verify headscale also rejects it.
if tf.Error {
testSSHError(t, tf)
return
}
// Build nodes per-scenario from this file's topology.
// the capture tool uses clean-slate mode, so each scenario has
// different node IPs.
// Each capture pins its own topology IPs, so nodes are
// rebuilt from the capture rather than a shared fixture.
nodes := buildGrantsNodesFromCapture(users, tf)
// Use the captured full policy as is. Anonymization in
// captures already rewrite SaaS emails to @example.com.
policyJSON := tf.Input.FullPolicy
policyJSON := []byte(tf.Input.FullPolicy)
pol, err := unmarshalPolicy([]byte(policyJSON))
if tf.Input.APIResponseCode != 200 {
if reason, ok := sshRejectSkipReasons[tf.TestID]; ok {
t.Skipf("skipping: %s", reason)
return
}
pm, parseErr := NewPolicyManager(policyJSON, users, nodes.ViewSlice())
var got error
switch {
case parseErr != nil:
got = parseErr
default:
_, setErr := pm.SetPolicy(policyJSON)
got = setErr
}
require.Error(t, got, "tailscale rejected; headscale must reject too")
if tf.Input.APIResponseBody == nil ||
tf.Input.APIResponseBody.Message == "" {
return
}
want := tf.Input.APIResponseBody.Message
if !strings.Contains(got.Error(), want) {
t.Errorf(
"error body mismatch\n tailscale wants: %q\n headscale got: %q",
want,
got.Error(),
)
}
return
}
if reason, ok := sshSkipReasons[tf.TestID]; ok {
t.Skipf("skipping: %s", reason)
return
}
pol, err := unmarshalPolicy(policyJSON)
require.NoError(
t,
err,
"%s: policy should parse successfully\nPolicy:\n%s",
tf.TestID,
policyJSON,
tf.Input.FullPolicy,
)
for nodeName, capture := range tf.Captures {
@@ -309,97 +306,3 @@ func TestSSHDataCompat(t *testing.T) {
})
}
}
// sshErrorMessageMap maps Tailscale SaaS error substrings to headscale
// equivalents where the wording differs but the meaning is the same.
var sshErrorMessageMap = map[string]string{}
// testSSHError verifies that an invalid policy produces the expected error.
func testSSHError(t *testing.T, tf *testcapture.Capture) {
t.Helper()
policyJSON := []byte(tf.Input.FullPolicy)
pol, err := unmarshalPolicy(policyJSON)
if err != nil {
// Parse-time error.
if tf.Input.APIResponseBody != nil {
wantMsg := tf.Input.APIResponseBody.Message
if wantMsg != "" {
assertSSHErrorContains(t, err, wantMsg, tf.TestID)
}
}
return
}
err = pol.validate()
if err != nil {
if tf.Input.APIResponseBody != nil {
wantMsg := tf.Input.APIResponseBody.Message
if wantMsg != "" {
assertSSHErrorContains(t, err, wantMsg, tf.TestID)
}
}
return
}
t.Errorf(
"%s: expected error but policy parsed and validated successfully",
tf.TestID,
)
}
// assertSSHErrorContains checks that an error message matches the
// expected Tailscale SaaS message, using progressive fallbacks:
// 1. Direct substring match
// 2. Mapped equivalent from sshErrorMessageMap
// 3. Key-part extraction (tags, autogroups)
// 4. t.Errorf on no match (strict)
func assertSSHErrorContains(
t *testing.T,
err error,
wantMsg string,
testID string,
) {
t.Helper()
errStr := err.Error()
// 1. Direct substring match.
if strings.Contains(errStr, wantMsg) {
return
}
// 2. Mapped equivalent.
for tsKey, hsKey := range sshErrorMessageMap {
if strings.Contains(wantMsg, tsKey) &&
strings.Contains(errStr, hsKey) {
return
}
}
// 3. Key-part extraction.
for _, part := range []string{
"autogroup:",
"tag:",
"undefined",
"not valid",
} {
if strings.Contains(wantMsg, part) &&
strings.Contains(errStr, part) {
return
}
}
// 4. No match — strict failure.
t.Errorf(
"%s: error message mismatch\n"+
" want (tailscale): %q\n"+
" got (headscale): %q",
testID,
wantMsg,
errStr,
)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,6 @@ import (
"github.com/go-json-experiment/json"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"github.com/tailscale/hujson"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
@@ -35,8 +34,6 @@ const Wildcard = Asterix(0)
var ErrAutogroupSelfRequiresPerNodeResolution = errors.New("autogroup:self requires per-node resolution and cannot be resolved in this context")
var ErrCircularReference = errors.New("circular reference detected")
var ErrUndefinedTagReference = errors.New("references undefined tag")
// SSH validation errors.
@@ -46,17 +43,25 @@ var (
ErrSSHAutogroupSelfRequiresUserSource = errors.New("autogroup:self destination requires source to contain only users or groups, not tags or autogroup:tagged")
ErrSSHTagSourceToAutogroupMember = errors.New("tags in SSH source cannot access autogroup:member (user-owned devices)")
ErrSSHWildcardDestination = errors.New("wildcard (*) is not supported as SSH destination")
ErrSSHCheckPeriodBelowMin = errors.New("checkPeriod below minimum of 1 minute")
ErrSSHCheckPeriodAboveMax = errors.New("checkPeriod above maximum of 168 hours (1 week)")
ErrSSHCheckPeriodAboveMax = errors.New("is above the max (168h)")
ErrSSHCheckPeriodNegative = errors.New("must be a positive duration")
ErrSSHCheckPeriodOnNonCheck = errors.New("checkPeriod is only valid with action \"check\"")
ErrInvalidLocalpart = errors.New("invalid localpart format, must be localpart:*@<domain>")
ErrSSHUsersMustBeSpecified = errors.New("users must be specified")
ErrSSHUserInvalid = errors.New("is not valid")
ErrSSHAcceptEnvEmpty = errors.New("acceptEnv values cannot be empty")
ErrSSHActionMustBeSpecified = errors.New("action must be specified")
ErrSSHActionInvalid = errors.New("is not a valid action")
ErrSSHDestinationHostAlias = errors.New("invalid dst")
ErrTagNameMustStartWithLetter = errors.New("tag names must start with a letter, after 'tag:'")
ErrGroupMembersCannotBeRecursive = errors.New("group members cannot be recursive")
)
// SSH check period constants per Tailscale docs:
// https://tailscale.com/docs/features/tailscale-ssh#checkperiod
// SaaS imposes no minimum (0s is accepted) so headscale matches.
const (
SSHCheckPeriodDefault = 12 * time.Hour
SSHCheckPeriodMin = time.Minute
SSHCheckPeriodMax = 168 * time.Hour
)
@@ -120,7 +125,6 @@ var (
ErrGroupNotDefined = errors.New("group not defined in policy")
ErrInvalidGroupMember = errors.New("invalid group member type")
ErrGroupValueNotArray = errors.New("group value must be an array of users")
ErrNestedGroups = errors.New("nested groups are not allowed")
ErrInvalidHostIP = errors.New("hostname contains invalid IP address")
ErrTagNotDefined = errors.New("tag not found")
ErrAutoApproverNotAlias = errors.New("auto approver is not an alias")
@@ -137,7 +141,6 @@ var (
ErrAutogroupDangerAllDst = errors.New("cannot use autogroup:danger-all as a dst")
ErrAutogroupNotSupportedSSHSrc = errors.New("autogroup not supported for SSH sources")
ErrAutogroupNotSupportedSSHDst = errors.New("autogroup not supported for SSH destinations")
ErrAutogroupNotSupportedSSHUsr = errors.New("autogroup not supported for SSH user")
ErrHostNotDefined = errors.New("host not defined in policy")
ErrSSHSourceAliasNotSupported = errors.New("alias not supported for SSH source")
ErrSSHDestAliasNotSupported = errors.New("alias not supported for SSH destination")
@@ -539,12 +542,27 @@ func (g *Group) resolve(p *Policy, users types.Users, nodes views.Slice[types.No
// Tag is a special string which is always prefixed with `tag:`.
type Tag string
// Validate enforces the `tag:` prefix and the SaaS rule that the
// first character after the prefix is an ASCII letter ([A-Za-z]).
// Subsequent characters may be ASCII letters, digits, hyphens, or
// dots — those are checked by the existing alias parser elsewhere.
func (t *Tag) Validate() error {
if isTag(string(*t)) {
return nil
s := string(*t)
if !isTag(s) {
return fmt.Errorf("%w, got: %q", ErrInvalidTagFormat, *t)
}
return fmt.Errorf("%w, got: %q", ErrInvalidTagFormat, *t)
rest := strings.TrimPrefix(s, "tag:")
if rest == "" {
return ErrTagNameMustStartWithLetter
}
first := rest[0]
if (first < 'a' || first > 'z') && (first < 'A' || first > 'Z') {
return ErrTagNameMustStartWithLetter
}
return nil
}
func (t *Tag) UnmarshalJSON(b []byte) error {
@@ -1056,10 +1074,17 @@ func parseAlias(vs string) (Alias, error) {
// AliasEnc is used to deserialize a Alias.
type AliasEnc struct{ Alias }
// UnmarshalJSON trims surrounding whitespace from each alias string
// before dispatching so that `"tag:server "` or `" odin@example.com"`
// resolves to the same tag or user SaaS would resolve. SaaS trims
// before lookup; a literal-match policy here would drop the affected
// node from every rule referencing it.
func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
ptr, err := unmarshalPointer(
b,
parseAlias,
func(s string) (Alias, error) {
return parseAlias(strings.TrimSpace(s))
},
)
if err != nil {
return err
@@ -1379,6 +1404,25 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
}
}
// Reject group-in-group references. Reverse-sort the keys so the
// reported (parent, child) pair names the deepest non-leaf parent
// first.
keys := make([]string, 0, len(rawGroups))
for k := range rawGroups {
keys = append(keys, k)
}
slices.Sort(keys)
slices.Reverse(keys)
for _, key := range keys {
for _, u := range rawGroups[key] {
if isGroup(u) {
return fmt.Errorf("groups[%q]: %q: %w", key, u, ErrGroupMembersCannotBeRecursive)
}
}
}
*g = make(Groups)
for key, value := range rawGroups {
@@ -1391,10 +1435,6 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
err := username.Validate()
if err != nil {
if isGroup(u) {
return fmt.Errorf("%w: found %q inside %q", ErrNestedGroups, u, group)
}
return err
}
@@ -1646,16 +1686,20 @@ func (a *SSHAction) String() string {
return string(*a)
}
// UnmarshalJSON implements JSON unmarshaling for SSHAction.
// UnmarshalJSON trims surrounding whitespace before matching, lets the
// empty string through (per-rule Validate() surfaces it later), and
// rejects every other unknown value here.
func (a *SSHAction) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
str := strings.TrimSpace(strings.Trim(string(b), `"`))
switch str {
case "":
*a = SSHAction("")
case "accept":
*a = SSHActionAccept
case "check":
*a = SSHActionCheck
default:
return fmt.Errorf("%w: %q, must be one of: accept, check", ErrInvalidSSHAction, str)
return fmt.Errorf("%q %w", str, ErrSSHActionInvalid)
}
return nil
@@ -2052,7 +2096,6 @@ var (
autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged, AutoGroupSelf}
autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
autogroupForSSHDst = []AutoGroup{AutoGroupMember, AutoGroupTagged, AutoGroupSelf}
autogroupForSSHUser = []AutoGroup{AutoGroupNonRoot}
autogroupForNodeAttrs = []AutoGroup{AutoGroupMember, AutoGroupTagged}
autogroupNotSupported = []AutoGroup{}
@@ -2192,18 +2235,6 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
return nil
}
func validateAutogroupForSSHUser(user *AutoGroup) error {
if user == nil {
return nil
}
if !slices.Contains(autogroupForSSHUser, *user) {
return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedSSHUsr, *user, autogroupForSSHUser)
}
return nil
}
// validateSSHSrcDstCombination validates that SSH source/destination combinations
// follow Tailscale's security model:
// - Destination can be: tags, autogroup:self (if source is users/groups), or same-user
@@ -2479,23 +2510,29 @@ func (p *Policy) validate() error {
}
for _, ssh := range p.SSHs {
// Empty action and users survive parse; surface them here.
if ssh.Action == "" {
errs = append(errs, ErrSSHActionMustBeSpecified)
}
if len(ssh.Users) == 0 {
errs = append(errs, ErrSSHUsersMustBeSpecified)
}
// "" and "*" are not valid login users; any other string
// (including autogroup, group, tag, malformed localpart) is
// treated as a literal user name.
for _, user := range ssh.Users {
if strings.HasPrefix(string(user), "autogroup:") {
maybeAuto := AutoGroup(user)
err := validateAutogroupForSSHUser(&maybeAuto)
if err != nil {
errs = append(errs, err)
continue
}
switch user {
case "", "*":
errs = append(errs, fmt.Errorf("user %q %w", user, ErrSSHUserInvalid))
}
}
if user.IsLocalpart() {
_, err := user.ParseLocalpart()
if err != nil {
errs = append(errs, err)
continue
}
// acceptEnv entries cannot be empty; "*" and "**" are valid.
for _, env := range ssh.AcceptEnv {
if env == "" {
errs = append(errs, ErrSSHAcceptEnvEmpty)
}
}
@@ -2555,6 +2592,10 @@ func (p *Policy) validate() error {
if err != nil {
errs = append(errs, err)
}
case *Host:
// Hosts-table aliases are valid on ACL dst but
// rejected here for SSH dst.
errs = append(errs, fmt.Errorf("%w %q", ErrSSHDestinationHostAlias, string(*dst)))
}
}
@@ -2877,12 +2918,16 @@ func (p *SSHCheckPeriod) UnmarshalJSON(b []byte) error {
return nil
}
d, err := model.ParseDuration(str)
// time.ParseDuration produces error strings like
// `time: invalid duration "abc"` which match SaaS body wording
// exactly; model.ParseDuration wraps the same parse with custom
// phrasing and would diverge.
d, err := time.ParseDuration(str)
if err != nil {
return fmt.Errorf("parsing checkPeriod %q: %w", str, err)
return err
}
p.Duration = time.Duration(d)
p.Duration = d
return nil
}
@@ -2896,26 +2941,19 @@ func (p SSHCheckPeriod) MarshalJSON() ([]byte, error) {
return fmt.Appendf(nil, "%q", p.Duration.String()), nil
}
// Validate checks that the SSHCheckPeriod is within allowed bounds.
// Validate rejects negative durations and anything above the inclusive
// 168h max.
func (p *SSHCheckPeriod) Validate() error {
if p.Always {
return nil
}
if p.Duration < SSHCheckPeriodMin {
return fmt.Errorf(
"%w: got %s",
ErrSSHCheckPeriodBelowMin,
p.Duration,
)
if p.Duration < 0 {
return fmt.Errorf("checkPeriod %s %w", p.Duration, ErrSSHCheckPeriodNegative)
}
if p.Duration > SSHCheckPeriodMax {
return fmt.Errorf(
"%w: got %s",
ErrSSHCheckPeriodAboveMax,
p.Duration,
)
return fmt.Errorf("checkPeriod %s %w", p.Duration, ErrSSHCheckPeriodAboveMax)
}
return nil
@@ -3093,25 +3131,31 @@ func (u SSHUsers) ContainsNonRoot() bool {
return slices.Contains(u, SSHUser(AutoGroupNonRoot))
}
// ContainsLocalpart returns true if any entry has the localpart: prefix.
// ContainsLocalpart returns true if any entry is a canonical
// `localpart:*@<domain>` form. Non-canonical strings starting with
// `localpart:` are treated as literal usernames.
func (u SSHUsers) ContainsLocalpart() bool {
return slices.ContainsFunc(u, func(user SSHUser) bool {
return user.IsLocalpart()
return user.IsCanonicalLocalpart()
})
}
// NormalUsers returns all SSH users that are not root, autogroup:nonroot,
// or localpart: entries.
// NormalUsers returns users that land in the compiled literal user map
// (everything except root, autogroup:nonroot, and canonical
// `localpart:*@<domain>`). Malformed `localpart:` strings stay here as
// literals.
func (u SSHUsers) NormalUsers() []SSHUser {
return slicesx.Filter(nil, u, func(user SSHUser) bool {
return user != "root" && user != SSHUser(AutoGroupNonRoot) && !user.IsLocalpart()
return user != "root" && user != SSHUser(AutoGroupNonRoot) && !user.IsCanonicalLocalpart()
})
}
// LocalpartEntries returns only the localpart: prefixed entries.
// LocalpartEntries returns only canonical `localpart:*@<domain>` entries.
// Non-canonical localpart strings are excluded so they do not trigger
// the resolution path; they are emitted literally by NormalUsers.
func (u SSHUsers) LocalpartEntries() []SSHUser {
return slicesx.Filter(nil, u, func(user SSHUser) bool {
return user.IsLocalpart()
return user.IsCanonicalLocalpart()
})
}
@@ -3121,11 +3165,25 @@ func (u SSHUser) String() string {
return string(u)
}
// IsLocalpart returns true if the SSHUser has the localpart: prefix.
// IsLocalpart returns true if the SSHUser has the literal `localpart:`
// prefix. It is a syntactic check only — non-canonical shapes still
// pass.
func (u SSHUser) IsLocalpart() bool {
return strings.HasPrefix(string(u), SSHUserLocalpartPrefix)
}
// IsCanonicalLocalpart reports whether the SSHUser parses as the
// canonical `localpart:*@<domain>` form that resolution acts on.
func (u SSHUser) IsCanonicalLocalpart() bool {
if !u.IsLocalpart() {
return false
}
_, err := u.ParseLocalpart()
return err == nil
}
// ParseLocalpart validates and extracts the domain from a localpart: entry.
// The expected format is localpart:*@<domain>.
// Returns the domain part or an error if the format is invalid.
@@ -3161,6 +3219,20 @@ func (u SSHUser) MarshalJSON() ([]byte, error) {
return json.Marshal(string(u))
}
// UnmarshalJSON trims surrounding whitespace per element. A whitespace-
// only entry collapses to `""` and surfaces as `user "" is not valid` in
// the per-rule Validate() pass.
func (u *SSHUser) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil { //nolint:noinlineerr
return err
}
*u = SSHUser(strings.TrimSpace(s))
return nil
}
// unmarshalPolicy takes a byte slice and unmarshals it into a Policy struct.
// In addition to unmarshalling, it will also validate the policy.
// This is the only entrypoint of reading a policy from a file or other source.
@@ -3188,10 +3260,19 @@ func unmarshalPolicy(b []byte) (*Policy, error) {
}
// Non-tag entries in grant.via surface as type errors on
// []Tag; match SaaS wording instead of Go's JSON diagnostic.
// []Tag; rephrase to the wire-compatible body.
if strings.Contains(string(serr.JSONPointer), "/via/") {
return nil, ErrGrantViaNotATag
}
// Non-ASCII tag-name failures surface from Tag.Validate
// at unmarshal time. Reshape to `tagOwners["tag:X"]: …`.
if errors.Is(serr.Err, ErrTagNameMustStartWithLetter) {
ptr := serr.JSONPointer
name := ptr.LastToken()
return nil, fmt.Errorf("tagOwners[%q]: %w", name, ErrTagNameMustStartWithLetter)
}
}
return nil, fmt.Errorf("parsing policy from bytes: %w", err)

View File

@@ -407,8 +407,50 @@ func TestUnmarshalPolicy(t *testing.T) {
},
}
`,
// wantErr: `username must contain @, got: "group:inner"`,
wantErr: `nested groups are not allowed: found "group:inner" inside "group:example"`,
wantErr: `groups["group:example"]: "group:inner": group members cannot be recursive`,
},
{
// SaaS reports the deepest non-leaf parent first: for
// the three-deep chain `a -> b -> c -> user`, the
// reported pair is `b -> c` rather than `a -> b`.
name: "group-nested-three-deep",
input: `
{
"groups": {
"group:a": ["group:b"],
"group:b": ["group:c"],
"group:c": ["thor@example.org"],
},
}
`,
wantErr: `groups["group:b"]: "group:c": group members cannot be recursive`,
},
{
// Cycle `a <-> b`: reported as `b -> a` so the body
// matches SaaS exactly.
name: "group-nested-cycle",
input: `
{
"groups": {
"group:a": ["group:b"],
"group:b": ["group:a"],
},
}
`,
wantErr: `groups["group:b"]: "group:a": group members cannot be recursive`,
},
{
// Self-cycle: the same group appears as its own
// member. Same wording.
name: "group-nested-self-cycle",
input: `
{
"groups": {
"group:a": ["group:a"],
},
}
`,
wantErr: `groups["group:a"]: "group:a": group members cannot be recursive`,
},
{
name: "invalid-addr",
@@ -660,7 +702,7 @@ func TestUnmarshalPolicy(t *testing.T) {
},
},
{
name: "ssh-with-tag-and-user",
name: "ssh-with-tag-and-wildcard-user",
input: `
{
"tagOwners": {
@@ -681,26 +723,7 @@ func TestUnmarshalPolicy(t *testing.T) {
]
}
`,
want: &Policy{
TagOwners: TagOwners{
Tag("tag:web"): Owners{new(Username("admin@example.com"))},
Tag("tag:server"): Owners{new(Username("admin@example.com"))},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{
tp("tag:web"),
},
Destinations: SSHDstAliases{
tp("tag:server"),
},
Users: []SSHUser{
SSHUser("*"),
},
},
},
},
wantErr: `user "*" is not valid`,
},
{
name: "ssh-with-check-period",
@@ -2006,8 +2029,11 @@ func TestUnmarshalPolicy(t *testing.T) {
`,
wantErr: "square brackets are only valid around IPv6 addresses",
},
// Non-canonical `localpart:` strings flow through as literal
// user names per SaaS behaviour — captured in
// ssh-malformed-user-localpart-{no-at,no-glob,no-domain}.
{
name: "ssh-localpart-invalid-no-at-sign",
name: "ssh-localpart-non-canonical-no-at-sign",
input: `
{
"tagOwners": {"tag:prod": ["admin@"]},
@@ -2019,10 +2045,22 @@ func TestUnmarshalPolicy(t *testing.T) {
}]
}
`,
wantErr: "invalid localpart format",
want: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{new(Username("admin@"))},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:prod")},
Users: []SSHUser{SSHUser("localpart:foo")},
},
},
},
},
{
name: "ssh-localpart-invalid-non-wildcard",
name: "ssh-localpart-non-canonical-non-wildcard",
input: `
{
"tagOwners": {"tag:prod": ["admin@"]},
@@ -2034,10 +2072,22 @@ func TestUnmarshalPolicy(t *testing.T) {
}]
}
`,
wantErr: "invalid localpart format",
want: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{new(Username("admin@"))},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:prod")},
Users: []SSHUser{SSHUser("localpart:alice@example.com")},
},
},
},
},
{
name: "ssh-localpart-invalid-empty-domain",
name: "ssh-localpart-non-canonical-empty-domain",
input: `
{
"tagOwners": {"tag:prod": ["admin@"]},
@@ -2049,7 +2099,19 @@ func TestUnmarshalPolicy(t *testing.T) {
}]
}
`,
wantErr: "invalid localpart format",
want: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{new(Username("admin@"))},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:prod")},
Users: []SSHUser{SSHUser("localpart:*@")},
},
},
},
},
// A test entry with neither accept nor deny asserts nothing
// and is silently accepted today. Tailscale rejects the policy.
@@ -4133,13 +4195,45 @@ func TestFlattenTagOwners(t *testing.T) {
wantErr: "",
},
{
name: "circular-reference",
// SaaS tolerates tag:a <-> tag:b cycles by dropping the
// cycle edge; both tags resolve to an empty owner set
// because neither chain reaches a non-tag owner.
name: "circular-reference-resolves-to-empty",
input: TagOwners{
Tag("tag:a"): Owners{new(Tag("tag:b"))},
Tag("tag:b"): Owners{new(Tag("tag:a"))},
},
want: nil,
wantErr: "circular reference detected: tag:a -> tag:b",
want: TagOwners{
Tag("tag:a"): nil,
Tag("tag:b"): nil,
},
wantErr: "",
},
{
// tag:a -> tag:a self-reference: the only owner is the
// cycle edge itself; result is empty.
name: "self-reference-resolves-to-empty",
input: TagOwners{
Tag("tag:a"): Owners{new(Tag("tag:a"))},
},
want: TagOwners{
Tag("tag:a"): nil,
},
wantErr: "",
},
{
// Cycle plus a sibling non-tag owner: the cycle edge
// drops out, the sibling owner survives.
name: "cycle-plus-sibling-keeps-sibling",
input: TagOwners{
Tag("tag:a"): Owners{new(Tag("tag:b")), new(Username("alice@example.com"))},
Tag("tag:b"): Owners{new(Tag("tag:a"))},
},
want: TagOwners{
Tag("tag:a"): Owners{new(Username("alice@example.com"))},
Tag("tag:b"): Owners{new(Username("alice@example.com"))},
},
wantErr: "",
},
{
name: "mixed-owners",
@@ -4198,7 +4292,9 @@ func TestFlattenTagOwners(t *testing.T) {
wantErr: "",
},
{
name: "tag-long-circular-chain",
// Long cycle: every tag eventually points back to itself.
// Each tag resolves to the empty owner set.
name: "tag-long-circular-chain-resolves-to-empty",
input: TagOwners{
Tag("tag:a"): Owners{new(Tag("tag:g"))},
Tag("tag:b"): Owners{new(Tag("tag:a"))},
@@ -4208,7 +4304,16 @@ func TestFlattenTagOwners(t *testing.T) {
Tag("tag:f"): Owners{new(Tag("tag:e"))},
Tag("tag:g"): Owners{new(Tag("tag:f"))},
},
wantErr: "circular reference detected: tag:a -> tag:b -> tag:c -> tag:d -> tag:e -> tag:f -> tag:g",
want: TagOwners{
Tag("tag:a"): nil,
Tag("tag:b"): nil,
Tag("tag:c"): nil,
Tag("tag:d"): nil,
Tag("tag:e"): nil,
Tag("tag:f"): nil,
Tag("tag:g"): nil,
},
wantErr: "",
},
{
name: "undefined-tag-reference",
@@ -4364,7 +4469,15 @@ func TestSSHCheckPeriodValidate(t *testing.T) {
period: SSHCheckPeriod{Always: true},
},
{
name: "1m minimum valid",
name: "zero duration is valid",
period: SSHCheckPeriod{Duration: 0},
},
{
name: "30s below previous minimum is valid (matches SaaS)",
period: SSHCheckPeriod{Duration: 30 * time.Second},
},
{
name: "1m valid",
period: SSHCheckPeriod{Duration: time.Minute},
},
{
@@ -4372,15 +4485,20 @@ func TestSSHCheckPeriodValidate(t *testing.T) {
period: SSHCheckPeriod{Duration: 168 * time.Hour},
},
{
name: "30s below minimum",
period: SSHCheckPeriod{Duration: 30 * time.Second},
wantErr: ErrSSHCheckPeriodBelowMin,
name: "168h0m1s above maximum",
period: SSHCheckPeriod{Duration: 168*time.Hour + time.Second},
wantErr: ErrSSHCheckPeriodAboveMax,
},
{
name: "169h above maximum",
period: SSHCheckPeriod{Duration: 169 * time.Hour},
wantErr: ErrSSHCheckPeriodAboveMax,
},
{
name: "negative duration rejected",
period: SSHCheckPeriod{Duration: -time.Minute},
wantErr: ErrSSHCheckPeriodNegative,
},
}
for _, tt := range tests {
@@ -4444,7 +4562,7 @@ func TestSSHCheckPeriodPolicyValidation(t *testing.T) {
wantErr: ErrSSHCheckPeriodOnNonCheck,
},
{
name: "check with 30s is invalid",
name: "check with 30s is valid (matches SaaS, no minimum)",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
@@ -4452,7 +4570,49 @@ func TestSSHCheckPeriodPolicyValidation(t *testing.T) {
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: 30 * time.Second},
},
wantErr: ErrSSHCheckPeriodBelowMin,
},
{
name: "check with 168h exactly is valid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: 168 * time.Hour},
},
},
{
name: "check with 168h0m1s just above max is invalid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: 168*time.Hour + time.Second},
},
wantErr: ErrSSHCheckPeriodAboveMax,
},
{
name: "check with 200h above max is invalid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: 200 * time.Hour},
},
wantErr: ErrSSHCheckPeriodAboveMax,
},
{
name: "check with negative duration is invalid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: -time.Minute},
},
wantErr: ErrSSHCheckPeriodNegative,
},
}
@@ -4472,6 +4632,406 @@ func TestSSHCheckPeriodPolicyValidation(t *testing.T) {
}
}
// TestSSHRuleSaaSValidation exercises the SaaS-aligned rejections
// added to match the API body strings exactly.
func TestSSHRuleSaaSValidation(t *testing.T) {
baseSSH := func(modify func(*SSH)) SSH {
ssh := SSH{
Action: SSHActionAccept,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
}
if modify != nil {
modify(&ssh)
}
return ssh
}
tests := []struct {
name string
ssh SSH
wantErr error
}{
{
name: "users empty rejected",
ssh: baseSSH(func(s *SSH) { s.Users = nil }),
wantErr: ErrSSHUsersMustBeSpecified,
},
{
name: "users empty array rejected",
ssh: baseSSH(func(s *SSH) { s.Users = SSHUsers{} }),
wantErr: ErrSSHUsersMustBeSpecified,
},
{
name: "user empty string rejected",
ssh: baseSSH(func(s *SSH) { s.Users = SSHUsers{""} }),
wantErr: ErrSSHUserInvalid,
},
{
name: "user wildcard rejected",
ssh: baseSSH(func(s *SSH) { s.Users = SSHUsers{"*"} }),
wantErr: ErrSSHUserInvalid,
},
{
name: "acceptEnv empty entry rejected",
ssh: baseSSH(func(s *SSH) { s.AcceptEnv = []string{"FOO", ""} }),
wantErr: ErrSSHAcceptEnvEmpty,
},
{
name: "action empty rejected",
ssh: baseSSH(func(s *SSH) { s.Action = "" }),
wantErr: ErrSSHActionMustBeSpecified,
},
{
name: "user autogroup non-nonroot accepted (literal)",
ssh: baseSSH(func(s *SSH) {
s.Users = SSHUsers{"autogroup:internet"}
}),
},
{
name: "user malformed localpart accepted (literal)",
ssh: baseSSH(func(s *SSH) {
s.Users = SSHUsers{"localpart:foo"}
}),
},
{
name: "acceptEnv double-glob accepted",
ssh: baseSSH(func(s *SSH) {
s.AcceptEnv = []string{"**"}
}),
},
{
// SaaS rejects hosts-table aliases on SSH dst with
// `invalid dst "srv"`. headscale validates the same
// regardless of whether the alias resolves to a
// single IP or a CIDR.
name: "host alias as SSH dst rejected",
ssh: baseSSH(func(s *SSH) {
s.Destinations = SSHDstAliases{hp("srv")}
}),
wantErr: ErrSSHDestinationHostAlias,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pol := &Policy{
Hosts: Hosts{Host("srv"): Prefix(mp("100.64.0.16/32"))},
SSHs: []SSH{tt.ssh},
}
err := pol.validate()
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
return
}
require.NoError(t, err)
})
}
}
// TestSSHActionInvalidUnmarshal verifies the SaaS-aligned wording for
// non-empty unknown actions surfaces at JSON parse time.
func TestSSHActionInvalidUnmarshal(t *testing.T) {
tests := []struct {
name string
input string
wantValue SSHAction
wantErr error
wantMsg string
}{
{
name: "exact match accept",
input: `"accept"`,
wantValue: SSHActionAccept,
},
{
name: "exact match check",
input: `"check"`,
wantValue: SSHActionCheck,
},
{
name: "whitespace trimmed to accept",
input: `" accept "`,
wantValue: SSHActionAccept,
},
{
name: "uppercase rejected",
input: `"ACCEPT"`,
wantErr: ErrSSHActionInvalid,
wantMsg: `"ACCEPT" is not a valid action`,
},
{
name: "mixedcase rejected",
input: `"Accept"`,
wantErr: ErrSSHActionInvalid,
wantMsg: `"Accept" is not a valid action`,
},
{
name: "whitespace trimmed then mixedcase rejected",
input: `" Accept"`,
wantErr: ErrSSHActionInvalid,
wantMsg: `"Accept" is not a valid action`,
},
{
name: "unknown action rejected",
input: `"deny"`,
wantErr: ErrSSHActionInvalid,
wantMsg: `"deny" is not a valid action`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var a SSHAction
err := json.Unmarshal([]byte(tt.input), &a)
if tt.wantErr != nil {
require.Error(t, err)
require.ErrorIs(t, err, tt.wantErr)
require.Contains(t, err.Error(), tt.wantMsg)
return
}
require.NoError(t, err)
require.Equal(t, tt.wantValue, a)
})
}
}
// TestSSHUserUnmarshalTrim verifies per-element whitespace trimming so
// that the compiled `sshUsers` map matches SaaS exactly. A
// whitespace-only entry collapses to "" and is left for the per-rule
// validate() pass to reject via ErrSSHUserInvalid.
func TestSSHUserUnmarshalTrim(t *testing.T) {
tests := []struct {
name string
input string
want SSHUser
}{
{
name: "leading whitespace trimmed",
input: `" root"`,
want: SSHUser("root"),
},
{
name: "trailing whitespace trimmed",
input: `"root "`,
want: SSHUser("root"),
},
{
name: "whitespace-only collapses to empty",
input: `" "`,
want: SSHUser(""),
},
{
name: "no trim needed",
input: `"root"`,
want: SSHUser("root"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var u SSHUser
err := json.Unmarshal([]byte(tt.input), &u)
require.NoError(t, err)
require.Equal(t, tt.want, u)
})
}
}
// TestSSHUserTrimEndToEnd verifies that a policy with `[" root"]`
// parses cleanly and that the policy validate() pass treats `[" "]`
// as the empty-user case (per-element trim happens at unmarshal time).
func TestSSHUserTrimEndToEnd(t *testing.T) {
t.Run("leading whitespace user accepted and trimmed", func(t *testing.T) {
policy := `
{
"tagOwners": {"tag:server": ["odin@example.com"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:server"],
"users": [" root"]
}]
}`
pol, err := unmarshalPolicy([]byte(policy))
require.NoError(t, err)
require.Len(t, pol.SSHs, 1)
require.Equal(t, SSHUsers{SSHUser("root")}, pol.SSHs[0].Users)
})
t.Run("whitespace-only user rejected as empty", func(t *testing.T) {
policy := `
{
"tagOwners": {"tag:server": ["odin@example.com"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:server"],
"users": [" "]
}]
}`
_, err := unmarshalPolicy([]byte(policy))
require.Error(t, err)
require.ErrorIs(t, err, ErrSSHUserInvalid)
require.Contains(t, err.Error(), `user "" is not valid`)
})
}
// TestAliasEncUnmarshalTrim verifies that src/dst entries get
// trimmed before alias dispatch so `"tag:server "` resolves to the
// same Tag alias SaaS uses and `" odin@example.com"` resolves to the
// same Username alias. Covers tag, group, user, and autogroup entries
// on both the leading- and trailing-whitespace edges.
func TestAliasEncUnmarshalTrim(t *testing.T) {
tests := []struct {
name string
input string
want Alias
}{
{
name: "tag trailing whitespace",
input: `"tag:server "`,
want: new(Tag("tag:server")),
},
{
name: "tag leading whitespace",
input: `" tag:server"`,
want: new(Tag("tag:server")),
},
{
name: "group leading whitespace",
input: `" group:admins"`,
want: new(Group("group:admins")),
},
{
name: "user trailing whitespace",
input: `"odin@example.com "`,
want: new(Username("odin@example.com")),
},
{
name: "autogroup trailing whitespace",
input: `"autogroup:member "`,
want: new(AutoGroup("autogroup:member")),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var a AliasEnc
err := json.Unmarshal([]byte(tt.input), &a)
require.NoError(t, err)
require.Equal(t, tt.want, a.Alias)
})
}
}
// TestTagValidateFirstCharLetter exercises the SaaS rule that the
// first character after `tag:` must be an ASCII letter. Digits,
// punctuation, and non-ASCII Unicode letters are rejected with the
// same body SaaS produces. Subsequent characters are unconstrained.
func TestTagValidateFirstCharLetter(t *testing.T) {
tests := []struct {
name string
tag Tag
wantErr error
}{
{
name: "ascii lowercase letter",
tag: Tag("tag:server"),
},
{
name: "ascii uppercase letter",
tag: Tag("tag:Server"),
},
{
name: "ascii letter then digit",
tag: Tag("tag:a1"),
},
{
name: "leading digit rejected",
tag: Tag("tag:1server"),
wantErr: ErrTagNameMustStartWithLetter,
},
{
name: "leading hyphen rejected",
tag: Tag("tag:-server"),
wantErr: ErrTagNameMustStartWithLetter,
},
{
name: "cyrillic letter rejected",
tag: Tag("tag:сервер"),
wantErr: ErrTagNameMustStartWithLetter,
},
{
name: "empty name rejected",
tag: Tag("tag:"),
wantErr: ErrTagNameMustStartWithLetter,
},
{
name: "missing prefix rejected",
tag: Tag("server"),
wantErr: ErrInvalidTagFormat,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.tag.Validate()
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
return
}
require.NoError(t, err)
})
}
}
// TestUnmarshalPolicyCyrillicTagOwner verifies the full SaaS body
// (`tagOwners["tag:сервер"]: …`) surfaces when a non-ASCII tag
// appears as a tagOwners key.
func TestUnmarshalPolicyCyrillicTagOwner(t *testing.T) {
policy := []byte(`{"tagOwners": {"tag:сервер": ["odin@example.com"]}}`)
_, err := unmarshalPolicy(policy)
require.Error(t, err)
require.ErrorIs(t, err, ErrTagNameMustStartWithLetter)
require.Contains(t, err.Error(),
`tagOwners["tag:сервер"]: tag names must start with a letter, after 'tag:'`)
}
// TestSSHCheckPeriodInvalidDuration verifies the SaaS body for the
// malformed-duration case (`time: invalid duration "abc"`).
func TestSSHCheckPeriodInvalidDuration(t *testing.T) {
var p SSHCheckPeriod
err := json.Unmarshal([]byte(`"abc"`), &p)
require.Error(t, err)
require.Contains(t, err.Error(), `time: invalid duration "abc"`)
}
// TestSSHCheckPeriodNegativeMessage verifies the SaaS body for the
// negative-duration case (`checkPeriod -1m0s must be a positive duration`).
func TestSSHCheckPeriodNegativeMessage(t *testing.T) {
p := SSHCheckPeriod{Duration: -time.Minute}
err := p.Validate()
require.Error(t, err)
require.ErrorIs(t, err, ErrSSHCheckPeriodNegative)
require.Contains(t, err.Error(), "checkPeriod -1m0s must be a positive duration")
}
func TestUnmarshalGrants(t *testing.T) {
tests := []struct {
name string