From cb3b6949ea0c72d45f25d5eb2e8f2ef9584af88a Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 24 Feb 2026 18:48:57 +0000 Subject: [PATCH] auth: generalise auth flow and introduce AuthVerdict Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850 --- cmd/headscale/cli/debug.go | 2 +- hscontrol/app.go | 3 +- hscontrol/auth.go | 101 +++++++++-------- hscontrol/auth_tags_test.go | 24 ++--- hscontrol/auth_test.go | 118 +++++++++----------- hscontrol/db/db.go | 4 +- hscontrol/db/db_test.go | 4 +- hscontrol/grpcv1.go | 29 +++-- hscontrol/handlers.go | 45 ++++++-- hscontrol/mapper/batcher_test.go | 4 +- hscontrol/noise.go | 31 ++++++ hscontrol/oidc.go | 118 ++++++++++++-------- hscontrol/state/state.go | 138 ++++++++++++------------ hscontrol/templates/register_web.go | 2 +- hscontrol/templates_consistency_test.go | 8 +- hscontrol/types/common.go | 98 ++++++++++++----- hscontrol/util/util.go | 6 +- hscontrol/util/util_test.go | 10 +- integration/cli_test.go | 34 +++--- 19 files changed, 443 insertions(+), 336 deletions(-) diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index fac317fc..9e4a67fd 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -37,7 +37,7 @@ var createNodeCmd = &cobra.Command{ name, _ := cmd.Flags().GetString("name") registrationID, _ := cmd.Flags().GetString("key") - _, err := types.RegistrationIDFromString(registrationID) + _, err := types.AuthIDFromString(registrationID) if err != nil { return fmt.Errorf("parsing machine key: %w", err) } diff --git a/hscontrol/app.go b/hscontrol/app.go index bb4733f7..87da6f87 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -479,7 +479,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { r.Get("/health", h.HealthHandler) r.Get("/version", h.VersionHandler) r.Get("/key", h.KeyHandler) - r.Get("/register/{registration_id}", h.authProvider.RegisterHandler) + r.Get("/register/{auth_id}", h.authProvider.RegisterHandler) + r.Get("/auth/{auth_id}", h.authProvider.AuthHandler) if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { r.Get("/oidc/callback", provider.OIDCCallbackHandler) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index d5a77bd7..3a066d91 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -20,7 +20,9 @@ import ( type AuthProvider interface { RegisterHandler(w http.ResponseWriter, r *http.Request) - AuthURL(regID types.RegistrationID) string + AuthHandler(w http.ResponseWriter, r *http.Request) + RegisterURL(authID types.AuthID) string + AuthURL(authID types.AuthID) string } func (h *Headscale) handleRegister( @@ -263,22 +265,24 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err) } - followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) + followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) if err != nil { return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err) } - if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok { + if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok { select { case <-ctx.Done(): return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) - case node := <-reg.Registered: - if node == nil { - // registration is expired in the cache, instruct the client to try a new registration - return h.reqToNewRegisterResponse(req, machineKey) - } + case verdict := <-reg.WaitForAuth(): + if verdict.Accept() { + if !verdict.Node.Valid() { + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(req, machineKey) + } - return nodeToRegisterResponse(node.View()), nil + return nodeToRegisterResponse(verdict.Node), nil + } } } @@ -293,14 +297,14 @@ func (h *Headscale) reqToNewRegisterResponse( req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - newRegID, err := types.NewRegistrationID() + newAuthID, err := types.NewAuthID() if err != nil { return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) } // Ensure we have a valid hostname hostname := util.EnsureHostname( - req.Hostinfo, + req.Hostinfo.View(), machineKey.String(), req.NodeKey.String(), ) @@ -309,25 +313,25 @@ func (h *Headscale) reqToNewRegisterResponse( hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) hostinfo.Hostname = hostname - nodeToRegister := types.NewRegisterNode( - types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - }, - ) - - if !req.Expiry.IsZero() { - nodeToRegister.Node.Expiry = &req.Expiry + nodeToRegister := types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: new(time.Now()), } - log.Info().Msgf("new followup node registration using key: %s", newRegID) - h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + if !req.Expiry.IsZero() { + nodeToRegister.Expiry = &req.Expiry + } + + authRegReq := types.NewRegisterAuthRequest(nodeToRegister) + + log.Info().Msgf("new followup node registration using key: %s", newAuthID) + h.state.SetAuthCacheEntry(newAuthID, authRegReq) return &tailcfg.RegisterResponse{ - AuthURL: h.authProvider.AuthURL(newRegID), + AuthURL: h.authProvider.RegisterURL(newAuthID), }, nil } @@ -378,13 +382,6 @@ func (h *Headscale) handleRegisterWithAuthKey( // Send both changes. Empty changes are ignored by Change(). h.Change(changed, routesChange) - // TODO(kradalby): I think this is covered above, but we need to validate that. - // // If policy changed due to node registration, send a separate policy change - // if policyChanged { - // policyChange := change.PolicyChange() - // h.Change(policyChange) - // } - resp := &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), @@ -406,14 +403,14 @@ func (h *Headscale) handleRegisterInteractive( req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - registrationId, err := types.NewRegistrationID() + authID, err := types.NewAuthID() if err != nil { return nil, fmt.Errorf("generating registration ID: %w", err) } // Ensure we have a valid hostname hostname := util.EnsureHostname( - req.Hostinfo, + req.Hostinfo.View(), machineKey.String(), req.NodeKey.String(), ) @@ -436,28 +433,28 @@ func (h *Headscale) handleRegisterInteractive( hostinfo.Hostname = hostname - nodeToRegister := types.NewRegisterNode( - types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - }, - ) - - if !req.Expiry.IsZero() { - nodeToRegister.Node.Expiry = &req.Expiry + nodeToRegister := types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: new(time.Now()), } - h.state.SetRegistrationCacheEntry( - registrationId, - nodeToRegister, + if !req.Expiry.IsZero() { + nodeToRegister.Expiry = &req.Expiry + } + + authRegReq := types.NewRegisterAuthRequest(nodeToRegister) + + h.state.SetAuthCacheEntry( + authID, + authRegReq, ) - log.Info().Msgf("starting node registration using key: %s", registrationId) + log.Info().Msgf("starting node registration using key: %s", authID) return &tailcfg.RegisterResponse{ - AuthURL: h.authProvider.AuthURL(registrationId), + AuthURL: h.authProvider.RegisterURL(authID), }, nil } diff --git a/hscontrol/auth_tags_test.go b/hscontrol/auth_tags_test.go index 2564154c..cd9d4c96 100644 --- a/hscontrol/auth_tags_test.go +++ b/hscontrol/auth_tags_test.go @@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 1: Create user-owned node WITH expiry set clientExpiry := time.Now().Add(24 * time.Hour) - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "personal-to-tagged", @@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { }, Expiry: &clientExpiry, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) node, _, err := app.state.HandleNodeFromAuthPath( registrationID1, types.UserID(user.ID), nil, "webauth", @@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 2: Re-auth with tags (Personal → Tagged conversion) nodeKey2 := key.NewNode() - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "personal-to-tagged", @@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { }, Expiry: &clientExpiry, // Client still sends expiry }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) nodeAfter, _, err := app.state.HandleNodeFromAuthPath( registrationID2, types.UserID(user.ID), nil, "webauth", @@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { nodeKey1 := key.NewNode() // Step 1: Create tagged node (expiry should be nil) - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "tagged-to-personal", @@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { RequestTags: []string{"tag:server"}, // Tagged node }, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) node, _, err := app.state.HandleNodeFromAuthPath( registrationID1, types.UserID(user.ID), nil, "webauth", @@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { // Step 2: Re-auth with empty tags (Tagged → Personal conversion) nodeKey2 := key.NewNode() clientExpiry := time.Now().Add(48 * time.Hour) - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "tagged-to-personal", @@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { }, Expiry: &clientExpiry, // Client requests expiry }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) nodeAfter, _, err := app.state.HandleNodeFromAuthPath( registrationID2, types.UserID(user.ID), nil, "webauth", diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 83dfb913..2a878851 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_success", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "followup-success-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "followup-success-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) - // Simulate successful registration - send to buffered channel - // The channel is buffered (size 1), so this can complete immediately - // and handleRegister will receive the value when it starts waiting + // Simulate successful registration + // handleRegister will receive the value when it starts waiting go func() { user := app.state.CreateUserForTest("followup-user") node := app.state.CreateNodeForTest(user, "followup-success-node") - registered <- node + nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()}) }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_timeout", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "followup-timeout-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) - // Don't send anything on channel - will timeout + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "followup-timeout-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) + // Don't call FinishRegistration - will timeout return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil }, @@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_node_nil_response", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "nil-response-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "nil-response-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) - // Simulate registration that returns nil (cache expired during auth) - // The channel is buffered (size 1), so this can complete immediately + // Simulate registration that returns empty NodeView (cache expired during auth) go func() { - registered <- nil // Nil indicates cache expiry + nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // Generate a registration ID that doesn't exist in cache // This simulates an expired/missing cache entry - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } @@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) { // Extract and validate the new registration ID exists in cache newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") - newRegID, err := types.RegistrationIDFromString(newRegIDStr) + newRegID, err := types.AuthIDFromString(newRegIDStr) assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure // Verify new registration entry exists in cache - _, found := app.state.GetRegistrationCacheEntry(newRegID) + _, found := app.state.GetAuthCacheEntry(newRegID) assert.True(t, found, "new registration should exist in cache") }, }, @@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) { require.NoError(t, err) // Verify cache entry exists - cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) assert.True(t, found, "registration cache entry should exist initially") assert.NotNil(t, cacheEntry) @@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern // Cache entry should still exist after auth error (for retry scenarios) - _, stillFound := app.state.GetRegistrationCacheEntry(registrationID) + _, stillFound := app.state.GetAuthCacheEntry(registrationID) assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry") }, }, @@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) { assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") // Both cache entries should exist simultaneously - _, found1 := app.state.GetRegistrationCacheEntry(regID1) - _, found2 := app.state.GetRegistrationCacheEntry(regID2) + _, found1 := app.state.GetAuthCacheEntry(regID1) + _, found2 := app.state.GetAuthCacheEntry(regID2) assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist") @@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) { require.NoError(t, err) // Verify both exist - _, found1 := app.state.GetRegistrationCacheEntry(regID1) - _, found2 := app.state.GetRegistrationCacheEntry(regID2) + _, found1 := app.state.GetAuthCacheEntry(regID1) + _, found2 := app.state.GetAuthCacheEntry(regID2) assert.True(t, found1, "first cache entry should exist") assert.True(t, found2, "second cache entry should exist") @@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) { } // First registration should still be in cache (not completed) - _, stillFound := app.state.GetRegistrationCacheEntry(regID1) + _, stillFound := app.state.GetAuthCacheEntry(regID1) assert.True(t, stillFound, "first registration should still be pending") }, }, @@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { var ( initialResp *tailcfg.RegisterResponse authURL string - registrationID types.RegistrationID + registrationID types.AuthID finalResp *tailcfg.RegisterResponse err error ) @@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { if step.expectCacheEntry { // Verify registration cache entry was created - cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) require.True(t, found, "registration cache entry should exist") require.NotNil(t, cacheEntry, "cache entry should not be nil") - require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key") + require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key") } case stepTypeAuthCompletion: @@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { // Check cache cleanup expectation for this step if step.expectCacheEntry == false && registrationID != "" { // Verify cache entry was cleaned up - _, found := app.state.GetRegistrationCacheEntry(registrationID) + _, found := app.state.GetAuthCacheEntry(registrationID) require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType) } } @@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { } // extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. -func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { +func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) { // AuthURL format: "http://localhost/register/abc123" const registerPrefix = "/register/" @@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err idStr := authURL[idx+len(registerPrefix):] - return types.RegistrationIDFromString(idStr) + return types.AuthIDFromString(idStr) } // validateCompleteRegistrationResponse performs comprehensive validation of a registration response. @@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { nodeKey := key.NewNode() // Simulate a registration cache entry (as would be created during web auth) - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "webauth-tags-node", @@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // Complete the web auth - should fail because tag is unauthorized _, _, err := app.state.HandleNodeFromAuthPath( @@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { nodeKey1 := key.NewNode() // Step 1: Initial registration with tags - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "reauth-untag-node", @@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { RequestTags: []string{"tag:valid-owned", "tag:second"}, }, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) // Complete initial registration with tags node, _, err := app.state.HandleNodeFromAuthPath( @@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { // Step 2: Reauth with EMPTY tags to untag nodeKey2 := key.NewNode() // New node key for reauth - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "reauth-untag-node", @@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { RequestTags: []string{}, // EMPTY - should untag }, }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) // Complete reauth with empty tags nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( @@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { // Step 2: Reauth via web auth with EMPTY tags to transition to user-owned nodeKey2 := key.NewNode() // New node key for reauth - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "authkey-tagged-node", @@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { RequestTags: []string{}, // EMPTY - should untag }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // Complete reauth with empty tags nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( @@ -3958,8 +3944,8 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { // Step 4: Re-register the node to alice via HandleNodeFromAuthPath // This is what happens when running: headscale nodes register --user alice --key ... nodeKey2 := key.NewNode() - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key as the tagged node NodeKey: nodeKey2.Public(), Hostname: "tagged-orphan-node", @@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { RequestTags: []string{}, // Empty - transition to user-owned }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // This should NOT panic - before the fix, this would panic with: // panic: runtime error: invalid memory address or nil pointer dereference diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 478614fb..cfc3b789 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -47,7 +47,7 @@ const ( type HSDatabase struct { DB *gorm.DB cfg *types.Config - regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + regCache *zcache.Cache[types.AuthID, types.AuthRequest] } // NewHeadscaleDatabase creates a new database connection and runs migrations. @@ -56,7 +56,7 @@ type HSDatabase struct { //nolint:gocyclo // complex database initialization with many migrations func NewHeadscaleDatabase( cfg *types.Config, - regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], + regCache *zcache.Cache[types.AuthID, types.AuthRequest], ) (*HSDatabase, error) { dbConn, err := openDB(cfg.Database) if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 3c687b39..151d9966 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { } } -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { + return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) } func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 3af8e807..c0fd5a3e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode( Str(zf.RegistrationKey, registrationKey). Msg("registering node") - registrationId, err := types.RegistrationIDFromString(request.GetKey()) + registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } @@ -808,33 +808,32 @@ func (api headscaleV1APIServer) DebugCreateNode( Hostname: request.GetName(), } - registrationId, err := types.RegistrationIDFromString(request.GetKey()) + registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } - newNode := types.NewRegisterNode( - types.Node{ - NodeKey: key.NewNode().Public(), - MachineKey: key.NewMachine().Public(), - Hostname: request.GetName(), - User: user, + newNode := types.Node{ + NodeKey: key.NewNode().Public(), + MachineKey: key.NewMachine().Public(), + Hostname: request.GetName(), + User: user, - Expiry: &time.Time{}, - LastSeen: &time.Time{}, + Expiry: &time.Time{}, + LastSeen: &time.Time{}, - Hostinfo: &hostinfo, - }, - ) + Hostinfo: &hostinfo, + } log.Debug(). Caller(). Str("registration_id", registrationId.String()). Msg("adding debug machine via CLI, appending to registration cache") - api.h.state.SetRegistrationCacheEntry(registrationId, newNode) + authRegReq := types.NewRegisterAuthRequest(newNode) + api.h.state.SetAuthCacheEntry(registrationId, authRegReq) - return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil + return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil } func (api headscaleV1APIServer) Health( diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 7c45f1ec..b7aa8460 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/assets" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -245,11 +244,41 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb { } } -func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { +func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - registrationId.String()) + authID.String()) +} + +func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string { + return fmt.Sprintf( + "%s/auth/%s", + strings.TrimSuffix(a.serverURL, "/"), + authID.String()) +} + +func (a *AuthProviderWeb) AuthHandler( + writer http.ResponseWriter, + req *http.Request, +) { +} + +func authIDFromRequest(req *http.Request) (types.AuthID, error) { + registrationId, err := urlParam[types.AuthID](req, "auth_id") + if err != nil { + return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) + } + + // We need to make sure we dont open for XSS style injections, if the parameter that + // is passed as a key is not parsable/validated as a NodePublic key, then fail to render + // the template and log an error. + err = registrationId.Validate() + if err != nil { + return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) + } + + return registrationId, nil } // RegisterHandler shows a simple message in the browser to point to the CLI @@ -261,15 +290,9 @@ func (a *AuthProviderWeb) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { - vars := mux.Vars(req) - registrationIdStr := vars["registration_id"] - - // We need to make sure we dont open for XSS style injections, if the parameter that - // is passed as a key is not parsable/validated as a NodePublic key, then fail to render - // the template and log an error. - registrationId, err := types.RegistrationIDFromString(registrationIdStr) + registrationId, err := authIDFromRequest(req) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + httpError(writer, err) return } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 9e544633..6f3fbccb 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{ } // emptyCache creates an empty registration cache for testing. -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { + return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) } // Test configuration constants. diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 57a79b96..ffcab68e 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -24,6 +24,12 @@ import ( // ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version. var ErrUnsupportedClientVersion = errors.New("unsupported client version") +// ErrMissingURLParameter is returned when a required URL parameter is not provided. +var ErrMissingURLParameter = errors.New("missing URL parameter") + +// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type. +var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type") + const ( // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. ts2021UpgradePath = "/ts2021" @@ -374,3 +380,28 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types. return nv, nil } + +// urlParam extracts a typed URL parameter from a chi router request. +func urlParam[T any](req *http.Request, key string) (T, error) { + var zero T + + param := chi.URLParam(req, key) + if param == "" { + return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key) + } + + var value T + switch any(value).(type) { + case string: + v, ok := any(param).(T) + if !ok { + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + value = v + default: + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + return value, nil +} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 9d284921..2bc62fa9 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -12,7 +12,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -26,8 +25,8 @@ import ( const ( randomByteSize = 16 defaultOAuthOptionsCount = 3 - registerCacheExpiration = time.Minute * 15 - registerCacheCleanup = time.Minute * 20 + authCacheExpiration = time.Minute * 15 + authCacheCleanup = time.Minute * 20 ) var ( @@ -44,17 +43,21 @@ var ( errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email") ) -// RegistrationInfo contains both machine key and verifier information for OIDC validation. -type RegistrationInfo struct { - RegistrationID types.RegistrationID - Verifier *string +// AuthInfo contains both auth ID and verifier information for OIDC validation. +type AuthInfo struct { + AuthID types.AuthID + Verifier *string + Registration bool } type AuthProviderOIDC struct { - h *Headscale - serverURL string - cfg *types.OIDCConfig - registrationCache *zcache.Cache[string, RegistrationInfo] + h *Headscale + serverURL string + cfg *types.OIDCConfig + + // authCache holds auth information between + // the auth and the callback steps. + authCache *zcache.Cache[string, AuthInfo] oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -81,45 +84,63 @@ func NewAuthProviderOIDC( Scopes: cfg.Scope, } - registrationCache := zcache.New[string, RegistrationInfo]( - registerCacheExpiration, - registerCacheCleanup, + authCache := zcache.New[string, AuthInfo]( + authCacheExpiration, + authCacheCleanup, ) return &AuthProviderOIDC{ - h: h, - serverURL: serverURL, - cfg: cfg, - registrationCache: registrationCache, + h: h, + serverURL: serverURL, + cfg: cfg, + authCache: authCache, oidcProvider: oidcProvider, oauth2Config: oauth2Config, }, nil } -func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string { +func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string { + return fmt.Sprintf( + "%s/auth/%s", + strings.TrimSuffix(a.serverURL, "/"), + authID.String()) +} + +func (a *AuthProviderOIDC) AuthHandler( + writer http.ResponseWriter, + req *http.Request, +) { + a.authHandler(writer, req, false) +} + +func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - registrationID.String()) + authID.String()) } // RegisterHandler registers the OIDC callback handler with the given router. // It puts NodeKey in cache so the callback can retrieve it using the oidc state param. -// Listens in /register/:registration_id. +// Listens in /register/:auth_id. func (a *AuthProviderOIDC) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { - vars := mux.Vars(req) - registrationIdStr := vars["registration_id"] + a.authHandler(writer, req, true) +} - // We need to make sure we dont open for XSS style injections, if the parameter that - // is passed as a key is not parsable/validated as a NodePublic key, then fail to render - // the template and log an error. - registrationId, err := types.RegistrationIDFromString(registrationIdStr) +// authHandler takes an incoming request that needs to be authenticated and +// validates and prepares it for the OIDC flow. +func (a *AuthProviderOIDC) authHandler( + writer http.ResponseWriter, + req *http.Request, + registration bool, +) { + authID, err := authIDFromRequest(req) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + httpError(writer, err) return } @@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler( return } - // Initialize registration info with machine key - registrationInfo := RegistrationInfo{ - RegistrationID: registrationId, + registrationInfo := AuthInfo{ + AuthID: authID, + Registration: registration, } extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) @@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler( extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info - a.registrationCache.Set(state, registrationInfo) + a.authCache.Set(state, registrationInfo) authURL := a.oauth2Config.AuthCodeURL(state, extras...) log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL) @@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // If the node exists, then the node should be reauthenticated, // if the node does not exist, and the machine key exists, then // this is a new node that should be registered. - registrationId := a.getRegistrationIDFromState(state) + authInfo := a.getAuthInfoFromState(state) + if authInfo == nil { + log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) - // Register the node if it does not exist. - if registrationId != nil { + return + } + + // If this is a registration flow, then we need to register the node. + if authInfo.Registration { verb := "Reauthenticated" - newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) + newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { - log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") + log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed") httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) return @@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // Neither node nor machine key was found in the state cache meaning - // that we could not reauth nor register the node. - httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + // TODO(kradalby): handle login flow (without registration) if needed. + // We need to send an update here to whatever might be waiting for this auth flow. } func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { @@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token( var exchangeOpts []oauth2.AuthCodeOption if a.cfg.PKCE.Enabled { - regInfo, ok := a.registrationCache.Get(state) + regInfo, ok := a.authCache.Get(state) if !ok { return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } @@ -507,14 +533,14 @@ func doOIDCAuthorization( return nil } -// getRegistrationIDFromState retrieves the registration ID from the state. -func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID { - regInfo, ok := a.registrationCache.Get(state) +// getAuthInfoFromState retrieves the registration ID from the state. +func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo { + authInfo, ok := a.authCache.Get(state) if !ok { return nil } - return ®Info.RegistrationID + return &authInfo } func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( @@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( func (a *AuthProviderOIDC) handleRegistration( user *types.User, - registrationID types.RegistrationID, + registrationID types.AuthID, expiry time.Time, ) (bool, error) { node, nodeChange, err := a.h.state.HandleNodeFromAuthPath( diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index e8f4b9ce..83585732 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore") // ErrNodeNameNotUnique is returned when a node name is not unique. var ErrNodeNameNotUnique = errors.New("node name is not unique") +// ErrRegistrationExpired is returned when a registration has expired. +var ErrRegistrationExpired = errors.New("registration expired") + // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { @@ -82,8 +85,10 @@ type State struct { derpMap atomic.Pointer[tailcfg.DERPMap] // polMan handles policy evaluation and management polMan policy.PolicyManager - // registrationCache caches node registration data to reduce database load - registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + + // authCache caches any pending authentication requests, from either auth type (Web and OIDC). + authCache *zcache.Cache[types.AuthID, types.AuthRequest] + // primaryRoutes tracks primary route assignments for nodes primaryRoutes *routes.PrimaryRoutes } @@ -101,20 +106,20 @@ func NewState(cfg *types.Config) (*State, error) { cacheCleanup = cfg.Tuning.RegisterCacheCleanup } - registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( + authCache := zcache.New[types.AuthID, types.AuthRequest]( cacheExpiration, cacheCleanup, ) - registrationCache.OnEvicted( - func(id types.RegistrationID, rn types.RegisterNode) { - rn.SendAndClose(nil) + authCache.OnEvicted( + func(id types.AuthID, rn types.AuthRequest) { + rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired}) }, ) db, err := hsdb.NewHeadscaleDatabase( cfg, - registrationCache, + authCache, ) if err != nil { return nil, fmt.Errorf("initializing database: %w", err) @@ -178,12 +183,12 @@ func NewState(cfg *types.Config) (*State, error) { return &State{ cfg: cfg, - db: db, - ipAlloc: ipAlloc, - polMan: polMan, - registrationCache: registrationCache, - primaryRoutes: routes.New(), - nodeStore: nodeStore, + db: db, + ipAlloc: ipAlloc, + polMan: polMan, + authCache: authCache, + primaryRoutes: routes.New(), + nodeStore: nodeStore, }, nil } @@ -1057,9 +1062,9 @@ func (s *State) DeletePreAuthKey(id uint64) error { return s.db.DeletePreAuthKey(id) } -// GetRegistrationCacheEntry retrieves a node registration from cache. -func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) { - entry, found := s.registrationCache.Get(id) +// GetAuthCacheEntry retrieves a node registration from cache. +func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) { + entry, found := s.authCache.Get(id) if !found { return nil, false } @@ -1067,26 +1072,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis return &entry, true } -// SetRegistrationCacheEntry stores a node registration in cache. -func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) { - s.registrationCache.Set(id, entry) +// SetAuthCacheEntry stores a node registration in cache. +func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) { + s.authCache.Set(id, entry) } // logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. -func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) { - if hostinfo == nil { +func logHostinfoValidation(nv types.NodeView, username, hostname string) { + if !nv.Hostinfo().Valid() { log.Warn(). Caller(). - Str(zf.MachineKey, machineKey). - Str(zf.NodeKey, nodeKey). + EmbedObject(nv). Str(zf.UserName, username). Str(zf.GeneratedHostname, hostname). Msg("Registration had nil hostinfo, generated default hostname") - } else if hostinfo.Hostname == "" { + } else if nv.Hostinfo().Hostname() == "" { log.Warn(). Caller(). - Str(zf.MachineKey, machineKey). - Str(zf.NodeKey, nodeKey). + EmbedObject(nv). Str(zf.UserName, username). Str(zf.GeneratedHostname, hostname). Msg("Registration had empty hostname, generated default") @@ -1128,7 +1131,7 @@ type authNodeUpdateParams struct { // Node to update; must be valid and in NodeStore. ExistingNode types.NodeView // Client data: keys, hostinfo, endpoints. - RegEntry *types.RegisterNode + RegEntry *types.AuthRequest // Pre-validated hostinfo; NetInfo preserved from ExistingNode. ValidHostinfo *tailcfg.Hostinfo // Hostname from hostinfo, or generated from keys if client omits it. @@ -1147,6 +1150,7 @@ type authNodeUpdateParams struct { // an existing node. It updates the node in NodeStore, processes RequestTags, and // persists changes to the database. func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) { + regNv := params.RegEntry.Node() // Log the operation type if params.IsConvertFromTag { log.Info(). @@ -1155,16 +1159,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView Msg("Converting tagged node to user-owned node") } else { log.Info(). - EmbedObject(params.ExistingNode). - Interface("hostinfo", params.RegEntry.Node.Hostinfo). + Object("existing", params.ExistingNode). + Object("incoming", regNv). Msg("Updating existing node registration via reauth") } // Process RequestTags during reauth (#2979) // Due to json:",omitempty", we treat empty/nil as "clear tags" var requestTags []string - if params.RegEntry.Node.Hostinfo != nil { - requestTags = params.RegEntry.Node.Hostinfo.RequestTags + if regNv.Hostinfo().Valid() { + requestTags = regNv.Hostinfo().RequestTags().AsSlice() } oldTags := params.ExistingNode.Tags().AsSlice() @@ -1182,8 +1186,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView // Update existing node in NodeStore - validation passed, safe to mutate updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) { - node.NodeKey = params.RegEntry.Node.NodeKey - node.DiscoKey = params.RegEntry.Node.DiscoKey + node.NodeKey = regNv.NodeKey() + node.DiscoKey = regNv.DiscoKey() node.Hostname = params.Hostname // Preserve NetInfo from existing node when re-registering @@ -1194,7 +1198,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView params.ValidHostinfo, ) - node.Endpoints = params.RegEntry.Node.Endpoints + node.Endpoints = regNv.Endpoints().AsSlice() node.IsOnline = new(false) node.LastSeen = new(time.Now()) @@ -1203,7 +1207,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.IsConvertFromTag { node.RegisterMethod = params.RegisterMethod } else { - node.RegisterMethod = params.RegEntry.Node.RegisterMethod + node.RegisterMethod = regNv.RegisterMethod() } // Track tagged status BEFORE processing tags @@ -1223,7 +1227,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } case !wasTagged && isTagged: // Personal → Tagged: clear expiry (tagged nodes don't expire) @@ -1233,14 +1237,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } case !isTagged: // Personal → Personal: update expiry from client if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } } // Tagged → Tagged: keep existing expiry (nil) - no action needed @@ -1527,13 +1531,13 @@ func (s *State) processReauthTags( // HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC). func (s *State) HandleNodeFromAuthPath( - registrationID types.RegistrationID, + authID types.AuthID, userID types.UserID, expiry *time.Time, registrationMethod string, ) (types.NodeView, change.Change, error) { // Get the registration entry from cache - regEntry, ok := s.GetRegistrationCacheEntry(registrationID) + regEntry, ok := s.GetAuthCacheEntry(authID) if !ok { return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache } @@ -1546,25 +1550,27 @@ func (s *State) HandleNodeFromAuthPath( // Ensure we have a valid hostname from the registration cache entry hostname := util.EnsureHostname( - regEntry.Node.Hostinfo, - regEntry.Node.MachineKey.String(), - regEntry.Node.NodeKey.String(), + regEntry.Node().Hostinfo(), + regEntry.Node().MachineKey().String(), + regEntry.Node().NodeKey().String(), ) // Ensure we have valid hostinfo - validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{}) - validHostinfo.Hostname = hostname + hostinfo := &tailcfg.Hostinfo{} + if regEntry.Node().Hostinfo().Valid() { + hostinfo = regEntry.Node().Hostinfo().AsStruct() + } + + hostinfo.Hostname = hostname logHostinfoValidation( - regEntry.Node.MachineKey.ShortString(), - regEntry.Node.NodeKey.String(), + regEntry.Node(), user.Name, hostname, - regEntry.Node.Hostinfo, ) // Lookup existing nodes - machineKey := regEntry.Node.MachineKey + machineKey := regEntry.Node().MachineKey() existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID)) existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) @@ -1578,7 +1584,7 @@ func (s *State) HandleNodeFromAuthPath( // Create logger with common fields for all auth operations logger := log.With(). - Str(zf.RegistrationID, registrationID.String()). + Str(zf.RegistrationID, authID.String()). Str(zf.UserName, user.Name). Str(zf.MachineKey, machineKey.ShortString()). Str(zf.Method, registrationMethod). @@ -1587,7 +1593,7 @@ func (s *State) HandleNodeFromAuthPath( // Common params for update operations updateParams := authNodeUpdateParams{ RegEntry: regEntry, - ValidHostinfo: validHostinfo, + ValidHostinfo: hostinfo, Hostname: hostname, User: user, Expiry: expiry, @@ -1621,7 +1627,7 @@ func (s *State) HandleNodeFromAuthPath( Msg("Creating new node for different user (same machine key exists for another user)") finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, validHostinfo, + logger, user, regEntry, hostname, hostinfo, expiry, registrationMethod, existingNodeAnyUser, ) if err != nil { @@ -1629,7 +1635,7 @@ func (s *State) HandleNodeFromAuthPath( } } else { finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, validHostinfo, + logger, user, regEntry, hostname, hostinfo, expiry, registrationMethod, types.NodeView{}, ) if err != nil { @@ -1638,10 +1644,10 @@ func (s *State) HandleNodeFromAuthPath( } // Signal to waiting clients - regEntry.SendAndClose(finalNode.AsStruct()) + regEntry.FinishAuth(types.AuthVerdict{Node: finalNode}) // Delete from registration cache - s.registrationCache.Delete(registrationID) + s.authCache.Delete(authID) // Update policy managers usersChange, err := s.updatePolicyManagerUsers() @@ -1670,7 +1676,7 @@ func (s *State) HandleNodeFromAuthPath( func (s *State) createNewNodeFromAuth( logger zerolog.Logger, user *types.User, - regEntry *types.RegisterNode, + regEntry *types.AuthRequest, hostname string, validHostinfo *tailcfg.Hostinfo, expiry *time.Time, @@ -1683,13 +1689,13 @@ func (s *State) createNewNodeFromAuth( return s.createAndSaveNewNode(newNodeParams{ User: *user, - MachineKey: regEntry.Node.MachineKey, - NodeKey: regEntry.Node.NodeKey, - DiscoKey: regEntry.Node.DiscoKey, + MachineKey: regEntry.Node().MachineKey(), + NodeKey: regEntry.Node().NodeKey(), + DiscoKey: regEntry.Node().DiscoKey(), Hostname: hostname, Hostinfo: validHostinfo, - Endpoints: regEntry.Node.Endpoints, - Expiry: cmp.Or(expiry, regEntry.Node.Expiry), + Endpoints: regEntry.Node().Endpoints().AsSlice(), + Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()), RegisterMethod: registrationMethod, ExistingNodeForNetinfo: existingNodeForNetinfo, }) @@ -1784,7 +1790,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Ensure we have a valid hostname - handle nil/empty cases hostname := util.EnsureHostname( - regReq.Hostinfo, + regReq.Hostinfo.View(), machineKey.String(), regReq.NodeKey.String(), ) @@ -1793,14 +1799,6 @@ func (s *State) HandleNodeFromPreAuthKey( validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{}) validHostinfo.Hostname = hostname - logHostinfoValidation( - machineKey.ShortString(), - regReq.NodeKey.ShortString(), - pakUsername(), - hostname, - regReq.Hostinfo, - ) - log.Debug(). Caller(). Str(zf.NodeName, hostname). diff --git a/hscontrol/templates/register_web.go b/hscontrol/templates/register_web.go index 829af7fb..cdede03b 100644 --- a/hscontrol/templates/register_web.go +++ b/hscontrol/templates/register_web.go @@ -7,7 +7,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" ) -func RegisterWeb(registrationID types.RegistrationID) *elem.Element { +func RegisterWeb(registrationID types.AuthID) *elem.Element { return HtmlStructure( elem.Title(nil, elem.Text("Registration - Headscale")), mdTypesetBody( diff --git a/hscontrol/templates_consistency_test.go b/hscontrol/templates_consistency_test.go index 369639cc..0464fb88 100644 --- a/hscontrol/templates_consistency_test.go +++ b/hscontrol/templates_consistency_test.go @@ -21,7 +21,7 @@ func TestTemplateHTMLConsistency(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", @@ -77,7 +77,7 @@ func TestTemplateModernHTMLFeatures(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", @@ -125,7 +125,7 @@ func TestTemplateExternalLinkSecurity(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), externalURLs: []string{}, // No external links }, { @@ -190,7 +190,7 @@ func TestTemplateAccessibilityAttributes(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index d852753e..891969d3 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -22,8 +22,8 @@ const ( // Common errors. var ( - ErrCannotParsePrefix = errors.New("cannot parse prefix") - ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length") + ErrCannotParsePrefix = errors.New("cannot parse prefix") + ErrInvalidAuthIDLength = errors.New("registration ID has invalid length") ) type StateUpdateType int @@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { } } -const RegistrationIDLength = 24 +const AuthIDLength = 24 -type RegistrationID string +type AuthID string -func NewRegistrationID() (RegistrationID, error) { - rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength) +func NewAuthID() (AuthID, error) { + rid, err := util.GenerateRandomStringURLSafe(AuthIDLength) if err != nil { return "", err } - return RegistrationID(rid), nil + return AuthID(rid), nil } -func MustRegistrationID() RegistrationID { - rid, err := NewRegistrationID() +func MustAuthID() AuthID { + rid, err := NewAuthID() if err != nil { panic(err) } @@ -181,43 +181,89 @@ func MustRegistrationID() RegistrationID { return rid } -func RegistrationIDFromString(str string) (RegistrationID, error) { - if len(str) != RegistrationIDLength { - return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str)) +func AuthIDFromString(str string) (AuthID, error) { + r := AuthID(str) + + err := r.Validate() + if err != nil { + return "", err } - return RegistrationID(str), nil + return r, nil } -func (r RegistrationID) String() string { +func (r AuthID) String() string { return string(r) } -type RegisterNode struct { - Node Node - Registered chan *Node - closed *atomic.Bool +func (r AuthID) Validate() error { + if len(r) != AuthIDLength { + return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r)) + } + + return nil } -func NewRegisterNode(node Node) RegisterNode { - return RegisterNode{ - Node: node, - Registered: make(chan *Node), - closed: &atomic.Bool{}, +// AuthRequest represent a pending authentication request from a user or a node. +// If it is a registration request, the node field will be populate with the node that is trying to register. +// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel. +// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed. +type AuthRequest struct { + node *Node + finished chan AuthVerdict + closed *atomic.Bool +} + +func NewRegisterAuthRequest(node Node) AuthRequest { + return AuthRequest{ + node: &node, + finished: make(chan AuthVerdict), + closed: &atomic.Bool{}, } } -func (rn *RegisterNode) SendAndClose(node *Node) { +// Node returns the node that is trying to register. +// It will panic if the AuthRequest is not a registration request. +// Can _only_ be used in the registration path. +func (rn *AuthRequest) Node() NodeView { + if rn.node == nil { + panic("Node can only be used in registration requests") + } + + return rn.node.View() +} + +func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) { if rn.closed.Swap(true) { return } select { - case rn.Registered <- node: + case rn.finished <- verdict: default: } - close(rn.Registered) + close(rn.finished) +} + +func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict { + return rn.finished +} + +type AuthVerdict struct { + // Err is the error that occurred during the authentication process, if any. + // If Err is nil, the authentication process has succeeded. + // If Err is not nil, the authentication process has failed and the node should not be authenticated. + Err error + + // Node is the node that has been authenticated. + // Node is only valid if the auth request was a registration request + // and the authentication process has succeeded. + Node NodeView +} + +func (v AuthVerdict) Accept() bool { + return v.Err == nil } // DefaultBatcherWorkers returns the default number of batcher workers. diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index cbce663b..034779b5 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -295,8 +295,8 @@ func IsCI() bool { // 3. If normalisation fails → generate invalid- replacement // // Returns the guaranteed-valid hostname to use. -func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string { - if hostinfo == nil || hostinfo.Hostname == "" { +func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string { + if !hostinfo.Valid() || hostinfo.Hostname() == "" { key := cmp.Or(machineKey, nodeKey) if key == "" { return "unknown-node" @@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri return "node-" + keyPrefix } - lowercased := strings.ToLower(hostinfo.Hostname) + lowercased := strings.ToLower(hostinfo.Hostname()) err := ValidateHostname(lowercased) if err == nil { diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 5cca4990..6e7a0630 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.want, "invalid-") { if !strings.HasPrefix(got, "invalid-") { @@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.wantHostname, "invalid-") { if !strings.HasPrefix(gotHostname, "invalid-") { @@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { hostinfo := &tailcfg.Hostinfo{Hostname: hostname} - result := EnsureHostname(hostinfo, "mkey", "nkey") + result := EnsureHostname(hostinfo.View(), "mkey", "nkey") if len(result) > 63 { t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) } @@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) { OS: "linux", } - hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey") - hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey") + hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey") + hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey") if hostname1 != hostname2 { t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) diff --git a/integration/cli_test.go b/integration/cli_test.go index a1174277..c46361d4 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) { assert.Equal(t, "node-5", listAll[4].GetName()) otherUserRegIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) @@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs))