integration: replace time.Sleep with assert.EventuallyWithT (#2680)
Some checks failed
Build / build-nix (push) Has been cancelled
Build / build-cross (GOARCH=386 GOOS=linux) (push) Has been cancelled
Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Has been cancelled
Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Has been cancelled
Build / build-cross (GOARCH=arm GOOS=linux GOARM=5) (push) Has been cancelled
Build / build-cross (GOARCH=arm GOOS=linux GOARM=6) (push) Has been cancelled
Build / build-cross (GOARCH=arm GOOS=linux GOARM=7) (push) Has been cancelled
Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Has been cancelled
Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Has been cancelled
Tests / test (push) Has been cancelled
update-flake-lock / lockfile (push) Has been cancelled
GitHub Actions Version Updater / build (push) Has been cancelled

This commit is contained in:
Kristoffer Dalby
2025-07-10 23:38:55 +02:00
committed by GitHub
parent b904276f2b
commit c6d7b512bd
73 changed files with 584 additions and 573 deletions

View File

@@ -48,5 +48,4 @@ jobs:
- name: Deploy stable docs from tag
if: startsWith(github.ref, 'refs/tags/v')
# This assumes that only newer tags are pushed
run:
mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest
run: mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest

View File

@@ -75,7 +75,7 @@ jobs:
# Some of the jobs might still require manual restart as they are really
# slow and this will cause them to eventually be killed by Github actions.
attempt_delay: 300000 # 5 min
attempt_limit: 3
attempt_limit: 2
command: |
nix develop --command -- hi run "^${{ inputs.test }}$" \
--timeout=120m \

View File

@@ -36,8 +36,7 @@ jobs:
- name: golangci-lint
if: steps.changed-files.outputs.files == 'true'
run:
nix develop --command -- golangci-lint run
run: nix develop --command -- golangci-lint run
--new-from-rev=${{github.event.pull_request.base.sha}}
--format=colored-line-number
@@ -75,8 +74,7 @@ jobs:
- name: Prettify code
if: steps.changed-files.outputs.files == 'true'
run:
nix develop --command -- prettier --no-error-on-unmatched-pattern
run: nix develop --command -- prettier --no-error-on-unmatched-pattern
--ignore-unknown --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
proto-lint:

View File

@@ -117,7 +117,7 @@ var createNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()),
"Cannot create node: "+status.Convert(err).Message(),
output,
)
}

View File

