mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-17 03:33:36 +09:00
mapper/batcher: add unit tests and benchmarks
Add comprehensive unit tests for the LockFreeBatcher covering AddNode/RemoveNode lifecycle, addToBatch routing (broadcast, targeted, full update), processBatchedChanges deduplication, cleanup of offline nodes, close/shutdown behavior, IsConnected state tracking, and connected map consistency. Add benchmarks for connection entry send, multi-channel send and broadcast, peer diff computation, sentPeers updates, addToBatch at various scales (10/100/1000 nodes), processBatchedChanges, broadcast delivery, IsConnected lookups, connected map enumeration, connection churn, and concurrent send+churn scenarios. Widen setupBatcherWithTestData to accept testing.TB so benchmarks can reuse the same database-backed test setup as unit tests.
This commit is contained in:
783
hscontrol/mapper/batcher_bench_test.go
Normal file
783
hscontrol/mapper/batcher_bench_test.go
Normal file
@@ -0,0 +1,783 @@
|
||||
package mapper
|
||||
|
||||
// Benchmarks for batcher components and full pipeline.
|
||||
//
|
||||
// Organized into three tiers:
|
||||
// - Component benchmarks: individual functions (connectionEntry.send, computePeerDiff, etc.)
|
||||
// - System benchmarks: batching mechanics (addToBatch, processBatchedChanges, broadcast)
|
||||
// - Full pipeline benchmarks: end-to-end with real DB (gated behind !testing.Short())
|
||||
//
|
||||
// All benchmarks use sub-benchmarks with 10/100/1000 node counts for scaling analysis.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// Component Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
// BenchmarkConnectionEntry_Send measures the throughput of sending a single
|
||||
// MapResponse through a connectionEntry with a buffered channel.
|
||||
func BenchmarkConnectionEntry_Send(b *testing.B) {
|
||||
ch := make(chan *tailcfg.MapResponse, b.N+1)
|
||||
entry := makeConnectionEntry("bench-conn", ch)
|
||||
data := testMapResponse()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
_ = entry.send(data)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMultiChannelSend measures broadcast throughput to multiple connections.
|
||||
func BenchmarkMultiChannelSend(b *testing.B) {
|
||||
for _, connCount := range []int{1, 3, 10} {
|
||||
b.Run(fmt.Sprintf("%dconn", connCount), func(b *testing.B) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
channels := make([]chan *tailcfg.MapResponse, connCount)
|
||||
for i := range channels {
|
||||
channels[i] = make(chan *tailcfg.MapResponse, b.N+1)
|
||||
mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), channels[i]))
|
||||
}
|
||||
|
||||
data := testMapResponse()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
_ = mc.send(data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkComputePeerDiff measures the cost of computing peer diffs at scale.
|
||||
func BenchmarkComputePeerDiff(b *testing.B) {
|
||||
for _, peerCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dpeers", peerCount), func(b *testing.B) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// Populate tracked peers: 1..peerCount
|
||||
for i := 1; i <= peerCount; i++ {
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(i), struct{}{})
|
||||
}
|
||||
|
||||
// Current peers: remove ~10% (every 10th peer is missing)
|
||||
current := make([]tailcfg.NodeID, 0, peerCount)
|
||||
for i := 1; i <= peerCount; i++ {
|
||||
if i%10 != 0 {
|
||||
current = append(current, tailcfg.NodeID(i))
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
_ = mc.computePeerDiff(current)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkUpdateSentPeers measures the cost of updating peer tracking state.
|
||||
func BenchmarkUpdateSentPeers(b *testing.B) {
|
||||
for _, peerCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dpeers_full", peerCount), func(b *testing.B) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// Pre-build response with full peer list
|
||||
peerIDs := make([]tailcfg.NodeID, peerCount)
|
||||
for i := range peerIDs {
|
||||
peerIDs[i] = tailcfg.NodeID(i + 1)
|
||||
}
|
||||
|
||||
resp := testMapResponseWithPeers(peerIDs...)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
mc.updateSentPeers(resp)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("%dpeers_incremental", peerCount), func(b *testing.B) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// Pre-populate with existing peers
|
||||
for i := 1; i <= peerCount; i++ {
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(i), struct{}{})
|
||||
}
|
||||
|
||||
// Build incremental response: add 10% new peers
|
||||
addCount := peerCount / 10
|
||||
if addCount == 0 {
|
||||
addCount = 1
|
||||
}
|
||||
|
||||
resp := testMapResponse()
|
||||
|
||||
resp.PeersChanged = make([]*tailcfg.Node, addCount)
|
||||
for i := range addCount {
|
||||
resp.PeersChanged[i] = &tailcfg.Node{ID: tailcfg.NodeID(peerCount + i + 1)}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
mc.updateSentPeers(resp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// System Benchmarks (no DB, batcher mechanics only)
|
||||
// ============================================================================
|
||||
|
||||
// benchBatcher creates a lightweight batcher for benchmarks. Unlike the test
|
||||
// helper, it doesn't register cleanup and suppresses logging.
|
||||
func benchBatcher(nodeCount, bufferSize int) (*LockFreeBatcher, map[types.NodeID]chan *tailcfg.MapResponse) {
|
||||
b := &LockFreeBatcher{
|
||||
tick: time.NewTicker(1 * time.Hour), // never fires during bench
|
||||
workers: 4,
|
||||
workCh: make(chan work, 4*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount)
|
||||
for i := 1; i <= nodeCount; i++ {
|
||||
id := types.NodeID(i) //nolint:gosec // benchmark with small controlled values
|
||||
mc := newMultiChannelNodeConn(id, nil)
|
||||
ch := make(chan *tailcfg.MapResponse, bufferSize)
|
||||
entry := &connectionEntry{
|
||||
id: fmt.Sprintf("conn-%d", i),
|
||||
c: ch,
|
||||
version: tailcfg.CapabilityVersion(100),
|
||||
created: time.Now(),
|
||||
}
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
mc.addConnection(entry)
|
||||
b.nodes.Store(id, mc)
|
||||
b.connected.Store(id, nil)
|
||||
channels[id] = ch
|
||||
}
|
||||
|
||||
b.totalNodes.Store(int64(nodeCount))
|
||||
|
||||
return b, channels
|
||||
}
|
||||
|
||||
// BenchmarkAddToBatch_Broadcast measures the cost of broadcasting a change
|
||||
// to all nodes via addToBatch (no worker processing, just queuing).
|
||||
func BenchmarkAddToBatch_Broadcast(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, 10)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
ch := change.DERPMap()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
batcher.addToBatch(ch)
|
||||
// Clear pending to avoid unbounded growth
|
||||
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
|
||||
batcher.pendingChanges.Delete(id)
|
||||
return true
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAddToBatch_Targeted measures the cost of adding a targeted change
|
||||
// to a single node.
|
||||
func BenchmarkAddToBatch_Targeted(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, 10)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := range b.N {
|
||||
targetID := types.NodeID(1 + (i % nodeCount)) //nolint:gosec // benchmark
|
||||
ch := change.Change{
|
||||
Reason: "bench-targeted",
|
||||
TargetNode: targetID,
|
||||
PeerPatches: []*tailcfg.PeerChange{
|
||||
{NodeID: tailcfg.NodeID(targetID)}, //nolint:gosec // benchmark
|
||||
},
|
||||
}
|
||||
batcher.addToBatch(ch)
|
||||
// Clear pending periodically to avoid growth
|
||||
if i%100 == 99 {
|
||||
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
|
||||
batcher.pendingChanges.Delete(id)
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAddToBatch_FullUpdate measures the cost of a FullUpdate broadcast.
|
||||
func BenchmarkAddToBatch_FullUpdate(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, 10)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
batcher.addToBatch(change.FullUpdate())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkProcessBatchedChanges measures the cost of moving pending changes
|
||||
// to the work queue.
|
||||
func BenchmarkProcessBatchedChanges(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dpending", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, 10)
|
||||
// Use a very large work channel to avoid blocking
|
||||
batcher.workCh = make(chan work, nodeCount*b.N+1)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
b.StopTimer()
|
||||
// Seed pending changes
|
||||
for i := 1; i <= nodeCount; i++ {
|
||||
batcher.pendingChanges.Store(types.NodeID(i), []change.Change{change.DERPMap()}) //nolint:gosec // benchmark
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
batcher.processBatchedChanges()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkBroadcastToN measures end-to-end broadcast: addToBatch + processBatchedChanges
|
||||
// to N nodes. Does NOT include worker processing (MapResponse generation).
|
||||
func BenchmarkBroadcastToN(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, 10)
|
||||
batcher.workCh = make(chan work, nodeCount*b.N+1)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
ch := change.DERPMap()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
batcher.addToBatch(ch)
|
||||
batcher.processBatchedChanges()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMultiChannelBroadcast measures the cost of sending a MapResponse
|
||||
// to N nodes each with varying connection counts.
|
||||
func BenchmarkMultiChannelBroadcast(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, b.N+1)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
// Add extra connections to every 3rd node
|
||||
for i := 1; i <= nodeCount; i++ {
|
||||
if i%3 == 0 {
|
||||
if mc, ok := batcher.nodes.Load(types.NodeID(i)); ok { //nolint:gosec // benchmark
|
||||
for j := range 2 {
|
||||
ch := make(chan *tailcfg.MapResponse, b.N+1)
|
||||
entry := &connectionEntry{
|
||||
id: fmt.Sprintf("extra-%d-%d", i, j),
|
||||
c: ch,
|
||||
version: tailcfg.CapabilityVersion(100),
|
||||
created: time.Now(),
|
||||
}
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
mc.addConnection(entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data := testMapResponse()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
batcher.nodes.Range(func(_ types.NodeID, mc *multiChannelNodeConn) bool {
|
||||
_ = mc.send(data)
|
||||
return true
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConcurrentAddToBatch measures addToBatch throughput under
|
||||
// concurrent access from multiple goroutines.
|
||||
func BenchmarkConcurrentAddToBatch(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, 10)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
// Background goroutine to drain pending periodically
|
||||
drainDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(drainDone)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-batcher.done:
|
||||
return
|
||||
default:
|
||||
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool {
|
||||
batcher.pendingChanges.Delete(id)
|
||||
return true
|
||||
})
|
||||
time.Sleep(time.Millisecond) //nolint:forbidigo // benchmark drain loop
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ch := change.DERPMap()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
batcher.addToBatch(ch)
|
||||
}
|
||||
})
|
||||
b.StopTimer()
|
||||
|
||||
// Cleanup
|
||||
close(batcher.done)
|
||||
<-drainDone
|
||||
// Re-open done so the defer doesn't double-close
|
||||
batcher.done = make(chan struct{})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkIsConnected measures the read throughput of IsConnected checks.
|
||||
func BenchmarkIsConnected(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, 1)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := range b.N {
|
||||
id := types.NodeID(1 + (i % nodeCount)) //nolint:gosec // benchmark
|
||||
_ = batcher.IsConnected(id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConnectedMap measures the cost of building the full connected map.
|
||||
func BenchmarkConnectedMap(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, _ := benchBatcher(nodeCount, 1)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
// Disconnect 10% of nodes for a realistic mix
|
||||
for i := 1; i <= nodeCount; i++ {
|
||||
if i%10 == 0 {
|
||||
now := time.Now()
|
||||
batcher.connected.Store(types.NodeID(i), &now) //nolint:gosec // benchmark
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
_ = batcher.ConnectedMap()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConnectionChurn measures the cost of add/remove connection cycling
|
||||
// which simulates client reconnection patterns.
|
||||
func BenchmarkConnectionChurn(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100, 1000} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, channels := benchBatcher(nodeCount, 10)
|
||||
|
||||
defer func() {
|
||||
close(batcher.done)
|
||||
batcher.tick.Stop()
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := range b.N {
|
||||
id := types.NodeID(1 + (i % nodeCount)) //nolint:gosec // benchmark
|
||||
|
||||
mc, ok := batcher.nodes.Load(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove old connection
|
||||
oldCh := channels[id]
|
||||
mc.removeConnectionByChannel(oldCh)
|
||||
|
||||
// Add new connection
|
||||
newCh := make(chan *tailcfg.MapResponse, 10)
|
||||
entry := &connectionEntry{
|
||||
id: fmt.Sprintf("churn-%d", i),
|
||||
c: newCh,
|
||||
version: tailcfg.CapabilityVersion(100),
|
||||
created: time.Now(),
|
||||
}
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
mc.addConnection(entry)
|
||||
|
||||
channels[id] = newCh
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConcurrentSendAndChurn measures the combined cost of sends happening
|
||||
// concurrently with connection churn - the hot path in production.
|
||||
func BenchmarkConcurrentSendAndChurn(b *testing.B) {
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
batcher, channels := benchBatcher(nodeCount, 100)
|
||||
|
||||
var mu sync.Mutex // protect channels map
|
||||
|
||||
stopChurn := make(chan struct{})
|
||||
defer close(stopChurn)
|
||||
|
||||
// Background churn on 10% of nodes
|
||||
go func() {
|
||||
i := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopChurn:
|
||||
return
|
||||
default:
|
||||
id := types.NodeID(1 + (i % nodeCount)) //nolint:gosec // benchmark
|
||||
if i%10 == 0 { // only churn 10%
|
||||
mc, ok := batcher.nodes.Load(id)
|
||||
if ok {
|
||||
mu.Lock()
|
||||
oldCh := channels[id]
|
||||
mu.Unlock()
|
||||
mc.removeConnectionByChannel(oldCh)
|
||||
|
||||
newCh := make(chan *tailcfg.MapResponse, 100)
|
||||
entry := &connectionEntry{
|
||||
id: fmt.Sprintf("churn-%d", i),
|
||||
c: newCh,
|
||||
version: tailcfg.CapabilityVersion(100),
|
||||
created: time.Now(),
|
||||
}
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
mc.addConnection(entry)
|
||||
mu.Lock()
|
||||
channels[id] = newCh
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
data := testMapResponse()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
batcher.nodes.Range(func(_ types.NodeID, mc *multiChannelNodeConn) bool {
|
||||
_ = mc.send(data)
|
||||
return true
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Full Pipeline Benchmarks (with DB)
|
||||
// ============================================================================
|
||||
|
||||
// BenchmarkAddNode measures the cost of adding nodes to the batcher,
|
||||
// including initial MapResponse generation from a real database.
|
||||
func BenchmarkAddNode(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("skipping full pipeline benchmark in short mode")
|
||||
}
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
testData, cleanup := setupBatcherWithTestData(b, NewBatcherAndMapper, 1, nodeCount, LARGE_BUFFER_SIZE)
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
allNodes := testData.Nodes
|
||||
|
||||
// Start consumers
|
||||
for i := range allNodes {
|
||||
allNodes[i].start()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for i := range allNodes {
|
||||
allNodes[i].cleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
// Connect all nodes (measuring AddNode cost)
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
// Disconnect for next iteration
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
// Drain channels
|
||||
for i := range allNodes {
|
||||
for {
|
||||
select {
|
||||
case <-allNodes[i].ch:
|
||||
default:
|
||||
goto drained
|
||||
}
|
||||
}
|
||||
|
||||
drained:
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFullPipeline measures the full pipeline cost: addToBatch → processBatchedChanges
|
||||
// → worker → generateMapResponse → send, with real nodes from a database.
|
||||
func BenchmarkFullPipeline(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("skipping full pipeline benchmark in short mode")
|
||||
}
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
testData, cleanup := setupBatcherWithTestData(b, NewBatcherAndMapper, 1, nodeCount, LARGE_BUFFER_SIZE)
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
allNodes := testData.Nodes
|
||||
|
||||
// Start consumers
|
||||
for i := range allNodes {
|
||||
allNodes[i].start()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for i := range allNodes {
|
||||
allNodes[i].cleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
// Connect all nodes first
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
b.Fatalf("failed to add node %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for initial maps to settle
|
||||
time.Sleep(200 * time.Millisecond) //nolint:forbidigo // benchmark coordination
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for range b.N {
|
||||
batcher.AddWork(change.DERPMap())
|
||||
// Allow workers to process (the batcher tick is what normally
|
||||
// triggers processBatchedChanges, but for benchmarks we need
|
||||
// to give the system time to process)
|
||||
time.Sleep(20 * time.Millisecond) //nolint:forbidigo // benchmark coordination
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMapResponseFromChange measures the cost of synchronous
|
||||
// MapResponse generation for individual nodes.
|
||||
func BenchmarkMapResponseFromChange(b *testing.B) {
|
||||
if testing.Short() {
|
||||
b.Skip("skipping full pipeline benchmark in short mode")
|
||||
}
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
for _, nodeCount := range []int{10, 100} {
|
||||
b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) {
|
||||
testData, cleanup := setupBatcherWithTestData(b, NewBatcherAndMapper, 1, nodeCount, LARGE_BUFFER_SIZE)
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
allNodes := testData.Nodes
|
||||
|
||||
// Start consumers
|
||||
for i := range allNodes {
|
||||
allNodes[i].start()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for i := range allNodes {
|
||||
allNodes[i].cleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
// Connect all nodes
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
b.Fatalf("failed to add node %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond) //nolint:forbidigo // benchmark coordination
|
||||
|
||||
ch := change.DERPMap()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := range b.N {
|
||||
nodeIdx := i % len(allNodes)
|
||||
_, _ = batcher.MapResponseFromChange(allNodes[nodeIdx].n.ID, ch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -160,7 +160,7 @@ type node struct {
|
||||
//
|
||||
// Returns TestData struct containing all created entities and a cleanup function.
|
||||
func setupBatcherWithTestData(
|
||||
t *testing.T,
|
||||
t testing.TB,
|
||||
bf batcherFunc,
|
||||
userCount, nodesPerUser, bufferSize int,
|
||||
) (*TestData, func()) {
|
||||
|
||||
948
hscontrol/mapper/batcher_unit_test.go
Normal file
948
hscontrol/mapper/batcher_unit_test.go
Normal file
@@ -0,0 +1,948 @@
|
||||
package mapper
|
||||
|
||||
// Unit tests for batcher components that do NOT require database setup.
|
||||
// These tests exercise connectionEntry, multiChannelNodeConn, computePeerDiff,
|
||||
// updateSentPeers, generateMapResponse branching, and handleNodeChange in isolation.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// Mock Infrastructure
|
||||
// ============================================================================
|
||||
|
||||
// mockNodeConnection implements nodeConnection for isolated unit testing
|
||||
// of generateMapResponse and handleNodeChange without a real database.
|
||||
type mockNodeConnection struct {
|
||||
id types.NodeID
|
||||
ver tailcfg.CapabilityVersion
|
||||
|
||||
// sendFn allows injecting custom send behavior.
|
||||
// If nil, sends are recorded and succeed.
|
||||
sendFn func(*tailcfg.MapResponse) error
|
||||
|
||||
// sent records all successful sends for assertion.
|
||||
sent []*tailcfg.MapResponse
|
||||
mu sync.Mutex
|
||||
|
||||
// Peer tracking
|
||||
peers *xsync.Map[tailcfg.NodeID, struct{}]
|
||||
}
|
||||
|
||||
func newMockNodeConnection(id types.NodeID) *mockNodeConnection {
|
||||
return &mockNodeConnection{
|
||||
id: id,
|
||||
ver: tailcfg.CapabilityVersion(100),
|
||||
peers: xsync.NewMap[tailcfg.NodeID, struct{}](),
|
||||
}
|
||||
}
|
||||
|
||||
// withSendError configures the mock to return the given error on send.
|
||||
func (m *mockNodeConnection) withSendError(err error) *mockNodeConnection {
|
||||
m.sendFn = func(_ *tailcfg.MapResponse) error { return err }
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockNodeConnection) nodeID() types.NodeID { return m.id }
|
||||
func (m *mockNodeConnection) version() tailcfg.CapabilityVersion { return m.ver }
|
||||
|
||||
func (m *mockNodeConnection) send(data *tailcfg.MapResponse) error {
|
||||
if m.sendFn != nil {
|
||||
return m.sendFn(data)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.sent = append(m.sent, data)
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockNodeConnection) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
|
||||
currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
|
||||
for _, id := range currentPeers {
|
||||
currentSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
var removed []tailcfg.NodeID
|
||||
|
||||
m.peers.Range(func(id tailcfg.NodeID, _ struct{}) bool {
|
||||
if _, exists := currentSet[id]; !exists {
|
||||
removed = append(removed, id)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
func (m *mockNodeConnection) updateSentPeers(resp *tailcfg.MapResponse) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Peers != nil {
|
||||
m.peers.Clear()
|
||||
|
||||
for _, peer := range resp.Peers {
|
||||
m.peers.Store(peer.ID, struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
for _, peer := range resp.PeersChanged {
|
||||
m.peers.Store(peer.ID, struct{}{})
|
||||
}
|
||||
|
||||
for _, id := range resp.PeersRemoved {
|
||||
m.peers.Delete(id)
|
||||
}
|
||||
}
|
||||
|
||||
// getSent returns a thread-safe copy of all sent responses.
|
||||
func (m *mockNodeConnection) getSent() []*tailcfg.MapResponse {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return append([]*tailcfg.MapResponse{}, m.sent...)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test Helpers
|
||||
// ============================================================================
|
||||
|
||||
// testMapResponse creates a minimal valid MapResponse for testing.
|
||||
func testMapResponse() *tailcfg.MapResponse {
|
||||
now := time.Now()
|
||||
|
||||
return &tailcfg.MapResponse{
|
||||
ControlTime: &now,
|
||||
}
|
||||
}
|
||||
|
||||
// testMapResponseWithPeers creates a MapResponse with the given peer IDs.
|
||||
func testMapResponseWithPeers(peerIDs ...tailcfg.NodeID) *tailcfg.MapResponse {
|
||||
resp := testMapResponse()
|
||||
|
||||
resp.Peers = make([]*tailcfg.Node, len(peerIDs))
|
||||
for i, id := range peerIDs {
|
||||
resp.Peers[i] = &tailcfg.Node{ID: id}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ids is a convenience for creating a slice of tailcfg.NodeID.
|
||||
func ids(nodeIDs ...tailcfg.NodeID) []tailcfg.NodeID {
|
||||
return nodeIDs
|
||||
}
|
||||
|
||||
// expectReceive asserts that a message arrives on the channel within 100ms.
|
||||
func expectReceive(t *testing.T, ch <-chan *tailcfg.MapResponse, msg string) *tailcfg.MapResponse {
|
||||
t.Helper()
|
||||
|
||||
const timeout = 100 * time.Millisecond
|
||||
|
||||
select {
|
||||
case data := <-ch:
|
||||
return data
|
||||
case <-time.After(timeout):
|
||||
t.Fatalf("expected to receive on channel within %v: %s", timeout, msg)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// expectNoReceive asserts that no message arrives within timeout.
|
||||
func expectNoReceive(t *testing.T, ch <-chan *tailcfg.MapResponse, timeout time.Duration, msg string) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case data := <-ch:
|
||||
t.Fatalf("expected no receive but got %+v: %s", data, msg)
|
||||
case <-time.After(timeout):
|
||||
// Expected
|
||||
}
|
||||
}
|
||||
|
||||
// makeConnectionEntry creates a connectionEntry with the given channel.
|
||||
func makeConnectionEntry(id string, ch chan<- *tailcfg.MapResponse) *connectionEntry {
|
||||
entry := &connectionEntry{
|
||||
id: id,
|
||||
c: ch,
|
||||
version: tailcfg.CapabilityVersion(100),
|
||||
created: time.Now(),
|
||||
}
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// connectionEntry.send() Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConnectionEntry_SendSuccess(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 1)
|
||||
entry := makeConnectionEntry("test-conn", ch)
|
||||
data := testMapResponse()
|
||||
|
||||
beforeSend := time.Now().Unix()
|
||||
err := entry.send(data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, entry.lastUsed.Load(), beforeSend,
|
||||
"lastUsed should be updated after successful send")
|
||||
|
||||
// Verify data was actually sent
|
||||
received := expectReceive(t, ch, "data should be on channel")
|
||||
assert.Equal(t, data, received)
|
||||
}
|
||||
|
||||
func TestConnectionEntry_SendNilData(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 1)
|
||||
entry := makeConnectionEntry("test-conn", ch)
|
||||
|
||||
err := entry.send(nil)
|
||||
|
||||
require.NoError(t, err, "nil data should return nil error")
|
||||
expectNoReceive(t, ch, 10*time.Millisecond, "nil data should not be sent to channel")
|
||||
}
|
||||
|
||||
func TestConnectionEntry_SendTimeout(t *testing.T) {
|
||||
// Unbuffered channel with no reader = always blocks
|
||||
ch := make(chan *tailcfg.MapResponse)
|
||||
entry := makeConnectionEntry("test-conn", ch)
|
||||
data := testMapResponse()
|
||||
|
||||
start := time.Now()
|
||||
err := entry.send(data)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.ErrorIs(t, err, ErrConnectionSendTimeout)
|
||||
assert.GreaterOrEqual(t, elapsed, 40*time.Millisecond,
|
||||
"should wait approximately 50ms before timeout")
|
||||
}
|
||||
|
||||
func TestConnectionEntry_SendClosed(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 1)
|
||||
entry := makeConnectionEntry("test-conn", ch)
|
||||
|
||||
// Mark as closed before sending
|
||||
entry.closed.Store(true)
|
||||
|
||||
err := entry.send(testMapResponse())
|
||||
|
||||
require.ErrorIs(t, err, errConnectionClosed)
|
||||
expectNoReceive(t, ch, 10*time.Millisecond,
|
||||
"closed entry should not send data to channel")
|
||||
}
|
||||
|
||||
func TestConnectionEntry_SendUpdatesLastUsed(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 1)
|
||||
entry := makeConnectionEntry("test-conn", ch)
|
||||
|
||||
// Set lastUsed to a past time
|
||||
pastTime := time.Now().Add(-1 * time.Hour).Unix()
|
||||
entry.lastUsed.Store(pastTime)
|
||||
|
||||
err := entry.send(testMapResponse())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Greater(t, entry.lastUsed.Load(), pastTime,
|
||||
"lastUsed should be updated to current time after send")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// multiChannelNodeConn.send() Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestMultiChannelSend_AllSuccess(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// Create 3 buffered channels (all will succeed)
|
||||
channels := make([]chan *tailcfg.MapResponse, 3)
|
||||
for i := range channels {
|
||||
channels[i] = make(chan *tailcfg.MapResponse, 1)
|
||||
mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), channels[i]))
|
||||
}
|
||||
|
||||
data := testMapResponse()
|
||||
err := mc.send(data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, mc.getActiveConnectionCount(),
|
||||
"all connections should remain active after success")
|
||||
|
||||
// Verify all channels received the data
|
||||
for i, ch := range channels {
|
||||
received := expectReceive(t, ch,
|
||||
fmt.Sprintf("channel %d should receive data", i))
|
||||
assert.Equal(t, data, received)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiChannelSend_PartialFailure(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// 2 buffered channels (will succeed) + 1 unbuffered (will timeout)
|
||||
goodCh1 := make(chan *tailcfg.MapResponse, 1)
|
||||
goodCh2 := make(chan *tailcfg.MapResponse, 1)
|
||||
badCh := make(chan *tailcfg.MapResponse) // unbuffered, no reader
|
||||
|
||||
mc.addConnection(makeConnectionEntry("good-1", goodCh1))
|
||||
mc.addConnection(makeConnectionEntry("bad", badCh))
|
||||
mc.addConnection(makeConnectionEntry("good-2", goodCh2))
|
||||
|
||||
err := mc.send(testMapResponse())
|
||||
|
||||
require.NoError(t, err, "should succeed if at least one connection works")
|
||||
assert.Equal(t, 2, mc.getActiveConnectionCount(),
|
||||
"failed connection should be removed")
|
||||
|
||||
// Good channels should have received data
|
||||
expectReceive(t, goodCh1, "good-1 should receive")
|
||||
expectReceive(t, goodCh2, "good-2 should receive")
|
||||
}
|
||||
|
||||
func TestMultiChannelSend_AllFail(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// All unbuffered channels with no readers
|
||||
for i := range 3 {
|
||||
ch := make(chan *tailcfg.MapResponse) // unbuffered
|
||||
mc.addConnection(makeConnectionEntry(fmt.Sprintf("bad-%d", i), ch))
|
||||
}
|
||||
|
||||
err := mc.send(testMapResponse())
|
||||
|
||||
require.Error(t, err, "should return error when all connections fail")
|
||||
assert.Equal(t, 0, mc.getActiveConnectionCount(),
|
||||
"all failed connections should be removed")
|
||||
}
|
||||
|
||||
func TestMultiChannelSend_ZeroConnections(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
err := mc.send(testMapResponse())
|
||||
|
||||
require.NoError(t, err,
|
||||
"sending to node with 0 connections should succeed silently (rapid reconnection scenario)")
|
||||
}
|
||||
|
||||
func TestMultiChannelSend_NilData(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
ch := make(chan *tailcfg.MapResponse, 1)
|
||||
mc.addConnection(makeConnectionEntry("conn", ch))
|
||||
|
||||
err := mc.send(nil)
|
||||
|
||||
require.NoError(t, err, "nil data should return nil immediately")
|
||||
expectNoReceive(t, ch, 10*time.Millisecond, "nil data should not be sent")
|
||||
}
|
||||
|
||||
func TestMultiChannelSend_FailedConnectionRemoved(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
goodCh := make(chan *tailcfg.MapResponse, 10) // large buffer
|
||||
badCh := make(chan *tailcfg.MapResponse) // unbuffered, will timeout
|
||||
|
||||
mc.addConnection(makeConnectionEntry("good", goodCh))
|
||||
mc.addConnection(makeConnectionEntry("bad", badCh))
|
||||
|
||||
assert.Equal(t, 2, mc.getActiveConnectionCount())
|
||||
|
||||
// First send: bad connection removed
|
||||
err := mc.send(testMapResponse())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, mc.getActiveConnectionCount())
|
||||
|
||||
// Second send: only good connection remains, should succeed
|
||||
err = mc.send(testMapResponse())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, mc.getActiveConnectionCount())
|
||||
}
|
||||
|
||||
func TestMultiChannelSend_UpdateCount(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
ch := make(chan *tailcfg.MapResponse, 10)
|
||||
mc.addConnection(makeConnectionEntry("conn", ch))
|
||||
|
||||
assert.Equal(t, int64(0), mc.updateCount.Load())
|
||||
|
||||
_ = mc.send(testMapResponse())
|
||||
assert.Equal(t, int64(1), mc.updateCount.Load())
|
||||
|
||||
_ = mc.send(testMapResponse())
|
||||
assert.Equal(t, int64(2), mc.updateCount.Load())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// multiChannelNodeConn.close() Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestMultiChannelClose_MarksEntriesClosed(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
entries := make([]*connectionEntry, 3)
|
||||
for i := range entries {
|
||||
ch := make(chan *tailcfg.MapResponse, 1)
|
||||
entries[i] = makeConnectionEntry(fmt.Sprintf("conn-%d", i), ch)
|
||||
mc.addConnection(entries[i])
|
||||
}
|
||||
|
||||
mc.close()
|
||||
|
||||
for i, entry := range entries {
|
||||
assert.True(t, entry.closed.Load(),
|
||||
"entry %d should be marked as closed", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiChannelClose_PreventsSendPanic(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
ch := make(chan *tailcfg.MapResponse, 1)
|
||||
entry := makeConnectionEntry("conn", ch)
|
||||
mc.addConnection(entry)
|
||||
|
||||
mc.close()
|
||||
|
||||
// After close, connectionEntry.send should return errConnectionClosed
|
||||
// (not panic on send to closed channel)
|
||||
err := entry.send(testMapResponse())
|
||||
require.ErrorIs(t, err, errConnectionClosed,
|
||||
"send after close should return errConnectionClosed, not panic")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// multiChannelNodeConn connection management Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestMultiChannelNodeConn_AddRemoveConnections(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||
ch2 := make(chan *tailcfg.MapResponse, 1)
|
||||
ch3 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
// Add connections
|
||||
mc.addConnection(makeConnectionEntry("c1", ch1))
|
||||
assert.Equal(t, 1, mc.getActiveConnectionCount())
|
||||
assert.True(t, mc.hasActiveConnections())
|
||||
|
||||
mc.addConnection(makeConnectionEntry("c2", ch2))
|
||||
mc.addConnection(makeConnectionEntry("c3", ch3))
|
||||
assert.Equal(t, 3, mc.getActiveConnectionCount())
|
||||
|
||||
// Remove by channel pointer
|
||||
assert.True(t, mc.removeConnectionByChannel(ch2))
|
||||
assert.Equal(t, 2, mc.getActiveConnectionCount())
|
||||
|
||||
// Remove non-existent channel
|
||||
nonExistentCh := make(chan *tailcfg.MapResponse)
|
||||
assert.False(t, mc.removeConnectionByChannel(nonExistentCh))
|
||||
assert.Equal(t, 2, mc.getActiveConnectionCount())
|
||||
|
||||
// Remove remaining
|
||||
assert.True(t, mc.removeConnectionByChannel(ch1))
|
||||
assert.True(t, mc.removeConnectionByChannel(ch3))
|
||||
assert.Equal(t, 0, mc.getActiveConnectionCount())
|
||||
assert.False(t, mc.hasActiveConnections())
|
||||
}
|
||||
|
||||
func TestMultiChannelNodeConn_Version(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// No connections - version should be 0
|
||||
assert.Equal(t, tailcfg.CapabilityVersion(0), mc.version())
|
||||
|
||||
// Add connection with version 100
|
||||
ch := make(chan *tailcfg.MapResponse, 1)
|
||||
entry := makeConnectionEntry("conn", ch)
|
||||
entry.version = tailcfg.CapabilityVersion(100)
|
||||
mc.addConnection(entry)
|
||||
|
||||
assert.Equal(t, tailcfg.CapabilityVersion(100), mc.version())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// computePeerDiff Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestComputePeerDiff(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tracked []tailcfg.NodeID // peers previously sent to client
|
||||
current []tailcfg.NodeID // peers visible now
|
||||
wantRemoved []tailcfg.NodeID // expected removed peers
|
||||
}{
|
||||
{
|
||||
name: "no_changes",
|
||||
tracked: ids(1, 2, 3),
|
||||
current: ids(1, 2, 3),
|
||||
wantRemoved: nil,
|
||||
},
|
||||
{
|
||||
name: "one_removed",
|
||||
tracked: ids(1, 2, 3),
|
||||
current: ids(1, 3),
|
||||
wantRemoved: ids(2),
|
||||
},
|
||||
{
|
||||
name: "multiple_removed",
|
||||
tracked: ids(1, 2, 3, 4, 5),
|
||||
current: ids(2, 4),
|
||||
wantRemoved: ids(1, 3, 5),
|
||||
},
|
||||
{
|
||||
name: "all_removed",
|
||||
tracked: ids(1, 2, 3),
|
||||
current: nil,
|
||||
wantRemoved: ids(1, 2, 3),
|
||||
},
|
||||
{
|
||||
name: "peers_added_no_removal",
|
||||
tracked: ids(1),
|
||||
current: ids(1, 2, 3),
|
||||
wantRemoved: nil,
|
||||
},
|
||||
{
|
||||
name: "empty_tracked",
|
||||
tracked: nil,
|
||||
current: ids(1, 2, 3),
|
||||
wantRemoved: nil,
|
||||
},
|
||||
{
|
||||
name: "both_empty",
|
||||
tracked: nil,
|
||||
current: nil,
|
||||
wantRemoved: nil,
|
||||
},
|
||||
{
|
||||
name: "disjoint_sets",
|
||||
tracked: ids(1, 2, 3),
|
||||
current: ids(4, 5, 6),
|
||||
wantRemoved: ids(1, 2, 3),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// Populate tracked peers
|
||||
for _, id := range tt.tracked {
|
||||
mc.lastSentPeers.Store(id, struct{}{})
|
||||
}
|
||||
|
||||
got := mc.computePeerDiff(tt.current)
|
||||
|
||||
assert.ElementsMatch(t, tt.wantRemoved, got,
|
||||
"removed peers should match expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// updateSentPeers Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestUpdateSentPeers(t *testing.T) {
|
||||
t.Run("full_peer_list_replaces_all", func(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
// Pre-populate with old peers
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(100), struct{}{})
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(200), struct{}{})
|
||||
|
||||
// Send full peer list
|
||||
mc.updateSentPeers(testMapResponseWithPeers(1, 2, 3))
|
||||
|
||||
// Old peers should be gone
|
||||
_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(100))
|
||||
assert.False(t, exists, "old peer 100 should be cleared")
|
||||
|
||||
// New peers should be tracked
|
||||
for _, id := range ids(1, 2, 3) {
|
||||
_, exists := mc.lastSentPeers.Load(id)
|
||||
assert.True(t, exists, "peer %d should be tracked", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("incremental_add_via_PeersChanged", func(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
|
||||
|
||||
resp := testMapResponse()
|
||||
resp.PeersChanged = []*tailcfg.Node{{ID: 2}, {ID: 3}}
|
||||
mc.updateSentPeers(resp)
|
||||
|
||||
// All three should be tracked
|
||||
for _, id := range ids(1, 2, 3) {
|
||||
_, exists := mc.lastSentPeers.Load(id)
|
||||
assert.True(t, exists, "peer %d should be tracked", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("incremental_remove_via_PeersRemoved", func(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(2), struct{}{})
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(3), struct{}{})
|
||||
|
||||
resp := testMapResponse()
|
||||
resp.PeersRemoved = ids(2)
|
||||
mc.updateSentPeers(resp)
|
||||
|
||||
_, exists1 := mc.lastSentPeers.Load(tailcfg.NodeID(1))
|
||||
_, exists2 := mc.lastSentPeers.Load(tailcfg.NodeID(2))
|
||||
_, exists3 := mc.lastSentPeers.Load(tailcfg.NodeID(3))
|
||||
|
||||
assert.True(t, exists1, "peer 1 should remain")
|
||||
assert.False(t, exists2, "peer 2 should be removed")
|
||||
assert.True(t, exists3, "peer 3 should remain")
|
||||
})
|
||||
|
||||
t.Run("nil_response_is_noop", func(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
|
||||
|
||||
mc.updateSentPeers(nil)
|
||||
|
||||
_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(1))
|
||||
assert.True(t, exists, "nil response should not change tracked peers")
|
||||
})
|
||||
|
||||
t.Run("full_then_incremental_sequence", func(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// Step 1: Full peer list
|
||||
mc.updateSentPeers(testMapResponseWithPeers(1, 2, 3))
|
||||
|
||||
// Step 2: Add peer 4
|
||||
resp := testMapResponse()
|
||||
resp.PeersChanged = []*tailcfg.Node{{ID: 4}}
|
||||
mc.updateSentPeers(resp)
|
||||
|
||||
// Step 3: Remove peer 2
|
||||
resp2 := testMapResponse()
|
||||
resp2.PeersRemoved = ids(2)
|
||||
mc.updateSentPeers(resp2)
|
||||
|
||||
// Should have 1, 3, 4
|
||||
for _, id := range ids(1, 3, 4) {
|
||||
_, exists := mc.lastSentPeers.Load(id)
|
||||
assert.True(t, exists, "peer %d should be tracked", id)
|
||||
}
|
||||
|
||||
_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(2))
|
||||
assert.False(t, exists, "peer 2 should have been removed")
|
||||
})
|
||||
|
||||
t.Run("empty_full_peer_list_clears_all", func(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
|
||||
mc.lastSentPeers.Store(tailcfg.NodeID(2), struct{}{})
|
||||
|
||||
// Empty Peers slice (not nil) means "no peers"
|
||||
resp := testMapResponse()
|
||||
resp.Peers = []*tailcfg.Node{} // empty, not nil
|
||||
mc.updateSentPeers(resp)
|
||||
|
||||
count := 0
|
||||
|
||||
mc.lastSentPeers.Range(func(_ tailcfg.NodeID, _ struct{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 0, count, "empty peer list should clear all tracking")
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// generateMapResponse Tests (branching logic only, no DB needed)
|
||||
// ============================================================================
|
||||
|
||||
func TestGenerateMapResponse_EmptyChange(t *testing.T) {
|
||||
mc := newMockNodeConnection(1)
|
||||
|
||||
resp, err := generateMapResponse(mc, nil, change.Change{})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp, "empty change should return nil response")
|
||||
}
|
||||
|
||||
func TestGenerateMapResponse_InvalidNodeID(t *testing.T) {
|
||||
mc := newMockNodeConnection(0) // Invalid ID
|
||||
|
||||
resp, err := generateMapResponse(mc, &mapper{}, change.DERPMap())
|
||||
|
||||
require.ErrorIs(t, err, ErrInvalidNodeID)
|
||||
assert.Nil(t, resp)
|
||||
}
|
||||
|
||||
func TestGenerateMapResponse_NilMapper(t *testing.T) {
|
||||
mc := newMockNodeConnection(1)
|
||||
|
||||
resp, err := generateMapResponse(mc, nil, change.DERPMap())
|
||||
|
||||
require.ErrorIs(t, err, ErrMapperNil)
|
||||
assert.Nil(t, resp)
|
||||
}
|
||||
|
||||
func TestGenerateMapResponse_SelfOnlyOtherNode(t *testing.T) {
|
||||
mc := newMockNodeConnection(1)
|
||||
|
||||
// SelfUpdate targeted at node 99 should be skipped for node 1
|
||||
ch := change.SelfUpdate(99)
|
||||
resp, err := generateMapResponse(mc, &mapper{}, ch)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp,
|
||||
"self-only change targeted at different node should return nil")
|
||||
}
|
||||
|
||||
func TestGenerateMapResponse_SelfOnlySameNode(t *testing.T) {
|
||||
// SelfUpdate targeted at node 1: IsSelfOnly()=true and TargetNode==nodeID
|
||||
// This should NOT be short-circuited - it should attempt to generate.
|
||||
// We verify the routing logic by checking that the change is not empty
|
||||
// and not filtered out (unlike SelfOnlyOtherNode above).
|
||||
ch := change.SelfUpdate(1)
|
||||
assert.False(t, ch.IsEmpty(), "SelfUpdate should not be empty")
|
||||
assert.True(t, ch.IsSelfOnly(), "SelfUpdate should be self-only")
|
||||
assert.True(t, ch.ShouldSendToNode(1), "should be sent to target node")
|
||||
assert.False(t, ch.ShouldSendToNode(2), "should NOT be sent to other nodes")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// handleNodeChange Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestHandleNodeChange_NilConnection(t *testing.T) {
|
||||
err := handleNodeChange(nil, nil, change.DERPMap())
|
||||
|
||||
assert.ErrorIs(t, err, ErrNodeConnectionNil)
|
||||
}
|
||||
|
||||
func TestHandleNodeChange_EmptyChange(t *testing.T) {
|
||||
mc := newMockNodeConnection(1)
|
||||
|
||||
err := handleNodeChange(mc, nil, change.Change{})
|
||||
|
||||
require.NoError(t, err, "empty change should not send anything")
|
||||
assert.Empty(t, mc.getSent(), "no data should be sent for empty change")
|
||||
}
|
||||
|
||||
var errConnectionBroken = errors.New("connection broken")
|
||||
|
||||
func TestHandleNodeChange_SendError(t *testing.T) {
|
||||
mc := newMockNodeConnection(1).withSendError(errConnectionBroken)
|
||||
|
||||
// Need a real mapper for this test - we can't easily mock it.
|
||||
// Instead, test that when generateMapResponse returns nil data,
|
||||
// no send occurs. The send error path requires a valid MapResponse
|
||||
// which requires a mapper with state.
|
||||
// So we test the nil-data path here.
|
||||
err := handleNodeChange(mc, nil, change.Change{})
|
||||
assert.NoError(t, err, "empty change produces nil data, no send needed")
|
||||
}
|
||||
|
||||
func TestHandleNodeChange_NilDataNoSend(t *testing.T) {
|
||||
mc := newMockNodeConnection(1)
|
||||
|
||||
// SelfUpdate targeted at different node produces nil data
|
||||
ch := change.SelfUpdate(99)
|
||||
err := handleNodeChange(mc, &mapper{}, ch)
|
||||
|
||||
require.NoError(t, err, "nil data should not cause error")
|
||||
assert.Empty(t, mc.getSent(), "nil data should not trigger send")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// connectionEntry concurrent safety Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConnectionEntry_ConcurrentSends(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 100)
|
||||
entry := makeConnectionEntry("concurrent", ch)
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
successCount atomic.Int64
|
||||
)
|
||||
|
||||
// 50 goroutines sending concurrently
|
||||
|
||||
for range 50 {
|
||||
wg.Go(func() {
|
||||
err := entry.send(testMapResponse())
|
||||
if err == nil {
|
||||
successCount.Add(1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, int64(50), successCount.Load(),
|
||||
"all sends to buffered channel should succeed")
|
||||
|
||||
// Drain and count
|
||||
count := 0
|
||||
|
||||
for range len(ch) {
|
||||
<-ch
|
||||
|
||||
count++
|
||||
}
|
||||
|
||||
assert.Equal(t, 50, count, "all 50 messages should be on channel")
|
||||
}
|
||||
|
||||
func TestConnectionEntry_ConcurrentSendAndClose(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 100)
|
||||
entry := makeConnectionEntry("race", ch)
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
panicked atomic.Bool
|
||||
)
|
||||
|
||||
// Goroutines sending rapidly
|
||||
|
||||
for range 20 {
|
||||
wg.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked.Store(true)
|
||||
}
|
||||
}()
|
||||
|
||||
for range 10 {
|
||||
_ = entry.send(testMapResponse())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Close midway through
|
||||
|
||||
wg.Go(func() {
|
||||
time.Sleep(1 * time.Millisecond) //nolint:forbidigo // concurrency test coordination
|
||||
entry.closed.Store(true)
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.False(t, panicked.Load(),
|
||||
"concurrent send and close should not panic")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// multiChannelNodeConn concurrent Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestMultiChannelSend_ConcurrentAddAndSend(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
// Start with one connection
|
||||
ch1 := make(chan *tailcfg.MapResponse, 100)
|
||||
mc.addConnection(makeConnectionEntry("initial", ch1))
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
panicked atomic.Bool
|
||||
)
|
||||
|
||||
// Goroutine adding connections
|
||||
|
||||
wg.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked.Store(true)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := range 10 {
|
||||
ch := make(chan *tailcfg.MapResponse, 100)
|
||||
mc.addConnection(makeConnectionEntry(fmt.Sprintf("added-%d", i), ch))
|
||||
}
|
||||
})
|
||||
|
||||
// Goroutine sending data
|
||||
|
||||
wg.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked.Store(true)
|
||||
}
|
||||
}()
|
||||
|
||||
for range 20 {
|
||||
_ = mc.send(testMapResponse())
|
||||
}
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.False(t, panicked.Load(),
|
||||
"concurrent add and send should not panic (mutex protects both)")
|
||||
}
|
||||
|
||||
func TestMultiChannelSend_ConcurrentRemoveAndSend(t *testing.T) {
|
||||
mc := newMultiChannelNodeConn(1, nil)
|
||||
|
||||
channels := make([]chan *tailcfg.MapResponse, 10)
|
||||
for i := range channels {
|
||||
channels[i] = make(chan *tailcfg.MapResponse, 100)
|
||||
mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), channels[i]))
|
||||
}
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
panicked atomic.Bool
|
||||
)
|
||||
|
||||
// Goroutine removing connections
|
||||
|
||||
wg.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked.Store(true)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, ch := range channels {
|
||||
mc.removeConnectionByChannel(ch)
|
||||
}
|
||||
})
|
||||
|
||||
// Goroutine sending data concurrently
|
||||
|
||||
wg.Go(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked.Store(true)
|
||||
}
|
||||
}()
|
||||
|
||||
for range 20 {
|
||||
_ = mc.send(testMapResponse())
|
||||
}
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.False(t, panicked.Load(),
|
||||
"concurrent remove and send should not panic")
|
||||
}
|
||||
Reference in New Issue
Block a user