diff --git a/hscontrol/noise.go b/hscontrol/noise.go index ffcab68e..b5d41b5b 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -7,12 +7,14 @@ import ( "fmt" "io" "net/http" + "net/url" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/metrics" "github.com/juanfont/headscale/hscontrol/capver" "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "golang.org/x/net/http2" "tailscale.com/control/controlbase" @@ -30,6 +32,9 @@ var ErrMissingURLParameter = errors.New("missing URL parameter") // ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type. var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type") +// ErrNoAuthSession is returned when an auth_id does not match any active auth session. +var ErrNoAuthSession = errors.New("no auth session found") + const ( // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. ts2021UpgradePath = "/ts2021" @@ -113,7 +118,7 @@ func (h *Headscale) NoiseUpgradeHandler( })) r.Use(middleware.RequestID) r.Use(middleware.RealIP) - r.Use(middleware.Logger) + r.Use(middleware.RequestLogger(&zerologRequestLogger{})) r.Use(middleware.Recoverer) r.Handle("/metrics", metrics.Handler()) @@ -122,6 +127,9 @@ func (h *Headscale) NoiseUpgradeHandler( r.Post("/register", ns.RegistrationHandler) r.Post("/map", ns.PollNetMapHandler) + // SSH Check mode endpoint, consulted to validate if a given SSH connection should be accepted or rejected. + r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler) + // Not implemented yet // // /whoami is a debug endpoint to validate that the client can communicate over the connection, @@ -153,12 +161,10 @@ func (h *Headscale) NoiseUpgradeHandler( r.Post("/update-health", ns.NotImplementedHandler) r.Route("/webclient", func(r chi.Router) {}) + + r.Post("/c2n", ns.NotImplementedHandler) }) - r.Post("/c2n", ns.NotImplementedHandler) - - r.Get("/ssh-action", ns.SSHAction) - ns.httpBaseConfig = &http.Server{ Handler: r, ReadHeaderTimeout: types.HTTPTimeout, @@ -249,10 +255,233 @@ func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *ht http.Error(writer, "Not implemented yet", http.StatusNotImplemented) } -// SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction] -// to the client with the verdict of an SSH access request. -func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) { - log.Trace().Caller().Str("path", req.URL.String()).Msg("got SSH action request") +func urlParam[T any](req *http.Request, key string) (T, error) { + var zero T + + param := chi.URLParam(req, key) + if param == "" { + return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key) + } + + var value T + switch any(value).(type) { + case string: + v, ok := any(param).(T) + if !ok { + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + value = v + case types.NodeID: + id, err := types.ParseNodeID(param) + if err != nil { + return zero, fmt.Errorf("parsing %s: %w", key, err) + } + + v, ok := any(id).(T) + if !ok { + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + value = v + default: + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + return value, nil +} + +// SSHActionHandler handles the /ssh-action endpoint, returning a +// [tailcfg.SSHAction] to the client with the verdict of an SSH access +// request. +func (ns *noiseServer) SSHActionHandler( + writer http.ResponseWriter, + req *http.Request, +) { + srcNodeID, err := urlParam[types.NodeID](req, "src_node_id") + if err != nil { + httpError(writer, NewHTTPError( + http.StatusBadRequest, + "Invalid src_node_id", + err, + )) + + return + } + + dstNodeID, err := urlParam[types.NodeID](req, "dst_node_id") + if err != nil { + httpError(writer, NewHTTPError( + http.StatusBadRequest, + "Invalid dst_node_id", + err, + )) + + return + } + + reqLog := log.With(). + Uint64("src_node_id", srcNodeID.Uint64()). + Uint64("dst_node_id", dstNodeID.Uint64()). + Str("ssh_user", req.URL.Query().Get("ssh_user")). + Str("local_user", req.URL.Query().Get("local_user")). + Logger() + + reqLog.Trace().Caller().Msg("SSH action request") + + action, err := ns.sshAction( + reqLog, + req.URL.Query().Get("auth_id"), + ) + if err != nil { + httpError(writer, err) + + return + } + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + err = json.NewEncoder(writer).Encode(action) + if err != nil { + reqLog.Error().Caller().Err(err). + Msg("failed to encode SSH action response") + + return + } + + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } +} + +// sshAction resolves the SSH action for the given request parameters. +// It returns the action to send to the client, or an HTTPError on +// failure. +// +// Two cases: +// 1. Initial request — build a HoldAndDelegate URL and wait for the +// user to authenticate. +// 2. Follow-up request — an auth_id is present, wait for the auth +// verdict and accept or reject. +func (ns *noiseServer) sshAction( + reqLog zerolog.Logger, + authIDStr string, +) (*tailcfg.SSHAction, error) { + action := tailcfg.SSHAction{ + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + } + + // Follow-up request with auth_id — wait for the auth verdict. + if authIDStr != "" { + return ns.sshActionFollowUp( + reqLog, &action, authIDStr, + ) + } + + // Initial request — create an auth session and hold. + return ns.sshActionHoldAndDelegate(reqLog, &action) +} + +// sshActionHoldAndDelegate creates a new auth session and returns a +// HoldAndDelegate action that directs the client to authenticate. +func (ns *noiseServer) sshActionHoldAndDelegate( + reqLog zerolog.Logger, + action *tailcfg.SSHAction, +) (*tailcfg.SSHAction, error) { + holdURL, err := url.Parse( + ns.headscale.cfg.ServerURL + + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID" + + "?ssh_user=$SSH_USER&local_user=$LOCAL_USER", + ) + if err != nil { + return nil, NewHTTPError( + http.StatusInternalServerError, + "Internal error", + fmt.Errorf("parsing SSH action URL: %w", err), + ) + } + + authID, err := types.NewAuthID() + if err != nil { + return nil, NewHTTPError( + http.StatusInternalServerError, + "Internal error", + fmt.Errorf("generating auth ID: %w", err), + ) + } + + ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest()) + + authURL := ns.headscale.authProvider.AuthURL(authID) + + q := holdURL.Query() + q.Set("auth_id", authID.String()) + holdURL.RawQuery = q.Encode() + + action.HoldAndDelegate = holdURL.String() + + // TODO(kradalby): here we can also send a very tiny mapresponse + // "popping" the url and opening it for the user. + action.Message = fmt.Sprintf( + "# Headscale SSH requires an additional check.\n"+ + "# To authenticate, visit: %s\n"+ + "# Authentication checked with Headscale SSH.\n", + authURL, + ) + + reqLog.Info().Caller(). + Str("auth_id", authID.String()). + Msg("SSH check pending, waiting for auth") + + return action, nil +} + +// sshActionFollowUp handles follow-up requests where the client +// provides an auth_id. It blocks until the auth session resolves. +func (ns *noiseServer) sshActionFollowUp( + reqLog zerolog.Logger, + action *tailcfg.SSHAction, + authIDStr string, +) (*tailcfg.SSHAction, error) { + authID, err := types.AuthIDFromString(authIDStr) + if err != nil { + return nil, NewHTTPError( + http.StatusBadRequest, + "Invalid auth_id", + fmt.Errorf("parsing auth_id: %w", err), + ) + } + + reqLog = reqLog.With().Str("auth_id", authID.String()).Logger() + + auth, ok := ns.headscale.state.GetAuthCacheEntry(authID) + if !ok { + return nil, NewHTTPError( + http.StatusBadRequest, + "Invalid auth_id", + fmt.Errorf("%w: %s", ErrNoAuthSession, authID), + ) + } + + reqLog.Trace().Caller().Msg("SSH action follow-up") + + verdict := <-auth.WaitForAuth() + + if !verdict.Accept() { + action.Reject = true + + reqLog.Trace().Caller().Err(verdict.Err). + Msg("authentication rejected") + + return action, nil + } + + action.Accept = true + + return action, nil } // PollNetMapHandler takes care of /machine/:id/map using the Noise protocol @@ -380,28 +609,3 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types. return nv, nil } - -// urlParam extracts a typed URL parameter from a chi router request. -func urlParam[T any](req *http.Request, key string) (T, error) { - var zero T - - param := chi.URLParam(req, key) - if param == "" { - return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key) - } - - var value T - switch any(value).(type) { - case string: - v, ok := any(param).(T) - if !ok { - return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) - } - - value = v - default: - return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) - } - - return value, nil -} diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 6dfacd91..2de2e8dd 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -19,7 +19,7 @@ type PolicyManager interface { MatchersForNode(node types.NodeView) ([]matcher.Match, error) // BuildPeerMap constructs peer relationship maps for the given nodes BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView - SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) + SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) SetPolicy(pol []byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 9c97e39c..536c86f3 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1188,8 +1188,9 @@ func TestSSHPolicyRules(t *testing.T) { "root": "", }, Action: &tailcfg.SSHAction{ - Accept: true, + Accept: false, SessionDuration: 24 * time.Hour, + HoldAndDelegate: "unused-url/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -1476,7 +1477,7 @@ func TestSSHPolicyRules(t *testing.T) { require.NoError(t, err) - got, err := pm.SSHPolicy(tt.targetNode.View()) + got, err := pm.SSHPolicy("unused-url", tt.targetNode.View()) require.NoError(t, err) if diff := cmp.Diff(tt.wantSSH, got); diff != "" { diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 9c2c5f17..c8c515cd 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -319,11 +319,27 @@ func (pol *Policy) compileACLWithAutogroupSelf( return rules, nil } -func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { +var sshAccept = tailcfg.SSHAction{ + Reject: false, + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, +} + +func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { return tailcfg.SSHAction{ - Reject: !accept, - Accept: accept, - SessionDuration: duration, + Reject: false, + Accept: false, + SessionDuration: duration, + // Replaced in the client: + // * $SRC_NODE_IP (URL escaped) + // * $SRC_NODE_ID (Node.ID as int64 string) + // * $DST_NODE_IP (URL escaped) + // * $DST_NODE_ID (Node.ID as int64 string) + // * $SSH_USER (URL escaped, ssh user requested) + // * $LOCAL_USER (URL escaped, local user mapped) + HoldAndDelegate: baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -332,6 +348,7 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { //nolint:gocyclo // complex SSH policy compilation logic func (pol *Policy) compileSSHPolicy( + baseURL string, users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], @@ -377,9 +394,9 @@ func (pol *Policy) compileSSHPolicy( switch rule.Action { case SSHActionAccept: - action = sshAction(true, 0) + action = sshAccept case SSHActionCheck: - action = sshAction(true, time.Duration(rule.CheckPeriod)) + action = sshCheck(baseURL, time.Duration(rule.CheckPeriod)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) } @@ -503,6 +520,23 @@ func (pol *Policy) compileSSHPolicy( } } + // Sort rules: check (HoldAndDelegate) before accept, per Tailscale + // evaluation order (most-restrictive first). + slices.SortStableFunc(rules, func(a, b *tailcfg.SSHRule) int { + aIsCheck := a.Action != nil && a.Action.HoldAndDelegate != "" + + bIsCheck := b.Action != nil && b.Action.HoldAndDelegate != "" + if aIsCheck == bIsCheck { + return 0 + } + + if aIsCheck { + return -1 + } + + return 1 + }) + return &tailcfg.SSHPolicy{ Rules: rules, }, nil diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index cdf7c131..01d3d71d 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -615,7 +615,7 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { require.NoError(t, err) // Compile SSH policy - sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice()) + sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice()) require.NoError(t, err) if tt.wantEmpty { @@ -691,7 +691,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { err := policy.validate() require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -704,11 +704,90 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { } assert.Equal(t, expectedUsers, rule.SSHUsers) - // Verify check action with session duration - assert.True(t, rule.Action.Accept) + // Verify check action: Accept is false, HoldAndDelegate is set + assert.False(t, rule.Action.Accept) + assert.False(t, rule.Action.Reject) + assert.NotEmpty(t, rule.Action.HoldAndDelegate) + assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/") assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration) } +// TestCompileSSHPolicy_CheckBeforeAcceptOrdering verifies that check +// (HoldAndDelegate) rules are sorted before accept rules, even when +// the accept rule appears first in the policy definition. +func TestCompileSSHPolicy_CheckBeforeAcceptOrdering(t *testing.T) { + users := types.Users{ + {Name: "user1", Model: gorm.Model{ID: 1}}, + {Name: "user2", Model: gorm.Model{ID: 2}}, + } + + nodeTaggedServer := types.Node{ + Hostname: "tagged-server", + IPv4: createAddr("100.64.0.1"), + UserID: new(users[0].ID), + User: new(users[0]), + Tags: []string{"tag:server"}, + } + nodeUser2 := types.Node{ + Hostname: "user2-device", + IPv4: createAddr("100.64.0.2"), + UserID: new(users[1].ID), + User: new(users[1]), + } + + nodes := types.Nodes{&nodeTaggedServer, &nodeUser2} + + // Accept rule appears BEFORE check rule in policy definition. + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{"root"}, + }, + { + Action: "check", + CheckPeriod: model.Duration(24 * time.Hour), + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{"ssh-it-user"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + sshPolicy, err := policy.compileSSHPolicy( + "unused-server-url", + users, + nodeTaggedServer.View(), + nodes.ViewSlice(), + ) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 2) + + // First rule must be the check rule (HoldAndDelegate set). + assert.NotEmpty(t, sshPolicy.Rules[0].Action.HoldAndDelegate, + "first rule should be check (HoldAndDelegate)") + assert.False(t, sshPolicy.Rules[0].Action.Accept, + "first rule should not be accept") + + // Second rule must be the accept rule. + assert.True(t, sshPolicy.Rules[1].Action.Accept, + "second rule should be accept") + assert.Empty(t, sshPolicy.Rules[1].Action.HoldAndDelegate, + "second rule should not have HoldAndDelegate") +} + // TestSSHIntegrationReproduction reproduces the exact scenario from the integration test // TestSSHOneUserToAll that was failing with empty sshUsers. func TestSSHIntegrationReproduction(t *testing.T) { @@ -756,7 +835,7 @@ func TestSSHIntegrationReproduction(t *testing.T) { require.NoError(t, err) // Test SSH policy compilation for node2 (owned by user2, who is in the group) - sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -806,7 +885,7 @@ func TestSSHJSONSerialization(t *testing.T) { err := policy.validate() require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) @@ -1413,7 +1492,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for user1's first node node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1432,7 +1511,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for user2's first node node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy2) require.Len(t, sshPolicy2.Rules, 1) @@ -1451,7 +1530,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for tagged node (should have no SSH rules) node5 := nodes[4].View() - sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + sshPolicy3, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy3 != nil { @@ -1491,7 +1570,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { // For user1's node: should allow SSH from user1's devices node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1508,7 +1587,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { // For user2's node: should have no rules (user1's devices can't match user2's self) node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1551,7 +1630,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { // For user1's node: should allow SSH from user1's devices only (not user2's) node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1568,7 +1647,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { // For user3's node: should have no rules (not in group:admins) node5 := nodes[4].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1610,7 +1689,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { // For untagged node: should only get principals from other untagged nodes node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1628,7 +1707,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { // For tagged node: should get no SSH rules node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1671,7 +1750,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Test 1: Compile for user1's device (should only match autogroup:self destination) node1 := nodes[0].View() - sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy1, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy1) require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)") @@ -1690,7 +1769,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Test 2: Compile for router (should only match tag:router destination) routerNode := nodes[3].View() // user2-router - sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice()) + sshPolicyRouter, err := policy.compileSSHPolicy("unused-server-url", users, routerNode, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicyRouter) require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 74b7ba6a..744f52c7 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -222,7 +222,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { return true, nil } -func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { +func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -230,7 +230,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err return sshPol, nil } - sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes) + sshPol, err := pm.pol.compileSSHPolicy(baseURL, pm.users, node, pm.nodes) if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 83585732..f2ae99a9 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -871,7 +871,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha // SSHPolicy returns the SSH access policy for a node. func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { - return s.polMan.SSHPolicy(node) + return s.polMan.SSHPolicy(s.cfg.ServerURL, node) } // Filter returns the current network filter rules and matches. diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 891969d3..a78278a4 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "runtime" + "strings" "sync/atomic" "time" @@ -23,7 +24,8 @@ const ( // Common errors. var ( ErrCannotParsePrefix = errors.New("cannot parse prefix") - ErrInvalidAuthIDLength = errors.New("registration ID has invalid length") + ErrInvalidAuthIDLength = errors.New("auth ID has invalid length") + ErrInvalidAuthIDPrefix = errors.New("auth ID has invalid prefix") ) type StateUpdateType int @@ -159,17 +161,22 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { } } -const AuthIDLength = 24 +const ( + authIDPrefix = "hskey-authreq-" + authIDRandomLength = 24 + // AuthIDLength is the total length of an AuthID: 14 (prefix) + 24 (random). + AuthIDLength = 38 +) type AuthID string func NewAuthID() (AuthID, error) { - rid, err := util.GenerateRandomStringURLSafe(AuthIDLength) + rid, err := util.GenerateRandomStringURLSafe(authIDRandomLength) if err != nil { return "", err } - return AuthID(rid), nil + return AuthID(authIDPrefix + rid), nil } func MustAuthID() AuthID { @@ -197,8 +204,18 @@ func (r AuthID) String() string { } func (r AuthID) Validate() error { + if !strings.HasPrefix(string(r), authIDPrefix) { + return fmt.Errorf( + "%w: expected prefix %q", + ErrInvalidAuthIDPrefix, authIDPrefix, + ) + } + if len(r) != AuthIDLength { - return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r)) + return fmt.Errorf( + "%w: expected %d, got %d", + ErrInvalidAuthIDLength, AuthIDLength, len(r), + ) } return nil @@ -214,6 +231,13 @@ type AuthRequest struct { closed *atomic.Bool } +func NewAuthRequest() AuthRequest { + return AuthRequest{ + finished: make(chan AuthVerdict), + closed: &atomic.Bool{}, + } +} + func NewRegisterAuthRequest(node Node) AuthRequest { return AuthRequest{ node: &node, diff --git a/integration/scenario.go b/integration/scenario.go index cd43b78f..e769bd73 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -141,6 +141,12 @@ type ScenarioSpec struct { // Versions is specific list of versions to use for the test. Versions []string + // OIDCSkipUserCreation, if true, skips creating users via headscale CLI + // during environment setup. Useful for OIDC tests where the SSH policy + // references users by name, since OIDC login creates users automatically + // and pre-creating them via CLI causes duplicate user records. + OIDCSkipUserCreation bool + // OIDCUsers, if populated, will start a Mock OIDC server and populate // the user login stack with the given users. // If the NodesPerUser is set, it should align with this list to ensure @@ -866,9 +872,18 @@ func (s *Scenario) createHeadscaleEnvWithTags( } for _, user := range s.spec.Users { - u, err := s.CreateUser(user) - if err != nil { - return err + var u *v1.User + + if s.spec.OIDCSkipUserCreation { + // Only register locally — OIDC login will create the headscale user. + s.mu.Lock() + s.users[user] = &User{Clients: make(map[string]TailscaleClient)} + s.mu.Unlock() + } else { + u, err = s.CreateUser(user) + if err != nil { + return err + } } var userOpts []tsic.Option diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 45bc2dc7..15867579 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -579,3 +579,75 @@ func TestSSHAutogroupSelf(t *testing.T) { } } } + +func TestSSHOneUserToOneCheckMode(t *testing.T) { + IntegrationSkip(t) + + scenario := sshScenario(t, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + SSHs: []policyv2.SSH{ + { + Action: "check", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + // Use autogroup:member and autogroup:tagged instead of wildcard + // since wildcard (*) is no longer supported for SSH destinations + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + }, + }, + 1, + ) + // defer scenario.ShutdownAssertNoPanics(t) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHHostname(t, client, peer) + } + } + + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +}