diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 1dfd10ee..4a9cbe98 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -229,6 +229,7 @@ jobs: - TestUpdateHostnameFromClient - TestExpireNode - TestSetNodeExpiryInFuture + - TestDisableNodeExpiry - TestNodeOnlineStatus - TestPingAllByIPManyUpDown - Test2118DeletingOnlineNodePanics diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index dbc7e8bf..930efc29 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -32,6 +32,7 @@ func init() { expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.") + expireNodeCmd.Flags().BoolP("disable", "d", false, "Disable key expiry (node will never expire)") mustMarkRequired(expireNodeCmd, "identifier") nodeCmd.AddCommand(expireNodeCmd) @@ -143,12 +144,31 @@ var listNodeRoutesCmd = &cobra.Command{ } var expireNodeCmd = &cobra.Command{ - Use: "expire", - Short: "Expire (log out) a node in your network", - Long: "Expiring a node will keep the node in the database and force it to reauthenticate.", + Use: "expire", + Short: "Expire (log out) a node in your network", + Long: `Expiring a node will keep the node in the database and force it to reauthenticate. + +Use --disable to disable key expiry (node will never expire).`, Aliases: []string{"logout", "exp", "e"}, RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { identifier, _ := cmd.Flags().GetUint64("identifier") + disableExpiry, _ := cmd.Flags().GetBool("disable") + + // Handle disable expiry - node will never expire. + if disableExpiry { + request := &v1.ExpireNodeRequest{ + NodeId: identifier, + DisableExpiry: true, + } + + response, err := client.ExpireNode(ctx, request) + if err != nil { + return fmt.Errorf("disabling node expiry: %w", err) + } + + return printOutput(cmd, response.GetNode(), "Node expiry disabled") + } + expiry, _ := cmd.Flags().GetString("expiry") now := time.Now() diff --git a/hscontrol/auth.go b/hscontrol/auth.go index fdc63461..d5a77bd7 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -212,7 +212,9 @@ func (h *Headscale) handleLogout( // Update the internal state with the nodes new expiry, meaning it is // logged out. - updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), req.Expiry) + expiry := req.Expiry + + updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), &expiry) if err != nil { return nil, fmt.Errorf("setting node expiry: %w", err) } diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index d28ed565..83dfb913 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -587,7 +587,7 @@ func TestAuthenticationFlows(t *testing.T) { // Expire the node expiredTime := time.Now().Add(-1 * time.Hour) - _, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime) + _, _, err = app.state.SetNodeExpiry(node.ID(), &expiredTime) return "", err }, diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 51bba035..d2db012c 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -315,16 +315,15 @@ func RenameNode(tx *gorm.DB, return nil } -func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error { +func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry *time.Time) error { return hsdb.Write(func(tx *gorm.DB) error { return NodeSetExpiry(tx, nodeID, expiry) }) } -// NodeSetExpiry takes a Node struct and a new expiry time. -func NodeSetExpiry(tx *gorm.DB, - nodeID types.NodeID, expiry time.Time, -) error { +// NodeSetExpiry sets a new expiry time for a node. +// If expiry is nil, the node's expiry is disabled (node will never expire). +func NodeSetExpiry(tx *gorm.DB, nodeID types.NodeID, expiry *time.Time) error { return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 55289ca4..128baf5b 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -128,7 +128,7 @@ func TestExpireNode(t *testing.T) { assert.False(t, nodeFromDB.IsExpired()) now := time.Now() - err = db.NodeSetExpiry(nodeFromDB.ID, now) + err = db.NodeSetExpiry(nodeFromDB.ID, &now) require.NoError(t, err) nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") @@ -137,6 +137,48 @@ func TestExpireNode(t *testing.T) { assert.True(t, nodeFromDB.IsExpired()) } +func TestDisableNodeExpiry(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + pakID := pak.ID + node := &types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + Expiry: &time.Time{}, + } + db.DB.Save(node) + + // Set an expiry first. + past := time.Now().Add(-time.Hour) + err = db.NodeSetExpiry(node.ID, &past) + require.NoError(t, err) + + nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + assert.True(t, nodeFromDB.IsExpired(), "node should be expired") + + // Disable expiry by setting nil. + err = db.NodeSetExpiry(node.ID, nil) + require.NoError(t, err) + + nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + assert.False(t, nodeFromDB.IsExpired(), "node should not be expired after disabling expiry") + assert.Nil(t, nodeFromDB.Expiry, "expiry should be nil after disabling") +} + func TestSetTags(t *testing.T) { db, err := newSQLiteTestDB() require.NoError(t, err) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 073c6677..3af8e807 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -451,12 +451,40 @@ func (api headscaleV1APIServer) ExpireNode( ctx context.Context, request *v1.ExpireNodeRequest, ) (*v1.ExpireNodeResponse, error) { + if request.GetDisableExpiry() && request.GetExpiry() != nil { + return nil, status.Error( + codes.InvalidArgument, + "cannot set both disable_expiry and expiry", + ) + } + + // Handle disable expiry request - node will never expire. + if request.GetDisableExpiry() { + node, nodeChange, err := api.h.state.SetNodeExpiry( + types.NodeID(request.GetNodeId()), nil, + ) + if err != nil { + return nil, err + } + + api.h.Change(nodeChange) + + log.Trace(). + Caller(). + EmbedObject(node). + Msg("node expiry disabled") + + return &v1.ExpireNodeResponse{Node: node.Proto()}, nil + } + expiry := time.Now() if request.GetExpiry() != nil { expiry = request.GetExpiry().AsTime() } - node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), expiry) + node, nodeChange, err := api.h.state.SetNodeExpiry( + types.NodeID(request.GetNodeId()), &expiry, + ) if err != nil { return nil, err } @@ -467,7 +495,7 @@ func (api headscaleV1APIServer) ExpireNode( log.Trace(). Caller(). EmbedObject(node). - Time(zf.ExpiresAt, *node.AsStruct().Expiry). + Time(zf.ExpiresAt, expiry). Msg("node expired") return &v1.ExpireNodeResponse{Node: node.Proto()}, nil diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index e421d5bd..35544aa3 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -638,22 +638,38 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { } // SetNodeExpiry updates the expiration time for a node. -func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.Change, error) { +// If expiry is nil, the node's expiry is disabled (node will never expire). +func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry *time.Time) (types.NodeView, change.Change, error) { // Update NodeStore before database to ensure consistency. The NodeStore update is // blocking and will be the source of truth for the batcher. The database update must // make the exact same change. If the database update fails, the NodeStore change will // remain, but since we return an error, no change notification will be sent to the // batcher, preventing inconsistent state propagation. - expiryPtr := expiry n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { - node.Expiry = &expiryPtr + node.Expiry = expiry }) if !ok { return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) } - return s.persistNodeToDB(n) + // Persist expiry change to database directly since persistNodeToDB omits expiry. + err := s.db.NodeSetExpiry(nodeID, expiry) + if err != nil { + return types.NodeView{}, change.Change{}, fmt.Errorf("setting node expiry in database: %w", err) + } + + // Update policy manager and generate change notification. + c, err := s.updatePolicyManagerNodes() + if err != nil { + return n, change.Change{}, fmt.Errorf("updating policy manager after setting expiry: %w", err) + } + + if c.IsEmpty() { + c = change.NodeAdded(n.ID()) + } + + return n, c, nil } // SetNodeTags assigns tags to a node, making it a "tagged node". diff --git a/integration/general_test.go b/integration/general_test.go index f44a0f03..42ba58bf 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -1166,6 +1166,103 @@ func TestSetNodeExpiryInFuture(t *testing.T) { } } +// TestDisableNodeExpiry tests disabling key expiry for a node. +// First sets an expiry, then disables it and verifies the node never expires. +func TestDisableNodeExpiry(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("disableexpiry")) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // First set an expiry on the node. + result, err := headscale.Execute( + []string{ + "headscale", "nodes", "expire", + "--identifier", "1", + "--output", "json", + "--expiry", time.Now().Add(time.Hour).Format(time.RFC3339), + }, + ) + require.NoError(t, err) + + var node v1.Node + err = json.Unmarshal([]byte(result), &node) + require.NoError(t, err) + require.NotNil(t, node.GetExpiry(), "node should have an expiry set") + + // Now disable the expiry. + result, err = headscale.Execute( + []string{ + "headscale", "nodes", "expire", + "--identifier", "1", + "--output", "json", + "--disable", + }, + ) + require.NoError(t, err) + + var nodeDisabled v1.Node + err = json.Unmarshal([]byte(result), &nodeDisabled) + require.NoError(t, err) + + // Expiry should be nil (or zero time) when disabled. + if nodeDisabled.GetExpiry() != nil { + require.True(t, nodeDisabled.GetExpiry().AsTime().IsZero(), + "node expiry should be zero/nil after disabling") + } + + var nodeKey key.NodePublic + err = nodeKey.UnmarshalText([]byte(nodeDisabled.GetNodeKey())) + require.NoError(t, err) + + // Verify peers see the node as not expired. + for _, client := range allClients { + if client.Hostname() == nodeDisabled.GetName() { + continue + } + + assert.EventuallyWithT( + t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + + peerStatus, ok := status.Peer[nodeKey] + assert.True(ct, ok, "node key should be present in peer list") + + if !ok { + return + } + + // Node should not be expired. + assert.Falsef( + ct, + peerStatus.Expired, + "node %q should not be marked as expired after disabling expiry", + peerStatus.HostName, + ) + }, 3*time.Minute, 5*time.Second, "waiting for disabled expiry to propagate", + ) + } +} + func TestNodeOnlineStatus(t *testing.T) { IntegrationSkip(t)