Mirror of https://github.com/juanfont/headscale.git (synced 2025-10-26 02:33:51 +09:00)
stability and race conditions in auth and node store (#2781)
This PR addresses some consistency issues that were introduced or discovered with the nodestore.

- nodestore: now returns the node that is being put or updated once the operation has finished. This closes a race condition where reading the node back did not necessarily return it with the given change applied, and it ensures we also see all other updates from that batch write.
- auth: authentication paths have been unified and simplified. This removes a lot of bad branches and ensures we only do the minimal work. A comprehensive auth test set has been created so we do not have to run integration tests to validate auth, and it has allowed us to generate test cases for all the branches we currently know of.
- integration: added a lot more tooling and checks to validate that nodes reach the expected state when they come up and down, and standardised behaviour between the different auth models. Much of this exists to support or detect issues in the changes to the nodestore (races) and auth (inconsistencies after login and reaching the correct state).

This PR was assisted, particularly on tests, by Claude Code.
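A minimal sketch of the write-then-use pattern the nodestore change enables. The `UpdateNode` signature is taken from the diff below; the surrounding call site is illustrative only:

```go
// Hypothetical call site: the caller no longer writes, then reads back and
// hopes the batch has been applied; the post-batch view is returned directly.
updated, ok := store.UpdateNode(nodeID, func(n *types.Node) {
	n.Hostname = "renamed"
})
if !ok {
	// The node was deleted in the same batch (e.g. an ephemeral logout);
	// do not act on or persist a stale view.
	return
}
// `updated` reflects every change in the batch, not just ours.
_ = updated
```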
Changed: .github/workflows/test-integration.yaml (vendored, 4 lines changed)
@@ -31,9 +31,11 @@ jobs:
- TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser
- TestOIDCReloginSameNodeSameUser
- TestOIDCFollowUpUrl
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin
- TestAuthWebFlowLogoutAndReloginSameUser
- TestAuthWebFlowLogoutAndReloginNewUser
- TestUserCommand
- TestPreAuthKeyCommand
- TestPreAuthKeyCommandWithoutExpiry
@@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
@@ -25,26 +26,84 @@ type AuthProvider interface {

func (h *Headscale) handleRegister(
ctx context.Context,
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey)
// Check for logout/expiry FIRST, before checking auth key.
// Tailscale clients may send logout requests with BOTH a past expiry AND an auth key.
// A past expiry takes precedence - it's a logout regardless of other fields.
if !req.Expiry.IsZero() && req.Expiry.Before(time.Now()) {
log.Debug().
Str("node.key", req.NodeKey.ShortString()).
Time("expiry", req.Expiry).
Bool("has_auth", req.Auth != nil).
Msg("Detected logout attempt with past expiry")

if ok {
resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
// This is a logout attempt (expiry in the past)
if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
log.Debug().
Uint64("node.id", node.ID().Uint64()).
Str("node.name", node.Hostname()).
Bool("is_ephemeral", node.IsEphemeral()).
Bool("has_authkey", node.AuthKey().Valid()).
Msg("Found existing node for logout, calling handleLogout")

resp, err := h.handleLogout(node, req, machineKey)
if err != nil {
return nil, fmt.Errorf("handling logout: %w", err)
}
if resp != nil {
return resp, nil
}
} else {
log.Warn().
Str("node.key", req.NodeKey.ShortString()).
Msg("Logout attempt but node not found in NodeStore")
}

return resp, nil
}

if regReq.Followup != "" {
return h.waitForFollowup(ctx, regReq, machineKey)
// If the register request does not contain a Auth struct, it means we are logging
// out an existing node (legacy logout path for clients that send Auth=nil).
if req.Auth == nil {
// If the register request present a NodeKey that is currently in use, we will
// check if the node needs to be sent to re-auth, or if the node is logging out.
// We do not look up nodes by [key.MachinePublic] as it might belong to multiple
// nodes, separated by users and this path is handling expiring/logout paths.
if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
resp, err := h.handleLogout(node, req, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
}

// If resp is not nil, we have a response to return to the node.
// If resp is nil, we should proceed and see if the node is trying to re-auth.
if resp != nil {
return resp, nil
}
} else {
// If the register request is not attempting to register a node, and
// we cannot match it with an existing node, we consider that unexpected
// as only register nodes should attempt to log out.
log.Debug().
Str("node.key", req.NodeKey.ShortString()).
Str("machine.key", machineKey.ShortString()).
Bool("unexpected", true).
Msg("received register request with no auth, and no existing node")
}
}

if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
// If the [tailcfg.RegisterRequest] has a Followup URL, it means that the
// node has already started the registration process and we should wait for
// it to finish the original registration.
if req.Followup != "" {
return h.waitForFollowup(ctx, req, machineKey)
}

// Pre authenticated keys are handled slightly different than interactive
// logins as they can be done fully sync and we can respond to the node with
// the result as it is waiting.
if isAuthKey(req) {
resp, err := h.handleRegisterWithAuthKey(req, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
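The dispatch order above matters because a client logout can carry both a past expiry and an auth key. A hedged, self-contained illustration of that precedence rule (the helper below is not part of the diff; it only restates the branch order of the new handleRegister, and the request construction uses the tailcfg field names visible in the diff):

```go
package main

import (
	"fmt"
	"time"

	"tailscale.com/tailcfg"
)

// classify restates the dispatch order of the new handleRegister:
// past expiry first, then follow-up, then auth key, then interactive.
// Illustration only, not headscale code.
func classify(req tailcfg.RegisterRequest) string {
	switch {
	case !req.Expiry.IsZero() && req.Expiry.Before(time.Now()):
		return "logout" // wins even if req.Auth carries an auth key
	case req.Followup != "":
		return "followup"
	case req.Auth != nil && req.Auth.AuthKey != "":
		return "authkey"
	default:
		return "interactive"
	}
}

func main() {
	req := tailcfg.RegisterRequest{
		Expiry: time.Now().Add(-time.Hour),
		Auth:   &tailcfg.RegisterResponseAuth{AuthKey: "preauthkey"},
	}
	fmt.Println(classify(req)) // prints "logout"
}
```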
@@ -58,7 +117,7 @@ func (h *Headscale) handleRegister(
return resp, nil
}

resp, err := h.handleRegisterInteractive(regReq, machineKey)
resp, err := h.handleRegisterInteractive(req, machineKey)
if err != nil {
return nil, fmt.Errorf("handling register interactive: %w", err)
}
@@ -66,20 +125,34 @@ func (h *Headscale) handleRegister(
return resp, nil
}

func (h *Headscale) handleExistingNode(
node *types.Node,
regReq tailcfg.RegisterRequest,
// handleLogout checks if the [tailcfg.RegisterRequest] is a
// logout attempt from a node. If the node is not attempting to
func (h *Headscale) handleLogout(
node types.NodeView,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
if node.MachineKey != machineKey {
// Fail closed if it looks like this is an attempt to modify a node where
// the node key and the machine key the noise session was started with does
// not align.
if node.MachineKey() != machineKey {
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
}

expired := node.IsExpired()
// Note: We do NOT return early if req.Auth is set, because Tailscale clients
// may send logout requests with BOTH a past expiry AND an auth key.
// A past expiry indicates logout, regardless of whether Auth is present.
// The expiry check below will handle the logout logic.

// If the node is expired and this is not a re-authentication attempt,
// force the client to re-authenticate
if expired && regReq.Auth == nil {
// force the client to re-authenticate.
// TODO(kradalby): I wonder if this is a path we ever hit?
if node.IsExpired() {
log.Trace().Str("node.name", node.Hostname()).
Uint64("node.id", node.ID().Uint64()).
Interface("reg.req", req).
Bool("unexpected", true).
Msg("Node key expired, forcing re-authentication")
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
@@ -87,49 +160,76 @@ func (h *Headscale) handleExistingNode(
}, nil
}

if !expired && !regReq.Expiry.IsZero() {
requestExpiry := regReq.Expiry
// If we get here, the node is not currently expired, and not trying to
// do an auth.
// The node is likely logging out, but before we run that logic, we will validate
// that the node is not attempting to tamper/extend their expiry.
// If it is not, we will expire the node or in the case of an ephemeral node, delete it.

// The client is trying to extend their key, this is not allowed.
if requestExpiry.After(time.Now()) {
return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
}

// If the request expiry is in the past, we consider it a logout.
if requestExpiry.Before(time.Now()) {
if node.IsEphemeral() {
c, err := h.state.DeleteNode(node.View())
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}

h.Change(c)

return nil, nil
}
}

updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}

h.Change(c)

// CRITICAL: Use the updated node view for the response
// The original node object has stale expiry information
node = updatedNode.AsStruct()
// The client is trying to extend their key, this is not allowed.
if req.Expiry.After(time.Now()) {
return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
}

return nodeToRegisterResponse(node), nil
// If the request expiry is in the past, we consider it a logout.
if req.Expiry.Before(time.Now()) {
log.Debug().
Uint64("node.id", node.ID().Uint64()).
Str("node.name", node.Hostname()).
Bool("is_ephemeral", node.IsEphemeral()).
Bool("has_authkey", node.AuthKey().Valid()).
Time("req.expiry", req.Expiry).
Msg("Processing logout request with past expiry")

if node.IsEphemeral() {
log.Info().
Uint64("node.id", node.ID().Uint64()).
Str("node.name", node.Hostname()).
Msg("Deleting ephemeral node during logout")

c, err := h.state.DeleteNode(node)
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}

h.Change(c)

return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
}, nil
}

log.Debug().
Uint64("node.id", node.ID().Uint64()).
Str("node.name", node.Hostname()).
Msg("Node is not ephemeral, setting expiry instead of deleting")
}

// Update the internal state with the nodes new expiry, meaning it is
// logged out.
updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), req.Expiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}

h.Change(c)

return nodeToRegisterResponse(updatedNode), nil
}

func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
// isAuthKey reports if the register request is a registration request
// using an pre auth key.
func isAuthKey(req tailcfg.RegisterRequest) bool {
return req.Auth != nil && req.Auth.AuthKey != ""
}

func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse {
return &tailcfg.RegisterResponse{
// TODO(kradalby): Only send for user-owned nodes
// and not tagged nodes when tags is working.
User: *node.User.TailscaleUser(),
Login: *node.User.TailscaleLogin(),
User: node.UserView().TailscaleUser(),
Login: node.UserView().TailscaleLogin(),
NodeKeyExpired: node.IsExpired(),

// Headscale does not implement the concept of machine authorization
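Condensed, the two logout outcomes implemented in handleLogout are the following (a sketch inside handleLogout, using only signatures visible in this diff; error handling elided):

```go
// Sketch: what handleLogout does once the request expiry is in the past.
if node.IsEphemeral() {
	// Ephemeral nodes are removed entirely; the client is told the key expired.
	c, _ := h.state.DeleteNode(node)
	h.Change(c)
	return &tailcfg.RegisterResponse{NodeKeyExpired: true, MachineAuthorized: false}, nil
}

// Persistent nodes are expired instead. The response must be built from the
// view returned by SetNodeExpiry; the `node` captured earlier is stale.
updatedNode, c, _ := h.state.SetNodeExpiry(node.ID(), req.Expiry)
h.Change(c)
return nodeToRegisterResponse(updatedNode), nil
```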
@@ -141,10 +241,10 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {

func (h *Headscale) waitForFollowup(
ctx context.Context,
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
fu, err := url.Parse(regReq.Followup)
fu, err := url.Parse(req.Followup)
if err != nil {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
}
@@ -161,21 +261,21 @@ func (h *Headscale) waitForFollowup(
case node := <-reg.Registered:
if node == nil {
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node), nil
return nodeToRegisterResponse(node.View()), nil
}
}

// if the follow-up registration isn't found anymore, instruct the client to try a new registration
return h.reqToNewRegisterResponse(regReq, machineKey)
return h.reqToNewRegisterResponse(req, machineKey)
}

// reqToNewRegisterResponse refreshes the registration flow by creating a new
// registration ID and returning the corresponding AuthURL so the client can
// restart the authentication process.
func (h *Headscale) reqToNewRegisterResponse(
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
@@ -183,18 +283,25 @@ func (h *Headscale) reqToNewRegisterResponse(
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}

// Ensure we have valid hostinfo and hostname
validHostinfo, hostname := util.EnsureValidHostinfo(
req.Hostinfo,
machineKey.String(),
req.NodeKey.String(),
)

nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
Hostname: hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
NodeKey: req.NodeKey,
Hostinfo: validHostinfo,
LastSeen: ptr.To(time.Now()),
},
)

if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
}

log.Info().Msgf("New followup node registration using key: %s", newRegID)
@@ -206,11 +313,11 @@ func (h *Headscale) reqToNewRegisterResponse(
}

func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq,
req,
machineKey,
)
if err != nil {
@@ -262,18 +369,26 @@ func (h *Headscale) handleRegisterWithAuthKey(
// h.Change(policyChange)
// }

user := node.User()

return &tailcfg.RegisterResponse{
resp := &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),
User: *user.TailscaleUser(),
Login: *user.TailscaleLogin(),
}, nil
User: node.UserView().TailscaleUser(),
Login: node.UserView().TailscaleLogin(),
}

log.Trace().
Caller().
Interface("reg.resp", resp).
Interface("reg.req", req).
Str("node.name", node.Hostname()).
Uint64("node.id", node.ID().Uint64()).
Msg("RegisterResponse")

return resp, nil
}

func (h *Headscale) handleRegisterInteractive(
regReq tailcfg.RegisterRequest,
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
registrationId, err := types.NewRegistrationID()
@@ -281,18 +396,39 @@ func (h *Headscale) handleRegisterInteractive(
return nil, fmt.Errorf("generating registration ID: %w", err)
}

// Ensure we have valid hostinfo and hostname
validHostinfo, hostname := util.EnsureValidHostinfo(
req.Hostinfo,
machineKey.String(),
req.NodeKey.String(),
)

if req.Hostinfo == nil {
log.Warn().
Str("machine.key", machineKey.ShortString()).
Str("node.key", req.NodeKey.ShortString()).
Str("generated.hostname", hostname).
Msg("Received registration request with nil hostinfo, generated default hostname")
} else if req.Hostinfo.Hostname == "" {
log.Warn().
Str("machine.key", machineKey.ShortString()).
Str("node.key", req.NodeKey.ShortString()).
Str("generated.hostname", hostname).
Msg("Received registration request with empty hostname, generated default")
}

nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: regReq.Hostinfo.Hostname,
Hostname: hostname,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
Hostinfo: regReq.Hostinfo,
NodeKey: req.NodeKey,
Hostinfo: validHostinfo,
LastSeen: ptr.To(time.Now()),
},
)

if !regReq.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &regReq.Expiry
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
}

h.state.SetRegistrationCacheEntry(
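Both registration paths now normalise the hostinfo before building the node. The real helper is util.EnsureValidHostinfo from hscontrol/util; the sketch below is purely illustrative, inferred only from the call sites and log messages above, and the fallback naming scheme is an assumption:

```go
// Illustrative only: the behaviour the call sites above rely on, not the
// actual implementation of util.EnsureValidHostinfo.
func ensureValidHostinfo(hi *tailcfg.Hostinfo, machineKey, nodeKey string) (*tailcfg.Hostinfo, string) {
	if hi == nil {
		hi = &tailcfg.Hostinfo{} // nil hostinfo: start from an empty one
	}
	hostname := hi.Hostname
	if hostname == "" {
		// Assumption: derive some deterministic default from the keys so the
		// node still gets a usable name.
		hostname = "node-" + nodeKey[:min(8, len(nodeKey))]
		hi.Hostname = hostname
	}
	return hi, hostname
}
```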
New file: hscontrol/auth_test.go (3006 lines). File diff suppressed because it is too large.
@@ -741,7 +741,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
hostinfo := tailcfg.Hostinfo{
RoutableIPs: routes,
OS: "TestOS",
Hostname: "DebugTestNode",
Hostname: request.GetName(),
}

registrationId, err := types.RegistrationIDFromString(request.GetKey())
@@ -197,11 +197,12 @@ func (m *mapSession) serveLongPoll() {
m.keepAliveTicker = time.NewTicker(m.keepAlive)

// Process the initial MapRequest to update node state (endpoints, hostinfo, etc.)
// CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly
// synchronized. When nodes reconnect, they send their hostinfo with announced routes
// in the MapRequest. We need this data in NodeStore before Connect() sets up the
// primary routes, otherwise SubnetRoutes() returns empty and the node is removed
// from AvailableRoutes.
// This must be done BEFORE calling Connect() to ensure routes are properly synchronized.
// When nodes reconnect, they send their hostinfo with announced routes in the MapRequest.
// We need this data in NodeStore before Connect() sets up the primary routes, because
// SubnetRoutes() calculates the intersection of announced and approved routes. If we
// call Connect() first, SubnetRoutes() returns empty (no announced routes yet), causing
// the node to be incorrectly removed from AvailableRoutes.
mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
m.errf(err, "failed to update node from initial MapRequest")
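A condensed sketch of the ordering constraint described in the comment above. Only UpdateNodeFromMapRequest and errf appear in this diff; the Connect() step is referenced by the comment and its exact signature is assumed:

```go
// 1. Fold the MapRequest (endpoints, hostinfo, announced routes) into the
//    NodeStore first, so the snapshot already contains the announced routes.
mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
	m.errf(err, "failed to update node from initial MapRequest")
}

// 2. Only afterwards run the Connect() step the comment refers to (signature
//    assumed). Primary-route selection then computes
//    SubnetRoutes() = announced ∩ approved against fresh data, so the node
//    is not dropped from AvailableRoutes on reconnect.
_ = mapReqChange
```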
@@ -60,9 +60,6 @@ type DebugStringInfo struct {

// DebugOverview returns a comprehensive overview of the current state for debugging.
func (s *State) DebugOverview() string {
s.mu.RLock()
defer s.mu.RUnlock()

allNodes := s.nodeStore.ListNodes()
users, _ := s.ListAllUsers()

@@ -270,9 +267,6 @@ func (s *State) PolicyDebugString() string {

// DebugOverviewJSON returns a structured overview of the current state for debugging.
func (s *State) DebugOverviewJSON() DebugOverviewInfo {
s.mu.RLock()
defer s.mu.RUnlock()

allNodes := s.nodeStore.ListNodes()
users, _ := s.ListAllUsers()
@@ -33,8 +33,8 @@ func TestNodeStoreDebugString(t *testing.T) {
store := NewNodeStore(nil, allowAllPeersFunc)
store.Start()

store.PutNode(node1)
store.PutNode(node2)
_ = store.PutNode(node1)
_ = store.PutNode(node2)

return store
},
New file: hscontrol/state/ephemeral_test.go (460 lines)
@@ -0,0 +1,460 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// TestEphemeralNodeDeleteWithConcurrentUpdate tests the race condition where UpdateNode and DeleteNode
|
||||
// are called concurrently and may be batched together. This reproduces the issue where ephemeral nodes
|
||||
// are not properly deleted during logout because UpdateNodeFromMapRequest returns a stale node view
|
||||
// after the node has been deleted from the NodeStore.
|
||||
func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
|
||||
// Create a simple test node
|
||||
node := createTestNode(1, 1, "test-user", "test-node")
|
||||
|
||||
// Create NodeStore
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
// Put the node in the store
|
||||
resultNode := store.PutNode(node)
|
||||
require.True(t, resultNode.Valid(), "initial PutNode should return valid node")
|
||||
|
||||
// Verify node exists
|
||||
retrievedNode, found := store.GetNode(node.ID)
|
||||
require.True(t, found)
|
||||
require.Equal(t, node.ID, retrievedNode.ID())
|
||||
|
||||
// Test scenario: UpdateNode is called, returns a node view from the batch,
|
||||
// but in the same batch a DeleteNode removes the node.
|
||||
// This simulates what happens when:
|
||||
// 1. UpdateNodeFromMapRequest calls UpdateNode and gets back updatedNode
|
||||
// 2. At the same time, handleLogout calls DeleteNode
|
||||
// 3. They get batched together: [UPDATE, DELETE]
|
||||
// 4. UPDATE modifies the node, DELETE removes it
|
||||
// 5. UpdateNode returns a node view based on the state AFTER both operations
|
||||
// 6. If DELETE came after UPDATE, the returned node should be invalid
|
||||
|
||||
done := make(chan bool, 2)
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
|
||||
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
|
||||
go func() {
|
||||
updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
|
||||
go func() {
|
||||
// Small delay to increase chance of batching together
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
store.DeleteNode(node.ID)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Wait for both operations
|
||||
<-done
|
||||
<-done
|
||||
|
||||
// Give batching time to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// The key assertion: if UpdateNode and DeleteNode were batched together
|
||||
// with DELETE after UPDATE, then UpdateNode should return an invalid node
|
||||
// OR it should return a valid node but the node should no longer exist in the store
|
||||
|
||||
_, found = store.GetNode(node.ID)
|
||||
assert.False(t, found, "node should be deleted from NodeStore")
|
||||
|
||||
// If the update happened before delete in the batch, the returned node might be invalid
|
||||
if updateOk {
|
||||
t.Logf("UpdateNode returned ok=true, valid=%v", updatedNode.Valid())
|
||||
// This is the bug scenario - UpdateNode thinks it succeeded but node is gone
|
||||
if updatedNode.Valid() {
|
||||
t.Logf("WARNING: UpdateNode returned valid node but node was deleted - this indicates the race condition bug")
|
||||
}
|
||||
} else {
|
||||
t.Logf("UpdateNode correctly returned ok=false (node deleted in same batch)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch specifically tests that when
|
||||
// UpdateNode and DeleteNode are in the same batch with DELETE after UPDATE,
|
||||
// the UpdateNode should return an invalid node view.
|
||||
func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
|
||||
node := createTestNode(2, 1, "test-user", "test-node-2")
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
// Put node in store
|
||||
_ = store.PutNode(node)
|
||||
|
||||
// Simulate the exact sequence: UpdateNode gets queued, then DeleteNode gets queued,
|
||||
// they batch together, and we check what UpdateNode returns
|
||||
|
||||
resultChan := make(chan struct {
|
||||
node types.NodeView
|
||||
ok bool
|
||||
})
|
||||
|
||||
// Start UpdateNode - it will block until batch is applied
|
||||
go func() {
|
||||
node, ok := store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
resultChan <- struct {
|
||||
node types.NodeView
|
||||
ok bool
|
||||
}{node, ok}
|
||||
}()
|
||||
|
||||
// Give UpdateNode a moment to queue its work
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
// Now queue DeleteNode - should batch with the UPDATE
|
||||
store.DeleteNode(node.ID)
|
||||
|
||||
// Get the result from UpdateNode
|
||||
result := <-resultChan
|
||||
|
||||
// Wait for batch to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Node should be deleted
|
||||
_, found := store.GetNode(node.ID)
|
||||
assert.False(t, found, "node should be deleted")
|
||||
|
||||
// The critical check: what did UpdateNode return?
|
||||
// After the commit c6b09289988f34398eb3157e31ba092eb8721a9f,
|
||||
// UpdateNode returns the node state from the batch.
|
||||
// If DELETE came after UPDATE in the batch, the node doesn't exist anymore,
|
||||
// so UpdateNode should return (invalid, false)
|
||||
t.Logf("UpdateNode returned: ok=%v, valid=%v", result.ok, result.node.Valid())
|
||||
|
||||
// This is the expected behavior - if node was deleted in same batch,
|
||||
// UpdateNode should return invalid node
|
||||
if result.ok && result.node.Valid() {
|
||||
t.Error("BUG: UpdateNode returned valid node even though it was deleted in same batch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPersistNodeToDBPreventsRaceCondition tests that persistNodeToDB correctly handles
|
||||
// the race condition where a node is deleted after UpdateNode returns but before
|
||||
// persistNodeToDB is called. This reproduces the ephemeral node deletion bug.
|
||||
func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
|
||||
node := createTestNode(3, 1, "test-user", "test-node-3")
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
// Put node in store
|
||||
_ = store.PutNode(node)
|
||||
|
||||
// Simulate UpdateNode being called
|
||||
updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
require.True(t, ok, "UpdateNode should succeed")
|
||||
require.True(t, updatedNode.Valid(), "UpdateNode should return valid node")
|
||||
|
||||
// Now delete the node (simulating ephemeral logout happening concurrently)
|
||||
store.DeleteNode(node.ID)
|
||||
|
||||
// Wait for deletion to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify node is deleted
|
||||
_, found := store.GetNode(node.ID)
|
||||
require.False(t, found, "node should be deleted")
|
||||
|
||||
// Now try to use the updatedNode from before the deletion
|
||||
// In the old code, this would re-insert the node into the database
|
||||
// With our fix, GetNode check in persistNodeToDB should prevent this
|
||||
|
||||
// Simulate what persistNodeToDB does - check if node still exists
|
||||
_, exists := store.GetNode(updatedNode.ID())
|
||||
if !exists {
|
||||
t.Log("SUCCESS: persistNodeToDB check would prevent re-insertion of deleted node")
|
||||
} else {
|
||||
t.Error("BUG: Node still exists in NodeStore after deletion")
|
||||
}
|
||||
|
||||
// The key assertion: after deletion, attempting to persist the old updatedNode
|
||||
// should fail because the node no longer exists in NodeStore
|
||||
assert.False(t, exists, "persistNodeToDB should detect node was deleted and refuse to persist")
|
||||
}
|
||||
|
||||
// TestEphemeralNodeLogoutRaceCondition tests the specific race condition that occurs
|
||||
// when an ephemeral node logs out. This reproduces the bug where:
|
||||
// 1. UpdateNodeFromMapRequest calls UpdateNode and receives a node view
|
||||
// 2. Concurrently, handleLogout is called for the ephemeral node and calls DeleteNode
|
||||
// 3. UpdateNode and DeleteNode get batched together
|
||||
// 4. If UpdateNode's result is used to call persistNodeToDB after the deletion,
|
||||
// the node could be re-inserted into the database even though it was deleted
|
||||
func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
||||
ephemeralNode := createTestNode(4, 1, "test-user", "ephemeral-node")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
ID: 1,
|
||||
Key: "test-key",
|
||||
Ephemeral: true,
|
||||
}
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
// Put ephemeral node in store
|
||||
_ = store.PutNode(ephemeralNode)
|
||||
|
||||
// Simulate concurrent operations:
|
||||
// 1. UpdateNode (from UpdateNodeFromMapRequest during polling)
|
||||
// 2. DeleteNode (from handleLogout when client sends logout request)
|
||||
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
done := make(chan bool, 2)
|
||||
|
||||
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
|
||||
go func() {
|
||||
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
|
||||
go func() {
|
||||
time.Sleep(1 * time.Millisecond) // Slight delay to batch operations
|
||||
store.DeleteNode(ephemeralNode.ID)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Wait for both operations
|
||||
<-done
|
||||
<-done
|
||||
|
||||
// Give batching time to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Node should be deleted from store
|
||||
_, found := store.GetNode(ephemeralNode.ID)
|
||||
assert.False(t, found, "ephemeral node should be deleted from NodeStore")
|
||||
|
||||
// Critical assertion: if UpdateNode returned before DeleteNode completed,
|
||||
// the updatedNode might be valid but the node is actually deleted.
|
||||
// This is the bug - UpdateNodeFromMapRequest would get a valid node,
|
||||
// then try to persist it, re-inserting the deleted ephemeral node.
|
||||
if updateOk && updatedNode.Valid() {
|
||||
t.Log("UpdateNode returned valid node, but node is deleted - this is the race condition")
|
||||
|
||||
// In the real code, this would cause persistNodeToDB to be called with updatedNode
|
||||
// The fix in persistNodeToDB checks if the node still exists:
|
||||
_, stillExists := store.GetNode(updatedNode.ID())
|
||||
assert.False(t, stillExists, "persistNodeToDB should check NodeStore and find node deleted")
|
||||
} else if !updateOk || !updatedNode.Valid() {
|
||||
t.Log("UpdateNode correctly returned invalid/not-ok result (delete happened in same batch)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateNodeFromMapRequestEphemeralLogoutSequence tests the exact sequence
|
||||
// that causes ephemeral node logout failures:
|
||||
// 1. Client sends MapRequest with updated endpoint info
|
||||
// 2. UpdateNodeFromMapRequest starts processing, calls UpdateNode
|
||||
// 3. Client sends logout request (past expiry)
|
||||
// 4. handleLogout calls DeleteNode for ephemeral node
|
||||
// 5. UpdateNode and DeleteNode batch together
|
||||
// 6. UpdateNode returns a valid node (from before delete in batch)
|
||||
// 7. persistNodeToDB is called with the stale valid node
|
||||
// 8. Node gets re-inserted into database instead of staying deleted
|
||||
func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
|
||||
ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
ID: 2,
|
||||
Key: "test-key-2",
|
||||
Ephemeral: true,
|
||||
}
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
// Initial state: ephemeral node exists
|
||||
_ = store.PutNode(ephemeralNode)
|
||||
|
||||
// Step 1: UpdateNodeFromMapRequest calls UpdateNode
|
||||
// (simulating client sending MapRequest with endpoint updates)
|
||||
updateStarted := make(chan bool)
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
|
||||
go func() {
|
||||
updateStarted <- true
|
||||
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
endpoint := netip.MustParseAddrPort("10.0.0.1:41641")
|
||||
n.Endpoints = []netip.AddrPort{endpoint}
|
||||
})
|
||||
}()
|
||||
|
||||
<-updateStarted
|
||||
// Small delay to ensure UpdateNode is queued
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
// Step 2: Logout happens - handleLogout calls DeleteNode
|
||||
// (simulating client sending logout with past expiry)
|
||||
store.DeleteNode(ephemeralNode.ID)
|
||||
|
||||
// Wait for batching to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Step 3: Check results
|
||||
_, nodeExists := store.GetNode(ephemeralNode.ID)
|
||||
assert.False(t, nodeExists, "ephemeral node must be deleted after logout")
|
||||
|
||||
// Step 4: Simulate what happens if we try to persist the updatedNode
|
||||
if updateOk && updatedNode.Valid() {
|
||||
// This is the problematic path - UpdateNode returned a valid node
|
||||
// but the node was deleted in the same batch
|
||||
t.Log("UpdateNode returned valid node even though node was deleted")
|
||||
|
||||
// The fix: persistNodeToDB must check NodeStore before persisting
|
||||
_, checkExists := store.GetNode(updatedNode.ID())
|
||||
if checkExists {
|
||||
t.Error("BUG: Node still exists in NodeStore after deletion - should be impossible")
|
||||
} else {
|
||||
t.Log("SUCCESS: persistNodeToDB would detect node is deleted and refuse to persist")
|
||||
}
|
||||
} else {
|
||||
t.Log("UpdateNode correctly indicated node was deleted (returned invalid or not-ok)")
|
||||
}
|
||||
|
||||
// Final assertion: node must not exist
|
||||
_, finalExists := store.GetNode(ephemeralNode.ID)
|
||||
assert.False(t, finalExists, "ephemeral node must remain deleted")
|
||||
}
|
||||
|
||||
// TestUpdateNodeDeletedInSameBatchReturnsInvalid specifically tests that when
|
||||
// UpdateNode and DeleteNode are batched together with DELETE after UPDATE,
|
||||
// UpdateNode returns ok=false to indicate the node was deleted.
|
||||
func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
|
||||
node := createTestNode(6, 1, "test-user", "test-node-6")
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
// Put node in store
|
||||
_ = store.PutNode(node)
|
||||
|
||||
// Queue UpdateNode
|
||||
updateDone := make(chan struct {
|
||||
node types.NodeView
|
||||
ok bool
|
||||
})
|
||||
|
||||
go func() {
|
||||
updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
updateDone <- struct {
|
||||
node types.NodeView
|
||||
ok bool
|
||||
}{updatedNode, ok}
|
||||
}()
|
||||
|
||||
// Small delay to ensure UpdateNode is queued
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
// Queue DeleteNode - should batch with UpdateNode
|
||||
store.DeleteNode(node.ID)
|
||||
|
||||
// Get UpdateNode result
|
||||
result := <-updateDone
|
||||
|
||||
// Wait for batch to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Node should be deleted
|
||||
_, exists := store.GetNode(node.ID)
|
||||
assert.False(t, exists, "node should be deleted from store")
|
||||
|
||||
// UpdateNode should indicate the node was deleted
|
||||
// After c6b09289988f34398eb3157e31ba092eb8721a9f, when UPDATE and DELETE
|
||||
// are in the same batch with DELETE after UPDATE, UpdateNode returns
|
||||
// the state after the batch is applied - which means the node doesn't exist
|
||||
assert.False(t, result.ok, "UpdateNode should return ok=false when node deleted in same batch")
|
||||
assert.False(t, result.node.Valid(), "UpdateNode should return invalid node when node deleted in same batch")
|
||||
}
|
||||
|
||||
// TestPersistNodeToDBChecksNodeStoreBeforePersist verifies that persistNodeToDB
|
||||
// checks if the node still exists in NodeStore before persisting to database.
|
||||
// This prevents the race condition where:
|
||||
// 1. UpdateNodeFromMapRequest calls UpdateNode and gets a valid node
|
||||
// 2. Ephemeral node logout calls DeleteNode
|
||||
// 3. UpdateNode and DeleteNode batch together
|
||||
// 4. UpdateNode returns a valid node (from before delete in batch)
|
||||
// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node
|
||||
// 6. persistNodeToDB must detect the node is deleted and refuse to persist
|
||||
func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
|
||||
ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
ID: 3,
|
||||
Key: "test-key-3",
|
||||
Ephemeral: true,
|
||||
}
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
// Put node in store
|
||||
_ = store.PutNode(ephemeralNode)
|
||||
|
||||
// Simulate the race:
|
||||
// 1. UpdateNode is called (from UpdateNodeFromMapRequest)
|
||||
updatedNode, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
require.True(t, ok, "UpdateNode should succeed")
|
||||
require.True(t, updatedNode.Valid(), "UpdateNode should return valid node")
|
||||
|
||||
// 2. Node is deleted (from handleLogout for ephemeral node)
|
||||
store.DeleteNode(ephemeralNode.ID)
|
||||
|
||||
// Wait for deletion
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 3. Verify node is deleted from store
|
||||
_, exists := store.GetNode(ephemeralNode.ID)
|
||||
require.False(t, exists, "node should be deleted from NodeStore")
|
||||
|
||||
// 4. Simulate what persistNodeToDB does - check if node still exists
|
||||
// The fix in persistNodeToDB checks NodeStore before persisting:
|
||||
// if !exists { return error }
|
||||
// This prevents re-inserting the deleted node into the database
|
||||
|
||||
// Verify the node from UpdateNode is valid but node is gone from store
|
||||
assert.True(t, updatedNode.Valid(), "UpdateNode returned a valid node view")
|
||||
_, stillExists := store.GetNode(updatedNode.ID())
|
||||
assert.False(t, stillExists, "but node should be deleted from NodeStore")
|
||||
|
||||
// This is the critical test: persistNodeToDB must check NodeStore
|
||||
// and refuse to persist if the node doesn't exist anymore
|
||||
// The actual persistNodeToDB implementation does:
|
||||
// _, exists := s.nodeStore.GetNode(node.ID())
|
||||
// if !exists { return error }
|
||||
}
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// NetInfoFromMapRequest determines the correct NetInfo to use.
|
||||
// netInfoFromMapRequest determines the correct NetInfo to use.
|
||||
// Returns the NetInfo that should be used for this request.
|
||||
func NetInfoFromMapRequest(
|
||||
func netInfoFromMapRequest(
|
||||
nodeID types.NodeID,
|
||||
currentHostinfo *tailcfg.Hostinfo,
|
||||
reqHostinfo *tailcfg.Hostinfo,
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestNetInfoFromMapRequest(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := NetInfoFromMapRequest(nodeID, tt.currentHostinfo, tt.reqHostinfo)
|
||||
result := netInfoFromMapRequest(nodeID, tt.currentHostinfo, tt.reqHostinfo)
|
||||
|
||||
if tt.expectNetInfo == nil {
|
||||
assert.Nil(t, result, "expected nil NetInfo")
|
||||
@@ -100,14 +100,40 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
// BUG: Using the node being modified (no NetInfo) instead of existing node (has NetInfo)
|
||||
buggyResult := NetInfoFromMapRequest(nodeID, nodeBeingModifiedHostinfo, newRegistrationHostinfo)
|
||||
buggyResult := netInfoFromMapRequest(nodeID, nodeBeingModifiedHostinfo, newRegistrationHostinfo)
|
||||
assert.Nil(t, buggyResult, "Bug: Should return nil when using wrong hostinfo reference")
|
||||
|
||||
// CORRECT: Using the existing node's hostinfo (has NetInfo)
|
||||
correctResult := NetInfoFromMapRequest(nodeID, existingNodeHostinfo, newRegistrationHostinfo)
|
||||
correctResult := netInfoFromMapRequest(nodeID, existingNodeHostinfo, newRegistrationHostinfo)
|
||||
assert.NotNil(t, correctResult, "Fix: Should preserve NetInfo when using correct hostinfo reference")
|
||||
assert.Equal(t, 5, correctResult.PreferredDERP, "Should preserve the DERP region from existing node")
|
||||
})
|
||||
|
||||
t.Run("new_node_creation_for_different_user_should_preserve_netinfo", func(t *testing.T) {
|
||||
// This test covers the scenario where:
|
||||
// 1. A node exists for user1 with NetInfo
|
||||
// 2. The same machine logs in as user2 (different user)
|
||||
// 3. A NEW node is created for user2 (pre-auth key flow)
|
||||
// 4. The new node should preserve NetInfo from the old node
|
||||
|
||||
// Existing node for user1 with NetInfo
|
||||
existingNodeUser1Hostinfo := &tailcfg.Hostinfo{
|
||||
Hostname: "test-node",
|
||||
NetInfo: &tailcfg.NetInfo{PreferredDERP: 7},
|
||||
}
|
||||
|
||||
// New registration request for user2 (no NetInfo yet)
|
||||
newNodeUser2Hostinfo := &tailcfg.Hostinfo{
|
||||
Hostname: "test-node",
|
||||
OS: "linux",
|
||||
// NetInfo is nil - registration request doesn't include it
|
||||
}
|
||||
|
||||
// When creating a new node for user2, we should preserve NetInfo from user1's node
|
||||
result := netInfoFromMapRequest(types.NodeID(2), existingNodeUser1Hostinfo, newNodeUser2Hostinfo)
|
||||
assert.NotNil(t, result, "New node for user2 should preserve NetInfo from user1's node")
|
||||
assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node")
|
||||
})
|
||||
}
|
||||
|
||||
// Simple helper function for tests
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
)

const (
batchSize = 10
batchSize = 100
batchTimeout = 500 * time.Millisecond
)
@@ -121,10 +121,11 @@ type Snapshot struct {
nodesByID map[types.NodeID]types.Node

// calculated from nodesByID
nodesByNodeKey map[key.NodePublic]types.NodeView
peersByNode map[types.NodeID][]types.NodeView
nodesByUser map[types.UserID][]types.NodeView
allNodes []types.NodeView
nodesByNodeKey map[key.NodePublic]types.NodeView
nodesByMachineKey map[key.MachinePublic]map[types.UserID]types.NodeView
peersByNode map[types.NodeID][]types.NodeView
nodesByUser map[types.UserID][]types.NodeView
allNodes []types.NodeView
}
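A small, self-contained illustration of the shape of the new nodesByMachineKey index; the types here are generic stand-ins for key.MachinePublic, types.UserID and types.NodeView:

```go
package main

import "fmt"

// Stand-ins for key.MachinePublic, types.UserID and types.NodeView.
type machineKey string
type userID int
type nodeView struct{ hostname string }

func main() {
	// machine key -> user -> node: the same physical machine can hold one
	// node per user, e.g. after re-authenticating as a different user.
	idx := map[machineKey]map[userID]nodeView{}

	put := func(mk machineKey, uid userID, nv nodeView) {
		if idx[mk] == nil {
			idx[mk] = make(map[userID]nodeView)
		}
		idx[mk][uid] = nv
	}

	put("mkey:aa", 1, nodeView{"alice-laptop"})
	put("mkey:aa", 2, nodeView{"bob-laptop"}) // same machine, different user

	fmt.Println(idx["mkey:aa"][2].hostname) // "bob-laptop"
}
```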
// PeersFunc is a function that takes a list of nodes and returns a map

@@ -135,26 +136,29 @@ type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView

// work represents a single operation to be performed on the NodeStore.
type work struct {
op int
nodeID types.NodeID
node types.Node
updateFn UpdateNodeFunc
result chan struct{}
op int
nodeID types.NodeID
node types.Node
updateFn UpdateNodeFunc
result chan struct{}
nodeResult chan types.NodeView // Channel to return the resulting node after batch application
}

// PutNode adds or updates a node in the store.
// If the node already exists, it will be replaced.
// If the node does not exist, it will be added.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) PutNode(n types.Node) {
// Returns the resulting node after all modifications in the batch have been applied.
func (s *NodeStore) PutNode(n types.Node) types.NodeView {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
defer timer.ObserveDuration()

work := work{
op: put,
nodeID: n.ID,
node: n,
result: make(chan struct{}),
op: put,
nodeID: n.ID,
node: n,
result: make(chan struct{}),
nodeResult: make(chan types.NodeView, 1),
}

nodeStoreQueueDepth.Inc()

@@ -162,7 +166,10 @@ func (s *NodeStore) PutNode(n types.Node) {
<-work.result
nodeStoreQueueDepth.Dec()

resultNode := <-work.nodeResult
nodeStoreOperations.WithLabelValues("put").Inc()

return resultNode
}
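Usage as exercised by the updated tests later in this diff (a short sketch, not a standalone program): callers that need the post-batch state check the returned view, and callers that only care about the write can discard it.

```go
// Callers that need the post-batch state:
resultNode := store.PutNode(node)
if !resultNode.Valid() {
	// The node was removed by a delete that landed in the same batch.
}

// Callers that only care about the write can ignore the return value:
_ = store.PutNode(node)
```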
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.

@@ -173,6 +180,7 @@ type UpdateNodeFunc func(n *types.Node)
// This is analogous to a database "transaction", or, the caller should
// rather collect all data they want to change, and then call this function.
// Fewer calls are better.
// Returns the resulting node after all modifications in the batch have been applied.
//
// TODO(kradalby): Technically we could have a version of this that modifies the node
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
@@ -181,15 +189,16 @@ type UpdateNodeFunc func(n *types.Node)
// a lock around the nodesByID map to ensure that no other writes are happening
// while we are modifying the node. Which mean we would need to implement read-write locks
// on all read operations.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) (types.NodeView, bool) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
defer timer.ObserveDuration()

work := work{
op: update,
nodeID: nodeID,
updateFn: updateFn,
result: make(chan struct{}),
op: update,
nodeID: nodeID,
updateFn: updateFn,
result: make(chan struct{}),
nodeResult: make(chan types.NodeView, 1),
}

nodeStoreQueueDepth.Inc()

@@ -197,7 +206,11 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
<-work.result
nodeStoreQueueDepth.Dec()

resultNode := <-work.nodeResult
nodeStoreOperations.WithLabelValues("update").Inc()

// Return the node and whether it exists (is valid)
return resultNode, resultNode.Valid()
}
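The (view, ok) return is what lets callers avoid persisting a stale view, as the ephemeral-node tests elsewhere in this diff describe. A sketch of that guard (persistNodeToDB itself is not shown in this excerpt; applyMapRequest is a hypothetical update function, and the GetNode re-check is the one the tests refer to):

```go
updatedNode, ok := store.UpdateNode(nodeID, applyMapRequest)
if !ok {
	// Deleted in the same batch (e.g. ephemeral logout): nothing to persist.
	return nil
}

// Even with a valid view, re-check the store before writing to the database,
// so a delete that landed after this batch cannot be undone by a stale write.
if _, exists := store.GetNode(updatedNode.ID()); !exists {
	return fmt.Errorf("node %d was deleted, refusing to persist", updatedNode.ID())
}
```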
// DeleteNode removes a node from the store by its ID.

@@ -282,18 +295,32 @@ func (s *NodeStore) applyBatch(batch []work) {
nodes := make(map[types.NodeID]types.Node)
maps.Copy(nodes, s.data.Load().nodesByID)

for _, w := range batch {
// Track which work items need node results
nodeResultRequests := make(map[types.NodeID][]*work)

for i := range batch {
w := &batch[i]
switch w.op {
case put:
nodes[w.nodeID] = w.node
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
case update:
// Update the specific node identified by nodeID
if n, exists := nodes[w.nodeID]; exists {
w.updateFn(&n)
nodes[w.nodeID] = n
}
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
case del:
delete(nodes, w.nodeID)
// For delete operations, send an invalid NodeView if requested
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
}
}

@@ -303,6 +330,24 @@ func (s *NodeStore) applyBatch(batch []work) {
// Update node count gauge
nodeStoreNodesCount.Set(float64(len(nodes)))

// Send the resulting nodes to all work items that requested them
for nodeID, workItems := range nodeResultRequests {
if node, exists := nodes[nodeID]; exists {
nodeView := node.View()
for _, w := range workItems {
w.nodeResult <- nodeView
close(w.nodeResult)
}
} else {
// Node was deleted or doesn't exist
for _, w := range workItems {
w.nodeResult <- types.NodeView{} // Send invalid view
close(w.nodeResult)
}
}
}

// Signal completion for all work items
for _, w := range batch {
close(w.result)
}
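The result channels above are buffered (size 1) and written exactly once per requesting work item, so applyBatch never blocks on a slow caller. A self-contained miniature of that pattern, using plain ints in place of nodes:

```go
package main

import "fmt"

type result struct{ value int }

type work struct {
	id         int
	nodeResult chan result // buffered, written once by the batch applier
}

// applyBatch delivers the post-batch state (or a zero "invalid" result when
// the id no longer exists) to every work item, then closes the channel.
func applyBatch(batch []work, state map[int]int) {
	for i := range batch {
		w := &batch[i]
		if v, ok := state[w.id]; ok {
			w.nodeResult <- result{value: v}
		} else {
			w.nodeResult <- result{} // id deleted in this batch
		}
		close(w.nodeResult)
	}
}

func main() {
	state := map[int]int{1: 42} // id 2 was deleted in this batch
	batch := []work{
		{id: 1, nodeResult: make(chan result, 1)},
		{id: 2, nodeResult: make(chan result, 1)},
	}
	applyBatch(batch, state)
	fmt.Println(<-batch[0].nodeResult) // {42}
	fmt.Println(<-batch[1].nodeResult) // {0} -> caller treats it as deleted
}
```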
@@ -323,9 +368,10 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
}

newSnap := Snapshot{
nodesByID: nodes,
allNodes: allNodes,
nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
nodesByID: nodes,
allNodes: allNodes,
nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
nodesByMachineKey: make(map[key.MachinePublic]map[types.UserID]types.NodeView),

// peersByNode is most likely the most expensive operation,
// it will use the list of all nodes, combined with the
@@ -339,11 +385,19 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
nodesByUser: make(map[types.UserID][]types.NodeView),
}

// Build nodesByUser and nodesByNodeKey maps
// Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps
for _, n := range nodes {
nodeView := n.View()
newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
userID := types.UserID(n.UserID)

newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView)
newSnap.nodesByNodeKey[n.NodeKey] = nodeView

// Build machine key index
if newSnap.nodesByMachineKey[n.MachineKey] == nil {
newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView)
}
newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView
}

return newSnap
@@ -382,19 +436,40 @@ func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bo
return nodeView, exists
}

// GetNodeByMachineKey returns a node by its machine key. The bool indicates if the node exists.
func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic) (types.NodeView, bool) {
// GetNodeByMachineKey returns a node by its machine key and user ID. The bool indicates if the node exists.
func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic, userID types.UserID) (types.NodeView, bool) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key"))
defer timer.ObserveDuration()

nodeStoreOperations.WithLabelValues("get_by_machine_key").Inc()

snapshot := s.data.Load()
// We don't have a byMachineKey map, so we need to iterate
// This could be optimized by adding a byMachineKey map if this becomes a hot path
for _, node := range snapshot.nodesByID {
if node.MachineKey == machineKey {
return node.View(), true
if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists {
if node, exists := userMap[userID]; exists {
return node, true
}
}

return types.NodeView{}, false
}

// GetNodeByMachineKeyAnyUser returns the first node with the given machine key,
// regardless of which user it belongs to. This is useful for scenarios like
// transferring a node to a different user when re-authenticating with a
// different user's auth key.
// If multiple nodes exist with the same machine key (different users), the
// first one found is returned (order is not guaranteed).
func (s *NodeStore) GetNodeByMachineKeyAnyUser(machineKey key.MachinePublic) (types.NodeView, bool) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key_any_user"))
defer timer.ObserveDuration()

nodeStoreOperations.WithLabelValues("get_by_machine_key_any_user").Inc()

snapshot := s.data.Load()
if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists {
// Return the first node found (order not guaranteed due to map iteration)
for _, node := range userMap {
return node, true
}
}
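A hedged usage sketch of the two lookups introduced above (a fragment, not a standalone program; `user` is a hypothetical variable, and the any-user case is the re-auth-as-a-different-user scenario the doc comment describes):

```go
// Exact-user lookup: used when we know which user's node we want.
if nv, ok := store.GetNodeByMachineKey(machineKey, types.UserID(user.ID)); ok {
	_ = nv // the node this machine registered for this user
}

// Any-user lookup: used when re-authenticating with another user's auth key
// and the node may need to be transferred. If several users have a node on
// this machine, which one is returned is not guaranteed.
if nv, ok := store.GetNodeByMachineKeyAnyUser(machineKey); ok {
	_ = nv
}
```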
@@ -1,7 +1,11 @@
package state

import (
"context"
"fmt"
"net/netip"
"runtime"
"sync"
"testing"
"time"

@@ -249,7 +253,9 @@ func TestNodeStoreOperations(t *testing.T) {
name: "add first node",
action: func(store *NodeStore) {
node := createTestNode(1, 1, "user1", "node1")
store.PutNode(node)
resultNode := store.PutNode(node)
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
assert.Equal(t, node.ID, resultNode.ID())

snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 1)
@@ -288,7 +294,9 @@ func TestNodeStoreOperations(t *testing.T) {
name: "add second node same user",
action: func(store *NodeStore) {
node2 := createTestNode(2, 1, "user1", "node2")
store.PutNode(node2)
resultNode := store.PutNode(node2)
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
assert.Equal(t, types.NodeID(2), resultNode.ID())

snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 2)
@@ -308,7 +316,9 @@ func TestNodeStoreOperations(t *testing.T) {
name: "add third node different user",
action: func(store *NodeStore) {
node3 := createTestNode(3, 2, "user2", "node3")
store.PutNode(node3)
resultNode := store.PutNode(node3)
assert.True(t, resultNode.Valid(), "PutNode should return valid node")
assert.Equal(t, types.NodeID(3), resultNode.ID())

snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 3)
@@ -409,10 +419,14 @@ func TestNodeStoreOperations(t *testing.T) {
{
name: "update node hostname",
action: func(store *NodeStore) {
store.UpdateNode(1, func(n *types.Node) {
resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "updated-node1"
n.GivenName = "updated-node1"
})
assert.True(t, ok, "UpdateNode should return true for existing node")
assert.True(t, resultNode.Valid(), "Result node should be valid")
assert.Equal(t, "updated-node1", resultNode.Hostname())
assert.Equal(t, "updated-node1", resultNode.GivenName())

snapshot := store.data.Load()
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
@@ -436,10 +450,14 @@ func TestNodeStoreOperations(t *testing.T) {
name: "add nodes with odd-even filtering",
action: func(store *NodeStore) {
// Add nodes in sequence
store.PutNode(createTestNode(1, 1, "user1", "node1"))
store.PutNode(createTestNode(2, 2, "user2", "node2"))
store.PutNode(createTestNode(3, 3, "user3", "node3"))
store.PutNode(createTestNode(4, 4, "user4", "node4"))
n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
assert.True(t, n1.Valid())
n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
assert.True(t, n2.Valid())
n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
assert.True(t, n3.Valid())
n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
assert.True(t, n4.Valid())

snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 4)
@@ -478,6 +496,328 @@ func TestNodeStoreOperations(t *testing.T) {
},
},
},
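The adjusted assertions above rely on the new contract that PutNode and UpdateNode return the node exactly as the store holds it after the write. A short sketch of the caller-side pattern this enables, with hypothetical saveToDB and id stand-ins; it avoids the old read-after-write lookup that could race with other batched changes:

updated, ok := store.UpdateNode(id, func(n *types.Node) {
	n.GivenName = "new-name"
})
if ok {
	// The returned view already reflects every change applied in the same
	// batch, so it is safe to persist directly.
	saveToDB(updated.AsStruct())
}
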
{
name: "test batch modifications return correct node state",
setupFunc: func(t *testing.T) *NodeStore {
node1 := createTestNode(1, 1, "user1", "node1")
node2 := createTestNode(2, 1, "user1", "node2")
initialNodes := types.Nodes{&node1, &node2}
return NewNodeStore(initialNodes, allowAllPeersFunc)
},
steps: []testStep{
{
name: "verify initial state",
action: func(store *NodeStore) {
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 2)
assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
},
},
{
name: "concurrent updates should reflect all batch changes",
action: func(store *NodeStore) {
// Start multiple updates that will be batched together
done1 := make(chan struct{})
done2 := make(chan struct{})
done3 := make(chan struct{})

var resultNode1, resultNode2 types.NodeView
var newNode3 types.NodeView
var ok1, ok2 bool

// These should all be processed in the same batch
go func() {
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "batch-updated-node1"
n.GivenName = "batch-given-1"
})
close(done1)
}()

go func() {
resultNode2, ok2 = store.UpdateNode(2, func(n *types.Node) {
n.Hostname = "batch-updated-node2"
n.GivenName = "batch-given-2"
})
close(done2)
}()

go func() {
node3 := createTestNode(3, 1, "user1", "node3")
newNode3 = store.PutNode(node3)
close(done3)
}()

// Wait for all operations to complete
<-done1
<-done2
<-done3

// Verify the returned nodes reflect the batch state
assert.True(t, ok1, "UpdateNode should succeed for node 1")
assert.True(t, ok2, "UpdateNode should succeed for node 2")
assert.True(t, resultNode1.Valid())
assert.True(t, resultNode2.Valid())
assert.True(t, newNode3.Valid())

// Check that returned nodes have the updated values
assert.Equal(t, "batch-updated-node1", resultNode1.Hostname())
assert.Equal(t, "batch-given-1", resultNode1.GivenName())
assert.Equal(t, "batch-updated-node2", resultNode2.Hostname())
assert.Equal(t, "batch-given-2", resultNode2.GivenName())
assert.Equal(t, "node3", newNode3.Hostname())

// Verify the snapshot also reflects all changes
snapshot := store.data.Load()
assert.Len(t, snapshot.nodesByID, 3)
assert.Equal(t, "batch-updated-node1", snapshot.nodesByID[1].Hostname)
assert.Equal(t, "batch-updated-node2", snapshot.nodesByID[2].Hostname)
assert.Equal(t, "node3", snapshot.nodesByID[3].Hostname)

// Verify peer relationships are updated correctly with new node
assert.Len(t, snapshot.peersByNode[1], 2) // sees nodes 2 and 3
assert.Len(t, snapshot.peersByNode[2], 2) // sees nodes 1 and 3
assert.Len(t, snapshot.peersByNode[3], 2) // sees nodes 1 and 2
},
},
{
name: "update non-existent node returns invalid view",
action: func(store *NodeStore) {
resultNode, ok := store.UpdateNode(999, func(n *types.Node) {
n.Hostname = "should-not-exist"
})

assert.False(t, ok, "UpdateNode should return false for non-existent node")
assert.False(t, resultNode.Valid(), "Result should be invalid NodeView")
},
},
{
name: "multiple updates to same node in batch all see final state",
action: func(store *NodeStore) {
// This test verifies that when multiple updates to the same node
// are batched together, each returned node reflects ALL changes
// in the batch, not just the individual update's changes.

done1 := make(chan struct{})
done2 := make(chan struct{})
done3 := make(chan struct{})

var resultNode1, resultNode2, resultNode3 types.NodeView
var ok1, ok2, ok3 bool

// These updates all modify node 1 and should be batched together
// The final state should have all three modifications applied
go func() {
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "multi-update-hostname"
})
close(done1)
}()

go func() {
resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
n.GivenName = "multi-update-givenname"
})
close(done2)
}()

go func() {
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.ForcedTags = []string{"tag1", "tag2"}
})
close(done3)
}()

// Wait for all operations to complete
<-done1
<-done2
<-done3

// All updates should succeed
assert.True(t, ok1, "First update should succeed")
assert.True(t, ok2, "Second update should succeed")
assert.True(t, ok3, "Third update should succeed")

// CRITICAL: Each returned node should reflect ALL changes from the batch
// not just the change from its specific update call

// resultNode1 (from hostname update) should also have the givenname and tags changes
assert.Equal(t, "multi-update-hostname", resultNode1.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode1.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice())

// resultNode2 (from givenname update) should also have the hostname and tags changes
assert.Equal(t, "multi-update-hostname", resultNode2.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode2.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice())

// resultNode3 (from tags update) should also have the hostname and givenname changes
assert.Equal(t, "multi-update-hostname", resultNode3.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode3.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice())

// Verify the snapshot also has all changes
snapshot := store.data.Load()
finalNode := snapshot.nodesByID[1]
assert.Equal(t, "multi-update-hostname", finalNode.Hostname)
assert.Equal(t, "multi-update-givenname", finalNode.GivenName)
assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags)
},
},
},
},
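As the "update non-existent node" step above shows, the zero NodeView doubles as the not-found signal. A brief sketch of the guard callers are expected to apply before persisting or announcing a change (nodeID is a hypothetical stand-in):

if view, ok := store.UpdateNode(nodeID, func(n *types.Node) { n.Hostname = "renamed" }); !ok || !view.Valid() {
	// Nothing was updated: skip the database save and the change notification.
	return
}
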
{
|
||||
name: "test UpdateNode result is immutable for database save",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
initialNodes := types.Nodes{&node1, &node2}
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify returned node is complete and consistent",
|
||||
action: func(store *NodeStore) {
|
||||
// Update a node and verify the returned view is complete
|
||||
resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "db-save-hostname"
|
||||
n.GivenName = "db-save-given"
|
||||
n.ForcedTags = []string{"db-tag1", "db-tag2"}
|
||||
})
|
||||
|
||||
assert.True(t, ok, "UpdateNode should succeed")
|
||||
assert.True(t, resultNode.Valid(), "Result should be valid")
|
||||
|
||||
// Verify the returned node has all expected values
|
||||
assert.Equal(t, "db-save-hostname", resultNode.Hostname())
|
||||
assert.Equal(t, "db-save-given", resultNode.GivenName())
|
||||
assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice())
|
||||
|
||||
// Convert to struct as would be done for database save
|
||||
nodePtr := resultNode.AsStruct()
|
||||
assert.NotNil(t, nodePtr)
|
||||
assert.Equal(t, "db-save-hostname", nodePtr.Hostname)
|
||||
assert.Equal(t, "db-save-given", nodePtr.GivenName)
|
||||
assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags)
|
||||
|
||||
// Verify the snapshot also reflects the same state
|
||||
snapshot := store.data.Load()
|
||||
storedNode := snapshot.nodesByID[1]
|
||||
assert.Equal(t, "db-save-hostname", storedNode.Hostname)
|
||||
assert.Equal(t, "db-save-given", storedNode.GivenName)
|
||||
assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "concurrent updates all return consistent final state for DB save",
|
||||
action: func(store *NodeStore) {
|
||||
// Multiple goroutines updating the same node
|
||||
// All should receive the final batch state suitable for DB save
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var result1, result2, result3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
|
||||
// Start concurrent updates
|
||||
go func() {
|
||||
result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "concurrent-db-hostname"
|
||||
})
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "concurrent-db-given"
|
||||
})
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.ForcedTags = []string{"concurrent-tag"}
|
||||
})
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
// Wait for all to complete
|
||||
<-done1
|
||||
<-done2
|
||||
<-done3
|
||||
|
||||
assert.True(t, ok1 && ok2 && ok3, "All updates should succeed")
|
||||
|
||||
// All results should be valid and suitable for database save
|
||||
assert.True(t, result1.Valid())
|
||||
assert.True(t, result2.Valid())
|
||||
assert.True(t, result3.Valid())
|
||||
|
||||
// Convert each to struct as would be done for DB save
|
||||
nodePtr1 := result1.AsStruct()
|
||||
nodePtr2 := result2.AsStruct()
|
||||
nodePtr3 := result3.AsStruct()
|
||||
|
||||
// All should have the complete final state
|
||||
assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname)
|
||||
assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName)
|
||||
assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags)
|
||||
|
||||
assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname)
|
||||
assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName)
|
||||
assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags)
|
||||
|
||||
assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname)
|
||||
assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName)
|
||||
assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags)
|
||||
|
||||
// Verify consistency with stored state
|
||||
snapshot := store.data.Load()
|
||||
storedNode := snapshot.nodesByID[1]
|
||||
assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname)
|
||||
assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName)
|
||||
assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "verify returned node preserves all fields for DB save",
|
||||
action: func(store *NodeStore) {
|
||||
// Get initial state
|
||||
snapshot := store.data.Load()
|
||||
originalNode := snapshot.nodesByID[2]
|
||||
originalIPv4 := originalNode.IPv4
|
||||
originalIPv6 := originalNode.IPv6
|
||||
originalCreatedAt := originalNode.CreatedAt
|
||||
originalUser := originalNode.User
|
||||
|
||||
// Update only hostname
|
||||
resultNode, ok := store.UpdateNode(2, func(n *types.Node) {
|
||||
n.Hostname = "preserve-test-hostname"
|
||||
})
|
||||
|
||||
assert.True(t, ok, "Update should succeed")
|
||||
|
||||
// Convert to struct for DB save
|
||||
nodeForDB := resultNode.AsStruct()
|
||||
|
||||
// Verify all fields are preserved
|
||||
assert.Equal(t, "preserve-test-hostname", nodeForDB.Hostname)
|
||||
assert.Equal(t, originalIPv4, nodeForDB.IPv4)
|
||||
assert.Equal(t, originalIPv6, nodeForDB.IPv6)
|
||||
assert.Equal(t, originalCreatedAt, nodeForDB.CreatedAt)
|
||||
assert.Equal(t, originalUser.Name, nodeForDB.User.Name)
|
||||
assert.Equal(t, types.NodeID(2), nodeForDB.ID)
|
||||
|
||||
// These fields should be suitable for direct database save
|
||||
assert.NotNil(t, nodeForDB.IPv4)
|
||||
assert.NotNil(t, nodeForDB.IPv6)
|
||||
assert.False(t, nodeForDB.CreatedAt.IsZero())
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -499,3 +839,302 @@ type testStep struct {
|
||||
name string
|
||||
action func(store *NodeStore)
|
||||
}
|
||||
|
||||
// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---
|
||||
|
||||
// Helper for concurrent test nodes
|
||||
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey := key.NewNode()
|
||||
return types.Node{
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "concurrent-test-user",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// --- Concurrency: concurrent PutNode operations ---
|
||||
func TestNodeStoreConcurrentPutNode(t *testing.T) {
|
||||
const concurrentOps = 20
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan bool, concurrentOps)
|
||||
for i := 0; i < concurrentOps; i++ {
|
||||
wg.Add(1)
|
||||
go func(nodeID int) {
|
||||
defer wg.Done()
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node")
|
||||
resultNode := store.PutNode(node)
|
||||
results <- resultNode.Valid()
|
||||
}(i + 1)
|
||||
}
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
for success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed")
|
||||
}
|
||||
|
||||
// --- Batching: concurrent ops fit in one batch ---
|
||||
func TestNodeStoreBatchingEfficiency(t *testing.T) {
|
||||
const batchSize = 10
|
||||
const ops = 15 // more than batchSize
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan bool, ops)
|
||||
for i := 0; i < ops; i++ {
|
||||
wg.Add(1)
|
||||
go func(nodeID int) {
|
||||
defer wg.Done()
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node")
|
||||
resultNode := store.PutNode(node)
|
||||
results <- resultNode.Valid()
|
||||
}(i + 1)
|
||||
}
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
for success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
require.Equal(t, ops, successCount, "All batch PutNode operations should succeed")
|
||||
}
|
||||
|
||||
// --- Race conditions: many goroutines on same node ---
|
||||
func TestNodeStoreRaceConditions(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
node := createConcurrentTestNode(nodeID, "race-node")
|
||||
resultNode := store.PutNode(node)
|
||||
require.True(t, resultNode.Valid())
|
||||
|
||||
const numGoroutines = 30
|
||||
const opsPerGoroutine = 10
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines*opsPerGoroutine)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(gid int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < opsPerGoroutine; j++ {
|
||||
switch j % 3 {
|
||||
case 0:
|
||||
resultNode, _ := store.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.Hostname = "race-updated"
|
||||
})
|
||||
if !resultNode.Valid() {
|
||||
errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j)
|
||||
}
|
||||
case 1:
|
||||
retrieved, found := store.GetNode(nodeID)
|
||||
if !found || !retrieved.Valid() {
|
||||
errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j)
|
||||
}
|
||||
case 2:
|
||||
newNode := createConcurrentTestNode(nodeID, "race-put")
|
||||
resultNode := store.PutNode(newNode)
|
||||
if !resultNode.Valid() {
|
||||
errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
errorCount := 0
|
||||
for err := range errors {
|
||||
t.Error(err)
|
||||
errorCount++
|
||||
}
|
||||
if errorCount > 0 {
|
||||
t.Fatalf("Race condition test failed with %d errors", errorCount)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Resource cleanup: goroutine leak detection ---
|
||||
func TestNodeStoreResourceCleanup(t *testing.T) {
|
||||
// initialGoroutines := runtime.NumGoroutine()
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
afterStartGoroutines := runtime.NumGoroutine()
|
||||
|
||||
const ops = 100
|
||||
for i := 0; i < ops; i++ {
|
||||
nodeID := types.NodeID(i + 1)
|
||||
node := createConcurrentTestNode(nodeID, "cleanup-node")
|
||||
resultNode := store.PutNode(node)
|
||||
assert.True(t, resultNode.Valid())
|
||||
store.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.Hostname = "cleanup-updated"
|
||||
})
|
||||
retrieved, found := store.GetNode(nodeID)
|
||||
assert.True(t, found && retrieved.Valid())
|
||||
if i%10 == 9 {
|
||||
store.DeleteNode(nodeID)
|
||||
}
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > afterStartGoroutines+2 {
|
||||
t.Errorf("Potential goroutine leak: started with %d, ended with %d", afterStartGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Timeout/deadlock: operations complete within reasonable time ---
|
||||
func TestNodeStoreOperationTimeout(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
const ops = 30
|
||||
var wg sync.WaitGroup
|
||||
putResults := make([]error, ops)
|
||||
updateResults := make([]error, ops)
|
||||
|
||||
// Launch all PutNode operations concurrently
|
||||
for i := 1; i <= ops; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
wg.Add(1)
|
||||
go func(idx int, id types.NodeID) {
|
||||
defer wg.Done()
|
||||
startPut := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id)
|
||||
node := createConcurrentTestNode(id, "timeout-node")
|
||||
resultNode := store.PutNode(node)
|
||||
endPut := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut))
|
||||
if !resultNode.Valid() {
|
||||
putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id)
|
||||
}
|
||||
}(i, nodeID)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Launch all UpdateNode operations concurrently
|
||||
wg = sync.WaitGroup{}
|
||||
for i := 1; i <= ops; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
wg.Add(1)
|
||||
go func(idx int, id types.NodeID) {
|
||||
defer wg.Done()
|
||||
startUpdate := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id)
|
||||
resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
|
||||
n.Hostname = "timeout-updated"
|
||||
})
|
||||
endUpdate := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate))
|
||||
if !ok || !resultNode.Valid() {
|
||||
updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id)
|
||||
}
|
||||
}(i, nodeID)
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
errorCount := 0
|
||||
for _, err := range putResults {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
for _, err := range updateResults {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
if errorCount == 0 {
|
||||
t.Log("All concurrent operations completed successfully within timeout")
|
||||
} else {
|
||||
t.Fatalf("Some concurrent operations failed: %d errors", errorCount)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
fmt.Println("[TestNodeStoreOperationTimeout] Timeout reached, test failed")
|
||||
t.Fatal("Operations timed out - potential deadlock or resource issue")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Edge case: update non-existent node ---
|
||||
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
nonExistentID := types.NodeID(999 + i)
|
||||
updateCallCount := 0
|
||||
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID)
|
||||
resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
|
||||
updateCallCount++
|
||||
n.Hostname = "should-never-be-called"
|
||||
})
|
||||
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) finished, valid=%v, ok=%v, updateCallCount=%d\n", nonExistentID, resultNode.Valid(), ok, updateCallCount)
|
||||
assert.False(t, ok, "UpdateNode should return false for non-existent node")
|
||||
assert.False(t, resultNode.Valid(), "UpdateNode should return invalid node for non-existent node")
|
||||
assert.Equal(t, 0, updateCallCount, "UpdateFn should not be called for non-existent node")
|
||||
store.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// --- Allocation benchmark ---
|
||||
func BenchmarkNodeStoreAllocations(b *testing.B) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
nodeID := types.NodeID(i + 1)
|
||||
node := createConcurrentTestNode(nodeID, "bench-node")
|
||||
store.PutNode(node)
|
||||
store.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.Hostname = "bench-updated"
|
||||
})
|
||||
store.GetNode(nodeID)
|
||||
if i%10 == 9 {
|
||||
store.DeleteNode(nodeID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeStoreAllocationStats(t *testing.T) {
|
||||
res := testing.Benchmark(BenchmarkNodeStoreAllocations)
|
||||
allocs := res.AllocsPerOp()
|
||||
t.Logf("NodeStore allocations per op: %.2f", float64(allocs))
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
@@ -638,6 +638,11 @@ func (node Node) DebugString() string {
return sb.String()
}

func (v NodeView) UserView() UserView {
u := v.User()
return u.View()
}

func (v NodeView) IPs() []netip.Addr {
if !v.Valid() {
return nil

@@ -104,27 +104,31 @@ func (u *User) profilePicURL() string {
return u.ProfilePicURL
}

func (u *User) TailscaleUser() *tailcfg.User {
user := tailcfg.User{
func (u *User) TailscaleUser() tailcfg.User {
return tailcfg.User{
ID: tailcfg.UserID(u.ID),
DisplayName: u.Display(),
ProfilePicURL: u.profilePicURL(),
Created: u.CreatedAt,
}

return &user
}

func (u *User) TailscaleLogin() *tailcfg.Login {
login := tailcfg.Login{
func (u UserView) TailscaleUser() tailcfg.User {
return u.ж.TailscaleUser()
}

func (u *User) TailscaleLogin() tailcfg.Login {
return tailcfg.Login{
ID: tailcfg.LoginID(u.ID),
Provider: u.Provider,
LoginName: u.Username(),
DisplayName: u.Display(),
ProfilePicURL: u.profilePicURL(),
}
}

return &login
func (u UserView) TailscaleLogin() tailcfg.Login {
return u.ж.TailscaleLogin()
}

func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
@@ -136,6 +140,10 @@ func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
}
}

func (u UserView) TailscaleUserProfile() tailcfg.UserProfile {
return u.ж.TailscaleUserProfile()
}

func (u *User) Proto() *v1.User {
return &v1.User{
Id: uint64(u.ID),

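With TailscaleUser and TailscaleLogin now returning values rather than pointers, and gaining UserView wrappers, profile data can be derived straight from a NodeView without reaching into the underlying struct. A small sketch, assuming nv is a valid NodeView:

uv := nv.UserView()
profile := uv.TailscaleUserProfile() // tailcfg.UserProfile by value
login := uv.TailscaleLogin()         // tailcfg.Login by value
_, _ = profile, login
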
@@ -11,6 +11,7 @@ import (
"strings"
"time"

"tailscale.com/tailcfg"
"tailscale.com/util/cmpver"
)

@@ -258,3 +259,59 @@ func IsCI() bool {

return false
}

// SafeHostname extracts a hostname from Hostinfo, providing sensible defaults
// if Hostinfo is nil or Hostname is empty. This prevents nil pointer dereferences
// and ensures nodes always have a valid hostname.
// The hostname is truncated to 63 characters to comply with DNS label length limits (RFC 1123).
func SafeHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string {
if hostinfo == nil || hostinfo.Hostname == "" {
// Generate a default hostname using machine key prefix
if machineKey != "" {
keyPrefix := machineKey
if len(machineKey) > 8 {
keyPrefix = machineKey[:8]
}
return fmt.Sprintf("node-%s", keyPrefix)
}
if nodeKey != "" {
keyPrefix := nodeKey
if len(nodeKey) > 8 {
keyPrefix = nodeKey[:8]
}
return fmt.Sprintf("node-%s", keyPrefix)
}
return "unknown-node"
}

hostname := hostinfo.Hostname

// Validate hostname length - DNS label limit is 63 characters (RFC 1123)
// Truncate if necessary to ensure compatibility with given name generation
if len(hostname) > 63 {
hostname = hostname[:63]
}

return hostname
}

// EnsureValidHostinfo ensures that Hostinfo is non-nil and has a valid hostname.
// If Hostinfo is nil, it creates a minimal valid Hostinfo with a generated hostname.
// Returns the validated/created Hostinfo and the extracted hostname.
func EnsureValidHostinfo(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) (*tailcfg.Hostinfo, string) {
if hostinfo == nil {
hostname := SafeHostname(nil, machineKey, nodeKey)
return &tailcfg.Hostinfo{
Hostname: hostname,
}, hostname
}

hostname := SafeHostname(hostinfo, machineKey, nodeKey)

// Update the hostname in the hostinfo if it was empty or if it was truncated
if hostinfo.Hostname == "" || hostinfo.Hostname != hostname {
hostinfo.Hostname = hostname
}

return hostinfo, hostname
}

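A sketch of how a registration path might use EnsureValidHostinfo before storing a node, so Hostinfo is never nil and the hostname is DNS-safe; req, machineKey and node are stand-ins for the real request, key and node values, not code from this change:

hi, hostname := util.EnsureValidHostinfo(req.Hostinfo, machineKey.ShortString(), req.NodeKey.ShortString())
node.Hostinfo = hi       // guaranteed non-nil
node.Hostname = hostname // non-empty, at most 63 characters
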
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestTailscaleVersionNewerOrEqual(t *testing.T) {
|
||||
@@ -793,3 +794,395 @@ over a maximum of 30 hops:
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeHostname(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hostinfo *tailcfg.Hostinfo
|
||||
machineKey string
|
||||
nodeKey string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid_hostname",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test-node",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "test-node",
|
||||
},
|
||||
{
|
||||
name: "nil_hostinfo_with_machine_key",
|
||||
hostinfo: nil,
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "node-mkey1234",
|
||||
},
|
||||
{
|
||||
name: "nil_hostinfo_with_node_key_only",
|
||||
hostinfo: nil,
|
||||
machineKey: "",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "node-nkey1234",
|
||||
},
|
||||
{
|
||||
name: "nil_hostinfo_no_keys",
|
||||
hostinfo: nil,
|
||||
machineKey: "",
|
||||
nodeKey: "",
|
||||
want: "unknown-node",
|
||||
},
|
||||
{
|
||||
name: "empty_hostname_with_machine_key",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "node-mkey1234",
|
||||
},
|
||||
{
|
||||
name: "empty_hostname_with_node_key_only",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "",
|
||||
},
|
||||
machineKey: "",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "node-nkey1234",
|
||||
},
|
||||
{
|
||||
name: "empty_hostname_no_keys",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "",
|
||||
},
|
||||
machineKey: "",
|
||||
nodeKey: "",
|
||||
want: "unknown-node",
|
||||
},
|
||||
{
|
||||
name: "hostname_exactly_63_chars",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "123456789012345678901234567890123456789012345678901234567890123",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "123456789012345678901234567890123456789012345678901234567890123",
|
||||
},
|
||||
{
|
||||
name: "hostname_64_chars_truncated",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "1234567890123456789012345678901234567890123456789012345678901234",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "123456789012345678901234567890123456789012345678901234567890123",
|
||||
},
|
||||
{
|
||||
name: "hostname_very_long_truncated",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters-and-should-be-truncated",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits",
|
||||
},
|
||||
{
|
||||
name: "hostname_with_special_chars",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "node-with-special!@#$%",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "node-with-special!@#$%",
|
||||
},
|
||||
{
|
||||
name: "hostname_with_unicode",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "node-ñoño-测试",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "node-ñoño-测试",
|
||||
},
|
||||
{
|
||||
name: "short_machine_key",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "",
|
||||
},
|
||||
machineKey: "short",
|
||||
nodeKey: "nkey12345678",
|
||||
want: "node-short",
|
||||
},
|
||||
{
|
||||
name: "short_node_key",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "",
|
||||
},
|
||||
machineKey: "",
|
||||
nodeKey: "short",
|
||||
want: "node-short",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := SafeHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
|
||||
if got != tt.want {
|
||||
t.Errorf("SafeHostname() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureValidHostinfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hostinfo *tailcfg.Hostinfo
|
||||
machineKey string
|
||||
nodeKey string
|
||||
wantHostname string
|
||||
checkHostinfo func(*testing.T, *tailcfg.Hostinfo)
|
||||
}{
|
||||
{
|
||||
name: "valid_hostinfo_unchanged",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test-node",
|
||||
OS: "linux",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
wantHostname: "test-node",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if hi.Hostname != "test-node" {
|
||||
t.Errorf("hostname = %v, want test-node", hi.Hostname)
|
||||
}
|
||||
if hi.OS != "linux" {
|
||||
t.Errorf("OS = %v, want linux", hi.OS)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil_hostinfo_creates_default",
|
||||
hostinfo: nil,
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
wantHostname: "node-mkey1234",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if hi.Hostname != "node-mkey1234" {
|
||||
t.Errorf("hostname = %v, want node-mkey1234", hi.Hostname)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_hostname_updated",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "",
|
||||
OS: "darwin",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
wantHostname: "node-mkey1234",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if hi.Hostname != "node-mkey1234" {
|
||||
t.Errorf("hostname = %v, want node-mkey1234", hi.Hostname)
|
||||
}
|
||||
if hi.OS != "darwin" {
|
||||
t.Errorf("OS = %v, want darwin", hi.OS)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "long_hostname_truncated",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
wantHostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if hi.Hostname != "test-node-with-very-long-hostname-that-exceeds-dns-label-limits" {
|
||||
t.Errorf("hostname = %v, want truncated", hi.Hostname)
|
||||
}
|
||||
if len(hi.Hostname) != 63 {
|
||||
t.Errorf("hostname length = %v, want 63", len(hi.Hostname))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil_hostinfo_node_key_only",
|
||||
hostinfo: nil,
|
||||
machineKey: "",
|
||||
nodeKey: "nkey12345678",
|
||||
wantHostname: "node-nkey1234",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if hi.Hostname != "node-nkey1234" {
|
||||
t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil_hostinfo_no_keys",
|
||||
hostinfo: nil,
|
||||
machineKey: "",
|
||||
nodeKey: "",
|
||||
wantHostname: "unknown-node",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if hi.Hostname != "unknown-node" {
|
||||
t.Errorf("hostname = %v, want unknown-node", hi.Hostname)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_hostname_no_keys",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "",
|
||||
},
|
||||
machineKey: "",
|
||||
nodeKey: "",
|
||||
wantHostname: "unknown-node",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if hi.Hostname != "unknown-node" {
|
||||
t.Errorf("hostname = %v, want unknown-node", hi.Hostname)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "preserves_other_fields",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "test",
|
||||
OS: "windows",
|
||||
OSVersion: "10.0.19044",
|
||||
DeviceModel: "test-device",
|
||||
BackendLogID: "log123",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
wantHostname: "test",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if hi.Hostname != "test" {
|
||||
t.Errorf("hostname = %v, want test", hi.Hostname)
|
||||
}
|
||||
if hi.OS != "windows" {
|
||||
t.Errorf("OS = %v, want windows", hi.OS)
|
||||
}
|
||||
if hi.OSVersion != "10.0.19044" {
|
||||
t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion)
|
||||
}
|
||||
if hi.DeviceModel != "test-device" {
|
||||
t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel)
|
||||
}
|
||||
if hi.BackendLogID != "log123" {
|
||||
t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exactly_63_chars_unchanged",
|
||||
hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: "123456789012345678901234567890123456789012345678901234567890123",
|
||||
},
|
||||
machineKey: "mkey12345678",
|
||||
nodeKey: "nkey12345678",
|
||||
wantHostname: "123456789012345678901234567890123456789012345678901234567890123",
|
||||
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
|
||||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
if len(hi.Hostname) != 63 {
|
||||
t.Errorf("hostname length = %v, want 63", len(hi.Hostname))
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
gotHostinfo, gotHostname := EnsureValidHostinfo(tt.hostinfo, tt.machineKey, tt.nodeKey)
|
||||
|
||||
if gotHostname != tt.wantHostname {
|
||||
t.Errorf("EnsureValidHostinfo() hostname = %v, want %v", gotHostname, tt.wantHostname)
|
||||
}
|
||||
if gotHostinfo == nil {
|
||||
t.Error("returned hostinfo should never be nil")
|
||||
}
|
||||
|
||||
if tt.checkHostinfo != nil {
|
||||
tt.checkHostinfo(t, gotHostinfo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeHostname_DNSLabelLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []string{
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
|
||||
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
|
||||
"dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd",
|
||||
}
|
||||
|
||||
for i, hostname := range testCases {
|
||||
t.Run(cmp.Diff("", ""), func(t *testing.T) {
|
||||
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
|
||||
result := SafeHostname(hostinfo, "mkey", "nkey")
|
||||
if len(result) > 63 {
|
||||
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureValidHostinfo_Idempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
originalHostinfo := &tailcfg.Hostinfo{
|
||||
Hostname: "test-node",
|
||||
OS: "linux",
|
||||
}
|
||||
|
||||
hostinfo1, hostname1 := EnsureValidHostinfo(originalHostinfo, "mkey", "nkey")
|
||||
hostinfo2, hostname2 := EnsureValidHostinfo(hostinfo1, "mkey", "nkey")
|
||||
|
||||
if hostname1 != hostname2 {
|
||||
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)
|
||||
}
|
||||
if hostinfo1.Hostname != hostinfo2.Hostname {
|
||||
t.Errorf("hostinfo hostnames not equal: %v != %v", hostinfo1.Hostname, hostinfo2.Hostname)
|
||||
}
|
||||
if hostinfo1.OS != hostinfo2.OS {
|
||||
t.Errorf("hostinfo OS not equal: %v != %v", hostinfo1.OS, hostinfo2.OS)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
opts := []hsic.Option{
|
||||
@@ -43,31 +43,25 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
expectedNodes := make([]types.NodeID, 0, len(allClients))
|
||||
for _, client := range allClients {
|
||||
status := client.MustStatus()
|
||||
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
assertNoErr(t, err)
|
||||
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
|
||||
}
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 30*time.Second)
|
||||
expectedNodes := collectExpectedNodeIDs(t, allClients)
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 120*time.Second)
|
||||
|
||||
// Validate that all nodes have NetInfo and DERP servers before logout
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 1*time.Minute)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 3*time.Minute)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -97,19 +91,20 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleLogout()
|
||||
assertNoErrLogout(t, err)
|
||||
requireNoErrLogout(t, err)
|
||||
|
||||
// After taking down all nodes, verify all systems show nodes offline
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second)
|
||||
|
||||
t.Logf("all clients logged out")
|
||||
|
||||
t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count")
|
||||
}, 20*time.Second, 1*time.Second)
|
||||
assert.NoError(ct, err, "Failed to list nodes after logout")
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating node persistence after logout (nodes should remain in database)")
|
||||
|
||||
for _, node := range listNodes {
|
||||
assertLastSeenSet(t, node)
|
||||
@@ -125,7 +120,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
}
|
||||
|
||||
userMap, err := headscale.MapUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, userName := range spec.Users {
|
||||
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
|
||||
@@ -139,12 +134,13 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection")
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
assert.NoError(ct, err, "Failed to list nodes after relogin")
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
|
||||
}, 60*time.Second, 2*time.Second, "validating node count stability after same-user auth key relogin")
|
||||
|
||||
for _, node := range listNodes {
|
||||
assertLastSeenSet(t, node)
|
||||
@@ -152,11 +148,15 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 120*time.Second)
|
||||
|
||||
// Wait for Tailscale sync before validating NetInfo to ensure proper state propagation
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// Validate that all nodes have NetInfo and DERP servers after reconnection
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 1*time.Minute)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 3*time.Minute)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
@@ -197,69 +197,10 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database
// and a valid DERP server based on the NetInfo. This function follows the pattern of
// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state.
func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) {
t.Helper()

startTime := time.Now()
t.Logf("requireAllClientsNetInfoAndDERP: Starting validation at %s - %s", startTime.Format(TimestampFormat), message)

require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}

// Validate node counts first
expectedCount := len(expectedNodes)
assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch")

// Check each expected node
for _, nodeID := range expectedNodes {
node, exists := nodeStore[nodeID]
assert.True(c, exists, "Node %d not found in nodestore", nodeID)
if !exists {
continue
}

// Validate that the node has Hostinfo
assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo", nodeID, node.Hostname)
if node.Hostinfo == nil {
continue
}

// Validate that the node has NetInfo
assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo", nodeID, node.Hostname)
if node.Hostinfo.NetInfo == nil {
continue
}

// Validate that the node has a valid DERP server (PreferredDERP should be > 0)
preferredDERP := node.Hostinfo.NetInfo.PreferredDERP
assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0), got %d", nodeID, node.Hostname, preferredDERP)

t.Logf("Node %d (%s) has valid NetInfo with DERP server %d", nodeID, node.Hostname, preferredDERP)
}
}, timeout, 2*time.Second, message)

endTime := time.Now()
duration := endTime.Sub(startTime)
t.Logf("requireAllClientsNetInfoAndDERP: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message)
}

func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node)
assert.NotNil(t, node.GetLastSeen())
}

// This test will first log in two sets of nodes to two sets of users, then
|
||||
// it will log out all users from user2 and log them in as user1.
|
||||
// This should leave us with all nodes connected to user1, while user2
|
||||
// still has nodes, but they are not connected.
|
||||
// it will log out all nodes and log them in as user1 using a pre-auth key.
|
||||
// This should create new nodes for user1 while preserving the original nodes for user2.
|
||||
// Pre-auth key re-authentication with a different user creates new nodes, not transfers.
|
||||
func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
@@ -269,7 +210,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{},
|
||||
@@ -277,18 +218,25 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
||||
hsic.WithTLS(),
|
||||
hsic.WithDERPAsIP(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
// Collect expected node IDs for validation
|
||||
expectedNodes := collectExpectedNodeIDs(t, allClients)
|
||||
|
||||
// Validate initial connection state
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assert.Len(t, allClients, len(listNodes))
|
||||
@@ -303,12 +251,15 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleLogout()
|
||||
assertNoErrLogout(t, err)
|
||||
requireNoErrLogout(t, err)
|
||||
|
||||
// Validate that all nodes are offline after logout
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second)
|
||||
|
||||
t.Logf("all clients logged out")
|
||||
|
||||
userMap, err := headscale.MapUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new authkey for user1, to be used for all clients
|
||||
key, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), true, false)
|
||||
@@ -326,28 +277,43 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
||||
}
|
||||
|
||||
var user1Nodes []*v1.Node
|
||||
t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
user1Nodes, err = headscale.ListNodes("user1")
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all clients after re-login")
|
||||
}, 20*time.Second, 1*time.Second)
|
||||
assert.NoError(ct, err, "Failed to list nodes for user1 after relogin")
|
||||
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes))
|
||||
}, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after auth key relogin")
|
||||
|
||||
// Validate that all the old nodes are still present with user2
|
||||
// Collect expected node IDs for user1 after relogin
|
||||
expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes))
|
||||
for _, node := range user1Nodes {
|
||||
expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId()))
|
||||
}
|
||||
|
||||
// Validate connection state after relogin as user1
|
||||
requireAllClientsOnline(t, headscale, expectedUser1Nodes, true, "all user1 nodes should be connected after relogin", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin", 3*time.Minute)
|
||||
|
||||
// Validate that user2 still has their original nodes after user1's re-authentication
|
||||
// When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created
|
||||
// for the new user. The original nodes remain with the original user.
|
||||
var user2Nodes []*v1.Node
|
||||
t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
user2Nodes, err = headscale.ListNodes("user2")
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should have half the clients")
|
||||
}, 20*time.Second, 1*time.Second)
|
||||
assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin")
|
||||
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)")
|
||||
|
||||
t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat))
|
||||
for _, client := range allClients {
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
|
||||
assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1", client.Hostname())
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName)
|
||||
}, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after auth key user switch", client.Hostname()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,7 +328,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
opts := []hsic.Option{
|
||||
@@ -376,13 +342,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -396,7 +362,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
||||
}
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
// Collect expected node IDs for validation
|
||||
expectedNodes := collectExpectedNodeIDs(t, allClients)
|
||||
|
||||
// Validate initial connection state
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assert.Len(t, allClients, len(listNodes))
|
||||
@@ -411,7 +384,10 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleLogout()
|
||||
assertNoErrLogout(t, err)
|
||||
requireNoErrLogout(t, err)
|
||||
|
||||
// Validate that all nodes are offline after logout
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second)
|
||||
|
||||
t.Logf("all clients logged out")
|
||||
|
||||
@@ -425,7 +401,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
||||
}
|
||||
|
||||
userMap, err := headscale.MapUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, userName := range spec.Users {
|
||||
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
|
||||
@@ -443,7 +419,8 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
||||
"expire",
|
||||
key.GetKey(),
|
||||
})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
|
||||
assert.ErrorContains(t, err, "authkey expired")
|
||||
|
||||
@@ -5,17 +5,20 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/oauth2-proxy/mockoidc"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||
@@ -34,7 +37,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
@@ -52,16 +55,16 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||
hsic.WithTLS(),
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -73,10 +76,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
want := []*v1.User{
|
||||
{
|
||||
@@ -142,7 +145,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
oidcMap := map[string]string{
|
||||
@@ -157,18 +160,18 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||
hsic.WithTestName("oidcexpirenodes"),
|
||||
hsic.WithConfigEnv(oidcMap),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
// Record when sync completes to better estimate token expiry timing
|
||||
syncCompleteTime := time.Now()
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
loginDuration := time.Since(syncCompleteTime)
|
||||
t.Logf("Login and sync completed in %v", loginDuration)
|
||||
|
||||
@@ -349,7 +352,7 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
oidcMap := map[string]string{
|
||||
@@ -367,20 +370,20 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||
hsic.WithTLS(),
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
// Ensure that the nodes have logged in, this is what
|
||||
// triggers user creation via OIDC.
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
want := tt.want(scenario.mockOIDC.Issuer())
|
||||
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
sort.Slice(listUsers, func(i, j int) bool {
|
||||
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||
@@ -406,7 +409,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
oidcMap := map[string]string{
|
||||
@@ -424,17 +427,17 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
|
||||
hsic.WithTLS(),
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
// Get all clients and verify they can connect
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
@@ -444,6 +447,11 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
}
|
||||
|
||||
// TestOIDCReloginSameNodeNewUser tests the scenario where:
// 1. A Tailscale client logs in with user1 (creates node1 for user1)
// 2. The same client logs out and logs in with user2 (creates node2 for user2)
// 3. The same client logs out and logs in with user1 again (reuses node1, node2 remains)
// This validates that OIDC relogin properly handles node reuse and cleanup.
func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
@@ -458,7 +466,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||
oidcMockUser("user1", true),
|
||||
},
|
||||
})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
oidcMap := map[string]string{
|
||||
@@ -477,24 +485,25 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithDERPAsIP(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = doLoginURL(ts.Hostname(), u)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Validating initial user creation at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, listUsers, 1)
|
||||
assert.NoError(ct, err, "Failed to list users during initial validation")
|
||||
assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers))
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
@@ -510,44 +519,61 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||
})
|
||||
|
||||
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||
t.Fatalf("unexpected users: %s", diff)
|
||||
ct.Errorf("User validation failed after first login - unexpected users: %s", diff)
|
||||
}
|
||||
}, 30*time.Second, 1*time.Second, "validating users after first login")
|
||||
}, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login")
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, listNodes, 1)
|
||||
t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat))
|
||||
var listNodes []*v1.Node
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes during initial validation")
|
||||
assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes))
|
||||
}, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login")
|
||||
|
||||
// Collect expected node IDs for validation after user1 initial login
|
||||
expectedNodes := make([]types.NodeID, 0, 1)
|
||||
status := ts.MustStatus()
|
||||
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
require.NoError(t, err)
|
||||
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
|
||||
|
||||
// Validate initial connection state for user1
|
||||
validateInitialConnection(t, headscale, expectedNodes)
|
||||
|
||||
// Log out user1 and log in as user2; this should create a new node
// for user2 with the same machine key and a new node key.
err = ts.Logout()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
|
||||
// logs in immediately after the first logout and I cannot reproduce it
|
||||
// manually.
|
||||
err = ts.Logout()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for logout to complete and then do second logout
|
||||
t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
// Check that the first logout completed
|
||||
status, err := ts.Status()
|
||||
assert.NoError(ct, err)
|
||||
assert.Equal(ct, "NeedsLogin", status.BackendState)
|
||||
}, 30*time.Second, 1*time.Second)
|
||||
assert.NoError(ct, err, "Failed to get client status during logout validation")
|
||||
assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState)
|
||||
}, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before user2 login")
|
||||
|
||||
u, err = ts.LoginWithURL(headscale.GetEndpoint())
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = doLoginURL(ts.Hostname(), u)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Validating user2 creation at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, listUsers, 2)
|
||||
assert.NoError(ct, err, "Failed to list users after user2 login")
|
||||
assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers))
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
@@ -570,27 +596,83 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||
})
|
||||
|
||||
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||
ct.Errorf("unexpected users: %s", diff)
|
||||
ct.Errorf("User validation failed after user2 login - expected both user1 and user2: %s", diff)
|
||||
}
|
||||
}, 30*time.Second, 1*time.Second, "validating users after new user login")
|
||||
}, 30*time.Second, 1*time.Second, "validating both user1 and user2 exist after second OIDC login")
|
||||
|
||||
var listNodesAfterNewUserLogin []*v1.Node
|
||||
// First, wait for the new node to be created
|
||||
t.Logf("Waiting for user2 node creation at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listNodesAfterNewUserLogin, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, listNodesAfterNewUserLogin, 2)
|
||||
assert.NoError(ct, err, "Failed to list nodes after user2 login")
|
||||
// We might temporarily have more than 2 nodes during cleanup, so check for at least 2
|
||||
assert.GreaterOrEqual(ct, len(listNodesAfterNewUserLogin), 2, "Should have at least 2 nodes after user2 login, got %d (may include temporary nodes during cleanup)", len(listNodesAfterNewUserLogin))
|
||||
}, 30*time.Second, 1*time.Second, "waiting for user2 node creation (allowing temporary extra nodes during cleanup)")
|
||||
|
||||
// Machine key is the same as the "machine" has not changed,
|
||||
// but Node key is not as it is a new node
|
||||
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
|
||||
assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
|
||||
assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
|
||||
}, 30*time.Second, 1*time.Second, "listing nodes after new user login")
|
||||
// Then wait for cleanup to stabilize at exactly 2 nodes
|
||||
t.Logf("Waiting for node cleanup stabilization at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listNodesAfterNewUserLogin, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes during cleanup validation")
|
||||
assert.Len(ct, listNodesAfterNewUserLogin, 2, "Should have exactly 2 nodes after cleanup (1 for user1, 1 for user2), got %d nodes", len(listNodesAfterNewUserLogin))
|
||||
|
||||
// Validate that both nodes have the same machine key but different node keys
|
||||
if len(listNodesAfterNewUserLogin) >= 2 {
|
||||
// The machine key is unchanged because the physical "machine" is the same,
// but the node key differs because this is a new node.
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Machine key should be preserved from original node")
|
||||
assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "Both nodes should share the same machine key")
|
||||
assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey(), "Node keys should be different between user1 and user2 nodes")
|
||||
}
|
||||
}, 90*time.Second, 2*time.Second, "waiting for node count stabilization at exactly 2 nodes after user2 login")
|
||||
|
||||
// Security validation: Only user2's node should be active after user switch
|
||||
var activeUser2NodeID types.NodeID
|
||||
for _, node := range listNodesAfterNewUserLogin {
|
||||
if node.GetUser().GetId() == 2 { // user2
|
||||
activeUser2NodeID = types.NodeID(node.GetId())
|
||||
t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName())
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Validate only user2's node is online (security requirement)
|
||||
t.Logf("Validating only user2 node is online at %s", time.Now().Format(TimestampFormat))
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
|
||||
// Check user2 node is online
|
||||
if node, exists := nodeStore[activeUser2NodeID]; exists {
|
||||
assert.NotNil(c, node.IsOnline, "User2 node should have online status")
|
||||
if node.IsOnline != nil {
|
||||
assert.True(c, *node.IsOnline, "User2 node should be online after login")
|
||||
}
|
||||
} else {
|
||||
assert.Fail(c, "User2 node not found in nodestore")
|
||||
}
|
||||
}, 60*time.Second, 2*time.Second, "validating only user2 node is online after user switch")
|
||||
|
||||
// Before logging out user2, validate we have exactly 2 nodes and both are stable
|
||||
t.Logf("Pre-logout validation: checking node stability at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
currentNodes, err := headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes before user2 logout")
|
||||
assert.Len(ct, currentNodes, 2, "Should have exactly 2 stable nodes before user2 logout, got %d", len(currentNodes))
|
||||
|
||||
// Validate node stability - ensure no phantom nodes
|
||||
for i, node := range currentNodes {
|
||||
assert.NotNil(ct, node.GetUser(), "Node %d should have a valid user before logout", i)
|
||||
assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should have a valid machine key before logout", i)
|
||||
t.Logf("Pre-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...")
|
||||
}
|
||||
}, 60*time.Second, 2*time.Second, "validating stable node count and integrity before user2 logout")
|
||||
|
||||
// Log out user2, and log into user1, no new node should be created,
|
||||
// the node should now "become" node1 again
|
||||
err = ts.Logout()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Logged out take one")
|
||||
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
|
||||
@@ -599,41 +681,63 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||
// logs in immediately after the first logout and I cannot reproduce it
|
||||
// manually.
|
||||
err = ts.Logout()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Logged out take two")
|
||||
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
|
||||
|
||||
// Wait for logout to complete and then do second logout
|
||||
t.Logf("Waiting for user2 logout completion at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
// Check that the first logout completed
|
||||
status, err := ts.Status()
|
||||
assert.NoError(ct, err)
|
||||
assert.Equal(ct, "NeedsLogin", status.BackendState)
|
||||
}, 30*time.Second, 1*time.Second)
|
||||
assert.NoError(ct, err, "Failed to get client status during user2 logout validation")
|
||||
assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after user2 logout, got %s", status.BackendState)
|
||||
}, 30*time.Second, 1*time.Second, "waiting for user2 logout to complete before user1 relogin")
|
||||
|
||||
// Before logging back in, ensure we still have exactly 2 nodes
|
||||
// Note: We skip validateLogoutComplete here since it expects all nodes to be offline,
|
||||
// but in OIDC scenario we maintain both nodes in DB with only active user online
|
||||
|
||||
// Additional validation that nodes are properly maintained during logout
|
||||
t.Logf("Post-logout validation: checking node persistence at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
currentNodes, err := headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes after user2 logout")
|
||||
assert.Len(ct, currentNodes, 2, "Should still have exactly 2 nodes after user2 logout (nodes should persist), got %d", len(currentNodes))
|
||||
|
||||
// Ensure both nodes are still valid (not cleaned up incorrectly)
|
||||
for i, node := range currentNodes {
|
||||
assert.NotNil(ct, node.GetUser(), "Node %d should still have a valid user after user2 logout", i)
|
||||
assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should still have a valid machine key after user2 logout", i)
|
||||
t.Logf("Post-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...")
|
||||
}
|
||||
}, 60*time.Second, 2*time.Second, "validating node persistence and integrity after user2 logout")
|
||||
|
||||
// We do not actually "change" the user here; it happens by logging in again,
// as the OIDC mock server behaves like a stack and the next user is
// already prepared and ready to go.
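// (How the next identity is served is an assumption based on the
// oauth2-proxy/mockoidc API rather than on this diff: the scenario helper
// conceptually queues one identity per entry in ScenarioSpec.OIDCUsers,
// e.g. via mockOIDC.QueueUser(...), and each subsequent LoginWithURL then
// consumes the next queued user.)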
|
||||
u, err = ts.LoginWithURL(headscale.GetEndpoint())
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = doLoginURL(ts.Hostname(), u)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := ts.Status()
|
||||
assert.NoError(ct, err)
|
||||
assert.Equal(ct, "Running", status.BackendState)
|
||||
}, 30*time.Second, 1*time.Second)
|
||||
assert.NoError(ct, err, "Failed to get client status during user1 relogin validation")
|
||||
assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState)
|
||||
}, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (final login)")
|
||||
|
||||
t.Logf("Logged back in")
|
||||
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
|
||||
|
||||
t.Logf("Final validation: checking user persistence at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, listUsers, 2)
|
||||
assert.NoError(ct, err, "Failed to list users during final validation")
|
||||
assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers))
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
@@ -656,37 +760,77 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||
})
|
||||
|
||||
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||
ct.Errorf("unexpected users: %s", diff)
|
||||
ct.Errorf("Final user validation failed - both users should persist after relogin cycle: %s", diff)
|
||||
}
|
||||
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
|
||||
}, 30*time.Second, 1*time.Second, "validating user persistence after complete relogin cycle (user1->user2->user1)")
|
||||
|
||||
var listNodesAfterLoggingBackIn []*v1.Node
|
||||
// Wait for login to complete and nodes to stabilize
|
||||
t.Logf("Final node validation: checking node stability after user1 relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listNodesAfterLoggingBackIn, err := headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, listNodesAfterLoggingBackIn, 2)
|
||||
listNodesAfterLoggingBackIn, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes during final validation")
|
||||
|
||||
// Allow for temporary instability during login process
|
||||
if len(listNodesAfterLoggingBackIn) < 2 {
|
||||
ct.Errorf("Not enough nodes yet during final validation, got %d, want at least 2", len(listNodesAfterLoggingBackIn))
|
||||
return
|
||||
}
|
||||
|
||||
// Final check should have exactly 2 nodes
|
||||
assert.Len(ct, listNodesAfterLoggingBackIn, 2, "Should have exactly 2 nodes after complete relogin cycle, got %d", len(listNodesAfterLoggingBackIn))
|
||||
|
||||
// Validate that the machine from the first login has the same machine key,
// but a different node ID, than the newly logged-in (user2) version of the
// same machine.
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
|
||||
assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
|
||||
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
|
||||
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
|
||||
assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
|
||||
assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
|
||||
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Original user1 machine key should match user1 node after user switch")
|
||||
assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey(), "Original user1 node key should match user1 node after user switch")
|
||||
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId(), "Original user1 node ID should match user1 node after user switch")
|
||||
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "User1 and user2 nodes should share the same machine key")
|
||||
assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId(), "User1 and user2 nodes should have different node IDs")
|
||||
assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId(), "User1 and user2 nodes should belong to different users")
|
||||
|
||||
// Even tho we are logging in again with the same user, the previous key has been expired
|
||||
// and a new one has been generated. The node entry in the database should be the same
|
||||
// as the user + machinekey still matches.
|
||||
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
|
||||
assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
|
||||
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
|
||||
assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey(), "Machine key should remain consistent after user1 relogin")
|
||||
assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey(), "Node key should be regenerated after user1 relogin")
|
||||
assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId(), "Node ID should be preserved for user1 after relogin")
|
||||
|
||||
// The "logged back in" machine should have the same machinekey but a different nodekey
|
||||
// than the version logged in with a different user.
|
||||
assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
|
||||
assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
|
||||
}, 30*time.Second, 1*time.Second, "log out user2, and log into user1, no new node should be created")
|
||||
assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey(), "Both final nodes should share the same machine key")
|
||||
assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey(), "Final nodes should have different node keys for different users")
|
||||
|
||||
t.Logf("Final validation complete - node counts and key relationships verified at %s", time.Now().Format(TimestampFormat))
|
||||
}, 60*time.Second, 2*time.Second, "validating final node state after complete user1->user2->user1 relogin cycle with detailed key validation")
|
||||
|
||||
// Security validation: Only user1's node should be active after relogin
|
||||
var activeUser1NodeID types.NodeID
|
||||
for _, node := range listNodesAfterLoggingBackIn {
|
||||
if node.GetUser().GetId() == 1 { // user1
|
||||
activeUser1NodeID = types.NodeID(node.GetId())
|
||||
t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName())
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Validate only user1's node is online (security requirement)
|
||||
t.Logf("Validating only user1 node is online after relogin at %s", time.Now().Format(TimestampFormat))
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
|
||||
// Check user1 node is online
|
||||
if node, exists := nodeStore[activeUser1NodeID]; exists {
|
||||
assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin")
|
||||
if node.IsOnline != nil {
|
||||
assert.True(c, *node.IsOnline, "User1 node should be online after relogin")
|
||||
}
|
||||
} else {
|
||||
assert.Fail(c, "User1 node not found in nodestore after relogin")
|
||||
}
|
||||
}, 60*time.Second, 2*time.Second, "validating only user1 node is online after final relogin")
|
||||
}
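
// The user-switch assertions above keep restating one invariant: after the
// same physical device re-registers under another user, the two node rows
// share a machine key but differ in node key, node ID and owner. The helper
// below is only an illustrative sketch of that invariant (it is not part of
// this change and the name is made up); the test above inlines the checks.
func requireSharedMachineDistinctNode(t *testing.T, a, b *v1.Node) {
	t.Helper()

	// Same physical machine: the machine key must be identical.
	require.Equal(t, a.GetMachineKey(), b.GetMachineKey(), "nodes should share a machine key")

	// Different logical nodes: node key, node ID and owning user all differ.
	require.NotEqual(t, a.GetNodeKey(), b.GetNodeKey(), "node keys should differ")
	require.NotEqual(t, a.GetId(), b.GetId(), "node IDs should differ")
	require.NotEqual(t, a.GetUser().GetId(), b.GetUser().GetId(), "owning users should differ")
}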
|
||||
|
||||
// TestOIDCFollowUpUrl validates the follow-up login flow
|
||||
@@ -709,7 +853,7 @@ func TestOIDCFollowUpUrl(t *testing.T) {
|
||||
},
|
||||
)
|
||||
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
oidcMap := map[string]string{
|
||||
@@ -730,43 +874,43 @@ func TestOIDCFollowUpUrl(t *testing.T) {
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, listUsers)
|
||||
|
||||
ts, err := scenario.CreateTailscaleNode(
|
||||
"unstable",
|
||||
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the registration cache to expire by sleeping a little longer
// than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION.
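// The follow-up flow exercised here is, roughly (a sketch inferred from the
// assertions below, not a statement about the server internals):
//
//	1. LoginWithURL hands out an AuthURL backed by a registration cache entry.
//	2. Once HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION passes, that entry is
//	   gone and the client falls back to NeedsLogin with a fresh AuthURL in
//	   its status.
//	3. Completing the fresh URL (doLoginURL below) registers the node as usual.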
time.Sleep(2 * time.Minute)
|
||||
|
||||
st, err := ts.Status()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "NeedsLogin", st.BackendState)
|
||||
|
||||
// get new AuthURL from daemon
|
||||
newUrl, err := url.Parse(st.AuthURL)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
|
||||
|
||||
_, err = doLoginURL(ts.Hostname(), newUrl)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
listUsers, err = headscale.ListUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, listUsers, 1)
|
||||
|
||||
wantUsers := []*v1.User{
|
||||
@@ -795,30 +939,230 @@ func TestOIDCFollowUpUrl(t *testing.T) {
|
||||
}
|
||||
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, listNodes, 1)
|
||||
}
|
||||
|
||||
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
|
||||
// are in the logged-out state (NeedsLogin).
|
||||
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
|
||||
if h, ok := t.(interface{ Helper() }); ok {
|
||||
h.Helper()
|
||||
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client
// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user.
//
// OIDC is an authentication layer built on top of OAuth 2.0 that allows users to authenticate
// using external identity providers (like Google, Microsoft, etc.) rather than managing
// credentials directly in headscale.
//
// This test validates the "same user relogin" behavior in headscale's OIDC authentication flow:
// - A single client authenticates via OIDC as user1
// - The client logs out, ending the session
// - The same client logs back in via OIDC as the same user (user1)
// - The test verifies that the user account persists correctly
// - The test verifies that the machine key is preserved (since it's the same physical device)
// - The test verifies that the node ID is preserved (since it's the same user on the same device)
// - The test verifies that the node key is regenerated (since it's a new session)
// - The test verifies that the client comes back online properly
//
// This scenario is important for normal user workflows where someone might need to restart
// their Tailscale client, reboot their computer, or temporarily disconnect and reconnect.
// It ensures that headscale properly handles session management while preserving device
// identity and user associations.
//
// The test uses a single node scenario (unlike multi-node tests) to focus specifically on
// the authentication and session management aspects rather than network topology changes.
// The "same node" in the name refers to the same physical device/client, while "same user"
// refers to authenticating with the same OIDC identity.
func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
// Create scenario with same user for both login attempts
|
||||
scenario, err := NewScenario(ScenarioSpec{
|
||||
OIDCUsers: []mockoidc.MockUser{
|
||||
oidcMockUser("user1", true), // Initial login
|
||||
oidcMockUser("user1", true), // Relogin with same user
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
oidcMap := map[string]string{
|
||||
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
|
||||
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
|
||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
status, err := client.Status()
|
||||
assert.NoError(t, err, "failed to get status for client %s", client.Hostname())
|
||||
assert.Equal(t, "NeedsLogin", status.BackendState,
|
||||
"client %s should be logged out", client.Hostname())
|
||||
}
|
||||
}
|
||||
err = scenario.CreateHeadscaleEnvWithLoginURL(
|
||||
nil,
|
||||
hsic.WithTestName("oidcsameuser"),
|
||||
hsic.WithConfigEnv(oidcMap),
|
||||
hsic.WithTLS(),
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithDERPAsIP(),
|
||||
)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
|
||||
return mockoidc.MockUser{
|
||||
Subject: username,
|
||||
PreferredUsername: username,
|
||||
Email: username + "@headscale.net",
|
||||
EmailVerified: emailVerified,
|
||||
}
|
||||
headscale, err := scenario.Headscale()
|
||||
require.NoError(t, err)
|
||||
|
||||
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initial login as user1
|
||||
u, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = doLoginURL(ts.Hostname(), u)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Validating initial user1 creation at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assert.NoError(ct, err, "Failed to list users during initial validation")
|
||||
assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers))
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
Name: "user1",
|
||||
Email: "user1@headscale.net",
|
||||
Provider: "oidc",
|
||||
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
|
||||
},
|
||||
}
|
||||
|
||||
sort.Slice(listUsers, func(i, j int) bool {
|
||||
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||
})
|
||||
|
||||
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||
ct.Errorf("User validation failed after first login - unexpected users: %s", diff)
|
||||
}
|
||||
}, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login")
|
||||
|
||||
t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat))
|
||||
var initialNodes []*v1.Node
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
initialNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes during initial validation")
|
||||
assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes))
|
||||
}, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login")
|
||||
|
||||
// Collect expected node IDs for validation after user1 initial login
|
||||
expectedNodes := make([]types.NodeID, 0, 1)
|
||||
status := ts.MustStatus()
|
||||
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
require.NoError(t, err)
|
||||
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
|
||||
|
||||
// Validate initial connection state for user1
|
||||
validateInitialConnection(t, headscale, expectedNodes)
|
||||
|
||||
// Store initial node keys for comparison
|
||||
initialMachineKey := initialNodes[0].GetMachineKey()
|
||||
initialNodeKey := initialNodes[0].GetNodeKey()
|
||||
initialNodeID := initialNodes[0].GetId()
|
||||
|
||||
// Logout user1
|
||||
err = ts.Logout()
|
||||
require.NoError(t, err)
|
||||
|
||||
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
|
||||
// logs in immediately after the first logout and I cannot reproduce it
|
||||
// manually.
|
||||
err = ts.Logout()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for logout to complete
|
||||
t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
// Check that the logout completed
|
||||
status, err := ts.Status()
|
||||
assert.NoError(ct, err, "Failed to get client status during logout validation")
|
||||
assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState)
|
||||
}, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before same-user relogin")
|
||||
|
||||
// Validate node persistence during logout (node should remain in DB)
|
||||
t.Logf("Validating node persistence during logout at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listNodes, err := headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes during logout validation")
|
||||
assert.Len(ct, listNodes, 1, "Should still have exactly 1 node during logout (node should persist in DB), got %d", len(listNodes))
|
||||
}, 30*time.Second, 1*time.Second, "validating node persistence in database during same-user logout")
|
||||
|
||||
// Login again as the same user (user1)
|
||||
u, err = ts.LoginWithURL(headscale.GetEndpoint())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = doLoginURL(ts.Hostname(), u)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := ts.Status()
|
||||
assert.NoError(ct, err, "Failed to get client status during relogin validation")
|
||||
assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState)
|
||||
}, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (same user)")
|
||||
|
||||
t.Logf("Final validation: checking user persistence after same-user relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
listUsers, err := headscale.ListUsers()
|
||||
assert.NoError(ct, err, "Failed to list users during final validation")
|
||||
assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers))
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
Name: "user1",
|
||||
Email: "user1@headscale.net",
|
||||
Provider: "oidc",
|
||||
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
|
||||
},
|
||||
}
|
||||
|
||||
sort.Slice(listUsers, func(i, j int) bool {
|
||||
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||
})
|
||||
|
||||
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||
ct.Errorf("Final user validation failed - user1 should persist after same-user relogin: %s", diff)
|
||||
}
|
||||
}, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle")
|
||||
|
||||
var finalNodes []*v1.Node
|
||||
t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
finalNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes during final validation")
|
||||
assert.Len(ct, finalNodes, 1, "Should have exactly 1 node after same-user relogin, got %d", len(finalNodes))
|
||||
|
||||
// Validate node key behavior for same user relogin
|
||||
finalNode := finalNodes[0]
|
||||
|
||||
// Machine key should be preserved (same physical machine)
|
||||
assert.Equal(ct, initialMachineKey, finalNode.GetMachineKey(), "Machine key should be preserved for same user same node relogin")
|
||||
|
||||
// Node ID should be preserved (same user, same machine)
|
||||
assert.Equal(ct, initialNodeID, finalNode.GetId(), "Node ID should be preserved for same user same node relogin")
|
||||
|
||||
// Node key should be regenerated (new session after logout)
|
||||
assert.NotEqual(ct, initialNodeKey, finalNode.GetNodeKey(), "Node key should be regenerated after logout/relogin even for same user")
|
||||
|
||||
t.Logf("Final validation complete - same user relogin key relationships verified at %s", time.Now().Format(TimestampFormat))
|
||||
}, 60*time.Second, 2*time.Second, "validating final node state after same-user OIDC relogin cycle with key preservation validation")
|
||||
|
||||
// Security validation: user1's node should be active after relogin
|
||||
activeUser1NodeID := types.NodeID(finalNodes[0].GetId())
|
||||
t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat))
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
|
||||
// Check user1 node is online
|
||||
if node, exists := nodeStore[activeUser1NodeID]; exists {
|
||||
assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin")
|
||||
if node.IsOnline != nil {
|
||||
assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin")
|
||||
}
|
||||
} else {
|
||||
assert.Fail(c, "User1 node not found in nodestore after same-user relogin")
|
||||
}
|
||||
}, 60*time.Second, 2*time.Second, "validating user1 node is online after same-user OIDC relogin")
|
||||
}
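
// The same-user relogin above asserts the mirror-image invariant of the
// user-switch test: the machine key and node ID survive the logout/login
// cycle while only the node key rotates. A small illustrative helper (not
// part of this change; the name is an invention) could state it directly:
func requireSameNodeNewSession(t *testing.T, before, after *v1.Node) {
	t.Helper()

	// Same device and same database row survive the cycle.
	require.Equal(t, before.GetMachineKey(), after.GetMachineKey(), "machine key should be preserved")
	require.Equal(t, before.GetId(), after.GetId(), "node ID should be preserved")

	// A fresh session always gets a fresh node key.
	require.NotEqual(t, before.GetNodeKey(), after.GetNodeKey(), "node key should rotate on relogin")
}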
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/integrationutil"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
|
||||
@@ -33,16 +37,16 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
|
||||
hsic.WithDERPAsIP(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -54,7 +58,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
}
|
||||
|
||||
func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||
func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
spec := ScenarioSpec{
|
||||
@@ -63,7 +67,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnvWithLoginURL(
|
||||
@@ -72,16 +76,16 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||
hsic.WithDERPAsIP(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -93,15 +97,22 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
// Collect expected node IDs for validation
|
||||
expectedNodes := collectExpectedNodeIDs(t, allClients)
|
||||
|
||||
// Validate initial connection state
|
||||
validateInitialConnection(t, headscale, expectedNodes)
|
||||
|
||||
var listNodes []*v1.Node
|
||||
t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, listNodes, len(allClients), "Node count should match client count after login")
|
||||
}, 20*time.Second, 1*time.Second)
|
||||
assert.NoError(ct, err, "Failed to list nodes after web authentication")
|
||||
assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication")
|
||||
nodeCountBeforeLogout := len(listNodes)
|
||||
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||
|
||||
@@ -122,7 +133,10 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleLogout()
|
||||
assertNoErrLogout(t, err)
|
||||
requireNoErrLogout(t, err)
|
||||
|
||||
// Validate that all nodes are offline after logout
|
||||
validateLogoutComplete(t, headscale, expectedNodes)
|
||||
|
||||
t.Logf("all clients logged out")
|
||||
|
||||
@@ -135,8 +149,20 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||
|
||||
t.Logf("all clients logged in again")
|
||||
|
||||
t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes after web flow logout")
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
|
||||
}, 60*time.Second, 2*time.Second, "validating node persistence in database after web flow logout")
|
||||
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))
|
||||
|
||||
// Validate connection state after relogin
|
||||
validateReloginComplete(t, headscale, expectedNodes)
|
||||
|
||||
allIps, err = scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
allAddrs = lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
@@ -145,14 +171,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||
success = pingAllHelper(t, allClients, allAddrs)
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count after re-login")
|
||||
}, 20*time.Second, 1*time.Second)
|
||||
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))
|
||||
|
||||
for _, client := range allClients {
|
||||
ips, err := client.IPs()
|
||||
if err != nil {
|
||||
@@ -180,3 +198,166 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||
|
||||
t.Logf("all clients IPs are the same")
|
||||
}
|
||||
|
||||
// TestAuthWebFlowLogoutAndReloginNewUser tests the scenario where multiple Tailscale clients
// initially authenticate using the web-based authentication flow (where users visit a URL
// in their browser to authenticate), then all clients log out and log back in as a different user.
//
// This test validates the "user switching" behavior in headscale's web authentication flow:
// - Multiple clients authenticate via web flow, each to their respective users (user1, user2)
// - All clients log out simultaneously
// - All clients log back in via web flow, but this time they all authenticate as user1
// - The test verifies that user1 ends up with all the client nodes
// - The test verifies that user2's original nodes still exist in the database but are offline
// - The test verifies network connectivity works after the user switch
//
// This scenario is important for organizations that need to reassign devices between users
// or when consolidating multiple user accounts. It ensures that headscale properly handles
// the security implications of user switching while maintaining node persistence in the database.
//
// The test uses headscale's web authentication flow, which is the most user-friendly method
// where authentication happens through a web browser rather than pre-shared keys or OIDC.
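//
// In terms of expected end state (a sketch of what the assertions below check,
// not additional behaviour):
//
//	headscale.ListNodes("user1") -> len(allClients) nodes, online after relogin
//	headscale.ListNodes("user2") -> len(allClients)/2 original nodes, still in
//	                                the database but offline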
func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: len(MustTestVersions),
|
||||
Users: []string{"user1", "user2"},
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnvWithLoginURL(
|
||||
nil,
|
||||
hsic.WithTestName("webflowrelnewuser"),
|
||||
hsic.WithDERPAsIP(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
// Collect expected node IDs for validation
|
||||
expectedNodes := collectExpectedNodeIDs(t, allClients)
|
||||
|
||||
// Validate initial connection state
|
||||
validateInitialConnection(t, headscale, expectedNodes)
|
||||
|
||||
var listNodes []*v1.Node
|
||||
t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes after initial web authentication")
|
||||
assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication")
|
||||
nodeCountBeforeLogout := len(listNodes)
|
||||
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||
|
||||
// Log out all clients
|
||||
for _, client := range allClients {
|
||||
err := client.Logout()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to logout client %s: %s", client.Hostname(), err)
|
||||
}
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleLogout()
|
||||
requireNoErrLogout(t, err)
|
||||
|
||||
// Validate that all nodes are offline after logout
|
||||
validateLogoutComplete(t, headscale, expectedNodes)
|
||||
|
||||
t.Logf("all clients logged out")
|
||||
|
||||
// Log all clients back in as user1 using web flow
|
||||
// We manually iterate over all clients and authenticate each one as user1
|
||||
// This tests the cross-user re-authentication behavior where ALL clients
|
||||
// (including those originally from user2) are registered to user1
|
||||
for _, client := range allClients {
|
||||
loginURL, err := client.LoginWithURL(headscale.GetEndpoint())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get login URL for client %s: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
body, err := doLoginURL(client.Hostname(), loginURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to complete login for client %s: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
// Register all clients as user1 (this is where cross-user registration happens)
|
||||
// This simulates: headscale nodes register --user user1 --key <key>
|
||||
scenario.runHeadscaleRegister("user1", body)
|
||||
}
|
||||
|
||||
// Wait for all clients to reach running state
|
||||
for _, client := range allClients {
|
||||
err := client.WaitForRunning(integrationutil.PeerSyncTimeout())
|
||||
if err != nil {
|
||||
t.Fatalf("%s tailscale node has not reached running: %s", client.Hostname(), err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("all clients logged back in as user1")
|
||||
|
||||
var user1Nodes []*v1.Node
|
||||
t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
user1Nodes, err = headscale.ListNodes("user1")
|
||||
assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin")
|
||||
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes))
|
||||
}, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after web flow user switch relogin")
|
||||
|
||||
// Collect expected node IDs for user1 after relogin
|
||||
expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes))
|
||||
for _, node := range user1Nodes {
|
||||
expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId()))
|
||||
}
|
||||
|
||||
// Validate connection state after relogin as user1
|
||||
validateReloginComplete(t, headscale, expectedUser1Nodes)
|
||||
|
||||
// Validate that user2's old nodes still exist in database (but are expired/offline)
|
||||
// When CLI registration creates new nodes for user1, user2's old nodes remain
|
||||
var user2Nodes []*v1.Node
|
||||
t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
user2Nodes, err = headscale.ListNodes("user2")
|
||||
assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1")
|
||||
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1")
|
||||
|
||||
t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat))
|
||||
for _, client := range allClients {
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
|
||||
assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after web flow user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName)
|
||||
}, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after web flow user switch", client.Hostname()))
|
||||
}
|
||||
|
||||
// Test connectivity after user switch
|
||||
allIps, err = scenario.ListTailscaleClientsIPs()
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
||||
success := pingAllHelper(t, allClients, allAddrs)
|
||||
t.Logf("%d successful pings out of %d after web flow user switch", success, len(allClients)*len(allIps))
|
||||
}
|
||||
|
||||
@@ -54,14 +54,14 @@ func TestUserCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var listUsers []*v1.User
|
||||
var result []string
|
||||
@@ -99,7 +99,7 @@ func TestUserCommand(t *testing.T) {
|
||||
"--new-name=newname",
|
||||
},
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var listAfterRenameUsers []*v1.User
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
@@ -138,7 +138,7 @@ func TestUserCommand(t *testing.T) {
|
||||
},
|
||||
&listByUsername,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
slices.SortFunc(listByUsername, sortWithID)
|
||||
want := []*v1.User{
|
||||
@@ -165,7 +165,7 @@ func TestUserCommand(t *testing.T) {
|
||||
},
|
||||
&listByID,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
slices.SortFunc(listByID, sortWithID)
|
||||
want = []*v1.User{
|
||||
@@ -244,7 +244,7 @@ func TestUserCommand(t *testing.T) {
|
||||
},
|
||||
&listAfterNameDelete,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Empty(t, listAfterNameDelete)
|
||||
}
|
||||
@@ -260,17 +260,17 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys := make([]*v1.PreAuthKey, count)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for index := range count {
|
||||
var preAuthKey v1.PreAuthKey
|
||||
@@ -292,7 +292,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
},
|
||||
&preAuthKey,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys[index] = &preAuthKey
|
||||
}
|
||||
@@ -313,7 +313,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
},
|
||||
&listedPreAuthKeys,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// There is one key created by "scenario.CreateHeadscaleEnv"
|
||||
assert.Len(t, listedPreAuthKeys, 4)
|
||||
@@ -372,7 +372,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
listedPreAuthKeys[1].GetKey(),
|
||||
},
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var listedPreAuthKeysAfterExpire []v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
@@ -388,7 +388,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
},
|
||||
&listedPreAuthKeysAfterExpire,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now()))
|
||||
assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now()))
|
||||
@@ -404,14 +404,14 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var preAuthKey v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
@@ -428,7 +428,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||
},
|
||||
&preAuthKey,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var listedPreAuthKeys []v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
@@ -444,7 +444,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||
},
|
||||
&listedPreAuthKeys,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// There is one key created by "scenario.CreateHeadscaleEnv"
|
||||
assert.Len(t, listedPreAuthKeys, 2)
|
||||
@@ -465,14 +465,14 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var preAuthReusableKey v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
@@ -489,7 +489,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||
},
|
||||
&preAuthReusableKey,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var preAuthEphemeralKey v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
@@ -506,7 +506,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||
},
|
||||
&preAuthEphemeralKey,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, preAuthEphemeralKey.GetEphemeral())
|
||||
assert.False(t, preAuthEphemeralKey.GetReusable())
|
||||
@@ -525,7 +525,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||
},
|
||||
&listedPreAuthKeys,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// There is one key created by "scenario.CreateHeadscaleEnv"
|
||||
assert.Len(t, listedPreAuthKeys, 3)
|
||||
@@ -543,7 +543,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
@@ -552,13 +552,13 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
u2, err := headscale.CreateUser(user2)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var user2Key v1.PreAuthKey
|
||||
|
||||
@@ -580,7 +580,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||
},
|
||||
&user2Key,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var listNodes []*v1.Node
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
@@ -592,7 +592,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||
}, 15*time.Second, 1*time.Second)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
require.Len(t, allClients, 1)
|
||||
|
||||
@@ -600,10 +600,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||
|
||||
// Log out from user1
|
||||
err = client.Logout()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleLogout()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
@@ -613,7 +613,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
|
||||
err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
@@ -642,14 +642,14 @@ func TestApiKeyCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys := make([]string, count)
|
||||
|
||||
@@ -808,14 +808,14 @@ func TestNodeTagCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
regIDs := []string{
|
||||
types.MustRegistrationID().String(),
|
||||
@@ -1007,7 +1007,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
@@ -1015,10 +1015,10 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
|
||||
hsic.WithTestName("cliadvtags"),
|
||||
hsic.WithACLPolicy(tt.policy),
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test list all nodes after added seconds
|
||||
resultMachines := make([]*v1.Node, spec.NodesPerUser)
|
||||
@@ -1058,14 +1058,14 @@ func TestNodeCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
regIDs := []string{
|
||||
types.MustRegistrationID().String(),
|
||||
@@ -1302,14 +1302,14 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
regIDs := []string{
|
||||
types.MustRegistrationID().String(),
|
||||
@@ -1427,14 +1427,14 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
regIDs := []string{
|
||||
types.MustRegistrationID().String(),
|
||||
@@ -1462,7 +1462,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||
"json",
|
||||
},
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var node v1.Node
|
||||
err = executeAndUnmarshal(
|
||||
@@ -1480,7 +1480,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||
},
|
||||
&node,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodes[index] = &node
|
||||
}
|
||||
@@ -1591,20 +1591,20 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Randomly generated node key
|
||||
regID := types.MustRegistrationID()
|
||||
|
||||
userMap, err := headscale.MapUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = headscale.Execute(
|
||||
[]string{
|
||||
@@ -1753,7 +1753,7 @@ func TestPolicyCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
@@ -1763,10 +1763,10 @@ func TestPolicyCommand(t *testing.T) {
|
||||
"HEADSCALE_POLICY_MODE": "database",
|
||||
}),
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
p := policyv2.Policy{
|
||||
ACLs: []policyv2.ACL{
|
||||
@@ -1789,7 +1789,7 @@ func TestPolicyCommand(t *testing.T) {
|
||||
policyFilePath := "/etc/headscale/policy.json"
|
||||
|
||||
err = headscale.WriteFile(policyFilePath, pBytes)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// No policy is present at this time.
|
||||
// Add a new policy from a file.
|
||||
@@ -1803,7 +1803,7 @@ func TestPolicyCommand(t *testing.T) {
|
||||
},
|
||||
)
|
||||
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the current policy and check
|
||||
// if it is the same as the one we set.
|
||||
@@ -1819,7 +1819,7 @@ func TestPolicyCommand(t *testing.T) {
|
||||
},
|
||||
&output,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, output.TagOwners, 1)
|
||||
assert.Len(t, output.ACLs, 1)
|
||||
@@ -1834,7 +1834,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
@@ -1844,10 +1844,10 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||
"HEADSCALE_POLICY_MODE": "database",
|
||||
}),
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
p := policyv2.Policy{
|
||||
ACLs: []policyv2.ACL{
|
||||
@@ -1872,7 +1872,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||
policyFilePath := "/etc/headscale/policy.json"
|
||||
|
||||
err = headscale.WriteFile(policyFilePath, pBytes)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// No policy is present at this time.
|
||||
// Add a new policy from a file.
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/integrationutil"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/net/netmon"
|
||||
@@ -23,7 +24,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
|
||||
|
||||
// Generate random hostname for the headscale instance
|
||||
hash, err := util.GenerateRandomStringDNSSafe(6)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
testName := "derpverify"
|
||||
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
|
||||
|
||||
@@ -31,7 +32,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
|
||||
|
||||
// Create cert for headscale
|
||||
certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: len(MustTestVersions),
|
||||
@@ -39,14 +40,14 @@ func TestDERPVerifyEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
derper, err := scenario.CreateDERPServer("head",
|
||||
dsic.WithCACert(certHeadscale),
|
||||
dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))),
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
derpRegion := tailcfg.DERPRegion{
|
||||
RegionCode: "test-derpverify",
|
||||
@@ -74,17 +75,17 @@ func TestDERPVerifyEndpoint(t *testing.T) {
|
||||
hsic.WithPort(headscalePort),
|
||||
hsic.WithCustomTLS(certHeadscale, keyHeadscale),
|
||||
hsic.WithDERPConfig(derpMap))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
fakeKey := key.NewNode()
|
||||
DERPVerify(t, fakeKey, derpRegion, false)
|
||||
|
||||
for _, client := range allClients {
|
||||
nodeKey, err := client.GetNodePrivateKey()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
DERPVerify(t, *nodeKey, derpRegion, true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
@@ -22,26 +23,26 @@ func TestResolveMagicDNS(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns"))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
// Poor mans cache
|
||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||
assertNoErrListFQDN(t, err)
|
||||
requireNoErrListFQDN(t, err)
|
||||
|
||||
_, err = scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
for _, peer := range allClients {
|
||||
@@ -78,7 +79,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
const erPath = "/tmp/extra_records.json"
|
||||
@@ -109,29 +110,29 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
// Poor mans cache
|
||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||
assertNoErrListFQDN(t, err)
|
||||
requireNoErrListFQDN(t, err)
|
||||
|
||||
_, err = scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6")
|
||||
}
|
||||
|
||||
hs, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write the file directly into place from the docker API.
|
||||
b0, _ := json.Marshal([]tailcfg.DNSRecord{
|
||||
@@ -143,7 +144,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
|
||||
})
|
||||
|
||||
err = hs.WriteFile(erPath, b0)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "2.2.2.2")
|
||||
@@ -159,9 +160,9 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
|
||||
b2, _ := json.Marshal(extraRecords)
|
||||
|
||||
err = hs.WriteFile(erPath+"2", b2)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
_, err = hs.Execute([]string{"mv", erPath + "2", erPath})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6")
|
||||
@@ -179,9 +180,9 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
|
||||
})
|
||||
|
||||
err = hs.WriteFile(erPath+"3", b3)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
_, err = hs.Execute([]string{"cp", erPath + "3", erPath})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8")
|
||||
@@ -197,7 +198,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
|
||||
})
|
||||
command := []string{"echo", fmt.Sprintf("'%s'", string(b4)), ">", erPath}
|
||||
_, err = hs.Execute([]string{"bash", "-c", strings.Join(command, " ")})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9")
|
||||
@@ -205,7 +206,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
|
||||
|
||||
// Delete the file and create a new one to ensure it is picked up again.
|
||||
_, err = hs.Execute([]string{"rm", erPath})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The same paths should still be available as it is not cleared on delete.
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
@@ -219,7 +220,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
|
||||
// Write a new file, the backoff mechanism should make the filewatcher pick it up
|
||||
// again.
|
||||
err = hs.WriteFile(erPath, b3)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8")
|
||||
|
||||
17  integration/dockertestutil/build.go  Normal file
@@ -0,0 +1,17 @@
package dockertestutil

import (
    "os/exec"
)

// RunDockerBuildForDiagnostics runs docker build manually to get detailed error output.
// This is used when a docker build fails, to provide more detailed diagnostic information
// than what dockertest typically provides.
func RunDockerBuildForDiagnostics(contextDir, dockerfile string) string {
    cmd := exec.Command("docker", "build", "-f", dockerfile, contextDir)
    output, err := cmd.CombinedOutput()
    if err != nil {
        return string(output)
    }
    return ""
}
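
A minimal sketch of how this diagnostics helper might be called once a dockertest-driven build has failed. The test name, the build context ".." and the Dockerfile path "Dockerfile.integration" are illustrative assumptions and not part of this commit; only RunDockerBuildForDiagnostics comes from the file above.

package dockertestutil_test

import (
    "testing"

    "github.com/juanfont/headscale/integration/dockertestutil"
)

// TestBuildDiagnosticsSketch re-runs "docker build" directly after a failed
// dockertest build and dumps the full builder output into the test log.
func TestBuildDiagnosticsSketch(t *testing.T) {
    t.Skip("illustrative sketch; assumes a local Docker daemon and build context")

    // Hypothetical context directory and Dockerfile, for illustration only.
    if out := dockertestutil.RunDockerBuildForDiagnostics("..", "Dockerfile.integration"); out != "" {
        t.Logf("docker build diagnostics:\n%s", out)
    }
}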
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
@@ -29,7 +30,7 @@ func TestDERPServerScenario(t *testing.T) {
|
||||
|
||||
derpServerScenario(t, spec, false, func(scenario *Scenario) {
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
t.Logf("checking %d clients for websocket connections", len(allClients))
|
||||
|
||||
for _, client := range allClients {
|
||||
@@ -43,7 +44,7 @@ func TestDERPServerScenario(t *testing.T) {
|
||||
}
|
||||
|
||||
hsServer, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
derpRegion := tailcfg.DERPRegion{
|
||||
RegionCode: "test-derpverify",
|
||||
@@ -79,7 +80,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
|
||||
|
||||
derpServerScenario(t, spec, true, func(scenario *Scenario) {
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
t.Logf("checking %d clients for websocket connections", len(allClients))
|
||||
|
||||
for _, client := range allClients {
|
||||
@@ -108,7 +109,7 @@ func derpServerScenario(
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
@@ -128,16 +129,16 @@ func derpServerScenario(
|
||||
"HEADSCALE_DERP_SERVER_VERIFY_CLIENTS": "true",
|
||||
}),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
||||
assertNoErrListFQDN(t, err)
|
||||
requireNoErrListFQDN(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
|
||||
@@ -10,19 +10,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/integrationutil"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
"tailscale.com/types/key"
|
||||
@@ -38,7 +34,7 @@ func TestPingAllByIP(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
@@ -48,16 +44,16 @@ func TestPingAllByIP(t *testing.T) {
|
||||
hsic.WithTLS(),
|
||||
hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
hs, err := scenario.Headscale()
|
||||
require.NoError(t, err)
|
||||
@@ -80,7 +76,7 @@ func TestPingAllByIP(t *testing.T) {
|
||||
|
||||
// Get headscale instance for batcher debug check
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test our DebugBatcher functionality
|
||||
t.Logf("Testing DebugBatcher functionality...")
|
||||
@@ -99,23 +95,23 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
[]tsic.Option{},
|
||||
hsic.WithTestName("pingallbyippubderp"),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -148,11 +144,11 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
headscale, err := scenario.Headscale(opts...)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
for _, userName := range spec.Users {
|
||||
user, err := scenario.CreateUser(userName)
|
||||
@@ -177,13 +173,13 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
@@ -200,7 +196,7 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleLogout()
|
||||
assertNoErrLogout(t, err)
|
||||
requireNoErrLogout(t, err)
|
||||
|
||||
t.Logf("all clients logged out")
|
||||
|
||||
@@ -222,7 +218,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
headscale, err := scenario.Headscale(
|
||||
@@ -231,7 +227,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
|
||||
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s",
|
||||
}),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
for _, userName := range spec.Users {
|
||||
user, err := scenario.CreateUser(userName)
|
||||
@@ -256,13 +252,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
@@ -344,22 +340,22 @@ func TestPingAllByHostname(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname"))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
||||
assertNoErrListFQDN(t, err)
|
||||
requireNoErrListFQDN(t, err)
|
||||
|
||||
success := pingAllHelper(t, allClients, allHostnames)
|
||||
|
||||
@@ -379,7 +375,7 @@ func TestTaildrop(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{},
|
||||
@@ -387,17 +383,17 @@ func TestTaildrop(t *testing.T) {
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// This will essentially fetch and cache all the FQDNs
|
||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||
assertNoErrListFQDN(t, err)
|
||||
requireNoErrListFQDN(t, err)
|
||||
|
||||
for _, client := range allClients {
|
||||
if !strings.Contains(client.Hostname(), "head") {
|
||||
@@ -498,7 +494,7 @@ func TestTaildrop(t *testing.T) {
|
||||
)
|
||||
|
||||
result, _, err := client.Execute(command)
|
||||
assertNoErrf(t, "failed to execute command to ls taildrop: %s", err)
|
||||
require.NoErrorf(t, err, "failed to execute command to ls taildrop")
|
||||
|
||||
log.Printf("Result for %s: %s\n", peer.Hostname(), result)
|
||||
if fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()) != result {
|
||||
@@ -528,25 +524,25 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
require.NoErrorf(t, err, "failed to create scenario")
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname"))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
// update hostnames using the up command
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
command := []string{
|
||||
"tailscale",
|
||||
@@ -554,11 +550,11 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
"--hostname=" + hostnames[string(status.Self.ID)],
|
||||
}
|
||||
_, _, err = client.Execute(command)
|
||||
assertNoErrf(t, "failed to set hostname: %s", err)
|
||||
require.NoErrorf(t, err, "failed to set hostname")
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// Wait for nodestore batch processing to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 1 second
|
||||
@@ -597,7 +593,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
"--identifier",
|
||||
strconv.FormatUint(node.GetId(), 10),
|
||||
})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify that the server-side rename is reflected in DNSName while HostName remains unchanged
|
||||
@@ -643,7 +639,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
command := []string{
|
||||
"tailscale",
|
||||
@@ -651,11 +647,11 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
"--hostname=" + hostnames[string(status.Self.ID)] + "NEW",
|
||||
}
|
||||
_, _, err = client.Execute(command)
|
||||
assertNoErrf(t, "failed to set hostname: %s", err)
|
||||
require.NoErrorf(t, err, "failed to set hostname")
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// Wait for nodestore batch processing to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 1 second
|
||||
@@ -696,20 +692,20 @@ func TestExpireNode(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode"))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -731,22 +727,22 @@ func TestExpireNode(t *testing.T) {
|
||||
}
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TODO(kradalby): This is Headscale specific and would not play nicely
|
||||
// with other implementations of the ControlServer interface
|
||||
result, err := headscale.Execute([]string{
|
||||
"headscale", "nodes", "expire", "--identifier", "1", "--output", "json",
|
||||
})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var node v1.Node
|
||||
err = json.Unmarshal([]byte(result), &node)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var expiredNodeKey key.NodePublic
|
||||
err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey()))
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())
|
||||
|
||||
@@ -773,14 +769,14 @@ func TestExpireNode(t *testing.T) {
|
||||
// Verify that the expired node has been marked in all peers list.
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
if client.Hostname() != node.GetName() {
|
||||
t.Logf("available peers of %s: %v", client.Hostname(), status.Peers())
|
||||
|
||||
// Ensures that the node is present, and that it is expired.
|
||||
if peerStatus, ok := status.Peer[expiredNodeKey]; ok {
|
||||
assertNotNil(t, peerStatus.Expired)
|
||||
requireNotNil(t, peerStatus.Expired)
|
||||
assert.NotNil(t, peerStatus.KeyExpiry)
|
||||
|
||||
t.Logf(
|
||||
@@ -840,20 +836,20 @@ func TestNodeOnlineStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online"))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -866,14 +862,14 @@ func TestNodeOnlineStatus(t *testing.T) {
|
||||
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that we have the original count - self
|
||||
assert.Len(t, status.Peers(), len(MustTestVersions)-1)
|
||||
}
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Duration is chosen arbitrarily, 10m is reported in #1561
|
||||
testDuration := 12 * time.Minute
|
||||
@@ -963,7 +959,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
@@ -973,16 +969,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
||||
hsic.WithDERPAsIP(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// assertClientsState(t, allClients)
|
||||
|
||||
@@ -992,7 +988,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
||||
|
||||
// Get headscale instance for batcher debug checks
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initial check: all nodes should be connected to batcher
|
||||
// Extract node IDs for validation
|
||||
@@ -1000,7 +996,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
||||
for _, client := range allClients {
|
||||
status := client.MustStatus()
|
||||
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
|
||||
}
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second)
|
||||
@@ -1072,7 +1068,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
@@ -1081,16 +1077,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
requireNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
@@ -1100,7 +1096,7 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test list all nodes after added otherUser
|
||||
var nodeList []v1.Node
|
||||
@@ -1170,159 +1166,3 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
|
||||
assert.True(t, nodeListAfter[0].GetOnline())
|
||||
assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
|
||||
}
|
||||
|
||||
// NodeSystemStatus represents the online status of a node across different systems
|
||||
type NodeSystemStatus struct {
|
||||
Batcher bool
|
||||
BatcherConnCount int
|
||||
MapResponses bool
|
||||
NodeStore bool
|
||||
}
|
||||
|
||||
// requireAllSystemsOnline checks that nodes are online/offline across batcher, mapresponses, and nodestore
|
||||
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
startTime := time.Now()
|
||||
t.Logf("requireAllSystemsOnline: Starting validation at %s - %s", startTime.Format(TimestampFormat), message)
|
||||
|
||||
var prevReport string
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// Get batcher state
|
||||
debugInfo, err := headscale.DebugBatcher()
|
||||
assert.NoError(c, err, "Failed to get batcher debug info")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get map responses
|
||||
mapResponses, err := headscale.GetAllMapReponses()
|
||||
assert.NoError(c, err, "Failed to get map responses")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get nodestore state
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate node counts first
|
||||
expectedCount := len(expectedNodes)
|
||||
assert.Equal(c, expectedCount, debugInfo.TotalNodes, "Batcher total nodes mismatch")
|
||||
assert.Equal(c, expectedCount, len(nodeStore), "NodeStore total nodes mismatch")
|
||||
|
||||
// Check that we have map responses for expected nodes
|
||||
mapResponseCount := len(mapResponses)
|
||||
assert.Equal(c, expectedCount, mapResponseCount, "MapResponses total nodes mismatch")
|
||||
|
||||
// Build status map for each node
|
||||
nodeStatus := make(map[types.NodeID]NodeSystemStatus)
|
||||
|
||||
// Initialize all expected nodes
|
||||
for _, nodeID := range expectedNodes {
|
||||
nodeStatus[nodeID] = NodeSystemStatus{}
|
||||
}
|
||||
|
||||
// Check batcher state
|
||||
for nodeIDStr, nodeInfo := range debugInfo.ConnectedNodes {
|
||||
nodeID := types.MustParseNodeID(nodeIDStr)
|
||||
if status, exists := nodeStatus[nodeID]; exists {
|
||||
status.Batcher = nodeInfo.Connected
|
||||
status.BatcherConnCount = nodeInfo.ActiveConnections
|
||||
nodeStatus[nodeID] = status
|
||||
}
|
||||
}
|
||||
|
||||
// Check map responses using buildExpectedOnlineMap
|
||||
onlineFromMaps := make(map[types.NodeID]bool)
|
||||
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
|
||||
for nodeID := range nodeStatus {
|
||||
NODE_STATUS:
|
||||
for id, peerMap := range onlineMap {
|
||||
if id == nodeID {
|
||||
continue
|
||||
}
|
||||
|
||||
online := peerMap[nodeID]
|
||||
// If the node is offline in any map response, we consider it offline
|
||||
if !online {
|
||||
onlineFromMaps[nodeID] = false
|
||||
continue NODE_STATUS
|
||||
}
|
||||
|
||||
onlineFromMaps[nodeID] = true
|
||||
}
|
||||
}
|
||||
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")
|
||||
|
||||
// Update status with map response data
|
||||
for nodeID, online := range onlineFromMaps {
|
||||
if status, exists := nodeStatus[nodeID]; exists {
|
||||
status.MapResponses = online
|
||||
nodeStatus[nodeID] = status
|
||||
}
|
||||
}
|
||||
|
||||
// Check nodestore state
|
||||
for nodeID, node := range nodeStore {
|
||||
if status, exists := nodeStatus[nodeID]; exists {
|
||||
// Check if node is online in nodestore
|
||||
status.NodeStore = node.IsOnline != nil && *node.IsOnline
|
||||
nodeStatus[nodeID] = status
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all systems show nodes in expected state and report failures
|
||||
allMatch := true
|
||||
var failureReport strings.Builder
|
||||
|
||||
ids := types.NodeIDs(maps.Keys(nodeStatus))
|
||||
slices.Sort(ids)
|
||||
for _, nodeID := range ids {
|
||||
status := nodeStatus[nodeID]
|
||||
systemsMatch := (status.Batcher == expectedOnline) &&
|
||||
(status.MapResponses == expectedOnline) &&
|
||||
(status.NodeStore == expectedOnline)
|
||||
|
||||
if !systemsMatch {
|
||||
allMatch = false
|
||||
stateStr := "offline"
|
||||
if expectedOnline {
|
||||
stateStr = "online"
|
||||
}
|
||||
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s:\n", nodeID, stateStr))
|
||||
failureReport.WriteString(fmt.Sprintf(" - batcher: %t\n", status.Batcher))
|
||||
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
|
||||
failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (down with at least one peer)\n", status.MapResponses))
|
||||
failureReport.WriteString(fmt.Sprintf(" - nodestore: %t\n", status.NodeStore))
|
||||
}
|
||||
}
|
||||
|
||||
if !allMatch {
|
||||
if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" {
|
||||
t.Log("Diff between reports:")
|
||||
t.Logf("Prev report: \n%s\n", prevReport)
|
||||
t.Logf("New report: \n%s\n", failureReport.String())
|
||||
t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
|
||||
prevReport = failureReport.String()
|
||||
}
|
||||
|
||||
failureReport.WriteString("timestamp: " + time.Now().Format(TimestampFormat) + "\n")
|
||||
|
||||
assert.Fail(c, failureReport.String())
|
||||
}
|
||||
|
||||
stateStr := "offline"
|
||||
if expectedOnline {
|
||||
stateStr = "online"
|
||||
}
|
||||
assert.True(c, allMatch, fmt.Sprintf("Not all nodes are %s across all systems", stateStr))
|
||||
}, timeout, 2*time.Second, message)
|
||||
|
||||
endTime := time.Now()
|
||||
duration := endTime.Sub(startTime)
|
||||
t.Logf("requireAllSystemsOnline: Completed validation at %s - Duration: %v - %s", endTime.Format(TimestampFormat), duration, message)
|
||||
}
|
||||
|
||||
922
integration/helpers.go
Normal file
922
integration/helpers.go
Normal file
@@ -0,0 +1,922 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/juanfont/headscale/integration/integrationutil"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/oauth2-proxy/mockoidc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
// derpPingTimeout defines the timeout for individual DERP ping operations
|
||||
// Used in DERP connectivity tests to verify relay server communication.
|
||||
derpPingTimeout = 2 * time.Second
|
||||
|
||||
// derpPingCount defines the number of ping attempts for DERP connectivity tests
|
||||
// Higher count provides better reliability assessment of DERP connectivity.
|
||||
derpPingCount = 10
|
||||
|
||||
// TimestampFormat is the standard timestamp format used across all integration tests
|
||||
// Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps
|
||||
// suitable for debugging and log correlation in integration tests.
|
||||
TimestampFormat = "2006-01-02T15-04-05.999999999"
|
||||
|
||||
// TimestampFormatRunID is used for generating unique run identifiers
|
||||
// Format: "20060102-150405" provides compact date-time for file/directory names.
|
||||
TimestampFormatRunID = "20060102-150405"
|
||||
)
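
// Illustrative only (not part of this change): TimestampFormat is intended for
// log correlation inside the tests, e.g.
//
//	t.Logf("checkpoint reached at %s", time.Now().Format(TimestampFormat))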
|
||||
|
||||
// NodeSystemStatus represents the status of a node across different systems
|
||||
type NodeSystemStatus struct {
|
||||
Batcher bool
|
||||
BatcherConnCount int
|
||||
MapResponses bool
|
||||
NodeStore bool
|
||||
}
|
||||
|
||||
// requireNotNil validates that an object is not nil and fails the test if it is.
|
||||
// This helper provides consistent error messaging for nil checks in integration tests.
|
||||
func requireNotNil(t *testing.T, object interface{}) {
|
||||
t.Helper()
|
||||
require.NotNil(t, object)
|
||||
}
|
||||
|
||||
// requireNoErrHeadscaleEnv validates that headscale environment creation succeeded.
|
||||
// Provides specific error context for headscale environment setup failures.
|
||||
func requireNoErrHeadscaleEnv(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.NoError(t, err, "failed to create headscale environment")
|
||||
}
|
||||
|
||||
// requireNoErrGetHeadscale validates that headscale server retrieval succeeded.
|
||||
// Provides specific error context for headscale server access failures.
|
||||
func requireNoErrGetHeadscale(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.NoError(t, err, "failed to get headscale")
|
||||
}
|
||||
|
||||
// requireNoErrListClients validates that client listing operations succeeded.
|
||||
// Provides specific error context for client enumeration failures.
|
||||
func requireNoErrListClients(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.NoError(t, err, "failed to list clients")
|
||||
}
|
||||
|
||||
// requireNoErrListClientIPs validates that client IP retrieval succeeded.
|
||||
// Provides specific error context for client IP address enumeration failures.
|
||||
func requireNoErrListClientIPs(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.NoError(t, err, "failed to get client IPs")
|
||||
}
|
||||
|
||||
// requireNoErrSync validates that client synchronization operations succeeded.
|
||||
// Provides specific error context for client sync failures across the network.
|
||||
func requireNoErrSync(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.NoError(t, err, "failed to have all clients sync up")
|
||||
}
|
||||
|
||||
// requireNoErrListFQDN validates that FQDN listing operations succeeded.
|
||||
// Provides specific error context for DNS name enumeration failures.
|
||||
func requireNoErrListFQDN(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.NoError(t, err, "failed to list FQDNs")
|
||||
}
|
||||
|
||||
// requireNoErrLogout validates that tailscale node logout operations succeeded.
|
||||
// Provides specific error context for client logout failures.
|
||||
func requireNoErrLogout(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.NoError(t, err, "failed to log out tailscale nodes")
|
||||
}
|
||||
|
||||
// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes
|
||||
func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID {
|
||||
t.Helper()
|
||||
|
||||
expectedNodes := make([]types.NodeID, 0, len(clients))
|
||||
for _, client := range clients {
|
||||
status := client.MustStatus()
|
||||
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
require.NoError(t, err)
|
||||
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
|
||||
}
|
||||
return expectedNodes
|
||||
}
|
||||
|
||||
// validateInitialConnection performs comprehensive validation after initial client login.
|
||||
// Validates that all nodes are online and have proper NetInfo/DERP configuration,
|
||||
// essential for ensuring successful initial connection state in relogin tests.
|
||||
func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) {
|
||||
t.Helper()
|
||||
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
|
||||
}
|
||||
|
||||
// validateLogoutComplete performs comprehensive validation after client logout.
|
||||
// Ensures all nodes are properly offline across all headscale systems,
|
||||
// critical for validating clean logout state in relogin tests.
|
||||
func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) {
|
||||
t.Helper()
|
||||
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second)
|
||||
}
|
||||
|
||||
// validateReloginComplete performs comprehensive validation after client relogin.
|
||||
// Validates that all nodes are back online with proper NetInfo/DERP configuration,
|
||||
// ensuring successful relogin state restoration in integration tests.
|
||||
func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) {
|
||||
t.Helper()
|
||||
|
||||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin", 3*time.Minute)
|
||||
}
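
// Illustrative sketch (not part of this change) of how the validate* helpers
// above are meant to be combined in a relogin test; "allClients" and
// "headscale" are hypothetical variables from the surrounding scenario setup:
//
//	expectedNodes := collectExpectedNodeIDs(t, allClients)
//	validateInitialConnection(t, headscale, expectedNodes)
//	// ... log every client out ...
//	validateLogoutComplete(t, headscale, expectedNodes)
//	// ... log every client back in ...
//	validateReloginComplete(t, headscale, expectedNodes)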
|
||||
|
||||
// requireAllClientsOnline verifies that all expected nodes are in the specified online/offline state across all headscale systems.
|
||||
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
startTime := time.Now()
|
||||
stateStr := "offline"
|
||||
if expectedOnline {
|
||||
stateStr = "online"
|
||||
}
|
||||
t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message)
|
||||
|
||||
if expectedOnline {
|
||||
// For online validation, use the existing logic with full timeout
|
||||
requireAllClientsOnlineWithSingleTimeout(t, headscale, expectedNodes, expectedOnline, message, timeout)
|
||||
} else {
|
||||
// For offline validation, use staged approach with component-specific timeouts
|
||||
requireAllClientsOfflineStaged(t, headscale, expectedNodes, message, timeout)
|
||||
}
|
||||
|
||||
endTime := time.Now()
|
||||
t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message)
|
||||
}
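
// Example calls (illustrative only; the timeouts mirror the ones used by the
// validate* helpers above, and the variables are hypothetical):
//
//	requireAllClientsOnline(t, headscale, expectedNodes, true,
//		"all clients should be connected after login", 120*time.Second)
//	requireAllClientsOnline(t, headscale, expectedNodes, false,
//		"all clients should be offline after logout", 120*time.Second)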
|
||||
|
||||
// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state
|
||||
func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
var prevReport string
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// Get batcher state
|
||||
debugInfo, err := headscale.DebugBatcher()
|
||||
assert.NoError(c, err, "Failed to get batcher debug info")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get map responses
|
||||
mapResponses, err := headscale.GetAllMapReponses()
|
||||
assert.NoError(c, err, "Failed to get map responses")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get nodestore state
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that all expected nodes are present in nodeStore
|
||||
for _, nodeID := range expectedNodes {
|
||||
_, exists := nodeStore[nodeID]
|
||||
assert.True(c, exists, "Expected node %d not found in nodeStore", nodeID)
|
||||
}
|
||||
|
||||
// Check that we have map responses for expected nodes
|
||||
mapResponseCount := len(mapResponses)
|
||||
expectedCount := len(expectedNodes)
|
||||
assert.GreaterOrEqual(c, mapResponseCount, expectedCount, "MapResponses insufficient - expected at least %d responses, got %d", expectedCount, mapResponseCount)
|
||||
|
||||
// Build status map for each node
|
||||
nodeStatus := make(map[types.NodeID]NodeSystemStatus)
|
||||
|
||||
// Initialize all expected nodes
|
||||
for _, nodeID := range expectedNodes {
|
||||
nodeStatus[nodeID] = NodeSystemStatus{}
|
||||
}
|
||||
|
||||
// Check batcher state for expected nodes
|
||||
for _, nodeID := range expectedNodes {
|
||||
nodeIDStr := fmt.Sprintf("%d", nodeID)
|
||||
if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists {
|
||||
if status, exists := nodeStatus[nodeID]; exists {
|
||||
status.Batcher = nodeInfo.Connected
|
||||
status.BatcherConnCount = nodeInfo.ActiveConnections
|
||||
nodeStatus[nodeID] = status
|
||||
}
|
||||
} else {
|
||||
// Node not found in batcher, mark as disconnected
|
||||
if status, exists := nodeStatus[nodeID]; exists {
|
||||
status.Batcher = false
|
||||
status.BatcherConnCount = 0
|
||||
nodeStatus[nodeID] = status
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check map responses using buildExpectedOnlineMap
|
||||
onlineFromMaps := make(map[types.NodeID]bool)
|
||||
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
|
||||
|
||||
// For single node scenarios, we can't validate peer visibility since there are no peers
|
||||
if len(expectedNodes) == 1 {
|
||||
// For single node, just check that we have map responses for the node
|
||||
for nodeID := range nodeStatus {
|
||||
if _, exists := onlineMap[nodeID]; exists {
|
||||
onlineFromMaps[nodeID] = true
|
||||
} else {
|
||||
onlineFromMaps[nodeID] = false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Multi-node scenario: check peer visibility
|
||||
for nodeID := range nodeStatus {
|
||||
// Initialize as offline - will be set to true only if visible in all relevant peer maps
|
||||
onlineFromMaps[nodeID] = false
|
||||
|
||||
// Count how many peer maps should show this node
|
||||
expectedPeerMaps := 0
|
||||
foundOnlinePeerMaps := 0
|
||||
|
||||
for id, peerMap := range onlineMap {
|
||||
if id == nodeID {
|
||||
continue // Skip self-references
|
||||
}
|
||||
expectedPeerMaps++
|
||||
|
||||
if online, exists := peerMap[nodeID]; exists && online {
|
||||
foundOnlinePeerMaps++
|
||||
}
|
||||
}
|
||||
|
||||
// Node is considered online if it appears online in all peer maps
|
||||
// (or if there are no peer maps to check)
|
||||
if expectedPeerMaps == 0 || foundOnlinePeerMaps == expectedPeerMaps {
|
||||
onlineFromMaps[nodeID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")
|
||||
|
||||
// Update status with map response data
|
||||
for nodeID, online := range onlineFromMaps {
|
||||
if status, exists := nodeStatus[nodeID]; exists {
|
||||
status.MapResponses = online
|
||||
nodeStatus[nodeID] = status
|
||||
}
|
||||
}
|
||||
|
||||
// Check nodestore state for expected nodes
|
||||
for _, nodeID := range expectedNodes {
|
||||
if node, exists := nodeStore[nodeID]; exists {
|
||||
if status, exists := nodeStatus[nodeID]; exists {
|
||||
// Check if node is online in nodestore
|
||||
status.NodeStore = node.IsOnline != nil && *node.IsOnline
|
||||
nodeStatus[nodeID] = status
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all systems show nodes in expected state and report failures
|
||||
allMatch := true
|
||||
var failureReport strings.Builder
|
||||
|
||||
ids := types.NodeIDs(maps.Keys(nodeStatus))
|
||||
slices.Sort(ids)
|
||||
for _, nodeID := range ids {
|
||||
status := nodeStatus[nodeID]
|
||||
systemsMatch := (status.Batcher == expectedOnline) &&
|
||||
(status.MapResponses == expectedOnline) &&
|
||||
(status.NodeStore == expectedOnline)
|
||||
|
||||
if !systemsMatch {
|
||||
allMatch = false
|
||||
stateStr := "offline"
|
||||
if expectedOnline {
|
||||
stateStr = "online"
|
||||
}
|
||||
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat)))
|
||||
failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline))
|
||||
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
|
||||
failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (expected: %t, down with at least one peer)\n", status.MapResponses, expectedOnline))
|
||||
failureReport.WriteString(fmt.Sprintf(" - nodestore: %t (expected: %t)\n", status.NodeStore, expectedOnline))
|
||||
}
|
||||
}
|
||||
|
||||
if !allMatch {
|
||||
if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" {
|
||||
t.Logf("Node state validation report changed at %s:", time.Now().Format(TimestampFormat))
|
||||
t.Logf("Previous report:\n%s", prevReport)
|
||||
t.Logf("Current report:\n%s", failureReport.String())
|
||||
t.Logf("Report diff:\n%s", diff)
|
||||
prevReport = failureReport.String()
|
||||
}
|
||||
|
||||
failureReport.WriteString(fmt.Sprintf("validation_timestamp: %s\n", time.Now().Format(TimestampFormat)))
|
||||
// Note: timeout_remaining not available in this context
|
||||
|
||||
assert.Fail(c, failureReport.String())
|
||||
}
|
||||
|
||||
stateStr := "offline"
|
||||
if expectedOnline {
|
||||
stateStr = "online"
|
||||
}
|
||||
assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr))
|
||||
}, timeout, 2*time.Second, message)
|
||||
}
|
||||
|
||||
// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components
|
||||
func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
// Stage 1: Verify batcher disconnection (should be immediate)
|
||||
t.Logf("Stage 1: Verifying batcher disconnection for %d nodes", len(expectedNodes))
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
debugInfo, err := headscale.DebugBatcher()
|
||||
assert.NoError(c, err, "Failed to get batcher debug info")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
allBatcherOffline := true
|
||||
for _, nodeID := range expectedNodes {
|
||||
nodeIDStr := fmt.Sprintf("%d", nodeID)
|
||||
if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected {
|
||||
allBatcherOffline = false
|
||||
assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID)
|
||||
}
|
||||
}
|
||||
assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher")
|
||||
}, 15*time.Second, 1*time.Second, "batcher disconnection validation")
|
||||
|
||||
// Stage 2: Verify nodestore offline status (up to 20 seconds due to disconnect detection delay)
|
||||
t.Logf("Stage 2: Verifying nodestore offline status for %d nodes (allowing for 10s disconnect detection delay)", len(expectedNodes))
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
allNodeStoreOffline := true
|
||||
for _, nodeID := range expectedNodes {
|
||||
if node, exists := nodeStore[nodeID]; exists {
|
||||
isOnline := node.IsOnline != nil && *node.IsOnline
|
||||
if isOnline {
|
||||
allNodeStoreOffline = false
|
||||
assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID)
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore")
|
||||
}, 20*time.Second, 1*time.Second, "nodestore offline validation")
|
||||
|
||||
// Stage 3: Verify map response propagation (longest delay due to peer update timing)
|
||||
t.Logf("Stage 3: Verifying map response propagation for %d nodes (allowing for peer map update delays)", len(expectedNodes))
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
mapResponses, err := headscale.GetAllMapReponses()
|
||||
assert.NoError(c, err, "Failed to get map responses")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses)
|
||||
allMapResponsesOffline := true
|
||||
|
||||
if len(expectedNodes) == 1 {
|
||||
// Single node: check if it appears in map responses
|
||||
for nodeID := range onlineMap {
|
||||
if slices.Contains(expectedNodes, nodeID) {
|
||||
allMapResponsesOffline = false
|
||||
assert.False(c, true, "Node %d should not appear in map responses", nodeID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Multi-node: check peer visibility
|
||||
for _, nodeID := range expectedNodes {
|
||||
for id, peerMap := range onlineMap {
|
||||
if id == nodeID {
|
||||
continue // Skip self-references
|
||||
}
|
||||
if online, exists := peerMap[nodeID]; exists && online {
|
||||
allMapResponsesOffline = false
|
||||
assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses")
|
||||
}, 60*time.Second, 2*time.Second, "map response propagation validation")
|
||||
|
||||
t.Logf("All stages completed: nodes are fully offline across all systems")
|
||||
}
|
||||
|
||||
// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database
|
||||
// and a valid DERP server based on the NetInfo. This function follows the pattern of
|
||||
// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state.
|
||||
func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
startTime := time.Now()
|
||||
t.Logf("requireAllClientsNetInfoAndDERP: Starting NetInfo/DERP validation for %d nodes at %s - %s", len(expectedNodes), startTime.Format(TimestampFormat), message)
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// Get nodestore state
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that all expected nodes are present in nodeStore
|
||||
for _, nodeID := range expectedNodes {
|
||||
_, exists := nodeStore[nodeID]
|
||||
assert.True(c, exists, "Expected node %d not found in nodeStore during NetInfo validation", nodeID)
|
||||
}
|
||||
|
||||
// Check each expected node
|
||||
for _, nodeID := range expectedNodes {
|
||||
node, exists := nodeStore[nodeID]
|
||||
assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID)
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate that the node has Hostinfo
|
||||
assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname)
|
||||
if node.Hostinfo == nil {
|
||||
t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat))
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate that the node has NetInfo
|
||||
assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname)
|
||||
if node.Hostinfo.NetInfo == nil {
|
||||
t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat))
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate that the node has a valid DERP server (PreferredDERP should be > 0)
|
||||
preferredDERP := node.Hostinfo.NetInfo.PreferredDERP
|
||||
assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP)
|
||||
|
||||
t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat))
|
||||
}
|
||||
}, timeout, 5*time.Second, message)
|
||||
|
||||
endTime := time.Now()
|
||||
duration := endTime.Sub(startTime)
|
||||
t.Logf("requireAllClientsNetInfoAndDERP: Completed NetInfo/DERP validation for %d nodes at %s - Duration: %v - %s", len(expectedNodes), endTime.Format(TimestampFormat), duration, message)
|
||||
}
|
||||
|
||||
// assertLastSeenSet validates that a node has a non-nil LastSeen timestamp.
|
||||
// Critical for ensuring node activity tracking is functioning properly.
|
||||
func assertLastSeenSet(t *testing.T, node *v1.Node) {
|
||||
assert.NotNil(t, node)
|
||||
assert.NotNil(t, node.GetLastSeen())
|
||||
}
|
||||
|
||||
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
|
||||
// are in the logged-out state (NeedsLogin).
|
||||
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
|
||||
if h, ok := t.(interface{ Helper() }); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
status, err := client.Status()
|
||||
assert.NoError(t, err, "failed to get status for client %s", client.Hostname())
|
||||
assert.Equal(t, "NeedsLogin", status.BackendState,
|
||||
"client %s should be logged out", client.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
// pingAllHelper performs ping tests between all clients and addresses, returning success count.
|
||||
// This is used to validate network connectivity in integration tests.
|
||||
// Returns the total number of successful ping operations.
|
||||
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int {
|
||||
t.Helper()
|
||||
success := 0
|
||||
|
||||
for _, client := range clients {
|
||||
for _, addr := range addrs {
|
||||
err := client.Ping(addr, opts...)
|
||||
if err != nil {
|
||||
t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
|
||||
} else {
|
||||
success++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
|
||||
// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses.
|
||||
// This specifically tests connectivity through DERP relay servers, which is important
|
||||
// for validating NAT traversal and relay functionality. Returns success count.
|
||||
func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
|
||||
t.Helper()
|
||||
success := 0
|
||||
|
||||
for _, client := range clients {
|
||||
for _, addr := range addrs {
|
||||
if isSelfClient(client, addr) {
|
||||
continue
|
||||
}
|
||||
|
||||
err := client.Ping(
|
||||
addr,
|
||||
tsic.WithPingTimeout(derpPingTimeout),
|
||||
tsic.WithPingCount(derpPingCount),
|
||||
tsic.WithPingUntilDirect(false),
|
||||
)
|
||||
if err != nil {
|
||||
t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
|
||||
} else {
|
||||
success++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
|
||||
// isSelfClient determines if the given address belongs to the client itself.
|
||||
// Used to avoid self-ping operations in connectivity tests by checking
|
||||
// hostname and IP address matches.
|
||||
func isSelfClient(client TailscaleClient, addr string) bool {
|
||||
if addr == client.Hostname() {
|
||||
return true
|
||||
}
|
||||
|
||||
ips, err := client.IPs()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.String() == addr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// assertClientsState validates the status and netmap of a list of clients for general connectivity.
|
||||
// Runs parallel validation of status, netcheck, and netmap for all clients to ensure
|
||||
// they have proper network configuration for all-to-all connectivity tests.
|
||||
func assertClientsState(t *testing.T, clients []TailscaleClient) {
|
||||
t.Helper()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, client := range clients {
|
||||
wg.Add(1)
|
||||
c := client // Avoid loop pointer
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
assertValidStatus(t, c)
|
||||
assertValidNetcheck(t, c)
|
||||
assertValidNetmap(t, c)
|
||||
}()
|
||||
}
|
||||
|
||||
t.Logf("waiting for client state checks to finish")
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// assertValidNetmap validates that a client's netmap has all required fields for proper operation.
|
||||
// Checks self node and all peers for essential networking data including hostinfo, addresses,
|
||||
// endpoints, and DERP configuration. Skips validation for Tailscale versions below 1.56.
|
||||
// This test is not suitable for ACL/partial connection tests.
|
||||
func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
||||
t.Helper()
|
||||
|
||||
if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
|
||||
t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("Checking netmap of %q", client.Hostname())
|
||||
|
||||
netmap, err := client.Netmap()
|
||||
if err != nil {
|
||||
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
|
||||
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
|
||||
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
|
||||
|
||||
assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
|
||||
|
||||
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
|
||||
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
|
||||
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
|
||||
|
||||
for _, peer := range netmap.Peers {
|
||||
assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
|
||||
assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
|
||||
|
||||
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
|
||||
if hi := peer.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
|
||||
|
||||
// Netinfo is not always set
|
||||
// assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
|
||||
if ni := hi.NetInfo(); ni.Valid() {
|
||||
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), ni.PreferredDERP())
|
||||
}
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
|
||||
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
|
||||
|
||||
assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
|
||||
|
||||
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
|
||||
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
|
||||
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
// assertValidStatus validates that a client's status has all required fields for proper operation.
|
||||
// Checks self and peer status for essential data including hostinfo, tailscale IPs, endpoints,
|
||||
// and network map presence. This test is not suitable for ACL/partial connection tests.
|
||||
func assertValidStatus(t *testing.T, client TailscaleClient) {
|
||||
t.Helper()
|
||||
status, err := client.Status(true)
|
||||
if err != nil {
|
||||
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
|
||||
assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
|
||||
assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())
|
||||
|
||||
assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())
|
||||
|
||||
// This does not seem to appear until version 1.56.
|
||||
if status.Self.AllowedIPs != nil {
|
||||
assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
|
||||
}
|
||||
|
||||
assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())
|
||||
|
||||
assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())
|
||||
|
||||
assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())
|
||||
|
||||
// This isn't really relevant for Self as it won't be in its own socket/wireguard.
|
||||
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
|
||||
// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
|
||||
|
||||
for _, peer := range status.Peer {
|
||||
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
|
||||
assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
|
||||
assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())
|
||||
|
||||
assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())
|
||||
|
||||
// This does not seem to appear until version 1.56.
|
||||
if peer.AllowedIPs != nil {
|
||||
assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
|
||||
}
|
||||
|
||||
// Addrs does not seem to appear in the status from peers.
|
||||
// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())
|
||||
|
||||
assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())
|
||||
|
||||
assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
|
||||
assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())
|
||||
|
||||
// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
|
||||
// there might be some interesting stuff to test here in the future.
|
||||
// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
// assertValidNetcheck validates that a client has a proper DERP relay configured.
|
||||
// Ensures the client has discovered and selected a DERP server for relay functionality,
|
||||
// which is essential for NAT traversal and connectivity in restricted networks.
|
||||
func assertValidNetcheck(t *testing.T, client TailscaleClient) {
|
||||
t.Helper()
|
||||
report, err := client.Netcheck()
|
||||
if err != nil {
|
||||
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
|
||||
}
|
||||
|
||||
// assertCommandOutputContains executes a command with exponential backoff retry until the output
|
||||
// contains the expected string or timeout is reached (10 seconds).
|
||||
// This implements eventual consistency patterns and should be used instead of time.Sleep
|
||||
// before executing commands that depend on network state propagation.
|
||||
//
|
||||
// Timeout: 10 seconds with exponential backoff
|
||||
// Use cases: DNS resolution, route propagation, policy updates.
|
||||
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := backoff.Retry(t.Context(), func() (struct{}, error) {
|
||||
stdout, stderr, err := c.Execute(command)
|
||||
if err != nil {
|
||||
return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
|
||||
}
|
||||
|
||||
if !strings.Contains(stdout, contains) {
|
||||
return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout)
|
||||
}
|
||||
|
||||
return struct{}{}, nil
|
||||
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
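
// Illustrative usage (not part of this change; "client" and "peer" are
// hypothetical): waiting for a peer to show up in the local status output
// instead of sleeping before the check:
//
//	assertCommandOutputContains(t, client,
//		[]string{"tailscale", "status"}, peer.Hostname())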
|
||||
|
||||
// dockertestMaxWait returns the maximum wait time for Docker-based test operations.
|
||||
// Uses longer timeouts in CI environments to account for slower resource allocation
|
||||
// and higher system load during automated testing.
|
||||
func dockertestMaxWait() time.Duration {
|
||||
wait := 300 * time.Second //nolint
|
||||
|
||||
if util.IsCI() {
|
||||
wait = 600 * time.Second //nolint
|
||||
}
|
||||
|
||||
return wait
|
||||
}
|
||||
|
||||
// didClientUseWebsocketForDERP analyzes client logs to determine if WebSocket was used for DERP.
|
||||
// Searches for WebSocket connection indicators in client logs to validate
|
||||
// DERP relay communication method for debugging connectivity issues.
|
||||
func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool {
|
||||
t.Helper()
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
err := client.WriteLogs(buf, buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
count, err := countMatchingLines(buf, func(line string) bool {
|
||||
return strings.Contains(line, "websocket: connected to ")
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// countMatchingLines counts lines in a reader that match the given predicate function.
|
||||
// Uses optimized buffering for log analysis and provides flexible line-by-line
|
||||
// filtering for log parsing and pattern matching in integration tests.
|
||||
func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) {
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(in)
|
||||
{
|
||||
const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB
|
||||
buff := make([]byte, logBufferInitialSize)
|
||||
scanner.Buffer(buff, len(buff))
|
||||
scanner.Split(bufio.ScanLines)
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
if predicate(scanner.Text()) {
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
|
||||
return count, scanner.Err()
|
||||
}
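
// Illustrative usage (the log pattern below is a hypothetical example, and
// "buf" is a buffer filled via WriteLogs as in the function above): counting
// how often a client reports a DERP home change in its logs.
//
//	n, err := countMatchingLines(buf, func(line string) bool {
//		return strings.Contains(line, "home is now derp")
//	})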
|
||||
|
||||
// wildcard returns a wildcard alias (*) for use in policy v2 configurations.
|
||||
// Provides a convenient helper for creating permissive policy rules.
|
||||
func wildcard() policyv2.Alias {
|
||||
return policyv2.Wildcard
|
||||
}
|
||||
|
||||
// usernamep returns a pointer to a Username as an Alias for policy v2 configurations.
|
||||
// Used in ACL rules to reference specific users in network access policies.
|
||||
func usernamep(name string) policyv2.Alias {
|
||||
return ptr.To(policyv2.Username(name))
|
||||
}
|
||||
|
||||
// hostp returns a pointer to a Host as an Alias for policy v2 configurations.
|
||||
// Used in ACL rules to reference specific hosts in network access policies.
|
||||
func hostp(name string) policyv2.Alias {
|
||||
return ptr.To(policyv2.Host(name))
|
||||
}
|
||||
|
||||
// groupp returns a pointer to a Group as an Alias for policy v2 configurations.
|
||||
// Used in ACL rules to reference user groups in network access policies.
|
||||
func groupp(name string) policyv2.Alias {
|
||||
return ptr.To(policyv2.Group(name))
|
||||
}
|
||||
|
||||
// tagp returns a pointer to a Tag as an Alias for policy v2 configurations.
|
||||
// Used in ACL rules to reference node tags in network access policies.
|
||||
func tagp(name string) policyv2.Alias {
|
||||
return ptr.To(policyv2.Tag(name))
|
||||
}
|
||||
|
||||
// prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations.
|
||||
// Converts CIDR notation to policy prefix format for network range specifications.
|
||||
func prefixp(cidr string) policyv2.Alias {
|
||||
prefix := netip.MustParsePrefix(cidr)
|
||||
return ptr.To(policyv2.Prefix(prefix))
|
||||
}
|
||||
|
||||
// aliasWithPorts creates an AliasWithPorts structure from an alias and port ranges.
|
||||
// Combines network targets with specific port restrictions for fine-grained
|
||||
// access control in policy v2 configurations.
|
||||
func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts {
|
||||
return policyv2.AliasWithPorts{
|
||||
Alias: alias,
|
||||
Ports: ports,
|
||||
}
|
||||
}
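
// Illustrative sketch (not part of this change) of how the policy helpers
// compose; the user name and port range here are arbitrary examples:
//
//	src := usernamep("user1@")
//	dst := aliasWithPorts(wildcard(), tailcfg.PortRangeAny)
//	_, _ = src, dst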
|
||||
|
||||
// usernameOwner returns a Username as an Owner for use in TagOwners policies.
|
||||
// Specifies which users can assign and manage specific tags in ACL configurations.
|
||||
func usernameOwner(name string) policyv2.Owner {
|
||||
return ptr.To(policyv2.Username(name))
|
||||
}
|
||||
|
||||
// groupOwner returns a Group as an Owner for use in TagOwners policies.
|
||||
// Specifies which groups can assign and manage specific tags in ACL configurations.
|
||||
func groupOwner(name string) policyv2.Owner {
|
||||
return ptr.To(policyv2.Group(name))
|
||||
}
|
||||
|
||||
// usernameApprover returns a Username as an AutoApprover for subnet route policies.
|
||||
// Specifies which users can automatically approve subnet route advertisements.
|
||||
func usernameApprover(name string) policyv2.AutoApprover {
|
||||
return ptr.To(policyv2.Username(name))
|
||||
}
|
||||
|
||||
// groupApprover returns a Group as an AutoApprover for subnet route policies.
|
||||
// Specifies which groups can automatically approve subnet route advertisements.
|
||||
func groupApprover(name string) policyv2.AutoApprover {
|
||||
return ptr.To(policyv2.Group(name))
|
||||
}
|
||||
|
||||
// tagApprover returns a Tag as an AutoApprover for subnet route policies.
|
||||
// Specifies which tagged nodes can automatically approve subnet route advertisements.
|
||||
func tagApprover(name string) policyv2.AutoApprover {
|
||||
return ptr.To(policyv2.Tag(name))
|
||||
}
|
||||
|
||||
// oidcMockUser creates a MockUser for OIDC authentication testing.
|
||||
// Generates consistent test user data with configurable email verification status
|
||||
// for validating OIDC integration flows in headscale authentication tests.
|
||||
func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
|
||||
return mockoidc.MockUser{
|
||||
Subject: username,
|
||||
PreferredUsername: username,
|
||||
Email: username + "@headscale.net",
|
||||
EmailVerified: emailVerified,
|
||||
}
|
||||
}
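
// Illustrative usage (hypothetical user name):
//
//	u := oidcMockUser("user1", true)
//	// u.Email is "user1@headscale.net" and the email is marked as verified.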
|
||||
@@ -460,6 +460,12 @@ func New(
|
||||
dockertestutil.DockerAllowNetworkAdministration,
|
||||
)
|
||||
if err != nil {
|
||||
// Try to get more detailed build output
|
||||
log.Printf("Docker build failed, attempting to get detailed output...")
|
||||
buildOutput := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, IntegrationTestDockerFileName)
|
||||
if buildOutput != "" {
|
||||
return nil, fmt.Errorf("could not start headscale container: %w\n\nDetailed build output:\n%s", err, buildOutput)
|
||||
}
|
||||
return nil, fmt.Errorf("could not start headscale container: %w", err)
|
||||
}
|
||||
log.Printf("Created %s container\n", hsic.hostname)
|
||||
|
||||
@@ -53,16 +53,16 @@ func TestEnablingRoutes(t *testing.T) {
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
[]tsic.Option{tsic.WithAcceptRoutes()},
|
||||
hsic.WithTestName("clienableroute"))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
expectedRoutes := map[string]string{
|
||||
"1": "10.0.0.0/24",
|
||||
@@ -83,7 +83,7 @@ func TestEnablingRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
var nodes []*v1.Node
|
||||
// Wait for route advertisements to propagate to NodeStore
|
||||
@@ -256,16 +256,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
prefp, err := scenario.SubnetOfNetwork("usernet1")
|
||||
require.NoError(t, err)
|
||||
@@ -319,7 +319,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// Wait for route configuration changes after advertising routes
|
||||
var nodes []*v1.Node
|
||||
@@ -1341,16 +1341,16 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||
},
|
||||
},
|
||||
))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
expectedRoutes := map[string]string{
|
||||
"1": "10.33.0.0/16",
|
||||
@@ -1393,7 +1393,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
// Wait for route advertisements to propagate to the server
|
||||
var nodes []*v1.Node
|
||||
@@ -1572,25 +1572,25 @@ func TestEnablingExitRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
require.NoErrorf(t, err, "failed to create scenario")
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
err = scenario.CreateHeadscaleEnv([]tsic.Option{
|
||||
tsic.WithExtraLoginArgs([]string{"--advertise-exit-node"}),
|
||||
}, hsic.WithTestName("clienableroute"))
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
nodes, err := headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
@@ -1686,16 +1686,16 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
assert.NotNil(t, headscale)
|
||||
|
||||
pref, err := scenario.SubnetOfNetwork("usernet1")
|
||||
@@ -1833,16 +1833,16 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
assert.NotNil(t, headscale)
|
||||
|
||||
var user1c, user2c TailscaleClient
|
||||
@@ -2247,13 +2247,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
err = scenario.createHeadscaleEnv(tt.withURL, tsOpts,
|
||||
opts...,
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
services, err := scenario.Services("usernet1")
|
||||
require.NoError(t, err)
|
||||
@@ -2263,7 +2263,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
assert.NotNil(t, headscale)
|
||||
|
||||
// Add the Docker network route to the auto-approvers
|
||||
@@ -2304,21 +2304,21 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
||||
|
||||
if tt.withURL {
|
||||
u, err := routerUsernet1.LoginWithURL(headscale.GetEndpoint())
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := doLoginURL(routerUsernet1.Hostname(), u)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
scenario.runHeadscaleRegister("user1", body)
|
||||
} else {
|
||||
userMap, err := headscale.MapUsers()
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
// extra creation end.
|
||||
|
||||
@@ -2893,13 +2893,13 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
|
||||
hsic.WithACLPolicy(aclPolicy),
|
||||
hsic.WithPolicyMode(types.PolicyModeDB),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErrGetHeadscale(t, err)
|
||||
requireNoErrGetHeadscale(t, err)
|
||||
|
||||
// Get the router and node clients by user
|
||||
routerClients, err := scenario.ListTailscaleClients(routerUser)
|
||||
@@ -2944,7 +2944,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
|
||||
require.NoErrorf(t, err, "failed to advertise routes: %s", err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
var routerNode, nodeNode *v1.Node
|
||||
// Wait for route advertisements to propagate to NodeStore
|
||||
|
||||
@@ -838,14 +838,14 @@ func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
|
||||
|
||||
var err error
|
||||
hc := &http.Client{
|
||||
Transport: LoggingRoundTripper{},
|
||||
Transport: LoggingRoundTripper{Hostname: hostname},
|
||||
}
|
||||
hc.Jar, err = cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
|
||||
}
|
||||
|
||||
log.Printf("%s logging in with url", hostname)
|
||||
log.Printf("%s logging in with url: %s", hostname, loginURL.String())
|
||||
ctx := context.Background()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
||||
resp, err := hc.Do(req)
|
||||
@@ -907,7 +907,9 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
|
||||
return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable)
|
||||
}
|
||||
|
||||
type LoggingRoundTripper struct{}
|
||||
type LoggingRoundTripper struct {
|
||||
Hostname string
|
||||
}
|
||||
|
||||
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
noTls := &http.Transport{
|
||||
@@ -918,9 +920,12 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("---")
|
||||
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
|
||||
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
|
||||
log.Printf(`
|
||||
---
|
||||
%s - method: %s | url: %s
|
||||
%s - status: %d | cookies: %+v
|
||||
---
|
||||
`, t.Hostname, req.Method, req.URL.String(), t.Hostname, resp.StatusCode, resp.Cookies())
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// This file is intended to "test the test framework", by proxy it will also test
|
||||
@@ -34,7 +35,7 @@ func TestHeadscale(t *testing.T) {
|
||||
user := "test-space"
|
||||
|
||||
scenario, err := NewScenario(ScenarioSpec{})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
t.Run("start-headscale", func(t *testing.T) {
|
||||
@@ -82,7 +83,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
|
||||
count := 1
|
||||
|
||||
scenario, err := NewScenario(ScenarioSpec{})
|
||||
assertNoErr(t, err)
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
t.Run("start-headscale", func(t *testing.T) {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
@@ -30,7 +31,7 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce
		Users: []string{"user1", "user2"},
	}
	scenario, err := NewScenario(spec)
	assertNoErr(t, err)
	require.NoError(t, err)

	err = scenario.CreateHeadscaleEnv(
		[]tsic.Option{
@@ -50,13 +51,13 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce
		hsic.WithACLPolicy(policy),
		hsic.WithTestName("ssh"),
	)
	assertNoErr(t, err)
	require.NoError(t, err)

	err = scenario.WaitForTailscaleSync()
	assertNoErr(t, err)
	require.NoError(t, err)

	_, err = scenario.ListTailscaleClientsFQDNs()
	assertNoErr(t, err)
	require.NoError(t, err)

	return scenario
}
@@ -93,19 +94,19 @@ func TestSSHOneUserToAll(t *testing.T) {
	defer scenario.ShutdownAssertNoPanics(t)

	allClients, err := scenario.ListTailscaleClients()
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	user1Clients, err := scenario.ListTailscaleClients("user1")
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	user2Clients, err := scenario.ListTailscaleClients("user2")
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	err = scenario.WaitForTailscaleSync()
	assertNoErrSync(t, err)
	requireNoErrSync(t, err)

	_, err = scenario.ListTailscaleClientsFQDNs()
	assertNoErrListFQDN(t, err)
	requireNoErrListFQDN(t, err)

	for _, client := range user1Clients {
		for _, peer := range allClients {
@@ -160,16 +161,16 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
	defer scenario.ShutdownAssertNoPanics(t)

	nsOneClients, err := scenario.ListTailscaleClients("user1")
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	nsTwoClients, err := scenario.ListTailscaleClients("user2")
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	err = scenario.WaitForTailscaleSync()
	assertNoErrSync(t, err)
	requireNoErrSync(t, err)

	_, err = scenario.ListTailscaleClientsFQDNs()
	assertNoErrListFQDN(t, err)
	requireNoErrListFQDN(t, err)

	testInterUserSSH := func(sourceClients []TailscaleClient, targetClients []TailscaleClient) {
		for _, client := range sourceClients {
@@ -208,13 +209,13 @@ func TestSSHNoSSHConfigured(t *testing.T) {
	defer scenario.ShutdownAssertNoPanics(t)

	allClients, err := scenario.ListTailscaleClients()
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	err = scenario.WaitForTailscaleSync()
	assertNoErrSync(t, err)
	requireNoErrSync(t, err)

	_, err = scenario.ListTailscaleClientsFQDNs()
	assertNoErrListFQDN(t, err)
	requireNoErrListFQDN(t, err)

	for _, client := range allClients {
		for _, peer := range allClients {
@@ -259,13 +260,13 @@ func TestSSHIsBlockedInACL(t *testing.T) {
	defer scenario.ShutdownAssertNoPanics(t)

	allClients, err := scenario.ListTailscaleClients()
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	err = scenario.WaitForTailscaleSync()
	assertNoErrSync(t, err)
	requireNoErrSync(t, err)

	_, err = scenario.ListTailscaleClientsFQDNs()
	assertNoErrListFQDN(t, err)
	requireNoErrListFQDN(t, err)

	for _, client := range allClients {
		for _, peer := range allClients {
@@ -317,16 +318,16 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
	defer scenario.ShutdownAssertNoPanics(t)

	ssh1Clients, err := scenario.ListTailscaleClients("user1")
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	ssh2Clients, err := scenario.ListTailscaleClients("user2")
	assertNoErrListClients(t, err)
	requireNoErrListClients(t, err)

	err = scenario.WaitForTailscaleSync()
	assertNoErrSync(t, err)
	requireNoErrSync(t, err)

	_, err = scenario.ListTailscaleClientsFQDNs()
	assertNoErrListFQDN(t, err)
	requireNoErrListFQDN(t, err)

	for _, client := range ssh1Clients {
		for _, peer := range ssh2Clients {
@@ -422,9 +423,9 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien
	t.Helper()

	result, _, err := doSSH(t, client, peer)
	assertNoErr(t, err)
	require.NoError(t, err)

	assertContains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", ""))
	require.Contains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", ""))
}

func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) {
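The hunks above swap the suite's local assert* wrappers for testify's require-style helpers (require.NoError, requireNoErrListClients, and so on). The practical difference is that require aborts the failing test immediately instead of recording the failure and continuing, so later steps never run against broken setup. A minimal, self-contained illustration of that distinction (the test below is hypothetical, not part of the suite):

package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestRequireVsAssert is only an illustration of the two failure modes.
func TestRequireVsAssert(t *testing.T) {
	var err error

	// assert: marks the test as failed but keeps executing later checks.
	assert.NoError(t, err)

	// require: calls t.FailNow() on error, so setup code that depends on
	// err being nil is never reached with a broken precondition.
	require.NoError(t, err)
}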
@@ -322,6 +322,20 @@ func New(
			dockertestutil.DockerAllowNetworkAdministration,
			dockertestutil.DockerMemoryLimit,
		)
		if err != nil {
			// Try to get more detailed build output
			log.Printf("Docker build failed for %s, attempting to get detailed output...", hostname)
			buildOutput := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, "Dockerfile.tailscale-HEAD")
			if buildOutput != "" {
				return nil, fmt.Errorf(
					"%s could not start tailscale container (version: %s): %w\n\nDetailed build output:\n%s",
					hostname,
					version,
					err,
					buildOutput,
				)
			}
		}
	case "unstable":
		tailscaleOptions.Repository = "tailscale/tailscale"
		tailscaleOptions.Tag = version
@@ -333,6 +347,9 @@ func New(
			dockertestutil.DockerAllowNetworkAdministration,
			dockertestutil.DockerMemoryLimit,
		)
		if err != nil {
			log.Printf("Docker run failed for %s (unstable), error: %v", hostname, err)
		}
	default:
		tailscaleOptions.Repository = "tailscale/tailscale"
		tailscaleOptions.Tag = "v" + version
@@ -344,6 +361,9 @@ func New(
			dockertestutil.DockerAllowNetworkAdministration,
			dockertestutil.DockerMemoryLimit,
		)
		if err != nil {
			log.Printf("Docker run failed for %s (version: v%s), error: %v", hostname, version, err)
		}
	}

	if err != nil {
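For the HEAD build path, the new branch re-runs the Docker build out of band purely to capture verbose output for the returned error, via dockertestutil.RunDockerBuildForDiagnostics. The diff only shows the call site; the sketch below is an assumption about what such a helper could look like (the docker flags and the empty-string-on-success convention are illustrative, not taken from the repository):

package dockertestutil

import (
	"os/exec"
	"path/filepath"
)

// RunDockerBuildForDiagnostics is sketched here for illustration only; the
// real helper lives in integration/dockertestutil and may differ.
func RunDockerBuildForDiagnostics(contextDir, dockerfile string) string {
	cmd := exec.Command(
		"docker", "build",
		"--progress=plain", // assumption: plain progress output is easier to read in CI logs
		"-f", filepath.Join(contextDir, dockerfile),
		contextDir,
	)

	// CombinedOutput interleaves stdout and stderr, giving one diagnostic
	// blob that can be embedded in the error message at the call site.
	out, err := cmd.CombinedOutput()
	if err != nil {
		return string(out)
	}

	// An empty string signals "no extra diagnostics", matching how the call
	// site above falls through when buildOutput == "".
	return ""
}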
@@ -1,533 +0,0 @@
package integration

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net/netip"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/cenkalti/backoff/v5"
	policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/juanfont/headscale/integration/tsic"
	"github.com/stretchr/testify/assert"
	"tailscale.com/tailcfg"
	"tailscale.com/types/ptr"
)

const (
	// derpPingTimeout defines the timeout for individual DERP ping operations
	// Used in DERP connectivity tests to verify relay server communication.
	derpPingTimeout = 2 * time.Second

	// derpPingCount defines the number of ping attempts for DERP connectivity tests
	// Higher count provides better reliability assessment of DERP connectivity.
	derpPingCount = 10

	// TimestampFormat is the standard timestamp format used across all integration tests
	// Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps
	// suitable for debugging and log correlation in integration tests.
	TimestampFormat = "2006-01-02T15-04-05.999999999"

	// TimestampFormatRunID is used for generating unique run identifiers
	// Format: "20060102-150405" provides compact date-time for file/directory names.
	TimestampFormatRunID = "20060102-150405"
)
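Both layouts in the removed const block follow Go's reference-time convention ("Mon Jan 2 15:04:05 2006"): the first is precise enough to correlate individual log lines, the second is filesystem-safe for naming per-run artifact directories. A small usage sketch (the directory name and suffix are hypothetical):

package main

import (
	"fmt"
	"path/filepath"
	"time"
)

const (
	TimestampFormat      = "2006-01-02T15-04-05.999999999"
	TimestampFormatRunID = "20060102-150405"
)

func main() {
	now := time.Now()

	// High-precision stamp for correlating an event with container logs.
	fmt.Println("event at", now.Format(TimestampFormat))

	// Compact stamp used to name a per-run artifact directory (illustrative path).
	runDir := filepath.Join("control_logs", now.Format(TimestampFormatRunID)+"-example")
	fmt.Println("writing artifacts to", runDir)
}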
func assertNoErr(t *testing.T, err error) {
	t.Helper()
	assertNoErrf(t, "unexpected error: %s", err)
}

func assertNoErrf(t *testing.T, msg string, err error) {
	t.Helper()
	if err != nil {
		t.Fatalf(msg, err)
	}
}

func assertNotNil(t *testing.T, thing interface{}) {
	t.Helper()
	if thing == nil {
		t.Fatal("got unexpected nil")
	}
}

func assertNoErrHeadscaleEnv(t *testing.T, err error) {
	t.Helper()
	assertNoErrf(t, "failed to create headscale environment: %s", err)
}

func assertNoErrGetHeadscale(t *testing.T, err error) {
	t.Helper()
	assertNoErrf(t, "failed to get headscale: %s", err)
}

func assertNoErrListClients(t *testing.T, err error) {
	t.Helper()
	assertNoErrf(t, "failed to list clients: %s", err)
}

func assertNoErrListClientIPs(t *testing.T, err error) {
	t.Helper()
	assertNoErrf(t, "failed to get client IPs: %s", err)
}

func assertNoErrSync(t *testing.T, err error) {
	t.Helper()
	assertNoErrf(t, "failed to have all clients sync up: %s", err)
}

func assertNoErrListFQDN(t *testing.T, err error) {
	t.Helper()
	assertNoErrf(t, "failed to list FQDNs: %s", err)
}

func assertNoErrLogout(t *testing.T, err error) {
	t.Helper()
	assertNoErrf(t, "failed to log out tailscale nodes: %s", err)
}

func assertContains(t *testing.T, str, subStr string) {
	t.Helper()
	if !strings.Contains(str, subStr) {
		t.Fatalf("%#v does not contain %#v", str, subStr)
	}
}
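Every wrapper above calls t.Helper() before failing, so a failure inside, say, assertNoErrSync is reported at the test line that invoked the wrapper rather than inside the wrapper itself. A compressed, self-contained sketch of that pattern (names here are hypothetical):

package example

import "testing"

// failNow stands in for the assertNoErr* family: t.Helper() makes the testing
// package skip this frame when it reports the file:line of a failure.
func failNow(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
}

func doSomething() error { return nil }

func TestHelperAttribution(t *testing.T) {
	err := doSomething() // hypothetical setup step
	failNow(t, err)      // a failure is attributed to this line, not to failNow's body
}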
func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool {
	t.Helper()

	buf := &bytes.Buffer{}
	err := client.WriteLogs(buf, buf)
	if err != nil {
		t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err)
	}

	count, err := countMatchingLines(buf, func(line string) bool {
		return strings.Contains(line, "websocket: connected to ")
	})
	if err != nil {
		t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err)
	}

	return count > 0
}

// pingAllHelper performs ping tests between all clients and addresses, returning success count.
// This is used to validate network connectivity in integration tests.
// Returns the total number of successful ping operations.
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int {
	t.Helper()
	success := 0

	for _, client := range clients {
		for _, addr := range addrs {
			err := client.Ping(addr, opts...)
			if err != nil {
				t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
			} else {
				success++
			}
		}
	}

	return success
}
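Because pingAllHelper returns a success count instead of failing outright, callers typically compare the count against the full mesh size (clients × addresses). A typical call looks roughly like the excerpt below; the scenario methods are the ones used elsewhere in the suite and should be read as assumptions here:

	// Collect every client's Tailscale IPs, then expect a full mesh of pings.
	allClients, err := scenario.ListTailscaleClients()
	require.NoError(t, err)

	allIps, err := scenario.ListTailscaleClientsIPs()
	require.NoError(t, err)

	var allAddrs []string
	for _, ip := range allIps {
		allAddrs = append(allAddrs, ip.String())
	}

	success := pingAllHelper(t, allClients, allAddrs)
	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allAddrs))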
// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses.
// This specifically tests connectivity through DERP relay servers, which is important
// for validating NAT traversal and relay functionality. Returns success count.
func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
	t.Helper()
	success := 0

	for _, client := range clients {
		for _, addr := range addrs {
			if isSelfClient(client, addr) {
				continue
			}

			err := client.Ping(
				addr,
				tsic.WithPingTimeout(derpPingTimeout),
				tsic.WithPingCount(derpPingCount),
				tsic.WithPingUntilDirect(false),
			)
			if err != nil {
				t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
			} else {
				success++
			}
		}
	}

	return success
}

// assertClientsState validates the status and netmap of a list of
// clients for the general case of all to all connectivity.
func assertClientsState(t *testing.T, clients []TailscaleClient) {
	t.Helper()

	var wg sync.WaitGroup

	for _, client := range clients {
		wg.Add(1)
		c := client // Avoid loop pointer
		go func() {
			defer wg.Done()
			assertValidStatus(t, c)
			assertValidNetcheck(t, c)
			assertValidNetmap(t, c)
		}()
	}

	t.Logf("waiting for client state checks to finish")
	wg.Wait()
}
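The `c := client` copy above matters on Go versions before 1.22, where the range variable is reused across iterations and every goroutine could otherwise observe the last client. Passing the value as a goroutine parameter is an equivalent alternative, sketched below as a fragment of the same fan-out:

	for _, client := range clients {
		wg.Add(1)

		// Passing the client as a parameter is an alternative to the
		// `c := client` copy and is safe on every Go version.
		go func(c TailscaleClient) {
			defer wg.Done()
			assertValidStatus(t, c)
			assertValidNetcheck(t, c)
			assertValidNetmap(t, c)
		}(client)
	}

	wg.Wait()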
// assertValidNetmap asserts that the netmap of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
// This test can only be run on clients from 1.56.1. It will
// automatically pass all clients below that and is safe to call
// for all versions.
func assertValidNetmap(t *testing.T, client TailscaleClient) {
	t.Helper()

	if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
		t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())

		return
	}

	t.Logf("Checking netmap of %q", client.Hostname())

	netmap, err := client.Netmap()
	if err != nil {
		t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
	}

	assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
	if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
		assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
	}

	assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
	assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())

	assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())

	assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
	assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
	assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())

	for _, peer := range netmap.Peers {
		assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
		assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())

		assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
		if hi := peer.Hostinfo(); hi.Valid() {
			assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())

			// Netinfo is not always set
			// assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
			if ni := hi.NetInfo(); ni.Valid() {
				assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
			}
		}

		assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
		assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
		assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())

		assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())

		assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
		assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
		assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
	}
}

// assertValidStatus asserts that the status of a client has all
// the minimum required fields set to a known working config for
// the general case. Fields are checked on self, then all peers.
// This test is not suitable for ACL/partial connection tests.
func assertValidStatus(t *testing.T, client TailscaleClient) {
	t.Helper()
	status, err := client.Status(true)
	if err != nil {
		t.Fatalf("getting status for %q: %s", client.Hostname(), err)
	}

	assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname())
	assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname())
	assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname())

	assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname())

	// This seem to not appear until version 1.56
	if status.Self.AllowedIPs != nil {
		assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname())
	}

	assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname())

	assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname())

	assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname())

	// This isn't really relevant for Self as it won't be in its own socket/wireguard.
	// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
	// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())

	for _, peer := range status.Peer {
		assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
		assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname())
		assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname())

		assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname())

		// This seem to not appear until version 1.56
		if peer.AllowedIPs != nil {
			assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname())
		}

		// Addrs does not seem to appear in the status from peers.
		// assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname())

		assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname())

		assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname())
		assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname())

		// TODO(kradalby): InEngine is only true when a proper tunnel is set up,
		// there might be some interesting stuff to test here in the future.
		// assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname())
	}
}

func assertValidNetcheck(t *testing.T, client TailscaleClient) {
	t.Helper()
	report, err := client.Netcheck()
	if err != nil {
		t.Fatalf("getting status for %q: %s", client.Hostname(), err)
	}

	assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname())
}

// assertCommandOutputContains executes a command with exponential backoff retry until the output
// contains the expected string or timeout is reached (10 seconds).
// This implements eventual consistency patterns and should be used instead of time.Sleep
// before executing commands that depend on network state propagation.
//
// Timeout: 10 seconds with exponential backoff
// Use cases: DNS resolution, route propagation, policy updates.
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
	t.Helper()

	_, err := backoff.Retry(t.Context(), func() (struct{}, error) {
		stdout, stderr, err := c.Execute(command)
		if err != nil {
			return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
		}

		if !strings.Contains(stdout, contains) {
			return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout)
		}

		return struct{}{}, nil
	}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))

	assert.NoError(t, err)
}
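assertCommandOutputContains wraps cenkalti/backoff/v5's Retry, re-running the command until the expected substring appears or roughly ten seconds elapse, which is what lets tests drop fixed time.Sleep calls for state that propagates asynchronously. A hypothetical call waiting for a peer's Tailscale IP to become resolvable from a client (the command, FQDN, and address are illustrative, not taken from the diff):

	// Wait until the client reports the expected IPv4 address for a peer.
	assertCommandOutputContains(
		t,
		client,
		[]string{"tailscale", "ip", "-4", "peer1.headscale.net"},
		"100.64.0.2",
	)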
func isSelfClient(client TailscaleClient, addr string) bool {
	if addr == client.Hostname() {
		return true
	}

	ips, err := client.IPs()
	if err != nil {
		return false
	}

	for _, ip := range ips {
		if ip.String() == addr {
			return true
		}
	}

	return false
}

func dockertestMaxWait() time.Duration {
	wait := 300 * time.Second //nolint

	if util.IsCI() {
		wait = 600 * time.Second //nolint
	}

	return wait
}

func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) {
	count := 0
	scanner := bufio.NewScanner(in)
	{
		const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB
		buff := make([]byte, logBufferInitialSize)
		scanner.Buffer(buff, len(buff))
		scanner.Split(bufio.ScanLines)
	}

	for scanner.Scan() {
		if predicate(scanner.Text()) {
			count += 1
		}
	}

	return count, scanner.Err()
}
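countMatchingLines enlarges the scanner buffer because bufio.Scanner's default token limit is 64 KiB and tailscaled log lines (netmap dumps in particular) can exceed it, which would otherwise abort the scan with bufio.ErrTooLong. A self-contained reproduction of the idea:

package main

import (
	"bufio"
	"log"
	"strings"
)

func main() {
	// A single log line longer than bufio's default 64 KiB token limit.
	long := strings.Repeat("x", 200*1024) + "\n"

	scanner := bufio.NewScanner(strings.NewReader(long))

	// Without this, Scan() stops early and Err() returns bufio.ErrTooLong.
	buf := make([]byte, 1024<<10) // 1 MiB, the same size the helper above uses
	scanner.Buffer(buf, len(buf))

	lines := 0
	for scanner.Scan() {
		lines++
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}
	log.Printf("scanned %d line(s)", lines)
}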
// func dockertestCommandTimeout() time.Duration {
//	timeout := 10 * time.Second //nolint
//
//	if isCI() {
//		timeout = 60 * time.Second //nolint
//	}
//
//	return timeout
// }

// pingAllNegativeHelper is intended to have 1 or more nodes timing out from the ping,
// it counts failures instead of successes.
// func pingAllNegativeHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
//	t.Helper()
//	failures := 0
//
//	timeout := 100
//	count := 3
//
//	for _, client := range clients {
//		for _, addr := range addrs {
//			err := client.Ping(
//				addr,
//				tsic.WithPingTimeout(time.Duration(timeout)*time.Millisecond),
//				tsic.WithPingCount(count),
//			)
//			if err != nil {
//				failures++
//			}
//		}
//	}
//
//	return failures
// }

// // findPeerByIP takes an IP and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus
// // if there is a peer with the given IP. If no peer is found, nil is returned.
// func findPeerByIP(
//	ip netip.Addr,
//	peers map[key.NodePublic]*ipnstate.PeerStatus,
// ) *ipnstate.PeerStatus {
//	for _, peer := range peers {
//		for _, peerIP := range peer.TailscaleIPs {
//			if ip == peerIP {
//				return peer
//			}
//		}
//	}
//
//	return nil
// }

// Helper functions for creating typed policy entities

// wildcard returns a wildcard alias (*).
func wildcard() policyv2.Alias {
	return policyv2.Wildcard
}

// usernamep returns a pointer to a Username as an Alias.
func usernamep(name string) policyv2.Alias {
	return ptr.To(policyv2.Username(name))
}

// hostp returns a pointer to a Host.
func hostp(name string) policyv2.Alias {
	return ptr.To(policyv2.Host(name))
}

// groupp returns a pointer to a Group as an Alias.
func groupp(name string) policyv2.Alias {
	return ptr.To(policyv2.Group(name))
}

// tagp returns a pointer to a Tag as an Alias.
func tagp(name string) policyv2.Alias {
	return ptr.To(policyv2.Tag(name))
}

// prefixp returns a pointer to a Prefix from a CIDR string.
func prefixp(cidr string) policyv2.Alias {
	prefix := netip.MustParsePrefix(cidr)
	return ptr.To(policyv2.Prefix(prefix))
}

// aliasWithPorts creates an AliasWithPorts structure from an alias and ports.
func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts {
	return policyv2.AliasWithPorts{
		Alias: alias,
		Ports: ports,
	}
}

// usernameOwner returns a Username as an Owner for use in TagOwners.
func usernameOwner(name string) policyv2.Owner {
	return ptr.To(policyv2.Username(name))
}

// groupOwner returns a Group as an Owner for use in TagOwners.
func groupOwner(name string) policyv2.Owner {
	return ptr.To(policyv2.Group(name))
}

// usernameApprover returns a Username as an AutoApprover.
func usernameApprover(name string) policyv2.AutoApprover {
	return ptr.To(policyv2.Username(name))
}

// groupApprover returns a Group as an AutoApprover.
func groupApprover(name string) policyv2.AutoApprover {
	return ptr.To(policyv2.Group(name))
}

// tagApprover returns a Tag as an AutoApprover.
func tagApprover(name string) policyv2.AutoApprover {
	return ptr.To(policyv2.Tag(name))
}
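These constructors let tests assemble a typed policyv2.Policy instead of hand-writing HuJSON. The fragment below shows how they compose into a single accept rule; the Policy and ACL field names mirror how the helpers are used elsewhere in the integration suite and should be treated as illustrative here:

	// Allow user1's nodes to reach user2's nodes on any port (field names are assumptions).
	policy := &policyv2.Policy{
		ACLs: []policyv2.ACL{
			{
				Action:  "accept",
				Sources: []policyv2.Alias{usernamep("user1@")},
				Destinations: []policyv2.AliasWithPorts{
					aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny),
				},
			},
		},
	}

A value like this is what sshScenario above hands to hsic.WithACLPolicy, so the policy is installed on the headscale container before any client joins.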
//
// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus
// // if there is a peer with the given hostname. If no peer is found, nil is returned.
// func findPeerByHostname(
//	hostname string,
//	peers map[key.NodePublic]*ipnstate.PeerStatus,
// ) *ipnstate.PeerStatus {
//	for _, peer := range peers {
//		if hostname == peer.HostName {
//			return peer
//		}
//	}
//
//	return nil
// }