policy/v2: evaluate the tests block on user-initiated writes

v2 silently dropped policy.tests, so a policy that contradicted its
own assertions still applied. Resolve src/dst via the existing Alias
machinery, walk the compiled global filter rules (acls and grants
both contribute), and run on every user-write boundary: SetPolicy,
the file watcher, and `headscale policy check`. A failing test
rejects the write before it mutates live state.

Boot-time reload never rejects a policy because of its tests (a failure
is only logged as a warning); an already-stored policy that references
a deleted user shouldn't lock the server out.

`headscale policy check` is a thin frontend for the new CheckPolicy
gRPC method. The server-side handler builds a fresh PolicyManager
from the request bytes and the state's live users/nodes, runs
SetPolicy on the sandbox so the tests block executes, and returns
the result through gRPC status. No persistence, no policy_mode
coupling. --bypass-grpc-and-access-database-directly opens the DB
directly when the server is not running.

cmd/headscale/cli/root.go no longer special-cases `policy check` in
init() (the early return from PR #2580 broke --config registration
and viper priming for --bypass).

integration/cli_policy_test.go covers policy_mode={file,database} x
fixture={acl-only, acl+passing-tests, acl+failing-tests} x
bypass={false,true} = 12 rows.

Updates #1803

Co-authored-by: Janis Jansons <janhouse@gmail.com>
This commit is contained in:
Kristoffer Dalby
2026-04-29 14:27:12 +00:00
parent 56146de377
commit b29ae25356
9 changed files with 1062 additions and 19 deletions

View File

@@ -202,6 +202,7 @@ jobs:
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndReloginSameUser
- TestAuthWebFlowLogoutAndReloginNewUser
- TestPolicyCheckCommand
- TestUserCommand
- TestPreAuthKeyCommand
- TestPreAuthKeyCommandWithoutExpiry

View File

