mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-22 05:59:02 +09:00
hscontrol: add servertest harness for in-process control plane testing
Add a new hscontrol/servertest package that provides a test harness for exercising the full Headscale control protocol in-process, using Tailscale's controlclient.Direct as the client. The harness consists of: - TestServer: wraps a Headscale instance with an httptest.Server - TestClient: wraps controlclient.Direct with NetworkMap tracking - TestHarness: orchestrates N clients against a single server - Assertion helpers for mesh completeness, visibility, and consistency Export minimal accessor methods on Headscale (HTTPHandler, NoisePublicKey, GetState, SetServerURL, StartBatcher, StartEphemeralGC) so the servertest package can construct a working server from outside the hscontrol package. This enables fast, deterministic tests of connection lifecycle, update propagation, and network weather scenarios without Docker.
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
@@ -1069,6 +1070,56 @@ func (h *Headscale) Change(cs ...change.Change) {
|
||||
h.mapBatcher.AddWork(cs...)
|
||||
}
|
||||
|
||||
// HTTPHandler returns an http.Handler for the Headscale control server.
|
||||
// The handler serves the Tailscale control protocol including the /key
|
||||
// endpoint and /ts2021 Noise upgrade path.
|
||||
func (h *Headscale) HTTPHandler() http.Handler {
|
||||
return h.createRouter(grpcRuntime.NewServeMux())
|
||||
}
|
||||
|
||||
// NoisePublicKey returns the server's Noise protocol public key.
|
||||
func (h *Headscale) NoisePublicKey() key.MachinePublic {
|
||||
return h.noisePrivateKey.Public()
|
||||
}
|
||||
|
||||
// GetState returns the server's state manager for programmatic access
|
||||
// to users, nodes, policies, and other server state.
|
||||
func (h *Headscale) GetState() *state.State {
|
||||
return h.state
|
||||
}
|
||||
|
||||
// SetServerURLForTest updates the server URL in the configuration.
|
||||
// This is needed for test servers where the URL is not known until
|
||||
// the HTTP test server starts.
|
||||
// It panics when called outside of tests.
|
||||
func (h *Headscale) SetServerURLForTest(tb testing.TB, url string) {
|
||||
tb.Helper()
|
||||
|
||||
h.cfg.ServerURL = url
|
||||
}
|
||||
|
||||
// StartBatcherForTest initialises and starts the map response batcher.
|
||||
// It registers a cleanup function on tb to stop the batcher.
|
||||
// It panics when called outside of tests.
|
||||
func (h *Headscale) StartBatcherForTest(tb testing.TB) {
|
||||
tb.Helper()
|
||||
|
||||
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
||||
h.mapBatcher.Start()
|
||||
tb.Cleanup(func() { h.mapBatcher.Close() })
|
||||
}
|
||||
|
||||
// StartEphemeralGCForTest starts the ephemeral node garbage collector.
|
||||
// It registers a cleanup function on tb to stop the collector.
|
||||
// It panics when called outside of tests.
|
||||
func (h *Headscale) StartEphemeralGCForTest(tb testing.TB) {
|
||||
tb.Helper()
|
||||
|
||||
go h.ephemeralGC.Start()
|
||||
|
||||
tb.Cleanup(func() { h.ephemeralGC.Close() })
|
||||
}
|
||||
|
||||
// Provide some middleware that can inspect the ACME/autocert https calls
|
||||
// and log when things are failing.
|
||||
type acmeLogger struct {
|
||||
|
||||
219
hscontrol/servertest/assertions.go
Normal file
219
hscontrol/servertest/assertions.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package servertest
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AssertMeshComplete verifies that every client in the slice sees
|
||||
// exactly (len(clients) - 1) peers, i.e. a fully connected mesh.
|
||||
func AssertMeshComplete(tb testing.TB, clients []*TestClient) {
|
||||
tb.Helper()
|
||||
|
||||
expected := len(clients) - 1
|
||||
for _, c := range clients {
|
||||
nm := c.Netmap()
|
||||
if nm == nil {
|
||||
tb.Errorf("AssertMeshComplete: %s has no netmap", c.Name)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if got := len(nm.Peers); got != expected {
|
||||
tb.Errorf("AssertMeshComplete: %s has %d peers, want %d (peers: %v)",
|
||||
c.Name, got, expected, c.PeerNames())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AssertSymmetricVisibility checks that peer visibility is symmetric:
|
||||
// if client A sees client B, then client B must also see client A.
|
||||
func AssertSymmetricVisibility(tb testing.TB, clients []*TestClient) {
|
||||
tb.Helper()
|
||||
|
||||
for _, a := range clients {
|
||||
for _, b := range clients {
|
||||
if a == b {
|
||||
continue
|
||||
}
|
||||
|
||||
_, aSeesB := a.PeerByName(b.Name)
|
||||
|
||||
_, bSeesA := b.PeerByName(a.Name)
|
||||
if aSeesB != bSeesA {
|
||||
tb.Errorf("AssertSymmetricVisibility: %s sees %s = %v, but %s sees %s = %v",
|
||||
a.Name, b.Name, aSeesB, b.Name, a.Name, bSeesA)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AssertPeerOnline checks that the observer sees peerName as online.
|
||||
func AssertPeerOnline(tb testing.TB, observer *TestClient, peerName string) {
|
||||
tb.Helper()
|
||||
|
||||
peer, ok := observer.PeerByName(peerName)
|
||||
if !ok {
|
||||
tb.Errorf("AssertPeerOnline: %s does not see peer %s", observer.Name, peerName)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
isOnline, known := peer.Online().GetOk()
|
||||
if !known || !isOnline {
|
||||
tb.Errorf("AssertPeerOnline: %s sees peer %s but Online=%v (known=%v), want true",
|
||||
observer.Name, peerName, isOnline, known)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertPeerOffline checks that the observer sees peerName as offline.
|
||||
func AssertPeerOffline(tb testing.TB, observer *TestClient, peerName string) {
|
||||
tb.Helper()
|
||||
|
||||
peer, ok := observer.PeerByName(peerName)
|
||||
if !ok {
|
||||
// Peer gone entirely counts as "offline" for this assertion.
|
||||
return
|
||||
}
|
||||
|
||||
isOnline, known := peer.Online().GetOk()
|
||||
if known && isOnline {
|
||||
tb.Errorf("AssertPeerOffline: %s sees peer %s as online, want offline",
|
||||
observer.Name, peerName)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertPeerGone checks that the observer does NOT have peerName in
|
||||
// its peer list at all.
|
||||
func AssertPeerGone(tb testing.TB, observer *TestClient, peerName string) {
|
||||
tb.Helper()
|
||||
|
||||
_, ok := observer.PeerByName(peerName)
|
||||
if ok {
|
||||
tb.Errorf("AssertPeerGone: %s still sees peer %s", observer.Name, peerName)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertPeerHasAllowedIPs checks that a peer has the expected
|
||||
// AllowedIPs prefixes.
|
||||
func AssertPeerHasAllowedIPs(tb testing.TB, observer *TestClient, peerName string, want []netip.Prefix) {
|
||||
tb.Helper()
|
||||
|
||||
peer, ok := observer.PeerByName(peerName)
|
||||
if !ok {
|
||||
tb.Errorf("AssertPeerHasAllowedIPs: %s does not see peer %s", observer.Name, peerName)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
got := make([]netip.Prefix, 0, peer.AllowedIPs().Len())
|
||||
for i := range peer.AllowedIPs().Len() {
|
||||
got = append(got, peer.AllowedIPs().At(i))
|
||||
}
|
||||
|
||||
if len(got) != len(want) {
|
||||
tb.Errorf("AssertPeerHasAllowedIPs: %s sees %s with AllowedIPs %v, want %v",
|
||||
observer.Name, peerName, got, want)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set for comparison.
|
||||
wantSet := make(map[netip.Prefix]bool, len(want))
|
||||
for _, p := range want {
|
||||
wantSet[p] = true
|
||||
}
|
||||
|
||||
for _, p := range got {
|
||||
if !wantSet[p] {
|
||||
tb.Errorf("AssertPeerHasAllowedIPs: %s sees %s with unexpected AllowedIP %v (want %v)",
|
||||
observer.Name, peerName, p, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AssertConsistentState checks that all clients agree on peer
|
||||
// properties: every connected client should see the same set of
|
||||
// peer hostnames.
|
||||
func AssertConsistentState(tb testing.TB, clients []*TestClient) {
|
||||
tb.Helper()
|
||||
|
||||
for _, c := range clients {
|
||||
nm := c.Netmap()
|
||||
if nm == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
peerNames := make(map[string]bool, len(nm.Peers))
|
||||
for _, p := range nm.Peers {
|
||||
hi := p.Hostinfo()
|
||||
if hi.Valid() {
|
||||
peerNames[hi.Hostname()] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check that c sees all other connected clients.
|
||||
for _, other := range clients {
|
||||
if other == c || other.Netmap() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if !peerNames[other.Name] {
|
||||
tb.Errorf("AssertConsistentState: %s does not see %s (peers: %v)",
|
||||
c.Name, other.Name, c.PeerNames())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EventuallyAssertMeshComplete retries AssertMeshComplete up to
|
||||
// timeout, useful when waiting for state to propagate.
|
||||
func EventuallyAssertMeshComplete(tb testing.TB, clients []*TestClient, timeout time.Duration) {
|
||||
tb.Helper()
|
||||
|
||||
expected := len(clients) - 1
|
||||
deadline := time.After(timeout)
|
||||
|
||||
for {
|
||||
allGood := true
|
||||
|
||||
for _, c := range clients {
|
||||
nm := c.Netmap()
|
||||
if nm == nil || len(nm.Peers) < expected {
|
||||
allGood = false
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allGood {
|
||||
// Final strict check.
|
||||
AssertMeshComplete(tb, clients)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-deadline:
|
||||
// Report the failure with details.
|
||||
for _, c := range clients {
|
||||
nm := c.Netmap()
|
||||
|
||||
got := 0
|
||||
if nm != nil {
|
||||
got = len(nm.Peers)
|
||||
}
|
||||
|
||||
if got != expected {
|
||||
tb.Errorf("EventuallyAssertMeshComplete: %s has %d peers, want %d (timeout %v)",
|
||||
c.Name, got, expected, timeout)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Poll again.
|
||||
}
|
||||
}
|
||||
}
|
||||
430
hscontrol/servertest/client.go
Normal file
430
hscontrol/servertest/client.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package servertest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/types/persist"
|
||||
"tailscale.com/util/eventbus"
|
||||
)
|
||||
|
||||
// TestClient wraps a Tailscale controlclient.Direct connected to a
|
||||
// TestServer. It tracks all received NetworkMap updates, providing
|
||||
// helpers to wait for convergence and inspect the client's view of
|
||||
// the network.
|
||||
type TestClient struct {
|
||||
// Name is a human-readable identifier for this client.
|
||||
Name string
|
||||
|
||||
server *TestServer
|
||||
direct *controlclient.Direct
|
||||
authKey string
|
||||
user *types.User
|
||||
|
||||
// Connection lifecycle.
|
||||
pollCtx context.Context //nolint:containedctx // test-only; context stored for cancel control
|
||||
pollCancel context.CancelFunc
|
||||
pollDone chan struct{}
|
||||
|
||||
// Accumulated state from MapResponse callbacks.
|
||||
mu sync.RWMutex
|
||||
netmap *netmap.NetworkMap
|
||||
history []*netmap.NetworkMap
|
||||
|
||||
// updates is a buffered channel that receives a signal
|
||||
// each time a new NetworkMap arrives.
|
||||
updates chan *netmap.NetworkMap
|
||||
|
||||
bus *eventbus.Bus
|
||||
dialer *tsdial.Dialer
|
||||
tracker *health.Tracker
|
||||
}
|
||||
|
||||
// ClientOption configures a TestClient.
|
||||
type ClientOption func(*clientConfig)
|
||||
|
||||
type clientConfig struct {
|
||||
ephemeral bool
|
||||
hostname string
|
||||
tags []string
|
||||
user *types.User
|
||||
}
|
||||
|
||||
// WithEphemeral makes the client register as an ephemeral node.
|
||||
func WithEphemeral() ClientOption {
|
||||
return func(c *clientConfig) { c.ephemeral = true }
|
||||
}
|
||||
|
||||
// WithHostname sets the client's hostname in Hostinfo.
|
||||
func WithHostname(name string) ClientOption {
|
||||
return func(c *clientConfig) { c.hostname = name }
|
||||
}
|
||||
|
||||
// WithTags sets ACL tags on the pre-auth key.
|
||||
func WithTags(tags ...string) ClientOption {
|
||||
return func(c *clientConfig) { c.tags = tags }
|
||||
}
|
||||
|
||||
// WithUser sets the user for the client. If not set, the harness
|
||||
// creates a default user.
|
||||
func WithUser(user *types.User) ClientOption {
|
||||
return func(c *clientConfig) { c.user = user }
|
||||
}
|
||||
|
||||
// NewClient creates a TestClient, registers it with the TestServer
|
||||
// using a pre-auth key, and starts long-polling for map updates.
|
||||
func NewClient(tb testing.TB, server *TestServer, name string, opts ...ClientOption) *TestClient {
|
||||
tb.Helper()
|
||||
|
||||
cc := &clientConfig{
|
||||
hostname: name,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(cc)
|
||||
}
|
||||
|
||||
// Resolve user.
|
||||
user := cc.user
|
||||
if user == nil {
|
||||
// Create a per-client user if none specified.
|
||||
user = server.CreateUser(tb, "user-"+name)
|
||||
}
|
||||
|
||||
// Create pre-auth key.
|
||||
uid := types.UserID(user.ID)
|
||||
|
||||
var authKey string
|
||||
if cc.ephemeral {
|
||||
authKey = server.CreateEphemeralPreAuthKey(tb, uid)
|
||||
} else {
|
||||
authKey = server.CreatePreAuthKey(tb, uid)
|
||||
}
|
||||
|
||||
// Set up Tailscale client infrastructure.
|
||||
bus := eventbus.New()
|
||||
tracker := health.NewTracker(bus)
|
||||
dialer := tsdial.NewDialer(netmon.NewStatic())
|
||||
dialer.SetBus(bus)
|
||||
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
direct, err := controlclient.NewDirect(controlclient.Options{
|
||||
Persist: persist.Persist{},
|
||||
GetMachinePrivateKey: func() (key.MachinePrivate, error) { return machineKey, nil },
|
||||
ServerURL: server.URL,
|
||||
AuthKey: authKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
BackendLogID: "servertest-" + name,
|
||||
Hostname: cc.hostname,
|
||||
},
|
||||
DiscoPublicKey: key.NewDisco().Public(),
|
||||
Logf: tb.Logf,
|
||||
HealthTracker: tracker,
|
||||
Dialer: dialer,
|
||||
Bus: bus,
|
||||
})
|
||||
if err != nil {
|
||||
tb.Fatalf("servertest: NewDirect(%s): %v", name, err)
|
||||
}
|
||||
|
||||
tc := &TestClient{
|
||||
Name: name,
|
||||
server: server,
|
||||
direct: direct,
|
||||
authKey: authKey,
|
||||
user: user,
|
||||
updates: make(chan *netmap.NetworkMap, 64),
|
||||
bus: bus,
|
||||
dialer: dialer,
|
||||
tracker: tracker,
|
||||
}
|
||||
|
||||
tb.Cleanup(func() {
|
||||
tc.cleanup()
|
||||
})
|
||||
|
||||
// Register with the server.
|
||||
tc.register(tb)
|
||||
|
||||
// Start long-polling in the background.
|
||||
tc.startPoll(tb)
|
||||
|
||||
return tc
|
||||
}
|
||||
|
||||
// register performs the initial TryLogin to register the client.
|
||||
func (c *TestClient) register(tb testing.TB) {
|
||||
tb.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
url, err := c.direct.TryLogin(ctx, controlclient.LoginDefault)
|
||||
if err != nil {
|
||||
tb.Fatalf("servertest: TryLogin(%s): %v", c.Name, err)
|
||||
}
|
||||
|
||||
if url != "" {
|
||||
tb.Fatalf("servertest: TryLogin(%s): unexpected auth URL: %s (expected auto-auth with preauth key)", c.Name, url)
|
||||
}
|
||||
}
|
||||
|
||||
// startPoll begins the long-poll MapRequest loop.
|
||||
func (c *TestClient) startPoll(tb testing.TB) {
|
||||
tb.Helper()
|
||||
|
||||
c.pollCtx, c.pollCancel = context.WithCancel(context.Background())
|
||||
c.pollDone = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(c.pollDone)
|
||||
// PollNetMap blocks until ctx is cancelled or the server closes
|
||||
// the connection.
|
||||
_ = c.direct.PollNetMap(c.pollCtx, c)
|
||||
}()
|
||||
}
|
||||
|
||||
// UpdateFullNetmap implements controlclient.NetmapUpdater.
|
||||
// Called by controlclient.Direct when a new NetworkMap is received.
|
||||
func (c *TestClient) UpdateFullNetmap(nm *netmap.NetworkMap) {
|
||||
c.mu.Lock()
|
||||
c.netmap = nm
|
||||
c.history = append(c.history, nm)
|
||||
c.mu.Unlock()
|
||||
|
||||
// Non-blocking send to the updates channel.
|
||||
select {
|
||||
case c.updates <- nm:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup releases all resources.
|
||||
func (c *TestClient) cleanup() {
|
||||
if c.pollCancel != nil {
|
||||
c.pollCancel()
|
||||
}
|
||||
|
||||
if c.pollDone != nil {
|
||||
// Wait for PollNetMap to exit, but don't hang.
|
||||
select {
|
||||
case <-c.pollDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if c.direct != nil {
|
||||
c.direct.Close()
|
||||
}
|
||||
|
||||
if c.dialer != nil {
|
||||
c.dialer.Close()
|
||||
}
|
||||
|
||||
if c.bus != nil {
|
||||
c.bus.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// --- Lifecycle methods ---
|
||||
|
||||
// Disconnect cancels the long-poll context, simulating a clean
|
||||
// client disconnect.
|
||||
func (c *TestClient) Disconnect(tb testing.TB) {
|
||||
tb.Helper()
|
||||
|
||||
if c.pollCancel != nil {
|
||||
c.pollCancel()
|
||||
<-c.pollDone
|
||||
}
|
||||
}
|
||||
|
||||
// Reconnect registers and starts a new long-poll session.
|
||||
// Call Disconnect first, or this will disconnect automatically.
|
||||
func (c *TestClient) Reconnect(tb testing.TB) {
|
||||
tb.Helper()
|
||||
|
||||
// Cancel any existing poll.
|
||||
if c.pollCancel != nil {
|
||||
c.pollCancel()
|
||||
|
||||
select {
|
||||
case <-c.pollDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
tb.Fatalf("servertest: Reconnect(%s): old poll did not exit", c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Re-register and start polling again.
|
||||
c.register(tb)
|
||||
|
||||
c.startPoll(tb)
|
||||
}
|
||||
|
||||
// ReconnectAfter disconnects, waits for d, then reconnects.
|
||||
// The timer works correctly with testing/synctest for
|
||||
// time-controlled tests.
|
||||
func (c *TestClient) ReconnectAfter(tb testing.TB, d time.Duration) {
|
||||
tb.Helper()
|
||||
c.Disconnect(tb)
|
||||
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
|
||||
<-timer.C
|
||||
c.Reconnect(tb)
|
||||
}
|
||||
|
||||
// --- State accessors ---
|
||||
|
||||
// Netmap returns the latest NetworkMap, or nil if none received yet.
|
||||
func (c *TestClient) Netmap() *netmap.NetworkMap {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return c.netmap
|
||||
}
|
||||
|
||||
// WaitForPeers blocks until the client sees at least n peers,
|
||||
// or until timeout expires.
|
||||
func (c *TestClient) WaitForPeers(tb testing.TB, n int, timeout time.Duration) {
|
||||
tb.Helper()
|
||||
|
||||
deadline := time.After(timeout)
|
||||
|
||||
for {
|
||||
if nm := c.Netmap(); nm != nil && len(nm.Peers) >= n {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.updates:
|
||||
// Check again.
|
||||
case <-deadline:
|
||||
nm := c.Netmap()
|
||||
|
||||
got := 0
|
||||
if nm != nil {
|
||||
got = len(nm.Peers)
|
||||
}
|
||||
|
||||
tb.Fatalf("servertest: WaitForPeers(%s, %d): timeout after %v (got %d peers)", c.Name, n, timeout, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForUpdate blocks until the next netmap update arrives or timeout.
|
||||
func (c *TestClient) WaitForUpdate(tb testing.TB, timeout time.Duration) *netmap.NetworkMap {
|
||||
tb.Helper()
|
||||
|
||||
select {
|
||||
case nm := <-c.updates:
|
||||
return nm
|
||||
case <-time.After(timeout):
|
||||
tb.Fatalf("servertest: WaitForUpdate(%s): timeout after %v", c.Name, timeout)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Peers returns the current peer list, or nil.
|
||||
func (c *TestClient) Peers() []tailcfg.NodeView {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.netmap == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.netmap.Peers
|
||||
}
|
||||
|
||||
// PeerByName finds a peer by hostname. Returns the peer and true
|
||||
// if found, zero value and false otherwise.
|
||||
func (c *TestClient) PeerByName(hostname string) (tailcfg.NodeView, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.netmap == nil {
|
||||
return tailcfg.NodeView{}, false
|
||||
}
|
||||
|
||||
for _, p := range c.netmap.Peers {
|
||||
hi := p.Hostinfo()
|
||||
if hi.Valid() && hi.Hostname() == hostname {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
|
||||
return tailcfg.NodeView{}, false
|
||||
}
|
||||
|
||||
// PeerNames returns the hostnames of all current peers.
|
||||
func (c *TestClient) PeerNames() []string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.netmap == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(c.netmap.Peers))
|
||||
for _, p := range c.netmap.Peers {
|
||||
hi := p.Hostinfo()
|
||||
if hi.Valid() {
|
||||
names = append(names, hi.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
// UpdateCount returns the total number of full netmap updates received.
|
||||
func (c *TestClient) UpdateCount() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.history)
|
||||
}
|
||||
|
||||
// History returns a copy of all NetworkMap snapshots in order.
|
||||
func (c *TestClient) History() []*netmap.NetworkMap {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
out := make([]*netmap.NetworkMap, len(c.history))
|
||||
copy(out, c.history)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// SelfName returns the self node's hostname from the latest netmap.
|
||||
func (c *TestClient) SelfName() string {
|
||||
nm := c.Netmap()
|
||||
if nm == nil || !nm.SelfNode.Valid() {
|
||||
return ""
|
||||
}
|
||||
|
||||
return nm.SelfNode.Hostinfo().Hostname()
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer for debug output.
|
||||
func (c *TestClient) String() string {
|
||||
nm := c.Netmap()
|
||||
if nm == nil {
|
||||
return fmt.Sprintf("TestClient(%s, no netmap)", c.Name)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("TestClient(%s, %d peers)", c.Name, len(nm.Peers))
|
||||
}
|
||||
157
hscontrol/servertest/harness.go
Normal file
157
hscontrol/servertest/harness.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package servertest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
)
|
||||
|
||||
// TestHarness orchestrates a TestServer with multiple TestClients,
|
||||
// providing a convenient setup for multi-node control plane tests.
|
||||
type TestHarness struct {
|
||||
Server *TestServer
|
||||
clients []*TestClient
|
||||
|
||||
// Default user shared by all clients unless overridden.
|
||||
defaultUser *types.User
|
||||
}
|
||||
|
||||
// HarnessOption configures a TestHarness.
|
||||
type HarnessOption func(*harnessConfig)
|
||||
|
||||
type harnessConfig struct {
|
||||
serverOpts []ServerOption
|
||||
clientOpts []ClientOption
|
||||
convergenceMax time.Duration
|
||||
}
|
||||
|
||||
func defaultHarnessConfig() *harnessConfig {
|
||||
return &harnessConfig{
|
||||
convergenceMax: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// WithServerOptions passes ServerOptions through to the underlying
|
||||
// TestServer.
|
||||
func WithServerOptions(opts ...ServerOption) HarnessOption {
|
||||
return func(c *harnessConfig) { c.serverOpts = append(c.serverOpts, opts...) }
|
||||
}
|
||||
|
||||
// WithDefaultClientOptions applies ClientOptions to every client
|
||||
// created by NewHarness.
|
||||
func WithDefaultClientOptions(opts ...ClientOption) HarnessOption {
|
||||
return func(c *harnessConfig) { c.clientOpts = append(c.clientOpts, opts...) }
|
||||
}
|
||||
|
||||
// WithConvergenceTimeout sets how long WaitForMeshComplete waits.
|
||||
func WithConvergenceTimeout(d time.Duration) HarnessOption {
|
||||
return func(c *harnessConfig) { c.convergenceMax = d }
|
||||
}
|
||||
|
||||
// NewHarness creates a TestServer and numClients connected clients.
|
||||
// All clients share a default user and are registered with reusable
|
||||
// pre-auth keys. The harness waits for all clients to form a
|
||||
// complete mesh before returning.
|
||||
func NewHarness(tb testing.TB, numClients int, opts ...HarnessOption) *TestHarness {
|
||||
tb.Helper()
|
||||
|
||||
hc := defaultHarnessConfig()
|
||||
for _, o := range opts {
|
||||
o(hc)
|
||||
}
|
||||
|
||||
server := NewServer(tb, hc.serverOpts...)
|
||||
|
||||
// Create a shared default user.
|
||||
user := server.CreateUser(tb, "harness-default")
|
||||
|
||||
h := &TestHarness{
|
||||
Server: server,
|
||||
defaultUser: user,
|
||||
}
|
||||
|
||||
// Create and connect clients.
|
||||
for i := range numClients {
|
||||
name := clientName(i)
|
||||
|
||||
copts := append([]ClientOption{WithUser(user)}, hc.clientOpts...)
|
||||
c := NewClient(tb, server, name, copts...)
|
||||
h.clients = append(h.clients, c)
|
||||
}
|
||||
|
||||
// Wait for the mesh to converge.
|
||||
if numClients > 1 {
|
||||
h.WaitForMeshComplete(tb, hc.convergenceMax)
|
||||
} else if numClients == 1 {
|
||||
// Single node: just wait for the first netmap.
|
||||
h.clients[0].WaitForUpdate(tb, hc.convergenceMax)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Client returns the i-th client (0-indexed).
|
||||
func (h *TestHarness) Client(i int) *TestClient {
|
||||
return h.clients[i]
|
||||
}
|
||||
|
||||
// Clients returns all clients.
|
||||
func (h *TestHarness) Clients() []*TestClient {
|
||||
return h.clients
|
||||
}
|
||||
|
||||
// ConnectedClients returns clients that currently have an active
|
||||
// long-poll session (pollDone channel is still open).
|
||||
func (h *TestHarness) ConnectedClients() []*TestClient {
|
||||
var out []*TestClient
|
||||
|
||||
for _, c := range h.clients {
|
||||
select {
|
||||
case <-c.pollDone:
|
||||
// Poll has ended, client is disconnected.
|
||||
default:
|
||||
out = append(out, c)
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// AddClient creates and connects a new client to the existing mesh.
|
||||
func (h *TestHarness) AddClient(tb testing.TB, opts ...ClientOption) *TestClient {
|
||||
tb.Helper()
|
||||
|
||||
name := clientName(len(h.clients))
|
||||
copts := append([]ClientOption{WithUser(h.defaultUser)}, opts...)
|
||||
c := NewClient(tb, h.Server, name, copts...)
|
||||
h.clients = append(h.clients, c)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// WaitForMeshComplete blocks until every connected client sees
|
||||
// (connectedCount - 1) peers.
|
||||
func (h *TestHarness) WaitForMeshComplete(tb testing.TB, timeout time.Duration) {
|
||||
tb.Helper()
|
||||
|
||||
connected := h.ConnectedClients()
|
||||
|
||||
expectedPeers := max(len(connected)-1, 0)
|
||||
|
||||
for _, c := range connected {
|
||||
c.WaitForPeers(tb, expectedPeers, timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForConvergence waits until all connected clients have a
|
||||
// non-nil NetworkMap and their peer counts have stabilised.
|
||||
func (h *TestHarness) WaitForConvergence(tb testing.TB, timeout time.Duration) {
|
||||
tb.Helper()
|
||||
h.WaitForMeshComplete(tb, timeout)
|
||||
}
|
||||
|
||||
func clientName(index int) string {
|
||||
return fmt.Sprintf("node-%d", index)
|
||||
}
|
||||
182
hscontrol/servertest/server.go
Normal file
182
hscontrol/servertest/server.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// Package servertest provides an in-process test harness for Headscale's
|
||||
// control plane. It wires a real Headscale server to real Tailscale
|
||||
// controlclient.Direct instances, enabling fast, deterministic tests
|
||||
// of the full control protocol without Docker or separate processes.
|
||||
package servertest
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
hscontrol "github.com/juanfont/headscale/hscontrol"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// TestServer is an in-process Headscale control server suitable for
|
||||
// use with Tailscale's controlclient.Direct.
|
||||
type TestServer struct {
|
||||
App *hscontrol.Headscale
|
||||
HTTPServer *httptest.Server
|
||||
URL string
|
||||
st *state.State
|
||||
}
|
||||
|
||||
// ServerOption configures a TestServer.
|
||||
type ServerOption func(*serverConfig)
|
||||
|
||||
type serverConfig struct {
|
||||
batchDelay time.Duration
|
||||
bufferedChanSize int
|
||||
ephemeralTimeout time.Duration
|
||||
batcherWorkers int
|
||||
}
|
||||
|
||||
func defaultServerConfig() *serverConfig {
|
||||
return &serverConfig{
|
||||
batchDelay: 50 * time.Millisecond,
|
||||
batcherWorkers: 1,
|
||||
ephemeralTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// WithBatchDelay sets the batcher's change coalescing delay.
|
||||
func WithBatchDelay(d time.Duration) ServerOption {
|
||||
return func(c *serverConfig) { c.batchDelay = d }
|
||||
}
|
||||
|
||||
// WithBufferedChanSize sets the per-node map session channel buffer.
|
||||
func WithBufferedChanSize(n int) ServerOption {
|
||||
return func(c *serverConfig) { c.bufferedChanSize = n }
|
||||
}
|
||||
|
||||
// WithEphemeralTimeout sets the ephemeral node inactivity timeout.
|
||||
func WithEphemeralTimeout(d time.Duration) ServerOption {
|
||||
return func(c *serverConfig) { c.ephemeralTimeout = d }
|
||||
}
|
||||
|
||||
// NewServer creates and starts a Headscale test server.
|
||||
// The server is fully functional and accepts real Tailscale control
|
||||
// protocol connections over Noise.
|
||||
func NewServer(tb testing.TB, opts ...ServerOption) *TestServer {
|
||||
tb.Helper()
|
||||
|
||||
sc := defaultServerConfig()
|
||||
for _, o := range opts {
|
||||
o(sc)
|
||||
}
|
||||
|
||||
tmpDir := tb.TempDir()
|
||||
|
||||
cfg := types.Config{
|
||||
// Placeholder; updated below once httptest server starts.
|
||||
ServerURL: "http://localhost:0",
|
||||
NoisePrivateKeyPath: tmpDir + "/noise_private.key",
|
||||
EphemeralNodeInactivityTimeout: sc.ephemeralTimeout,
|
||||
Database: types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
Policy: types.PolicyConfig{
|
||||
Mode: types.PolicyModeDB,
|
||||
},
|
||||
Tuning: types.Tuning{
|
||||
BatchChangeDelay: sc.batchDelay,
|
||||
BatcherWorkers: sc.batcherWorkers,
|
||||
NodeMapSessionBufferedChanSize: sc.bufferedChanSize,
|
||||
},
|
||||
}
|
||||
|
||||
app, err := hscontrol.NewHeadscale(&cfg)
|
||||
if err != nil {
|
||||
tb.Fatalf("servertest: NewHeadscale: %v", err)
|
||||
}
|
||||
|
||||
// Set a minimal DERP map so MapResponse generation works.
|
||||
app.GetState().SetDERPMap(&tailcfg.DERPMap{
|
||||
Regions: map[int]*tailcfg.DERPRegion{
|
||||
900: {
|
||||
RegionID: 900,
|
||||
RegionCode: "test",
|
||||
RegionName: "Test Region",
|
||||
Nodes: []*tailcfg.DERPNode{{
|
||||
Name: "test0",
|
||||
RegionID: 900,
|
||||
HostName: "127.0.0.1",
|
||||
IPv4: "127.0.0.1",
|
||||
DERPPort: -1, // not a real DERP, just needed for MapResponse
|
||||
}},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Start subsystems.
|
||||
app.StartBatcherForTest(tb)
|
||||
app.StartEphemeralGCForTest(tb)
|
||||
|
||||
// Start the HTTP server with Headscale's full handler (including
|
||||
// /key and /ts2021 Noise upgrade).
|
||||
ts := httptest.NewServer(app.HTTPHandler())
|
||||
tb.Cleanup(ts.Close)
|
||||
|
||||
// Now update the config to point at the real URL so that
|
||||
// MapResponse.ControlURL etc. are correct.
|
||||
app.SetServerURLForTest(tb, ts.URL)
|
||||
|
||||
return &TestServer{
|
||||
App: app,
|
||||
HTTPServer: ts,
|
||||
URL: ts.URL,
|
||||
st: app.GetState(),
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the server's state manager for creating users,
|
||||
// nodes, and pre-auth keys.
|
||||
func (s *TestServer) State() *state.State {
|
||||
return s.st
|
||||
}
|
||||
|
||||
// CreateUser creates a test user and returns it.
|
||||
func (s *TestServer) CreateUser(tb testing.TB, name string) *types.User {
|
||||
tb.Helper()
|
||||
|
||||
u, _, err := s.st.CreateUser(types.User{Name: name})
|
||||
if err != nil {
|
||||
tb.Fatalf("servertest: CreateUser(%q): %v", name, err)
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// CreatePreAuthKey creates a reusable pre-auth key for the given user.
|
||||
func (s *TestServer) CreatePreAuthKey(tb testing.TB, userID types.UserID) string {
|
||||
tb.Helper()
|
||||
|
||||
uid := userID
|
||||
|
||||
pak, err := s.st.CreatePreAuthKey(&uid, true, false, nil, nil)
|
||||
if err != nil {
|
||||
tb.Fatalf("servertest: CreatePreAuthKey: %v", err)
|
||||
}
|
||||
|
||||
return pak.Key
|
||||
}
|
||||
|
||||
// CreateEphemeralPreAuthKey creates an ephemeral pre-auth key.
|
||||
func (s *TestServer) CreateEphemeralPreAuthKey(tb testing.TB, userID types.UserID) string {
|
||||
tb.Helper()
|
||||
|
||||
uid := userID
|
||||
|
||||
pak, err := s.st.CreatePreAuthKey(&uid, false, true, nil, nil)
|
||||
if err != nil {
|
||||
tb.Fatalf("servertest: CreateEphemeralPreAuthKey: %v", err)
|
||||
}
|
||||
|
||||
return pak.Key
|
||||
}
|
||||
Reference in New Issue
Block a user