mapper/batcher: reduce lock contention with two-phase send

Rewrite multiChannelNodeConn.send() to use a two-phase approach:
1. RLock: snapshot connections slice (cheap pointer copy)
2. Unlock: send to all connections (50ms timeouts happen here)
3. Lock: remove failed connections by pointer identity

Previously, send() held the write lock for the entire duration of
sending to all connections. With N stale connections each timing out
at 50ms, this blocked addConnection/removeConnection for N*50ms.
The two-phase approach holds the lock only for O(N) pointer
operations, not for N*50ms I/O waits.
This commit is contained in:
Kristoffer Dalby
2026-03-10 17:18:12 +00:00
parent da33795e79
commit 3ebe4d99c1
2 changed files with 106 additions and 53 deletions

View File

@@ -950,18 +950,19 @@ func TestBug7_CleanupOfflineNodes_PendingChangesCleanedStructurally(t *testing.T
}
// TestBug8_SerialTimeoutUnderWriteLock exercises Bug #8 (performance):
// multiChannelNodeConn.send() holds the write lock for the ENTIRE duration
// of sending to all connections. Each send has a 50ms timeout for stale
// connections. With N stale connections, the write lock is held for N*50ms,
// blocking all addConnection/removeConnection calls.
// multiChannelNodeConn.send() originally held the write lock for the ENTIRE
// duration of sending to all connections. Each send has a 50ms timeout for
// stale connections. With N stale connections, the write lock was held for
// N*50ms, blocking all addConnection/removeConnection calls.
//
// BUG: batcher_lockfree.go:629-697 - mutex.Lock() held during all conn.send()
// BUG: mutex.Lock() held during all conn.send() calls, each with 50ms timeout.
//
// calls, each with 50ms timeout. 5 stale connections = 250ms lock hold.
// 5 stale connections = 250ms lock hold, blocking addConnection/removeConnection.
//
// FIX: Copy connections under read lock, send without lock, then take
// FIX: Snapshot connections under read lock, release, send without any lock
//
// write lock only for removing failed connections.
// (timeouts happen here), then write-lock only to remove failed connections.
// The lock is now held only for O(N) pointer copies, not for N*50ms I/O.
func TestBug8_SerialTimeoutUnderWriteLock(t *testing.T) {
zerolog.SetGlobalLevel(zerolog.Disabled)
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
@@ -975,27 +976,42 @@ func TestBug8_SerialTimeoutUnderWriteLock(t *testing.T) {
mc.addConnection(makeConnectionEntry(fmt.Sprintf("stale-%d", i), ch))
}
// Measure how long send() takes - it should timeout at ~50ms for ONE
// connection, but with serial timeouts it takes staleCount * 50ms.
start := time.Now()
// The key test: verify that the mutex is NOT held during the slow sends.
// We do this by trying to acquire the lock from another goroutine during
// the send. With the old code (lock held for 250ms), this would block.
// With the fix, the lock is free during sends.
lockAcquired := make(chan time.Duration, 1)
go func() {
// Give send() a moment to start (it will be in the unlocked send window)
time.Sleep(20 * time.Millisecond) //nolint:forbidigo // concurrency test coordination
// Try to acquire the write lock. It should succeed quickly because
// the lock is only held briefly for the snapshot and cleanup.
start := time.Now()
mc.mutex.Lock()
lockWait := time.Since(start)
mc.mutex.Unlock()
lockAcquired <- lockWait
}()
// Run send() with 5 stale connections. Total wall time will be ~250ms
// (5 * 50ms serial timeouts), but the lock should be free during sends.
_ = mc.send(testMapResponse())
elapsed := time.Since(start)
t.Logf("send() with %d stale connections took %v (expected ~50ms, got ~%dms)",
staleCount, elapsed, elapsed.Milliseconds())
lockWait := <-lockAcquired
t.Logf("lock acquisition during send() with %d stale connections waited %v",
staleCount, lockWait)
// The write lock is held for the entire duration. With 5 stale connections,
// each timing out at 50ms, that's ~250ms of write lock hold time.
// This blocks ALL other operations (addConnection, removeConnection, etc.)
//
// The fix should make send() complete in ~50ms regardless of stale count
// by releasing the lock before sending, or sending in parallel.
assert.Less(t, elapsed, 100*time.Millisecond,
"BUG #8: send() held write lock for %v with %d stale connections. "+
"Serial 50ms timeouts under write lock cause %d*50ms=%dms lock hold. "+
"Fix: copy connections under read lock, send without lock, then "+
"write-lock only for cleanup",
elapsed, staleCount, staleCount, staleCount*50)
// The lock wait should be very short (<50ms) since the lock is released
// before sending. With the old code it would be ~230ms (250ms - 20ms sleep).
assert.Less(t, lockWait, 50*time.Millisecond,
"mutex was held for %v during send() with %d stale connections; "+
"lock should be released before sending to allow "+
"concurrent addConnection/removeConnection calls",
lockWait, staleCount)
}
// TestBug1_BroadcastNoDataLoss verifies that concurrent broadcast addToBatch

