app: switch from gorilla to chi mux

Replace gorilla/mux with go-chi/chi as the HTTP router and add a
custom zerolog-based request logger to replace chi's default
stdlib-based middleware.Logger, consistent with the rest of the
application.

Updates #1850
This commit is contained in:
Kristoffer Dalby
2026-02-24 18:47:40 +00:00
parent 25ccb5a161
commit 30338441c1
5 changed files with 179 additions and 55 deletions

View File

@@ -27,7 +27,7 @@
let
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
buildGo = pkgs.buildGo126Module;
vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0=";
vendorHash = "sha256-oUN53ELb3+xn4yA7lEfXyT2c7NxbQC6RtbkGVq6+RLU=";
in
{
headscale = buildGo {

2
go.mod
View File

@@ -14,6 +14,8 @@ require (
github.com/docker/docker v28.5.2+incompatible
github.com/fsnotify/fsnotify v1.9.0
github.com/glebarez/sqlite v1.11.0
github.com/go-chi/chi/v5 v5.2.5
github.com/go-chi/metrics v0.1.1
github.com/go-gormigrate/gormigrate/v2 v2.1.5
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e
github.com/gofrs/uuid/v5 v5.4.0

4
go.sum
View File

@@ -181,6 +181,10 @@ github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-chi/metrics v0.1.1 h1:CXhbnkAVVjb0k73EBRQ6Z2YdWFnbXZgNtg1Mboguibk=
github.com/go-chi/metrics v0.1.1/go.mod h1:mcGTM1pPalP7WCtb+akNYFO/lwNwBBLCuedepqjoPn4=
github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8=
github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M=
github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY=

View File

@@ -20,7 +20,9 @@ import (
"github.com/cenkalti/backoff/v5"
"github.com/davecgh/go-spew/spew"
"github.com/gorilla/mux"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/metrics"
grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@@ -457,50 +459,57 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
return os.Remove(h.cfg.UnixSocket)
}
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router := mux.NewRouter()
router.Use(prometheusMiddleware)
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux {
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != http.MethodOptions
},
}))
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.RequestLogger(&zerologRequestLogger{}))
r.Use(middleware.Recoverer)
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).
Methods(http.MethodPost, http.MethodGet)
r.Post(ts2021UpgradePath, h.NoiseUpgradeHandler)
router.HandleFunc("/robots.txt", h.RobotsHandler).Methods(http.MethodGet)
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/version", h.VersionHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).
Methods(http.MethodGet)
r.Get("/robots.txt", h.RobotsHandler)
r.Get("/health", h.HealthHandler)
r.Get("/version", h.VersionHandler)
r.Get("/key", h.KeyHandler)
r.Get("/register/{registration_id}", h.authProvider.RegisterHandler)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
r.Get("/oidc/callback", provider.OIDCCallbackHandler)
}
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
Methods(http.MethodGet)
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
r.Get("/apple", h.AppleConfigMessage)
r.Get("/apple/{platform}", h.ApplePlatformConfig)
r.Get("/windows", h.WindowsConfigMessage)
// TODO(kristoffer): move swagger into a package
router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1).
Methods(http.MethodGet)
r.Get("/swagger", headscale.SwaggerUI)
r.Get("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1)
router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost)
r.Post("/verify", h.VerifyHandler)
if h.cfg.DERP.ServerEnabled {
router.HandleFunc("/derp", h.DERPServer.DERPHandler)
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
r.HandleFunc("/derp", h.DERPServer.DERPHandler)
r.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
r.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
r.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
}
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(h.httpAuthenticationMiddleware)
apiRouter.PathPrefix("/v1/").HandlerFunc(grpcMux.ServeHTTP)
router.HandleFunc("/favicon.ico", FaviconHandler)
router.PathPrefix("/").HandlerFunc(BlankHandler)
r.Route("/api", func(r chi.Router) {
r.Use(h.httpAuthenticationMiddleware)
r.HandleFunc("/v1/*", grpcMux.ServeHTTP)
})
r.Get("/favicon.ico", FaviconHandler)
r.Get("/", BlankHandler)
return router
return r
}
// Serve launches the HTTP and gRPC server service Headscale and the API.
@@ -1083,3 +1092,52 @@ func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
return resp, nil
}
// zerologRequestLogger implements chi's middleware.LogFormatter
// to route HTTP request logs through zerolog.
type zerologRequestLogger struct{}
func (z *zerologRequestLogger) NewLogEntry(
r *http.Request,
) middleware.LogEntry {
return &zerologLogEntry{
method: r.Method,
path: r.URL.Path,
proto: r.Proto,
remote: r.RemoteAddr,
}
}
type zerologLogEntry struct {
method string
path string
proto string
remote string
}
func (e *zerologLogEntry) Write(
status, bytes int,
header http.Header,
elapsed time.Duration,
extra any,
) {
log.Info().
Str("method", e.method).
Str("path", e.path).
Str("proto", e.proto).
Str("remote", e.remote).
Int("status", status).
Int("bytes", bytes).
Dur("elapsed", elapsed).
Msg("http request")
}
func (e *zerologLogEntry) Panic(
v any,
stack []byte,
) {
log.Error().
Interface("panic", v).
Bytes("stack", stack).
Msg("http handler panic")
}

