diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index c800aefc..fbbb2519 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -43,6 +43,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *Batcher { // The size of this channel is arbitrary chosen, the sizing should be revisited. workCh: make(chan work, workers*200), + done: make(chan struct{}), nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), connected: xsync.NewMap[types.NodeID, *time.Time](), } @@ -258,6 +259,7 @@ func (b *Batcher) AddNode( if err != nil { nlog.Error().Err(err).Msg("initial map generation failed") nodeConn.removeConnectionByChannel(c) + b.markDisconnectedIfNoConns(id, nodeConn) return fmt.Errorf("generating initial map for node %d: %w", id, err) } @@ -272,6 +274,7 @@ func (b *Batcher) AddNode( nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd Msg("initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) + b.markDisconnectedIfNoConns(id, nodeConn) return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id) } @@ -335,8 +338,6 @@ func (b *Batcher) Start() { return } - b.done = make(chan struct{}) - b.wg.Add(1) go b.doWork() @@ -349,9 +350,7 @@ func (b *Batcher) Close() { // close workCh here because processBatchedChanges or // MapResponseFromChange may still be sending on it concurrently. b.doneOnce.Do(func() { - if b.done != nil { - close(b.done) - } + close(b.done) }) // Wait for all worker goroutines (and doWork) to exit before @@ -619,8 +618,6 @@ func (b *Batcher) cleanupOfflineNodes() { cleaned := 0 for _, nodeID := range nodesToCleanup { - deleted := false - b.nodes.Compute( nodeID, func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) { @@ -628,22 +625,22 @@ func (b *Batcher) cleanupOfflineNodes() { return conn, xsync.CancelOp } - deleted = true + // Perform all bookkeeping inside the Compute callback so + // that a concurrent AddNode (which calls LoadOrStore on + // b.nodes) cannot slip in between the delete and the + // connected/counter updates. + b.connected.Delete(nodeID) + b.totalNodes.Add(-1) + + cleaned++ + + log.Info().Uint64(zf.NodeID, nodeID.Uint64()). + Dur("offline_duration", cleanupThreshold). + Msg("cleaning up node that has been offline for too long") return conn, xsync.DeleteOp }, ) - - if deleted { - log.Info().Uint64(zf.NodeID, nodeID.Uint64()). - Dur("offline_duration", cleanupThreshold). - Msg("cleaning up node that has been offline for too long") - - b.connected.Delete(nodeID) - b.totalNodes.Add(-1) - - cleaned++ - } } if cleaned > 0 { @@ -711,6 +708,17 @@ func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { return ret } +// markDisconnectedIfNoConns stores a disconnect timestamp in b.connected +// when the node has no remaining active connections. This prevents +// IsConnected from returning a stale true after all connections have been +// removed on an error path (e.g. AddNode failure). +func (b *Batcher) markDisconnectedIfNoConns(id types.NodeID, nc *multiChannelNodeConn) { + if !nc.hasActiveConnections() { + now := time.Now() + b.connected.Store(id, &now) + } +} + // MapResponseFromChange queues work to generate a map response and waits for the result. // This allows synchronous map generation using the same worker pool. func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) { @@ -808,7 +816,9 @@ func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) { // Caller must hold mc.mutex. func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry { conn := mc.connections[i] - mc.connections = append(mc.connections[:i], mc.connections[i+1:]...) + copy(mc.connections[i:], mc.connections[i+1:]) + mc.connections[len(mc.connections)-1] = nil // release pointer for GC + mc.connections = mc.connections[:len(mc.connections)-1] if stopConnection { mc.stopConnection(conn) @@ -964,6 +974,12 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { } } + // Nil out trailing slots so removed *connectionEntry values + // are not retained by the backing array. + for i := len(clean); i < len(mc.connections); i++ { + mc.connections[i] = nil + } + mc.connections = clean mc.mutex.Unlock() }