mirror of
https://github.com/juanfont/headscale.git
synced 2026-05-23 10:42:30 +09:00
policy/v2: evaluate sshTests at write boundary
SetPolicy and policy check now compile per-dst SSH rules and replay each sshTests entry. The accept assertion treats check-action rules as reachable; the check assertion requires HoldAndDelegate on the matching rule. Boot reload warns and continues.
This commit is contained in:
@@ -208,6 +208,10 @@ func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.Node
|
||||
log.Warn().Err(testErr).Msg("policy tests failed at boot; server starting anyway, fix the policy and reload")
|
||||
}
|
||||
|
||||
if testErr := pm.RunSSHTests(); testErr != nil { //nolint:noinlineerr // boot path: warn-and-continue, not return
|
||||
log.Warn().Err(testErr).Msg("policy sshTests failed at boot; server starting anyway, fix the policy and reload")
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
@@ -482,9 +486,16 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
|
||||
// sandbox compiled from the new policy + current users/nodes; if
|
||||
// they fail, return without mutating the live PolicyManager so the
|
||||
// failed write does not knock the running config offline.
|
||||
err = evaluateTests(pol, pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, err
|
||||
//
|
||||
// Aggregate ACL and SSH test failures via multierr so operators
|
||||
// see both classes in a single response instead of having to
|
||||
// fix-and-retry to discover the second one.
|
||||
testErr := multierr.New(
|
||||
evaluateTests(pol, pm.users, pm.nodes),
|
||||
evaluateSSHTests(pol, pm.users, pm.nodes),
|
||||
)
|
||||
if testErr != nil {
|
||||
return false, testErr
|
||||
}
|
||||
|
||||
// Log policy metadata for debugging
|
||||
|
||||
664
hscontrol/policy/v2/sshtest.go
Normal file
664
hscontrol/policy/v2/sshtest.go
Normal file
@@ -0,0 +1,664 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
// sshTests assertions evaluate on user-initiated writes; boot reload
|
||||
// skips them so a stale reference does not block startup. Each entry
|
||||
// names a src and one or more dst, and uses:
|
||||
//
|
||||
// - accept: every listed user reaches every dst via an accept- or
|
||||
// check-action rule.
|
||||
// - deny: no listed user reaches any dst.
|
||||
// - check: every listed user reaches every dst via a check-action
|
||||
// rule specifically (accept-only matches fail the assertion).
|
||||
|
||||
// SSHPolicyTestResult is the outcome of a single SSHPolicyTest.
|
||||
type SSHPolicyTestResult struct {
|
||||
Src string `json:"src"`
|
||||
Passed bool `json:"passed"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
|
||||
AcceptOK map[string][]string `json:"accept_ok,omitempty"`
|
||||
AcceptFail map[string][]string `json:"accept_fail,omitempty"`
|
||||
DenyOK map[string][]string `json:"deny_ok,omitempty"`
|
||||
DenyFail map[string][]string `json:"deny_fail,omitempty"`
|
||||
CheckOK map[string][]string `json:"check_ok,omitempty"`
|
||||
CheckFail map[string][]string `json:"check_fail,omitempty"`
|
||||
}
|
||||
|
||||
// SSHPolicyTestResults aggregates one evaluation run.
|
||||
type SSHPolicyTestResults struct {
|
||||
AllPassed bool `json:"all_passed"`
|
||||
Results []SSHPolicyTestResult `json:"results"`
|
||||
}
|
||||
|
||||
// Errors renders the per-test failure breakdown joined by newlines.
|
||||
func (r SSHPolicyTestResults) Errors() string {
|
||||
if r.AllPassed {
|
||||
return ""
|
||||
}
|
||||
|
||||
var lines []string
|
||||
|
||||
for _, res := range r.Results {
|
||||
if res.Passed {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, e := range res.Errors {
|
||||
lines = append(lines, fmt.Sprintf("%s: %s", res.Src, e))
|
||||
}
|
||||
|
||||
for _, user := range sortedUsers(res.AcceptFail) {
|
||||
for _, dst := range res.AcceptFail[user] {
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"%s/%s -> %s: expected ALLOWED, got DENIED",
|
||||
res.Src, displayUser(user), dst,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
for _, user := range sortedUsers(res.DenyFail) {
|
||||
for _, dst := range res.DenyFail[user] {
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"%s/%s -> %s: expected DENIED, got ALLOWED",
|
||||
res.Src, displayUser(user), dst,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
for _, user := range sortedUsers(res.CheckFail) {
|
||||
for _, dst := range res.CheckFail[user] {
|
||||
lines = append(lines, fmt.Sprintf(
|
||||
"%s/%s -> %s: expected ALLOWED via check, got %s",
|
||||
res.Src, displayUser(user), dst,
|
||||
checkFailReason(res, user, dst),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func sortedUsers(m map[string][]string) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
slices.Sort(keys)
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// displayUser shows an empty username as `""` rather than blank.
|
||||
func displayUser(u string) string {
|
||||
if u == "" {
|
||||
return `""`
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// checkFailReason annotates a check-fail with whether the user reached
|
||||
// the dst via an accept rule or did not reach at all.
|
||||
func checkFailReason(res SSHPolicyTestResult, user, dst string) string {
|
||||
if slices.Contains(res.AcceptOK[user], dst) {
|
||||
return "ALLOWED via accept"
|
||||
}
|
||||
|
||||
return "DENIED"
|
||||
}
|
||||
|
||||
// RunSSHTests evaluates the live policy's sshTests block and wraps any
|
||||
// failure in errSSHPolicyTestsFailed.
|
||||
func (pm *PolicyManager) RunSSHTests() error {
|
||||
if pm == nil || pm.pol == nil || len(pm.pol.SSHTests) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
cache := make(map[types.NodeID]*tailcfg.SSHPolicy)
|
||||
results := runSSHPolicyTests(pm.pol, pm.users, pm.nodes, cache)
|
||||
|
||||
if results.AllPassed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w:\n%s", errSSHPolicyTestsFailed, results.Errors())
|
||||
}
|
||||
|
||||
// evaluateSSHTests runs the block against pol without mutating live state.
|
||||
func evaluateSSHTests(
|
||||
pol *Policy,
|
||||
users []types.User,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) error {
|
||||
if pol == nil || len(pol.SSHTests) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cache := make(map[types.NodeID]*tailcfg.SSHPolicy)
|
||||
results := runSSHPolicyTests(pol, users, nodes, cache)
|
||||
|
||||
if results.AllPassed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w:\n%s", errSSHPolicyTestsFailed, results.Errors())
|
||||
}
|
||||
|
||||
// runSSHPolicyTests evaluates every sshTests entry. The cache is keyed
|
||||
// by dst NodeID so repeat destinations only compile once per pass.
|
||||
func runSSHPolicyTests(
|
||||
pol *Policy,
|
||||
users []types.User,
|
||||
nodes views.Slice[types.NodeView],
|
||||
cache map[types.NodeID]*tailcfg.SSHPolicy,
|
||||
) SSHPolicyTestResults {
|
||||
results := SSHPolicyTestResults{
|
||||
AllPassed: true,
|
||||
Results: make([]SSHPolicyTestResult, 0, len(pol.SSHTests)),
|
||||
}
|
||||
|
||||
for _, test := range pol.SSHTests {
|
||||
res := runSSHPolicyTest(test, pol, users, nodes, cache)
|
||||
if !res.Passed {
|
||||
results.AllPassed = false
|
||||
}
|
||||
|
||||
results.Results = append(results.Results, res)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// runSSHPolicyTest evaluates one entry: resolve src → resolve dst →
|
||||
// walk accept/deny/check arrays against each dst's compiled SSH policy.
|
||||
func runSSHPolicyTest(
|
||||
test SSHPolicyTest,
|
||||
pol *Policy,
|
||||
users []types.User,
|
||||
nodes views.Slice[types.NodeView],
|
||||
cache map[types.NodeID]*tailcfg.SSHPolicy,
|
||||
) SSHPolicyTestResult {
|
||||
srcLabel := ""
|
||||
if test.Src != nil {
|
||||
srcLabel = test.Src.String()
|
||||
}
|
||||
|
||||
res := SSHPolicyTestResult{
|
||||
Src: srcLabel,
|
||||
Passed: true,
|
||||
}
|
||||
|
||||
srcAddrs, srcUserID, err := resolveSSHTestSource(test.Src, pol, users, nodes)
|
||||
if err != nil {
|
||||
res.Passed = false
|
||||
res.Errors = append(res.Errors,
|
||||
fmt.Sprintf("failed to resolve source %q: %v", srcLabel, err))
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
if len(srcAddrs) == 0 {
|
||||
res.Passed = false
|
||||
res.Errors = append(res.Errors,
|
||||
fmt.Sprintf("source %q resolved to no IP addresses", srcLabel))
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// An entry with no assertion arrays would silently pass.
|
||||
if len(test.Accept) == 0 && len(test.Deny) == 0 && len(test.Check) == 0 {
|
||||
res.Passed = false
|
||||
res.Errors = append(res.Errors,
|
||||
"no accept, deny, or check assertions specified")
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
dstNodes, emptyDsts, err := resolveSSHTestDestNodes(test.Dst, pol, users, nodes, srcUserID)
|
||||
if err != nil {
|
||||
res.Passed = false
|
||||
res.Errors = append(res.Errors,
|
||||
fmt.Sprintf("failed to resolve destinations: %v", err))
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// A dst resolving to zero nodes would silently pass.
|
||||
for _, dst := range emptyDsts {
|
||||
res.Passed = false
|
||||
res.Errors = append(res.Errors,
|
||||
fmt.Sprintf("dst alias %q resolved to no nodes", dst))
|
||||
}
|
||||
|
||||
if len(dstNodes) == 0 {
|
||||
return res
|
||||
}
|
||||
|
||||
for _, user := range test.Accept {
|
||||
evaluateAssertion(
|
||||
pol, users, nodes, cache,
|
||||
srcAddrs, dstNodes, user.String(),
|
||||
assertAccept, &res,
|
||||
)
|
||||
}
|
||||
|
||||
for _, user := range test.Deny {
|
||||
evaluateAssertion(
|
||||
pol, users, nodes, cache,
|
||||
srcAddrs, dstNodes, user.String(),
|
||||
assertDeny, &res,
|
||||
)
|
||||
}
|
||||
|
||||
for _, user := range test.Check {
|
||||
evaluateAssertion(
|
||||
pol, users, nodes, cache,
|
||||
srcAddrs, dstNodes, user.String(),
|
||||
assertCheck, &res,
|
||||
)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type sshAssertion int
|
||||
|
||||
const (
|
||||
assertAccept sshAssertion = iota
|
||||
assertDeny
|
||||
assertCheck
|
||||
)
|
||||
|
||||
// evaluateAssertion walks every (srcAddr, dstNode) pair for one user
|
||||
// and records the outcome. Empty username fails — SSH login users
|
||||
// cannot be empty even when parse accepted it.
|
||||
func evaluateAssertion(
|
||||
pol *Policy,
|
||||
users []types.User,
|
||||
nodes views.Slice[types.NodeView],
|
||||
cache map[types.NodeID]*tailcfg.SSHPolicy,
|
||||
srcAddrs []netip.Addr,
|
||||
dstNodes []types.NodeView,
|
||||
user string,
|
||||
kind sshAssertion,
|
||||
res *SSHPolicyTestResult,
|
||||
) {
|
||||
dstLoop:
|
||||
for _, dst := range dstNodes {
|
||||
dstPol, err := compiledSSHPolicy(pol, users, nodes, cache, dst)
|
||||
if err != nil {
|
||||
res.Passed = false
|
||||
res.Errors = append(res.Errors,
|
||||
fmt.Sprintf("compiling SSH policy for %s: %v",
|
||||
dst.Hostname(), err))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
dstLabel := dst.Hostname()
|
||||
|
||||
acceptHit := false
|
||||
checkHit := false
|
||||
|
||||
for _, srcAddr := range srcAddrs {
|
||||
a, c := reachability(dstPol, srcAddr, user)
|
||||
if a {
|
||||
acceptHit = true
|
||||
}
|
||||
|
||||
if c {
|
||||
checkHit = true
|
||||
}
|
||||
|
||||
// All src IPs must agree; one counter-example fails
|
||||
// the whole (user, dst) pair.
|
||||
switch kind {
|
||||
case assertAccept:
|
||||
if !a {
|
||||
res.Passed = false
|
||||
res.AcceptFail = appendUserDst(res.AcceptFail, user, dstLabel)
|
||||
|
||||
continue dstLoop
|
||||
}
|
||||
case assertDeny:
|
||||
if a {
|
||||
res.Passed = false
|
||||
res.DenyFail = appendUserDst(res.DenyFail, user, dstLabel)
|
||||
|
||||
continue dstLoop
|
||||
}
|
||||
case assertCheck:
|
||||
if !c {
|
||||
res.Passed = false
|
||||
res.CheckFail = appendUserDst(res.CheckFail, user, dstLabel)
|
||||
|
||||
// Record whether the accept side passed so
|
||||
// the rendered error can say "ALLOWED via
|
||||
// accept" instead of "DENIED".
|
||||
if a {
|
||||
res.AcceptOK = appendUserDst(res.AcceptOK, user, dstLabel)
|
||||
}
|
||||
|
||||
continue dstLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case assertAccept:
|
||||
if acceptHit {
|
||||
res.AcceptOK = appendUserDst(res.AcceptOK, user, dstLabel)
|
||||
}
|
||||
case assertDeny:
|
||||
res.DenyOK = appendUserDst(res.DenyOK, user, dstLabel)
|
||||
case assertCheck:
|
||||
if checkHit {
|
||||
res.CheckOK = appendUserDst(res.CheckOK, user, dstLabel)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// appendUserDst appends dst to m[user], allocating m on first use.
|
||||
func appendUserDst(m map[string][]string, user, dst string) map[string][]string {
|
||||
if m == nil {
|
||||
m = make(map[string][]string)
|
||||
}
|
||||
|
||||
m[user] = append(m[user], dst)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// resolveSSHTestSource returns the src's principal addresses and, for
|
||||
// user-shaped sources, the user ID (so autogroup:self can scope to it).
|
||||
// Tag, host, and IP sources return userID 0.
|
||||
func resolveSSHTestSource(
|
||||
src Alias,
|
||||
pol *Policy,
|
||||
users []types.User,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) ([]netip.Addr, uint, error) {
|
||||
if src == nil {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
addrs, err := src.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("resolving: %w", err)
|
||||
}
|
||||
|
||||
if addrs == nil || addrs.Empty() {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
out := make([]netip.Addr, 0)
|
||||
for a := range addrs.Iter() {
|
||||
out = append(out, a)
|
||||
}
|
||||
|
||||
var userID uint
|
||||
|
||||
u, ok := src.(*Username)
|
||||
if ok {
|
||||
resolved, rErr := u.resolveUser(users)
|
||||
if rErr == nil {
|
||||
userID = resolved.ID
|
||||
}
|
||||
}
|
||||
|
||||
return out, userID, nil
|
||||
}
|
||||
|
||||
// resolveSSHTestDestNodes maps each dst alias to its destination
|
||||
// NodeViews. autogroup:self needs special handling: it cannot resolve
|
||||
// without per-node context, so it walks the node set keyed on src's
|
||||
// owning user. Other aliases resolve to an IPSet and match via InIPSet.
|
||||
func resolveSSHTestDestNodes(
|
||||
dsts SSHTestDestinations,
|
||||
pol *Policy,
|
||||
users []types.User,
|
||||
nodes views.Slice[types.NodeView],
|
||||
srcUserID uint,
|
||||
) ([]types.NodeView, []string, error) {
|
||||
seen := make(map[types.NodeID]struct{})
|
||||
|
||||
var (
|
||||
out []types.NodeView
|
||||
emptyDsts []string
|
||||
)
|
||||
|
||||
for _, alias := range dsts {
|
||||
dstLabel := alias.String()
|
||||
matched := false
|
||||
|
||||
if ag, ok := alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
// autogroup:self resolves to non-tagged nodes owned by
|
||||
// the same user as src; tagged/IP sources have no user.
|
||||
if srcUserID == 0 {
|
||||
emptyDsts = append(emptyDsts, dstLabel)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if n.IsTagged() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !n.User().Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
if n.User().ID() != srcUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
matched = true
|
||||
|
||||
if _, dup := seen[n.ID()]; dup {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[n.ID()] = struct{}{}
|
||||
out = append(out, n)
|
||||
}
|
||||
|
||||
if !matched {
|
||||
emptyDsts = append(emptyDsts, dstLabel)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ips, err := alias.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("resolving destination %q: %w", dstLabel, err)
|
||||
}
|
||||
|
||||
if ips == nil || ips.Empty() {
|
||||
emptyDsts = append(emptyDsts, dstLabel)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
set, err := prefixesToIPSet(ips.Prefixes())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("building IPSet for %q: %w", dstLabel, err)
|
||||
}
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.InIPSet(set) {
|
||||
continue
|
||||
}
|
||||
|
||||
matched = true
|
||||
|
||||
if _, dup := seen[n.ID()]; dup {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[n.ID()] = struct{}{}
|
||||
out = append(out, n)
|
||||
}
|
||||
|
||||
if !matched {
|
||||
emptyDsts = append(emptyDsts, dstLabel)
|
||||
}
|
||||
}
|
||||
|
||||
return out, emptyDsts, nil
|
||||
}
|
||||
|
||||
// prefixesToIPSet builds the IPSet that InIPSet expects on the node
|
||||
// side.
|
||||
func prefixesToIPSet(prefixes []netip.Prefix) (*netipx.IPSet, error) {
|
||||
var b netipx.IPSetBuilder
|
||||
|
||||
for _, p := range prefixes {
|
||||
b.AddPrefix(p)
|
||||
}
|
||||
|
||||
return b.IPSet()
|
||||
}
|
||||
|
||||
// compiledSSHPolicy returns the per-node compiled SSH policy, caching
|
||||
// on miss. baseURL is empty because reachability only checks for the
|
||||
// presence of HoldAndDelegate, not its value.
|
||||
func compiledSSHPolicy(
|
||||
pol *Policy,
|
||||
users []types.User,
|
||||
nodes views.Slice[types.NodeView],
|
||||
cache map[types.NodeID]*tailcfg.SSHPolicy,
|
||||
node types.NodeView,
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if sshPol, ok := cache[node.ID()]; ok {
|
||||
return sshPol, nil
|
||||
}
|
||||
|
||||
sshPol, err := pol.compileSSHPolicy("", users, node, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache[node.ID()] = sshPol
|
||||
|
||||
return sshPol, nil
|
||||
}
|
||||
|
||||
// reachability reports whether srcAddr can log in as user via:
|
||||
//
|
||||
// - any matching rule (acceptHit, satisfies accept assertions)
|
||||
// - a check-action rule (checkHit, satisfies check assertions)
|
||||
func reachability(
|
||||
dstPolicy *tailcfg.SSHPolicy,
|
||||
srcAddr netip.Addr,
|
||||
user string,
|
||||
) (bool, bool) {
|
||||
if dstPolicy == nil {
|
||||
return false, false
|
||||
}
|
||||
|
||||
var acceptHit, checkHit bool
|
||||
|
||||
for _, rule := range dstPolicy.Rules {
|
||||
if !principalContainsAddr(rule.Principals, srcAddr) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !sshUserMapAllows(rule.SSHUsers, user) {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.Action == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
acceptHit = true
|
||||
|
||||
if rule.Action.HoldAndDelegate != "" {
|
||||
checkHit = true
|
||||
}
|
||||
|
||||
// Early-out only when both bits are set: a rule satisfying
|
||||
// accept does not always satisfy check.
|
||||
if acceptHit && checkHit {
|
||||
return acceptHit, checkHit
|
||||
}
|
||||
}
|
||||
|
||||
return acceptHit, checkHit
|
||||
}
|
||||
|
||||
// principalContainsAddr reports whether any principal's NodeIP matches
|
||||
// srcAddr exactly (the SSH compiler emits one principal per source IP).
|
||||
func principalContainsAddr(
|
||||
principals []*tailcfg.SSHPrincipal,
|
||||
srcAddr netip.Addr,
|
||||
) bool {
|
||||
for _, p := range principals {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if p.NodeIP == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(p.NodeIP)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if addr == srcAddr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// sshUserMapAllows reports whether SSHUsers permits user. The SSHUsers
|
||||
// wire shape (see filter.go compileSSHPolicy):
|
||||
//
|
||||
// - SSHUsers["root"] == "root" allows root; == "" disallows it.
|
||||
// - SSHUsers["*"] == "=" is the wildcard fallback for non-root users
|
||||
// (set when the rule lists autogroup:nonroot).
|
||||
// - SSHUsers[<literal>] == <literal> for every named user.
|
||||
func sshUserMapAllows(m map[string]string, user string) bool {
|
||||
if user == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if v, ok := m[user]; ok {
|
||||
return v != ""
|
||||
}
|
||||
|
||||
if user == "root" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Wildcard fallback for non-root users.
|
||||
if v, ok := m["*"]; ok {
|
||||
return v != ""
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
1000
hscontrol/policy/v2/sshtest_test.go
Normal file
1000
hscontrol/policy/v2/sshtest_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -46,61 +46,6 @@ func setupSSHDataCompatUsers() types.Users {
|
||||
}
|
||||
}
|
||||
|
||||
// setupSSHDataCompatNodes returns the test nodes for SSH data-driven
|
||||
// compatibility tests. Node GivenNames match the anonymized pokémon names:
|
||||
// - bulbasaur (owned by odin)
|
||||
// - ivysaur (owned by thor)
|
||||
// - venusaur (owned by freya)
|
||||
// - beedrill (tag:server)
|
||||
// - kakuna (tag:prod)
|
||||
func setupSSHDataCompatNodes(users types.Users) types.Nodes {
|
||||
return types.Nodes{
|
||||
&types.Node{
|
||||
ID: 1,
|
||||
GivenName: "bulbasaur",
|
||||
User: &users[0],
|
||||
UserID: &users[0].ID,
|
||||
IPv4: ptrAddr("100.90.199.68"),
|
||||
IPv6: ptrAddr("fd7a:115c:a1e0::2d01:c747"),
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
&types.Node{
|
||||
ID: 2,
|
||||
GivenName: "ivysaur",
|
||||
User: &users[1],
|
||||
UserID: &users[1].ID,
|
||||
IPv4: ptrAddr("100.110.121.96"),
|
||||
IPv6: ptrAddr("fd7a:115c:a1e0::1737:7960"),
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
&types.Node{
|
||||
ID: 3,
|
||||
GivenName: "venusaur",
|
||||
User: &users[2],
|
||||
UserID: &users[2].ID,
|
||||
IPv4: ptrAddr("100.103.90.82"),
|
||||
IPv6: ptrAddr("fd7a:115c:a1e0::9e37:5a52"),
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
&types.Node{
|
||||
ID: 4,
|
||||
GivenName: "beedrill",
|
||||
IPv4: ptrAddr("100.108.74.26"),
|
||||
IPv6: ptrAddr("fd7a:115c:a1e0::b901:4a87"),
|
||||
Tags: []string{"tag:server"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
&types.Node{
|
||||
ID: 5,
|
||||
GivenName: "kakuna",
|
||||
IPv4: ptrAddr("100.103.8.15"),
|
||||
IPv6: ptrAddr("fd7a:115c:a1e0::5b37:80f"),
|
||||
Tags: []string{"tag:prod"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// loadSSHTestFile loads and parses a single SSH capture HuJSON file.
|
||||
func loadSSHTestFile(t *testing.T, path string) *testcapture.Capture {
|
||||
t.Helper()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/go-json-experiment/json"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -57,23 +58,80 @@ type PolicyTest struct {
|
||||
// listed user is asserted against every entry in Dst.
|
||||
type SSHPolicyTest struct {
|
||||
// Src is a single source alias (user, group, tag, host, or IP).
|
||||
Src string `json:"src"`
|
||||
Src Alias `json:"src"`
|
||||
|
||||
// Dst lists destinations the test exercises (tag, host, or SSH-
|
||||
// compatible autogroup). Ports, CIDRs, and autogroup:internet are
|
||||
// rejected at parse time.
|
||||
Dst []string `json:"dst"`
|
||||
Dst SSHTestDestinations `json:"dst"`
|
||||
|
||||
// Accept lists users that must reach every Dst via an accept- or
|
||||
// check-action rule.
|
||||
Accept []string `json:"accept,omitempty"`
|
||||
Accept []SSHUser `json:"accept,omitempty"`
|
||||
|
||||
// Deny lists users that must NOT reach any Dst.
|
||||
Deny []string `json:"deny,omitempty"`
|
||||
Deny []SSHUser `json:"deny,omitempty"`
|
||||
|
||||
// Check lists users that must reach every Dst via a check-action
|
||||
// rule specifically; an accept-action rule does not satisfy this.
|
||||
Check []string `json:"check,omitempty"`
|
||||
Check []SSHUser `json:"check,omitempty"`
|
||||
}
|
||||
|
||||
// SSHTestDestinations is the typed list of destination aliases an
|
||||
// sshTests entry targets. validateSSHTestDestination enforces the
|
||||
// SSH-specific shape rules (no :port, no CIDR, no autogroup:internet,
|
||||
// known tag).
|
||||
type SSHTestDestinations []Alias
|
||||
|
||||
func (d *SSHTestDestinations) UnmarshalJSON(b []byte) error {
|
||||
var aliases []AliasEnc
|
||||
|
||||
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*d = make([]Alias, len(aliases))
|
||||
for i, a := range aliases {
|
||||
(*d)[i] = a.Alias
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses each typed field. An empty src lands as a nil
|
||||
// Alias so validation surfaces ErrSSHTestEmptySrc rather than a parser
|
||||
// failure.
|
||||
func (t *SSHPolicyTest) UnmarshalJSON(b []byte) error {
|
||||
var raw struct {
|
||||
Src string `json:"src"`
|
||||
Dst SSHTestDestinations `json:"dst"`
|
||||
Accept []SSHUser `json:"accept,omitempty"`
|
||||
Deny []SSHUser `json:"deny,omitempty"`
|
||||
Check []SSHUser `json:"check,omitempty"`
|
||||
}
|
||||
|
||||
err := json.Unmarshal(b, &raw, policyJSONOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
trimmedSrc := strings.TrimSpace(raw.Src)
|
||||
if trimmedSrc != "" {
|
||||
alias, parseErr := parseAlias(trimmedSrc)
|
||||
if parseErr != nil {
|
||||
return parseErr
|
||||
}
|
||||
|
||||
t.Src = alias
|
||||
}
|
||||
|
||||
t.Dst = raw.Dst
|
||||
t.Accept = raw.Accept
|
||||
t.Deny = raw.Deny
|
||||
t.Check = raw.Check
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PolicyTestResult is the outcome of a single PolicyTest.
|
||||
|
||||
18758
hscontrol/policy/v2/testdata/sshtest_results/sshtest-malformed-dst-bare-ipv4.hujson
vendored
Normal file
18758
hscontrol/policy/v2/testdata/sshtest_results/sshtest-malformed-dst-bare-ipv4.hujson
vendored
Normal file
File diff suppressed because it is too large
Load Diff
18737
hscontrol/policy/v2/testdata/sshtest_results/sshtest-malformed-dst-bare-ipv6.hujson
vendored
Normal file
18737
hscontrol/policy/v2/testdata/sshtest_results/sshtest-malformed-dst-bare-ipv6.hujson
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -863,6 +863,12 @@ type Alias interface {
|
||||
Validate() error
|
||||
UnmarshalJSON(b []byte) error
|
||||
|
||||
// String renders the alias back to its policy-file form. Implementations
|
||||
// are expected to return a value that round-trips through parseAlias for
|
||||
// any alias the parser accepted, so callers can use it as a stable
|
||||
// identity in rendered errors and logs.
|
||||
String() string
|
||||
|
||||
// Resolve resolves the Alias to an IPSet. The IPSet will contain all the IP
|
||||
// addresses that the Alias represents within Headscale. It is the product
|
||||
// of the Alias and the Policy, Users and Nodes.
|
||||
@@ -2927,10 +2933,6 @@ func (p *SSHCheckPeriod) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// time.ParseDuration produces error strings like
|
||||
// `time: invalid duration "abc"` which match SaaS body wording
|
||||
// exactly; model.ParseDuration wraps the same parse with custom
|
||||
// phrasing and would diverge.
|
||||
d, err := time.ParseDuration(str)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -3407,7 +3409,7 @@ func validateSSHTests(pol *Policy, tests []SSHPolicyTest) error {
|
||||
var errs []error
|
||||
|
||||
for i, t := range tests {
|
||||
if t.Src == "" {
|
||||
if t.Src == nil {
|
||||
errs = append(errs, fmt.Errorf("sshTest %d: %w", i, ErrSSHTestEmptySrc))
|
||||
}
|
||||
|
||||
@@ -3439,11 +3441,8 @@ func validateSSHTests(pol *Policy, tests []SSHPolicyTest) error {
|
||||
//
|
||||
// A bare IP literal (single-host /BitLen prefix) is accepted. Tag
|
||||
// entries must exist in tagOwners.
|
||||
func validateSSHTestDestination(pol *Policy, dst string) error {
|
||||
alias, err := parseAlias(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w %q", ErrSSHTestDstDisallowedElement, dst)
|
||||
}
|
||||
func validateSSHTestDestination(pol *Policy, alias Alias) error {
|
||||
dst := alias.String()
|
||||
|
||||
switch a := alias.(type) {
|
||||
case *AutoGroup:
|
||||
@@ -3454,9 +3453,10 @@ func validateSSHTestDestination(pol *Policy, dst string) error {
|
||||
}
|
||||
|
||||
case *Prefix:
|
||||
// Bare IP parses to a *Prefix without slash; reject any
|
||||
// explicit CIDR.
|
||||
if strings.Contains(dst, "/") {
|
||||
// A bare IP parses as `/BitLen` and is a valid single-host dst;
|
||||
// any narrower CIDR is a multi-host range and is rejected.
|
||||
p := netip.Prefix(*a)
|
||||
if p.Bits() < p.Addr().BitLen() {
|
||||
return fmt.Errorf("%w %q", ErrSSHTestDstDisallowedElement, dst)
|
||||
}
|
||||
|
||||
|
||||
@@ -6003,9 +6003,12 @@ func TestUnmarshalPolicySSHTests(t *testing.T) {
|
||||
t.Helper()
|
||||
require.Len(t, pol.SSHTests, 1)
|
||||
got := pol.SSHTests[0]
|
||||
require.Equal(t, "thor@example.org", got.Src)
|
||||
require.Equal(t, []string{"tag:server"}, got.Dst)
|
||||
require.Equal(t, []string{"root"}, got.Accept)
|
||||
require.IsType(t, (*Username)(nil), got.Src)
|
||||
require.Equal(t, "thor@example.org", got.Src.String())
|
||||
require.Len(t, got.Dst, 1)
|
||||
require.IsType(t, (*Tag)(nil), got.Dst[0])
|
||||
require.Equal(t, "tag:server", got.Dst[0].String())
|
||||
require.Equal(t, []SSHUser{"root"}, got.Accept)
|
||||
require.Empty(t, got.Deny)
|
||||
require.Empty(t, got.Check)
|
||||
},
|
||||
@@ -6030,9 +6033,9 @@ func TestUnmarshalPolicySSHTests(t *testing.T) {
|
||||
t.Helper()
|
||||
require.Len(t, pol.SSHTests, 1)
|
||||
got := pol.SSHTests[0]
|
||||
require.Equal(t, []string{"root"}, got.Accept)
|
||||
require.Equal(t, []string{"nobody"}, got.Deny)
|
||||
require.Equal(t, []string{"alice"}, got.Check)
|
||||
require.Equal(t, []SSHUser{"root"}, got.Accept)
|
||||
require.Equal(t, []SSHUser{"nobody"}, got.Deny)
|
||||
require.Equal(t, []SSHUser{"alice"}, got.Check) //nolint:goconst
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -6086,6 +6089,52 @@ func TestUnmarshalPolicySSHTests(t *testing.T) {
|
||||
`,
|
||||
wantErr: ErrSSHTestDstDisallowedElement,
|
||||
},
|
||||
{
|
||||
// SaaS accepts a bare IPv4 literal as a host address. The
|
||||
// Prefix parser turns it into a /32 so validateSSHTestDestination
|
||||
// must match Bits() against Addr().BitLen() rather than reject
|
||||
// the whole *Prefix branch.
|
||||
name: "dst-bare-ipv4-accepted",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:server": ["admin@example.org"]},
|
||||
"sshTests": [
|
||||
{"src": "thor@example.org", "dst": ["100.64.0.16"], "accept": ["root"]}
|
||||
]
|
||||
}
|
||||
`,
|
||||
check: func(t *testing.T, pol *Policy) {
|
||||
t.Helper()
|
||||
require.Len(t, pol.SSHTests, 1)
|
||||
got := pol.SSHTests[0]
|
||||
require.Len(t, got.Dst, 1)
|
||||
pref, ok := got.Dst[0].(*Prefix)
|
||||
require.True(t, ok, "want *Prefix, got %T", got.Dst[0])
|
||||
require.Equal(t, "100.64.0.16/32", pref.String())
|
||||
},
|
||||
},
|
||||
{
|
||||
// IPv6 mirror of the IPv4 case: bare `fd7a::10` parses to
|
||||
// /128 and must pass the parse-time shape check.
|
||||
name: "dst-bare-ipv6-accepted",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:server": ["admin@example.org"]},
|
||||
"sshTests": [
|
||||
{"src": "thor@example.org", "dst": ["fd7a:115c:a1e0::10"], "accept": ["root"]}
|
||||
]
|
||||
}
|
||||
`,
|
||||
check: func(t *testing.T, pol *Policy) {
|
||||
t.Helper()
|
||||
require.Len(t, pol.SSHTests, 1)
|
||||
got := pol.SSHTests[0]
|
||||
require.Len(t, got.Dst, 1)
|
||||
pref, ok := got.Dst[0].(*Prefix)
|
||||
require.True(t, ok, "want *Prefix, got %T", got.Dst[0])
|
||||
require.Equal(t, "fd7a:115c:a1e0::10/128", pref.String())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dst-autogroup-internet",
|
||||
input: `
|
||||
|
||||
Reference in New Issue
Block a user