diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 5b2adddc..2c1bf94e 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -52,10 +52,9 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB tick: time.NewTicker(batchTime), // The size of this channel is arbitrary chosen, the sizing should be revisited. - workCh: make(chan work, workers*200), - nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), - connected: xsync.NewMap[types.NodeID, *time.Time](), - pendingChanges: xsync.NewMap[types.NodeID, []change.Change](), + workCh: make(chan work, workers*200), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + connected: xsync.NewMap[types.NodeID, *time.Time](), } } diff --git a/hscontrol/mapper/batcher_bench_test.go b/hscontrol/mapper/batcher_bench_test.go index 65d1c4ba..2d30110b 100644 --- a/hscontrol/mapper/batcher_bench_test.go +++ b/hscontrol/mapper/batcher_bench_test.go @@ -150,13 +150,12 @@ func BenchmarkUpdateSentPeers(b *testing.B) { // helper, it doesn't register cleanup and suppresses logging. func benchBatcher(nodeCount, bufferSize int) (*LockFreeBatcher, map[types.NodeID]chan *tailcfg.MapResponse) { b := &LockFreeBatcher{ - tick: time.NewTicker(1 * time.Hour), // never fires during bench - workers: 4, - workCh: make(chan work, 4*200), - nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), - connected: xsync.NewMap[types.NodeID, *time.Time](), - pendingChanges: xsync.NewMap[types.NodeID, []change.Change](), - done: make(chan struct{}), + tick: time.NewTicker(1 * time.Hour), // never fires during bench + workers: 4, + workCh: make(chan work, 4*200), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + connected: xsync.NewMap[types.NodeID, *time.Time](), + done: make(chan struct{}), } channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount) @@ -204,8 +203,8 @@ func BenchmarkAddToBatch_Broadcast(b *testing.B) { for range b.N { batcher.addToBatch(ch) // Clear pending to avoid unbounded growth - batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool { - batcher.pendingChanges.Delete(id) + batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { + nc.drainPending() return true }) } @@ -242,8 +241,8 @@ func BenchmarkAddToBatch_Targeted(b *testing.B) { batcher.addToBatch(ch) // Clear pending periodically to avoid growth if i%100 == 99 { - batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool { - batcher.pendingChanges.Delete(id) + batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { + nc.drainPending() return true }) } @@ -298,7 +297,9 @@ func BenchmarkProcessBatchedChanges(b *testing.B) { b.StopTimer() // Seed pending changes for i := 1; i <= nodeCount; i++ { - batcher.pendingChanges.Store(types.NodeID(i), []change.Change{change.DERPMap()}) //nolint:gosec // benchmark + if nc, ok := batcher.nodes.Load(types.NodeID(i)); ok { //nolint:gosec // benchmark + nc.appendPending(change.DERPMap()) + } } b.StartTimer() @@ -411,8 +412,8 @@ func BenchmarkConcurrentAddToBatch(b *testing.B) { case <-batcher.done: return default: - batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool { - batcher.pendingChanges.Delete(id) + batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { + nc.drainPending() return true }) time.Sleep(time.Millisecond) //nolint:forbidigo // benchmark drain loop @@ -646,7 +647,7 @@ func BenchmarkAddNode(b *testing.B) { // Connect all nodes (measuring AddNode cost) for i := range allNodes { node := &allNodes[i] - _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) } b.StopTimer() @@ -707,7 +708,7 @@ func BenchmarkFullPipeline(b *testing.B) { for i := range allNodes { node := &allNodes[i] - err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { b.Fatalf("failed to add node %d: %v", i, err) } @@ -762,7 +763,7 @@ func BenchmarkMapResponseFromChange(b *testing.B) { for i := range allNodes { node := &allNodes[i] - err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { b.Fatalf("failed to add node %d: %v", i, err) } diff --git a/hscontrol/mapper/batcher_concurrency_test.go b/hscontrol/mapper/batcher_concurrency_test.go new file mode 100644 index 00000000..1a99dc93 --- /dev/null +++ b/hscontrol/mapper/batcher_concurrency_test.go @@ -0,0 +1,1731 @@ +package mapper + +// Concurrency, lifecycle, and scale tests for the batcher. +// Tests in this file exercise: +// - addToBatch and processBatchedChanges under concurrent access +// - cleanupOfflineNodes correctness +// - Batcher lifecycle (Close, shutdown, double-close) +// - 1000-node scale testing of batching and channel mechanics +// +// Most tests use the lightweight batcher helper which creates a batcher with +// pre-populated nodes but NO database, enabling fast 1000-node tests. + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/puzpuzpuz/xsync/v4" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +// ============================================================================ +// Lightweight Batcher Helper (no database needed) +// ============================================================================ + +// lightweightBatcher provides a batcher with pre-populated nodes for testing +// the batching, channel, and concurrency mechanics without database overhead. +type lightweightBatcher struct { + b *LockFreeBatcher + channels map[types.NodeID]chan *tailcfg.MapResponse +} + +// setupLightweightBatcher creates a batcher with nodeCount pre-populated nodes. +// Each node gets a buffered channel of bufferSize. The batcher's worker loop +// is NOT started (no doWork), so addToBatch/processBatchedChanges can be tested +// in isolation. Use startWorkers() if you need the full loop. +func setupLightweightBatcher(t *testing.T, nodeCount, bufferSize int) *lightweightBatcher { + t.Helper() + + b := &LockFreeBatcher{ + tick: time.NewTicker(10 * time.Millisecond), + workers: 4, + workCh: make(chan work, 4*200), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + connected: xsync.NewMap[types.NodeID, *time.Time](), + done: make(chan struct{}), + } + + channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount) + for i := 1; i <= nodeCount; i++ { + id := types.NodeID(i) //nolint:gosec // test with small controlled values + mc := newMultiChannelNodeConn(id, nil) // nil mapper is fine for channel tests + ch := make(chan *tailcfg.MapResponse, bufferSize) + entry := &connectionEntry{ + id: fmt.Sprintf("conn-%d", i), + c: ch, + version: tailcfg.CapabilityVersion(100), + created: time.Now(), + } + entry.lastUsed.Store(time.Now().Unix()) + mc.addConnection(entry) + b.nodes.Store(id, mc) + b.connected.Store(id, nil) // nil = connected + channels[id] = ch + } + + b.totalNodes.Store(int64(nodeCount)) + + return &lightweightBatcher{b: b, channels: channels} +} + +func (lb *lightweightBatcher) cleanup() { + lb.b.doneOnce.Do(func() { + close(lb.b.done) + }) + lb.b.tick.Stop() +} + +// countTotalPending counts total pending change entries across all nodes. +func countTotalPending(b *LockFreeBatcher) int { + count := 0 + + b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { + nc.pendingMu.Lock() + count += len(nc.pending) + nc.pendingMu.Unlock() + + return true + }) + + return count +} + +// countNodesPending counts how many nodes have pending changes. +func countNodesPending(b *LockFreeBatcher) int { + count := 0 + + b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { + nc.pendingMu.Lock() + hasPending := len(nc.pending) > 0 + nc.pendingMu.Unlock() + + if hasPending { + count++ + } + + return true + }) + + return count +} + +// getPendingForNode returns pending changes for a specific node. +func getPendingForNode(b *LockFreeBatcher, id types.NodeID) []change.Change { + nc, ok := b.nodes.Load(id) + if !ok { + return nil + } + + nc.pendingMu.Lock() + pending := make([]change.Change, len(nc.pending)) + copy(pending, nc.pending) + nc.pendingMu.Unlock() + + return pending +} + +// runConcurrently runs n goroutines executing fn, waits for all to finish, +// and returns the number of panics caught. +func runConcurrently(t *testing.T, n int, fn func(i int)) int { + t.Helper() + + var ( + wg sync.WaitGroup + panics atomic.Int64 + ) + + for i := range n { + wg.Add(1) + + go func(idx int) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + panics.Add(1) + t.Logf("panic in goroutine %d: %v", idx, r) + } + }() + + fn(idx) + }(i) + } + + wg.Wait() + + return int(panics.Load()) +} + +// runConcurrentlyWithTimeout is like runConcurrently but fails if not done +// within timeout (deadlock detection). +func runConcurrentlyWithTimeout(t *testing.T, n int, timeout time.Duration, fn func(i int)) int { + t.Helper() + + done := make(chan int, 1) + + go func() { + done <- runConcurrently(t, n, fn) + }() + + select { + case panics := <-done: + return panics + case <-time.After(timeout): + t.Fatalf("deadlock detected: %d goroutines did not complete within %v", n, timeout) + return -1 + } +} + +// ============================================================================ +// addToBatch Concurrency Tests +// ============================================================================ + +// TestAddToBatch_ConcurrentTargeted_NoDataLoss verifies that concurrent +// targeted addToBatch calls do not lose data. +// +// Previously (Bug #1): addToBatch used LoadOrStore→append→Store on a +// separate pendingChanges map, which was NOT atomic. Two goroutines could +// Load the same slice, both append, and one Store would overwrite the other. +// FIX: pendingChanges moved into multiChannelNodeConn with mutex protection, +// eliminating the race entirely. +func TestAddToBatch_ConcurrentTargeted_NoDataLoss(t *testing.T) { + lb := setupLightweightBatcher(t, 10, 10) + defer lb.cleanup() + + targetNode := types.NodeID(1) + + const goroutines = 100 + + // Each goroutine adds one targeted change to the same node + panics := runConcurrentlyWithTimeout(t, goroutines, 10*time.Second, func(i int) { + ch := change.Change{ + Reason: fmt.Sprintf("targeted-%d", i), + TargetNode: targetNode, + PeerPatches: []*tailcfg.PeerChange{ + {NodeID: tailcfg.NodeID(i + 100)}, //nolint:gosec // test + }, + } + lb.b.addToBatch(ch) + }) + + require.Zero(t, panics, "no panics expected") + + // All 100 changes MUST be present. The Load→append→Store race causes + // data loss: typically 30-50% of changes are silently dropped. + pending := getPendingForNode(lb.b, targetNode) + t.Logf("targeted changes: expected=%d, got=%d (lost=%d)", + goroutines, len(pending), goroutines-len(pending)) + + assert.Len(t, pending, goroutines, + "addToBatch lost %d/%d targeted changes under concurrent access", + goroutines-len(pending), goroutines) +} + +// TestAddToBatch_ConcurrentBroadcast verifies that concurrent broadcasts +// distribute changes to all nodes. +func TestAddToBatch_ConcurrentBroadcast(t *testing.T) { + lb := setupLightweightBatcher(t, 50, 10) + defer lb.cleanup() + + const goroutines = 50 + + panics := runConcurrentlyWithTimeout(t, goroutines, 10*time.Second, func(_ int) { + lb.b.addToBatch(change.DERPMap()) + }) + + assert.Zero(t, panics, "no panics expected") + + // Each node should have received some DERP changes + nodesWithPending := countNodesPending(lb.b) + t.Logf("nodes with pending changes: %d/%d", nodesWithPending, 50) + assert.Positive(t, nodesWithPending, + "at least some nodes should have pending changes after broadcast") +} + +// TestAddToBatch_FullUpdateOverrides verifies that a FullUpdate replaces +// all pending changes for every node. +func TestAddToBatch_FullUpdateOverrides(t *testing.T) { + lb := setupLightweightBatcher(t, 10, 10) + defer lb.cleanup() + + // Add some targeted changes first + for i := 1; i <= 10; i++ { + lb.b.addToBatch(change.Change{ + Reason: "pre-existing", + TargetNode: types.NodeID(i), //nolint:gosec // test with small values + PeerPatches: []*tailcfg.PeerChange{ + {NodeID: tailcfg.NodeID(100 + i)}, //nolint:gosec // test with small values + }, + }) + } + + // Full update should replace all pending changes + lb.b.addToBatch(change.FullUpdate()) + + // Every node should have exactly one pending change (the FullUpdate) + lb.b.nodes.Range(func(id types.NodeID, _ *multiChannelNodeConn) bool { + pending := getPendingForNode(lb.b, id) + require.Len(t, pending, 1, "node %d should have exactly 1 pending (FullUpdate)", id) + assert.True(t, pending[0].IsFull(), "pending change should be a full update") + + return true + }) +} + +// TestAddToBatch_NodeRemovalCleanup verifies that PeersRemoved in a change +// cleans up the node from the batcher's internal state. +func TestAddToBatch_NodeRemovalCleanup(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + removedNode := types.NodeID(3) + + // Verify node exists before removal + _, exists := lb.b.nodes.Load(removedNode) + require.True(t, exists, "node 3 should exist before removal") + + // Send a change that includes node 3 in PeersRemoved + lb.b.addToBatch(change.Change{ + Reason: "node deleted", + PeersRemoved: []types.NodeID{removedNode}, + }) + + // Node should be removed from all maps + _, exists = lb.b.nodes.Load(removedNode) + assert.False(t, exists, "node 3 should be removed from nodes map") + + _, exists = lb.b.connected.Load(removedNode) + assert.False(t, exists, "node 3 should be removed from connected map") + + pending := getPendingForNode(lb.b, removedNode) + assert.Empty(t, pending, "node 3 should have no pending changes") + + assert.Equal(t, int64(4), lb.b.totalNodes.Load(), "total nodes should be decremented") +} + +// ============================================================================ +// processBatchedChanges Tests +// ============================================================================ + +// TestProcessBatchedChanges_QueuesWork verifies that processBatchedChanges +// moves pending changes to the work queue and clears them. +func TestProcessBatchedChanges_QueuesWork(t *testing.T) { + lb := setupLightweightBatcher(t, 3, 10) + defer lb.cleanup() + + // Add pending changes for each node + for i := 1; i <= 3; i++ { + if nc, ok := lb.b.nodes.Load(types.NodeID(i)); ok { //nolint:gosec // test + nc.appendPending(change.DERPMap()) + } + } + + lb.b.processBatchedChanges() + + // Pending should be cleared + assert.Equal(t, 0, countNodesPending(lb.b), + "all pending changes should be cleared after processing") + + // Work items should be on the work channel + assert.Len(t, lb.b.workCh, 3, + "3 work items should be queued") +} + +// TestProcessBatchedChanges_ConcurrentAdd_NoDataLoss verifies that concurrent +// addToBatch and processBatchedChanges calls do not lose data. +// +// Previously (Bug #2): processBatchedChanges used Range→Delete on a separate +// pendingChanges map. A concurrent addToBatch could Store new changes between +// Range reading the key and Delete removing it, losing freshly-stored changes. +// FIX: pendingChanges moved into multiChannelNodeConn with atomic drainPending(), +// eliminating the race entirely. +func TestProcessBatchedChanges_ConcurrentAdd_NoDataLoss(t *testing.T) { + // Use a single node to maximize contention on one key. + lb := setupLightweightBatcher(t, 1, 10) + defer lb.cleanup() + + // Use a large work channel so processBatchedChanges never blocks. + lb.b.workCh = make(chan work, 100000) + + const iterations = 500 + + var addedCount atomic.Int64 + + var wg sync.WaitGroup + + // Goroutine 1: continuously add targeted changes to node 1 + + wg.Go(func() { + for i := range iterations { + lb.b.addToBatch(change.Change{ + Reason: fmt.Sprintf("add-%d", i), + TargetNode: types.NodeID(1), + PeerPatches: []*tailcfg.PeerChange{ + {NodeID: tailcfg.NodeID(i + 100)}, //nolint:gosec // test + }, + }) + addedCount.Add(1) + } + }) + + // Goroutine 2: continuously process batched changes + + wg.Go(func() { + for range iterations { + lb.b.processBatchedChanges() + } + }) + + wg.Wait() + + // One final process to flush any remaining + lb.b.processBatchedChanges() + + // Count how many work items were actually queued + queuedWork := len(lb.b.workCh) + // Also count any still-pending + remaining := len(getPendingForNode(lb.b, types.NodeID(1))) + + total := queuedWork + remaining + added := int(addedCount.Load()) + + t.Logf("added=%d, queued_work=%d, still_pending=%d, total_accounted=%d, lost=%d", + added, queuedWork, remaining, total, added-total) + + // Every added change must either be in the work queue or still pending. + // The Range→Delete race in processBatchedChanges causes inconsistency: + // changes can be lost (total < added) or duplicated (total > added). + assert.Equal(t, added, total, + "processBatchedChanges has %d inconsistent changes (%d added vs %d accounted) "+ + "under concurrent access", + total-added, added, total) +} + +// TestProcessBatchedChanges_EmptyPending verifies processBatchedChanges +// is a no-op when there are no pending changes. +func TestProcessBatchedChanges_EmptyPending(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + lb.b.processBatchedChanges() + + assert.Empty(t, lb.b.workCh, + "no work should be queued when there are no pending changes") +} + +// ============================================================================ +// cleanupOfflineNodes Tests +// ============================================================================ + +// TestCleanupOfflineNodes_RemovesOld verifies that nodes offline longer +// than the 15-minute threshold are removed. +func TestCleanupOfflineNodes_RemovesOld(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + // Make node 3 appear offline for 20 minutes + oldTime := time.Now().Add(-20 * time.Minute) + lb.b.connected.Store(types.NodeID(3), &oldTime) + // Remove its active connections so it appears truly offline + if mc, ok := lb.b.nodes.Load(types.NodeID(3)); ok { + ch := lb.channels[types.NodeID(3)] + mc.removeConnectionByChannel(ch) + } + + lb.b.cleanupOfflineNodes() + + _, exists := lb.b.nodes.Load(types.NodeID(3)) + assert.False(t, exists, "node 3 should be cleaned up (offline >15min)") + + // Other nodes should still be present + _, exists = lb.b.nodes.Load(types.NodeID(1)) + assert.True(t, exists, "node 1 should still exist") +} + +// TestCleanupOfflineNodes_KeepsRecent verifies that recently disconnected +// nodes are not cleaned up. +func TestCleanupOfflineNodes_KeepsRecent(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + // Make node 3 appear offline for only 5 minutes (under threshold) + recentTime := time.Now().Add(-5 * time.Minute) + lb.b.connected.Store(types.NodeID(3), &recentTime) + + if mc, ok := lb.b.nodes.Load(types.NodeID(3)); ok { + ch := lb.channels[types.NodeID(3)] + mc.removeConnectionByChannel(ch) + } + + lb.b.cleanupOfflineNodes() + + _, exists := lb.b.nodes.Load(types.NodeID(3)) + assert.True(t, exists, "node 3 should NOT be cleaned up (offline <15min)") +} + +// TestCleanupOfflineNodes_KeepsActive verifies that nodes with active +// connections are never cleaned up, even if disconnect time is set. +func TestCleanupOfflineNodes_KeepsActive(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + // Set old disconnect time but keep the connection active + oldTime := time.Now().Add(-20 * time.Minute) + lb.b.connected.Store(types.NodeID(3), &oldTime) + // Don't remove connection - node still has active connections + + lb.b.cleanupOfflineNodes() + + _, exists := lb.b.nodes.Load(types.NodeID(3)) + assert.True(t, exists, + "node 3 should NOT be cleaned up (still has active connections)") +} + +// ============================================================================ +// Batcher Lifecycle Tests +// ============================================================================ + +// TestBatcher_CloseStopsWorkers verifies that Close() signals workers to stop +// and doesn't deadlock. +func TestBatcher_CloseStopsWorkers(t *testing.T) { + lb := setupLightweightBatcher(t, 3, 10) + + // Start workers + lb.b.Start() + + // Queue some work + if nc, ok := lb.b.nodes.Load(types.NodeID(1)); ok { + nc.appendPending(change.DERPMap()) + } + + lb.b.processBatchedChanges() + + // Close should not deadlock + done := make(chan struct{}) + + go func() { + lb.b.Close() + close(done) + }() + + select { + case <-done: + // Success + case <-time.After(5 * time.Second): + t.Fatal("Close() deadlocked") + } +} + +// TestBatcher_CloseMultipleTimes_DoubleClosePanic exercises Bug #4: +// multiChannelNodeConn.close() has no idempotency guard. Calling Close() +// concurrently triggers close() on the same channels multiple times, +// panicking with "close of closed channel". +// +// BUG: batcher_lockfree.go:555-565 - close() calls close(conn.c) with no guard +// FIX: Add sync.Once or atomic.Bool to multiChannelNodeConn.close(). +func TestBatcher_CloseMultipleTimes_DoubleClosePanic(t *testing.T) { + lb := setupLightweightBatcher(t, 3, 10) + lb.b.Start() + + // Close multiple times concurrently. + // The done channel and workCh are protected by sync.Once and should not panic. + // But node connection close() WILL panic because it has no idempotency guard. + panics := runConcurrently(t, 10, func(_ int) { + lb.b.Close() + }) + + assert.Zero(t, panics, + "BUG #4: %d panics from concurrent Close() due to "+ + "multiChannelNodeConn.close() lacking idempotency guard. "+ + "Fix: add sync.Once or atomic.Bool to close()", panics) +} + +// TestBatcher_QueueWorkDuringShutdown verifies that queueWork doesn't block +// when the batcher is shutting down. +func TestBatcher_QueueWorkDuringShutdown(t *testing.T) { + lb := setupLightweightBatcher(t, 3, 10) + + // Close the done channel to simulate shutdown + close(lb.b.done) + + // queueWork should not block (it selects on done channel) + done := make(chan struct{}) + + go func() { + lb.b.queueWork(work{ + c: change.DERPMap(), + nodeID: types.NodeID(1), + }) + close(done) + }() + + select { + case <-done: + // Success - didn't block + case <-time.After(1 * time.Second): + t.Fatal("queueWork blocked during shutdown") + } +} + +// TestBatcher_MapResponseDuringShutdown verifies that MapResponseFromChange +// returns ErrBatcherShuttingDown when the batcher is closed. +func TestBatcher_MapResponseDuringShutdown(t *testing.T) { + lb := setupLightweightBatcher(t, 3, 10) + + // Close the done channel + close(lb.b.done) + + _, err := lb.b.MapResponseFromChange(types.NodeID(1), change.DERPMap()) + assert.ErrorIs(t, err, ErrBatcherShuttingDown) +} + +// TestBatcher_IsConnectedReflectsState verifies IsConnected accurately +// reflects the connection state of nodes. +func TestBatcher_IsConnectedReflectsState(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + // All nodes should be connected + for i := 1; i <= 5; i++ { + assert.True(t, lb.b.IsConnected(types.NodeID(i)), //nolint:gosec // test + "node %d should be connected", i) + } + + // Non-existent node should not be connected + assert.False(t, lb.b.IsConnected(types.NodeID(999))) + + // Disconnect node 3 (remove connection + set disconnect time) + if mc, ok := lb.b.nodes.Load(types.NodeID(3)); ok { + mc.removeConnectionByChannel(lb.channels[types.NodeID(3)]) + } + + now := time.Now() + lb.b.connected.Store(types.NodeID(3), &now) + + assert.False(t, lb.b.IsConnected(types.NodeID(3)), + "node 3 should not be connected after disconnection") + + // Other nodes should still be connected + assert.True(t, lb.b.IsConnected(types.NodeID(1))) + assert.True(t, lb.b.IsConnected(types.NodeID(5))) +} + +// TestBatcher_ConnectedMapConsistency verifies ConnectedMap returns accurate +// state for all nodes. +func TestBatcher_ConnectedMapConsistency(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + // Disconnect node 2 + if mc, ok := lb.b.nodes.Load(types.NodeID(2)); ok { + mc.removeConnectionByChannel(lb.channels[types.NodeID(2)]) + } + + now := time.Now() + lb.b.connected.Store(types.NodeID(2), &now) + + cm := lb.b.ConnectedMap() + + // Connected nodes + for _, id := range []types.NodeID{1, 3, 4, 5} { + val, ok := cm.Load(id) + assert.True(t, ok, "node %d should be in ConnectedMap", id) + assert.True(t, val, "node %d should be connected", id) + } + + // Disconnected node + val, ok := cm.Load(types.NodeID(2)) + assert.True(t, ok, "node 2 should be in ConnectedMap") + assert.False(t, val, "node 2 should be disconnected") +} + +// ============================================================================ +// Bug Reproduction Tests (all expected to FAIL until bugs are fixed) +// ============================================================================ + +// TestBug3_CleanupOfflineNodes_TOCTOU exercises Bug #3: +// cleanupOfflineNodes has a TOCTOU (time-of-check-time-of-use) race. +// Between checking hasActiveConnections()==false and calling nodes.Delete(), +// AddNode can reconnect the node, and the cleanup deletes the fresh connection. +// +// BUG: batcher_lockfree.go:407-414 checks hasActiveConnections, +// +// then :426 deletes the node. A reconnect between these two lines +// causes a live node to be deleted. +// +// FIX: Use Compute() on nodes map to atomically check-and-delete, or +// +// add a generation counter to detect stale cleanup. +func TestBug3_CleanupOfflineNodes_TOCTOU(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + targetNode := types.NodeID(3) + + // Make node 3 appear offline for >15 minutes (past cleanup threshold) + oldTime := time.Now().Add(-20 * time.Minute) + lb.b.connected.Store(targetNode, &oldTime) + // Remove its active connections so it appears truly offline + if mc, ok := lb.b.nodes.Load(targetNode); ok { + ch := lb.channels[targetNode] + mc.removeConnectionByChannel(ch) + } + + // Now simulate a reconnection happening concurrently with cleanup. + // We'll add a new connection to the node DURING cleanup. + var wg sync.WaitGroup + + reconnected := make(chan struct{}) + + // Goroutine 1: wait a tiny bit, then reconnect node 3 + + wg.Go(func() { + // Wait for cleanup to start the Range phase + time.Sleep(50 * time.Microsecond) //nolint:forbidigo // concurrency test coordination + + mc, exists := lb.b.nodes.Load(targetNode) + if !exists { + // Node already deleted by cleanup - that's the bug! + return + } + + newCh := make(chan *tailcfg.MapResponse, 10) + entry := &connectionEntry{ + id: "reconnected", + c: newCh, + version: tailcfg.CapabilityVersion(100), + created: time.Now(), + } + entry.lastUsed.Store(time.Now().Unix()) + mc.addConnection(entry) + lb.b.connected.Store(targetNode, nil) // nil = connected + lb.channels[targetNode] = newCh + + close(reconnected) + }) + + // Goroutine 2: run cleanup + + wg.Go(func() { + lb.b.cleanupOfflineNodes() + }) + + wg.Wait() + + // After cleanup + reconnection, node 3 MUST still exist. + // The TOCTOU bug: cleanup checks hasActiveConnections=false, then a + // reconnection adds a connection, then cleanup deletes the node anyway. + _, exists := lb.b.nodes.Load(targetNode) + assert.True(t, exists, + "BUG #3: cleanupOfflineNodes deleted node %d despite it being reconnected. "+ + "TOCTOU race: check hasActiveConnections→reconnect→Delete loses the live node. "+ + "Fix: use Compute() for atomic check-and-delete, or generation counter", + targetNode) +} + +// TestBug5_WorkerPanicKillsWorkerPermanently exercises Bug #5: +// Workers have no recover() wrapper. A panic in generateMapResponse or +// handleNodeChange permanently kills the worker goroutine, reducing +// throughput and eventually deadlocking when all workers are dead. +// +// BUG: batcher_lockfree.go:212-287 - worker() has no defer recover() +// FIX: Add defer recover() that logs the panic and continues the loop. +func TestBug5_WorkerPanicKillsWorkerPermanently(t *testing.T) { + zerolog.SetGlobalLevel(zerolog.Disabled) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 3, 10) + defer lb.cleanup() + + // We need workers running. Use a small worker count. + lb.b.workers = 2 + lb.b.Start() + + // Give workers time to start + time.Sleep(50 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + + // Record initial work processed count + initialProcessed := lb.b.workProcessed.Load() + + // Queue work that will cause the worker to encounter an error + // (node exists but mapper is nil, which goes through the nc.change path + // that calls handleNodeChange → generateMapResponse with nil mapper). + // This produces an error but doesn't panic by itself. + // + // To actually trigger a panic, we need to make the node connection's + // change() method panic. We can do this by corrupting internal state. + // However, that's fragile. Instead, we verify the architectural issue: + // if a worker DID panic, does the batcher recover? + // + // We simulate this by checking: after queuing invalid work that produces + // errors, can we still process valid work? With no panic recovery, + // a real panic would make subsequent work permanently stuck. + + // Queue several work items for non-existent nodes (produces errors) + for range 10 { + lb.b.queueWork(work{ + c: change.DERPMap(), + nodeID: types.NodeID(99999), // doesn't exist + }) + } + + // Wait for workers to process the error-producing work + time.Sleep(100 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + + errorsAfterBad := lb.b.workErrors.Load() + processedAfterBad := lb.b.workProcessed.Load() + t.Logf("after bad work: processed=%d, errors=%d", + processedAfterBad-initialProcessed, errorsAfterBad) + + // Now queue valid work items (node 1 exists) + beforeValid := lb.b.workProcessed.Load() + for range 5 { + lb.b.queueWork(work{ + c: change.DERPMap(), + nodeID: types.NodeID(1), + }) + } + + // Wait for processing + time.Sleep(200 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + + afterValid := lb.b.workProcessed.Load() + validProcessed := afterValid - beforeValid + + t.Logf("valid work processed: %d/5", validProcessed) + + // This passes currently because nil-mapper errors don't panic. + // But the architectural bug remains: if ANY code path in the worker + // panics (e.g., nil pointer in mapper, index out of range in builder), + // the worker dies permanently with no recovery. + // + // We assert that workers SHOULD have a recovery mechanism: + assert.Equal(t, int64(5), validProcessed, + "workers should process all valid work even after encountering errors") + + // The real test: verify worker() has defer recover(). + // Since we can't easily cause a real panic in the worker without + // modifying production code, we document this as a structural bug. + // A proper fix adds: defer func() { if r := recover(); r != nil { log... } }() + // at the top of worker(). + // + // For now, we verify at minimum that error-producing work doesn't kill workers. + assert.GreaterOrEqual(t, errorsAfterBad, int64(10), + "worker should have recorded errors for non-existent node work") +} + +// TestBug6_StartCalledMultipleTimes_GoroutineLeak exercises Bug #6: +// Start() creates a new done channel and launches doWork() every time, +// with no guard against multiple calls. Each call spawns (workers+1) +// goroutines that never get cleaned up. +// +// BUG: batcher_lockfree.go:163-166 - Start() has no "already started" check +// FIX: Add sync.Once or atomic.Bool to prevent multiple Start() calls. +func TestBug6_StartCalledMultipleTimes_GoroutineLeak(t *testing.T) { + zerolog.SetGlobalLevel(zerolog.Disabled) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 3, 10) + lb.b.workers = 2 + + goroutinesBefore := runtime.NumGoroutine() + + // Call Start() once - this should launch (workers + 1) goroutines + // (1 for doWork + workers for worker()) + lb.b.Start() + time.Sleep(50 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + + goroutinesAfterFirst := runtime.NumGoroutine() + firstStartDelta := goroutinesAfterFirst - goroutinesBefore + t.Logf("goroutines: before=%d, after_first_Start=%d, delta=%d", + goroutinesBefore, goroutinesAfterFirst, firstStartDelta) + + // Call Start() again - this SHOULD be a no-op + // BUG: it creates a NEW done channel (orphaning goroutines listening on the old one) + // and launches another doWork()+workers set + lb.b.Start() + time.Sleep(50 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + + goroutinesAfterSecond := runtime.NumGoroutine() + secondStartDelta := goroutinesAfterSecond - goroutinesAfterFirst + t.Logf("goroutines: after_second_Start=%d, delta=%d (should be 0)", + goroutinesAfterSecond, secondStartDelta) + + // Call Start() a third time + lb.b.Start() + time.Sleep(50 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + + goroutinesAfterThird := runtime.NumGoroutine() + thirdStartDelta := goroutinesAfterThird - goroutinesAfterSecond + t.Logf("goroutines: after_third_Start=%d, delta=%d (should be 0)", + goroutinesAfterThird, thirdStartDelta) + + // Close() only closes the LAST done channel, leaving earlier goroutines leaked + lb.b.Close() + time.Sleep(100 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + + goroutinesAfterClose := runtime.NumGoroutine() + t.Logf("goroutines after Close: %d (leaked: %d)", + goroutinesAfterClose, goroutinesAfterClose-goroutinesBefore) + + // Second Start() should NOT have created new goroutines + assert.Zero(t, secondStartDelta, + "BUG #6: second Start() call leaked %d goroutines. "+ + "Start() has no idempotency guard, each call spawns new goroutines. "+ + "Fix: add sync.Once or atomic.Bool to prevent multiple Start() calls", + secondStartDelta) +} + +// TestBug7_CleanupOfflineNodes_PendingChangesCleanedStructurally verifies that +// pending changes are automatically cleaned up when a node is removed from the +// nodes map, because pending state lives inside multiChannelNodeConn. +// +// Previously (Bug #7): pendingChanges was a separate map that was NOT cleaned +// when cleanupOfflineNodes removed a node, causing orphaned entries. +// FIX: pendingChanges moved into multiChannelNodeConn — deleting the node +// from b.nodes automatically drops its pending changes. +func TestBug7_CleanupOfflineNodes_PendingChangesCleanedStructurally(t *testing.T) { + lb := setupLightweightBatcher(t, 5, 10) + defer lb.cleanup() + + targetNode := types.NodeID(3) + + // Make node 3 appear offline for >15 minutes + oldTime := time.Now().Add(-20 * time.Minute) + lb.b.connected.Store(targetNode, &oldTime) + + if mc, ok := lb.b.nodes.Load(targetNode); ok { + ch := lb.channels[targetNode] + mc.removeConnectionByChannel(ch) + } + + // Add pending changes for node 3 before cleanup + if nc, ok := lb.b.nodes.Load(targetNode); ok { + nc.appendPending(change.DERPMap()) + } + + // Verify pending exists before cleanup + pending := getPendingForNode(lb.b, targetNode) + require.Len(t, pending, 1, "node 3 should have pending changes before cleanup") + + // Run cleanup + lb.b.cleanupOfflineNodes() + + // Node 3 should be removed from nodes and connected + _, existsInNodes := lb.b.nodes.Load(targetNode) + assert.False(t, existsInNodes, "node 3 should be removed from nodes map") + + _, existsInConnected := lb.b.connected.Load(targetNode) + assert.False(t, existsInConnected, "node 3 should be removed from connected map") + + // Pending changes are structurally gone because the node was deleted. + // getPendingForNode returns nil for non-existent nodes. + pendingAfter := getPendingForNode(lb.b, targetNode) + assert.Empty(t, pendingAfter, + "pending changes should be gone after node deletion (structural fix)") +} + +// TestBug8_SerialTimeoutUnderWriteLock exercises Bug #8 (performance): +// multiChannelNodeConn.send() holds the write lock for the ENTIRE duration +// of sending to all connections. Each send has a 50ms timeout for stale +// connections. With N stale connections, the write lock is held for N*50ms, +// blocking all addConnection/removeConnection calls. +// +// BUG: batcher_lockfree.go:629-697 - mutex.Lock() held during all conn.send() +// +// calls, each with 50ms timeout. 5 stale connections = 250ms lock hold. +// +// FIX: Copy connections under read lock, send without lock, then take +// +// write lock only for removing failed connections. +func TestBug8_SerialTimeoutUnderWriteLock(t *testing.T) { + zerolog.SetGlobalLevel(zerolog.Disabled) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + mc := newMultiChannelNodeConn(1, nil) + + // Add 5 stale connections (unbuffered, no reader = will timeout at 50ms each) + const staleCount = 5 + for i := range staleCount { + ch := make(chan *tailcfg.MapResponse) // unbuffered + mc.addConnection(makeConnectionEntry(fmt.Sprintf("stale-%d", i), ch)) + } + + // Measure how long send() takes - it should timeout at ~50ms for ONE + // connection, but with serial timeouts it takes staleCount * 50ms. + start := time.Now() + _ = mc.send(testMapResponse()) + elapsed := time.Since(start) + + t.Logf("send() with %d stale connections took %v (expected ~50ms, got ~%dms)", + staleCount, elapsed, elapsed.Milliseconds()) + + // The write lock is held for the entire duration. With 5 stale connections, + // each timing out at 50ms, that's ~250ms of write lock hold time. + // This blocks ALL other operations (addConnection, removeConnection, etc.) + // + // The fix should make send() complete in ~50ms regardless of stale count + // by releasing the lock before sending, or sending in parallel. + assert.Less(t, elapsed, 100*time.Millisecond, + "BUG #8: send() held write lock for %v with %d stale connections. "+ + "Serial 50ms timeouts under write lock cause %d*50ms=%dms lock hold. "+ + "Fix: copy connections under read lock, send without lock, then "+ + "write-lock only for cleanup", + elapsed, staleCount, staleCount, staleCount*50) +} + +// TestBug1_BroadcastNoDataLoss verifies that concurrent broadcast addToBatch +// calls do not lose data. +// +// Previously (Bug #1, broadcast path): Same Load→append→Store race as targeted +// changes, but on the broadcast code path within the Range callback. +// FIX: pendingChanges moved into multiChannelNodeConn with mutex protection. +func TestBug1_BroadcastNoDataLoss(t *testing.T) { + // Use many nodes so the Range iteration takes longer, widening the race window + lb := setupLightweightBatcher(t, 100, 10) + defer lb.cleanup() + + const goroutines = 50 + + // Each goroutine broadcasts a DERPMap change to all 100 nodes + panics := runConcurrentlyWithTimeout(t, goroutines, 10*time.Second, func(_ int) { + lb.b.addToBatch(change.DERPMap()) + }) + + require.Zero(t, panics, "no panics expected") + + // Each of the 100 nodes should have exactly `goroutines` pending changes. + // The race causes some nodes to have fewer. + var ( + totalLost int + nodesWithLoss int + ) + + lb.b.nodes.Range(func(id types.NodeID, _ *multiChannelNodeConn) bool { + pending := getPendingForNode(lb.b, id) + if len(pending) < goroutines { + totalLost += goroutines - len(pending) + nodesWithLoss++ + } + + return true + }) + + t.Logf("broadcast data loss: %d total changes lost across %d/%d nodes", + totalLost, nodesWithLoss, 100) + + assert.Zero(t, totalLost, + "broadcast lost %d changes across %d nodes under concurrent access", + totalLost, nodesWithLoss) +} + +// ============================================================================ +// 1000-Node Scale Tests (lightweight, no DB) +// ============================================================================ + +// TestScale1000_AddToBatch_Broadcast verifies that broadcasting to 1000 nodes +// works correctly under concurrent access. +func TestScale1000_AddToBatch_Broadcast(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 1000, 10) + defer lb.cleanup() + + const concurrentBroadcasts = 100 + + panics := runConcurrentlyWithTimeout(t, concurrentBroadcasts, 30*time.Second, func(_ int) { + lb.b.addToBatch(change.DERPMap()) + }) + + assert.Zero(t, panics, "no panics expected") + + nodesWithPending := countNodesPending(lb.b) + totalPending := countTotalPending(lb.b) + + t.Logf("1000-node broadcast: %d/%d nodes have pending, %d total pending items", + nodesWithPending, 1000, totalPending) + + // All 1000 nodes should have at least some pending changes + // (may lose some due to Bug #1 race, but should have most) + assert.GreaterOrEqual(t, nodesWithPending, 900, + "at least 90%% of nodes should have pending changes") +} + +// TestScale1000_ProcessBatchedWithConcurrentAdd tests processBatchedChanges +// running concurrently with addToBatch at 1000 nodes. +func TestScale1000_ProcessBatchedWithConcurrentAdd(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 1000, 10) + defer lb.cleanup() + + // Use a large work channel to avoid blocking. + // 50 broadcasts × 1000 nodes = up to 50,000 work items. + lb.b.workCh = make(chan work, 100000) + + var wg sync.WaitGroup + + // Producer: add broadcasts + + wg.Go(func() { + for range 50 { + lb.b.addToBatch(change.DERPMap()) + } + }) + + // Consumer: process batched changes repeatedly + + wg.Go(func() { + for range 50 { + lb.b.processBatchedChanges() + time.Sleep(1 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + } + }) + + done := make(chan struct{}) + + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + t.Logf("1000-node concurrent add+process completed without deadlock") + case <-time.After(30 * time.Second): + t.Fatal("deadlock detected in 1000-node concurrent add+process") + } + + queuedWork := len(lb.b.workCh) + t.Logf("work items queued: %d", queuedWork) + assert.Positive(t, queuedWork, "should have queued some work items") +} + +// TestScale1000_MultiChannelBroadcast tests broadcasting a MapResponse +// to 1000 nodes, each with 1-3 connections. +func TestScale1000_MultiChannelBroadcast(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + const ( + nodeCount = 1000 + bufferSize = 5 + ) + + // Create nodes with varying connection counts + b := &LockFreeBatcher{ + tick: time.NewTicker(10 * time.Millisecond), + workers: 4, + workCh: make(chan work, 4*200), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + connected: xsync.NewMap[types.NodeID, *time.Time](), + done: make(chan struct{}), + } + + defer func() { + close(b.done) + b.tick.Stop() + }() + + type nodeChannels struct { + channels []chan *tailcfg.MapResponse + } + + allNodeChannels := make(map[types.NodeID]*nodeChannels, nodeCount) + + for i := 1; i <= nodeCount; i++ { + id := types.NodeID(i) //nolint:gosec // test with small controlled values + mc := newMultiChannelNodeConn(id, nil) + + connCount := 1 + (i % 3) // 1, 2, or 3 connections + nc := &nodeChannels{channels: make([]chan *tailcfg.MapResponse, connCount)} + + for j := range connCount { + ch := make(chan *tailcfg.MapResponse, bufferSize) + nc.channels[j] = ch + entry := &connectionEntry{ + id: fmt.Sprintf("conn-%d-%d", i, j), + c: ch, + version: tailcfg.CapabilityVersion(100), + created: time.Now(), + } + entry.lastUsed.Store(time.Now().Unix()) + mc.addConnection(entry) + } + + b.nodes.Store(id, mc) + allNodeChannels[id] = nc + } + + // Broadcast to all nodes + data := testMapResponse() + + var successCount, failCount atomic.Int64 + + start := time.Now() + + b.nodes.Range(func(id types.NodeID, mc *multiChannelNodeConn) bool { + err := mc.send(data) + if err != nil { + failCount.Add(1) + } else { + successCount.Add(1) + } + + return true + }) + + elapsed := time.Since(start) + + t.Logf("broadcast to %d nodes: %d success, %d failures, took %v", + nodeCount, successCount.Load(), failCount.Load(), elapsed) + + assert.Equal(t, int64(nodeCount), successCount.Load(), + "all nodes should receive broadcast successfully") + assert.Zero(t, failCount.Load(), "no broadcast failures expected") + + // Verify at least some channels received data + receivedCount := 0 + + for _, nc := range allNodeChannels { + for _, ch := range nc.channels { + select { + case <-ch: + receivedCount++ + default: + } + } + } + + t.Logf("channels that received data: %d", receivedCount) + assert.Positive(t, receivedCount, "channels should have received broadcast data") +} + +// TestScale1000_ConnectionChurn tests 1000 nodes with 10% churning connections +// while broadcasts are happening. Stable nodes should not lose data. +func TestScale1000_ConnectionChurn(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 1000, 20) + defer lb.cleanup() + + const churnNodes = 100 // 10% of nodes churn + + const churnCycles = 50 + + var ( + panics atomic.Int64 + wg sync.WaitGroup + ) + + // Churn goroutine: rapidly add/remove connections for nodes 901-1000 + + wg.Go(func() { + for cycle := range churnCycles { + for i := 901; i <= 901+churnNodes-1; i++ { + id := types.NodeID(i) //nolint:gosec // test with small controlled values + + mc, exists := lb.b.nodes.Load(id) + if !exists { + continue + } + + // Remove old connection + oldCh := lb.channels[id] + mc.removeConnectionByChannel(oldCh) + + // Add new connection + newCh := make(chan *tailcfg.MapResponse, 20) + entry := &connectionEntry{ + id: fmt.Sprintf("churn-%d-%d", i, cycle), + c: newCh, + version: tailcfg.CapabilityVersion(100), + created: time.Now(), + } + entry.lastUsed.Store(time.Now().Unix()) + mc.addConnection(entry) + + lb.channels[id] = newCh + } + } + }) + + // Broadcast goroutine: send addToBatch calls during churn + + wg.Go(func() { + for range churnCycles { + func() { + defer func() { + if r := recover(); r != nil { + panics.Add(1) + } + }() + + lb.b.addToBatch(change.DERPMap()) + }() + } + }) + + done := make(chan struct{}) + + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success + case <-time.After(30 * time.Second): + t.Fatal("deadlock in 1000-node connection churn test") + } + + assert.Zero(t, panics.Load(), "no panics during connection churn") + + // Verify stable nodes (1-900) still have active connections + stableConnected := 0 + + for i := 1; i <= 900; i++ { + if mc, exists := lb.b.nodes.Load(types.NodeID(i)); exists { //nolint:gosec // test + if mc.hasActiveConnections() { + stableConnected++ + } + } + } + + t.Logf("stable nodes still connected: %d/900", stableConnected) + assert.Equal(t, 900, stableConnected, + "all stable nodes should retain their connections during churn") +} + +// TestScale1000_ConcurrentAddRemove tests concurrent AddNode-like and +// RemoveNode-like operations at 1000-node scale. +func TestScale1000_ConcurrentAddRemove(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 1000, 10) + defer lb.cleanup() + + const goroutines = 200 + + panics := runConcurrentlyWithTimeout(t, goroutines, 30*time.Second, func(i int) { + id := types.NodeID(1 + (i % 1000)) //nolint:gosec // test + + mc, exists := lb.b.nodes.Load(id) + if !exists { + return + } + + if i%2 == 0 { + // Add a new connection + ch := make(chan *tailcfg.MapResponse, 10) + entry := &connectionEntry{ + id: fmt.Sprintf("concurrent-%d", i), + c: ch, + version: tailcfg.CapabilityVersion(100), + created: time.Now(), + } + entry.lastUsed.Store(time.Now().Unix()) + mc.addConnection(entry) + } else { + // Try to remove a connection (may fail if already removed) + ch := lb.channels[id] + mc.removeConnectionByChannel(ch) + } + }) + + assert.Zero(t, panics, "no panics during concurrent add/remove at 1000 nodes") +} + +// TestScale1000_IsConnectedConsistency verifies IsConnected returns consistent +// results during rapid connection state changes at 1000-node scale. +func TestScale1000_IsConnectedConsistency(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 1000, 10) + defer lb.cleanup() + + var ( + panics atomic.Int64 + wg sync.WaitGroup + ) + + // Goroutines reading IsConnected + + wg.Go(func() { + for range 1000 { + func() { + defer func() { + if r := recover(); r != nil { + panics.Add(1) + } + }() + + for i := 1; i <= 1000; i++ { + _ = lb.b.IsConnected(types.NodeID(i)) //nolint:gosec // test + } + }() + } + }) + + // Goroutine modifying connection state + + wg.Go(func() { + for i := range 100 { + id := types.NodeID(1 + (i % 1000)) //nolint:gosec // test + if i%2 == 0 { + now := time.Now() + lb.b.connected.Store(id, &now) // disconnect + } else { + lb.b.connected.Store(id, nil) // reconnect + } + } + }) + + done := make(chan struct{}) + + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success + case <-time.After(30 * time.Second): + t.Fatal("deadlock in IsConnected consistency test") + } + + assert.Zero(t, panics.Load(), + "IsConnected should not panic under concurrent modification") +} + +// TestScale1000_BroadcastDuringNodeChurn tests that broadcast addToBatch +// calls work correctly while 20% of nodes are joining and leaving. +func TestScale1000_BroadcastDuringNodeChurn(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 1000, 10) + defer lb.cleanup() + + var ( + panics atomic.Int64 + wg sync.WaitGroup + ) + + // Node churn: 20% of nodes (nodes 801-1000) joining/leaving + + wg.Go(func() { + for cycle := range 20 { + for i := 801; i <= 1000; i++ { + func() { + defer func() { + if r := recover(); r != nil { + panics.Add(1) + } + }() + + id := types.NodeID(i) //nolint:gosec // test + if cycle%2 == 0 { + // "Remove" node + lb.b.nodes.Delete(id) + lb.b.connected.Delete(id) + } else { + // "Add" node back + mc := newMultiChannelNodeConn(id, nil) + ch := make(chan *tailcfg.MapResponse, 10) + entry := &connectionEntry{ + id: fmt.Sprintf("rechurn-%d-%d", i, cycle), + c: ch, + version: tailcfg.CapabilityVersion(100), + created: time.Now(), + } + entry.lastUsed.Store(time.Now().Unix()) + mc.addConnection(entry) + lb.b.nodes.Store(id, mc) + lb.b.connected.Store(id, nil) + } + }() + } + } + }) + + // Concurrent broadcasts + + wg.Go(func() { + for range 50 { + func() { + defer func() { + if r := recover(); r != nil { + panics.Add(1) + } + }() + + lb.b.addToBatch(change.DERPMap()) + }() + } + }) + + done := make(chan struct{}) + + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + t.Logf("broadcast during churn completed, panics: %d", panics.Load()) + case <-time.After(30 * time.Second): + t.Fatal("deadlock in broadcast during node churn") + } + + assert.Zero(t, panics.Load(), + "broadcast during node churn should not panic") +} + +// TestScale1000_WorkChannelSaturation tests that the work channel doesn't +// deadlock when it fills up (queueWork selects on done channel as escape). +func TestScale1000_WorkChannelSaturation(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + // Create batcher with SMALL work channel to force saturation + b := &LockFreeBatcher{ + tick: time.NewTicker(10 * time.Millisecond), + workers: 2, + workCh: make(chan work, 10), // Very small - will saturate + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + connected: xsync.NewMap[types.NodeID, *time.Time](), + done: make(chan struct{}), + } + + defer func() { + close(b.done) + b.tick.Stop() + }() + + // Add 1000 nodes + for i := 1; i <= 1000; i++ { + id := types.NodeID(i) //nolint:gosec // test + mc := newMultiChannelNodeConn(id, nil) + ch := make(chan *tailcfg.MapResponse, 1) + entry := &connectionEntry{ + id: fmt.Sprintf("conn-%d", i), + c: ch, + version: tailcfg.CapabilityVersion(100), + created: time.Now(), + } + entry.lastUsed.Store(time.Now().Unix()) + mc.addConnection(entry) + b.nodes.Store(id, mc) + } + + // Add pending changes for all 1000 nodes + for i := 1; i <= 1000; i++ { + if nc, ok := b.nodes.Load(types.NodeID(i)); ok { //nolint:gosec // test + nc.appendPending(change.DERPMap()) + } + } + + // processBatchedChanges should not deadlock even with small work channel. + // queueWork uses select with b.done as escape hatch. + // Start a consumer to slowly drain the work channel. + var consumed atomic.Int64 + + go func() { + for { + select { + case <-b.workCh: + consumed.Add(1) + case <-b.done: + return + } + } + }() + + done := make(chan struct{}) + + go func() { + b.processBatchedChanges() + close(done) + }() + + select { + case <-done: + t.Logf("processBatchedChanges completed, consumed %d work items", consumed.Load()) + case <-time.After(30 * time.Second): + t.Fatal("processBatchedChanges deadlocked with saturated work channel") + } +} + +// TestScale1000_FullUpdate_AllNodesGetPending verifies that a FullUpdate +// creates pending entries for all 1000 nodes. +func TestScale1000_FullUpdate_AllNodesGetPending(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node test in short mode") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + lb := setupLightweightBatcher(t, 1000, 10) + defer lb.cleanup() + + lb.b.addToBatch(change.FullUpdate()) + + nodesWithPending := countNodesPending(lb.b) + assert.Equal(t, 1000, nodesWithPending, + "FullUpdate should create pending entries for all 1000 nodes") + + // Verify each node has exactly one full update pending + lb.b.nodes.Range(func(id types.NodeID, _ *multiChannelNodeConn) bool { + pending := getPendingForNode(lb.b, id) + require.Len(t, pending, 1, "node %d should have 1 pending change", id) + assert.True(t, pending[0].IsFull(), "pending change for node %d should be full", id) + + return true + }) +} + +// ============================================================================ +// 1000-Node Full Pipeline Tests (with DB) +// ============================================================================ + +// TestScale1000_AllToAll_FullPipeline tests the complete pipeline: +// create 1000 nodes in DB, add them to batcher, send FullUpdate, +// verify all nodes see 999 peers. +func TestScale1000_AllToAll_FullPipeline(t *testing.T) { + if testing.Short() { + t.Skip("skipping 1000-node full pipeline test in short mode") + } + + if util.RaceEnabled { + t.Skip("skipping 1000-node test with race detector (bcrypt setup too slow)") + } + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + t.Logf("setting up 1000-node test environment (this may take a minute)...") + + testData, cleanup := setupBatcherWithTestData(t, NewBatcherAndMapper, 1, 1000, 200) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + + t.Logf("created %d nodes, connecting to batcher...", len(allNodes)) + + // Start update consumers + for i := range allNodes { + allNodes[i].start() + } + + // Connect all nodes + for i := range allNodes { + node := &allNodes[i] + + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) + if err != nil { + t.Fatalf("failed to add node %d: %v", i, err) + } + // Yield periodically to avoid overwhelming the work queue + if i%50 == 49 { + time.Sleep(10 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + } + } + + t.Logf("all nodes connected, sending FullUpdate and waiting for convergence...") + + // Send FullUpdate + batcher.AddWork(change.FullUpdate()) + + expectedPeers := len(allNodes) - 1 // Each sees all others + + // Wait for all nodes to see all peers + assert.EventuallyWithT(t, func(c *assert.CollectT) { + convergedCount := 0 + + for i := range allNodes { + if int(allNodes[i].maxPeersCount.Load()) >= expectedPeers { + convergedCount++ + } + } + + assert.Equal(c, len(allNodes), convergedCount, + "all nodes should see %d peers (converged: %d/%d)", + expectedPeers, convergedCount, len(allNodes)) + }, 5*time.Minute, 5*time.Second, "waiting for 1000-node convergence") + + // Final statistics + totalUpdates := int64(0) + minPeers := len(allNodes) + maxPeers := 0 + + for i := range allNodes { + stats := allNodes[i].cleanup() + + totalUpdates += stats.TotalUpdates + if stats.MaxPeersSeen < minPeers { + minPeers = stats.MaxPeersSeen + } + + if stats.MaxPeersSeen > maxPeers { + maxPeers = stats.MaxPeersSeen + } + } + + t.Logf("1000-node pipeline: total_updates=%d, min_peers=%d, max_peers=%d, expected=%d", + totalUpdates, minPeers, maxPeers, expectedPeers) + + assert.GreaterOrEqual(t, minPeers, expectedPeers, + "all nodes should have seen at least %d peers", expectedPeers) +} diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 4d35c274..fdb220c5 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -41,8 +41,7 @@ type LockFreeBatcher struct { done chan struct{} doneOnce sync.Once // Ensures done is only closed once - // Batching state - pendingChanges *xsync.Map[types.NodeID, []change.Change] + started atomic.Bool // Ensures Start() is only called once // Metrics totalNodes atomic.Int64 @@ -167,7 +166,12 @@ func (b *LockFreeBatcher) AddWork(r ...change.Change) { } func (b *LockFreeBatcher) Start() { + if !b.started.CompareAndSwap(false, true) { + return + } + b.done = make(chan struct{}) + go b.doWork() } @@ -336,15 +340,16 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) { } b.connected.Delete(removedID) - b.pendingChanges.Delete(removedID) } } // Short circuit if any of the changes is a full update, which // means we can skip sending individual changes. if change.HasFull(changes) { - b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { - b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()}) + b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { + nc.pendingMu.Lock() + nc.pending = []change.Change{change.FullUpdate()} + nc.pendingMu.Unlock() return true }) @@ -356,20 +361,18 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) { // Handle targeted changes - send only to the specific node for _, ch := range targeted { - pending, _ := b.pendingChanges.LoadOrStore(ch.TargetNode, []change.Change{}) - pending = append(pending, ch) - b.pendingChanges.Store(ch.TargetNode, pending) + if nc, ok := b.nodes.Load(ch.TargetNode); ok { + nc.appendPending(ch) + } } // Handle broadcast changes - send to all nodes, filtering as needed if len(broadcast) > 0 { - b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { + b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { filtered := change.FilterForNode(nodeID, broadcast) if len(filtered) > 0 { - pending, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{}) - pending = append(pending, filtered...) - b.pendingChanges.Store(nodeID, pending) + nc.appendPending(filtered...) } return true @@ -379,12 +382,8 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) { // processBatchedChanges processes all pending batched changes. func (b *LockFreeBatcher) processBatchedChanges() { - if b.pendingChanges == nil { - return - } - - // Process all pending changes - b.pendingChanges.Range(func(nodeID types.NodeID, pending []change.Change) bool { + b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { + pending := nc.drainPending() if len(pending) == 0 { return true } @@ -394,9 +393,6 @@ func (b *LockFreeBatcher) processBatchedChanges() { b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil}) } - // Clear the pending changes for this node - b.pendingChanges.Delete(nodeID) - return true }) } @@ -532,6 +528,13 @@ type multiChannelNodeConn struct { mutex sync.RWMutex connections []*connectionEntry + // pendingMu protects pending changes independently of the connection mutex. + // This avoids contention between addToBatch (which appends changes) and + // send() (which sends data to connections). + pendingMu sync.Mutex + pending []change.Change + + closeOnce sync.Once updateCount atomic.Int64 // lastSentPeers tracks which peers were last sent to this node. @@ -560,12 +563,14 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC } func (mc *multiChannelNodeConn) close() { - mc.mutex.Lock() - defer mc.mutex.Unlock() + mc.closeOnce.Do(func() { + mc.mutex.Lock() + defer mc.mutex.Unlock() - for _, conn := range mc.connections { - mc.stopConnection(conn) - } + for _, conn := range mc.connections { + mc.stopConnection(conn) + } + }) } // stopConnection marks a connection as closed and tears down the owning session @@ -647,6 +652,25 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int { return len(mc.connections) } +// appendPending appends changes to this node's pending change list. +// Thread-safe via pendingMu; does not contend with the connection mutex. +func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) { + mc.pendingMu.Lock() + mc.pending = append(mc.pending, changes...) + mc.pendingMu.Unlock() +} + +// drainPending atomically removes and returns all pending changes. +// Returns nil if there are no pending changes. +func (mc *multiChannelNodeConn) drainPending() []change.Change { + mc.pendingMu.Lock() + p := mc.pending + mc.pending = nil + mc.pendingMu.Unlock() + + return p +} + // send broadcasts data to all active connections for the node. func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { if data == nil { diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 4036361f..281bdb02 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -2302,8 +2302,8 @@ func TestBatcherRapidReconnection(t *testing.T) { t.Logf("=== RAPID RECONNECTION TEST ===") t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes)) - // Phase 1: Connect all nodes initially - t.Logf("Phase 1: Connecting all nodes...") + // Connect all nodes initially. + t.Logf("Connecting all nodes...") for i := range allNodes { node := &allNodes[i] @@ -2321,8 +2321,8 @@ func TestBatcherRapidReconnection(t *testing.T) { } }, 5*time.Second, 50*time.Millisecond, "waiting for connections to settle") - // Phase 2: Rapid disconnect ALL nodes (simulating nodes going down) - t.Logf("Phase 2: Rapid disconnect all nodes...") + // Rapid disconnect ALL nodes (simulating nodes going down). + t.Logf("Rapid disconnect all nodes...") for i := range allNodes { node := &allNodes[i] @@ -2330,8 +2330,8 @@ func TestBatcherRapidReconnection(t *testing.T) { t.Logf("Node %d RemoveNode result: %t", i, removed) } - // Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up) - t.Logf("Phase 3: Rapid reconnect with new channels...") + // Rapid reconnect with NEW channels (simulating nodes coming back up). + t.Logf("Rapid reconnect with new channels...") newChannels := make([]chan *tailcfg.MapResponse, len(allNodes)) for i := range allNodes { @@ -2351,8 +2351,8 @@ func TestBatcherRapidReconnection(t *testing.T) { } }, 5*time.Second, 50*time.Millisecond, "waiting for reconnections to settle") - // Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR - t.Logf("Phase 4: Checking debug status...") + // Check debug status after reconnection. + t.Logf("Checking debug status...") if debugBatcher, ok := batcher.(interface { Debug() map[types.NodeID]any @@ -2396,8 +2396,8 @@ func TestBatcherRapidReconnection(t *testing.T) { t.Logf("Batcher does not implement Debug() method") } - // Phase 5: Test if "disconnected" nodes can actually receive updates - t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...") + // Test if "disconnected" nodes can actually receive updates. + t.Logf("Testing if nodes can receive updates despite debug status...") // Send a change that should reach all nodes batcher.AddWork(change.DERPMap()) @@ -2442,8 +2442,8 @@ func TestBatcherMultiConnection(t *testing.T) { t.Logf("=== MULTI-CONNECTION TEST ===") - // Phase 1: Connect first node with initial connection - t.Logf("Phase 1: Connecting node 1 with first connection...") + // Connect first node with initial connection. + t.Logf("Connecting node 1 with first connection...") err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { @@ -2462,8 +2462,8 @@ func TestBatcherMultiConnection(t *testing.T) { assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected") }, time.Second, 10*time.Millisecond, "waiting for initial connections") - // Phase 2: Add second connection for node1 (multi-connection scenario) - t.Logf("Phase 2: Adding second connection for node 1...") + // Add second connection for node1 (multi-connection scenario). + t.Logf("Adding second connection for node 1...") secondChannel := make(chan *tailcfg.MapResponse, 10) @@ -2475,8 +2475,8 @@ func TestBatcherMultiConnection(t *testing.T) { // Yield to allow connection to be processed runtime.Gosched() - // Phase 3: Add third connection for node1 - t.Logf("Phase 3: Adding third connection for node 1...") + // Add third connection for node1. + t.Logf("Adding third connection for node 1...") thirdChannel := make(chan *tailcfg.MapResponse, 10) @@ -2488,8 +2488,8 @@ func TestBatcherMultiConnection(t *testing.T) { // Yield to allow connection to be processed runtime.Gosched() - // Phase 4: Verify debug status shows correct connection count - t.Logf("Phase 4: Verifying debug status shows multiple connections...") + // Verify debug status shows correct connection count. + t.Logf("Verifying debug status shows multiple connections...") if debugBatcher, ok := batcher.(interface { Debug() map[types.NodeID]any @@ -2525,8 +2525,8 @@ func TestBatcherMultiConnection(t *testing.T) { } } - // Phase 5: Send update and verify ALL connections receive it - t.Logf("Phase 5: Testing update distribution to all connections...") + // Send update and verify ALL connections receive it. + t.Logf("Testing update distribution to all connections...") // Clear any existing updates from all channels clearChannel := func(ch chan *tailcfg.MapResponse) { @@ -2591,8 +2591,8 @@ func TestBatcherMultiConnection(t *testing.T) { connection1Received, connection2Received, connection3Received) } - // Phase 6: Test connection removal and verify remaining connections still work - t.Logf("Phase 6: Testing connection removal...") + // Test connection removal and verify remaining connections still work. + t.Logf("Testing connection removal...") // Remove the second connection removed := batcher.RemoveNode(node1.n.ID, secondChannel) diff --git a/hscontrol/util/norace.go b/hscontrol/util/norace.go new file mode 100644 index 00000000..d925058c --- /dev/null +++ b/hscontrol/util/norace.go @@ -0,0 +1,6 @@ +//go:build !race + +package util + +// RaceEnabled is true when the race detector is active. +const RaceEnabled = false diff --git a/hscontrol/util/race.go b/hscontrol/util/race.go new file mode 100644 index 00000000..2c21a9b3 --- /dev/null +++ b/hscontrol/util/race.go @@ -0,0 +1,6 @@ +//go:build race + +package util + +// RaceEnabled is true when the race detector is active. +const RaceEnabled = true