Files
headscale/hscontrol/mapper/node_conn.go
Kristoffer Dalby 86e279869e mapper/batcher: minor production code cleanup
L1: Replace crypto/rand with an atomic counter for generating
connection IDs. These identifiers are process-local and do not need
cryptographic randomness; a monotonic counter is cheaper and
produces shorter, sortable IDs.

L5: Use getActiveConnectionCount() in Debug() instead of directly
locking the mutex and reading the connections slice. This avoids
bypassing the accessor that already exists for this purpose.

L6: Extract the hardcoded 15*time.Minute cleanup threshold into
the named constant offlineNodeCleanupThreshold.

L7: Inline the trivial addWork wrapper; AddWork now calls addToBatch
directly.

Updates #2545
2026-03-14 02:52:28 -07:00

391 lines
12 KiB
Go

package mapper
import (
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
// connectionEntry represents a single connection to a node: the channel that
// map responses are delivered on, plus the bookkeeping used to detect and tear
// down connections that are closed or no longer being read.
type connectionEntry struct {
id string // unique connection ID; appears in log fields and error messages
c chan<- *tailcfg.MapResponse // send-only channel that send() delivers responses on
version tailcfg.CapabilityVersion // capability version; surfaced through multiChannelNodeConn.version()
created time.Time // NOTE(review): presumably the time the connection was established; not written in this part of the file, confirm at the construction site
stop func() // optional teardown callback; stopConnection calls it at most once and tolerates nil
lastUsed atomic.Int64 // Unix timestamp (seconds) of last successful send
closed atomic.Bool // set by stopConnection; once true, send() returns an error instead of delivering
}
// multiChannelNodeConn manages multiple concurrent connections for a single node.
//
// Two independent locks are used: mutex guards the connections slice, while
// pendingMu guards the pending change list. The remaining mutable fields are
// atomics or a concurrent map and are accessed without either lock.
type multiChannelNodeConn struct {
// id is the node these connections belong to; it is also attached to log.
id types.NodeID
// mapper is handed to handleNodeChange when a change is applied.
mapper *mapper
// log is a logger pre-populated with the node ID field.
log zerolog.Logger
// mutex guards connections.
mutex sync.RWMutex
// connections holds the currently registered connections in insertion order.
connections []*connectionEntry
// pendingMu protects pending changes independently of the connection mutex.
// This avoids contention between addToBatch (which appends changes) and
// send() (which sends data to connections).
pendingMu sync.Mutex
// pending accumulates changes until drainPending hands them off.
pending []change.Change
// closeOnce ensures the shutdown performed by close() runs a single time.
closeOnce sync.Once
// updateCount is incremented once per send() call that had at least one
// connection to broadcast to, whether or not any delivery succeeded.
updateCount atomic.Int64
// lastSentPeers tracks which peers were last sent to this node.
// This enables computing diffs for policy changes instead of sending
// full peer lists (which clients interpret as "no change" when empty).
// Using xsync.Map for lock-free concurrent access.
lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
}
// connIDCounter backs generateConnectionID. It only ever grows, so every value
// handed out within this process is distinct. Connection IDs are process-local
// and do not need to be cryptographically random, which is why a plain atomic
// counter is used rather than crypto/rand.
var connIDCounter atomic.Uint64

// generateConnectionID returns the next connection identifier: the decimal
// representation of the incremented counter. The first call in a process
// yields "1". It is safe for concurrent use.
func generateConnectionID() string {
	const decimal = 10

	next := connIDCounter.Add(1)

	return strconv.FormatUint(next, decimal)
}
// newMultiChannelNodeConn builds the connection set for node id. The result
// starts with no connections and an empty sent-peer map, and carries a logger
// that tags every entry with the node ID.
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
	nodeLog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
	sentPeers := xsync.NewMap[tailcfg.NodeID, struct{}]()

	conn := &multiChannelNodeConn{
		id:            id,
		mapper:        mapper,
		log:           nodeLog,
		lastSentPeers: sentPeers,
	}

	return conn
}
// close stops every connection currently registered for the node. The work is
// wrapped in closeOnce, so only the first call has any effect. The connections
// slice itself is left untouched; entries are only marked closed and stopped.
func (mc *multiChannelNodeConn) close() {
	stopAll := func() {
		mc.mutex.Lock()
		defer mc.mutex.Unlock()

		for i := range mc.connections {
			mc.stopConnection(mc.connections[i])
		}
	}

	mc.closeOnce.Do(stopAll)
}
// stopConnection flips the connection's closed flag and, for whichever caller
// wins that flip, invokes the connection's stop callback if one is set.
// Racing cleanup paths therefore tear the owning session down at most once.
func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) {
	// Only the caller that transitions closed from false to true proceeds.
	if !conn.closed.CompareAndSwap(false, true) {
		return
	}

	if stop := conn.stop; stop != nil {
		stop()
	}
}
// removeConnectionAtIndexLocked deletes the entry at index i from
// mc.connections, keeping the remaining entries in order, and returns the
// removed entry. When stopConnection is true the removed entry is also stopped.
//
// Caller must hold mc.mutex. An out-of-range i panics.
func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry {
	removed := mc.connections[i]
	last := len(mc.connections) - 1

	// Shift the tail left by one, then clear the vacated final slot so the
	// backing array no longer references the removed entry.
	copy(mc.connections[i:last], mc.connections[i+1:])
	mc.connections[last] = nil
	mc.connections = mc.connections[:last]

	if stopConnection {
		mc.stopConnection(removed)
	}

	return removed
}
// addConnection registers entry as an additional connection for the node and
// logs the resulting connection count at debug level.
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
	mc.mutex.Lock()
	defer mc.mutex.Unlock()

	mc.connections = append(mc.connections, entry)
	total := len(mc.connections)

	mc.log.Debug().
		Str(zf.ConnID, entry.id).
		Int("total_connections", total).
		Msg("connection added")
}
// removeConnectionByChannel unregisters the first connection whose channel is
// c. The entry is removed without being stopped. It reports whether a matching
// connection was found.
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
	mc.mutex.Lock()
	defer mc.mutex.Unlock()

	// Locate the first entry that owns the given channel.
	idx := -1

	for i := range mc.connections {
		if mc.connections[i].c == c {
			idx = i

			break
		}
	}

	if idx < 0 {
		return false
	}

	removed := mc.removeConnectionAtIndexLocked(idx, false)
	mc.log.Debug().
		Str(zf.ConnID, removed.id).
		Int("remaining_connections", len(mc.connections)).
		Msg("connection removed")

	return true
}
// hasActiveConnections reports whether at least one connection is currently
// registered for the node.
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
	mc.mutex.RLock()
	active := len(mc.connections) != 0
	mc.mutex.RUnlock()

	return active
}
// getActiveConnectionCount returns how many connections are currently
// registered for the node.
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
	mc.mutex.RLock()
	count := len(mc.connections)
	mc.mutex.RUnlock()

	return count
}
// appendPending queues changes at the end of this node's pending change list.
// It synchronizes on pendingMu only, so it never contends with the connection
// mutex.
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()

	mc.pending = append(mc.pending, changes...)
}
// drainPending takes ownership of all queued changes and resets the pending
// list in one step under pendingMu. The result is nil when nothing was queued.
func (mc *multiChannelNodeConn) drainPending() []change.Change {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()

	drained := mc.pending
	mc.pending = nil

	return drained
}
// send delivers data to every connection registered for the node.
//
// A nil data, or a node without connections, is a no-op that returns nil.
// Otherwise the call proceeds in three phases so that slow deliveries never
// happen while the mutex is held (each stale connection can block for up to
// 50ms):
//
//  1. copy the connection list under the read lock;
//  2. send to each copied connection with no lock held;
//  3. if any delivery failed, take the write lock to drop and stop exactly
//     those failed entries.
//
// Connections added between phases 1 and 3 are left alone. Skipping this
// update for them is safe because they receive a full initial map via AddNode.
//
// It returns nil when at least one delivery succeeded, and otherwise an error
// wrapping the last delivery error.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
	if data == nil {
		return nil
	}

	// Phase 1: copy the connection list so the read lock can be released
	// before any channel operation.
	mc.mutex.RLock()
	targets := append([]*connectionEntry(nil), mc.connections...)
	mc.mutex.RUnlock()

	if len(targets) == 0 {
		mc.log.Trace().
			Msg("send: no active connections, skipping")

		return nil
	}

	mc.log.Trace().
		Int("total_connections", len(targets)).
		Msg("send: broadcasting")

	// Phase 2: deliver without holding any lock, so per-connection stale
	// timeouts do not block other goroutines that need the mutex.
	var (
		lastErr   error
		delivered int
		broken    []*connectionEntry
	)

	for _, conn := range targets {
		if err := conn.send(data); err != nil {
			lastErr = err
			broken = append(broken, conn)
			mc.log.Warn().Err(err).
				Str(zf.ConnID, conn.id).
				Msg("send: connection failed")

			continue
		}

		delivered++
	}

	// Phase 3: take the write lock only when there is something to remove.
	if len(broken) > 0 {
		mc.mutex.Lock()

		// Match by pointer identity so that only entries which failed above
		// and are still registered get dropped; anything added since the
		// copy was taken is kept.
		doomed := make(map[*connectionEntry]struct{}, len(broken))
		for _, conn := range broken {
			doomed[conn] = struct{}{}
		}

		// Compact the survivors towards the front of the existing slice.
		kept := 0

		for _, conn := range mc.connections {
			if _, drop := doomed[conn]; drop {
				mc.log.Debug().
					Str(zf.ConnID, conn.id).
					Msg("send: removing failed connection")
				// Tear down the owning session so the old serveLongPoll
				// goroutine exits instead of lingering as a stale session.
				mc.stopConnection(conn)

				continue
			}

			mc.connections[kept] = conn
			kept++
		}

		// Clear the now-unused tail so the backing array does not keep the
		// removed entries reachable.
		for i := kept; i < len(mc.connections); i++ {
			mc.connections[i] = nil
		}

		mc.connections = mc.connections[:kept]
		mc.mutex.Unlock()
	}

	mc.updateCount.Add(1)
	mc.log.Trace().
		Int("successful_sends", delivered).
		Int("failed_connections", len(broken)).
		Msg("send: broadcast complete")

	// One successful delivery is enough for the broadcast to count as sent.
	if delivered == 0 {
		return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
	}

	return nil
}
// send delivers data on this connection's channel, giving up after a short
// timeout so that a connection nobody is reading from is reported as stale.
//
// A nil data is a no-op. An entry already marked closed fails immediately with
// an error wrapping errConnectionClosed; a delivery that does not complete
// within the timeout fails with an error wrapping ErrConnectionSendTimeout.
// On success the entry's lastUsed timestamp is refreshed.
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
	if data == nil {
		return nil
	}

	// Refuse to touch the channel once the connection is closed. This can
	// happen during shutdown when Close() is called while workers are still
	// processing, and guards against a send-on-closed-channel panic.
	if entry.closed.Load() {
		return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
	}

	// How long a reader gets to accept the response. This is what detects
	// clients that are gone while their channel still appears open, such as
	// Docker containers that were forcefully terminated.
	const staleAfter = 50 * time.Millisecond

	// An explicit timer with a deferred Stop is released as soon as this
	// function returns. time.After would instead leave its timer pending
	// until it fires even when the send succeeds immediately, which adds up
	// on the hot path (1000+ nodes per tick).
	deadline := time.NewTimer(staleAfter)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		// Nobody picked the response up in time: treat the connection as stale.
		return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
	case entry.c <- data:
		// Record when this connection last accepted a response.
		entry.lastUsed.Store(time.Now().Unix())

		return nil
	}
}
// nodeID returns the ID of the node this connection set belongs to. The value
// is fixed at construction, so no locking is needed to read it.
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
return mc.id
}
// version reports the capability version of the first registered connection,
// or 0 when the node has no connections. All connections for a node should
// have the same version in practice, so the first one is representative.
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
	mc.mutex.RLock()
	defer mc.mutex.RUnlock()

	var capVer tailcfg.CapabilityVersion
	if len(mc.connections) > 0 {
		capVer = mc.connections[0].version
	}

	return capVer
}
// updateSentPeers folds a sent MapResponse into the record of peers the client
// knows about. Call it after the response has been sent successfully so that
// later diffs are computed against what the client actually received.
//
// A non-nil Peers list replaces the record wholesale; PeersChanged entries are
// then added and PeersRemoved entries deleted. A nil resp is ignored.
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
	if resp == nil {
		return
	}

	sent := mc.lastSentPeers
	remember := func(id tailcfg.NodeID) {
		sent.Store(id, struct{}{})
	}

	// A full peer list supersedes everything recorded so far. Note that an
	// empty but non-nil list still resets the record.
	if resp.Peers != nil {
		sent.Clear()

		for i := range resp.Peers {
			remember(resp.Peers[i].ID)
		}
	}

	// Apply incremental additions, then incremental removals.
	for i := range resp.PeersChanged {
		remember(resp.PeersChanged[i].ID)
	}

	for i := range resp.PeersRemoved {
		sent.Delete(resp.PeersRemoved[i])
	}
}
// computePeerDiff returns the IDs recorded in lastSentPeers that do not appear
// in currentPeers, i.e. the peers that have been removed since the last send.
// The result is nil when nothing was removed; its order follows the map's
// iteration order.
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
	// Index the current peers for constant-time membership checks.
	present := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
	for i := range currentPeers {
		present[currentPeers[i]] = struct{}{}
	}

	var gone []tailcfg.NodeID

	mc.lastSentPeers.Range(func(peerID tailcfg.NodeID, _ struct{}) bool {
		if _, ok := present[peerID]; ok {
			return true
		}

		// Previously sent but no longer current: report it as removed.
		gone = append(gone, peerID)

		return true
	})

	return gone
}
// change applies a change to all active connections for the node by delegating
// to handleNodeChange with this connection set and its mapper. The returned
// error is whatever handleNodeChange reports.
func (mc *multiChannelNodeConn) change(r change.Change) error {
return handleNodeChange(mc, mc.mapper, r)
}