diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 459fc664..dc2787c0 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -31,9 +31,11 @@ jobs: - TestOIDC024UserCreation - TestOIDCAuthenticationWithPKCE - TestOIDCReloginSameNodeNewUser + - TestOIDCReloginSameNodeSameUser - TestOIDCFollowUpUrl - TestAuthWebFlowAuthenticationPingAll - - TestAuthWebFlowLogoutAndRelogin + - TestAuthWebFlowLogoutAndReloginSameUser + - TestAuthWebFlowLogoutAndReloginNewUser - TestUserCommand - TestPreAuthKeyCommand - TestPreAuthKeyCommandWithoutExpiry diff --git a/hscontrol/auth.go b/hscontrol/auth.go index c28ecf20..22f8cd7c 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -25,26 +26,84 @@ type AuthProvider interface { func (h *Headscale) handleRegister( ctx context.Context, - regReq tailcfg.RegisterRequest, + req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey) + // Check for logout/expiry FIRST, before checking auth key. + // Tailscale clients may send logout requests with BOTH a past expiry AND an auth key. + // A past expiry takes precedence - it's a logout regardless of other fields. + if !req.Expiry.IsZero() && req.Expiry.Before(time.Now()) { + log.Debug(). + Str("node.key", req.NodeKey.ShortString()). + Time("expiry", req.Expiry). + Bool("has_auth", req.Auth != nil). + Msg("Detected logout attempt with past expiry") - if ok { - resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey) - if err != nil { - return nil, fmt.Errorf("handling existing node: %w", err) + // This is a logout attempt (expiry in the past) + if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok { + log.Debug(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Bool("is_ephemeral", node.IsEphemeral()). + Bool("has_authkey", node.AuthKey().Valid()). + Msg("Found existing node for logout, calling handleLogout") + + resp, err := h.handleLogout(node, req, machineKey) + if err != nil { + return nil, fmt.Errorf("handling logout: %w", err) + } + if resp != nil { + return resp, nil + } + } else { + log.Warn(). + Str("node.key", req.NodeKey.ShortString()). + Msg("Logout attempt but node not found in NodeStore") } - - return resp, nil } - if regReq.Followup != "" { - return h.waitForFollowup(ctx, regReq, machineKey) + // If the register request does not contain a Auth struct, it means we are logging + // out an existing node (legacy logout path for clients that send Auth=nil). + if req.Auth == nil { + // If the register request present a NodeKey that is currently in use, we will + // check if the node needs to be sent to re-auth, or if the node is logging out. + // We do not look up nodes by [key.MachinePublic] as it might belong to multiple + // nodes, separated by users and this path is handling expiring/logout paths. + if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok { + resp, err := h.handleLogout(node, req, machineKey) + if err != nil { + return nil, fmt.Errorf("handling existing node: %w", err) + } + + // If resp is not nil, we have a response to return to the node. 
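Editorial note, not part of the patch: the request shape the new early check above is defending against is a logout from a client that originally registered with a pre auth key, which can carry both a past Expiry and a populated Auth. A minimal sketch of that field combination, with a made-up auth key value:

// exampleLogoutRequest shows the field combination described above: a past
// Expiry together with a non-nil Auth. Illustrative only; the key value is
// hypothetical and only the combination of fields matters.
func exampleLogoutRequest(nodeKey key.NodePrivate) tailcfg.RegisterRequest {
	return tailcfg.RegisterRequest{
		NodeKey: nodeKey.Public(),
		Expiry:  time.Now().Add(-1 * time.Minute), // already in the past => logout
		Auth: &tailcfg.RegisterResponseAuth{
			AuthKey: "hskey-example", // still present, but must not override the logout
		},
	}
}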
+ // If resp is nil, we should proceed and see if the node is trying to re-auth. + if resp != nil { + return resp, nil + } + } else { + // If the register request is not attempting to register a node, and + // we cannot match it with an existing node, we consider that unexpected + // as only register nodes should attempt to log out. + log.Debug(). + Str("node.key", req.NodeKey.ShortString()). + Str("machine.key", machineKey.ShortString()). + Bool("unexpected", true). + Msg("received register request with no auth, and no existing node") + } } - if regReq.Auth != nil && regReq.Auth.AuthKey != "" { - resp, err := h.handleRegisterWithAuthKey(regReq, machineKey) + // If the [tailcfg.RegisterRequest] has a Followup URL, it means that the + // node has already started the registration process and we should wait for + // it to finish the original registration. + if req.Followup != "" { + return h.waitForFollowup(ctx, req, machineKey) + } + + // Pre authenticated keys are handled slightly different than interactive + // logins as they can be done fully sync and we can respond to the node with + // the result as it is waiting. + if isAuthKey(req) { + resp, err := h.handleRegisterWithAuthKey(req, machineKey) if err != nil { // Preserve HTTPError types so they can be handled properly by the HTTP layer var httpErr HTTPError @@ -58,7 +117,7 @@ func (h *Headscale) handleRegister( return resp, nil } - resp, err := h.handleRegisterInteractive(regReq, machineKey) + resp, err := h.handleRegisterInteractive(req, machineKey) if err != nil { return nil, fmt.Errorf("handling register interactive: %w", err) } @@ -66,20 +125,34 @@ func (h *Headscale) handleRegister( return resp, nil } -func (h *Headscale) handleExistingNode( - node *types.Node, - regReq tailcfg.RegisterRequest, +// handleLogout checks if the [tailcfg.RegisterRequest] is a +// logout attempt from a node. If the node is not attempting to +func (h *Headscale) handleLogout( + node types.NodeView, + req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - if node.MachineKey != machineKey { + // Fail closed if it looks like this is an attempt to modify a node where + // the node key and the machine key the noise session was started with does + // not align. + if node.MachineKey() != machineKey { return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) } - expired := node.IsExpired() + // Note: We do NOT return early if req.Auth is set, because Tailscale clients + // may send logout requests with BOTH a past expiry AND an auth key. + // A past expiry indicates logout, regardless of whether Auth is present. + // The expiry check below will handle the logout logic. // If the node is expired and this is not a re-authentication attempt, - // force the client to re-authenticate - if expired && regReq.Auth == nil { + // force the client to re-authenticate. + // TODO(kradalby): I wonder if this is a path we ever hit? + if node.IsExpired() { + log.Trace().Str("node.name", node.Hostname()). + Uint64("node.id", node.ID().Uint64()). + Interface("reg.req", req). + Bool("unexpected", true). + Msg("Node key expired, forcing re-authentication") return &tailcfg.RegisterResponse{ NodeKeyExpired: true, MachineAuthorized: false, @@ -87,49 +160,76 @@ func (h *Headscale) handleExistingNode( }, nil } - if !expired && !regReq.Expiry.IsZero() { - requestExpiry := regReq.Expiry + // If we get here, the node is not currently expired, and not trying to + // do an auth. 
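Editorial note, not part of the patch: taken together, the rewritten handleRegister applies a fixed precedence to incoming requests. A condensed, illustrative restatement of that order follows; it flattens the fall-through the real code performs when the Auth==nil path finds no matching node or the node is not logging out.

// classifyRegisterRequest only names the branch handleRegister tries first
// for a given request shape; it is a sketch, not the actual implementation.
func classifyRegisterRequest(req tailcfg.RegisterRequest) string {
	switch {
	case !req.Expiry.IsZero() && req.Expiry.Before(time.Now()):
		return "logout" // a past expiry wins, even when req.Auth is populated
	case req.Auth == nil:
		return "legacy logout / re-auth check" // may fall through in the real code
	case req.Followup != "":
		return "wait for in-flight interactive registration"
	case req.Auth.AuthKey != "":
		return "pre auth key registration (handled synchronously)"
	default:
		return "interactive registration (client is sent an AuthURL)"
	}
}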
+ // The node is likely logging out, but before we run that logic, we will validate + // that the node is not attempting to tamper/extend their expiry. + // If it is not, we will expire the node or in the case of an ephemeral node, delete it. - // The client is trying to extend their key, this is not allowed. - if requestExpiry.After(time.Now()) { - return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil) - } - - // If the request expiry is in the past, we consider it a logout. - if requestExpiry.Before(time.Now()) { - if node.IsEphemeral() { - c, err := h.state.DeleteNode(node.View()) - if err != nil { - return nil, fmt.Errorf("deleting ephemeral node: %w", err) - } - - h.Change(c) - - return nil, nil - } - } - - updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry) - if err != nil { - return nil, fmt.Errorf("setting node expiry: %w", err) - } - - h.Change(c) - - // CRITICAL: Use the updated node view for the response - // The original node object has stale expiry information - node = updatedNode.AsStruct() + // The client is trying to extend their key, this is not allowed. + if req.Expiry.After(time.Now()) { + return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil) } - return nodeToRegisterResponse(node), nil + // If the request expiry is in the past, we consider it a logout. + if req.Expiry.Before(time.Now()) { + log.Debug(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Bool("is_ephemeral", node.IsEphemeral()). + Bool("has_authkey", node.AuthKey().Valid()). + Time("req.expiry", req.Expiry). + Msg("Processing logout request with past expiry") + + if node.IsEphemeral() { + log.Info(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Msg("Deleting ephemeral node during logout") + + c, err := h.state.DeleteNode(node) + if err != nil { + return nil, fmt.Errorf("deleting ephemeral node: %w", err) + } + + h.Change(c) + + return &tailcfg.RegisterResponse{ + NodeKeyExpired: true, + MachineAuthorized: false, + }, nil + } + + log.Debug(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Msg("Node is not ephemeral, setting expiry instead of deleting") + } + + // Update the internal state with the nodes new expiry, meaning it is + // logged out. + updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), req.Expiry) + if err != nil { + return nil, fmt.Errorf("setting node expiry: %w", err) + } + + h.Change(c) + + return nodeToRegisterResponse(updatedNode), nil } -func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { +// isAuthKey reports if the register request is a registration request +// using an pre auth key. +func isAuthKey(req tailcfg.RegisterRequest) bool { + return req.Auth != nil && req.Auth.AuthKey != "" +} + +func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse { return &tailcfg.RegisterResponse{ // TODO(kradalby): Only send for user-owned nodes // and not tagged nodes when tags is working. 
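Editorial note, not part of the patch: once the machine key check passes, handleLogout above resolves to one of a small set of outcomes. A compact, illustrative restatement:

// logoutOutcome summarises the decision handleLogout makes after the machine
// key has been verified; sketch only, not the actual implementation.
func logoutOutcome(nodeExpired, nodeEphemeral bool, reqExpiry time.Time) string {
	switch {
	case nodeExpired:
		return "NodeKeyExpired=true: client must re-authenticate"
	case reqExpiry.After(time.Now()):
		return "400: extending the key is not allowed"
	case reqExpiry.Before(time.Now()) && nodeEphemeral:
		return "ephemeral node is deleted"
	default:
		return "SetNodeExpiry(req.Expiry): node is marked logged out"
	}
}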
- User: *node.User.TailscaleUser(), - Login: *node.User.TailscaleLogin(), + User: node.UserView().TailscaleUser(), + Login: node.UserView().TailscaleLogin(), NodeKeyExpired: node.IsExpired(), // Headscale does not implement the concept of machine authorization @@ -141,10 +241,10 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse { func (h *Headscale) waitForFollowup( ctx context.Context, - regReq tailcfg.RegisterRequest, + req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - fu, err := url.Parse(regReq.Followup) + fu, err := url.Parse(req.Followup) if err != nil { return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err) } @@ -161,21 +261,21 @@ func (h *Headscale) waitForFollowup( case node := <-reg.Registered: if node == nil { // registration is expired in the cache, instruct the client to try a new registration - return h.reqToNewRegisterResponse(regReq, machineKey) + return h.reqToNewRegisterResponse(req, machineKey) } - return nodeToRegisterResponse(node), nil + return nodeToRegisterResponse(node.View()), nil } } // if the follow-up registration isn't found anymore, instruct the client to try a new registration - return h.reqToNewRegisterResponse(regReq, machineKey) + return h.reqToNewRegisterResponse(req, machineKey) } // reqToNewRegisterResponse refreshes the registration flow by creating a new // registration ID and returning the corresponding AuthURL so the client can // restart the authentication process. func (h *Headscale) reqToNewRegisterResponse( - regReq tailcfg.RegisterRequest, + req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { newRegID, err := types.NewRegistrationID() @@ -183,18 +283,25 @@ func (h *Headscale) reqToNewRegisterResponse( return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) } + // Ensure we have valid hostinfo and hostname + validHostinfo, hostname := util.EnsureValidHostinfo( + req.Hostinfo, + machineKey.String(), + req.NodeKey.String(), + ) + nodeToRegister := types.NewRegisterNode( types.Node{ - Hostname: regReq.Hostinfo.Hostname, + Hostname: hostname, MachineKey: machineKey, - NodeKey: regReq.NodeKey, - Hostinfo: regReq.Hostinfo, + NodeKey: req.NodeKey, + Hostinfo: validHostinfo, LastSeen: ptr.To(time.Now()), }, ) - if !regReq.Expiry.IsZero() { - nodeToRegister.Node.Expiry = ®Req.Expiry + if !req.Expiry.IsZero() { + nodeToRegister.Node.Expiry = &req.Expiry } log.Info().Msgf("New followup node registration using key: %s", newRegID) @@ -206,11 +313,11 @@ func (h *Headscale) reqToNewRegisterResponse( } func (h *Headscale) handleRegisterWithAuthKey( - regReq tailcfg.RegisterRequest, + req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { node, changed, err := h.state.HandleNodeFromPreAuthKey( - regReq, + req, machineKey, ) if err != nil { @@ -262,18 +369,26 @@ func (h *Headscale) handleRegisterWithAuthKey( // h.Change(policyChange) // } - user := node.User() - - return &tailcfg.RegisterResponse{ + resp := &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), - User: *user.TailscaleUser(), - Login: *user.TailscaleLogin(), - }, nil + User: node.UserView().TailscaleUser(), + Login: node.UserView().TailscaleLogin(), + } + + log.Trace(). + Caller(). + Interface("reg.resp", resp). + Interface("reg.req", req). + Str("node.name", node.Hostname()). + Uint64("node.id", node.ID().Uint64()). 
+ Msg("RegisterResponse") + + return resp, nil } func (h *Headscale) handleRegisterInteractive( - regReq tailcfg.RegisterRequest, + req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { registrationId, err := types.NewRegistrationID() @@ -281,18 +396,39 @@ func (h *Headscale) handleRegisterInteractive( return nil, fmt.Errorf("generating registration ID: %w", err) } + // Ensure we have valid hostinfo and hostname + validHostinfo, hostname := util.EnsureValidHostinfo( + req.Hostinfo, + machineKey.String(), + req.NodeKey.String(), + ) + + if req.Hostinfo == nil { + log.Warn(). + Str("machine.key", machineKey.ShortString()). + Str("node.key", req.NodeKey.ShortString()). + Str("generated.hostname", hostname). + Msg("Received registration request with nil hostinfo, generated default hostname") + } else if req.Hostinfo.Hostname == "" { + log.Warn(). + Str("machine.key", machineKey.ShortString()). + Str("node.key", req.NodeKey.ShortString()). + Str("generated.hostname", hostname). + Msg("Received registration request with empty hostname, generated default") + } + nodeToRegister := types.NewRegisterNode( types.Node{ - Hostname: regReq.Hostinfo.Hostname, + Hostname: hostname, MachineKey: machineKey, - NodeKey: regReq.NodeKey, - Hostinfo: regReq.Hostinfo, + NodeKey: req.NodeKey, + Hostinfo: validHostinfo, LastSeen: ptr.To(time.Now()), }, ) - if !regReq.Expiry.IsZero() { - nodeToRegister.Node.Expiry = ®Req.Expiry + if !req.Expiry.IsZero() { + nodeToRegister.Node.Expiry = &req.Expiry } h.state.SetRegistrationCacheEntry( diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go new file mode 100644 index 00000000..1727be1a --- /dev/null +++ b/hscontrol/auth_test.go @@ -0,0 +1,3006 @@ +package hscontrol + +import ( + "context" + "fmt" + "net/url" + "strings" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/mapper" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// Interactive step type constants +const ( + stepTypeInitialRequest = "initial_request" + stepTypeAuthCompletion = "auth_completion" + stepTypeFollowupRequest = "followup_request" +) + +// interactiveStep defines a step in the interactive authentication workflow +type interactiveStep struct { + stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest + expectAuthURL bool + expectCacheEntry bool + callAuthPath bool // Real call to HandleNodeFromAuthPath, not mocked +} + +func TestAuthenticationFlows(t *testing.T) { + // Shared test keys for consistent behavior across test cases + machineKey1 := key.NewMachine() + machineKey2 := key.NewMachine() + nodeKey1 := key.NewNode() + nodeKey2 := key.NewNode() + + tests := []struct { + name string + setupFunc func(*testing.T, *Headscale) (string, error) // Returns dynamic values like auth keys + request func(dynamicValue string) tailcfg.RegisterRequest + machineKey func() key.MachinePublic + wantAuth bool + wantError bool + wantAuthURL bool + wantExpired bool + validate func(*testing.T, *tailcfg.RegisterResponse, *Headscale) + + // Interactive workflow support + requiresInteractiveFlow bool + interactiveSteps []interactiveStep + validateRegistrationCache bool + expectedAuthURLPattern string + simulateAuthCompletion bool + validateCompleteResponse bool + }{ + // === PRE-AUTH KEY SCENARIOS === + // Tests authentication using pre-authorization keys for automated node 
registration. + // Pre-auth keys allow nodes to join without interactive authentication. + + // TEST: Valid pre-auth key registers a new node + // WHAT: Tests successful node registration using a valid pre-auth key + // INPUT: Register request with valid pre-auth key, node key, and hostinfo + // EXPECTED: Node is authorized immediately, registered in database + // WHY: Pre-auth keys enable automated/headless node registration without user interaction + { + name: "preauth_key_valid_new_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("preauth-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "preauth-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + assert.NotEmpty(t, resp.User.DisplayName) + + // Verify node was created in database + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "preauth-node-1", node.Hostname()) + }, + }, + + // TEST: Reusable pre-auth key can register multiple nodes + // WHAT: Tests that a reusable pre-auth key can be used for multiple node registrations + // INPUT: Same reusable pre-auth key used to register two different nodes + // EXPECTED: Both nodes successfully register with the same key + // WHY: Reusable keys allow multiple machines to join using one key (useful for fleet deployments) + { + name: "preauth_key_reusable_multiple_nodes", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("reusable-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Use the key for first node + firstReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reusable-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reusable-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey2.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + 
+ // Verify both nodes exist + node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) + node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) + assert.True(t, found2) + assert.Equal(t, "reusable-node-1", node1.Hostname()) + assert.Equal(t, "reusable-node-2", node2.Hostname()) + }, + }, + + // TEST: Single-use pre-auth key cannot be reused + // WHAT: Tests that a single-use pre-auth key fails on second use + // INPUT: Single-use key used for first node (succeeds), then attempted for second node + // EXPECTED: First node registers successfully, second node fails with error + // WHY: Single-use keys provide security by preventing key reuse after initial registration + { + name: "preauth_key_single_use_exhausted", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("single-use-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + if err != nil { + return "", err + } + + // Use the key for first node (should work) + firstReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "single-use-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "single-use-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey2.Public() }, + wantError: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // First node should exist, second should not + _, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) + _, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) + assert.False(t, found2) + }, + }, + + // TEST: Invalid pre-auth key is rejected + // WHAT: Tests that an invalid/non-existent pre-auth key is rejected + // INPUT: Register request with invalid auth key string + // EXPECTED: Registration fails with error + // WHY: Invalid keys must be rejected to prevent unauthorized node registration + { + name: "preauth_key_invalid", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "invalid-key-12345", nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "invalid-key-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + + // TEST: Ephemeral pre-auth key creates ephemeral node + // WHAT: Tests that a node registered with ephemeral key is marked as ephemeral + // INPUT: Pre-auth key with ephemeral=true, standard register request + // EXPECTED: Node 
registers and is marked as ephemeral (will be deleted on logout) + // WHY: Ephemeral nodes auto-cleanup when disconnected, useful for temporary/CI environments + { + name: "preauth_key_ephemeral_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("ephemeral-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "ephemeral-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify ephemeral node was created + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.NotNil(t, node.AuthKey) + assert.True(t, node.AuthKey().Ephemeral()) + }, + }, + + // === INTERACTIVE REGISTRATION SCENARIOS === + // Tests interactive authentication flow where user completes registration via web UI. + // Interactive flow: node requests registration → receives AuthURL → user authenticates → node gets registered + + // TEST: Complete interactive workflow for new node + // WHAT: Tests full interactive registration flow from initial request to completion + // INPUT: Register request with no auth → user completes auth → followup request + // EXPECTED: Initial request returns AuthURL, after auth completion node is registered + // WHY: Interactive flow is the standard user-facing authentication method for new nodes + { + name: "full_interactive_workflow_new_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-flow-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, // cleaned up after completion + }, + validateCompleteResponse: true, + expectedAuthURLPattern: "/register/", + }, + // TEST: Interactive workflow with no Auth struct in request + // WHAT: Tests interactive flow when request has no Auth field (nil) + // INPUT: Register request with Auth field set to nil + // EXPECTED: Node receives AuthURL and can complete registration via interactive flow + // WHY: Validates handling of requests without Auth field, same as empty auth + { + name: "interactive_workflow_no_auth_struct", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + // No Auth field at all + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-no-auth-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + 
requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, // cleaned up after completion + }, + validateCompleteResponse: true, + expectedAuthURLPattern: "/register/", + }, + + // === EXISTING NODE SCENARIOS === + // Tests behavior when existing registered nodes send requests (logout, re-auth, expiry, etc.) + + // TEST: Existing node logout with past expiry + // WHAT: Tests node logout by sending request with expiry in the past + // INPUT: Previously registered node sends request with Auth=nil and past expiry time + // EXPECTED: Node expiry is updated, NodeKeyExpired=true, MachineAuthorized=true (for compatibility) + // WHY: Nodes signal logout by setting expiry to past time; system updates node state accordingly + { + name: "existing_node_logout", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("logout-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node first + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "logout-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + t.Logf("Setup registered node: %+v", resp) + + // Wait for node to be available in NodeStore with debug info + var attemptCount int + require.EventuallyWithT(t, func(c *assert.CollectT) { + attemptCount++ + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + if assert.True(c, found, "node should be available in NodeStore") { + t.Logf("Node found in NodeStore after %d attempts", attemptCount) + } + }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(-1 * time.Hour), // Past expiry = logout + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + wantExpired: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.True(t, resp.NodeKeyExpired) + }, + }, + // TEST: Existing node with different machine key is rejected + // WHAT: Tests that requests for existing node with wrong machine key are rejected + // INPUT: Node key matches existing node, but machine key is different + // EXPECTED: Request fails with unauthorized error (machine key mismatch) + // WHY: Machine key must match to prevent node hijacking/impersonation + { + name: "existing_node_machine_key_mismatch", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("mismatch-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register with machineKey1 + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "mismatch-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = 
app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(-1 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey2.Public() }, // Different machine key + wantError: true, + }, + // TEST: Existing node cannot extend expiry without re-auth + // WHAT: Tests that nodes cannot extend their expiry time without authentication + // INPUT: Existing node sends request with Auth=nil and future expiry (extension attempt) + // EXPECTED: Request fails with error (extending key not allowed) + // WHY: Prevents nodes from extending their own lifetime; must re-authenticate + { + name: "existing_node_key_extension_not_allowed", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("extend-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node first + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "extend-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(48 * time.Hour), // Future time = extend attempt + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + // TEST: Expired node must re-authenticate + // WHAT: Tests that expired nodes receive NodeKeyExpired=true and must re-auth + // INPUT: Previously expired node sends request with no auth + // EXPECTED: Response has NodeKeyExpired=true, node must re-authenticate + // WHY: Expired nodes must go through authentication again for security + { + name: "existing_node_expired_forces_reauth", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("reauth-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node first + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + var node 
types.NodeView + var found bool + require.EventuallyWithT(t, func(c *assert.CollectT) { + node, found = app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + if !found { + return "", fmt.Errorf("node not found after setup") + } + + // Expire the node + expiredTime := time.Now().Add(-1 * time.Hour) + _, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime) + return "", err + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(24 * time.Hour), // Future expiry + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantExpired: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.NodeKeyExpired) + assert.False(t, resp.MachineAuthorized) + }, + }, + // TEST: Ephemeral node is deleted on logout + // WHAT: Tests that ephemeral nodes are deleted (not just expired) on logout + // INPUT: Ephemeral node sends logout request (past expiry) + // EXPECTED: Node is completely deleted from database, not just marked expired + // WHY: Ephemeral nodes should not persist after logout; auto-cleanup + { + name: "ephemeral_node_logout_deletion", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("ephemeral-logout-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + if err != nil { + return "", err + } + + // Register ephemeral node + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "ephemeral-logout-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(-1 * time.Hour), // Logout + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantExpired: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.NodeKeyExpired) + assert.False(t, resp.MachineAuthorized) + + // Ephemeral node should be deleted, not just marked expired + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.False(t, found, "ephemeral node should be deleted on logout") + }, + }, + + // === FOLLOWUP REGISTRATION SCENARIOS === + // Tests followup request handling after interactive registration is initiated. + // Followup requests are sent by nodes waiting for auth completion. 
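Editorial note, not part of the patch: the followup cases below all build their requests the same way; the client echoes the AuthURL it was handed back as the Followup field while it waits for auth completion. A minimal sketch, assuming the http://localhost:8080/register/<registrationID> URL layout used in this test file:

// exampleFollowupRequest mirrors how the followup cases below construct their
// polling request; illustrative only.
func exampleFollowupRequest(authURL string, nodeKey key.NodePrivate) tailcfg.RegisterRequest {
	return tailcfg.RegisterRequest{
		Followup: authURL, // e.g. "http://localhost:8080/register/<registrationID>"
		NodeKey:  nodeKey.Public(),
	}
}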
+ + // TEST: Successful followup registration after auth completion + // WHAT: Tests node successfully completes registration via followup URL + // INPUT: Register request with followup URL after auth completion + // EXPECTED: Node receives successful registration response with user info + // WHY: Followup mechanism allows nodes to poll/wait for auth completion + { + name: "followup_registration_success", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + regID, err := types.NewRegistrationID() + if err != nil { + return "", err + } + + registered := make(chan *types.Node, 1) + nodeToRegister := types.RegisterNode{ + Node: types.Node{ + Hostname: "followup-success-node", + }, + Registered: registered, + } + app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + + // Simulate successful registration + go func() { + time.Sleep(20 * time.Millisecond) + user := app.state.CreateUserForTest("followup-user") + node := app.state.CreateNodeForTest(user, "followup-success-node") + registered <- node + }() + + return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + }, + }, + // TEST: Followup registration times out when auth not completed + // WHAT: Tests that followup request times out if auth is not completed in time + // INPUT: Followup request with short timeout, no auth completion + // EXPECTED: Request times out with unauthorized error + // WHY: Prevents indefinite waiting; nodes must retry if auth takes too long + { + name: "followup_registration_timeout", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + regID, err := types.NewRegistrationID() + if err != nil { + return "", err + } + + registered := make(chan *types.Node, 1) + nodeToRegister := types.RegisterNode{ + Node: types.Node{ + Hostname: "followup-timeout-node", + }, + Registered: registered, + } + app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + // Don't send anything on channel - will timeout + + return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + // TEST: Invalid followup URL is rejected + // WHAT: Tests that malformed/invalid followup URLs are rejected + // INPUT: Register request with invalid URL in Followup field + // EXPECTED: Request fails with error (invalid followup URL) + // WHY: Validates URL format to prevent errors and potential exploits + { + name: "followup_invalid_url", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "invalid://url[malformed", nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + // TEST: Non-existent registration ID is rejected + // WHAT: Tests that followup with non-existent registration ID fails + // INPUT: Valid 
followup URL but registration ID not in cache + // EXPECTED: Request fails with unauthorized error + // WHY: Registration must exist in cache; prevents invalid/expired registrations + { + name: "followup_registration_not_found", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "http://localhost:8080/register/nonexistent-id", nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + + // === EDGE CASES === + // Tests handling of malformed, invalid, or unusual input data + + // TEST: Empty hostname is handled with defensive code + // WHAT: Tests that empty hostname in hostinfo generates a default hostname + // INPUT: Register request with hostinfo containing empty hostname string + // EXPECTED: Node registers successfully with generated hostname (node-MACHINEKEY) + // WHY: Defensive code prevents errors from missing hostnames; generates sensible default + { + name: "empty_hostname", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("empty-hostname-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "", // Empty hostname should be handled gracefully + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + + // Node should be created with generated hostname + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.NotEmpty(t, node.Hostname()) + }, + }, + // TEST: Nil hostinfo is handled with defensive code + // WHAT: Tests that nil hostinfo in register request is handled gracefully + // INPUT: Register request with Hostinfo field set to nil + // EXPECTED: Node registers successfully with generated hostname starting with "node-" + // WHY: Defensive code prevents nil pointer panics; creates valid default hostinfo + { + name: "nil_hostinfo", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("nil-hostinfo-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: nil, // Nil hostinfo should be handled with defensive code + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + + // Node should be created with generated hostname from defensive code + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.NotEmpty(t, node.Hostname()) + // Hostname 
should start with "node-" (generated from machine key) + assert.True(t, strings.HasPrefix(node.Hostname(), "node-")) + }, + }, + + // === PRE-AUTH KEY WITH EXPIRY SCENARIOS === + // Tests pre-auth key expiration handling + + // TEST: Expired pre-auth key is rejected + // WHAT: Tests that a pre-auth key with past expiration date cannot be used + // INPUT: Pre-auth key with expiry 1 hour in the past + // EXPECTED: Registration fails with error + // WHY: Expired keys must be rejected to maintain security and key lifecycle management + { + name: "preauth_key_expired", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("expired-pak-user") + expiry := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "expired-pak-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + + // TEST: Pre-auth key with ACL tags applies tags to node + // WHAT: Tests that ACL tags from pre-auth key are applied to registered node + // INPUT: Pre-auth key with ACL tags ["tag:test", "tag:integration"], register request + // EXPECTED: Node registers with specified ACL tags applied as ForcedTags + // WHY: Pre-auth keys can enforce ACL policies on nodes during registration + { + name: "preauth_key_with_acl_tags", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("tagged-pak-user") + tags := []string{"tag:server", "tag:database"} + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, tags) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-pak-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify node was created with tags + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "tagged-pak-node", node.Hostname()) + if node.AuthKey().Valid() { + assert.NotEmpty(t, node.AuthKey().Tags()) + } + }, + }, + + // === RE-AUTHENTICATION SCENARIOS === + // TEST: Existing node re-authenticates with new pre-auth key + // WHAT: Tests that existing node can re-authenticate using new pre-auth key + // INPUT: Existing node sends request with new valid pre-auth key + // EXPECTED: Node successfully re-authenticates, stays authorized + // WHY: Allows nodes to refresh authentication using pre-auth keys + { + name: "existing_node_reauth_with_new_authkey", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("reauth-user") + + // First, register with initial auth key + pak1, err := 
app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + // Create new auth key for re-authentication + pak2, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + return pak2.Key, nil + }, + request: func(newAuthKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: newAuthKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-node-updated", + }, + Expiry: time.Now().Add(48 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify node was updated, not duplicated + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "reauth-node-updated", node.Hostname()) + }, + }, + // TEST: Existing node re-authenticates via interactive flow + // WHAT: Tests that existing expired node can re-authenticate interactively + // INPUT: Expired node initiates interactive re-authentication + // EXPECTED: Node receives AuthURL and can complete re-authentication + // WHY: Allows expired nodes to re-authenticate without pre-auth keys + { + name: "existing_node_reauth_interactive_flow", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("interactive-reauth-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register initially with auth key + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-reauth-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: "", // Empty auth key triggers interactive flow + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-reauth-node-updated", + }, + Expiry: time.Now().Add(48 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return 
machineKey1.Public() }, + wantAuthURL: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.Contains(t, resp.AuthURL, "register/") + assert.False(t, resp.MachineAuthorized) + }, + }, + + // === NODE KEY ROTATION SCENARIOS === + // Tests node key rotation where node changes its node key while keeping same machine key + + // TEST: Node key rotation with same machine key updates in place + // WHAT: Tests that registering with new node key and same machine key updates existing node + // INPUT: Register node with nodeKey1, then register again with nodeKey2 but same machineKey + // EXPECTED: Node is updated in place; nodeKey2 exists, nodeKey1 no longer exists + // WHY: Same machine key means same physical device; node key rotation updates, doesn't duplicate + { + name: "node_key_rotation_same_machine", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("rotation-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register with initial node key + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "rotation-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + // Create new auth key for rotation + pakRotation, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + return pakRotation.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey2.Public(), // Different node key, same machine + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "rotation-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // When same machine key is used, node is updated in place (not duplicated) + // The old nodeKey1 should no longer exist + _, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.False(t, found1, "old node key should not exist after rotation") + + // The new nodeKey2 should exist with the same machine key + node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found2, "new node key should exist after rotation") + assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "machine key should remain the same") + }, + }, + + // === MALFORMED REQUEST SCENARIOS === + // Tests handling of requests with malformed or unusual field values + + // TEST: Zero-time expiry is handled correctly + // WHAT: Tests registration with expiry set to zero time value + // INPUT: Register request with Expiry set to time.Time{} (zero value) + // EXPECTED: Node registers successfully; zero time treated as no expiry + // WHY: Zero 
time is valid Go default; should be handled gracefully + { + name: "malformed_expiry_zero_time", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("zero-expiry-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "zero-expiry-node", + }, + Expiry: time.Time{}, // Zero time + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + + // Node should be created with default expiry handling + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "zero-expiry-node", node.Hostname()) + }, + }, + // TEST: Malformed hostinfo with very long hostname is truncated + // WHAT: Tests that excessively long hostname is truncated to DNS label limit + // INPUT: Hostinfo with 110-character hostname (exceeds 63-char DNS limit) + // EXPECTED: Node registers successfully; hostname truncated to 63 characters + // WHY: Defensive code enforces DNS label limit (RFC 1123); prevents errors + { + name: "malformed_hostinfo_invalid_data", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("invalid-hostinfo-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node-with-very-long-hostname-that-might-exceed-normal-limits-and-contain-special-chars-!@#$%", + BackendLogID: "invalid-log-id", + OS: "unknown-os", + OSVersion: "999.999.999", + DeviceModel: "test-device-model", + RequestTags: []string{"invalid:tag", "another!tag"}, + Services: []tailcfg.Service{{Proto: "tcp", Port: 65535}}, + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + + // Node should be created even with malformed hostinfo + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + // Hostname should be sanitized or handled gracefully + assert.NotEmpty(t, node.Hostname()) + }, + }, + + // === REGISTRATION CACHE EDGE CASES === + // Tests edge cases in registration cache handling during interactive flow + + // TEST: Followup registration with nil response (cache expired during auth) + // WHAT: Tests that followup request handles nil node response (cache expired/cleared) + // INPUT: Followup request where auth completion sends nil (cache was cleared) + // EXPECTED: Returns new AuthURL so client can retry authentication + // WHY: Nil response means cache expired - give client new AuthURL instead of error + { + name: "followup_registration_node_nil_response", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + regID, err := 
types.NewRegistrationID() + if err != nil { + return "", err + } + + registered := make(chan *types.Node, 1) + nodeToRegister := types.RegisterNode{ + Node: types.Node{ + Hostname: "nil-response-node", + }, + Registered: registered, + } + app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + + // Simulate registration that returns nil (cache expired during auth) + go func() { + time.Sleep(20 * time.Millisecond) + registered <- nil // Nil indicates cache expiry + }() + + return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "nil-response-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: false, // Should not be authorized yet - needs to use new AuthURL + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Should get a new AuthURL, not an error + assert.NotEmpty(t, resp.AuthURL, "should receive new AuthURL when cache returns nil") + assert.Contains(t, resp.AuthURL, "/register/", "AuthURL should contain registration path") + assert.False(t, resp.MachineAuthorized, "machine should not be authorized yet") + }, + }, + // TEST: Malformed followup path is rejected + // WHAT: Tests that followup URL with malformed path is rejected + // INPUT: Followup URL with path that doesn't match expected format + // EXPECTED: Request fails with error (invalid followup URL) + // WHY: Path validation prevents processing of corrupted/invalid URLs + { + name: "followup_registration_malformed_path", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "http://localhost:8080/register/", nil // Missing registration ID + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + // TEST: Wrong followup path format is rejected + // WHAT: Tests that followup URL with incorrect path structure fails + // INPUT: Valid URL but path doesn't start with "/register/" + // EXPECTED: Request fails with error (invalid path format) + // WHY: Strict path validation ensures only valid registration URLs accepted + { + name: "followup_registration_wrong_path_format", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "http://localhost:8080/wrong/path/format", nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + + // === AUTH PROVIDER EDGE CASES === + // TEST: Interactive workflow preserves custom hostinfo + // WHAT: Tests that custom hostinfo fields are preserved through interactive flow + // INPUT: Interactive registration with detailed hostinfo (OS, version, model, etc.) 
+ // EXPECTED: Node registers with all hostinfo fields preserved
+ // WHY: Ensures interactive flow doesn't lose custom hostinfo data
+ {
+ name: "interactive_workflow_with_custom_hostinfo",
+ setupFunc: func(t *testing.T, app *Headscale) (string, error) {
+ return "", nil
+ },
+ request: func(_ string) tailcfg.RegisterRequest {
+ return tailcfg.RegisterRequest{
+ NodeKey: nodeKey1.Public(),
+ Hostinfo: &tailcfg.Hostinfo{
+ Hostname: "custom-interactive-node",
+ OS: "linux",
+ OSVersion: "20.04",
+ DeviceModel: "server",
+ RequestTags: []string{"tag:server"},
+ },
+ Expiry: time.Now().Add(24 * time.Hour),
+ }
+ },
+ machineKey: func() key.MachinePublic { return machineKey1.Public() },
+ requiresInteractiveFlow: true,
+ interactiveSteps: []interactiveStep{
+ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
+ {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, // cleaned up after completion
+ },
+ validateCompleteResponse: true,
+ expectedAuthURLPattern: "/register/",
+ validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
+ // Verify custom hostinfo was preserved through interactive workflow
+ node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
+ assert.True(t, found, "node should be found after interactive registration")
+ if found {
+ assert.Equal(t, "custom-interactive-node", node.Hostname())
+ assert.Equal(t, "linux", node.Hostinfo().OS())
+ assert.Equal(t, "20.04", node.Hostinfo().OSVersion())
+ assert.Equal(t, "server", node.Hostinfo().DeviceModel())
+ assert.Contains(t, node.Hostinfo().RequestTags().AsSlice(), "tag:server")
+ }
+ },
+ },
+
+ // === PRE-AUTH KEY USAGE TRACKING ===
+ // Tests accurate tracking of pre-auth key usage
+
+ // TEST: Pre-auth key usage is tracked on registration
+ // WHAT: Tests that using a pre-auth key to register a node is recorded against the key
+ // INPUT: Single-use pre-auth key used to register one node
+ // EXPECTED: Node registers successfully and the key is marked as used (not reusable)
+ // WHY: Usage tracking enables monitoring and auditing of pre-auth key usage
+ {
+ name: "preauth_key_usage_count_tracking",
+ setupFunc: func(t *testing.T, app *Headscale) (string, error) {
+ user := app.state.CreateUserForTest("usage-count-user")
+ pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // Single use
+ if err != nil {
+ return "", err
+ }
+ return pak.Key, nil
+ },
+ request: func(authKey string) tailcfg.RegisterRequest {
+ return tailcfg.RegisterRequest{
+ Auth: &tailcfg.RegisterResponseAuth{
+ AuthKey: authKey,
+ },
+ NodeKey: nodeKey1.Public(),
+ Hostinfo: &tailcfg.Hostinfo{
+ Hostname: "usage-count-node",
+ },
+ Expiry: time.Now().Add(24 * time.Hour),
+ }
+ },
+ machineKey: func() key.MachinePublic { return machineKey1.Public() },
+ wantAuth: true,
+ validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
+ assert.True(t, resp.MachineAuthorized)
+ assert.False(t, resp.NodeKeyExpired)
+
+ // Verify auth key usage was tracked
+ node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
+ assert.True(t, found)
+ assert.Equal(t, "usage-count-node", node.Hostname())
+
+ // Key should now be used up (single use)
+ if node.AuthKey().Valid() {
+ assert.False(t, node.AuthKey().Reusable())
+ }
+ },
+ },
+
+ // === REGISTRATION ID GENERATION AND ADVANCED EDGE CASES ===
+ // TEST: Interactive workflow generates valid registration IDs
+ // WHAT: Tests that interactive flow generates unique, valid registration IDs
+ // INPUT: Interactive registration request
+ // EXPECTED: AuthURL contains valid registration ID that can be extracted
+ // WHY: Registration IDs must be unique and valid for cache lookup
+ {
+ name: "interactive_workflow_registration_id_generation",
+ setupFunc: func(t *testing.T, app *Headscale) (string, error) {
+ return "", nil
+ },
+ request: func(_ string) tailcfg.RegisterRequest {
+ return tailcfg.RegisterRequest{
+ NodeKey: nodeKey1.Public(),
+ Hostinfo: &tailcfg.Hostinfo{
+ Hostname: "registration-id-test-node",
+ OS: "test-os",
+ },
+ Expiry: time.Now().Add(24 * time.Hour),
+ }
+ },
+ machineKey: func() key.MachinePublic { return machineKey1.Public() },
+ requiresInteractiveFlow: true,
+ interactiveSteps: []interactiveStep{
+ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
+ {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false},
+ },
+ validateCompleteResponse: true,
+ expectedAuthURLPattern: "/register/",
+ validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
+ // Verify registration ID was properly generated and used
+ node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
+ assert.True(t, found, "node should be registered after interactive workflow")
+ if found {
+ assert.Equal(t, "registration-id-test-node", node.Hostname())
+ assert.Equal(t, "test-os", node.Hostinfo().OS())
+ }
+ },
+ },
+ // TEST: Registration with the same node key
+ // WHAT: Tests that a plain auth-key registration for nodeKey1 succeeds and yields a consistent node
+ // INPUT: Reusable pre-auth key, registration request using nodeKey1
+ // EXPECTED: Node registers successfully and is found in the NodeStore
+ // WHY: Baseline for the concurrent registration scenarios that reuse the same node key
+ {
+ name: "concurrent_registration_same_node_key",
+ setupFunc: func(t *testing.T, app *Headscale) (string, error) {
+ user := app.state.CreateUserForTest("concurrent-user")
+ pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
+ if err != nil {
+ return "", err
+ }
+ return pak.Key, nil
+ },
+ request: func(authKey string) tailcfg.RegisterRequest {
+ return tailcfg.RegisterRequest{
+ Auth: &tailcfg.RegisterResponseAuth{
+ AuthKey: authKey,
+ },
+ NodeKey: nodeKey1.Public(),
+ Hostinfo: &tailcfg.Hostinfo{
+ Hostname: "concurrent-node",
+ },
+ Expiry: time.Now().Add(24 * time.Hour),
+ }
+ },
+ machineKey: func() key.MachinePublic { return machineKey1.Public() },
+ wantAuth: true,
+ validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
+ assert.True(t, resp.MachineAuthorized)
+ assert.False(t, resp.NodeKeyExpired)
+
+ // Verify node was registered
+ node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
+ assert.True(t, found)
+ assert.Equal(t, "concurrent-node", node.Hostname())
+ },
+ },
+ // TEST: Auth key expiry vs request expiry handling
+ // WHAT: Tests that the requested node expiry is independent of the pre-auth key's expiry
+ // INPUT: Pre-auth key valid for 48 hours, registration request with a 12-hour expiry
+ // EXPECTED: Node registers successfully with the shorter request expiry
+ // WHY: The node's expiry comes from the request, not from the auth key's expiry
+ {
+ name: "auth_key_with_future_expiry_past_request_expiry",
+ setupFunc: func(t *testing.T, app *Headscale) (string, error) {
+ user := app.state.CreateUserForTest("future-expiry-user")
+ // Auth key expires in the future
+ expiry := time.Now().Add(48 * time.Hour)
+ pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil)
+ if err != nil {
+ return "", err
+ }
+ return pak.Key, nil
+ },
+ request: func(authKey string) tailcfg.RegisterRequest {
+ return tailcfg.RegisterRequest{
+ Auth: &tailcfg.RegisterResponseAuth{
+ AuthKey: authKey,
+ },
+ NodeKey: nodeKey1.Public(),
+ Hostinfo: &tailcfg.Hostinfo{
+ Hostname: "future-expiry-node",
+ },
+ // Request 
expires before auth key
+ Expiry: time.Now().Add(12 * time.Hour),
+ }
+ },
+ machineKey: func() key.MachinePublic { return machineKey1.Public() },
+ wantAuth: true,
+ validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
+ assert.True(t, resp.MachineAuthorized)
+ assert.False(t, resp.NodeKeyExpired)
+
+ // Node should be created with request expiry (shorter than auth key expiry)
+ node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
+ assert.True(t, found)
+ assert.Equal(t, "future-expiry-node", node.Hostname())
+ },
+ },
+ // TEST: Re-authentication with a different user's auth key
+ // WHAT: Tests that re-authenticating with another user's auth key creates a new node
+ // INPUT: Node registered with user1's auth key, then re-registers with user2's auth key
+ // EXPECTED: A new node is created for user2; user1's original node remains
+ // WHY: Each user owns distinct node entries, so switching users must not transfer ownership
+ {
+ name: "reauth_existing_node_different_user_auth_key",
+ setupFunc: func(t *testing.T, app *Headscale) (string, error) {
+ // Create two users
+ user1 := app.state.CreateUserForTest("user1-context")
+ user2 := app.state.CreateUserForTest("user2-context")
+
+ // Register node with user1's auth key
+ pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
+ if err != nil {
+ return "", err
+ }
+
+ regReq := tailcfg.RegisterRequest{
+ Auth: &tailcfg.RegisterResponseAuth{
+ AuthKey: pak1.Key,
+ },
+ NodeKey: nodeKey1.Public(),
+ Hostinfo: &tailcfg.Hostinfo{
+ Hostname: "context-node-user1",
+ },
+ Expiry: time.Now().Add(24 * time.Hour),
+ }
+ _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
+ if err != nil {
+ return "", err
+ }
+
+ // Wait for node to be available
+ require.EventuallyWithT(t, func(c *assert.CollectT) {
+ _, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
+ assert.True(c, found, "node should be available in NodeStore")
+ }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
+
+ // Return user2's auth key for re-authentication
+ pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil)
+ if err != nil {
+ return "", err
+ }
+ return pak2.Key, nil
+ },
+ request: func(user2AuthKey string) tailcfg.RegisterRequest {
+ return tailcfg.RegisterRequest{
+ Auth: &tailcfg.RegisterResponseAuth{
+ AuthKey: user2AuthKey,
+ },
+ NodeKey: nodeKey1.Public(),
+ Hostinfo: &tailcfg.Hostinfo{
+ Hostname: "context-node-user2",
+ },
+ Expiry: time.Now().Add(24 * time.Hour),
+ }
+ },
+ machineKey: func() key.MachinePublic { return machineKey1.Public() },
+ wantAuth: true,
+ validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
+ assert.True(t, resp.MachineAuthorized)
+ assert.False(t, resp.NodeKeyExpired)
+
+ // Verify NEW node was created for user2
+ node2, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2))
+ require.True(t, found, "new node should exist for user2")
+ assert.Equal(t, uint(2), node2.UserID(), "new node should belong to user2")
+
+ user := node2.User()
+ assert.Equal(t, "user2-context", user.Username(), "new node should show user2 username")
+
+ // Verify original node still exists for user1
+ node1, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1))
+ require.True(t, found, "original node should still exist for user1")
+ assert.Equal(t, uint(1), node1.UserID(), "original node should still belong to user1")
+
+ // Verify they are 
different nodes (different IDs) + assert.NotEqual(t, node1.ID(), node2.ID(), "should be different node IDs") + }, + }, + // TEST: Re-authentication with different user via interactive flow creates new node + // WHAT: Tests new node creation when re-authenticating interactively with a different user + // INPUT: Node registered with user1, re-authenticates interactively as user2 (same machine key, same node key) + // EXPECTED: New node is created for user2, user1's original node remains (no transfer) + // WHY: Same physical machine can have separate node identities per user + { + name: "interactive_reauth_existing_node_different_user_creates_new_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // Create user1 and register a node with auth key + user1 := app.state.CreateUserForTest("interactive-user-1") + pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register node with user1's auth key first + initialReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "transfer-node-user1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{}, // Empty auth triggers interactive flow + NodeKey: nodeKey1.Public(), // Same node key as original registration + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "transfer-node-user2", // Different hostname + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, // Same machine key + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // User1's original node should STILL exist (not transferred) + node1, found1 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1)) + require.True(t, found1, "user1's original node should still exist") + assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1") + assert.Equal(t, nodeKey1.Public(), node1.NodeKey(), "user1's node should have original node key") + + // User2 should have a NEW node created + node2, found2 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2)) + require.True(t, found2, "user2 should have new node created") + assert.Equal(t, uint(2), node2.UserID(), "user2's node should belong to user2") + + user := node2.User() + assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should show correct username") + + // Both nodes should have the same machine key but different IDs + assert.NotEqual(t, node1.ID(), node2.ID(), "should be different nodes (different IDs)") + assert.Equal(t, 
machineKey1.Public(), node2.MachineKey(), "user2's node should have same machine key") + }, + }, + // TEST: Followup request after registration cache expiry + // WHAT: Tests that expired followup requests get a new AuthURL instead of error + // INPUT: Followup request for registration ID that has expired/been evicted from cache + // EXPECTED: Returns new AuthURL (not error) so client can retry authentication + // WHY: Validates new reqToNewRegisterResponse functionality - prevents client getting stuck + { + name: "followup_request_after_cache_expiry", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // Generate a registration ID that doesn't exist in cache + // This simulates an expired/missing cache entry + regID, err := types.NewRegistrationID() + if err != nil { + return "", err + } + // Don't add it to cache - it's already expired/missing + return regID.String(), nil + }, + request: func(regID string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: "http://localhost:8080/register/" + regID, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "expired-cache-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: false, // Should not be authorized yet - needs to use new AuthURL + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Should get a new AuthURL, not an error + assert.NotEmpty(t, resp.AuthURL, "should receive new AuthURL when registration expired") + assert.Contains(t, resp.AuthURL, "/register/", "AuthURL should contain registration path") + assert.False(t, resp.MachineAuthorized, "machine should not be authorized yet") + + // Verify the response contains a valid registration URL + authURL, err := url.Parse(resp.AuthURL) + assert.NoError(t, err, "AuthURL should be a valid URL") + assert.True(t, strings.HasPrefix(authURL.Path, "/register/"), "AuthURL path should start with /register/") + + // Extract and validate the new registration ID exists in cache + newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") + newRegID, err := types.RegistrationIDFromString(newRegIDStr) + assert.NoError(t, err, "should be able to parse new registration ID") + + // Verify new registration entry exists in cache + _, found := app.state.GetRegistrationCacheEntry(newRegID) + assert.True(t, found, "new registration should exist in cache") + }, + }, + // TEST: Logout with expiry exactly at current time + // WHAT: Tests logout when expiry is set to exact current time (boundary case) + // INPUT: Existing node sends request with expiry=time.Now() (not past, not future) + // EXPECTED: Node is logged out (treated as expired) + // WHY: Edge case: current time should be treated as expired + { + name: "logout_with_exactly_now_expiry", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("exact-now-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node first + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "exact-now-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, 
func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now(), // Exactly now (edge case between past and future) + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + wantExpired: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.True(t, resp.NodeKeyExpired) + + // Node should be marked as expired but still exist + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.True(t, node.IsExpired()) + }, + }, + // TEST: Interactive workflow timeout cleans up cache + // WHAT: Tests that timed-out interactive registrations clean up cache entries + // INPUT: Interactive registration that times out without completion + // EXPECTED: Cache entry should be cleaned up (behavior depends on implementation) + // WHY: Prevents cache bloat from abandoned registrations + { + name: "interactive_workflow_timeout_cleanup", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-timeout-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey2.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + // NOTE: No auth_completion step - simulates timeout scenario + }, + validateRegistrationCache: true, // should be cleaned up eventually + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Verify AuthURL was generated but registration not completed + assert.Contains(t, resp.AuthURL, "/register/") + assert.False(t, resp.MachineAuthorized) + }, + }, + + // === COMPREHENSIVE INTERACTIVE WORKFLOW EDGE CASES === + // TEST: Interactive workflow with existing node from different user creates new node + // WHAT: Tests new node creation when re-authenticating interactively with different user + // INPUT: Node already registered with user1, interactive auth with user2 (same machine key, different node key) + // EXPECTED: New node is created for user2, user1's original node remains (no transfer) + // WHY: Same physical machine can have separate node identities per user + { + name: "interactive_workflow_with_existing_node_different_user_creates_new_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // First create a node under user1 + user1 := app.state.CreateUserForTest("existing-user-1") + pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node with user1 first + initialReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "existing-node-user1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegister(context.Background(), initialReq, 
machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{}, // Empty auth triggers interactive flow + NodeKey: nodeKey2.Public(), // Different node key for different user + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "existing-node-user2", // Different hostname + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // User1's original node with nodeKey1 should STILL exist + node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found1, "user1's original node with nodeKey1 should still exist") + assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1") + assert.Equal(t, uint64(1), node1.ID().Uint64(), "user1's node should be ID=1") + + // User2 should have a NEW node with nodeKey2 + node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found2, "user2 should have new node with nodeKey2") + + assert.Equal(t, "existing-node-user2", node2.Hostname(), "hostname should be from new registration") + user := node2.User() + assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2") + assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "machine key should be the same") + + // Verify it's a NEW node, not transferred + assert.NotEqual(t, uint64(1), node2.ID().Uint64(), "should be a NEW node (different ID)") + }, + }, + // TEST: Interactive workflow with malformed followup URL + // WHAT: Tests that malformed followup URLs in interactive flow are rejected + // INPUT: Interactive registration with invalid followup URL format + // EXPECTED: Request fails with error (invalid URL) + // WHY: Validates followup URLs to prevent errors + { + name: "interactive_workflow_malformed_followup_url", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "malformed-followup-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + }, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Test malformed followup URLs after getting initial AuthURL + authURL := resp.AuthURL + assert.Contains(t, authURL, "/register/") + + // Test various malformed followup URLs - use completely invalid IDs to avoid blocking + malformedURLs := []string{ + "invalid-url", + "/register/", + 
"/register/invalid-id-that-does-not-exist", + "/register/00000000-0000-0000-0000-000000000000", + "http://malicious-site.com/register/invalid-id", + } + + for _, malformedURL := range malformedURLs { + followupReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Followup: malformedURL, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "malformed-followup-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + // These should all fail gracefully + _, err := app.handleRegister(context.Background(), followupReq, machineKey1.Public()) + assert.Error(t, err, "malformed followup URL should be rejected: %s", malformedURL) + } + }, + }, + // TEST: Concurrent interactive workflow registrations + // WHAT: Tests multiple simultaneous interactive registrations + // INPUT: Two nodes initiate interactive registration concurrently + // EXPECTED: Both registrations succeed independently + // WHY: System should handle concurrent interactive flows without conflicts + { + name: "interactive_workflow_concurrent_registrations", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "concurrent-registration-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // This test validates concurrent interactive registration attempts + assert.Contains(t, resp.AuthURL, "/register/") + + // Start multiple concurrent followup requests + authURL := resp.AuthURL + numConcurrent := 3 + results := make(chan error, numConcurrent) + + for i := range numConcurrent { + go func(index int) { + followupReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Followup: authURL, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: fmt.Sprintf("concurrent-node-%d", index), + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err := app.handleRegister(context.Background(), followupReq, machineKey1.Public()) + results <- err + }(i) + } + + // All should wait since no auth completion happened + // After a short delay, they should timeout or be waiting + time.Sleep(100 * time.Millisecond) + + // Now complete the authentication to signal one of them + registrationID, err := extractRegistrationIDFromAuthURL(authURL) + require.NoError(t, err) + + user := app.state.CreateUserForTest("concurrent-test-user") + _, _, err = app.state.HandleNodeFromAuthPath( + registrationID, + types.UserID(user.ID), + nil, + "concurrent-test-method", + ) + require.NoError(t, err) + + // Collect results - at least one should succeed + successCount := 0 + for range numConcurrent { + select { + case err := <-results: + if err == nil { + successCount++ + } + case <-time.After(2 * time.Second): + // Some may timeout, which is expected + } + } + + // At least one concurrent request should have succeeded + assert.GreaterOrEqual(t, successCount, 1, "at least one concurrent registration should succeed") + }, + }, + // TEST: Interactive workflow with node key rotation attempt + // WHAT: Tests interactive registration with different node key (appears as rotation) + // INPUT: Node registered with nodeKey1, then interactive registration with nodeKey2 + // EXPECTED: Creates new node for different user (not true rotation) + // WHY: Interactive flow creates new nodes with new users; doesn't rotate existing nodes + { + 
name: "interactive_workflow_node_key_rotation", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // Register initial node + user := app.state.CreateUserForTest("rotation-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + if err != nil { + return "", err + } + + initialReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "rotation-node-initial", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey2.Public(), // Different node key (rotation scenario) + OldNodeKey: nodeKey1.Public(), // Previous node key + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "rotation-node-updated", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // User1's original node with nodeKey1 should STILL exist + oldNode, foundOld := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, foundOld, "user1's original node with nodeKey1 should still exist") + assert.Equal(t, uint(1), oldNode.UserID(), "user1's node should still belong to user1") + assert.Equal(t, uint64(1), oldNode.ID().Uint64(), "user1's node should be ID=1") + + // User2 should have a NEW node with nodeKey2 + newNode, found := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found, "user2 should have new node with nodeKey2") + assert.Equal(t, "rotation-node-updated", newNode.Hostname()) + assert.Equal(t, machineKey1.Public(), newNode.MachineKey()) + + user := newNode.User() + assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2") + + // Verify it's a NEW node, not transferred + assert.NotEqual(t, uint64(1), newNode.ID().Uint64(), "should be a NEW node (different ID)") + }, + }, + // TEST: Interactive workflow with nil hostinfo + // WHAT: Tests interactive registration when request has nil hostinfo + // INPUT: Interactive registration request with Hostinfo=nil + // EXPECTED: Node registers successfully with generated default hostname + // WHY: Defensive code handles nil hostinfo in interactive flow + { + name: "interactive_workflow_with_nil_hostinfo", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: nil, // Nil hostinfo should be handled gracefully + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + 
requiresInteractiveFlow: true,
+ interactiveSteps: []interactiveStep{
+ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
+ {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false},
+ },
+ validateCompleteResponse: true,
+ validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
+ // Should handle nil hostinfo gracefully
+ node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
+ assert.True(t, found, "node should be registered despite nil hostinfo")
+ if found {
+ // Should have some default hostname or handle nil gracefully
+ hostname := node.Hostname()
+ assert.NotEmpty(t, hostname, "should have some hostname even with nil hostinfo")
+ }
+ },
+ },
+ // TEST: Registration cache behavior on authentication error
+ // WHAT: Tests that a failed auth completion does not discard the pending registration
+ // INPUT: Interactive registration whose auth completion fails (non-existent user ID)
+ // EXPECTED: Auth completion returns an error and the cache entry remains
+ // WHY: Keeping the pending registration after a failed completion lets the client retry
+ {
+ name: "interactive_workflow_registration_cache_cleanup_on_error",
+ setupFunc: func(t *testing.T, app *Headscale) (string, error) {
+ return "", nil
+ },
+ request: func(_ string) tailcfg.RegisterRequest {
+ return tailcfg.RegisterRequest{
+ NodeKey: nodeKey1.Public(),
+ Hostinfo: &tailcfg.Hostinfo{
+ Hostname: "cache-cleanup-test-node",
+ },
+ Expiry: time.Now().Add(24 * time.Hour),
+ }
+ },
+ machineKey: func() key.MachinePublic { return machineKey1.Public() },
+ validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
+ // Get initial AuthURL and extract registration ID
+ authURL := resp.AuthURL
+ assert.Contains(t, authURL, "/register/")
+
+ registrationID, err := extractRegistrationIDFromAuthURL(authURL)
+ require.NoError(t, err)
+
+ // Verify cache entry exists
+ cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
+ assert.True(t, found, "registration cache entry should exist initially")
+ assert.NotNil(t, cacheEntry)
+
+ // Try to complete authentication with invalid user ID (should cause error)
+ invalidUserID := types.UserID(99999) // Non-existent user
+ _, _, err = app.state.HandleNodeFromAuthPath(
+ registrationID,
+ invalidUserID,
+ nil,
+ "error-test-method",
+ )
+ assert.Error(t, err, "should fail with invalid user ID")
+
+ // Cache entry should still exist after auth error (for retry scenarios)
+ _, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
+ assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
+ },
+ },
+ // TEST: Interactive workflow with multiple registration attempts for same node
+ // WHAT: Tests that multiple interactive registrations can be created for same node
+ // INPUT: Start two interactive registrations, verify both cache entries exist
+ // EXPECTED: Both registrations get different IDs and can coexist
+ // WHY: Validates that multiple pending registrations don't interfere with each other
+ {
+ name: "interactive_workflow_multiple_steps_same_node",
+ setupFunc: func(t *testing.T, app *Headscale) (string, error) {
+ return "", nil
+ },
+
request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "multi-step-node", + OS: "linux", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Test multiple interactive registration attempts for the same node can coexist + authURL1 := resp.AuthURL + assert.Contains(t, authURL1, "/register/") + + // Start a second interactive registration for the same node + secondReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "multi-step-node-updated", + OS: "linux-updated", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) + require.NoError(t, err) + authURL2 := resp2.AuthURL + assert.Contains(t, authURL2, "/register/") + + // Both should have different registration IDs + regID1, err1 := extractRegistrationIDFromAuthURL(authURL1) + regID2, err2 := extractRegistrationIDFromAuthURL(authURL2) + require.NoError(t, err1) + require.NoError(t, err2) + assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") + + // Both cache entries should exist simultaneously + _, found1 := app.state.GetRegistrationCacheEntry(regID1) + _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first registration cache entry should exist") + assert.True(t, found2, "second registration cache entry should exist") + + // This validates that multiple pending registrations can coexist + // without interfering with each other + }, + }, + // TEST: Complete one of multiple pending registrations + // WHAT: Tests completing the second of two pending registrations for same node + // INPUT: Create two pending registrations, complete the second one + // EXPECTED: Second registration completes successfully, node is created + // WHY: Validates that you can complete any pending registration, not just the first + { + name: "interactive_workflow_complete_second_of_multiple_pending", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "pending-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + authURL1 := resp.AuthURL + regID1, err := extractRegistrationIDFromAuthURL(authURL1) + require.NoError(t, err) + + // Start a second interactive registration for the same node + secondReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "pending-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) + require.NoError(t, err) + authURL2 := resp2.AuthURL + regID2, err := extractRegistrationIDFromAuthURL(authURL2) + require.NoError(t, err) + + // Verify both exist + _, found1 := app.state.GetRegistrationCacheEntry(regID1) + _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first cache entry should exist") + assert.True(t, found2, "second cache entry should 
exist") + + // Complete the SECOND registration (not the first) + user := app.state.CreateUserForTest("second-registration-user") + + // Start followup request in goroutine (it will wait for auth completion) + responseChan := make(chan *tailcfg.RegisterResponse, 1) + errorChan := make(chan error, 1) + + followupReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Followup: authURL2, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "pending-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + go func() { + resp, err := app.handleRegister(context.Background(), followupReq, machineKey1.Public()) + if err != nil { + errorChan <- err + return + } + responseChan <- resp + }() + + // Give followup time to start waiting + time.Sleep(50 * time.Millisecond) + + // Complete authentication for second registration + _, _, err = app.state.HandleNodeFromAuthPath( + regID2, + types.UserID(user.ID), + nil, + "second-registration-method", + ) + require.NoError(t, err) + + // Wait for followup to complete + select { + case err := <-errorChan: + t.Fatalf("followup request failed: %v", err) + case finalResp := <-responseChan: + require.NotNil(t, finalResp) + assert.True(t, finalResp.MachineAuthorized, "machine should be authorized") + case <-time.After(2 * time.Second): + t.Fatal("followup request timed out") + } + + // Verify the node was created with the second registration's data + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found, "node should be registered") + if found { + assert.Equal(t, "pending-node-2", node.Hostname()) + assert.Equal(t, "second-registration-user", node.User().Name) + } + + // First registration should still be in cache (not completed) + _, stillFound := app.state.GetRegistrationCacheEntry(regID1) + assert.True(t, stillFound, "first registration should still be pending") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test app + app := createTestApp(t) + + // Run setup function + dynamicValue, err := tt.setupFunc(t, app) + require.NoError(t, err, "setup should not fail") + + // Check if this test requires interactive workflow + if tt.requiresInteractiveFlow { + runInteractiveWorkflowTest(t, tt, app, dynamicValue) + return + } + + // Build request + req := tt.request(dynamicValue) + machineKey := tt.machineKey() + + // Set up context with timeout for followup tests + ctx := context.Background() + if req.Followup != "" { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + } + + // Debug: check node availability before test execution + if req.Auth == nil { + if node, found := app.state.GetNodeByNodeKey(req.NodeKey); found { + t.Logf("Node found before handleRegister: hostname=%s, expired=%t", node.Hostname(), node.IsExpired()) + } else { + t.Logf("Node NOT found before handleRegister for key %s", req.NodeKey.ShortString()) + } + } + + // Execute the test + resp, err := app.handleRegister(ctx, req, machineKey) + + // Validate error expectations + if tt.wantError { + assert.Error(t, err, "expected error but got none") + return + } + + require.NoError(t, err, "unexpected error: %v", err) + require.NotNil(t, resp, "response should not be nil") + + // Validate basic response properties + if tt.wantAuth { + assert.True(t, resp.MachineAuthorized, "machine should be authorized") + } else { + assert.False(t, resp.MachineAuthorized, "machine should not be authorized") + } + + if tt.wantAuthURL { + assert.NotEmpty(t, 
resp.AuthURL, "should have AuthURL") + assert.Contains(t, resp.AuthURL, "register/", "AuthURL should contain registration path") + } + + if tt.wantExpired { + assert.True(t, resp.NodeKeyExpired, "node key should be expired") + } else { + assert.False(t, resp.NodeKeyExpired, "node key should not be expired") + } + + // Run custom validation if provided + if tt.validate != nil { + tt.validate(t, resp, app) + } + }) + } +} + +// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow +func runInteractiveWorkflowTest(t *testing.T, tt struct { + name string + setupFunc func(*testing.T, *Headscale) (string, error) + request func(dynamicValue string) tailcfg.RegisterRequest + machineKey func() key.MachinePublic + wantAuth bool + wantError bool + wantAuthURL bool + wantExpired bool + validate func(*testing.T, *tailcfg.RegisterResponse, *Headscale) + requiresInteractiveFlow bool + interactiveSteps []interactiveStep + validateRegistrationCache bool + expectedAuthURLPattern string + simulateAuthCompletion bool + validateCompleteResponse bool +}, app *Headscale, dynamicValue string, +) { + // Build initial request + req := tt.request(dynamicValue) + machineKey := tt.machineKey() + ctx := context.Background() + + // Execute interactive workflow steps + var ( + initialResp *tailcfg.RegisterResponse + authURL string + registrationID types.RegistrationID + finalResp *tailcfg.RegisterResponse + err error + ) + + // Execute the steps in the correct sequence for interactive workflow + for i, step := range tt.interactiveSteps { + t.Logf("Executing interactive step %d: %s", i+1, step.stepType) + + switch step.stepType { + case stepTypeInitialRequest: + // Step 1: Initial request should get AuthURL back + initialResp, err = app.handleRegister(ctx, req, machineKey) + require.NoError(t, err, "initial request should not fail") + require.NotNil(t, initialResp, "initial response should not be nil") + + if step.expectAuthURL { + require.NotEmpty(t, initialResp.AuthURL, "should have AuthURL") + require.Contains(t, initialResp.AuthURL, "/register/", "AuthURL should contain registration path") + authURL = initialResp.AuthURL + + // Extract registration ID from AuthURL + registrationID, err = extractRegistrationIDFromAuthURL(authURL) + require.NoError(t, err, "should be able to extract registration ID from AuthURL") + } + + if step.expectCacheEntry { + // Verify registration cache entry was created + cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + require.True(t, found, "registration cache entry should exist") + require.NotNil(t, cacheEntry, "cache entry should not be nil") + require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key") + } + + case stepTypeAuthCompletion: + // Step 2: Start followup request that will wait, then complete authentication + if step.callAuthPath { + require.NotEmpty(t, registrationID, "registration ID should be available from previous step") + + // Prepare followup request + followupReq := tt.request(dynamicValue) + followupReq.Followup = authURL + + // Start the followup request in a goroutine - it will wait for channel signal + responseChan := make(chan *tailcfg.RegisterResponse, 1) + errorChan := make(chan error, 1) + + go func() { + resp, err := app.handleRegister(context.Background(), followupReq, machineKey) + if err != nil { + errorChan <- err + return + } + responseChan <- resp + }() + + // Give the followup request time to start waiting + time.Sleep(50 * time.Millisecond) + + // Now complete 
the authentication - this will signal the waiting followup request + user := app.state.CreateUserForTest("interactive-test-user") + _, _, err = app.state.HandleNodeFromAuthPath( + registrationID, + types.UserID(user.ID), + nil, // no custom expiry + "test-method", + ) + require.NoError(t, err, "HandleNodeFromAuthPath should succeed") + + // Wait for the followup request to complete + select { + case err := <-errorChan: + require.NoError(t, err, "followup request should not fail") + case finalResp = <-responseChan: + require.NotNil(t, finalResp, "final response should not be nil") + // Verify machine is now authorized + require.True(t, finalResp.MachineAuthorized, "machine should be authorized after followup") + case <-time.After(5 * time.Second): + t.Fatal("followup request timed out waiting for authentication completion") + } + } + + case stepTypeFollowupRequest: + // This step is deprecated - followup is now handled within auth_completion step + t.Logf("followup_request step is deprecated - use expectCacheEntry in auth_completion instead") + + default: + t.Fatalf("unknown interactive step type: %s", step.stepType) + } + + // Check cache cleanup expectation for this step + if step.expectCacheEntry == false && registrationID != "" { + // Verify cache entry was cleaned up + _, found := app.state.GetRegistrationCacheEntry(registrationID) + require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType) + } + } + + // Validate final response if requested + if tt.validateCompleteResponse && finalResp != nil { + validateCompleteRegistrationResponse(t, finalResp, req) + } + + // Run custom validation if provided + if tt.validate != nil { + responseToValidate := finalResp + if responseToValidate == nil { + responseToValidate = initialResp + } + tt.validate(t, responseToValidate, app) + } +} + +// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL +func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { + // AuthURL format: "http://localhost/register/abc123" + const registerPrefix = "/register/" + idx := strings.LastIndex(authURL, registerPrefix) + if idx == -1 { + return "", fmt.Errorf("invalid AuthURL format: %s", authURL) + } + + idStr := authURL[idx+len(registerPrefix):] + return types.RegistrationIDFromString(idStr) +} + +// validateCompleteRegistrationResponse performs comprehensive validation of a registration response +func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) { + // Basic response validation + require.NotNil(t, resp, "response should not be nil") + require.True(t, resp.MachineAuthorized, "machine should be authorized") + require.False(t, resp.NodeKeyExpired, "node key should not be expired") + require.NotEmpty(t, resp.User.DisplayName, "user should have display name") + + // Additional validation can be added here as needed + // Note: NodeKey field may not be present in all response types + + // Additional validation can be added here as needed +} + +// Simple test to validate basic node creation and lookup +func TestNodeStoreLookup(t *testing.T) { + app := createTestApp(t) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + user := app.state.CreateUserForTest("test-user") + pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + require.NoError(t, err) + + // Register a node + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: 
pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.NotNil(t, resp) + require.True(t, resp.MachineAuthorized) + + t.Logf("Registered node successfully: %+v", resp) + + // Wait for node to be available in NodeStore + var node types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { + var found bool + node, found = app.state.GetNodeByNodeKey(nodeKey.Public()) + assert.True(c, found, "Node should be found in NodeStore") + }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore") + + require.Equal(t, "test-node", node.Hostname()) + + t.Logf("Found node: hostname=%s, id=%d", node.Hostname(), node.ID().Uint64()) +} + +// TestPreAuthKeyLogoutAndReloginDifferentUser tests the scenario where: +// 1. Multiple nodes register with different users using pre-auth keys +// 2. All nodes logout +// 3. All nodes re-login using a different user's pre-auth key +// EXPECTED BEHAVIOR: Should create NEW nodes for the new user, leaving old nodes with the old user. +// This matches the integration test expectation and web flow behavior. +func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { + app := createTestApp(t) + + // Create two users + user1 := app.state.CreateUserForTest("user1") + user2 := app.state.CreateUserForTest("user2") + + // Create pre-auth keys for both users + pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + require.NoError(t, err) + pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil) + require.NoError(t, err) + + // Create machine and node keys for 4 nodes (2 per user) + type nodeInfo struct { + machineKey key.MachinePrivate + nodeKey key.NodePrivate + hostname string + nodeID types.NodeID + } + + nodes := []nodeInfo{ + {machineKey: key.NewMachine(), nodeKey: key.NewNode(), hostname: "user1-node1"}, + {machineKey: key.NewMachine(), nodeKey: key.NewNode(), hostname: "user1-node2"}, + {machineKey: key.NewMachine(), nodeKey: key.NewNode(), hostname: "user2-node1"}, + {machineKey: key.NewMachine(), nodeKey: key.NewNode(), hostname: "user2-node2"}, + } + + // Register nodes: first 2 to user1, last 2 to user2 + for i, node := range nodes { + authKey := pak1.Key + if i >= 2 { + authKey = pak2.Key + } + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: node.nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: node.hostname, + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, node.machineKey.Public()) + require.NoError(t, err) + require.NotNil(t, resp) + require.True(t, resp.MachineAuthorized) + + // Get the node ID + var registeredNode types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { + var found bool + registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public()) + assert.True(c, found, "Node should be found in NodeStore") + }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available") + + nodes[i].nodeID = registeredNode.ID() + t.Logf("Registered node %s with ID %d to user%d", node.hostname, registeredNode.ID().Uint64(), i/2+1) + } + + // Verify initial state: user1 has 2 nodes, user2 has 2 nodes + user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID)) + user2Nodes := 
app.state.ListNodesByUser(types.UserID(user2.ID)) + require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially") + require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially") + + t.Logf("Initial state verified: user1=%d nodes, user2=%d nodes", user1Nodes.Len(), user2Nodes.Len()) + + // Simulate logout for all nodes + for _, node := range nodes { + logoutReq := tailcfg.RegisterRequest{ + Auth: nil, // nil Auth indicates logout + NodeKey: node.nodeKey.Public(), + } + + resp, err := app.handleRegister(context.Background(), logoutReq, node.machineKey.Public()) + require.NoError(t, err) + t.Logf("Logout response for %s: %+v", node.hostname, resp) + } + + t.Logf("All nodes logged out") + + // Create a new pre-auth key for user1 (reusable for all nodes) + newPak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + require.NoError(t, err) + + // Re-login all nodes using user1's new pre-auth key + for i, node := range nodes { + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: newPak1.Key, + }, + NodeKey: node.nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: node.hostname, + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, node.machineKey.Public()) + require.NoError(t, err) + require.NotNil(t, resp) + require.True(t, resp.MachineAuthorized) + + t.Logf("Re-registered node %s (originally user%d) with user1's pre-auth key", node.hostname, i/2+1) + } + + // Verify final state after re-login + // EXPECTED: New nodes created for user1, old nodes remain with original users + user1NodesAfter := app.state.ListNodesByUser(types.UserID(user1.ID)) + user2NodesAfter := app.state.ListNodesByUser(types.UserID(user2.ID)) + + t.Logf("Final state: user1=%d nodes, user2=%d nodes", user1NodesAfter.Len(), user2NodesAfter.Len()) + + // CORRECT BEHAVIOR: When re-authenticating with a DIFFERENT user's pre-auth key, + // new nodes should be created (not transferred). This matches: + // 1. The integration test expectation + // 2. The web flow behavior (creates new nodes) + // 3. 
The principle that each user owns distinct node entries + require.Equal(t, 4, user1NodesAfter.Len(), "user1 should have 4 nodes total (2 original + 2 new from user2's machines)") + require.Equal(t, 2, user2NodesAfter.Len(), "user2 should still have 2 nodes (old nodes from original registration)") + + // Verify original nodes still exist with original users + for i := 0; i < 2; i++ { + node := nodes[i] + // User1's original nodes should still be owned by user1 + registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID)) + require.True(t, found, "User1's original node %s should still exist", node.hostname) + require.Equal(t, user1.ID, registeredNode.UserID(), "Node %s should still belong to user1", node.hostname) + t.Logf("✓ User1's original node %s (ID=%d) still owned by user1", node.hostname, registeredNode.ID().Uint64()) + } + + for i := 2; i < 4; i++ { + node := nodes[i] + // User2's original nodes should still be owned by user2 + registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user2.ID)) + require.True(t, found, "User2's original node %s should still exist", node.hostname) + require.Equal(t, user2.ID, registeredNode.UserID(), "Node %s should still belong to user2", node.hostname) + t.Logf("✓ User2's original node %s (ID=%d) still owned by user2", node.hostname, registeredNode.ID().Uint64()) + } + + // Verify new nodes were created for user1 with the same machine keys + t.Logf("Verifying new nodes created for user1 from user2's machine keys...") + for i := 2; i < 4; i++ { + node := nodes[i] + // Should be able to find a node with user1 and this machine key (the new one) + newNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID)) + require.True(t, found, "Should have created new node for user1 with machine key from %s", node.hostname) + require.Equal(t, user1.ID, newNode.UserID(), "New node should belong to user1") + t.Logf("✓ New node created for user1 with machine key from %s (ID=%d)", node.hostname, newNode.ID().Uint64()) + } +} + +// TestWebFlowReauthDifferentUser validates CLI registration behavior when switching users. +// This test replicates the TestAuthWebFlowLogoutAndReloginNewUser integration test scenario. +// +// IMPORTANT: CLI registration creates NEW nodes (different from interactive flow which transfers). +// +// Scenario: +// 1. Node registers with user1 via pre-auth key +// 2. Node logs out (expires) +// 3. 
Admin runs: headscale nodes register --user user2 --key +// +// Expected behavior: +// - User1's original node should STILL EXIST (expired) +// - User2 should get a NEW node created (NOT transfer) +// - Both nodes share the same machine key (same physical device) +func TestWebFlowReauthDifferentUser(t *testing.T) { + machineKey := key.NewMachine() + nodeKey1 := key.NewNode() + nodeKey2 := key.NewNode() // Node key rotates on re-auth + + app := createTestApp(t) + + // Step 1: Register node for user1 via pre-auth key (simulating initial web flow registration) + user1 := app.state.CreateUserForTest("user1") + pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + require.NoError(t, err) + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-machine", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized, "Should be authorized via pre-auth key") + + // Verify node exists for user1 + user1Node, found := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) + require.True(t, found, "Node should exist for user1") + require.Equal(t, user1.ID, user1Node.UserID(), "Node should belong to user1") + user1NodeID := user1Node.ID() + t.Logf("✓ User1 node created with ID: %d", user1NodeID) + + // Step 2: Simulate logout by expiring the node + pastTime := time.Now().Add(-1 * time.Hour) + logoutReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Expiry: pastTime, // Expired = logout + } + _, err = app.handleRegister(context.Background(), logoutReq, machineKey.Public()) + require.NoError(t, err) + + // Verify node is expired + user1Node, found = app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) + require.True(t, found, "Node should still exist after logout") + require.True(t, user1Node.IsExpired(), "Node should be expired after logout") + t.Logf("✓ User1 node expired (logged out)") + + // Step 3: Start interactive re-authentication (simulates "tailscale up") + user2 := app.state.CreateUserForTest("user2") + + reAuthReq := tailcfg.RegisterRequest{ + // No Auth field - triggers interactive flow + NodeKey: nodeKey2.Public(), // New node key (rotated on re-auth) + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-machine", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + // Initial request should return AuthURL + initialResp, err := app.handleRegister(context.Background(), reAuthReq, machineKey.Public()) + require.NoError(t, err) + require.NotEmpty(t, initialResp.AuthURL, "Should receive AuthURL for interactive flow") + t.Logf("✓ Interactive flow started, AuthURL: %s", initialResp.AuthURL) + + // Extract registration ID from AuthURL + regID, err := extractRegistrationIDFromAuthURL(initialResp.AuthURL) + require.NoError(t, err, "Should extract registration ID from AuthURL") + require.NotEmpty(t, regID, "Should have valid registration ID") + + // Step 4: Admin completes authentication via CLI + // This simulates: headscale nodes register --user user2 --key + node, _, err := app.state.HandleNodeFromAuthPath( + regID, + types.UserID(user2.ID), // Register to user2, not user1! 
+ nil, // No custom expiry + "cli", // Registration method (CLI register command) + ) + require.NoError(t, err, "HandleNodeFromAuthPath should succeed") + t.Logf("✓ Admin registered node to user2 via CLI (node ID: %d)", node.ID()) + + t.Run("user1_original_node_still_exists", func(t *testing.T) { + // User1's original node should STILL exist (not transferred to user2) + user1NodeAfter, found1 := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) + assert.True(t, found1, "User1's original node should still exist (not transferred)") + + if !found1 { + t.Fatal("User1's node was transferred or deleted - this breaks the integration test!") + } + + assert.Equal(t, user1.ID, user1NodeAfter.UserID(), "User1's node should still belong to user1") + assert.Equal(t, user1NodeID, user1NodeAfter.ID(), "Should be the same node (same ID)") + assert.True(t, user1NodeAfter.IsExpired(), "User1's node should still be expired") + t.Logf("✓ User1's original node still exists (ID: %d, expired: %v)", user1NodeAfter.ID(), user1NodeAfter.IsExpired()) + }) + + t.Run("user2_has_new_node_created", func(t *testing.T) { + // User2 should have a NEW node created (not transfer from user1) + user2Node, found2 := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user2.ID)) + assert.True(t, found2, "User2 should have a new node created") + + if !found2 { + t.Fatal("User2 doesn't have a node - registration failed!") + } + + assert.Equal(t, user2.ID, user2Node.UserID(), "User2's node should belong to user2") + assert.NotEqual(t, user1NodeID, user2Node.ID(), "Should be a NEW node (different ID), not transfer!") + assert.Equal(t, machineKey.Public(), user2Node.MachineKey(), "Should have same machine key") + assert.Equal(t, nodeKey2.Public(), user2Node.NodeKey(), "Should have new node key") + assert.False(t, user2Node.IsExpired(), "User2's node should NOT be expired (active)") + t.Logf("✓ User2's new node created (ID: %d, active)", user2Node.ID()) + }) + + t.Run("returned_node_is_user2_new_node", func(t *testing.T) { + // The node returned from HandleNodeFromAuthPath should be user2's NEW node + assert.Equal(t, user2.ID, node.UserID(), "Returned node should belong to user2") + assert.NotEqual(t, user1NodeID, node.ID(), "Returned node should be NEW, not transferred from user1") + t.Logf("✓ HandleNodeFromAuthPath returned user2's new node (ID: %d)", node.ID()) + }) + + t.Run("both_nodes_share_machine_key", func(t *testing.T) { + // Both nodes should have the same machine key (same physical device) + user1NodeFinal, found1 := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) + user2NodeFinal, found2 := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user2.ID)) + + require.True(t, found1, "User1 node should exist") + require.True(t, found2, "User2 node should exist") + + assert.Equal(t, machineKey.Public(), user1NodeFinal.MachineKey(), "User1 node should have correct machine key") + assert.Equal(t, machineKey.Public(), user2NodeFinal.MachineKey(), "User2 node should have same machine key") + t.Logf("✓ Both nodes share machine key: %s", machineKey.Public().ShortString()) + }) + + t.Run("total_node_count", func(t *testing.T) { + // We should have exactly 2 nodes total: one for user1 (expired), one for user2 (active) + allNodesSlice := app.state.ListNodes() + assert.Equal(t, 2, allNodesSlice.Len(), "Should have exactly 2 nodes total") + + // Count nodes per user + user1Nodes := 0 + user2Nodes := 0 + for i := 0; i < allNodesSlice.Len(); i++ { + n := 
allNodesSlice.At(i) + if n.UserID() == user1.ID { + user1Nodes++ + } + if n.UserID() == user2.ID { + user2Nodes++ + } + } + + assert.Equal(t, 1, user1Nodes, "User1 should have 1 node") + assert.Equal(t, 1, user2Nodes, "User2 should have 1 node") + t.Logf("✓ Total: 2 nodes (user1: 1 expired, user2: 1 active)") + }) +} + +// Helper function to create test app +func createTestApp(t *testing.T) *Headscale { + t.Helper() + + tmpDir := t.TempDir() + + cfg := types.Config{ + ServerURL: "http://localhost:8080", + NoisePrivateKeyPath: tmpDir + "/noise_private.key", + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + OIDC: types.OIDCConfig{}, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + Tuning: types.Tuning{ + BatchChangeDelay: 100 * time.Millisecond, + BatcherWorkers: 1, + }, + } + + app, err := NewHeadscale(&cfg) + require.NoError(t, err) + + // Initialize and start the mapBatcher to handle Change() calls + app.mapBatcher = mapper.NewBatcherAndMapper(&cfg, app.state) + app.mapBatcher.Start() + + // Clean up the batcher when the test finishes + t.Cleanup(func() { + if app.mapBatcher != nil { + app.mapBatcher.Close() + } + }) + + return app +} diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 01d3c6b3..6290e065 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -741,7 +741,7 @@ func (api headscaleV1APIServer) DebugCreateNode( hostinfo := tailcfg.Hostinfo{ RoutableIPs: routes, OS: "TestOS", - Hostname: "DebugTestNode", + Hostname: request.GetName(), } registrationId, err := types.RegistrationIDFromString(request.GetKey()) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index ada9fd15..4324ffba 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -197,11 +197,12 @@ func (m *mapSession) serveLongPoll() { m.keepAliveTicker = time.NewTicker(m.keepAlive) // Process the initial MapRequest to update node state (endpoints, hostinfo, etc.) - // CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly - // synchronized. When nodes reconnect, they send their hostinfo with announced routes - // in the MapRequest. We need this data in NodeStore before Connect() sets up the - // primary routes, otherwise SubnetRoutes() returns empty and the node is removed - // from AvailableRoutes. + // This must be done BEFORE calling Connect() to ensure routes are properly synchronized. + // When nodes reconnect, they send their hostinfo with announced routes in the MapRequest. + // We need this data in NodeStore before Connect() sets up the primary routes, because + // SubnetRoutes() calculates the intersection of announced and approved routes. If we + // call Connect() first, SubnetRoutes() returns empty (no announced routes yet), causing + // the node to be incorrectly removed from AvailableRoutes. mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) if err != nil { m.errf(err, "failed to update node from initial MapRequest") diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 7c60128f..03d6854f 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -60,9 +60,6 @@ type DebugStringInfo struct { // DebugOverview returns a comprehensive overview of the current state for debugging. 
func (s *State) DebugOverview() string { - s.mu.RLock() - defer s.mu.RUnlock() - allNodes := s.nodeStore.ListNodes() users, _ := s.ListAllUsers() @@ -270,9 +267,6 @@ func (s *State) PolicyDebugString() string { // DebugOverviewJSON returns a structured overview of the current state for debugging. func (s *State) DebugOverviewJSON() DebugOverviewInfo { - s.mu.RLock() - defer s.mu.RUnlock() - allNodes := s.nodeStore.ListNodes() users, _ := s.ListAllUsers() diff --git a/hscontrol/state/debug_test.go b/hscontrol/state/debug_test.go index ae6c340b..60d77245 100644 --- a/hscontrol/state/debug_test.go +++ b/hscontrol/state/debug_test.go @@ -33,8 +33,8 @@ func TestNodeStoreDebugString(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc) store.Start() - store.PutNode(node1) - store.PutNode(node2) + _ = store.PutNode(node1) + _ = store.PutNode(node2) return store }, diff --git a/hscontrol/state/ephemeral_test.go b/hscontrol/state/ephemeral_test.go new file mode 100644 index 00000000..e3acc9b9 --- /dev/null +++ b/hscontrol/state/ephemeral_test.go @@ -0,0 +1,460 @@ +package state + +import ( + "net/netip" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/ptr" +) + +// TestEphemeralNodeDeleteWithConcurrentUpdate tests the race condition where UpdateNode and DeleteNode +// are called concurrently and may be batched together. This reproduces the issue where ephemeral nodes +// are not properly deleted during logout because UpdateNodeFromMapRequest returns a stale node view +// after the node has been deleted from the NodeStore. +func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { + // Create a simple test node + node := createTestNode(1, 1, "test-user", "test-node") + + // Create NodeStore + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + // Put the node in the store + resultNode := store.PutNode(node) + require.True(t, resultNode.Valid(), "initial PutNode should return valid node") + + // Verify node exists + retrievedNode, found := store.GetNode(node.ID) + require.True(t, found) + require.Equal(t, node.ID, retrievedNode.ID()) + + // Test scenario: UpdateNode is called, returns a node view from the batch, + // but in the same batch a DeleteNode removes the node. + // This simulates what happens when: + // 1. UpdateNodeFromMapRequest calls UpdateNode and gets back updatedNode + // 2. At the same time, handleLogout calls DeleteNode + // 3. They get batched together: [UPDATE, DELETE] + // 4. UPDATE modifies the node, DELETE removes it + // 5. UpdateNode returns a node view based on the state AFTER both operations + // 6. 
If DELETE came after UPDATE, the returned node should be invalid + + done := make(chan bool, 2) + var updatedNode types.NodeView + var updateOk bool + + // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) + go func() { + updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + done <- true + }() + + // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) + go func() { + // Small delay to increase chance of batching together + time.Sleep(1 * time.Millisecond) + store.DeleteNode(node.ID) + done <- true + }() + + // Wait for both operations + <-done + <-done + + // Give batching time to complete + time.Sleep(50 * time.Millisecond) + + // The key assertion: if UpdateNode and DeleteNode were batched together + // with DELETE after UPDATE, then UpdateNode should return an invalid node + // OR it should return a valid node but the node should no longer exist in the store + + _, found = store.GetNode(node.ID) + assert.False(t, found, "node should be deleted from NodeStore") + + // If the update happened before delete in the batch, the returned node might be invalid + if updateOk { + t.Logf("UpdateNode returned ok=true, valid=%v", updatedNode.Valid()) + // This is the bug scenario - UpdateNode thinks it succeeded but node is gone + if updatedNode.Valid() { + t.Logf("WARNING: UpdateNode returned valid node but node was deleted - this indicates the race condition bug") + } + } else { + t.Logf("UpdateNode correctly returned ok=false (node deleted in same batch)") + } +} + +// TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch specifically tests that when +// UpdateNode and DeleteNode are in the same batch with DELETE after UPDATE, +// the UpdateNode should return an invalid node view. +func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) { + node := createTestNode(2, 1, "test-user", "test-node-2") + + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + // Put node in store + _ = store.PutNode(node) + + // Simulate the exact sequence: UpdateNode gets queued, then DeleteNode gets queued, + // they batch together, and we check what UpdateNode returns + + resultChan := make(chan struct { + node types.NodeView + ok bool + }) + + // Start UpdateNode - it will block until batch is applied + go func() { + node, ok := store.UpdateNode(node.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + resultChan <- struct { + node types.NodeView + ok bool + }{node, ok} + }() + + // Give UpdateNode a moment to queue its work + time.Sleep(5 * time.Millisecond) + + // Now queue DeleteNode - should batch with the UPDATE + store.DeleteNode(node.ID) + + // Get the result from UpdateNode + result := <-resultChan + + // Wait for batch to complete + time.Sleep(50 * time.Millisecond) + + // Node should be deleted + _, found := store.GetNode(node.ID) + assert.False(t, found, "node should be deleted") + + // The critical check: what did UpdateNode return? + // After the commit c6b09289988f34398eb3157e31ba092eb8721a9f, + // UpdateNode returns the node state from the batch. 
+ // If DELETE came after UPDATE in the batch, the node doesn't exist anymore, + // so UpdateNode should return (invalid, false) + t.Logf("UpdateNode returned: ok=%v, valid=%v", result.ok, result.node.Valid()) + + // This is the expected behavior - if node was deleted in same batch, + // UpdateNode should return invalid node + if result.ok && result.node.Valid() { + t.Error("BUG: UpdateNode returned valid node even though it was deleted in same batch") + } +} + +// TestPersistNodeToDBPreventsRaceCondition tests that persistNodeToDB correctly handles +// the race condition where a node is deleted after UpdateNode returns but before +// persistNodeToDB is called. This reproduces the ephemeral node deletion bug. +func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) { + node := createTestNode(3, 1, "test-user", "test-node-3") + + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + // Put node in store + _ = store.PutNode(node) + + // Simulate UpdateNode being called + updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + require.True(t, ok, "UpdateNode should succeed") + require.True(t, updatedNode.Valid(), "UpdateNode should return valid node") + + // Now delete the node (simulating ephemeral logout happening concurrently) + store.DeleteNode(node.ID) + + // Wait for deletion to complete + time.Sleep(50 * time.Millisecond) + + // Verify node is deleted + _, found := store.GetNode(node.ID) + require.False(t, found, "node should be deleted") + + // Now try to use the updatedNode from before the deletion + // In the old code, this would re-insert the node into the database + // With our fix, GetNode check in persistNodeToDB should prevent this + + // Simulate what persistNodeToDB does - check if node still exists + _, exists := store.GetNode(updatedNode.ID()) + if !exists { + t.Log("SUCCESS: persistNodeToDB check would prevent re-insertion of deleted node") + } else { + t.Error("BUG: Node still exists in NodeStore after deletion") + } + + // The key assertion: after deletion, attempting to persist the old updatedNode + // should fail because the node no longer exists in NodeStore + assert.False(t, exists, "persistNodeToDB should detect node was deleted and refuse to persist") +} + +// TestEphemeralNodeLogoutRaceCondition tests the specific race condition that occurs +// when an ephemeral node logs out. This reproduces the bug where: +// 1. UpdateNodeFromMapRequest calls UpdateNode and receives a node view +// 2. Concurrently, handleLogout is called for the ephemeral node and calls DeleteNode +// 3. UpdateNode and DeleteNode get batched together +// 4. If UpdateNode's result is used to call persistNodeToDB after the deletion, +// the node could be re-inserted into the database even though it was deleted +func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { + ephemeralNode := createTestNode(4, 1, "test-user", "ephemeral-node") + ephemeralNode.AuthKey = &types.PreAuthKey{ + ID: 1, + Key: "test-key", + Ephemeral: true, + } + + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + // Put ephemeral node in store + _ = store.PutNode(ephemeralNode) + + // Simulate concurrent operations: + // 1. UpdateNode (from UpdateNodeFromMapRequest during polling) + // 2. 
DeleteNode (from handleLogout when client sends logout request) + + var updatedNode types.NodeView + var updateOk bool + done := make(chan bool, 2) + + // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) + go func() { + updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + done <- true + }() + + // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) + go func() { + time.Sleep(1 * time.Millisecond) // Slight delay to batch operations + store.DeleteNode(ephemeralNode.ID) + done <- true + }() + + // Wait for both operations + <-done + <-done + + // Give batching time to complete + time.Sleep(50 * time.Millisecond) + + // Node should be deleted from store + _, found := store.GetNode(ephemeralNode.ID) + assert.False(t, found, "ephemeral node should be deleted from NodeStore") + + // Critical assertion: if UpdateNode returned before DeleteNode completed, + // the updatedNode might be valid but the node is actually deleted. + // This is the bug - UpdateNodeFromMapRequest would get a valid node, + // then try to persist it, re-inserting the deleted ephemeral node. + if updateOk && updatedNode.Valid() { + t.Log("UpdateNode returned valid node, but node is deleted - this is the race condition") + + // In the real code, this would cause persistNodeToDB to be called with updatedNode + // The fix in persistNodeToDB checks if the node still exists: + _, stillExists := store.GetNode(updatedNode.ID()) + assert.False(t, stillExists, "persistNodeToDB should check NodeStore and find node deleted") + } else if !updateOk || !updatedNode.Valid() { + t.Log("UpdateNode correctly returned invalid/not-ok result (delete happened in same batch)") + } +} + +// TestUpdateNodeFromMapRequestEphemeralLogoutSequence tests the exact sequence +// that causes ephemeral node logout failures: +// 1. Client sends MapRequest with updated endpoint info +// 2. UpdateNodeFromMapRequest starts processing, calls UpdateNode +// 3. Client sends logout request (past expiry) +// 4. handleLogout calls DeleteNode for ephemeral node +// 5. UpdateNode and DeleteNode batch together +// 6. UpdateNode returns a valid node (from before delete in batch) +// 7. persistNodeToDB is called with the stale valid node +// 8. 
Node gets re-inserted into database instead of staying deleted +func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { + ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5") + ephemeralNode.AuthKey = &types.PreAuthKey{ + ID: 2, + Key: "test-key-2", + Ephemeral: true, + } + + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + // Initial state: ephemeral node exists + _ = store.PutNode(ephemeralNode) + + // Step 1: UpdateNodeFromMapRequest calls UpdateNode + // (simulating client sending MapRequest with endpoint updates) + updateStarted := make(chan bool) + var updatedNode types.NodeView + var updateOk bool + + go func() { + updateStarted <- true + updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + endpoint := netip.MustParseAddrPort("10.0.0.1:41641") + n.Endpoints = []netip.AddrPort{endpoint} + }) + }() + + <-updateStarted + // Small delay to ensure UpdateNode is queued + time.Sleep(5 * time.Millisecond) + + // Step 2: Logout happens - handleLogout calls DeleteNode + // (simulating client sending logout with past expiry) + store.DeleteNode(ephemeralNode.ID) + + // Wait for batching to complete + time.Sleep(50 * time.Millisecond) + + // Step 3: Check results + _, nodeExists := store.GetNode(ephemeralNode.ID) + assert.False(t, nodeExists, "ephemeral node must be deleted after logout") + + // Step 4: Simulate what happens if we try to persist the updatedNode + if updateOk && updatedNode.Valid() { + // This is the problematic path - UpdateNode returned a valid node + // but the node was deleted in the same batch + t.Log("UpdateNode returned valid node even though node was deleted") + + // The fix: persistNodeToDB must check NodeStore before persisting + _, checkExists := store.GetNode(updatedNode.ID()) + if checkExists { + t.Error("BUG: Node still exists in NodeStore after deletion - should be impossible") + } else { + t.Log("SUCCESS: persistNodeToDB would detect node is deleted and refuse to persist") + } + } else { + t.Log("UpdateNode correctly indicated node was deleted (returned invalid or not-ok)") + } + + // Final assertion: node must not exist + _, finalExists := store.GetNode(ephemeralNode.ID) + assert.False(t, finalExists, "ephemeral node must remain deleted") +} + +// TestUpdateNodeDeletedInSameBatchReturnsInvalid specifically tests that when +// UpdateNode and DeleteNode are batched together with DELETE after UPDATE, +// UpdateNode returns ok=false to indicate the node was deleted. 
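+//
+// That contract is what lets callers guard persistence on the returned ok
+// value. A minimal sketch of such a guard (illustrative only; "mutate" is a
+// placeholder and the real UpdateNodeFromMapRequest call site may differ):
+//
+//	if view, ok := store.UpdateNode(id, mutate); !ok || !view.Valid() {
+//		// node was removed in the same batch; do not write it back to the DB
+//	}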
+func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { + node := createTestNode(6, 1, "test-user", "test-node-6") + + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + // Put node in store + _ = store.PutNode(node) + + // Queue UpdateNode + updateDone := make(chan struct { + node types.NodeView + ok bool + }) + + go func() { + updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + updateDone <- struct { + node types.NodeView + ok bool + }{updatedNode, ok} + }() + + // Small delay to ensure UpdateNode is queued + time.Sleep(5 * time.Millisecond) + + // Queue DeleteNode - should batch with UpdateNode + store.DeleteNode(node.ID) + + // Get UpdateNode result + result := <-updateDone + + // Wait for batch to complete + time.Sleep(50 * time.Millisecond) + + // Node should be deleted + _, exists := store.GetNode(node.ID) + assert.False(t, exists, "node should be deleted from store") + + // UpdateNode should indicate the node was deleted + // After c6b09289988f34398eb3157e31ba092eb8721a9f, when UPDATE and DELETE + // are in the same batch with DELETE after UPDATE, UpdateNode returns + // the state after the batch is applied - which means the node doesn't exist + assert.False(t, result.ok, "UpdateNode should return ok=false when node deleted in same batch") + assert.False(t, result.node.Valid(), "UpdateNode should return invalid node when node deleted in same batch") +} + +// TestPersistNodeToDBChecksNodeStoreBeforePersist verifies that persistNodeToDB +// checks if the node still exists in NodeStore before persisting to database. +// This prevents the race condition where: +// 1. UpdateNodeFromMapRequest calls UpdateNode and gets a valid node +// 2. Ephemeral node logout calls DeleteNode +// 3. UpdateNode and DeleteNode batch together +// 4. UpdateNode returns a valid node (from before delete in batch) +// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node +// 6. persistNodeToDB must detect the node is deleted and refuse to persist +func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { + ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7") + ephemeralNode.AuthKey = &types.PreAuthKey{ + ID: 3, + Key: "test-key-3", + Ephemeral: true, + } + + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + // Put node in store + _ = store.PutNode(ephemeralNode) + + // Simulate the race: + // 1. UpdateNode is called (from UpdateNodeFromMapRequest) + updatedNode, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + require.True(t, ok, "UpdateNode should succeed") + require.True(t, updatedNode.Valid(), "UpdateNode should return valid node") + + // 2. Node is deleted (from handleLogout for ephemeral node) + store.DeleteNode(ephemeralNode.ID) + + // Wait for deletion + time.Sleep(50 * time.Millisecond) + + // 3. Verify node is deleted from store + _, exists := store.GetNode(ephemeralNode.ID) + require.False(t, exists, "node should be deleted from NodeStore") + + // 4. 
Simulate what persistNodeToDB does - check if node still exists + // The fix in persistNodeToDB checks NodeStore before persisting: + // if !exists { return error } + // This prevents re-inserting the deleted node into the database + + // Verify the node from UpdateNode is valid but node is gone from store + assert.True(t, updatedNode.Valid(), "UpdateNode returned a valid node view") + _, stillExists := store.GetNode(updatedNode.ID()) + assert.False(t, stillExists, "but node should be deleted from NodeStore") + + // This is the critical test: persistNodeToDB must check NodeStore + // and refuse to persist if the node doesn't exist anymore + // The actual persistNodeToDB implementation does: + // _, exists := s.nodeStore.GetNode(node.ID()) + // if !exists { return error } +} diff --git a/hscontrol/state/maprequest.go b/hscontrol/state/maprequest.go index 9d6f1a09..e7dfc11c 100644 --- a/hscontrol/state/maprequest.go +++ b/hscontrol/state/maprequest.go @@ -10,9 +10,9 @@ import ( "tailscale.com/tailcfg" ) -// NetInfoFromMapRequest determines the correct NetInfo to use. +// netInfoFromMapRequest determines the correct NetInfo to use. // Returns the NetInfo that should be used for this request. -func NetInfoFromMapRequest( +func netInfoFromMapRequest( nodeID types.NodeID, currentHostinfo *tailcfg.Hostinfo, reqHostinfo *tailcfg.Hostinfo, diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index dfb2abd0..865d3eb4 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -61,7 +61,7 @@ func TestNetInfoFromMapRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := NetInfoFromMapRequest(nodeID, tt.currentHostinfo, tt.reqHostinfo) + result := netInfoFromMapRequest(nodeID, tt.currentHostinfo, tt.reqHostinfo) if tt.expectNetInfo == nil { assert.Nil(t, result, "expected nil NetInfo") @@ -100,14 +100,40 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { } // BUG: Using the node being modified (no NetInfo) instead of existing node (has NetInfo) - buggyResult := NetInfoFromMapRequest(nodeID, nodeBeingModifiedHostinfo, newRegistrationHostinfo) + buggyResult := netInfoFromMapRequest(nodeID, nodeBeingModifiedHostinfo, newRegistrationHostinfo) assert.Nil(t, buggyResult, "Bug: Should return nil when using wrong hostinfo reference") // CORRECT: Using the existing node's hostinfo (has NetInfo) - correctResult := NetInfoFromMapRequest(nodeID, existingNodeHostinfo, newRegistrationHostinfo) + correctResult := netInfoFromMapRequest(nodeID, existingNodeHostinfo, newRegistrationHostinfo) assert.NotNil(t, correctResult, "Fix: Should preserve NetInfo when using correct hostinfo reference") assert.Equal(t, 5, correctResult.PreferredDERP, "Should preserve the DERP region from existing node") }) + + t.Run("new_node_creation_for_different_user_should_preserve_netinfo", func(t *testing.T) { + // This test covers the scenario where: + // 1. A node exists for user1 with NetInfo + // 2. The same machine logs in as user2 (different user) + // 3. A NEW node is created for user2 (pre-auth key flow) + // 4. 
The new node should preserve NetInfo from the old node + + // Existing node for user1 with NetInfo + existingNodeUser1Hostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + NetInfo: &tailcfg.NetInfo{PreferredDERP: 7}, + } + + // New registration request for user2 (no NetInfo yet) + newNodeUser2Hostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + OS: "linux", + // NetInfo is nil - registration request doesn't include it + } + + // When creating a new node for user2, we should preserve NetInfo from user1's node + result := netInfoFromMapRequest(types.NodeID(2), existingNodeUser1Hostinfo, newNodeUser2Hostinfo) + assert.NotNil(t, result, "New node for user2 should preserve NetInfo from user1's node") + assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node") + }) } // Simple helper function for tests diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 555766d1..34bbb24f 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -15,7 +15,7 @@ import ( ) const ( - batchSize = 10 + batchSize = 100 batchTimeout = 500 * time.Millisecond ) @@ -121,10 +121,11 @@ type Snapshot struct { nodesByID map[types.NodeID]types.Node // calculated from nodesByID - nodesByNodeKey map[key.NodePublic]types.NodeView - peersByNode map[types.NodeID][]types.NodeView - nodesByUser map[types.UserID][]types.NodeView - allNodes []types.NodeView + nodesByNodeKey map[key.NodePublic]types.NodeView + nodesByMachineKey map[key.MachinePublic]map[types.UserID]types.NodeView + peersByNode map[types.NodeID][]types.NodeView + nodesByUser map[types.UserID][]types.NodeView + allNodes []types.NodeView } // PeersFunc is a function that takes a list of nodes and returns a map @@ -135,26 +136,29 @@ type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView // work represents a single operation to be performed on the NodeStore. type work struct { - op int - nodeID types.NodeID - node types.Node - updateFn UpdateNodeFunc - result chan struct{} + op int + nodeID types.NodeID + node types.Node + updateFn UpdateNodeFunc + result chan struct{} + nodeResult chan types.NodeView // Channel to return the resulting node after batch application } // PutNode adds or updates a node in the store. // If the node already exists, it will be replaced. // If the node does not exist, it will be added. // This is a blocking operation that waits for the write to complete. -func (s *NodeStore) PutNode(n types.Node) { +// Returns the resulting node after all modifications in the batch have been applied. +func (s *NodeStore) PutNode(n types.Node) types.NodeView { timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put")) defer timer.ObserveDuration() work := work{ - op: put, - nodeID: n.ID, - node: n, - result: make(chan struct{}), + op: put, + nodeID: n.ID, + node: n, + result: make(chan struct{}), + nodeResult: make(chan types.NodeView, 1), } nodeStoreQueueDepth.Inc() @@ -162,7 +166,10 @@ func (s *NodeStore) PutNode(n types.Node) { <-work.result nodeStoreQueueDepth.Dec() + resultNode := <-work.nodeResult nodeStoreOperations.WithLabelValues("put").Inc() + + return resultNode } // UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it. @@ -173,6 +180,7 @@ type UpdateNodeFunc func(n *types.Node) // This is analogous to a database "transaction", or, the caller should // rather collect all data they want to change, and then call this function. // Fewer calls are better. 
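+//
+// A caller would typically collect its changes into one call and check the
+// returned ok value before persisting; a minimal sketch (illustrative only,
+// not code from this change):
+//
+//	view, ok := store.UpdateNode(id, func(n *types.Node) {
+//		n.Hostname = "renamed"
+//		n.GivenName = "renamed"
+//	})
+//	if !ok {
+//		// node no longer exists (it may have been deleted in the same batch)
+//	}
+//	_ = view // reflects the state after the whole batch was applied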
+// Returns the resulting node after all modifications in the batch have been applied. // // TODO(kradalby): Technically we could have a version of this that modifies the node // in the current snapshot if _we know_ that the change will not affect the peer relationships. @@ -181,15 +189,16 @@ type UpdateNodeFunc func(n *types.Node) // a lock around the nodesByID map to ensure that no other writes are happening // while we are modifying the node. Which mean we would need to implement read-write locks // on all read operations. -func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) { +func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) (types.NodeView, bool) { timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update")) defer timer.ObserveDuration() work := work{ - op: update, - nodeID: nodeID, - updateFn: updateFn, - result: make(chan struct{}), + op: update, + nodeID: nodeID, + updateFn: updateFn, + result: make(chan struct{}), + nodeResult: make(chan types.NodeView, 1), } nodeStoreQueueDepth.Inc() @@ -197,7 +206,11 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node) <-work.result nodeStoreQueueDepth.Dec() + resultNode := <-work.nodeResult nodeStoreOperations.WithLabelValues("update").Inc() + + // Return the node and whether it exists (is valid) + return resultNode, resultNode.Valid() } // DeleteNode removes a node from the store by its ID. @@ -282,18 +295,32 @@ func (s *NodeStore) applyBatch(batch []work) { nodes := make(map[types.NodeID]types.Node) maps.Copy(nodes, s.data.Load().nodesByID) - for _, w := range batch { + // Track which work items need node results + nodeResultRequests := make(map[types.NodeID][]*work) + + for i := range batch { + w := &batch[i] switch w.op { case put: nodes[w.nodeID] = w.node + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } case update: // Update the specific node identified by nodeID if n, exists := nodes[w.nodeID]; exists { w.updateFn(&n) nodes[w.nodeID] = n } + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } case del: delete(nodes, w.nodeID) + // For delete operations, send an invalid NodeView if requested + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } } } @@ -303,6 +330,24 @@ func (s *NodeStore) applyBatch(batch []work) { // Update node count gauge nodeStoreNodesCount.Set(float64(len(nodes))) + // Send the resulting nodes to all work items that requested them + for nodeID, workItems := range nodeResultRequests { + if node, exists := nodes[nodeID]; exists { + nodeView := node.View() + for _, w := range workItems { + w.nodeResult <- nodeView + close(w.nodeResult) + } + } else { + // Node was deleted or doesn't exist + for _, w := range workItems { + w.nodeResult <- types.NodeView{} // Send invalid view + close(w.nodeResult) + } + } + } + + // Signal completion for all work items for _, w := range batch { close(w.result) } @@ -323,9 +368,10 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S } newSnap := Snapshot{ - nodesByID: nodes, - allNodes: allNodes, - nodesByNodeKey: make(map[key.NodePublic]types.NodeView), + nodesByID: nodes, + allNodes: allNodes, + nodesByNodeKey: make(map[key.NodePublic]types.NodeView), + nodesByMachineKey: make(map[key.MachinePublic]map[types.UserID]types.NodeView), // peersByNode is most likely the most expensive operation, 
// it will use the list of all nodes, combined with the @@ -339,11 +385,19 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S nodesByUser: make(map[types.UserID][]types.NodeView), } - // Build nodesByUser and nodesByNodeKey maps + // Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps for _, n := range nodes { nodeView := n.View() - newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView) + userID := types.UserID(n.UserID) + + newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView) newSnap.nodesByNodeKey[n.NodeKey] = nodeView + + // Build machine key index + if newSnap.nodesByMachineKey[n.MachineKey] == nil { + newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView) + } + newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView } return newSnap @@ -382,19 +436,40 @@ func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bo return nodeView, exists } -// GetNodeByMachineKey returns a node by its machine key. The bool indicates if the node exists. -func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) { +// GetNodeByMachineKey returns a node by its machine key and user ID. The bool indicates if the node exists. +func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic, userID types.UserID) (types.NodeView, bool) { timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key")) defer timer.ObserveDuration() nodeStoreOperations.WithLabelValues("get_by_machine_key").Inc() snapshot := s.data.Load() - // We don't have a byMachineKey map, so we need to iterate - // This could be optimized by adding a byMachineKey map if this becomes a hot path - for _, node := range snapshot.nodesByID { - if node.MachineKey == machineKey { - return node.View(), true + if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists { + if node, exists := userMap[userID]; exists { + return node, true + } + } + + return types.NodeView{}, false +} + +// GetNodeByMachineKeyAnyUser returns the first node with the given machine key, +// regardless of which user it belongs to. This is useful for scenarios like +// transferring a node to a different user when re-authenticating with a +// different user's auth key. +// If multiple nodes exist with the same machine key (different users), the +// first one found is returned (order is not guaranteed). 
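+//
+// Illustrative use (a sketch, not an actual call site from this change):
+//
+//	if existing, ok := s.GetNodeByMachineKeyAnyUser(machineKey); ok {
+//		_ = existing.UserID() // same device, possibly owned by a different user
+//	}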
+func (s *NodeStore) GetNodeByMachineKeyAnyUser(machineKey key.MachinePublic) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key_any_user")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("get_by_machine_key_any_user").Inc() + + snapshot := s.data.Load() + if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists { + // Return the first node found (order not guaranteed due to map iteration) + for _, node := range userMap { + return node, true } } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 9666e5db..64ee0406 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -1,7 +1,11 @@ package state import ( + "context" + "fmt" "net/netip" + "runtime" + "sync" "testing" "time" @@ -249,7 +253,9 @@ func TestNodeStoreOperations(t *testing.T) { name: "add first node", action: func(store *NodeStore) { node := createTestNode(1, 1, "user1", "node1") - store.PutNode(node) + resultNode := store.PutNode(node) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, node.ID, resultNode.ID()) snapshot := store.data.Load() assert.Len(t, snapshot.nodesByID, 1) @@ -288,7 +294,9 @@ func TestNodeStoreOperations(t *testing.T) { name: "add second node same user", action: func(store *NodeStore) { node2 := createTestNode(2, 1, "user1", "node2") - store.PutNode(node2) + resultNode := store.PutNode(node2) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, types.NodeID(2), resultNode.ID()) snapshot := store.data.Load() assert.Len(t, snapshot.nodesByID, 2) @@ -308,7 +316,9 @@ func TestNodeStoreOperations(t *testing.T) { name: "add third node different user", action: func(store *NodeStore) { node3 := createTestNode(3, 2, "user2", "node3") - store.PutNode(node3) + resultNode := store.PutNode(node3) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, types.NodeID(3), resultNode.ID()) snapshot := store.data.Load() assert.Len(t, snapshot.nodesByID, 3) @@ -409,10 +419,14 @@ func TestNodeStoreOperations(t *testing.T) { { name: "update node hostname", action: func(store *NodeStore) { - store.UpdateNode(1, func(n *types.Node) { + resultNode, ok := store.UpdateNode(1, func(n *types.Node) { n.Hostname = "updated-node1" n.GivenName = "updated-node1" }) + assert.True(t, ok, "UpdateNode should return true for existing node") + assert.True(t, resultNode.Valid(), "Result node should be valid") + assert.Equal(t, "updated-node1", resultNode.Hostname()) + assert.Equal(t, "updated-node1", resultNode.GivenName()) snapshot := store.data.Load() assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname) @@ -436,10 +450,14 @@ func TestNodeStoreOperations(t *testing.T) { name: "add nodes with odd-even filtering", action: func(store *NodeStore) { // Add nodes in sequence - store.PutNode(createTestNode(1, 1, "user1", "node1")) - store.PutNode(createTestNode(2, 2, "user2", "node2")) - store.PutNode(createTestNode(3, 3, "user3", "node3")) - store.PutNode(createTestNode(4, 4, "user4", "node4")) + n1 := store.PutNode(createTestNode(1, 1, "user1", "node1")) + assert.True(t, n1.Valid()) + n2 := store.PutNode(createTestNode(2, 2, "user2", "node2")) + assert.True(t, n2.Valid()) + n3 := store.PutNode(createTestNode(3, 3, "user3", "node3")) + assert.True(t, n3.Valid()) + n4 := store.PutNode(createTestNode(4, 4, "user4", "node4")) + assert.True(t, n4.Valid()) 
snapshot := store.data.Load() assert.Len(t, snapshot.nodesByID, 4) @@ -478,6 +496,328 @@ func TestNodeStoreOperations(t *testing.T) { }, }, }, + { + name: "test batch modifications return correct node state", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify initial state", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) + }, + }, + { + name: "concurrent updates should reflect all batch changes", + action: func(store *NodeStore) { + // Start multiple updates that will be batched together + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var resultNode1, resultNode2 types.NodeView + var newNode3 types.NodeView + var ok1, ok2 bool + + // These should all be processed in the same batch + go func() { + resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "batch-updated-node1" + n.GivenName = "batch-given-1" + }) + close(done1) + }() + + go func() { + resultNode2, ok2 = store.UpdateNode(2, func(n *types.Node) { + n.Hostname = "batch-updated-node2" + n.GivenName = "batch-given-2" + }) + close(done2) + }() + + go func() { + node3 := createTestNode(3, 1, "user1", "node3") + newNode3 = store.PutNode(node3) + close(done3) + }() + + // Wait for all operations to complete + <-done1 + <-done2 + <-done3 + + // Verify the returned nodes reflect the batch state + assert.True(t, ok1, "UpdateNode should succeed for node 1") + assert.True(t, ok2, "UpdateNode should succeed for node 2") + assert.True(t, resultNode1.Valid()) + assert.True(t, resultNode2.Valid()) + assert.True(t, newNode3.Valid()) + + // Check that returned nodes have the updated values + assert.Equal(t, "batch-updated-node1", resultNode1.Hostname()) + assert.Equal(t, "batch-given-1", resultNode1.GivenName()) + assert.Equal(t, "batch-updated-node2", resultNode2.Hostname()) + assert.Equal(t, "batch-given-2", resultNode2.GivenName()) + assert.Equal(t, "node3", newNode3.Hostname()) + + // Verify the snapshot also reflects all changes + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Equal(t, "batch-updated-node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "batch-updated-node2", snapshot.nodesByID[2].Hostname) + assert.Equal(t, "node3", snapshot.nodesByID[3].Hostname) + + // Verify peer relationships are updated correctly with new node + assert.Len(t, snapshot.peersByNode[1], 2) // sees nodes 2 and 3 + assert.Len(t, snapshot.peersByNode[2], 2) // sees nodes 1 and 3 + assert.Len(t, snapshot.peersByNode[3], 2) // sees nodes 1 and 2 + }, + }, + { + name: "update non-existent node returns invalid view", + action: func(store *NodeStore) { + resultNode, ok := store.UpdateNode(999, func(n *types.Node) { + n.Hostname = "should-not-exist" + }) + + assert.False(t, ok, "UpdateNode should return false for non-existent node") + assert.False(t, resultNode.Valid(), "Result should be invalid NodeView") + }, + }, + { + name: "multiple updates to same node in batch all see final state", + action: func(store *NodeStore) { + // This test verifies that when multiple updates to the same node + // are batched together, each returned node reflects ALL 
changes + // in the batch, not just the individual update's changes. + + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var resultNode1, resultNode2, resultNode3 types.NodeView + var ok1, ok2, ok3 bool + + // These updates all modify node 1 and should be batched together + // The final state should have all three modifications applied + go func() { + resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "multi-update-hostname" + }) + close(done1) + }() + + go func() { + resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) { + n.GivenName = "multi-update-givenname" + }) + close(done2) + }() + + go func() { + resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { + n.ForcedTags = []string{"tag1", "tag2"} + }) + close(done3) + }() + + // Wait for all operations to complete + <-done1 + <-done2 + <-done3 + + // All updates should succeed + assert.True(t, ok1, "First update should succeed") + assert.True(t, ok2, "Second update should succeed") + assert.True(t, ok3, "Third update should succeed") + + // CRITICAL: Each returned node should reflect ALL changes from the batch + // not just the change from its specific update call + + // resultNode1 (from hostname update) should also have the givenname and tags changes + assert.Equal(t, "multi-update-hostname", resultNode1.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode1.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice()) + + // resultNode2 (from givenname update) should also have the hostname and tags changes + assert.Equal(t, "multi-update-hostname", resultNode2.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode2.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice()) + + // resultNode3 (from tags update) should also have the hostname and givenname changes + assert.Equal(t, "multi-update-hostname", resultNode3.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode3.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice()) + + // Verify the snapshot also has all changes + snapshot := store.data.Load() + finalNode := snapshot.nodesByID[1] + assert.Equal(t, "multi-update-hostname", finalNode.Hostname) + assert.Equal(t, "multi-update-givenname", finalNode.GivenName) + assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags) + }, + }, + }, + }, + { + name: "test UpdateNode result is immutable for database save", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + return NewNodeStore(initialNodes, allowAllPeersFunc) + }, + steps: []testStep{ + { + name: "verify returned node is complete and consistent", + action: func(store *NodeStore) { + // Update a node and verify the returned view is complete + resultNode, ok := store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "db-save-hostname" + n.GivenName = "db-save-given" + n.ForcedTags = []string{"db-tag1", "db-tag2"} + }) + + assert.True(t, ok, "UpdateNode should succeed") + assert.True(t, resultNode.Valid(), "Result should be valid") + + // Verify the returned node has all expected values + assert.Equal(t, "db-save-hostname", resultNode.Hostname()) + assert.Equal(t, "db-save-given", resultNode.GivenName()) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice()) + + // Convert to struct as 
would be done for database save + nodePtr := resultNode.AsStruct() + assert.NotNil(t, nodePtr) + assert.Equal(t, "db-save-hostname", nodePtr.Hostname) + assert.Equal(t, "db-save-given", nodePtr.GivenName) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags) + + // Verify the snapshot also reflects the same state + snapshot := store.data.Load() + storedNode := snapshot.nodesByID[1] + assert.Equal(t, "db-save-hostname", storedNode.Hostname) + assert.Equal(t, "db-save-given", storedNode.GivenName) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags) + }, + }, + { + name: "concurrent updates all return consistent final state for DB save", + action: func(store *NodeStore) { + // Multiple goroutines updating the same node + // All should receive the final batch state suitable for DB save + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var result1, result2, result3 types.NodeView + var ok1, ok2, ok3 bool + + // Start concurrent updates + go func() { + result1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "concurrent-db-hostname" + }) + close(done1) + }() + + go func() { + result2, ok2 = store.UpdateNode(1, func(n *types.Node) { + n.GivenName = "concurrent-db-given" + }) + close(done2) + }() + + go func() { + result3, ok3 = store.UpdateNode(1, func(n *types.Node) { + n.ForcedTags = []string{"concurrent-tag"} + }) + close(done3) + }() + + // Wait for all to complete + <-done1 + <-done2 + <-done3 + + assert.True(t, ok1 && ok2 && ok3, "All updates should succeed") + + // All results should be valid and suitable for database save + assert.True(t, result1.Valid()) + assert.True(t, result2.Valid()) + assert.True(t, result3.Valid()) + + // Convert each to struct as would be done for DB save + nodePtr1 := result1.AsStruct() + nodePtr2 := result2.AsStruct() + nodePtr3 := result3.AsStruct() + + // All should have the complete final state + assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags) + + assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags) + + assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags) + + // Verify consistency with stored state + snapshot := store.data.Load() + storedNode := snapshot.nodesByID[1] + assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname) + assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName) + assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags) + }, + }, + { + name: "verify returned node preserves all fields for DB save", + action: func(store *NodeStore) { + // Get initial state + snapshot := store.data.Load() + originalNode := snapshot.nodesByID[2] + originalIPv4 := originalNode.IPv4 + originalIPv6 := originalNode.IPv6 + originalCreatedAt := originalNode.CreatedAt + originalUser := originalNode.User + + // Update only hostname + resultNode, ok := store.UpdateNode(2, func(n *types.Node) { + n.Hostname = "preserve-test-hostname" + }) + + assert.True(t, ok, "Update should succeed") + + // Convert to struct for DB save + nodeForDB := resultNode.AsStruct() + + // Verify all fields are preserved + assert.Equal(t, "preserve-test-hostname", nodeForDB.Hostname) + 
assert.Equal(t, originalIPv4, nodeForDB.IPv4) + assert.Equal(t, originalIPv6, nodeForDB.IPv6) + assert.Equal(t, originalCreatedAt, nodeForDB.CreatedAt) + assert.Equal(t, originalUser.Name, nodeForDB.User.Name) + assert.Equal(t, types.NodeID(2), nodeForDB.ID) + + // These fields should be suitable for direct database save + assert.NotNil(t, nodeForDB.IPv4) + assert.NotNil(t, nodeForDB.IPv6) + assert.False(t, nodeForDB.CreatedAt.IsZero()) + }, + }, + }, + }, } for _, tt := range tests { @@ -499,3 +839,302 @@ type testStep struct { name string action func(store *NodeStore) } + +// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests --- + +// Helper for concurrent test nodes +func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { + machineKey := key.NewMachine() + nodeKey := key.NewNode() + return types.Node{ + ID: id, + Hostname: hostname, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + UserID: 1, + User: types.User{ + Name: "concurrent-test-user", + }, + } +} + +// --- Concurrency: concurrent PutNode operations --- +func TestNodeStoreConcurrentPutNode(t *testing.T) { + const concurrentOps = 20 + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + var wg sync.WaitGroup + results := make(chan bool, concurrentOps) + for i := 0; i < concurrentOps; i++ { + wg.Add(1) + go func(nodeID int) { + defer wg.Done() + node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") + resultNode := store.PutNode(node) + results <- resultNode.Valid() + }(i + 1) + } + wg.Wait() + close(results) + + successCount := 0 + for success := range results { + if success { + successCount++ + } + } + require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed") +} + +// --- Batching: concurrent ops fit in one batch --- +func TestNodeStoreBatchingEfficiency(t *testing.T) { + const batchSize = 10 + const ops = 15 // more than batchSize + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + var wg sync.WaitGroup + results := make(chan bool, ops) + for i := 0; i < ops; i++ { + wg.Add(1) + go func(nodeID int) { + defer wg.Done() + node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") + resultNode := store.PutNode(node) + results <- resultNode.Valid() + }(i + 1) + } + wg.Wait() + close(results) + + successCount := 0 + for success := range results { + if success { + successCount++ + } + } + require.Equal(t, ops, successCount, "All batch PutNode operations should succeed") +} + +// --- Race conditions: many goroutines on same node --- +func TestNodeStoreRaceConditions(t *testing.T) { + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + nodeID := types.NodeID(1) + node := createConcurrentTestNode(nodeID, "race-node") + resultNode := store.PutNode(node) + require.True(t, resultNode.Valid()) + + const numGoroutines = 30 + const opsPerGoroutine = 10 + var wg sync.WaitGroup + errors := make(chan error, numGoroutines*opsPerGoroutine) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(gid int) { + defer wg.Done() + for j := 0; j < opsPerGoroutine; j++ { + switch j % 3 { + case 0: + resultNode, _ := store.UpdateNode(nodeID, func(n *types.Node) { + n.Hostname = "race-updated" + }) + if !resultNode.Valid() { + errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j) + } + case 1: + retrieved, found := store.GetNode(nodeID) + if !found || !retrieved.Valid() { + errors 
<- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j) + } + case 2: + newNode := createConcurrentTestNode(nodeID, "race-put") + resultNode := store.PutNode(newNode) + if !resultNode.Valid() { + errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) + } + } + } + }(i) + } + wg.Wait() + close(errors) + + errorCount := 0 + for err := range errors { + t.Error(err) + errorCount++ + } + if errorCount > 0 { + t.Fatalf("Race condition test failed with %d errors", errorCount) + } +} + +// --- Resource cleanup: goroutine leak detection --- +func TestNodeStoreResourceCleanup(t *testing.T) { + // initialGoroutines := runtime.NumGoroutine() + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + time.Sleep(50 * time.Millisecond) + afterStartGoroutines := runtime.NumGoroutine() + + const ops = 100 + for i := 0; i < ops; i++ { + nodeID := types.NodeID(i + 1) + node := createConcurrentTestNode(nodeID, "cleanup-node") + resultNode := store.PutNode(node) + assert.True(t, resultNode.Valid()) + store.UpdateNode(nodeID, func(n *types.Node) { + n.Hostname = "cleanup-updated" + }) + retrieved, found := store.GetNode(nodeID) + assert.True(t, found && retrieved.Valid()) + if i%10 == 9 { + store.DeleteNode(nodeID) + } + } + runtime.GC() + time.Sleep(100 * time.Millisecond) + finalGoroutines := runtime.NumGoroutine() + if finalGoroutines > afterStartGoroutines+2 { + t.Errorf("Potential goroutine leak: started with %d, ended with %d", afterStartGoroutines, finalGoroutines) + } +} + +// --- Timeout/deadlock: operations complete within reasonable time --- +func TestNodeStoreOperationTimeout(t *testing.T) { + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + const ops = 30 + var wg sync.WaitGroup + putResults := make([]error, ops) + updateResults := make([]error, ops) + + // Launch all PutNode operations concurrently + for i := 1; i <= ops; i++ { + nodeID := types.NodeID(i) + wg.Add(1) + go func(idx int, id types.NodeID) { + defer wg.Done() + startPut := time.Now() + fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id) + node := createConcurrentTestNode(id, "timeout-node") + resultNode := store.PutNode(node) + endPut := time.Now() + fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut)) + if !resultNode.Valid() { + putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) + } + }(i, nodeID) + } + wg.Wait() + + // Launch all UpdateNode operations concurrently + wg = sync.WaitGroup{} + for i := 1; i <= ops; i++ { + nodeID := types.NodeID(i) + wg.Add(1) + go func(idx int, id types.NodeID) { + defer wg.Done() + startUpdate := time.Now() + fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id) + resultNode, ok := store.UpdateNode(id, func(n *types.Node) { + n.Hostname = "timeout-updated" + }) + endUpdate := time.Now() + fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate)) + if !ok || !resultNode.Valid() { + updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) + } + }(i, nodeID) + } + done := make(chan struct{}) + go func() { + wg.Wait() + 
close(done) + }() + select { + case <-done: + errorCount := 0 + for _, err := range putResults { + if err != nil { + t.Error(err) + errorCount++ + } + } + for _, err := range updateResults { + if err != nil { + t.Error(err) + errorCount++ + } + } + if errorCount == 0 { + t.Log("All concurrent operations completed successfully within timeout") + } else { + t.Fatalf("Some concurrent operations failed: %d errors", errorCount) + } + case <-ctx.Done(): + fmt.Println("[TestNodeStoreOperationTimeout] Timeout reached, test failed") + t.Fatal("Operations timed out - potential deadlock or resource issue") + } +} + +// --- Edge case: update non-existent node --- +func TestNodeStoreUpdateNonExistentNode(t *testing.T) { + for i := 0; i < 10; i++ { + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + nonExistentID := types.NodeID(999 + i) + updateCallCount := 0 + fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID) + resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) { + updateCallCount++ + n.Hostname = "should-never-be-called" + }) + fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) finished, valid=%v, ok=%v, updateCallCount=%d\n", nonExistentID, resultNode.Valid(), ok, updateCallCount) + assert.False(t, ok, "UpdateNode should return false for non-existent node") + assert.False(t, resultNode.Valid(), "UpdateNode should return invalid node for non-existent node") + assert.Equal(t, 0, updateCallCount, "UpdateFn should not be called for non-existent node") + store.Stop() + } +} + +// --- Allocation benchmark --- +func BenchmarkNodeStoreAllocations(b *testing.B) { + store := NewNodeStore(nil, allowAllPeersFunc) + store.Start() + defer store.Stop() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + nodeID := types.NodeID(i + 1) + node := createConcurrentTestNode(nodeID, "bench-node") + store.PutNode(node) + store.UpdateNode(nodeID, func(n *types.Node) { + n.Hostname = "bench-updated" + }) + store.GetNode(nodeID) + if i%10 == 9 { + store.DeleteNode(nodeID) + } + } +} + +func TestNodeStoreAllocationStats(t *testing.T) { + res := testing.Benchmark(BenchmarkNodeStoreAllocations) + allocs := res.AllocsPerOp() + t.Logf("NodeStore allocations per op: %.2f", float64(allocs)) +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index ad7770ff..c8e33544 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -4,6 +4,7 @@ package state import ( + "cmp" "context" "errors" "fmt" @@ -23,7 +24,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" - "github.com/sasha-s/go-deadlock" "golang.org/x/sync/errgroup" "gorm.io/gorm" "tailscale.com/net/tsaddr" @@ -48,8 +48,6 @@ var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode") // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { - // mu protects all in-memory data structures from concurrent access - mu deadlock.RWMutex // cfg holds the current Headscale configuration cfg *types.Config @@ -257,9 +255,6 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { // CreateUser creates a new user and updates the policy manager. // Returns the created user, change set, and any error. 
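The tests above exercise the property that makes the locking changes below safe: NodeStore serializes every write itself, so State no longer needs a process-wide mutex. A minimal sketch of that calling pattern, using only the helpers and signatures that appear in the tests above (illustrative only, not part of the patch):

func nodeStoreUsageSketch() {
	store := NewNodeStore(nil, allowAllPeersFunc)
	store.Start()
	defer store.Stop()

	// PutNode blocks until the batched write is applied and returns the stored view.
	nv := store.PutNode(createConcurrentTestNode(1, "sketch-node"))
	_ = nv.Valid()

	// UpdateNode reports whether the node existed; the returned view reflects the applied change.
	if updated, ok := store.UpdateNode(1, func(n *types.Node) { n.Hostname = "sketch-renamed" }); ok {
		_ = updated.Hostname()
	}
}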
func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) { - s.mu.Lock() - defer s.mu.Unlock() - if err := s.db.DB.Save(&user).Error; err != nil { return nil, change.EmptySet, fmt.Errorf("creating user: %w", err) } @@ -288,9 +283,6 @@ func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, erro // UpdateUser modifies an existing user using the provided update function within a transaction. // Returns the updated user, change set, and any error. func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) { - s.mu.Lock() - defer s.mu.Unlock() - user, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.User, error) { user, err := hsdb.GetUserByID(tx, userID) if err != nil { @@ -361,44 +353,28 @@ func (s *State) ListAllUsers() ([]types.User, error) { return s.db.ListUsers() } -// updateNodeTx performs a database transaction to update a node and refresh the policy manager. -// IMPORTANT: This function does NOT update the NodeStore. The caller MUST update the NodeStore -// BEFORE calling this function with the EXACT same changes that the database update will make. -// This ensures the NodeStore is the source of truth for the batcher and maintains consistency. -// Returns error only; callers should get the updated NodeView from NodeStore to maintain consistency. -func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) error { - s.mu.Lock() - defer s.mu.Unlock() +// persistNodeToDB saves the given node state to the database. +// This function must receive the exact node state to save to ensure consistency between +// NodeStore and the database. It verifies the node still exists in NodeStore to prevent +// race conditions where a node might be deleted between UpdateNode returning and +// persistNodeToDB being called. +func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.ChangeSet, error) { + if !node.Valid() { + return types.NodeView{}, change.EmptySet, fmt.Errorf("invalid node view provided") + } - _, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := updateFn(tx); err != nil { - return nil, err - } - - node, err := hsdb.GetNodeByID(tx, nodeID) - if err != nil { - return nil, err - } - - if err := tx.Save(node).Error; err != nil { - return nil, fmt.Errorf("updating node: %w", err) - } - - return node, nil - }) - return err -} - -// persistNodeToDB saves the current state of a node from NodeStore to the database. -// CRITICAL: This function MUST get the latest node from NodeStore to ensure consistency. -func (s *State) persistNodeToDB(nodeID types.NodeID) (types.NodeView, change.ChangeSet, error) { - s.mu.Lock() - defer s.mu.Unlock() - - // CRITICAL: Always get the latest node from NodeStore to ensure we save the current state - node, found := s.nodeStore.GetNode(nodeID) - if !found { - return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + // Verify the node still exists in NodeStore before persisting to database. + // Without this check, we could hit a race condition where UpdateNode returns a valid + // node from a batch update, then the node gets deleted (e.g., ephemeral node logout), + // and persistNodeToDB would incorrectly re-insert the deleted node into the database. + _, exists := s.nodeStore.GetNode(node.ID()) + if !exists { + log.Warn(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Bool("is_ephemeral", node.IsEphemeral()). 
+ Msg("Node no longer exists in NodeStore, skipping database persist to prevent race condition") + return types.NodeView{}, change.EmptySet, fmt.Errorf("node %d no longer exists in NodeStore, skipping database persist", node.ID()) } nodePtr := node.AsStruct() @@ -424,10 +400,10 @@ func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, // Update NodeStore first nodePtr := node.AsStruct() - s.nodeStore.PutNode(*nodePtr) + resultNode := s.nodeStore.PutNode(*nodePtr) - // Then save to database - return s.persistNodeToDB(node.ID()) + // Then save to database using the result from PutNode + return s.persistNodeToDB(resultNode) } // DeleteNode permanently removes a node and cleans up associated resources. @@ -461,17 +437,14 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet { // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, // the NodeStore already reflects the correct online status for full map generation. // now := time.Now() - s.nodeStore.UpdateNode(id, func(n *types.Node) { + node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { n.IsOnline = ptr.To(true) // n.LastSeen = ptr.To(now) }) - c := []change.ChangeSet{change.NodeOnline(id)} - - // Get fresh node data from NodeStore after the online status update - node, found := s.GetNodeByID(id) - if !found { + if !ok { return nil } + c := []change.ChangeSet{change.NodeOnline(id)} log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected") @@ -491,39 +464,25 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet { func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) { now := time.Now() - // Get node info before updating for logging - node, found := s.GetNodeByID(id) - var nodeName string - if found { - nodeName = node.Hostname() - } - - s.nodeStore.UpdateNode(id, func(n *types.Node) { + node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { n.LastSeen = ptr.To(now) // NodeStore is the source of truth for all node state including online status. 
n.IsOnline = ptr.To(false) }) - if found { - log.Info().Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Node disconnected") + if !ok { + return nil, fmt.Errorf("node not found: %d", id) } - err := s.updateNodeTx(id, func(tx *gorm.DB) error { - // Update last_seen in the database - // Note: IsOnline is managed only in NodeStore (marked with gorm:"-"), not persisted to database - return hsdb.SetLastSeen(tx, id, now) - }) + log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node disconnected") + + // Special error handling for disconnect - we log errors but continue + // because NodeStore is already updated and we need to notify peers + _, c, err := s.persistNodeToDB(node) if err != nil { // Log error but don't fail the disconnection - NodeStore is already updated // and we need to send change notifications to peers - log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update last seen in database") - } - - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - // Log error but continue - disconnection must proceed - log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", nodeName).Msg("Failed to update policy manager after node disconnect") + log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Failed to update last seen in database") c = change.EmptySet } @@ -559,12 +518,12 @@ func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool) return s.nodeStore.GetNodeByNodeKey(nodeKey) } -// GetNodeByMachineKey retrieves a node by its machine key. +// GetNodeByMachineKey retrieves a node by its machine key and user ID. // The bool indicates if the node exists or is available (like "err not found"). // The NodeView might be invalid, so it must be checked with .Valid(), which must be used to ensure // it isn't an invalid node (this is more of a node error or node is broken). -func (s *State) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) { - return s.nodeStore.GetNodeByMachineKey(machineKey) +func (s *State) GetNodeByMachineKey(machineKey key.MachinePublic, userID types.UserID) (types.NodeView, bool) { + return s.nodeStore.GetNodeByMachineKey(machineKey, userID) } // ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided. @@ -635,77 +594,37 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { // SetNodeExpiry updates the expiration time for a node. func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) { - // CRITICAL: Update NodeStore BEFORE database to ensure consistency. - // The NodeStore update is blocking and will be the source of truth for the batcher. - // The database update MUST make the EXACT same change. - // If the database update fails, the NodeStore change will remain, but since we return - // an error, no change notification will be sent to the batcher. + // Update NodeStore before database to ensure consistency. The NodeStore update is + // blocking and will be the source of truth for the batcher. The database update must + // make the exact same change. If the database update fails, the NodeStore change will + // remain, but since we return an error, no change notification will be sent to the + // batcher, preventing inconsistent state propagation. 
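To make the contract described above concrete from the caller's side: the returned ChangeSet is only handed on when the error is nil, so a failed database write never notifies peers even though NodeStore already holds the new value. A hypothetical caller sketch (the function name and notify callback are invented for illustration):

func setExpiryAndNotify(st *State, notify func(change.ChangeSet), id types.NodeID, expiry time.Time) error {
	_, c, err := st.SetNodeExpiry(id, expiry)
	if err != nil {
		// NodeStore may already carry the new expiry, but no change notification goes out.
		return err
	}
	notify(c)
	return nil
}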
expiryPtr := expiry - s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { node.Expiry = &expiryPtr }) - err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.NodeSetExpiry(tx, nodeID, expiry) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err) - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - if !c.IsFull() { - c = change.KeyExpiry(nodeID, expiry) - } - - return n, c, nil + return s.persistNodeToDB(n) } // SetNodeTags assigns tags to a node for use in access control policies. func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) { - // CRITICAL: Update NodeStore BEFORE database to ensure consistency. - // The NodeStore update is blocking and will be the source of truth for the batcher. - // The database update MUST make the EXACT same change. - s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + // Update NodeStore before database to ensure consistency. The NodeStore update is + // blocking and will be the source of truth for the batcher. The database update must + // make the exact same change. + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { node.ForcedTags = tags }) - err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.SetTags(tx, nodeID, tags) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err) - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - if !c.IsFull() { - c = change.NodeAdded(nodeID) - } - - return n, c, nil + return s.persistNodeToDB(n) } // SetApprovedRoutes sets the network routes that a node is approved to advertise. @@ -713,44 +632,32 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t // TODO(kradalby): In principle we should call the AutoApprove logic here // because even if the CLI removes an auto-approved route, it will be added // back automatically. 
- s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { node.ApprovedRoutes = routes }) - err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.SetApprovedRoutes(tx, nodeID, routes) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err) - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() + // Persist the node changes to the database + nodeView, c, err := s.persistNodeToDB(n) if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) + return types.NodeView{}, change.EmptySet, err } - // Get the node from NodeStore to ensure we have the latest state - nodeView, ok := s.GetNodeByID(nodeID) - if !ok { - return n, change.EmptySet, fmt.Errorf("node %d not found in NodeStore", nodeID) - } - // Use SubnetRoutes() instead of ApprovedRoutes() to ensure we only set - // primary routes for routes that are both announced AND approved + // Update primary routes table based on SubnetRoutes (intersection of announced and approved). + // The primary routes table is what the mapper uses to generate network maps, so updating it + // here ensures that route changes are distributed to peers. routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.SubnetRoutes()...) + // If routes changed or the changeset isn't already a full update, trigger a policy change + // to ensure all nodes get updated network maps if routeChange || !c.IsFull() { c = change.PolicyChange() } - return n, c, nil + return nodeView, c, nil } // RenameNode changes the display name of a node. @@ -760,49 +667,27 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) } - // Check name uniqueness - nodes, err := s.db.ListNodes() - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("checking name uniqueness: %w", err) - } - for _, node := range nodes { - if node.ID != nodeID && node.GivenName == newName { + // Check name uniqueness against NodeStore + allNodes := s.nodeStore.ListNodes() + for i := 0; i < allNodes.Len(); i++ { + node := allNodes.At(i) + if node.ID() != nodeID && node.AsStruct().GivenName == newName { return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName) } } - // CRITICAL: Update NodeStore BEFORE database to ensure consistency. - // The NodeStore update is blocking and will be the source of truth for the batcher. - // The database update MUST make the EXACT same change. - s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + // Update NodeStore before database to ensure consistency. The NodeStore update is + // blocking and will be the source of truth for the batcher. The database update must + // make the exact same change. 
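The SetApprovedRoutes change above keys the primary-routes table off SubnetRoutes(), which the comments describe as the intersection of what a node announces and what an administrator has approved. The set logic amounts to the following (hypothetical helper for illustration, not the actual implementation):

func intersectRoutes(announced, approved []netip.Prefix) []netip.Prefix {
	approvedSet := make(map[netip.Prefix]struct{}, len(approved))
	for _, p := range approved {
		approvedSet[p] = struct{}{}
	}
	var subnet []netip.Prefix
	for _, p := range announced {
		if _, ok := approvedSet[p]; ok {
			subnet = append(subnet, p)
		}
	}
	return subnet
}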
+ n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { node.GivenName = newName }) - err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.RenameNode(tx, nodeID, newName) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - if !c.IsFull() { - c = change.NodeAdded(nodeID) - } - - return n, c, nil + return s.persistNodeToDB(n) } // AssignNodeToUser transfers a node to a different user. @@ -818,39 +703,19 @@ func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (type return types.NodeView{}, change.EmptySet, fmt.Errorf("user not found: %w", err) } - // CRITICAL: Update NodeStore BEFORE database to ensure consistency. - // The NodeStore update is blocking and will be the source of truth for the batcher. - // The database update MUST make the EXACT same change. - s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { + // Update NodeStore before database to ensure consistency. The NodeStore update is + // blocking and will be the source of truth for the batcher. The database update must + // make the exact same change. + n, ok := s.nodeStore.UpdateNode(nodeID, func(n *types.Node) { n.User = *user n.UserID = uint(userID) }) - err = s.updateNodeTx(nodeID, func(tx *gorm.DB) error { - return hsdb.AssignNodeToUser(tx, nodeID, userID) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, err - } - - // Get the updated node from NodeStore to ensure consistency - // TODO(kradalby): Validate if this NodeStore read makes sense after database update - n, found := s.GetNodeByID(nodeID) - if !found { + if !ok { return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) } - // Check if policy manager needs updating - c, err := s.updatePolicyManagerNodes() - if err != nil { - return n, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err) - } - - if !c.IsFull() { - c = change.NodeAdded(nodeID) - } - - return n, c, nil + return s.persistNodeToDB(n) } // BackfillNodeIPs assigns IP addresses to nodes that don't have them. @@ -877,20 +742,13 @@ func (s *State) BackfillNodeIPs() ([]string, error) { // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). // Preserve NetInfo from existing node to prevent loss during backfill - netInfo := NetInfoFromMapRequest(node.ID, existingNode.AsStruct().Hostinfo, node.Hostinfo) - if netInfo != nil { - if node.Hostinfo != nil { - hostinfoCopy := *node.Hostinfo - hostinfoCopy.NetInfo = netInfo - node.Hostinfo = &hostinfoCopy - } else { - node.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} - } - } + netInfo := netInfoFromMapRequest(node.ID, existingNode.Hostinfo().AsStruct(), node.Hostinfo) + node.Hostinfo = existingNode.Hostinfo().AsStruct() + node.Hostinfo.NetInfo = netInfo } // TODO(kradalby): This should just update the IP addresses, nothing else in the node store. // We should avoid PutNode here. 
- s.nodeStore.PutNode(*node) + _ = s.nodeStore.PutNode(*node) } } @@ -1043,6 +901,38 @@ func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral b return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags) } +// Test helpers for the state layer + +// CreateUserForTest creates a test user. This is a convenience wrapper around the database layer. +func (s *State) CreateUserForTest(name ...string) *types.User { + return s.db.CreateUserForTest(name...) +} + +// CreateNodeForTest creates a test node. This is a convenience wrapper around the database layer. +func (s *State) CreateNodeForTest(user *types.User, hostname ...string) *types.Node { + return s.db.CreateNodeForTest(user, hostname...) +} + +// CreateRegisteredNodeForTest creates a test node with allocated IPs. This is a convenience wrapper around the database layer. +func (s *State) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node { + return s.db.CreateRegisteredNodeForTest(user, hostname...) +} + +// CreateNodesForTest creates multiple test nodes. This is a convenience wrapper around the database layer. +func (s *State) CreateNodesForTest(user *types.User, count int, namePrefix ...string) []*types.Node { + return s.db.CreateNodesForTest(user, count, namePrefix...) +} + +// CreateUsersForTest creates multiple test users. This is a convenience wrapper around the database layer. +func (s *State) CreateUsersForTest(count int, namePrefix ...string) []*types.User { + return s.db.CreateUsersForTest(count, namePrefix...) +} + +// DB returns the underlying database for testing purposes. +func (s *State) DB() *hsdb.HSDatabase { + return s.db +} + // GetPreAuthKey retrieves a pre-authentication key by ID. func (s *State) GetPreAuthKey(id string) (*types.PreAuthKey, error) { return s.db.GetPreAuthKey(id) @@ -1073,6 +963,131 @@ func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.R s.registrationCache.Set(id, entry) } +// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. +func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) { + if hostinfo == nil { + log.Warn(). + Caller(). + Str("machine.key", machineKey). + Str("node.key", nodeKey). + Str("user.name", username). + Str("generated.hostname", hostname). + Msg("Registration had nil hostinfo, generated default hostname") + } else if hostinfo.Hostname == "" { + log.Warn(). + Caller(). + Str("machine.key", machineKey). + Str("node.key", nodeKey). + Str("user.name", username). + Str("generated.hostname", hostname). + Msg("Registration had empty hostname, generated default") + } +} + +// preserveNetInfo preserves NetInfo from an existing node for faster DERP connectivity. +// If no existing node is provided, it creates new netinfo from the provided hostinfo. +func preserveNetInfo(existingNode types.NodeView, nodeID types.NodeID, validHostinfo *tailcfg.Hostinfo) *tailcfg.NetInfo { + var existingHostinfo *tailcfg.Hostinfo + if existingNode.Valid() { + existingHostinfo = existingNode.Hostinfo().AsStruct() + } + return netInfoFromMapRequest(nodeID, existingHostinfo, validHostinfo) +} + +// newNodeParams contains parameters for creating a new node. 
+type newNodeParams struct { + User types.User + MachineKey key.MachinePublic + NodeKey key.NodePublic + DiscoKey key.DiscoPublic + Hostname string + Hostinfo *tailcfg.Hostinfo + Endpoints []netip.AddrPort + Expiry *time.Time + RegisterMethod string + + // Optional: Pre-auth key specific fields + PreAuthKey *types.PreAuthKey + + // Optional: Existing node for netinfo preservation + ExistingNodeForNetinfo types.NodeView +} + +// createAndSaveNewNode creates a new node, allocates IPs, saves to DB, and adds to NodeStore. +// It preserves netinfo from an existing node if one is provided (for faster DERP connectivity). +func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, error) { + // Preserve NetInfo from existing node if available + if params.Hostinfo != nil { + params.Hostinfo.NetInfo = preserveNetInfo( + params.ExistingNodeForNetinfo, + types.NodeID(0), + params.Hostinfo, + ) + } + + // Prepare the node for registration + nodeToRegister := types.Node{ + Hostname: params.Hostname, + UserID: params.User.ID, + User: params.User, + MachineKey: params.MachineKey, + NodeKey: params.NodeKey, + DiscoKey: params.DiscoKey, + Hostinfo: params.Hostinfo, + Endpoints: params.Endpoints, + LastSeen: ptr.To(time.Now()), + RegisterMethod: params.RegisterMethod, + Expiry: params.Expiry, + } + + // Pre-auth key specific fields + if params.PreAuthKey != nil { + nodeToRegister.ForcedTags = params.PreAuthKey.Proto().GetAclTags() + nodeToRegister.AuthKey = params.PreAuthKey + nodeToRegister.AuthKeyID = ¶ms.PreAuthKey.ID + } + + // Allocate new IPs + ipv4, ipv6, err := s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, fmt.Errorf("allocating IPs: %w", err) + } + + nodeToRegister.IPv4 = ipv4 + nodeToRegister.IPv6 = ipv6 + + // Ensure unique given name if not set + if nodeToRegister.GivenName == "" { + givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) + if err != nil { + return types.NodeView{}, fmt.Errorf("failed to ensure unique given name: %w", err) + } + nodeToRegister.GivenName = givenName + } + + // New node - database first to get ID, then NodeStore + savedNode, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if params.PreAuthKey != nil && !params.PreAuthKey.Reusable { + err := hsdb.UsePreAuthKey(tx, params.PreAuthKey) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, err + } + + // Add to NodeStore after database creates the ID + return s.nodeStore.PutNode(*savedNode), nil +} + // HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC). 
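Both registration paths below (auth/OIDC and pre-auth key) funnel new nodes through the helper above; the minimal call shape mirrors the two call sites below, with only the required fields filled in (values here are invented for illustration):

nv, err := s.createAndSaveNewNode(newNodeParams{
	User:           user,
	MachineKey:     machineKey,
	NodeKey:        nodeKey,
	Hostname:       hostname,
	Hostinfo:       validHostinfo,
	RegisterMethod: util.RegisterMethodAuthKey,
	// PreAuthKey, Expiry, Endpoints and ExistingNodeForNetinfo are optional.
})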
func (s *State) HandleNodeFromAuthPath( registrationID types.RegistrationID, @@ -1080,9 +1095,6 @@ func (s *State) HandleNodeFromAuthPath( expiry *time.Time, registrationMethod string, ) (types.NodeView, change.ChangeSet, error) { - s.mu.Lock() - defer s.mu.Unlock() - // Get the registration entry from cache regEntry, ok := s.GetRegistrationCacheEntry(registrationID) if !ok { @@ -1095,182 +1107,161 @@ func (s *State) HandleNodeFromAuthPath( return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err) } - // Check if node already exists by node key - existingNodeView, exists := s.nodeStore.GetNodeByNodeKey(regEntry.Node.NodeKey) - if exists && existingNodeView.Valid() { - // Node exists - this is a refresh/re-registration + // Ensure we have valid hostinfo and hostname from the registration cache entry + validHostinfo, hostname := util.EnsureValidHostinfo( + regEntry.Node.Hostinfo, + regEntry.Node.MachineKey.String(), + regEntry.Node.NodeKey.String(), + ) + + logHostinfoValidation( + regEntry.Node.MachineKey.ShortString(), + regEntry.Node.NodeKey.String(), + user.Username(), + hostname, + regEntry.Node.Hostinfo, + ) + + var finalNode types.NodeView + + // Check if node already exists with same machine key for this user + existingNodeSameUser, existsSameUser := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey, types.UserID(user.ID)) + + // If this node exists for this user, update the node in place. + if existsSameUser && existingNodeSameUser.Valid() { log.Debug(). Caller(). Str("registration_id", registrationID.String()). Str("user.name", user.Username()). Str("registrationMethod", registrationMethod). - Str("node.name", existingNodeView.Hostname()). - Uint64("node.id", existingNodeView.ID().Uint64()). - Msg("Refreshing existing node registration") + Str("node.name", existingNodeSameUser.Hostname()). + Uint64("node.id", existingNodeSameUser.ID().Uint64()). + Msg("Updating existing node registration") - // Update NodeStore first with the new expiry - s.nodeStore.UpdateNode(existingNodeView.ID(), func(node *types.Node) { - if expiry != nil { - node.Expiry = expiry - } - // Mark as offline since node is reconnecting - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) - }) - - // Save to database - _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - err := hsdb.NodeSetExpiry(tx, existingNodeView.ID(), *expiry) - if err != nil { - return nil, err - } - // Return the node to satisfy the Write signature - return hsdb.GetNodeByID(tx, existingNodeView.ID()) - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update node expiry: %w", err) - } - - // Get updated node from NodeStore - updatedNode, _ := s.nodeStore.GetNode(existingNodeView.ID()) - - if expiry != nil { - return updatedNode, change.KeyExpiry(existingNodeView.ID(), *expiry), nil - } - - return updatedNode, change.FullSet, nil - } - - // New node registration - log.Debug(). - Caller(). - Str("registration_id", registrationID.String()). - Str("user.name", user.Username()). - Str("registrationMethod", registrationMethod). - Str("expiresAt", fmt.Sprintf("%v", expiry)). 
- Msg("Registering new node from auth callback") - - // Check if node exists with same machine key - var existingMachineNode *types.Node - if nv, exists := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey); exists && nv.Valid() { - existingMachineNode = nv.AsStruct() - } - - // Prepare the node for registration - nodeToRegister := regEntry.Node - nodeToRegister.UserID = uint(userID) - nodeToRegister.User = *user - nodeToRegister.RegisterMethod = registrationMethod - if expiry != nil { - nodeToRegister.Expiry = expiry - } - - // Handle IP allocation - var ipv4, ipv6 *netip.Addr - if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { - // Reuse existing IPs and properties - nodeToRegister.ID = existingMachineNode.ID - nodeToRegister.GivenName = existingMachineNode.GivenName - nodeToRegister.ApprovedRoutes = existingMachineNode.ApprovedRoutes - ipv4 = existingMachineNode.IPv4 - ipv6 = existingMachineNode.IPv6 - } else { - // Allocate new IPs - ipv4, ipv6, err = s.ipAlloc.Next() - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) - } - } - - nodeToRegister.IPv4 = ipv4 - nodeToRegister.IPv6 = ipv6 - - // Ensure unique given name if not set - if nodeToRegister.GivenName == "" { - givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) - } - nodeToRegister.GivenName = givenName - } - - var savedNode *types.Node - if existingMachineNode != nil && existingMachineNode.UserID == uint(userID) { // Update existing node - NodeStore first, then database - s.nodeStore.UpdateNode(existingMachineNode.ID, func(node *types.Node) { - node.NodeKey = nodeToRegister.NodeKey - node.DiscoKey = nodeToRegister.DiscoKey - node.Hostname = nodeToRegister.Hostname + updatedNodeView, ok := s.nodeStore.UpdateNode(existingNodeSameUser.ID(), func(node *types.Node) { + node.NodeKey = regEntry.Node.NodeKey + node.DiscoKey = regEntry.Node.DiscoKey + node.Hostname = hostname // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). 
// Preserve NetInfo from existing node when re-registering - netInfo := NetInfoFromMapRequest(existingMachineNode.ID, existingMachineNode.Hostinfo, nodeToRegister.Hostinfo) - if netInfo != nil { - if nodeToRegister.Hostinfo != nil { - hostinfoCopy := *nodeToRegister.Hostinfo - hostinfoCopy.NetInfo = netInfo - node.Hostinfo = &hostinfoCopy - } else { - node.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} - } - } else { - node.Hostinfo = nodeToRegister.Hostinfo - } + node.Hostinfo = validHostinfo + node.Hostinfo.NetInfo = preserveNetInfo(existingNodeSameUser, existingNodeSameUser.ID(), validHostinfo) - node.Endpoints = nodeToRegister.Endpoints - node.RegisterMethod = nodeToRegister.RegisterMethod - if expiry != nil { - node.Expiry = expiry - } + node.Endpoints = regEntry.Node.Endpoints + node.RegisterMethod = regEntry.Node.RegisterMethod node.IsOnline = ptr.To(false) node.LastSeen = ptr.To(time.Now()) + + if expiry != nil { + node.Expiry = expiry + } else { + node.Expiry = regEntry.Node.Expiry + } }) - // Save to database - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { + if !ok { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID()) + } + + // Use the node from UpdateNode to save to database + _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } - return &nodeToRegister, nil + return nil, nil }) if err != nil { return types.NodeView{}, change.EmptySet, err } + + log.Trace(). + Caller(). + Str("node.name", updatedNodeView.Hostname()). + Uint64("node.id", updatedNodeView.ID().Uint64()). + Str("machine.key", regEntry.Node.MachineKey.ShortString()). + Str("node.key", updatedNodeView.NodeKey().ShortString()). + Str("user.name", user.Name). + Msg("Node re-authorized") + + finalNode = updatedNodeView } else { - // New node - database first to get ID, then NodeStore - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } - return &nodeToRegister, nil + // Node does not exist for this user with this machine key + // Check if node exists with this machine key for a different user (for netinfo preservation) + existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(regEntry.Node.MachineKey) + + if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != user.ID { + // Node exists but belongs to a different user + // Create a NEW node for the new user (do not transfer) + // This allows the same machine to have separate node identities per user + oldUser := existingNodeAnyUser.User() + log.Info(). + Caller(). + Str("existing.node.name", existingNodeAnyUser.Hostname()). + Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). + Str("machine.key", regEntry.Node.MachineKey.ShortString()). + Str("old.user", oldUser.Username()). + Str("new.user", user.Username()). + Str("method", registrationMethod). + Msg("Creating new node for different user (same machine key exists for another user)") + } + + // Create a completely new node + log.Debug(). + Caller(). + Str("registration_id", registrationID.String()). + Str("user.name", user.Username()). + Str("registrationMethod", registrationMethod). 
+ Str("expiresAt", fmt.Sprintf("%v", expiry)). + Msg("Registering new node from auth callback") + + // Create and save new node + var err error + finalNode, err = s.createAndSaveNewNode(newNodeParams{ + User: *user, + MachineKey: regEntry.Node.MachineKey, + NodeKey: regEntry.Node.NodeKey, + DiscoKey: regEntry.Node.DiscoKey, + Hostname: hostname, + Hostinfo: validHostinfo, + Endpoints: regEntry.Node.Endpoints, + Expiry: cmp.Or(expiry, regEntry.Node.Expiry), + RegisterMethod: registrationMethod, + ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}), }) if err != nil { return types.NodeView{}, change.EmptySet, err } - - // Add to NodeStore after database creates the ID - s.nodeStore.PutNode(*savedNode) } // Signal to waiting clients - regEntry.SendAndClose(savedNode) + regEntry.SendAndClose(finalNode.AsStruct()) // Delete from registration cache s.registrationCache.Delete(registrationID) - // Update policy manager + // Update policy managers + usersChange, err := s.updatePolicyManagerUsers() + if err != nil { + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager users: %w", err) + } + nodesChange, err := s.updatePolicyManagerNodes() if err != nil { - return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager: %w", err) + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err) } - if !nodesChange.Empty() { - return savedNode.View(), nodesChange, nil + var c change.ChangeSet + if !usersChange.Empty() || !nodesChange.Empty() { + c = change.PolicyChange() + } else { + c = change.NodeAdded(finalNode.ID()) } - return savedNode.View(), change.NodeAdded(savedNode.ID), nil + return finalNode, c, nil } // HandleNodeFromPreAuthKey handles node registration using a pre-authentication key. @@ -1278,9 +1269,6 @@ func (s *State) HandleNodeFromPreAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (types.NodeView, change.ChangeSet, error) { - s.mu.Lock() - defer s.mu.Unlock() - pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey) if err != nil { return types.NodeView{}, change.EmptySet, err @@ -1291,196 +1279,166 @@ func (s *State) HandleNodeFromPreAuthKey( return types.NodeView{}, change.EmptySet, err } - // Check if this is a logout request for an ephemeral node - if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral { - // Find the node to delete - var nodeToDelete types.NodeView - for _, nv := range s.nodeStore.ListNodes().All() { - if nv.Valid() && nv.MachineKey() == machineKey { - nodeToDelete = nv - break - } - } - if nodeToDelete.Valid() { - c, err := s.DeleteNode(nodeToDelete) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err) - } + // Ensure we have valid hostinfo and hostname - handle nil/empty cases + validHostinfo, hostname := util.EnsureValidHostinfo( + regReq.Hostinfo, + machineKey.String(), + regReq.NodeKey.String(), + ) - return types.NodeView{}, c, nil - } - - return types.NodeView{}, change.EmptySet, nil - } + logHostinfoValidation( + machineKey.ShortString(), + regReq.NodeKey.ShortString(), + pak.User.Username(), + hostname, + regReq.Hostinfo, + ) log.Debug(). Caller(). - Str("node.name", regReq.Hostinfo.Hostname). + Str("node.name", hostname). Str("machine.key", machineKey.ShortString()). Str("node.key", regReq.NodeKey.ShortString()). Str("user.name", pak.User.Username()). 
Msg("Registering node with pre-auth key") - // Check if node already exists with same machine key - var existingNode *types.Node - if nv, exists := s.nodeStore.GetNodeByMachineKey(machineKey); exists && nv.Valid() { - existingNode = nv.AsStruct() - } + var finalNode types.NodeView - // Prepare the node for registration - nodeToRegister := types.Node{ - Hostname: regReq.Hostinfo.Hostname, - UserID: pak.User.ID, - User: pak.User, - MachineKey: machineKey, - NodeKey: regReq.NodeKey, - Hostinfo: regReq.Hostinfo, - LastSeen: ptr.To(time.Now()), - RegisterMethod: util.RegisterMethodAuthKey, - ForcedTags: pak.Proto().GetAclTags(), - AuthKey: pak, - AuthKeyID: &pak.ID, - } + // Check if node already exists with same machine key for this user + existingNodeSameUser, existsSameUser := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(pak.User.ID)) - if !regReq.Expiry.IsZero() { - nodeToRegister.Expiry = ®Req.Expiry - } + // If this node exists for this user, update the node in place. + if existsSameUser && existingNodeSameUser.Valid() { + log.Trace(). + Caller(). + Str("node.name", existingNodeSameUser.Hostname()). + Uint64("node.id", existingNodeSameUser.ID().Uint64()). + Str("machine.key", machineKey.ShortString()). + Str("node.key", existingNodeSameUser.NodeKey().ShortString()). + Str("user.name", pak.User.Username()). + Msg("Node re-registering with existing machine key and user, updating in place") - // Handle IP allocation and existing node properties - var ipv4, ipv6 *netip.Addr - if existingNode != nil && existingNode.UserID == pak.User.ID { - // Reuse existing node properties - nodeToRegister.ID = existingNode.ID - nodeToRegister.GivenName = existingNode.GivenName - nodeToRegister.ApprovedRoutes = existingNode.ApprovedRoutes - ipv4 = existingNode.IPv4 - ipv6 = existingNode.IPv6 - } else { - // Allocate new IPs - ipv4, ipv6, err = s.ipAlloc.Next() - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err) - } - } - - nodeToRegister.IPv4 = ipv4 - nodeToRegister.IPv6 = ipv6 - - // Ensure unique given name if not set - if nodeToRegister.GivenName == "" { - givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to ensure unique given name: %w", err) - } - nodeToRegister.GivenName = givenName - } - - var savedNode *types.Node - if existingNode != nil && existingNode.UserID == pak.User.ID { // Update existing node - NodeStore first, then database - s.nodeStore.UpdateNode(existingNode.ID, func(node *types.Node) { - node.NodeKey = nodeToRegister.NodeKey - node.Hostname = nodeToRegister.Hostname + updatedNodeView, ok := s.nodeStore.UpdateNode(existingNodeSameUser.ID(), func(node *types.Node) { + node.NodeKey = regReq.NodeKey + node.Hostname = hostname // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). 
// Preserve NetInfo from existing node when re-registering - netInfo := NetInfoFromMapRequest(existingNode.ID, existingNode.Hostinfo, nodeToRegister.Hostinfo) - if netInfo != nil { - if nodeToRegister.Hostinfo != nil { - hostinfoCopy := *nodeToRegister.Hostinfo - hostinfoCopy.NetInfo = netInfo - node.Hostinfo = &hostinfoCopy - } else { - node.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} - } - } else { - node.Hostinfo = nodeToRegister.Hostinfo - } + node.Hostinfo = validHostinfo + node.Hostinfo.NetInfo = preserveNetInfo(existingNodeSameUser, existingNodeSameUser.ID(), validHostinfo) - node.Endpoints = nodeToRegister.Endpoints - node.RegisterMethod = nodeToRegister.RegisterMethod - node.ForcedTags = nodeToRegister.ForcedTags - node.AuthKey = nodeToRegister.AuthKey - node.AuthKeyID = nodeToRegister.AuthKeyID - if nodeToRegister.Expiry != nil { - node.Expiry = nodeToRegister.Expiry - } + node.RegisterMethod = util.RegisterMethodAuthKey + + // TODO(kradalby): This might need a rework as part of #2417 + node.ForcedTags = pak.Proto().GetAclTags() + node.AuthKey = pak + node.AuthKeyID = &pak.ID node.IsOnline = ptr.To(false) node.LastSeen = ptr.To(time.Now()) + + // Update expiry, if it is zero, it means that the node will + // not have an expiry anymore. If it is non-zero, we set that. + node.Expiry = ®Req.Expiry }) + if !ok { + return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID()) + } + + // Use the node from UpdateNode to save to database + _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return nil, nil + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) + } + log.Trace(). Caller(). - Str("node.name", nodeToRegister.Hostname). - Uint64("node.id", existingNode.ID.Uint64()). + Str("node.name", updatedNodeView.Hostname()). + Uint64("node.id", updatedNodeView.ID().Uint64()). Str("machine.key", machineKey.ShortString()). - Str("node.key", regReq.NodeKey.ShortString()). + Str("node.key", updatedNodeView.NodeKey().ShortString()). Str("user.name", pak.User.Username()). 
Msg("Node re-authorized") - // Save to database - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } - - if !pak.Reusable { - err = hsdb.UsePreAuthKey(tx, pak) - if err != nil { - return nil, fmt.Errorf("using pre auth key: %w", err) - } - } - - return &nodeToRegister, nil - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) - } + finalNode = updatedNodeView } else { - // New node - database first to get ID, then NodeStore - savedNode, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } + // Node does not exist for this user with this machine key + // Check if node exists with this machine key for a different user + existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) - if !pak.Reusable { - err = hsdb.UsePreAuthKey(tx, pak) - if err != nil { - return nil, fmt.Errorf("using pre auth key: %w", err) - } - } - - return &nodeToRegister, nil - }) - if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) + if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != pak.User.ID { + // Node exists but belongs to a different user + // Create a NEW node for the new user (do not transfer) + // This allows the same machine to have separate node identities per user + oldUser := existingNodeAnyUser.User() + log.Info(). + Caller(). + Str("existing.node.name", existingNodeAnyUser.Hostname()). + Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). + Str("machine.key", machineKey.ShortString()). + Str("old.user", oldUser.Username()). + Str("new.user", pak.User.Username()). 
+ Msg("Creating new node for different user (same machine key exists for another user)") } - // Add to NodeStore after database creates the ID - s.nodeStore.PutNode(*savedNode) + // This is a new node for this user - create it + // (Either completely new, or new for this user while existing for another user) + + // Create and save new node + var err error + finalNode, err = s.createAndSaveNewNode(newNodeParams{ + User: pak.User, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + DiscoKey: key.DiscoPublic{}, // DiscoKey not available in RegisterRequest + Hostname: hostname, + Hostinfo: validHostinfo, + Endpoints: nil, // Endpoints not available in RegisterRequest + Expiry: ®Req.Expiry, + RegisterMethod: util.RegisterMethodAuthKey, + PreAuthKey: pak, + ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}), + }) + if err != nil { + return types.NodeView{}, change.EmptySet, fmt.Errorf("creating new node: %w", err) + } } // Update policy managers usersChange, err := s.updatePolicyManagerUsers() if err != nil { - return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager users: %w", err) + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager users: %w", err) } nodesChange, err := s.updatePolicyManagerNodes() if err != nil { - return savedNode.View(), change.NodeAdded(savedNode.ID), fmt.Errorf("failed to update policy manager nodes: %w", err) + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err) } var c change.ChangeSet if !usersChange.Empty() || !nodesChange.Empty() { c = change.PolicyChange() } else { - c = change.NodeAdded(savedNode.ID) + c = change.NodeAdded(finalNode.ID()) } - return savedNode.View(), c, nil + return finalNode, c, nil } // updatePolicyManagerUsers updates the policy manager with current users. @@ -1603,26 +1561,16 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest var needsRouteApproval bool // We need to ensure we update the node as it is in the NodeStore at // the time of the request. 
- s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { + updatedNode, ok := s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { peerChange := currentNode.PeerChangeFromMapRequest(req) hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) // Get the correct NetInfo to use - netInfo := NetInfoFromMapRequest(id, currentNode.Hostinfo, req.Hostinfo) - - // Apply NetInfo to request Hostinfo + netInfo := netInfoFromMapRequest(id, currentNode.Hostinfo, req.Hostinfo) if req.Hostinfo != nil { - if netInfo != nil { - // Create a copy to avoid modifying the original - hostinfoCopy := *req.Hostinfo - hostinfoCopy.NetInfo = netInfo - req.Hostinfo = &hostinfoCopy - } - } else if netInfo != nil { - // Create minimal Hostinfo with NetInfo - req.Hostinfo = &tailcfg.Hostinfo{ - NetInfo: netInfo, - } + req.Hostinfo.NetInfo = netInfo + } else { + req.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} } // Re-check hostinfoChanged after potential NetInfo preservation @@ -1706,6 +1654,10 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest } }) + if !ok { + return change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", id) + } + nodeRouteChange := change.EmptySet // Handle route changes after NodeStore update @@ -1735,12 +1687,6 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest } if needsRouteUpdate { - // Get the updated node to access its subnet routes - updatedNode, exists := s.GetNodeByID(id) - if !exists { - return change.EmptySet, fmt.Errorf("node disappeared during update: %d", id) - } - // SetNodeRoutes sets the active/distributed routes, so we must use SubnetRoutes() // which returns only the intersection of announced AND approved routes. // Using AnnouncedRoutes() would bypass the security model and auto-approve everything. @@ -1754,7 +1700,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest nodeRouteChange = s.SetNodeRoutes(id, updatedNode.SubnetRoutes()...) 
} - _, policyChange, err := s.persistNodeToDB(id) + _, policyChange, err := s.persistNodeToDB(updatedNode) if err != nil { return change.EmptySet, fmt.Errorf("saving to database: %w", err) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index a7d25e11..6b20091b 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -638,6 +638,11 @@ func (node Node) DebugString() string { return sb.String() } +func (v NodeView) UserView() UserView { + u := v.User() + return u.View() +} + func (v NodeView) IPs() []netip.Addr { if !v.Valid() { return nil diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 131e8019..b7cb1038 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -104,27 +104,31 @@ func (u *User) profilePicURL() string { return u.ProfilePicURL } -func (u *User) TailscaleUser() *tailcfg.User { - user := tailcfg.User{ +func (u *User) TailscaleUser() tailcfg.User { + return tailcfg.User{ ID: tailcfg.UserID(u.ID), DisplayName: u.Display(), ProfilePicURL: u.profilePicURL(), Created: u.CreatedAt, } - - return &user } -func (u *User) TailscaleLogin() *tailcfg.Login { - login := tailcfg.Login{ +func (u UserView) TailscaleUser() tailcfg.User { + return u.ж.TailscaleUser() +} + +func (u *User) TailscaleLogin() tailcfg.Login { + return tailcfg.Login{ ID: tailcfg.LoginID(u.ID), Provider: u.Provider, LoginName: u.Username(), DisplayName: u.Display(), ProfilePicURL: u.profilePicURL(), } +} - return &login +func (u UserView) TailscaleLogin() tailcfg.Login { + return u.ж.TailscaleLogin() } func (u *User) TailscaleUserProfile() tailcfg.UserProfile { @@ -136,6 +140,10 @@ func (u *User) TailscaleUserProfile() tailcfg.UserProfile { } } +func (u UserView) TailscaleUserProfile() tailcfg.UserProfile { + return u.ж.TailscaleUserProfile() +} + func (u *User) Proto() *v1.User { return &v1.User{ Id: uint64(u.ID), diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index f3843f81..143998cc 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "tailscale.com/tailcfg" "tailscale.com/util/cmpver" ) @@ -258,3 +259,59 @@ func IsCI() bool { return false } + +// SafeHostname extracts a hostname from Hostinfo, providing sensible defaults +// if Hostinfo is nil or Hostname is empty. This prevents nil pointer dereferences +// and ensures nodes always have a valid hostname. +// The hostname is truncated to 63 characters to comply with DNS label length limits (RFC 1123). +func SafeHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string { + if hostinfo == nil || hostinfo.Hostname == "" { + // Generate a default hostname using machine key prefix + if machineKey != "" { + keyPrefix := machineKey + if len(machineKey) > 8 { + keyPrefix = machineKey[:8] + } + return fmt.Sprintf("node-%s", keyPrefix) + } + if nodeKey != "" { + keyPrefix := nodeKey + if len(nodeKey) > 8 { + keyPrefix = nodeKey[:8] + } + return fmt.Sprintf("node-%s", keyPrefix) + } + return "unknown-node" + } + + hostname := hostinfo.Hostname + + // Validate hostname length - DNS label limit is 63 characters (RFC 1123) + // Truncate if necessary to ensure compatibility with given name generation + if len(hostname) > 63 { + hostname = hostname[:63] + } + + return hostname +} + +// EnsureValidHostinfo ensures that Hostinfo is non-nil and has a valid hostname. +// If Hostinfo is nil, it creates a minimal valid Hostinfo with a generated hostname. +// Returns the validated/created Hostinfo and the extracted hostname. 
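The types changes above (NodeView.UserView plus the value-returning TailscaleUser, TailscaleLogin, and TailscaleUserProfile methods on UserView) let callers derive tailcfg user structs directly from a NodeView without pointer plumbing. A minimal sketch (the function name is invented):

func profileForNode(nv types.NodeView) tailcfg.UserProfile {
	u := nv.UserView()
	_ = u.TailscaleLogin() // now returns a value, not a pointer
	return u.TailscaleUserProfile()
}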
+func EnsureValidHostinfo(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) (*tailcfg.Hostinfo, string) { + if hostinfo == nil { + hostname := SafeHostname(nil, machineKey, nodeKey) + return &tailcfg.Hostinfo{ + Hostname: hostname, + }, hostname + } + + hostname := SafeHostname(hostinfo, machineKey, nodeKey) + + // Update the hostname in the hostinfo if it was empty or if it was truncated + if hostinfo.Hostname == "" || hostinfo.Hostname != hostname { + hostinfo.Hostname = hostname + } + + return hostinfo, hostname +} diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 47a2709b..e0414071 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "tailscale.com/tailcfg" ) func TestTailscaleVersionNewerOrEqual(t *testing.T) { @@ -793,3 +794,395 @@ over a maximum of 30 hops: }) } } + +func TestSafeHostname(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + hostinfo *tailcfg.Hostinfo + machineKey string + nodeKey string + want string + }{ + { + name: "valid_hostname", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "test-node", + }, + { + name: "nil_hostinfo_with_machine_key", + hostinfo: nil, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "node-mkey1234", + }, + { + name: "nil_hostinfo_with_node_key_only", + hostinfo: nil, + machineKey: "", + nodeKey: "nkey12345678", + want: "node-nkey1234", + }, + { + name: "nil_hostinfo_no_keys", + hostinfo: nil, + machineKey: "", + nodeKey: "", + want: "unknown-node", + }, + { + name: "empty_hostname_with_machine_key", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "node-mkey1234", + }, + { + name: "empty_hostname_with_node_key_only", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "", + nodeKey: "nkey12345678", + want: "node-nkey1234", + }, + { + name: "empty_hostname_no_keys", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "", + nodeKey: "", + want: "unknown-node", + }, + { + name: "hostname_exactly_63_chars", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "123456789012345678901234567890123456789012345678901234567890123", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "123456789012345678901234567890123456789012345678901234567890123", + }, + { + name: "hostname_64_chars_truncated", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "1234567890123456789012345678901234567890123456789012345678901234", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "123456789012345678901234567890123456789012345678901234567890123", + }, + { + name: "hostname_very_long_truncated", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters-and-should-be-truncated", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits", + }, + { + name: "hostname_with_special_chars", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "node-with-special!@#$%", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "node-with-special!@#$%", + }, + { + name: "hostname_with_unicode", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "node-ñoño-测试", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "node-ñoño-测试", + }, + { + name: "short_machine_key", + hostinfo: 
&tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "short", + nodeKey: "nkey12345678", + want: "node-short", + }, + { + name: "short_node_key", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "", + nodeKey: "short", + want: "node-short", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := SafeHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + if got != tt.want { + t.Errorf("SafeHostname() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEnsureValidHostinfo(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + hostinfo *tailcfg.Hostinfo + machineKey string + nodeKey string + wantHostname string + checkHostinfo func(*testing.T, *tailcfg.Hostinfo) + }{ + { + name: "valid_hostinfo_unchanged", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + OS: "linux", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "test-node", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "test-node" { + t.Errorf("hostname = %v, want test-node", hi.Hostname) + } + if hi.OS != "linux" { + t.Errorf("OS = %v, want linux", hi.OS) + } + }, + }, + { + name: "nil_hostinfo_creates_default", + hostinfo: nil, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "node-mkey1234", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "node-mkey1234" { + t.Errorf("hostname = %v, want node-mkey1234", hi.Hostname) + } + }, + }, + { + name: "empty_hostname_updated", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + OS: "darwin", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "node-mkey1234", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "node-mkey1234" { + t.Errorf("hostname = %v, want node-mkey1234", hi.Hostname) + } + if hi.OS != "darwin" { + t.Errorf("OS = %v, want darwin", hi.OS) + } + }, + }, + { + name: "long_hostname_truncated", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "test-node-with-very-long-hostname-that-exceeds-dns-label-limits" { + t.Errorf("hostname = %v, want truncated", hi.Hostname) + } + if len(hi.Hostname) != 63 { + t.Errorf("hostname length = %v, want 63", len(hi.Hostname)) + } + }, + }, + { + name: "nil_hostinfo_node_key_only", + hostinfo: nil, + machineKey: "", + nodeKey: "nkey12345678", + wantHostname: "node-nkey1234", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "node-nkey1234" { + t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname) + } + }, + }, + { + name: "nil_hostinfo_no_keys", + hostinfo: nil, + machineKey: "", + nodeKey: "", + wantHostname: "unknown-node", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "unknown-node" { + t.Errorf("hostname = %v, want unknown-node", hi.Hostname) + } + }, + }, + { + 
name: "empty_hostname_no_keys", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "", + nodeKey: "", + wantHostname: "unknown-node", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "unknown-node" { + t.Errorf("hostname = %v, want unknown-node", hi.Hostname) + } + }, + }, + { + name: "preserves_other_fields", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + OS: "windows", + OSVersion: "10.0.19044", + DeviceModel: "test-device", + BackendLogID: "log123", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "test", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "test" { + t.Errorf("hostname = %v, want test", hi.Hostname) + } + if hi.OS != "windows" { + t.Errorf("OS = %v, want windows", hi.OS) + } + if hi.OSVersion != "10.0.19044" { + t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion) + } + if hi.DeviceModel != "test-device" { + t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel) + } + if hi.BackendLogID != "log123" { + t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID) + } + }, + }, + { + name: "exactly_63_chars_unchanged", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "123456789012345678901234567890123456789012345678901234567890123", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "123456789012345678901234567890123456789012345678901234567890123", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if len(hi.Hostname) != 63 { + t.Errorf("hostname length = %v, want 63", len(hi.Hostname)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + gotHostinfo, gotHostname := EnsureValidHostinfo(tt.hostinfo, tt.machineKey, tt.nodeKey) + + if gotHostname != tt.wantHostname { + t.Errorf("EnsureValidHostinfo() hostname = %v, want %v", gotHostname, tt.wantHostname) + } + if gotHostinfo == nil { + t.Error("returned hostinfo should never be nil") + } + + if tt.checkHostinfo != nil { + tt.checkHostinfo(t, gotHostinfo) + } + }) + } +} + +func TestSafeHostname_DNSLabelLimit(t *testing.T) { + t.Parallel() + + testCases := []string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", + 
"dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd", + } + + for i, hostname := range testCases { + t.Run(cmp.Diff("", ""), func(t *testing.T) { + hostinfo := &tailcfg.Hostinfo{Hostname: hostname} + result := SafeHostname(hostinfo, "mkey", "nkey") + if len(result) > 63 { + t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) + } + }) + } +} + +func TestEnsureValidHostinfo_Idempotent(t *testing.T) { + t.Parallel() + + originalHostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + OS: "linux", + } + + hostinfo1, hostname1 := EnsureValidHostinfo(originalHostinfo, "mkey", "nkey") + hostinfo2, hostname2 := EnsureValidHostinfo(hostinfo1, "mkey", "nkey") + + if hostname1 != hostname2 { + t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) + } + if hostinfo1.Hostname != hostinfo2.Hostname { + t.Errorf("hostinfo hostnames not equal: %v != %v", hostinfo1.Hostname, hostinfo2.Hostname) + } + if hostinfo1.OS != hostinfo2.OS { + t.Errorf("hostinfo OS not equal: %v != %v", hostinfo1.OS, hostinfo2.OS) + } +} diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 90034434..7f8a9e8f 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -28,7 +28,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) opts := []hsic.Option{ @@ -43,31 +43,25 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) 
- assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) - expectedNodes := make([]types.NodeID, 0, len(allClients)) - for _, client := range allClients { - status := client.MustStatus() - nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) - assertNoErr(t, err) - expectedNodes = append(expectedNodes, types.NodeID(nodeID)) - } - requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second) + expectedNodes := collectExpectedNodeIDs(t, allClients) + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 120*time.Second) // Validate that all nodes have NetInfo and DERP servers before logout - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 1*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 3*time.Minute) // assertClientsState(t, allClients) @@ -97,19 +91,20 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) // After taking down all nodes, verify all systems show nodes offline requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second) t.Logf("all clients logged out") + t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error listNodes, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes after logout") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) + }, 30*time.Second, 2*time.Second, "validating node persistence after logout (nodes should remain in database)") for _, node := range listNodes { assertLastSeenSet(t, node) @@ -125,7 +120,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) for _, userName := range spec.Users { key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) @@ -139,12 +134,13 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } } + t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error listNodes, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection") - }, 30*time.Second, 2*time.Second) + assert.NoError(ct, err, "Failed to list nodes after relogin") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", 
nodeCountBeforeLogout, len(listNodes)) + }, 60*time.Second, 2*time.Second, "validating node count stability after same-user auth key relogin") for _, node := range listNodes { assertLastSeenSet(t, node) @@ -152,11 +148,15 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 120*time.Second) + // Wait for Tailscale sync before validating NetInfo to ensure proper state propagation + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + // Validate that all nodes have NetInfo and DERP servers after reconnection - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 1*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 3*time.Minute) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -197,69 +197,10 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } } -// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database -// and a valid DERP server based on the NetInfo. This function follows the pattern of -// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. -func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { - t.Helper() - - startTime := time.Now() - t.Logf("requireAllClientsNetInfoAndDERP: Starting validation at %s - %s", startTime.Format(TimestampFormat), message) - - require.EventuallyWithT(t, func(c *assert.CollectT) { - // Get nodestore state - nodeStore, err := headscale.DebugNodeStore() - assert.NoError(c, err, "Failed to get nodestore debug info") - if err != nil { - return - } - - // Validate node counts first - expectedCount := len(expectedNodes) - assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch") - - // Check each expected node - for _, nodeID := range expectedNodes { - node, exists := nodeStore[nodeID] - assert.True(c, exists, "Node %d not found in nodestore", nodeID) - if !exists { - continue - } - - // Validate that the node has Hostinfo - assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo", nodeID, node.Hostname) - if node.Hostinfo == nil { - continue - } - - // Validate that the node has NetInfo - assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo", nodeID, node.Hostname) - if node.Hostinfo.NetInfo == nil { - continue - } - - // Validate that the node has a valid DERP server (PreferredDERP should be > 0) - preferredDERP := node.Hostinfo.NetInfo.PreferredDERP - assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0), got %d", nodeID, node.Hostname, preferredDERP) - - t.Logf("Node %d (%s) has valid NetInfo with DERP server %d", nodeID, node.Hostname, preferredDERP) - } - }, timeout, 2*time.Second, message) - - endTime := time.Now() - duration := endTime.Sub(startTime) - t.Logf("requireAllClientsNetInfoAndDERP: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message) -} - -func assertLastSeenSet(t *testing.T, node *v1.Node) { - assert.NotNil(t, node) - assert.NotNil(t, node.GetLastSeen()) -} - // This test will first log in two sets of nodes to two sets 
of users, then -// it will log out all users from user2 and log them in as user1. -// This should leave us with all nodes connected to user1, while user2 -// still has nodes, but they are not connected. +// it will log out all nodes and log them in as user1 using a pre-auth key. +// This should create new nodes for user1 while preserving the original nodes for user2. +// Pre-auth key re-authentication with a different user creates new nodes, not transfers. func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { IntegrationSkip(t) @@ -269,7 +210,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, @@ -277,18 +218,25 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { hsic.WithTLS(), hsic.WithDERPAsIP(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) listNodes, err := headscale.ListNodes() assert.Len(t, allClients, len(listNodes)) @@ -303,12 +251,15 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) t.Logf("all clients logged out") userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) // Create a new authkey for user1, to be used for all clients key, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), true, false) @@ -326,28 +277,43 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error user1Nodes, err = headscale.ListNodes("user1") - assert.NoError(ct, err) - assert.Len(ct, user1Nodes, len(allClients), "User1 should have all clients after re-login") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes for user1 after relogin") + assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes)) + }, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after auth key relogin") - // Validate that all the old nodes are still present with user2 + // Collect expected node IDs for user1 after relogin + expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes)) + for _, node := range user1Nodes { + expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId())) + } + + // 
Validate connection state after relogin as user1 + requireAllClientsOnline(t, headscale, expectedUser1Nodes, true, "all user1 nodes should be connected after relogin", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin", 3*time.Minute) + + // Validate that user2 still has their original nodes after user1's re-authentication + // When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created + // for the new user. The original nodes remain with the original user. var user2Nodes []*v1.Node + t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error user2Nodes, err = headscale.ListNodes("user2") - assert.NoError(ct, err) - assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should have half the clients") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin") + assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes)) + }, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)") + t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat)) for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) - assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1", client.Hostname()) - }, 30*time.Second, 2*time.Second) + assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) + }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after auth key user switch", client.Hostname())) } } @@ -362,7 +328,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) opts := []hsic.Option{ @@ -376,13 +342,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) 
- assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -396,7 +362,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) listNodes, err := headscale.ListNodes() assert.Len(t, allClients, len(listNodes)) @@ -411,7 +384,10 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) t.Logf("all clients logged out") @@ -425,7 +401,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) for _, userName := range spec.Users { key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) @@ -443,7 +419,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { "expire", key.GetKey(), }) - assertNoErr(t, err) + require.NoError(t, err) err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) assert.ErrorContains(t, err, "authkey expired") diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index fcb1b4cb..fb05b1ba 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -5,17 +5,20 @@ import ( "net/netip" "net/url" "sort" + "strconv" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/oauth2-proxy/mockoidc" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestOIDCAuthenticationPingAll(t *testing.T) { @@ -34,7 +37,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -52,16 +55,16 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { hsic.WithTLS(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, 
allClients) @@ -73,10 +76,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) listUsers, err := headscale.ListUsers() - assertNoErr(t, err) + require.NoError(t, err) want := []*v1.User{ { @@ -142,7 +145,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -157,18 +160,18 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { hsic.WithTestName("oidcexpirenodes"), hsic.WithConfigEnv(oidcMap), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) // Record when sync completes to better estimate token expiry timing syncCompleteTime := time.Now() err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) loginDuration := time.Since(syncCompleteTime) t.Logf("Login and sync completed in %v", loginDuration) @@ -349,7 +352,7 @@ func TestOIDC024UserCreation(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -367,20 +370,20 @@ func TestOIDC024UserCreation(t *testing.T) { hsic.WithTLS(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) // Ensure that the nodes have logged in, this is what // triggers user creation via OIDC. err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) want := tt.want(scenario.mockOIDC.Issuer()) listUsers, err := headscale.ListUsers() - assertNoErr(t, err) + require.NoError(t, err) sort.Slice(listUsers, func(i, j int) bool { return listUsers[i].GetId() < listUsers[j].GetId() @@ -406,7 +409,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -424,17 +427,17 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { hsic.WithTLS(), hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) // Get all clients and verify they can connect allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -444,6 +447,11 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } +// TestOIDCReloginSameNodeNewUser tests the scenario where: +// 1. A Tailscale client logs in with user1 (creates node1 for user1) +// 2. 
The same client logs out and logs in with user2 (creates node2 for user2) +// 3. The same client logs out and logs in with user1 again (reuses node1, node2 remains) +// This validates that OIDC relogin properly handles node reuse and cleanup. func TestOIDCReloginSameNodeNewUser(t *testing.T) { IntegrationSkip(t) @@ -458,7 +466,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { oidcMockUser("user1", true), }, }) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -477,24 +485,25 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithDERPAsIP(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) - assertNoErr(t, err) + require.NoError(t, err) u, err := ts.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) _, err = doLoginURL(ts.Hostname(), u) - assertNoErr(t, err) + require.NoError(t, err) + t.Logf("Validating initial user creation at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { listUsers, err := headscale.ListUsers() - assertNoErr(t, err) - assert.Len(t, listUsers, 1) + assert.NoError(ct, err, "Failed to list users during initial validation") + assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) wantUsers := []*v1.User{ { Id: 1, @@ -510,44 +519,61 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }) if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - t.Fatalf("unexpected users: %s", diff) + ct.Errorf("User validation failed after first login - unexpected users: %s", diff) } - }, 30*time.Second, 1*time.Second, "validating users after first login") + }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") - listNodes, err := headscale.ListNodes() - assertNoErr(t, err) - assert.Len(t, listNodes, 1) + t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var listNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during initial validation") + assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes)) + }, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login") + + // Collect expected node IDs for validation after user1 initial login + expectedNodes := make([]types.NodeID, 0, 1) + status := ts.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + + // Validate initial connection state for user1 + validateInitialConnection(t, headscale, expectedNodes) // Log out user1 and log in user2, this should create a new node // for user2, the node should have the same machine key and // a new node key. err = ts.Logout() - assertNoErr(t, err) + require.NoError(t, err) // TODO(kradalby): Not sure why we need to logout twice, but it fails and // logs in immediately after the first logout and I cannot reproduce it // manually. 
err = ts.Logout() - assertNoErr(t, err) + require.NoError(t, err) // Wait for logout to complete and then do second logout + t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check that the first logout completed status, err := ts.Status() - assert.NoError(ct, err) - assert.Equal(ct, "NeedsLogin", status.BackendState) - }, 30*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to get client status during logout validation") + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before user2 login") u, err = ts.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) _, err = doLoginURL(ts.Hostname(), u) - assertNoErr(t, err) + require.NoError(t, err) + t.Logf("Validating user2 creation at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { listUsers, err := headscale.ListUsers() - assertNoErr(t, err) - assert.Len(t, listUsers, 2) + assert.NoError(ct, err, "Failed to list users after user2 login") + assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers)) wantUsers := []*v1.User{ { Id: 1, @@ -570,27 +596,83 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }) if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - ct.Errorf("unexpected users: %s", diff) + ct.Errorf("User validation failed after user2 login - expected both user1 and user2: %s", diff) } - }, 30*time.Second, 1*time.Second, "validating users after new user login") + }, 30*time.Second, 1*time.Second, "validating both user1 and user2 exist after second OIDC login") var listNodesAfterNewUserLogin []*v1.Node + // First, wait for the new node to be created + t.Logf("Waiting for user2 node creation at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { listNodesAfterNewUserLogin, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodesAfterNewUserLogin, 2) + assert.NoError(ct, err, "Failed to list nodes after user2 login") + // We might temporarily have more than 2 nodes during cleanup, so check for at least 2 + assert.GreaterOrEqual(ct, len(listNodesAfterNewUserLogin), 2, "Should have at least 2 nodes after user2 login, got %d (may include temporary nodes during cleanup)", len(listNodesAfterNewUserLogin)) + }, 30*time.Second, 1*time.Second, "waiting for user2 node creation (allowing temporary extra nodes during cleanup)") - // Machine key is the same as the "machine" has not changed, - // but Node key is not as it is a new node - assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) - assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) - assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey()) - }, 30*time.Second, 1*time.Second, "listing nodes after new user login") + // Then wait for cleanup to stabilize at exactly 2 nodes + t.Logf("Waiting for node cleanup stabilization at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodesAfterNewUserLogin, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list 
nodes during cleanup validation") + assert.Len(ct, listNodesAfterNewUserLogin, 2, "Should have exactly 2 nodes after cleanup (1 for user1, 1 for user2), got %d nodes", len(listNodesAfterNewUserLogin)) + + // Validate that both nodes have the same machine key but different node keys + if len(listNodesAfterNewUserLogin) >= 2 { + // Machine key is the same as the "machine" has not changed, + // but Node key is not as it is a new node + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Machine key should be preserved from original node") + assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "Both nodes should share the same machine key") + assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey(), "Node keys should be different between user1 and user2 nodes") + } + }, 90*time.Second, 2*time.Second, "waiting for node count stabilization at exactly 2 nodes after user2 login") + + // Security validation: Only user2's node should be active after user switch + var activeUser2NodeID types.NodeID + for _, node := range listNodesAfterNewUserLogin { + if node.GetUser().GetId() == 2 { // user2 + activeUser2NodeID = types.NodeID(node.GetId()) + t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break + } + } + + // Validate only user2's node is online (security requirement) + t.Logf("Validating only user2 node is online at %s", time.Now().Format(TimestampFormat)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + + // Check user2 node is online + if node, exists := nodeStore[activeUser2NodeID]; exists { + assert.NotNil(c, node.IsOnline, "User2 node should have online status") + if node.IsOnline != nil { + assert.True(c, *node.IsOnline, "User2 node should be online after login") + } + } else { + assert.Fail(c, "User2 node not found in nodestore") + } + }, 60*time.Second, 2*time.Second, "validating only user2 node is online after user switch") + + // Before logging out user2, validate we have exactly 2 nodes and both are stable + t.Logf("Pre-logout validation: checking node stability at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + currentNodes, err := headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes before user2 logout") + assert.Len(ct, currentNodes, 2, "Should have exactly 2 stable nodes before user2 logout, got %d", len(currentNodes)) + + // Validate node stability - ensure no phantom nodes + for i, node := range currentNodes { + assert.NotNil(ct, node.GetUser(), "Node %d should have a valid user before logout", i) + assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should have a valid machine key before logout", i) + t.Logf("Pre-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...") + } + }, 60*time.Second, 2*time.Second, "validating stable node count and integrity before user2 logout") // Log out user2, and log into user1, no new node should be created, // the node should now "become" node1 again err = ts.Logout() - assertNoErr(t, err) + require.NoError(t, err) t.Logf("Logged out take one") t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") @@ -599,41 +681,63 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // logs in immediately after the first logout and I 
cannot reproduce it // manually. err = ts.Logout() - assertNoErr(t, err) + require.NoError(t, err) t.Logf("Logged out take two") t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") // Wait for logout to complete and then do second logout + t.Logf("Waiting for user2 logout completion at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check that the first logout completed status, err := ts.Status() - assert.NoError(ct, err) - assert.Equal(ct, "NeedsLogin", status.BackendState) - }, 30*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to get client status during user2 logout validation") + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after user2 logout, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user2 logout to complete before user1 relogin") + + // Before logging back in, ensure we still have exactly 2 nodes + // Note: We skip validateLogoutComplete here since it expects all nodes to be offline, + // but in OIDC scenario we maintain both nodes in DB with only active user online + + // Additional validation that nodes are properly maintained during logout + t.Logf("Post-logout validation: checking node persistence at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + currentNodes, err := headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after user2 logout") + assert.Len(ct, currentNodes, 2, "Should still have exactly 2 nodes after user2 logout (nodes should persist), got %d", len(currentNodes)) + + // Ensure both nodes are still valid (not cleaned up incorrectly) + for i, node := range currentNodes { + assert.NotNil(ct, node.GetUser(), "Node %d should still have a valid user after user2 logout", i) + assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should still have a valid machine key after user2 logout", i) + t.Logf("Post-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...") + } + }, 60*time.Second, 2*time.Second, "validating node persistence and integrity after user2 logout") // We do not actually "change" the user here, it is done by logging in again // as the OIDC mock server is kind of like a stack, and the next user is // prepared and ready to go. 
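// Editor's illustration (assumed ordering, not shown in this hunk): if the scenario's
// OIDCUsers list is [user1, user2, user1], the first LoginWithURL consumed user1, the
// second consumed user2, and this third login picks up the final user1 entry, which is
// why no explicit switch back to user1 is needed at this point.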
u, err = ts.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) _, err = doLoginURL(ts.Hostname(), u) - assertNoErr(t, err) + require.NoError(t, err) + t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := ts.Status() - assert.NoError(ct, err) - assert.Equal(ct, "Running", status.BackendState) - }, 30*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to get client status during user1 relogin validation") + assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (final login)") t.Logf("Logged back in") t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") + t.Logf("Final validation: checking user persistence at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { listUsers, err := headscale.ListUsers() - assert.NoError(ct, err) - assert.Len(ct, listUsers, 2) + assert.NoError(ct, err, "Failed to list users during final validation") + assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers)) wantUsers := []*v1.User{ { Id: 1, @@ -656,37 +760,77 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }) if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { - ct.Errorf("unexpected users: %s", diff) + ct.Errorf("Final user validation failed - both users should persist after relogin cycle: %s", diff) } - }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") + }, 30*time.Second, 1*time.Second, "validating user persistence after complete relogin cycle (user1->user2->user1)") + var listNodesAfterLoggingBackIn []*v1.Node + // Wait for login to complete and nodes to stabilize + t.Logf("Final node validation: checking node stability after user1 relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { - listNodesAfterLoggingBackIn, err := headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodesAfterLoggingBackIn, 2) + listNodesAfterLoggingBackIn, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during final validation") + + // Allow for temporary instability during login process + if len(listNodesAfterLoggingBackIn) < 2 { + ct.Errorf("Not enough nodes yet during final validation, got %d, want at least 2", len(listNodesAfterLoggingBackIn)) + return + } + + // Final check should have exactly 2 nodes + assert.Len(ct, listNodesAfterLoggingBackIn, 2, "Should have exactly 2 nodes after complete relogin cycle, got %d", len(listNodesAfterLoggingBackIn)) // Validate that the machine we had when we logged in the first time, has the same // machine key, but a different ID than the newly logged in version of the same // machine. 
- assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey()) - assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey()) - assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId()) - assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey()) - assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId()) - assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId()) + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Original user1 machine key should match user1 node after user switch") + assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey(), "Original user1 node key should match user1 node after user switch") + assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId(), "Original user1 node ID should match user1 node after user switch") + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "User1 and user2 nodes should share the same machine key") + assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId(), "User1 and user2 nodes should have different node IDs") + assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId(), "User1 and user2 nodes should belong to different users") // Even tho we are logging in again with the same user, the previous key has been expired // and a new one has been generated. The node entry in the database should be the same // as the user + machinekey still matches. - assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey()) - assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey()) - assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId()) + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey(), "Machine key should remain consistent after user1 relogin") + assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey(), "Node key should be regenerated after user1 relogin") + assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId(), "Node ID should be preserved for user1 after relogin") // The "logged back in" machine should have the same machinekey but a different nodekey // than the version logged in with a different user. 
- assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey()) - assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey()) - }, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created") + assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey(), "Both final nodes should share the same machine key") + assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey(), "Final nodes should have different node keys for different users") + + t.Logf("Final validation complete - node counts and key relationships verified at %s", time.Now().Format(TimestampFormat)) + }, 60*time.Second, 2*time.Second, "validating final node state after complete user1->user2->user1 relogin cycle with detailed key validation") + + // Security validation: Only user1's node should be active after relogin + var activeUser1NodeID types.NodeID + for _, node := range listNodesAfterLoggingBackIn { + if node.GetUser().GetId() == 1 { // user1 + activeUser1NodeID = types.NodeID(node.GetId()) + t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break + } + } + + // Validate only user1's node is online (security requirement) + t.Logf("Validating only user1 node is online after relogin at %s", time.Now().Format(TimestampFormat)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + + // Check user1 node is online + if node, exists := nodeStore[activeUser1NodeID]; exists { + assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin") + if node.IsOnline != nil { + assert.True(c, *node.IsOnline, "User1 node should be online after relogin") + } + } else { + assert.Fail(c, "User1 node not found in nodestore after relogin") + } + }, 60*time.Second, 2*time.Second, "validating only user1 node is online after final relogin") } // TestOIDCFollowUpUrl validates the follow-up login flow @@ -709,7 +853,7 @@ func TestOIDCFollowUpUrl(t *testing.T) { }, ) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ @@ -730,43 +874,43 @@ func TestOIDCFollowUpUrl(t *testing.T) { hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), hsic.WithEmbeddedDERPServerOnly(), ) - assertNoErrHeadscaleEnv(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) listUsers, err := headscale.ListUsers() - assertNoErr(t, err) + require.NoError(t, err) assert.Empty(t, listUsers) ts, err := scenario.CreateTailscaleNode( "unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), ) - assertNoErr(t, err) + require.NoError(t, err) u, err := ts.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) // wait for the registration cache to expire // a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION time.Sleep(2 * time.Minute) st, err := ts.Status() - assertNoErr(t, err) + require.NoError(t, err) assert.Equal(t, "NeedsLogin", st.BackendState) // get new AuthURL from daemon newUrl, err := url.Parse(st.AuthURL) - assertNoErr(t, err) + require.NoError(t, err) assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change") _, 
err = doLoginURL(ts.Hostname(), newUrl) - assertNoErr(t, err) + require.NoError(t, err) listUsers, err = headscale.ListUsers() - assertNoErr(t, err) + require.NoError(t, err) assert.Len(t, listUsers, 1) wantUsers := []*v1.User{ @@ -795,30 +939,230 @@ func TestOIDCFollowUpUrl(t *testing.T) { } listNodes, err := headscale.ListNodes() - assertNoErr(t, err) + require.NoError(t, err) assert.Len(t, listNodes, 1) } -// assertTailscaleNodesLogout verifies that all provided Tailscale clients -// are in the logged-out state (NeedsLogin). -func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { - if h, ok := t.(interface{ Helper() }); ok { - h.Helper() +// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client +// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user. +// +// OIDC is an authentication layer built on top of OAuth 2.0 that allows users to authenticate +// using external identity providers (like Google, Microsoft, etc.) rather than managing +// credentials directly in headscale. +// +// This test validates the "same user relogin" behavior in headscale's OIDC authentication flow: +// - A single client authenticates via OIDC as user1 +// - The client logs out, ending the session +// - The same client logs back in via OIDC as the same user (user1) +// - The test verifies that the user account persists correctly +// - The test verifies that the machine key is preserved (since it's the same physical device) +// - The test verifies that the node ID is preserved (since it's the same user on the same device) +// - The test verifies that the node key is regenerated (since it's a new session) +// - The test verifies that the client comes back online properly +// +// This scenario is important for normal user workflows where someone might need to restart +// their Tailscale client, reboot their computer, or temporarily disconnect and reconnect. +// It ensures that headscale properly handles session management while preserving device +// identity and user associations. +// +// The test uses a single node scenario (unlike multi-node tests) to focus specifically on +// the authentication and session management aspects rather than network topology changes. +// The "same node" in the name refers to the same physical device/client, while "same user" +// refers to authenticating with the same OIDC identity. 
+func TestOIDCReloginSameNodeSameUser(t *testing.T) { + IntegrationSkip(t) + + // Create scenario with same user for both login attempts + scenario, err := NewScenario(ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), // Initial login + oidcMockUser("user1", true), // Relogin with same user + }, + }) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", } - for _, client := range clients { - status, err := client.Status() - assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) - assert.Equal(t, "NeedsLogin", status.BackendState, - "client %s should be logged out", client.Hostname()) - } -} + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcsameuser"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithDERPAsIP(), + ) + requireNoErrHeadscaleEnv(t, err) -func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { - return mockoidc.MockUser{ - Subject: username, - PreferredUsername: username, - Email: username + "@headscale.net", - EmailVerified: emailVerified, - } + headscale, err := scenario.Headscale() + require.NoError(t, err) + + ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) + require.NoError(t, err) + + // Initial login as user1 + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + t.Logf("Validating initial user1 creation at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assert.NoError(ct, err, "Failed to list users during initial validation") + assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("User validation failed after first login - unexpected users: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") + + t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var initialNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + initialNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during initial validation") + assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes)) + }, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login") + + // Collect expected node IDs for validation after user1 initial login + expectedNodes := make([]types.NodeID, 0, 1) + status := ts.MustStatus() + nodeID, err 
:= strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + + // Validate initial connection state for user1 + validateInitialConnection(t, headscale, expectedNodes) + + // Store initial node keys for comparison + initialMachineKey := initialNodes[0].GetMachineKey() + initialNodeKey := initialNodes[0].GetNodeKey() + initialNodeID := initialNodes[0].GetId() + + // Logout user1 + err = ts.Logout() + require.NoError(t, err) + + // TODO(kradalby): Not sure why we need to logout twice, but it fails and + // logs in immediately after the first logout and I cannot reproduce it + // manually. + err = ts.Logout() + require.NoError(t, err) + + // Wait for logout to complete + t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the logout completed + status, err := ts.Status() + assert.NoError(ct, err, "Failed to get client status during logout validation") + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before same-user relogin") + + // Validate node persistence during logout (node should remain in DB) + t.Logf("Validating node persistence during logout at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodes, err := headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during logout validation") + assert.Len(ct, listNodes, 1, "Should still have exactly 1 node during logout (node should persist in DB), got %d", len(listNodes)) + }, 30*time.Second, 1*time.Second, "validating node persistence in database during same-user logout") + + // Login again as the same user (user1) + u, err = ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := ts.Status() + assert.NoError(ct, err, "Failed to get client status during relogin validation") + assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (same user)") + + t.Logf("Final validation: checking user persistence after same-user relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assert.NoError(ct, err, "Failed to list users during final validation") + assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("Final user validation failed - user1 should persist after same-user relogin: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin 
cycle") + + var finalNodes []*v1.Node + t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + finalNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during final validation") + assert.Len(ct, finalNodes, 1, "Should have exactly 1 node after same-user relogin, got %d", len(finalNodes)) + + // Validate node key behavior for same user relogin + finalNode := finalNodes[0] + + // Machine key should be preserved (same physical machine) + assert.Equal(ct, initialMachineKey, finalNode.GetMachineKey(), "Machine key should be preserved for same user same node relogin") + + // Node ID should be preserved (same user, same machine) + assert.Equal(ct, initialNodeID, finalNode.GetId(), "Node ID should be preserved for same user same node relogin") + + // Node key should be regenerated (new session after logout) + assert.NotEqual(ct, initialNodeKey, finalNode.GetNodeKey(), "Node key should be regenerated after logout/relogin even for same user") + + t.Logf("Final validation complete - same user relogin key relationships verified at %s", time.Now().Format(TimestampFormat)) + }, 60*time.Second, 2*time.Second, "validating final node state after same-user OIDC relogin cycle with key preservation validation") + + // Security validation: user1's node should be active after relogin + activeUser1NodeID := types.NodeID(finalNodes[0].GetId()) + t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + + // Check user1 node is online + if node, exists := nodeStore[activeUser1NodeID]; exists { + assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin") + if node.IsOnline != nil { + assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin") + } + } else { + assert.Fail(c, "User1 node not found in nodestore after same-user relogin") + } + }, 60*time.Second, 2*time.Second, "validating user1 node is online after same-user OIDC relogin") } diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index ff190142..5dd546f3 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -1,15 +1,19 @@ package integration import ( + "fmt" "net/netip" "slices" "testing" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { @@ -33,16 +37,16 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { hsic.WithDERPAsIP(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -54,7 +58,7 @@ func 
TestAuthWebFlowAuthenticationPingAll(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } -func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { +func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { IntegrationSkip(t) spec := ScenarioSpec{ @@ -63,7 +67,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnvWithLoginURL( @@ -72,16 +76,16 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { hsic.WithDERPAsIP(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -93,15 +97,22 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error listNodes, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodes, len(allClients), "Node count should match client count after login") - }, 20*time.Second, 1*time.Second) + assert.NoError(ct, err, "Failed to list nodes after web authentication") + assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) + }, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication") nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -122,7 +133,10 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + validateLogoutComplete(t, headscale, expectedNodes) t.Logf("all clients logged out") @@ -135,8 +149,20 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients logged in again") + t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after web flow logout") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) + }, 60*time.Second, 2*time.Second, "validating node persistence in database after web flow logout") + t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) + + // Validate connection state after relogin + validateReloginComplete(t, headscale, expectedNodes) + allIps, err = 
scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) allAddrs = lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -145,14 +171,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { success = pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) - assert.EventuallyWithT(t, func(ct *assert.CollectT) { - var err error - listNodes, err = headscale.ListNodes() - assert.NoError(ct, err) - assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count after re-login") - }, 20*time.Second, 1*time.Second) - t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) - for _, client := range allClients { ips, err := client.IPs() if err != nil { @@ -180,3 +198,166 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients IPs are the same") } + +// TestAuthWebFlowLogoutAndReloginNewUser tests the scenario where multiple Tailscale clients +// initially authenticate using the web-based authentication flow (where users visit a URL +// in their browser to authenticate), then all clients log out and log back in as a different user. +// +// This test validates the "user switching" behavior in headscale's web authentication flow: +// - Multiple clients authenticate via web flow, each to their respective users (user1, user2) +// - All clients log out simultaneously +// - All clients log back in via web flow, but this time they all authenticate as user1 +// - The test verifies that user1 ends up with all the client nodes +// - The test verifies that user2's original nodes still exist in the database but are offline +// - The test verifies network connectivity works after the user switch +// +// This scenario is important for organizations that need to reassign devices between users +// or when consolidating multiple user accounts. It ensures that headscale properly handles +// the security implications of user switching while maintaining node persistence in the database. +// +// The test uses headscale's web authentication flow, which is the most user-friendly method +// where authentication happens through a web browser rather than pre-shared keys or OIDC. 
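For orientation, a minimal sketch (not part of the patch) of the cross-user relogin loop the test below performs: every client re-authenticates via the web flow and is then registered to user1. `Scenario`, `ControlServer`, `TailscaleClient`, and `doLoginURL` are the integration-package names used in this diff; the helper itself is an assumption:

```go
package integration

import "testing"

// reloginAllAsUser1 mirrors the loop in TestAuthWebFlowLogoutAndReloginNewUser:
// each client gets a fresh login URL, completes the web flow, and is then
// registered to user1 regardless of which user it originally belonged to.
func reloginAllAsUser1(t *testing.T, scenario *Scenario, headscale ControlServer, clients []TailscaleClient) {
	t.Helper()

	for _, client := range clients {
		loginURL, err := client.LoginWithURL(headscale.GetEndpoint())
		if err != nil {
			t.Fatalf("failed to get login URL for %s: %s", client.Hostname(), err)
		}

		body, err := doLoginURL(client.Hostname(), loginURL)
		if err != nil {
			t.Fatalf("failed to complete login for %s: %s", client.Hostname(), err)
		}

		// Equivalent to: headscale nodes register --user user1 --key <key>
		scenario.runHeadscaleRegister("user1", body)
	}
}
```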
+func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("webflowrelnewuser"), + hsic.WithDERPAsIP(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + validateInitialConnection(t, headscale, expectedNodes) + + var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after initial web authentication") + assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) + }, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication") + nodeCountBeforeLogout := len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) + + // Log out all clients + for _, client := range allClients { + err := client.Logout() + if err != nil { + t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) + } + } + + err = scenario.WaitForTailscaleLogout() + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + validateLogoutComplete(t, headscale, expectedNodes) + + t.Logf("all clients logged out") + + // Log all clients back in as user1 using web flow + // We manually iterate over all clients and authenticate each one as user1 + // This tests the cross-user re-authentication behavior where ALL clients + // (including those originally from user2) are registered to user1 + for _, client := range allClients { + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + if err != nil { + t.Fatalf("failed to get login URL for client %s: %s", client.Hostname(), err) + } + + body, err := doLoginURL(client.Hostname(), loginURL) + if err != nil { + t.Fatalf("failed to complete login for client %s: %s", client.Hostname(), err) + } + + // Register all clients as user1 (this is where cross-user registration happens) + // This simulates: headscale nodes register --user user1 --key + scenario.runHeadscaleRegister("user1", body) + } + + // Wait for all clients to reach running state + for _, client := range allClients { + err := client.WaitForRunning(integrationutil.PeerSyncTimeout()) + if err != nil { + t.Fatalf("%s tailscale node has not reached running: %s", client.Hostname(), err) + } + } + + t.Logf("all clients logged back in as user1") + + var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + user1Nodes, err = headscale.ListNodes("user1") + assert.NoError(ct, err, "Failed to list nodes for user1 after web flow 
relogin") + assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes)) + }, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after web flow user switch relogin") + + // Collect expected node IDs for user1 after relogin + expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes)) + for _, node := range user1Nodes { + expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId())) + } + + // Validate connection state after relogin as user1 + validateReloginComplete(t, headscale, expectedUser1Nodes) + + // Validate that user2's old nodes still exist in database (but are expired/offline) + // When CLI registration creates new nodes for user1, user2's old nodes remain + var user2Nodes []*v1.Node + t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + user2Nodes, err = headscale.ListNodes("user2") + assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1") + assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes)) + }, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1") + + t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) + assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after web flow user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) + }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after web flow user switch", client.Hostname())) + } + + // Test connectivity after user switch + allIps, err = scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d after web flow user switch", success, len(allClients)*len(allIps)) +} diff --git a/integration/cli_test.go b/integration/cli_test.go index 98e2ddf3..40afd2c3 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -54,14 +54,14 @@ func TestUserCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) var listUsers []*v1.User var result []string @@ -99,7 +99,7 @@ func TestUserCommand(t *testing.T) { "--new-name=newname", }, ) - assertNoErr(t, err) + require.NoError(t, err) var listAfterRenameUsers []*v1.User assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -138,7 +138,7 @@ func TestUserCommand(t *testing.T) { }, &listByUsername, ) - assertNoErr(t, err) + require.NoError(t, err) slices.SortFunc(listByUsername, sortWithID) want := []*v1.User{ 
@@ -165,7 +165,7 @@ func TestUserCommand(t *testing.T) { }, &listByID, ) - assertNoErr(t, err) + require.NoError(t, err) slices.SortFunc(listByID, sortWithID) want = []*v1.User{ @@ -244,7 +244,7 @@ func TestUserCommand(t *testing.T) { }, &listAfterNameDelete, ) - assertNoErr(t, err) + require.NoError(t, err) require.Empty(t, listAfterNameDelete) } @@ -260,17 +260,17 @@ func TestPreAuthKeyCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) keys := make([]*v1.PreAuthKey, count) - assertNoErr(t, err) + require.NoError(t, err) for index := range count { var preAuthKey v1.PreAuthKey @@ -292,7 +292,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &preAuthKey, ) - assertNoErr(t, err) + require.NoError(t, err) keys[index] = &preAuthKey } @@ -313,7 +313,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &listedPreAuthKeys, ) - assertNoErr(t, err) + require.NoError(t, err) // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 4) @@ -372,7 +372,7 @@ func TestPreAuthKeyCommand(t *testing.T) { listedPreAuthKeys[1].GetKey(), }, ) - assertNoErr(t, err) + require.NoError(t, err) var listedPreAuthKeysAfterExpire []v1.PreAuthKey err = executeAndUnmarshal( @@ -388,7 +388,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &listedPreAuthKeysAfterExpire, ) - assertNoErr(t, err) + require.NoError(t, err) assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now())) @@ -404,14 +404,14 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) var preAuthKey v1.PreAuthKey err = executeAndUnmarshal( @@ -428,7 +428,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { }, &preAuthKey, ) - assertNoErr(t, err) + require.NoError(t, err) var listedPreAuthKeys []v1.PreAuthKey err = executeAndUnmarshal( @@ -444,7 +444,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { }, &listedPreAuthKeys, ) - assertNoErr(t, err) + require.NoError(t, err) // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 2) @@ -465,14 +465,14 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) var preAuthReusableKey v1.PreAuthKey err = executeAndUnmarshal( @@ -489,7 +489,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { }, &preAuthReusableKey, ) - assertNoErr(t, err) + require.NoError(t, err) var preAuthEphemeralKey v1.PreAuthKey err = executeAndUnmarshal( @@ -506,7 +506,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t 
*testing.T) { }, &preAuthEphemeralKey, ) - assertNoErr(t, err) + require.NoError(t, err) assert.True(t, preAuthEphemeralKey.GetEphemeral()) assert.False(t, preAuthEphemeralKey.GetReusable()) @@ -525,7 +525,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { }, &listedPreAuthKeys, ) - assertNoErr(t, err) + require.NoError(t, err) // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 3) @@ -543,7 +543,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -552,13 +552,13 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) u2, err := headscale.CreateUser(user2) - assertNoErr(t, err) + require.NoError(t, err) var user2Key v1.PreAuthKey @@ -580,7 +580,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { }, &user2Key, ) - assertNoErr(t, err) + require.NoError(t, err) var listNodes []*v1.Node assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -592,7 +592,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { }, 15*time.Second, 1*time.Second) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) require.Len(t, allClients, 1) @@ -600,10 +600,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { // Log out from user1 err = client.Logout() - assertNoErr(t, err) + require.NoError(t, err) err = scenario.WaitForTailscaleLogout() - assertNoErr(t, err) + require.NoError(t, err) assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() @@ -613,7 +613,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { }, 30*time.Second, 2*time.Second) err = client.Login(headscale.GetEndpoint(), user2Key.GetKey()) - assertNoErr(t, err) + require.NoError(t, err) assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() @@ -642,14 +642,14 @@ func TestApiKeyCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) keys := make([]string, count) @@ -808,14 +808,14 @@ func TestNodeTagCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) regIDs := []string{ types.MustRegistrationID().String(), @@ -1007,7 +1007,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -1015,10 +1015,10 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy(tt.policy), ) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := 
scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Test list all nodes after added seconds resultMachines := make([]*v1.Node, spec.NodesPerUser) @@ -1058,14 +1058,14 @@ func TestNodeCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) regIDs := []string{ types.MustRegistrationID().String(), @@ -1302,14 +1302,14 @@ func TestNodeExpireCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) regIDs := []string{ types.MustRegistrationID().String(), @@ -1427,14 +1427,14 @@ func TestNodeRenameCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) regIDs := []string{ types.MustRegistrationID().String(), @@ -1462,7 +1462,7 @@ func TestNodeRenameCommand(t *testing.T) { "json", }, ) - assertNoErr(t, err) + require.NoError(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1480,7 +1480,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &node, ) - assertNoErr(t, err) + require.NoError(t, err) nodes[index] = &node } @@ -1591,20 +1591,20 @@ func TestNodeMoveCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Randomly generated node key regID := types.MustRegistrationID() userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) _, err = headscale.Execute( []string{ @@ -1753,7 +1753,7 @@ func TestPolicyCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -1763,10 +1763,10 @@ func TestPolicyCommand(t *testing.T) { "HEADSCALE_POLICY_MODE": "database", }), ) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) p := policyv2.Policy{ ACLs: []policyv2.ACL{ @@ -1789,7 +1789,7 @@ func TestPolicyCommand(t *testing.T) { policyFilePath := "/etc/headscale/policy.json" err = headscale.WriteFile(policyFilePath, pBytes) - assertNoErr(t, err) + require.NoError(t, err) // No policy is present at this time. // Add a new policy from a file. @@ -1803,7 +1803,7 @@ func TestPolicyCommand(t *testing.T) { }, ) - assertNoErr(t, err) + require.NoError(t, err) // Get the current policy and check // if it is the same as the one we set. 
@@ -1819,7 +1819,7 @@ func TestPolicyCommand(t *testing.T) { }, &output, ) - assertNoErr(t, err) + require.NoError(t, err) assert.Len(t, output.TagOwners, 1) assert.Len(t, output.ACLs, 1) @@ -1834,7 +1834,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -1844,10 +1844,10 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { "HEADSCALE_POLICY_MODE": "database", }), ) - assertNoErr(t, err) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) p := policyv2.Policy{ ACLs: []policyv2.ACL{ @@ -1872,7 +1872,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath := "/etc/headscale/policy.json" err = headscale.WriteFile(policyFilePath, pBytes) - assertNoErr(t, err) + require.NoError(t, err) // No policy is present at this time. // Add a new policy from a file. diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index 4a5e52ae..60260bb1 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/require" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netmon" @@ -23,7 +24,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { // Generate random hostname for the headscale instance hash, err := util.GenerateRandomStringDNSSafe(6) - assertNoErr(t, err) + require.NoError(t, err) testName := "derpverify" hostname := fmt.Sprintf("hs-%s-%s", testName, hash) @@ -31,7 +32,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { // Create cert for headscale certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname) - assertNoErr(t, err) + require.NoError(t, err) spec := ScenarioSpec{ NodesPerUser: len(MustTestVersions), @@ -39,14 +40,14 @@ func TestDERPVerifyEndpoint(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) derper, err := scenario.CreateDERPServer("head", dsic.WithCACert(certHeadscale), dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))), ) - assertNoErr(t, err) + require.NoError(t, err) derpRegion := tailcfg.DERPRegion{ RegionCode: "test-derpverify", @@ -74,17 +75,17 @@ func TestDERPVerifyEndpoint(t *testing.T) { hsic.WithPort(headscalePort), hsic.WithCustomTLS(certHeadscale, keyHeadscale), hsic.WithDERPConfig(derpMap)) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) fakeKey := key.NewNode() DERPVerify(t, fakeKey, derpRegion, false) for _, client := range allClients { nodeKey, err := client.GetNodePrivateKey() - assertNoErr(t, err) + require.NoError(t, err) DERPVerify(t, *nodeKey, derpRegion, true) } } diff --git a/integration/dns_test.go b/integration/dns_test.go index 7cac4d47..7267bc09 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" "tailscale.com/tailcfg" ) @@ -22,26 +23,26 @@ func TestResolveMagicDNS(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) // Poor mans cache _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) _, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) for _, client := range allClients { for _, peer := range allClients { @@ -78,7 +79,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) const erPath = "/tmp/extra_records.json" @@ -109,29 +110,29 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) // Poor mans cache _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) _, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6") } hs, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Write the file directly into place from the docker API. 
b0, _ := json.Marshal([]tailcfg.DNSRecord{ @@ -143,7 +144,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { }) err = hs.WriteFile(erPath, b0) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "2.2.2.2") @@ -159,9 +160,9 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { b2, _ := json.Marshal(extraRecords) err = hs.WriteFile(erPath+"2", b2) - assertNoErr(t, err) + require.NoError(t, err) _, err = hs.Execute([]string{"mv", erPath + "2", erPath}) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6") @@ -179,9 +180,9 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { }) err = hs.WriteFile(erPath+"3", b3) - assertNoErr(t, err) + require.NoError(t, err) _, err = hs.Execute([]string{"cp", erPath + "3", erPath}) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8") @@ -197,7 +198,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { }) command := []string{"echo", fmt.Sprintf("'%s'", string(b4)), ">", erPath} _, err = hs.Execute([]string{"bash", "-c", strings.Join(command, " ")}) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9") @@ -205,7 +206,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { // Delete the file and create a new one to ensure it is picked up again. _, err = hs.Execute([]string{"rm", erPath}) - assertNoErr(t, err) + require.NoError(t, err) // The same paths should still be available as it is not cleared on delete. assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -219,7 +220,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { // Write a new file, the backoff mechanism should make the filewatcher pick it up // again. err = hs.WriteFile(erPath, b3) - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8") diff --git a/integration/dockertestutil/build.go b/integration/dockertestutil/build.go new file mode 100644 index 00000000..635f91ef --- /dev/null +++ b/integration/dockertestutil/build.go @@ -0,0 +1,17 @@ +package dockertestutil + +import ( + "os/exec" +) + +// RunDockerBuildForDiagnostics runs docker build manually to get detailed error output. +// This is used when a docker build fails to provide more detailed diagnostic information +// than what dockertest typically provides. 
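A hypothetical usage sketch for the helper defined below; `reportBuildFailure`, its parameters, and the call site are assumptions for illustration, not part of this change:

```go
package integration

import (
	"testing"

	"github.com/juanfont/headscale/integration/dockertestutil"
)

// reportBuildFailure shows one way the diagnostics helper could be wired in:
// when an image build fails, re-run the build manually to surface the full
// builder output, log it, and then fail the test with the original error.
func reportBuildFailure(t *testing.T, contextDir, dockerfile string, buildErr error) {
	t.Helper()

	if buildErr == nil {
		return
	}

	if out := dockertestutil.RunDockerBuildForDiagnostics(contextDir, dockerfile); out != "" {
		t.Logf("manual docker build output for %s:\n%s", dockerfile, out)
	}

	t.Fatalf("building image failed: %s", buildErr)
}
```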
+func RunDockerBuildForDiagnostics(contextDir, dockerfile string) string { + cmd := exec.Command("docker", "build", "-f", dockerfile, contextDir) + output, err := cmd.CombinedOutput() + if err != nil { + return string(output) + } + return "" +} diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index e9ba69dd..17cb01af 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -29,7 +30,7 @@ func TestDERPServerScenario(t *testing.T) { derpServerScenario(t, spec, false, func(scenario *Scenario) { allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) t.Logf("checking %d clients for websocket connections", len(allClients)) for _, client := range allClients { @@ -43,7 +44,7 @@ func TestDERPServerScenario(t *testing.T) { } hsServer, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) derpRegion := tailcfg.DERPRegion{ RegionCode: "test-derpverify", @@ -79,7 +80,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) { derpServerScenario(t, spec, true, func(scenario *Scenario) { allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) t.Logf("checking %d clients for websocket connections", len(allClients)) for _, client := range allClients { @@ -108,7 +109,7 @@ func derpServerScenario( IntegrationSkip(t) scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -128,16 +129,16 @@ func derpServerScenario( "HEADSCALE_DERP_SERVER_VERIFY_CLIENTS": "true", }), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allHostnames, err := scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { diff --git a/integration/general_test.go b/integration/general_test.go index 65131af0..ab6d4f71 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -10,19 +10,15 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" - "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" "github.com/rs/zerolog/log" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale/apitype" "tailscale.com/types/key" @@ -38,7 +34,7 @@ func TestPingAllByIP(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -48,16 +44,16 @@ func TestPingAllByIP(t *testing.T) { hsic.WithTLS(), 
hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) hs, err := scenario.Headscale() require.NoError(t, err) @@ -80,7 +76,7 @@ func TestPingAllByIP(t *testing.T) { // Get headscale instance for batcher debug check headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Test our DebugBatcher functionality t.Logf("Testing DebugBatcher functionality...") @@ -99,23 +95,23 @@ func TestPingAllByIPPublicDERP(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( []tsic.Option{}, hsic.WithTestName("pingallbyippubderp"), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -148,11 +144,11 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) headscale, err := scenario.Headscale(opts...) 
- assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) for _, userName := range spec.Users { user, err := scenario.CreateUser(userName) @@ -177,13 +173,13 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -200,7 +196,7 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) t.Logf("all clients logged out") @@ -222,7 +218,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) headscale, err := scenario.Headscale( @@ -231,7 +227,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s", }), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) for _, userName := range spec.Users { user, err := scenario.CreateUser(userName) @@ -256,13 +252,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -344,22 +340,22 @@ func TestPingAllByHostname(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) allHostnames, err := scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) success := pingAllHelper(t, allClients, allHostnames) @@ -379,7 +375,7 @@ func TestTaildrop(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, @@ -387,17 +383,17 @@ func TestTaildrop(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // This will essentially fetch and cache all the FQDNs _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { if !strings.Contains(client.Hostname(), "head") { @@ -498,7 +494,7 @@ func 
TestTaildrop(t *testing.T) { ) result, _, err := client.Execute(command) - assertNoErrf(t, "failed to execute command to ls taildrop: %s", err) + require.NoErrorf(t, err, "failed to execute command to ls taildrop") log.Printf("Result for %s: %s\n", peer.Hostname(), result) if fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()) != result { @@ -528,25 +524,25 @@ func TestUpdateHostnameFromClient(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErrf(t, "failed to create scenario: %s", err) + require.NoErrorf(t, err, "failed to create scenario") defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) // update hostnames using the up command for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) command := []string{ "tailscale", @@ -554,11 +550,11 @@ func TestUpdateHostnameFromClient(t *testing.T) { "--hostname=" + hostnames[string(status.Self.ID)], } _, _, err = client.Execute(command) - assertNoErrf(t, "failed to set hostname: %s", err) + require.NoErrorf(t, err, "failed to set hostname") } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // Wait for nodestore batch processing to complete // NodeStore batching timeout is 500ms, so we wait up to 1 second @@ -597,7 +593,7 @@ func TestUpdateHostnameFromClient(t *testing.T) { "--identifier", strconv.FormatUint(node.GetId(), 10), }) - assertNoErr(t, err) + require.NoError(t, err) } // Verify that the server-side rename is reflected in DNSName while HostName remains unchanged @@ -643,7 +639,7 @@ func TestUpdateHostnameFromClient(t *testing.T) { for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) command := []string{ "tailscale", @@ -651,11 +647,11 @@ func TestUpdateHostnameFromClient(t *testing.T) { "--hostname=" + hostnames[string(status.Self.ID)] + "NEW", } _, _, err = client.Execute(command) - assertNoErrf(t, "failed to set hostname: %s", err) + require.NoErrorf(t, err, "failed to set hostname") } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // Wait for nodestore batch processing to complete // NodeStore batching timeout is 500ms, so we wait up to 1 second @@ -696,20 +692,20 @@ func TestExpireNode(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -731,22 +727,22 @@ func TestExpireNode(t *testing.T) { } headscale, err := scenario.Headscale() - assertNoErr(t, err) + 
require.NoError(t, err) // TODO(kradalby): This is Headscale specific and would not play nicely // with other implementations of the ControlServer interface result, err := headscale.Execute([]string{ "headscale", "nodes", "expire", "--identifier", "1", "--output", "json", }) - assertNoErr(t, err) + require.NoError(t, err) var node v1.Node err = json.Unmarshal([]byte(result), &node) - assertNoErr(t, err) + require.NoError(t, err) var expiredNodeKey key.NodePublic err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey())) - assertNoErr(t, err) + require.NoError(t, err) t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String()) @@ -773,14 +769,14 @@ func TestExpireNode(t *testing.T) { // Verify that the expired node has been marked in all peers list. for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) if client.Hostname() != node.GetName() { t.Logf("available peers of %s: %v", client.Hostname(), status.Peers()) // Ensures that the node is present, and that it is expired. if peerStatus, ok := status.Peer[expiredNodeKey]; ok { - assertNotNil(t, peerStatus.Expired) + requireNotNil(t, peerStatus.Expired) assert.NotNil(t, peerStatus.KeyExpiry) t.Logf( @@ -840,20 +836,20 @@ func TestNodeOnlineStatus(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -866,14 +862,14 @@ func TestNodeOnlineStatus(t *testing.T) { for _, client := range allClients { status, err := client.Status() - assertNoErr(t, err) + require.NoError(t, err) // Assert that we have the original count - self assert.Len(t, status.Peers(), len(MustTestVersions)-1) } headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Duration is chosen arbitrarily, 10m is reported in #1561 testDuration := 12 * time.Minute @@ -963,7 +959,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -973,16 +969,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) { hsic.WithDERPAsIP(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // assertClientsState(t, allClients) @@ -992,7 +988,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) { // Get headscale instance for batcher debug checks headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Initial check: all nodes should be connected to batcher // Extract node IDs for validation @@ -1000,7 +996,7 @@ func 
TestPingAllByIPManyUpDown(t *testing.T) { for _, client := range allClients { status := client.MustStatus() nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) - assertNoErr(t, err) + require.NoError(t, err) expectedNodes = append(expectedNodes, types.NodeID(nodeID)) } requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second) @@ -1072,7 +1068,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( @@ -1081,16 +1077,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -1100,7 +1096,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // Test list all nodes after added otherUser var nodeList []v1.Node @@ -1170,159 +1166,3 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { assert.True(t, nodeListAfter[0].GetOnline()) assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId()) } - -// NodeSystemStatus represents the online status of a node across different systems -type NodeSystemStatus struct { - Batcher bool - BatcherConnCount int - MapResponses bool - NodeStore bool -} - -// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore -func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { - t.Helper() - - startTime := time.Now() - t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format(TimestampFormat), message) - - var prevReport string - require.EventuallyWithT(t, func(c *assert.CollectT) { - // Get batcher state - debugInfo, err := headscale.DebugBatcher() - assert.NoError(c, err, "Failed to get batcher debug info") - if err != nil { - return - } - - // Get map responses - mapResponses, err := headscale.GetAllMapReponses() - assert.NoError(c, err, "Failed to get map responses") - if err != nil { - return - } - - // Get nodestore state - nodeStore, err := headscale.DebugNodeStore() - assert.NoError(c, err, "Failed to get nodestore debug info") - if err != nil { - return - } - - // Validate node counts first - expectedCount := len(expectedNodes) - assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch") - assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch") - - // Check that we have map responses for expected nodes - mapResponseCount := len(mapResponses) - assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch") - - // Build status map for each node - nodeStatus := make(map[types.NodeID]NodeSystemStatus) - - // Initialize all expected nodes - for _, nodeID := range expectedNodes { - 
nodeStatus[nodeID] = NodeSystemStatus{} - } - - // Check batcher state - for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes { - nodeID := types.MustParseNodeID(nodeIDStr) - if status, exists := nodeStatus[nodeID]; exists { - status.Batcher = nodeInfo.Connected - status.BatcherConnCount = nodeInfo.ActiveConnections - nodeStatus[nodeID] = status - } - } - - // Check map responses using buildExpectedOnlineMap - onlineFromMaps := make(map[types.NodeID]bool) - onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) - for nodeID := range nodeStatus { - NODE_STATUS: - for id, peerMap := range onlineMap { - if id == nodeID { - continue - } - - online := peerMap[nodeID] - // If the node is offline in any map response, we consider it offline - if !online { - onlineFromMaps[nodeID] = false - continue NODE_STATUS - } - - onlineFromMaps[nodeID] = true - } - } - assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") - - // Update status with map response data - for nodeID, online := range onlineFromMaps { - if status, exists := nodeStatus[nodeID]; exists { - status.MapResponses = online - nodeStatus[nodeID] = status - } - } - - // Check nodestore state - for nodeID, node := range nodeStore { - if status, exists := nodeStatus[nodeID]; exists { - // Check if node is online in nodestore - status.NodeStore = node.IsOnline != nil && *node.IsOnline - nodeStatus[nodeID] = status - } - } - - // Verify all systems show nodes in expected state and report failures - allMatch := true - var failureReport strings.Builder - - ids := types.NodeIDs(maps.Keys(nodeStatus)) - slices.Sort(ids) - for _, nodeID := range ids { - status := nodeStatus[nodeID] - systemsMatch := (status.Batcher == expectedOnline) && - (status.MapResponses == expectedOnline) && - (status.NodeStore == expectedOnline) - - if !systemsMatch { - allMatch = false - stateStr := "offline" - if expectedOnline { - stateStr = "online" - } - failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr)) - failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher)) - failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) - failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses)) - failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore)) - } - } - - if !allMatch { - if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" { - t.Log("Diff between reports:") - t.Logf("Prev report: \n%s\n", prevReport) - t.Logf("New report: \n%s\n", failureReport.String()) - t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") - prevReport = failureReport.String() - } - - failureReport.WriteString("timestamp: " + time.Now().Format(TimestampFormat) + "\n") - - assert.Fail(c, failureReport.String()) - } - - stateStr := "offline" - if expectedOnline { - stateStr = "online" - } - assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr)) - }, timeout, 2*time.Second, message) - - endTime := time.Now() - duration := endTime.Sub(startTime) - t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message) -} diff --git a/integration/helpers.go b/integration/helpers.go new file mode 100644 index 00000000..8e81fa9b --- /dev/null +++ b/integration/helpers.go @@ -0,0 +1,922 @@ +package integration + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/netip" + "strconv" 
+ "strings" + "sync" + "testing" + "time" + + "github.com/cenkalti/backoff/v5" + "github.com/google/go-cmp/cmp" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/juanfont/headscale/integration/integrationutil" + "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +const ( + // derpPingTimeout defines the timeout for individual DERP ping operations + // Used in DERP connectivity tests to verify relay server communication. + derpPingTimeout = 2 * time.Second + + // derpPingCount defines the number of ping attempts for DERP connectivity tests + // Higher count provides better reliability assessment of DERP connectivity. + derpPingCount = 10 + + // TimestampFormat is the standard timestamp format used across all integration tests + // Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps + // suitable for debugging and log correlation in integration tests. + TimestampFormat = "2006-01-02T15-04-05.999999999" + + // TimestampFormatRunID is used for generating unique run identifiers + // Format: "20060102-150405" provides compact date-time for file/directory names. + TimestampFormatRunID = "20060102-150405" +) + +// NodeSystemStatus represents the status of a node across different systems +type NodeSystemStatus struct { + Batcher bool + BatcherConnCount int + MapResponses bool + NodeStore bool +} + +// requireNotNil validates that an object is not nil and fails the test if it is. +// This helper provides consistent error messaging for nil checks in integration tests. +func requireNotNil(t *testing.T, object interface{}) { + t.Helper() + require.NotNil(t, object) +} + +// requireNoErrHeadscaleEnv validates that headscale environment creation succeeded. +// Provides specific error context for headscale environment setup failures. +func requireNoErrHeadscaleEnv(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to create headscale environment") +} + +// requireNoErrGetHeadscale validates that headscale server retrieval succeeded. +// Provides specific error context for headscale server access failures. +func requireNoErrGetHeadscale(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to get headscale") +} + +// requireNoErrListClients validates that client listing operations succeeded. +// Provides specific error context for client enumeration failures. +func requireNoErrListClients(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to list clients") +} + +// requireNoErrListClientIPs validates that client IP retrieval succeeded. +// Provides specific error context for client IP address enumeration failures. +func requireNoErrListClientIPs(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to get client IPs") +} + +// requireNoErrSync validates that client synchronization operations succeeded. +// Provides specific error context for client sync failures across the network. +func requireNoErrSync(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to have all clients sync up") +} + +// requireNoErrListFQDN validates that FQDN listing operations succeeded. 
+// Provides specific error context for DNS name enumeration failures. +func requireNoErrListFQDN(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to list FQDNs") +} + +// requireNoErrLogout validates that tailscale node logout operations succeeded. +// Provides specific error context for client logout failures. +func requireNoErrLogout(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to log out tailscale nodes") +} + +// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes +func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID { + t.Helper() + + expectedNodes := make([]types.NodeID, 0, len(clients)) + for _, client := range clients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + } + return expectedNodes +} + +// validateInitialConnection performs comprehensive validation after initial client login. +// Validates that all nodes are online and have proper NetInfo/DERP configuration, +// essential for ensuring successful initial connection state in relogin tests. +func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) +} + +// validateLogoutComplete performs comprehensive validation after client logout. +// Ensures all nodes are properly offline across all headscale systems, +// critical for validating clean logout state in relogin tests. +func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) +} + +// validateReloginComplete performs comprehensive validation after client relogin. +// Validates that all nodes are back online with proper NetInfo/DERP configuration, +// ensuring successful relogin state restoration in integration tests. 
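+//
+// A minimal usage sketch (assuming allClients and headscale come from the usual
+// scenario setup); the three validate helpers are meant to bracket the
+// login -> logout -> login cycle:
+//
+//	expectedNodes := collectExpectedNodeIDs(t, allClients)
+//	validateInitialConnection(t, headscale, expectedNodes)
+//	// ... log all clients out ...
+//	validateLogoutComplete(t, headscale, expectedNodes)
+//	// ... log all clients back in ...
+//	validateReloginComplete(t, headscale, expectedNodes)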
+func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin", 3*time.Minute) +} + +// requireAllClientsOnline validates that all nodes are online/offline across all headscale systems +// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems +func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { + t.Helper() + + startTime := time.Now() + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message) + + if expectedOnline { + // For online validation, use the existing logic with full timeout + requireAllClientsOnlineWithSingleTimeout(t, headscale, expectedNodes, expectedOnline, message, timeout) + } else { + // For offline validation, use staged approach with component-specific timeouts + requireAllClientsOfflineStaged(t, headscale, expectedNodes, message, timeout) + } + + endTime := time.Now() + t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message) +} + +// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state +func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { + t.Helper() + + var prevReport string + require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get batcher state + debugInfo, err := headscale.DebugBatcher() + assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { + return + } + + // Get map responses + mapResponses, err := headscale.GetAllMapReponses() + assert.NoError(c, err, "Failed to get map responses") + if err != nil { + return + } + + // Get nodestore state + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + // Validate that all expected nodes are present in nodeStore + for _, nodeID := range expectedNodes { + _, exists := nodeStore[nodeID] + assert.True(c, exists, "Expected node %d not found in nodeStore", nodeID) + } + + // Check that we have map responses for expected nodes + mapResponseCount := len(mapResponses) + expectedCount := len(expectedNodes) + assert.GreaterOrEqual(c, mapResponseCount, expectedCount, "MapResponses insufficient - expected at least %d responses, got %d", expectedCount, mapResponseCount) + + // Build status map for each node + nodeStatus := make(map[types.NodeID]NodeSystemStatus) + + // Initialize all expected nodes + for _, nodeID := range expectedNodes { + nodeStatus[nodeID] = NodeSystemStatus{} + } + + // Check batcher state for expected nodes + for _, nodeID := range expectedNodes { + nodeIDStr := fmt.Sprintf("%d", nodeID) + if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists { + if status, exists := nodeStatus[nodeID]; exists { + status.Batcher = nodeInfo.Connected + status.BatcherConnCount = 
nodeInfo.ActiveConnections + nodeStatus[nodeID] = status + } + } else { + // Node not found in batcher, mark as disconnected + if status, exists := nodeStatus[nodeID]; exists { + status.Batcher = false + status.BatcherConnCount = 0 + nodeStatus[nodeID] = status + } + } + } + + // Check map responses using buildExpectedOnlineMap + onlineFromMaps := make(map[types.NodeID]bool) + onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) + + // For single node scenarios, we can't validate peer visibility since there are no peers + if len(expectedNodes) == 1 { + // For single node, just check that we have map responses for the node + for nodeID := range nodeStatus { + if _, exists := onlineMap[nodeID]; exists { + onlineFromMaps[nodeID] = true + } else { + onlineFromMaps[nodeID] = false + } + } + } else { + // Multi-node scenario: check peer visibility + for nodeID := range nodeStatus { + // Initialize as offline - will be set to true only if visible in all relevant peer maps + onlineFromMaps[nodeID] = false + + // Count how many peer maps should show this node + expectedPeerMaps := 0 + foundOnlinePeerMaps := 0 + + for id, peerMap := range onlineMap { + if id == nodeID { + continue // Skip self-references + } + expectedPeerMaps++ + + if online, exists := peerMap[nodeID]; exists && online { + foundOnlinePeerMaps++ + } + } + + // Node is considered online if it appears online in all peer maps + // (or if there are no peer maps to check) + if expectedPeerMaps == 0 || foundOnlinePeerMaps == expectedPeerMaps { + onlineFromMaps[nodeID] = true + } + } + } + assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") + + // Update status with map response data + for nodeID, online := range onlineFromMaps { + if status, exists := nodeStatus[nodeID]; exists { + status.MapResponses = online + nodeStatus[nodeID] = status + } + } + + // Check nodestore state for expected nodes + for _, nodeID := range expectedNodes { + if node, exists := nodeStore[nodeID]; exists { + if status, exists := nodeStatus[nodeID]; exists { + // Check if node is online in nodestore + status.NodeStore = node.IsOnline != nil && *node.IsOnline + nodeStatus[nodeID] = status + } + } + } + + // Verify all systems show nodes in expected state and report failures + allMatch := true + var failureReport strings.Builder + + ids := types.NodeIDs(maps.Keys(nodeStatus)) + slices.Sort(ids) + for _, nodeID := range ids { + status := nodeStatus[nodeID] + systemsMatch := (status.Batcher == expectedOnline) && + (status.MapResponses == expectedOnline) && + (status.NodeStore == expectedOnline) + + if !systemsMatch { + allMatch = false + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat))) + failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline)) + failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) + failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (expected: %t, down with at least one peer)\n", status.MapResponses, expectedOnline)) + failureReport.WriteString(fmt.Sprintf(" - nodestore: %t (expected: %t)\n", status.NodeStore, expectedOnline)) + } + } + + if !allMatch { + if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" { + t.Logf("Node state validation report changed at %s:", time.Now().Format(TimestampFormat)) + t.Logf("Previous report:\n%s", 
prevReport) + t.Logf("Current report:\n%s", failureReport.String()) + t.Logf("Report diff:\n%s", diff) + prevReport = failureReport.String() + } + + failureReport.WriteString(fmt.Sprintf("validation_timestamp: %s\n", time.Now().Format(TimestampFormat))) + // Note: timeout_remaining not available in this context + + assert.Fail(c, failureReport.String()) + } + + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) + }, timeout, 2*time.Second, message) +} + +// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components +func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) { + t.Helper() + + // Stage 1: Verify batcher disconnection (should be immediate) + t.Logf("Stage 1: Verifying batcher disconnection for %d nodes", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + debugInfo, err := headscale.DebugBatcher() + assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { + return + } + + allBatcherOffline := true + for _, nodeID := range expectedNodes { + nodeIDStr := fmt.Sprintf("%d", nodeID) + if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected { + allBatcherOffline = false + assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID) + } + } + assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher") + }, 15*time.Second, 1*time.Second, "batcher disconnection validation") + + // Stage 2: Verify nodestore offline status (up to 15 seconds due to disconnect detection delay) + t.Logf("Stage 2: Verifying nodestore offline status for %d nodes (allowing for 10s disconnect detection delay)", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + allNodeStoreOffline := true + for _, nodeID := range expectedNodes { + if node, exists := nodeStore[nodeID]; exists { + isOnline := node.IsOnline != nil && *node.IsOnline + if isOnline { + allNodeStoreOffline = false + assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID) + } + } + } + assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore") + }, 20*time.Second, 1*time.Second, "nodestore offline validation") + + // Stage 3: Verify map response propagation (longest delay due to peer update timing) + t.Logf("Stage 3: Verifying map response propagation for %d nodes (allowing for peer map update delays)", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + mapResponses, err := headscale.GetAllMapReponses() + assert.NoError(c, err, "Failed to get map responses") + if err != nil { + return + } + + onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) + allMapResponsesOffline := true + + if len(expectedNodes) == 1 { + // Single node: check if it appears in map responses + for nodeID := range onlineMap { + if slices.Contains(expectedNodes, nodeID) { + allMapResponsesOffline = false + assert.False(c, true, "Node %d should not appear in map responses", nodeID) + } + } + } else { + // Multi-node: check peer visibility + for _, nodeID := range expectedNodes { + for id, peerMap := range 
onlineMap { + if id == nodeID { + continue // Skip self-references + } + if online, exists := peerMap[nodeID]; exists && online { + allMapResponsesOffline = false + assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id) + } + } + } + } + assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses") + }, 60*time.Second, 2*time.Second, "map response propagation validation") + + t.Logf("All stages completed: nodes are fully offline across all systems") +} + +// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database +// and a valid DERP server based on the NetInfo. This function follows the pattern of +// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. +func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { + t.Helper() + + startTime := time.Now() + t.Logf("requireAllClientsNetInfoAndDERP: Starting NetInfo/DERP validation for %d nodes at %s - %s", len(expectedNodes), startTime.Format(TimestampFormat), message) + + require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get nodestore state + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + // Validate that all expected nodes are present in nodeStore + for _, nodeID := range expectedNodes { + _, exists := nodeStore[nodeID] + assert.True(c, exists, "Expected node %d not found in nodeStore during NetInfo validation", nodeID) + } + + // Check each expected node + for _, nodeID := range expectedNodes { + node, exists := nodeStore[nodeID] + assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID) + if !exists { + continue + } + + // Validate that the node has Hostinfo + assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname) + if node.Hostinfo == nil { + t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) + continue + } + + // Validate that the node has NetInfo + assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname) + if node.Hostinfo.NetInfo == nil { + t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) + continue + } + + // Validate that the node has a valid DERP server (PreferredDERP should be > 0) + preferredDERP := node.Hostinfo.NetInfo.PreferredDERP + assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) + + t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat)) + } + }, timeout, 5*time.Second, message) + + endTime := time.Now() + duration := endTime.Sub(startTime) + t.Logf("requireAllClientsNetInfoAndDERP: Completed NetInfo/DERP validation for %d nodes at %s - Duration: %v - %s", len(expectedNodes), endTime.Format(TimestampFormat), duration, message) +} + +// assertLastSeenSet validates that a node has a non-nil LastSeen timestamp. +// Critical for ensuring node activity tracking is functioning properly. 
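+//
+// Usage sketch (assuming nodes was returned by headscale.ListNodes()):
+//
+//	for _, node := range nodes {
+//		assertLastSeenSet(t, node)
+//	}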
+func assertLastSeenSet(t *testing.T, node *v1.Node) { + assert.NotNil(t, node) + assert.NotNil(t, node.GetLastSeen()) +} + +// assertTailscaleNodesLogout verifies that all provided Tailscale clients +// are in the logged-out state (NeedsLogin). +func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { + if h, ok := t.(interface{ Helper() }); ok { + h.Helper() + } + + for _, client := range clients { + status, err := client.Status() + assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) + assert.Equal(t, "NeedsLogin", status.BackendState, + "client %s should be logged out", client.Hostname()) + } +} + +// pingAllHelper performs ping tests between all clients and addresses, returning success count. +// This is used to validate network connectivity in integration tests. +// Returns the total number of successful ping operations. +func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { + t.Helper() + success := 0 + + for _, client := range clients { + for _, addr := range addrs { + err := client.Ping(addr, opts...) + if err != nil { + t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) + } else { + success++ + } + } + } + + return success +} + +// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses. +// This specifically tests connectivity through DERP relay servers, which is important +// for validating NAT traversal and relay functionality. Returns success count. +func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { + t.Helper() + success := 0 + + for _, client := range clients { + for _, addr := range addrs { + if isSelfClient(client, addr) { + continue + } + + err := client.Ping( + addr, + tsic.WithPingTimeout(derpPingTimeout), + tsic.WithPingCount(derpPingCount), + tsic.WithPingUntilDirect(false), + ) + if err != nil { + t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err) + } else { + success++ + } + } + } + + return success +} + +// isSelfClient determines if the given address belongs to the client itself. +// Used to avoid self-ping operations in connectivity tests by checking +// hostname and IP address matches. +func isSelfClient(client TailscaleClient, addr string) bool { + if addr == client.Hostname() { + return true + } + + ips, err := client.IPs() + if err != nil { + return false + } + + for _, ip := range ips { + if ip.String() == addr { + return true + } + } + + return false +} + +// assertClientsState validates the status and netmap of a list of clients for general connectivity. +// Runs parallel validation of status, netcheck, and netmap for all clients to ensure +// they have proper network configuration for all-to-all connectivity tests. +func assertClientsState(t *testing.T, clients []TailscaleClient) { + t.Helper() + + var wg sync.WaitGroup + + for _, client := range clients { + wg.Add(1) + c := client // Avoid loop pointer + go func() { + defer wg.Done() + assertValidStatus(t, c) + assertValidNetcheck(t, c) + assertValidNetmap(t, c) + }() + } + + t.Logf("waiting for client state checks to finish") + wg.Wait() +} + +// assertValidNetmap validates that a client's netmap has all required fields for proper operation. +// Checks self node and all peers for essential networking data including hostinfo, addresses, +// endpoints, and DERP configuration. Skips validation for Tailscale versions below 1.56. +// This test is not suitable for ACL/partial connection tests. 
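+//
+// Typically called once per client (sketch, assuming allClients from the scenario);
+// assertClientsState above wraps this together with the status and netcheck checks:
+//
+//	for _, client := range allClients {
+//		assertValidNetmap(t, client)
+//	}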
+func assertValidNetmap(t *testing.T, client TailscaleClient) { + t.Helper() + + if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) { + t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version()) + + return + } + + t.Logf("Checking netmap of %q", client.Hostname()) + + netmap, err := client.Netmap() + if err != nil { + t.Fatalf("getting netmap for %q: %s", client.Hostname(), err) + } + + assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) + if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { + assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) + } + + assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) + assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) + + assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname()) + + assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) + assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) + assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) + + for _, peer := range netmap.Peers { + assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) + assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) + + assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) + if hi := peer.Hostinfo(); hi.Valid() { + assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) + + // Netinfo is not always set + // assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname()) + if ni := hi.NetInfo(); ni.Valid() { + assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) + } + } + + assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) + + assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname()) + + assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) + assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname()) + assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) + } +} + +// assertValidStatus validates that a client's status has all required fields for proper operation. 
+// Checks self and peer status for essential data including hostinfo, tailscale IPs, endpoints, +// and network map presence. This test is not suitable for ACL/partial connection tests. +func assertValidStatus(t *testing.T, client TailscaleClient) { + t.Helper() + status, err := client.Status(true) + if err != nil { + t.Fatalf("getting status for %q: %s", client.Hostname(), err) + } + + assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname()) + assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname()) + assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname()) + + assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname()) + + // This seem to not appear until version 1.56 + if status.Self.AllowedIPs != nil { + assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname()) + } + + assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname()) + + assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname()) + + assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname()) + + // This isn't really relevant for Self as it won't be in its own socket/wireguard. + // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) + // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname()) + + for _, peer := range status.Peer { + assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) + assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname()) + assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname()) + + assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname()) + + // This seem to not appear until version 1.56 + if peer.AllowedIPs != nil { + assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname()) + } + + // Addrs does not seem to appear in the status from peers. + // assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname()) + + assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname()) + + assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname()) + assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname()) + + // TODO(kradalby): InEngine is only true when a proper tunnel is set up, + // there might be some interesting stuff to test here in the future. + // assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname()) + } +} + +// assertValidNetcheck validates that a client has a proper DERP relay configured. +// Ensures the client has discovered and selected a DERP server for relay functionality, +// which is essential for NAT traversal and connectivity in restricted networks. 
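+//
+// Sketch: useful as a precondition before DERP-only connectivity checks such as
+// pingDerpAllHelper (assuming allClients from the scenario):
+//
+//	for _, client := range allClients {
+//		assertValidNetcheck(t, client)
+//	}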
+func assertValidNetcheck(t *testing.T, client TailscaleClient) { + t.Helper() + report, err := client.Netcheck() + if err != nil { + t.Fatalf("getting status for %q: %s", client.Hostname(), err) + } + + assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname()) +} + +// assertCommandOutputContains executes a command with exponential backoff retry until the output +// contains the expected string or timeout is reached (10 seconds). +// This implements eventual consistency patterns and should be used instead of time.Sleep +// before executing commands that depend on network state propagation. +// +// Timeout: 10 seconds with exponential backoff +// Use cases: DNS resolution, route propagation, policy updates. +func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { + t.Helper() + + _, err := backoff.Retry(t.Context(), func() (struct{}, error) { + stdout, stderr, err := c.Execute(command) + if err != nil { + return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err) + } + + if !strings.Contains(stdout, contains) { + return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) + } + + return struct{}{}, nil + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + + assert.NoError(t, err) +} + +// dockertestMaxWait returns the maximum wait time for Docker-based test operations. +// Uses longer timeouts in CI environments to account for slower resource allocation +// and higher system load during automated testing. +func dockertestMaxWait() time.Duration { + wait := 300 * time.Second //nolint + + if util.IsCI() { + wait = 600 * time.Second //nolint + } + + return wait +} + +// didClientUseWebsocketForDERP analyzes client logs to determine if WebSocket was used for DERP. +// Searches for WebSocket connection indicators in client logs to validate +// DERP relay communication method for debugging connectivity issues. +func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool { + t.Helper() + + buf := &bytes.Buffer{} + err := client.WriteLogs(buf, buf) + if err != nil { + t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err) + } + + count, err := countMatchingLines(buf, func(line string) bool { + return strings.Contains(line, "websocket: connected to ") + }) + if err != nil { + t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err) + } + + return count > 0 +} + +// countMatchingLines counts lines in a reader that match the given predicate function. +// Uses optimized buffering for log analysis and provides flexible line-by-line +// filtering for log parsing and pattern matching in integration tests. +func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) { + count := 0 + scanner := bufio.NewScanner(in) + { + const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB + buff := make([]byte, logBufferInitialSize) + scanner.Buffer(buff, len(buff)) + scanner.Split(bufio.ScanLines) + } + + for scanner.Scan() { + if predicate(scanner.Text()) { + count += 1 + } + } + + return count, scanner.Err() +} + +// wildcard returns a wildcard alias (*) for use in policy v2 configurations. +// Provides a convenient helper for creating permissive policy rules. +func wildcard() policyv2.Alias { + return policyv2.Wildcard +} + +// usernamep returns a pointer to a Username as an Alias for policy v2 configurations. 
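+//
+// For example (sketch only), combined with the other alias helpers below, a
+// destination of "user1" on any port can be expressed as:
+//
+//	aliasWithPorts(usernamep("user1"), tailcfg.PortRangeAny)
+//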
+// Used in ACL rules to reference specific users in network access policies. +func usernamep(name string) policyv2.Alias { + return ptr.To(policyv2.Username(name)) +} + +// hostp returns a pointer to a Host as an Alias for policy v2 configurations. +// Used in ACL rules to reference specific hosts in network access policies. +func hostp(name string) policyv2.Alias { + return ptr.To(policyv2.Host(name)) +} + +// groupp returns a pointer to a Group as an Alias for policy v2 configurations. +// Used in ACL rules to reference user groups in network access policies. +func groupp(name string) policyv2.Alias { + return ptr.To(policyv2.Group(name)) +} + +// tagp returns a pointer to a Tag as an Alias for policy v2 configurations. +// Used in ACL rules to reference node tags in network access policies. +func tagp(name string) policyv2.Alias { + return ptr.To(policyv2.Tag(name)) +} + +// prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations. +// Converts CIDR notation to policy prefix format for network range specifications. +func prefixp(cidr string) policyv2.Alias { + prefix := netip.MustParsePrefix(cidr) + return ptr.To(policyv2.Prefix(prefix)) +} + +// aliasWithPorts creates an AliasWithPorts structure from an alias and port ranges. +// Combines network targets with specific port restrictions for fine-grained +// access control in policy v2 configurations. +func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts { + return policyv2.AliasWithPorts{ + Alias: alias, + Ports: ports, + } +} + +// usernameOwner returns a Username as an Owner for use in TagOwners policies. +// Specifies which users can assign and manage specific tags in ACL configurations. +func usernameOwner(name string) policyv2.Owner { + return ptr.To(policyv2.Username(name)) +} + +// groupOwner returns a Group as an Owner for use in TagOwners policies. +// Specifies which groups can assign and manage specific tags in ACL configurations. +func groupOwner(name string) policyv2.Owner { + return ptr.To(policyv2.Group(name)) +} + +// usernameApprover returns a Username as an AutoApprover for subnet route policies. +// Specifies which users can automatically approve subnet route advertisements. +func usernameApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Username(name)) +} + +// groupApprover returns a Group as an AutoApprover for subnet route policies. +// Specifies which groups can automatically approve subnet route advertisements. +func groupApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Group(name)) +} + +// tagApprover returns a Tag as an AutoApprover for subnet route policies. +// Specifies which tagged nodes can automatically approve subnet route advertisements. +func tagApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Tag(name)) +} + +// oidcMockUser creates a MockUser for OIDC authentication testing. +// Generates consistent test user data with configurable email verification status +// for validating OIDC integration flows in headscale authentication tests. 
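+//
+// A minimal sketch of what it produces (the email domain is fixed to headscale.net):
+//
+//	u := oidcMockUser("user1", true)
+//	// u.Subject == "user1", u.Email == "user1@headscale.net", u.EmailVerified == true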
+func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { + return mockoidc.MockUser{ + Subject: username, + PreferredUsername: username, + Email: username + "@headscale.net", + EmailVerified: emailVerified, + } +} diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 9c28dc00..553b8b1c 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -460,6 +460,12 @@ func New( dockertestutil.DockerAllowNetworkAdministration, ) if err != nil { + // Try to get more detailed build output + log.Printf("Docker build failed, attempting to get detailed output...") + buildOutput := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, IntegrationTestDockerFileName) + if buildOutput != "" { + return nil, fmt.Errorf("could not start headscale container: %w\n\nDetailed build output:\n%s", err, buildOutput) + } return nil, fmt.Errorf("could not start headscale container: %w", err) } log.Printf("Created %s container\n", hsic.hostname) diff --git a/integration/route_test.go b/integration/route_test.go index 9aced164..a613c375 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -53,16 +53,16 @@ func TestEnablingRoutes(t *testing.T) { err = scenario.CreateHeadscaleEnv( []tsic.Option{tsic.WithAcceptRoutes()}, hsic.WithTestName("clienableroute")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) expectedRoutes := map[string]string{ "1": "10.0.0.0/24", @@ -83,7 +83,7 @@ func TestEnablingRoutes(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) var nodes []*v1.Node // Wait for route advertisements to propagate to NodeStore @@ -256,16 +256,16 @@ func TestHASubnetRouterFailover(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) prefp, err := scenario.SubnetOfNetwork("usernet1") require.NoError(t, err) @@ -319,7 +319,7 @@ func TestHASubnetRouterFailover(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // Wait for route configuration changes after advertising routes var nodes []*v1.Node @@ -1341,16 +1341,16 @@ func TestSubnetRouteACL(t *testing.T) { }, }, )) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) expectedRoutes := map[string]string{ "1": "10.33.0.0/16", @@ -1393,7 +1393,7 @@ func TestSubnetRouteACL(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) // Wait for route advertisements to propagate to the 
server var nodes []*v1.Node @@ -1572,25 +1572,25 @@ func TestEnablingExitRoutes(t *testing.T) { } scenario, err := NewScenario(spec) - assertNoErrf(t, "failed to create scenario: %s", err) + require.NoErrorf(t, err, "failed to create scenario") defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv([]tsic.Option{ tsic.WithExtraLoginArgs([]string{"--advertise-exit-node"}), }, hsic.WithTestName("clienableroute")) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) nodes, err := headscale.ListNodes() require.NoError(t, err) @@ -1686,16 +1686,16 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) assert.NotNil(t, headscale) pref, err := scenario.SubnetOfNetwork("usernet1") @@ -1833,16 +1833,16 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) assert.NotNil(t, headscale) var user1c, user2c TailscaleClient @@ -2247,13 +2247,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) { err = scenario.createHeadscaleEnv(tt.withURL, tsOpts, opts..., ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) services, err := scenario.Services("usernet1") require.NoError(t, err) @@ -2263,7 +2263,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) assert.NotNil(t, headscale) // Add the Docker network route to the auto-approvers @@ -2304,21 +2304,21 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if tt.withURL { u, err := routerUsernet1.LoginWithURL(headscale.GetEndpoint()) - assertNoErr(t, err) + require.NoError(t, err) body, err := doLoginURL(routerUsernet1.Hostname(), u) - assertNoErr(t, err) + require.NoError(t, err) scenario.runHeadscaleRegister("user1", body) } else { userMap, err := headscale.MapUsers() - assertNoErr(t, err) + require.NoError(t, err) pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) - assertNoErr(t, err) + require.NoError(t, err) err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey()) - 
assertNoErr(t, err) + require.NoError(t, err) } // extra creation end. @@ -2893,13 +2893,13 @@ func TestSubnetRouteACLFiltering(t *testing.T) { hsic.WithACLPolicy(aclPolicy), hsic.WithPolicyMode(types.PolicyModeDB), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) // Get the router and node clients by user routerClients, err := scenario.ListTailscaleClients(routerUser) @@ -2944,7 +2944,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { require.NoErrorf(t, err, "failed to advertise routes: %s", err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) var routerNode, nodeNode *v1.Node // Wait for route advertisements to propagate to NodeStore diff --git a/integration/scenario.go b/integration/scenario.go index 8382d6a8..b48e3265 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -838,14 +838,14 @@ func doLoginURL(hostname string, loginURL *url.URL) (string, error) { var err error hc := &http.Client{ - Transport: LoggingRoundTripper{}, + Transport: LoggingRoundTripper{Hostname: hostname}, } hc.Jar, err = cookiejar.New(nil) if err != nil { return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err) } - log.Printf("%s logging in with url", hostname) + log.Printf("%s logging in with url: %s", hostname, loginURL.String()) ctx := context.Background() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) resp, err := hc.Do(req) @@ -907,7 +907,9 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) } -type LoggingRoundTripper struct{} +type LoggingRoundTripper struct { + Hostname string +} func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { noTls := &http.Transport{ @@ -918,9 +920,12 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error return nil, err } - log.Printf("---") - log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String()) - log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies()) + log.Printf(` +--- +%s - method: %s | url: %s +%s - status: %d | cookies: %+v +--- +`, t.Hostname, req.Method, req.URL.String(), t.Hostname, resp.StatusCode, resp.Cookies()) return resp, nil } diff --git a/integration/scenario_test.go b/integration/scenario_test.go index ead3f1fd..1e2a151a 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -5,6 +5,7 @@ import ( "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/require" ) // This file is intended to "test the test framework", by proxy it will also test @@ -34,7 +35,7 @@ func TestHeadscale(t *testing.T) { user := "test-space" scenario, err := NewScenario(ScenarioSpec{}) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) t.Run("start-headscale", func(t *testing.T) { @@ -82,7 +83,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { count := 1 scenario, err := NewScenario(ScenarioSpec{}) - assertNoErr(t, err) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) t.Run("start-headscale", func(t *testing.T) { diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 
a5975eb4..1299ba52 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "tailscale.com/tailcfg" ) @@ -30,7 +31,7 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce Users: []string{"user1", "user2"}, } scenario, err := NewScenario(spec) - assertNoErr(t, err) + require.NoError(t, err) err = scenario.CreateHeadscaleEnv( []tsic.Option{ @@ -50,13 +51,13 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce hsic.WithACLPolicy(policy), hsic.WithTestName("ssh"), ) - assertNoErr(t, err) + require.NoError(t, err) err = scenario.WaitForTailscaleSync() - assertNoErr(t, err) + require.NoError(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErr(t, err) + require.NoError(t, err) return scenario } @@ -93,19 +94,19 @@ func TestSSHOneUserToAll(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range user1Clients { for _, peer := range allClients { @@ -160,16 +161,16 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) nsOneClients, err := scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) nsTwoClients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) testInterUserSSH := func(sourceClients []TailscaleClient, targetClients []TailscaleClient) { for _, client := range sourceClients { @@ -208,13 +209,13 @@ func TestSSHNoSSHConfigured(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { for _, peer := range allClients { @@ -259,13 +260,13 @@ func TestSSHIsBlockedInACL(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { for _, peer := range allClients { @@ -317,16 +318,16 @@ func TestSSHUserOnlyIsolation(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) ssh1Clients, err := 
scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) ssh2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range ssh1Clients { for _, peer := range ssh2Clients { @@ -422,9 +423,9 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien t.Helper() result, _, err := doSSH(t, client, peer) - assertNoErr(t, err) + require.NoError(t, err) - assertContains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", "")) + require.Contains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", "")) } func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) { diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index 665fd670..ddd5027f 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -322,6 +322,20 @@ func New( dockertestutil.DockerAllowNetworkAdministration, dockertestutil.DockerMemoryLimit, ) + if err != nil { + // Try to get more detailed build output + log.Printf("Docker build failed for %s, attempting to get detailed output...", hostname) + buildOutput := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, "Dockerfile.tailscale-HEAD") + if buildOutput != "" { + return nil, fmt.Errorf( + "%s could not start tailscale container (version: %s): %w\n\nDetailed build output:\n%s", + hostname, + version, + err, + buildOutput, + ) + } + } case "unstable": tailscaleOptions.Repository = "tailscale/tailscale" tailscaleOptions.Tag = version @@ -333,6 +347,9 @@ func New( dockertestutil.DockerAllowNetworkAdministration, dockertestutil.DockerMemoryLimit, ) + if err != nil { + log.Printf("Docker run failed for %s (unstable), error: %v", hostname, err) + } default: tailscaleOptions.Repository = "tailscale/tailscale" tailscaleOptions.Tag = "v" + version @@ -344,6 +361,9 @@ func New( dockertestutil.DockerAllowNetworkAdministration, dockertestutil.DockerMemoryLimit, ) + if err != nil { + log.Printf("Docker run failed for %s (version: v%s), error: %v", hostname, version, err) + } } if err != nil { diff --git a/integration/utils.go b/integration/utils.go deleted file mode 100644 index 117bdab7..00000000 --- a/integration/utils.go +++ /dev/null @@ -1,533 +0,0 @@ -package integration - -import ( - "bufio" - "bytes" - "fmt" - "io" - "net/netip" - "strings" - "sync" - "testing" - "time" - - "github.com/cenkalti/backoff/v5" - policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/juanfont/headscale/integration/tsic" - "github.com/stretchr/testify/assert" - "tailscale.com/tailcfg" - "tailscale.com/types/ptr" -) - -const ( - // derpPingTimeout defines the timeout for individual DERP ping operations - // Used in DERP connectivity tests to verify relay server communication. - derpPingTimeout = 2 * time.Second - - // derpPingCount defines the number of ping attempts for DERP connectivity tests - // Higher count provides better reliability assessment of DERP connectivity. 
- derpPingCount = 10 - - // TimestampFormat is the standard timestamp format used across all integration tests - // Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps - // suitable for debugging and log correlation in integration tests. - TimestampFormat = "2006-01-02T15-04-05.999999999" - - // TimestampFormatRunID is used for generating unique run identifiers - // Format: "20060102-150405" provides compact date-time for file/directory names. - TimestampFormatRunID = "20060102-150405" -) - -func assertNoErr(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "unexpected error: %s", err) -} - -func assertNoErrf(t *testing.T, msg string, err error) { - t.Helper() - if err != nil { - t.Fatalf(msg, err) - } -} - -func assertNotNil(t *testing.T, thing interface{}) { - t.Helper() - if thing == nil { - t.Fatal("got unexpected nil") - } -} - -func assertNoErrHeadscaleEnv(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to create headscale environment: %s", err) -} - -func assertNoErrGetHeadscale(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to get headscale: %s", err) -} - -func assertNoErrListClients(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to list clients: %s", err) -} - -func assertNoErrListClientIPs(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to get client IPs: %s", err) -} - -func assertNoErrSync(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to have all clients sync up: %s", err) -} - -func assertNoErrListFQDN(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to list FQDNs: %s", err) -} - -func assertNoErrLogout(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to log out tailscale nodes: %s", err) -} - -func assertContains(t *testing.T, str, subStr string) { - t.Helper() - if !strings.Contains(str, subStr) { - t.Fatalf("%#v does not contain %#v", str, subStr) - } -} - -func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool { - t.Helper() - - buf := &bytes.Buffer{} - err := client.WriteLogs(buf, buf) - if err != nil { - t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err) - } - - count, err := countMatchingLines(buf, func(line string) bool { - return strings.Contains(line, "websocket: connected to ") - }) - if err != nil { - t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err) - } - - return count > 0 -} - -// pingAllHelper performs ping tests between all clients and addresses, returning success count. -// This is used to validate network connectivity in integration tests. -// Returns the total number of successful ping operations. -func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { - t.Helper() - success := 0 - - for _, client := range clients { - for _, addr := range addrs { - err := client.Ping(addr, opts...) - if err != nil { - t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) - } else { - success++ - } - } - } - - return success -} - -// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses. -// This specifically tests connectivity through DERP relay servers, which is important -// for validating NAT traversal and relay functionality. Returns success count. 
-func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { - t.Helper() - success := 0 - - for _, client := range clients { - for _, addr := range addrs { - if isSelfClient(client, addr) { - continue - } - - err := client.Ping( - addr, - tsic.WithPingTimeout(derpPingTimeout), - tsic.WithPingCount(derpPingCount), - tsic.WithPingUntilDirect(false), - ) - if err != nil { - t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err) - } else { - success++ - } - } - } - - return success -} - -// assertClientsState validates the status and netmap of a list of -// clients for the general case of all to all connectivity. -func assertClientsState(t *testing.T, clients []TailscaleClient) { - t.Helper() - - var wg sync.WaitGroup - - for _, client := range clients { - wg.Add(1) - c := client // Avoid loop pointer - go func() { - defer wg.Done() - assertValidStatus(t, c) - assertValidNetcheck(t, c) - assertValidNetmap(t, c) - }() - } - - t.Logf("waiting for client state checks to finish") - wg.Wait() -} - -// assertValidNetmap asserts that the netmap of a client has all -// the minimum required fields set to a known working config for -// the general case. Fields are checked on self, then all peers. -// This test is not suitable for ACL/partial connection tests. -// This test can only be run on clients from 1.56.1. It will -// automatically pass all clients below that and is safe to call -// for all versions. -func assertValidNetmap(t *testing.T, client TailscaleClient) { - t.Helper() - - if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) { - t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version()) - - return - } - - t.Logf("Checking netmap of %q", client.Hostname()) - - netmap, err := client.Netmap() - if err != nil { - t.Fatalf("getting netmap for %q: %s", client.Hostname(), err) - } - - assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) - if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { - assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) - } - - assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) - assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) - - assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname()) - - assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) - assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) - assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) - - for _, peer := range netmap.Peers { - assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) - assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) - - assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) - if hi := peer.Hostinfo(); hi.Valid() { - assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), 
client.Hostname(), peer.Hostinfo().Services()) - - // Netinfo is not always set - // assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname()) - if ni := hi.NetInfo(); ni.Valid() { - assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) - } - } - - assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) - - assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname()) - - assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) - } -} - -// assertValidStatus asserts that the status of a client has all -// the minimum required fields set to a known working config for -// the general case. Fields are checked on self, then all peers. -// This test is not suitable for ACL/partial connection tests. -func assertValidStatus(t *testing.T, client TailscaleClient) { - t.Helper() - status, err := client.Status(true) - if err != nil { - t.Fatalf("getting status for %q: %s", client.Hostname(), err) - } - - assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname()) - assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname()) - assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname()) - - assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname()) - - // This seem to not appear until version 1.56 - if status.Self.AllowedIPs != nil { - assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname()) - } - - assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname()) - - assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname()) - - assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname()) - - // This isn't really relevant for Self as it won't be in its own socket/wireguard. 
- // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) - // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname()) - - for _, peer := range status.Peer { - assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) - assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname()) - assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname()) - - assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname()) - - // This seem to not appear until version 1.56 - if peer.AllowedIPs != nil { - assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname()) - } - - // Addrs does not seem to appear in the status from peers. - // assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname()) - - assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname()) - - assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname()) - assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname()) - - // TODO(kradalby): InEngine is only true when a proper tunnel is set up, - // there might be some interesting stuff to test here in the future. - // assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname()) - } -} - -func assertValidNetcheck(t *testing.T, client TailscaleClient) { - t.Helper() - report, err := client.Netcheck() - if err != nil { - t.Fatalf("getting status for %q: %s", client.Hostname(), err) - } - - assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname()) -} - -// assertCommandOutputContains executes a command with exponential backoff retry until the output -// contains the expected string or timeout is reached (10 seconds). -// This implements eventual consistency patterns and should be used instead of time.Sleep -// before executing commands that depend on network state propagation. -// -// Timeout: 10 seconds with exponential backoff -// Use cases: DNS resolution, route propagation, policy updates. 
-func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { - t.Helper() - - _, err := backoff.Retry(t.Context(), func() (struct{}, error) { - stdout, stderr, err := c.Execute(command) - if err != nil { - return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err) - } - - if !strings.Contains(stdout, contains) { - return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) - } - - return struct{}{}, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) - - assert.NoError(t, err) -} - -func isSelfClient(client TailscaleClient, addr string) bool { - if addr == client.Hostname() { - return true - } - - ips, err := client.IPs() - if err != nil { - return false - } - - for _, ip := range ips { - if ip.String() == addr { - return true - } - } - - return false -} - -func dockertestMaxWait() time.Duration { - wait := 300 * time.Second //nolint - - if util.IsCI() { - wait = 600 * time.Second //nolint - } - - return wait -} - -func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) { - count := 0 - scanner := bufio.NewScanner(in) - { - const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB - buff := make([]byte, logBufferInitialSize) - scanner.Buffer(buff, len(buff)) - scanner.Split(bufio.ScanLines) - } - - for scanner.Scan() { - if predicate(scanner.Text()) { - count += 1 - } - } - - return count, scanner.Err() -} - -// func dockertestCommandTimeout() time.Duration { -// timeout := 10 * time.Second //nolint -// -// if isCI() { -// timeout = 60 * time.Second //nolint -// } -// -// return timeout -// } - -// pingAllNegativeHelper is intended to have 1 or more nodes timing out from the ping, -// it counts failures instead of successes. -// func pingAllNegativeHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { -// t.Helper() -// failures := 0 -// -// timeout := 100 -// count := 3 -// -// for _, client := range clients { -// for _, addr := range addrs { -// err := client.Ping( -// addr, -// tsic.WithPingTimeout(time.Duration(timeout)*time.Millisecond), -// tsic.WithPingCount(count), -// ) -// if err != nil { -// failures++ -// } -// } -// } -// -// return failures -// } - -// // findPeerByIP takes an IP and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus -// // if there is a peer with the given IP. If no peer is found, nil is returned. -// func findPeerByIP( -// ip netip.Addr, -// peers map[key.NodePublic]*ipnstate.PeerStatus, -// ) *ipnstate.PeerStatus { -// for _, peer := range peers { -// for _, peerIP := range peer.TailscaleIPs { -// if ip == peerIP { -// return peer -// } -// } -// } -// -// return nil -// } - -// Helper functions for creating typed policy entities - -// wildcard returns a wildcard alias (*). -func wildcard() policyv2.Alias { - return policyv2.Wildcard -} - -// usernamep returns a pointer to a Username as an Alias. -func usernamep(name string) policyv2.Alias { - return ptr.To(policyv2.Username(name)) -} - -// hostp returns a pointer to a Host. -func hostp(name string) policyv2.Alias { - return ptr.To(policyv2.Host(name)) -} - -// groupp returns a pointer to a Group as an Alias. -func groupp(name string) policyv2.Alias { - return ptr.To(policyv2.Group(name)) -} - -// tagp returns a pointer to a Tag as an Alias. 
-func tagp(name string) policyv2.Alias { - return ptr.To(policyv2.Tag(name)) -} - -// prefixp returns a pointer to a Prefix from a CIDR string. -func prefixp(cidr string) policyv2.Alias { - prefix := netip.MustParsePrefix(cidr) - return ptr.To(policyv2.Prefix(prefix)) -} - -// aliasWithPorts creates an AliasWithPorts structure from an alias and ports. -func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts { - return policyv2.AliasWithPorts{ - Alias: alias, - Ports: ports, - } -} - -// usernameOwner returns a Username as an Owner for use in TagOwners. -func usernameOwner(name string) policyv2.Owner { - return ptr.To(policyv2.Username(name)) -} - -// groupOwner returns a Group as an Owner for use in TagOwners. -func groupOwner(name string) policyv2.Owner { - return ptr.To(policyv2.Group(name)) -} - -// usernameApprover returns a Username as an AutoApprover. -func usernameApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Username(name)) -} - -// groupApprover returns a Group as an AutoApprover. -func groupApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Group(name)) -} - -// tagApprover returns a Tag as an AutoApprover. -func tagApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Tag(name)) -} - -// -// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus -// // if there is a peer with the given hostname. If no peer is found, nil is returned. -// func findPeerByHostname( -// hostname string, -// peers map[key.NodePublic]*ipnstate.PeerStatus, -// ) *ipnstate.PeerStatus { -// for _, peer := range peers { -// if hostname == peer.HostName { -// return peer -// } -// } -// -// return nil -// }
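
The requireNoErr* helpers that replace the assertNoErr* family above are not defined anywhere in this diff; their definitions live outside the changed files. A minimal sketch of how such wrappers might look on top of testify's require package, reusing the messages from the removed helpers (the bodies below are assumptions, not the real code):

package integration

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// requireNoErrHeadscaleEnv is a sketch of the renamed helper; it mirrors the
// removed assertNoErrHeadscaleEnv but delegates to testify instead of t.Fatalf.
func requireNoErrHeadscaleEnv(t *testing.T, err error) {
	t.Helper()
	require.NoErrorf(t, err, "failed to create headscale environment: %s", err)
}

// requireNoErrSync mirrors the removed assertNoErrSync in the same way.
func requireNoErrSync(t *testing.T, err error) {
	t.Helper()
	require.NoErrorf(t, err, "failed to have all clients sync up: %s", err)
}

In practice the behaviour is unchanged: both the removed assert* wrappers (t.Fatalf) and require fail the test immediately; the switch simply standardizes on testify.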
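
The Hostname field added to LoggingRoundTripper exists so that HTTP log lines from many concurrently running clients can be attributed to the container that produced them. A self-contained sketch of the same pattern (taggedTransport and the example URL are illustrative, not headscale code):

package main

import (
	"log"
	"net/http"
)

// taggedTransport tags every request/response log line with a hostname so
// interleaved output from multiple clients stays attributable.
type taggedTransport struct {
	hostname string
	next     http.RoundTripper
}

func (t taggedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := t.next.RoundTrip(req)
	if err != nil {
		log.Printf("%s - %s %s failed: %v", t.hostname, req.Method, req.URL, err)
		return nil, err
	}
	log.Printf("%s - %s %s -> %d", t.hostname, req.Method, req.URL, resp.StatusCode)
	return resp, nil
}

func main() {
	// Wire the transport into a client the same way doLoginURL does above.
	hc := &http.Client{Transport: taggedTransport{hostname: "ts-client-1", next: http.DefaultTransport}}
	resp, err := hc.Get("https://example.com") // hypothetical URL
	if err != nil {
		log.Fatal(err)
	}
	resp.Body.Close()
}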
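
dockertestutil.RunDockerBuildForDiagnostics, called from the new error path in tsic.go, is also not shown in this diff. A sketch of what such a helper could plausibly do, re-running the image build to capture its output so the failure reason can be attached to the returned error (package name, flags, and behaviour are assumptions):

package diagnostics

import (
	"os/exec"
	"path/filepath"
)

// runDockerBuildForDiagnostics re-runs a docker build with plain progress
// output and returns whatever the CLI printed, regardless of exit status.
// This is only a sketch of the idea behind the real helper.
func runDockerBuildForDiagnostics(contextDir, dockerfile string) string {
	cmd := exec.Command(
		"docker", "build",
		"--progress=plain",
		"-f", filepath.Join(contextDir, dockerfile),
		contextDir,
	)
	// CombinedOutput is used deliberately: the interesting lines are the ones
	// emitted when the build fails.
	out, _ := cmd.CombinedOutput()
	return string(out)
}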
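
The removed assertCommandOutputContains helper is built on cenkalti/backoff/v5's Retry with an exponential backoff and a 10-second elapsed-time budget. A standalone sketch of that retry pattern, with the tailscale command execution swapped for a hypothetical check function:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v5"
)

// pollUntilReady retries check with exponential backoff until it succeeds or
// the elapsed-time budget runs out, mirroring the pattern used by the removed
// assertCommandOutputContains helper.
func pollUntilReady(ctx context.Context, check func() (string, error)) (string, error) {
	return backoff.Retry(ctx, func() (string, error) {
		out, err := check()
		if err != nil {
			// Returning an error tells backoff to retry after the next delay.
			return "", fmt.Errorf("check failed: %w", err)
		}
		return out, nil
	},
		backoff.WithBackOff(backoff.NewExponentialBackOff()),
		backoff.WithMaxElapsedTime(10*time.Second),
	)
}

func main() {
	attempts := 0
	out, err := pollUntilReady(context.Background(), func() (string, error) {
		attempts++
		if attempts < 3 {
			return "", errors.New("not ready yet") // simulated transient failure
		}
		return "ready", nil
	})
	fmt.Println(out, err)
}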