From 30338441c1c95d4aa975d52b6f42a9ac7cc20b96 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 24 Feb 2026 18:47:40 +0000 Subject: [PATCH] app: switch from gorilla to chi mux Replace gorilla/mux with go-chi/chi as the HTTP router and add a custom zerolog-based request logger to replace chi's default stdlib-based middleware.Logger, consistent with the rest of the application. Updates #1850 --- flake.nix | 2 +- go.mod | 2 + go.sum | 4 ++ hscontrol/app.go | 120 +++++++++++++++++++++++++++++++++------------ hscontrol/noise.go | 106 ++++++++++++++++++++++++++++++--------- 5 files changed, 179 insertions(+), 55 deletions(-) diff --git a/flake.nix b/flake.nix index 8a1ba421..ae02d0ff 100644 --- a/flake.nix +++ b/flake.nix @@ -27,7 +27,7 @@ let pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system}; buildGo = pkgs.buildGo126Module; - vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0="; + vendorHash = "sha256-oUN53ELb3+xn4yA7lEfXyT2c7NxbQC6RtbkGVq6+RLU="; in { headscale = buildGo { diff --git a/go.mod b/go.mod index c99d4ddd..3adc7e48 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,8 @@ require ( github.com/docker/docker v28.5.2+incompatible github.com/fsnotify/fsnotify v1.9.0 github.com/glebarez/sqlite v1.11.0 + github.com/go-chi/chi/v5 v5.2.5 + github.com/go-chi/metrics v0.1.1 github.com/go-gormigrate/gormigrate/v2 v2.1.5 github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e github.com/gofrs/uuid/v5 v5.4.0 diff --git a/go.sum b/go.sum index e9c39e36..4c5f48ac 100644 --- a/go.sum +++ b/go.sum @@ -181,6 +181,10 @@ github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-chi/metrics v0.1.1 h1:CXhbnkAVVjb0k73EBRQ6Z2YdWFnbXZgNtg1Mboguibk= +github.com/go-chi/metrics v0.1.1/go.mod h1:mcGTM1pPalP7WCtb+akNYFO/lwNwBBLCuedepqjoPn4= github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8= github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M= github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= diff --git a/hscontrol/app.go b/hscontrol/app.go index abd29a45..bb4733f7 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -20,7 +20,9 @@ import ( "github.com/cenkalti/backoff/v5" "github.com/davecgh/go-spew/spew" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/metrics" grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -457,50 +459,57 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error { return os.Remove(h.cfg.UnixSocket) } -func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { - router := mux.NewRouter() - router.Use(prometheusMiddleware) +func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { + r := chi.NewRouter() + r.Use(metrics.Collector(metrics.CollectorOpts{ + Host: false, + Proto: true, + Skip: func(r *http.Request) bool { + return r.Method != http.MethodOptions + }, + })) + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.RequestLogger(&zerologRequestLogger{})) + r.Use(middleware.Recoverer) - router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler). - Methods(http.MethodPost, http.MethodGet) + r.Post(ts2021UpgradePath, h.NoiseUpgradeHandler) - router.HandleFunc("/robots.txt", h.RobotsHandler).Methods(http.MethodGet) - router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) - router.HandleFunc("/version", h.VersionHandler).Methods(http.MethodGet) - router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) - router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler). - Methods(http.MethodGet) + r.Get("/robots.txt", h.RobotsHandler) + r.Get("/health", h.HealthHandler) + r.Get("/version", h.VersionHandler) + r.Get("/key", h.KeyHandler) + r.Get("/register/{registration_id}", h.authProvider.RegisterHandler) if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { - router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet) + r.Get("/oidc/callback", provider.OIDCCallbackHandler) } - router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) - router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). - Methods(http.MethodGet) - router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) + r.Get("/apple", h.AppleConfigMessage) + r.Get("/apple/{platform}", h.ApplePlatformConfig) + r.Get("/windows", h.WindowsConfigMessage) // TODO(kristoffer): move swagger into a package - router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet) - router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1). - Methods(http.MethodGet) + r.Get("/swagger", headscale.SwaggerUI) + r.Get("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1) - router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost) + r.Post("/verify", h.VerifyHandler) if h.cfg.DERP.ServerEnabled { - router.HandleFunc("/derp", h.DERPServer.DERPHandler) - router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) - router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler) - router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap())) + r.HandleFunc("/derp", h.DERPServer.DERPHandler) + r.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) + r.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler) + r.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap())) } - apiRouter := router.PathPrefix("/api").Subrouter() - apiRouter.Use(h.httpAuthenticationMiddleware) - apiRouter.PathPrefix("/v1/").HandlerFunc(grpcMux.ServeHTTP) - router.HandleFunc("/favicon.ico", FaviconHandler) - router.PathPrefix("/").HandlerFunc(BlankHandler) + r.Route("/api", func(r chi.Router) { + r.Use(h.httpAuthenticationMiddleware) + r.HandleFunc("/v1/*", grpcMux.ServeHTTP) + }) + r.Get("/favicon.ico", FaviconHandler) + r.Get("/", BlankHandler) - return router + return r } // Serve launches the HTTP and gRPC server service Headscale and the API. @@ -1083,3 +1092,52 @@ func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) { return resp, nil } + +// zerologRequestLogger implements chi's middleware.LogFormatter +// to route HTTP request logs through zerolog. +type zerologRequestLogger struct{} + +func (z *zerologRequestLogger) NewLogEntry( + r *http.Request, +) middleware.LogEntry { + return &zerologLogEntry{ + method: r.Method, + path: r.URL.Path, + proto: r.Proto, + remote: r.RemoteAddr, + } +} + +type zerologLogEntry struct { + method string + path string + proto string + remote string +} + +func (e *zerologLogEntry) Write( + status, bytes int, + header http.Header, + elapsed time.Duration, + extra any, +) { + log.Info(). + Str("method", e.method). + Str("path", e.path). + Str("proto", e.proto). + Str("remote", e.remote). + Int("status", status). + Int("bytes", bytes). + Dur("elapsed", elapsed). + Msg("http request") +} + +func (e *zerologLogEntry) Panic( + v any, + stack []byte, +) { + log.Error(). + Interface("panic", v). + Bytes("stack", stack). + Msg("http handler panic") +} diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 2880f33a..57a79b96 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -8,7 +8,9 @@ import ( "io" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/metrics" "github.com/juanfont/headscale/hscontrol/capver" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" @@ -69,7 +71,7 @@ func (h *Headscale) NoiseUpgradeHandler( return } - noiseServer := noiseServer{ + ns := noiseServer{ headscale: h, challenge: key.NewChallenge(), } @@ -79,42 +81,88 @@ func (h *Headscale) NoiseUpgradeHandler( writer, req, *h.noisePrivateKey, - noiseServer.earlyNoise, + ns.earlyNoise, ) if err != nil { httpError(writer, fmt.Errorf("upgrading noise connection: %w", err)) return } - noiseServer.conn = noiseConn - noiseServer.machineKey = noiseServer.conn.Peer() - noiseServer.protocolVersion = noiseServer.conn.ProtocolVersion() + ns.conn = noiseConn + ns.machineKey = ns.conn.Peer() + ns.protocolVersion = ns.conn.ProtocolVersion() // This router is served only over the Noise connection, and exposes only the new API. // // The HTTP2 server that exposes this router is created for // a single hijacked connection from /ts2021, using netutil.NewOneConnListener - router := mux.NewRouter() - router.Use(prometheusMiddleware) - router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler). - Methods(http.MethodPost) + r := chi.NewRouter() + r.Use(metrics.Collector(metrics.CollectorOpts{ + Host: false, + Proto: true, + Skip: func(r *http.Request) bool { + return r.Method != http.MethodOptions + }, + })) + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) - // Endpoints outside of the register endpoint must use getAndValidateNode to - // get the node to ensure that the MachineKey matches the Node setting up the - // connection. - router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler) + r.Handle("/metrics", metrics.Handler()) - noiseServer.httpBaseConfig = &http.Server{ - Handler: router, + r.Route("/machine", func(r chi.Router) { + r.Post("/register", ns.RegistrationHandler) + r.Post("/map", ns.PollNetMapHandler) + + // Not implemented yet + // + // /whoami is a debug endpoint to validate that the client can communicate over the connection, + // not clear if there is a specific response, it looks like it is just logged. + // https://github.com/tailscale/tailscale/blob/dfba01ca9bd8c4df02c3c32f400d9aeb897c5fc7/cmd/tailscale/cli/debug.go#L1138 + r.Get("/whoami", ns.NotImplementedHandler) + + // client sends a [tailcfg.SetDNSRequest] to this endpoints and expect + // the server to create or update this DNS record "somewhere". + // It is typically a TXT record for an ACME challenge. + r.Post("/set-dns", ns.NotImplementedHandler) + + // A patch of [tailcfg.SetDeviceAttributesRequest] to update device attributes. + // We currently do not support device attributes. + r.Patch("/set-device-attr", ns.NotImplementedHandler) + + // A [tailcfg.AuditLogRequest] to send audit log entries to the server. + // The server is expected to store them "somewhere". + // We currently do not support device attributes. + r.Post("/audit-log", ns.NotImplementedHandler) + + // handles requests to get an OIDC ID token. Receives a [tailcfg.TokenRequest]. + r.Post("/id-token", ns.NotImplementedHandler) + + // Asks the server if a feature is available and receive information about how to enable it. + // Gets a [tailcfg.QueryFeatureRequest] and returns a [tailcfg.QueryFeatureResponse]. + r.Post("/feature/query", ns.NotImplementedHandler) + + r.Post("/update-health", ns.NotImplementedHandler) + + r.Route("/webclient", func(r chi.Router) {}) + }) + + r.Post("/c2n", ns.NotImplementedHandler) + + r.Get("/ssh-action", ns.SSHAction) + + ns.httpBaseConfig = &http.Server{ + Handler: r, ReadHeaderTimeout: types.HTTPTimeout, } - noiseServer.http2Server = &http2.Server{} + ns.http2Server = &http2.Server{} - noiseServer.http2Server.ServeConn( + ns.http2Server.ServeConn( noiseConn, &http2.ServeConnOpts{ - BaseConfig: noiseServer.httpBaseConfig, + BaseConfig: ns.httpBaseConfig, }, ) } @@ -189,7 +237,19 @@ func rejectUnsupported( return false } -// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol +func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) { + d, _ := io.ReadAll(req.Body) + log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit") + http.Error(writer, "Not implemented yet", http.StatusNotImplemented) +} + +// SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction] +// to the client with the verdict of an SSH access request. +func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) { + log.Trace().Caller().Str("path", req.URL.String()).Msg("got SSH action request") +} + +// PollNetMapHandler takes care of /machine/:id/map using the Noise protocol // // This is the busiest endpoint, as it keeps the HTTP long poll that updates // the clients when something in the network changes. @@ -198,7 +258,7 @@ func rejectUnsupported( // only after their first request (marked with the ReadOnly field). // // At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (ns *noiseServer) NoisePollNetMapHandler( +func (ns *noiseServer) PollNetMapHandler( writer http.ResponseWriter, req *http.Request, ) { @@ -237,8 +297,8 @@ func regErr(err error) *tailcfg.RegisterResponse { return &tailcfg.RegisterResponse{Error: err.Error()} } -// NoiseRegistrationHandler handles the actual registration process of a node. -func (ns *noiseServer) NoiseRegistrationHandler( +// RegistrationHandler handles the actual registration process of a node. +func (ns *noiseServer) RegistrationHandler( writer http.ResponseWriter, req *http.Request, ) {