diff --git a/hscontrol/app.go b/hscontrol/app.go index 30134ac6..77b0c103 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -99,6 +99,10 @@ type Headscale struct { DERPServer *derpServer.DERPServer + // realIPMiddleware is nil when cfg.TrustedProxies is empty; the + // router skips the mount and r.RemoteAddr stays as the TCP peer. + realIPMiddleware func(http.Handler) http.Handler + // Things that generate changes extraRecordMan *dns.ExtraRecordsMan authProvider AuthProvider @@ -140,6 +144,13 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { state: s, } + if len(cfg.TrustedProxies) > 0 { + app.realIPMiddleware, err = trustedProxyRealIP(cfg.TrustedProxies) + if err != nil { + return nil, fmt.Errorf("building trusted_proxies middleware: %w", err) + } + } + // Initialize ephemeral garbage collector ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) { node, ok := app.state.GetNodeByID(ni) @@ -512,7 +523,11 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { }, })) r.Use(middleware.RequestID) - r.Use(middleware.RealIP) + + if h.realIPMiddleware != nil { + r.Use(h.realIPMiddleware) + } + r.Use(middleware.RequestLogger(&zerologRequestLogger{})) r.Use(middleware.Recoverer) r.Use(securityHeaders) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index d8a26cb1..9ad9a274 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -158,7 +158,11 @@ func (h *Headscale) NoiseUpgradeHandler( }, })) r.Use(middleware.RequestID) - r.Use(middleware.RealIP) + + if h.realIPMiddleware != nil { + r.Use(h.realIPMiddleware) + } + r.Use(middleware.RequestLogger(&zerologRequestLogger{})) r.Use(middleware.Recoverer) diff --git a/hscontrol/realip.go b/hscontrol/realip.go new file mode 100644 index 00000000..2b363060 --- /dev/null +++ b/hscontrol/realip.go @@ -0,0 +1,124 @@ +package hscontrol + +import ( + "fmt" + "net" + "net/http" + "net/netip" + + realclientip "github.com/realclientip/realclientip-go" +) + +const ( + headerTrueClientIP = "True-Client-IP" + headerXRealIP = "X-Real-IP" + headerXForwardedFor = "X-Forwarded-For" +) + +var proxyHeaders = [...]string{headerTrueClientIP, headerXRealIP, headerXForwardedFor} + +// trustedProxyRealIP rewrites r.RemoteAddr from proxy headers when the +// peer is in trusted; for any other peer the headers are stripped so a +// downstream handler cannot read a spoofed value. X-Forwarded-For uses +// RightmostTrustedRangeStrategy so prepending a value cannot win in a +// proxy chain. +func trustedProxyRealIP(trusted []netip.Prefix) (func(http.Handler) http.Handler, error) { + ranges := make([]net.IPNet, 0, len(trusted)) + for _, p := range trusted { + ranges = append(ranges, prefixToIPNet(p)) + } + + trueClientIP, err := realclientip.NewSingleIPHeaderStrategy(headerTrueClientIP) + if err != nil { + return nil, fmt.Errorf("%s strategy: %w", headerTrueClientIP, err) + } + + xRealIP, err := realclientip.NewSingleIPHeaderStrategy(headerXRealIP) + if err != nil { + return nil, fmt.Errorf("%s strategy: %w", headerXRealIP, err) + } + + xForwardedFor, err := realclientip.NewRightmostTrustedRangeStrategy(headerXForwardedFor, ranges) + if err != nil { + return nil, fmt.Errorf("%s strategy: %w", headerXForwardedFor, err) + } + + strategy := realclientip.NewChainStrategy(trueClientIP, xRealIP, xForwardedFor) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !peerTrusted(r.RemoteAddr, trusted) { + for _, h := range proxyHeaders { + r.Header.Del(h) + } + + next.ServeHTTP(w, r) + + return + } + + // Proxy headers carry no port; write the IP alone so + // `remote=` logs the resolved client, not the proxy's + // ephemeral TCP port. + if ip := strategy.ClientIP(r.Header, r.RemoteAddr); ip != "" { + r.RemoteAddr = ip + } + + next.ServeHTTP(w, r) + }) + }, nil +} + +// peerTrusted returns false on unparseable input so callers fall +// through to the header-stripping path. +func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool { + addr, ok := parsePeerAddr(remoteAddr) + if !ok { + return false + } + + for _, p := range trusted { + if p.Contains(addr) { + return true + } + } + + return false +} + +func parsePeerAddr(remoteAddr string) (netip.Addr, bool) { + if remoteAddr == "" { + return netip.Addr{}, false + } + + ap, err := netip.ParseAddrPort(remoteAddr) + if err == nil { + return ap.Addr(), true + } + + host, _, splitErr := net.SplitHostPort(remoteAddr) + if splitErr != nil { + host = remoteAddr + } + + addr, err := netip.ParseAddr(host) + if err != nil { + return netip.Addr{}, false + } + + return addr, true +} + +// prefixToIPNet bridges to realclientip-go, which predates net/netip. +func prefixToIPNet(p netip.Prefix) net.IPNet { + addr := p.Addr() + if addr.Is4() { + b := addr.As4() + + return net.IPNet{IP: b[:], Mask: net.CIDRMask(p.Bits(), 32)} + } + + b := addr.As16() + + return net.IPNet{IP: b[:], Mask: net.CIDRMask(p.Bits(), 128)} +} diff --git a/hscontrol/realip_test.go b/hscontrol/realip_test.go new file mode 100644 index 00000000..2d3b3371 --- /dev/null +++ b/hscontrol/realip_test.go @@ -0,0 +1,242 @@ +package hscontrol + +import ( + "net/http" + "net/http/httptest" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +//nolint:goconst // repeated test fixtures (addresses, headers), not refactor candidates +func TestPeerTrusted(t *testing.T) { + trusted := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/16"), + netip.MustParsePrefix("127.0.0.1/32"), + netip.MustParsePrefix("fd00::/8"), + } + + tests := []struct { + name string + remoteAddr string + want bool + }{ + {name: "v4-in-range", remoteAddr: "10.0.0.5:1234", want: true}, + {name: "v4-edge", remoteAddr: "10.0.255.255:1", want: true}, + {name: "v4-out-of-range", remoteAddr: "10.1.0.0:1234", want: false}, + {name: "v4-loopback", remoteAddr: "127.0.0.1:443", want: true}, + {name: "v6-in-range", remoteAddr: "[fd00::1]:443", want: true}, + {name: "v6-out-of-range", remoteAddr: "[2001:db8::1]:443", want: false}, + {name: "no-port", remoteAddr: "10.0.0.5", want: true}, + {name: "empty", remoteAddr: "", want: false}, + {name: "non-ip-host", remoteAddr: "localhost:8080", want: false}, + {name: "garbage", remoteAddr: "not-a-thing", want: false}, + {name: "unix-socket", remoteAddr: "@", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := peerTrusted(tt.remoteAddr, trusted) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestPrefixToIPNet(t *testing.T) { + tests := []struct { + name string + in netip.Prefix + want string + }{ + {name: "v4", in: netip.MustParsePrefix("10.0.0.0/16"), want: "10.0.0.0/16"}, + {name: "v4-host", in: netip.MustParsePrefix("127.0.0.1/32"), want: "127.0.0.1/32"}, + {name: "v6", in: netip.MustParsePrefix("fd00::/8"), want: "fd00::/8"}, + {name: "v6-host", in: netip.MustParsePrefix("::1/128"), want: "::1/128"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := prefixToIPNet(tt.in) + assert.Equal(t, tt.want, got.String()) + }) + } +} + +//nolint:goconst // repeated test fixtures (addresses, headers), not refactor candidates +func TestTrustedProxyRealIP(t *testing.T) { + trusted := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/16"), + netip.MustParsePrefix("fd00::/8"), + } + + mw, err := trustedProxyRealIP(trusted) + require.NoError(t, err) + + tests := []struct { + name string + remoteAddr string + headers map[string]string + wantRemote string + wantStripped []string + wantKept map[string]string + }{ + { + name: "untrusted/no-headers", + remoteAddr: "203.0.113.1:1234", + wantRemote: "203.0.113.1:1234", + }, + { + name: "untrusted/strips-x-real-ip", + remoteAddr: "203.0.113.1:1234", + headers: map[string]string{"X-Real-IP": "1.2.3.4"}, + wantRemote: "203.0.113.1:1234", + wantStripped: []string{"X-Real-IP"}, + }, + { + name: "untrusted/strips-x-forwarded-for", + remoteAddr: "203.0.113.1:1234", + headers: map[string]string{"X-Forwarded-For": "1.2.3.4"}, + wantRemote: "203.0.113.1:1234", + wantStripped: []string{"X-Forwarded-For"}, + }, + { + name: "untrusted/strips-true-client-ip", + remoteAddr: "203.0.113.1:1234", + headers: map[string]string{"True-Client-IP": "1.2.3.4"}, + wantRemote: "203.0.113.1:1234", + wantStripped: []string{"True-Client-IP"}, + }, + { + name: "untrusted/strips-all-three", + remoteAddr: "203.0.113.1:1234", + headers: map[string]string{ + "True-Client-IP": "1.2.3.4", + "X-Real-IP": "5.6.7.8", + "X-Forwarded-For": "9.10.11.12", + }, + wantRemote: "203.0.113.1:1234", + wantStripped: []string{"True-Client-IP", "X-Real-IP", "X-Forwarded-For"}, + }, + { + name: "untrusted/keeps-unrelated-header", + remoteAddr: "203.0.113.1:1234", + headers: map[string]string{"User-Agent": "curl/8", "X-Real-IP": "1.2.3.4"}, + wantRemote: "203.0.113.1:1234", + wantStripped: []string{"X-Real-IP"}, + wantKept: map[string]string{"User-Agent": "curl/8"}, + }, + { + name: "trusted/no-headers", + remoteAddr: "10.0.0.5:1234", + wantRemote: "10.0.0.5:1234", + }, + { + name: "trusted/x-real-ip", + remoteAddr: "10.0.0.5:1234", + headers: map[string]string{"X-Real-IP": "1.2.3.4"}, + wantRemote: "1.2.3.4", + }, + { + name: "trusted/true-client-ip-wins-over-others", + remoteAddr: "10.0.0.5:1234", + headers: map[string]string{ + "True-Client-IP": "1.2.3.4", + "X-Real-IP": "5.6.7.8", + "X-Forwarded-For": "9.10.11.12", + }, + wantRemote: "1.2.3.4", + }, + { + name: "trusted/x-real-ip-wins-over-xff", + remoteAddr: "10.0.0.5:1234", + headers: map[string]string{ + "X-Real-IP": "1.2.3.4", + "X-Forwarded-For": "9.10.11.12", + }, + wantRemote: "1.2.3.4", + }, + { + name: "trusted/xff-rightmost-walk-discards-trusted-hop", + remoteAddr: "10.0.0.5:1234", + headers: map[string]string{"X-Forwarded-For": "203.0.113.99, 10.0.0.5"}, + wantRemote: "203.0.113.99", + }, + { + name: "trusted/xff-all-trusted-leaves-remote-alone", + remoteAddr: "10.0.0.5:1234", + headers: map[string]string{"X-Forwarded-For": "10.0.0.99, 10.0.0.5"}, + wantRemote: "10.0.0.5:1234", + }, + { + name: "trusted/ipv6-peer-v6-real-ip", + remoteAddr: "[fd00::1]:1234", + headers: map[string]string{"X-Real-IP": "2001:db8::1"}, + wantRemote: "2001:db8::1", + }, + { + name: "ipv6-untrusted-strips-header", + remoteAddr: "[2001:db8::1]:1234", + headers: map[string]string{"X-Real-IP": "1.2.3.4"}, + wantRemote: "[2001:db8::1]:1234", + wantStripped: []string{"X-Real-IP"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var observed *http.Request + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + observed = r + })) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) + req.RemoteAddr = tt.remoteAddr + + for k, v := range tt.headers { + req.Header.Set(k, v) + } + + handler.ServeHTTP(httptest.NewRecorder(), req) + + require.NotNil(t, observed, "handler must be invoked") + assert.Equal(t, tt.wantRemote, observed.RemoteAddr) + + for _, h := range tt.wantStripped { + assert.Empty(t, observed.Header.Get(h), "header %s should be stripped", h) + } + + for k, v := range tt.wantKept { + assert.Equal(t, v, observed.Header.Get(k), "header %s should be preserved", k) + } + }) + } +} + +func TestTrustedProxyRealIPEmptyTrusted(t *testing.T) { + // Sanity: factory accepts an empty slice without error. Wiring code is + // responsible for skipping the mount entirely, but the factory itself + // must remain safe for tests that compose it manually. + mw, err := trustedProxyRealIP(nil) + require.NoError(t, err) + require.NotNil(t, mw) + + var observed *http.Request + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + observed = r + })) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) + req.RemoteAddr = "10.0.0.5:1234" + req.Header.Set("X-Real-IP", "1.2.3.4") + handler.ServeHTTP(httptest.NewRecorder(), req) + + require.NotNil(t, observed) + // No prefix is trusted, so even an LAN-looking peer is not trusted; the + // spoofed header must be stripped and RemoteAddr left alone. + assert.Equal(t, "10.0.0.5:1234", observed.RemoteAddr) + assert.Empty(t, observed.Header.Get("X-Real-IP")) +} diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 8df7c7bb..bb0a8918 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -35,6 +35,7 @@ var ( errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errTrustedProxyZeroRange = errors.New("0.0.0.0/0 and ::/0 are not allowed") ErrNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") ErrInvalidAllocationStrategy = errors.New("invalid prefix allocation strategy") ) @@ -98,6 +99,7 @@ type Config struct { MetricsAddr string GRPCAddr string GRPCAllowInsecure bool + TrustedProxies []netip.Prefix Node NodeConfig PrefixV4 *netip.Prefix PrefixV6 *netip.Prefix @@ -1049,6 +1051,31 @@ func prefixV6() (*netip.Prefix, bool, error) { return &prefixV6, !ipSet.ContainsPrefix(prefixV6), nil } +// trustedProxies rejects 0.0.0.0/0 and ::/0 because they defeat the +// peer-trust gate and almost always indicate misconfiguration. +func trustedProxies() ([]netip.Prefix, error) { + raw := viper.GetStringSlice("trusted_proxies") + if len(raw) == 0 { + return nil, nil + } + + out := make([]netip.Prefix, 0, len(raw)) + for i, s := range raw { + p, err := netip.ParsePrefix(s) + if err != nil { + return nil, fmt.Errorf("trusted_proxies[%d] %q: %w", i, s, err) + } + + if p.Bits() == 0 { + return nil, fmt.Errorf("trusted_proxies[%d] %q: %w", i, s, errTrustedProxyZeroRange) + } + + out = append(out, p.Masked()) + } + + return out, nil +} + // LoadCLIConfig returns the needed configuration for the CLI client // of Headscale to connect to a Headscale server. func LoadCLIConfig() (*Config, error) { @@ -1088,6 +1115,11 @@ func LoadServerConfig() (*Config, error) { return nil, err } + trusted, err := trustedProxies() + if err != nil { + return nil, err + } + if prefix4 == nil && prefix6 == nil { return nil, ErrNoPrefixConfigured } @@ -1178,6 +1210,7 @@ func LoadServerConfig() (*Config, error) { MetricsAddr: viper.GetString("metrics_listen_addr"), GRPCAddr: viper.GetString("grpc_listen_addr"), GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), + TrustedProxies: trusted, DisableUpdateCheck: false, PrefixV4: prefix4, diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index d3b23a2d..8943683a 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -3,6 +3,7 @@ package types import ( "encoding/json" "fmt" + "net/netip" "os" "path/filepath" "testing" @@ -510,3 +511,93 @@ func TestConfigJSONOmitsSecrets(t *testing.T) { "marshalled Config must not contain secret %q", secret) } } + +//nolint:goconst // repeated CIDR strings are test fixtures, not refactor candidates +func TestTrustedProxies(t *testing.T) { + tests := []struct { + name string + input any + want []netip.Prefix + wantErr string + }{ + { + name: "unset", + input: nil, + want: nil, + }, + { + name: "empty", + input: []string{}, + want: nil, + }, + { + name: "single-v4", + input: []string{"10.0.0.0/16"}, + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/16")}, + }, + { + name: "single-v6", + input: []string{"fd00::/8"}, + want: []netip.Prefix{netip.MustParsePrefix("fd00::/8")}, + }, + { + name: "mixed-v4-v6", + input: []string{"127.0.0.1/32", "::1/128", "10.0.0.0/16"}, + want: []netip.Prefix{ + netip.MustParsePrefix("127.0.0.1/32"), + netip.MustParsePrefix("::1/128"), + netip.MustParsePrefix("10.0.0.0/16"), + }, + }, + { + name: "non-canonical-masked", + input: []string{"10.0.0.5/16"}, + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/16")}, + }, + { + name: "bare-ip-rejected", + input: []string{"10.0.0.1"}, + wantErr: `trusted_proxies[0] "10.0.0.1"`, + }, + { + name: "garbage-reports-index", + input: []string{"10.0.0.0/16", "not-an-ip"}, + wantErr: `trusted_proxies[1] "not-an-ip"`, + }, + { + name: "ipv4-zero-rejected", + input: []string{"0.0.0.0/0"}, + wantErr: "0.0.0.0/0 and ::/0 are not allowed", + }, + { + name: "ipv6-zero-rejected", + input: []string{"::/0"}, + wantErr: "0.0.0.0/0 and ::/0 are not allowed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + viper.Reset() + + if tt.input != nil { + viper.Set("trusted_proxies", tt.input) + } + + got, err := trustedProxies() + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + + return + } + + require.NoError(t, err) + + if diff := cmp.Diff(tt.want, got, cmpopts.EquateComparable(netip.Prefix{})); diff != "" { + t.Errorf("trustedProxies() mismatch (-want +got):\n%s", diff) + } + }) + } +}