diff --git a/modules/public/vitedev.go b/modules/public/vitedev.go index 7cfe692390b..e6be460599a 100644 --- a/modules/public/vitedev.go +++ b/modules/public/vitedev.go @@ -87,6 +87,7 @@ func getViteDevProxy() *httputil.ReverseProxy { // the Vite dev server port from the port file written by the viteDevServerPortPlugin. // It is needed because there are container-based development, only Gitea web server's port is exposed. func ViteDevMiddleware(next http.Handler) http.Handler { + markLongPolling := routing.MarkLongPolling() return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { if !isViteDevRequest(req) { next.ServeHTTP(resp, req) @@ -97,8 +98,7 @@ func ViteDevMiddleware(next http.Handler) http.Handler { next.ServeHTTP(resp, req) return } - routing.MarkLongPolling(resp, req) - proxy.ServeHTTP(resp, req) + markLongPolling(proxy).ServeHTTP(resp, req) }) } diff --git a/modules/web/router.go b/modules/web/router.go index 5ef18e96795..f4575399b9d 100644 --- a/modules/web/router.go +++ b/modules/web/router.go @@ -13,18 +13,12 @@ import ( "code.gitea.io/gitea/modules/reqctx" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/web/middleware" + "code.gitea.io/gitea/modules/web/types" "gitea.com/go-chi/binding" "github.com/go-chi/chi/v5" ) -// PreMiddlewareProvider is a special middleware provider which will be executed -// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting). -// A route can do something (e.g.: set middleware options) at the place where it is declared, -// and the code will be executed before other middlewares which are added before the declaration. -// Use cases: mark a route with some meta info, set some options for middlewares, etc. -type PreMiddlewareProvider func(next http.Handler) http.Handler - // Bind binding an obj to a handler's context data func Bind[T any](_ T) http.HandlerFunc { return func(resp http.ResponseWriter, req *http.Request) { @@ -112,7 +106,7 @@ func isNilOrFuncNil(v any) bool { func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider { for _, m := range middlewares { - if h, ok := m.(PreMiddlewareProvider); ok && h != nil { + if h, ok := m.(types.PreMiddlewareProvider); ok && h != nil { all = append(all, toHandlerProvider(middlewareProvider(h))) } } @@ -121,7 +115,7 @@ func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []midd func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider { for _, m := range middlewares { - if _, ok := m.(PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) { + if _, ok := m.(types.PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) { all = append(all, toHandlerProvider(m)) } } diff --git a/modules/web/router_test.go b/modules/web/router_test.go index d424f072e95..c1c314f85e1 100644 --- a/modules/web/router_test.go +++ b/modules/web/router_test.go @@ -13,6 +13,7 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/test" "code.gitea.io/gitea/modules/util" + "code.gitea.io/gitea/modules/web/types" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" @@ -312,12 +313,12 @@ func TestPreMiddlewareProvider(t *testing.T) { root := NewRouter() root.BeforeRouting(h("before-root")) root.AfterRouting(h("root")) - root.Get("/a/1", h("mid"), PreMiddlewareProvider(p("pre-root")), h("end1")) + root.Get("/a/1", h("mid"), types.PreMiddlewareProvider(p("pre-root")), h("end1")) sub := NewRouter() sub.BeforeRouting(h("before-sub")) sub.AfterRouting(h("sub")) - sub.Get("/2", h("mid"), PreMiddlewareProvider(p("pre-sub")), h("end2")) + sub.Get("/2", h("mid"), types.PreMiddlewareProvider(p("pre-sub")), h("end2")) sub.NotFound(h("not-found")) root.Mount("/a", sub) diff --git a/modules/web/routing/context.go b/modules/web/routing/context.go index e302507bf27..7799a24e948 100644 --- a/modules/web/routing/context.go +++ b/modules/web/routing/context.go @@ -10,12 +10,18 @@ import ( "code.gitea.io/gitea/modules/gtprof" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/reqctx" + "code.gitea.io/gitea/modules/web/types" ) type contextKeyType struct{} var contextKey contextKeyType +func getRequestRecord(ctx context.Context) *requestRecord { + record, _ := ctx.Value(contextKey).(*requestRecord) + return record +} + // RecordFuncInfo records a func info into context func RecordFuncInfo(ctx context.Context, funcInfo *FuncInfo) (end func()) { end = func() {} @@ -24,7 +30,7 @@ func RecordFuncInfo(ctx context.Context, funcInfo *FuncInfo) (end func()) { traceSpan, end = gtprof.GetTracer().StartInContext(reqCtx, "http.func") traceSpan.SetAttributeString("func", funcInfo.shortName) } - if record, ok := ctx.Value(contextKey).(*requestRecord); ok { + if record := getRequestRecord(ctx); record != nil { record.lock.Lock() record.funcInfo = funcInfo record.lock.Unlock() @@ -32,22 +38,39 @@ func RecordFuncInfo(ctx context.Context, funcInfo *FuncInfo) (end func()) { return end } -// MarkLongPolling marks the request is a long-polling request, and the logger may output different message for it -func MarkLongPolling(resp http.ResponseWriter, req *http.Request) { - record, ok := req.Context().Value(contextKey).(*requestRecord) - if !ok { - return +func GetRequestRecordInfo(reqCtx context.Context) (ret struct { + HasRecord bool + IsLongPolling bool +}, +) { + record := getRequestRecord(reqCtx) + if record == nil { + return ret } + ret.HasRecord = true + record.lock.RLock() + ret.IsLongPolling = record.isLongPolling + record.lock.RUnlock() + return ret +} - record.lock.Lock() - record.isLongPolling = true - record.logLevel = log.TRACE - record.lock.Unlock() +// MarkLongPolling marks the request is a long-polling request, and the logger may output different message for it +func MarkLongPolling() types.PreMiddlewareProvider { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + record := getRequestRecord(req.Context()) // it must exist + record.lock.Lock() + record.isLongPolling = true + record.logLevel = log.TRACE + record.lock.Unlock() + next.ServeHTTP(w, req) + }) + } } func MarkLogLevelTrace(resp http.ResponseWriter, req *http.Request) { - record, ok := req.Context().Value(contextKey).(*requestRecord) - if !ok { + record := getRequestRecord(req.Context()) + if record == nil { return } @@ -58,8 +81,8 @@ func MarkLogLevelTrace(resp http.ResponseWriter, req *http.Request) { // UpdatePanicError updates a context's error info, a panic may be recovered by other middlewares, but we still need to know that. func UpdatePanicError(ctx context.Context, err error) { - record, ok := ctx.Value(contextKey).(*requestRecord) - if !ok { + record := getRequestRecord(ctx) + if record == nil { return } diff --git a/modules/web/types/premiddleware.go b/modules/web/types/premiddleware.go new file mode 100644 index 00000000000..275b55b8c69 --- /dev/null +++ b/modules/web/types/premiddleware.go @@ -0,0 +1,13 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package types + +import "net/http" + +// PreMiddlewareProvider is a special middleware provider which will be executed +// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting). +// A route can do something (e.g.: set middleware options) at the place where it is declared, +// and the code will be executed before other middlewares which are added before the declaration. +// Use cases: mark a route with some meta info, set some options for middlewares, etc. +type PreMiddlewareProvider func(next http.Handler) http.Handler diff --git a/routers/common/blockexpensive.go b/routers/common/blockexpensive.go index fec364351ca..0407264b1e4 100644 --- a/routers/common/blockexpensive.go +++ b/routers/common/blockexpensive.go @@ -11,6 +11,7 @@ import ( "code.gitea.io/gitea/modules/reqctx" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/web/middleware" + "code.gitea.io/gitea/modules/web/routing" "github.com/go-chi/chi/v5" ) @@ -71,10 +72,6 @@ func isRoutePathExpensive(routePattern string) bool { return false } -func isRoutePathForLongPolling(routePattern string) bool { - return routePattern == "/user/events" -} - func determineRequestPriority(reqCtx reqctx.RequestContext) (ret struct { SignedIn bool Expensive bool @@ -86,7 +83,7 @@ func determineRequestPriority(reqCtx reqctx.RequestContext) (ret struct { ret.SignedIn = true } else { ret.Expensive = isRoutePathExpensive(chiRoutePath) - ret.LongPolling = isRoutePathForLongPolling(chiRoutePath) + ret.LongPolling = routing.GetRequestRecordInfo(reqCtx).IsLongPolling } return ret } diff --git a/routers/common/blockexpensive_test.go b/routers/common/blockexpensive_test.go index db5c0db7dda..6ee4af60e86 100644 --- a/routers/common/blockexpensive_test.go +++ b/routers/common/blockexpensive_test.go @@ -25,6 +25,4 @@ func TestBlockExpensive(t *testing.T) { for _, c := range cases { assert.Equal(t, c.expensive, isRoutePathExpensive(c.routePath), "routePath: %s", c.routePath) } - - assert.True(t, isRoutePathForLongPolling("/user/events")) } diff --git a/routers/common/qos.go b/routers/common/qos.go index 96f23b64fe6..fbde4192236 100644 --- a/routers/common/qos.go +++ b/routers/common/qos.go @@ -14,6 +14,7 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/templates" "code.gitea.io/gitea/modules/web/middleware" + "code.gitea.io/gitea/modules/web/routing" "github.com/bohde/codel" "github.com/go-chi/chi/v5" @@ -68,7 +69,7 @@ func QoS() func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { ctx := req.Context() - + reqRecordInfo := routing.GetRequestRecordInfo(ctx) priority := requestPriority(ctx) // Check if the request can begin processing. @@ -79,9 +80,8 @@ func QoS() func(next http.Handler) http.Handler { return } - // Release long-polling immediately, so they don't always - // take up an in-flight request - if strings.Contains(req.URL.Path, "/user/events") { + // Release long-polling immediately, so they don't always take up an in-flight request + if reqRecordInfo.IsLongPolling { c.Release() } else { defer c.Release() diff --git a/routers/web/web.go b/routers/web/web.go index 15f8d4886b4..d70eb2d02d5 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -23,6 +23,7 @@ import ( "code.gitea.io/gitea/modules/web" "code.gitea.io/gitea/modules/web/middleware" "code.gitea.io/gitea/modules/web/routing" + "code.gitea.io/gitea/modules/web/types" "code.gitea.io/gitea/routers/common" "code.gitea.io/gitea/routers/web/admin" "code.gitea.io/gitea/routers/web/auth" @@ -91,8 +92,8 @@ func optionsCorsHandler() func(next http.Handler) http.Handler { } type AuthMiddleware struct { - AllowOAuth2 web.PreMiddlewareProvider - AllowBasic web.PreMiddlewareProvider + AllowOAuth2 types.PreMiddlewareProvider + AllowBasic types.PreMiddlewareProvider MiddlewareHandler func(*context.Context) } @@ -101,7 +102,7 @@ func newWebAuthMiddleware() *AuthMiddleware { type keyAllowBasic struct{} webAuth := &AuthMiddleware{} - middlewareSetContextValue := func(key, val any) web.PreMiddlewareProvider { + middlewareSetContextValue := func(key, val any) types.PreMiddlewareProvider { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { dataStore := reqctx.GetRequestDataStore(r.Context()) @@ -588,7 +589,7 @@ func registerWebRoutes(m *web.Router, webAuth *AuthMiddleware) { }) }, reqSignOut) - m.Any("/user/events", routing.MarkLongPolling, events.Events) + m.Any("/user/events", routing.MarkLongPolling(), events.Events) m.Group("/login/oauth", func() { m.Group("", func() { diff --git a/tests/integration/signin_test.go b/tests/integration/signin_test.go index 4000c7ebe1d..5d5420d95d4 100644 --- a/tests/integration/signin_test.go +++ b/tests/integration/signin_test.go @@ -177,6 +177,8 @@ func TestRequireSignInView(t *testing.T) { require.False(t, setting.Service.BlockAnonymousAccessExpensive) req := NewRequest(t, "GET", "/user2/repo1/src/branch/master") MakeRequest(t, req, http.StatusOK) + req = NewRequest(t, "GET", "/user/events") + MakeRequest(t, req, http.StatusOK) }) t.Run("RequireSignInView", func(t *testing.T) { defer test.MockVariableValue(&setting.Service.RequireSignInViewStrict, true)() @@ -192,6 +194,8 @@ func TestRequireSignInView(t *testing.T) { req := NewRequest(t, "GET", "/user2/repo1") MakeRequest(t, req, http.StatusOK) + req = NewRequest(t, "GET", "/user/events") + MakeRequest(t, req, http.StatusSeeOther) req = NewRequest(t, "GET", "/user2/repo1/src/branch/master") resp := MakeRequest(t, req, http.StatusSeeOther)