diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index c5555b43..53f4a488 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -202,6 +202,7 @@ jobs: - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndReloginSameUser - TestAuthWebFlowLogoutAndReloginNewUser + - TestPolicyCheckCommand - TestUserCommand - TestPreAuthKeyCommand - TestPreAuthKeyCommandWithoutExpiry diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 2c3365ed..1c6cb3dd 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -48,7 +48,7 @@ func init() { policyCmd.AddCommand(setPolicy) checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") - checkPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running. Required to validate that user@ tokens resolve against the user database; without it, the check is syntax-only.") + checkPolicy.Flags().BoolP(bypassFlag, "", false, "Open the database directly (no gRPC, no running server) to validate user@ token references and to evaluate the policy's tests block. Required when those checks are needed.") mustMarkRequired(checkPolicy, "file") policyCmd.AddCommand(checkPolicy) } @@ -171,6 +171,11 @@ var setPolicy = &cobra.Command{ var checkPolicy = &cobra.Command{ Use: "check", Short: "Check the Policy file for errors", + Long: ` + Check validates the policy against the server's live users and nodes, + running any "tests" block. 
By default the command is a thin frontend + for a gRPC call to a running headscale; pass --` + bypassFlag + ` to + open the database directly when headscale is not running.`, RunE: func(cmd *cobra.Command, args []string) error { policyPath, _ := cmd.Flags().GetString("file") @@ -179,8 +184,6 @@ var checkPolicy = &cobra.Command{ return fmt.Errorf("reading policy file: %w", err) } - var users []types.User - if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass { if !confirmAction(cmd, "DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?") { return errAborted @@ -192,23 +195,49 @@ var checkPolicy = &cobra.Command{ } defer d.Close() - users, err = d.ListUsers() + users, err := d.ListUsers() if err != nil { - return fmt.Errorf("loading users for policy validation: %w", err) + return fmt.Errorf("loading users: %w", err) } - } - _, err = policy.NewPolicyManager(policyBytes, users, views.Slice[types.NodeView]{}) - if err != nil { - return fmt.Errorf("parsing policy file: %w", err) - } + nodes, err := d.ListNodes() + if err != nil { + return fmt.Errorf("loading nodes: %w", err) + } + + // NewPolicyManager validates structure and user references; + // its own run of the tests block only logs a warning (boot path). + // SetPolicy is the user-write boundary and is what rejects a + // policy whose tests fail.
+ pm, err := policy.NewPolicyManager(policyBytes, users, nodes.ViewSlice()) + if err != nil { + return fmt.Errorf("parsing policy file: %w", err) + } + + _, err = pm.SetPolicy(policyBytes) + if err != nil { + return err + } - if users == nil { - fmt.Println("Policy syntax is valid (run with --" + bypassFlag + " to also validate user references against the database)") - } else { fmt.Println("Policy is valid") + + return nil } + ctx, client, conn, cancel, err := newHeadscaleCLIWithConfig() + if err != nil { + return fmt.Errorf("connecting to headscale: %w", err) + } + defer cancel() + defer conn.Close() + + _, err = client.CheckPolicy(ctx, &v1.CheckPolicyRequest{Policy: string(policyBytes)}) + if err != nil { + return err + } + + fmt.Println("Policy is valid") + return nil }, } diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index a8a9bf69..fa7ded2d 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -3,7 +3,6 @@ package cli import ( "os" "runtime" - "slices" "strings" "github.com/juanfont/headscale/hscontrol/types" @@ -22,11 +21,6 @@ func init() { return } - if slices.Contains(os.Args, "policy") && slices.Contains(os.Args, "check") { - zerolog.SetGlobalLevel(zerolog.Disabled) - return - } - cobra.OnInitialize(initConfig) rootCmd.PersistentFlags(). 
StringVarP(&cfgFile, "config", "c", "", "config file (default is /etc/headscale/config.yaml)") diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 48d6e2d1..953ae0d7 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -26,6 +26,7 @@ import ( "tailscale.com/types/views" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -781,6 +782,35 @@ func (api headscaleV1APIServer) SetPolicy( return response, nil } +// CheckPolicy validates the given policy against the server's live users +// and nodes, running its `tests` block as a sandbox. Nothing is persisted +// and the live PolicyManager is not touched. Works regardless of +// policy.mode so operators can validate a policy file before storing it. +func (api headscaleV1APIServer) CheckPolicy( + _ context.Context, + request *v1.CheckPolicyRequest, +) (*v1.CheckPolicyResponse, error) { + polB := []byte(request.GetPolicy()) + + users, err := api.h.state.ListAllUsers() + if err != nil { + return nil, status.Errorf(codes.Internal, "loading users: %s", err) + } + + nodes := api.h.state.ListNodes() + + pm, err := policyv2.NewPolicyManager(polB, users, nodes) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + + if _, err := pm.SetPolicy(polB); err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + + return &v1.CheckPolicyResponse{}, nil +} + // The following service calls are for testing and debugging func (api headscaleV1APIServer) DebugCreateNode( ctx context.Context, diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index f11741fb..1941ecf7 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -187,6 +187,16 @@ func NewPolicyManager(b []byte, users []types.User, nodes 
views.Slice[types.Node return nil, err } + // Boot path: log a warning if the stored policy's tests would + // fail against the current users and nodes, but keep the server + // running. A stale stored policy (e.g. referencing a user that + // was deleted while the server was offline) should not block + // boot; the operator finds out via logs and re-runs the write + // boundary when they are ready. + if testErr := pm.RunTests(); testErr != nil { //nolint:noinlineerr // boot path: warn-and-continue, not return + log.Warn().Err(testErr).Msg("policy tests failed at boot; server starting anyway, fix the policy and reload") + } + return &pm, nil } @@ -447,6 +457,15 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { return false, fmt.Errorf("validating policy user references: %w", err) } + // SetPolicy is the user-write boundary. Tests evaluate against a + // sandbox compiled from the new policy + current users/nodes; if + // they fail, return without mutating the live PolicyManager so the + // failed write does not knock the running config offline. + err = evaluateTests(pol, pm.users, pm.nodes) + if err != nil { + return false, err + } + // Log policy metadata for debugging log.Debug(). Int("policy.bytes", len(polB)). @@ -455,6 +474,7 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { Int("hosts.count", len(pol.Hosts)). Int("tagOwners.count", len(pol.TagOwners)). Int("autoApprovers.routes.count", len(pol.AutoApprovers.Routes)). + Int("tests.count", len(pol.Tests)). 
Msg("Policy parsed successfully") pm.pol = pol diff --git a/hscontrol/policy/v2/test.go b/hscontrol/policy/v2/test.go new file mode 100644 index 00000000..cf7b4bc6 --- /dev/null +++ b/hscontrol/policy/v2/test.go @@ -0,0 +1,420 @@ +package v2 + +import ( + "errors" + "fmt" + "net/netip" + "slices" + "strings" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "tailscale.com/tailcfg" + "tailscale.com/types/views" +) + +// Tailscale's policy file `tests` block validates a policy against operator +// assertions: from a given src, named dst:port pairs must be accepted, and +// (optionally) other dst:port pairs must be denied. They run at user-write +// boundaries — `headscale policy set`, file-mode reload after a change, +// `headscale policy check` — and reject the write if any assertion fails. +// At boot an already-stored policy only gets a warning when they fail, so a +// stale referenced entity (e.g. a deleted user) cannot lock the server out. +// +// The tests evaluate against the compiled global filter rules, which fold in +// both `acls` and `grants`, so the `tests` block validates the whole policy. + +// errPolicyTestsFailed is the sentinel wrapped around the rendered failure +// body so callers can use errors.Is to tell test failures apart from parse +// errors. The Error() output is the user-facing message and is intended to +// match Tailscale SaaS verbatim once the corpus is captured via tscap. +var ( + errPolicyTestsFailed = errors.New("policy tests failed") + errTestDestinationNoIP = errors.New("destination resolved to no IP addresses") +) + +// PolicyTest is one entry in the policy's `tests` block. +type PolicyTest struct { + // Src is a single source alias (user, group, tag, host, autogroup, or IP). + // Tailscale only supports a single src per test entry. + Src string `json:"src"` + + // Proto restricts the test to one protocol.
Empty matches the default + // set the client applies when proto is omitted (TCP/UDP/ICMP). + Proto Protocol `json:"proto,omitempty"` + + // Accept lists destinations in `host:port` form that must be reachable + // from Src. A test fails if any entry is denied by the compiled filter. + Accept []string `json:"accept,omitempty"` + + // Deny lists destinations in `host:port` form that must NOT be reachable + // from Src. A test fails if any entry is allowed by the compiled filter. + Deny []string `json:"deny,omitempty"` +} + +// PolicyTestResult is the outcome of a single PolicyTest. +type PolicyTestResult struct { + Src string `json:"src"` + Proto Protocol `json:"proto,omitempty"` + Passed bool `json:"passed"` + + // Errors are non-assertion problems: src failed to resolve, dst was + // malformed, etc. These cause the test to fail. + Errors []string `json:"errors,omitempty"` + + // AcceptOK / AcceptFail / DenyOK / DenyFail partition the per-dst + // outcomes for diagnostics. + AcceptOK []string `json:"accept_ok,omitempty"` + AcceptFail []string `json:"accept_fail,omitempty"` + DenyOK []string `json:"deny_ok,omitempty"` + DenyFail []string `json:"deny_fail,omitempty"` +} + +// PolicyTestResults aggregates a run. +type PolicyTestResults struct { + AllPassed bool `json:"all_passed"` + Results []PolicyTestResult `json:"results"` +} + +// Errors renders the failure body. Format is intended to byte-exact match +// Tailscale SaaS once captured via tscap; until the corpus lands, the +// strings below are best-effort and will be updated to match. 
+func (r PolicyTestResults) Errors() string { + if r.AllPassed { + return "" + } + + var lines []string + + for _, res := range r.Results { + if res.Passed { + continue + } + + protoSuffix := "" + if res.Proto != "" { + protoSuffix = fmt.Sprintf(" (%s)", res.Proto) + } + + for _, e := range res.Errors { + lines = append(lines, fmt.Sprintf("%s%s: %s", res.Src, protoSuffix, e)) + } + + for _, dst := range res.AcceptFail { + lines = append(lines, fmt.Sprintf("%s -> %s%s: expected ALLOWED, got DENIED", res.Src, dst, protoSuffix)) + } + + for _, dst := range res.DenyFail { + lines = append(lines, fmt.Sprintf("%s -> %s%s: expected DENIED, got ALLOWED", res.Src, dst, protoSuffix)) + } + } + + return strings.Join(lines, "\n") +} + +// RunTests evaluates the policy's own `tests` block against the live compiled +// filter and returns a wrapped error when any test fails. Callers that need +// the per-test breakdown can call runPolicyTests directly. +func (pm *PolicyManager) RunTests() error { + if pm == nil || pm.pol == nil || len(pm.pol.Tests) == 0 { + return nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + results := runPolicyTests(pm.pol, pm.filter, pm.users, pm.nodes) + if results.AllPassed { + return nil + } + + return fmt.Errorf("%w:\n%s", errPolicyTestsFailed, results.Errors()) +} + +// evaluateTests runs the `tests` block against a fresh compilation of pol. +// It is the user-write sandbox: the live PolicyManager state is left +// untouched, so a failing test rejects the write without side effects. 
+func evaluateTests(pol *Policy, users []types.User, nodes views.Slice[types.NodeView]) error { + if pol == nil || len(pol.Tests) == 0 { + return nil + } + + grants := pol.compileGrants(users, nodes) + + var filter []tailcfg.FilterRule + if pol.ACLs == nil && pol.Grants == nil { + filter = tailcfg.FilterAllowAll + } else { + filter = globalFilterRules(grants) + } + + results := runPolicyTests(pol, filter, users, nodes) + if results.AllPassed { + return nil + } + + return fmt.Errorf("%w:\n%s", errPolicyTestsFailed, results.Errors()) +} + +// runPolicyTests is the pure evaluation function: given a policy, the +// compiled filter rules derived from it, and the active users/nodes, run +// every test and return the aggregated outcome. It does not lock anything +// or mutate any input. +func runPolicyTests(pol *Policy, filter []tailcfg.FilterRule, users []types.User, nodes views.Slice[types.NodeView]) PolicyTestResults { + results := PolicyTestResults{ + AllPassed: true, + Results: make([]PolicyTestResult, 0, len(pol.Tests)), + } + + for _, test := range pol.Tests { + res := runPolicyTest(test, pol, filter, users, nodes) + if !res.Passed { + results.AllPassed = false + } + + results.Results = append(results.Results, res) + } + + return results +} + +// runPolicyTest evaluates one PolicyTest. 
+func runPolicyTest(test PolicyTest, pol *Policy, filter []tailcfg.FilterRule, users []types.User, nodes views.Slice[types.NodeView]) PolicyTestResult { + res := PolicyTestResult{ + Src: test.Src, + Proto: test.Proto, + Passed: true, + } + + srcPrefixes, err := resolveTestSource(test.Src, pol, users, nodes) + if err != nil { + res.Passed = false + res.Errors = append(res.Errors, fmt.Sprintf("failed to resolve source %q: %v", test.Src, err)) + + return res + } + + if len(srcPrefixes) == 0 { + res.Passed = false + res.Errors = append(res.Errors, fmt.Sprintf("source %q resolved to no IP addresses", test.Src)) + + return res + } + + for _, dst := range test.Accept { + allowed, err := evalReachability(srcPrefixes, dst, test.Proto, pol, filter, users, nodes) + if err != nil { + res.Passed = false + res.Errors = append(res.Errors, fmt.Sprintf("error testing %q: %v", dst, err)) + + continue + } + + if allowed { + res.AcceptOK = append(res.AcceptOK, dst) + } else { + res.Passed = false + res.AcceptFail = append(res.AcceptFail, dst) + } + } + + for _, dst := range test.Deny { + allowed, err := evalReachability(srcPrefixes, dst, test.Proto, pol, filter, users, nodes) + if err != nil { + res.Passed = false + res.Errors = append(res.Errors, fmt.Sprintf("error testing %q: %v", dst, err)) + + continue + } + + if !allowed { + res.DenyOK = append(res.DenyOK, dst) + } else { + res.Passed = false + res.DenyFail = append(res.DenyFail, dst) + } + } + + return res +} + +// resolveTestSource resolves the Src alias of a PolicyTest into a slice of +// netip.Prefix. parseAlias + Alias.Resolve cover every alias type the rest +// of the policy engine supports, so tests inherit alias semantics for free. 
+func resolveTestSource(src string, pol *Policy, users []types.User, nodes views.Slice[types.NodeView]) ([]netip.Prefix, error) { + alias, err := parseAlias(src) + if err != nil { + return nil, fmt.Errorf("invalid alias: %w", err) + } + + addrs, err := alias.Resolve(pol, users, nodes) + if err != nil { + return nil, fmt.Errorf("resolving: %w", err) + } + + if addrs == nil || addrs.Empty() { + return nil, nil + } + + return addrs.Prefixes(), nil +} + +// evalReachability reports whether traffic from any srcPrefix to dst (in +// `host:port` form) is allowed by filter for the requested protocol. +// +// Empty proto means the default set the client applies when proto is +// omitted (TCP/UDP/ICMP) — we accept a rule whose IPProto list contains +// any of those, or rules with no IPProto restriction at all. +func evalReachability(srcPrefixes []netip.Prefix, dst string, proto Protocol, pol *Policy, filter []tailcfg.FilterRule, users []types.User, nodes views.Slice[types.NodeView]) (bool, error) { + awp, err := parseDestinationAlias(dst) + if err != nil { + return false, fmt.Errorf("invalid destination %q: %w", dst, err) + } + + dstAddrs, err := awp.Resolve(pol, users, nodes) + if err != nil { + return false, fmt.Errorf("resolving destination: %w", err) + } + + if dstAddrs == nil || dstAddrs.Empty() { + return false, fmt.Errorf("%w: %q", errTestDestinationNoIP, dst) + } + + dstPrefixes := dstAddrs.Prefixes() + + // Tailscale's tests semantics: ALL src prefixes must reach the dst for + // the test to consider it allowed. A partial allow is a fail. + for _, src := range srcPrefixes { + if !srcReachesDst(src, dstPrefixes, awp.Ports, proto, filter) { + return false, nil + } + } + + return true, nil +} + +// parseDestinationAlias is a thin wrapper over AliasWithPorts.UnmarshalJSON +// so callers can hand it a bare `"host:port"` string without re-implementing +// the parse logic. 
+func parseDestinationAlias(dst string) (*AliasWithPorts, error) { + var awp AliasWithPorts + + // AliasWithPorts.UnmarshalJSON expects a quoted JSON string, so wrap (dst is inserted unescaped). + err := awp.UnmarshalJSON([]byte(`"` + dst + `"`)) + if err != nil { + return nil, err + } + + return &awp, nil +} + +// srcReachesDst walks the compiled filter rules and reports whether +// traffic from src to any prefix in dstPrefixes on at least one of ports +// (or any port when ports is empty) is allowed under proto. +func srcReachesDst(src netip.Prefix, dstPrefixes []netip.Prefix, ports []tailcfg.PortRange, proto Protocol, filter []tailcfg.FilterRule) bool { + requestedProtos := proto.toIANAProtocolNumbers() + + for _, rule := range filter { + if !ruleMatchesSource(rule, src) { + continue + } + + if !ruleMatchesProto(rule, requestedProtos) { + continue + } + + if ruleAllowsAnyDest(rule, dstPrefixes, ports) { + return true + } + } + + return false +} + +// ruleMatchesSource reports whether any rule source entry overlaps src. +// SrcIPs may be CIDR, single addresses, IP ranges (`a-b`), or `*`; we use +// util.ParseIPSet to cover all of those uniformly. Unparseable entries +// are skipped (the rule compiler emits well-formed strings, so this is +// defence-in-depth, not error handling). +func ruleMatchesSource(rule tailcfg.FilterRule, src netip.Prefix) bool { + for _, raw := range rule.SrcIPs { + set, err := util.ParseIPSet(raw, nil) + if err != nil { + continue + } + + if set.OverlapsPrefix(src) { + return true + } + } + + return false +} + +// ruleMatchesProto reports whether the rule permits the requested +// protocols. An unset rule.IPProto means "any protocol" and matches +// everything; an empty requestedProtos matches only such unrestricted +// rules and never a rule that lists explicit protocols.
+func ruleMatchesProto(rule tailcfg.FilterRule, requestedProtos []int) bool { + if len(rule.IPProto) == 0 { + return true + } + + if len(requestedProtos) == 0 { + // No protocol numbers requested: a rule that lists explicit + // protocols does not match. + return false + } + + for _, ruleProto := range rule.IPProto { + if slices.Contains(requestedProtos, ruleProto) { + return true + } + } + + return false +} + +// ruleAllowsAnyDest reports whether at least one destination prefix in +// dstPrefixes is allowed by at least one of the rule's DstPorts entries +// for at least one of ports (or any port when ports is empty). +func ruleAllowsAnyDest(rule tailcfg.FilterRule, dstPrefixes []netip.Prefix, ports []tailcfg.PortRange) bool { + for _, dp := range rule.DstPorts { + if !destEntryMatchesPrefixes(dp, dstPrefixes) { + continue + } + + if portsAllowed(ports, dp.Ports) { + return true + } + } + + return false +} + +// destEntryMatchesPrefixes reports whether the rule's NetPortRange.IP +// (CIDR, single IP, IP range, or "*") overlaps any prefix in dstPrefixes. +func destEntryMatchesPrefixes(dp tailcfg.NetPortRange, dstPrefixes []netip.Prefix) bool { + set, err := util.ParseIPSet(dp.IP, nil) + if err != nil { + return false + } + + return slices.ContainsFunc(dstPrefixes, set.OverlapsPrefix) +} + +// portsAllowed reports whether at least one requested port range lies +// entirely within allowed. Empty requested means "any port".
+func portsAllowed(requested []tailcfg.PortRange, allowed tailcfg.PortRange) bool { + if len(requested) == 0 { + return true + } + + for _, r := range requested { + if r.First >= allowed.First && r.Last <= allowed.Last { + return true + } + } + + return false +} diff --git a/hscontrol/policy/v2/test_test.go b/hscontrol/policy/v2/test_test.go new file mode 100644 index 00000000..b87a1856 --- /dev/null +++ b/hscontrol/policy/v2/test_test.go @@ -0,0 +1,382 @@ +package v2 + +import ( + "strings" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +// policyTestUsers/policyTestNodes are reused across the test cases below to +// keep each table row focussed on the policy + tests under exercise. +func policyTestUsers() types.Users { + return types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@headscale.net"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@headscale.net"}, + } +} + +func policyTestNodes(users types.Users) types.Nodes { + nodes := types.Nodes{ + // alice's user-owned laptop + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[0], + UserID: &users[0].ID, + }, + // bob's user-owned laptop + { + ID: 2, + Hostname: "bob-laptop", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &users[1], + UserID: &users[1].ID, + }, + // tagged server (created via tagged preauth key from alice) + { + ID: 3, + Hostname: "server", + IPv4: ap("100.64.0.3"), + IPv6: ap("fd7a:115c:a1e0::3"), + User: &users[0], + UserID: &users[0].ID, + Tags: []string{"tag:server"}, + }, + } + + return nodes +} + +// TestRunTests covers the engine's per-test outcome reporting. Each row +// constructs a PolicyManager (whose own test run only logs a warning) and +// checks the resulting RunTests behaviour. SetPolicy gating is exercised +// separately in TestSetPolicyRejectsFailingTests.
+func TestRunTests(t *testing.T) { + users := policyTestUsers() + nodes := policyTestNodes(users) + + tests := []struct { + name string + policy string + wantPass bool + wantErrSub []string // substrings expected in the rendered error + wantNoErrIs error // sentinel the error must wrap + }{ + { + name: "all-pass-user-to-tag", + policy: `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "acls": [{ + "action": "accept", + "src": ["alice@headscale.net"], + "dst": ["tag:server:22"] + }], + "tests": [{ + "src": "alice@headscale.net", + "accept": ["tag:server:22"] + }] + }`, + wantPass: true, + }, + { + name: "accept-fail-blocked-by-policy", + policy: `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "acls": [{ + "action": "accept", + "src": ["alice@headscale.net"], + "dst": ["tag:server:22"] + }], + "tests": [{ + "src": "bob@headscale.net", + "accept": ["tag:server:22"] + }] + }`, + wantPass: false, + wantErrSub: []string{"bob@headscale.net", "tag:server:22", "expected ALLOWED"}, + wantNoErrIs: errPolicyTestsFailed, + }, + { + name: "deny-fail-policy-allows-traffic", + policy: `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "acls": [{ + "action": "accept", + "src": ["alice@headscale.net"], + "dst": ["tag:server:22"] + }], + "tests": [{ + "src": "alice@headscale.net", + "deny": ["tag:server:22"] + }] + }`, + wantPass: false, + wantErrSub: []string{"alice@headscale.net", "tag:server:22", "expected DENIED"}, + wantNoErrIs: errPolicyTestsFailed, + }, + { + name: "unknown-src-user", + policy: `{ + "acls": [{ + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + }], + "tests": [{ + "src": "ghost@headscale.net", + "accept": ["alice-laptop:22"] + }] + }`, + wantPass: false, + wantErrSub: []string{"ghost@headscale.net", "failed to resolve source"}, + wantNoErrIs: errPolicyTestsFailed, + }, + { + name: "malformed-dst-missing-port", + policy: `{ + "acls": [{ + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + }], + "tests": [{ + 
"src": "alice@headscale.net", + "accept": ["alice-laptop"] + }] + }`, + wantPass: false, + wantErrSub: []string{"alice-laptop", "error testing"}, + wantNoErrIs: errPolicyTestsFailed, + }, + { + name: "wildcard-src-passes", + policy: `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "acls": [{ + "action": "accept", + "src": ["*"], + "dst": ["tag:server:80"] + }], + "tests": [{ + "src": "alice@headscale.net", + "accept": ["tag:server:80"] + }] + }`, + wantPass: true, + }, + { + name: "proto-restrict-tcp-only", + policy: `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "acls": [{ + "action": "accept", + "proto": "tcp", + "src": ["alice@headscale.net"], + "dst": ["tag:server:22"] + }], + "tests": [ + { + "src": "alice@headscale.net", + "proto": "tcp", + "accept": ["tag:server:22"] + }, + { + "src": "alice@headscale.net", + "proto": "udp", + "deny": ["tag:server:22"] + } + ] + }`, + wantPass: true, + }, + { + name: "grants-only-policy-evaluated", + policy: `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "grants": [{ + "src": ["alice@headscale.net"], + "dst": ["tag:server"], + "ip": ["22"] + }], + "tests": [{ + "src": "alice@headscale.net", + "accept": ["tag:server:22"] + }] + }`, + wantPass: true, + }, + { + name: "mixed-pass-and-fail-reports-failure", + policy: `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "acls": [{ + "action": "accept", + "src": ["alice@headscale.net"], + "dst": ["tag:server:22"] + }], + "tests": [ + { + "src": "alice@headscale.net", + "accept": ["tag:server:22"] + }, + { + "src": "bob@headscale.net", + "accept": ["tag:server:22"] + } + ] + }`, + wantPass: false, + wantErrSub: []string{"bob@headscale.net", "expected ALLOWED"}, + wantNoErrIs: errPolicyTestsFailed, + }, + { + name: "no-tests-block-is-no-op", + policy: `{ + "acls": [{ + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + }] + }`, + wantPass: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
pm, err := NewPolicyManager([]byte(tt.policy), users, nodes.ViewSlice()) + require.NoError(t, err, "policy must parse and compile") + + runErr := pm.RunTests() + if tt.wantPass { + require.NoError(t, runErr, "tests should pass") + + return + } + + require.Error(t, runErr, "tests should fail") + require.ErrorIs(t, runErr, tt.wantNoErrIs, "error should wrap errPolicyTestsFailed") + + for _, sub := range tt.wantErrSub { + assert.Contains(t, runErr.Error(), sub, "rendered error should mention %q", sub) + } + }) + } +} + +// TestSetPolicyRejectsFailingTests asserts that SetPolicy is the user-write +// boundary: a policy whose tests fail must be rejected without mutating the +// live PolicyManager. NewPolicyManager (boot path) only warns on failures. +func TestSetPolicyRejectsFailingTests(t *testing.T) { + users := policyTestUsers() + nodes := policyTestNodes(users) + + good := `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "acls": [{ + "action": "accept", + "src": ["alice@headscale.net"], + "dst": ["tag:server:22"] + }], + "tests": [{ + "src": "alice@headscale.net", + "accept": ["tag:server:22"] + }] + }` + + bad := `{ + "tagOwners": { "tag:server": ["alice@headscale.net"] }, + "acls": [{ + "action": "accept", + "src": ["alice@headscale.net"], + "dst": ["tag:server:22"] + }], + "tests": [{ + "src": "bob@headscale.net", + "accept": ["tag:server:22"] + }] + }` + + pm, err := NewPolicyManager([]byte(good), users, nodes.ViewSlice()) + require.NoError(t, err) + + beforeFilter, _ := pm.Filter() + + changed, err := pm.SetPolicy([]byte(bad)) + require.Error(t, err, "SetPolicy must reject a policy whose tests fail") + require.False(t, changed, "SetPolicy must report no change when rejected") + require.ErrorIs(t, err, errPolicyTestsFailed) + require.Contains(t, err.Error(), "expected ALLOWED") + + afterFilter, _ := pm.Filter() + require.Len(t, afterFilter, len(beforeFilter), "live filter must not change after a rejected SetPolicy") +} + +//
TestNewPolicyManagerSkipsTests asserts the boot path does not evaluate +// tests as a gate (failures are only logged), so a stale stored policy with +// a now-deleted user does not stop the server from booting. +func TestNewPolicyManagerSkipsTests(t *testing.T) { + users := policyTestUsers() + nodes := policyTestNodes(users) + + // Tests reference "ghost@headscale.net" which doesn't exist. Boot + // must not error. + stale := `{ + "acls": [{ + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + }], + "tests": [{ + "src": "ghost@headscale.net", + "accept": ["alice-laptop:22"] + }] + }` + + pm, err := NewPolicyManager([]byte(stale), users, nodes.ViewSlice()) + require.NoError(t, err, "boot must not run tests") + require.NotNil(t, pm) + + // And a subsequent SetPolicy of the same body must reject — that's + // the user-write path. + _, err = pm.SetPolicy([]byte(stale)) + require.Error(t, err) + require.ErrorIs(t, err, errPolicyTestsFailed) +} + +// TestPolicyTestResultsErrorsRendering checks the multi-line render layout +// since the body becomes the user-facing error. +func TestPolicyTestResultsErrorsRendering(t *testing.T) { + results := PolicyTestResults{ + AllPassed: false, + Results: []PolicyTestResult{ + { + Src: "alice@headscale.net", + AcceptFail: []string{"tag:server:22"}, + }, + { + Src: "bob@headscale.net", + Proto: "tcp", + DenyFail: []string{"tag:server:443"}, + }, + }, + } + + rendered := results.Errors() + for _, sub := range []string{ + "alice@headscale.net -> tag:server:22: expected ALLOWED, got DENIED", + "bob@headscale.net -> tag:server:443 (tcp): expected DENIED, got ALLOWED", + } { + assert.Contains(t, rendered, sub) + } + + // Lines should be newline-separated, not space-joined.
+ assert.Equal(t, 2, strings.Count(rendered, "\n")+1, "expected one line per failing assertion") +} diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 4c75dd61..13fffa94 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1986,6 +1986,7 @@ type Policy struct { Grants []Grant `json:"grants,omitempty"` AutoApprovers AutoApproverPolicy `json:"autoApprovers"` SSHs []SSH `json:"ssh,omitempty"` + Tests []PolicyTest `json:"tests,omitempty"` } // MarshalJSON is deliberately not implemented for Policy. diff --git a/integration/cli_policy_test.go b/integration/cli_policy_test.go new file mode 100644 index 00000000..2f8a0c11 --- /dev/null +++ b/integration/cli_policy_test.go @@ -0,0 +1,166 @@ +package integration + +import ( + "encoding/json" + "testing" + + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +// TestPolicyCheckCommand exercises `headscale policy check` across the +// matrix that nblock asked about on PR #3229: +// +// - policyMode: server runs with policy_mode=file vs policy_mode=database. +// `check` reads from `--file`, so the server-side mode should not +// change the outcome; running both proves that. +// - fixture: ACL only, ACL with passing tests, ACL with failing tests. +// - bypass: no-bypass talks to the server over gRPC; bypass opens the +// database directly. +// +// Each row spins up its own scenario because policy_mode is fixed at boot +// via `HEADSCALE_POLICY_MODE`. The two users + two nodes give the tests +// block real `user@` aliases to resolve against. 
+func TestPolicyCheckCommand(t *testing.T) { + IntegrationSkip(t) + + type fixture struct { + name string + policy policyv2.Policy + } + + const ( + user1 = "user1@" + user2 = "user2@" + ) + + aclOnly := policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: policyv2.ActionAccept, + Protocol: "tcp", //nolint:goconst // protocol literal, used inline once + Sources: []policyv2.Alias{usernamep(user1)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep(user2), tailcfg.PortRange{First: 22, Last: 22}), + }, + }, + }, + } + + aclPlusPassingTests := aclOnly + aclPlusPassingTests.Tests = []policyv2.PolicyTest{ + { + Src: user1, + Accept: []string{user2 + ":22"}, + }, + } + + aclPlusFailingTests := aclOnly + aclPlusFailingTests.Tests = []policyv2.PolicyTest{ + { + // Reverse direction is not allowed by the ACL; the test + // asserts ALLOWED, so it must fail. + Src: user2, + Accept: []string{user1 + ":22"}, + }, + } + + fixtures := []fixture{ + {name: "acl-only", policy: aclOnly}, + {name: "acl-plus-passing-tests", policy: aclPlusPassingTests}, + {name: "acl-plus-failing-tests", policy: aclPlusFailingTests}, + } + + type row struct { + name string + policyMode string + fixture fixture + bypass bool + wantErr string + wantStdout string + } + + modes := []string{"file", "database"} //nolint:goconst // axis labels match HEADSCALE_POLICY_MODE values + bypasses := []bool{false, true} + rows := make([]row, 0, len(modes)*len(fixtures)*len(bypasses)) + + for _, mode := range modes { + for _, f := range fixtures { + for _, bypass := range bypasses { + suffix := "no-bypass" + if bypass { + suffix = "bypass" + } + + r := row{ + name: mode + "-" + f.name + "-" + suffix, + policyMode: mode, + fixture: f, + bypass: bypass, + wantStdout: "Policy is valid", + } + if f.name == "acl-plus-failing-tests" { + r.wantErr = "policy tests failed" // text of errPolicyTestsFailed in hscontrol/policy/v2/test.go + r.wantStdout = "" + } + + rows = append(rows, r) + } + } + } + + for _, tt := range rows { + t.Run(tt.name, func(t *testing.T) { + spec := 
ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, //nolint:goconst // matches usernamep("user1@")/("user2@") above + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("cli-policycheck"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_POLICY_MODE": tt.policyMode, //nolint:goconst // env var name from hscontrol/types/config.go + }), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + pBytes, err := json.Marshal(tt.fixture.policy) + require.NoError(t, err) + + policyFilePath := "/etc/headscale/policy.json" //nolint:goconst // standard headscale policy path + err = headscale.WriteFile(policyFilePath, pBytes) + require.NoError(t, err) + + cmd := []string{"headscale", "policy", "check", "-f", policyFilePath} //nolint:goconst // CLI invocation + if tt.bypass { + // --force suppresses the "is the server running?" + // confirmation prompt so the command can run + // non-interactively under the test harness. + cmd = append(cmd, "--bypass-grpc-and-access-database-directly", "--force") + } + + stdout, err := headscale.Execute(cmd) + + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + + return + } + + require.NoError(t, err) + require.Contains(t, stdout, tt.wantStdout) + }) + } +}