policy/v2: evaluate sshTests at write boundary

SetPolicy and policy check now compile per-dst SSH rules and replay each sshTests entry. The accept assertion treats check-action rules as reachable; the check assertion requires HoldAndDelegate on the matching rule. Boot reload warns and continues.
This commit is contained in:
Kristoffer Dalby
2026-05-13 14:17:04 +00:00
parent 6a0a297c7f
commit 013dea4f40
9 changed files with 39304 additions and 82 deletions

View File

@@ -208,6 +208,10 @@ func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.Node
log.Warn().Err(testErr).Msg("policy tests failed at boot; server starting anyway, fix the policy and reload")
}
if testErr := pm.RunSSHTests(); testErr != nil { //nolint:noinlineerr // boot path: warn-and-continue, not return
log.Warn().Err(testErr).Msg("policy sshTests failed at boot; server starting anyway, fix the policy and reload")
}
return &pm, nil
}
@@ -482,9 +486,16 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
// sandbox compiled from the new policy + current users/nodes; if
// they fail, return without mutating the live PolicyManager so the
// failed write does not knock the running config offline.
err = evaluateTests(pol, pm.users, pm.nodes)
if err != nil {
return false, err
//
// Aggregate ACL and SSH test failures via multierr so operators
// see both classes in a single response instead of having to
// fix-and-retry to discover the second one.
testErr := multierr.New(
evaluateTests(pol, pm.users, pm.nodes),
evaluateSSHTests(pol, pm.users, pm.nodes),
)
if testErr != nil {
return false, testErr
}
// Log policy metadata for debugging

View File

@@ -0,0 +1,664 @@
package v2
import (
"fmt"
"net/netip"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
)
// sshTests assertions evaluate on user-initiated writes; boot reload
// skips them so a stale reference does not block startup. Each entry
// names a src and one or more dst, and uses:
//
// - accept: every listed user reaches every dst via an accept- or
// check-action rule.
// - deny: no listed user reaches any dst.
// - check: every listed user reaches every dst via a check-action
// rule specifically (accept-only matches fail the assertion).
// SSHPolicyTestResult is the outcome of a single SSHPolicyTest.
type SSHPolicyTestResult struct {
Src string `json:"src"`
Passed bool `json:"passed"`
Errors []string `json:"errors,omitempty"`
AcceptOK map[string][]string `json:"accept_ok,omitempty"`
AcceptFail map[string][]string `json:"accept_fail,omitempty"`
DenyOK map[string][]string `json:"deny_ok,omitempty"`
DenyFail map[string][]string `json:"deny_fail,omitempty"`
CheckOK map[string][]string `json:"check_ok,omitempty"`
CheckFail map[string][]string `json:"check_fail,omitempty"`
}
// SSHPolicyTestResults aggregates one evaluation run.
type SSHPolicyTestResults struct {
AllPassed bool `json:"all_passed"`
Results []SSHPolicyTestResult `json:"results"`
}
// Errors renders the per-test failure breakdown joined by newlines.
func (r SSHPolicyTestResults) Errors() string {
if r.AllPassed {
return ""
}
var lines []string
for _, res := range r.Results {
if res.Passed {
continue
}
for _, e := range res.Errors {
lines = append(lines, fmt.Sprintf("%s: %s", res.Src, e))
}
for _, user := range sortedUsers(res.AcceptFail) {
for _, dst := range res.AcceptFail[user] {
lines = append(lines, fmt.Sprintf(
"%s/%s -> %s: expected ALLOWED, got DENIED",
res.Src, displayUser(user), dst,
))
}
}
for _, user := range sortedUsers(res.DenyFail) {
for _, dst := range res.DenyFail[user] {
lines = append(lines, fmt.Sprintf(
"%s/%s -> %s: expected DENIED, got ALLOWED",
res.Src, displayUser(user), dst,
))
}
}
for _, user := range sortedUsers(res.CheckFail) {
for _, dst := range res.CheckFail[user] {
lines = append(lines, fmt.Sprintf(
"%s/%s -> %s: expected ALLOWED via check, got %s",
res.Src, displayUser(user), dst,
checkFailReason(res, user, dst),
))
}
}
}
return strings.Join(lines, "\n")
}
func sortedUsers(m map[string][]string) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
slices.Sort(keys)
return keys
}
// displayUser shows an empty username as `""` rather than blank.
func displayUser(u string) string {
if u == "" {
return `""`
}
return u
}
// checkFailReason annotates a check-fail with whether the user reached
// the dst via an accept rule or did not reach at all.
func checkFailReason(res SSHPolicyTestResult, user, dst string) string {
if slices.Contains(res.AcceptOK[user], dst) {
return "ALLOWED via accept"
}
return "DENIED"
}
// RunSSHTests evaluates the live policy's sshTests block and wraps any
// failure in errSSHPolicyTestsFailed.
func (pm *PolicyManager) RunSSHTests() error {
if pm == nil || pm.pol == nil || len(pm.pol.SSHTests) == 0 {
return nil
}
pm.mu.Lock()
defer pm.mu.Unlock()
cache := make(map[types.NodeID]*tailcfg.SSHPolicy)
results := runSSHPolicyTests(pm.pol, pm.users, pm.nodes, cache)
if results.AllPassed {
return nil
}
return fmt.Errorf("%w:\n%s", errSSHPolicyTestsFailed, results.Errors())
}
// evaluateSSHTests runs the block against pol without mutating live state.
func evaluateSSHTests(
pol *Policy,
users []types.User,
nodes views.Slice[types.NodeView],
) error {
if pol == nil || len(pol.SSHTests) == 0 {
return nil
}
cache := make(map[types.NodeID]*tailcfg.SSHPolicy)
results := runSSHPolicyTests(pol, users, nodes, cache)
if results.AllPassed {
return nil
}
return fmt.Errorf("%w:\n%s", errSSHPolicyTestsFailed, results.Errors())
}
// runSSHPolicyTests evaluates every sshTests entry. The cache is keyed
// by dst NodeID so repeat destinations only compile once per pass.
func runSSHPolicyTests(
pol *Policy,
users []types.User,
nodes views.Slice[types.NodeView],
cache map[types.NodeID]*tailcfg.SSHPolicy,
) SSHPolicyTestResults {
results := SSHPolicyTestResults{
AllPassed: true,
Results: make([]SSHPolicyTestResult, 0, len(pol.SSHTests)),
}
for _, test := range pol.SSHTests {
res := runSSHPolicyTest(test, pol, users, nodes, cache)
if !res.Passed {
results.AllPassed = false
}
results.Results = append(results.Results, res)
}
return results
}
// runSSHPolicyTest evaluates one entry: resolve src → resolve dst →
// walk accept/deny/check arrays against each dst's compiled SSH policy.
func runSSHPolicyTest(
test SSHPolicyTest,
pol *Policy,
users []types.User,
nodes views.Slice[types.NodeView],
cache map[types.NodeID]*tailcfg.SSHPolicy,
) SSHPolicyTestResult {
srcLabel := ""
if test.Src != nil {
srcLabel = test.Src.String()
}
res := SSHPolicyTestResult{
Src: srcLabel,
Passed: true,
}
srcAddrs, srcUserID, err := resolveSSHTestSource(test.Src, pol, users, nodes)
if err != nil {
res.Passed = false
res.Errors = append(res.Errors,
fmt.Sprintf("failed to resolve source %q: %v", srcLabel, err))
return res
}
if len(srcAddrs) == 0 {
res.Passed = false
res.Errors = append(res.Errors,
fmt.Sprintf("source %q resolved to no IP addresses", srcLabel))
return res
}
// An entry with no assertion arrays would silently pass.
if len(test.Accept) == 0 && len(test.Deny) == 0 && len(test.Check) == 0 {
res.Passed = false
res.Errors = append(res.Errors,
"no accept, deny, or check assertions specified")
return res
}
dstNodes, emptyDsts, err := resolveSSHTestDestNodes(test.Dst, pol, users, nodes, srcUserID)
if err != nil {
res.Passed = false
res.Errors = append(res.Errors,
fmt.Sprintf("failed to resolve destinations: %v", err))
return res
}
// A dst resolving to zero nodes would silently pass.
for _, dst := range emptyDsts {
res.Passed = false
res.Errors = append(res.Errors,
fmt.Sprintf("dst alias %q resolved to no nodes", dst))
}
if len(dstNodes) == 0 {
return res
}
for _, user := range test.Accept {
evaluateAssertion(
pol, users, nodes, cache,
srcAddrs, dstNodes, user.String(),
assertAccept, &res,
)
}
for _, user := range test.Deny {
evaluateAssertion(
pol, users, nodes, cache,
srcAddrs, dstNodes, user.String(),
assertDeny, &res,
)
}
for _, user := range test.Check {
evaluateAssertion(
pol, users, nodes, cache,
srcAddrs, dstNodes, user.String(),
assertCheck, &res,
)
}
return res
}
type sshAssertion int
const (
assertAccept sshAssertion = iota
assertDeny
assertCheck
)
// evaluateAssertion walks every (srcAddr, dstNode) pair for one user
// and records the outcome. Empty username fails — SSH login users
// cannot be empty even when parse accepted it.
func evaluateAssertion(
pol *Policy,
users []types.User,
nodes views.Slice[types.NodeView],
cache map[types.NodeID]*tailcfg.SSHPolicy,
srcAddrs []netip.Addr,
dstNodes []types.NodeView,
user string,
kind sshAssertion,
res *SSHPolicyTestResult,
) {
dstLoop:
for _, dst := range dstNodes {
dstPol, err := compiledSSHPolicy(pol, users, nodes, cache, dst)
if err != nil {
res.Passed = false
res.Errors = append(res.Errors,
fmt.Sprintf("compiling SSH policy for %s: %v",
dst.Hostname(), err))
continue
}
dstLabel := dst.Hostname()
acceptHit := false
checkHit := false
for _, srcAddr := range srcAddrs {
a, c := reachability(dstPol, srcAddr, user)
if a {
acceptHit = true
}
if c {
checkHit = true
}
// All src IPs must agree; one counter-example fails
// the whole (user, dst) pair.
switch kind {
case assertAccept:
if !a {
res.Passed = false
res.AcceptFail = appendUserDst(res.AcceptFail, user, dstLabel)
continue dstLoop
}
case assertDeny:
if a {
res.Passed = false
res.DenyFail = appendUserDst(res.DenyFail, user, dstLabel)
continue dstLoop
}
case assertCheck:
if !c {
res.Passed = false
res.CheckFail = appendUserDst(res.CheckFail, user, dstLabel)
// Record whether the accept side passed so
// the rendered error can say "ALLOWED via
// accept" instead of "DENIED".
if a {
res.AcceptOK = appendUserDst(res.AcceptOK, user, dstLabel)
}
continue dstLoop
}
}
}
switch kind {
case assertAccept:
if acceptHit {
res.AcceptOK = appendUserDst(res.AcceptOK, user, dstLabel)
}
case assertDeny:
res.DenyOK = appendUserDst(res.DenyOK, user, dstLabel)
case assertCheck:
if checkHit {
res.CheckOK = appendUserDst(res.CheckOK, user, dstLabel)
}
}
}
}
// appendUserDst appends dst to m[user], allocating m on first use.
func appendUserDst(m map[string][]string, user, dst string) map[string][]string {
if m == nil {
m = make(map[string][]string)
}
m[user] = append(m[user], dst)
return m
}
// resolveSSHTestSource returns the src's principal addresses and, for
// user-shaped sources, the user ID (so autogroup:self can scope to it).
// Tag, host, and IP sources return userID 0.
func resolveSSHTestSource(
src Alias,
pol *Policy,
users []types.User,
nodes views.Slice[types.NodeView],
) ([]netip.Addr, uint, error) {
if src == nil {
return nil, 0, nil
}
addrs, err := src.Resolve(pol, users, nodes)
if err != nil {
return nil, 0, fmt.Errorf("resolving: %w", err)
}
if addrs == nil || addrs.Empty() {
return nil, 0, nil
}
out := make([]netip.Addr, 0)
for a := range addrs.Iter() {
out = append(out, a)
}
var userID uint
u, ok := src.(*Username)
if ok {
resolved, rErr := u.resolveUser(users)
if rErr == nil {
userID = resolved.ID
}
}
return out, userID, nil
}
// resolveSSHTestDestNodes maps each dst alias to its destination
// NodeViews. autogroup:self needs special handling: it cannot resolve
// without per-node context, so it walks the node set keyed on src's
// owning user. Other aliases resolve to an IPSet and match via InIPSet.
func resolveSSHTestDestNodes(
dsts SSHTestDestinations,
pol *Policy,
users []types.User,
nodes views.Slice[types.NodeView],
srcUserID uint,
) ([]types.NodeView, []string, error) {
seen := make(map[types.NodeID]struct{})
var (
out []types.NodeView
emptyDsts []string
)
for _, alias := range dsts {
dstLabel := alias.String()
matched := false
if ag, ok := alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
// autogroup:self resolves to non-tagged nodes owned by
// the same user as src; tagged/IP sources have no user.
if srcUserID == 0 {
emptyDsts = append(emptyDsts, dstLabel)
continue
}
for _, n := range nodes.All() {
if n.IsTagged() {
continue
}
if !n.User().Valid() {
continue
}
if n.User().ID() != srcUserID {
continue
}
matched = true
if _, dup := seen[n.ID()]; dup {
continue
}
seen[n.ID()] = struct{}{}
out = append(out, n)
}
if !matched {
emptyDsts = append(emptyDsts, dstLabel)
}
continue
}
ips, err := alias.Resolve(pol, users, nodes)
if err != nil {
return nil, nil, fmt.Errorf("resolving destination %q: %w", dstLabel, err)
}
if ips == nil || ips.Empty() {
emptyDsts = append(emptyDsts, dstLabel)
continue
}
set, err := prefixesToIPSet(ips.Prefixes())
if err != nil {
return nil, nil, fmt.Errorf("building IPSet for %q: %w", dstLabel, err)
}
for _, n := range nodes.All() {
if !n.InIPSet(set) {
continue
}
matched = true
if _, dup := seen[n.ID()]; dup {
continue
}
seen[n.ID()] = struct{}{}
out = append(out, n)
}
if !matched {
emptyDsts = append(emptyDsts, dstLabel)
}
}
return out, emptyDsts, nil
}
// prefixesToIPSet builds the IPSet that InIPSet expects on the node
// side.
func prefixesToIPSet(prefixes []netip.Prefix) (*netipx.IPSet, error) {
var b netipx.IPSetBuilder
for _, p := range prefixes {
b.AddPrefix(p)
}
return b.IPSet()
}
// compiledSSHPolicy returns the per-node compiled SSH policy, caching
// on miss. baseURL is empty because reachability only checks for the
// presence of HoldAndDelegate, not its value.
func compiledSSHPolicy(
pol *Policy,
users []types.User,
nodes views.Slice[types.NodeView],
cache map[types.NodeID]*tailcfg.SSHPolicy,
node types.NodeView,
) (*tailcfg.SSHPolicy, error) {
if sshPol, ok := cache[node.ID()]; ok {
return sshPol, nil
}
sshPol, err := pol.compileSSHPolicy("", users, node, nodes)
if err != nil {
return nil, err
}
cache[node.ID()] = sshPol
return sshPol, nil
}
// reachability reports whether srcAddr can log in as user via:
//
// - any matching rule (acceptHit, satisfies accept assertions)
// - a check-action rule (checkHit, satisfies check assertions)
func reachability(
dstPolicy *tailcfg.SSHPolicy,
srcAddr netip.Addr,
user string,
) (bool, bool) {
if dstPolicy == nil {
return false, false
}
var acceptHit, checkHit bool
for _, rule := range dstPolicy.Rules {
if !principalContainsAddr(rule.Principals, srcAddr) {
continue
}
if !sshUserMapAllows(rule.SSHUsers, user) {
continue
}
if rule.Action == nil {
continue
}
acceptHit = true
if rule.Action.HoldAndDelegate != "" {
checkHit = true
}
// Early-out only when both bits are set: a rule satisfying
// accept does not always satisfy check.
if acceptHit && checkHit {
return acceptHit, checkHit
}
}
return acceptHit, checkHit
}
// principalContainsAddr reports whether any principal's NodeIP matches
// srcAddr exactly (the SSH compiler emits one principal per source IP).
func principalContainsAddr(
principals []*tailcfg.SSHPrincipal,
srcAddr netip.Addr,
) bool {
for _, p := range principals {
if p == nil {
continue
}
if p.NodeIP == "" {
continue
}
addr, err := netip.ParseAddr(p.NodeIP)
if err != nil {
continue
}
if addr == srcAddr {
return true
}
}
return false
}
// sshUserMapAllows reports whether SSHUsers permits user. The SSHUsers
// wire shape (see filter.go compileSSHPolicy):
//
// - SSHUsers["root"] == "root" allows root; == "" disallows it.
// - SSHUsers["*"] == "=" is the wildcard fallback for non-root users
// (set when the rule lists autogroup:nonroot).
// - SSHUsers[<literal>] == <literal> for every named user.
func sshUserMapAllows(m map[string]string, user string) bool {
if user == "" {
return false
}
if v, ok := m[user]; ok {
return v != ""
}
if user == "root" {
return false
}
// Wildcard fallback for non-root users.
if v, ok := m["*"]; ok {
return v != ""
}
return false
}

