From 6a0a297c7f8669f291b6e12f8e05ffac9e202e6f Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 13 May 2026 14:11:48 +0000 Subject: [PATCH] policy/v2: validate sshTests at parse Adds SSHPolicyTest plus parse-time validation: empty src/dst, port/CIDR/autogroup-internet destinations, and tag references missing from tagOwners are rejected. Engine evaluation comes in a follow-up. --- hscontrol/policy/v2/test.go | 37 ++++- hscontrol/policy/v2/types.go | 104 ++++++++++++++ hscontrol/policy/v2/types_test.go | 222 ++++++++++++++++++++++++++++++ 3 files changed, 356 insertions(+), 7 deletions(-) diff --git a/hscontrol/policy/v2/test.go b/hscontrol/policy/v2/test.go index a93b7c6b..b342f9f3 100644 --- a/hscontrol/policy/v2/test.go +++ b/hscontrol/policy/v2/test.go @@ -24,14 +24,13 @@ import ( // The tests evaluate against the compiled global filter rules, which fold in // both `acls` and `grants`, so the `tests` block validates the whole policy. -// errPolicyTestsFailed wraps the rendered failure body so callers can -// type-assert when they need to react differently to test failures vs. parse -// errors. The Error() prefix is "test(s) failed", the same string Tailscale -// SaaS returns in the api_response_body.message — see -// hscontrol/policy/v2/testdata/policytest_results/. +// errPolicyTestsFailed and errSSHPolicyTestsFailed share the +// "test(s) failed" prefix but stay distinct so callers can use +// errors.Is to tell ACL-test and SSH-test failures apart. var ( - errPolicyTestsFailed = errors.New("test(s) failed") - errTestDestinationNoIP = errors.New("destination resolved to no IP addresses") + errPolicyTestsFailed = errors.New("test(s) failed") + errSSHPolicyTestsFailed = errors.New("test(s) failed") + errTestDestinationNoIP = errors.New("destination resolved to no IP addresses") ) // PolicyTest is one entry in the policy's `tests` block. @@ -53,6 +52,30 @@ type PolicyTest struct { Deny []string `json:"deny,omitempty"` } +// SSHPolicyTest is one entry in the policy's `sshTests` block. The +// accept/deny/check arrays carry usernames, not destinations — every +// listed user is asserted against every entry in Dst. +type SSHPolicyTest struct { + // Src is a single source alias (user, group, tag, host, or IP). + Src string `json:"src"` + + // Dst lists destinations the test exercises (tag, host, or SSH- + // compatible autogroup). Ports, CIDRs, and autogroup:internet are + // rejected at parse time. + Dst []string `json:"dst"` + + // Accept lists users that must reach every Dst via an accept- or + // check-action rule. + Accept []string `json:"accept,omitempty"` + + // Deny lists users that must NOT reach any Dst. + Deny []string `json:"deny,omitempty"` + + // Check lists users that must reach every Dst via a check-action + // rule specifically; an accept-action rule does not satisfy this. + Check []string `json:"check,omitempty"` +} + // PolicyTestResult is the outcome of a single PolicyTest. type PolicyTestResult struct { Src string `json:"src"` diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 7de62307..04b26015 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -153,6 +153,10 @@ var ( ErrTestDestinationMultiPort = errors.New("test destination port must be a single port") ErrTestDestinationCIDR = errors.New("test destination must be a single host, not a CIDR range") ErrAutogroupInternetTestDst = errors.New("autogroup:internet not valid as a test destination") + ErrSSHTestEmptySrc = errors.New("SSH tests entry must have a non-empty src") + ErrSSHTestEmptyDst = errors.New("SSH tests entry must have at least one dst") + ErrSSHTestDstUnknownTag = errors.New("SSH tests dst contains unknown tag") + ErrSSHTestDstDisallowedElement = errors.New("SSH tests dst contains disallowed element") ) type resolved struct { @@ -2084,6 +2088,7 @@ type Policy struct { AutoApprovers AutoApproverPolicy `json:"autoApprovers"` SSHs []SSH `json:"ssh,omitempty"` Tests []PolicyTest `json:"tests,omitempty"` + SSHTests []SSHPolicyTest `json:"sshTests,omitempty"` RandomizeClientPort bool `json:"randomizeClientPort,omitempty"` } @@ -2891,6 +2896,10 @@ func (p *Policy) validate() error { errs = append(errs, err) } + if err := validateSSHTests(p, p.SSHTests); err != nil { //nolint:noinlineerr + errs = append(errs, err) + } + if len(errs) > 0 { return multierr.New(errs...) } @@ -3389,3 +3398,98 @@ func validateTestDestination(pol *Policy, dst string) error { return nil } + +// validateSSHTests enforces parse-time shape on every sshTests entry: +// non-empty src, at least one dst, and each dst describing a single +// SSH-reachable host. Login-user assertions land with the engine so +// failures surface through the same errSSHPolicyTestsFailed wrapper. +func validateSSHTests(pol *Policy, tests []SSHPolicyTest) error { + var errs []error + + for i, t := range tests { + if t.Src == "" { + errs = append(errs, fmt.Errorf("sshTest %d: %w", i, ErrSSHTestEmptySrc)) + } + + if len(t.Dst) == 0 { + errs = append(errs, fmt.Errorf("sshTest %d: %w", i, ErrSSHTestEmptyDst)) + } + + for _, dst := range t.Dst { + err := validateSSHTestDestination(pol, dst) + if err != nil { + errs = append(errs, fmt.Errorf("sshTest %d: %w", i, err)) + } + } + } + + if len(errs) > 0 { + return fmt.Errorf("%w:\n%w", errSSHPolicyTestsFailed, multierr.New(errs...)) + } + + return nil +} + +// validateSSHTestDestination rejects sshTests dst shapes that cannot +// name a single SSH-reachable host: +// +// - `host:port` suffixes (parsed as an unknown tag), +// - multi-host CIDRs (raw `/N` or a hosts: entry resolving wider), +// - autogroup:internet (valid as ACL dst only). +// +// A bare IP literal (single-host /BitLen prefix) is accepted. Tag +// entries must exist in tagOwners. +func validateSSHTestDestination(pol *Policy, dst string) error { + alias, err := parseAlias(dst) + if err != nil { + return fmt.Errorf("%w %q", ErrSSHTestDstDisallowedElement, dst) + } + + switch a := alias.(type) { + case *AutoGroup: + // autogroup:internet is the only autogroup not valid here; + // member/tagged/self/nonroot pass to engine evaluation. + if *a == AutoGroupInternet { + return fmt.Errorf("%w %q", ErrSSHTestDstDisallowedElement, dst) + } + + case *Prefix: + // Bare IP parses to a *Prefix without slash; reject any + // explicit CIDR. + if strings.Contains(dst, "/") { + return fmt.Errorf("%w %q", ErrSSHTestDstDisallowedElement, dst) + } + + case *Tag: + // A tag must be declared in tagOwners. `tag:server:22` lands + // here too because isTag only checks the prefix, so the lookup + // misses and the colon-port suffix surfaces as unknown-tag. + if pol == nil { + return fmt.Errorf("%w %q", ErrSSHTestDstUnknownTag, string(*a)) + } + + err := pol.TagOwners.Contains(a) + if err != nil { + return fmt.Errorf("%w %q", ErrSSHTestDstUnknownTag, string(*a)) + } + + case *Host: + // A hosts: alias that resolves to multiple addresses is a CIDR + // in disguise. + if pol == nil { + return nil + } + + pref, ok := pol.Hosts[*a] + if !ok { + return nil + } + + p := netip.Prefix(pref) + if p.Bits() < p.Addr().BitLen() { + return fmt.Errorf("%w %q", ErrSSHTestDstDisallowedElement, dst) + } + } + + return nil +} diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 5c4652c5..8dd06464 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -5973,3 +5973,225 @@ func TestUnmarshalPolicyEmptyArrays(t *testing.T) { }) } } + +// TestUnmarshalPolicySSHTests covers the parse-time shape rules for the +// sshTests block. Positive rows confirm the SSHPolicyTest struct fields +// round-trip through JSON. Rejection rows pin each parse-time sentinel +// against a representative malformed input. SaaS evaluation-time failures +// (empty assertions, empty user strings) are deliberately accepted at +// parse — they share the "test(s) failed" body with true failures and +// land with the engine. +func TestUnmarshalPolicySSHTests(t *testing.T) { + cases := []struct { + name string + input string + wantErr error // sentinel for errors.Is; nil means parse must succeed + extraSentinels []error // additional sentinels reachable via errors.Is + check func(t *testing.T, pol *Policy) + }{ + { + name: "valid-minimal-shape", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "thor@example.org", "dst": ["tag:server"], "accept": ["root"]} + ] +} +`, + check: func(t *testing.T, pol *Policy) { + t.Helper() + require.Len(t, pol.SSHTests, 1) + got := pol.SSHTests[0] + require.Equal(t, "thor@example.org", got.Src) + require.Equal(t, []string{"tag:server"}, got.Dst) + require.Equal(t, []string{"root"}, got.Accept) + require.Empty(t, got.Deny) + require.Empty(t, got.Check) + }, + }, + { + name: "valid-all-three-action-arrays", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + { + "src": "thor@example.org", + "dst": ["tag:server"], + "accept": ["root"], + "deny": ["nobody"], + "check": ["alice"] + } + ] +} +`, + check: func(t *testing.T, pol *Policy) { + t.Helper() + require.Len(t, pol.SSHTests, 1) + got := pol.SSHTests[0] + require.Equal(t, []string{"root"}, got.Accept) + require.Equal(t, []string{"nobody"}, got.Deny) + require.Equal(t, []string{"alice"}, got.Check) + }, + }, + { + // Empty accept+deny+check is rejected by SaaS at evaluation, + // not at parse — the captured body is the same "test(s) failed" + // that true evaluation failures emit. The parse layer must let + // this through so the engine reports it consistently. + name: "valid-empty-arrays-engine-deferred", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "thor@example.org", "dst": ["tag:server"]} + ] +} +`, + check: func(t *testing.T, pol *Policy) { + t.Helper() + require.Len(t, pol.SSHTests, 1) + got := pol.SSHTests[0] + require.Empty(t, got.Accept) + require.Empty(t, got.Deny) + require.Empty(t, got.Check) + }, + }, + { + // `tag:server:22` parses as a Tag because isTag only checks + // the `tag:` prefix; the colon-port suffix is retained in the + // tag string and the tagOwners lookup misses. SaaS reports + // this as an unknown tag with the bad value quoted. + name: "dst-with-port", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "thor@example.org", "dst": ["tag:server:22"], "accept": ["root"]} + ] +} +`, + wantErr: ErrSSHTestDstUnknownTag, + }, + { + name: "dst-cidr", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "thor@example.org", "dst": ["10.0.0.0/8"], "accept": ["root"]} + ] +} +`, + wantErr: ErrSSHTestDstDisallowedElement, + }, + { + name: "dst-autogroup-internet", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "thor@example.org", "dst": ["autogroup:internet"], "accept": ["root"]} + ] +} +`, + wantErr: ErrSSHTestDstDisallowedElement, + }, + { + name: "dst-unknown-tag", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "thor@example.org", "dst": ["tag:not-in-tagOwners"], "accept": ["root"]} + ] +} +`, + wantErr: ErrSSHTestDstUnknownTag, + }, + { + name: "empty-src", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "", "dst": ["tag:server"], "accept": ["root"]} + ] +} +`, + wantErr: ErrSSHTestEmptySrc, + }, + { + name: "empty-dst", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "thor@example.org", "dst": [], "accept": ["root"]} + ] +} +`, + wantErr: ErrSSHTestEmptyDst, + }, + { + // Multiple shape failures in one entry must aggregate through + // multierr.New under errSSHPolicyTestsFailed so the surfaced + // body matches the SaaS body byte-for-byte and every + // individual sentinel remains reachable via errors.Is. + name: "multierr-wrap", + input: ` +{ + "tagOwners": {"tag:server": ["admin@example.org"]}, + "sshTests": [ + {"src": "", "dst": ["10.0.0.0/8", "autogroup:internet"], "accept": ["root"]} + ] +} +`, + wantErr: ErrSSHTestEmptySrc, + extraSentinels: []error{ErrSSHTestDstDisallowedElement}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + pol, err := unmarshalPolicy([]byte(tc.input)) + + if tc.wantErr == nil { + require.NoError(t, err) + require.NotNil(t, pol) + + if tc.check != nil { + tc.check(t, pol) + } + + return + } + + require.Error(t, err) + require.ErrorIs( + t, + err, tc.wantErr, + "want errors.Is(err, %v); got %v", + tc.wantErr, + err, + ) + require.Contains( + t, + err.Error(), "test(s) failed", + `want err to contain "test(s) failed"; got %q`, + err.Error(), + ) + + for _, sentinel := range tc.extraSentinels { + require.ErrorIs( + t, + err, sentinel, + "want errors.Is(err, %v); got %v", + sentinel, + err, + ) + } + }) + } +}