mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 21:17:43 +09:00 
			
		
		
		
	Rewrite authentication flow (#2374)
This commit is contained in:
		
							
								
								
									
										2
									
								
								.github/workflows/check-tests.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/check-tests.yaml
									
									
									
									
										vendored
									
									
								
							| @@ -32,7 +32,7 @@ jobs: | |||||||
|       - name: Generate and check integration tests |       - name: Generate and check integration tests | ||||||
|         if: steps.changed-files.outputs.files == 'true' |         if: steps.changed-files.outputs.files == 'true' | ||||||
|         run: | |         run: | | ||||||
|           nix develop --command bash -c "cd cmd/gh-action-integration-generator/ && go generate" |           nix develop --command bash -c "cd .github/workflows && go generate" | ||||||
|           git diff --exit-code .github/workflows/test-integration.yaml |           git diff --exit-code .github/workflows/test-integration.yaml | ||||||
|  |  | ||||||
|       - name: Show missing tests |       - name: Show missing tests | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| //go:generate go run ./main.go | //go:generate go run ./gh-action-integration-generator.go | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| @@ -42,15 +42,19 @@ func updateYAML(tests []string) { | |||||||
| 	testsForYq := fmt.Sprintf("[%s]", strings.Join(tests, ", ")) | 	testsForYq := fmt.Sprintf("[%s]", strings.Join(tests, ", ")) | ||||||
| 
 | 
 | ||||||
| 	yqCommand := fmt.Sprintf( | 	yqCommand := fmt.Sprintf( | ||||||
| 		"yq eval '.jobs.integration-test.strategy.matrix.test = %s' ../../.github/workflows/test-integration.yaml -i", | 		"yq eval '.jobs.integration-test.strategy.matrix.test = %s' ./test-integration.yaml -i", | ||||||
| 		testsForYq, | 		testsForYq, | ||||||
| 	) | 	) | ||||||
| 	cmd := exec.Command("bash", "-c", yqCommand) | 	cmd := exec.Command("bash", "-c", yqCommand) | ||||||
| 
 | 
 | ||||||
| 	var out bytes.Buffer | 	var stdout bytes.Buffer | ||||||
| 	cmd.Stdout = &out | 	var stderr bytes.Buffer | ||||||
|  | 	cmd.Stdout = &stdout | ||||||
|  | 	cmd.Stderr = &stderr | ||||||
| 	err := cmd.Run() | 	err := cmd.Run() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		log.Printf("stdout: %s", stdout.String()) | ||||||
|  | 		log.Printf("stderr: %s", stderr.String()) | ||||||
| 		log.Fatalf("failed to run yq command: %s", err) | 		log.Fatalf("failed to run yq command: %s", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
							
								
								
									
										4
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							| @@ -22,10 +22,13 @@ jobs: | |||||||
|           - TestACLNamedHostsCanReach |           - TestACLNamedHostsCanReach | ||||||
|           - TestACLDevice1CanAccessDevice2 |           - TestACLDevice1CanAccessDevice2 | ||||||
|           - TestPolicyUpdateWhileRunningWithCLIInDatabase |           - TestPolicyUpdateWhileRunningWithCLIInDatabase | ||||||
|  |           - TestAuthKeyLogoutAndReloginSameUser | ||||||
|  |           - TestAuthKeyLogoutAndReloginNewUser | ||||||
|           - TestOIDCAuthenticationPingAll |           - TestOIDCAuthenticationPingAll | ||||||
|           - TestOIDCExpireNodesBasedOnTokenExpiry |           - TestOIDCExpireNodesBasedOnTokenExpiry | ||||||
|           - TestOIDC024UserCreation |           - TestOIDC024UserCreation | ||||||
|           - TestOIDCAuthenticationWithPKCE |           - TestOIDCAuthenticationWithPKCE | ||||||
|  |           - TestOIDCReloginSameNodeNewUser | ||||||
|           - TestAuthWebFlowAuthenticationPingAll |           - TestAuthWebFlowAuthenticationPingAll | ||||||
|           - TestAuthWebFlowLogoutAndRelogin |           - TestAuthWebFlowLogoutAndRelogin | ||||||
|           - TestUserCommand |           - TestUserCommand | ||||||
| @@ -50,7 +53,6 @@ jobs: | |||||||
|           - TestDERPServerWebsocketScenario |           - TestDERPServerWebsocketScenario | ||||||
|           - TestPingAllByIP |           - TestPingAllByIP | ||||||
|           - TestPingAllByIPPublicDERP |           - TestPingAllByIPPublicDERP | ||||||
|           - TestAuthKeyLogoutAndRelogin |  | ||||||
|           - TestEphemeral |           - TestEphemeral | ||||||
|           - TestEphemeralInAlternateTimezone |           - TestEphemeralInAlternateTimezone | ||||||
|           - TestEphemeral2006DeletedTooQuickly |           - TestEphemeral2006DeletedTooQuickly | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								CHANGELOG.md
									
									
									
									
									
								
							| @@ -2,6 +2,18 @@ | |||||||
|  |  | ||||||
| ## Next | ## Next | ||||||
|  |  | ||||||
|  | ### BREAKING | ||||||
|  |  | ||||||
|  | - Authentication flow has been rewritten | ||||||
|  |   [#2374](https://github.com/juanfont/headscale/pull/2374) This change should be | ||||||
|  |   transparent to users with the exception of some buxfixes that has been | ||||||
|  |   discovered and was fixed as part of the rewrite. | ||||||
|  |   - When a node is registered with _a new user_, it will be registered as a new | ||||||
|  |     node ([#2327](https://github.com/juanfont/headscale/issues/2327) and | ||||||
|  |     [#1310](https://github.com/juanfont/headscale/issues/1310)). | ||||||
|  |   - A logged out node logging in with the same user will replace the existing | ||||||
|  |     node. | ||||||
|  |  | ||||||
| ### Changes | ### Changes | ||||||
|  |  | ||||||
| - `oidc.map_legacy_users` is now `false` by default | - `oidc.map_legacy_users` is now `false` by default | ||||||
|   | |||||||
| @@ -521,25 +521,28 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not | |||||||
|  |  | ||||||
| // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. | // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. | ||||||
| // Maybe we should attempt a new in memory state and not go via the DB? | // Maybe we should attempt a new in memory state and not go via the DB? | ||||||
| func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { | // A bool is returned indicating if a full update was sent to all nodes | ||||||
|  | func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) (bool, error) { | ||||||
| 	nodes, err := db.ListNodes() | 	nodes, err := db.ListNodes() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return false, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	changed, err := polMan.SetNodes(nodes) | 	filterChanged, err := polMan.SetNodes(nodes) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return false, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if changed { | 	if filterChanged { | ||||||
| 		ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") | 		ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") | ||||||
| 		notif.NotifyAll(ctx, types.StateUpdate{ | 		notif.NotifyAll(ctx, types.StateUpdate{ | ||||||
| 			Type: types.StateFullUpdate, | 			Type: types.StateFullUpdate, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | 		return true, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return false, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Serve launches the HTTP and gRPC server service Headscale and the API. | // Serve launches the HTTP and gRPC server service Headscale and the API. | ||||||
|   | |||||||
| @@ -2,7 +2,6 @@ package hscontrol | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| @@ -13,7 +12,6 @@ import ( | |||||||
| 	"github.com/juanfont/headscale/hscontrol/db" | 	"github.com/juanfont/headscale/hscontrol/db" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/types" | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/rs/zerolog/log" |  | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| 	"tailscale.com/types/key" | 	"tailscale.com/types/key" | ||||||
| @@ -25,730 +23,244 @@ type AuthProvider interface { | |||||||
| 	AuthURL(types.RegistrationID) string | 	AuthURL(types.RegistrationID) string | ||||||
| } | } | ||||||
|  |  | ||||||
| func logAuthFunc( | func (h *Headscale) handleRegister( | ||||||
| 	registerRequest tailcfg.RegisterRequest, | 	ctx context.Context, | ||||||
|  | 	regReq tailcfg.RegisterRequest, | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| 	registrationId types.RegistrationID, | ) (*tailcfg.RegisterResponse, error) { | ||||||
| ) (func(string), func(string), func(error, string)) { | 	node, err := h.db.GetNodeByNodeKey(regReq.NodeKey) | ||||||
| 	return func(msg string) { | 	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { | ||||||
| 			log.Info(). | 		return nil, fmt.Errorf("looking up node in database: %w", err) | ||||||
| 				Caller(). | 	} | ||||||
| 				Str("registration_id", registrationId.String()). |  | ||||||
| 				Str("machine_key", machineKey.ShortString()). | 	if node != nil { | ||||||
| 				Str("node_key", registerRequest.NodeKey.ShortString()). | 		resp, err := h.handleExistingNode(node, regReq, machineKey) | ||||||
| 				Str("node_key_old", registerRequest.OldNodeKey.ShortString()). | 		if err != nil { | ||||||
| 				Str("node", registerRequest.Hostinfo.Hostname). | 			return nil, fmt.Errorf("handling existing node: %w", err) | ||||||
| 				Str("followup", registerRequest.Followup). |  | ||||||
| 				Time("expiry", registerRequest.Expiry). |  | ||||||
| 				Msg(msg) |  | ||||||
| 		}, |  | ||||||
| 		func(msg string) { |  | ||||||
| 			log.Trace(). |  | ||||||
| 				Caller(). |  | ||||||
| 				Str("registration_id", registrationId.String()). |  | ||||||
| 				Str("machine_key", machineKey.ShortString()). |  | ||||||
| 				Str("node_key", registerRequest.NodeKey.ShortString()). |  | ||||||
| 				Str("node_key_old", registerRequest.OldNodeKey.ShortString()). |  | ||||||
| 				Str("node", registerRequest.Hostinfo.Hostname). |  | ||||||
| 				Str("followup", registerRequest.Followup). |  | ||||||
| 				Time("expiry", registerRequest.Expiry). |  | ||||||
| 				Msg(msg) |  | ||||||
| 		}, |  | ||||||
| 		func(err error, msg string) { |  | ||||||
| 			log.Error(). |  | ||||||
| 				Caller(). |  | ||||||
| 				Str("registration_id", registrationId.String()). |  | ||||||
| 				Str("machine_key", machineKey.ShortString()). |  | ||||||
| 				Str("node_key", registerRequest.NodeKey.ShortString()). |  | ||||||
| 				Str("node_key_old", registerRequest.OldNodeKey.ShortString()). |  | ||||||
| 				Str("node", registerRequest.Hostinfo.Hostname). |  | ||||||
| 				Str("followup", registerRequest.Followup). |  | ||||||
| 				Time("expiry", registerRequest.Expiry). |  | ||||||
| 				Err(err). |  | ||||||
| 				Msg(msg) |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		return resp, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if regReq.Followup != "" { | ||||||
|  | 		// TODO(kradalby): Does this need to return an error of some sort? | ||||||
|  | 		// Maybe if the registration fails down the line it can be sent | ||||||
|  | 		// on the channel and returned here? | ||||||
|  | 		h.waitForFollowup(ctx, regReq) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if regReq.Auth != nil && regReq.Auth.AuthKey != "" { | ||||||
|  | 		resp, err := h.handleRegisterWithAuthKey(regReq, machineKey) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("handling register with auth key: %w", err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return resp, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err := h.handleRegisterInteractive(regReq, machineKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("handling register interactive: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Headscale) handleExistingNode( | ||||||
|  | 	node *types.Node, | ||||||
|  | 	regReq tailcfg.RegisterRequest, | ||||||
|  | 	machineKey key.MachinePublic, | ||||||
|  | ) (*tailcfg.RegisterResponse, error) { | ||||||
|  | 	if node.MachineKey != machineKey { | ||||||
|  | 		return nil, errors.New("node already exists with different machine key") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	expired := node.IsExpired() | ||||||
|  | 	if !expired && !regReq.Expiry.IsZero() { | ||||||
|  | 		requestExpiry := regReq.Expiry | ||||||
|  |  | ||||||
|  | 		// The client is trying to extend their key, this is not allowed. | ||||||
|  | 		if requestExpiry.After(time.Now()) { | ||||||
|  | 			return nil, errors.New("extending key is not allowed") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// If the request expiry is in the past, we consider it a logout. | ||||||
|  | 		if requestExpiry.Before(time.Now()) { | ||||||
|  | 			if node.IsEphemeral() { | ||||||
|  | 				changedNodes, err := h.db.DeleteNode(node, h.nodeNotifier.LikelyConnectedMap()) | ||||||
|  | 				if err != nil { | ||||||
|  | 					return nil, fmt.Errorf("deleting ephemeral node: %w", err) | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") | ||||||
|  | 				h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ | ||||||
|  | 					Type:    types.StatePeerRemoved, | ||||||
|  | 					Removed: []types.NodeID{node.ID}, | ||||||
|  | 				}) | ||||||
|  | 				if changedNodes != nil { | ||||||
|  | 					h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ | ||||||
|  | 						Type:        types.StatePeerChanged, | ||||||
|  | 						ChangeNodes: changedNodes, | ||||||
|  | 					}) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			expired = true | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		err := h.db.NodeSetExpiry(node.ID, requestExpiry) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("setting node expiry: %w", err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") | ||||||
|  | 		h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, requestExpiry), node.ID) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &tailcfg.RegisterResponse{ | ||||||
|  | 		// TODO(kradalby): Only send for user-owned nodes | ||||||
|  | 		// and not tagged nodes when tags is working. | ||||||
|  | 		User:           *node.User.TailscaleUser(), | ||||||
|  | 		Login:          *node.User.TailscaleLogin(), | ||||||
|  | 		NodeKeyExpired: expired, | ||||||
|  |  | ||||||
|  | 		// Headscale does not implement the concept of machine authorization | ||||||
|  | 		// so we always return true here. | ||||||
|  | 		// Revisit this if #2176 gets implemented. | ||||||
|  | 		MachineAuthorized: true, | ||||||
|  | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *Headscale) waitForFollowup( | func (h *Headscale) waitForFollowup( | ||||||
| 	req *http.Request, | 	ctx context.Context, | ||||||
| 	regReq tailcfg.RegisterRequest, | 	regReq tailcfg.RegisterRequest, | ||||||
| 	logTrace func(string), |  | ||||||
| ) { | ) { | ||||||
| 	logTrace("register request is a followup") |  | ||||||
| 	fu, err := url.Parse(regReq.Followup) | 	fu, err := url.Parse(regReq.Followup) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logTrace("failed to parse followup URL") |  | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) | 	followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logTrace("followup URL does not contains a valid registration ID") |  | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg)) |  | ||||||
|  |  | ||||||
| 	if reg, ok := h.registrationCache.Get(followupReg); ok { | 	if reg, ok := h.registrationCache.Get(followupReg); ok { | ||||||
| 		logTrace("Node is waiting for interactive login") |  | ||||||
|  |  | ||||||
| 		select { | 		select { | ||||||
| 		case <-req.Context().Done(): | 		case <-ctx.Done(): | ||||||
| 			logTrace("node went away before it was registered") |  | ||||||
| 			return | 			return | ||||||
| 		case <-reg.Registered: | 		case <-reg.Registered: | ||||||
| 			logTrace("node has successfully registered") |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // handleRegister is the logic for registering a client. | func (h *Headscale) handleRegisterWithAuthKey( | ||||||
| func (h *Headscale) handleRegister( |  | ||||||
| 	writer http.ResponseWriter, |  | ||||||
| 	req *http.Request, |  | ||||||
| 	regReq tailcfg.RegisterRequest, | 	regReq tailcfg.RegisterRequest, | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| ) { | ) (*tailcfg.RegisterResponse, error) { | ||||||
| 	registrationId, err := types.NewRegistrationID() | 	pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		return nil, fmt.Errorf("invalid pre auth key: %w", err) | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed to generate registration ID") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId) | 	nodeToRegister := types.Node{ | ||||||
| 	now := time.Now().UTC() | 		Hostname:       regReq.Hostinfo.Hostname, | ||||||
| 	logTrace("handleRegister called, looking up machine in DB") | 		UserID:         pak.User.ID, | ||||||
|  | 		User:           pak.User, | ||||||
|  | 		MachineKey:     machineKey, | ||||||
|  | 		NodeKey:        regReq.NodeKey, | ||||||
|  | 		Hostinfo:       regReq.Hostinfo, | ||||||
|  | 		LastSeen:       ptr.To(time.Now()), | ||||||
|  | 		RegisterMethod: util.RegisterMethodAuthKey, | ||||||
|  |  | ||||||
| 	// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs | 		// TODO(kradalby): This should not be set on the node, | ||||||
| 	// key refreshes. This will allow us to remove the machineKey from the registration request. | 		// they should be looked up through the key, which is | ||||||
| 	node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) | 		// attached to the node. | ||||||
| 	logTrace("handleRegister database lookup has returned") | 		ForcedTags: pak.Proto().GetAclTags(), | ||||||
| 	if errors.Is(err, gorm.ErrRecordNotFound) { | 		AuthKey:    pak, | ||||||
| 		// If the node has AuthKey set, handle registration via PreAuthKeys | 		AuthKeyID:  &pak.ID, | ||||||
| 		if regReq.Auth != nil && regReq.Auth.AuthKey != "" { |  | ||||||
| 			h.handleAuthKey(writer, regReq, machineKey) |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Check if the node is waiting for interactive login. |  | ||||||
| 		if regReq.Followup != "" { |  | ||||||
| 			h.waitForFollowup(req, regReq, logTrace) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		logInfo("Node not found in database, creating new") |  | ||||||
|  |  | ||||||
| 		// The node did not have a key to authenticate, which means |  | ||||||
| 		// that we rely on a method that calls back some how (OpenID or CLI) |  | ||||||
| 		// We create the node and then keep it around until a callback |  | ||||||
| 		// happens |  | ||||||
| 		newNode := types.RegisterNode{ |  | ||||||
| 			Node: types.Node{ |  | ||||||
| 				MachineKey: machineKey, |  | ||||||
| 				Hostname:   regReq.Hostinfo.Hostname, |  | ||||||
| 				NodeKey:    regReq.NodeKey, |  | ||||||
| 				LastSeen:   &now, |  | ||||||
| 				Expiry:     &time.Time{}, |  | ||||||
| 			}, |  | ||||||
| 			Registered: make(chan struct{}), |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if !regReq.Expiry.IsZero() { |  | ||||||
| 			logTrace("Non-zero expiry time requested") |  | ||||||
| 			newNode.Node.Expiry = ®Req.Expiry |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		h.registrationCache.Set( |  | ||||||
| 			registrationId, |  | ||||||
| 			newNode, |  | ||||||
| 		) |  | ||||||
|  |  | ||||||
| 		h.handleNewNode(writer, regReq, registrationId) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// The node is already in the DB. This could mean one of the following: | 	if !regReq.Expiry.IsZero() { | ||||||
| 	// - The node is authenticated and ready to /map | 		nodeToRegister.Expiry = ®Req.Expiry | ||||||
| 	// - We are doing a key refresh |  | ||||||
| 	// - The node is logged out (or expired) and pending to be authorized. TODO(juan): We need to keep alive the connection here |  | ||||||
| 	if node != nil { |  | ||||||
| 		// (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021, |  | ||||||
| 		// due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054 |  | ||||||
| 		// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it. |  | ||||||
| 		if err != nil || node.MachineKey.IsZero() { |  | ||||||
| 			if err := h.db.NodeSetMachineKey(node, machineKey); err != nil { |  | ||||||
| 				log.Error(). |  | ||||||
| 					Caller(). |  | ||||||
| 					Str("func", "RegistrationHandler"). |  | ||||||
| 					Str("node", node.Hostname). |  | ||||||
| 					Err(err). |  | ||||||
| 					Msg("Error saving machine key to database") |  | ||||||
|  |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// If the NodeKey stored in headscale is the same as the key presented in a registration |  | ||||||
| 		// request, then we have a node that is either: |  | ||||||
| 		// - Trying to log out (sending a expiry in the past) |  | ||||||
| 		// - A valid, registered node, looking for /map |  | ||||||
| 		// - Expired node wanting to reauthenticate |  | ||||||
| 		if node.NodeKey.String() == regReq.NodeKey.String() { |  | ||||||
| 			// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) |  | ||||||
| 			//   https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 |  | ||||||
| 			if !regReq.Expiry.IsZero() && |  | ||||||
| 				regReq.Expiry.UTC().Before(now) { |  | ||||||
| 				h.handleNodeLogOut(writer, *node) |  | ||||||
|  |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// If node is not expired, and it is register, we have a already accepted this node, |  | ||||||
| 			// let it proceed with a valid registration |  | ||||||
| 			if !node.IsExpired() { |  | ||||||
| 				h.handleNodeWithValidRegistration(writer, *node) |  | ||||||
|  |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration |  | ||||||
| 		if node.NodeKey.String() == regReq.OldNodeKey.String() && |  | ||||||
| 			!node.IsExpired() { |  | ||||||
| 			h.handleNodeKeyRefresh( |  | ||||||
| 				writer, |  | ||||||
| 				regReq, |  | ||||||
| 				*node, |  | ||||||
| 			) |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed |  | ||||||
| 		if node.NodeKey.String() != regReq.NodeKey.String() && |  | ||||||
| 			regReq.OldNodeKey.IsZero() && !node.IsExpired() { |  | ||||||
| 			h.handleNodeKeyRefresh( |  | ||||||
| 				writer, |  | ||||||
| 				regReq, |  | ||||||
| 				*node, |  | ||||||
| 			) |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if regReq.Followup != "" { |  | ||||||
| 			h.waitForFollowup(req, regReq, logTrace) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// The node has expired or it is logged out |  | ||||||
| 		h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId) |  | ||||||
|  |  | ||||||
| 		// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use |  | ||||||
| 		node.Expiry = &time.Time{} |  | ||||||
|  |  | ||||||
| 		// TODO(kradalby): do we need to rethink this as part of authflow? |  | ||||||
| 		// If we are here it means the client needs to be reauthorized, |  | ||||||
| 		// we need to make sure the NodeKey matches the one in the request |  | ||||||
| 		// TODO(juan): What happens when using fast user switching between two |  | ||||||
| 		// headscale-managed tailnets? |  | ||||||
| 		node.NodeKey = regReq.NodeKey |  | ||||||
| 		h.registrationCache.Set( |  | ||||||
| 			registrationId, |  | ||||||
| 			types.RegisterNode{ |  | ||||||
| 				Node:       *node, |  | ||||||
| 				Registered: make(chan struct{}), |  | ||||||
| 			}, |  | ||||||
| 		) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| } |  | ||||||
|  |  | ||||||
| // handleAuthKey contains the logic to manage auth key client registration | 	ipv4, ipv6, err := h.ipAlloc.Next() | ||||||
| // When using Noise, the machineKey is Zero. |  | ||||||
| func (h *Headscale) handleAuthKey( |  | ||||||
| 	writer http.ResponseWriter, |  | ||||||
| 	registerRequest tailcfg.RegisterRequest, |  | ||||||
| 	machineKey key.MachinePublic, |  | ||||||
| ) { |  | ||||||
| 	log.Debug(). |  | ||||||
| 		Caller(). |  | ||||||
| 		Str("node", registerRequest.Hostinfo.Hostname). |  | ||||||
| 		Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) |  | ||||||
| 	resp := tailcfg.RegisterResponse{} |  | ||||||
|  |  | ||||||
| 	pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		return nil, fmt.Errorf("allocating IPs: %w", err) | ||||||
| 			Caller(). |  | ||||||
| 			Str("node", registerRequest.Hostinfo.Hostname). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed authentication via AuthKey") |  | ||||||
| 		resp.MachineAuthorized = false |  | ||||||
|  |  | ||||||
| 		respBody, err := json.Marshal(resp) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error(). |  | ||||||
| 				Caller(). |  | ||||||
| 				Str("node", registerRequest.Hostinfo.Hostname). |  | ||||||
| 				Err(err). |  | ||||||
| 				Msg("Cannot encode message") |  | ||||||
| 			http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		writer.Header().Set("Content-Type", "application/json; charset=utf-8") |  | ||||||
| 		writer.WriteHeader(http.StatusUnauthorized) |  | ||||||
| 		_, err = writer.Write(respBody) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error(). |  | ||||||
| 				Caller(). |  | ||||||
| 				Err(err). |  | ||||||
| 				Msg("Failed to write response") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Str("node", registerRequest.Hostinfo.Hostname). |  | ||||||
| 			Msg("Failed authentication via AuthKey") |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug(). | 	node, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.Node, error) { | ||||||
| 		Caller(). | 		node, err := db.RegisterNode(tx, | ||||||
| 		Str("node", registerRequest.Hostinfo.Hostname). |  | ||||||
| 		Msg("Authentication key was valid, proceeding to acquire IP addresses") |  | ||||||
|  |  | ||||||
| 	nodeKey := registerRequest.NodeKey |  | ||||||
|  |  | ||||||
| 	// retrieve node information if it exist |  | ||||||
| 	// The error is not important, because if it does not |  | ||||||
| 	// exist, then this is a new node and we will move |  | ||||||
| 	// on to registration. |  | ||||||
| 	// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs |  | ||||||
| 	// key refreshes. This will allow us to remove the machineKey from the registration request. |  | ||||||
| 	node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) |  | ||||||
| 	if node != nil { |  | ||||||
| 		log.Trace(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Str("node", node.Hostname). |  | ||||||
| 			Msg("node was already registered before, refreshing with new auth key") |  | ||||||
|  |  | ||||||
| 		node.NodeKey = nodeKey |  | ||||||
| 		if pak.ID != 0 { |  | ||||||
| 			node.AuthKeyID = ptr.To(pak.ID) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		node.Expiry = ®isterRequest.Expiry |  | ||||||
| 		node.User = pak.User |  | ||||||
| 		node.UserID = pak.UserID |  | ||||||
| 		err := h.db.DB.Save(node).Error |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error(). |  | ||||||
| 				Caller(). |  | ||||||
| 				Str("node", node.Hostname). |  | ||||||
| 				Err(err). |  | ||||||
| 				Msg("failed to save node after logging in with auth key") |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		aclTags := pak.Proto().GetAclTags() |  | ||||||
| 		if len(aclTags) > 0 { |  | ||||||
| 			// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login |  | ||||||
| 			err = h.db.SetTags(node.ID, aclTags) |  | ||||||
| 			if err != nil { |  | ||||||
| 				log.Error(). |  | ||||||
| 					Caller(). |  | ||||||
| 					Str("node", node.Hostname). |  | ||||||
| 					Strs("aclTags", aclTags). |  | ||||||
| 					Err(err). |  | ||||||
| 					Msg("Failed to set tags after refreshing node") |  | ||||||
|  |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na") |  | ||||||
| 		h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{node.ID}}) |  | ||||||
| 	} else { |  | ||||||
| 		now := time.Now().UTC() |  | ||||||
|  |  | ||||||
| 		nodeToRegister := types.Node{ |  | ||||||
| 			Hostname:       registerRequest.Hostinfo.Hostname, |  | ||||||
| 			UserID:         pak.User.ID, |  | ||||||
| 			User:           pak.User, |  | ||||||
| 			MachineKey:     machineKey, |  | ||||||
| 			RegisterMethod: util.RegisterMethodAuthKey, |  | ||||||
| 			Expiry:         ®isterRequest.Expiry, |  | ||||||
| 			NodeKey:        nodeKey, |  | ||||||
| 			LastSeen:       &now, |  | ||||||
| 			ForcedTags:     pak.Proto().GetAclTags(), |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		ipv4, ipv6, err := h.ipAlloc.Next() |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error(). |  | ||||||
| 				Caller(). |  | ||||||
| 				Str("func", "RegistrationHandler"). |  | ||||||
| 				Str("hostinfo.name", registerRequest.Hostinfo.Hostname). |  | ||||||
| 				Err(err). |  | ||||||
| 				Msg("failed to allocate IP	") |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		pakID := uint(pak.ID) |  | ||||||
| 		if pakID != 0 { |  | ||||||
| 			nodeToRegister.AuthKeyID = ptr.To(pak.ID) |  | ||||||
| 		} |  | ||||||
| 		node, err = h.db.RegisterNode( |  | ||||||
| 			nodeToRegister, | 			nodeToRegister, | ||||||
| 			ipv4, ipv6, | 			ipv4, ipv6, | ||||||
| 		) | 		) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			return nil, fmt.Errorf("registering node: %w", err) | ||||||
| 				Caller(). |  | ||||||
| 				Err(err). |  | ||||||
| 				Msg("could not register node") |  | ||||||
| 			http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier) | 		if !pak.Reusable { | ||||||
| 		if err != nil { | 			err = db.UsePreAuthKey(tx, pak) | ||||||
| 			http.Error(writer, "Internal server error", http.StatusInternalServerError) | 			if err != nil { | ||||||
| 			return | 				return nil, fmt.Errorf("using pre auth key: %w", err) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	err = h.db.Write(func(tx *gorm.DB) error { | 		return node, nil | ||||||
| 		return db.UsePreAuthKey(tx, pak) |  | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		return nil, err | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed to use pre-auth key") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	resp.MachineAuthorized = true | 	updateSent, err := nodesChangedHook(h.db, h.polMan, h.nodeNotifier) | ||||||
| 	resp.User = *pak.User.TailscaleUser() |  | ||||||
| 	// Provide LoginName when registering with pre-auth key |  | ||||||
| 	// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* |  | ||||||
| 	resp.Login = *pak.User.TailscaleLogin() |  | ||||||
|  |  | ||||||
| 	respBody, err := json.Marshal(resp) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		return nil, fmt.Errorf("nodes changed hook: %w", err) | ||||||
| 			Caller(). |  | ||||||
| 			Str("node", registerRequest.Hostinfo.Hostname). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Cannot encode message") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") |  | ||||||
| 	writer.WriteHeader(http.StatusOK) |  | ||||||
| 	_, err = writer.Write(respBody) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed to write response") |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Info(). | 	if !updateSent { | ||||||
| 		Str("node", registerRequest.Hostinfo.Hostname). | 		ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname) | ||||||
| 		Msg("Successfully authenticated via AuthKey") | 		h.nodeNotifier.NotifyAll(ctx, types.StateUpdatePeerAdded(node.ID)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &tailcfg.RegisterResponse{ | ||||||
|  | 		MachineAuthorized: true, | ||||||
|  | 		NodeKeyExpired:    node.IsExpired(), | ||||||
|  | 		User:              *pak.User.TailscaleUser(), | ||||||
|  | 		Login:             *pak.User.TailscaleLogin(), | ||||||
|  | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // handleNewNode returns the authorisation URL to the client based on what type | func (h *Headscale) handleRegisterInteractive( | ||||||
| // of registration headscale is configured with. |  | ||||||
| // This url is then showed to the user by the local Tailscale client. |  | ||||||
| func (h *Headscale) handleNewNode( |  | ||||||
| 	writer http.ResponseWriter, |  | ||||||
| 	registerRequest tailcfg.RegisterRequest, |  | ||||||
| 	registrationId types.RegistrationID, |  | ||||||
| ) { |  | ||||||
| 	logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId) |  | ||||||
|  |  | ||||||
| 	resp := tailcfg.RegisterResponse{} |  | ||||||
|  |  | ||||||
| 	// The node registration is new, redirect the client to the registration URL |  | ||||||
| 	logTrace("The node is new, sending auth url") |  | ||||||
|  |  | ||||||
| 	resp.AuthURL = h.authProvider.AuthURL(registrationId) |  | ||||||
|  |  | ||||||
| 	respBody, err := json.Marshal(resp) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logErr(err, "Cannot encode message") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") |  | ||||||
| 	writer.WriteHeader(http.StatusOK) |  | ||||||
| 	_, err = writer.Write(respBody) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logErr(err, "Failed to write response") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	logInfo(fmt.Sprintf("Successfully sent auth url: %s", resp.AuthURL)) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *Headscale) handleNodeLogOut( |  | ||||||
| 	writer http.ResponseWriter, |  | ||||||
| 	node types.Node, |  | ||||||
| ) { |  | ||||||
| 	resp := tailcfg.RegisterResponse{} |  | ||||||
|  |  | ||||||
| 	log.Info(). |  | ||||||
| 		Str("node", node.Hostname). |  | ||||||
| 		Msg("Client requested logout") |  | ||||||
|  |  | ||||||
| 	now := time.Now() |  | ||||||
| 	err := h.db.NodeSetExpiry(node.ID, now) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed to expire node") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") |  | ||||||
| 	h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID) |  | ||||||
|  |  | ||||||
| 	resp.AuthURL = "" |  | ||||||
| 	resp.MachineAuthorized = false |  | ||||||
| 	resp.NodeKeyExpired = true |  | ||||||
| 	resp.User = *node.User.TailscaleUser() |  | ||||||
| 	respBody, err := json.Marshal(resp) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Cannot encode message") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") |  | ||||||
| 	writer.WriteHeader(http.StatusOK) |  | ||||||
| 	_, err = writer.Write(respBody) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed to write response") |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if node.IsEphemeral() { |  | ||||||
| 		changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.LikelyConnectedMap()) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error(). |  | ||||||
| 				Err(err). |  | ||||||
| 				Str("node", node.Hostname). |  | ||||||
| 				Msg("Cannot delete ephemeral node from the database") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") |  | ||||||
| 		h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ |  | ||||||
| 			Type:    types.StatePeerRemoved, |  | ||||||
| 			Removed: []types.NodeID{node.ID}, |  | ||||||
| 		}) |  | ||||||
| 		if changedNodes != nil { |  | ||||||
| 			h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ |  | ||||||
| 				Type:        types.StatePeerChanged, |  | ||||||
| 				ChangeNodes: changedNodes, |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Info(). |  | ||||||
| 		Caller(). |  | ||||||
| 		Str("node", node.Hostname). |  | ||||||
| 		Msg("Successfully logged out") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *Headscale) handleNodeWithValidRegistration( |  | ||||||
| 	writer http.ResponseWriter, |  | ||||||
| 	node types.Node, |  | ||||||
| ) { |  | ||||||
| 	resp := tailcfg.RegisterResponse{} |  | ||||||
|  |  | ||||||
| 	// The node registration is valid, respond with redirect to /map |  | ||||||
| 	log.Debug(). |  | ||||||
| 		Caller(). |  | ||||||
| 		Str("node", node.Hostname). |  | ||||||
| 		Msg("Client is registered and we have the current NodeKey. All clear to /map") |  | ||||||
|  |  | ||||||
| 	resp.AuthURL = "" |  | ||||||
| 	resp.MachineAuthorized = true |  | ||||||
| 	resp.User = *node.User.TailscaleUser() |  | ||||||
| 	resp.Login = *node.User.TailscaleLogin() |  | ||||||
|  |  | ||||||
| 	respBody, err := json.Marshal(resp) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Cannot encode message") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") |  | ||||||
| 	writer.WriteHeader(http.StatusOK) |  | ||||||
| 	_, err = writer.Write(respBody) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed to write response") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Info(). |  | ||||||
| 		Caller(). |  | ||||||
| 		Str("node", node.Hostname). |  | ||||||
| 		Msg("Node successfully authorized") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *Headscale) handleNodeKeyRefresh( |  | ||||||
| 	writer http.ResponseWriter, |  | ||||||
| 	registerRequest tailcfg.RegisterRequest, |  | ||||||
| 	node types.Node, |  | ||||||
| ) { |  | ||||||
| 	resp := tailcfg.RegisterResponse{} |  | ||||||
|  |  | ||||||
| 	log.Info(). |  | ||||||
| 		Caller(). |  | ||||||
| 		Str("node", node.Hostname). |  | ||||||
| 		Msg("We have the OldNodeKey in the database. This is a key refresh") |  | ||||||
|  |  | ||||||
| 	err := h.db.Write(func(tx *gorm.DB) error { |  | ||||||
| 		return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey) |  | ||||||
| 	}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed to update machine key in the database") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	resp.AuthURL = "" |  | ||||||
| 	resp.User = *node.User.TailscaleUser() |  | ||||||
| 	respBody, err := json.Marshal(resp) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Cannot encode message") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") |  | ||||||
| 	writer.WriteHeader(http.StatusOK) |  | ||||||
| 	_, err = writer.Write(respBody) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Failed to write response") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Info(). |  | ||||||
| 		Caller(). |  | ||||||
| 		Str("node_key", registerRequest.NodeKey.ShortString()). |  | ||||||
| 		Str("old_node_key", registerRequest.OldNodeKey.ShortString()). |  | ||||||
| 		Str("node", node.Hostname). |  | ||||||
| 		Msg("Node key successfully refreshed") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (h *Headscale) handleNodeExpiredOrLoggedOut( |  | ||||||
| 	writer http.ResponseWriter, |  | ||||||
| 	regReq tailcfg.RegisterRequest, | 	regReq tailcfg.RegisterRequest, | ||||||
| 	node types.Node, |  | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| 	registrationId types.RegistrationID, | ) (*tailcfg.RegisterResponse, error) { | ||||||
| ) { | 	registrationId, err := types.NewRegistrationID() | ||||||
| 	resp := tailcfg.RegisterResponse{} |  | ||||||
|  |  | ||||||
| 	if regReq.Auth != nil && regReq.Auth.AuthKey != "" { |  | ||||||
| 		h.handleAuthKey(writer, regReq, machineKey) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// The client has registered before, but has expired or logged out |  | ||||||
| 	log.Trace(). |  | ||||||
| 		Caller(). |  | ||||||
| 		Str("node", node.Hostname). |  | ||||||
| 		Str("registration_id", registrationId.String()). |  | ||||||
| 		Str("node_key", regReq.NodeKey.ShortString()). |  | ||||||
| 		Str("node_key_old", regReq.OldNodeKey.ShortString()). |  | ||||||
| 		Msg("Node registration has expired or logged out. Sending a auth url to register") |  | ||||||
|  |  | ||||||
| 	resp.AuthURL = h.authProvider.AuthURL(registrationId) |  | ||||||
|  |  | ||||||
| 	respBody, err := json.Marshal(resp) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		return nil, fmt.Errorf("generating registration ID: %w", err) | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Cannot encode message") |  | ||||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) |  | ||||||
|  |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") | 	newNode := types.RegisterNode{ | ||||||
| 	writer.WriteHeader(http.StatusOK) | 		Node: types.Node{ | ||||||
| 	_, err = writer.Write(respBody) | 			Hostname:   regReq.Hostinfo.Hostname, | ||||||
| 	if err != nil { | 			MachineKey: machineKey, | ||||||
| 		log.Error(). | 			NodeKey:    regReq.NodeKey, | ||||||
| 			Caller(). | 			Hostinfo:   regReq.Hostinfo, | ||||||
| 			Err(err). | 			LastSeen:   ptr.To(time.Now()), | ||||||
| 			Msg("Failed to write response") | 		}, | ||||||
|  | 		Registered: make(chan struct{}), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Trace(). | 	if !regReq.Expiry.IsZero() { | ||||||
| 		Caller(). | 		newNode.Node.Expiry = ®Req.Expiry | ||||||
| 		Str("registration_id", registrationId.String()). | 	} | ||||||
| 		Str("node_key", regReq.NodeKey.ShortString()). |  | ||||||
| 		Str("node_key_old", regReq.OldNodeKey.ShortString()). | 	h.registrationCache.Set( | ||||||
| 		Str("node", node.Hostname). | 		registrationId, | ||||||
| 		Msg("Node logged out. Sent AuthURL for reauthentication") | 		newNode, | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	return &tailcfg.RegisterResponse{ | ||||||
|  | 		AuthURL: h.authProvider.AuthURL(registrationId), | ||||||
|  | 	}, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -182,38 +182,6 @@ func GetNodeByNodeKey( | |||||||
| 	return &mach, nil | 	return &mach, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (hsdb *HSDatabase) GetNodeByAnyKey( |  | ||||||
| 	machineKey key.MachinePublic, |  | ||||||
| 	nodeKey key.NodePublic, |  | ||||||
| 	oldNodeKey key.NodePublic, |  | ||||||
| ) (*types.Node, error) { |  | ||||||
| 	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { |  | ||||||
| 		return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey) |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. |  | ||||||
| // TODO(kradalby): see if we can remove this. |  | ||||||
| func GetNodeByAnyKey( |  | ||||||
| 	tx *gorm.DB, |  | ||||||
| 	machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, |  | ||||||
| ) (*types.Node, error) { |  | ||||||
| 	node := types.Node{} |  | ||||||
| 	if result := tx. |  | ||||||
| 		Preload("AuthKey"). |  | ||||||
| 		Preload("AuthKey.User"). |  | ||||||
| 		Preload("User"). |  | ||||||
| 		Preload("Routes"). |  | ||||||
| 		First(&node, "machine_key = ? OR node_key = ? OR node_key = ?", |  | ||||||
| 			machineKey.String(), |  | ||||||
| 			nodeKey.String(), |  | ||||||
| 			oldNodeKey.String()); result.Error != nil { |  | ||||||
| 		return nil, result.Error |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &node, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (hsdb *HSDatabase) SetTags( | func (hsdb *HSDatabase) SetTags( | ||||||
| 	nodeID types.NodeID, | 	nodeID types.NodeID, | ||||||
| 	tags []string, | 	tags []string, | ||||||
| @@ -437,6 +405,18 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad | |||||||
| 		Str("user", node.User.Username()). | 		Str("user", node.User.Username()). | ||||||
| 		Msg("Registering node") | 		Msg("Registering node") | ||||||
|  |  | ||||||
|  | 	// If the a new node is registered with the same machine key, to the same user, | ||||||
|  | 	// update the existing node. | ||||||
|  | 	// If the same node is registered again, but to a new user, then that is considered | ||||||
|  | 	// a new node. | ||||||
|  | 	oldNode, _ := GetNodeByMachineKey(tx, node.MachineKey) | ||||||
|  | 	if oldNode != nil && oldNode.UserID == node.UserID { | ||||||
|  | 		node.ID = oldNode.ID | ||||||
|  | 		node.GivenName = oldNode.GivenName | ||||||
|  | 		ipv4 = oldNode.IPv4 | ||||||
|  | 		ipv6 = oldNode.IPv6 | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// If the node exists and it already has IP(s), we just save it | 	// If the node exists and it already has IP(s), we just save it | ||||||
| 	// so we store the node.Expire and node.Nodekey that has been set when | 	// so we store the node.Expire and node.Nodekey that has been set when | ||||||
| 	// adding it to the registrationCache | 	// adding it to the registrationCache | ||||||
|   | |||||||
| @@ -84,37 +84,6 @@ func (s *Suite) TestGetNodeByID(c *check.C) { | |||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { |  | ||||||
| 	user, err := db.CreateUser(types.User{Name: "test"}) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
|  |  | ||||||
| 	pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
|  |  | ||||||
| 	_, err = db.GetNodeByID(0) |  | ||||||
| 	c.Assert(err, check.NotNil) |  | ||||||
|  |  | ||||||
| 	nodeKey := key.NewNode() |  | ||||||
| 	oldNodeKey := key.NewNode() |  | ||||||
|  |  | ||||||
| 	machineKey := key.NewMachine() |  | ||||||
|  |  | ||||||
| 	node := types.Node{ |  | ||||||
| 		ID:             0, |  | ||||||
| 		MachineKey:     machineKey.Public(), |  | ||||||
| 		NodeKey:        nodeKey.Public(), |  | ||||||
| 		Hostname:       "testnode", |  | ||||||
| 		UserID:         user.ID, |  | ||||||
| 		RegisterMethod: util.RegisterMethodAuthKey, |  | ||||||
| 		AuthKeyID:      ptr.To(pak.ID), |  | ||||||
| 	} |  | ||||||
| 	trx := db.DB.Save(&node) |  | ||||||
| 	c.Assert(trx.Error, check.IsNil) |  | ||||||
|  |  | ||||||
| 	_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) |  | ||||||
| 	c.Assert(err, check.IsNil) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *Suite) TestHardDeleteNode(c *check.C) { | func (s *Suite) TestHardDeleteNode(c *check.C) { | ||||||
| 	user, err := db.CreateUser(types.User{Name: "test"}) | 	user, err := db.CreateUser(types.User{Name: "test"}) | ||||||
| 	c.Assert(err, check.IsNil) | 	c.Assert(err, check.IsNil) | ||||||
|   | |||||||
| @@ -256,10 +256,17 @@ func (api headscaleV1APIServer) RegisterNode( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) | 	updateSent, err := nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("updating resources using node: %w", err) | 		return nil, fmt.Errorf("updating resources using node: %w", err) | ||||||
| 	} | 	} | ||||||
|  | 	if !updateSent { | ||||||
|  | 		ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname) | ||||||
|  | 		api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ | ||||||
|  | 			Type:        types.StatePeerChanged, | ||||||
|  | 			ChangeNodes: []types.NodeID{node.ID}, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil | 	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -156,7 +156,12 @@ func isSupportedVersion(version tailcfg.CapabilityVersion) bool { | |||||||
| 	return version >= MinimumCapVersion | 	return version >= MinimumCapVersion | ||||||
| } | } | ||||||
|  |  | ||||||
| func rejectUnsupported(writer http.ResponseWriter, version tailcfg.CapabilityVersion, mkey key.MachinePublic, nkey key.NodePublic) bool { | func rejectUnsupported( | ||||||
|  | 	writer http.ResponseWriter, | ||||||
|  | 	version tailcfg.CapabilityVersion, | ||||||
|  | 	mkey key.MachinePublic, | ||||||
|  | 	nkey key.NodePublic, | ||||||
|  | ) bool { | ||||||
| 	// Reject unsupported versions | 	// Reject unsupported versions | ||||||
| 	if !isSupportedVersion(version) { | 	if !isSupportedVersion(version) { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| @@ -204,11 +209,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( | |||||||
|  |  | ||||||
| 	ns.nodeKey = mapRequest.NodeKey | 	ns.nodeKey = mapRequest.NodeKey | ||||||
|  |  | ||||||
| 	node, err := ns.headscale.db.GetNodeByAnyKey( | 	node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey) | ||||||
| 		ns.conn.Peer(), |  | ||||||
| 		mapRequest.NodeKey, |  | ||||||
| 		key.NodePublic{}, |  | ||||||
| 	) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		httpError(writer, err, "Internal error", http.StatusInternalServerError) | 		httpError(writer, err, "Internal error", http.StatusInternalServerError) | ||||||
| 		return | 		return | ||||||
| @@ -234,12 +235,38 @@ func (ns *noiseServer) NoiseRegistrationHandler( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	body, _ := io.ReadAll(req.Body) | 	registerRequest, registerResponse, err := func() (*tailcfg.RegisterRequest, []byte, error) { | ||||||
| 	var registerRequest tailcfg.RegisterRequest | 		body, err := io.ReadAll(req.Body) | ||||||
| 	if err := json.Unmarshal(body, ®isterRequest); err != nil { | 		if err != nil { | ||||||
| 		httpError(writer, err, "Internal error", http.StatusInternalServerError) | 			return nil, nil, err | ||||||
|  | 		} | ||||||
|  | 		var registerRequest tailcfg.RegisterRequest | ||||||
|  | 		if err := json.Unmarshal(body, ®isterRequest); err != nil { | ||||||
|  | 			return nil, nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		return | 		ns.nodeKey = registerRequest.NodeKey | ||||||
|  |  | ||||||
|  | 		resp, err := ns.headscale.handleRegister(req.Context(), registerRequest, ns.conn.Peer()) | ||||||
|  | 		// TODO(kradalby): Here we could have two error types, one that is surfaced to the client | ||||||
|  | 		// and one that returns 500. | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		respBody, err := json.Marshal(resp) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return ®isterRequest, respBody, nil | ||||||
|  | 	}() | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error(). | ||||||
|  | 			Caller(). | ||||||
|  | 			Err(err). | ||||||
|  | 			Msg("Error handling registration") | ||||||
|  | 		http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Reject unsupported versions | 	// Reject unsupported versions | ||||||
| @@ -247,7 +274,13 @@ func (ns *noiseServer) NoiseRegistrationHandler( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ns.nodeKey = registerRequest.NodeKey | 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") | ||||||
|  | 	writer.WriteHeader(http.StatusOK) | ||||||
| 	ns.headscale.handleRegister(writer, req, registerRequest, ns.conn.Peer()) | 	_, err = writer.Write(registerResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error(). | ||||||
|  | 			Caller(). | ||||||
|  | 			Err(err). | ||||||
|  | 			Msg("Failed to write response") | ||||||
|  | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -512,24 +512,21 @@ func (a *AuthProviderOIDC) handleRegistrationID( | |||||||
| 	// Send an update to all nodes if this is a new node that they need to know | 	// Send an update to all nodes if this is a new node that they need to know | ||||||
| 	// about. | 	// about. | ||||||
| 	// If this is a refresh, just send new expiry updates. | 	// If this is a refresh, just send new expiry updates. | ||||||
| 	if newNode { | 	updateSent, err := nodesChangedHook(a.db, a.polMan, a.notifier) | ||||||
| 		err = nodesChangedHook(a.db, a.polMan, a.notifier) | 	if err != nil { | ||||||
| 		if err != nil { | 		return false, fmt.Errorf("updating resources using node: %w", err) | ||||||
| 			return false, fmt.Errorf("updating resources using node: %w", err) | 	} | ||||||
| 		} |  | ||||||
| 	} else { | 	if !updateSent { | ||||||
| 		ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) | 		ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) | ||||||
| 		a.notifier.NotifyByNodeID( | 		a.notifier.NotifyByNodeID( | ||||||
| 			ctx, | 			ctx, | ||||||
| 			types.StateUpdate{ | 			types.StateSelf(node.ID), | ||||||
| 				Type:        types.StateSelfUpdate, |  | ||||||
| 				ChangeNodes: []types.NodeID{node.ID}, |  | ||||||
| 			}, |  | ||||||
| 			node.ID, | 			node.ID, | ||||||
| 		) | 		) | ||||||
|  |  | ||||||
| 		ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) | 		ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) | ||||||
| 		a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID) | 		a.notifier.NotifyWithIgnore(ctx, types.StateUpdatePeerAdded(node.ID), node.ID) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return newNode, nil | 	return newNode, nil | ||||||
|   | |||||||
| @@ -102,6 +102,20 @@ func (su *StateUpdate) Empty() bool { | |||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func StateSelf(nodeID NodeID) StateUpdate { | ||||||
|  | 	return StateUpdate{ | ||||||
|  | 		Type:        StateSelfUpdate, | ||||||
|  | 		ChangeNodes: []NodeID{nodeID}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func StateUpdatePeerAdded(nodeIDs ...NodeID) StateUpdate { | ||||||
|  | 	return StateUpdate{ | ||||||
|  | 		Type:        StatePeerChanged, | ||||||
|  | 		ChangeNodes: nodeIDs, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { | func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { | ||||||
| 	return StateUpdate{ | 	return StateUpdate{ | ||||||
| 		Type: StatePeerChangedPatch, | 		Type: StatePeerChangedPatch, | ||||||
|   | |||||||
							
								
								
									
										230
									
								
								integration/auth_key_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										230
									
								
								integration/auth_key_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,230 @@ | |||||||
|  | package integration | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/juanfont/headscale/integration/hsic" | ||||||
|  | 	"github.com/juanfont/headscale/integration/tsic" | ||||||
|  | 	"github.com/samber/lo" | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { | ||||||
|  | 	IntegrationSkip(t) | ||||||
|  | 	t.Parallel() | ||||||
|  |  | ||||||
|  | 	for _, https := range []bool{true, false} { | ||||||
|  | 		t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { | ||||||
|  | 			scenario, err := NewScenario(dockertestMaxWait()) | ||||||
|  | 			assertNoErr(t, err) | ||||||
|  | 			defer scenario.ShutdownAssertNoPanics(t) | ||||||
|  |  | ||||||
|  | 			spec := map[string]int{ | ||||||
|  | 				"user1": len(MustTestVersions), | ||||||
|  | 				"user2": len(MustTestVersions), | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			opts := []hsic.Option{hsic.WithTestName("pingallbyip")} | ||||||
|  | 			if https { | ||||||
|  | 				opts = append(opts, []hsic.Option{ | ||||||
|  | 					hsic.WithTLS(), | ||||||
|  | 				}...) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) | ||||||
|  | 			assertNoErrHeadscaleEnv(t, err) | ||||||
|  |  | ||||||
|  | 			allClients, err := scenario.ListTailscaleClients() | ||||||
|  | 			assertNoErrListClients(t, err) | ||||||
|  |  | ||||||
|  | 			err = scenario.WaitForTailscaleSync() | ||||||
|  | 			assertNoErrSync(t, err) | ||||||
|  |  | ||||||
|  | 			// assertClientsState(t, allClients) | ||||||
|  |  | ||||||
|  | 			clientIPs := make(map[TailscaleClient][]netip.Addr) | ||||||
|  | 			for _, client := range allClients { | ||||||
|  | 				ips, err := client.IPs() | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) | ||||||
|  | 				} | ||||||
|  | 				clientIPs[client] = ips | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			headscale, err := scenario.Headscale() | ||||||
|  | 			assertNoErrGetHeadscale(t, err) | ||||||
|  |  | ||||||
|  | 			listNodes, err := headscale.ListNodes() | ||||||
|  | 			assert.Equal(t, len(listNodes), len(allClients)) | ||||||
|  | 			nodeCountBeforeLogout := len(listNodes) | ||||||
|  | 			t.Logf("node count before logout: %d", nodeCountBeforeLogout) | ||||||
|  |  | ||||||
|  | 			for _, client := range allClients { | ||||||
|  | 				err := client.Logout() | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			err = scenario.WaitForTailscaleLogout() | ||||||
|  | 			assertNoErrLogout(t, err) | ||||||
|  |  | ||||||
|  | 			t.Logf("all clients logged out") | ||||||
|  |  | ||||||
|  | 			// if the server is not running with HTTPS, we have to wait a bit before | ||||||
|  | 			// reconnection as the newest Tailscale client has a measure that will only | ||||||
|  | 			// reconnect over HTTPS if they saw a noise connection previously. | ||||||
|  | 			// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 | ||||||
|  | 			// https://github.com/juanfont/headscale/issues/2164 | ||||||
|  | 			if !https { | ||||||
|  | 				time.Sleep(5 * time.Minute) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			for userName := range spec { | ||||||
|  | 				key, err := scenario.CreatePreAuthKey(userName, true, false) | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			listNodes, err = headscale.ListNodes() | ||||||
|  | 			require.Equal(t, nodeCountBeforeLogout, len(listNodes)) | ||||||
|  |  | ||||||
|  | 			allIps, err := scenario.ListTailscaleClientsIPs() | ||||||
|  | 			assertNoErrListClientIPs(t, err) | ||||||
|  |  | ||||||
|  | 			allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { | ||||||
|  | 				return x.String() | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			success := pingAllHelper(t, allClients, allAddrs) | ||||||
|  | 			t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | ||||||
|  |  | ||||||
|  | 			for _, client := range allClients { | ||||||
|  | 				ips, err := client.IPs() | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// lets check if the IPs are the same | ||||||
|  | 				if len(ips) != len(clientIPs[client]) { | ||||||
|  | 					t.Fatalf("IPs changed for client %s", client.Hostname()) | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				for _, ip := range ips { | ||||||
|  | 					found := false | ||||||
|  | 					for _, oldIP := range clientIPs[client] { | ||||||
|  | 						if ip == oldIP { | ||||||
|  | 							found = true | ||||||
|  |  | ||||||
|  | 							break | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  |  | ||||||
|  | 					if !found { | ||||||
|  | 						t.Fatalf( | ||||||
|  | 							"IPs changed for client %s. Used to be %v now %v", | ||||||
|  | 							client.Hostname(), | ||||||
|  | 							clientIPs[client], | ||||||
|  | 							ips, | ||||||
|  | 						) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // This test will first log in two sets of nodes to two sets of users, then | ||||||
|  | // it will log out all users from user2 and log them in as user1. | ||||||
|  | // This should leave us with all nodes connected to user1, while user2 | ||||||
|  | // still has nodes, but they are not connected. | ||||||
|  | func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { | ||||||
|  | 	IntegrationSkip(t) | ||||||
|  | 	t.Parallel() | ||||||
|  |  | ||||||
|  | 	scenario, err := NewScenario(dockertestMaxWait()) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	defer scenario.ShutdownAssertNoPanics(t) | ||||||
|  |  | ||||||
|  | 	spec := map[string]int{ | ||||||
|  | 		"user1": len(MustTestVersions), | ||||||
|  | 		"user2": len(MustTestVersions), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, | ||||||
|  | 		hsic.WithTestName("keyrelognewuser"), | ||||||
|  | 		hsic.WithTLS(), | ||||||
|  | 	) | ||||||
|  | 	assertNoErrHeadscaleEnv(t, err) | ||||||
|  |  | ||||||
|  | 	allClients, err := scenario.ListTailscaleClients() | ||||||
|  | 	assertNoErrListClients(t, err) | ||||||
|  |  | ||||||
|  | 	err = scenario.WaitForTailscaleSync() | ||||||
|  | 	assertNoErrSync(t, err) | ||||||
|  |  | ||||||
|  | 	// assertClientsState(t, allClients) | ||||||
|  |  | ||||||
|  | 	headscale, err := scenario.Headscale() | ||||||
|  | 	assertNoErrGetHeadscale(t, err) | ||||||
|  |  | ||||||
|  | 	listNodes, err := headscale.ListNodes() | ||||||
|  | 	assert.Equal(t, len(listNodes), len(allClients)) | ||||||
|  | 	nodeCountBeforeLogout := len(listNodes) | ||||||
|  | 	t.Logf("node count before logout: %d", nodeCountBeforeLogout) | ||||||
|  |  | ||||||
|  | 	for _, client := range allClients { | ||||||
|  | 		err := client.Logout() | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = scenario.WaitForTailscaleLogout() | ||||||
|  | 	assertNoErrLogout(t, err) | ||||||
|  |  | ||||||
|  | 	t.Logf("all clients logged out") | ||||||
|  |  | ||||||
|  | 	// Create a new authkey for user1, to be used for all clients | ||||||
|  | 	key, err := scenario.CreatePreAuthKey("user1", true, false) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to create pre-auth key for user1: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Log in all clients as user1, iterating over the spec only returns the | ||||||
|  | 	// clients, not the usernames. | ||||||
|  | 	for userName := range spec { | ||||||
|  | 		err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	user1Nodes, err := headscale.ListNodes("user1") | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, user1Nodes, len(allClients)) | ||||||
|  |  | ||||||
|  | 	// Validate that all the old nodes are still present with user2 | ||||||
|  | 	user2Nodes, err := headscale.ListNodes("user2") | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, user2Nodes, len(allClients)/2) | ||||||
|  |  | ||||||
|  | 	for _, client := range allClients { | ||||||
|  | 		status, err := client.Status() | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("failed to get status for client %s: %s", client.Hostname(), err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		assert.Equal(t, "user1@test.no", status.User[status.Self.UserID].LoginName) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -116,20 +116,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { | |||||||
| 	headscale, err := scenario.Headscale() | 	headscale, err := scenario.Headscale() | ||||||
| 	assertNoErr(t, err) | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
| 	var listUsers []v1.User | 	listUsers, err := headscale.ListUsers() | ||||||
| 	err = executeAndUnmarshal(headscale, |  | ||||||
| 		[]string{ |  | ||||||
| 			"headscale", |  | ||||||
| 			"users", |  | ||||||
| 			"list", |  | ||||||
| 			"--output", |  | ||||||
| 			"json", |  | ||||||
| 		}, |  | ||||||
| 		&listUsers, |  | ||||||
| 	) |  | ||||||
| 	assertNoErr(t, err) | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
| 	want := []v1.User{ | 	want := []*v1.User{ | ||||||
| 		{ | 		{ | ||||||
| 			Id:    1, | 			Id:    1, | ||||||
| 			Name:  "user1", | 			Name:  "user1", | ||||||
| @@ -249,7 +239,7 @@ func TestOIDC024UserCreation(t *testing.T) { | |||||||
| 		emailVerified bool | 		emailVerified bool | ||||||
| 		cliUsers      []string | 		cliUsers      []string | ||||||
| 		oidcUsers     []string | 		oidcUsers     []string | ||||||
| 		want          func(iss string) []v1.User | 		want          func(iss string) []*v1.User | ||||||
| 	}{ | 	}{ | ||||||
| 		{ | 		{ | ||||||
| 			name: "no-migration-verified-email", | 			name: "no-migration-verified-email", | ||||||
| @@ -259,8 +249,8 @@ func TestOIDC024UserCreation(t *testing.T) { | |||||||
| 			emailVerified: true, | 			emailVerified: true, | ||||||
| 			cliUsers:      []string{"user1", "user2"}, | 			cliUsers:      []string{"user1", "user2"}, | ||||||
| 			oidcUsers:     []string{"user1", "user2"}, | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
| 			want: func(iss string) []v1.User { | 			want: func(iss string) []*v1.User { | ||||||
| 				return []v1.User{ | 				return []*v1.User{ | ||||||
| 					{ | 					{ | ||||||
| 						Id:    1, | 						Id:    1, | ||||||
| 						Name:  "user1", | 						Name:  "user1", | ||||||
| @@ -296,8 +286,8 @@ func TestOIDC024UserCreation(t *testing.T) { | |||||||
| 			emailVerified: false, | 			emailVerified: false, | ||||||
| 			cliUsers:      []string{"user1", "user2"}, | 			cliUsers:      []string{"user1", "user2"}, | ||||||
| 			oidcUsers:     []string{"user1", "user2"}, | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
| 			want: func(iss string) []v1.User { | 			want: func(iss string) []*v1.User { | ||||||
| 				return []v1.User{ | 				return []*v1.User{ | ||||||
| 					{ | 					{ | ||||||
| 						Id:    1, | 						Id:    1, | ||||||
| 						Name:  "user1", | 						Name:  "user1", | ||||||
| @@ -332,8 +322,8 @@ func TestOIDC024UserCreation(t *testing.T) { | |||||||
| 			emailVerified: true, | 			emailVerified: true, | ||||||
| 			cliUsers:      []string{"user1", "user2"}, | 			cliUsers:      []string{"user1", "user2"}, | ||||||
| 			oidcUsers:     []string{"user1", "user2"}, | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
| 			want: func(iss string) []v1.User { | 			want: func(iss string) []*v1.User { | ||||||
| 				return []v1.User{ | 				return []*v1.User{ | ||||||
| 					{ | 					{ | ||||||
| 						Id:         1, | 						Id:         1, | ||||||
| 						Name:       "user1", | 						Name:       "user1", | ||||||
| @@ -360,8 +350,8 @@ func TestOIDC024UserCreation(t *testing.T) { | |||||||
| 			emailVerified: false, | 			emailVerified: false, | ||||||
| 			cliUsers:      []string{"user1", "user2"}, | 			cliUsers:      []string{"user1", "user2"}, | ||||||
| 			oidcUsers:     []string{"user1", "user2"}, | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
| 			want: func(iss string) []v1.User { | 			want: func(iss string) []*v1.User { | ||||||
| 				return []v1.User{ | 				return []*v1.User{ | ||||||
| 					{ | 					{ | ||||||
| 						Id:    1, | 						Id:    1, | ||||||
| 						Name:  "user1", | 						Name:  "user1", | ||||||
| @@ -396,8 +386,8 @@ func TestOIDC024UserCreation(t *testing.T) { | |||||||
| 			emailVerified: true, | 			emailVerified: true, | ||||||
| 			cliUsers:      []string{"user1.headscale.net", "user2.headscale.net"}, | 			cliUsers:      []string{"user1.headscale.net", "user2.headscale.net"}, | ||||||
| 			oidcUsers:     []string{"user1", "user2"}, | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
| 			want: func(iss string) []v1.User { | 			want: func(iss string) []*v1.User { | ||||||
| 				return []v1.User{ | 				return []*v1.User{ | ||||||
| 					// Hmm I think we will have to overwrite the initial name here | 					// Hmm I think we will have to overwrite the initial name here | ||||||
| 					// createuser with "user1.headscale.net", but oidc with "user1" | 					// createuser with "user1.headscale.net", but oidc with "user1" | ||||||
| 					{ | 					{ | ||||||
| @@ -426,8 +416,8 @@ func TestOIDC024UserCreation(t *testing.T) { | |||||||
| 			emailVerified: false, | 			emailVerified: false, | ||||||
| 			cliUsers:      []string{"user1.headscale.net", "user2.headscale.net"}, | 			cliUsers:      []string{"user1.headscale.net", "user2.headscale.net"}, | ||||||
| 			oidcUsers:     []string{"user1", "user2"}, | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
| 			want: func(iss string) []v1.User { | 			want: func(iss string) []*v1.User { | ||||||
| 				return []v1.User{ | 				return []*v1.User{ | ||||||
| 					{ | 					{ | ||||||
| 						Id:    1, | 						Id:    1, | ||||||
| 						Name:  "user1.headscale.net", | 						Name:  "user1.headscale.net", | ||||||
| @@ -509,17 +499,7 @@ func TestOIDC024UserCreation(t *testing.T) { | |||||||
|  |  | ||||||
| 			want := tt.want(oidcConfig.Issuer) | 			want := tt.want(oidcConfig.Issuer) | ||||||
|  |  | ||||||
| 			var listUsers []v1.User | 			listUsers, err := headscale.ListUsers() | ||||||
| 			err = executeAndUnmarshal(headscale, |  | ||||||
| 				[]string{ |  | ||||||
| 					"headscale", |  | ||||||
| 					"users", |  | ||||||
| 					"list", |  | ||||||
| 					"--output", |  | ||||||
| 					"json", |  | ||||||
| 				}, |  | ||||||
| 				&listUsers, |  | ||||||
| 			) |  | ||||||
| 			assertNoErr(t, err) | 			assertNoErr(t, err) | ||||||
|  |  | ||||||
| 			sort.Slice(listUsers, func(i, j int) bool { | 			sort.Slice(listUsers, func(i, j int) bool { | ||||||
| @@ -587,23 +567,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { | |||||||
| 	err = scenario.WaitForTailscaleSync() | 	err = scenario.WaitForTailscaleSync() | ||||||
| 	assertNoErrSync(t, err) | 	assertNoErrSync(t, err) | ||||||
|  |  | ||||||
| 	// Verify PKCE was used in authentication |  | ||||||
| 	headscale, err := scenario.Headscale() |  | ||||||
| 	assertNoErr(t, err) |  | ||||||
|  |  | ||||||
| 	var listUsers []v1.User |  | ||||||
| 	err = executeAndUnmarshal(headscale, |  | ||||||
| 		[]string{ |  | ||||||
| 			"headscale", |  | ||||||
| 			"users", |  | ||||||
| 			"list", |  | ||||||
| 			"--output", |  | ||||||
| 			"json", |  | ||||||
| 		}, |  | ||||||
| 		&listUsers, |  | ||||||
| 	) |  | ||||||
| 	assertNoErr(t, err) |  | ||||||
|  |  | ||||||
| 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { | 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { | ||||||
| 		return x.String() | 		return x.String() | ||||||
| 	}) | 	}) | ||||||
| @@ -612,6 +575,228 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { | |||||||
| 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestOIDCReloginSameNodeNewUser(t *testing.T) { | ||||||
|  | 	IntegrationSkip(t) | ||||||
|  | 	t.Parallel() | ||||||
|  |  | ||||||
|  | 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	scenario := AuthOIDCScenario{ | ||||||
|  | 		Scenario: baseScenario, | ||||||
|  | 	} | ||||||
|  | 	defer scenario.ShutdownAssertNoPanics(t) | ||||||
|  |  | ||||||
|  | 	// Create no nodes and no users | ||||||
|  | 	spec := map[string]int{} | ||||||
|  |  | ||||||
|  | 	// First login creates the first OIDC user | ||||||
|  | 	// Second login logs in the same node, which creates a new node | ||||||
|  | 	// Third login logs in the same node back into the original user | ||||||
|  | 	mockusers := []mockoidc.MockUser{ | ||||||
|  | 		oidcMockUser("user1", true), | ||||||
|  | 		oidcMockUser("user2", true), | ||||||
|  | 		oidcMockUser("user1", true), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||||
|  | 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||||
|  | 	// defer scenario.mockOIDC.Close() | ||||||
|  |  | ||||||
|  | 	oidcMap := map[string]string{ | ||||||
|  | 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||||
|  | 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||||
|  | 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||||
|  | 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||||
|  | 		// TODO(kradalby): Remove when strip_email_domain is removed | ||||||
|  | 		// after #2170 is cleaned up | ||||||
|  | 		"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "0", | ||||||
|  | 		"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = scenario.CreateHeadscaleEnv( | ||||||
|  | 		spec, | ||||||
|  | 		hsic.WithTestName("oidcauthrelog"), | ||||||
|  | 		hsic.WithConfigEnv(oidcMap), | ||||||
|  | 		hsic.WithTLS(), | ||||||
|  | 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||||
|  | 		hsic.WithEmbeddedDERPServerOnly(), | ||||||
|  | 	) | ||||||
|  | 	assertNoErrHeadscaleEnv(t, err) | ||||||
|  |  | ||||||
|  | 	headscale, err := scenario.Headscale() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	listUsers, err := headscale.ListUsers() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, listUsers, 0) | ||||||
|  |  | ||||||
|  | 	ts, err := scenario.CreateTailscaleNode("unstable") | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	u, err := ts.LoginWithURL(headscale.GetEndpoint()) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	_, err = doLoginURL(ts.Hostname(), u) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	listUsers, err = headscale.ListUsers() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, listUsers, 1) | ||||||
|  | 	wantUsers := []*v1.User{ | ||||||
|  | 		{ | ||||||
|  | 			Id:         1, | ||||||
|  | 			Name:       "user1", | ||||||
|  | 			Email:      "user1@headscale.net", | ||||||
|  | 			Provider:   "oidc", | ||||||
|  | 			ProviderId: oidcConfig.Issuer + "/user1", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sort.Slice(listUsers, func(i, j int) bool { | ||||||
|  | 		return listUsers[i].GetId() < listUsers[j].GetId() | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { | ||||||
|  | 		t.Fatalf("unexpected users: %s", diff) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	listNodes, err := headscale.ListNodes() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, listNodes, 1) | ||||||
|  |  | ||||||
|  | 	// Log out user1 and log in user2, this should create a new node | ||||||
|  | 	// for user2, the node should have the same machine key and | ||||||
|  | 	// a new node key. | ||||||
|  | 	err = ts.Logout() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	time.Sleep(5 * time.Second) | ||||||
|  |  | ||||||
|  | 	// TODO(kradalby): Not sure why we need to logout twice, but it fails and | ||||||
|  | 	// logs in immediately after the first logout and I cannot reproduce it | ||||||
|  | 	// manually. | ||||||
|  | 	err = ts.Logout() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	u, err = ts.LoginWithURL(headscale.GetEndpoint()) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	_, err = doLoginURL(ts.Hostname(), u) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	listUsers, err = headscale.ListUsers() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, listUsers, 2) | ||||||
|  | 	wantUsers = []*v1.User{ | ||||||
|  | 		{ | ||||||
|  | 			Id:         1, | ||||||
|  | 			Name:       "user1", | ||||||
|  | 			Email:      "user1@headscale.net", | ||||||
|  | 			Provider:   "oidc", | ||||||
|  | 			ProviderId: oidcConfig.Issuer + "/user1", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         2, | ||||||
|  | 			Name:       "user2", | ||||||
|  | 			Email:      "user2@headscale.net", | ||||||
|  | 			Provider:   "oidc", | ||||||
|  | 			ProviderId: oidcConfig.Issuer + "/user2", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sort.Slice(listUsers, func(i, j int) bool { | ||||||
|  | 		return listUsers[i].GetId() < listUsers[j].GetId() | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { | ||||||
|  | 		t.Fatalf("unexpected users: %s", diff) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	listNodesAfterNewUserLogin, err := headscale.ListNodes() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, listNodesAfterNewUserLogin, 2) | ||||||
|  |  | ||||||
|  | 	// Machine key is the same as the "machine" has not changed, | ||||||
|  | 	// but Node key is not as it is a new node | ||||||
|  | 	assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) | ||||||
|  | 	assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) | ||||||
|  | 	assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey) | ||||||
|  |  | ||||||
|  | 	// Log out user2, and log into user1, no new node should be created, | ||||||
|  | 	// the node should now "become" node1 again | ||||||
|  | 	err = ts.Logout() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	time.Sleep(5 * time.Second) | ||||||
|  |  | ||||||
|  | 	// TODO(kradalby): Not sure why we need to logout twice, but it fails and | ||||||
|  | 	// logs in immediately after the first logout and I cannot reproduce it | ||||||
|  | 	// manually. | ||||||
|  | 	err = ts.Logout() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	u, err = ts.LoginWithURL(headscale.GetEndpoint()) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	_, err = doLoginURL(ts.Hostname(), u) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  |  | ||||||
|  | 	listUsers, err = headscale.ListUsers() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, listUsers, 2) | ||||||
|  | 	wantUsers = []*v1.User{ | ||||||
|  | 		{ | ||||||
|  | 			Id:         1, | ||||||
|  | 			Name:       "user1", | ||||||
|  | 			Email:      "user1@headscale.net", | ||||||
|  | 			Provider:   "oidc", | ||||||
|  | 			ProviderId: oidcConfig.Issuer + "/user1", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         2, | ||||||
|  | 			Name:       "user2", | ||||||
|  | 			Email:      "user2@headscale.net", | ||||||
|  | 			Provider:   "oidc", | ||||||
|  | 			ProviderId: oidcConfig.Issuer + "/user2", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sort.Slice(listUsers, func(i, j int) bool { | ||||||
|  | 		return listUsers[i].GetId() < listUsers[j].GetId() | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { | ||||||
|  | 		t.Fatalf("unexpected users: %s", diff) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	listNodesAfterLoggingBackIn, err := headscale.ListNodes() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 	assert.Len(t, listNodesAfterLoggingBackIn, 2) | ||||||
|  |  | ||||||
|  | 	// Validate that the machine we had when we logged in the first time, has the same | ||||||
|  | 	// machine key, but a different ID than the newly logged in version of the same | ||||||
|  | 	// machine. | ||||||
|  | 	assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) | ||||||
|  | 	assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey) | ||||||
|  | 	assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id) | ||||||
|  | 	assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) | ||||||
|  | 	assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id) | ||||||
|  | 	assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id) | ||||||
|  |  | ||||||
|  | 	// Even tho we are logging in again with the same user, the previous key has been expired | ||||||
|  | 	// and a new one has been generated. The node entry in the database should be the same | ||||||
|  | 	// as the user + machinekey still matches. | ||||||
|  | 	assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey) | ||||||
|  | 	assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey) | ||||||
|  | 	assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id) | ||||||
|  |  | ||||||
|  | 	// The "logged back in" machine should have the same machinekey but a different nodekey | ||||||
|  | 	// than the version logged in with a different user. | ||||||
|  | 	assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey) | ||||||
|  | 	assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey) | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *AuthOIDCScenario) CreateHeadscaleEnv( | func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||||
| 	users map[string]int, | 	users map[string]int, | ||||||
| 	opts ...hsic.Option, | 	opts ...hsic.Option, | ||||||
|   | |||||||
| @@ -11,6 +11,8 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/juanfont/headscale/integration/hsic" | 	"github.com/juanfont/headscale/integration/hsic" | ||||||
| 	"github.com/samber/lo" | 	"github.com/samber/lo" | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var errParseAuthPage = errors.New("failed to parse auth page") | var errParseAuthPage = errors.New("failed to parse auth page") | ||||||
| @@ -106,6 +108,14 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { | |||||||
| 	success := pingAllHelper(t, allClients, allAddrs) | 	success := pingAllHelper(t, allClients, allAddrs) | ||||||
| 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | ||||||
|  |  | ||||||
|  | 	headscale, err := scenario.Headscale() | ||||||
|  | 	assertNoErrGetHeadscale(t, err) | ||||||
|  |  | ||||||
|  | 	listNodes, err := headscale.ListNodes() | ||||||
|  | 	assert.Equal(t, len(listNodes), len(allClients)) | ||||||
|  | 	nodeCountBeforeLogout := len(listNodes) | ||||||
|  | 	t.Logf("node count before logout: %d", nodeCountBeforeLogout) | ||||||
|  |  | ||||||
| 	clientIPs := make(map[TailscaleClient][]netip.Addr) | 	clientIPs := make(map[TailscaleClient][]netip.Addr) | ||||||
| 	for _, client := range allClients { | 	for _, client := range allClients { | ||||||
| 		ips, err := client.IPs() | 		ips, err := client.IPs() | ||||||
| @@ -127,9 +137,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { | |||||||
|  |  | ||||||
| 	t.Logf("all clients logged out") | 	t.Logf("all clients logged out") | ||||||
|  |  | ||||||
| 	headscale, err := scenario.Headscale() |  | ||||||
| 	assertNoErrGetHeadscale(t, err) |  | ||||||
|  |  | ||||||
| 	for userName := range spec { | 	for userName := range spec { | ||||||
| 		err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) | 		err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -139,9 +146,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { | |||||||
|  |  | ||||||
| 	t.Logf("all clients logged in again") | 	t.Logf("all clients logged in again") | ||||||
|  |  | ||||||
| 	allClients, err = scenario.ListTailscaleClients() |  | ||||||
| 	assertNoErrListClients(t, err) |  | ||||||
|  |  | ||||||
| 	allIps, err = scenario.ListTailscaleClientsIPs() | 	allIps, err = scenario.ListTailscaleClientsIPs() | ||||||
| 	assertNoErrListClientIPs(t, err) | 	assertNoErrListClientIPs(t, err) | ||||||
|  |  | ||||||
| @@ -152,6 +156,10 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { | |||||||
| 	success = pingAllHelper(t, allClients, allAddrs) | 	success = pingAllHelper(t, allClients, allAddrs) | ||||||
| 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | ||||||
|  |  | ||||||
|  | 	listNodes, err = headscale.ListNodes() | ||||||
|  | 	require.Equal(t, nodeCountBeforeLogout, len(listNodes)) | ||||||
|  | 	t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) | ||||||
|  |  | ||||||
| 	for _, client := range allClients { | 	for _, client := range allClients { | ||||||
| 		ips, err := client.IPs() | 		ips, err := client.IPs() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|   | |||||||
| @@ -606,22 +606,12 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { | |||||||
| 		t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String()) | 		t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String()) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var listNodes []v1.Node | 	listNodes, err := headscale.ListNodes() | ||||||
| 	err = executeAndUnmarshal( |  | ||||||
| 		headscale, |  | ||||||
| 		[]string{ |  | ||||||
| 			"headscale", |  | ||||||
| 			"nodes", |  | ||||||
| 			"list", |  | ||||||
| 			"--output", |  | ||||||
| 			"json", |  | ||||||
| 		}, |  | ||||||
| 		&listNodes, |  | ||||||
| 	) |  | ||||||
| 	assert.Nil(t, err) | 	assert.Nil(t, err) | ||||||
| 	assert.Len(t, listNodes, 1) | 	assert.Len(t, listNodes, 2) | ||||||
|  |  | ||||||
| 	assert.Equal(t, "user2", listNodes[0].GetUser().GetName()) | 	assert.Equal(t, "user1", listNodes[0].GetUser().GetName()) | ||||||
|  | 	assert.Equal(t, "user2", listNodes[1].GetUser().GetName()) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestApiKeyCommand(t *testing.T) { | func TestApiKeyCommand(t *testing.T) { | ||||||
|   | |||||||
| @@ -17,7 +17,8 @@ type ControlServer interface { | |||||||
| 	WaitForRunning() error | 	WaitForRunning() error | ||||||
| 	CreateUser(user string) error | 	CreateUser(user string) error | ||||||
| 	CreateAuthKey(user string, reusable bool, ephemeral bool) (*v1.PreAuthKey, error) | 	CreateAuthKey(user string, reusable bool, ephemeral bool) (*v1.PreAuthKey, error) | ||||||
| 	ListNodesInUser(user string) ([]*v1.Node, error) | 	ListNodes(users ...string) ([]*v1.Node, error) | ||||||
|  | 	ListUsers() ([]*v1.User, error) | ||||||
| 	GetCert() []byte | 	GetCert() []byte | ||||||
| 	GetHostname() string | 	GetHostname() string | ||||||
| 	GetIP() string | 	GetIP() string | ||||||
|   | |||||||
| @@ -105,137 +105,6 @@ func TestPingAllByIPPublicDERP(t *testing.T) { | |||||||
| 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestAuthKeyLogoutAndRelogin(t *testing.T) { |  | ||||||
| 	IntegrationSkip(t) |  | ||||||
| 	t.Parallel() |  | ||||||
|  |  | ||||||
| 	for _, https := range []bool{true, false} { |  | ||||||
| 		t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { |  | ||||||
| 			scenario, err := NewScenario(dockertestMaxWait()) |  | ||||||
| 			assertNoErr(t, err) |  | ||||||
| 			defer scenario.ShutdownAssertNoPanics(t) |  | ||||||
|  |  | ||||||
| 			spec := map[string]int{ |  | ||||||
| 				"user1": len(MustTestVersions), |  | ||||||
| 				"user2": len(MustTestVersions), |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			opts := []hsic.Option{hsic.WithTestName("pingallbyip")} |  | ||||||
| 			if https { |  | ||||||
| 				opts = append(opts, []hsic.Option{ |  | ||||||
| 					hsic.WithTLS(), |  | ||||||
| 				}...) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) |  | ||||||
| 			assertNoErrHeadscaleEnv(t, err) |  | ||||||
|  |  | ||||||
| 			allClients, err := scenario.ListTailscaleClients() |  | ||||||
| 			assertNoErrListClients(t, err) |  | ||||||
|  |  | ||||||
| 			err = scenario.WaitForTailscaleSync() |  | ||||||
| 			assertNoErrSync(t, err) |  | ||||||
|  |  | ||||||
| 			// assertClientsState(t, allClients) |  | ||||||
|  |  | ||||||
| 			clientIPs := make(map[TailscaleClient][]netip.Addr) |  | ||||||
| 			for _, client := range allClients { |  | ||||||
| 				ips, err := client.IPs() |  | ||||||
| 				if err != nil { |  | ||||||
| 					t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) |  | ||||||
| 				} |  | ||||||
| 				clientIPs[client] = ips |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			for _, client := range allClients { |  | ||||||
| 				err := client.Logout() |  | ||||||
| 				if err != nil { |  | ||||||
| 					t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = scenario.WaitForTailscaleLogout() |  | ||||||
| 			assertNoErrLogout(t, err) |  | ||||||
|  |  | ||||||
| 			t.Logf("all clients logged out") |  | ||||||
|  |  | ||||||
| 			headscale, err := scenario.Headscale() |  | ||||||
| 			assertNoErrGetHeadscale(t, err) |  | ||||||
|  |  | ||||||
| 			// if the server is not running with HTTPS, we have to wait a bit before |  | ||||||
| 			// reconnection as the newest Tailscale client has a measure that will only |  | ||||||
| 			// reconnect over HTTPS if they saw a noise connection previously. |  | ||||||
| 			// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 |  | ||||||
| 			// https://github.com/juanfont/headscale/issues/2164 |  | ||||||
| 			if !https { |  | ||||||
| 				time.Sleep(5 * time.Minute) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			for userName := range spec { |  | ||||||
| 				key, err := scenario.CreatePreAuthKey(userName, true, false) |  | ||||||
| 				if err != nil { |  | ||||||
| 					t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) |  | ||||||
| 				if err != nil { |  | ||||||
| 					t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = scenario.WaitForTailscaleSync() |  | ||||||
| 			assertNoErrSync(t, err) |  | ||||||
|  |  | ||||||
| 			// assertClientsState(t, allClients) |  | ||||||
|  |  | ||||||
| 			allClients, err = scenario.ListTailscaleClients() |  | ||||||
| 			assertNoErrListClients(t, err) |  | ||||||
|  |  | ||||||
| 			allIps, err := scenario.ListTailscaleClientsIPs() |  | ||||||
| 			assertNoErrListClientIPs(t, err) |  | ||||||
|  |  | ||||||
| 			allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { |  | ||||||
| 				return x.String() |  | ||||||
| 			}) |  | ||||||
|  |  | ||||||
| 			success := pingAllHelper(t, allClients, allAddrs) |  | ||||||
| 			t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) |  | ||||||
|  |  | ||||||
| 			for _, client := range allClients { |  | ||||||
| 				ips, err := client.IPs() |  | ||||||
| 				if err != nil { |  | ||||||
| 					t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				// lets check if the IPs are the same |  | ||||||
| 				if len(ips) != len(clientIPs[client]) { |  | ||||||
| 					t.Fatalf("IPs changed for client %s", client.Hostname()) |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				for _, ip := range ips { |  | ||||||
| 					found := false |  | ||||||
| 					for _, oldIP := range clientIPs[client] { |  | ||||||
| 						if ip == oldIP { |  | ||||||
| 							found = true |  | ||||||
|  |  | ||||||
| 							break |  | ||||||
| 						} |  | ||||||
| 					} |  | ||||||
|  |  | ||||||
| 					if !found { |  | ||||||
| 						t.Fatalf( |  | ||||||
| 							"IPs changed for client %s. Used to be %v now %v", |  | ||||||
| 							client.Hostname(), |  | ||||||
| 							clientIPs[client], |  | ||||||
| 							ips, |  | ||||||
| 						) |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestEphemeral(t *testing.T) { | func TestEphemeral(t *testing.T) { | ||||||
| 	testEphemeralWithOptions(t, hsic.WithTestName("ephemeral")) | 	testEphemeralWithOptions(t, hsic.WithTestName("ephemeral")) | ||||||
| } | } | ||||||
| @@ -314,21 +183,9 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { | |||||||
|  |  | ||||||
| 	t.Logf("all clients logged out") | 	t.Logf("all clients logged out") | ||||||
|  |  | ||||||
| 	for userName := range spec { | 	nodes, err := headscale.ListNodes() | ||||||
| 		nodes, err := headscale.ListNodesInUser(userName) | 	assertNoErr(t, err) | ||||||
| 		if err != nil { | 	require.Len(t, nodes, 0) | ||||||
| 			log.Error(). |  | ||||||
| 				Err(err). |  | ||||||
| 				Str("user", userName). |  | ||||||
| 				Msg("Error listing nodes in user") |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if len(nodes) != 0 { |  | ||||||
| 			t.Fatalf("expected no nodes, got %d in user %s", len(nodes), userName) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not | // TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not | ||||||
| @@ -431,7 +288,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { | |||||||
| 	time.Sleep(3 * time.Minute) | 	time.Sleep(3 * time.Minute) | ||||||
|  |  | ||||||
| 	for userName := range spec { | 	for userName := range spec { | ||||||
| 		nodes, err := headscale.ListNodesInUser(userName) | 		nodes, err := headscale.ListNodes(userName) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| 				Err(err). | 				Err(err). | ||||||
|   | |||||||
| @@ -16,7 +16,7 @@ func DefaultConfigEnv() map[string]string { | |||||||
| 		"HEADSCALE_POLICY_PATH":                       "", | 		"HEADSCALE_POLICY_PATH":                       "", | ||||||
| 		"HEADSCALE_DATABASE_TYPE":                     "sqlite", | 		"HEADSCALE_DATABASE_TYPE":                     "sqlite", | ||||||
| 		"HEADSCALE_DATABASE_SQLITE_PATH":              "/tmp/integration_test_db.sqlite3", | 		"HEADSCALE_DATABASE_SQLITE_PATH":              "/tmp/integration_test_db.sqlite3", | ||||||
| 		"HEADSCALE_DATABASE_DEBUG":                    "1", | 		"HEADSCALE_DATABASE_DEBUG":                    "0", | ||||||
| 		"HEADSCALE_DATABASE_GORM_SLOW_THRESHOLD":      "1", | 		"HEADSCALE_DATABASE_GORM_SLOW_THRESHOLD":      "1", | ||||||
| 		"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", | 		"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", | ||||||
| 		"HEADSCALE_PREFIXES_V4":                       "100.64.0.0/10", | 		"HEADSCALE_PREFIXES_V4":                       "100.64.0.0/10", | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package hsic | package hsic | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"cmp" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| @@ -10,6 +11,7 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path" | 	"path" | ||||||
|  | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -744,12 +746,58 @@ func (t *HeadscaleInContainer) CreateAuthKey( | |||||||
| 	return &preAuthKey, nil | 	return &preAuthKey, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // ListNodesInUser list the TailscaleClients (Node, Headscale internal representation) | // ListNodes lists the currently registered Nodes in headscale. | ||||||
| // associated with a user. | // Optionally a list of usernames can be passed to get users for | ||||||
| func (t *HeadscaleInContainer) ListNodesInUser( | // specific users. | ||||||
| 	user string, | func (t *HeadscaleInContainer) ListNodes( | ||||||
|  | 	users ...string, | ||||||
| ) ([]*v1.Node, error) { | ) ([]*v1.Node, error) { | ||||||
| 	command := []string{"headscale", "--user", user, "nodes", "list", "--output", "json"} | 	var ret []*v1.Node | ||||||
|  | 	execUnmarshal := func(command []string) error { | ||||||
|  | 		result, _, err := dockertestutil.ExecuteCommand( | ||||||
|  | 			t.container, | ||||||
|  | 			command, | ||||||
|  | 			[]string{}, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("failed to execute list node command: %w", err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		var nodes []*v1.Node | ||||||
|  | 		err = json.Unmarshal([]byte(result), &nodes) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("failed to unmarshal nodes: %w", err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		ret = append(ret, nodes...) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(users) == 0 { | ||||||
|  | 		err := execUnmarshal([]string{"headscale", "nodes", "list", "--output", "json"}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		for _, user := range users { | ||||||
|  | 			command := []string{"headscale", "--user", user, "nodes", "list", "--output", "json"} | ||||||
|  |  | ||||||
|  | 			err := execUnmarshal(command) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return nil, err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sort.Slice(ret, func(i, j int) bool { | ||||||
|  | 		return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1 | ||||||
|  | 	}) | ||||||
|  | 	return ret, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ListUsers returns a list of users from Headscale. | ||||||
|  | func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) { | ||||||
|  | 	command := []string{"headscale", "users", "list", "--output", "json"} | ||||||
|  |  | ||||||
| 	result, _, err := dockertestutil.ExecuteCommand( | 	result, _, err := dockertestutil.ExecuteCommand( | ||||||
| 		t.container, | 		t.container, | ||||||
| @@ -760,13 +808,13 @@ func (t *HeadscaleInContainer) ListNodesInUser( | |||||||
| 		return nil, fmt.Errorf("failed to execute list node command: %w", err) | 		return nil, fmt.Errorf("failed to execute list node command: %w", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var nodes []*v1.Node | 	var users []*v1.User | ||||||
| 	err = json.Unmarshal([]byte(result), &nodes) | 	err = json.Unmarshal([]byte(result), &users) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to unmarshal nodes: %w", err) | 		return nil, fmt.Errorf("failed to unmarshal nodes: %w", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nodes, nil | 	return users, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // WriteFile save file inside the Headscale container. | // WriteFile save file inside the Headscale container. | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user