mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-11 20:55:15 +09:00
policy, noise: implement SSH check action
Implement the SSH "check" action which requires additional verification before allowing SSH access. The policy compiler generates a HoldAndDelegate URL that the Tailscale client calls back to headscale. The SSHActionHandler creates an auth session and waits for approval via the generalised auth flow. Sort check (HoldAndDelegate) rules before accept rules to match Tailscale's first-match-wins evaluation order. Updates #1850
This commit is contained in:
@@ -7,12 +7,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/metrics"
|
||||
"github.com/juanfont/headscale/hscontrol/capver"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http2"
|
||||
"tailscale.com/control/controlbase"
|
||||
@@ -30,6 +32,9 @@ var ErrMissingURLParameter = errors.New("missing URL parameter")
|
||||
// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type.
|
||||
var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type")
|
||||
|
||||
// ErrNoAuthSession is returned when an auth_id does not match any active auth session.
|
||||
var ErrNoAuthSession = errors.New("no auth session found")
|
||||
|
||||
const (
|
||||
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
|
||||
ts2021UpgradePath = "/ts2021"
|
||||
@@ -113,7 +118,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
||||
}))
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.RequestLogger(&zerologRequestLogger{}))
|
||||
r.Use(middleware.Recoverer)
|
||||
|
||||
r.Handle("/metrics", metrics.Handler())
|
||||
@@ -122,6 +127,9 @@ func (h *Headscale) NoiseUpgradeHandler(
|
||||
r.Post("/register", ns.RegistrationHandler)
|
||||
r.Post("/map", ns.PollNetMapHandler)
|
||||
|
||||
// SSH Check mode endpoint, consulted to validate if a given SSH connection should be accepted or rejected.
|
||||
r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler)
|
||||
|
||||
// Not implemented yet
|
||||
//
|
||||
// /whoami is a debug endpoint to validate that the client can communicate over the connection,
|
||||
@@ -153,12 +161,10 @@ func (h *Headscale) NoiseUpgradeHandler(
|
||||
r.Post("/update-health", ns.NotImplementedHandler)
|
||||
|
||||
r.Route("/webclient", func(r chi.Router) {})
|
||||
|
||||
r.Post("/c2n", ns.NotImplementedHandler)
|
||||
})
|
||||
|
||||
r.Post("/c2n", ns.NotImplementedHandler)
|
||||
|
||||
r.Get("/ssh-action", ns.SSHAction)
|
||||
|
||||
ns.httpBaseConfig = &http.Server{
|
||||
Handler: r,
|
||||
ReadHeaderTimeout: types.HTTPTimeout,
|
||||
@@ -249,10 +255,233 @@ func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *ht
|
||||
http.Error(writer, "Not implemented yet", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction]
|
||||
// to the client with the verdict of an SSH access request.
|
||||
func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) {
|
||||
log.Trace().Caller().Str("path", req.URL.String()).Msg("got SSH action request")
|
||||
func urlParam[T any](req *http.Request, key string) (T, error) {
|
||||
var zero T
|
||||
|
||||
param := chi.URLParam(req, key)
|
||||
if param == "" {
|
||||
return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
|
||||
}
|
||||
|
||||
var value T
|
||||
switch any(value).(type) {
|
||||
case string:
|
||||
v, ok := any(param).(T)
|
||||
if !ok {
|
||||
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
|
||||
}
|
||||
|
||||
value = v
|
||||
case types.NodeID:
|
||||
id, err := types.ParseNodeID(param)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("parsing %s: %w", key, err)
|
||||
}
|
||||
|
||||
v, ok := any(id).(T)
|
||||
if !ok {
|
||||
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
|
||||
}
|
||||
|
||||
value = v
|
||||
default:
|
||||
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// SSHActionHandler handles the /ssh-action endpoint, returning a
|
||||
// [tailcfg.SSHAction] to the client with the verdict of an SSH access
|
||||
// request.
|
||||
func (ns *noiseServer) SSHActionHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
srcNodeID, err := urlParam[types.NodeID](req, "src_node_id")
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(
|
||||
http.StatusBadRequest,
|
||||
"Invalid src_node_id",
|
||||
err,
|
||||
))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
dstNodeID, err := urlParam[types.NodeID](req, "dst_node_id")
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(
|
||||
http.StatusBadRequest,
|
||||
"Invalid dst_node_id",
|
||||
err,
|
||||
))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
reqLog := log.With().
|
||||
Uint64("src_node_id", srcNodeID.Uint64()).
|
||||
Uint64("dst_node_id", dstNodeID.Uint64()).
|
||||
Str("ssh_user", req.URL.Query().Get("ssh_user")).
|
||||
Str("local_user", req.URL.Query().Get("local_user")).
|
||||
Logger()
|
||||
|
||||
reqLog.Trace().Caller().Msg("SSH action request")
|
||||
|
||||
action, err := ns.sshAction(
|
||||
reqLog,
|
||||
req.URL.Query().Get("auth_id"),
|
||||
)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
err = json.NewEncoder(writer).Encode(action)
|
||||
if err != nil {
|
||||
reqLog.Error().Caller().Err(err).
|
||||
Msg("failed to encode SSH action response")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// sshAction resolves the SSH action for the given request parameters.
|
||||
// It returns the action to send to the client, or an HTTPError on
|
||||
// failure.
|
||||
//
|
||||
// Two cases:
|
||||
// 1. Initial request — build a HoldAndDelegate URL and wait for the
|
||||
// user to authenticate.
|
||||
// 2. Follow-up request — an auth_id is present, wait for the auth
|
||||
// verdict and accept or reject.
|
||||
func (ns *noiseServer) sshAction(
|
||||
reqLog zerolog.Logger,
|
||||
authIDStr string,
|
||||
) (*tailcfg.SSHAction, error) {
|
||||
action := tailcfg.SSHAction{
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
}
|
||||
|
||||
// Follow-up request with auth_id — wait for the auth verdict.
|
||||
if authIDStr != "" {
|
||||
return ns.sshActionFollowUp(
|
||||
reqLog, &action, authIDStr,
|
||||
)
|
||||
}
|
||||
|
||||
// Initial request — create an auth session and hold.
|
||||
return ns.sshActionHoldAndDelegate(reqLog, &action)
|
||||
}
|
||||
|
||||
// sshActionHoldAndDelegate creates a new auth session and returns a
|
||||
// HoldAndDelegate action that directs the client to authenticate.
|
||||
func (ns *noiseServer) sshActionHoldAndDelegate(
|
||||
reqLog zerolog.Logger,
|
||||
action *tailcfg.SSHAction,
|
||||
) (*tailcfg.SSHAction, error) {
|
||||
holdURL, err := url.Parse(
|
||||
ns.headscale.cfg.ServerURL +
|
||||
"/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID" +
|
||||
"?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(
|
||||
http.StatusInternalServerError,
|
||||
"Internal error",
|
||||
fmt.Errorf("parsing SSH action URL: %w", err),
|
||||
)
|
||||
}
|
||||
|
||||
authID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(
|
||||
http.StatusInternalServerError,
|
||||
"Internal error",
|
||||
fmt.Errorf("generating auth ID: %w", err),
|
||||
)
|
||||
}
|
||||
|
||||
ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest())
|
||||
|
||||
authURL := ns.headscale.authProvider.AuthURL(authID)
|
||||
|
||||
q := holdURL.Query()
|
||||
q.Set("auth_id", authID.String())
|
||||
holdURL.RawQuery = q.Encode()
|
||||
|
||||
action.HoldAndDelegate = holdURL.String()
|
||||
|
||||
// TODO(kradalby): here we can also send a very tiny mapresponse
|
||||
// "popping" the url and opening it for the user.
|
||||
action.Message = fmt.Sprintf(
|
||||
"# Headscale SSH requires an additional check.\n"+
|
||||
"# To authenticate, visit: %s\n"+
|
||||
"# Authentication checked with Headscale SSH.\n",
|
||||
authURL,
|
||||
)
|
||||
|
||||
reqLog.Info().Caller().
|
||||
Str("auth_id", authID.String()).
|
||||
Msg("SSH check pending, waiting for auth")
|
||||
|
||||
return action, nil
|
||||
}
|
||||
|
||||
// sshActionFollowUp handles follow-up requests where the client
|
||||
// provides an auth_id. It blocks until the auth session resolves.
|
||||
func (ns *noiseServer) sshActionFollowUp(
|
||||
reqLog zerolog.Logger,
|
||||
action *tailcfg.SSHAction,
|
||||
authIDStr string,
|
||||
) (*tailcfg.SSHAction, error) {
|
||||
authID, err := types.AuthIDFromString(authIDStr)
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(
|
||||
http.StatusBadRequest,
|
||||
"Invalid auth_id",
|
||||
fmt.Errorf("parsing auth_id: %w", err),
|
||||
)
|
||||
}
|
||||
|
||||
reqLog = reqLog.With().Str("auth_id", authID.String()).Logger()
|
||||
|
||||
auth, ok := ns.headscale.state.GetAuthCacheEntry(authID)
|
||||
if !ok {
|
||||
return nil, NewHTTPError(
|
||||
http.StatusBadRequest,
|
||||
"Invalid auth_id",
|
||||
fmt.Errorf("%w: %s", ErrNoAuthSession, authID),
|
||||
)
|
||||
}
|
||||
|
||||
reqLog.Trace().Caller().Msg("SSH action follow-up")
|
||||
|
||||
verdict := <-auth.WaitForAuth()
|
||||
|
||||
if !verdict.Accept() {
|
||||
action.Reject = true
|
||||
|
||||
reqLog.Trace().Caller().Err(verdict.Err).
|
||||
Msg("authentication rejected")
|
||||
|
||||
return action, nil
|
||||
}
|
||||
|
||||
action.Accept = true
|
||||
|
||||
return action, nil
|
||||
}
|
||||
|
||||
// PollNetMapHandler takes care of /machine/:id/map using the Noise protocol
|
||||
@@ -380,28 +609,3 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.
|
||||
|
||||
return nv, nil
|
||||
}
|
||||
|
||||
// urlParam extracts a typed URL parameter from a chi router request.
|
||||
func urlParam[T any](req *http.Request, key string) (T, error) {
|
||||
var zero T
|
||||
|
||||
param := chi.URLParam(req, key)
|
||||
if param == "" {
|
||||
return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
|
||||
}
|
||||
|
||||
var value T
|
||||
switch any(value).(type) {
|
||||
case string:
|
||||
v, ok := any(param).(T)
|
||||
if !ok {
|
||||
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
|
||||
}
|
||||
|
||||
value = v
|
||||
default:
|
||||
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type PolicyManager interface {
|
||||
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
|
||||
// BuildPeerMap constructs peer relationship maps for the given nodes
|
||||
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
|
||||
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
|
||||
SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error)
|
||||
SetPolicy(pol []byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
|
||||
|
||||
@@ -1188,8 +1188,9 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
"root": "",
|
||||
},
|
||||
Action: &tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
Accept: false,
|
||||
SessionDuration: 24 * time.Hour,
|
||||
HoldAndDelegate: "unused-url/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
@@ -1476,7 +1477,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := pm.SSHPolicy(tt.targetNode.View())
|
||||
got, err := pm.SSHPolicy("unused-url", tt.targetNode.View())
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.wantSSH, got); diff != "" {
|
||||
|
||||
@@ -319,11 +319,27 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||
var sshAccept = tailcfg.SSHAction{
|
||||
Reject: false,
|
||||
Accept: true,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
}
|
||||
|
||||
func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction {
|
||||
return tailcfg.SSHAction{
|
||||
Reject: !accept,
|
||||
Accept: accept,
|
||||
SessionDuration: duration,
|
||||
Reject: false,
|
||||
Accept: false,
|
||||
SessionDuration: duration,
|
||||
// Replaced in the client:
|
||||
// * $SRC_NODE_IP (URL escaped)
|
||||
// * $SRC_NODE_ID (Node.ID as int64 string)
|
||||
// * $DST_NODE_IP (URL escaped)
|
||||
// * $DST_NODE_ID (Node.ID as int64 string)
|
||||
// * $SSH_USER (URL escaped, ssh user requested)
|
||||
// * $LOCAL_USER (URL escaped, local user mapped)
|
||||
HoldAndDelegate: baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
AllowRemotePortForwarding: true,
|
||||
@@ -332,6 +348,7 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||
|
||||
//nolint:gocyclo // complex SSH policy compilation logic
|
||||
func (pol *Policy) compileSSHPolicy(
|
||||
baseURL string,
|
||||
users types.Users,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
@@ -377,9 +394,9 @@ func (pol *Policy) compileSSHPolicy(
|
||||
|
||||
switch rule.Action {
|
||||
case SSHActionAccept:
|
||||
action = sshAction(true, 0)
|
||||
action = sshAccept
|
||||
case SSHActionCheck:
|
||||
action = sshAction(true, time.Duration(rule.CheckPeriod))
|
||||
action = sshCheck(baseURL, time.Duration(rule.CheckPeriod))
|
||||
default:
|
||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
|
||||
}
|
||||
@@ -503,6 +520,23 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}
|
||||
}
|
||||
|
||||
// Sort rules: check (HoldAndDelegate) before accept, per Tailscale
|
||||
// evaluation order (most-restrictive first).
|
||||
slices.SortStableFunc(rules, func(a, b *tailcfg.SSHRule) int {
|
||||
aIsCheck := a.Action != nil && a.Action.HoldAndDelegate != ""
|
||||
|
||||
bIsCheck := b.Action != nil && b.Action.HoldAndDelegate != ""
|
||||
if aIsCheck == bIsCheck {
|
||||
return 0
|
||||
}
|
||||
|
||||
if aIsCheck {
|
||||
return -1
|
||||
}
|
||||
|
||||
return 1
|
||||
})
|
||||
|
||||
return &tailcfg.SSHPolicy{
|
||||
Rules: rules,
|
||||
}, nil
|
||||
|
||||
@@ -615,7 +615,7 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compile SSH policy
|
||||
sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice())
|
||||
sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.wantEmpty {
|
||||
@@ -691,7 +691,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
||||
err := policy.validate()
|
||||
require.NoError(t, err)
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1)
|
||||
@@ -704,11 +704,90 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, expectedUsers, rule.SSHUsers)
|
||||
|
||||
// Verify check action with session duration
|
||||
assert.True(t, rule.Action.Accept)
|
||||
// Verify check action: Accept is false, HoldAndDelegate is set
|
||||
assert.False(t, rule.Action.Accept)
|
||||
assert.False(t, rule.Action.Reject)
|
||||
assert.NotEmpty(t, rule.Action.HoldAndDelegate)
|
||||
assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/")
|
||||
assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration)
|
||||
}
|
||||
|
||||
// TestCompileSSHPolicy_CheckBeforeAcceptOrdering verifies that check
|
||||
// (HoldAndDelegate) rules are sorted before accept rules, even when
|
||||
// the accept rule appears first in the policy definition.
|
||||
func TestCompileSSHPolicy_CheckBeforeAcceptOrdering(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "user1", Model: gorm.Model{ID: 1}},
|
||||
{Name: "user2", Model: gorm.Model{ID: 2}},
|
||||
}
|
||||
|
||||
nodeTaggedServer := types.Node{
|
||||
Hostname: "tagged-server",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
nodeUser2 := types.Node{
|
||||
Hostname: "user2-device",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{&nodeTaggedServer, &nodeUser2}
|
||||
|
||||
// Accept rule appears BEFORE check rule in policy definition.
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("user1@")},
|
||||
},
|
||||
Groups: Groups{
|
||||
Group("group:admins"): []Username{Username("user2@")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{gp("group:admins")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{"root"},
|
||||
},
|
||||
{
|
||||
Action: "check",
|
||||
CheckPeriod: model.Duration(24 * time.Hour),
|
||||
Sources: SSHSrcAliases{gp("group:admins")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{"ssh-it-user"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := policy.validate()
|
||||
require.NoError(t, err)
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(
|
||||
"unused-server-url",
|
||||
users,
|
||||
nodeTaggedServer.View(),
|
||||
nodes.ViewSlice(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 2)
|
||||
|
||||
// First rule must be the check rule (HoldAndDelegate set).
|
||||
assert.NotEmpty(t, sshPolicy.Rules[0].Action.HoldAndDelegate,
|
||||
"first rule should be check (HoldAndDelegate)")
|
||||
assert.False(t, sshPolicy.Rules[0].Action.Accept,
|
||||
"first rule should not be accept")
|
||||
|
||||
// Second rule must be the accept rule.
|
||||
assert.True(t, sshPolicy.Rules[1].Action.Accept,
|
||||
"second rule should be accept")
|
||||
assert.Empty(t, sshPolicy.Rules[1].Action.HoldAndDelegate,
|
||||
"second rule should not have HoldAndDelegate")
|
||||
}
|
||||
|
||||
// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers.
|
||||
func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
@@ -756,7 +835,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test SSH policy compilation for node2 (owned by user2, who is in the group)
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice())
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1)
|
||||
@@ -806,7 +885,7 @@ func TestSSHJSONSerialization(t *testing.T) {
|
||||
err := policy.validate()
|
||||
require.NoError(t, err)
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice())
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
|
||||
@@ -1413,7 +1492,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
|
||||
// Test for user1's first node
|
||||
node1 := nodes[0].View()
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1)
|
||||
@@ -1432,7 +1511,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
|
||||
// Test for user2's first node
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy2)
|
||||
require.Len(t, sshPolicy2.Rules, 1)
|
||||
@@ -1451,7 +1530,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
|
||||
// Test for tagged node (should have no SSH rules)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
sshPolicy3, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy3 != nil {
|
||||
@@ -1491,7 +1570,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
|
||||
// For user1's node: should allow SSH from user1's devices
|
||||
node1 := nodes[0].View()
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1)
|
||||
@@ -1508,7 +1587,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
|
||||
// For user2's node: should have no rules (user1's devices can't match user2's self)
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
@@ -1551,7 +1630,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
|
||||
// For user1's node: should allow SSH from user1's devices only (not user2's)
|
||||
node1 := nodes[0].View()
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1)
|
||||
@@ -1568,7 +1647,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
|
||||
// For user3's node: should have no rules (not in group:admins)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
@@ -1610,7 +1689,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
|
||||
// For untagged node: should only get principals from other untagged nodes
|
||||
node1 := nodes[0].View()
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
|
||||
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1)
|
||||
@@ -1628,7 +1707,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
|
||||
// For tagged node: should get no SSH rules
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
@@ -1671,7 +1750,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
||||
|
||||
// Test 1: Compile for user1's device (should only match autogroup:self destination)
|
||||
node1 := nodes[0].View()
|
||||
sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
|
||||
sshPolicy1, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy1)
|
||||
require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)")
|
||||
@@ -1690,7 +1769,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
||||
|
||||
// Test 2: Compile for router (should only match tag:router destination)
|
||||
routerNode := nodes[3].View() // user2-router
|
||||
sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice())
|
||||
sshPolicyRouter, err := policy.compileSSHPolicy("unused-server-url", users, routerNode, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicyRouter)
|
||||
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")
|
||||
|
||||
@@ -222,7 +222,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
@@ -230,7 +230,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
|
||||
return sshPol, nil
|
||||
}
|
||||
|
||||
sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
|
||||
sshPol, err := pm.pol.compileSSHPolicy(baseURL, pm.users, node, pm.nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||
}
|
||||
|
||||
@@ -871,7 +871,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
|
||||
|
||||
// SSHPolicy returns the SSH access policy for a node.
|
||||
func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
return s.polMan.SSHPolicy(node)
|
||||
return s.polMan.SSHPolicy(s.cfg.ServerURL, node)
|
||||
}
|
||||
|
||||
// Filter returns the current network filter rules and matches.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -23,7 +24,8 @@ const (
|
||||
// Common errors.
|
||||
var (
|
||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
ErrInvalidAuthIDLength = errors.New("registration ID has invalid length")
|
||||
ErrInvalidAuthIDLength = errors.New("auth ID has invalid length")
|
||||
ErrInvalidAuthIDPrefix = errors.New("auth ID has invalid prefix")
|
||||
)
|
||||
|
||||
type StateUpdateType int
|
||||
@@ -159,17 +161,22 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
||||
}
|
||||
}
|
||||
|
||||
const AuthIDLength = 24
|
||||
const (
|
||||
authIDPrefix = "hskey-authreq-"
|
||||
authIDRandomLength = 24
|
||||
// AuthIDLength is the total length of an AuthID: 14 (prefix) + 24 (random).
|
||||
AuthIDLength = 38
|
||||
)
|
||||
|
||||
type AuthID string
|
||||
|
||||
func NewAuthID() (AuthID, error) {
|
||||
rid, err := util.GenerateRandomStringURLSafe(AuthIDLength)
|
||||
rid, err := util.GenerateRandomStringURLSafe(authIDRandomLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return AuthID(rid), nil
|
||||
return AuthID(authIDPrefix + rid), nil
|
||||
}
|
||||
|
||||
func MustAuthID() AuthID {
|
||||
@@ -197,8 +204,18 @@ func (r AuthID) String() string {
|
||||
}
|
||||
|
||||
func (r AuthID) Validate() error {
|
||||
if !strings.HasPrefix(string(r), authIDPrefix) {
|
||||
return fmt.Errorf(
|
||||
"%w: expected prefix %q",
|
||||
ErrInvalidAuthIDPrefix, authIDPrefix,
|
||||
)
|
||||
}
|
||||
|
||||
if len(r) != AuthIDLength {
|
||||
return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r))
|
||||
return fmt.Errorf(
|
||||
"%w: expected %d, got %d",
|
||||
ErrInvalidAuthIDLength, AuthIDLength, len(r),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -214,6 +231,13 @@ type AuthRequest struct {
|
||||
closed *atomic.Bool
|
||||
}
|
||||
|
||||
func NewAuthRequest() AuthRequest {
|
||||
return AuthRequest{
|
||||
finished: make(chan AuthVerdict),
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func NewRegisterAuthRequest(node Node) AuthRequest {
|
||||
return AuthRequest{
|
||||
node: &node,
|
||||
|
||||
@@ -141,6 +141,12 @@ type ScenarioSpec struct {
|
||||
// Versions is specific list of versions to use for the test.
|
||||
Versions []string
|
||||
|
||||
// OIDCSkipUserCreation, if true, skips creating users via headscale CLI
|
||||
// during environment setup. Useful for OIDC tests where the SSH policy
|
||||
// references users by name, since OIDC login creates users automatically
|
||||
// and pre-creating them via CLI causes duplicate user records.
|
||||
OIDCSkipUserCreation bool
|
||||
|
||||
// OIDCUsers, if populated, will start a Mock OIDC server and populate
|
||||
// the user login stack with the given users.
|
||||
// If the NodesPerUser is set, it should align with this list to ensure
|
||||
@@ -866,9 +872,18 @@ func (s *Scenario) createHeadscaleEnvWithTags(
|
||||
}
|
||||
|
||||
for _, user := range s.spec.Users {
|
||||
u, err := s.CreateUser(user)
|
||||
if err != nil {
|
||||
return err
|
||||
var u *v1.User
|
||||
|
||||
if s.spec.OIDCSkipUserCreation {
|
||||
// Only register locally — OIDC login will create the headscale user.
|
||||
s.mu.Lock()
|
||||
s.users[user] = &User{Clients: make(map[string]TailscaleClient)}
|
||||
s.mu.Unlock()
|
||||
} else {
|
||||
u, err = s.CreateUser(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var userOpts []tsic.Option
|
||||
|
||||
@@ -579,3 +579,75 @@ func TestSSHAutogroupSelf(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHOneUserToOneCheckMode(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario := sshScenario(t,
|
||||
&policyv2.Policy{
|
||||
Groups: policyv2.Groups{
|
||||
policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")},
|
||||
},
|
||||
ACLs: []policyv2.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Sources: []policyv2.Alias{wildcard()},
|
||||
Destinations: []policyv2.AliasWithPorts{
|
||||
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
|
||||
},
|
||||
},
|
||||
},
|
||||
SSHs: []policyv2.SSH{
|
||||
{
|
||||
Action: "check",
|
||||
Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")},
|
||||
// Use autogroup:member and autogroup:tagged instead of wildcard
|
||||
// since wildcard (*) is no longer supported for SSH destinations
|
||||
Destinations: policyv2.SSHDstAliases{
|
||||
new(policyv2.AutoGroupMember),
|
||||
new(policyv2.AutoGroupTagged),
|
||||
},
|
||||
Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")},
|
||||
},
|
||||
},
|
||||
},
|
||||
1,
|
||||
)
|
||||
// defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||
requireNoErrListFQDN(t, err)
|
||||
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range allClients {
|
||||
if client.Hostname() == peer.Hostname() {
|
||||
continue
|
||||
}
|
||||
|
||||
assertSSHHostname(t, client, peer)
|
||||
}
|
||||
}
|
||||
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range allClients {
|
||||
if client.Hostname() == peer.Hostname() {
|
||||
continue
|
||||
}
|
||||
|
||||
assertSSHPermissionDenied(t, client, peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user