File diff suppressed because it is too large Load Diff

View File

@@ -46,61 +46,6 @@ func setupSSHDataCompatUsers() types.Users {
}
}
// setupSSHDataCompatNodes returns the test nodes for SSH data-driven
// compatibility tests. Node GivenNames match the anonymized pokémon names:
// - bulbasaur (owned by odin)
// - ivysaur (owned by thor)
// - venusaur (owned by freya)
// - beedrill (tag:server)
// - kakuna (tag:prod)
func setupSSHDataCompatNodes(users types.Users) types.Nodes {
return types.Nodes{
&types.Node{
ID: 1,
GivenName: "bulbasaur",
User: &users[0],
UserID: &users[0].ID,
IPv4: ptrAddr("100.90.199.68"),
IPv6: ptrAddr("fd7a:115c:a1e0::2d01:c747"),
Hostinfo: &tailcfg.Hostinfo{},
},
&types.Node{
ID: 2,
GivenName: "ivysaur",
User: &users[1],
UserID: &users[1].ID,
IPv4: ptrAddr("100.110.121.96"),
IPv6: ptrAddr("fd7a:115c:a1e0::1737:7960"),
Hostinfo: &tailcfg.Hostinfo{},
},
&types.Node{
ID: 3,
GivenName: "venusaur",
User: &users[2],
UserID: &users[2].ID,
IPv4: ptrAddr("100.103.90.82"),
IPv6: ptrAddr("fd7a:115c:a1e0::9e37:5a52"),
Hostinfo: &tailcfg.Hostinfo{},
},
&types.Node{
ID: 4,
GivenName: "beedrill",
IPv4: ptrAddr("100.108.74.26"),
IPv6: ptrAddr("fd7a:115c:a1e0::b901:4a87"),
Tags: []string{"tag:server"},
Hostinfo: &tailcfg.Hostinfo{},
},
&types.Node{
ID: 5,
GivenName: "kakuna",
IPv4: ptrAddr("100.103.8.15"),
IPv6: ptrAddr("fd7a:115c:a1e0::5b37:80f"),
Tags: []string{"tag:prod"},
Hostinfo: &tailcfg.Hostinfo{},
},
}
}
// loadSSHTestFile loads and parses a single SSH capture HuJSON file.
func loadSSHTestFile(t *testing.T, path string) *testcapture.Capture {
t.Helper()

View File

@@ -7,6 +7,7 @@ import (
"slices"
"strings"
"github.com/go-json-experiment/json"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
@@ -57,23 +58,80 @@ type PolicyTest struct {
// listed user is asserted against every entry in Dst.
type SSHPolicyTest struct {
// Src is a single source alias (user, group, tag, host, or IP).
Src string `json:"src"`
Src Alias `json:"src"`
// Dst lists destinations the test exercises (tag, host, or SSH-
// compatible autogroup). Ports, CIDRs, and autogroup:internet are
// rejected at parse time.
Dst []string `json:"dst"`
Dst SSHTestDestinations `json:"dst"`
// Accept lists users that must reach every Dst via an accept- or
// check-action rule.
Accept []string `json:"accept,omitempty"`
Accept []SSHUser `json:"accept,omitempty"`
// Deny lists users that must NOT reach any Dst.
Deny []string `json:"deny,omitempty"`
Deny []SSHUser `json:"deny,omitempty"`
// Check lists users that must reach every Dst via a check-action
// rule specifically; an accept-action rule does not satisfy this.
Check []string `json:"check,omitempty"`
Check []SSHUser `json:"check,omitempty"`
}
// SSHTestDestinations is the typed list of destination aliases an
// sshTests entry targets. validateSSHTestDestination enforces the
// SSH-specific shape rules (no :port, no CIDR, no autogroup:internet,
// known tag).
type SSHTestDestinations []Alias
func (d *SSHTestDestinations) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
if err != nil {
return err
}
*d = make([]Alias, len(aliases))
for i, a := range aliases {
(*d)[i] = a.Alias
}
return nil
}
// UnmarshalJSON parses each typed field. An empty src lands as a nil
// Alias so validation surfaces ErrSSHTestEmptySrc rather than a parser
// failure.
func (t *SSHPolicyTest) UnmarshalJSON(b []byte) error {
var raw struct {
Src string `json:"src"`
Dst SSHTestDestinations `json:"dst"`
Accept []SSHUser `json:"accept,omitempty"`
Deny []SSHUser `json:"deny,omitempty"`
Check []SSHUser `json:"check,omitempty"`
}
err := json.Unmarshal(b, &raw, policyJSONOpts...)
if err != nil {
return err
}
trimmedSrc := strings.TrimSpace(raw.Src)
if trimmedSrc != "" {
alias, parseErr := parseAlias(trimmedSrc)
if parseErr != nil {
return parseErr
}
t.Src = alias
}
t.Dst = raw.Dst
t.Accept = raw.Accept
t.Deny = raw.Deny
t.Check = raw.Check
return nil
}
// PolicyTestResult is the outcome of a single PolicyTest.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -863,6 +863,12 @@ type Alias interface {
Validate() error
UnmarshalJSON(b []byte) error
// String renders the alias back to its policy-file form. Implementations
// are expected to return a value that round-trips through parseAlias for
// any alias the parser accepted, so callers can use it as a stable
// identity in rendered errors and logs.
String() string
// Resolve resolves the Alias to an IPSet. The IPSet will contain all the IP
// addresses that the Alias represents within Headscale. It is the product
// of the Alias and the Policy, Users and Nodes.
@@ -2927,10 +2933,6 @@ func (p *SSHCheckPeriod) UnmarshalJSON(b []byte) error {
return nil
}
// time.ParseDuration produces error strings like
// `time: invalid duration "abc"` which match SaaS body wording
// exactly; model.ParseDuration wraps the same parse with custom
// phrasing and would diverge.
d, err := time.ParseDuration(str)
if err != nil {
return err
@@ -3407,7 +3409,7 @@ func validateSSHTests(pol *Policy, tests []SSHPolicyTest) error {
var errs []error
for i, t := range tests {
if t.Src == "" {
if t.Src == nil {
errs = append(errs, fmt.Errorf("sshTest %d: %w", i, ErrSSHTestEmptySrc))
}
@@ -3439,11 +3441,8 @@ func validateSSHTests(pol *Policy, tests []SSHPolicyTest) error {
//
// A bare IP literal (single-host /BitLen prefix) is accepted. Tag
// entries must exist in tagOwners.
func validateSSHTestDestination(pol *Policy, dst string) error {
alias, err := parseAlias(dst)
if err != nil {
return fmt.Errorf("%w %q", ErrSSHTestDstDisallowedElement, dst)
}
func validateSSHTestDestination(pol *Policy, alias Alias) error {
dst := alias.String()
switch a := alias.(type) {
case *AutoGroup:
@@ -3454,9 +3453,10 @@ func validateSSHTestDestination(pol *Policy, dst string) error {
}
case *Prefix:
// Bare IP parses to a *Prefix without slash; reject any
// explicit CIDR.
if strings.Contains(dst, "/") {
// A bare IP parses as `/BitLen` and is a valid single-host dst;
// any narrower CIDR is a multi-host range and is rejected.
p := netip.Prefix(*a)
if p.Bits() < p.Addr().BitLen() {
return fmt.Errorf("%w %q", ErrSSHTestDstDisallowedElement, dst)
}

View File

@@ -6003,9 +6003,12 @@ func TestUnmarshalPolicySSHTests(t *testing.T) {
t.Helper()
require.Len(t, pol.SSHTests, 1)
got := pol.SSHTests[0]
require.Equal(t, "thor@example.org", got.Src)
require.Equal(t, []string{"tag:server"}, got.Dst)
require.Equal(t, []string{"root"}, got.Accept)
require.IsType(t, (*Username)(nil), got.Src)
require.Equal(t, "thor@example.org", got.Src.String())
require.Len(t, got.Dst, 1)
require.IsType(t, (*Tag)(nil), got.Dst[0])
require.Equal(t, "tag:server", got.Dst[0].String())
require.Equal(t, []SSHUser{"root"}, got.Accept)
require.Empty(t, got.Deny)
require.Empty(t, got.Check)
},
@@ -6030,9 +6033,9 @@ func TestUnmarshalPolicySSHTests(t *testing.T) {
t.Helper()
require.Len(t, pol.SSHTests, 1)
got := pol.SSHTests[0]
require.Equal(t, []string{"root"}, got.Accept)
require.Equal(t, []string{"nobody"}, got.Deny)
require.Equal(t, []string{"alice"}, got.Check)
require.Equal(t, []SSHUser{"root"}, got.Accept)
require.Equal(t, []SSHUser{"nobody"}, got.Deny)
require.Equal(t, []SSHUser{"alice"}, got.Check) //nolint:goconst
},
},
{
@@ -6086,6 +6089,52 @@ func TestUnmarshalPolicySSHTests(t *testing.T) {
`,
wantErr: ErrSSHTestDstDisallowedElement,
},
{
// SaaS accepts a bare IPv4 literal as a host address. The
// Prefix parser turns it into a /32 so validateSSHTestDestination
// must match Bits() against Addr().BitLen() rather than reject
// the whole *Prefix branch.
name: "dst-bare-ipv4-accepted",
input: `
{
"tagOwners": {"tag:server": ["admin@example.org"]},
"sshTests": [
{"src": "thor@example.org", "dst": ["100.64.0.16"], "accept": ["root"]}
]
}
`,
check: func(t *testing.T, pol *Policy) {
t.Helper()
require.Len(t, pol.SSHTests, 1)
got := pol.SSHTests[0]
require.Len(t, got.Dst, 1)
pref, ok := got.Dst[0].(*Prefix)
require.True(t, ok, "want *Prefix, got %T", got.Dst[0])
require.Equal(t, "100.64.0.16/32", pref.String())
},
},
{
// IPv6 mirror of the IPv4 case: bare `fd7a::10` parses to
// /128 and must pass the parse-time shape check.
name: "dst-bare-ipv6-accepted",
input: `
{
"tagOwners": {"tag:server": ["admin@example.org"]},
"sshTests": [
{"src": "thor@example.org", "dst": ["fd7a:115c:a1e0::10"], "accept": ["root"]}
]
}
`,
check: func(t *testing.T, pol *Policy) {
t.Helper()
require.Len(t, pol.SSHTests, 1)
got := pol.SSHTests[0]
require.Len(t, got.Dst, 1)
pref, ok := got.Dst[0].(*Prefix)
require.True(t, ok, "want *Prefix, got %T", got.Dst[0])
require.Equal(t, "fd7a:115c:a1e0::10/128", pref.String())
},
},
{
name: "dst-autogroup-internet",
input: `