mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-17 19:53:57 +09:00
mapper/batcher: reduce lock contention with two-phase send
Rewrite multiChannelNodeConn.send() to use a two-phase approach:

1. RLock: snapshot the connections slice (cheap pointer copy)
2. Unlock: send to all connections (the 50ms timeouts happen here)
3. Lock: remove failed connections by pointer identity

Previously, send() held the write lock for the entire duration of sending to all connections. With N stale connections each timing out at 50ms, this blocked addConnection/removeConnection for N*50ms. The two-phase approach holds the lock only for O(N) pointer operations, not for N*50ms of I/O waits.
This commit is contained in:
@@ -950,18 +950,19 @@ func TestBug7_CleanupOfflineNodes_PendingChangesCleanedStructurally(t *testing.T
|
||||
}
|
||||
|
||||
// TestBug8_SerialTimeoutUnderWriteLock exercises Bug #8 (performance):
|
||||
// multiChannelNodeConn.send() holds the write lock for the ENTIRE duration
|
||||
// of sending to all connections. Each send has a 50ms timeout for stale
|
||||
// connections. With N stale connections, the write lock is held for N*50ms,
|
||||
// blocking all addConnection/removeConnection calls.
|
||||
// multiChannelNodeConn.send() originally held the write lock for the ENTIRE
|
||||
// duration of sending to all connections. Each send has a 50ms timeout for
|
||||
// stale connections. With N stale connections, the write lock was held for
|
||||
// N*50ms, blocking all addConnection/removeConnection calls.
|
||||
//
|
||||
// BUG: batcher_lockfree.go:629-697 - mutex.Lock() held during all conn.send()
|
||||
// BUG: mutex.Lock() held during all conn.send() calls, each with 50ms timeout.
|
||||
//
|
||||
// calls, each with 50ms timeout. 5 stale connections = 250ms lock hold.
|
||||
// 5 stale connections = 250ms lock hold, blocking addConnection/removeConnection.
|
||||
//
|
||||
// FIX: Copy connections under read lock, send without lock, then take
|
||||
// FIX: Snapshot connections under read lock, release, send without any lock
|
||||
//
|
||||
// write lock only for removing failed connections.
|
||||
// (timeouts happen here), then write-lock only to remove failed connections.
|
||||
// The lock is now held only for O(N) pointer copies, not for N*50ms I/O.
|
||||
func TestBug8_SerialTimeoutUnderWriteLock(t *testing.T) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
@@ -975,27 +976,42 @@ func TestBug8_SerialTimeoutUnderWriteLock(t *testing.T) {
|
||||
mc.addConnection(makeConnectionEntry(fmt.Sprintf("stale-%d", i), ch))
|
||||
}
|
||||
|
||||
// Measure how long send() takes - it should timeout at ~50ms for ONE
|
||||
// connection, but with serial timeouts it takes staleCount * 50ms.
|
||||
start := time.Now()
|
||||
// The key test: verify that the mutex is NOT held during the slow sends.
|
||||
// We do this by trying to acquire the lock from another goroutine during
|
||||
// the send. With the old code (lock held for 250ms), this would block.
|
||||
// With the fix, the lock is free during sends.
|
||||
lockAcquired := make(chan time.Duration, 1)
|
||||
|
||||
go func() {
|
||||
// Give send() a moment to start (it will be in the unlocked send window)
|
||||
time.Sleep(20 * time.Millisecond) //nolint:forbidigo // concurrency test coordination
|
||||
|
||||
// Try to acquire the write lock. It should succeed quickly because
|
||||
// the lock is only held briefly for the snapshot and cleanup.
|
||||
start := time.Now()
|
||||
|
||||
mc.mutex.Lock()
|
||||
lockWait := time.Since(start)
|
||||
mc.mutex.Unlock()
|
||||
|
||||
lockAcquired <- lockWait
|
||||
}()
|
||||
|
||||
// Run send() with 5 stale connections. Total wall time will be ~250ms
|
||||
// (5 * 50ms serial timeouts), but the lock should be free during sends.
|
||||
_ = mc.send(testMapResponse())
|
||||
elapsed := time.Since(start)
|
||||
|
||||
t.Logf("send() with %d stale connections took %v (expected ~50ms, got ~%dms)",
|
||||
staleCount, elapsed, elapsed.Milliseconds())
|
||||
lockWait := <-lockAcquired
|
||||
t.Logf("lock acquisition during send() with %d stale connections waited %v",
|
||||
staleCount, lockWait)
|
||||
|
||||
// The write lock is held for the entire duration. With 5 stale connections,
|
||||
// each timing out at 50ms, that's ~250ms of write lock hold time.
|
||||
// This blocks ALL other operations (addConnection, removeConnection, etc.)
|
||||
//
|
||||
// The fix should make send() complete in ~50ms regardless of stale count
|
||||
// by releasing the lock before sending, or sending in parallel.
|
||||
assert.Less(t, elapsed, 100*time.Millisecond,
|
||||
"BUG #8: send() held write lock for %v with %d stale connections. "+
|
||||
"Serial 50ms timeouts under write lock cause %d*50ms=%dms lock hold. "+
|
||||
"Fix: copy connections under read lock, send without lock, then "+
|
||||
"write-lock only for cleanup",
|
||||
elapsed, staleCount, staleCount, staleCount*50)
|
||||
// The lock wait should be very short (<50ms) since the lock is released
|
||||
// before sending. With the old code it would be ~230ms (250ms - 20ms sleep).
|
||||
assert.Less(t, lockWait, 50*time.Millisecond,
|
||||
"mutex was held for %v during send() with %d stale connections; "+
|
||||
"lock should be released before sending to allow "+
|
||||
"concurrent addConnection/removeConnection calls",
|
||||
lockWait, staleCount)
|
||||
}
|
||||
|
||||
// TestBug1_BroadcastNoDataLoss verifies that concurrent broadcast addToBatch
|
||||
|
||||
@@ -707,71 +707,108 @@ func (mc *multiChannelNodeConn) drainPending() []change.Change {
|
||||
}
|
||||
|
||||
// send broadcasts data to all active connections for the node.
|
||||
// send broadcasts data to all connections using a two-phase approach to avoid
|
||||
// holding the write lock during potentially slow sends. Each stale connection
|
||||
// can block for up to 50ms (see connectionEntry.send), so N stale connections
|
||||
// under a single write lock would block for N*50ms. The two-phase approach:
|
||||
//
|
||||
// 1. RLock: snapshot the connections slice (cheap pointer copy)
|
||||
// 2. Unlock: send to all connections without any lock held (timeouts happen here)
|
||||
// 3. Lock: remove only the failed connections by pointer identity
|
||||
//
|
||||
// New connections added during step 2 are safe: they receive a full initial
|
||||
// map via AddNode, so missing this particular update causes no data loss.
|
||||
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
// Phase 1: snapshot connections under read lock.
|
||||
mc.mutex.RLock()
|
||||
if len(mc.connections) == 0 {
|
||||
// During rapid reconnection, nodes may temporarily have no active connections.
|
||||
// This is not an error: the node will receive a full map when it reconnects.
|
||||
mc.mutex.RUnlock()
|
||||
mc.log.Debug().Caller().
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
|
||||
return nil // Return success instead of error
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy the entry pointers into a fresh snapshot (make+copy allocates a new
|
||||
// backing array, so phase-3 writes to mc.connections cannot affect it).
|
||||
snapshot := make([]*connectionEntry, len(mc.connections))
|
||||
copy(snapshot, mc.connections)
|
||||
mc.mutex.RUnlock()
|
||||
|
||||
mc.log.Debug().Caller().
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Int("total_connections", len(snapshot)).
|
||||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
// Phase 2: send to all connections without holding any lock.
|
||||
// Stale connection timeouts (50ms each) happen here without blocking
|
||||
// other goroutines that need the mutex.
|
||||
var (
|
||||
lastErr error
|
||||
successCount int
|
||||
failed []*connectionEntry
|
||||
)
|
||||
|
||||
successCount := 0
|
||||
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
for i, conn := range mc.connections {
|
||||
for _, conn := range snapshot {
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
Str(zf.ConnID, conn.id).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
err := conn.send(data)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
|
||||
failedConnections = append(failedConnections, i)
|
||||
failed = append(failed, conn)
|
||||
|
||||
mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
Str(zf.ConnID, conn.id).
|
||||
Msg("send: connection send failed")
|
||||
} else {
|
||||
successCount++
|
||||
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
Str(zf.ConnID, conn.id).
|
||||
Msg("send: successfully sent to connection")
|
||||
}
|
||||
}
|
||||
|
||||
// Remove failed connections (in reverse order to maintain indices)
|
||||
for i := len(failedConnections) - 1; i >= 0; i-- {
|
||||
idx := failedConnections[i]
|
||||
entry := mc.removeConnectionAtIndexLocked(idx, true)
|
||||
mc.log.Debug().Caller().
|
||||
Str(zf.ConnID, entry.id).
|
||||
Msg("send: removed failed connection")
|
||||
// Phase 3: write-lock only to remove failed connections.
|
||||
if len(failed) > 0 {
|
||||
mc.mutex.Lock()
|
||||
// Remove by pointer identity: only remove entries that still exist
|
||||
// in the current connections slice and match a failed pointer.
|
||||
// New connections added between phase 1 and 3 are not affected.
|
||||
failedSet := make(map[*connectionEntry]struct{}, len(failed))
|
||||
for _, f := range failed {
|
||||
failedSet[f] = struct{}{}
|
||||
}
|
||||
|
||||
clean := mc.connections[:0]
|
||||
for _, conn := range mc.connections {
|
||||
if _, isFailed := failedSet[conn]; !isFailed {
|
||||
clean = append(clean, conn)
|
||||
} else {
|
||||
mc.log.Debug().Caller().
|
||||
Str(zf.ConnID, conn.id).
|
||||
Msg("send: removing failed connection")
|
||||
// Tear down the owning session so the old serveLongPoll
|
||||
// goroutine exits instead of lingering as a stale session.
|
||||
mc.stopConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
mc.connections = clean
|
||||
mc.mutex.Unlock()
|
||||
}
|
||||
|
||||
mc.updateCount.Add(1)
|
||||
|
||||
mc.log.Debug().
|
||||
Int("successful_sends", successCount).
|
||||
Int("failed_connections", len(failedConnections)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Int("failed_connections", len(failed)).
|
||||
Msg("send: completed broadcast")
|
||||
|
||||
// Success if at least one send succeeded
|
||||
|
||||
Reference in New Issue
Block a user