diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 5c8757ed..634467a4 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -22,6 +22,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "golang.org/x/sync/errgroup" "gorm.io/gorm" @@ -1101,6 +1102,161 @@ type newNodeParams struct { ExistingNodeForNetinfo types.NodeView } +// authNodeUpdateParams contains parameters for updating an existing node during auth. +// Used by both reauthExistingNode and convertTaggedNodeToUser to share common logic. +type authNodeUpdateParams struct { + // Node to update; must be valid and in NodeStore. + ExistingNode types.NodeView + // Client data: keys, hostinfo, endpoints. + RegEntry *types.RegisterNode + // Pre-validated hostinfo; NetInfo preserved from ExistingNode. + ValidHostinfo *tailcfg.Hostinfo + // Hostname from hostinfo, or generated from keys if client omits it. + Hostname string + // Auth user; may differ from ExistingNode.User() on conversion. + User *types.User + // Overrides RegEntry.Node.Expiry; ignored for tagged nodes. + Expiry *time.Time + // Only used when IsConvertFromTag=true. + RegisterMethod string + // Set true for tagged->user conversion. Affects RegisterMethod and expiry. + IsConvertFromTag bool +} + +// applyAuthNodeUpdate applies common update logic for re-authenticating or converting +// an existing node. It updates the node in NodeStore, processes RequestTags, and +// persists changes to the database. +func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) { + // Process RequestTags during reauth (#2979) + // Due to json:",omitempty", we treat empty/nil as "clear tags" + var requestTags []string + if params.RegEntry.Node.Hostinfo != nil { + requestTags = params.RegEntry.Node.Hostinfo.RequestTags + } + + oldTags := params.ExistingNode.Tags().AsSlice() + + var rejectedTags []string + + // Update existing node in NodeStore + updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) { + node.NodeKey = params.RegEntry.Node.NodeKey + node.DiscoKey = params.RegEntry.Node.DiscoKey + node.Hostname = params.Hostname + + // Preserve NetInfo from existing node when re-registering + node.Hostinfo = params.ValidHostinfo + node.Hostinfo.NetInfo = preserveNetInfo( + params.ExistingNode, + params.ExistingNode.ID(), + params.ValidHostinfo, + ) + + node.Endpoints = params.RegEntry.Node.Endpoints + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + + // Set RegisterMethod - for conversion this is the new method, + // for reauth we preserve the existing one from regEntry + if params.IsConvertFromTag { + node.RegisterMethod = params.RegisterMethod + } else { + node.RegisterMethod = params.RegEntry.Node.RegisterMethod + } + + // Expiry handling differs based on node type: + // - Tagged nodes keep their existing expiry (disabled) + // - User-owned nodes update expiry from the provided value or registration entry + // - Converting from tagged to user-owned: always set expiry + if params.IsConvertFromTag || !node.IsTagged() { + if params.Expiry != nil { + node.Expiry = params.Expiry + } else { + node.Expiry = params.RegEntry.Node.Expiry + } + } + + rejectedTags = s.processReauthTags(node, requestTags, params.User, oldTags) + }) + + if !ok { + return types.NodeView{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, params.ExistingNode.ID()) + } + + if len(rejectedTags) > 0 { + return types.NodeView{}, fmt.Errorf( + "%w %v are invalid or not permitted", + ErrRequestedTagsInvalidOrNotPermitted, + rejectedTags, + ) + } + + // Persist to database + // Omit AuthKeyID/AuthKey to prevent stale PreAuthKey references from causing FK errors. + _, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := tx.Omit("AuthKeyID", "AuthKey").Updates(updatedNodeView.AsStruct()).Error + if err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + return nil, nil //nolint:nilnil // side-effect only write + }) + if err != nil { + return types.NodeView{}, err + } + + return updatedNodeView, nil +} + +// reauthExistingNode handles re-authentication of an existing node that belongs +// to the same user. This updates the node in place with new keys and hostinfo. +func (s *State) reauthExistingNode(logger zerolog.Logger, p authNodeUpdateParams) (types.NodeView, error) { + logger.Info(). + Str("node.name", p.ExistingNode.Hostname()). + Uint64("node.id", p.ExistingNode.ID().Uint64()). + Interface("hostinfo", p.RegEntry.Node.Hostinfo). + Msg("Updating existing node registration via reauth") + + updatedNode, err := s.applyAuthNodeUpdate(p) + if err != nil { + return types.NodeView{}, err + } + + logger.Trace(). + Str("node.name", updatedNode.Hostname()). + Uint64("node.id", updatedNode.ID().Uint64()). + Str("node.key", updatedNode.NodeKey().ShortString()). + Msg("Node re-authorized") + + return updatedNode, nil +} + +// convertTaggedNodeToUser converts a tagged node to be owned by a user. +// This handles the case where a node was registered with a tags-only PreAuthKey +// and is now being re-registered via OIDC or other user auth method. +func (s *State) convertTaggedNodeToUser(logger zerolog.Logger, p authNodeUpdateParams) (types.NodeView, error) { + logger.Info(). + Str("node.name", p.ExistingNode.Hostname()). + Uint64("node.id", p.ExistingNode.ID().Uint64()). + Strs("old.tags", p.ExistingNode.Tags().AsSlice()). + Msg("Converting tagged node to user-owned node") + + p.IsConvertFromTag = true + + updatedNode, err := s.applyAuthNodeUpdate(p) + if err != nil { + return types.NodeView{}, err + } + + logger.Trace(). + Str("node.name", updatedNode.Hostname()). + Uint64("node.id", updatedNode.ID().Uint64()). + Str("node.key", updatedNode.NodeKey().ShortString()). + Msg("Tagged node converted to user-owned") + + return updatedNode, nil +} + // createAndSaveNewNode creates a new node, allocates IPs, saves to DB, and adds to NodeStore. // It preserves netinfo from an existing node if one is provided (for faster DERP connectivity). func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, error) { @@ -1368,248 +1524,76 @@ func (s *State) HandleNodeFromAuthPath( regEntry.Node.Hostinfo, ) + // Lookup existing nodes + machineKey := regEntry.Node.MachineKey + existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID)) + existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) + + // Named conditions - describe WHAT we found, not HOW we check it + nodeExistsForSameUser := existingNodeSameUser.Valid() + nodeExistsForAnyUser := existingNodeAnyUser.Valid() + existingNodeIsTagged := nodeExistsForAnyUser && existingNodeAnyUser.IsTagged() + existingNodeOwnedByOtherUser := nodeExistsForAnyUser && + !existingNodeIsTagged && + existingNodeAnyUser.UserID().Get() != user.ID + + // Create logger with common fields for all auth operations + logger := log.With(). + Str("registration_id", registrationID.String()). + Str("user.name", user.Name). + Str("machine.key", machineKey.ShortString()). + Str("method", registrationMethod). + Logger() + + // Common params for update operations + updateParams := authNodeUpdateParams{ + RegEntry: regEntry, + ValidHostinfo: validHostinfo, + Hostname: hostname, + User: user, + Expiry: expiry, + RegisterMethod: registrationMethod, + } + var finalNode types.NodeView - // Check if node already exists with same machine key for this user - existingNodeSameUser, existsSameUser := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey, types.UserID(user.ID)) + if nodeExistsForSameUser { + updateParams.ExistingNode = existingNodeSameUser - // If this node exists for this user, update the node in place. - if existsSameUser && existingNodeSameUser.Valid() { - log.Info(). - Caller(). - Str("registration_id", registrationID.String()). - Str("user.name", user.Name). - Str("registrationMethod", registrationMethod). - Str("node.name", existingNodeSameUser.Hostname()). - Uint64("node.id", existingNodeSameUser.ID().Uint64()). - Interface("hostinfo", regEntry.Node.Hostinfo). - Msg("Updating existing node registration via reauth") - - // Process RequestTags during reauth (#2979) - // Due to json:",omitempty", we treat empty/nil as "clear tags" - var requestTags []string - if regEntry.Node.Hostinfo != nil { - requestTags = regEntry.Node.Hostinfo.RequestTags - } - - oldTags := existingNodeSameUser.Tags().AsSlice() - - var rejectedTags []string - - // Update existing node - NodeStore first, then database - updatedNodeView, ok := s.nodeStore.UpdateNode(existingNodeSameUser.ID(), func(node *types.Node) { - node.NodeKey = regEntry.Node.NodeKey - node.DiscoKey = regEntry.Node.DiscoKey - node.Hostname = hostname - - // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics - // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). - - // Preserve NetInfo from existing node when re-registering - node.Hostinfo = validHostinfo - node.Hostinfo.NetInfo = preserveNetInfo(existingNodeSameUser, existingNodeSameUser.ID(), validHostinfo) - - node.Endpoints = regEntry.Node.Endpoints - node.RegisterMethod = regEntry.Node.RegisterMethod - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) - - // Tagged nodes keep their existing expiry (disabled). - // User-owned nodes update expiry from the provided value or registration entry. - if !node.IsTagged() { - if expiry != nil { - node.Expiry = expiry - } else { - node.Expiry = regEntry.Node.Expiry - } - } - - rejectedTags = s.processReauthTags(node, requestTags, user, oldTags) - }) - - if !ok { - return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID()) - } - - if len(rejectedTags) > 0 { - return types.NodeView{}, change.Change{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, rejectedTags) - } - - _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - // Use Updates() to preserve fields not modified by UpdateNode. - // Omit AuthKeyID/AuthKey to prevent stale PreAuthKey references from causing FK errors. - err := tx.Omit("AuthKeyID", "AuthKey").Updates(updatedNodeView.AsStruct()).Error - if err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } - return nil, nil - }) + finalNode, err = s.reauthExistingNode(logger, updateParams) if err != nil { return types.NodeView{}, change.Change{}, err } + } else if existingNodeIsTagged { + updateParams.ExistingNode = existingNodeAnyUser - log.Trace(). - Caller(). - Str("node.name", updatedNodeView.Hostname()). - Uint64("node.id", updatedNodeView.ID().Uint64()). - Str("machine.key", regEntry.Node.MachineKey.ShortString()). - Str("node.key", updatedNodeView.NodeKey().ShortString()). - Str("user.name", user.Name). - Msg("Node re-authorized") + finalNode, err = s.convertTaggedNodeToUser(logger, updateParams) + if err != nil { + return types.NodeView{}, change.Change{}, err + } + } else if existingNodeOwnedByOtherUser { + oldUser := existingNodeAnyUser.User() - finalNode = updatedNodeView + logger.Info(). + Str("existing.node.name", existingNodeAnyUser.Hostname()). + Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). + Str("old.user", oldUser.Name()). + Msg("Creating new node for different user (same machine key exists for another user)") + + finalNode, err = s.createNewNodeFromAuth( + logger, user, regEntry, hostname, validHostinfo, + expiry, registrationMethod, existingNodeAnyUser, + ) + if err != nil { + return types.NodeView{}, change.Change{}, err + } } else { - // Node does not exist for this user with this machine key - // Check if node exists with this machine key for a different user/owner - existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(regEntry.Node.MachineKey) - - // If an existing TAGGED node is found (regardless of UserID), update it to be owned by - // the new user. This handles the case where a node was registered with a tags-only - // PreAuthKey and is now being re-registered to a user. - if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.IsTagged() { - log.Info(). - Caller(). - Str("existing.node.name", existingNodeAnyUser.Hostname()). - Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). - Str("machine.key", regEntry.Node.MachineKey.ShortString()). - Strs("old.tags", existingNodeAnyUser.Tags().AsSlice()). - Str("new.user", user.Name). - Str("method", registrationMethod). - Msg("Converting tagged node to user-owned node") - - // Process RequestTags during conversion - var requestTags []string - if regEntry.Node.Hostinfo != nil { - requestTags = regEntry.Node.Hostinfo.RequestTags - } - - oldTags := existingNodeAnyUser.Tags().AsSlice() - - var rejectedTags []string - - // Update existing node - convert from tagged to user-owned - updatedNodeView, ok := s.nodeStore.UpdateNode(existingNodeAnyUser.ID(), func(node *types.Node) { - node.NodeKey = regEntry.Node.NodeKey - node.DiscoKey = regEntry.Node.DiscoKey - node.Hostname = hostname - node.Hostinfo = validHostinfo - node.Hostinfo.NetInfo = preserveNetInfo(existingNodeAnyUser, existingNodeAnyUser.ID(), validHostinfo) - node.Endpoints = regEntry.Node.Endpoints - node.RegisterMethod = registrationMethod - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) - - // Set expiry for user-owned node - if expiry != nil { - node.Expiry = expiry - } else { - node.Expiry = regEntry.Node.Expiry - } - - rejectedTags = s.processReauthTags(node, requestTags, user, oldTags) - }) - - if !ok { - return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeAnyUser.ID()) - } - - if len(rejectedTags) > 0 { - return types.NodeView{}, change.Change{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, rejectedTags) - } - - _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - node := updatedNodeView.AsStruct() - - err := tx.Omit("AuthKeyID", "AuthKey").Updates(node).Error - if err != nil { - return nil, fmt.Errorf("failed to save node: %w", err) - } - - return node, nil - }) - if err != nil { - return types.NodeView{}, change.Change{}, err - } - - log.Trace(). - Caller(). - Str("node.name", updatedNodeView.Hostname()). - Uint64("node.id", updatedNodeView.ID().Uint64()). - Str("machine.key", regEntry.Node.MachineKey.ShortString()). - Str("node.key", updatedNodeView.NodeKey().ShortString()). - Str("user.name", user.Name). - Msg("Tagged node converted to user-owned") - - finalNode = updatedNodeView - } else if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != user.ID { - // Node exists but belongs to a different user (user-owned by someone else) - // Create a NEW node for the new user (do not transfer) - // This allows the same machine to have separate node identities per user - oldUser := existingNodeAnyUser.User() - - log.Info(). - Caller(). - Str("existing.node.name", existingNodeAnyUser.Hostname()). - Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). - Str("machine.key", regEntry.Node.MachineKey.ShortString()). - Str("old.user", oldUser.Name()). - Str("new.user", user.Name). - Str("method", registrationMethod). - Msg("Creating new node for different user (same machine key exists for another user)") - - // Create a completely new node - log.Debug(). - Caller(). - Str("registration_id", registrationID.String()). - Str("user.name", user.Name). - Str("registrationMethod", registrationMethod). - Str("expiresAt", fmt.Sprintf("%v", expiry)). - Msg("Registering new node from auth callback") - - var err error - - finalNode, err = s.createAndSaveNewNode(newNodeParams{ - User: *user, - MachineKey: regEntry.Node.MachineKey, - NodeKey: regEntry.Node.NodeKey, - DiscoKey: regEntry.Node.DiscoKey, - Hostname: hostname, - Hostinfo: validHostinfo, - Endpoints: regEntry.Node.Endpoints, - Expiry: cmp.Or(expiry, regEntry.Node.Expiry), - RegisterMethod: registrationMethod, - ExistingNodeForNetinfo: existingNodeAnyUser, - }) - if err != nil { - return types.NodeView{}, change.Change{}, err - } - } else { - // No existing node found - create a completely new node - log.Debug(). - Caller(). - Str("registration_id", registrationID.String()). - Str("user.name", user.Name). - Str("registrationMethod", registrationMethod). - Str("expiresAt", fmt.Sprintf("%v", expiry)). - Msg("Registering new node from auth callback") - - var err error - - finalNode, err = s.createAndSaveNewNode(newNodeParams{ - User: *user, - MachineKey: regEntry.Node.MachineKey, - NodeKey: regEntry.Node.NodeKey, - DiscoKey: regEntry.Node.DiscoKey, - Hostname: hostname, - Hostinfo: validHostinfo, - Endpoints: regEntry.Node.Endpoints, - Expiry: cmp.Or(expiry, regEntry.Node.Expiry), - RegisterMethod: registrationMethod, - ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}), - }) - if err != nil { - return types.NodeView{}, change.Change{}, err - } + finalNode, err = s.createNewNodeFromAuth( + logger, user, regEntry, hostname, validHostinfo, + expiry, registrationMethod, types.NodeView{}, + ) + if err != nil { + return types.NodeView{}, change.Change{}, err } } @@ -1640,6 +1624,37 @@ func (s *State) HandleNodeFromAuthPath( return finalNode, c, nil } +// createNewNodeFromAuth creates a new node during auth callback. +// This is used for both new registrations and when a machine already has a node +// for a different user. +func (s *State) createNewNodeFromAuth( + logger zerolog.Logger, + user *types.User, + regEntry *types.RegisterNode, + hostname string, + validHostinfo *tailcfg.Hostinfo, + expiry *time.Time, + registrationMethod string, + existingNodeForNetinfo types.NodeView, +) (types.NodeView, error) { + logger.Debug(). + Interface("expiry", expiry). + Msg("Registering new node from auth callback") + + return s.createAndSaveNewNode(newNodeParams{ + User: *user, + MachineKey: regEntry.Node.MachineKey, + NodeKey: regEntry.Node.NodeKey, + DiscoKey: regEntry.Node.DiscoKey, + Hostname: hostname, + Hostinfo: validHostinfo, + Endpoints: regEntry.Node.Endpoints, + Expiry: cmp.Or(expiry, regEntry.Node.Expiry), + RegisterMethod: registrationMethod, + ExistingNodeForNetinfo: existingNodeForNetinfo, + }) +} + // HandleNodeFromPreAuthKey handles node registration using a pre-authentication key. func (s *State) HandleNodeFromPreAuthKey( regReq tailcfg.RegisterRequest,