diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 1d9dc924..d16949e0 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -185,6 +185,10 @@ var ( // Batcher batches and distributes map responses to connected nodes. // It uses concurrent maps, per-node mutexes, and a worker pool. +// +// Lifecycle: Call Start() to spawn workers, then Close() to shut down. +// Close() blocks until all workers have exited. A Batcher must not +// be reused after Close(). type Batcher struct { tick *time.Ticker mapper *mapper @@ -198,6 +202,10 @@ type Batcher struct { done chan struct{} doneOnce sync.Once // Ensures done is only closed once + // wg tracks the doWork and all worker goroutines so that Close() + // can block until they have fully exited. + wg sync.WaitGroup + started atomic.Bool // Ensures Start() is only called once // Metrics @@ -329,6 +337,8 @@ func (b *Batcher) Start() { b.done = make(chan struct{}) + b.wg.Add(1) + go b.doWork() } @@ -344,6 +354,14 @@ func (b *Batcher) Close() { } }) + // Wait for all worker goroutines (and doWork) to exit before + // tearing down node connections. This prevents workers from + // sending on connections that are being closed concurrently. + b.wg.Wait() + + // Stop the ticker to prevent resource leaks. + b.tick.Stop() + // Close the underlying channels supplying the data to the clients. b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool { if conn == nil { @@ -357,7 +375,11 @@ func (b *Batcher) Close() { } func (b *Batcher) doWork() { + defer b.wg.Done() + for i := range b.workers { + b.wg.Add(1) + go b.worker(i + 1) } @@ -381,6 +403,8 @@ func (b *Batcher) doWork() { } func (b *Batcher) worker(workerID int) { + defer b.wg.Done() + wlog := log.With().Int(zf.WorkerID, workerID).Logger() for {