policy/v2: trim whitespace and reject negative checkPeriod

SaaS trims leading whitespace in action and per-user entries before
matching, so headscale does too. Reject negative checkPeriod with
"must be a positive duration" matching SaaS body. The 168h upper
bound is inclusive.
This commit is contained in:
Kristoffer Dalby
2026-05-13 08:17:11 +00:00
parent 2180380fc1
commit f8aa6c46ef
12 changed files with 201130 additions and 7 deletions
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+31 -2
View File
@@ -46,6 +46,7 @@ var (
ErrSSHTagSourceToAutogroupMember = errors.New("tags in SSH source cannot access autogroup:member (user-owned devices)")
ErrSSHWildcardDestination = errors.New("wildcard (*) is not supported as SSH destination")
ErrSSHCheckPeriodAboveMax = errors.New("is above the max (168h)")
ErrSSHCheckPeriodNegative = errors.New("must be a positive duration")
ErrSSHCheckPeriodOnNonCheck = errors.New("checkPeriod is only valid with action \"check\"")
ErrInvalidLocalpart = errors.New("invalid localpart format, must be localpart:*@<domain>")
ErrSSHUsersMustBeSpecified = errors.New("users must be specified")
@@ -1657,8 +1658,12 @@ func (a *SSHAction) String() string {
// Empty strings are accepted at parse time; the per-rule validate()
// pass surfaces them with `action must be specified` to match SaaS.
// Non-empty unknown values fail here with `"foo" is not a valid action`.
//
// SaaS trims surrounding whitespace before comparing, then complains
// about the trimmed content; the resulting error quotes the trimmed
// value (e.g. `" Accept "` → `"Accept" is not a valid action`).
func (a *SSHAction) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
str := strings.TrimSpace(strings.Trim(string(b), `"`))
switch str {
case "":
*a = SSHAction("")
@@ -2919,12 +2924,19 @@ func (p SSHCheckPeriod) MarshalJSON() ([]byte, error) {
}
// Validate checks that the SSHCheckPeriod is within allowed bounds.
// SaaS imposes no minimum; the only ceiling is 168h.
// SaaS rejects negative durations with `must be a positive duration`
// and anything above 168h with `is above the max (168h)`; the 168h
// upper bound is inclusive.
func (p *SSHCheckPeriod) Validate() error {
if p.Always {
return nil
}
if p.Duration < 0 {
// SaaS body: `checkPeriod -1m0s must be a positive duration`.
return fmt.Errorf("checkPeriod %s %w", p.Duration, ErrSSHCheckPeriodNegative)
}
if p.Duration > SSHCheckPeriodMax {
// SaaS body: `checkPeriod 200h0m0s is above the max (168h)`.
return fmt.Errorf("checkPeriod %s %w", p.Duration, ErrSSHCheckPeriodAboveMax)
@@ -3195,6 +3207,23 @@ func (u SSHUser) MarshalJSON() ([]byte, error) {
return json.Marshal(string(u))
}
// UnmarshalJSON trims surrounding whitespace per element so a policy
// like `"users": [" root"]` stores `"root"` and compiles to the same
// `sshUsers: {"root": "root"}` map SaaS produces. A whitespace-only
// entry like `[" "]` collapses to `""` and falls through to the
// per-rule validate() pass, which surfaces the SaaS-aligned
// `user "" is not valid`.
func (u *SSHUser) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil { //nolint:noinlineerr
return err
}
*u = SSHUser(strings.TrimSpace(s))
return nil
}
// unmarshalPolicy takes a byte slice and unmarshals it into a Policy struct.
// In addition to unmarshalling, it will also validate the policy.
// This is the only entrypoint of reading a policy from a file or other source.
+199 -5
View File
@@ -4399,11 +4399,21 @@ func TestSSHCheckPeriodValidate(t *testing.T) {
name: "168h maximum valid",
period: SSHCheckPeriod{Duration: 168 * time.Hour},
},
{
name: "168h0m1s above maximum",
period: SSHCheckPeriod{Duration: 168*time.Hour + time.Second},
wantErr: ErrSSHCheckPeriodAboveMax,
},
{
name: "169h above maximum",
period: SSHCheckPeriod{Duration: 169 * time.Hour},
wantErr: ErrSSHCheckPeriodAboveMax,
},
{
name: "negative duration rejected",
period: SSHCheckPeriod{Duration: -time.Minute},
wantErr: ErrSSHCheckPeriodNegative,
},
}
for _, tt := range tests {
@@ -4476,6 +4486,27 @@ func TestSSHCheckPeriodPolicyValidation(t *testing.T) {
CheckPeriod: &SSHCheckPeriod{Duration: 30 * time.Second},
},
},
{
name: "check with 168h exactly is valid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: 168 * time.Hour},
},
},
{
name: "check with 168h0m1s just above max is invalid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: 168*time.Hour + time.Second},
},
wantErr: ErrSSHCheckPeriodAboveMax,
},
{
name: "check with 200h above max is invalid",
ssh: SSH{
@@ -4487,6 +4518,17 @@ func TestSSHCheckPeriodPolicyValidation(t *testing.T) {
},
wantErr: ErrSSHCheckPeriodAboveMax,
},
{
name: "check with negative duration is invalid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: -time.Minute},
},
wantErr: ErrSSHCheckPeriodNegative,
},
}
for _, tt := range tests {
@@ -4596,12 +4638,153 @@ func TestSSHRuleSaaSValidation(t *testing.T) {
// TestSSHActionInvalidUnmarshal verifies the SaaS-aligned wording for
// non-empty unknown actions surfaces at JSON parse time.
func TestSSHActionInvalidUnmarshal(t *testing.T) {
var a SSHAction
tests := []struct {
name string
input string
wantValue SSHAction
wantErr error
wantMsg string
}{
{
name: "exact match accept",
input: `"accept"`,
wantValue: SSHActionAccept,
},
{
name: "exact match check",
input: `"check"`,
wantValue: SSHActionCheck,
},
{
name: "whitespace trimmed to accept",
input: `" accept "`,
wantValue: SSHActionAccept,
},
{
name: "uppercase rejected",
input: `"ACCEPT"`,
wantErr: ErrSSHActionInvalid,
wantMsg: `"ACCEPT" is not a valid action`,
},
{
name: "mixedcase rejected",
input: `"Accept"`,
wantErr: ErrSSHActionInvalid,
wantMsg: `"Accept" is not a valid action`,
},
{
name: "whitespace trimmed then mixedcase rejected",
input: `" Accept"`,
wantErr: ErrSSHActionInvalid,
wantMsg: `"Accept" is not a valid action`,
},
{
name: "unknown action rejected",
input: `"deny"`,
wantErr: ErrSSHActionInvalid,
wantMsg: `"deny" is not a valid action`,
},
}
err := json.Unmarshal([]byte(`"deny"`), &a)
require.Error(t, err)
require.ErrorIs(t, err, ErrSSHActionInvalid)
require.Contains(t, err.Error(), `"deny" is not a valid action`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var a SSHAction
err := json.Unmarshal([]byte(tt.input), &a)
if tt.wantErr != nil {
require.Error(t, err)
require.ErrorIs(t, err, tt.wantErr)
require.Contains(t, err.Error(), tt.wantMsg)
return
}
require.NoError(t, err)
require.Equal(t, tt.wantValue, a)
})
}
}
// TestSSHUserUnmarshalTrim verifies per-element whitespace trimming so
// that the compiled `sshUsers` map matches SaaS exactly. A
// whitespace-only entry collapses to "" and is left for the per-rule
// validate() pass to reject via ErrSSHUserInvalid.
func TestSSHUserUnmarshalTrim(t *testing.T) {
tests := []struct {
name string
input string
want SSHUser
}{
{
name: "leading whitespace trimmed",
input: `" root"`,
want: SSHUser("root"),
},
{
name: "trailing whitespace trimmed",
input: `"root "`,
want: SSHUser("root"),
},
{
name: "whitespace-only collapses to empty",
input: `" "`,
want: SSHUser(""),
},
{
name: "no trim needed",
input: `"root"`,
want: SSHUser("root"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var u SSHUser
err := json.Unmarshal([]byte(tt.input), &u)
require.NoError(t, err)
require.Equal(t, tt.want, u)
})
}
}
// TestSSHUserTrimEndToEnd verifies that a policy with `[" root"]`
// parses cleanly and that the policy validate() pass treats `[" "]`
// as the empty-user case (per-element trim happens at unmarshal time).
func TestSSHUserTrimEndToEnd(t *testing.T) {
t.Run("leading whitespace user accepted and trimmed", func(t *testing.T) {
policy := `
{
"tagOwners": {"tag:server": ["odin@example.com"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:server"],
"users": [" root"]
}]
}`
pol, err := unmarshalPolicy([]byte(policy))
require.NoError(t, err)
require.Len(t, pol.SSHs, 1)
require.Equal(t, SSHUsers{SSHUser("root")}, pol.SSHs[0].Users)
})
t.Run("whitespace-only user rejected as empty", func(t *testing.T) {
policy := `
{
"tagOwners": {"tag:server": ["odin@example.com"]},
"ssh": [{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:server"],
"users": [" "]
}]
}`
_, err := unmarshalPolicy([]byte(policy))
require.Error(t, err)
require.ErrorIs(t, err, ErrSSHUserInvalid)
require.Contains(t, err.Error(), `user "" is not valid`)
})
}
// TestSSHCheckPeriodInvalidDuration verifies the SaaS body for the
@@ -4614,6 +4797,17 @@ func TestSSHCheckPeriodInvalidDuration(t *testing.T) {
require.Contains(t, err.Error(), `time: invalid duration "abc"`)
}
// TestSSHCheckPeriodNegativeMessage verifies the SaaS body for the
// negative-duration case (`checkPeriod -1m0s must be a positive duration`).
func TestSSHCheckPeriodNegativeMessage(t *testing.T) {
p := SSHCheckPeriod{Duration: -time.Minute}
err := p.Validate()
require.Error(t, err)
require.ErrorIs(t, err, ErrSSHCheckPeriodNegative)
require.Contains(t, err.Error(), "checkPeriod -1m0s must be a positive duration")
}
func TestUnmarshalGrants(t *testing.T) {
tests := []struct {
name string