mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-17 03:33:36 +09:00
mapper/batcher: restructure internals for correctness
Move per-node pending changes from a shared xsync.Map on the batcher into multiChannelNodeConn, protected by a dedicated mutex. The new appendPending/drainPending methods provide atomic append and drain operations, eliminating data races in addToBatch and processBatchedChanges. Add sync.Once to multiChannelNodeConn.close() to make it idempotent, preventing panics from concurrent close calls on the same channel. Add started atomic.Bool to guard Start() against being called multiple times, preventing orphaned goroutines. Add comprehensive concurrency tests validating these changes.
This commit is contained in:
@@ -52,10 +52,9 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
|
||||
tick: time.NewTicker(batchTime),
|
||||
|
||||
// The size of this channel is arbitrarily chosen; the sizing should be revisited.
|
||||
workCh: make(chan work, workers*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
|
||||
workCh: make(chan work, workers*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -150,13 +150,12 @@ func BenchmarkUpdateSentPeers(b *testing.B) {
|
||||
// helper, it doesn't register cleanup and suppresses logging.
|
||||
func benchBatcher(nodeCount, bufferSize int) (*LockFreeBatcher, map[types.NodeID]chan *tailcfg.MapResponse) {
|
||||
b := &LockFreeBatcher{
|
||||
tick: time.NewTicker(1 * time.Hour), // never fires during bench
|
||||
workers: 4,
|
||||
workCh: make(chan work, 4*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
|
||||
done: make(chan struct{}),
|
||||
tick: time.NewTicker(1 * time.Hour), // never fires during bench
|
||||
workers: 4,
|
||||
workCh: make(chan work, 4*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount)
|
||||
@@ -204,8 +203,8 @@ func BenchmarkAddToBatch_Broadcast(b *testing.B) {
|
||||
for range b.N {
|
||||
batcher.addToBatch(ch)
|
||||
// Clear pending to avoid unbounded growth
|
||||
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
|
||||
batcher.pendingChanges.Delete(id)
|
||||
batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
nc.drainPending()
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -242,8 +241,8 @@ func BenchmarkAddToBatch_Targeted(b *testing.B) {
|
||||
batcher.addToBatch(ch)
|
||||
// Clear pending periodically to avoid growth
|
||||
if i%100 == 99 {
|
||||
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
|
||||
batcher.pendingChanges.Delete(id)
|
||||
batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
nc.drainPending()
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -298,7 +297,9 @@ func BenchmarkProcessBatchedChanges(b *testing.B) {
|
||||
b.StopTimer()
|
||||
// Seed pending changes
|
||||
for i := 1; i <= nodeCount; i++ {
|
||||
batcher.pendingChanges.Store(types.NodeID(i), []change.Change{change.DERPMap()}) //nolint:gosec // benchmark
|
||||
if nc, ok := batcher.nodes.Load(types.NodeID(i)); ok { //nolint:gosec // benchmark
|
||||
nc.appendPending(change.DERPMap())
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
@@ -411,8 +412,8 @@ func BenchmarkConcurrentAddToBatch(b *testing.B) {
|
||||
case <-batcher.done:
|
||||
return
|
||||
default:
|
||||
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
|
||||
batcher.pendingChanges.Delete(id)
|
||||
batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
nc.drainPending()
|
||||
return true
|
||||
})
|
||||
time.Sleep(time.Millisecond) //nolint:forbidigo // benchmark drain loop
|
||||
@@ -646,7 +647,7 @@ func BenchmarkAddNode(b *testing.B) {
|
||||
// Connect all nodes (measuring AddNode cost)
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
@@ -707,7 +708,7 @@ func BenchmarkFullPipeline(b *testing.B) {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to add node %d: %v", i, err)
|
||||
}
|
||||
@@ -762,7 +763,7 @@ func BenchmarkMapResponseFromChange(b *testing.B) {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to add node %d: %v", i, err)
|
||||
}
|
||||
|
||||
1731
hscontrol/mapper/batcher_concurrency_test.go
Normal file
1731
hscontrol/mapper/batcher_concurrency_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -41,8 +41,7 @@ type LockFreeBatcher struct {
|
||||
done chan struct{}
|
||||
doneOnce sync.Once // Ensures done is only closed once
|
||||
|
||||
// Batching state
|
||||
pendingChanges *xsync.Map[types.NodeID, []change.Change]
|
||||
started atomic.Bool // Ensures Start() is only called once
|
||||
|
||||
// Metrics
|
||||
totalNodes atomic.Int64
|
||||
@@ -167,7 +166,12 @@ func (b *LockFreeBatcher) AddWork(r ...change.Change) {
|
||||
}
|
||||
|
||||
// Start creates the done channel and launches doWork in a new goroutine.
// Only the first call has any effect: the started flag is flipped with
// CompareAndSwap, so later (or concurrent) calls return immediately and
// never spawn a second doWork goroutine.
func (b *LockFreeBatcher) Start() {
|
||||
// Whoever wins the false -> true swap proceeds; everyone else bails out.
if !b.started.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// NOTE(review): done is assigned here, after construction — confirm that
// nothing reads b.done before Start has been called.
b.done = make(chan struct{})
|
||||
|
||||
go b.doWork()
|
||||
}
|
||||
|
||||
@@ -336,15 +340,16 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
|
||||
}
|
||||
|
||||
b.connected.Delete(removedID)
|
||||
b.pendingChanges.Delete(removedID)
|
||||
}
|
||||
}
|
||||
|
||||
// Short circuit if any of the changes is a full update, which
|
||||
// means we can skip sending individual changes.
|
||||
if change.HasFull(changes) {
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
|
||||
b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()})
|
||||
b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
nc.pendingMu.Lock()
|
||||
nc.pending = []change.Change{change.FullUpdate()}
|
||||
nc.pendingMu.Unlock()
|
||||
|
||||
return true
|
||||
})
|
||||
@@ -356,20 +361,18 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
|
||||
|
||||
// Handle targeted changes - send only to the specific node
|
||||
for _, ch := range targeted {
|
||||
pending, _ := b.pendingChanges.LoadOrStore(ch.TargetNode, []change.Change{})
|
||||
pending = append(pending, ch)
|
||||
b.pendingChanges.Store(ch.TargetNode, pending)
|
||||
if nc, ok := b.nodes.Load(ch.TargetNode); ok {
|
||||
nc.appendPending(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle broadcast changes - send to all nodes, filtering as needed
|
||||
if len(broadcast) > 0 {
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
|
||||
b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
filtered := change.FilterForNode(nodeID, broadcast)
|
||||
|
||||
if len(filtered) > 0 {
|
||||
pending, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{})
|
||||
pending = append(pending, filtered...)
|
||||
b.pendingChanges.Store(nodeID, pending)
|
||||
nc.appendPending(filtered...)
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -379,12 +382,8 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
|
||||
|
||||
// processBatchedChanges processes all pending batched changes.
|
||||
func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
if b.pendingChanges == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Process all pending changes
|
||||
b.pendingChanges.Range(func(nodeID types.NodeID, pending []change.Change) bool {
|
||||
b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
pending := nc.drainPending()
|
||||
if len(pending) == 0 {
|
||||
return true
|
||||
}
|
||||
@@ -394,9 +393,6 @@ func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil})
|
||||
}
|
||||
|
||||
// Clear the pending changes for this node
|
||||
b.pendingChanges.Delete(nodeID)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -532,6 +528,13 @@ type multiChannelNodeConn struct {
|
||||
mutex sync.RWMutex
|
||||
connections []*connectionEntry
|
||||
|
||||
// pendingMu protects pending changes independently of the connection mutex.
|
||||
// This avoids contention between addToBatch (which appends changes) and
|
||||
// send() (which sends data to connections).
|
||||
pendingMu sync.Mutex
|
||||
pending []change.Change
|
||||
|
||||
closeOnce sync.Once
|
||||
updateCount atomic.Int64
|
||||
|
||||
// lastSentPeers tracks which peers were last sent to this node.
|
||||
@@ -560,12 +563,14 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC
|
||||
}
|
||||
|
||||
func (mc *multiChannelNodeConn) close() {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
mc.closeOnce.Do(func() {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
for _, conn := range mc.connections {
|
||||
mc.stopConnection(conn)
|
||||
}
|
||||
for _, conn := range mc.connections {
|
||||
mc.stopConnection(conn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// stopConnection marks a connection as closed and tears down the owning session
|
||||
@@ -647,6 +652,25 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
|
||||
return len(mc.connections)
|
||||
}
|
||||
|
||||
// appendPending appends changes to this node's pending change list.
|
||||
// Thread-safe via pendingMu; does not contend with the connection mutex.
// Calling it with no changes leaves the pending list as it was.
|
||||
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
|
||||
mc.pendingMu.Lock()
|
||||
// The append runs entirely under pendingMu, so it cannot interleave with
// drainPending or with another appendPending call.
mc.pending = append(mc.pending, changes...)
|
||||
mc.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
// drainPending atomically removes and returns all pending changes.
|
||||
// Returns nil if there are no pending changes.
// The pending field is reset to nil while pendingMu is held, so the returned
// slice is no longer referenced by mc and later appendPending calls start a
// fresh list.
|
||||
func (mc *multiChannelNodeConn) drainPending() []change.Change {
|
||||
mc.pendingMu.Lock()
|
||||
// Take the current slice and detach it from mc in one critical section.
p := mc.pending
|
||||
mc.pending = nil
|
||||
mc.pendingMu.Unlock()
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// send broadcasts data to all active connections for the node.
|
||||
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
if data == nil {
|
||||
|
||||
@@ -2302,8 +2302,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
t.Logf("=== RAPID RECONNECTION TEST ===")
|
||||
t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes))
|
||||
|
||||
// Phase 1: Connect all nodes initially
|
||||
t.Logf("Phase 1: Connecting all nodes...")
|
||||
// Connect all nodes initially.
|
||||
t.Logf("Connecting all nodes...")
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
@@ -2321,8 +2321,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
}
|
||||
}, 5*time.Second, 50*time.Millisecond, "waiting for connections to settle")
|
||||
|
||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||
// Rapid disconnect ALL nodes (simulating nodes going down).
|
||||
t.Logf("Rapid disconnect all nodes...")
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
@@ -2330,8 +2330,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||
}
|
||||
|
||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||
// Rapid reconnect with NEW channels (simulating nodes coming back up).
|
||||
t.Logf("Rapid reconnect with new channels...")
|
||||
|
||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||
for i := range allNodes {
|
||||
@@ -2351,8 +2351,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
}
|
||||
}, 5*time.Second, 50*time.Millisecond, "waiting for reconnections to settle")
|
||||
|
||||
// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR
|
||||
t.Logf("Phase 4: Checking debug status...")
|
||||
// Check debug status after reconnection.
|
||||
t.Logf("Checking debug status...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
@@ -2396,8 +2396,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
t.Logf("Batcher does not implement Debug() method")
|
||||
}
|
||||
|
||||
// Phase 5: Test if "disconnected" nodes can actually receive updates
|
||||
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
|
||||
// Test if "disconnected" nodes can actually receive updates.
|
||||
t.Logf("Testing if nodes can receive updates despite debug status...")
|
||||
|
||||
// Send a change that should reach all nodes
|
||||
batcher.AddWork(change.DERPMap())
|
||||
@@ -2442,8 +2442,8 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
t.Logf("=== MULTI-CONNECTION TEST ===")
|
||||
|
||||
// Phase 1: Connect first node with initial connection
|
||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||
// Connect first node with initial connection.
|
||||
t.Logf("Connecting node 1 with first connection...")
|
||||
|
||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil)
|
||||
if err != nil {
|
||||
@@ -2462,8 +2462,8 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected")
|
||||
}, time.Second, 10*time.Millisecond, "waiting for initial connections")
|
||||
|
||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||
// Add second connection for node1 (multi-connection scenario).
|
||||
t.Logf("Adding second connection for node 1...")
|
||||
|
||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
@@ -2475,8 +2475,8 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
// Yield to allow connection to be processed
|
||||
runtime.Gosched()
|
||||
|
||||
// Phase 3: Add third connection for node1
|
||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||
// Add third connection for node1.
|
||||
t.Logf("Adding third connection for node 1...")
|
||||
|
||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
@@ -2488,8 +2488,8 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
// Yield to allow connection to be processed
|
||||
runtime.Gosched()
|
||||
|
||||
// Phase 4: Verify debug status shows correct connection count
|
||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||
// Verify debug status shows correct connection count.
|
||||
t.Logf("Verifying debug status shows multiple connections...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
@@ -2525,8 +2525,8 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Send update and verify ALL connections receive it
|
||||
t.Logf("Phase 5: Testing update distribution to all connections...")
|
||||
// Send update and verify ALL connections receive it.
|
||||
t.Logf("Testing update distribution to all connections...")
|
||||
|
||||
// Clear any existing updates from all channels
|
||||
clearChannel := func(ch chan *tailcfg.MapResponse) {
|
||||
@@ -2591,8 +2591,8 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
connection1Received, connection2Received, connection3Received)
|
||||
}
|
||||
|
||||
// Phase 6: Test connection removal and verify remaining connections still work
|
||||
t.Logf("Phase 6: Testing connection removal...")
|
||||
// Test connection removal and verify remaining connections still work.
|
||||
t.Logf("Testing connection removal...")
|
||||
|
||||
// Remove the second connection
|
||||
removed := batcher.RemoveNode(node1.n.ID, secondChannel)
|
||||
|
||||
6
hscontrol/util/norace.go
Normal file
6
hscontrol/util/norace.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build !race
|
||||
|
||||
package util
|
||||
|
||||
// RaceEnabled is true when the race detector is active.
// This file carries the "!race" build constraint, so it is the definition
// used when the race build tag is not set; the value here is false.
|
||||
const RaceEnabled = false
|
||||
6
hscontrol/util/race.go
Normal file
6
hscontrol/util/race.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build race
|
||||
|
||||
package util
|
||||
|
||||
// RaceEnabled is true when the race detector is active.
// This file carries the "race" build constraint, so it is the definition
// used when the race build tag is set; the value here is true.
|
||||
const RaceEnabled = true
|
||||
Reference in New Issue
Block a user