mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-17 11:43:36 +09:00
L1: Replace crypto/rand with an atomic counter for generating connection IDs. These identifiers are process-local and do not need cryptographic randomness; a monotonic counter is cheaper and produces shorter, sortable IDs. L5: Use getActiveConnectionCount() in Debug() instead of directly locking the mutex and reading the connections slice. This avoids bypassing the accessor that already exists for this purpose. L6: Extract the hardcoded 15*time.Minute cleanup threshold into the named constant offlineNodeCleanupThreshold. L7: Inline the trivial addWork wrapper; AddWork now calls addToBatch directly. Updates #2545
391 lines
12 KiB
Go
391 lines
12 KiB
Go
package mapper
|
|
|
|
import (
|
|
"fmt"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
|
"github.com/puzpuzpuz/xsync/v4"
|
|
"github.com/rs/zerolog"
|
|
"github.com/rs/zerolog/log"
|
|
"tailscale.com/tailcfg"
|
|
)
|
|
|
|
// connectionEntry represents a single connection to a node.
|
|
type connectionEntry struct {
|
|
id string // unique connection ID
|
|
c chan<- *tailcfg.MapResponse
|
|
version tailcfg.CapabilityVersion
|
|
created time.Time
|
|
stop func()
|
|
lastUsed atomic.Int64 // Unix timestamp of last successful send
|
|
closed atomic.Bool // Indicates if this connection has been closed
|
|
}
|
|
|
|
// multiChannelNodeConn manages multiple concurrent connections for a single node.
|
|
type multiChannelNodeConn struct {
|
|
id types.NodeID
|
|
mapper *mapper
|
|
log zerolog.Logger
|
|
|
|
mutex sync.RWMutex
|
|
connections []*connectionEntry
|
|
|
|
// pendingMu protects pending changes independently of the connection mutex.
|
|
// This avoids contention between addToBatch (which appends changes) and
|
|
// send() (which sends data to connections).
|
|
pendingMu sync.Mutex
|
|
pending []change.Change
|
|
|
|
closeOnce sync.Once
|
|
updateCount atomic.Int64
|
|
|
|
// lastSentPeers tracks which peers were last sent to this node.
|
|
// This enables computing diffs for policy changes instead of sending
|
|
// full peer lists (which clients interpret as "no change" when empty).
|
|
// Using xsync.Map for lock-free concurrent access.
|
|
lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
|
|
}
|
|
|
|
// connIDCounter is a monotonically increasing counter used to generate
|
|
// unique connection identifiers without the overhead of crypto/rand.
|
|
// Connection IDs are process-local and need not be cryptographically random.
|
|
var connIDCounter atomic.Uint64
|
|
|
|
// generateConnectionID generates a unique connection identifier.
|
|
func generateConnectionID() string {
|
|
return strconv.FormatUint(connIDCounter.Add(1), 10)
|
|
}
|
|
|
|
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
|
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
|
|
return &multiChannelNodeConn{
|
|
id: id,
|
|
mapper: mapper,
|
|
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
|
|
log: log.With().Uint64(zf.NodeID, id.Uint64()).Logger(),
|
|
}
|
|
}
|
|
|
|
func (mc *multiChannelNodeConn) close() {
|
|
mc.closeOnce.Do(func() {
|
|
mc.mutex.Lock()
|
|
defer mc.mutex.Unlock()
|
|
|
|
for _, conn := range mc.connections {
|
|
mc.stopConnection(conn)
|
|
}
|
|
})
|
|
}
|
|
|
|
// stopConnection marks a connection as closed and tears down the owning session
|
|
// at most once, even if multiple cleanup paths race to remove it.
|
|
func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) {
|
|
if conn.closed.CompareAndSwap(false, true) {
|
|
if conn.stop != nil {
|
|
conn.stop()
|
|
}
|
|
}
|
|
}
|
|
|
|
// removeConnectionAtIndexLocked removes the active connection at index.
|
|
// If stopConnection is true, it also stops that session.
|
|
// Caller must hold mc.mutex.
|
|
func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry {
|
|
conn := mc.connections[i]
|
|
copy(mc.connections[i:], mc.connections[i+1:])
|
|
mc.connections[len(mc.connections)-1] = nil // release pointer for GC
|
|
mc.connections = mc.connections[:len(mc.connections)-1]
|
|
|
|
if stopConnection {
|
|
mc.stopConnection(conn)
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
// addConnection adds a new connection.
|
|
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
|
mc.mutex.Lock()
|
|
defer mc.mutex.Unlock()
|
|
|
|
mc.connections = append(mc.connections, entry)
|
|
mc.log.Debug().Str(zf.ConnID, entry.id).
|
|
Int("total_connections", len(mc.connections)).
|
|
Msg("connection added")
|
|
}
|
|
|
|
// removeConnectionByChannel removes a connection by matching channel pointer.
|
|
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
|
|
mc.mutex.Lock()
|
|
defer mc.mutex.Unlock()
|
|
|
|
for i, entry := range mc.connections {
|
|
if entry.c == c {
|
|
mc.removeConnectionAtIndexLocked(i, false)
|
|
mc.log.Debug().Str(zf.ConnID, entry.id).
|
|
Int("remaining_connections", len(mc.connections)).
|
|
Msg("connection removed")
|
|
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// hasActiveConnections checks if the node has any active connections.
|
|
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
|
|
mc.mutex.RLock()
|
|
defer mc.mutex.RUnlock()
|
|
|
|
return len(mc.connections) > 0
|
|
}
|
|
|
|
// getActiveConnectionCount returns the number of active connections.
|
|
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
|
|
mc.mutex.RLock()
|
|
defer mc.mutex.RUnlock()
|
|
|
|
return len(mc.connections)
|
|
}
|
|
|
|
// appendPending appends changes to this node's pending change list.
|
|
// Thread-safe via pendingMu; does not contend with the connection mutex.
|
|
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
|
|
mc.pendingMu.Lock()
|
|
mc.pending = append(mc.pending, changes...)
|
|
mc.pendingMu.Unlock()
|
|
}
|
|
|
|
// drainPending atomically removes and returns all pending changes.
|
|
// Returns nil if there are no pending changes.
|
|
func (mc *multiChannelNodeConn) drainPending() []change.Change {
|
|
mc.pendingMu.Lock()
|
|
p := mc.pending
|
|
mc.pending = nil
|
|
mc.pendingMu.Unlock()
|
|
|
|
return p
|
|
}
|
|
|
|
// send delivers data to every connection registered for the node.
//
// Individual sends can be slow: a connection whose reader has stalled takes
// up to 50ms to time out. To keep that out of the critical section, send
// copies the connection list under the read lock and performs all sends with
// no lock held. It takes the write lock afterwards only to drop connections
// that failed. Connections registered after the copy are left untouched.
// They receive a full initial map via AddNode, so skipping this update loses
// nothing.
//
// A nil data is a no-op. send returns nil if at least one connection accepted
// data, or if there were no connections at all. It returns an error wrapping
// the last failure when every send failed.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
	if data == nil {
		return nil
	}

	// Take a private copy of the connection list so the lock can be released
	// before any (potentially blocking) channel send.
	mc.mutex.RLock()
	if len(mc.connections) == 0 {
		mc.mutex.RUnlock()
		mc.log.Trace().
			Msg("send: no active connections, skipping")

		return nil
	}
	targets := append([]*connectionEntry(nil), mc.connections...)
	mc.mutex.RUnlock()

	mc.log.Trace().
		Int("total_connections", len(targets)).
		Msg("send: broadcasting")

	var (
		delivered int
		lastErr   error
		failures  []*connectionEntry
	)

	// No lock is held here, so stale-connection timeouts do not block other
	// goroutines waiting on mc.mutex.
	for _, conn := range targets {
		if err := conn.send(data); err != nil {
			lastErr = err
			failures = append(failures, conn)
			mc.log.Warn().Err(err).
				Str(zf.ConnID, conn.id).
				Msg("send: connection failed")

			continue
		}

		delivered++
	}

	if len(failures) > 0 {
		// Match by pointer identity: only entries that are still registered
		// and that failed above are removed. Connections added since the copy
		// are not in the set and therefore survive.
		dead := make(map[*connectionEntry]struct{}, len(failures))
		for _, f := range failures {
			dead[f] = struct{}{}
		}

		mc.mutex.Lock()

		kept := mc.connections[:0]
		for _, conn := range mc.connections {
			if _, isDead := dead[conn]; isDead {
				mc.log.Debug().
					Str(zf.ConnID, conn.id).
					Msg("send: removing failed connection")
				// Stop the owning session so its serveLongPoll goroutine
				// exits instead of lingering as a stale session.
				mc.stopConnection(conn)

				continue
			}

			kept = append(kept, conn)
		}

		// Zero the slots past the kept entries so the backing array no longer
		// references the removed connections.
		tail := mc.connections[len(kept):]
		for i := range tail {
			tail[i] = nil
		}

		mc.connections = kept
		mc.mutex.Unlock()
	}

	mc.updateCount.Add(1)

	mc.log.Trace().
		Int("successful_sends", delivered).
		Int("failed_connections", len(failures)).
		Msg("send: broadcast complete")

	if delivered == 0 {
		return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
	}

	return nil
}
|
|
|
|
// send hands data to this single connection, giving up after 50ms.
//
// A nil data is a no-op. If the entry has already been marked closed, send
// fails fast with errConnectionClosed rather than touching the channel. This
// guards against a send-on-closed-channel panic when Close() runs while
// workers are still processing. Otherwise it waits up to 50ms for the
// receiver to accept data. A receiver that does not read in time (for
// example, a killed container whose channel still appears open) yields
// ErrConnectionSendTimeout. On success the entry's lastUsed timestamp is
// refreshed.
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
	switch {
	case data == nil:
		return nil
	case entry.closed.Load():
		return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
	}

	// Use an explicit timer that is stopped on return, so it is released as
	// soon as the send finishes instead of lingering until the deadline.
	// This runs for every connection on every broadcast.
	deadline := time.NewTimer(50 * time.Millisecond) //nolint:mnd
	defer deadline.Stop()

	select {
	case <-deadline.C:
		// The client is not draining its channel; treat it as stale.
		return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
	case entry.c <- data:
		entry.lastUsed.Store(time.Now().Unix())

		return nil
	}
}
|
|
|
|
// nodeID returns the node ID.
|
|
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
|
|
return mc.id
|
|
}
|
|
|
|
// version returns the capability version from the first active connection.
|
|
// All connections for a node should have the same version in practice.
|
|
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
|
|
mc.mutex.RLock()
|
|
defer mc.mutex.RUnlock()
|
|
|
|
if len(mc.connections) == 0 {
|
|
return 0
|
|
}
|
|
|
|
return mc.connections[0].version
|
|
}
|
|
|
|
// updateSentPeers updates the tracked peer state based on a sent MapResponse.
|
|
// This must be called after successfully sending a response to keep track of
|
|
// what the client knows about, enabling accurate diffs for future updates.
|
|
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
|
|
if resp == nil {
|
|
return
|
|
}
|
|
|
|
// Full peer list replaces tracked state entirely
|
|
if resp.Peers != nil {
|
|
mc.lastSentPeers.Clear()
|
|
|
|
for _, peer := range resp.Peers {
|
|
mc.lastSentPeers.Store(peer.ID, struct{}{})
|
|
}
|
|
}
|
|
|
|
// Incremental additions
|
|
for _, peer := range resp.PeersChanged {
|
|
mc.lastSentPeers.Store(peer.ID, struct{}{})
|
|
}
|
|
|
|
// Incremental removals
|
|
for _, id := range resp.PeersRemoved {
|
|
mc.lastSentPeers.Delete(id)
|
|
}
|
|
}
|
|
|
|
// computePeerDiff compares the current peer list against what was last sent
|
|
// and returns the peers that were removed (in lastSentPeers but not in current).
|
|
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
|
|
currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
|
|
for _, id := range currentPeers {
|
|
currentSet[id] = struct{}{}
|
|
}
|
|
|
|
var removed []tailcfg.NodeID
|
|
|
|
// Find removed: in lastSentPeers but not in current
|
|
mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool {
|
|
if _, exists := currentSet[id]; !exists {
|
|
removed = append(removed, id)
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
return removed
|
|
}
|
|
|
|
// change applies a change to all active connections for the node.
|
|
func (mc *multiChannelNodeConn) change(r change.Change) error {
|
|
return handleNodeChange(mc, mc.mapper, r)
|
|
}
|