View File

@@ -8,7 +8,9 @@ import (
"io"
"net/http"
"github.com/gorilla/mux"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/metrics"
"github.com/juanfont/headscale/hscontrol/capver"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
@@ -69,7 +71,7 @@ func (h *Headscale) NoiseUpgradeHandler(
return
}
noiseServer := noiseServer{
ns := noiseServer{
headscale: h,
challenge: key.NewChallenge(),
}
@@ -79,42 +81,88 @@ func (h *Headscale) NoiseUpgradeHandler(
writer,
req,
*h.noisePrivateKey,
noiseServer.earlyNoise,
ns.earlyNoise,
)
if err != nil {
httpError(writer, fmt.Errorf("upgrading noise connection: %w", err))
return
}
noiseServer.conn = noiseConn
noiseServer.machineKey = noiseServer.conn.Peer()
noiseServer.protocolVersion = noiseServer.conn.ProtocolVersion()
ns.conn = noiseConn
ns.machineKey = ns.conn.Peer()
ns.protocolVersion = ns.conn.ProtocolVersion()
// This router is served only over the Noise connection, and exposes only the new API.
//
// The HTTP2 server that exposes this router is created for
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
router := mux.NewRouter()
router.Use(prometheusMiddleware)
router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler).
Methods(http.MethodPost)
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != http.MethodOptions
},
}))
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
// Endpoints outside of the register endpoint must use getAndValidateNode to
// get the node to ensure that the MachineKey matches the Node setting up the
// connection.
router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler)
r.Handle("/metrics", metrics.Handler())
noiseServer.httpBaseConfig = &http.Server{
Handler: router,
r.Route("/machine", func(r chi.Router) {
r.Post("/register", ns.RegistrationHandler)
r.Post("/map", ns.PollNetMapHandler)
// Not implemented yet
//
// /whoami is a debug endpoint to validate that the client can communicate over the connection,
// not clear if there is a specific response, it looks like it is just logged.
// https://github.com/tailscale/tailscale/blob/dfba01ca9bd8c4df02c3c32f400d9aeb897c5fc7/cmd/tailscale/cli/debug.go#L1138
r.Get("/whoami", ns.NotImplementedHandler)
// client sends a [tailcfg.SetDNSRequest] to this endpoints and expect
// the server to create or update this DNS record "somewhere".
// It is typically a TXT record for an ACME challenge.
r.Post("/set-dns", ns.NotImplementedHandler)
// A patch of [tailcfg.SetDeviceAttributesRequest] to update device attributes.
// We currently do not support device attributes.
r.Patch("/set-device-attr", ns.NotImplementedHandler)
// A [tailcfg.AuditLogRequest] to send audit log entries to the server.
// The server is expected to store them "somewhere".
// We currently do not support device attributes.
r.Post("/audit-log", ns.NotImplementedHandler)
// handles requests to get an OIDC ID token. Receives a [tailcfg.TokenRequest].
r.Post("/id-token", ns.NotImplementedHandler)
// Asks the server if a feature is available and receive information about how to enable it.
// Gets a [tailcfg.QueryFeatureRequest] and returns a [tailcfg.QueryFeatureResponse].
r.Post("/feature/query", ns.NotImplementedHandler)
r.Post("/update-health", ns.NotImplementedHandler)
r.Route("/webclient", func(r chi.Router) {})
})
r.Post("/c2n", ns.NotImplementedHandler)
r.Get("/ssh-action", ns.SSHAction)
ns.httpBaseConfig = &http.Server{
Handler: r,
ReadHeaderTimeout: types.HTTPTimeout,
}
noiseServer.http2Server = &http2.Server{}
ns.http2Server = &http2.Server{}
noiseServer.http2Server.ServeConn(
ns.http2Server.ServeConn(
noiseConn,
&http2.ServeConnOpts{
BaseConfig: noiseServer.httpBaseConfig,
BaseConfig: ns.httpBaseConfig,
},
)
}
@@ -189,7 +237,19 @@ func rejectUnsupported(
return false
}
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) {
d, _ := io.ReadAll(req.Body)
log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit")
http.Error(writer, "Not implemented yet", http.StatusNotImplemented)
}
// SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction]
// to the client with the verdict of an SSH access request.
func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) {
log.Trace().Caller().Str("path", req.URL.String()).Msg("got SSH action request")
}
// PollNetMapHandler takes care of /machine/:id/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
@@ -198,7 +258,7 @@ func rejectUnsupported(
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (ns *noiseServer) NoisePollNetMapHandler(
func (ns *noiseServer) PollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
@@ -237,8 +297,8 @@ func regErr(err error) *tailcfg.RegisterResponse {
return &tailcfg.RegisterResponse{Error: err.Error()}
}
// NoiseRegistrationHandler handles the actual registration process of a node.
func (ns *noiseServer) NoiseRegistrationHandler(
// RegistrationHandler handles the actual registration process of a node.
func (ns *noiseServer) RegistrationHandler(
writer http.ResponseWriter,
req *http.Request,
) {