mapper/batcher: restructure internals for correctness

Move per-node pending changes from a shared xsync.Map on the batcher
into multiChannelNodeConn, protected by a dedicated mutex. The new
appendPending/drainPending methods provide atomic append and drain
operations, eliminating data races in addToBatch and
processBatchedChanges.

Add sync.Once to multiChannelNodeConn.close() to make it idempotent,
preventing panics from concurrent close calls on the same channel.

Add started atomic.Bool to guard Start() against being called
multiple times, preventing orphaned goroutines.

Add comprehensive concurrency tests validating these changes.
This commit is contained in:
Kristoffer Dalby
2026-03-13 13:31:39 +00:00
parent 21e02e5d1f
commit 57070680a5
7 changed files with 1836 additions and 69 deletions

View File

@@ -52,10 +52,9 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
tick: time.NewTicker(batchTime),
// The size of this channel is arbitrarily chosen; the sizing should be revisited.
workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
}
}

View File

@@ -150,13 +150,12 @@ func BenchmarkUpdateSentPeers(b *testing.B) {
// helper, it doesn't register cleanup and suppresses logging.
func benchBatcher(nodeCount, bufferSize int) (*LockFreeBatcher, map[types.NodeID]chan *tailcfg.MapResponse) {
b := &LockFreeBatcher{
tick: time.NewTicker(1 * time.Hour), // never fires during bench
workers: 4,
workCh: make(chan work, 4*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
done: make(chan struct{}),
tick: time.NewTicker(1 * time.Hour), // never fires during bench
workers: 4,
workCh: make(chan work, 4*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
done: make(chan struct{}),
}
channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount)
@@ -204,8 +203,8 @@ func BenchmarkAddToBatch_Broadcast(b *testing.B) {
for range b.N {
batcher.addToBatch(ch)
// Clear pending to avoid unbounded growth
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
batcher.pendingChanges.Delete(id)
batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
nc.drainPending()
return true
})
}
@@ -242,8 +241,8 @@ func BenchmarkAddToBatch_Targeted(b *testing.B) {
batcher.addToBatch(ch)
// Clear pending periodically to avoid growth
if i%100 == 99 {
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
batcher.pendingChanges.Delete(id)
batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
nc.drainPending()
return true
})
}
@@ -298,7 +297,9 @@ func BenchmarkProcessBatchedChanges(b *testing.B) {
b.StopTimer()
// Seed pending changes
for i := 1; i <= nodeCount; i++ {
batcher.pendingChanges.Store(types.NodeID(i), []change.Change{change.DERPMap()}) //nolint:gosec // benchmark
if nc, ok := batcher.nodes.Load(types.NodeID(i)); ok { //nolint:gosec // benchmark
nc.appendPending(change.DERPMap())
}
}
b.StartTimer()
@@ -411,8 +412,8 @@ func BenchmarkConcurrentAddToBatch(b *testing.B) {
case <-batcher.done:
return
default:
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
batcher.pendingChanges.Delete(id)
batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
nc.drainPending()
return true
})
time.Sleep(time.Millisecond) //nolint:forbidigo // benchmark drain loop
@@ -646,7 +647,7 @@ func BenchmarkAddNode(b *testing.B) {
// Connect all nodes (measuring AddNode cost)
for i := range allNodes {
node := &allNodes[i]
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
}
b.StopTimer()
@@ -707,7 +708,7 @@ func BenchmarkFullPipeline(b *testing.B) {
for i := range allNodes {
node := &allNodes[i]
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil {
b.Fatalf("failed to add node %d: %v", i, err)
}
@@ -762,7 +763,7 @@ func BenchmarkMapResponseFromChange(b *testing.B) {
for i := range allNodes {
node := &allNodes[i]
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil {
b.Fatalf("failed to add node %d: %v", i, err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -41,8 +41,7 @@ type LockFreeBatcher struct {
done chan struct{}
doneOnce sync.Once // Ensures done is only closed once
// Batching state
pendingChanges *xsync.Map[types.NodeID, []change.Change]
started atomic.Bool // Ensures Start() is only called once
// Metrics
totalNodes atomic.Int64
@@ -167,7 +166,12 @@ func (b *LockFreeBatcher) AddWork(r ...change.Change) {
}
func (b *LockFreeBatcher) Start() {
if !b.started.CompareAndSwap(false, true) {
return
}
b.done = make(chan struct{})
go b.doWork()
}
@@ -336,15 +340,16 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
}
b.connected.Delete(removedID)
b.pendingChanges.Delete(removedID)
}
}
// Short circuit if any of the changes is a full update, which
// means we can skip sending individual changes.
if change.HasFull(changes) {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()})
b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
nc.pendingMu.Lock()
nc.pending = []change.Change{change.FullUpdate()}
nc.pendingMu.Unlock()
return true
})
@@ -356,20 +361,18 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
// Handle targeted changes - send only to the specific node
for _, ch := range targeted {
pending, _ := b.pendingChanges.LoadOrStore(ch.TargetNode, []change.Change{})
pending = append(pending, ch)
b.pendingChanges.Store(ch.TargetNode, pending)
if nc, ok := b.nodes.Load(ch.TargetNode); ok {
nc.appendPending(ch)
}
}
// Handle broadcast changes - send to all nodes, filtering as needed
if len(broadcast) > 0 {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
filtered := change.FilterForNode(nodeID, broadcast)
if len(filtered) > 0 {
pending, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{})
pending = append(pending, filtered...)
b.pendingChanges.Store(nodeID, pending)
nc.appendPending(filtered...)
}
return true
@@ -379,12 +382,8 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
// processBatchedChanges processes all pending batched changes.
func (b *LockFreeBatcher) processBatchedChanges() {
if b.pendingChanges == nil {
return
}
// Process all pending changes
b.pendingChanges.Range(func(nodeID types.NodeID, pending []change.Change) bool {
b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
pending := nc.drainPending()
if len(pending) == 0 {
return true
}
@@ -394,9 +393,6 @@ func (b *LockFreeBatcher) processBatchedChanges() {
b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil})
}
// Clear the pending changes for this node
b.pendingChanges.Delete(nodeID)
return true
})
}
@@ -532,6 +528,13 @@ type multiChannelNodeConn struct {
mutex sync.RWMutex
connections []*connectionEntry
// pendingMu protects pending changes independently of the connection mutex.
// This avoids contention between addToBatch (which appends changes) and
// send() (which sends data to connections).
pendingMu sync.Mutex
pending []change.Change
closeOnce sync.Once
updateCount atomic.Int64
// lastSentPeers tracks which peers were last sent to this node.
@@ -560,12 +563,14 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC
}
func (mc *multiChannelNodeConn) close() {
mc.mutex.Lock()
defer mc.mutex.Unlock()
mc.closeOnce.Do(func() {
mc.mutex.Lock()
defer mc.mutex.Unlock()
for _, conn := range mc.connections {
mc.stopConnection(conn)
}
for _, conn := range mc.connections {
mc.stopConnection(conn)
}
})
}
// stopConnection marks a connection as closed and tears down the owning session
@@ -647,6 +652,25 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
return len(mc.connections)
}
// appendPending appends changes to this node's pending change list.
// Thread-safe via pendingMu; does not contend with the connection mutex.
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()

	mc.pending = append(mc.pending, changes...)
}
// drainPending atomically removes and returns all pending changes.
// Returns nil if there are no pending changes.
func (mc *multiChannelNodeConn) drainPending() []change.Change {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()

	drained := mc.pending
	mc.pending = nil

	return drained
}
// send broadcasts data to all active connections for the node.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
if data == nil {

View File

@@ -2302,8 +2302,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("=== RAPID RECONNECTION TEST ===")
t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes))
// Phase 1: Connect all nodes initially
t.Logf("Phase 1: Connecting all nodes...")
// Connect all nodes initially.
t.Logf("Connecting all nodes...")
for i := range allNodes {
node := &allNodes[i]
@@ -2321,8 +2321,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
}
}, 5*time.Second, 50*time.Millisecond, "waiting for connections to settle")
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
t.Logf("Phase 2: Rapid disconnect all nodes...")
// Rapid disconnect ALL nodes (simulating nodes going down).
t.Logf("Rapid disconnect all nodes...")
for i := range allNodes {
node := &allNodes[i]
@@ -2330,8 +2330,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("Node %d RemoveNode result: %t", i, removed)
}
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
t.Logf("Phase 3: Rapid reconnect with new channels...")
// Rapid reconnect with NEW channels (simulating nodes coming back up).
t.Logf("Rapid reconnect with new channels...")
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
for i := range allNodes {
@@ -2351,8 +2351,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
}
}, 5*time.Second, 50*time.Millisecond, "waiting for reconnections to settle")
// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR
t.Logf("Phase 4: Checking debug status...")
// Check debug status after reconnection.
t.Logf("Checking debug status...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
@@ -2396,8 +2396,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("Batcher does not implement Debug() method")
}
// Phase 5: Test if "disconnected" nodes can actually receive updates
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
// Test if "disconnected" nodes can actually receive updates.
t.Logf("Testing if nodes can receive updates despite debug status...")
// Send a change that should reach all nodes
batcher.AddWork(change.DERPMap())
@@ -2442,8 +2442,8 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Logf("=== MULTI-CONNECTION TEST ===")
// Phase 1: Connect first node with initial connection
t.Logf("Phase 1: Connecting node 1 with first connection...")
// Connect first node with initial connection.
t.Logf("Connecting node 1 with first connection...")
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil {
@@ -2462,8 +2462,8 @@ func TestBatcherMultiConnection(t *testing.T) {
assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected")
}, time.Second, 10*time.Millisecond, "waiting for initial connections")
// Phase 2: Add second connection for node1 (multi-connection scenario)
t.Logf("Phase 2: Adding second connection for node 1...")
// Add second connection for node1 (multi-connection scenario).
t.Logf("Adding second connection for node 1...")
secondChannel := make(chan *tailcfg.MapResponse, 10)
@@ -2475,8 +2475,8 @@ func TestBatcherMultiConnection(t *testing.T) {
// Yield to allow connection to be processed
runtime.Gosched()
// Phase 3: Add third connection for node1
t.Logf("Phase 3: Adding third connection for node 1...")
// Add third connection for node1.
t.Logf("Adding third connection for node 1...")
thirdChannel := make(chan *tailcfg.MapResponse, 10)
@@ -2488,8 +2488,8 @@ func TestBatcherMultiConnection(t *testing.T) {
// Yield to allow connection to be processed
runtime.Gosched()
// Phase 4: Verify debug status shows correct connection count
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
// Verify debug status shows correct connection count.
t.Logf("Verifying debug status shows multiple connections...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
@@ -2525,8 +2525,8 @@ func TestBatcherMultiConnection(t *testing.T) {
}
}
// Phase 5: Send update and verify ALL connections receive it
t.Logf("Phase 5: Testing update distribution to all connections...")
// Send update and verify ALL connections receive it.
t.Logf("Testing update distribution to all connections...")
// Clear any existing updates from all channels
clearChannel := func(ch chan *tailcfg.MapResponse) {
@@ -2591,8 +2591,8 @@ func TestBatcherMultiConnection(t *testing.T) {
connection1Received, connection2Received, connection3Received)
}
// Phase 6: Test connection removal and verify remaining connections still work
t.Logf("Phase 6: Testing connection removal...")
// Test connection removal and verify remaining connections still work.
t.Logf("Testing connection removal...")
// Remove the second connection
removed := batcher.RemoveNode(node1.n.ID, secondChannel)

6
hscontrol/util/norace.go Normal file
View File

@@ -0,0 +1,6 @@
//go:build !race

package util

// RaceEnabled is true when the race detector is active.
// This file is compiled only when the binary is built WITHOUT the
// -race flag; see race.go for the counterpart built with it.
const RaceEnabled = false

6
hscontrol/util/race.go Normal file
View File

@@ -0,0 +1,6 @@
//go:build race

package util

// RaceEnabled is true when the race detector is active.
// This file is compiled only when the binary is built WITH the
// -race flag; see norace.go for the counterpart built without it.
const RaceEnabled = true