From 3ebe4d99c1abcbc0aceec7fa65d4022dcd9f0218 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 10 Mar 2026 17:18:12 +0000 Subject: [PATCH] mapper/batcher: reduce lock contention with two-phase send Rewrite multiChannelNodeConn.send() to use a two-phase approach: 1. RLock: snapshot connections slice (cheap pointer copy) 2. Unlock: send to all connections (50ms timeouts happen here) 3. Lock: remove failed connections by pointer identity Previously, send() held the write lock for the entire duration of sending to all connections. With N stale connections each timing out at 50ms, this blocked addConnection/removeConnection for N*50ms. The two-phase approach holds the lock only for O(N) pointer operations, not for N*50ms I/O waits. --- hscontrol/mapper/batcher_concurrency_test.go | 68 +++++++++------ hscontrol/mapper/batcher_lockfree.go | 91 ++++++++++++++------ 2 files changed, 106 insertions(+), 53 deletions(-) diff --git a/hscontrol/mapper/batcher_concurrency_test.go b/hscontrol/mapper/batcher_concurrency_test.go index 9fbd3130..b15175b1 100644 --- a/hscontrol/mapper/batcher_concurrency_test.go +++ b/hscontrol/mapper/batcher_concurrency_test.go @@ -950,18 +950,19 @@ func TestBug7_CleanupOfflineNodes_PendingChangesCleanedStructurally(t *testing.T } // TestBug8_SerialTimeoutUnderWriteLock exercises Bug #8 (performance): -// multiChannelNodeConn.send() holds the write lock for the ENTIRE duration -// of sending to all connections. Each send has a 50ms timeout for stale -// connections. With N stale connections, the write lock is held for N*50ms, -// blocking all addConnection/removeConnection calls. +// multiChannelNodeConn.send() originally held the write lock for the ENTIRE +// duration of sending to all connections. Each send has a 50ms timeout for +// stale connections. With N stale connections, the write lock was held for +// N*50ms, blocking all addConnection/removeConnection calls. 
// -// BUG: batcher_lockfree.go:629-697 - mutex.Lock() held during all conn.send() +// BUG: mutex.Lock() held during all conn.send() calls, each with 50ms timeout. // -// calls, each with 50ms timeout. 5 stale connections = 250ms lock hold. +// 5 stale connections = 250ms lock hold, blocking addConnection/removeConnection. // -// FIX: Copy connections under read lock, send without lock, then take +// FIX: Snapshot connections under read lock, release, send without any lock // -// write lock only for removing failed connections. +// (timeouts happen here), then write-lock only to remove failed connections. +// The lock is now held only for O(N) pointer copies, not for N*50ms I/O. func TestBug8_SerialTimeoutUnderWriteLock(t *testing.T) { zerolog.SetGlobalLevel(zerolog.Disabled) defer zerolog.SetGlobalLevel(zerolog.DebugLevel) @@ -975,27 +976,42 @@ func TestBug8_SerialTimeoutUnderWriteLock(t *testing.T) { mc.addConnection(makeConnectionEntry(fmt.Sprintf("stale-%d", i), ch)) } - // Measure how long send() takes - it should timeout at ~50ms for ONE - // connection, but with serial timeouts it takes staleCount * 50ms. - start := time.Now() + // The key test: verify that the mutex is NOT held during the slow sends. + // We do this by trying to acquire the lock from another goroutine during + // the send. With the old code (lock held for 250ms), this would block. + // With the fix, the lock is free during sends. + lockAcquired := make(chan time.Duration, 1) + + go func() { + // Give send() a moment to start (it will be in the unlocked send window) + time.Sleep(20 * time.Millisecond) //nolint:forbidigo // concurrency test coordination + + // Try to acquire the write lock. It should succeed quickly because + // the lock is only held briefly for the snapshot and cleanup. + start := time.Now() + + mc.mutex.Lock() + lockWait := time.Since(start) + mc.mutex.Unlock() + + lockAcquired <- lockWait + }() + + // Run send() with 5 stale connections. 
Total wall time will be ~250ms + // (5 * 50ms serial timeouts), but the lock should be free during sends. _ = mc.send(testMapResponse()) - elapsed := time.Since(start) - t.Logf("send() with %d stale connections took %v (expected ~50ms, got ~%dms)", - staleCount, elapsed, elapsed.Milliseconds()) + lockWait := <-lockAcquired + t.Logf("lock acquisition during send() with %d stale connections waited %v", + staleCount, lockWait) - // The write lock is held for the entire duration. With 5 stale connections, - // each timing out at 50ms, that's ~250ms of write lock hold time. - // This blocks ALL other operations (addConnection, removeConnection, etc.) - // - // The fix should make send() complete in ~50ms regardless of stale count - // by releasing the lock before sending, or sending in parallel. - assert.Less(t, elapsed, 100*time.Millisecond, - "BUG #8: send() held write lock for %v with %d stale connections. "+ - "Serial 50ms timeouts under write lock cause %d*50ms=%dms lock hold. "+ - "Fix: copy connections under read lock, send without lock, then "+ - "write-lock only for cleanup", - elapsed, staleCount, staleCount, staleCount*50) + // The lock wait should be very short (<50ms) since the lock is released + // before sending. With the old code it would be ~230ms (250ms - 20ms sleep). + assert.Less(t, lockWait, 50*time.Millisecond, + "mutex was held for %v during send() with %d stale connections; "+ + "lock should be released before sending to allow "+ + "concurrent addConnection/removeConnection calls", + lockWait, staleCount) } // TestBug1_BroadcastNoDataLoss verifies that concurrent broadcast addToBatch diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 2d0f27c8..37a748bd 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -707,71 +707,108 @@ func (mc *multiChannelNodeConn) drainPending() []change.Change { } // send broadcasts data to all active connections for the node. 
+// It uses a two-phase approach to avoid holding the write lock during +// potentially slow sends. Each stale connection +// can block for up to 50ms (see connectionEntry.send), so N stale connections +// under a single write lock would block for N*50ms. The two-phase approach: +// +// 1. RLock: snapshot the connections slice (cheap pointer copy) +// 2. Unlock: send to all connections without any lock held (timeouts happen here) +// 3. Lock: remove only the failed connections by pointer identity +// +// New connections added during step 2 are safe: they receive a full initial +// map via AddNode, so missing this particular update causes no data loss. func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { if data == nil { return nil } - mc.mutex.Lock() - defer mc.mutex.Unlock() - + // Phase 1: snapshot connections under read lock. + mc.mutex.RLock() if len(mc.connections) == 0 { - // During rapid reconnection, nodes may temporarily have no active connections - // This is not an error - the node will receive a full map when it reconnects + mc.mutex.RUnlock() mc.log.Debug().Caller(). Msg("send: skipping send to node with no active connections (likely rapid reconnection)") - return nil // Return success instead of error + return nil } + // Copy the connection pointers into a fresh slice so that later + // mutations of mc.connections (under the write lock) cannot affect it. + snapshot := make([]*connectionEntry, len(mc.connections)) + copy(snapshot, mc.connections) + mc.mutex.RUnlock() + mc.log.Debug().Caller(). - Int("total_connections", len(mc.connections)). + Int("total_connections", len(snapshot)). Msg("send: broadcasting to all connections") - var lastErr error + // Phase 2: send to all connections without holding any lock. + // Stale connection timeouts (50ms each) happen here without blocking + // other goroutines that need the mutex. 
+ var ( + lastErr error + successCount int + failed []*connectionEntry + ) - successCount := 0 - - var failedConnections []int // Track failed connections for removal - - // Send to all connections - for i, conn := range mc.connections { + for _, conn := range snapshot { mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)). - Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i). + Str(zf.ConnID, conn.id). Msg("send: attempting to send to connection") err := conn.send(data) if err != nil { lastErr = err - failedConnections = append(failedConnections, i) + failed = append(failed, conn) + mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)). - Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i). + Str(zf.ConnID, conn.id). Msg("send: connection send failed") } else { successCount++ mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)). - Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i). + Str(zf.ConnID, conn.id). Msg("send: successfully sent to connection") } } - // Remove failed connections (in reverse order to maintain indices) - for i := len(failedConnections) - 1; i >= 0; i-- { - idx := failedConnections[i] - entry := mc.removeConnectionAtIndexLocked(idx, true) - mc.log.Debug().Caller(). - Str(zf.ConnID, entry.id). - Msg("send: removed failed connection") + // Phase 3: write-lock only to remove failed connections. + if len(failed) > 0 { + mc.mutex.Lock() + // Remove by pointer identity: only remove entries that still exist + // in the current connections slice and match a failed pointer. + // New connections added between phase 1 and 3 are not affected. + failedSet := make(map[*connectionEntry]struct{}, len(failed)) + for _, f := range failed { + failedSet[f] = struct{}{} + } + + clean := mc.connections[:0] + for _, conn := range mc.connections { + if _, isFailed := failedSet[conn]; !isFailed { + clean = append(clean, conn) + } else { + mc.log.Debug().Caller(). + Str(zf.ConnID, conn.id). 
+ Msg("send: removing failed connection") + // Tear down the owning session so the old serveLongPoll + // goroutine exits instead of lingering as a stale session. + mc.stopConnection(conn) + } + } + + mc.connections = clean + mc.mutex.Unlock() } mc.updateCount.Add(1) mc.log.Debug(). Int("successful_sends", successCount). - Int("failed_connections", len(failedConnections)). - Int("remaining_connections", len(mc.connections)). + Int("failed_connections", len(failed)). Msg("send: completed broadcast") // Success if at least one send succeeded