@@ -48,7 +48,7 @@ func init() {
policyCmd.AddCommand(setPolicy)
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
checkPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running. Required to validate that user@ tokens resolve against the user database; without it, the check is syntax-only.")
checkPolicy.Flags().BoolP(bypassFlag, "", false, "Open the database directly (no gRPC, no running server) to validate user@ token references and to evaluate the policy's tests block. Required when those checks are needed.")
mustMarkRequired(checkPolicy, "file")
policyCmd.AddCommand(checkPolicy)
}
@@ -171,6 +171,11 @@ var setPolicy = &cobra.Command{
var checkPolicy = &cobra.Command{
Use: "check",
Short: "Check the Policy file for errors",
Long: `
Check validates the policy against the server's live users and nodes,
running any "tests" block. By default the command is a thin frontend
for a gRPC call to a running headscale; pass --` + bypassFlag + ` to
open the database directly when headscale is not running.`,
RunE: func(cmd *cobra.Command, args []string) error {
policyPath, _ := cmd.Flags().GetString("file")
@@ -179,8 +184,6 @@ var checkPolicy = &cobra.Command{
return fmt.Errorf("reading policy file: %w", err)
}
var users []types.User
if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
if !confirmAction(cmd, "DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?") {
return errAborted
@@ -192,23 +195,49 @@ var checkPolicy = &cobra.Command{
}
defer d.Close()
users, err = d.ListUsers()
users, err := d.ListUsers()
if err != nil {
return fmt.Errorf("loading users for policy validation: %w", err)
return fmt.Errorf("loading users: %w", err)
}
}
_, err = policy.NewPolicyManager(policyBytes, users, views.Slice[types.NodeView]{})
if err != nil {
return fmt.Errorf("parsing policy file: %w", err)
}
nodes, err := d.ListNodes()
if err != nil {
return fmt.Errorf("loading nodes: %w", err)
}
// NewPolicyManager validates structure and user references
// but never fails on the tests block (boot path: a failing
// test is only logged as a warning). SetPolicy is the
// user-write boundary and is what rejects a policy whose
// tests fail.
pm, err := policy.NewPolicyManager(policyBytes, users, nodes.ViewSlice())
if err != nil {
return fmt.Errorf("parsing policy file: %w", err)
}
_, err = pm.SetPolicy(policyBytes)
if err != nil {
return err
}
if users == nil {
fmt.Println("Policy syntax is valid (run with --" + bypassFlag + " to also validate user references against the database)")
} else {
fmt.Println("Policy is valid")
return nil
}
ctx, client, conn, cancel, err := newHeadscaleCLIWithConfig()
if err != nil {
return fmt.Errorf("connecting to headscale: %w", err)
}
defer cancel()
defer conn.Close()
_, err = client.CheckPolicy(ctx, &v1.CheckPolicyRequest{Policy: string(policyBytes)})
if err != nil {
return err
}
fmt.Println("Policy is valid")
return nil
},
}

View File

@@ -3,7 +3,6 @@ package cli
import (
"os"
"runtime"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
@@ -22,11 +21,6 @@ func init() {
return
}
if slices.Contains(os.Args, "policy") && slices.Contains(os.Args, "check") {
zerolog.SetGlobalLevel(zerolog.Disabled)
return
}
cobra.OnInitialize(initConfig)
rootCmd.PersistentFlags().
StringVarP(&cfgFile, "config", "c", "", "config file (default is /etc/headscale/config.yaml)")

View File

@@ -26,6 +26,7 @@ import (
"tailscale.com/types/views"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@@ -781,6 +782,35 @@ func (api headscaleV1APIServer) SetPolicy(
return response, nil
}
// CheckPolicy dry-runs the supplied policy against the users and nodes the
// server currently knows about, including the policy's `tests` block.
//
// The evaluation happens on a throwaway PolicyManager built for this request
// only: the handler stores nothing and does not touch the live
// PolicyManager. It does not consult policy.mode, so operators can validate
// a policy file before storing it.
//
// A failure to load users is reported as codes.Internal. A policy that
// cannot be built, or that SetPolicy rejects, is reported as
// codes.InvalidArgument carrying the underlying error text.
func (api headscaleV1APIServer) CheckPolicy(
	_ context.Context,
	request *v1.CheckPolicyRequest,
) (*v1.CheckPolicyResponse, error) {
	policyBytes := []byte(request.GetPolicy())

	users, err := api.h.state.ListAllUsers()
	if err != nil {
		return nil, status.Errorf(codes.Internal, "loading users: %s", err)
	}

	sandbox, err := policyv2.NewPolicyManager(policyBytes, users, api.h.state.ListNodes())
	if err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	// SetPolicy on the sandbox is the step that evaluates the tests block
	// and rejects the policy when an assertion fails.
	_, err = sandbox.SetPolicy(policyBytes)
	if err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	return &v1.CheckPolicyResponse{}, nil
}
// The following service calls are for testing and debugging
func (api headscaleV1APIServer) DebugCreateNode(
ctx context.Context,

View File

@@ -187,6 +187,16 @@ func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.Node
return nil, err
}
// Boot path: log a warning if the stored policy's tests would
// fail against the current users and nodes, but keep the server
// running. A stale stored policy (e.g. referencing a user that
// was deleted while the server was offline) should not block
// boot; the operator finds out via logs and re-runs the write
// boundary when they are ready.
if testErr := pm.RunTests(); testErr != nil { //nolint:noinlineerr // boot path: warn-and-continue, not return
log.Warn().Err(testErr).Msg("policy tests failed at boot; server starting anyway, fix the policy and reload")
}
return &pm, nil
}
@@ -447,6 +457,15 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
return false, fmt.Errorf("validating policy user references: %w", err)
}
// SetPolicy is the user-write boundary. Tests evaluate against a
// sandbox compiled from the new policy + current users/nodes; if
// they fail, return without mutating the live PolicyManager so the
// failed write does not knock the running config offline.
err = evaluateTests(pol, pm.users, pm.nodes)
if err != nil {
return false, err
}
// Log policy metadata for debugging
log.Debug().
Int("policy.bytes", len(polB)).
@@ -455,6 +474,7 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
Int("hosts.count", len(pol.Hosts)).
Int("tagOwners.count", len(pol.TagOwners)).
Int("autoApprovers.routes.count", len(pol.AutoApprovers.Routes)).
Int("tests.count", len(pol.Tests)).
Msg("Policy parsed successfully")
pm.pol = pol

420
hscontrol/policy/v2/test.go Normal file
View File

@@ -0,0 +1,420 @@
package v2
import (
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
)
// Tailscale's policy file `tests` block validates a policy against operator
// assertions: from a given src, named dst:port pairs must be accepted, and
// (optionally) other dst:port pairs must be denied. They run at user-write
// boundaries — `headscale policy set`, file-mode reload after a change,
// `headscale policy check` — and reject the write if any assertion fails.
// Boot-time reload of an already-stored policy only logs a warning when they
// fail, so a stale referenced entity (e.g. a deleted user) cannot lock the
// server out.
//
// The tests evaluate against the compiled global filter rules, which fold in
// both `acls` and `grants`, so the `tests` block validates the whole policy.
// errPolicyTestsFailed is the sentinel wrapped (with %w) around the rendered
// failure body, so callers can use errors.Is when they need to react
// differently to test failures vs. parse errors. The Error() output is the
// user-facing message and is intended to match Tailscale SaaS verbatim once
// the corpus is captured via tscap.
//
// errTestDestinationNoIP is the sentinel wrapped when a test destination
// resolves to an empty address set.
var (
errPolicyTestsFailed = errors.New("policy tests failed")
errTestDestinationNoIP = errors.New("destination resolved to no IP addresses")
)
// PolicyTest is one entry in the policy's `tests` block. It names a single
// source and the destinations that must (Accept) or must not (Deny) be
// reachable from it under the compiled filter.
type PolicyTest struct {
// Src is a single source alias (user, group, tag, host, autogroup, or IP).
// Tailscale only supports a single src per test entry.
// It is parsed with parseAlias and resolved to prefixes when the test runs.
Src string `json:"src"`
// Proto restricts the test to one protocol. Empty matches the default
// set the client applies when proto is omitted (TCP/UDP/ICMP).
Proto Protocol `json:"proto,omitempty"`
// Accept lists destinations in `host:port` form that must be reachable
// from Src. A test fails if any entry is denied by the compiled filter.
Accept []string `json:"accept,omitempty"`
// Deny lists destinations in `host:port` form that must NOT be reachable
// from Src. A test fails if any entry is allowed by the compiled filter.
Deny []string `json:"deny,omitempty"`
}
// PolicyTestResult is the outcome of a single PolicyTest. Src and Proto are
// copied from the test so a result can be rendered without the original
// entry at hand.
type PolicyTestResult struct {
Src string `json:"src"`
Proto Protocol `json:"proto,omitempty"`
// Passed is false as soon as any error or failed assertion is recorded.
Passed bool `json:"passed"`
// Errors are non-assertion problems: src failed to resolve, dst was
// malformed, etc. These cause the test to fail.
Errors []string `json:"errors,omitempty"`
// AcceptOK / AcceptFail / DenyOK / DenyFail partition the per-dst
// outcomes for diagnostics. The *Fail slices hold the destinations
// whose observed reachability contradicted the assertion.
AcceptOK []string `json:"accept_ok,omitempty"`
AcceptFail []string `json:"accept_fail,omitempty"`
DenyOK []string `json:"deny_ok,omitempty"`
DenyFail []string `json:"deny_fail,omitempty"`
}
// PolicyTestResults aggregates a run: one PolicyTestResult per entry of the
// `tests` block, in the order the tests were declared.
type PolicyTestResults struct {
// AllPassed is true only when every entry in Results passed.
AllPassed bool `json:"all_passed"`
Results []PolicyTestResult `json:"results"`
}
// Errors renders the failure body as newline-separated lines, one per
// recorded error or failed assertion, skipping tests that passed. It returns
// the empty string when AllPassed is set.
//
// The format is intended to byte-exact match Tailscale SaaS once captured via
// tscap; until the corpus lands, the strings below are best-effort and will
// be updated to match.
func (r PolicyTestResults) Errors() string {
	if r.AllPassed {
		return ""
	}

	var out []string

	for i := range r.Results {
		failed := &r.Results[i]
		if failed.Passed {
			continue
		}

		// The protocol is only mentioned when the test restricted it.
		var suffix string
		if failed.Proto != "" {
			suffix = fmt.Sprintf(" (%s)", failed.Proto)
		}

		for _, msg := range failed.Errors {
			out = append(out, failed.Src+suffix+": "+msg)
		}

		for _, dst := range failed.AcceptFail {
			out = append(out, failed.Src+" -> "+dst+suffix+": expected ALLOWED, got DENIED")
		}

		for _, dst := range failed.DenyFail {
			out = append(out, failed.Src+" -> "+dst+suffix+": expected DENIED, got ALLOWED")
		}
	}

	return strings.Join(out, "\n")
}
// RunTests evaluates the policy's own `tests` block against the live compiled
// filter and returns an error wrapping errPolicyTestsFailed when any test
// fails. It returns nil for a nil receiver, when no policy is loaded, or
// when the policy declares no tests. Callers that need the per-test
// breakdown can call runPolicyTests directly.
func (pm *PolicyManager) RunTests() error {
	if pm == nil {
		return nil
	}

	// Take the lock before looking at pm.pol: the policy, filter, users and
	// nodes are all read below and must be observed consistently with
	// respect to writers that hold the same mutex.
	pm.mu.Lock()
	defer pm.mu.Unlock()

	if pm.pol == nil || len(pm.pol.Tests) == 0 {
		return nil
	}

	results := runPolicyTests(pm.pol, pm.filter, pm.users, pm.nodes)
	if results.AllPassed {
		return nil
	}

	return fmt.Errorf("%w:\n%s", errPolicyTestsFailed, results.Errors())
}
// evaluateTests compiles pol from scratch and runs its `tests` block against
// the result. It is the user-write sandbox: it works only on its arguments,
// so a failing test can reject a write without touching a live
// PolicyManager.
//
// It returns nil when pol is nil or declares no tests, and otherwise an
// error wrapping errPolicyTestsFailed followed by the rendered failures.
func evaluateTests(pol *Policy, users []types.User, nodes views.Slice[types.NodeView]) error {
	if pol == nil || len(pol.Tests) == 0 {
		return nil
	}

	compiled := pol.compileGrants(users, nodes)

	// A policy with neither acls nor grants is evaluated as allow-all;
	// anything else goes through the compiled global filter rules.
	filter := tailcfg.FilterAllowAll
	if pol.ACLs != nil || pol.Grants != nil {
		filter = globalFilterRules(compiled)
	}

	outcome := runPolicyTests(pol, filter, users, nodes)
	if !outcome.AllPassed {
		return fmt.Errorf("%w:\n%s", errPolicyTestsFailed, outcome.Errors())
	}

	return nil
}
// runPolicyTests evaluates every entry of pol.Tests against the supplied
// filter rules and users/nodes, and returns one result per test in
// declaration order. AllPassed is true only if every test passed (and is
// therefore true for an empty tests block). The function takes no locks.
func runPolicyTests(pol *Policy, filter []tailcfg.FilterRule, users []types.User, nodes views.Slice[types.NodeView]) PolicyTestResults {
	outcomes := make([]PolicyTestResult, 0, len(pol.Tests))
	allPassed := true

	for _, tc := range pol.Tests {
		outcome := runPolicyTest(tc, pol, filter, users, nodes)
		allPassed = allPassed && outcome.Passed
		outcomes = append(outcomes, outcome)
	}

	return PolicyTestResults{
		AllPassed: allPassed,
		Results:   outcomes,
	}
}
// runPolicyTest evaluates one PolicyTest.
//
// The source is resolved first; if that fails or yields no addresses the
// test fails immediately with a single error and no destinations are
// checked. Otherwise every Accept destination and then every Deny
// destination is checked with evalReachability. A destination that cannot
// be evaluated is recorded in Errors; one whose reachability contradicts the
// assertion is recorded in AcceptFail or DenyFail. Either marks the test as
// failed.
func runPolicyTest(test PolicyTest, pol *Policy, filter []tailcfg.FilterRule, users []types.User, nodes views.Slice[types.NodeView]) PolicyTestResult {
	out := PolicyTestResult{
		Src:   test.Src,
		Proto: test.Proto,
	}

	srcPrefixes, err := resolveTestSource(test.Src, pol, users, nodes)

	switch {
	case err != nil:
		out.Errors = append(out.Errors, fmt.Sprintf("failed to resolve source %q: %v", test.Src, err))
		return out
	case len(srcPrefixes) == 0:
		out.Errors = append(out.Errors, fmt.Sprintf("source %q resolved to no IP addresses", test.Src))
		return out
	}

	out.Passed = true

	// check walks one destination list. wantAllowed is the reachability the
	// assertion expects; matching destinations go to ok, the rest to fail.
	check := func(dsts []string, wantAllowed bool, ok, fail *[]string) {
		for _, dst := range dsts {
			allowed, err := evalReachability(srcPrefixes, dst, test.Proto, pol, filter, users, nodes)
			if err != nil {
				out.Passed = false
				out.Errors = append(out.Errors, fmt.Sprintf("error testing %q: %v", dst, err))

				continue
			}

			if allowed == wantAllowed {
				*ok = append(*ok, dst)

				continue
			}

			out.Passed = false
			*fail = append(*fail, dst)
		}
	}

	check(test.Accept, true, &out.AcceptOK, &out.AcceptFail)
	check(test.Deny, false, &out.DenyOK, &out.DenyFail)

	return out
}
// resolveTestSource turns the Src alias of a PolicyTest into a slice of
// netip.Prefix by going through parseAlias and Alias.Resolve, so tests use
// the same alias handling as the rest of the policy engine.
//
// It returns (nil, nil) when the alias resolves to nothing, and a wrapped
// error when the alias cannot be parsed or resolved.
func resolveTestSource(src string, pol *Policy, users []types.User, nodes views.Slice[types.NodeView]) ([]netip.Prefix, error) {
	parsed, err := parseAlias(src)
	if err != nil {
		return nil, fmt.Errorf("invalid alias: %w", err)
	}

	resolved, err := parsed.Resolve(pol, users, nodes)
	if err != nil {
		return nil, fmt.Errorf("resolving: %w", err)
	}

	if resolved != nil && !resolved.Empty() {
		return resolved.Prefixes(), nil
	}

	return nil, nil
}
// evalReachability reports whether filter allows traffic to dst (given in
// `host:port` form) for the requested protocol from every prefix in
// srcPrefixes.
//
// Empty proto means the default set the client applies when proto is
// omitted (TCP/UDP/ICMP); protocol matching itself is delegated to
// srcReachesDst.
//
// An error is returned when dst cannot be parsed or resolved, or when it
// resolves to no addresses (wrapping errTestDestinationNoIP).
func evalReachability(srcPrefixes []netip.Prefix, dst string, proto Protocol, pol *Policy, filter []tailcfg.FilterRule, users []types.User, nodes views.Slice[types.NodeView]) (bool, error) {
	target, err := parseDestinationAlias(dst)
	if err != nil {
		return false, fmt.Errorf("invalid destination %q: %w", dst, err)
	}

	resolved, err := target.Resolve(pol, users, nodes)
	if err != nil {
		return false, fmt.Errorf("resolving destination: %w", err)
	}

	if resolved == nil || resolved.Empty() {
		return false, fmt.Errorf("%w: %q", errTestDestinationNoIP, dst)
	}

	dstPrefixes := resolved.Prefixes()

	// Tailscale's tests semantics: ALL src prefixes must reach the dst for
	// the test to consider it allowed. A partial allow is a fail, so look
	// for the first source prefix that is blocked.
	blocked := slices.ContainsFunc(srcPrefixes, func(src netip.Prefix) bool {
		return !srcReachesDst(src, dstPrefixes, target.Ports, proto, filter)
	})

	return !blocked, nil
}
// parseDestinationAlias is a thin wrapper over AliasWithPorts.UnmarshalJSON
// so callers can hand it a bare `host:port` string without re-implementing
// the parse logic.
//
// dst is an already-decoded string, so it is re-encoded as a JSON string
// literal before being handed over: backslashes and double quotes are
// escaped rather than pasted between quotes verbatim. Without that, a
// backslash in dst would be read as the start of a JSON escape sequence and
// a double quote would terminate the literal early.
func parseDestinationAlias(dst string) (*AliasWithPorts, error) {
	// Backslash must be listed first conceptually, but strings.Replacer
	// never rescans its own output, so the two rules cannot interfere.
	escaped := strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(dst)

	var awp AliasWithPorts

	// AliasWithPorts.UnmarshalJSON expects a quoted JSON string, so wrap.
	err := awp.UnmarshalJSON([]byte(`"` + escaped + `"`))
	if err != nil {
		return nil, err
	}

	return &awp, nil
}
// srcReachesDst reports whether some rule in filter allows traffic from src
// to any prefix in dstPrefixes on at least one of ports (or any port when
// ports is empty) under proto. A rule counts only if it matches the source,
// the protocol and a destination all at once.
func srcReachesDst(src netip.Prefix, dstPrefixes []netip.Prefix, ports []tailcfg.PortRange, proto Protocol, filter []tailcfg.FilterRule) bool {
	// Translate the protocol once; it is the same for every rule.
	wanted := proto.toIANAProtocolNumbers()

	return slices.ContainsFunc(filter, func(rule tailcfg.FilterRule) bool {
		return ruleMatchesSource(rule, src) &&
			ruleMatchesProto(rule, wanted) &&
			ruleAllowsAnyDest(rule, dstPrefixes, ports)
	})
}
// ruleMatchesSource reports whether any entry of the rule's SrcIPs overlaps
// src. SrcIPs may be CIDR, single addresses, IP ranges (`a-b`), or `*`;
// util.ParseIPSet is used to handle all of those uniformly. Entries that
// fail to parse are ignored rather than reported (the rule compiler emits
// well-formed strings, so this is defence-in-depth, not error handling).
func ruleMatchesSource(rule tailcfg.FilterRule, src netip.Prefix) bool {
	return slices.ContainsFunc(rule.SrcIPs, func(entry string) bool {
		ipSet, err := util.ParseIPSet(entry, nil)

		return err == nil && ipSet.OverlapsPrefix(src)
	})
}
// ruleMatchesProto reports whether the rule applies to the requested
// protocols:
//
//   - a rule with no IPProto is unrestricted and always matches;
//   - a rule with IPProto set never matches an empty requestedProtos;
//   - otherwise the rule matches when it shares at least one protocol
//     number with requestedProtos.
func ruleMatchesProto(rule tailcfg.FilterRule, requestedProtos []int) bool {
	switch {
	case len(rule.IPProto) == 0:
		return true
	case len(requestedProtos) == 0:
		// Nothing specific was requested, so a protocol-restricted rule
		// is not treated as a match.
		return false
	}

	return slices.ContainsFunc(rule.IPProto, func(ruleProto int) bool {
		return slices.Contains(requestedProtos, ruleProto)
	})
}
// ruleAllowsAnyDest reports whether a single DstPorts entry of the rule both
// covers at least one prefix in dstPrefixes and permits at least one of
// ports (or any port when ports is empty). The address and port checks must
// succeed on the same entry.
func ruleAllowsAnyDest(rule tailcfg.FilterRule, dstPrefixes []netip.Prefix, ports []tailcfg.PortRange) bool {
	return slices.ContainsFunc(rule.DstPorts, func(entry tailcfg.NetPortRange) bool {
		return destEntryMatchesPrefixes(entry, dstPrefixes) && portsAllowed(ports, entry.Ports)
	})
}
// destEntryMatchesPrefixes reports whether the rule's NetPortRange.IP
// (CIDR, single IP, IP range, or "*") overlaps any prefix in dstPrefixes.
// An IP value that util.ParseIPSet cannot parse matches nothing.
func destEntryMatchesPrefixes(dp tailcfg.NetPortRange, dstPrefixes []netip.Prefix) bool {
	ipSet, err := util.ParseIPSet(dp.IP, nil)
	if err != nil {
		return false
	}

	for _, pfx := range dstPrefixes {
		if ipSet.OverlapsPrefix(pfx) {
			return true
		}
	}

	return false
}
// portsAllowed reports whether at least one requested port range lies
// entirely inside allowed (both its First and Last fall within the allowed
// bounds). An empty requested list means "any port" and always matches.
func portsAllowed(requested []tailcfg.PortRange, allowed tailcfg.PortRange) bool {
	if len(requested) == 0 {
		return true
	}

	return slices.ContainsFunc(requested, func(want tailcfg.PortRange) bool {
		return allowed.First <= want.First && want.Last <= allowed.Last
	})
}

View File

@@ -0,0 +1,382 @@
package v2
import (
"strings"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
// policyTestUsers returns the two users (alice, ID 1 and bob, ID 2) shared by
// the test cases below, so each table row can stay focussed on the policy
// and tests under exercise.
func policyTestUsers() types.Users {
	fixture := types.Users{
		{
			Model: gorm.Model{ID: 1},
			Name:  "alice",
			Email: "alice@headscale.net",
		},
		{
			Model: gorm.Model{ID: 2},
			Name:  "bob",
			Email: "bob@headscale.net",
		},
	}

	return fixture
}
// policyTestNodes returns the three nodes shared by the test cases below:
// one laptop each for users[0] (alice) and users[1] (bob), plus a node
// tagged tag:server that is attached to users[0]. The nodes point into the
// users slice, so it must hold at least two entries.
func policyTestNodes(users types.Users) types.Nodes {
	alice, bob := &users[0], &users[1]

	return types.Nodes{
		// alice's user-owned laptop
		{
			ID:       1,
			Hostname: "alice-laptop",
			IPv4:     ap("100.64.0.1"),
			IPv6:     ap("fd7a:115c:a1e0::1"),
			User:     alice,
			UserID:   &alice.ID,
		},
		// bob's user-owned laptop
		{
			ID:       2,
			Hostname: "bob-laptop",
			IPv4:     ap("100.64.0.2"),
			IPv6:     ap("fd7a:115c:a1e0::2"),
			User:     bob,
			UserID:   &bob.ID,
		},
		// tagged server (created via tagged preauth key from alice)
		{
			ID:       3,
			Hostname: "server",
			IPv4:     ap("100.64.0.3"),
			IPv6:     ap("fd7a:115c:a1e0::3"),
			User:     alice,
			UserID:   &alice.ID,
			Tags:     []string{"tag:server"},
		},
	}
}
// TestRunTests covers the engine's per-test outcome reporting. Each row
// builds a PolicyManager with NewPolicyManager and then checks what RunTests
// returns for that policy: nil for rows expected to pass, otherwise an error
// that wraps the given sentinel and mentions every expected substring.
// SetPolicy gating is exercised separately in
// TestSetPolicyRejectsFailingTests.
func TestRunTests(t *testing.T) {
users := policyTestUsers()
nodes := policyTestNodes(users)
tests := []struct {
name string
policy string
wantPass bool
wantErrSub []string // substrings expected in the rendered error
// wantNoErrIs is checked with require.ErrorIs on failing rows; despite
// the "No" in its name it is the sentinel the error MUST wrap.
wantNoErrIs error // sentinel the error must wrap
}{
{
name: "all-pass-user-to-tag",
policy: `{
"tagOwners": { "tag:server": ["alice@headscale.net"] },
"acls": [{
"action": "accept",
"src": ["alice@headscale.net"],
"dst": ["tag:server:22"]
}],
"tests": [{
"src": "alice@headscale.net",
"accept": ["tag:server:22"]
}]
}`,
wantPass: true,
},
{
name: "accept-fail-blocked-by-policy",
policy: `{
"tagOwners": { "tag:server": ["alice@headscale.net"] },
"acls": [{
"action": "accept",
"src": ["alice@headscale.net"],
"dst": ["tag:server:22"]
}],
"tests": [{
"src": "bob@headscale.net",
"accept": ["tag:server:22"]
}]
}`,
wantPass: false,
wantErrSub: []string{"bob@headscale.net", "tag:server:22", "expected ALLOWED"},
wantNoErrIs: errPolicyTestsFailed,
},
{
name: "deny-fail-policy-allows-traffic",
policy: `{
"tagOwners": { "tag:server": ["alice@headscale.net"] },
"acls": [{
"action": "accept",
"src": ["alice@headscale.net"],
"dst": ["tag:server:22"]
}],
"tests": [{
"src": "alice@headscale.net",
"deny": ["tag:server:22"]
}]
}`,
wantPass: false,
wantErrSub: []string{"alice@headscale.net", "tag:server:22", "expected DENIED"},
wantNoErrIs: errPolicyTestsFailed,
},
{
name: "unknown-src-user",
policy: `{
"acls": [{
"action": "accept",
"src": ["*"],
"dst": ["*:*"]
}],
"tests": [{
"src": "ghost@headscale.net",
"accept": ["alice-laptop:22"]
}]
}`,
wantPass: false,
wantErrSub: []string{"ghost@headscale.net", "failed to resolve source"},
wantNoErrIs: errPolicyTestsFailed,
},
{
name: "malformed-dst-missing-port",
policy: `{
"acls": [{
"action": "accept",
"src": ["*"],
"dst": ["*:*"]
}],
"tests": [{
"src": "alice@headscale.net",
"accept": ["alice-laptop"]
}]
}`,
wantPass: false,
wantErrSub: []string{"alice-laptop", "error testing"},
wantNoErrIs: errPolicyTestsFailed,
},
{
name: "wildcard-src-passes",
policy: `{
"tagOwners": { "tag:server": ["alice@headscale.net"] },
"acls": [{
"action": "accept",
"src": ["*"],
"dst": ["tag:server:80"]
}],
"tests": [{
"src": "alice@headscale.net",
"accept": ["tag:server:80"]
}]
}`,
wantPass: true,
},
{
name: "proto-restrict-tcp-only",
policy: `{
"tagOwners": { "tag:server": ["alice@headscale.net"] },
"acls": [{
"action": "accept",
"proto": "tcp",
"src": ["alice@headscale.net"],
"dst": ["tag:server:22"]
}],
"tests": [
{
"src": "alice@headscale.net",
"proto": "tcp",
"accept": ["tag:server:22"]
},
{
"src": "alice@headscale.net",
"proto": "udp",
"deny": ["tag:server:22"]
}
]
}`,
wantPass: true,
},
{
name: "grants-only-policy-evaluated",
policy: `{
"tagOwners": { "tag:server": ["alice@headscale.net"] },
"grants": [{
"src": ["alice@headscale.net"],
"dst": ["tag:server"],
"ip": ["22"]
}],
"tests": [{
"src": "alice@headscale.net",
"accept": ["tag:server:22"]
}]
}`,
wantPass: true,
},
{
name: "mixed-pass-and-fail-reports-failure",
policy: `{
"tagOwners": { "tag:server": ["alice@headscale.net"] },
"acls": [{
"action": "accept",
"src": ["alice@headscale.net"],
"dst": ["tag:server:22"]
}],
"tests": [
{
"src": "alice@headscale.net",
"accept": ["tag:server:22"]
},
{
"src": "bob@headscale.net",
"accept": ["tag:server:22"]
}
]
}`,
wantPass: false,
wantErrSub: []string{"bob@headscale.net", "expected ALLOWED"},
wantNoErrIs: errPolicyTestsFailed,
},
{
name: "no-tests-block-is-no-op",
policy: `{
"acls": [{
"action": "accept",
"src": ["*"],
"dst": ["*:*"]
}]
}`,
wantPass: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Building the manager must succeed for every row, including the
// ones whose tests are expected to fail.
pm, err := NewPolicyManager([]byte(tt.policy), users, nodes.ViewSlice())
require.NoError(t, err, "policy must parse and compile")
runErr := pm.RunTests()
if tt.wantPass {
require.NoError(t, runErr, "tests should pass")
return
}
require.Error(t, runErr, "tests should fail")
require.ErrorIs(t, runErr, tt.wantNoErrIs, "error should wrap errPolicyTestsFailed")
for _, sub := range tt.wantErrSub {
assert.Contains(t, runErr.Error(), sub, "rendered error should mention %q", sub)
}
})
}
}
// TestSetPolicyRejectsFailingTests asserts that SetPolicy is the user-write
// boundary: a policy whose tests fail must be rejected without mutating the
// live PolicyManager.
//
// The rejected policy deliberately carries a different ACL (port 443 instead
// of 22) than the one the manager was built with. If SetPolicy wrongly
// applied it, the compiled filter would change, which the equality check on
// the filter before and after the call would catch. Comparing only the
// number of rules, or using two policies with the same ACL, could not detect
// such a mutation.
func TestSetPolicyRejectsFailingTests(t *testing.T) {
	users := policyTestUsers()
	nodes := policyTestNodes(users)

	good := `{
		"tagOwners": { "tag:server": ["alice@headscale.net"] },
		"acls": [{
			"action": "accept",
			"src": ["alice@headscale.net"],
			"dst": ["tag:server:22"]
		}],
		"tests": [{
			"src": "alice@headscale.net",
			"accept": ["tag:server:22"]
		}]
	}`

	// bob is not granted anything, so the accept assertion fails.
	bad := `{
		"tagOwners": { "tag:server": ["alice@headscale.net"] },
		"acls": [{
			"action": "accept",
			"src": ["alice@headscale.net"],
			"dst": ["tag:server:443"]
		}],
		"tests": [{
			"src": "bob@headscale.net",
			"accept": ["tag:server:22"]
		}]
	}`

	pm, err := NewPolicyManager([]byte(good), users, nodes.ViewSlice())
	require.NoError(t, err)

	beforeFilter, _ := pm.Filter()

	changed, err := pm.SetPolicy([]byte(bad))
	require.Error(t, err, "SetPolicy must reject a policy whose tests fail")
	require.False(t, changed, "SetPolicy must report no change when rejected")
	require.ErrorIs(t, err, errPolicyTestsFailed)
	require.Contains(t, err.Error(), "expected ALLOWED")

	afterFilter, _ := pm.Filter()
	require.Equal(t, beforeFilter, afterFilter, "live filter must not change after a rejected SetPolicy")
}
// TestNewPolicyManagerSkipsTests asserts that the boot path never fails
// because of the tests block, so a stale stored policy referencing a
// now-deleted user does not stop the server from booting. Only the returned
// error is checked here; whether NewPolicyManager evaluates the tests and
// logs the outcome is not observable from this test.
func TestNewPolicyManagerSkipsTests(t *testing.T) {
users := policyTestUsers()
nodes := policyTestNodes(users)
// Tests reference "ghost@headscale.net" which doesn't exist. Boot
// must not error.
stale := `{
"acls": [{
"action": "accept",
"src": ["*"],
"dst": ["*:*"]
}],
"tests": [{
"src": "ghost@headscale.net",
"accept": ["alice-laptop:22"]
}]
}`
pm, err := NewPolicyManager([]byte(stale), users, nodes.ViewSlice())
require.NoError(t, err, "boot must not run tests")
require.NotNil(t, pm)
// And a subsequent SetPolicy of the same body must reject — that's
// the user-write path.
_, err = pm.SetPolicy([]byte(stale))
require.Error(t, err)
require.ErrorIs(t, err, errPolicyTestsFailed)
}
// TestPolicyTestResultsErrorsRendering checks the multi-line render layout
// since the body becomes the user-facing error.
func TestPolicyTestResultsErrorsRendering(t *testing.T) {
results := PolicyTestResults{
AllPassed: false,
Results: []PolicyTestResult{
{
Src: "alice@headscale.net",
AcceptFail: []string{"tag:server:22"},
},
{
Src: "bob@headscale.net",
Proto: "tcp",
DenyFail: []string{"tag:server:443"},
},
},
}
rendered := results.Errors()
for _, sub := range []string{
"alice@headscale.net -> tag:server:22: expected ALLOWED, got DENIED",
"bob@headscale.net -> tag:server:443 (tcp): expected DENIED, got ALLOWED",
} {
assert.Contains(t, rendered, sub)
}
// Lines should be newline-separated, not space-joined.
assert.Equal(t, 2, strings.Count(rendered, "\n")+1, "expected one line per failing assertion")
}

View File

@@ -1986,6 +1986,7 @@ type Policy struct {
Grants []Grant `json:"grants,omitempty"`
AutoApprovers AutoApproverPolicy `json:"autoApprovers"`
SSHs []SSH `json:"ssh,omitempty"`
Tests []PolicyTest `json:"tests,omitempty"`
}
// MarshalJSON is deliberately not implemented for Policy.

View File

@@ -0,0 +1,166 @@
package integration
import (
"encoding/json"
"testing"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
// TestPolicyCheckCommand exercises `headscale policy check` across the
// matrix that nblock asked about on PR #3229:
//
//   - policyMode: server runs with policy_mode=file vs policy_mode=database.
//     `check` reads from `--file`, so the server-side mode should not
//     change the outcome; running both proves that.
//   - fixture: ACL only, ACL with passing tests, ACL with failing tests.
//   - bypass: no-bypass talks to the server over gRPC; bypass opens the
//     database directly.
//
// Each row spins up its own scenario because policy_mode is fixed at boot
// via `HEADSCALE_POLICY_MODE`. The two users + two nodes give the tests
// block real `user@` aliases to resolve against.
//
// Rows using the failing fixture expect the command to fail with the text of
// the policy engine's test-failure sentinel ("policy tests failed"); all
// other rows expect "Policy is valid" on stdout.
func TestPolicyCheckCommand(t *testing.T) {
	IntegrationSkip(t)

	type fixture struct {
		name   string
		policy policyv2.Policy
	}

	const (
		user1 = "user1@"
		user2 = "user2@"
	)

	aclOnly := policyv2.Policy{
		ACLs: []policyv2.ACL{
			{
				Action:   policyv2.ActionAccept,
				Protocol: "tcp", //nolint:goconst // protocol literal, used inline once
				Sources:  []policyv2.Alias{usernamep(user1)},
				Destinations: []policyv2.AliasWithPorts{
					aliasWithPorts(usernamep(user2), tailcfg.PortRange{First: 22, Last: 22}),
				},
			},
		},
	}

	aclPlusPassingTests := aclOnly
	aclPlusPassingTests.Tests = []policyv2.PolicyTest{
		{
			Src:    user1,
			Accept: []string{user2 + ":22"},
		},
	}

	aclPlusFailingTests := aclOnly
	aclPlusFailingTests.Tests = []policyv2.PolicyTest{
		{
			// Reverse direction is not allowed by the ACL; the test
			// asserts ALLOWED, so it must fail.
			Src:    user2,
			Accept: []string{user1 + ":22"},
		},
	}

	fixtures := []fixture{
		{name: "acl-only", policy: aclOnly},
		{name: "acl-plus-passing-tests", policy: aclPlusPassingTests},
		{name: "acl-plus-failing-tests", policy: aclPlusFailingTests},
	}

	type row struct {
		name       string
		policyMode string
		fixture    fixture
		bypass     bool
		wantErr    string
		wantStdout string
	}

	modes := []string{"file", "database"} //nolint:goconst // axis labels match HEADSCALE_POLICY_MODE values
	bypasses := []bool{false, true}

	rows := make([]row, 0, len(modes)*len(fixtures)*len(bypasses))

	for _, mode := range modes {
		for _, f := range fixtures {
			for _, bypass := range bypasses {
				suffix := "no-bypass"
				if bypass {
					suffix = "bypass"
				}

				r := row{
					name:       mode + "-" + f.name + "-" + suffix,
					policyMode: mode,
					fixture:    f,
					bypass:     bypass,
					wantStdout: "Policy is valid",
				}

				if f.name == "acl-plus-failing-tests" {
					// Must match the message of the sentinel the policy
					// engine wraps around test failures; it is the only
					// fixed text present on both the gRPC and the
					// bypass path.
					r.wantErr = "policy tests failed"
					r.wantStdout = ""
				}

				rows = append(rows, r)
			}
		}
	}

	for _, tt := range rows {
		t.Run(tt.name, func(t *testing.T) {
			spec := ScenarioSpec{
				NodesPerUser: 1,
				Users:        []string{"user1", "user2"}, //nolint:goconst // matches usernamep("user1@")/("user2@") above
			}

			scenario, err := NewScenario(spec)
			require.NoError(t, err)

			defer scenario.ShutdownAssertNoPanics(t)

			err = scenario.CreateHeadscaleEnv(
				[]tsic.Option{},
				hsic.WithTestName("cli-policycheck"),
				hsic.WithConfigEnv(map[string]string{
					"HEADSCALE_POLICY_MODE": tt.policyMode, //nolint:goconst // env var name from hscontrol/types/config.go
				}),
			)
			require.NoError(t, err)

			headscale, err := scenario.Headscale()
			require.NoError(t, err)

			pBytes, err := json.Marshal(tt.fixture.policy)
			require.NoError(t, err)

			policyFilePath := "/etc/headscale/policy.json" //nolint:goconst // standard headscale policy path
			err = headscale.WriteFile(policyFilePath, pBytes)
			require.NoError(t, err)

			cmd := []string{"headscale", "policy", "check", "-f", policyFilePath} //nolint:goconst // CLI invocation
			if tt.bypass {
				// --force suppresses the "is the server running?"
				// confirmation prompt so the command can run
				// non-interactively under the test harness.
				cmd = append(cmd, "--bypass-grpc-and-access-database-directly", "--force")
			}

			stdout, err := headscale.Execute(cmd)
			if tt.wantErr != "" {
				require.ErrorContains(t, err, tt.wantErr)

				return
			}

			require.NoError(t, err)
			require.Contains(t, stdout, tt.wantStdout)
		})
	}
}