stability and race conditions in auth and node store (#2781)

This PR addresses some consistency issues that were introduced or discovered with the nodestore.

nodestore:
Now returns the node that is being put or updated once the operation has finished. This closes a race condition where reading the node back afterwards did not necessarily return it with the given change, and it ensures we also get all the other updates from that batch write.
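
A minimal sketch of how a caller can rely on the returned node, based on the new PutNode/UpdateNode signatures in this diff (the wrapper function itself is hypothetical and not part of the PR):

```go
package state

import (
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/types/ptr"
)

// touchNode is a hypothetical caller illustrating the new API: UpdateNode
// returns the node as it looks after the whole batch has been applied, and
// ok is false if the node was deleted in the same batch, so there is no need
// to read the node back (and risk a stale view) after the write.
func touchNode(store *NodeStore, id types.NodeID) (types.NodeView, bool) {
	return store.UpdateNode(id, func(n *types.Node) {
		n.LastSeen = ptr.To(time.Now())
	})
}
```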

auth:
Authentication paths have been unified and simplified. This removes a lot of problematic branches and ensures we only do the minimal amount of work.
A comprehensive auth test suite has been created so we do not have to run integration tests to validate auth, and it has allowed us to write test cases for all the branches we currently know of.
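
For reference, the unified flow in handleRegister now checks the request in a fixed order. The sketch below is a condensed, non-literal outline of the diff further down, living in the hscontrol package; error handling, logging, and the nil-response fall-through from handleLogout are omitted, and the function name is hypothetical:

```go
// handleRegisterOutline is a condensed outline of handleRegister after this
// change (not the literal code):
// 1. A past expiry is a logout, even if an auth key is also present.
// 2. No Auth struct on an existing node is treated as a (legacy) logout.
// 3. A Followup URL means the node is already mid-registration; wait for it.
// 4. A pre auth key is handled synchronously.
// 5. Anything else starts an interactive registration.
func (h *Headscale) handleRegisterOutline(ctx context.Context, req tailcfg.RegisterRequest, machineKey key.MachinePublic) (*tailcfg.RegisterResponse, error) {
	if !req.Expiry.IsZero() && req.Expiry.Before(time.Now()) {
		if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
			return h.handleLogout(node, req, machineKey)
		}
	}
	if req.Auth == nil {
		if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
			return h.handleLogout(node, req, machineKey)
		}
	}
	if req.Followup != "" {
		return h.waitForFollowup(ctx, req, machineKey)
	}
	if isAuthKey(req) {
		return h.handleRegisterWithAuthKey(req, machineKey)
	}
	return h.handleRegisterInteractive(req, machineKey)
}
```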

integration:
Added a lot more tooling and checks to validate that nodes reach the expected state as they come up and go down, standardised across the different auth models. A lot of this is there to support or detect issues in the changes to the nodestore (races) and auth (inconsistencies after login and in reaching the correct state).

This PR was assisted, particularly the tests, by Claude Code.
Commit fddc7117e4 (parent 881a6b9227) by Kristoffer Dalby, 2025-10-16 12:17:43 +02:00, committed via GitHub.
34 changed files with 7408 additions and 1876 deletions


@@ -31,9 +31,11 @@ jobs:
- TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser
- TestOIDCReloginSameNodeSameUser
- TestOIDCFollowUpUrl
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin
- TestAuthWebFlowLogoutAndReloginSameUser
- TestAuthWebFlowLogoutAndReloginNewUser
- TestUserCommand
- TestPreAuthKeyCommand
- TestPreAuthKeyCommandWithoutExpiry


@@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
@@ -25,26 +26,84 @@ type AuthProvider interface {
func (h *Headscale) handleRegister(
ctx context.Context,
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey)
// Check for logout/expiry FIRST, before checking auth key.
// Tailscale clients may send logout requests with BOTH a past expiry AND an auth key.
// A past expiry takes precedence - it's a logout regardless of other fields.
if !req.Expiry.IsZero() && req.Expiry.Before(time.Now()) {
log.Debug().
Str("node.key", req.NodeKey.ShortString()).
Time("expiry", req.Expiry).
Bool("has_auth", req.Auth != nil).
Msg("Detected logout attempt with past expiry")
if ok {
resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
// This is a logout attempt (expiry in the past)
if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
log.Debug().
Uint64("node.id", node.ID().Uint64()).
Str("node.name", node.Hostname()).
Bool("is_ephemeral", node.IsEphemeral()).
Bool("has_authkey", node.AuthKey().Valid()).
Msg("Found existing node for logout, calling handleLogout")
resp, err := h.handleLogout(node, req, machineKey)
if err != nil {
return nil, fmt.Errorf("handling logout: %w", err)
}
if resp != nil {
return resp, nil
}
} else {
log.Warn().
Str("node.key", req.NodeKey.ShortString()).
Msg("Logout attempt but node not found in NodeStore")
}
return resp, nil
}
if regReq.Followup != "" {
return h.waitForFollowup(ctx, regReq, machineKey)
// If the register request does not contain an Auth struct, it means we are logging
// out an existing node (legacy logout path for clients that send Auth=nil).
if req.Auth == nil {
// If the register request presents a NodeKey that is currently in use, we will
// check if the node needs to be sent to re-auth, or if the node is logging out.
// We do not look up nodes by [key.MachinePublic] as it might belong to multiple
// nodes, separated by users, and this path handles the expiring/logout paths.
if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
resp, err := h.handleLogout(node, req, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
}
// If resp is not nil, we have a response to return to the node.
// If resp is nil, we should proceed and see if the node is trying to re-auth.
if resp != nil {
return resp, nil
}
} else {
// If the register request is not attempting to register a node, and
// we cannot match it with an existing node, we consider that unexpected,
// as only registered nodes should attempt to log out.
log.Debug().
Str("node.key", req.NodeKey.ShortString()).
Str("machine.key", machineKey.ShortString()).
Bool("unexpected", true).
Msg("received register request with no auth, and no existing node")
}
}
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
// If the [tailcfg.RegisterRequest] has a Followup URL, it means that the
// node has already started the registration process and we should wait for
// it to finish the original registration.
if req.Followup != "" {
return h.waitForFollowup(ctx, req, machineKey)
}
// Pre-authenticated keys are handled slightly differently than interactive
// logins as they can be handled fully synchronously and we can respond to the node with
// the result while it is waiting.
if isAuthKey(req) {
resp, err := h.handleRegisterWithAuthKey(req, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
@@ -58,7 +117,7 @@ func (h *Headscale) handleRegister(
return resp, nil
}
resp, err := h.handleRegisterInteractive(regReq, machineKey)
resp, err := h.handleRegisterInteractive(req, machineKey)
if err != nil {
return nil, fmt.Errorf("handling register interactive: %w", err)
}
@@ -66,20 +125,34 @@ func (h *Headscale) handleRegister(
return resp, nil
}
func (h *Headscale) handleExistingNode(
node *types.Node,
regReq tailcfg.RegisterRequest,
// handleLogout checks if the [tailcfg.RegisterRequest] is a
// logout attempt from a node. If the node is not attempting to
func (h *Headscale) handleLogout(
node types.NodeView,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
if node.MachineKey != machineKey {
// Fail closed if it looks like this is an attempt to modify a node where
// the node key and the machine key the noise session was started with does
// not align.
if node.MachineKey() != machineKey {
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
}
expired := node.IsExpired()
// Note: We do NOT return early if req.Auth is set, because Tailscale clients
// may send logout requests with BOTH a past expiry AND an auth key.
// A past expiry indicates logout, regardless of whether Auth is present.
// The expiry check below will handle the logout logic.
// If the node is expired and this is not a re-authentication attempt,
// force the client to re-authenticate
if expired && regReq.Auth == nil {
// force the client to re-authenticate.
// TODO(kradalby): I wonder if this is a path we ever hit?
if node.IsExpired() {
log.Trace().Str("node.name", node.Hostname()).
Uint64("node.id", node.ID().Uint64()).
Interface("reg.req", req).
Bool("unexpected", true).
Msg("Node key expired, forcing re-authentication")
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
@@ -87,49 +160,76 @@ func (h *Headscale) handleExistingNode(
}, nil
}
if !expired && !regReq.Expiry.IsZero() {
requestExpiry := regReq.Expiry
// If we get here, the node is not currently expired, and not trying to
// do an auth.
// The node is likely logging out, but before we run that logic, we will validate
// that the node is not attempting to tamper/extend their expiry.
// If it is not, we will expire the node or in the case of an ephemeral node, delete it.
// The client is trying to extend their key, this is not allowed.
if requestExpiry.After(time.Now()) {
return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
}
// If the request expiry is in the past, we consider it a logout.
if requestExpiry.Before(time.Now()) {
if node.IsEphemeral() {
c, err := h.state.DeleteNode(node.View())
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
h.Change(c)
return nil, nil
}
}
updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}
h.Change(c)
// CRITICAL: Use the updated node view for the response
// The original node object has stale expiry information
node = updatedNode.AsStruct()
// The client is trying to extend their key, this is not allowed.
if req.Expiry.After(time.Now()) {
return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
}
return nodeToRegisterResponse(node), nil
// If the request expiry is in the past, we consider it a logout.
if req.Expiry.Before(time.Now()) {
log.Debug().
Uint64("node.id", node.ID().Uint64()).
Str("node.name", node.Hostname()).
Bool("is_ephemeral", node.IsEphemeral()).
Bool("has_authkey", node.AuthKey().Valid()).
Time("req.expiry", req.Expiry).
Msg("Processing logout request with past expiry")
if node.IsEphemeral() {
log.Info().
Uint64("node.id", node.ID().Uint64()).
Str("node.name", node.Hostname()).
Msg("Deleting ephemeral node during logout")
c, err := h.state.DeleteNode(node)
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
h.Change(c)
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
}, nil
}
log.Debug().
Uint64("node.id", node.ID().Uint64()).
Str("node.name", node.Hostname()).
Msg("Node is not ephemeral, setting expiry instead of deleting")
}
// Update the internal state with the node's new expiry, meaning it is
// logged out.
updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), req.Expiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}
h.Change(c)
return nodeToRegisterResponse(updatedNode), nil
}
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
// isAuthKey reports whether the register request is a registration request
// using a pre auth key.
func isAuthKey(req tailcfg.RegisterRequest) bool {
return req.Auth != nil && req.Auth.AuthKey != ""
}
func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse {
return &tailcfg.RegisterResponse{
// TODO(kradalby): Only send for user-owned nodes
// and not tagged nodes when tags is working.
User: *node.User.TailscaleUser(),
Login: *node.User.TailscaleLogin(),
User: node.UserView().TailscaleUser(),
Login: node.UserView().TailscaleLogin(),
NodeKeyExpired: node.IsExpired(),
// Headscale does not implement the concept of machine authorization
@@ -141,10 +241,10 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
func (h *Headscale) waitForFollowup(
ctx context.Context,
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
fu, err := url.Parse(regReq.Followup)
fu, err := url.Parse(req.Followup)
if err != nil {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
}
@@ -161,21 +261,21 @@ func (h *Headscale) waitForFollowup(
case node := <-reg.Registered:
if node == nil {
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node), nil
return nodeToRegisterResponse(node.View()), nil
}
}
// if the follow-up registration isn't found anymore, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
return h.reqToNewRegisterResponse(req, machineKey)
}
// reqToNewRegisterResponse refreshes the registration flow by creating a new
// registration ID and returning the corresponding AuthURL so the client can
// restart the authentication process.
func (h *Headscale) reqToNewRegisterResponse(
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
@@ -183,18 +283,25 @@ func (h *Headscale) reqToNewRegisterResponse(
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
// Ensure we have valid hostinfo and hostname
validHostinfo, hostname := util.EnsureValidHostinfo(
req.Hostinfo,
machineKey.String(),
req.NodeKey.String(),
)
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
Hostname: hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
NodeKey: req.NodeKey,
Hostinfo: validHostinfo,
LastSeen: ptr.To(time.Now()),
},
)
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
}
log.Info().Msgf("New followup node registration using key: %s", newRegID)
@@ -206,11 +313,11 @@ func (h *Headscale) reqToNewRegisterResponse(
}
func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq,
req,
machineKey,
)
if err != nil {
@@ -262,18 +369,26 @@ func (h *Headscale) handleRegisterWithAuthKey(
// h.Change(policyChange)
// }
user := node.User()
return &tailcfg.RegisterResponse{
resp := &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),
User: *user.TailscaleUser(),
Login: *user.TailscaleLogin(),
}, nil
User: node.UserView().TailscaleUser(),
Login: node.UserView().TailscaleLogin(),
}
log.Trace().
Caller().
Interface("reg.resp", resp).
Interface("reg.req", req).
Str("node.name", node.Hostname()).
Uint64("node.id", node.ID().Uint64()).
Msg("RegisterResponse")
return resp, nil
}
func (h *Headscale) handleRegisterInteractive(
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
registrationId, err := types.NewRegistrationID()
@@ -281,18 +396,39 @@ func (h *Headscale) handleRegisterInteractive(
return nil, fmt.Errorf("generating registration ID: %w", err)
}
// Ensure we have valid hostinfo and hostname
validHostinfo, hostname := util.EnsureValidHostinfo(
req.Hostinfo,
machineKey.String(),
req.NodeKey.String(),
)
if req.Hostinfo == nil {
log.Warn().
Str("machine.key", machineKey.ShortString()).
Str("node.key", req.NodeKey.ShortString()).
Str("generated.hostname", hostname).
Msg("Received registration request with nil hostinfo, generated default hostname")
} else if req.Hostinfo.Hostname == "" {
log.Warn().
Str("machine.key", machineKey.ShortString()).
Str("node.key", req.NodeKey.ShortString()).
Str("generated.hostname", hostname).
Msg("Received registration request with empty hostname, generated default")
}
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
Hostname: hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
NodeKey: req.NodeKey,
Hostinfo: validHostinfo,
LastSeen: ptr.To(time.Now()),
},
)
if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
}
h.state.SetRegistrationCacheEntry(

hscontrol/auth_test.go (new file, 3006 lines): file diff suppressed because it is too large.


@@ -741,7 +741,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
hostinfo := tailcfg.Hostinfo{
RoutableIPs: routes,
OS: "TestOS",
Hostname: "DebugTestNode",
Hostname: request.GetName(),
}
registrationId, err := types.RegistrationIDFromString(request.GetKey())


@@ -197,11 +197,12 @@ func (m *mapSession) serveLongPoll() {
m.keepAliveTicker = time.NewTicker(m.keepAlive)
// Process the initial MapRequest to update node state (endpoints, hostinfo, etc.)
// CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly
// synchronized. When nodes reconnect, they send their hostinfo with announced routes
// in the MapRequest. We need this data in NodeStore before Connect() sets up the
// primary routes, otherwise SubnetRoutes() returns empty and the node is removed
// from AvailableRoutes.
// This must be done BEFORE calling Connect() to ensure routes are properly synchronized.
// When nodes reconnect, they send their hostinfo with announced routes in the MapRequest.
// We need this data in NodeStore before Connect() sets up the primary routes, because
// SubnetRoutes() calculates the intersection of announced and approved routes. If we
// call Connect() first, SubnetRoutes() returns empty (no announced routes yet), causing
// the node to be incorrectly removed from AvailableRoutes.
mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
m.errf(err, "failed to update node from initial MapRequest")


@@ -60,9 +60,6 @@ type DebugStringInfo struct {
// DebugOverview returns a comprehensive overview of the current state for debugging.
func (s *State) DebugOverview() string {
s.mu.RLock()
defer s.mu.RUnlock()
allNodes := s.nodeStore.ListNodes()
users, _ := s.ListAllUsers()
@@ -270,9 +267,6 @@ func (s *State) PolicyDebugString() string {
// DebugOverviewJSON returns a structured overview of the current state for debugging.
func (s *State) DebugOverviewJSON() DebugOverviewInfo {
s.mu.RLock()
defer s.mu.RUnlock()
allNodes := s.nodeStore.ListNodes()
users, _ := s.ListAllUsers()


@@ -33,8 +33,8 @@ func TestNodeStoreDebugString(t *testing.T) {
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
store.PutNode(node1)
store.PutNode(node2)
_ = store.PutNode(node1)
_ = store.PutNode(node2)
return store
},


@@ -0,0 +1,460 @@
package state
import (
"net/netip"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/ptr"
)
// TestEphemeralNodeDeleteWithConcurrentUpdate tests the race condition where UpdateNode and DeleteNode
// are called concurrently and may be batched together. This reproduces the issue where ephemeral nodes
// are not properly deleted during logout because UpdateNodeFromMapRequest returns a stale node view
// after the node has been deleted from the NodeStore.
func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
// Create a simple test node
node := createTestNode(1, 1, "test-user", "test-node")
// Create NodeStore
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
// Put the node in the store
resultNode := store.PutNode(node)
require.True(t, resultNode.Valid(), "initial PutNode should return valid node")
// Verify node exists
retrievedNode, found := store.GetNode(node.ID)
require.True(t, found)
require.Equal(t, node.ID, retrievedNode.ID())
// Test scenario: UpdateNode is called, returns a node view from the batch,
// but in the same batch a DeleteNode removes the node.
// This simulates what happens when:
// 1. UpdateNodeFromMapRequest calls UpdateNode and gets back updatedNode
// 2. At the same time, handleLogout calls DeleteNode
// 3. They get batched together: [UPDATE, DELETE]
// 4. UPDATE modifies the node, DELETE removes it
// 5. UpdateNode returns a node view based on the state AFTER both operations
// 6. If DELETE came after UPDATE, the returned node should be invalid
done := make(chan bool, 2)
var updatedNode types.NodeView
var updateOk bool
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
go func() {
updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
})
done <- true
}()
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
go func() {
// Small delay to increase chance of batching together
time.Sleep(1 * time.Millisecond)
store.DeleteNode(node.ID)
done <- true
}()
// Wait for both operations
<-done
<-done
// Give batching time to complete
time.Sleep(50 * time.Millisecond)
// The key assertion: if UpdateNode and DeleteNode were batched together
// with DELETE after UPDATE, then UpdateNode should return an invalid node
// OR it should return a valid node but the node should no longer exist in the store
_, found = store.GetNode(node.ID)
assert.False(t, found, "node should be deleted from NodeStore")
// If the update happened before delete in the batch, the returned node might be invalid
if updateOk {
t.Logf("UpdateNode returned ok=true, valid=%v", updatedNode.Valid())
// This is the bug scenario - UpdateNode thinks it succeeded but node is gone
if updatedNode.Valid() {
t.Logf("WARNING: UpdateNode returned valid node but node was deleted - this indicates the race condition bug")
}
} else {
t.Logf("UpdateNode correctly returned ok=false (node deleted in same batch)")
}
}
// TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch specifically tests that when
// UpdateNode and DeleteNode are in the same batch with DELETE after UPDATE,
// the UpdateNode should return an invalid node view.
func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
node := createTestNode(2, 1, "test-user", "test-node-2")
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
// Put node in store
_ = store.PutNode(node)
// Simulate the exact sequence: UpdateNode gets queued, then DeleteNode gets queued,
// they batch together, and we check what UpdateNode returns
resultChan := make(chan struct {
node types.NodeView
ok bool
})
// Start UpdateNode - it will block until batch is applied
go func() {
node, ok := store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
})
resultChan <- struct {
node types.NodeView
ok bool
}{node, ok}
}()
// Give UpdateNode a moment to queue its work
time.Sleep(5 * time.Millisecond)
// Now queue DeleteNode - should batch with the UPDATE
store.DeleteNode(node.ID)
// Get the result from UpdateNode
result := <-resultChan
// Wait for batch to complete
time.Sleep(50 * time.Millisecond)
// Node should be deleted
_, found := store.GetNode(node.ID)
assert.False(t, found, "node should be deleted")
// The critical check: what did UpdateNode return?
// After the commit c6b09289988f34398eb3157e31ba092eb8721a9f,
// UpdateNode returns the node state from the batch.
// If DELETE came after UPDATE in the batch, the node doesn't exist anymore,
// so UpdateNode should return (invalid, false)
t.Logf("UpdateNode returned: ok=%v, valid=%v", result.ok, result.node.Valid())
// This is the expected behavior - if node was deleted in same batch,
// UpdateNode should return invalid node
if result.ok && result.node.Valid() {
t.Error("BUG: UpdateNode returned valid node even though it was deleted in same batch")
}
}
// TestPersistNodeToDBPreventsRaceCondition tests that persistNodeToDB correctly handles
// the race condition where a node is deleted after UpdateNode returns but before
// persistNodeToDB is called. This reproduces the ephemeral node deletion bug.
func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
node := createTestNode(3, 1, "test-user", "test-node-3")
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
// Put node in store
_ = store.PutNode(node)
// Simulate UpdateNode being called
updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
})
require.True(t, ok, "UpdateNode should succeed")
require.True(t, updatedNode.Valid(), "UpdateNode should return valid node")
// Now delete the node (simulating ephemeral logout happening concurrently)
store.DeleteNode(node.ID)
// Wait for deletion to complete
time.Sleep(50 * time.Millisecond)
// Verify node is deleted
_, found := store.GetNode(node.ID)
require.False(t, found, "node should be deleted")
// Now try to use the updatedNode from before the deletion
// In the old code, this would re-insert the node into the database
// With our fix, GetNode check in persistNodeToDB should prevent this
// Simulate what persistNodeToDB does - check if node still exists
_, exists := store.GetNode(updatedNode.ID())
if !exists {
t.Log("SUCCESS: persistNodeToDB check would prevent re-insertion of deleted node")
} else {
t.Error("BUG: Node still exists in NodeStore after deletion")
}
// The key assertion: after deletion, attempting to persist the old updatedNode
// should fail because the node no longer exists in NodeStore
assert.False(t, exists, "persistNodeToDB should detect node was deleted and refuse to persist")
}
// TestEphemeralNodeLogoutRaceCondition tests the specific race condition that occurs
// when an ephemeral node logs out. This reproduces the bug where:
// 1. UpdateNodeFromMapRequest calls UpdateNode and receives a node view
// 2. Concurrently, handleLogout is called for the ephemeral node and calls DeleteNode
// 3. UpdateNode and DeleteNode get batched together
// 4. If UpdateNode's result is used to call persistNodeToDB after the deletion,
// the node could be re-inserted into the database even though it was deleted
func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
ephemeralNode := createTestNode(4, 1, "test-user", "ephemeral-node")
ephemeralNode.AuthKey = &types.PreAuthKey{
ID: 1,
Key: "test-key",
Ephemeral: true,
}
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
// Put ephemeral node in store
_ = store.PutNode(ephemeralNode)
// Simulate concurrent operations:
// 1. UpdateNode (from UpdateNodeFromMapRequest during polling)
// 2. DeleteNode (from handleLogout when client sends logout request)
var updatedNode types.NodeView
var updateOk bool
done := make(chan bool, 2)
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
go func() {
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
})
done <- true
}()
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
go func() {
time.Sleep(1 * time.Millisecond) // Slight delay to batch operations
store.DeleteNode(ephemeralNode.ID)
done <- true
}()
// Wait for both operations
<-done
<-done
// Give batching time to complete
time.Sleep(50 * time.Millisecond)
// Node should be deleted from store
_, found := store.GetNode(ephemeralNode.ID)
assert.False(t, found, "ephemeral node should be deleted from NodeStore")
// Critical assertion: if UpdateNode returned before DeleteNode completed,
// the updatedNode might be valid but the node is actually deleted.
// This is the bug - UpdateNodeFromMapRequest would get a valid node,
// then try to persist it, re-inserting the deleted ephemeral node.
if updateOk && updatedNode.Valid() {
t.Log("UpdateNode returned valid node, but node is deleted - this is the race condition")
// In the real code, this would cause persistNodeToDB to be called with updatedNode
// The fix in persistNodeToDB checks if the node still exists:
_, stillExists := store.GetNode(updatedNode.ID())
assert.False(t, stillExists, "persistNodeToDB should check NodeStore and find node deleted")
} else if !updateOk || !updatedNode.Valid() {
t.Log("UpdateNode correctly returned invalid/not-ok result (delete happened in same batch)")
}
}
// TestUpdateNodeFromMapRequestEphemeralLogoutSequence tests the exact sequence
// that causes ephemeral node logout failures:
// 1. Client sends MapRequest with updated endpoint info
// 2. UpdateNodeFromMapRequest starts processing, calls UpdateNode
// 3. Client sends logout request (past expiry)
// 4. handleLogout calls DeleteNode for ephemeral node
// 5. UpdateNode and DeleteNode batch together
// 6. UpdateNode returns a valid node (from before delete in batch)
// 7. persistNodeToDB is called with the stale valid node
// 8. Node gets re-inserted into database instead of staying deleted
func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5")
ephemeralNode.AuthKey = &types.PreAuthKey{
ID: 2,
Key: "test-key-2",
Ephemeral: true,
}
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
// Initial state: ephemeral node exists
_ = store.PutNode(ephemeralNode)
// Step 1: UpdateNodeFromMapRequest calls UpdateNode
// (simulating client sending MapRequest with endpoint updates)
updateStarted := make(chan bool)
var updatedNode types.NodeView
var updateOk bool
go func() {
updateStarted <- true
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
endpoint := netip.MustParseAddrPort("10.0.0.1:41641")
n.Endpoints = []netip.AddrPort{endpoint}
})
}()
<-updateStarted
// Small delay to ensure UpdateNode is queued
time.Sleep(5 * time.Millisecond)
// Step 2: Logout happens - handleLogout calls DeleteNode
// (simulating client sending logout with past expiry)
store.DeleteNode(ephemeralNode.ID)
// Wait for batching to complete
time.Sleep(50 * time.Millisecond)
// Step 3: Check results
_, nodeExists := store.GetNode(ephemeralNode.ID)
assert.False(t, nodeExists, "ephemeral node must be deleted after logout")
// Step 4: Simulate what happens if we try to persist the updatedNode
if updateOk && updatedNode.Valid() {
// This is the problematic path - UpdateNode returned a valid node
// but the node was deleted in the same batch
t.Log("UpdateNode returned valid node even though node was deleted")
// The fix: persistNodeToDB must check NodeStore before persisting
_, checkExists := store.GetNode(updatedNode.ID())
if checkExists {
t.Error("BUG: Node still exists in NodeStore after deletion - should be impossible")
} else {
t.Log("SUCCESS: persistNodeToDB would detect node is deleted and refuse to persist")
}
} else {
t.Log("UpdateNode correctly indicated node was deleted (returned invalid or not-ok)")
}
// Final assertion: node must not exist
_, finalExists := store.GetNode(ephemeralNode.ID)
assert.False(t, finalExists, "ephemeral node must remain deleted")
}
// TestUpdateNodeDeletedInSameBatchReturnsInvalid specifically tests that when
// UpdateNode and DeleteNode are batched together with DELETE after UPDATE,
// UpdateNode returns ok=false to indicate the node was deleted.
func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
node := createTestNode(6, 1, "test-user", "test-node-6")
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
// Put node in store
_ = store.PutNode(node)
// Queue UpdateNode
updateDone := make(chan struct {
node types.NodeView
ok bool
})
go func() {
updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
})
updateDone <- struct {
node types.NodeView
ok bool
}{updatedNode, ok}
}()
// Small delay to ensure UpdateNode is queued
time.Sleep(5 * time.Millisecond)
// Queue DeleteNode - should batch with UpdateNode
store.DeleteNode(node.ID)
// Get UpdateNode result
result := <-updateDone
// Wait for batch to complete
time.Sleep(50 * time.Millisecond)
// Node should be deleted
_, exists := store.GetNode(node.ID)
assert.False(t, exists, "node should be deleted from store")
// UpdateNode should indicate the node was deleted
// After c6b09289988f34398eb3157e31ba092eb8721a9f, when UPDATE and DELETE
// are in the same batch with DELETE after UPDATE, UpdateNode returns
// the state after the batch is applied - which means the node doesn't exist
assert.False(t, result.ok, "UpdateNode should return ok=false when node deleted in same batch")
assert.False(t, result.node.Valid(), "UpdateNode should return invalid node when node deleted in same batch")
}
// TestPersistNodeToDBChecksNodeStoreBeforePersist verifies that persistNodeToDB
// checks if the node still exists in NodeStore before persisting to database.
// This prevents the race condition where:
// 1. UpdateNodeFromMapRequest calls UpdateNode and gets a valid node
// 2. Ephemeral node logout calls DeleteNode
// 3. UpdateNode and DeleteNode batch together
// 4. UpdateNode returns a valid node (from before delete in batch)
// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node
// 6. persistNodeToDB must detect the node is deleted and refuse to persist
func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7")
ephemeralNode.AuthKey = &types.PreAuthKey{
ID: 3,
Key: "test-key-3",
Ephemeral: true,
}
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
// Put node in store
_ = store.PutNode(ephemeralNode)
// Simulate the race:
// 1. UpdateNode is called (from UpdateNodeFromMapRequest)
updatedNode, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
})
require.True(t, ok, "UpdateNode should succeed")
require.True(t, updatedNode.Valid(), "UpdateNode should return valid node")
// 2. Node is deleted (from handleLogout for ephemeral node)
store.DeleteNode(ephemeralNode.ID)
// Wait for deletion
time.Sleep(50 * time.Millisecond)
// 3. Verify node is deleted from store
_, exists := store.GetNode(ephemeralNode.ID)
require.False(t, exists, "node should be deleted from NodeStore")
// 4. Simulate what persistNodeToDB does - check if node still exists
// The fix in persistNodeToDB checks NodeStore before persisting:
// if !exists { return error }
// This prevents re-inserting the deleted node into the database
// Verify the node from UpdateNode is valid but node is gone from store
assert.True(t, updatedNode.Valid(), "UpdateNode returned a valid node view")
_, stillExists := store.GetNode(updatedNode.ID())
assert.False(t, stillExists, "but node should be deleted from NodeStore")
// This is the critical test: persistNodeToDB must check NodeStore
// and refuse to persist if the node doesn't exist anymore
// The actual persistNodeToDB implementation does:
// _, exists := s.nodeStore.GetNode(node.ID())
// if !exists { return error }
}


@@ -10,9 +10,9 @@ import (
"tailscale.com/tailcfg"
)
// NetInfoFromMapRequest determines the correct NetInfo to use.
// netInfoFromMapRequest determines the correct NetInfo to use.
// Returns the NetInfo that should be used for this request.
func NetInfoFromMapRequest(
func netInfoFromMapRequest(
nodeID types.NodeID,
currentHostinfo *tailcfg.Hostinfo,
reqHostinfo *tailcfg.Hostinfo,


@@ -61,7 +61,7 @@ func TestNetInfoFromMapRequest(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := NetInfoFromMapRequest(nodeID, tt.currentHostinfo, tt.reqHostinfo)
result := netInfoFromMapRequest(nodeID, tt.currentHostinfo, tt.reqHostinfo)
if tt.expectNetInfo == nil {
assert.Nil(t, result, "expected nil NetInfo")
@@ -100,14 +100,40 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) {
}
// BUG: Using the node being modified (no NetInfo) instead of existing node (has NetInfo)
buggyResult := NetInfoFromMapRequest(nodeID, nodeBeingModifiedHostinfo, newRegistrationHostinfo)
buggyResult := netInfoFromMapRequest(nodeID, nodeBeingModifiedHostinfo, newRegistrationHostinfo)
assert.Nil(t, buggyResult, "Bug: Should return nil when using wrong hostinfo reference")
// CORRECT: Using the existing node's hostinfo (has NetInfo)
correctResult := NetInfoFromMapRequest(nodeID, existingNodeHostinfo, newRegistrationHostinfo)
correctResult := netInfoFromMapRequest(nodeID, existingNodeHostinfo, newRegistrationHostinfo)
assert.NotNil(t, correctResult, "Fix: Should preserve NetInfo when using correct hostinfo reference")
assert.Equal(t, 5, correctResult.PreferredDERP, "Should preserve the DERP region from existing node")
})
t.Run("new_node_creation_for_different_user_should_preserve_netinfo", func(t *testing.T) {
// This test covers the scenario where:
// 1. A node exists for user1 with NetInfo
// 2. The same machine logs in as user2 (different user)
// 3. A NEW node is created for user2 (pre-auth key flow)
// 4. The new node should preserve NetInfo from the old node
// Existing node for user1 with NetInfo
existingNodeUser1Hostinfo := &tailcfg.Hostinfo{
Hostname: "test-node",
NetInfo: &tailcfg.NetInfo{PreferredDERP: 7},
}
// New registration request for user2 (no NetInfo yet)
newNodeUser2Hostinfo := &tailcfg.Hostinfo{
Hostname: "test-node",
OS: "linux",
// NetInfo is nil - registration request doesn't include it
}
// When creating a new node for user2, we should preserve NetInfo from user1's node
result := netInfoFromMapRequest(types.NodeID(2), existingNodeUser1Hostinfo, newNodeUser2Hostinfo)
assert.NotNil(t, result, "New node for user2 should preserve NetInfo from user1's node")
assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node")
})
}
// Simple helper function for tests


@@ -15,7 +15,7 @@ import (
)
const (
batchSize = 10
batchSize = 100
batchTimeout = 500 * time.Millisecond
)
@@ -121,10 +121,11 @@ type Snapshot struct {
nodesByID map[types.NodeID]types.Node
// calculated from nodesByID
nodesByNodeKey map[key.NodePublic]types.NodeView
peersByNode map[types.NodeID][]types.NodeView
nodesByUser map[types.UserID][]types.NodeView
allNodes []types.NodeView
nodesByNodeKey map[key.NodePublic]types.NodeView
nodesByMachineKey map[key.MachinePublic]map[types.UserID]types.NodeView
peersByNode map[types.NodeID][]types.NodeView
nodesByUser map[types.UserID][]types.NodeView
allNodes []types.NodeView
}
// PeersFunc is a function that takes a list of nodes and returns a map
@@ -135,26 +136,29 @@ type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
// work represents a single operation to be performed on the NodeStore.
type work struct {
op int
nodeID types.NodeID
node types.Node
updateFn UpdateNodeFunc
result chan struct{}
op int
nodeID types.NodeID
node types.Node
updateFn UpdateNodeFunc
result chan struct{}
nodeResult chan types.NodeView // Channel to return the resulting node after batch application
}
// PutNode adds or updates a node in the store.
// If the node already exists, it will be replaced.
// If the node does not exist, it will be added.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) PutNode(n types.Node) {
// Returns the resulting node after all modifications in the batch have been applied.
func (s *NodeStore) PutNode(n types.Node) types.NodeView {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
defer timer.ObserveDuration()
work := work{
op: put,
nodeID: n.ID,
node: n,
result: make(chan struct{}),
op: put,
nodeID: n.ID,
node: n,
result: make(chan struct{}),
nodeResult: make(chan types.NodeView, 1),
}
nodeStoreQueueDepth.Inc()
@@ -162,7 +166,10 @@ func (s *NodeStore) PutNode(n types.Node) {
<-work.result
nodeStoreQueueDepth.Dec()
resultNode := <-work.nodeResult
nodeStoreOperations.WithLabelValues("put").Inc()
return resultNode
}
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
@@ -173,6 +180,7 @@ type UpdateNodeFunc func(n *types.Node)
// This is analogous to a database "transaction", or, the caller should
// rather collect all data they want to change, and then call this function.
// Fewer calls are better.
// Returns the resulting node after all modifications in the batch have been applied.
//
// TODO(kradalby): Technically we could have a version of this that modifies the node
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
@@ -181,15 +189,16 @@ type UpdateNodeFunc func(n *types.Node)
// a lock around the nodesByID map to ensure that no other writes are happening
// while we are modifying the node. Which mean we would need to implement read-write locks
// on all read operations.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) (types.NodeView, bool) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
defer timer.ObserveDuration()
work := work{
op: update,
nodeID: nodeID,
updateFn: updateFn,
result: make(chan struct{}),
op: update,
nodeID: nodeID,
updateFn: updateFn,
result: make(chan struct{}),
nodeResult: make(chan types.NodeView, 1),
}
nodeStoreQueueDepth.Inc()
@@ -197,7 +206,11 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
<-work.result
nodeStoreQueueDepth.Dec()
resultNode := <-work.nodeResult
nodeStoreOperations.WithLabelValues("update").Inc()
// Return the node and whether it exists (is valid)
return resultNode, resultNode.Valid()
}
// DeleteNode removes a node from the store by its ID.
@@ -282,18 +295,32 @@ func (s *NodeStore) applyBatch(batch []work) {
nodes := make(map[types.NodeID]types.Node)
maps.Copy(nodes, s.data.Load().nodesByID)
for _, w := range batch {
// Track which work items need node results
nodeResultRequests := make(map[types.NodeID][]*work)
for i := range batch {
w := &batch[i]
switch w.op {
case put:
nodes[w.nodeID] = w.node
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
case update:
// Update the specific node identified by nodeID
if n, exists := nodes[w.nodeID]; exists {
w.updateFn(&n)
nodes[w.nodeID] = n
}
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
case del:
delete(nodes, w.nodeID)
// For delete operations, send an invalid NodeView if requested
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
}
}
@@ -303,6 +330,24 @@ func (s *NodeStore) applyBatch(batch []work) {
// Update node count gauge
nodeStoreNodesCount.Set(float64(len(nodes)))
// Send the resulting nodes to all work items that requested them
for nodeID, workItems := range nodeResultRequests {
if node, exists := nodes[nodeID]; exists {
nodeView := node.View()
for _, w := range workItems {
w.nodeResult <- nodeView
close(w.nodeResult)
}
} else {
// Node was deleted or doesn't exist
for _, w := range workItems {
w.nodeResult <- types.NodeView{} // Send invalid view
close(w.nodeResult)
}
}
}
// Signal completion for all work items
for _, w := range batch {
close(w.result)
}
@@ -323,9 +368,10 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
}
newSnap := Snapshot{
nodesByID: nodes,
allNodes: allNodes,
nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
nodesByID: nodes,
allNodes: allNodes,
nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
nodesByMachineKey: make(map[key.MachinePublic]map[types.UserID]types.NodeView),
// peersByNode is most likely the most expensive operation,
// it will use the list of all nodes, combined with the
@@ -339,11 +385,19 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
nodesByUser: make(map[types.UserID][]types.NodeView),
}
// Build nodesByUser and nodesByNodeKey maps
// Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps
for _, n := range nodes {
nodeView := n.View()
newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
userID := types.UserID(n.UserID)
newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView)
newSnap.nodesByNodeKey[n.NodeKey] = nodeView
// Build machine key index
if newSnap.nodesByMachineKey[n.MachineKey] == nil {
newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView)
}
newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView
}
return newSnap
@@ -382,19 +436,40 @@ func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bo
return nodeView, exists
}
// GetNodeByMachineKey returns a node by its machine key. The bool indicates if the node exists.
func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) {
// GetNodeByMachineKey returns a node by its machine key and user ID. The bool indicates if the node exists.
func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic, userID types.UserID) (types.NodeView, bool) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("get_by_machine_key").Inc()
snapshot := s.data.Load()
// We don't have a byMachineKey map, so we need to iterate
// This could be optimized by adding a byMachineKey map if this becomes a hot path
for _, node := range snapshot.nodesByID {
if node.MachineKey == machineKey {
return node.View(), true
if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists {
if node, exists := userMap[userID]; exists {
return node, true
}
}
return types.NodeView{}, false
}
// GetNodeByMachineKeyAnyUser returns the first node with the given machine key,
// regardless of which user it belongs to. This is useful for scenarios like
// transferring a node to a different user when re-authenticating with a
// different user's auth key.
// If multiple nodes exist with the same machine key (different users), the
// first one found is returned (order is not guaranteed).
func (s *NodeStore) GetNodeByMachineKeyAnyUser(machineKey key.MachinePublic) (types.NodeView, bool) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key_any_user"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("get_by_machine_key_any_user").Inc()
snapshot := s.data.Load()
if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists {
// Return the first node found (order not guaranteed due to map iteration)
for _, node := range userMap {
return node, true
}
}


@@ -1,7 +1,11 @@
package state
import (
"context"
"fmt"
"net/netip"
"runtime"
"sync"
"testing"
"time"
@@ -249,7 +253,9 @@ func TestNodeStoreOperations(t *testing.T) {
name: "add first node",
action: func(store *NodeStore) {
node := createTestNode(1, 1, "user1", "node1")
store.PutNode(node)
resultNode := store.PutNode(node)
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
assert.Equal(t, node.ID, resultNode.ID())
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 1)
@@ -288,7 +294,9 @@ func TestNodeStoreOperations(t *testing.T) {
name: "add second node same user",
action: func(store *NodeStore) {
node2 := createTestNode(2, 1, "user1", "node2")
store.PutNode(node2)
resultNode := store.PutNode(node2)
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
assert.Equal(t, types.NodeID(2), resultNode.ID())
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 2)
@@ -308,7 +316,9 @@ func TestNodeStoreOperations(t *testing.T) {
name: "add third node different user",
action: func(store *NodeStore) {
node3 := createTestNode(3, 2, "user2", "node3")
store.PutNode(node3)
resultNode := store.PutNode(node3)
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
assert.Equal(t, types.NodeID(3), resultNode.ID())
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 3)
@@ -409,10 +419,14 @@ func TestNodeStoreOperations(t *testing.T) {
{
name: "update node hostname",
action: func(store *NodeStore) {
store.UpdateNode(1, func(n *types.Node) {
resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "updated-node1"
n.GivenName = "updated-node1"
})
assert.True(t, ok, "UpdateNode should return true for existing node")
assert.True(t, resultNode.Valid(), "Result node should be valid")
assert.Equal(t, "updated-node1", resultNode.Hostname())
assert.Equal(t, "updated-node1", resultNode.GivenName())
snapshot := store.data.Load()
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
@@ -436,10 +450,14 @@ func TestNodeStoreOperations(t *testing.T) {
name: "add nodes with odd-even filtering",
action: func(store *NodeStore) {
// Add nodes in sequence
store.PutNode(createTestNode(1, 1, "user1", "node1"))
store.PutNode(createTestNode(2, 2, "user2", "node2"))
store.PutNode(createTestNode(3, 3, "user3", "node3"))
store.PutNode(createTestNode(4, 4, "user4", "node4"))
n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
assert.True(t, n1.Valid())
n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
assert.True(t, n2.Valid())
n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
assert.True(t, n3.Valid())
n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
assert.True(t, n4.Valid())
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 4)
@@ -478,6 +496,328 @@ func TestNodeStoreOperations(t *testing.T) {
},
},
},
{
name: "test batch modifications return correct node state",
setupFunc: func(t *testing.T) *NodeStore {
node1 := createTestNode(1, 1, "user1", "node1")
node2 := createTestNode(2, 1, "user1", "node2")
initialNodes := types.Nodes{&node1, &node2}
return NewNodeStore(initialNodes, allowAllPeersFunc)
},
steps: []testStep{
{
name: "verify initial state",
action: func(store *NodeStore) {
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 2)
assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
},
},
{
name: "concurrent updates should reflect all batch changes",
action: func(store *NodeStore) {
// Start multiple updates that will be batched together
done1 := make(chan struct{})
done2 := make(chan struct{})
done3 := make(chan struct{})
var resultNode1, resultNode2 types.NodeView
var newNode3 types.NodeView
var ok1, ok2 bool
// These should all be processed in the same batch
go func() {
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "batch-updated-node1"
n.GivenName = "batch-given-1"
})
close(done1)
}()
go func() {
resultNode2, ok2 = store.UpdateNode(2, func(n *types.Node) {
n.Hostname = "batch-updated-node2"
n.GivenName = "batch-given-2"
})
close(done2)
}()
go func() {
node3 := createTestNode(3, 1, "user1", "node3")
newNode3 = store.PutNode(node3)
close(done3)
}()
// Wait for all operations to complete
<-done1
<-done2
<-done3
// Verify the returned nodes reflect the batch state
assert.True(t, ok1, "UpdateNode should succeed for node 1")
assert.True(t, ok2, "UpdateNode should succeed for node 2")
assert.True(t, resultNode1.Valid())
assert.True(t, resultNode2.Valid())
assert.True(t, newNode3.Valid())
// Check that returned nodes have the updated values
assert.Equal(t, "batch-updated-node1", resultNode1.Hostname())
assert.Equal(t, "batch-given-1", resultNode1.GivenName())
assert.Equal(t, "batch-updated-node2", resultNode2.Hostname())
assert.Equal(t, "batch-given-2", resultNode2.GivenName())
assert.Equal(t, "node3", newNode3.Hostname())
// Verify the snapshot also reflects all changes
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 3)
assert.Equal(t, "batch-updated-node1", snapshot.nodesByID[1].Hostname)
assert.Equal(t, "batch-updated-node2", snapshot.nodesByID[2].Hostname)
assert.Equal(t, "node3", snapshot.nodesByID[3].Hostname)
// Verify peer relationships are updated correctly with new node
assert.Len(t, snapshot.peersByNode[1], 2) // sees nodes 2 and 3
assert.Len(t, snapshot.peersByNode[2], 2) // sees nodes 1 and 3
assert.Len(t, snapshot.peersByNode[3], 2) // sees nodes 1 and 2
},
},
{
name: "update non-existent node returns invalid view",
action: func(store *NodeStore) {
resultNode, ok := store.UpdateNode(999, func(n *types.Node) {
n.Hostname = "should-not-exist"
})
assert.False(t, ok, "UpdateNode should return false for non-existent node")
assert.False(t, resultNode.Valid(), "Result should be invalid NodeView")
},
},
{
name: "multiple updates to same node in batch all see final state",
action: func(store *NodeStore) {
// This test verifies that when multiple updates to the same node
// are batched together, each returned node reflects ALL changes
// in the batch, not just the individual update's changes.
done1 := make(chan struct{})
done2 := make(chan struct{})
done3 := make(chan struct{})
var resultNode1, resultNode2, resultNode3 types.NodeView
var ok1, ok2, ok3 bool
// These updates all modify node 1 and should be batched together
// The final state should have all three modifications applied
go func() {
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "multi-update-hostname"
})
close(done1)
}()
go func() {
resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
n.GivenName = "multi-update-givenname"
})
close(done2)
}()
go func() {
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.ForcedTags = []string{"tag1", "tag2"}
})
close(done3)
}()
// Wait for all operations to complete
<-done1
<-done2
<-done3
// All updates should succeed
assert.True(t, ok1, "First update should succeed")
assert.True(t, ok2, "Second update should succeed")
assert.True(t, ok3, "Third update should succeed")
// CRITICAL: Each returned node should reflect ALL changes from the batch
// not just the change from its specific update call
// resultNode1 (from hostname update) should also have the givenname and tags changes
assert.Equal(t, "multi-update-hostname", resultNode1.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode1.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice())
// resultNode2 (from givenname update) should also have the hostname and tags changes
assert.Equal(t, "multi-update-hostname", resultNode2.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode2.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice())
// resultNode3 (from tags update) should also have the hostname and givenname changes
assert.Equal(t, "multi-update-hostname", resultNode3.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode3.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice())
// Verify the snapshot also has all changes
snapshot := store.data.Load()
finalNode := snapshot.nodesByID[1]
assert.Equal(t, "multi-update-hostname", finalNode.Hostname)
assert.Equal(t, "multi-update-givenname", finalNode.GivenName)
assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags)
},
},
},
},
{
name: "test UpdateNode result is immutable for database save",
setupFunc: func(t *testing.T) *NodeStore {
node1 := createTestNode(1, 1, "user1", "node1")
node2 := createTestNode(2, 1, "user1", "node2")
initialNodes := types.Nodes{&node1, &node2}
return NewNodeStore(initialNodes, allowAllPeersFunc)
},
steps: []testStep{
{
name: "verify returned node is complete and consistent",
action: func(store *NodeStore) {
// Update a node and verify the returned view is complete
resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "db-save-hostname"
n.GivenName = "db-save-given"
n.ForcedTags = []string{"db-tag1", "db-tag2"}
})
assert.True(t, ok, "UpdateNode should succeed")
assert.True(t, resultNode.Valid(), "Result should be valid")
// Verify the returned node has all expected values
assert.Equal(t, "db-save-hostname", resultNode.Hostname())
assert.Equal(t, "db-save-given", resultNode.GivenName())
assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice())
// Convert to struct as would be done for database save
nodePtr := resultNode.AsStruct()
assert.NotNil(t, nodePtr)
assert.Equal(t, "db-save-hostname", nodePtr.Hostname)
assert.Equal(t, "db-save-given", nodePtr.GivenName)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags)
// Verify the snapshot also reflects the same state
snapshot := store.data.Load()
storedNode := snapshot.nodesByID[1]
assert.Equal(t, "db-save-hostname", storedNode.Hostname)
assert.Equal(t, "db-save-given", storedNode.GivenName)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags)
},
},
{
name: "concurrent updates all return consistent final state for DB save",
action: func(store *NodeStore) {
// Multiple goroutines updating the same node
// All should receive the final batch state suitable for DB save
done1 := make(chan struct{})
done2 := make(chan struct{})
done3 := make(chan struct{})
var result1, result2, result3 types.NodeView
var ok1, ok2, ok3 bool
// Start concurrent updates
go func() {
result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "concurrent-db-hostname"
})
close(done1)
}()
go func() {
result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
n.GivenName = "concurrent-db-given"
})
close(done2)
}()
go func() {
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.ForcedTags = []string{"concurrent-tag"}
})
close(done3)
}()
// Wait for all to complete
<-done1
<-done2
<-done3
assert.True(t, ok1 && ok2 && ok3, "All updates should succeed")
// All results should be valid and suitable for database save
assert.True(t, result1.Valid())
assert.True(t, result2.Valid())
assert.True(t, result3.Valid())
// Convert each to struct as would be done for DB save
nodePtr1 := result1.AsStruct()
nodePtr2 := result2.AsStruct()
nodePtr3 := result3.AsStruct()
// All should have the complete final state
assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags)
assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags)
assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags)
// Verify consistency with stored state
snapshot := store.data.Load()
storedNode := snapshot.nodesByID[1]
assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname)
assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName)
assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags)
},
},
{
name: "verify returned node preserves all fields for DB save",
action: func(store *NodeStore) {
// Get initial state
snapshot := store.data.Load()
originalNode := snapshot.nodesByID[2]
originalIPv4 := originalNode.IPv4
originalIPv6 := originalNode.IPv6
originalCreatedAt := originalNode.CreatedAt
originalUser := originalNode.User
// Update only hostname
resultNode, ok := store.UpdateNode(2, func(n *types.Node) {
n.Hostname = "preserve-test-hostname"
})
assert.True(t, ok, "Update should succeed")
// Convert to struct for DB save
nodeForDB := resultNode.AsStruct()
// Verify all fields are preserved
assert.Equal(t, "preserve-test-hostname", nodeForDB.Hostname)
assert.Equal(t, originalIPv4, nodeForDB.IPv4)
assert.Equal(t, originalIPv6, nodeForDB.IPv6)
assert.Equal(t, originalCreatedAt, nodeForDB.CreatedAt)
assert.Equal(t, originalUser.Name, nodeForDB.User.Name)
assert.Equal(t, types.NodeID(2), nodeForDB.ID)
// These fields should be suitable for direct database save
assert.NotNil(t, nodeForDB.IPv4)
assert.NotNil(t, nodeForDB.IPv6)
assert.False(t, nodeForDB.CreatedAt.IsZero())
},
},
},
},
}
for _, tt := range tests {
@@ -499,3 +839,302 @@ type testStep struct {
name string
action func(store *NodeStore)
}
// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---
// Helper for concurrent test nodes
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
machineKey := key.NewMachine()
nodeKey := key.NewNode()
return types.Node{
ID: id,
Hostname: hostname,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
UserID: 1,
User: types.User{
Name: "concurrent-test-user",
},
}
}
// --- Concurrency: concurrent PutNode operations ---
func TestNodeStoreConcurrentPutNode(t *testing.T) {
const concurrentOps = 20
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
var wg sync.WaitGroup
results := make(chan bool, concurrentOps)
for i := 0; i < concurrentOps; i++ {
wg.Add(1)
go func(nodeID int) {
defer wg.Done()
node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node")
resultNode := store.PutNode(node)
results <- resultNode.Valid()
}(i + 1)
}
wg.Wait()
close(results)
successCount := 0
for success := range results {
if success {
successCount++
}
}
require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed")
}
// --- Batching: more concurrent ops than one batch size all succeed ---
func TestNodeStoreBatchingEfficiency(t *testing.T) {
const batchSize = 10
const ops = 15 // more than batchSize
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
var wg sync.WaitGroup
results := make(chan bool, ops)
for i := 0; i < ops; i++ {
wg.Add(1)
go func(nodeID int) {
defer wg.Done()
node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node")
resultNode := store.PutNode(node)
results <- resultNode.Valid()
}(i + 1)
}
wg.Wait()
close(results)
successCount := 0
for success := range results {
if success {
successCount++
}
}
require.Equal(t, ops, successCount, "All batch PutNode operations should succeed")
}
// --- Race conditions: many goroutines on same node ---
func TestNodeStoreRaceConditions(t *testing.T) {
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
nodeID := types.NodeID(1)
node := createConcurrentTestNode(nodeID, "race-node")
resultNode := store.PutNode(node)
require.True(t, resultNode.Valid())
const numGoroutines = 30
const opsPerGoroutine = 10
var wg sync.WaitGroup
errors := make(chan error, numGoroutines*opsPerGoroutine)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(gid int) {
defer wg.Done()
for j := 0; j < opsPerGoroutine; j++ {
switch j % 3 {
case 0:
resultNode, _ := store.UpdateNode(nodeID, func(n *types.Node) {
n.Hostname = "race-updated"
})
if !resultNode.Valid() {
errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j)
}
case 1:
retrieved, found := store.GetNode(nodeID)
if !found || !retrieved.Valid() {
errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j)
}
case 2:
newNode := createConcurrentTestNode(nodeID, "race-put")
resultNode := store.PutNode(newNode)
if !resultNode.Valid() {
errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j)
}
}
}
}(i)
}
wg.Wait()
close(errors)
errorCount := 0
for err := range errors {
t.Error(err)
errorCount++
}
if errorCount > 0 {
t.Fatalf("Race condition test failed with %d errors", errorCount)
}
}
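// NOTE: this test is most valuable under the race detector; a typical invocation
// looks like the following (the package selector is an assumption, adjust it to
// wherever NodeStore lives in the tree):
//
//	go test -race -run TestNodeStoreRaceConditions ./hscontrol/state/...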
// --- Resource cleanup: goroutine leak detection ---
func TestNodeStoreResourceCleanup(t *testing.T) {
// initialGoroutines := runtime.NumGoroutine()
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
time.Sleep(50 * time.Millisecond)
afterStartGoroutines := runtime.NumGoroutine()
const ops = 100
for i := 0; i < ops; i++ {
nodeID := types.NodeID(i + 1)
node := createConcurrentTestNode(nodeID, "cleanup-node")
resultNode := store.PutNode(node)
assert.True(t, resultNode.Valid())
store.UpdateNode(nodeID, func(n *types.Node) {
n.Hostname = "cleanup-updated"
})
retrieved, found := store.GetNode(nodeID)
assert.True(t, found && retrieved.Valid())
if i%10 == 9 {
store.DeleteNode(nodeID)
}
}
runtime.GC()
time.Sleep(100 * time.Millisecond)
finalGoroutines := runtime.NumGoroutine()
if finalGoroutines > afterStartGoroutines+2 {
t.Errorf("Potential goroutine leak: started with %d, ended with %d", afterStartGoroutines, finalGoroutines)
}
}
// --- Timeout/deadlock: operations complete within reasonable time ---
func TestNodeStoreOperationTimeout(t *testing.T) {
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
const ops = 30
var wg sync.WaitGroup
putResults := make([]error, ops)
updateResults := make([]error, ops)
// Launch all PutNode operations concurrently
for i := 1; i <= ops; i++ {
nodeID := types.NodeID(i)
wg.Add(1)
go func(idx int, id types.NodeID) {
defer wg.Done()
startPut := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id)
node := createConcurrentTestNode(id, "timeout-node")
resultNode := store.PutNode(node)
endPut := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut))
if !resultNode.Valid() {
putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id)
}
}(i, nodeID)
}
wg.Wait()
// Launch all UpdateNode operations concurrently
wg = sync.WaitGroup{}
for i := 1; i <= ops; i++ {
nodeID := types.NodeID(i)
wg.Add(1)
go func(idx int, id types.NodeID) {
defer wg.Done()
startUpdate := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id)
resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
n.Hostname = "timeout-updated"
})
endUpdate := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate))
if !ok || !resultNode.Valid() {
updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id)
}
}(i, nodeID)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
errorCount := 0
for _, err := range putResults {
if err != nil {
t.Error(err)
errorCount++
}
}
for _, err := range updateResults {
if err != nil {
t.Error(err)
errorCount++
}
}
if errorCount == 0 {
t.Log("All concurrent operations completed successfully within timeout")
} else {
t.Fatalf("Some concurrent operations failed: %d errors", errorCount)
}
case <-ctx.Done():
fmt.Println("[TestNodeStoreOperationTimeout] Timeout reached, test failed")
t.Fatal("Operations timed out - potential deadlock or resource issue")
}
}
// --- Edge case: update non-existent node ---
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
for i := 0; i < 10; i++ {
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
nonExistentID := types.NodeID(999 + i)
updateCallCount := 0
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID)
resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
updateCallCount++
n.Hostname = "should-never-be-called"
})
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) finished, valid=%v, ok=%v, updateCallCount=%d\n", nonExistentID, resultNode.Valid(), ok, updateCallCount)
assert.False(t, ok, "UpdateNode should return false for non-existent node")
assert.False(t, resultNode.Valid(), "UpdateNode should return invalid node for non-existent node")
assert.Equal(t, 0, updateCallCount, "UpdateFn should not be called for non-existent node")
store.Stop()
}
}
// --- Allocation benchmark ---
func BenchmarkNodeStoreAllocations(b *testing.B) {
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()
defer store.Stop()
b.ResetTimer()
for i := 0; i < b.N; i++ {
nodeID := types.NodeID(i + 1)
node := createConcurrentTestNode(nodeID, "bench-node")
store.PutNode(node)
store.UpdateNode(nodeID, func(n *types.Node) {
n.Hostname = "bench-updated"
})
store.GetNode(nodeID)
if i%10 == 9 {
store.DeleteNode(nodeID)
}
}
}
func TestNodeStoreAllocationStats(t *testing.T) {
res := testing.Benchmark(BenchmarkNodeStoreAllocations)
allocs := res.AllocsPerOp()
t.Logf("NodeStore allocations per op: %.2f", float64(allocs))
}
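// The allocation benchmark can also be run directly with the standard Go tooling,
// e.g. (the package path below is an assumption, adjust as needed):
//
//	go test -run '^$' -bench BenchmarkNodeStoreAllocations -benchmem ./hscontrol/state/...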

File diff suppressed because it is too large

View File

@@ -638,6 +638,11 @@ func (node Node) DebugString() string {
return sb.String()
}
func (v NodeView) UserView() UserView {
u := v.User()
return u.View()
}
func (v NodeView) IPs() []netip.Addr {
if !v.Valid() {
return nil

View File

@@ -104,27 +104,31 @@ func (u *User) profilePicURL() string {
return u.ProfilePicURL
}
func (u *User) TailscaleUser() *tailcfg.User {
user := tailcfg.User{
func (u *User) TailscaleUser() tailcfg.User {
return tailcfg.User{
ID: tailcfg.UserID(u.ID),
DisplayName: u.Display(),
ProfilePicURL: u.profilePicURL(),
Created: u.CreatedAt,
}
return &user
}
func (u *User) TailscaleLogin() *tailcfg.Login {
login := tailcfg.Login{
func (u UserView) TailscaleUser() tailcfg.User {
return u.ж.TailscaleUser()
}
func (u *User) TailscaleLogin() tailcfg.Login {
return tailcfg.Login{
ID: tailcfg.LoginID(u.ID),
Provider: u.Provider,
LoginName: u.Username(),
DisplayName: u.Display(),
ProfilePicURL: u.profilePicURL(),
}
}
return &login
func (u UserView) TailscaleLogin() tailcfg.Login {
return u.ж.TailscaleLogin()
}
func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
@@ -136,6 +140,10 @@ func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
}
}
func (u UserView) TailscaleUserProfile() tailcfg.UserProfile {
return u.ж.TailscaleUserProfile()
}
func (u *User) Proto() *v1.User {
return &v1.User{
Id: uint64(u.ID),

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"tailscale.com/tailcfg"
"tailscale.com/util/cmpver"
)
@@ -258,3 +259,59 @@ func IsCI() bool {
return false
}
// SafeHostname extracts a hostname from Hostinfo, providing sensible defaults
// if Hostinfo is nil or Hostname is empty. This prevents nil pointer dereferences
// and ensures nodes always have a valid hostname.
// The hostname is truncated to 63 characters to comply with DNS label length limits (RFC 1123).
func SafeHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string {
if hostinfo == nil || hostinfo.Hostname == "" {
// Generate a default hostname using machine key prefix
if machineKey != "" {
keyPrefix := machineKey
if len(machineKey) > 8 {
keyPrefix = machineKey[:8]
}
return fmt.Sprintf("node-%s", keyPrefix)
}
if nodeKey != "" {
keyPrefix := nodeKey
if len(nodeKey) > 8 {
keyPrefix = nodeKey[:8]
}
return fmt.Sprintf("node-%s", keyPrefix)
}
return "unknown-node"
}
hostname := hostinfo.Hostname
// Validate hostname length - DNS label limit is 63 characters (RFC 1123)
// Truncate if necessary to ensure compatibility with given name generation
if len(hostname) > 63 {
hostname = hostname[:63]
}
return hostname
}
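// A minimal usage sketch of the fallback behaviour (the key values below are
// illustrative only, mirroring the unit tests rather than a real registration):
//
//	hi := &tailcfg.Hostinfo{Hostname: ""}
//	name := SafeHostname(hi, "mkey12345678", "nkey12345678")
//	// name == "node-mkey1234": the first 8 characters of the machine key are used.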
// EnsureValidHostinfo ensures that Hostinfo is non-nil and has a valid hostname.
// If Hostinfo is nil, it creates a minimal valid Hostinfo with a generated hostname.
// Returns the validated/created Hostinfo and the extracted hostname.
func EnsureValidHostinfo(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) (*tailcfg.Hostinfo, string) {
if hostinfo == nil {
hostname := SafeHostname(nil, machineKey, nodeKey)
return &tailcfg.Hostinfo{
Hostname: hostname,
}, hostname
}
hostname := SafeHostname(hostinfo, machineKey, nodeKey)
// Sync the hostname back into the hostinfo if it was empty or truncated,
// i.e. whenever it differs from the validated value
if hostinfo.Hostname != hostname {
hostinfo.Hostname = hostname
}
return hostinfo, hostname
}
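// Sketch of how a registration path might guard incoming host information
// (the request and variable names here are assumptions, not the actual call site):
//
//	hostinfo, hostname := EnsureValidHostinfo(req.Hostinfo, machineKey.String(), req.NodeKey.String())
//	node.Hostinfo = hostinfo
//	node.Hostname = hostname // guaranteed non-empty and at most 63 characters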

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"tailscale.com/tailcfg"
)
func TestTailscaleVersionNewerOrEqual(t *testing.T) {
@@ -793,3 +794,395 @@ over a maximum of 30 hops:
})
}
}
func TestSafeHostname(t *testing.T) {
t.Parallel()
tests := []struct {
name string
hostinfo *tailcfg.Hostinfo
machineKey string
nodeKey string
want string
}{
{
name: "valid_hostname",
hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "test-node",
},
{
name: "nil_hostinfo_with_machine_key",
hostinfo: nil,
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "node-mkey1234",
},
{
name: "nil_hostinfo_with_node_key_only",
hostinfo: nil,
machineKey: "",
nodeKey: "nkey12345678",
want: "node-nkey1234",
},
{
name: "nil_hostinfo_no_keys",
hostinfo: nil,
machineKey: "",
nodeKey: "",
want: "unknown-node",
},
{
name: "empty_hostname_with_machine_key",
hostinfo: &tailcfg.Hostinfo{
Hostname: "",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "node-mkey1234",
},
{
name: "empty_hostname_with_node_key_only",
hostinfo: &tailcfg.Hostinfo{
Hostname: "",
},
machineKey: "",
nodeKey: "nkey12345678",
want: "node-nkey1234",
},
{
name: "empty_hostname_no_keys",
hostinfo: &tailcfg.Hostinfo{
Hostname: "",
},
machineKey: "",
nodeKey: "",
want: "unknown-node",
},
{
name: "hostname_exactly_63_chars",
hostinfo: &tailcfg.Hostinfo{
Hostname: "123456789012345678901234567890123456789012345678901234567890123",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "123456789012345678901234567890123456789012345678901234567890123",
},
{
name: "hostname_64_chars_truncated",
hostinfo: &tailcfg.Hostinfo{
Hostname: "1234567890123456789012345678901234567890123456789012345678901234",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "123456789012345678901234567890123456789012345678901234567890123",
},
{
name: "hostname_very_long_truncated",
hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters-and-should-be-truncated",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits",
},
{
name: "hostname_with_special_chars",
hostinfo: &tailcfg.Hostinfo{
Hostname: "node-with-special!@#$%",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "node-with-special!@#$%",
},
{
name: "hostname_with_unicode",
hostinfo: &tailcfg.Hostinfo{
Hostname: "node-ñoño-测试",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "node-ñoño-测试",
},
{
name: "short_machine_key",
hostinfo: &tailcfg.Hostinfo{
Hostname: "",
},
machineKey: "short",
nodeKey: "nkey12345678",
want: "node-short",
},
{
name: "short_node_key",
hostinfo: &tailcfg.Hostinfo{
Hostname: "",
},
machineKey: "",
nodeKey: "short",
want: "node-short",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := SafeHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
if got != tt.want {
t.Errorf("SafeHostname() = %v, want %v", got, tt.want)
}
})
}
}
func TestEnsureValidHostinfo(t *testing.T) {
t.Parallel()
tests := []struct {
name string
hostinfo *tailcfg.Hostinfo
machineKey string
nodeKey string
wantHostname string
checkHostinfo func(*testing.T, *tailcfg.Hostinfo)
}{
{
name: "valid_hostinfo_unchanged",
hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node",
OS: "linux",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "test-node",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "test-node" {
t.Errorf("hostname = %v, want test-node", hi.Hostname)
}
if hi.OS != "linux" {
t.Errorf("OS = %v, want linux", hi.OS)
}
},
},
{
name: "nil_hostinfo_creates_default",
hostinfo: nil,
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "node-mkey1234",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "node-mkey1234" {
t.Errorf("hostname = %v, want node-mkey1234", hi.Hostname)
}
},
},
{
name: "empty_hostname_updated",
hostinfo: &tailcfg.Hostinfo{
Hostname: "",
OS: "darwin",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "node-mkey1234",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "node-mkey1234" {
t.Errorf("hostname = %v, want node-mkey1234", hi.Hostname)
}
if hi.OS != "darwin" {
t.Errorf("OS = %v, want darwin", hi.OS)
}
},
},
{
name: "long_hostname_truncated",
hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "test-node-with-very-long-hostname-that-exceeds-dns-label-limits" {
t.Errorf("hostname = %v, want truncated", hi.Hostname)
}
if len(hi.Hostname) != 63 {
t.Errorf("hostname length = %v, want 63", len(hi.Hostname))
}
},
},
{
name: "nil_hostinfo_node_key_only",
hostinfo: nil,
machineKey: "",
nodeKey: "nkey12345678",
wantHostname: "node-nkey1234",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "node-nkey1234" {
t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname)
}
},
},
{
name: "nil_hostinfo_no_keys",
hostinfo: nil,
machineKey: "",
nodeKey: "",
wantHostname: "unknown-node",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "unknown-node" {
t.Errorf("hostname = %v, want unknown-node", hi.Hostname)
}
},
},
{
name: "empty_hostname_no_keys",
hostinfo: &tailcfg.Hostinfo{
Hostname: "",
},
machineKey: "",
nodeKey: "",
wantHostname: "unknown-node",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "unknown-node" {
t.Errorf("hostname = %v, want unknown-node", hi.Hostname)
}
},
},
{
name: "preserves_other_fields",
hostinfo: &tailcfg.Hostinfo{
Hostname: "test",
OS: "windows",
OSVersion: "10.0.19044",
DeviceModel: "test-device",
BackendLogID: "log123",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "test",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "test" {
t.Errorf("hostname = %v, want test", hi.Hostname)
}
if hi.OS != "windows" {
t.Errorf("OS = %v, want windows", hi.OS)
}
if hi.OSVersion != "10.0.19044" {
t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion)
}
if hi.DeviceModel != "test-device" {
t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel)
}
if hi.BackendLogID != "log123" {
t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID)
}
},
},
{
name: "exactly_63_chars_unchanged",
hostinfo: &tailcfg.Hostinfo{
Hostname: "123456789012345678901234567890123456789012345678901234567890123",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "123456789012345678901234567890123456789012345678901234567890123",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if len(hi.Hostname) != 63 {
t.Errorf("hostname length = %v, want 63", len(hi.Hostname))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotHostinfo, gotHostname := EnsureValidHostinfo(tt.hostinfo, tt.machineKey, tt.nodeKey)
if gotHostname != tt.wantHostname {
t.Errorf("EnsureValidHostinfo() hostname = %v, want %v", gotHostname, tt.wantHostname)
}
if gotHostinfo == nil {
t.Error("returned hostinfo should never be nil")
}
if tt.checkHostinfo != nil {
tt.checkHostinfo(t, gotHostinfo)
}
})
}
}
func TestSafeHostname_DNSLabelLimit(t *testing.T) {
t.Parallel()
testCases := []string{
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
"dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd",
}
for i, hostname := range testCases {
t.Run(hostname[:8], func(t *testing.T) {
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
result := SafeHostname(hostinfo, "mkey", "nkey")
if len(result) > 63 {
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
}
})
}
}
func TestEnsureValidHostinfo_Idempotent(t *testing.T) {
t.Parallel()
originalHostinfo := &tailcfg.Hostinfo{
Hostname: "test-node",
OS: "linux",
}
hostinfo1, hostname1 := EnsureValidHostinfo(originalHostinfo, "mkey", "nkey")
hostinfo2, hostname2 := EnsureValidHostinfo(hostinfo1, "mkey", "nkey")
if hostname1 != hostname2 {
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)
}
if hostinfo1.Hostname != hostinfo2.Hostname {
t.Errorf("hostinfo hostnames not equal: %v != %v", hostinfo1.Hostname, hostinfo2.Hostname)
}
if hostinfo1.OS != hostinfo2.OS {
t.Errorf("hostinfo OS not equal: %v != %v", hostinfo1.OS, hostinfo2.OS)
}
}

View File

@@ -28,7 +28,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
opts := []hsic.Option{
@@ -43,31 +43,25 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second)
expectedNodes := collectExpectedNodeIDs(t, allClients)
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 120*time.Second)
// Validate that all nodes have NetInfo and DERP servers before logout
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 1*time.Minute)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 3*time.Minute)
// assertClientsState(t, allClients)
@@ -97,19 +91,20 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err)
requireNoErrLogout(t, err)
// After taking down all nodes, verify all systems show nodes offline
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second)
t.Logf("all clients logged out")
t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count")
}, 20*time.Second, 1*time.Second)
assert.NoError(ct, err, "Failed to list nodes after logout")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
}, 30*time.Second, 2*time.Second, "validating node persistence after logout (nodes should remain in database)")
for _, node := range listNodes {
assertLastSeenSet(t, node)
@@ -125,7 +120,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
require.NoError(t, err)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
@@ -139,12 +134,13 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
}
t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection")
}, 30*time.Second, 2*time.Second)
assert.NoError(ct, err, "Failed to list nodes after relogin")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
}, 60*time.Second, 2*time.Second, "validating node count stability after same-user auth key relogin")
for _, node := range listNodes {
assertLastSeenSet(t, node)
@@ -152,11 +148,15 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 120*time.Second)
// Wait for Tailscale sync before validating NetInfo to ensure proper state propagation
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
// Validate that all nodes have NetInfo and DERP servers after reconnection
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 1*time.Minute)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 3*time.Minute)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -197,69 +197,10 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
}
// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database
// and a valid DERP server based on the NetInfo. This function follows the pattern of
// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state.
func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) {
t.Helper()
startTime := time.Now()
t.Logf("requireAllClientsNetInfoAndDERP: Starting validation at %s - %s", startTime.Format(TimestampFormat), message)
require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
// Validate node counts first
expectedCount := len(expectedNodes)
assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch")
// Check each expected node
for _, nodeID := range expectedNodes {
node, exists := nodeStore[nodeID]
assert.True(c, exists, "Node %d not found in nodestore", nodeID)
if !exists {
continue
}
// Validate that the node has Hostinfo
assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo", nodeID, node.Hostname)
if node.Hostinfo == nil {
continue
}
// Validate that the node has NetInfo
assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo", nodeID, node.Hostname)
if node.Hostinfo.NetInfo == nil {
continue
}
// Validate that the node has a valid DERP server (PreferredDERP should be > 0)
preferredDERP := node.Hostinfo.NetInfo.PreferredDERP
assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0), got %d", nodeID, node.Hostname, preferredDERP)
t.Logf("Node %d (%s) has valid NetInfo with DERP server %d", nodeID, node.Hostname, preferredDERP)
}
}, timeout, 2*time.Second, message)
endTime := time.Now()
duration := endTime.Sub(startTime)
t.Logf("requireAllClientsNetInfoAndDERP: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message)
}
func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node)
assert.NotNil(t, node.GetLastSeen())
}
// This test will first log in two sets of nodes to two sets of users, then
// it will log out all users from user2 and log them in as user1.
// This should leave us with all nodes connected to user1, while user2
// still has nodes, but they are not connected.
// it will log out all nodes and log them in as user1 using a pre-auth key.
// This should create new nodes for user1 while preserving the original nodes for user2.
// Pre-auth key re-authentication with a different user creates new nodes, not transfers.
func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
IntegrationSkip(t)
@@ -269,7 +210,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{},
@@ -277,18 +218,25 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
hsic.WithTLS(),
hsic.WithDERPAsIP(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
// Collect expected node IDs for validation
expectedNodes := collectExpectedNodeIDs(t, allClients)
// Validate initial connection state
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
@@ -303,12 +251,15 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
}
err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err)
requireNoErrLogout(t, err)
// Validate that all nodes are offline after logout
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second)
t.Logf("all clients logged out")
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
require.NoError(t, err)
// Create a new authkey for user1, to be used for all clients
key, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), true, false)
@@ -326,28 +277,43 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
}
var user1Nodes []*v1.Node
t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
user1Nodes, err = headscale.ListNodes("user1")
assert.NoError(ct, err)
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all clients after re-login")
}, 20*time.Second, 1*time.Second)
assert.NoError(ct, err, "Failed to list nodes for user1 after relogin")
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes))
}, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after auth key relogin")
// Validate that all the old nodes are still present with user2
// Collect expected node IDs for user1 after relogin
expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes))
for _, node := range user1Nodes {
expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId()))
}
// Validate connection state after relogin as user1
requireAllClientsOnline(t, headscale, expectedUser1Nodes, true, "all user1 nodes should be connected after relogin", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin", 3*time.Minute)
// Validate that user2 still has their original nodes after user1's re-authentication
// When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created
// for the new user. The original nodes remain with the original user.
var user2Nodes []*v1.Node
t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
user2Nodes, err = headscale.ListNodes("user2")
assert.NoError(ct, err)
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should have half the clients")
}, 20*time.Second, 1*time.Second)
assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin")
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes))
}, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)")
t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat))
for _, client := range allClients {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1", client.Hostname())
}, 30*time.Second, 2*time.Second)
assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName)
}, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after auth key user switch", client.Hostname()))
}
}
@@ -362,7 +328,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
opts := []hsic.Option{
@@ -376,13 +342,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
}
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -396,7 +362,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
}
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
// Collect expected node IDs for validation
expectedNodes := collectExpectedNodeIDs(t, allClients)
// Validate initial connection state
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
@@ -411,7 +384,10 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
}
err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err)
requireNoErrLogout(t, err)
// Validate that all nodes are offline after logout
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second)
t.Logf("all clients logged out")
@@ -425,7 +401,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
}
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
require.NoError(t, err)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
@@ -443,7 +419,8 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
"expire",
key.GetKey(),
})
assertNoErr(t, err)
require.NoError(t, err)
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
assert.ErrorContains(t, err, "authkey expired")

View File

@@ -5,17 +5,20 @@ import (
"net/netip"
"net/url"
"sort"
"strconv"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/oauth2-proxy/mockoidc"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOIDCAuthenticationPingAll(t *testing.T) {
@@ -34,7 +37,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -52,16 +55,16 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -73,10 +76,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
require.NoError(t, err)
want := []*v1.User{
{
@@ -142,7 +145,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
@@ -157,18 +160,18 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
hsic.WithTestName("oidcexpirenodes"),
hsic.WithConfigEnv(oidcMap),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
// Record when sync completes to better estimate token expiry timing
syncCompleteTime := time.Now()
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
loginDuration := time.Since(syncCompleteTime)
t.Logf("Login and sync completed in %v", loginDuration)
@@ -349,7 +352,7 @@ func TestOIDC024UserCreation(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
@@ -367,20 +370,20 @@ func TestOIDC024UserCreation(t *testing.T) {
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
// Ensure that the nodes have logged in; this is what
// triggers user creation via OIDC.
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
want := tt.want(scenario.mockOIDC.Issuer())
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
require.NoError(t, err)
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
@@ -406,7 +409,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
@@ -424,17 +427,17 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
// Get all clients and verify they can connect
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -444,6 +447,11 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
// TestOIDCReloginSameNodeNewUser tests the scenario where:
// 1. A Tailscale client logs in with user1 (creates node1 for user1)
// 2. The same client logs out and logs in with user2 (creates node2 for user2)
// 3. The same client logs out and logs in with user1 again (reuses node1, node2 remains)
// This validates that OIDC relogin properly handles node reuse and cleanup.
func TestOIDCReloginSameNodeNewUser(t *testing.T) {
IntegrationSkip(t)
@@ -458,7 +466,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
oidcMockUser("user1", true),
},
})
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
@@ -477,24 +485,25 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithDERPAsIP(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err)
require.NoError(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
require.NoError(t, err)
_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)
require.NoError(t, err)
t.Logf("Validating initial user creation at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 1)
assert.NoError(ct, err, "Failed to list users during initial validation")
assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
@@ -510,44 +519,61 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
})
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
t.Fatalf("unexpected users: %s", diff)
ct.Errorf("User validation failed after first login - unexpected users: %s", diff)
}
}, 30*time.Second, 1*time.Second, "validating users after first login")
}, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login")
listNodes, err := headscale.ListNodes()
assertNoErr(t, err)
assert.Len(t, listNodes, 1)
t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat))
var listNodes []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes during initial validation")
assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes))
}, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login")
// Collect expected node IDs for validation after user1 initial login
expectedNodes := make([]types.NodeID, 0, 1)
status := ts.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
require.NoError(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
// Validate initial connection state for user1
validateInitialConnection(t, headscale, expectedNodes)
// Log out user1 and log in user2; this should create a new node
// for user2 with the same machine key and
// a new node key.
err = ts.Logout()
assertNoErr(t, err)
require.NoError(t, err)
// TODO(kradalby): Not sure why we need to log out twice; without it, the client
// logs back in immediately after the first logout, and I cannot reproduce it
// manually.
err = ts.Logout()
assertNoErr(t, err)
require.NoError(t, err)
// Wait for logout to complete and then do second logout
t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 30*time.Second, 1*time.Second)
assert.NoError(ct, err, "Failed to get client status during logout validation")
assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState)
}, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before user2 login")
u, err = ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
require.NoError(t, err)
_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)
require.NoError(t, err)
t.Logf("Validating user2 creation at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 2)
assert.NoError(ct, err, "Failed to list users after user2 login")
assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
@@ -570,27 +596,83 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
})
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
ct.Errorf("unexpected users: %s", diff)
ct.Errorf("User validation failed after user2 login - expected both user1 and user2: %s", diff)
}
}, 30*time.Second, 1*time.Second, "validating users after new user login")
}, 30*time.Second, 1*time.Second, "validating both user1 and user2 exist after second OIDC login")
var listNodesAfterNewUserLogin []*v1.Node
// First, wait for the new node to be created
t.Logf("Waiting for user2 node creation at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodesAfterNewUserLogin, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodesAfterNewUserLogin, 2)
assert.NoError(ct, err, "Failed to list nodes after user2 login")
// We might temporarily have more than 2 nodes during cleanup, so check for at least 2
assert.GreaterOrEqual(ct, len(listNodesAfterNewUserLogin), 2, "Should have at least 2 nodes after user2 login, got %d (may include temporary nodes during cleanup)", len(listNodesAfterNewUserLogin))
}, 30*time.Second, 1*time.Second, "waiting for user2 node creation (allowing temporary extra nodes during cleanup)")
// Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
}, 30*time.Second, 1*time.Second, "listing nodes after new user login")
// Then wait for cleanup to stabilize at exactly 2 nodes
t.Logf("Waiting for node cleanup stabilization at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodesAfterNewUserLogin, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes during cleanup validation")
assert.Len(ct, listNodesAfterNewUserLogin, 2, "Should have exactly 2 nodes after cleanup (1 for user1, 1 for user2), got %d nodes", len(listNodesAfterNewUserLogin))
// Validate that both nodes have the same machine key but different node keys
if len(listNodesAfterNewUserLogin) >= 2 {
// The machine key is the same because the "machine" has not changed,
// but the node key differs because this is a new node
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Machine key should be preserved from original node")
assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "Both nodes should share the same machine key")
assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey(), "Node keys should be different between user1 and user2 nodes")
}
}, 90*time.Second, 2*time.Second, "waiting for node count stabilization at exactly 2 nodes after user2 login")
// Security validation: Only user2's node should be active after user switch
var activeUser2NodeID types.NodeID
for _, node := range listNodesAfterNewUserLogin {
if node.GetUser().GetId() == 2 { // user2
activeUser2NodeID = types.NodeID(node.GetId())
t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName())
break
}
}
// Validate only user2's node is online (security requirement)
t.Logf("Validating only user2 node is online at %s", time.Now().Format(TimestampFormat))
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
// Check user2 node is online
if node, exists := nodeStore[activeUser2NodeID]; exists {
assert.NotNil(c, node.IsOnline, "User2 node should have online status")
if node.IsOnline != nil {
assert.True(c, *node.IsOnline, "User2 node should be online after login")
}
} else {
assert.Fail(c, "User2 node not found in nodestore")
}
}, 60*time.Second, 2*time.Second, "validating only user2 node is online after user switch")
// Before logging out user2, validate we have exactly 2 nodes and both are stable
t.Logf("Pre-logout validation: checking node stability at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
currentNodes, err := headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes before user2 logout")
assert.Len(ct, currentNodes, 2, "Should have exactly 2 stable nodes before user2 logout, got %d", len(currentNodes))
// Validate node stability - ensure no phantom nodes
for i, node := range currentNodes {
assert.NotNil(ct, node.GetUser(), "Node %d should have a valid user before logout", i)
assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should have a valid machine key before logout", i)
t.Logf("Pre-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...")
}
}, 60*time.Second, 2*time.Second, "validating stable node count and integrity before user2 logout")
// Log out user2 and log back into user1; no new node should be created,
// and the node should now "become" node1 again
err = ts.Logout()
assertNoErr(t, err)
require.NoError(t, err)
t.Logf("Logged out take one")
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
@@ -599,41 +681,63 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// logs in immediately after the first logout and I cannot reproduce it
// manually.
err = ts.Logout()
assertNoErr(t, err)
require.NoError(t, err)
t.Logf("Logged out take two")
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
// Wait for logout to complete and then do second logout
t.Logf("Waiting for user2 logout completion at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 30*time.Second, 1*time.Second)
assert.NoError(ct, err, "Failed to get client status during user2 logout validation")
assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after user2 logout, got %s", status.BackendState)
}, 30*time.Second, 1*time.Second, "waiting for user2 logout to complete before user1 relogin")
// Before logging back in, ensure we still have exactly 2 nodes
// Note: We skip validateLogoutComplete here since it expects all nodes to be offline,
// but in the OIDC scenario we keep both nodes in the DB, with only the active user's node online
// Additional validation that nodes are properly maintained during logout
t.Logf("Post-logout validation: checking node persistence at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
currentNodes, err := headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes after user2 logout")
assert.Len(ct, currentNodes, 2, "Should still have exactly 2 nodes after user2 logout (nodes should persist), got %d", len(currentNodes))
// Ensure both nodes are still valid (not cleaned up incorrectly)
for i, node := range currentNodes {
assert.NotNil(ct, node.GetUser(), "Node %d should still have a valid user after user2 logout", i)
assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should still have a valid machine key after user2 logout", i)
t.Logf("Post-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...")
}
}, 60*time.Second, 2*time.Second, "validating node persistence and integrity after user2 logout")
// We do not actually "change" the user here; it happens by logging in again,
// as the OIDC mock server is kind of like a stack, and the next user is
// prepared and ready to go.
u, err = ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
require.NoError(t, err)
_, err = doLoginURL(ts.Hostname(), u)
assertNoErr(t, err)
require.NoError(t, err)
t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "Running", status.BackendState)
}, 30*time.Second, 1*time.Second)
assert.NoError(ct, err, "Failed to get client status during user1 relogin validation")
assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState)
}, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (final login)")
t.Logf("Logged back in")
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
t.Logf("Final validation: checking user persistence at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err)
assert.Len(ct, listUsers, 2)
assert.NoError(ct, err, "Failed to list users during final validation")
assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
@@ -656,37 +760,77 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
})
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
ct.Errorf("unexpected users: %s", diff)
ct.Errorf("Final user validation failed - both users should persist after relogin cycle: %s", diff)
}
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
}, 30*time.Second, 1*time.Second, "validating user persistence after complete relogin cycle (user1->user2->user1)")
var listNodesAfterLoggingBackIn []*v1.Node
// Wait for login to complete and nodes to stabilize
t.Logf("Final node validation: checking node stability after user1 relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodesAfterLoggingBackIn, err := headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodesAfterLoggingBackIn, 2)
listNodesAfterLoggingBackIn, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes during final validation")
// Allow for temporary instability during the login process
if len(listNodesAfterLoggingBackIn) < 2 {
ct.Errorf("Not enough nodes yet during final validation, got %d, want at least 2", len(listNodesAfterLoggingBackIn))
return
}
// Final check should have exactly 2 nodes
assert.Len(ct, listNodesAfterLoggingBackIn, 2, "Should have exactly 2 nodes after complete relogin cycle, got %d", len(listNodesAfterLoggingBackIn))
// Validate that the machine we had when we logged in the first time has the same
// machine key, but a different ID, than the newly logged-in version of the same
// machine.
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Original user1 machine key should match user1 node after user switch")
assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey(), "Original user1 node key should match user1 node after user switch")
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId(), "Original user1 node ID should match user1 node after user switch")
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "User1 and user2 nodes should share the same machine key")
assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId(), "User1 and user2 nodes should have different node IDs")
assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId(), "User1 and user2 nodes should belong to different users")
// Even though we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should remain the same,
// as the user + machine key still match.
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey(), "Machine key should remain consistent after user1 relogin")
assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey(), "Node key should be regenerated after user1 relogin")
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId(), "Node ID should be preserved for user1 after relogin")
// The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user.
assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey(), "Both final nodes should share the same machine key")
assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey(), "Final nodes should have different node keys for different users")
t.Logf("Final validation complete - node counts and key relationships verified at %s", time.Now().Format(TimestampFormat))
}, 60*time.Second, 2*time.Second, "validating final node state after complete user1->user2->user1 relogin cycle with detailed key validation")
// Security validation: Only user1's node should be active after relogin
var activeUser1NodeID types.NodeID
for _, node := range listNodesAfterLoggingBackIn {
if node.GetUser().GetId() == 1 { // user1
activeUser1NodeID = types.NodeID(node.GetId())
t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName())
break
}
}
// Validate only user1's node is online (security requirement)
t.Logf("Validating only user1 node is online after relogin at %s", time.Now().Format(TimestampFormat))
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
// Check user1 node is online
if node, exists := nodeStore[activeUser1NodeID]; exists {
assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin")
if node.IsOnline != nil {
assert.True(c, *node.IsOnline, "User1 node should be online after relogin")
}
} else {
assert.Fail(c, "User1 node not found in nodestore after relogin")
}
}, 60*time.Second, 2*time.Second, "validating only user1 node is online after final relogin")
}
// TestOIDCFollowUpUrl validates the follow-up login flow
@@ -709,7 +853,7 @@ func TestOIDCFollowUpUrl(t *testing.T) {
},
)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
@@ -730,43 +874,43 @@ func TestOIDCFollowUpUrl(t *testing.T) {
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
)
assertNoErrHeadscaleEnv(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
require.NoError(t, err)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode(
"unstable",
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
)
assertNoErr(t, err)
require.NoError(t, err)
u, err := ts.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
require.NoError(t, err)
// wait for the registration cache to expire
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
time.Sleep(2 * time.Minute)
st, err := ts.Status()
assertNoErr(t, err)
require.NoError(t, err)
assert.Equal(t, "NeedsLogin", st.BackendState)
// get new AuthURL from daemon
newUrl, err := url.Parse(st.AuthURL)
assertNoErr(t, err)
require.NoError(t, err)
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
_, err = doLoginURL(ts.Hostname(), newUrl)
assertNoErr(t, err)
require.NoError(t, err)
listUsers, err = headscale.ListUsers()
assertNoErr(t, err)
require.NoError(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
@@ -795,30 +939,230 @@ func TestOIDCFollowUpUrl(t *testing.T) {
}
listNodes, err := headscale.ListNodes()
assertNoErr(t, err)
require.NoError(t, err)
assert.Len(t, listNodes, 1)
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
// are in the logged-out state (NeedsLogin).
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
if h, ok := t.(interface{ Helper() }); ok {
h.Helper()
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client
// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user.
//
// OIDC is an authentication layer built on top of OAuth 2.0 that allows users to authenticate
// using external identity providers (like Google, Microsoft, etc.) rather than managing
// credentials directly in headscale.
//
// This test validates the "same user relogin" behavior in headscale's OIDC authentication flow:
// - A single client authenticates via OIDC as user1
// - The client logs out, ending the session
// - The same client logs back in via OIDC as the same user (user1)
// - The test verifies that the user account persists correctly
// - The test verifies that the machine key is preserved (since it's the same physical device)
// - The test verifies that the node ID is preserved (since it's the same user on the same device)
// - The test verifies that the node key is regenerated (since it's a new session)
// - The test verifies that the client comes back online properly
//
// This scenario is important for normal user workflows where someone might need to restart
// their Tailscale client, reboot their computer, or temporarily disconnect and reconnect.
// It ensures that headscale properly handles session management while preserving device
// identity and user associations.
//
// The test uses a single node scenario (unlike multi-node tests) to focus specifically on
// the authentication and session management aspects rather than network topology changes.
// The "same node" in the name refers to the same physical device/client, while "same user"
// refers to authenticating with the same OIDC identity.
func TestOIDCReloginSameNodeSameUser(t *testing.T) {
IntegrationSkip(t)
// Create scenario with same user for both login attempts
scenario, err := NewScenario(ScenarioSpec{
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true), // Initial login
oidcMockUser("user1", true), // Relogin with same user
},
})
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
}
for _, client := range clients {
status, err := client.Status()
assert.NoError(t, err, "failed to get status for client %s", client.Hostname())
assert.Equal(t, "NeedsLogin", status.BackendState,
"client %s should be logged out", client.Hostname())
}
}
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcsameuser"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithDERPAsIP(),
)
requireNoErrHeadscaleEnv(t, err)
func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
return mockoidc.MockUser{
Subject: username,
PreferredUsername: username,
Email: username + "@headscale.net",
EmailVerified: emailVerified,
}
headscale, err := scenario.Headscale()
require.NoError(t, err)
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
require.NoError(t, err)
// Initial login as user1
u, err := ts.LoginWithURL(headscale.GetEndpoint())
require.NoError(t, err)
_, err = doLoginURL(ts.Hostname(), u)
require.NoError(t, err)
t.Logf("Validating initial user1 creation at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err, "Failed to list users during initial validation")
assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
ct.Errorf("User validation failed after first login - unexpected users: %s", diff)
}
}, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login")
t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat))
var initialNodes []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
initialNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes during initial validation")
assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes))
}, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login")
// Collect expected node IDs for validation after user1 initial login
expectedNodes := make([]types.NodeID, 0, 1)
status := ts.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
require.NoError(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
// Validate initial connection state for user1
validateInitialConnection(t, headscale, expectedNodes)
// Store initial node keys for comparison
initialMachineKey := initialNodes[0].GetMachineKey()
initialNodeKey := initialNodes[0].GetNodeKey()
initialNodeID := initialNodes[0].GetId()
// Logout user1
err = ts.Logout()
require.NoError(t, err)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
// manually.
err = ts.Logout()
require.NoError(t, err)
// Wait for logout to complete
t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the logout completed
status, err := ts.Status()
assert.NoError(ct, err, "Failed to get client status during logout validation")
assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState)
}, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before same-user relogin")
// Validate node persistence during logout (node should remain in DB)
t.Logf("Validating node persistence during logout at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listNodes, err := headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes during logout validation")
assert.Len(ct, listNodes, 1, "Should still have exactly 1 node during logout (node should persist in DB), got %d", len(listNodes))
}, 30*time.Second, 1*time.Second, "validating node persistence in database during same-user logout")
// Login again as the same user (user1)
u, err = ts.LoginWithURL(headscale.GetEndpoint())
require.NoError(t, err)
_, err = doLoginURL(ts.Hostname(), u)
require.NoError(t, err)
t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := ts.Status()
assert.NoError(ct, err, "Failed to get client status during relogin validation")
assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState)
}, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (same user)")
t.Logf("Final validation: checking user persistence after same-user relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err, "Failed to list users during final validation")
assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
sort.Slice(listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
})
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
ct.Errorf("Final user validation failed - user1 should persist after same-user relogin: %s", diff)
}
}, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle")
var finalNodes []*v1.Node
t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
finalNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes during final validation")
assert.Len(ct, finalNodes, 1, "Should have exactly 1 node after same-user relogin, got %d", len(finalNodes))
// Validate node key behavior for same user relogin
finalNode := finalNodes[0]
// Machine key should be preserved (same physical machine)
assert.Equal(ct, initialMachineKey, finalNode.GetMachineKey(), "Machine key should be preserved for same user same node relogin")
// Node ID should be preserved (same user, same machine)
assert.Equal(ct, initialNodeID, finalNode.GetId(), "Node ID should be preserved for same user same node relogin")
// Node key should be regenerated (new session after logout)
assert.NotEqual(ct, initialNodeKey, finalNode.GetNodeKey(), "Node key should be regenerated after logout/relogin even for same user")
t.Logf("Final validation complete - same user relogin key relationships verified at %s", time.Now().Format(TimestampFormat))
}, 60*time.Second, 2*time.Second, "validating final node state after same-user OIDC relogin cycle with key preservation validation")
// Security validation: user1's node should be active after relogin
activeUser1NodeID := types.NodeID(finalNodes[0].GetId())
t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat))
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
// Check user1 node is online
if node, exists := nodeStore[activeUser1NodeID]; exists {
assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin")
if node.IsOnline != nil {
assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin")
}
} else {
assert.Fail(c, "User1 node not found in nodestore after same-user relogin")
}
}, 60*time.Second, 2*time.Second, "validating user1 node is online after same-user OIDC relogin")
}

View File

@@ -1,15 +1,19 @@
package integration
import (
"fmt"
"net/netip"
"slices"
"testing"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
@@ -33,16 +37,16 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
hsic.WithDERPAsIP(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -54,7 +58,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
@@ -63,7 +67,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnvWithLoginURL(
@@ -72,16 +76,16 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
hsic.WithDERPAsIP(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -93,15 +97,22 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
// Collect expected node IDs for validation
expectedNodes := collectExpectedNodeIDs(t, allClients)
// Validate initial connection state
validateInitialConnection(t, headscale, expectedNodes)
var listNodes []*v1.Node
t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodes, len(allClients), "Node count should match client count after login")
}, 20*time.Second, 1*time.Second)
assert.NoError(ct, err, "Failed to list nodes after web authentication")
assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes))
}, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication")
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@@ -122,7 +133,10 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
}
err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err)
requireNoErrLogout(t, err)
// Validate that all nodes are offline after logout
validateLogoutComplete(t, headscale, expectedNodes)
t.Logf("all clients logged out")
@@ -135,8 +149,20 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("all clients logged in again")
t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes after web flow logout")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
}, 60*time.Second, 2*time.Second, "validating node persistence in database after web flow logout")
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))
// Validate connection state after relogin
validateReloginComplete(t, headscale, expectedNodes)
allIps, err = scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
allAddrs = lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -145,14 +171,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
success = pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count after re-login")
}, 20*time.Second, 1*time.Second)
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))
for _, client := range allClients {
ips, err := client.IPs()
if err != nil {
@@ -180,3 +198,166 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("all clients IPs are the same")
}
// TestAuthWebFlowLogoutAndReloginNewUser tests the scenario where multiple Tailscale clients
// initially authenticate using the web-based authentication flow (where users visit a URL
// in their browser to authenticate), then all clients log out and log back in as a different user.
//
// This test validates the "user switching" behavior in headscale's web authentication flow:
// - Multiple clients authenticate via web flow, each to their respective users (user1, user2)
// - All clients log out simultaneously
// - All clients log back in via web flow, but this time they all authenticate as user1
// - The test verifies that user1 ends up with all the client nodes
// - The test verifies that user2's original nodes still exist in the database but are offline
// - The test verifies network connectivity works after the user switch
//
// This scenario is important for organizations that need to reassign devices between users
// or when consolidating multiple user accounts. It ensures that headscale properly handles
// the security implications of user switching while maintaining node persistence in the database.
//
// The test uses headscale's web authentication flow, which is the most user-friendly method
// where authentication happens through a web browser rather than pre-shared keys or OIDC.
func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("webflowrelnewuser"),
hsic.WithDERPAsIP(),
hsic.WithTLS(),
)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
requireNoErrGetHeadscale(t, err)
// Collect expected node IDs for validation
expectedNodes := collectExpectedNodeIDs(t, allClients)
// Validate initial connection state
validateInitialConnection(t, headscale, expectedNodes)
var listNodes []*v1.Node
t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes after initial web authentication")
assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes))
}, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication")
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
// Log out all clients
for _, client := range allClients {
err := client.Logout()
if err != nil {
t.Fatalf("failed to logout client %s: %s", client.Hostname(), err)
}
}
err = scenario.WaitForTailscaleLogout()
requireNoErrLogout(t, err)
// Validate that all nodes are offline after logout
validateLogoutComplete(t, headscale, expectedNodes)
t.Logf("all clients logged out")
// Log all clients back in as user1 using web flow
// We manually iterate over all clients and authenticate each one as user1
// This tests the cross-user re-authentication behavior where ALL clients
// (including those originally from user2) are registered to user1
for _, client := range allClients {
loginURL, err := client.LoginWithURL(headscale.GetEndpoint())
if err != nil {
t.Fatalf("failed to get login URL for client %s: %s", client.Hostname(), err)
}
body, err := doLoginURL(client.Hostname(), loginURL)
if err != nil {
t.Fatalf("failed to complete login for client %s: %s", client.Hostname(), err)
}
// Register all clients as user1 (this is where cross-user registration happens)
// This simulates: headscale nodes register --user user1 --key <key>
scenario.runHeadscaleRegister("user1", body)
}
// Wait for all clients to reach running state
for _, client := range allClients {
err := client.WaitForRunning(integrationutil.PeerSyncTimeout())
if err != nil {
t.Fatalf("%s tailscale node has not reached running: %s", client.Hostname(), err)
}
}
t.Logf("all clients logged back in as user1")
var user1Nodes []*v1.Node
t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
user1Nodes, err = headscale.ListNodes("user1")
assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin")
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes))
}, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after web flow user switch relogin")
// Collect expected node IDs for user1 after relogin
expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes))
for _, node := range user1Nodes {
expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId()))
}
// Validate connection state after relogin as user1
validateReloginComplete(t, headscale, expectedUser1Nodes)
// Validate that user2's old nodes still exist in the database (but are expired/offline)
// When CLI registration creates new nodes for user1, user2's old nodes remain
var user2Nodes []*v1.Node
t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
user2Nodes, err = headscale.ListNodes("user2")
assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1")
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes))
}, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1")
t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat))
for _, client := range allClients {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after web flow user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName)
}, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after web flow user switch", client.Hostname()))
}
// Test connectivity after user switch
allIps, err = scenario.ListTailscaleClientsIPs()
requireNoErrListClientIPs(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d after web flow user switch", success, len(allClients)*len(allIps))
}

View File

@@ -54,14 +54,14 @@ func TestUserCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
var listUsers []*v1.User
var result []string
@@ -99,7 +99,7 @@ func TestUserCommand(t *testing.T) {
"--new-name=newname",
},
)
assertNoErr(t, err)
require.NoError(t, err)
var listAfterRenameUsers []*v1.User
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
@@ -138,7 +138,7 @@ func TestUserCommand(t *testing.T) {
},
&listByUsername,
)
assertNoErr(t, err)
require.NoError(t, err)
slices.SortFunc(listByUsername, sortWithID)
want := []*v1.User{
@@ -165,7 +165,7 @@ func TestUserCommand(t *testing.T) {
},
&listByID,
)
assertNoErr(t, err)
require.NoError(t, err)
slices.SortFunc(listByID, sortWithID)
want = []*v1.User{
@@ -244,7 +244,7 @@ func TestUserCommand(t *testing.T) {
},
&listAfterNameDelete,
)
assertNoErr(t, err)
require.NoError(t, err)
require.Empty(t, listAfterNameDelete)
}
@@ -260,17 +260,17 @@ func TestPreAuthKeyCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
keys := make([]*v1.PreAuthKey, count)
assertNoErr(t, err)
require.NoError(t, err)
for index := range count {
var preAuthKey v1.PreAuthKey
@@ -292,7 +292,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
},
&preAuthKey,
)
assertNoErr(t, err)
require.NoError(t, err)
keys[index] = &preAuthKey
}
@@ -313,7 +313,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
},
&listedPreAuthKeys,
)
assertNoErr(t, err)
require.NoError(t, err)
// There is one key created by "scenario.CreateHeadscaleEnv"
assert.Len(t, listedPreAuthKeys, 4)
@@ -372,7 +372,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
listedPreAuthKeys[1].GetKey(),
},
)
assertNoErr(t, err)
require.NoError(t, err)
var listedPreAuthKeysAfterExpire []v1.PreAuthKey
err = executeAndUnmarshal(
@@ -388,7 +388,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
},
&listedPreAuthKeysAfterExpire,
)
assertNoErr(t, err)
require.NoError(t, err)
assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now()))
assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now()))
@@ -404,14 +404,14 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
var preAuthKey v1.PreAuthKey
err = executeAndUnmarshal(
@@ -428,7 +428,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
},
&preAuthKey,
)
assertNoErr(t, err)
require.NoError(t, err)
var listedPreAuthKeys []v1.PreAuthKey
err = executeAndUnmarshal(
@@ -444,7 +444,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
},
&listedPreAuthKeys,
)
assertNoErr(t, err)
require.NoError(t, err)
// There is one key created by "scenario.CreateHeadscaleEnv"
assert.Len(t, listedPreAuthKeys, 2)
@@ -465,14 +465,14 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
var preAuthReusableKey v1.PreAuthKey
err = executeAndUnmarshal(
@@ -489,7 +489,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
},
&preAuthReusableKey,
)
assertNoErr(t, err)
require.NoError(t, err)
var preAuthEphemeralKey v1.PreAuthKey
err = executeAndUnmarshal(
@@ -506,7 +506,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
},
&preAuthEphemeralKey,
)
assertNoErr(t, err)
require.NoError(t, err)
assert.True(t, preAuthEphemeralKey.GetEphemeral())
assert.False(t, preAuthEphemeralKey.GetReusable())
@@ -525,7 +525,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
},
&listedPreAuthKeys,
)
assertNoErr(t, err)
require.NoError(t, err)
// There is one key created by "scenario.CreateHeadscaleEnv"
assert.Len(t, listedPreAuthKeys, 3)
@@ -543,7 +543,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -552,13 +552,13 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
u2, err := headscale.CreateUser(user2)
assertNoErr(t, err)
require.NoError(t, err)
var user2Key v1.PreAuthKey
@@ -580,7 +580,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
},
&user2Key,
)
assertNoErr(t, err)
require.NoError(t, err)
var listNodes []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
@@ -592,7 +592,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
}, 15*time.Second, 1*time.Second)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
require.Len(t, allClients, 1)
@@ -600,10 +600,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
// Log out from user1
err = client.Logout()
assertNoErr(t, err)
require.NoError(t, err)
err = scenario.WaitForTailscaleLogout()
assertNoErr(t, err)
require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
@@ -613,7 +613,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
}, 30*time.Second, 2*time.Second)
err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
assertNoErr(t, err)
require.NoError(t, err)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
@@ -642,14 +642,14 @@ func TestApiKeyCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
keys := make([]string, count)
@@ -808,14 +808,14 @@ func TestNodeTagCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
@@ -1007,7 +1007,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -1015,10 +1015,10 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
hsic.WithTestName("cliadvtags"),
hsic.WithACLPolicy(tt.policy),
)
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec.NodesPerUser)
@@ -1058,14 +1058,14 @@ func TestNodeCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
@@ -1302,14 +1302,14 @@ func TestNodeExpireCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
@@ -1427,14 +1427,14 @@ func TestNodeRenameCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
@@ -1462,7 +1462,7 @@ func TestNodeRenameCommand(t *testing.T) {
"json",
},
)
assertNoErr(t, err)
require.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@@ -1480,7 +1480,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&node,
)
assertNoErr(t, err)
require.NoError(t, err)
nodes[index] = &node
}
@@ -1591,20 +1591,20 @@ func TestNodeMoveCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Randomly generated node key
regID := types.MustRegistrationID()
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
require.NoError(t, err)
_, err = headscale.Execute(
[]string{
@@ -1753,7 +1753,7 @@ func TestPolicyCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -1763,10 +1763,10 @@ func TestPolicyCommand(t *testing.T) {
"HEADSCALE_POLICY_MODE": "database",
}),
)
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
p := policyv2.Policy{
ACLs: []policyv2.ACL{
@@ -1789,7 +1789,7 @@ func TestPolicyCommand(t *testing.T) {
policyFilePath := "/etc/headscale/policy.json"
err = headscale.WriteFile(policyFilePath, pBytes)
assertNoErr(t, err)
require.NoError(t, err)
// No policy is present at this time.
// Add a new policy from a file.
@@ -1803,7 +1803,7 @@ func TestPolicyCommand(t *testing.T) {
},
)
assertNoErr(t, err)
require.NoError(t, err)
// Get the current policy and check
// if it is the same as the one we set.
@@ -1819,7 +1819,7 @@ func TestPolicyCommand(t *testing.T) {
},
&output,
)
assertNoErr(t, err)
require.NoError(t, err)
assert.Len(t, output.TagOwners, 1)
assert.Len(t, output.ACLs, 1)
@@ -1834,7 +1834,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -1844,10 +1844,10 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
"HEADSCALE_POLICY_MODE": "database",
}),
)
assertNoErr(t, err)
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
p := policyv2.Policy{
ACLs: []policyv2.ACL{
@@ -1872,7 +1872,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
policyFilePath := "/etc/headscale/policy.json"
err = headscale.WriteFile(policyFilePath, pBytes)
assertNoErr(t, err)
require.NoError(t, err)
// No policy is present at this time.
// Add a new policy from a file.

View File

@@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/require"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/net/netmon"
@@ -23,7 +24,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
// Generate random hostname for the headscale instance
hash, err := util.GenerateRandomStringDNSSafe(6)
assertNoErr(t, err)
require.NoError(t, err)
testName := "derpverify"
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
@@ -31,7 +32,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
// Create cert for headscale
certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname)
assertNoErr(t, err)
require.NoError(t, err)
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -39,14 +40,14 @@ func TestDERPVerifyEndpoint(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
derper, err := scenario.CreateDERPServer("head",
dsic.WithCACert(certHeadscale),
dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))),
)
assertNoErr(t, err)
require.NoError(t, err)
derpRegion := tailcfg.DERPRegion{
RegionCode: "test-derpverify",
@@ -74,17 +75,17 @@ func TestDERPVerifyEndpoint(t *testing.T) {
hsic.WithPort(headscalePort),
hsic.WithCustomTLS(certHeadscale, keyHeadscale),
hsic.WithDERPConfig(derpMap))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
fakeKey := key.NewNode()
DERPVerify(t, fakeKey, derpRegion, false)
for _, client := range allClients {
nodeKey, err := client.GetNodePrivateKey()
assertNoErr(t, err)
require.NoError(t, err)
DERPVerify(t, *nodeKey, derpRegion, true)
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
@@ -22,26 +23,26 @@ func TestResolveMagicDNS(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
// Poor man's cache
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
_, err = scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
for _, client := range allClients {
for _, peer := range allClients {
@@ -78,7 +79,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
const erPath = "/tmp/extra_records.json"
@@ -109,29 +110,29 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
// Poor man's cache
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
_, err = scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6")
}
hs, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Write the file directly into place from the docker API.
b0, _ := json.Marshal([]tailcfg.DNSRecord{
@@ -143,7 +144,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
})
err = hs.WriteFile(erPath, b0)
assertNoErr(t, err)
require.NoError(t, err)
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "2.2.2.2")
@@ -159,9 +160,9 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
b2, _ := json.Marshal(extraRecords)
err = hs.WriteFile(erPath+"2", b2)
assertNoErr(t, err)
require.NoError(t, err)
_, err = hs.Execute([]string{"mv", erPath + "2", erPath})
assertNoErr(t, err)
require.NoError(t, err)
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6")
@@ -179,9 +180,9 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
})
err = hs.WriteFile(erPath+"3", b3)
assertNoErr(t, err)
require.NoError(t, err)
_, err = hs.Execute([]string{"cp", erPath + "3", erPath})
assertNoErr(t, err)
require.NoError(t, err)
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8")
@@ -197,7 +198,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
})
command := []string{"echo", fmt.Sprintf("'%s'", string(b4)), ">", erPath}
_, err = hs.Execute([]string{"bash", "-c", strings.Join(command, " ")})
assertNoErr(t, err)
require.NoError(t, err)
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9")
@@ -205,7 +206,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
// Delete the file and create a new one to ensure it is picked up again.
_, err = hs.Execute([]string{"rm", erPath})
assertNoErr(t, err)
require.NoError(t, err)
// The same paths should still be available as it is not cleared on delete.
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
@@ -219,7 +220,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
// Write a new file, the backoff mechanism should make the filewatcher pick it up
// again.
err = hs.WriteFile(erPath, b3)
assertNoErr(t, err)
require.NoError(t, err)
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8")

View File

@@ -0,0 +1,17 @@
package dockertestutil
import (
"os/exec"
)
// RunDockerBuildForDiagnostics runs docker build manually to get detailed error output.
// It is used when a docker build fails, to provide more detailed diagnostic information
// than what dockertest typically provides.
func RunDockerBuildForDiagnostics(contextDir, dockerfile string) string {
cmd := exec.Command("docker", "build", "-f", dockerfile, contextDir)
output, err := cmd.CombinedOutput()
if err != nil {
return string(output)
}
return ""
}
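Below is a minimal usage sketch for this helper: a hypothetical call site (buildWithDiagnostics is not part of this PR) that assumes the package is importable as integration/dockertestutil; the real harness may wire the fallback differently.
package integration
import (
"log"
"github.com/juanfont/headscale/integration/dockertestutil"
)
// buildWithDiagnostics (hypothetical helper) runs a caller-supplied build step
// and, if it fails, re-runs the docker build manually to capture the full log.
func buildWithDiagnostics(contextDir, dockerfile string, build func() error) error {
if err := build(); err != nil {
// RunDockerBuildForDiagnostics returns combined output only when the
// manual build also fails; an empty string means it built cleanly.
if out := dockertestutil.RunDockerBuildForDiagnostics(contextDir, dockerfile); out != "" {
log.Printf("docker build failed, diagnostic output:\n%s", out)
}
return err
}
return nil
}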

View File

@@ -7,6 +7,7 @@ import (
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@@ -29,7 +30,7 @@ func TestDERPServerScenario(t *testing.T) {
derpServerScenario(t, spec, false, func(scenario *Scenario) {
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
t.Logf("checking %d clients for websocket connections", len(allClients))
for _, client := range allClients {
@@ -43,7 +44,7 @@ func TestDERPServerScenario(t *testing.T) {
}
hsServer, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
derpRegion := tailcfg.DERPRegion{
RegionCode: "test-derpverify",
@@ -79,7 +80,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
derpServerScenario(t, spec, true, func(scenario *Scenario) {
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
t.Logf("checking %d clients for websocket connections", len(allClients))
for _, client := range allClients {
@@ -108,7 +109,7 @@ func derpServerScenario(
IntegrationSkip(t)
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -128,16 +129,16 @@ func derpServerScenario(
"HEADSCALE_DERP_SERVER_VERIFY_CLIENTS": "true",
}),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
for _, client := range allClients {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {

View File

@@ -10,19 +10,15 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/types/key"
@@ -38,7 +34,7 @@ func TestPingAllByIP(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -48,16 +44,16 @@ func TestPingAllByIP(t *testing.T) {
hsic.WithTLS(),
hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
hs, err := scenario.Headscale()
require.NoError(t, err)
@@ -80,7 +76,7 @@ func TestPingAllByIP(t *testing.T) {
// Get headscale instance for batcher debug check
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Test our DebugBatcher functionality
t.Logf("Testing DebugBatcher functionality...")
@@ -99,23 +95,23 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("pingallbyippubderp"),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -148,11 +144,11 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
headscale, err := scenario.Headscale(opts...)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
for _, userName := range spec.Users {
user, err := scenario.CreateUser(userName)
@@ -177,13 +173,13 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -200,7 +196,7 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
}
err = scenario.WaitForTailscaleLogout()
assertNoErrLogout(t, err)
requireNoErrLogout(t, err)
t.Logf("all clients logged out")
@@ -222,7 +218,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
headscale, err := scenario.Headscale(
@@ -231,7 +227,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s",
}),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
for _, userName := range spec.Users {
user, err := scenario.CreateUser(userName)
@@ -256,13 +252,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -344,22 +340,22 @@ func TestPingAllByHostname(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
success := pingAllHelper(t, allClients, allHostnames)
@@ -379,7 +375,7 @@ func TestTaildrop(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{},
@@ -387,17 +383,17 @@ func TestTaildrop(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// This will essentially fetch and cache all the FQDNs
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
for _, client := range allClients {
if !strings.Contains(client.Hostname(), "head") {
@@ -498,7 +494,7 @@ func TestTaildrop(t *testing.T) {
)
result, _, err := client.Execute(command)
assertNoErrf(t, "failed to execute command to ls taildrop: %s", err)
require.NoErrorf(t, err, "failed to execute command to ls taildrop")
log.Printf("Result for %s: %s\n", peer.Hostname(), result)
if fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()) != result {
@@ -528,25 +524,25 @@ func TestUpdateHostnameFromClient(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErrf(t, "failed to create scenario: %s", err)
require.NoErrorf(t, err, "failed to create scenario")
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
// update hostnames using the up command
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
require.NoError(t, err)
command := []string{
"tailscale",
@@ -554,11 +550,11 @@ func TestUpdateHostnameFromClient(t *testing.T) {
"--hostname=" + hostnames[string(status.Self.ID)],
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to set hostname: %s", err)
require.NoErrorf(t, err, "failed to set hostname")
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// Wait for nodestore batch processing to complete
// NodeStore batching timeout is 500ms, so we wait up to 1 second
@@ -597,7 +593,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
"--identifier",
strconv.FormatUint(node.GetId(), 10),
})
assertNoErr(t, err)
require.NoError(t, err)
}
// Verify that the server-side rename is reflected in DNSName while HostName remains unchanged
@@ -643,7 +639,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
require.NoError(t, err)
command := []string{
"tailscale",
@@ -651,11 +647,11 @@ func TestUpdateHostnameFromClient(t *testing.T) {
"--hostname=" + hostnames[string(status.Self.ID)] + "NEW",
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to set hostname: %s", err)
require.NoErrorf(t, err, "failed to set hostname")
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// Wait for nodestore batch processing to complete
// NodeStore batching timeout is 500ms, so we wait up to 1 second
@@ -696,20 +692,20 @@ func TestExpireNode(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -731,22 +727,22 @@ func TestExpireNode(t *testing.T) {
}
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// TODO(kradalby): This is Headscale specific and would not play nicely
// with other implementations of the ControlServer interface
result, err := headscale.Execute([]string{
"headscale", "nodes", "expire", "--identifier", "1", "--output", "json",
})
assertNoErr(t, err)
require.NoError(t, err)
var node v1.Node
err = json.Unmarshal([]byte(result), &node)
assertNoErr(t, err)
require.NoError(t, err)
var expiredNodeKey key.NodePublic
err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey()))
assertNoErr(t, err)
require.NoError(t, err)
t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())
@@ -773,14 +769,14 @@ func TestExpireNode(t *testing.T) {
// Verify that the expired node has been marked in all peers list.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
require.NoError(t, err)
if client.Hostname() != node.GetName() {
t.Logf("available peers of %s: %v", client.Hostname(), status.Peers())
// Ensures that the node is present, and that it is expired.
if peerStatus, ok := status.Peer[expiredNodeKey]; ok {
assertNotNil(t, peerStatus.Expired)
requireNotNil(t, peerStatus.Expired)
assert.NotNil(t, peerStatus.KeyExpiry)
t.Logf(
@@ -840,20 +836,20 @@ func TestNodeOnlineStatus(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -866,14 +862,14 @@ func TestNodeOnlineStatus(t *testing.T) {
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
require.NoError(t, err)
// Assert that we have the original count - self
assert.Len(t, status.Peers(), len(MustTestVersions)-1)
}
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Duration is chosen arbitrarily, 10m is reported in #1561
testDuration := 12 * time.Minute
@@ -963,7 +959,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -973,16 +969,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
hsic.WithDERPAsIP(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// assertClientsState(t, allClients)
@@ -992,7 +988,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
// Get headscale instance for batcher debug checks
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Initial check: all nodes should be connected to batcher
// Extract node IDs for validation
@@ -1000,7 +996,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
assertNoErr(t, err)
require.NoError(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second)
@@ -1072,7 +1068,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
@@ -1081,16 +1077,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
@@ -1100,7 +1096,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
headscale, err := scenario.Headscale()
assertNoErr(t, err)
require.NoError(t, err)
// Test list all nodes after added otherUser
var nodeList []v1.Node
@@ -1170,159 +1166,3 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
assert.True(t, nodeListAfter[0].GetOnline())
assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
}
// NodeSystemStatus represents the online status of a node across different systems
type NodeSystemStatus struct {
Batcher bool
BatcherConnCount int
MapResponses bool
NodeStore bool
}
// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
t.Helper()
startTime := time.Now()
t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format(TimestampFormat), message)
var prevReport string
require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get batcher state
debugInfo, err := headscale.DebugBatcher()
assert.NoError(c, err, "Failed to get batcher debug info")
if err != nil {
return
}
// Get map responses
mapResponses, err := headscale.GetAllMapReponses()
assert.NoError(c, err, "Failed to get map responses")
if err != nil {
return
}
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
// Validate node counts first
expectedCount := len(expectedNodes)
assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch")
assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch")
// Check that we have map responses for expected nodes
mapResponseCount := len(mapResponses)
assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch")
// Build status map for each node
nodeStatus := make(map[types.NodeID]NodeSystemStatus)
// Initialize all expected nodes
for _, nodeID := range expectedNodes {
nodeStatus[nodeID] = NodeSystemStatus{}
}
// Check batcher state
for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes {
nodeID := types.MustParseNodeID(nodeIDStr)
if status, exists := nodeStatus[nodeID]; exists {
status.Batcher = nodeInfo.Connected
status.BatcherConnCount = nodeInfo.ActiveConnections
nodeStatus[nodeID] = status
}
}
// Check map responses using buildExpectedOnlineMap
onlineFromMaps := make(map[types.NodeID]bool)
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
for nodeID := range nodeStatus {
NODE_STATUS:
for id, peerMap := range onlineMap {
if id == nodeID {
continue
}
online := peerMap[nodeID]
// If the node is offline in any map response, we consider it offline
if !online {
onlineFromMaps[nodeID] = false
continue NODE_STATUS
}
onlineFromMaps[nodeID] = true
}
}
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")
// Update status with map response data
for nodeID, online := range onlineFromMaps {
if status, exists := nodeStatus[nodeID]; exists {
status.MapResponses = online
nodeStatus[nodeID] = status
}
}
// Check nodestore state
for nodeID, node := range nodeStore {
if status, exists := nodeStatus[nodeID]; exists {
// Check if node is online in nodestore
status.NodeStore = node.IsOnline != nil && *node.IsOnline
nodeStatus[nodeID] = status
}
}
// Verify all systems show nodes in expected state and report failures
allMatch := true
var failureReport strings.Builder
ids := types.NodeIDs(maps.Keys(nodeStatus))
slices.Sort(ids)
for _, nodeID := range ids {
status := nodeStatus[nodeID]
systemsMatch := (status.Batcher == expectedOnline) &&
(status.MapResponses == expectedOnline) &&
(status.NodeStore == expectedOnline)
if !systemsMatch {
allMatch = false
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr))
failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher))
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses))
failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore))
}
}
if !allMatch {
if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" {
t.Log("Diff between reports:")
t.Logf("Prev report: \n%s\n", prevReport)
t.Logf("New report: \n%s\n", failureReport.String())
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
prevReport = failureReport.String()
}
failureReport.WriteString("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
assert.Fail(c, failureReport.String())
}
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr))
}, timeout, 2*time.Second, message)
endTime := time.Now()
duration := endTime.Sub(startTime)
t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message)
}

integration/helpers.go (new file, 922 lines)
View File

@@ -0,0 +1,922 @@
package integration
import (
"bufio"
"bytes"
"fmt"
"io"
"net/netip"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/cenkalti/backoff/v5"
"github.com/google/go-cmp/cmp"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"github.com/oauth2-proxy/mockoidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
const (
// derpPingTimeout defines the timeout for individual DERP ping operations
// Used in DERP connectivity tests to verify relay server communication.
derpPingTimeout = 2 * time.Second
// derpPingCount defines the number of ping attempts for DERP connectivity tests
// Higher count provides better reliability assessment of DERP connectivity.
derpPingCount = 10
// TimestampFormat is the standard timestamp format used across all integration tests
// Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps
// suitable for debugging and log correlation in integration tests.
TimestampFormat = "2006-01-02T15-04-05.999999999"
// TimestampFormatRunID is used for generating unique run identifiers
// Format: "20060102-150405" provides compact date-time for file/directory names.
TimestampFormatRunID = "20060102-150405"
)
// NodeSystemStatus represents the status of a node across different systems
type NodeSystemStatus struct {
Batcher bool
BatcherConnCount int
MapResponses bool
NodeStore bool
}
// requireNotNil validates that an object is not nil and fails the test if it is.
// This helper provides consistent error messaging for nil checks in integration tests.
func requireNotNil(t *testing.T, object interface{}) {
t.Helper()
require.NotNil(t, object)
}
// requireNoErrHeadscaleEnv validates that headscale environment creation succeeded.
// Provides specific error context for headscale environment setup failures.
func requireNoErrHeadscaleEnv(t *testing.T, err error) {
t.Helper()
require.NoError(t, err, "failed to create headscale environment")
}
// requireNoErrGetHeadscale validates that headscale server retrieval succeeded.
// Provides specific error context for headscale server access failures.
func requireNoErrGetHeadscale(t *testing.T, err error) {
t.Helper()
require.NoError(t, err, "failed to get headscale")
}
// requireNoErrListClients validates that client listing operations succeeded.
// Provides specific error context for client enumeration failures.
func requireNoErrListClients(t *testing.T, err error) {
t.Helper()
require.NoError(t, err, "failed to list clients")
}
// requireNoErrListClientIPs validates that client IP retrieval succeeded.
// Provides specific error context for client IP address enumeration failures.
func requireNoErrListClientIPs(t *testing.T, err error) {
t.Helper()
require.NoError(t, err, "failed to get client IPs")
}
// requireNoErrSync validates that client synchronization operations succeeded.
// Provides specific error context for client sync failures across the network.
func requireNoErrSync(t *testing.T, err error) {
t.Helper()
require.NoError(t, err, "failed to have all clients sync up")
}
// requireNoErrListFQDN validates that FQDN listing operations succeeded.
// Provides specific error context for DNS name enumeration failures.
func requireNoErrListFQDN(t *testing.T, err error) {
t.Helper()
require.NoError(t, err, "failed to list FQDNs")
}
// requireNoErrLogout validates that tailscale node logout operations succeeded.
// Provides specific error context for client logout failures.
func requireNoErrLogout(t *testing.T, err error) {
t.Helper()
require.NoError(t, err, "failed to log out tailscale nodes")
}
// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes
func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID {
t.Helper()
expectedNodes := make([]types.NodeID, 0, len(clients))
for _, client := range clients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
require.NoError(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
return expectedNodes
}
// validateInitialConnection performs comprehensive validation after initial client login.
// Validates that all nodes are online and have proper NetInfo/DERP configuration,
// essential for ensuring successful initial connection state in relogin tests.
func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) {
t.Helper()
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
}
// validateLogoutComplete performs comprehensive validation after client logout.
// Ensures all nodes are properly offline across all headscale systems,
// critical for validating clean logout state in relogin tests.
func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) {
t.Helper()
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second)
}
// validateReloginComplete performs comprehensive validation after client relogin.
// Validates that all nodes are back online with proper NetInfo/DERP configuration,
// ensuring successful relogin state restoration in integration tests.
func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) {
t.Helper()
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin", 3*time.Minute)
}
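Taken together, these three validators define the lifecycle every relogin-style test is expected to walk through: connect, log out, log back in, and prove the state at each step. A minimal sketch of that flow, assuming the ScenarioSpec shape and scenario helpers used elsewhere in this diff (NewScenario, CreateHeadscaleEnv, collectExpectedNodeIDs); the logout/relogin steps are elided:
func testReloginLifecycleSketch(t *testing.T) {
	// Hypothetical spec; the concrete fields and values are illustrative only.
	spec := ScenarioSpec{NodesPerUser: 1, Users: []string{"user1"}}
	scenario, err := NewScenario(spec)
	require.NoError(t, err)
	defer scenario.ShutdownAssertNoPanics(t)

	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("reloginsketch"))
	requireNoErrHeadscaleEnv(t, err)

	headscale, err := scenario.Headscale()
	requireNoErrGetHeadscale(t, err)
	allClients, err := scenario.ListTailscaleClients()
	requireNoErrListClients(t, err)

	expectedNodes := collectExpectedNodeIDs(t, allClients)
	validateInitialConnection(t, headscale, expectedNodes)

	// ... log all clients out here, then confirm they are fully offline ...
	validateLogoutComplete(t, headscale, expectedNodes)

	// ... log the clients back in here, then confirm full recovery ...
	validateReloginComplete(t, headscale, expectedNodes)
}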
// requireAllClientsOnline verifies that all expected nodes are in the given online/offline state across all headscale systems (batcher, map responses, nodestore).
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
t.Helper()
startTime := time.Now()
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message)
if expectedOnline {
// For online validation, use the existing logic with full timeout
requireAllClientsOnlineWithSingleTimeout(t, headscale, expectedNodes, expectedOnline, message, timeout)
} else {
// For offline validation, use staged approach with component-specific timeouts
requireAllClientsOfflineStaged(t, headscale, expectedNodes, message, timeout)
}
endTime := time.Now()
t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message)
}
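Callers choose the direction and the time budget. As an illustration only, mirroring the calls already made in this diff (the batcher check in TestPingAllByIPManyUpDown and the logout validation above):
// Online check with a short budget.
requireAllClientsOnline(t, headscale, expectedNodes, true,
	"all clients should be connected to batcher", 30*time.Second)

// Offline check; the staged path below absorbs the ~10s disconnect detection delay.
requireAllClientsOnline(t, headscale, expectedNodes, false,
	"all nodes should be offline after logout", 120*time.Second)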
// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state
func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
t.Helper()
var prevReport string
require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get batcher state
debugInfo, err := headscale.DebugBatcher()
assert.NoError(c, err, "Failed to get batcher debug info")
if err != nil {
return
}
// Get map responses
mapResponses, err := headscale.GetAllMapReponses()
assert.NoError(c, err, "Failed to get map responses")
if err != nil {
return
}
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
// Validate that all expected nodes are present in nodeStore
for _, nodeID := range expectedNodes {
_, exists := nodeStore[nodeID]
assert.True(c, exists, "Expected node %d not found in nodeStore", nodeID)
}
// Check that we have map responses for expected nodes
mapResponseCount := len(mapResponses)
expectedCount := len(expectedNodes)
assert.GreaterOrEqual(c, mapResponseCount, expectedCount, "MapResponses insufficient - expected at least %d responses, got %d", expectedCount, mapResponseCount)
// Build status map for each node
nodeStatus := make(map[types.NodeID]NodeSystemStatus)
// Initialize all expected nodes
for _, nodeID := range expectedNodes {
nodeStatus[nodeID] = NodeSystemStatus{}
}
// Check batcher state for expected nodes
for _, nodeID := range expectedNodes {
nodeIDStr := fmt.Sprintf("%d", nodeID)
if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists {
if status, exists := nodeStatus[nodeID]; exists {
status.Batcher = nodeInfo.Connected
status.BatcherConnCount = nodeInfo.ActiveConnections
nodeStatus[nodeID] = status
}
} else {
// Node not found in batcher, mark as disconnected
if status, exists := nodeStatus[nodeID]; exists {
status.Batcher = false
status.BatcherConnCount = 0
nodeStatus[nodeID] = status
}
}
}
// Check map responses using buildExpectedOnlineMap
onlineFromMaps := make(map[types.NodeID]bool)
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
// For single node scenarios, we can't validate peer visibility since there are no peers
if len(expectedNodes) == 1 {
// For single node, just check that we have map responses for the node
for nodeID := range nodeStatus {
if _, exists := onlineMap[nodeID]; exists {
onlineFromMaps[nodeID] = true
} else {
onlineFromMaps[nodeID] = false
}
}
} else {
// Multi-node scenario: check peer visibility
for nodeID := range nodeStatus {
// Initialize as offline - will be set to true only if visible in all relevant peer maps
onlineFromMaps[nodeID] = false
// Count how many peer maps should show this node
expectedPeerMaps := 0
foundOnlinePeerMaps := 0
for id, peerMap := range onlineMap {
if id == nodeID {
continue // Skip self-references
}
expectedPeerMaps++
if online, exists := peerMap[nodeID]; exists && online {
foundOnlinePeerMaps++
}
}
// Node is considered online if it appears online in all peer maps
// (or if there are no peer maps to check)
if expectedPeerMaps == 0 || foundOnlinePeerMaps == expectedPeerMaps {
onlineFromMaps[nodeID] = true
}
}
}
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")
// Update status with map response data
for nodeID, online := range onlineFromMaps {
if status, exists := nodeStatus[nodeID]; exists {
status.MapResponses = online
nodeStatus[nodeID] = status
}
}
// Check nodestore state for expected nodes
for _, nodeID := range expectedNodes {
if node, exists := nodeStore[nodeID]; exists {
if status, exists := nodeStatus[nodeID]; exists {
// Check if node is online in nodestore
status.NodeStore = node.IsOnline != nil && *node.IsOnline
nodeStatus[nodeID] = status
}
}
}
// Verify all systems show nodes in expected state and report failures
allMatch := true
var failureReport strings.Builder
ids := types.NodeIDs(maps.Keys(nodeStatus))
slices.Sort(ids)
for _, nodeID := range ids {
status := nodeStatus[nodeID]
systemsMatch := (status.Batcher == expectedOnline) &&
(status.MapResponses == expectedOnline) &&
(status.NodeStore == expectedOnline)
if !systemsMatch {
allMatch = false
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat)))
failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline))
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (expected: %t, down with at least one peer)\n", status.MapResponses, expectedOnline))
failureReport.WriteString(fmt.Sprintf(" - nodestore: %t (expected: %t)\n", status.NodeStore, expectedOnline))
}
}
if !allMatch {
if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" {
t.Logf("Node state validation report changed at %s:", time.Now().Format(TimestampFormat))
t.Logf("Previous report:\n%s", prevReport)
t.Logf("Current report:\n%s", failureReport.String())
t.Logf("Report diff:\n%s", diff)
prevReport = failureReport.String()
}
failureReport.WriteString(fmt.Sprintf("validation_timestamp: %s\n", time.Now().Format(TimestampFormat)))
// Note: timeout_remaining not available in this context
assert.Fail(c, failureReport.String())
}
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr))
}, timeout, 2*time.Second, message)
}
// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components
func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) {
t.Helper()
// Stage 1: Verify batcher disconnection (should be immediate)
t.Logf("Stage 1: Verifying batcher disconnection for %d nodes", len(expectedNodes))
require.EventuallyWithT(t, func(c *assert.CollectT) {
debugInfo, err := headscale.DebugBatcher()
assert.NoError(c, err, "Failed to get batcher debug info")
if err != nil {
return
}
allBatcherOffline := true
for _, nodeID := range expectedNodes {
nodeIDStr := fmt.Sprintf("%d", nodeID)
if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected {
allBatcherOffline = false
assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID)
}
}
assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher")
}, 15*time.Second, 1*time.Second, "batcher disconnection validation")
// Stage 2: Verify nodestore offline status (up to 15 seconds due to disconnect detection delay)
t.Logf("Stage 2: Verifying nodestore offline status for %d nodes (allowing for 10s disconnect detection delay)", len(expectedNodes))
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
allNodeStoreOffline := true
for _, nodeID := range expectedNodes {
if node, exists := nodeStore[nodeID]; exists {
isOnline := node.IsOnline != nil && *node.IsOnline
if isOnline {
allNodeStoreOffline = false
assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID)
}
}
}
assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore")
}, 20*time.Second, 1*time.Second, "nodestore offline validation")
// Stage 3: Verify map response propagation (longest delay due to peer update timing)
t.Logf("Stage 3: Verifying map response propagation for %d nodes (allowing for peer map update delays)", len(expectedNodes))
require.EventuallyWithT(t, func(c *assert.CollectT) {
mapResponses, err := headscale.GetAllMapReponses()
assert.NoError(c, err, "Failed to get map responses")
if err != nil {
return
}
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
allMapResponsesOffline := true
if len(expectedNodes) == 1 {
// Single node: check if it appears in map responses
for nodeID := range onlineMap {
if slices.Contains(expectedNodes, nodeID) {
allMapResponsesOffline = false
assert.False(c, true, "Node %d should not appear in map responses", nodeID)
}
}
} else {
// Multi-node: check peer visibility
for _, nodeID := range expectedNodes {
for id, peerMap := range onlineMap {
if id == nodeID {
continue // Skip self-references
}
if online, exists := peerMap[nodeID]; exists && online {
allMapResponsesOffline = false
assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id)
}
}
}
}
assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses")
}, 60*time.Second, 2*time.Second, "map response propagation validation")
t.Logf("All stages completed: nodes are fully offline across all systems")
}
// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database
// and a valid DERP server based on the NetInfo. This function follows the pattern of
// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state.
func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) {
t.Helper()
startTime := time.Now()
t.Logf("requireAllClientsNetInfoAndDERP: Starting NetInfo/DERP validation for %d nodes at %s - %s", len(expectedNodes), startTime.Format(TimestampFormat), message)
require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
// Validate that all expected nodes are present in nodeStore
for _, nodeID := range expectedNodes {
_, exists := nodeStore[nodeID]
assert.True(c, exists, "Expected node %d not found in nodeStore during NetInfo validation", nodeID)
}
// Check each expected node
for _, nodeID := range expectedNodes {
node, exists := nodeStore[nodeID]
assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID)
if !exists {
continue
}
// Validate that the node has Hostinfo
assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname)
if node.Hostinfo == nil {
t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat))
continue
}
// Validate that the node has NetInfo
assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname)
if node.Hostinfo.NetInfo == nil {
t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat))
continue
}
// Validate that the node has a valid DERP server (PreferredDERP should be > 0)
preferredDERP := node.Hostinfo.NetInfo.PreferredDERP
assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP)
t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat))
}
}, timeout, 5*time.Second, message)
endTime := time.Now()
duration := endTime.Sub(startTime)
t.Logf("requireAllClientsNetInfoAndDERP: Completed NetInfo/DERP validation for %d nodes at %s - Duration: %v - %s", len(expectedNodes), endTime.Format(TimestampFormat), duration, message)
}
// assertLastSeenSet validates that a node has a non-nil LastSeen timestamp.
// Critical for ensuring node activity tracking is functioning properly.
func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node)
assert.NotNil(t, node.GetLastSeen())
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
// are in the logged-out state (NeedsLogin).
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
if h, ok := t.(interface{ Helper() }); ok {
h.Helper()
}
for _, client := range clients {
status, err := client.Status()
assert.NoError(t, err, "failed to get status for client %s", client.Hostname())
assert.Equal(t, "NeedsLogin", status.BackendState,
"client %s should be logged out", client.Hostname())
}
}
// pingAllHelper performs ping tests between all clients and addresses, returning success count.
// This is used to validate network connectivity in integration tests.
// Returns the total number of successful ping operations.
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int {
t.Helper()
success := 0
for _, client := range clients {
for _, addr := range addrs {
err := client.Ping(addr, opts...)
if err != nil {
t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
} else {
success++
}
}
}
return success
}
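A typical call site, following the pattern of the ping tests earlier in this diff; the final assertion varies by test, so treat the exact comparison here as illustrative:
// Every client pings every address; the aggregate is compared to the full matrix.
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allAddrs))
assert.Equal(t, len(allClients)*len(allAddrs), success)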
// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses.
// This specifically tests connectivity through DERP relay servers, which is important
// for validating NAT traversal and relay functionality. Returns success count.
func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
t.Helper()
success := 0
for _, client := range clients {
for _, addr := range addrs {
if isSelfClient(client, addr) {
continue
}
err := client.Ping(
addr,
tsic.WithPingTimeout(derpPingTimeout),
tsic.WithPingCount(derpPingCount),
tsic.WithPingUntilDirect(false),
)
if err != nil {
t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
} else {
success++
}
}
}
return success
}
// isSelfClient determines if the given address belongs to the client itself.
// Used to avoid self-ping operations in connectivity tests by checking
// hostname and IP address matches.
func isSelfClient(client TailscaleClient, addr string) bool {
if addr == client.Hostname() {
return true
}
ips, err := client.IPs()
if err != nil {
return false
}
for _, ip := range ips {
if ip.String() == addr {
return true
}
}
return false
}
// assertClientsState validates the status and netmap of a list of clients for general connectivity.
// Runs parallel validation of status, netcheck, and netmap for all clients to ensure
// they have proper network configuration for all-to-all connectivity tests.
func assertClientsState(t *testing.T, clients []TailscaleClient) {
t.Helper()
var wg sync.WaitGroup
for _, client := range clients {
wg.Add(1)
c := client // Avoid loop pointer
go func() {
defer wg.Done()
assertValidStatus(t, c)
assertValidNetcheck(t, c)
assertValidNetmap(t, c)
}()
}
t.Logf("waiting for client state checks to finish")
wg.Wait()
}
// assertValidNetmap validates that a client's netmap has all required fields for proper operation.
// Checks self node and all peers for essential networking data including hostinfo, addresses,
// endpoints, and DERP configuration. Skips validation for Tailscale versions below 1.56.
// This test is not suitable for ACL/partial connection tests.
func assertValidNetmap(t *testing.T, client TailscaleClient) {
t.Helper()
if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
return
}
t.Logf("Checking netmap of %q", client.Hostname())
netmap, err := client.Netmap()
if err != nil {
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
}
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
for _, peer := range netmap.Peers {
assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
// Netinfo is not always set
// assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
if ni := hi.NetInfo(); ni.Valid() {
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
}
}
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
}
}
// assertValidStatus validates that a client's status has all required fields for proper operation.
// Checks self and peer status for essential data including hostinfo, tailscale IPs, endpoints,
// and network map presence. This test is not suitable for ACL/partial connection tests.
func assertValidStatus(t *testing.T, client TailscaleClient) {
t.Helper()
status, err := client.Status(true)
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
}
assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
// This seems to not appear until version 1.56
if status.Self.AllowedIPs != nil {
assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
}
assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
// This isn't really relevant for Self as it won't be in its own socket/wireguard.
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
for _, peer := range status.Peer {
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
// This seems to not appear until version 1.56
if peer.AllowedIPs != nil {
assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
}
// Addrs does not seem to appear in the status from peers.
// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
// there might be some interesting stuff to test here in the future.
// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
}
}
// assertValidNetcheck validates that a client has a proper DERP relay configured.
// Ensures the client has discovered and selected a DERP server for relay functionality,
// which is essential for NAT traversal and connectivity in restricted networks.
func assertValidNetcheck(t *testing.T, client TailscaleClient) {
t.Helper()
report, err := client.Netcheck()
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
}
assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
}
// assertCommandOutputContains executes a command with exponential backoff retry until the output
// contains the expected string or timeout is reached (10 seconds).
// This implements eventual consistency patterns and should be used instead of time.Sleep
// before executing commands that depend on network state propagation.
//
// Timeout: 10 seconds with exponential backoff
// Use cases: DNS resolution, route propagation, policy updates.
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper()
_, err := backoff.Retry(t.Context(), func() (struct{}, error) {
stdout, stderr, err := c.Execute(command)
if err != nil {
return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
}
if !strings.Contains(stdout, contains) {
return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout)
}
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
assert.NoError(t, err)
}
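A hypothetical usage, replacing a sleep before a state-dependent command; the command and expected substring are placeholders, not taken from any specific test in this diff:
// Wait until a renamed peer shows up in `tailscale status` output.
assertCommandOutputContains(t, client,
	[]string{"tailscale", "status"},
	"expected-peer-hostname")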
// dockertestMaxWait returns the maximum wait time for Docker-based test operations.
// Uses longer timeouts in CI environments to account for slower resource allocation
// and higher system load during automated testing.
func dockertestMaxWait() time.Duration {
wait := 300 * time.Second //nolint
if util.IsCI() {
wait = 600 * time.Second //nolint
}
return wait
}
// didClientUseWebsocketForDERP analyzes client logs to determine if WebSocket was used for DERP.
// Searches for WebSocket connection indicators in client logs to validate
// DERP relay communication method for debugging connectivity issues.
func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool {
t.Helper()
buf := &bytes.Buffer{}
err := client.WriteLogs(buf, buf)
if err != nil {
t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err)
}
count, err := countMatchingLines(buf, func(line string) bool {
return strings.Contains(line, "websocket: connected to ")
})
if err != nil {
t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err)
}
return count > 0
}
// countMatchingLines counts lines in a reader that match the given predicate function.
// Uses optimized buffering for log analysis and provides flexible line-by-line
// filtering for log parsing and pattern matching in integration tests.
func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) {
count := 0
scanner := bufio.NewScanner(in)
{
const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB
buff := make([]byte, logBufferInitialSize)
scanner.Buffer(buff, len(buff))
scanner.Split(bufio.ScanLines)
}
for scanner.Scan() {
if predicate(scanner.Text()) {
count += 1
}
}
return count, scanner.Err()
}
// wildcard returns a wildcard alias (*) for use in policy v2 configurations.
// Provides a convenient helper for creating permissive policy rules.
func wildcard() policyv2.Alias {
return policyv2.Wildcard
}
// usernamep returns a pointer to a Username as an Alias for policy v2 configurations.
// Used in ACL rules to reference specific users in network access policies.
func usernamep(name string) policyv2.Alias {
return ptr.To(policyv2.Username(name))
}
// hostp returns a pointer to a Host as an Alias for policy v2 configurations.
// Used in ACL rules to reference specific hosts in network access policies.
func hostp(name string) policyv2.Alias {
return ptr.To(policyv2.Host(name))
}
// groupp returns a pointer to a Group as an Alias for policy v2 configurations.
// Used in ACL rules to reference user groups in network access policies.
func groupp(name string) policyv2.Alias {
return ptr.To(policyv2.Group(name))
}
// tagp returns a pointer to a Tag as an Alias for policy v2 configurations.
// Used in ACL rules to reference node tags in network access policies.
func tagp(name string) policyv2.Alias {
return ptr.To(policyv2.Tag(name))
}
// prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations.
// Converts CIDR notation to policy prefix format for network range specifications.
func prefixp(cidr string) policyv2.Alias {
prefix := netip.MustParsePrefix(cidr)
return ptr.To(policyv2.Prefix(prefix))
}
// aliasWithPorts creates an AliasWithPorts structure from an alias and port ranges.
// Combines network targets with specific port restrictions for fine-grained
// access control in policy v2 configurations.
func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts {
return policyv2.AliasWithPorts{
Alias: alias,
Ports: ports,
}
}
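These builders are meant to be composed into a policy literal. A permissive example assembled from them, where the Policy/ACL field names ("Action", "Sources", "Destinations") are assumed from how the route and ACL tests use policyv2 and are not confirmed by this diff:
// Sketch only: allow user1 and group:admins to reach everything on all ports.
policy := policyv2.Policy{
	ACLs: []policyv2.ACL{
		{
			Action:  "accept",
			Sources: []policyv2.Alias{usernamep("user1"), groupp("group:admins")},
			Destinations: []policyv2.AliasWithPorts{
				aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
			},
		},
	},
}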
// usernameOwner returns a Username as an Owner for use in TagOwners policies.
// Specifies which users can assign and manage specific tags in ACL configurations.
func usernameOwner(name string) policyv2.Owner {
return ptr.To(policyv2.Username(name))
}
// groupOwner returns a Group as an Owner for use in TagOwners policies.
// Specifies which groups can assign and manage specific tags in ACL configurations.
func groupOwner(name string) policyv2.Owner {
return ptr.To(policyv2.Group(name))
}
// usernameApprover returns a Username as an AutoApprover for subnet route policies.
// Specifies which users can automatically approve subnet route advertisements.
func usernameApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Username(name))
}
// groupApprover returns a Group as an AutoApprover for subnet route policies.
// Specifies which groups can automatically approve subnet route advertisements.
func groupApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Group(name))
}
// tagApprover returns a Tag as an AutoApprover for subnet route policies.
// Specifies which tagged nodes can automatically approve subnet route advertisements.
func tagApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Tag(name))
}
// oidcMockUser creates a MockUser for OIDC authentication testing.
// Generates consistent test user data with configurable email verification status
// for validating OIDC integration flows in headscale authentication tests.
func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
return mockoidc.MockUser{
Subject: username,
PreferredUsername: username,
Email: username + "@headscale.net",
EmailVerified: emailVerified,
}
}
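Illustrative only: the OIDC relogin tests build their mock user set from this helper and hand it to the scenario's mock OIDC setup (the consuming option is not shown here):
mockUsers := []mockoidc.MockUser{
	oidcMockUser("user1", true),
	oidcMockUser("user2", false), // exercises the unverified-email path
}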

View File

@@ -460,6 +460,12 @@ func New(
dockertestutil.DockerAllowNetworkAdministration,
)
if err != nil {
// Try to get more detailed build output
log.Printf("Docker build failed, attempting to get detailed output...")
buildOutput := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, IntegrationTestDockerFileName)
if buildOutput != "" {
return nil, fmt.Errorf("could not start headscale container: %w\n\nDetailed build output:\n%s", err, buildOutput)
}
return nil, fmt.Errorf("could not start headscale container: %w", err)
}
log.Printf("Created %s container\n", hsic.hostname)

View File

@@ -53,16 +53,16 @@ func TestEnablingRoutes(t *testing.T) {
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{tsic.WithAcceptRoutes()},
hsic.WithTestName("clienableroute"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
expectedRoutes := map[string]string{
"1": "10.0.0.0/24",
@@ -83,7 +83,7 @@ func TestEnablingRoutes(t *testing.T) {
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
var nodes []*v1.Node
// Wait for route advertisements to propagate to NodeStore
@@ -256,16 +256,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
prefp, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)
@@ -319,7 +319,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// Wait for route configuration changes after advertising routes
var nodes []*v1.Node
@@ -1341,16 +1341,16 @@ func TestSubnetRouteACL(t *testing.T) {
},
},
))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
expectedRoutes := map[string]string{
"1": "10.33.0.0/16",
@@ -1393,7 +1393,7 @@ func TestSubnetRouteACL(t *testing.T) {
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
// Wait for route advertisements to propagate to the server
var nodes []*v1.Node
@@ -1572,25 +1572,25 @@ func TestEnablingExitRoutes(t *testing.T) {
}
scenario, err := NewScenario(spec)
assertNoErrf(t, "failed to create scenario: %s", err)
require.NoErrorf(t, err, "failed to create scenario")
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{
tsic.WithExtraLoginArgs([]string{"--advertise-exit-node"}),
}, hsic.WithTestName("clienableroute"))
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
nodes, err := headscale.ListNodes()
require.NoError(t, err)
@@ -1686,16 +1686,16 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
assert.NotNil(t, headscale)
pref, err := scenario.SubnetOfNetwork("usernet1")
@@ -1833,16 +1833,16 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
assert.NotNil(t, headscale)
var user1c, user2c TailscaleClient
@@ -2247,13 +2247,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = scenario.createHeadscaleEnv(tt.withURL, tsOpts,
opts...,
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
services, err := scenario.Services("usernet1")
require.NoError(t, err)
@@ -2263,7 +2263,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
require.NoError(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
assert.NotNil(t, headscale)
// Add the Docker network route to the auto-approvers
@@ -2304,21 +2304,21 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
if tt.withURL {
u, err := routerUsernet1.LoginWithURL(headscale.GetEndpoint())
assertNoErr(t, err)
require.NoError(t, err)
body, err := doLoginURL(routerUsernet1.Hostname(), u)
assertNoErr(t, err)
require.NoError(t, err)
scenario.runHeadscaleRegister("user1", body)
} else {
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
require.NoError(t, err)
pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
assertNoErr(t, err)
require.NoError(t, err)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
assertNoErr(t, err)
require.NoError(t, err)
}
// extra creation end.
@@ -2893,13 +2893,13 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
hsic.WithACLPolicy(aclPolicy),
hsic.WithPolicyMode(types.PolicyModeDB),
)
assertNoErrHeadscaleEnv(t, err)
requireNoErrHeadscaleEnv(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
requireNoErrGetHeadscale(t, err)
// Get the router and node clients by user
routerClients, err := scenario.ListTailscaleClients(routerUser)
@@ -2944,7 +2944,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
require.NoErrorf(t, err, "failed to advertise routes: %s", err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
var routerNode, nodeNode *v1.Node
// Wait for route advertisements to propagate to NodeStore
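For context on the assert-to-require migration in the hunks above: testify's require helpers call t.FailNow() and abort the test on the first failed check, while the assert variants only record the failure and keep running. Aborting early matters for environment setup, where a nil scenario or client would otherwise cascade into panics later in the test. A minimal illustration, assuming only testify and the standard library (not code from this repository):

package example

import (
	"os"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestRequireStopsEarly(t *testing.T) {
	f, err := os.Open("does-not-exist")
	require.NoError(t, err) // fails and stops the test here via t.FailNow()
	defer f.Close()         // never reached on failure, so no nil dereference
}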

View File

@@ -838,14 +838,14 @@ func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
var err error
hc := &http.Client{
Transport: LoggingRoundTripper{},
Transport: LoggingRoundTripper{Hostname: hostname},
}
hc.Jar, err = cookiejar.New(nil)
if err != nil {
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
}
log.Printf("%s logging in with url", hostname)
log.Printf("%s logging in with url: %s", hostname, loginURL.String())
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
resp, err := hc.Do(req)
@@ -907,7 +907,9 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable)
}
type LoggingRoundTripper struct{}
type LoggingRoundTripper struct {
Hostname string
}
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
noTls := &http.Transport{
@@ -918,9 +920,12 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
return nil, err
}
log.Printf("---")
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
log.Printf(`
---
%s - method: %s | url: %s
%s - status: %d | cookies: %+v
---
`, t.Hostname, req.Method, req.URL.String(), t.Hostname, resp.StatusCode, resp.Cookies())
return resp, nil
}
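The Hostname field added above lets each client's login traffic be told apart in the shared test logs. A minimal self-contained sketch of the logging-transport pattern, assuming only the standard library (the repository's version also relaxes TLS verification for the test containers, which is omitted here):

package example

import (
	"log"
	"net/http"
)

// loggingRoundTripper tags every request/response line with a hostname so
// interleaved logs from multiple test clients stay readable.
type loggingRoundTripper struct {
	hostname string
	next     http.RoundTripper
}

func (l loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := l.next.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	log.Printf("%s - method: %s | url: %s | status: %d", l.hostname, req.Method, req.URL, resp.StatusCode)
	return resp, nil
}

Wired up as the Transport of an http.Client, e.g. &http.Client{Transport: loggingRoundTripper{hostname: "ts-1", next: http.DefaultTransport}}.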

View File

@@ -5,6 +5,7 @@ import (
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/require"
)
// This file is intended to "test the test framework", by proxy it will also test
@@ -34,7 +35,7 @@ func TestHeadscale(t *testing.T) {
user := "test-space"
scenario, err := NewScenario(ScenarioSpec{})
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
t.Run("start-headscale", func(t *testing.T) {
@@ -82,7 +83,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
count := 1
scenario, err := NewScenario(ScenarioSpec{})
assertNoErr(t, err)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
t.Run("start-headscale", func(t *testing.T) {

View File

@@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
@@ -30,7 +31,7 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
assertNoErr(t, err)
require.NoError(t, err)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{
@@ -50,13 +51,13 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce
hsic.WithACLPolicy(policy),
hsic.WithTestName("ssh"),
)
assertNoErr(t, err)
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErr(t, err)
require.NoError(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErr(t, err)
require.NoError(t, err)
return scenario
}
@@ -93,19 +94,19 @@ func TestSSHOneUserToAll(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
for _, client := range user1Clients {
for _, peer := range allClients {
@@ -160,16 +161,16 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t)
nsOneClients, err := scenario.ListTailscaleClients("user1")
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
nsTwoClients, err := scenario.ListTailscaleClients("user2")
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
testInterUserSSH := func(sourceClients []TailscaleClient, targetClients []TailscaleClient) {
for _, client := range sourceClients {
@@ -208,13 +209,13 @@ func TestSSHNoSSHConfigured(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
for _, client := range allClients {
for _, peer := range allClients {
@@ -259,13 +260,13 @@ func TestSSHIsBlockedInACL(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
for _, client := range allClients {
for _, peer := range allClients {
@@ -317,16 +318,16 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t)
ssh1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
ssh2Clients, err := scenario.ListTailscaleClients("user2")
assertNoErrListClients(t, err)
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
requireNoErrListFQDN(t, err)
for _, client := range ssh1Clients {
for _, peer := range ssh2Clients {
@@ -422,9 +423,9 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien
t.Helper()
result, _, err := doSSH(t, client, peer)
assertNoErr(t, err)
require.NoError(t, err)
assertContains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", ""))
require.Contains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", ""))
}
func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) {

View File

@@ -322,6 +322,20 @@ func New(
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
if err != nil {
// Try to get more detailed build output
log.Printf("Docker build failed for %s, attempting to get detailed output...", hostname)
buildOutput := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, "Dockerfile.tailscale-HEAD")
if buildOutput != "" {
return nil, fmt.Errorf(
"%s could not start tailscale container (version: %s): %w\n\nDetailed build output:\n%s",
hostname,
version,
err,
buildOutput,
)
}
}
case "unstable":
tailscaleOptions.Repository = "tailscale/tailscale"
tailscaleOptions.Tag = version
@@ -333,6 +347,9 @@ func New(
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
if err != nil {
log.Printf("Docker run failed for %s (unstable), error: %v", hostname, err)
}
default:
tailscaleOptions.Repository = "tailscale/tailscale"
tailscaleOptions.Tag = "v" + version
@@ -344,6 +361,9 @@ func New(
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
if err != nil {
log.Printf("Docker run failed for %s (version: v%s), error: %v", hostname, version, err)
}
}
if err != nil {

View File

@@ -1,533 +0,0 @@
package integration
import (
"bufio"
"bytes"
"fmt"
"io"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/cenkalti/backoff/v5"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
const (
// derpPingTimeout defines the timeout for individual DERP ping operations
// Used in DERP connectivity tests to verify relay server communication.
derpPingTimeout = 2 * time.Second
// derpPingCount defines the number of ping attempts for DERP connectivity tests
// Higher count provides better reliability assessment of DERP connectivity.
derpPingCount = 10
// TimestampFormat is the standard timestamp format used across all integration tests
// Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps
// suitable for debugging and log correlation in integration tests.
TimestampFormat = "2006-01-02T15-04-05.999999999"
// TimestampFormatRunID is used for generating unique run identifiers
// Format: "20060102-150405" provides compact date-time for file/directory names.
TimestampFormatRunID = "20060102-150405"
)
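Both layouts follow Go's reference-time convention (Mon Jan 2 15:04:05 2006), and the trailing 9s in TimestampFormat drop insignificant zeros. A quick illustrative example (the values are hypothetical):

func exampleTimestamps() {
	ts := time.Date(2025, 10, 16, 12, 17, 43, 500_000_000, time.UTC)
	_ = ts.Format(TimestampFormat)      // "2025-10-16T12-17-43.5"
	_ = ts.Format(TimestampFormatRunID) // "20251016-121743"
}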
func assertNoErr(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "unexpected error: %s", err)
}
func assertNoErrf(t *testing.T, msg string, err error) {
t.Helper()
if err != nil {
t.Fatalf(msg, err)
}
}
func assertNotNil(t *testing.T, thing interface{}) {
t.Helper()
if thing == nil {
t.Fatal("got unexpected nil")
}
}
func assertNoErrHeadscaleEnv(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "failed to create headscale environment: %s", err)
}
func assertNoErrGetHeadscale(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "failed to get headscale: %s", err)
}
func assertNoErrListClients(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "failed to list clients: %s", err)
}
func assertNoErrListClientIPs(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "failed to get client IPs: %s", err)
}
func assertNoErrSync(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "failed to have all clients sync up: %s", err)
}
func assertNoErrListFQDN(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "failed to list FQDNs: %s", err)
}
func assertNoErrLogout(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "failed to log out tailscale nodes: %s", err)
}
func assertContains(t *testing.T, str, subStr string) {
t.Helper()
if !strings.Contains(str, subStr) {
t.Fatalf("%#v does not contain %#v", str, subStr)
}
}
func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool {
t.Helper()
buf := &bytes.Buffer{}
err := client.WriteLogs(buf, buf)
if err != nil {
t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err)
}
count, err := countMatchingLines(buf, func(line string) bool {
return strings.Contains(line, "websocket: connected to ")
})
if err != nil {
t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err)
}
return count > 0
}
// pingAllHelper performs ping tests between all clients and addresses, returning success count.
// This is used to validate network connectivity in integration tests.
// Returns the total number of successful ping operations.
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int {
t.Helper()
success := 0
for _, client := range clients {
for _, addr := range addrs {
err := client.Ping(addr, opts...)
if err != nil {
t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
} else {
success++
}
}
}
return success
}
// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses.
// This specifically tests connectivity through DERP relay servers, which is important
// for validating NAT traversal and relay functionality. Returns success count.
func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
t.Helper()
success := 0
for _, client := range clients {
for _, addr := range addrs {
if isSelfClient(client, addr) {
continue
}
err := client.Ping(
addr,
tsic.WithPingTimeout(derpPingTimeout),
tsic.WithPingCount(derpPingCount),
tsic.WithPingUntilDirect(false),
)
if err != nil {
t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
} else {
success++
}
}
}
return success
}
// assertClientsState validates the status and netmap of a list of
// clients for the general case of all to all connectivity.
func assertClientsState(t *testing.T, clients []TailscaleClient) {
t.Helper()
var wg sync.WaitGroup
for _, client := range clients {
wg.Add(1)
c := client // Avoid loop pointer
go func() {
defer wg.Done()
assertValidStatus(t, c)
assertValidNetcheck(t, c)
assertValidNetmap(t, c)
}()
}
t.Logf("waiting for client state checks to finish")
wg.Wait()
}
// assertValidNetmap asserts that the netmap of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
// This test can only be run on clients from 1.56.1. It will
// automatically pass all clients below that and is safe to call
// for all versions.
func assertValidNetmap(t *testing.T, client TailscaleClient) {
t.Helper()
if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
return
}
t.Logf("Checking netmap of %q", client.Hostname())
netmap, err := client.Netmap()
if err != nil {
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
}
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
for _, peer := range netmap.Peers {
assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
// Netinfo is not always set
// assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
if ni := hi.NetInfo(); ni.Valid() {
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
}
}
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
}
}
// assertValidStatus asserts that the status of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
func assertValidStatus(t *testing.T, client TailscaleClient) {
t.Helper()
status, err := client.Status(true)
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
}
assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
// This does not seem to appear until version 1.56
if status.Self.AllowedIPs != nil {
assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
}
assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
// This isn't really relevant for Self as it won't be in its own socket/wireguard.
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
for _, peer := range status.Peer {
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
// This does not seem to appear until version 1.56
if peer.AllowedIPs != nil {
assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
}
// Addrs does not seem to appear in the status from peers.
// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
// there might be some interesting stuff to test here in the future.
// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
}
}
func assertValidNetcheck(t *testing.T, client TailscaleClient) {
t.Helper()
report, err := client.Netcheck()
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
}
assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
}
// assertCommandOutputContains executes a command with exponential backoff retry until the output
// contains the expected string or timeout is reached (10 seconds).
// This implements eventual consistency patterns and should be used instead of time.Sleep
// before executing commands that depend on network state propagation.
//
// Timeout: 10 seconds with exponential backoff
// Use cases: DNS resolution, route propagation, policy updates.
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper()
_, err := backoff.Retry(t.Context(), func() (struct{}, error) {
stdout, stderr, err := c.Execute(command)
if err != nil {
return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
}
if !strings.Contains(stdout, contains) {
return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout)
}
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
assert.NoError(t, err)
}
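A hypothetical call site, following the doc comment's use cases (the command and expected substring are illustrative only, not taken from the tests):

// Wait until the peer shows up in this client's status output instead of sleeping.
assertCommandOutputContains(t, client, []string{"tailscale", "status"}, "peer-hostname")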
func isSelfClient(client TailscaleClient, addr string) bool {
if addr == client.Hostname() {
return true
}
ips, err := client.IPs()
if err != nil {
return false
}
for _, ip := range ips {
if ip.String() == addr {
return true
}
}
return false
}
func dockertestMaxWait() time.Duration {
wait := 300 * time.Second //nolint
if util.IsCI() {
wait = 600 * time.Second //nolint
}
return wait
}
func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) {
count := 0
scanner := bufio.NewScanner(in)
{
const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB
buff := make([]byte, logBufferInitialSize)
scanner.Buffer(buff, len(buff))
scanner.Split(bufio.ScanLines)
}
for scanner.Scan() {
if predicate(scanner.Text()) {
count += 1
}
}
return count, scanner.Err()
}
// func dockertestCommandTimeout() time.Duration {
// timeout := 10 * time.Second //nolint
//
// if isCI() {
// timeout = 60 * time.Second //nolint
// }
//
// return timeout
// }
// pingAllNegativeHelper is intended to have one or more nodes timing out from the ping;
// it counts failures instead of successes.
// func pingAllNegativeHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
// t.Helper()
// failures := 0
//
// timeout := 100
// count := 3
//
// for _, client := range clients {
// for _, addr := range addrs {
// err := client.Ping(
// addr,
// tsic.WithPingTimeout(time.Duration(timeout)*time.Millisecond),
// tsic.WithPingCount(count),
// )
// if err != nil {
// failures++
// }
// }
// }
//
// return failures
// }
// // findPeerByIP takes an IP and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus
// // if there is a peer with the given IP. If no peer is found, nil is returned.
// func findPeerByIP(
// ip netip.Addr,
// peers map[key.NodePublic]*ipnstate.PeerStatus,
// ) *ipnstate.PeerStatus {
// for _, peer := range peers {
// for _, peerIP := range peer.TailscaleIPs {
// if ip == peerIP {
// return peer
// }
// }
// }
//
// return nil
// }
// Helper functions for creating typed policy entities
// wildcard returns a wildcard alias (*).
func wildcard() policyv2.Alias {
return policyv2.Wildcard
}
// usernamep returns a pointer to a Username as an Alias.
func usernamep(name string) policyv2.Alias {
return ptr.To(policyv2.Username(name))
}
// hostp returns a pointer to a Host.
func hostp(name string) policyv2.Alias {
return ptr.To(policyv2.Host(name))
}
// groupp returns a pointer to a Group as an Alias.
func groupp(name string) policyv2.Alias {
return ptr.To(policyv2.Group(name))
}
// tagp returns a pointer to a Tag as an Alias.
func tagp(name string) policyv2.Alias {
return ptr.To(policyv2.Tag(name))
}
// prefixp returns a pointer to a Prefix from a CIDR string.
func prefixp(cidr string) policyv2.Alias {
prefix := netip.MustParsePrefix(cidr)
return ptr.To(policyv2.Prefix(prefix))
}
// aliasWithPorts creates an AliasWithPorts structure from an alias and ports.
func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts {
return policyv2.AliasWithPorts{
Alias: alias,
Ports: ports,
}
}
// usernameOwner returns a Username as an Owner for use in TagOwners.
func usernameOwner(name string) policyv2.Owner {
return ptr.To(policyv2.Username(name))
}
// groupOwner returns a Group as an Owner for use in TagOwners.
func groupOwner(name string) policyv2.Owner {
return ptr.To(policyv2.Group(name))
}
// usernameApprover returns a Username as an AutoApprover.
func usernameApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Username(name))
}
// groupApprover returns a Group as an AutoApprover.
func groupApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Group(name))
}
// tagApprover returns a Tag as an AutoApprover.
func tagApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Tag(name))
}
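These constructors keep the typed policy literals in the tests compact. A hypothetical composition (the names are illustrative, and the surrounding policy struct fields are omitted since they depend on policyv2):

// user1 may reach any port on nodes tagged tag:router; the tag is group-approved.
src := usernamep("user1@")
dst := aliasWithPorts(tagp("tag:router"), tailcfg.PortRangeAny)
approver := groupApprover("group:admins")
_, _, _ = src, dst, approver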
//
// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus
// // if there is a peer with the given hostname. If no peer is found, nil is returned.
// func findPeerByHostname(
// hostname string,
// peers map[key.NodePublic]*ipnstate.PeerStatus,
// ) *ipnstate.PeerStatus {
// for _, peer := range peers {
// if hostname == peer.HostName {
// return peer
// }
// }
//
// return nil
// }