@@ -2,6 +2,7 @@ package cli
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
@@ -68,7 +69,7 @@ func mockOIDC() error {
userStr := os.Getenv("MOCKOIDC_USERS")
if userStr == "" {
return fmt.Errorf("MOCKOIDC_USERS not defined")
return errors.New("MOCKOIDC_USERS not defined")
}
var users []mockoidc.MockUser

View File

@@ -184,7 +184,7 @@ var listNodesCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
"Cannot get nodes: "+status.Convert(err).Message(),
output,
)
}
@@ -398,10 +398,7 @@ var deleteNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error getting node node: %s",
status.Convert(err).Message(),
),
"Error getting node node: "+status.Convert(err).Message(),
output,
)
@@ -437,10 +434,7 @@ var deleteNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error deleting node: %s",
status.Convert(err).Message(),
),
"Error deleting node: "+status.Convert(err).Message(),
output,
)
@@ -498,10 +492,7 @@ var moveNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error getting node: %s",
status.Convert(err).Message(),
),
"Error getting node: "+status.Convert(err).Message(),
output,
)
@@ -517,10 +508,7 @@ var moveNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error moving node: %s",
status.Convert(err).Message(),
),
"Error moving node: "+status.Convert(err).Message(),
output,
)
@@ -567,10 +555,7 @@ be assigned to nodes.`,
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error backfilling IPs: %s",
status.Convert(err).Message(),
),
"Error backfilling IPs: "+status.Convert(err).Message(),
output,
)

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/url"
"strconv"
survey "github.com/AlecAivazis/survey/v2"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@@ -27,10 +28,7 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
err := errors.New("--name or --identifier flag is required")
ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename user: %s",
status.Convert(err).Message(),
),
"Cannot rename user: "+status.Convert(err).Message(),
"",
)
}
@@ -114,10 +112,7 @@ var createUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot create user: %s",
status.Convert(err).Message(),
),
"Cannot create user: "+status.Convert(err).Message(),
output,
)
}
@@ -147,16 +142,16 @@ var destroyUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
"Error: "+status.Convert(err).Message(),
output,
)
}
if len(users.GetUsers()) != 1 {
err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID")
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
"Error: "+status.Convert(err).Message(),
output,
)
}
@@ -185,10 +180,7 @@ var destroyUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot destroy user: %s",
status.Convert(err).Message(),
),
"Cannot destroy user: "+status.Convert(err).Message(),
output,
)
}
@@ -233,7 +225,7 @@ var listUsersCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()),
"Cannot get users: "+status.Convert(err).Message(),
output,
)
}
@@ -247,7 +239,7 @@ var listUsersCmd = &cobra.Command{
tableData = append(
tableData,
[]string{
fmt.Sprintf("%d", user.GetId()),
strconv.FormatUint(user.GetId(), 10),
user.GetDisplayName(),
user.GetName(),
user.GetEmail(),
@@ -287,16 +279,16 @@ var renameUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
"Error: "+status.Convert(err).Message(),
output,
)
}
if len(users.GetUsers()) != 1 {
err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID")
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
"Error: "+status.Convert(err).Message(),
output,
)
}
@@ -312,10 +304,7 @@ var renameUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename user: %s",
status.Convert(err).Message(),
),
"Cannot rename user: "+status.Convert(err).Message(),
output,
)
}

View File

@@ -66,7 +66,7 @@ func killTestContainers(ctx context.Context) error {
if cont.State == "running" {
_ = cli.ContainerKill(ctx, cont.ID, "KILL")
}
// Then remove the container with retry logic
if removeContainerWithRetry(ctx, cli, cont.ID) {
removed++
@@ -87,25 +87,25 @@ func killTestContainers(ctx context.Context) error {
func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool {
maxRetries := 3
baseDelay := 100 * time.Millisecond
for attempt := 0; attempt < maxRetries; attempt++ {
for attempt := range maxRetries {
err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{
Force: true,
})
if err == nil {
return true
}
// If this is the last attempt, don't wait
if attempt == maxRetries-1 {
break
}
// Wait with exponential backoff
delay := baseDelay * time.Duration(1<<attempt)
time.Sleep(delay)
}
return false
}

View File

@@ -156,10 +156,10 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
projectRoot := findProjectRoot(pwd)
runID := dockertestutil.ExtractRunIDFromContainerName(containerName)
env := []string{
fmt.Sprintf("HEADSCALE_INTEGRATION_POSTGRES=%d", boolToInt(config.UsePostgres)),
fmt.Sprintf("HEADSCALE_INTEGRATION_RUN_ID=%s", runID),
"HEADSCALE_INTEGRATION_RUN_ID=" + runID,
}
containerConfig := &container.Config{
Image: "golang:" + config.GoVersion,
@@ -175,7 +175,7 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
// Get the correct Docker socket path from the current context
dockerSocketPath := getDockerSocketPath()
if config.Verbose {
log.Printf("Using Docker socket: %s", dockerSocketPath)
}
@@ -184,7 +184,7 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
AutoRemove: false, // We'll remove manually for better control
Binds: []string{
fmt.Sprintf("%s:%s", projectRoot, projectRoot),
fmt.Sprintf("%s:/var/run/docker.sock", dockerSocketPath),
dockerSocketPath + ":/var/run/docker.sock",
logsDir + ":/tmp/control",
},
Mounts: []mount.Mount{
@@ -237,7 +237,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
}
testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
// Wait for all test containers to reach a final state
maxWaitTime := 10 * time.Second
checkInterval := 500 * time.Millisecond
@@ -254,7 +254,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
return nil
case <-ticker.C:
allFinalized := true
for _, testCont := range testContainers {
inspect, err := cli.ContainerInspect(ctx, testCont.ID)
if err != nil {
@@ -263,17 +263,18 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
}
continue
}
// Check if container is in a final state
if !isContainerFinalized(inspect.State) {
allFinalized = false
if verbose {
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
}
break
}
}
if allFinalized {
if verbose {
log.Printf("All test containers finalized, ready for artifact extraction")
@@ -290,7 +291,6 @@ func isContainerFinalized(state *container.State) bool {
return !state.Running && state.FinishedAt != ""
}
// findProjectRoot locates the project root by finding the directory containing go.mod.
func findProjectRoot(startPath string) string {
current := startPath
@@ -427,7 +427,7 @@ func listControlFiles(logsDir string) {
}
if entry.IsDir() {
// Include directories (pprof, mapresponses)
// Include directories (pprof, mapresponses)
if strings.Contains(name, "-pprof") || strings.Contains(name, "-mapresponses") {
dataDirs = append(dataDirs, name)
}
@@ -510,7 +510,7 @@ type testContainer struct {
// getCurrentTestContainers filters containers to only include those from the current test run.
func getCurrentTestContainers(containers []container.Summary, testContainerID string, verbose bool) []testContainer {
var testRunContainers []testContainer
// Find the test container to get its run ID label
var runID string
for _, cont := range containers {
@@ -521,16 +521,16 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
break
}
}
if runID == "" {
log.Printf("Error: test container %s missing required hi.run-id label", testContainerID[:12])
return testRunContainers
}
if verbose {
log.Printf("Looking for containers with run ID: %s", runID)
}
// Find all containers with the same run ID
for _, cont := range containers {
for _, name := range cont.Names {
@@ -546,18 +546,19 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
log.Printf("Including container %s (run ID: %s)", containerName, runID)
}
}
break
}
}
}
return testRunContainers
}
// extractContainerArtifacts saves logs and tar files from a container.
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
// Ensure the logs directory exists
if err := os.MkdirAll(logsDir, 0755); err != nil {
if err := os.MkdirAll(logsDir, 0o755); err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
}
@@ -608,12 +609,12 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
}
// Write stdout logs
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0644); err != nil {
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stdout log: %w", err)
}
// Write stderr logs
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0644); err != nil {
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stderr log: %w", err)
}
@@ -626,7 +627,7 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
// extractContainerFiles extracts database file and directories from headscale containers.
// Note: The actual file extraction is now handled by the integration tests themselves
// via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go
// via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go.
func extractContainerFiles(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
// Files are now extracted directly by the integration tests
// This function is kept for potential future use or other file types
@@ -677,7 +678,7 @@ func extractDirectory(ctx context.Context, cli *client.Client, containerID, sour
// Create target directory
targetDir := filepath.Join(logsDir, dirName)
if err := os.MkdirAll(targetDir, 0755); err != nil {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
}

View File

@@ -10,10 +10,8 @@ import (
"strings"
)
var (
// ErrFileNotFoundInTar indicates a file was not found in the tar archive.
ErrFileNotFoundInTar = errors.New("file not found in tar")
)
// ErrFileNotFoundInTar indicates a file was not found in the tar archive.
var ErrFileNotFoundInTar = errors.New("file not found in tar")
// extractFileFromTar extracts a single file from a tar reader.
func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error {
@@ -42,6 +40,7 @@ func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error
if _, err := io.Copy(outFile, tr); err != nil {
return fmt.Errorf("failed to copy file contents: %w", err)
}
return nil
}
}
@@ -98,4 +97,4 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
}
return nil
}
}

View File

@@ -143,6 +143,7 @@
yq-go
ripgrep
postgresql
traceroute
# 'dot' is needed for pprof graphs
# go tool pprof -http=: <source>

View File

@@ -98,7 +98,6 @@ func (h *Headscale) handleExistingNode(
return nil, nil
}
}
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
@@ -169,7 +168,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq,
machineKey,
@@ -178,9 +176,11 @@ func (h *Headscale) handleRegisterWithAuthKey(
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
if perr, ok := err.(types.PAKError); ok {
var perr types.PAKError
if errors.As(err, &perr) {
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
}
return nil, err
}

View File

@@ -1,11 +1,10 @@
package capver
import (
"slices"
"sort"
"strings"
"slices"
xmaps "golang.org/x/exp/maps"
"tailscale.com/tailcfg"
"tailscale.com/util/set"

View File

@@ -1,6 +1,6 @@
package capver
//Generated DO NOT EDIT
// Generated DO NOT EDIT
import "tailscale.com/tailcfg"
@@ -38,17 +38,16 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.82.5": 115,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
87: "v1.60.0",
88: "v1.62.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
87: "v1.60.0",
88: "v1.62.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
}

View File

@@ -764,13 +764,13 @@ AND auth_key_id NOT IN (
// Drop all indexes first to avoid conflicts
indexesToDrop := []string{
"idx_users_deleted_at",
"idx_provider_identifier",
"idx_provider_identifier",
"idx_name_provider_identifier",
"idx_name_no_provider_identifier",
"idx_api_keys_prefix",
"idx_policies_deleted_at",
}
for _, index := range indexesToDrop {
_ = tx.Exec("DROP INDEX IF EXISTS " + index).Error
}
@@ -927,6 +927,7 @@ AND auth_key_id NOT IN (
}
log.Info().Msg("Schema recreation completed successfully")
return nil
},
Rollback: func(db *gorm.DB) error { return nil },

View File

@@ -93,7 +93,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
Avoid: false,
Nodes: []*tailcfg.DERPNode{
{
Name: fmt.Sprintf("%d", d.cfg.ServerRegionID),
Name: strconv.Itoa(d.cfg.ServerRegionID),
RegionID: d.cfg.ServerRegionID,
HostName: host,
DERPPort: port,

View File

@@ -103,7 +103,6 @@ func (e *ExtraRecordsMan) Run() {
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()))
if err != nil {
log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete")
continue

View File

@@ -475,7 +475,10 @@ func (api headscaleV1APIServer) RenameNode(
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
log.Trace().

View File

@@ -32,7 +32,7 @@ const (
reservedResponseHeaderSize = 4
)
// httpError logs an error and sends an HTTP error response with the given
// httpError logs an error and sends an HTTP error response with the given.
func httpError(w http.ResponseWriter, err error) {
var herr HTTPError
if errors.As(err, &herr) {
@@ -102,6 +102,7 @@ func (h *Headscale) handleVerifyRequest(
resp := &tailcfg.DERPAdmitClientResponse{
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
}
return json.NewEncoder(writer).Encode(resp)
}

View File

@@ -500,7 +500,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
}
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
// or for the given nodes if at least one node ID is given as parameter.
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.state.ListNodes(nodeIDs...)
if err != nil {

View File

@@ -80,7 +80,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
}
}
// mockState is a mock implementation that provides the required methods
// mockState is a mock implementation that provides the required methods.
type mockState struct {
polMan policy.PolicyManager
derpMap *tailcfg.DERPMap
@@ -133,6 +133,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
}
}
}
return filtered, nil
}
// Return all peers except the node itself
@@ -142,6 +143,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
filtered = append(filtered, peer)
}
}
return filtered, nil
}
@@ -157,8 +159,10 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
}
}
}
return filtered, nil
}
return m.nodes, nil
}

View File

@@ -11,7 +11,7 @@ import (
"tailscale.com/types/views"
)
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag.
type NodeCanHaveTagChecker interface {
NodeCanHaveTag(node types.NodeView, tag string) bool
}

View File

@@ -111,5 +111,6 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
}
n, err := r.ResponseWriter.Write(b)
r.written += int64(n)
return n, err
}

View File

@@ -50,6 +50,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
n.b = b
go b.doWork()
return n
}
@@ -72,7 +73,7 @@ func (n *Notifier) Close() {
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}
// safeCloseChannel closes a channel and panic recovers if already closed
// safeCloseChannel closes a channel and panic recovers if already closed.
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
defer func() {
if r := recover(); r != nil {
@@ -170,6 +171,7 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
}
@@ -182,7 +184,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return false
}
// LikelyConnectedMap returns a thread safe map of connected nodes
// LikelyConnectedMap returns a thread safe map of connected nodes.
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.connected
}

View File

@@ -1,17 +1,15 @@
package notifier
import (
"context"
"fmt"
"math/rand"
"net/netip"
"slices"
"sort"
"sync"
"testing"
"time"
"slices"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@@ -241,7 +239,7 @@ func TestBatcher(t *testing.T) {
defer n.RemoveNode(1, ch)
for _, u := range tt.updates {
n.NotifyAll(context.Background(), u)
n.NotifyAll(t.Context(), u)
}
n.b.flush()
@@ -270,7 +268,7 @@ func TestBatcher(t *testing.T) {
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier
cfg := &types.Config{
@@ -308,16 +306,17 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) {
for range iterations {
// Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node
if routineID%3 == 0 {
switch routineID % 3 {
case 0:
// This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
}
} else if routineID%3 == 1 {
case 1:
// This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan)
} else {
default:
// This goroutine adds the node back
notifier.AddNode(nodeID, updateChan)
}

View File

@@ -84,11 +84,8 @@ func NewAuthProviderOIDC(
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(serverURL, "/"),
),
Scopes: cfg.Scope,
RedirectURL: strings.TrimSuffix(serverURL, "/") + "/oidc/callback",
Scopes: cfg.Scope,
}
registrationCache := zcache.New[string, RegistrationInfo](
@@ -131,7 +128,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr, _ := vars["registration_id"]
registrationIdStr := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
@@ -232,7 +229,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
}
oauth2Token, err := a.getOauth2Token(req.Context(), code, state)
if err != nil {
httpError(writer, err)
return
@@ -364,6 +360,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
@@ -402,6 +399,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
if err != nil {
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
}
return oauth2Token, err
}

View File

@@ -2,9 +2,8 @@ package matcher
import (
"net/netip"
"strings"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
@@ -28,6 +27,7 @@ func (m Match) DebugString() string {
for _, prefix := range m.dests.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}
return sb.String()
}
@@ -36,6 +36,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
for _, rule := range rules {
matches = append(matches, MatchFromFilterRule(rule))
}
return matches
}

View File

@@ -4,7 +4,6 @@ import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"

View File

@@ -5,7 +5,6 @@ import (
"slices"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/samber/lo"
@@ -131,7 +130,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
// AutoApproveRoutes approves any route that can be autoapproved from
// the nodes perspective according to the given policy.
// It reports true if any routes were approved.
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
if pm == nil {
return false

View File

@@ -7,9 +7,8 @@ import (
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@@ -1974,6 +1973,7 @@ func TestSSHPolicyRules(t *testing.T) {
}
}
}
func TestReduceRoutes(t *testing.T) {
type args struct {
node *types.Node

View File

@@ -13,9 +13,7 @@ import (
"tailscale.com/types/views"
)
var (
ErrInvalidAction = errors.New("invalid action")
)
var ErrInvalidAction = errors.New("invalid action")
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
@@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules(
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, users, nodes)
ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
}
@@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
for _, pref := range ips.Prefixes() {
out = append(out, pref.String())
}
return out
}

View File

@@ -4,19 +4,17 @@ import (
"encoding/json"
"fmt"
"net/netip"
"slices"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"slices"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
"tailscale.com/types/views"
"tailscale.com/util/deephash"
)
type PolicyManager struct {
@@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter, pm.matchers
}
@@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
@@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
@@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// cannot just lookup in the prefix map and have to check
// if there is a "parent" prefix available.
for prefix, approveAddrs := range pm.autoApproveMap {
// Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {

View File

@@ -1,10 +1,10 @@
package v2
import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"

View File

@@ -6,9 +6,9 @@ import (
"errors"
"fmt"
"net/netip"
"strings"
"slices"
"strconv"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
// Check if it's the wildcard port range
if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 {
return json.Marshal(fmt.Sprintf("%s:*", alias))
return json.Marshal(alias + ":*")
}
// Otherwise, format as "alias:ports"
var ports []string
for _, port := range a.Ports {
if port.First == port.Last {
ports = append(ports, fmt.Sprintf("%d", port.First))
ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
} else {
ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last))
}
@@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error {
if err := u.Validate(); err != nil {
return err
}
return nil
}
@@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
return buildIPSetMultiErr(&ips, errs)
}
// Group is a special string which is always prefixed with `group:`
// Group is a special string which is always prefixed with `group:`.
type Group string
func (g Group) Validate() error {
@@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error {
if err := g.Validate(); err != nil {
return err
}
return nil
}
@@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod
return buildIPSetMultiErr(&ips, errs)
}
// Tag is a special string which is always prefixed with `tag:`
// Tag is a special string which is always prefixed with `tag:`.
type Tag string
func (t Tag) Validate() error {
@@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
if err := t.Validate(); err != nil {
return err
}
return nil
}
@@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
if err := h.Validate(); err != nil {
return err
}
return nil
}
@@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error {
}
*p = Prefix(addrPref)
return nil
}
@@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error {
return err
}
*p = Prefix(pref)
return nil
}
@@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
if err := p.Validate(); err != nil {
return err
}
return nil
}
@@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild
}
}
// AutoGroup is a special string which is always prefixed with `autogroup:`
// AutoGroup is a special string which is always prefixed with `autogroup:`.
type AutoGroup string
const (
@@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
if err := ag.Validate(); err != nil {
return err
}
return nil
}
@@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
if err := ve.Alias.Validate(); err != nil {
if err := ve.Validate(); err != nil {
return err
}
default:
return fmt.Errorf("type %T not supported", vs)
}
return nil
}
@@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.Alias = ptr
return nil
}
@@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
for i, alias := range aliases {
(*a)[i] = alias.Alias
}
return nil
}
@@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I
return ips, multierr.New(append(errs, err)...)
}
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer.
func unmarshalPointer[T any](
b []byte,
parseFunc func(string) (T, error),
@@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
for i, autoApprover := range autoApprovers {
(*aa)[i] = autoApprover.AutoApprover
}
return nil
}
@@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.AutoApprover = ptr
return nil
}
@@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.Owner = ptr
return nil
}
@@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error {
for i, owner := range owners {
(*o)[i] = owner.Owner
}
return nil
}
@@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) {
case isGroup(s):
return ptr.To(Group(s)), nil
}
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
- user (containing an "@")
- group (starting with "group:")
@@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
(*g)[group] = usernames
}
return nil
}
@@ -1252,7 +1269,7 @@ type Policy struct {
// We use the default JSON marshalling behavior provided by the Go runtime.
var (
// TODO(kradalby): Add these checks for tagOwners and autoApprovers
// TODO(kradalby): Add these checks for tagOwners and autoApprovers.
autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged}
autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
@@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error {
}
if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSrc, *src) {
@@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error {
}
if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSSHSrc, *src) {
@@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
}
if dst.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSSHDst, *dst) {
@@ -1360,14 +1377,14 @@ func (p *Policy) validate() error {
for _, acl := range p.ACLs {
for _, src := range acl.Sources {
switch src.(type) {
switch src := src.(type) {
case *Host:
h := src.(*Host)
h := src
if !p.Hosts.exist(*h) {
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
}
case *AutoGroup:
ag := src.(*AutoGroup)
ag := src
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
@@ -1379,12 +1396,12 @@ func (p *Policy) validate() error {
continue
}
case *Group:
g := src.(*Group)
g := src
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := src.(*Tag)
tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1440,9 +1457,9 @@ func (p *Policy) validate() error {
}
for _, src := range ssh.Sources {
switch src.(type) {
switch src := src.(type) {
case *AutoGroup:
ag := src.(*AutoGroup)
ag := src
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
@@ -1454,21 +1471,21 @@ func (p *Policy) validate() error {
continue
}
case *Group:
g := src.(*Group)
g := src
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := src.(*Tag)
tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
}
}
for _, dst := range ssh.Destinations {
switch dst.(type) {
switch dst := dst.(type) {
case *AutoGroup:
ag := dst.(*AutoGroup)
ag := dst
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
continue
@@ -1479,7 +1496,7 @@ func (p *Policy) validate() error {
continue
}
case *Tag:
tagOwner := dst.(*Tag)
tagOwner := dst
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1489,9 +1506,9 @@ func (p *Policy) validate() error {
for _, tagOwners := range p.TagOwners {
for _, tagOwner := range tagOwners {
switch tagOwner.(type) {
switch tagOwner := tagOwner.(type) {
case *Group:
g := tagOwner.(*Group)
g := tagOwner
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
@@ -1501,14 +1518,14 @@ func (p *Policy) validate() error {
for _, approvers := range p.AutoApprovers.Routes {
for _, approver := range approvers {
switch approver.(type) {
switch approver := approver.(type) {
case *Group:
g := approver.(*Group)
g := approver
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := approver.(*Tag)
tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1517,14 +1534,14 @@ func (p *Policy) validate() error {
}
for _, approver := range p.AutoApprovers.ExitNode {
switch approver.(type) {
switch approver := approver.(type) {
case *Group:
g := approver.(*Group)
g := approver
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := approver.(*Tag)
tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1536,6 +1553,7 @@ func (p *Policy) validate() error {
}
p.validated = true
return nil
}
@@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
)
}
}
return nil
}
@@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
)
}
}
return nil
}

View File

@@ -5,13 +5,13 @@ import (
"net/netip"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go4.org/netipx"
@@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) {
// Marshal the policy to JSON
marshalled, err := json.MarshalIndent(policy, "", " ")
require.NoError(t, err)
// Make sure all expected fields are present in the JSON
jsonString := string(marshalled)
assert.Contains(t, jsonString, "group:example")
@@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) {
assert.Contains(t, jsonString, "accept")
assert.Contains(t, jsonString, "tcp")
assert.Contains(t, jsonString, "80")
// Unmarshal back to verify round trip
var roundTripped Policy
err = json.Unmarshal(marshalled, &roundTripped)
require.NoError(t, err)
// Compare the original and round-tripped policies
cmps := append(util.Comparers,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
cmpopts.EquateEmpty(),
)
if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" {
t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff)
}
@@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) {
},
}
cmps := append(util.Comparers,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
)
// For round-trip testing, we'll normalize the policies before comparing
for _, tt := range tests {
@@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) {
} else if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr)
}
return // Skip the rest of the test if we expected an error
}
@@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) {
if err != nil {
t.Fatalf("round-trip unmarshalling: %v", err)
}
// Add EquateEmpty to handle nil vs empty maps/slices
roundTripCmps := append(cmps,
roundTripCmps := append(cmps,
cmpopts.EquateEmpty(),
cmpopts.IgnoreUnexported(Policy{}),
)
@@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
builder.AddPrefix(mp(p))
}
ipSet, _ := builder.IPSet()
return ipSet
}

View File

@@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) {
expected []tailcfg.PortRange
err string
}{
{"80", []tailcfg.PortRange{{80, 80}}, ""},
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
{"80-", nil, "invalid port range format"},
{"-90", nil, "invalid port range format"},

View File

@@ -158,6 +158,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
}
tsaddr.SortPrefixes(routes)
return routes
}

View File

@@ -429,6 +429,7 @@ func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
}
@@ -443,6 +444,7 @@ func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, er
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
}
@@ -701,7 +703,7 @@ func (s *State) HandleNodeFromPreAuthKey(
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
nodeToRegister.Expiry = &regReq.Expiry
} else if !regReq.Expiry.IsZero() {
// If client is sending an expired time (e.g., after logout),
// If client is sending an expired time (e.g., after logout),
// don't set expiry so the node won't be considered expired
log.Debug().
Time("requested_expiry", regReq.Expiry).

View File

@@ -2,6 +2,7 @@ package hscontrol
import (
"context"
"errors"
"fmt"
"net/http"
"os"
@@ -70,7 +71,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
// When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
certDomains := tsNode.CertDomains()
if len(certDomains) == 0 {
return fmt.Errorf("no cert domains available for HTTPS")
return errors.New("no cert domains available for HTTPS")
}
base := "https://" + certDomains[0]
go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -95,5 +96,6 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
logf("TailSQL started")
<-ctx.Done()
logf("TailSQL shutting down...")
return tsNode.Close()
}

View File

@@ -62,7 +62,7 @@ func Apple(url string) *elem.Element {
),
elem.Pre(nil,
elem.Code(nil,
elem.Text(fmt.Sprintf("tailscale login --login-server %s", url)),
elem.Text("tailscale login --login-server "+url),
),
),
headerTwo("GUI"),
@@ -143,10 +143,7 @@ func Apple(url string) *elem.Element {
elem.Code(
nil,
elem.Text(
fmt.Sprintf(
`defaults write io.tailscale.ipn.macos ControlURL %s`,
url,
),
"defaults write io.tailscale.ipn.macos ControlURL "+url,
),
),
),
@@ -155,10 +152,7 @@ func Apple(url string) *elem.Element {
elem.Code(
nil,
elem.Text(
fmt.Sprintf(
`defaults write io.tailscale.ipn.macsys ControlURL %s`,
url,
),
"defaults write io.tailscale.ipn.macsys ControlURL "+url,
),
),
),

View File

@@ -1,8 +1,6 @@
package templates
import (
"fmt"
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
)
@@ -31,7 +29,7 @@ func Windows(url string) *elem.Element {
),
elem.Pre(nil,
elem.Code(nil,
elem.Text(fmt.Sprintf(`tailscale login --login-server %s`, url)),
elem.Text("tailscale login --login-server "+url),
),
),
),

View File

@@ -180,6 +180,7 @@ func MustRegistrationID() RegistrationID {
if err != nil {
panic(err)
}
return rid
}

View File

@@ -339,6 +339,7 @@ func LoadConfig(path string, isFile bool) error {
log.Warn().Msg("No config file found, using defaults")
return nil
}
return fmt.Errorf("fatal error reading config file: %w", err)
}
@@ -843,7 +844,7 @@ func LoadServerConfig() (*Config, error) {
}
if prefix4 == nil && prefix6 == nil {
return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
}
allocStr := viper.GetString("prefixes.allocation")
@@ -1020,7 +1021,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
s := len(serverDomainParts)
b := len(baseDomainParts)
for i := range len(baseDomainParts) {
for i := range baseDomainParts {
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
return nil
}

View File

@@ -282,6 +282,7 @@ func TestReadConfigFromEnv(t *testing.T) {
assert.Equal(t, "trace", viper.GetString("log.level"))
assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4"))
assert.False(t, viper.GetBool("database.sqlite.write_ahead_log"))
return nil, nil
},
want: nil,

View File

@@ -28,8 +28,10 @@ var (
ErrNodeUserHasNoName = errors.New("node user has no name")
)
type NodeID uint64
type NodeIDs []NodeID
type (
NodeID uint64
NodeIDs []NodeID
)
func (n NodeIDs) Len() int { return len(n) }
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
@@ -169,6 +171,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
return true
}
}
return false
}
@@ -176,7 +179,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 {
return true
@@ -199,7 +202,7 @@ func (node *Node) IsTagged() bool {
// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (node *Node) HasTag(tag string) bool {
return slices.Contains(node.Tags(), tag)
}
@@ -577,6 +580,7 @@ func (nodes Nodes) DebugString() string {
sb.WriteString(node.DebugString())
sb.WriteString("\n")
}
return sb.String()
}
@@ -590,6 +594,7 @@ func (node Node) DebugString() string {
fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes())
fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes())
sb.WriteString("\n")
return sb.String()
}
@@ -689,7 +694,7 @@ func (v NodeView) Tags() []string {
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (v NodeView) IsTagged() bool {
if !v.Valid() {
return false
@@ -727,7 +732,7 @@ func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC
// GetFQDN returns the fully qualified domain name for the node.
func (v NodeView) GetFQDN(baseDomain string) (string, error) {
if !v.Valid() {
return "", fmt.Errorf("failed to create valid FQDN: node view is invalid")
return "", errors.New("failed to create valid FQDN: node view is invalid")
}
return v.ж.GetFQDN(baseDomain)
}
@@ -773,4 +778,3 @@ func (v NodeView) IPsAsString() []string {
}
return v.ж.IPsAsString()
}

View File

@@ -2,7 +2,6 @@ package types
import (
"fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip"
"strings"
"testing"
@@ -10,6 +9,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/types/key"

View File

@@ -11,7 +11,7 @@ import (
type PAKError string
func (e PAKError) Error() string { return string(e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {

View File

@@ -1,6 +1,7 @@
package types
import (
"errors"
"testing"
"time"
@@ -109,7 +110,8 @@ func TestCanUsePreAuthKey(t *testing.T) {
if err == nil {
t.Errorf("expected error but got none")
} else {
httpErr, ok := err.(PAKError)
var httpErr PAKError
ok := errors.As(err, &httpErr)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {

View File

@@ -249,7 +249,7 @@ func (c *OIDCClaims) Identifier() string {
// - Remove empty path segments
// - For non-URL identifiers, it joins non-empty segments with a single slash
// - Returns empty string for identifiers with only slashes
// - Normalize URL schemes to lowercase
// - Normalize URL schemes to lowercase.
func CleanIdentifier(identifier string) string {
if identifier == "" {
return identifier
@@ -273,7 +273,7 @@ func CleanIdentifier(identifier string) string {
cleanParts = append(cleanParts, part)
}
}
if len(cleanParts) == 0 {
u.Path = ""
} else {
@@ -281,6 +281,7 @@ func CleanIdentifier(identifier string) string {
}
// Ensure scheme is lowercase
u.Scheme = strings.ToLower(u.Scheme)
return u.String()
}
@@ -297,6 +298,7 @@ func CleanIdentifier(identifier string) string {
if len(cleanParts) == 0 {
return ""
}
return strings.Join(cleanParts, "/")
}

View File

@@ -1,4 +1,6 @@
package types
var Version = "dev"
var GitCommitHash = "dev"
var (
Version = "dev"
GitCommitHash = "dev"
)

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net/netip"
"regexp"
"strconv"
"strings"
"unicode"
@@ -21,8 +22,10 @@ const (
LabelHostnameLength = 63
)
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var (
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
var ErrInvalidUserName = errors.New("invalid user name")
@@ -141,7 +144,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{}
for i := lastOctet - 1; i >= 0; i-- {
rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10))
}
rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
rdnsBase := strings.Join(rdnsSlice, ".")
@@ -205,7 +208,7 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")
return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix))
return dnsname.ToFQDN(prefix + ".ip6.arpa")
}
var fqdns []dnsname.FQDN

View File

@@ -70,7 +70,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq
"rowsAffected": rowsAffected,
}
if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) {
if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) {
l.Logger.Error().Err(err).Fields(fields).Msgf("")
return
}

View File

@@ -58,5 +58,6 @@ var TheInternet = sync.OnceValue(func() *netipx.IPSet {
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
theInternetSet, _ := internetBuilder.IPSet()
return theInternetSet
})

View File

@@ -53,37 +53,37 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
}
type TraceroutePath struct {
// Hop is the current jump in the total traceroute.
Hop int
// Hop is the current jump in the total traceroute.
Hop int
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// IP is the IP address of the jump
IP netip.Addr
// IP is the IP address of the jump
IP netip.Addr
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
}
type Traceroute struct {
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// IP is the IP address of the target
IP netip.Addr
// IP is the IP address of the target
IP netip.Addr
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Success indicates if the traceroute was successful.
Success bool
// Success indicates if the traceroute was successful.
Success bool
// Err contains an error if the traceroute was not successful.
Err error
// Err contains an error if the traceroute was not successful.
Err error
}
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct.
func ParseTraceroute(output string) (Traceroute, error) {
lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) < 1 {
@@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
}
// Parse each hop line
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i])

View File

@@ -1077,7 +1077,6 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@@ -1213,7 +1212,6 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
func TestACLAutogroupMember(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := aclScenario(t,
&policyv2.Policy{
@@ -1271,7 +1269,6 @@ func TestACLAutogroupMember(t *testing.T) {
func TestACLAutogroupTagged(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := aclScenario(t,
&policyv2.Policy{

View File

@@ -3,12 +3,11 @@ package integration
import (
"fmt"
"net/netip"
"slices"
"strconv"
"testing"
"time"
"slices"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
@@ -19,7 +18,6 @@ import (
func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
@@ -66,7 +64,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@@ -161,12 +159,11 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
})
}
}
func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node)
assert.NotNil(t, node.LastSeen)
assert.NotNil(t, node.GetLastSeen())
}
// This test will first log in two sets of nodes to two sets of users, then
@@ -175,7 +172,6 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) {
// still has nodes, but they are not connected.
func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -204,7 +200,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@@ -259,7 +255,6 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
@@ -303,7 +298,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)

View File

@@ -1,14 +1,12 @@
package integration
import (
"fmt"
"maps"
"net/netip"
"sort"
"testing"
"time"
"maps"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@@ -21,7 +19,6 @@ import (
func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Logins to MockOIDC is served by a queue with a strict order,
// if we use more than one node per user, the order of the logins
@@ -119,7 +116,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
// This test is really flaky.
func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
shortAccessTTL := 5 * time.Minute
@@ -174,9 +170,13 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
// of safety reasons) before checking if the clients have logged out.
// The Wait function can't do it itself as it has an upper bound of 1
// min.
time.Sleep(shortAccessTTL + 10*time.Second)
assertTailscaleNodesLogout(t, allClients)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}
}, shortAccessTTL+10*time.Second, 5*time.Second)
}
func TestOIDC024UserCreation(t *testing.T) {
@@ -295,9 +295,7 @@ func TestOIDC024UserCreation(t *testing.T) {
spec := ScenarioSpec{
NodesPerUser: 1,
}
for _, user := range tt.cliUsers {
spec.Users = append(spec.Users, user)
}
spec.Users = append(spec.Users, tt.cliUsers...)
for _, user := range tt.oidcUsers {
spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified))
@@ -350,7 +348,6 @@ func TestOIDC024UserCreation(t *testing.T) {
func TestOIDCAuthenticationWithPKCE(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Single user with one node for testing PKCE flow
spec := ScenarioSpec{
@@ -402,7 +399,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
func TestOIDCReloginSameNodeNewUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Create no nodes and no users
scenario, err := NewScenario(ScenarioSpec{
@@ -440,7 +436,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 0)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err)
@@ -482,7 +478,13 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
@@ -530,16 +532,22 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey)
assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey)
assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
// Log out user2, and log into user1, no new node should be created,
// the node should now "become" node1 again
err = ts.Logout()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
@@ -588,24 +596,24 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same
// machine.
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey)
assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey)
assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id)
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey)
assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id)
assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
// Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches.
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey)
assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey)
assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
// The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user.
assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey)
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey)
assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
}
func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
@@ -623,7 +631,7 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
return mockoidc.MockUser{
Subject: username,
PreferredUsername: username,
Email: fmt.Sprintf("%s@headscale.net", username),
Email: username + "@headscale.net",
EmailVerified: emailVerified,
}
}

View File

@@ -2,9 +2,8 @@ package integration
import (
"net/netip"
"testing"
"slices"
"testing"
"github.com/juanfont/headscale/integration/hsic"
"github.com/samber/lo"
@@ -55,7 +54,6 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -95,7 +93,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@@ -140,7 +138,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))
for _, client := range allClients {

View File

@@ -18,8 +18,8 @@ import (
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"golang.org/x/exp/slices"
"tailscale.com/tailcfg"
)
func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
@@ -30,7 +30,7 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul
err = json.Unmarshal([]byte(str), result)
if err != nil {
return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str)
return fmt.Errorf("failed to unmarshal: %w\n command err: %s", err, str)
}
return nil
@@ -48,7 +48,6 @@ func sortWithID[T GRPCSortable](a, b T) int {
func TestUserCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"user1", "user2"},
@@ -184,7 +183,7 @@ func TestUserCommand(t *testing.T) {
"--identifier=1",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Contains(t, deleteResult, "User destroyed")
var listAfterIDDelete []*v1.User
@@ -222,7 +221,7 @@ func TestUserCommand(t *testing.T) {
"--name=newname",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Contains(t, deleteResult, "User destroyed")
var listAfterNameDelete []v1.User
@@ -238,12 +237,11 @@ func TestUserCommand(t *testing.T) {
)
assertNoErr(t, err)
require.Len(t, listAfterNameDelete, 0)
require.Empty(t, listAfterNameDelete)
}
func TestPreAuthKeyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "preauthkeyspace"
count := 3
@@ -347,7 +345,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
continue
}
assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"})
assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags())
}
// Test key expiry
@@ -386,7 +384,6 @@ func TestPreAuthKeyCommand(t *testing.T) {
func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "pre-auth-key-without-exp-user"
spec := ScenarioSpec{
@@ -448,7 +445,6 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "pre-auth-key-reus-ephm-user"
spec := ScenarioSpec{
@@ -524,7 +520,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user1 := "user1"
user2 := "user2"
@@ -575,7 +570,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
assertNoErr(t, err)
listNodes, err := headscale.ListNodes()
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, listNodes, 1)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
@@ -613,7 +608,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
}
listNodes, err = headscale.ListNodes()
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, listNodes, 2)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
assert.Equal(t, user2, listNodes[1].GetUser().GetName())
@@ -621,7 +616,6 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
func TestApiKeyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
count := 5
@@ -653,7 +647,7 @@ func TestApiKeyCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotEmpty(t, apiResult)
keys[idx] = apiResult
@@ -672,7 +666,7 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAPIKeys,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listedAPIKeys, 5)
@@ -728,7 +722,7 @@ func TestApiKeyCommand(t *testing.T) {
listedAPIKeys[idx].GetPrefix(),
},
)
assert.Nil(t, err)
assert.NoError(t, err)
expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true
}
@@ -744,7 +738,7 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAfterExpireAPIKeys,
)
assert.Nil(t, err)
assert.NoError(t, err)
for index := range listedAfterExpireAPIKeys {
if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok {
@@ -770,7 +764,7 @@ func TestApiKeyCommand(t *testing.T) {
"--prefix",
listedAPIKeys[0].GetPrefix(),
})
assert.Nil(t, err)
assert.NoError(t, err)
var listedAPIKeysAfterDelete []v1.ApiKey
err = executeAndUnmarshal(headscale,
@@ -783,14 +777,13 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAPIKeysAfterDelete,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listedAPIKeysAfterDelete, 4)
}
func TestNodeTagCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"user1"},
@@ -811,7 +804,7 @@ func TestNodeTagCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)
for index, regID := range regIDs {
_, err := headscale.Execute(
@@ -829,7 +822,7 @@ func TestNodeTagCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@@ -847,7 +840,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
nodes[index] = &node
}
@@ -866,7 +859,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
@@ -894,7 +887,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&resultMachines,
)
assert.Nil(t, err)
assert.NoError(t, err)
found := false
for _, node := range resultMachines {
if node.GetForcedTags() != nil {
@@ -905,19 +898,15 @@ func TestNodeTagCommand(t *testing.T) {
}
}
}
assert.Equal(
assert.True(
t,
true,
found,
"should find a node with the tag 'tag:test' in the list of nodes",
)
}
func TestNodeAdvertiseTagCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
tests := []struct {
name string
@@ -1024,7 +1013,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
},
&resultMachines,
)
assert.Nil(t, err)
assert.NoError(t, err)
found := false
for _, node := range resultMachines {
if tags := node.GetValidTags(); tags != nil {
@@ -1043,7 +1032,6 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
func TestNodeCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"node-user", "other-user"},
@@ -1067,7 +1055,7 @@ func TestNodeCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)
for index, regID := range regIDs {
_, err := headscale.Execute(
@@ -1085,7 +1073,7 @@ func TestNodeCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@@ -1103,7 +1091,7 @@ func TestNodeCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
nodes[index] = &node
}
@@ -1123,7 +1111,7 @@ func TestNodeCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAll, 5)
@@ -1144,7 +1132,7 @@ func TestNodeCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
assert.Nil(t, err)
assert.NoError(t, err)
for index, regID := range otherUserRegIDs {
_, err := headscale.Execute(
@@ -1162,7 +1150,7 @@ func TestNodeCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@@ -1180,7 +1168,7 @@ func TestNodeCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
otherUserMachines[index] = &node
}
@@ -1200,7 +1188,7 @@ func TestNodeCommand(t *testing.T) {
},
&listAllWithotherUser,
)
assert.Nil(t, err)
assert.NoError(t, err)
// All nodes, nodes + otherUser
assert.Len(t, listAllWithotherUser, 7)
@@ -1226,7 +1214,7 @@ func TestNodeCommand(t *testing.T) {
},
&listOnlyotherUserMachineUser,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listOnlyotherUserMachineUser, 2)
@@ -1258,7 +1246,7 @@ func TestNodeCommand(t *testing.T) {
"--force",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
// Test: list main user after node is deleted
var listOnlyMachineUserAfterDelete []v1.Node
@@ -1275,14 +1263,13 @@ func TestNodeCommand(t *testing.T) {
},
&listOnlyMachineUserAfterDelete,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listOnlyMachineUserAfterDelete, 4)
}
func TestNodeExpireCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"node-expire-user"},
@@ -1323,7 +1310,7 @@ func TestNodeExpireCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@@ -1341,7 +1328,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
nodes[index] = &node
}
@@ -1360,7 +1347,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAll, 5)
@@ -1377,10 +1364,10 @@ func TestNodeExpireCommand(t *testing.T) {
"nodes",
"expire",
"--identifier",
fmt.Sprintf("%d", listAll[idx].GetId()),
strconv.FormatUint(listAll[idx].GetId(), 10),
},
)
assert.Nil(t, err)
assert.NoError(t, err)
}
var listAllAfterExpiry []v1.Node
@@ -1395,7 +1382,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&listAllAfterExpiry,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAllAfterExpiry, 5)
@@ -1408,7 +1395,6 @@ func TestNodeExpireCommand(t *testing.T) {
func TestNodeRenameCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"node-rename-command"},
@@ -1432,7 +1418,7 @@ func TestNodeRenameCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)
for index, regID := range regIDs {
_, err := headscale.Execute(
@@ -1487,7 +1473,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAll, 5)
@@ -1504,11 +1490,11 @@ func TestNodeRenameCommand(t *testing.T) {
"nodes",
"rename",
"--identifier",
fmt.Sprintf("%d", listAll[idx].GetId()),
strconv.FormatUint(listAll[idx].GetId(), 10),
fmt.Sprintf("newnode-%d", idx+1),
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Contains(t, res, "Node renamed")
}
@@ -1525,7 +1511,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAllAfterRename,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAllAfterRename, 5)
@@ -1542,7 +1528,7 @@ func TestNodeRenameCommand(t *testing.T) {
"nodes",
"rename",
"--identifier",
fmt.Sprintf("%d", listAll[4].GetId()),
strconv.FormatUint(listAll[4].GetId(), 10),
strings.Repeat("t", 64),
},
)
@@ -1560,7 +1546,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAllAfterRenameAttempt,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAllAfterRenameAttempt, 5)
@@ -1573,7 +1559,6 @@ func TestNodeRenameCommand(t *testing.T) {
func TestNodeMoveCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"old-user", "new-user"},
@@ -1610,7 +1595,7 @@ func TestNodeMoveCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@@ -1628,13 +1613,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, uint64(1), node.GetId())
assert.Equal(t, "nomad-node", node.GetName())
assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())
nodeID := fmt.Sprintf("%d", node.GetId())
nodeID := strconv.FormatUint(node.GetId(), 10)
err = executeAndUnmarshal(
headscale,
@@ -1651,9 +1636,9 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", node.GetUser().GetName())
var allNodes []v1.Node
err = executeAndUnmarshal(
@@ -1667,13 +1652,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&allNodes,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, allNodes, 1)
assert.Equal(t, allNodes[0].GetId(), node.GetId())
assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", allNodes[0].GetUser().GetName())
_, err = headscale.Execute(
[]string{
@@ -1693,7 +1678,7 @@ func TestNodeMoveCommand(t *testing.T) {
err,
"user not found",
)
assert.Equal(t, node.GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", node.GetUser().GetName())
err = executeAndUnmarshal(
headscale,
@@ -1710,9 +1695,9 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())
err = executeAndUnmarshal(
headscale,
@@ -1729,14 +1714,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())
}
func TestPolicyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"user1"},
@@ -1817,7 +1801,6 @@ func TestPolicyCommand(t *testing.T) {
func TestPolicyBrokenConfigCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,

View File

@@ -1,7 +1,6 @@
package integration
import (
"context"
"fmt"
"net"
"strconv"
@@ -104,7 +103,7 @@ func DERPVerify(
defer c.Close()
var result error
if err := c.Connect(context.Background()); err != nil {
if err := c.Connect(t.Context()); err != nil {
result = fmt.Errorf("client Connect: %w", err)
}
if m, err := c.Recv(); err != nil {

View File

@@ -15,7 +15,6 @@ import (
func TestResolveMagicDNS(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -49,7 +48,7 @@ func TestResolveMagicDNS(t *testing.T) {
// It is safe to ignore this error as we handled it when caching it
peerFQDN, _ := peer.FQDN()
assert.Equal(t, fmt.Sprintf("%s.headscale.net.", peer.Hostname()), peerFQDN)
assert.Equal(t, peer.Hostname()+".headscale.net.", peerFQDN)
command := []string{
"tailscale",
@@ -85,7 +84,6 @@ func TestResolveMagicDNS(t *testing.T) {
func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@@ -222,12 +220,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
_, err = hs.Execute([]string{"rm", erPath})
assertNoErr(t, err)
time.Sleep(2 * time.Second)
// The same paths should still be available as it is not cleared on delete.
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9")
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
result, _, err := client.Execute([]string{"dig", "docker.myvpn.example.com"})
assert.NoError(ct, err)
assert.Contains(ct, result, "9.9.9.9")
}
}, 10*time.Second, 1*time.Second)
// Write a new file, the backoff mechanism should make the filewatcher pick it up
// again.

View File

@@ -33,26 +33,27 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) {
}
// GenerateRunID creates a unique run identifier with timestamp and random hash.
// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3)
// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3).
func GenerateRunID() string {
now := time.Now()
timestamp := now.Format("20060102-150405")
// Add a short random hash to ensure uniqueness
randomHash := util.MustGenerateRandomStringDNSSafe(6)
return fmt.Sprintf("%s-%s", timestamp, randomHash)
}
// ExtractRunIDFromContainerName extracts the run ID from container name.
// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH"
// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH".
func ExtractRunIDFromContainerName(containerName string) string {
parts := strings.Split(containerName, "-")
if len(parts) >= 3 {
// Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH)
return strings.Join(parts[len(parts)-3:], "-")
}
panic(fmt.Sprintf("unexpected container name format: %s", containerName))
panic("unexpected container name format: " + containerName)
}
// IsRunningInContainer checks if the current process is running inside a Docker container.
@@ -62,4 +63,4 @@ func IsRunningInContainer() bool {
// This could be improved with more robust detection if needed
_, err := os.Stat("/.dockerenv")
return err == nil
}
}

View File

@@ -30,7 +30,7 @@ func ExecuteCommandTimeout(timeout time.Duration) ExecuteCommandOption {
})
}
// buffer is a goroutine safe bytes.buffer
// buffer is a goroutine safe bytes.buffer.
type buffer struct {
store bytes.Buffer
mutex sync.Mutex
@@ -58,8 +58,8 @@ func ExecuteCommand(
env []string,
options ...ExecuteCommandOption,
) (string, string, error) {
var stdout = buffer{}
var stderr = buffer{}
stdout := buffer{}
stderr := buffer{}
execConfig := ExecuteCommandConfig{
timeout: dockerExecuteTimeout,

View File

@@ -159,7 +159,6 @@ func New(
},
}
if dsic.workdir != "" {
runOptions.WorkingDir = dsic.workdir
}
@@ -192,7 +191,7 @@ func New(
}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "derp")
container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
runOptions,

View File

@@ -2,13 +2,13 @@ package integration
import (
"strings"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"testing"
"time"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
type ClientsSpec struct {
@@ -71,9 +71,9 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
NodesPerUser: 1,
Users: []string{"user1", "user2", "user3"},
Networks: map[string][]string{
"usernet1": []string{"user1"},
"usernet2": []string{"user2"},
"usernet3": []string{"user3"},
"usernet1": {"user1"},
"usernet2": {"user2"},
"usernet3": {"user3"},
},
}
@@ -106,7 +106,6 @@ func derpServerScenario(
furtherAssertions ...func(*Scenario),
) {
IntegrationSkip(t)
// t.Parallel()
scenario, err := NewScenario(spec)
assertNoErr(t, err)

View File

@@ -26,7 +26,6 @@ import (
func TestPingAllByIP(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -68,7 +67,6 @@ func TestPingAllByIP(t *testing.T) {
func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -118,7 +116,6 @@ func TestEphemeralInAlternateTimezone(t *testing.T) {
func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -191,7 +188,6 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
// deleted by accident if they are still online and active.
func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -260,18 +256,21 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
// Wait a bit and bring up the clients again before the expiry
// time of the ephemeral nodes.
// Nodes should be able to reconnect and work fine.
time.Sleep(30 * time.Second)
for _, client := range allClients {
err := client.Up()
if err != nil {
t.Fatalf("failed to take down client %s: %s", client.Hostname(), err)
}
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
success = pingAllHelper(t, allClients, allAddrs)
// Wait for clients to sync and be able to ping each other after reconnection
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)
success = pingAllHelper(t, allClients, allAddrs)
assert.Greater(ct, success, 0, "Ephemeral nodes should be able to reconnect and ping")
}, 60*time.Second, 2*time.Second)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
// Take down all clients, this should start an expiry timer for each.
@@ -284,7 +283,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
// This time wait for all of the nodes to expire and check that they are no longer
// registered.
time.Sleep(3 * time.Minute)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName)
assert.NoError(ct, err)
assert.Len(ct, nodes, 0, "Ephemeral nodes should be expired and removed for user %s", userName)
}
}, 4*time.Minute, 10*time.Second)
for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName)
@@ -305,7 +310,6 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
func TestPingAllByHostname(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -341,20 +345,6 @@ func TestPingAllByHostname(t *testing.T) {
// nolint:tparallel
func TestTaildrop(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
retry := func(times int, sleepInterval time.Duration, doWork func() error) error {
var err error
for range times {
err = doWork()
if err == nil {
return nil
}
time.Sleep(sleepInterval)
}
return err
}
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -396,40 +386,27 @@ func TestTaildrop(t *testing.T) {
"/var/run/tailscale/tailscaled.sock",
"http://local-tailscaled.sock/localapi/v0/file-targets",
}
err = retry(10, 1*time.Second, func() error {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, _, err := client.Execute(curlCommand)
if err != nil {
return err
}
assert.NoError(ct, err)
var fts []apitype.FileTarget
err = json.Unmarshal([]byte(result), &fts)
if err != nil {
return err
}
assert.NoError(ct, err)
if len(fts) != len(allClients)-1 {
ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname())
for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
}
return fmt.Errorf(
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s",
client.Hostname(),
assert.Failf(ct, "client %s does not have all its peers as FileTargets",
"got %d, want: %d\n%s",
len(fts),
len(allClients)-1,
ftStr,
)
}
return err
})
if err != nil {
t.Errorf(
"failed to query localapi for filetarget on %s, err: %s",
client.Hostname(),
err,
)
}
}, 10*time.Second, 1*time.Second)
}
for _, client := range allClients {
@@ -454,24 +431,15 @@ func TestTaildrop(t *testing.T) {
fmt.Sprintf("%s:", peerFQDN),
}
err := retry(10, 1*time.Second, func() error {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
t.Logf(
"Sending file from %s to %s\n",
client.Hostname(),
peer.Hostname(),
)
_, _, err := client.Execute(command)
return err
})
if err != nil {
t.Fatalf(
"failed to send taildrop file on %s with command %q, err: %s",
client.Hostname(),
strings.Join(command, " "),
err,
)
}
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
})
}
}
@@ -520,7 +488,6 @@ func TestTaildrop(t *testing.T) {
func TestUpdateHostnameFromClient(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
hostnames := map[string]string{
"1": "user1-host",
@@ -603,9 +570,47 @@ func TestUpdateHostnameFromClient(t *testing.T) {
assertNoErr(t, err)
}
time.Sleep(5 * time.Second)
// Verify that the server-side rename is reflected in DNSName while HostName remains unchanged
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Build a map of expected DNSNames by node ID
expectedDNSNames := make(map[string]string)
for _, node := range nodes {
nodeID := strconv.FormatUint(node.GetId(), 10)
expectedDNSNames[nodeID] = fmt.Sprintf("%d-givenname.headscale.net.", node.GetId())
}
// Verify from each client's perspective
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
// Check self node
selfID := string(status.Self.ID)
expectedDNS := expectedDNSNames[selfID]
assert.Equal(ct, expectedDNS, status.Self.DNSName,
"Self DNSName should be renamed for client %s (ID: %s)", client.Hostname(), selfID)
// HostName should remain as the original client-reported hostname
originalHostname := hostnames[selfID]
assert.Equal(ct, originalHostname, status.Self.HostName,
"Self HostName should remain unchanged for client %s (ID: %s)", client.Hostname(), selfID)
// Check peers
for _, peer := range status.Peer {
peerID := string(peer.ID)
if expectedDNS, ok := expectedDNSNames[peerID]; ok {
assert.Equal(ct, expectedDNS, peer.DNSName,
"Peer DNSName should be renamed for peer ID %s as seen by client %s", peerID, client.Hostname())
// HostName should remain as the original client-reported hostname
originalHostname := hostnames[peerID]
assert.Equal(ct, originalHostname, peer.HostName,
"Peer HostName should remain unchanged for peer ID %s as seen by client %s", peerID, client.Hostname())
}
}
}
}, 60*time.Second, 2*time.Second)
// Verify that the clients can see the new hostname, but no givenName
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
@@ -647,7 +652,6 @@ func TestUpdateHostnameFromClient(t *testing.T) {
func TestExpireNode(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -707,7 +711,23 @@ func TestExpireNode(t *testing.T) {
t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())
time.Sleep(2 * time.Minute)
// Verify that the expired node has been marked in all peers list.
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
if client.Hostname() != node.GetName() {
// Check if the expired node appears as expired in this client's peer list
for key, peer := range status.Peer {
if key == expiredNodeKey {
assert.True(ct, peer.Expired, "Node should be marked as expired for client %s", client.Hostname())
break
}
}
}
}
}, 3*time.Minute, 10*time.Second)
now := time.Now()
@@ -774,7 +794,6 @@ func TestExpireNode(t *testing.T) {
func TestNodeOnlineStatus(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -890,7 +909,6 @@ func TestNodeOnlineStatus(t *testing.T) {
// five times ensuring they are able to restablish connectivity.
func TestPingAllByIPManyUpDown(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@@ -944,8 +962,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Fatalf("failed to take down all nodes: %s", err)
}
time.Sleep(5 * time.Second)
for _, client := range allClients {
c := client
wg.Go(func() error {
@@ -958,10 +974,14 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Fatalf("failed to take down all nodes: %s", err)
}
time.Sleep(5 * time.Second)
// Wait for sync and successful pings after nodes come back up
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
success := pingAllHelper(t, allClients, allAddrs)
assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up")
}, 30*time.Second, 2*time.Second)
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
@@ -970,7 +990,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
func Test2118DeletingOnlineNodePanics(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@@ -1042,10 +1061,24 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(2 * time.Second)
// Ensure that the node has been deleted, this did not occur due to a panic.
var nodeListAfter []v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&nodeListAfter,
)
assert.NoError(ct, err)
assert.Len(ct, nodeListAfter, 1, "Node should be deleted from list")
}, 10*time.Second, 1*time.Second)
err = executeAndUnmarshal(
headscale,
[]string{

View File

@@ -191,7 +191,7 @@ func WithPostgres() Option {
}
}
// WithPolicy sets the policy mode for headscale
// WithPolicy sets the policy mode for headscale.
func WithPolicyMode(mode types.PolicyMode) Option {
return func(hsic *HeadscaleInContainer) {
hsic.policyMode = mode
@@ -279,7 +279,7 @@ func New(
return nil, err
}
hostname := fmt.Sprintf("hs-%s", hash)
hostname := "hs-" + hash
hsic := &HeadscaleInContainer{
hostname: hostname,
@@ -308,14 +308,14 @@ func New(
if hsic.postgres {
hsic.env["HEADSCALE_DATABASE_TYPE"] = "postgres"
hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = fmt.Sprintf("postgres-%s", hash)
hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = "postgres-" + hash
hsic.env["HEADSCALE_DATABASE_POSTGRES_USER"] = "headscale"
hsic.env["HEADSCALE_DATABASE_POSTGRES_PASS"] = "headscale"
hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale"
delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH")
pgRunOptions := &dockertest.RunOptions{
Name: fmt.Sprintf("postgres-%s", hash),
Name: "postgres-" + hash,
Repository: "postgres",
Tag: "latest",
Networks: networks,
@@ -328,7 +328,7 @@ func New(
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres")
pg, err := pool.RunWithOptions(pgRunOptions)
if err != nil {
return nil, fmt.Errorf("starting postgres container: %w", err)
@@ -373,7 +373,6 @@ func New(
Env: env,
}
if len(hsic.hostPortBindings) > 0 {
runOptions.PortBindings = map[docker.Port][]docker.PortBinding{}
for port, hostPorts := range hsic.hostPortBindings {
@@ -396,7 +395,7 @@ func New(
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale")
container, err := pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
runOptions,
@@ -566,7 +565,7 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error {
// extractTarToDirectory extracts a tar archive to a directory.
func extractTarToDirectory(tarData []byte, targetDir string) error {
if err := os.MkdirAll(targetDir, 0755); err != nil {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
}
@@ -624,6 +623,7 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
}
targetDir := path.Join(savePath, t.hostname+"-pprof")
return extractTarToDirectory(tarFile, targetDir)
}
@@ -634,6 +634,7 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
}
targetDir := path.Join(savePath, t.hostname+"-mapresponses")
return extractTarToDirectory(tarFile, targetDir)
}
@@ -672,17 +673,16 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
if err != nil {
return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err)
}
if strings.TrimSpace(schemaCheck) == "" {
return fmt.Errorf("database file exists but has no schema (empty database)")
return errors.New("database file exists but has no schema (empty database)")
}
// Show a preview of the schema (first 500 chars)
schemaPreview := schemaCheck
if len(schemaPreview) > 500 {
schemaPreview = schemaPreview[:500] + "..."
}
log.Printf("Database schema preview:\n%s", schemaPreview)
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
if err != nil {
@@ -727,7 +727,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
}
}
return fmt.Errorf("no regular file found in database tar archive")
return errors.New("no regular file found in database tar archive")
}
// Execute runs a command inside the Headscale container and returns the
@@ -756,13 +756,13 @@ func (t *HeadscaleInContainer) Execute(
// GetPort returns the docker container port as a string.
func (t *HeadscaleInContainer) GetPort() string {
return fmt.Sprintf("%d", t.port)
return strconv.Itoa(t.port)
}
// GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer
// instance.
func (t *HeadscaleInContainer) GetHealthEndpoint() string {
return fmt.Sprintf("%s/health", t.GetEndpoint())
return t.GetEndpoint() + "/health"
}
// GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer.
@@ -772,10 +772,10 @@ func (t *HeadscaleInContainer) GetEndpoint() string {
t.port)
if t.hasTLS() {
return fmt.Sprintf("https://%s", hostEndpoint)
return "https://" + hostEndpoint
}
return fmt.Sprintf("http://%s", hostEndpoint)
return "http://" + hostEndpoint
}
// GetCert returns the public certificate of the HeadscaleInContainer.
@@ -910,6 +910,7 @@ func (t *HeadscaleInContainer) ListNodes(
}
ret = append(ret, nodes...)
return nil
}
@@ -932,6 +933,7 @@ func (t *HeadscaleInContainer) ListNodes(
sort.Slice(ret, func(i, j int) bool {
return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1
})
return ret, nil
}
@@ -943,10 +945,10 @@ func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {
var userMap map[string][]*v1.Node
for _, node := range nodes {
if _, ok := userMap[node.User.Name]; !ok {
mak.Set(&userMap, node.User.Name, []*v1.Node{node})
if _, ok := userMap[node.GetUser().GetName()]; !ok {
mak.Set(&userMap, node.GetUser().GetName(), []*v1.Node{node})
} else {
userMap[node.User.Name] = append(userMap[node.User.Name], node)
userMap[node.GetUser().GetName()] = append(userMap[node.GetUser().GetName()], node)
}
}
@@ -999,7 +1001,7 @@ func (t *HeadscaleInContainer) MapUsers() (map[string]*v1.User, error) {
var userMap map[string]*v1.User
for _, user := range users {
mak.Set(&userMap, user.Name, user)
mak.Set(&userMap, user.GetName(), user)
}
return userMap, nil
@@ -1095,7 +1097,7 @@ func (h *HeadscaleInContainer) PID() (int, error) {
case 1:
return pids[0], nil
default:
return 0, fmt.Errorf("multiple headscale processes running")
return 0, errors.New("multiple headscale processes running")
}
}
@@ -1121,7 +1123,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
"headscale", "nodes", "approve-routes",
"--output", "json",
"--identifier", strconv.FormatUint(id, 10),
fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")),
"--routes=" + strings.Join(util.PrefixesToString(routes), ","),
}
result, _, err := dockertestutil.ExecuteCommand(

View File

@@ -4,13 +4,12 @@ import (
"encoding/json"
"fmt"
"net/netip"
"slices"
"sort"
"strings"
"testing"
"time"
"slices"
cmpdiff "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@@ -37,7 +36,6 @@ var allPorts = filter.PortRange{First: 0, Last: 0xffff}
// routes.
func TestEnablingRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 3,
@@ -182,11 +180,12 @@ func TestEnablingRoutes(t *testing.T) {
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" {
switch peerStatus.ID {
case "1":
requirePeerSubnetRoutes(t, peerStatus, nil)
} else if peerStatus.ID == "2" {
case "2":
requirePeerSubnetRoutes(t, peerStatus, nil)
} else {
default:
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")})
}
}
@@ -195,7 +194,6 @@ func TestEnablingRoutes(t *testing.T) {
func TestHASubnetRouterFailover(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 3,
@@ -779,7 +777,6 @@ func TestHASubnetRouterFailover(t *testing.T) {
// https://github.com/juanfont/headscale/issues/1604
func TestSubnetRouteACL(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "user4"
@@ -1003,7 +1000,6 @@ func TestSubnetRouteACL(t *testing.T) {
// set during login instead of set.
func TestEnablingExitRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "user2"
@@ -1097,7 +1093,6 @@ func TestEnablingExitRoutes(t *testing.T) {
// subnet router is working as expected.
func TestSubnetRouterMultiNetwork(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@@ -1177,7 +1172,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
// Enable route
_, err = headscale.ApproveRoutes(
nodes[0].Id,
nodes[0].GetId(),
[]netip.Prefix{*pref},
)
require.NoError(t, err)
@@ -1224,7 +1219,6 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@@ -1300,7 +1294,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
}
// Enable route
_, err = headscale.ApproveRoutes(nodes[0].Id, []netip.Prefix{tsaddr.AllIPv4()})
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
require.NoError(t, err)
time.Sleep(5 * time.Second)
@@ -1719,7 +1713,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
assertNoErr(t, err)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.Key)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
assertNoErr(t, err)
}
// extra creation end.
@@ -2065,7 +2059,6 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
// that are explicitly allowed in the ACL.
func TestSubnetRouteACLFiltering(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Use router and node users for better clarity
routerUser := "router"
@@ -2090,7 +2083,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t)
// Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24)
aclPolicyStr := fmt.Sprintf(`{
aclPolicyStr := `{
"hosts": {
"router": "100.64.0.1/32",
"node": "100.64.0.2/32"
@@ -2115,7 +2108,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
]
}
]
}`)
}`
route, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)

View File

@@ -123,7 +123,7 @@ type ScenarioSpec struct {
// NodesPerUser is how many nodes should be attached to each user.
NodesPerUser int
// Networks, if set, is the seperate Docker networks that should be
// Networks, if set, is the separate Docker networks that should be
// created and a list of the users that should be placed in those networks.
// If not set, a single network will be created and all users+nodes will be
// added there.
@@ -1077,7 +1077,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
hostname := "hs-oidcmock-" + hash
usersJSON, err := json.Marshal(users)
if err != nil {
@@ -1093,16 +1093,15 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
},
Networks: s.Networks(),
Env: []string{
fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
"MOCKOIDC_ADDR=" + hostname,
fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
"MOCKOIDC_ACCESS_TTL=" + accessTTL.String(),
"MOCKOIDC_USERS=" + string(usersJSON),
},
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
@@ -1117,7 +1116,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc")
if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
mockOidcOptions,
@@ -1184,7 +1183,7 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-webservice-%s", hash)
hostname := "hs-webservice-" + hash
network, ok := s.networks[s.prefixedNetworkName(networkName)]
if !ok {

View File

@@ -28,7 +28,6 @@ func IntegrationSkip(t *testing.T) {
// nolint:tparallel
func TestHeadscale(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
var err error
@@ -75,7 +74,6 @@ func TestHeadscale(t *testing.T) {
// nolint:tparallel
func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
var err error

View File

@@ -22,35 +22,6 @@ func isSSHNoAccessStdError(stderr string) bool {
strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node")
}
var retry = func(times int, sleepInterval time.Duration,
doWork func() (string, string, error),
) (string, string, error) {
var result string
var stderr string
var err error
for range times {
tempResult, tempStderr, err := doWork()
result += tempResult
stderr += tempStderr
if err == nil {
return result, stderr, nil
}
// If we get a permission denied error, we can fail immediately
// since that is something we won-t recover from by retrying.
if err != nil && isSSHNoAccessStdError(stderr) {
return result, stderr, err
}
time.Sleep(sleepInterval)
}
return result, stderr, err
}
func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario {
t.Helper()
@@ -92,7 +63,6 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce
func TestSSHOneUserToAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@@ -160,7 +130,6 @@ func TestSSHOneUserToAll(t *testing.T) {
func TestSSHMultipleUsersAllToAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@@ -216,7 +185,6 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
func TestSSHNoSSHConfigured(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@@ -261,7 +229,6 @@ func TestSSHNoSSHConfigured(t *testing.T) {
func TestSSHIsBlockedInACL(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@@ -313,7 +280,6 @@ func TestSSHIsBlockedInACL(t *testing.T) {
func TestSSHUserOnlyIsolation(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@@ -404,6 +370,14 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
}
func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
return doSSHWithRetry(t, client, peer, true)
}
func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
return doSSHWithRetry(t, client, peer, false)
}
func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) {
t.Helper()
peerFQDN, _ := peer.FQDN()
@@ -417,9 +391,29 @@ func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string,
log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname())
log.Printf("Command: %s", strings.Join(command, " "))
return retry(10, 1*time.Second, func() (string, string, error) {
return client.Execute(command)
})
var result, stderr string
var err error
if retry {
// Use assert.EventuallyWithT to retry SSH connections for success cases
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, stderr, err = client.Execute(command)
// If we get a permission denied error, we can fail immediately
// since that is something we won't recover from by retrying.
if err != nil && isSSHNoAccessStdError(stderr) {
return // Don't retry permission denied errors
}
// For all other errors, assert no error to trigger retry
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
} else {
// For failure cases, just execute once
result, stderr, err = client.Execute(command)
}
return result, stderr, err
}
func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClient) {
@@ -434,7 +428,7 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien
func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) {
t.Helper()
result, stderr, err := doSSH(t, client, peer)
result, stderr, err := doSSHWithoutRetry(t, client, peer)
assert.Empty(t, result)
@@ -444,7 +438,7 @@ func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer Tailsc
func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient) {
t.Helper()
result, stderr, _ := doSSH(t, client, peer)
result, stderr, _ := doSSHWithoutRetry(t, client, peer)
assert.Empty(t, result)

View File

@@ -251,7 +251,6 @@ func New(
Env: []string{},
}
if tsic.withWebsocketDERP {
if version != VersionHead {
return tsic, errInvalidClientConfig
@@ -463,7 +462,7 @@ func (t *TailscaleInContainer) buildLoginCommand(
if len(t.withTags) > 0 {
command = append(command,
fmt.Sprintf(`--advertise-tags=%s`, strings.Join(t.withTags, ",")),
"--advertise-tags="+strings.Join(t.withTags, ","),
)
}
@@ -685,7 +684,7 @@ func (t *TailscaleInContainer) MustID() types.NodeID {
// Panics if version is lower then minimum.
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version))
panic("tsic.Netmap() called with unsupported version: " + t.version)
}
command := []string{
@@ -1026,7 +1025,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err
"tailscale", "ping",
fmt.Sprintf("--timeout=%s", args.timeout),
fmt.Sprintf("--c=%d", args.count),
fmt.Sprintf("--until-direct=%s", strconv.FormatBool(args.direct)),
"--until-direct=" + strconv.FormatBool(args.direct),
}
command = append(command, hostnameOrIP)
@@ -1131,11 +1130,11 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
command := []string{
"curl",
"--silent",
"--connect-timeout", fmt.Sprintf("%d", int(args.connectionTimeout.Seconds())),
"--max-time", fmt.Sprintf("%d", int(args.maxTime.Seconds())),
"--retry", fmt.Sprintf("%d", args.retry),
"--retry-delay", fmt.Sprintf("%d", int(args.retryDelay.Seconds())),
"--retry-max-time", fmt.Sprintf("%d", int(args.retryMaxTime.Seconds())),
"--connect-timeout", strconv.Itoa(int(args.connectionTimeout.Seconds())),
"--max-time", strconv.Itoa(int(args.maxTime.Seconds())),
"--retry", strconv.Itoa(args.retry),
"--retry-delay", strconv.Itoa(int(args.retryDelay.Seconds())),
"--retry-max-time", strconv.Itoa(int(args.retryMaxTime.Seconds())),
url,
}
@@ -1230,7 +1229,7 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) {
}
if out.Len() == 0 {
return nil, fmt.Errorf("file is empty")
return nil, errors.New("file is empty")
}
return out.Bytes(), nil
@@ -1259,5 +1258,6 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
if err = json.Unmarshal(currentProfile, &p); err != nil {
return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err)
}
return &p.Persist.PrivateNodeKey, nil
}

View File

@@ -3,7 +3,6 @@ package integration
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/netip"
@@ -267,7 +266,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) {
// This isn't really relevant for Self as it won't be in its own socket/wireguard.
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
for _, peer := range status.Peer {
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
@@ -311,7 +310,7 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper()
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
_, err := backoff.Retry(t.Context(), func() (struct{}, error) {
stdout, stderr, err := c.Execute(command)
if err != nil {
return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
@@ -492,6 +491,7 @@ func groupApprover(name string) policyv2.AutoApprover {
func tagApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Tag(name))
}
//
// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus
// // if there is a peer with the given hostname. If no peer is found, nil is returned.