mapper/batcher: track worker goroutines and stop ticker on Close

Close() previously closed the done channel and returned immediately,
without waiting for worker goroutines to exit. This caused goroutine
leaks in tests and allowed workers to race with connection teardown.
The ticker was also never stopped, so its underlying runtime timer
kept firing after shutdown.

Add a sync.WaitGroup to track the doWork goroutine and every worker
it spawns. Close() now calls wg.Wait() after signalling shutdown,
ensuring all goroutines have exited before tearing down connections.
Also stop the ticker to prevent resource leaks.

Document that a Batcher must not be reused after Close().
This commit is contained in:
Kristoffer Dalby
2026-03-13 15:32:15 +00:00
parent 3276bda0c0
commit 051a38a4c4

View File

@@ -185,6 +185,10 @@ var (
// Batcher batches and distributes map responses to connected nodes.
// It uses concurrent maps, per-node mutexes, and a worker pool.
//
// Lifecycle: Call Start() to spawn workers, then Close() to shut down.
// Close() blocks until all workers have exited. A Batcher must not
// be reused after Close().
type Batcher struct {
tick *time.Ticker
mapper *mapper
@@ -198,6 +202,10 @@ type Batcher struct {
done chan struct{}
doneOnce sync.Once // Ensures done is only closed once
// wg tracks the doWork goroutine and every worker goroutine it
// spawns, so that Close() can block until all of them have exited.
wg sync.WaitGroup
started atomic.Bool // Ensures Start() is only called once
// Metrics
@@ -329,6 +337,8 @@ func (b *Batcher) Start() {
b.done = make(chan struct{})
b.wg.Add(1)
go b.doWork()
}
@@ -344,6 +354,14 @@ func (b *Batcher) Close() {
}
})
// Wait for all worker goroutines (and doWork) to exit before
// tearing down node connections. This prevents workers from
// sending on connections that are being closed concurrently.
b.wg.Wait()
// Stop the ticker to prevent resource leaks.
b.tick.Stop()
// Close the underlying channels supplying the data to the clients.
b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
if conn == nil {
@@ -357,7 +375,11 @@ func (b *Batcher) Close() {
}
func (b *Batcher) doWork() {
defer b.wg.Done()
for i := range b.workers {
b.wg.Add(1)
go b.worker(i + 1)
}
@@ -381,6 +403,8 @@ func (b *Batcher) doWork() {
}
func (b *Batcher) worker(workerID int) {
defer b.wg.Done()
wlog := log.With().Int(zf.WorkerID, workerID).Logger()
for {