bunch of qol (#2748)
Some checks failed
Build / build-nix (push) Has been cancelled
Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Has been cancelled
Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Has been cancelled
Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Has been cancelled
Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Has been cancelled
Check Generated Files / check-generated (push) Has been cancelled
Tests / test (push) Has been cancelled
update-flake-lock / lockfile (push) Has been cancelled
GitHub Actions Version Updater / build (push) Has been cancelled
Close inactive issues / close-issues (push) Has been cancelled

Kristoffer Dalby authored on 2025-08-27 17:09:13 +02:00; committed by GitHub
parent 1a7a2f4196
commit 8e25f7f9dd
11 changed files with 307 additions and 95 deletions

View File

@@ -4,7 +4,7 @@
 # This Dockerfile is more or less lifted from tailscale/tailscale
 # to ensure a similar build process when testing the HEAD of tailscale.
-FROM golang:1.24-alpine AS build-env
+FROM golang:1.25-alpine AS build-env
 WORKDIR /go/src

View File

@@ -68,7 +68,7 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
 			continue // Skip potentially dangerous paths
 		}
-		targetPath := filepath.Join(targetDir, filepath.Base(cleanName))
+		targetPath := filepath.Join(targetDir, cleanName)
 		switch header.Typeflag {
 		case tar.TypeDir:
@@ -77,6 +77,11 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
 				return fmt.Errorf("failed to create directory %s: %w", targetPath, err)
 			}
 		case tar.TypeReg:
+			// Ensure parent directories exist
+			if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
+				return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err)
+			}
 			// Create file
 			outFile, err := os.Create(targetPath)
 			if err != nil {

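Since targetPath now keeps the cleaned relative path from the archive instead of flattening it with filepath.Base, an entry can point at a directory that does not exist yet, which is why MkdirAll runs before os.Create. A minimal, standalone sketch of that pattern (writeNested and the example paths are hypothetical, not part of this change):

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// writeNested creates the parent directories of target before writing the
// file, mirroring the MkdirAll-before-Create order used in the extraction
// code above.
func writeNested(target string, data []byte) error {
	if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		return fmt.Errorf("creating parent directories for %s: %w", target, err)
	}

	return os.WriteFile(target, data, 0o644)
}

func main() {
	// A nested path such as <tmp>/demo/1/resp.json needs its parents created first.
	target := filepath.Join(os.TempDir(), "demo", "1", "resp.json")
	if err := writeNested(target, []byte("{}")); err != nil {
		panic(err)
	}
	fmt.Println("wrote", target)
}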
View File

@@ -121,6 +121,29 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 		w.Write([]byte(h.state.PolicyDebugString()))
 	}))
+	debug.Handle("mapresponses", "Map responses for all nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		res, err := h.mapBatcher.DebugMapResponses()
+		if err != nil {
+			httpError(w, err)
+			return
+		}
+		if res == nil {
+			w.WriteHeader(http.StatusOK)
+			w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
+			return
+		}
+		resJSON, err := json.MarshalIndent(res, "", " ")
+		if err != nil {
+			httpError(w, err)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		w.Write(resJSON)
+	}))
 	err := statsviz.Register(debugMux)
 	if err == nil {
 		debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)")

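The new handler exposes the dumped map responses as JSON keyed by node ID, and answers with a plain-text notice when HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH is not set. A rough client-side sketch, assuming the debug server is reachable at http://localhost:9090 (the address the integration harness later in this commit curls from inside the container):

package main

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/tailcfg"
)

func fetchMapResponses(baseURL string) (map[types.NodeID][]tailcfg.MapResponse, error) {
	resp, err := http.Get(baseURL + "/debug/mapresponses")
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// When the dump path is not configured, the endpoint returns plain text
	// instead of JSON; surface that as an error here.
	var res map[types.NodeID][]tailcfg.MapResponse
	if err := json.Unmarshal(body, &res); err != nil {
		return nil, fmt.Errorf("decoding mapresponses (%q): %w", string(body), err)
	}

	return res, nil
}

func main() {
	res, err := fetchMapResponses("http://localhost:9090") // assumed debug address
	if err != nil {
		panic(err)
	}
	fmt.Printf("map responses recorded for %d nodes\n", len(res))
}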
View File

@@ -24,6 +24,7 @@ type Batcher interface {
 	ConnectedMap() *xsync.Map[types.NodeID, bool]
 	AddWork(c change.ChangeSet)
 	MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
+	DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
 }

 func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {

View File

@@ -489,3 +489,7 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
 	nc.updateCount.Add(1)
 	return nil
 }
+
+func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
+	return b.mapper.debugMapResponses()
+}

View File

@@ -237,7 +237,6 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
 // WithPeersRemoved adds removed peer IDs
 func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
 	var tailscaleIDs []tailcfg.NodeID
 	for _, id := range removedIDs {
 		tailscaleIDs = append(tailscaleIDs, id.NodeID())
@@ -247,12 +246,16 @@ func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapRe
 }

 // Build finalizes the response and returns marshaled bytes
-func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
+func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
 	if len(b.errs) > 0 {
 		return nil, multierr.New(b.errs...)
 	}
 	if debugDumpMapResponsePath != "" {
-		writeDebugMapResponse(b.resp, b.nodeID)
+		node, err := b.mapper.state.GetNodeByID(b.nodeID)
+		if err != nil {
+			return nil, err
+		}
+
+		writeDebugMapResponse(b.resp, node)
 	}

 	return b.resp, nil

View File

@@ -18,17 +18,17 @@ func TestMapResponseBuilder_Basic(t *testing.T) {
 			Enabled: true,
 		},
 	}
 	mockState := &state.State{}
 	m := &mapper{
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	builder := m.NewMapResponseBuilder(nodeID)
 	// Test basic builder creation
 	assert.NotNil(t, builder)
 	assert.Equal(t, nodeID, builder.nodeID)
@@ -45,13 +45,13 @@ func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	capVer := tailcfg.CapabilityVersion(42)
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithCapabilityVersion(capVer)
 	assert.Equal(t, capVer, builder.capVer)
 	assert.False(t, builder.hasErrors())
 }
@@ -62,18 +62,18 @@ func TestMapResponseBuilder_WithDomain(t *testing.T) {
 		ServerURL: "https://test.example.com",
 		BaseDomain: domain,
 	}
 	mockState := &state.State{}
 	m := &mapper{
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithDomain()
 	assert.Equal(t, domain, builder.resp.Domain)
 	assert.False(t, builder.hasErrors())
 }
@@ -85,12 +85,12 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithCollectServicesDisabled()
 	value, isSet := builder.resp.CollectServices.Get()
 	assert.True(t, isSet)
 	assert.False(t, value)
@@ -99,22 +99,22 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
 func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
 	tests := []struct {
 		name string
 		logTailEnabled bool
 		expected bool
 	}{
 		{
 			name: "LogTail enabled",
 			logTailEnabled: true,
 			expected: false, // DisableLogTail should be false when LogTail is enabled
 		},
 		{
 			name: "LogTail disabled",
 			logTailEnabled: false,
 			expected: true, // DisableLogTail should be true when LogTail is disabled
 		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			cfg := &types.Config{
@@ -127,12 +127,12 @@ func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
 				cfg: cfg,
 				state: mockState,
 			}
 			nodeID := types.NodeID(1)
 			builder := m.NewMapResponseBuilder(nodeID).
 				WithDebugConfig()
 			require.NotNil(t, builder.resp.Debug)
 			assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
 			assert.False(t, builder.hasErrors())
@@ -147,22 +147,22 @@ func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	changes := []*tailcfg.PeerChange{
 		{
 			NodeID: 123,
 			DERPRegion: 1,
 		},
 		{
 			NodeID: 456,
 			DERPRegion: 2,
 		},
 	}
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeerChangedPatch(changes)
 	assert.Equal(t, changes, builder.resp.PeersChangedPatch)
 	assert.False(t, builder.hasErrors())
 }
@@ -174,14 +174,14 @@ func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	removedID1 := types.NodeID(123)
 	removedID2 := types.NodeID(456)
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeersRemoved(removedID1, removedID2)
 	expected := []tailcfg.NodeID{
 		removedID1.NodeID(),
 		removedID2.NodeID(),
@@ -197,25 +197,25 @@ func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	// Simulate an error in the builder
 	builder := m.NewMapResponseBuilder(nodeID)
 	builder.addError(assert.AnError)
 	// All subsequent calls should continue to work and accumulate errors
 	result := builder.
 		WithDomain().
 		WithCollectServicesDisabled().
 		WithDebugConfig()
 	assert.True(t, result.hasErrors())
 	assert.Len(t, result.errs, 1)
 	assert.Equal(t, assert.AnError, result.errs[0])
 	// Build should return the error
-	data, err := result.Build("none")
+	data, err := result.Build()
 	assert.Nil(t, data)
 	assert.Error(t, err)
 }
@@ -229,22 +229,22 @@ func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
 			Enabled: false,
 		},
 	}
 	mockState := &state.State{}
 	m := &mapper{
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	capVer := tailcfg.CapabilityVersion(99)
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithCapabilityVersion(capVer).
 		WithDomain().
 		WithCollectServicesDisabled().
 		WithDebugConfig()
 	// Verify all fields are set correctly
 	assert.Equal(t, capVer, builder.capVer)
 	assert.Equal(t, domain, builder.resp.Domain)
@@ -263,16 +263,16 @@ func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	removedID1 := types.NodeID(100)
 	removedID2 := types.NodeID(200)
 	// Test calling WithPeersRemoved multiple times
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeersRemoved(removedID1).
 		WithPeersRemoved(removedID2)
 	// Second call should overwrite the first
 	expected := []tailcfg.NodeID{removedID2.NodeID()}
 	assert.Equal(t, expected, builder.resp.PeersRemoved)
@@ -286,12 +286,12 @@ func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeerChangedPatch([]*tailcfg.PeerChange{})
 	assert.Empty(t, builder.resp.PeersChangedPatch)
 	assert.False(t, builder.hasErrors())
 }
@@ -303,12 +303,12 @@ func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	builder := m.NewMapResponseBuilder(nodeID).
 		WithPeerChangedPatch(nil)
 	assert.Nil(t, builder.resp.PeersChangedPatch)
 	assert.False(t, builder.hasErrors())
 }
@@ -320,28 +320,28 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
 		cfg: cfg,
 		state: mockState,
 	}
 	nodeID := types.NodeID(1)
 	// Create a builder and add multiple errors
 	builder := m.NewMapResponseBuilder(nodeID)
 	builder.addError(assert.AnError)
 	builder.addError(assert.AnError)
 	builder.addError(nil) // This should be ignored
 	// All subsequent calls should continue to work
 	result := builder.
 		WithDomain().
 		WithCollectServicesDisabled()
 	assert.True(t, result.hasErrors())
 	assert.Len(t, result.errs, 2) // nil error should be ignored
 	// Build should return a multierr
-	data, err := result.Build("none")
+	data, err := result.Build()
 	assert.Nil(t, data)
 	assert.Error(t, err)
 	// The error should contain information about multiple errors
 	assert.Contains(t, err.Error(), "multiple errors")
 }

View File

@@ -9,6 +9,7 @@ import (
"os" "os"
"path" "path"
"slices" "slices"
"strconv"
"strings" "strings"
"time" "time"
@@ -154,7 +155,7 @@ func (m *mapper) fullMapResponse(
 		WithUserProfiles(peers).
 		WithPacketFilters().
 		WithPeers(peers).
-		Build(messages...)
+		Build()
 }

 func (m *mapper) derpMapResponse(
@@ -207,36 +208,15 @@ func (m *mapper) peerRemovedResponse(
 func writeDebugMapResponse(
 	resp *tailcfg.MapResponse,
-	nodeID types.NodeID,
-	messages ...string,
+	node *types.Node,
 ) {
-	data := map[string]any{
-		"Messages": messages,
-		"MapResponse": resp,
-	}
-
-	responseType := "keepalive"
-
-	switch {
-	case len(resp.Peers) > 0:
-		responseType = "full"
-	case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
-		responseType = "self"
-	case len(resp.PeersChanged) > 0:
-		responseType = "changed"
-	case len(resp.PeersChangedPatch) > 0:
-		responseType = "patch"
-	case len(resp.PeersRemoved) > 0:
-		responseType = "removed"
-	}
-
-	body, err := json.MarshalIndent(data, "", " ")
+	body, err := json.MarshalIndent(resp, "", " ")
 	if err != nil {
 		panic(err)
 	}

 	perms := fs.FileMode(debugMapResponsePerm)
-	mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
+	mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID))
 	err = os.MkdirAll(mPath, perms)
 	if err != nil {
 		panic(err)
@@ -246,7 +226,7 @@ func writeDebugMapResponse(
 	mapResponsePath := path.Join(
 		mPath,
-		fmt.Sprintf("%s-%s.json", now, responseType),
+		fmt.Sprintf("%s.json", now),
 	)

 	log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@@ -279,3 +259,62 @@ func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
 // netip.Prefixes that are allowed for that node. It is used to filter routes
 // from the primary route manager to the node.
 type routeFilterFunc func(id types.NodeID) []netip.Prefix
+
+func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
+	if debugDumpMapResponsePath == "" {
+		return nil, nil
+	}
+
+	nodes, err := os.ReadDir(debugDumpMapResponsePath)
+	if err != nil {
+		return nil, err
+	}
+
+	result := make(map[types.NodeID][]tailcfg.MapResponse)
+	for _, node := range nodes {
+		if !node.IsDir() {
+			continue
+		}
+
+		nodeIDu, err := strconv.ParseUint(node.Name(), 10, 64)
+		if err != nil {
+			log.Error().Err(err).Msgf("Parsing node ID from dir %s", node.Name())
+			continue
+		}
+
+		nodeID := types.NodeID(nodeIDu)
+
+		files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name()))
+		if err != nil {
+			log.Error().Err(err).Msgf("Reading dir %s", node.Name())
+			continue
+		}
+
+		slices.SortStableFunc(files, func(a, b fs.DirEntry) int {
+			return strings.Compare(a.Name(), b.Name())
+		})
+
+		for _, file := range files {
+			if file.IsDir() || !strings.HasSuffix(file.Name(), ".json") {
+				continue
+			}
+
+			body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name()))
+			if err != nil {
+				log.Error().Err(err).Msgf("Reading file %s", file.Name())
+				continue
+			}
+
+			var resp tailcfg.MapResponse
+			err = json.Unmarshal(body, &resp)
+			if err != nil {
+				log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())
+				continue
+			}
+
+			result[nodeID] = append(result[nodeID], resp)
+		}
+	}
+
+	return result, nil
+}

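For reference, debugMapResponses above walks back the tree that writeDebugMapResponse produces: one directory per numeric node ID under the dump root, each holding .json files that are replayed in lexical filename order. A tiny sketch of that layout (the root path and file name here are illustrative only; the real root comes from HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH):

package main

import (
	"fmt"
	"path"
)

func main() {
	root := "/var/lib/headscale/mapresponses" // hypothetical dump root
	nodeID := uint64(1)

	// One directory per node ID, one timestamped JSON file per map response.
	dir := path.Join(root, fmt.Sprintf("%d", nodeID))
	fmt.Println(path.Join(dir, "2025-08-27T17-09-13.json"))
}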
View File

@@ -5,7 +5,9 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"tailscale.com/tailcfg"
) )
type ControlServer interface { type ControlServer interface {
@@ -29,4 +31,5 @@ type ControlServer interface {
 	GetCert() []byte
 	GetHostname() string
 	SetPolicy(*policyv2.Policy) error
+	GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error)
 }

View File

@@ -21,6 +21,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@@ -55,6 +56,17 @@ func TestPingAllByIP(t *testing.T) {
 	err = scenario.WaitForTailscaleSync()
 	assertNoErrSync(t, err)

+	hs, err := scenario.Headscale()
+	require.NoError(t, err)
+
+	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+		all, err := hs.GetAllMapReponses()
+		assert.NoError(ct, err)
+
+		onlineMap := buildExpectedOnlineMap(all)
+		assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap)
+	}, 30*time.Second, 2*time.Second)
+
 	// assertClientsState(t, allClients)

 	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
@@ -940,6 +952,9 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
 	)
 	assertNoErrHeadscaleEnv(t, err)

+	hs, err := scenario.Headscale()
+	require.NoError(t, err)
+
 	allClients, err := scenario.ListTailscaleClients()
 	assertNoErrListClients(t, err)
@@ -961,7 +976,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
 	wg, _ := errgroup.WithContext(context.Background())

 	for run := range 3 {
-		t.Logf("Starting DownUpPing run %d", run+1)
+		t.Logf("Starting DownUpPing run %d at %s", run+1, time.Now().Format("2006-01-02T15-04-05.999999999"))

 		for _, client := range allClients {
 			c := client
@@ -974,6 +989,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
 		if err := wg.Wait(); err != nil {
 			t.Fatalf("failed to take down all nodes: %s", err)
 		}
+		t.Logf("All nodes taken down at %s", time.Now().Format("2006-01-02T15-04-05.999999999"))

 		for _, client := range allClients {
 			c := client
@@ -984,13 +1000,24 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
 		}

 		if err := wg.Wait(); err != nil {
-			t.Fatalf("failed to take down all nodes: %s", err)
+			t.Fatalf("failed to bring up all nodes: %s", err)
 		}
+		t.Logf("All nodes brought up at %s", time.Now().Format("2006-01-02T15-04-05.999999999"))

 		// Wait for sync and successful pings after nodes come back up
 		err = scenario.WaitForTailscaleSync()
 		assert.NoError(t, err)
+		t.Logf("All nodes synced up %s", time.Now().Format("2006-01-02T15-04-05.999999999"))
+
+		assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+			all, err := hs.GetAllMapReponses()
+			assert.NoError(ct, err)
+
+			onlineMap := buildExpectedOnlineMap(all)
+			assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap)
+		}, 60*time.Second, 2*time.Second)

 		success := pingAllHelper(t, allClients, allAddrs)
 		assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps))
 	}
@@ -1103,3 +1130,52 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
 	assert.True(t, nodeListAfter[0].GetOnline())
 	assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
 }
+
+func buildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool {
+	res := make(map[types.NodeID]map[types.NodeID]bool)
+	for nid, mrs := range all {
+		res[nid] = make(map[types.NodeID]bool)
+		for _, mr := range mrs {
+			for _, peer := range mr.Peers {
+				if peer.Online != nil {
+					res[nid][types.NodeID(peer.ID)] = *peer.Online
+				}
+			}
+
+			for _, peer := range mr.PeersChanged {
+				if peer.Online != nil {
+					res[nid][types.NodeID(peer.ID)] = *peer.Online
+				}
+			}
+
+			for _, peer := range mr.PeersChangedPatch {
+				if peer.Online != nil {
+					res[nid][types.NodeID(peer.NodeID)] = *peer.Online
+				}
+			}
+		}
+	}
+
+	return res
+}
+
+func assertExpectedOnlineMapAllOnline(t *assert.CollectT, expectedPeerCount int, onlineMap map[types.NodeID]map[types.NodeID]bool) {
+	for nid, peers := range onlineMap {
+		onlineCount := 0
+		for _, online := range peers {
+			if online {
+				onlineCount++
+			}
+		}
+
+		assert.Equalf(t, expectedPeerCount, len(peers), "node:%d had an unexpected number of peers in online map", nid)
+		if expectedPeerCount != onlineCount {
+			var sb strings.Builder
+			sb.WriteString(fmt.Sprintf("Not all of node:%d peers where online:\n", nid))
+			for pid, online := range peers {
+				sb.WriteString(fmt.Sprintf("\tPeer node:%d online: %t\n", pid, online))
+			}
+			sb.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n")
+			sb.WriteString("expected all peers to be online.")
+			t.Errorf("%s", sb.String())
+		}
+	}
+}

View File

@@ -622,6 +622,27 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
 	}

 	tarReader := tar.NewReader(bytes.NewReader(tarData))
+
+	// Find the top-level directory to strip
+	var topLevelDir string
+	firstPass := tar.NewReader(bytes.NewReader(tarData))
+	for {
+		header, err := firstPass.Next()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return fmt.Errorf("failed to read tar header: %w", err)
+		}
+
+		if header.Typeflag == tar.TypeDir && topLevelDir == "" {
+			topLevelDir = strings.TrimSuffix(header.Name, "/")
+			break
+		}
+	}
+
+	// Second pass: extract files, stripping the top-level directory
+	tarReader = tar.NewReader(bytes.NewReader(tarData))
+
 	for {
 		header, err := tarReader.Next()
 		if err == io.EOF {
@@ -637,7 +658,20 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
 			continue // Skip potentially dangerous paths
 		}

-		targetPath := filepath.Join(targetDir, filepath.Base(cleanName))
+		// Strip the top-level directory
+		if topLevelDir != "" && strings.HasPrefix(cleanName, topLevelDir+"/") {
+			cleanName = strings.TrimPrefix(cleanName, topLevelDir+"/")
+		} else if cleanName == topLevelDir {
+			// Skip the top-level directory itself
+			continue
+		}
+
+		// Skip empty paths after stripping
+		if cleanName == "" {
+			continue
+		}
+
+		targetPath := filepath.Join(targetDir, cleanName)

 		switch header.Typeflag {
 		case tar.TypeDir:
@@ -646,6 +680,11 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
return fmt.Errorf("failed to create directory %s: %w", targetPath, err) return fmt.Errorf("failed to create directory %s: %w", targetPath, err)
} }
case tar.TypeReg: case tar.TypeReg:
// Ensure parent directories exist
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err)
}
// Create file // Create file
outFile, err := os.Create(targetPath) outFile, err := os.Create(targetPath)
if err != nil { if err != nil {
@@ -674,7 +713,7 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
 		return err
 	}

-	targetDir := path.Join(savePath, t.hostname+"-pprof")
+	targetDir := path.Join(savePath, "pprof")

 	return extractTarToDirectory(tarFile, targetDir)
 }
@@ -685,7 +724,7 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
 		return err
 	}

-	targetDir := path.Join(savePath, t.hostname+"-mapresponses")
+	targetDir := path.Join(savePath, "mapresponses")

 	return extractTarToDirectory(tarFile, targetDir)
 }
@@ -1243,3 +1282,22 @@ func (t *HeadscaleInContainer) SendInterrupt() error {
 	return nil
 }
+
+func (t *HeadscaleInContainer) GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
+	// Execute curl inside the container to access the debug endpoint locally
+	command := []string{
+		"curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/mapresponses",
+	}
+
+	result, err := t.Execute(command)
+	if err != nil {
+		return nil, fmt.Errorf("fetching mapresponses from debug endpoint: %w", err)
+	}
+
+	var res map[types.NodeID][]tailcfg.MapResponse
+	if err := json.Unmarshal([]byte(result), &res); err != nil {
+		return nil, fmt.Errorf("decoding routes response: %w", err)
+	}
+
+	return res, nil
+}