From f20bd0cf086423d29446ef8f02ffa4ce87f2b377 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 20 Feb 2026 10:58:49 +0000 Subject: [PATCH] node: implement disable key expiry via CLI and API Add --disable flag to "headscale nodes expire" CLI command and disable_expiry field handling in the gRPC API to allow disabling key expiry for nodes. When disabled, the node's expiry is set to NULL and IsExpired() returns false. The CLI follows the new grpcRunE/RunE/printOutput patterns introduced in the recent CLI refactor. Also fix State.SetNodeExpiry to persist directly to the database via NodeSetExpiry instead of going through persistNodeToDB, which omits the expiry field. Fixes #2681 Co-authored-by: Marco Santos --- .github/workflows/test-integration.yaml | 1 + cmd/headscale/cli/nodes.go | 26 ++++++- hscontrol/auth.go | 4 +- hscontrol/auth_test.go | 2 +- hscontrol/db/node.go | 9 +-- hscontrol/db/node_test.go | 44 ++++++++++- hscontrol/grpcv1.go | 32 +++++++- hscontrol/state/state.go | 24 +++++- integration/general_test.go | 97 +++++++++++++++++++++++++ 9 files changed, 222 insertions(+), 17 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 1dfd10ee..4a9cbe98 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -229,6 +229,7 @@ jobs: - TestUpdateHostnameFromClient - TestExpireNode - TestSetNodeExpiryInFuture + - TestDisableNodeExpiry - TestNodeOnlineStatus - TestPingAllByIPManyUpDown - Test2118DeletingOnlineNodePanics diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index dbc7e8bf..930efc29 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -32,6 +32,7 @@ func init() { expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 
2025-08-27T10:00:00Z), or leave empty to expire immediately.") + expireNodeCmd.Flags().BoolP("disable", "d", false, "Disable key expiry (node will never expire)") mustMarkRequired(expireNodeCmd, "identifier") nodeCmd.AddCommand(expireNodeCmd) @@ -143,12 +144,31 @@ var listNodeRoutesCmd = &cobra.Command{ } var expireNodeCmd = &cobra.Command{ - Use: "expire", - Short: "Expire (log out) a node in your network", - Long: "Expiring a node will keep the node in the database and force it to reauthenticate.", + Use: "expire", + Short: "Expire (log out) a node in your network", + Long: `Expiring a node will keep the node in the database and force it to reauthenticate. + +Use --disable to disable key expiry (node will never expire).`, Aliases: []string{"logout", "exp", "e"}, RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { identifier, _ := cmd.Flags().GetUint64("identifier") + disableExpiry, _ := cmd.Flags().GetBool("disable") + + // Handle disable expiry - node will never expire. + if disableExpiry { + request := &v1.ExpireNodeRequest{ + NodeId: identifier, + DisableExpiry: true, + } + + response, err := client.ExpireNode(ctx, request) + if err != nil { + return fmt.Errorf("disabling node expiry: %w", err) + } + + return printOutput(cmd, response.GetNode(), "Node expiry disabled") + } + expiry, _ := cmd.Flags().GetString("expiry") now := time.Now() diff --git a/hscontrol/auth.go b/hscontrol/auth.go index fdc63461..d5a77bd7 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -212,7 +212,9 @@ func (h *Headscale) handleLogout( // Update the internal state with the nodes new expiry, meaning it is // logged out. 
- updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), req.Expiry) + expiry := req.Expiry + + updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), &expiry) if err != nil { return nil, fmt.Errorf("setting node expiry: %w", err) } diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index d28ed565..83dfb913 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -587,7 +587,7 @@ func TestAuthenticationFlows(t *testing.T) { // Expire the node expiredTime := time.Now().Add(-1 * time.Hour) - _, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime) + _, _, err = app.state.SetNodeExpiry(node.ID(), &expiredTime) return "", err }, diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 51bba035..d2db012c 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -315,16 +315,15 @@ func RenameNode(tx *gorm.DB, return nil } -func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error { +func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry *time.Time) error { return hsdb.Write(func(tx *gorm.DB) error { return NodeSetExpiry(tx, nodeID, expiry) }) } -// NodeSetExpiry takes a Node struct and a new expiry time. -func NodeSetExpiry(tx *gorm.DB, - nodeID types.NodeID, expiry time.Time, -) error { +// NodeSetExpiry sets a new expiry time for a node. +// If expiry is nil, the node's expiry is disabled (node will never expire). 
+func NodeSetExpiry(tx *gorm.DB, nodeID types.NodeID, expiry *time.Time) error { return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 55289ca4..128baf5b 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -128,7 +128,7 @@ func TestExpireNode(t *testing.T) { assert.False(t, nodeFromDB.IsExpired()) now := time.Now() - err = db.NodeSetExpiry(nodeFromDB.ID, now) + err = db.NodeSetExpiry(nodeFromDB.ID, &now) require.NoError(t, err) nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") @@ -137,6 +137,48 @@ func TestExpireNode(t *testing.T) { assert.True(t, nodeFromDB.IsExpired()) } +func TestDisableNodeExpiry(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + pakID := pak.ID + node := &types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + Expiry: &time.Time{}, + } + db.DB.Save(node) + + // Set an expiry first. + past := time.Now().Add(-time.Hour) + err = db.NodeSetExpiry(node.ID, &past) + require.NoError(t, err) + + nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + assert.True(t, nodeFromDB.IsExpired(), "node should be expired") + + // Disable expiry by setting nil. 
+ err = db.NodeSetExpiry(node.ID, nil) + require.NoError(t, err) + + nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + assert.False(t, nodeFromDB.IsExpired(), "node should not be expired after disabling expiry") + assert.Nil(t, nodeFromDB.Expiry, "expiry should be nil after disabling") +} + func TestSetTags(t *testing.T) { db, err := newSQLiteTestDB() require.NoError(t, err) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 073c6677..3af8e807 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -451,12 +451,40 @@ func (api headscaleV1APIServer) ExpireNode( ctx context.Context, request *v1.ExpireNodeRequest, ) (*v1.ExpireNodeResponse, error) { + if request.GetDisableExpiry() && request.GetExpiry() != nil { + return nil, status.Error( + codes.InvalidArgument, + "cannot set both disable_expiry and expiry", + ) + } + + // Handle disable expiry request - node will never expire. + if request.GetDisableExpiry() { + node, nodeChange, err := api.h.state.SetNodeExpiry( + types.NodeID(request.GetNodeId()), nil, + ) + if err != nil { + return nil, err + } + + api.h.Change(nodeChange) + + log.Trace(). + Caller(). + EmbedObject(node). + Msg("node expiry disabled") + + return &v1.ExpireNodeResponse{Node: node.Proto()}, nil + } + expiry := time.Now() if request.GetExpiry() != nil { expiry = request.GetExpiry().AsTime() } - node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), expiry) + node, nodeChange, err := api.h.state.SetNodeExpiry( + types.NodeID(request.GetNodeId()), &expiry, + ) if err != nil { return nil, err } @@ -467,7 +495,7 @@ func (api headscaleV1APIServer) ExpireNode( log.Trace(). Caller(). EmbedObject(node). - Time(zf.ExpiresAt, *node.AsStruct().Expiry). + Time(zf.ExpiresAt, expiry). 
Msg("node expired") return &v1.ExpireNodeResponse{Node: node.Proto()}, nil diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index e421d5bd..35544aa3 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -638,22 +638,38 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { } // SetNodeExpiry updates the expiration time for a node. -func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.Change, error) { +// If expiry is nil, the node's expiry is disabled (node will never expire). +func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry *time.Time) (types.NodeView, change.Change, error) { // Update NodeStore before database to ensure consistency. The NodeStore update is // blocking and will be the source of truth for the batcher. The database update must // make the exact same change. If the database update fails, the NodeStore change will // remain, but since we return an error, no change notification will be sent to the // batcher, preventing inconsistent state propagation. - expiryPtr := expiry n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { - node.Expiry = &expiryPtr + node.Expiry = expiry }) if !ok { return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) } - return s.persistNodeToDB(n) + // Persist expiry change to database directly since persistNodeToDB omits expiry. + err := s.db.NodeSetExpiry(nodeID, expiry) + if err != nil { + return types.NodeView{}, change.Change{}, fmt.Errorf("setting node expiry in database: %w", err) + } + + // Update policy manager and generate change notification. + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.Change{}, fmt.Errorf("updating policy manager after setting expiry: %w", err) + } + + if c.IsEmpty() { + c = change.NodeAdded(n.ID()) + } + + return n, c, nil } // SetNodeTags assigns tags to a node, making it a "tagged node". 
diff --git a/integration/general_test.go b/integration/general_test.go index f44a0f03..42ba58bf 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -1166,6 +1166,103 @@ func TestSetNodeExpiryInFuture(t *testing.T) { } } +// TestDisableNodeExpiry tests disabling key expiry for a node. +// First sets an expiry, then disables it and verifies the node never expires. +func TestDisableNodeExpiry(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("disableexpiry")) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // First set an expiry on the node. + result, err := headscale.Execute( + []string{ + "headscale", "nodes", "expire", + "--identifier", "1", + "--output", "json", + "--expiry", time.Now().Add(time.Hour).Format(time.RFC3339), + }, + ) + require.NoError(t, err) + + var node v1.Node + err = json.Unmarshal([]byte(result), &node) + require.NoError(t, err) + require.NotNil(t, node.GetExpiry(), "node should have an expiry set") + + // Now disable the expiry. + result, err = headscale.Execute( + []string{ + "headscale", "nodes", "expire", + "--identifier", "1", + "--output", "json", + "--disable", + }, + ) + require.NoError(t, err) + + var nodeDisabled v1.Node + err = json.Unmarshal([]byte(result), &nodeDisabled) + require.NoError(t, err) + + // Expiry should be nil (or zero time) when disabled. 
+ if nodeDisabled.GetExpiry() != nil { + require.True(t, nodeDisabled.GetExpiry().AsTime().IsZero(), + "node expiry should be zero/nil after disabling") + } + + var nodeKey key.NodePublic + err = nodeKey.UnmarshalText([]byte(nodeDisabled.GetNodeKey())) + require.NoError(t, err) + + // Verify peers see the node as not expired. + for _, client := range allClients { + if client.Hostname() == nodeDisabled.GetName() { + continue + } + + assert.EventuallyWithT( + t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + + peerStatus, ok := status.Peer[nodeKey] + assert.True(ct, ok, "node key should be present in peer list") + + if !ok { + return + } + + // Node should not be expired. + assert.Falsef( + ct, + peerStatus.Expired, + "node %q should not be marked as expired after disabling expiry", + peerStatus.HostName, + ) + }, 3*time.Minute, 5*time.Second, "waiting for disabled expiry to propagate", + ) + } +} + func TestNodeOnlineStatus(t *testing.T) { IntegrationSkip(t)