Files
headscale/hscontrol/mapper/mapper_test.go
Kristoffer Dalby 6fcff9e352 mapper, state: deliver nodeAttrs through MapResponse and harden nextdns DoH rewrite
WithSelfNode and buildTailPeers merge each node's policy CapMap
into the tailcfg.Node.CapMap they emit. State.NodeCapMap and
State.NodeCapMaps wrap the policy manager: NodeCapMap returns a
defensive clone per call; NodeCapMaps snapshots the full per-node
map once for batched callers, amortising pm.mu acquisition across
a peer build.

generateDNSConfig grew a per-node CapMap argument so it can apply
nodeAttr-driven DNS overlays. The nextdns DoH rewrite hardens against
policy-controlled inputs:

  - nextDNSDoHHost anchors the prefix match instead of substring,
    so a hostile resolver URL cannot smuggle a nextdns hostname in
    a path or query.
  - nextDNSProfileFromCapMap accepts only profile names matching
    [A-Za-z0-9._-]{1,64} and picks the lexicographically first when
    multiple are granted -- deterministic, no shell metacharacters
    or URL fragments through.
  - addNextDNSMetadata composes the rewritten URL via url.Parse +
    url.Values rather than fmt.Sprintf, so existing query strings
    on the resolver URL survive and metadata cannot inject a new
    component.

WithTaildropEnabled in servertest controls cfg.Taildrop.Enabled per
test so cap/file-sharing emission can be toggled in tests that need
to verify the off path.
2026-05-13 14:22:30 +02:00

196 lines
4.3 KiB
Go

package mapper
import (
"fmt"
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
)
var iap = func(ipStr string) *netip.Addr {
ip := netip.MustParseAddr(ipStr)
return &ip
}
func TestDNSConfigMapResponse(t *testing.T) {
tests := []struct {
magicDNS bool
want *tailcfg.DNSConfig
}{
{
magicDNS: true,
want: &tailcfg.DNSConfig{
Routes: map[string][]*dnstype.Resolver{},
Domains: []string{
"foobar.headscale.net",
},
Proxied: true,
},
},
{
magicDNS: false,
want: &tailcfg.DNSConfig{
Domains: []string{"foobar.headscale.net"},
Proxied: false,
},
},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("with-magicdns-%v", tt.magicDNS), func(t *testing.T) {
mach := func(hostname, username string, userid uint) *types.Node {
return &types.Node{
Hostname: hostname,
UserID: new(userid),
User: &types.User{
Name: username,
},
}
}
baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{
Routes: make(map[string][]*dnstype.Resolver),
Domains: []string{baseDomain},
Proxied: tt.magicDNS,
}
nodeInShared1 := mach("test_get_shared_nodes_1", "shared1", 1)
got := generateDNSConfig(
&types.Config{
TailcfgDNSConfig: &dnsConfigOrig,
},
nodeInShared1.View(),
nil,
)
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("expandAlias() unexpected result (-want +got):\n%s", diff)
}
})
}
}
func TestNextDNSCapMapRendering(t *testing.T) {
t.Parallel()
mkConfig := func(addrs ...string) *types.Config {
resolvers := make([]*dnstype.Resolver, len(addrs))
for i, a := range addrs {
resolvers[i] = &dnstype.Resolver{Addr: a}
}
return &types.Config{
TailcfgDNSConfig: &tailcfg.DNSConfig{
Resolvers: resolvers,
},
}
}
mkNode := func() types.NodeView {
return (&types.Node{
ID: 1,
Hostname: "node1",
IPv4: iap("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{OS: "linux"},
}).View()
}
// resolverAddr extracts the first resolver's address with a
// bounds check. Without it, a regression that drops the
// resolver list would nil-panic instead of failing cleanly.
resolverAddr := func(t *testing.T, got *tailcfg.DNSConfig) string {
t.Helper()
if got == nil {
t.Fatalf("generateDNSConfig returned nil")
}
if len(got.Resolvers) == 0 {
t.Fatalf("generateDNSConfig returned no Resolvers")
}
return got.Resolvers[0].Addr
}
t.Run("no_capmap_metadata_appended", func(t *testing.T) {
t.Parallel()
got := generateDNSConfig(
mkConfig("https://dns.nextdns.io/abc"),
mkNode(),
nil,
)
want := "https://dns.nextdns.io/abc?device_ip=100.64.0.1&device_model=linux&device_name=node1"
if addr := resolverAddr(t, got); addr != want {
t.Errorf("addr = %q, want %q", addr, want)
}
})
t.Run("profile_overrides_global", func(t *testing.T) {
t.Parallel()
capMap := tailcfg.NodeCapMap{
"nextdns:override": []tailcfg.RawMessage{},
}
got := generateDNSConfig(
mkConfig("https://dns.nextdns.io/global"),
mkNode(),
capMap,
)
want := "https://dns.nextdns.io/override?device_ip=100.64.0.1&device_model=linux&device_name=node1"
if addr := resolverAddr(t, got); addr != want {
t.Errorf("addr = %q, want %q", addr, want)
}
})
t.Run("no_device_info_skips_metadata", func(t *testing.T) {
t.Parallel()
capMap := tailcfg.NodeCapMap{
"nextdns:abc": []tailcfg.RawMessage{},
"nextdns:no-device-info": []tailcfg.RawMessage{},
}
got := generateDNSConfig(
mkConfig("https://dns.nextdns.io/global"),
mkNode(),
capMap,
)
want := "https://dns.nextdns.io/abc"
if addr := resolverAddr(t, got); addr != want {
t.Errorf("addr = %q, want %q", addr, want)
}
})
t.Run("non_nextdns_resolver_untouched", func(t *testing.T) {
t.Parallel()
capMap := tailcfg.NodeCapMap{
"nextdns:abc": []tailcfg.RawMessage{},
}
got := generateDNSConfig(
mkConfig("https://dns.example.org/dns-query"),
mkNode(),
capMap,
)
want := "https://dns.example.org/dns-query"
if addr := resolverAddr(t, got); addr != want {
t.Errorf("non-nextdns resolver was rewritten: %q", addr)
}
})
}