View File

@@ -707,71 +707,108 @@ func (mc *multiChannelNodeConn) drainPending() []change.Change {
}
// send broadcasts data to all active connections for the node.
// send broadcasts data to all connections using a two-phase approach to avoid
// holding the write lock during potentially slow sends. Each stale connection
// can block for up to 50ms (see connectionEntry.send), so N stale connections
// under a single write lock would block for N*50ms. The two-phase approach:
//
// 1. RLock: snapshot the connections slice (cheap pointer copy)
// 2. Unlock: send to all connections without any lock held (timeouts happen here)
// 3. Lock: remove only the failed connections by pointer identity
//
// New connections added during step 2 are safe: they receive a full initial
// map via AddNode, so missing this particular update causes no data loss.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
if data == nil {
return nil
}
mc.mutex.Lock()
defer mc.mutex.Unlock()
// Phase 1: snapshot connections under read lock.
mc.mutex.RLock()
if len(mc.connections) == 0 {
// During rapid reconnection, nodes may temporarily have no active connections.
// This is not an error - the node will receive a full map when it reconnects
mc.mutex.RUnlock()
mc.log.Debug().Caller().
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
return nil // Return success instead of error
return nil
}
// Copy the connection pointers into a fresh slice (make + copy) so that
// phase 2 can iterate over a stable snapshot even if mc.connections is
// mutated by other goroutines after the read lock is released; removals
// in phase 3 go through the write lock and rebuild mc.connections.
snapshot := make([]*connectionEntry, len(mc.connections))
copy(snapshot, mc.connections)
mc.mutex.RUnlock()
mc.log.Debug().Caller().
Int("total_connections", len(mc.connections)).
Int("total_connections", len(snapshot)).
Msg("send: broadcasting to all connections")
var lastErr error
// Phase 2: send to all connections without holding any lock.
// Stale connection timeouts (50ms each) happen here without blocking
// other goroutines that need the mutex.
var (
lastErr error
successCount int
failed []*connectionEntry
)
successCount := 0
var failedConnections []int // Track failed connections for removal
// Send to all connections
for i, conn := range mc.connections {
for _, conn := range snapshot {
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
Str(zf.ConnID, conn.id).
Msg("send: attempting to send to connection")
err := conn.send(data)
if err != nil {
lastErr = err
failedConnections = append(failedConnections, i)
failed = append(failed, conn)
mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
Str(zf.ConnID, conn.id).
Msg("send: connection send failed")
} else {
successCount++
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
Str(zf.ConnID, conn.id).
Msg("send: successfully sent to connection")
}
}
// Remove failed connections (in reverse order to maintain indices)
for i := len(failedConnections) - 1; i >= 0; i-- {
idx := failedConnections[i]
entry := mc.removeConnectionAtIndexLocked(idx, true)
mc.log.Debug().Caller().
Str(zf.ConnID, entry.id).
Msg("send: removed failed connection")
// Phase 3: write-lock only to remove failed connections.
if len(failed) > 0 {
mc.mutex.Lock()
// Remove by pointer identity: only remove entries that still exist
// in the current connections slice and match a failed pointer.
// New connections added between phase 1 and 3 are not affected.
failedSet := make(map[*connectionEntry]struct{}, len(failed))
for _, f := range failed {
failedSet[f] = struct{}{}
}
clean := mc.connections[:0]
for _, conn := range mc.connections {
if _, isFailed := failedSet[conn]; !isFailed {
clean = append(clean, conn)
} else {
mc.log.Debug().Caller().
Str(zf.ConnID, conn.id).
Msg("send: removing failed connection")
// Tear down the owning session so the old serveLongPoll
// goroutine exits instead of lingering as a stale session.
mc.stopConnection(conn)
}
}
mc.connections = clean
mc.mutex.Unlock()
}
mc.updateCount.Add(1)
mc.log.Debug().
Int("successful_sends", successCount).
Int("failed_connections", len(failedConnections)).
Int("remaining_connections", len(mc.connections)).
Int("failed_connections", len(failed)).
Msg("send: completed broadcast")
// Success if at least one send succeeded