Use MarkLongPolling instead of hard-coded route path (#37427)

This commit is contained in:
wxiaoguang
2026-04-26 19:42:29 +08:00
committed by GitHub
parent ebf30ac4db
commit 712b3a54b5
10 changed files with 73 additions and 42 deletions

View File

@@ -87,6 +87,7 @@ func getViteDevProxy() *httputil.ReverseProxy {
// the Vite dev server port from the port file written by the viteDevServerPortPlugin.
// It is needed because there are container-based development, only Gitea web server's port is exposed.
func ViteDevMiddleware(next http.Handler) http.Handler {
markLongPolling := routing.MarkLongPolling()
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
if !isViteDevRequest(req) {
next.ServeHTTP(resp, req)
@@ -97,8 +98,7 @@ func ViteDevMiddleware(next http.Handler) http.Handler {
next.ServeHTTP(resp, req)
return
}
routing.MarkLongPolling(resp, req)
proxy.ServeHTTP(resp, req)
markLongPolling(proxy).ServeHTTP(resp, req)
})
}

View File

@@ -13,18 +13,12 @@ import (
"code.gitea.io/gitea/modules/reqctx"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web/middleware"
"code.gitea.io/gitea/modules/web/types"
"gitea.com/go-chi/binding"
"github.com/go-chi/chi/v5"
)
// PreMiddlewareProvider is a special middleware provider which will be executed
// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting).
// A route can do something (e.g.: set middleware options) at the place where it is declared,
// and the code will be executed before other middlewares which are added before the declaration.
// Use cases: mark a route with some meta info, set some options for middlewares, etc.
type PreMiddlewareProvider func(next http.Handler) http.Handler
// Bind binding an obj to a handler's context data
func Bind[T any](_ T) http.HandlerFunc {
return func(resp http.ResponseWriter, req *http.Request) {
@@ -112,7 +106,7 @@ func isNilOrFuncNil(v any) bool {
func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider {
for _, m := range middlewares {
if h, ok := m.(PreMiddlewareProvider); ok && h != nil {
if h, ok := m.(types.PreMiddlewareProvider); ok && h != nil {
all = append(all, toHandlerProvider(middlewareProvider(h)))
}
}
@@ -121,7 +115,7 @@ func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []midd
func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider {
for _, m := range middlewares {
if _, ok := m.(PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) {
if _, ok := m.(types.PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) {
all = append(all, toHandlerProvider(m))
}
}

View File

@@ -13,6 +13,7 @@ import (
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/test"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/modules/web/types"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
@@ -312,12 +313,12 @@ func TestPreMiddlewareProvider(t *testing.T) {
root := NewRouter()
root.BeforeRouting(h("before-root"))
root.AfterRouting(h("root"))
root.Get("/a/1", h("mid"), PreMiddlewareProvider(p("pre-root")), h("end1"))
root.Get("/a/1", h("mid"), types.PreMiddlewareProvider(p("pre-root")), h("end1"))
sub := NewRouter()
sub.BeforeRouting(h("before-sub"))
sub.AfterRouting(h("sub"))
sub.Get("/2", h("mid"), PreMiddlewareProvider(p("pre-sub")), h("end2"))
sub.Get("/2", h("mid"), types.PreMiddlewareProvider(p("pre-sub")), h("end2"))
sub.NotFound(h("not-found"))
root.Mount("/a", sub)

View File

@@ -10,12 +10,18 @@ import (
"code.gitea.io/gitea/modules/gtprof"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/reqctx"
"code.gitea.io/gitea/modules/web/types"
)
type contextKeyType struct{}
var contextKey contextKeyType
func getRequestRecord(ctx context.Context) *requestRecord {
record, _ := ctx.Value(contextKey).(*requestRecord)
return record
}
// RecordFuncInfo records a func info into context
func RecordFuncInfo(ctx context.Context, funcInfo *FuncInfo) (end func()) {
end = func() {}
@@ -24,7 +30,7 @@ func RecordFuncInfo(ctx context.Context, funcInfo *FuncInfo) (end func()) {
traceSpan, end = gtprof.GetTracer().StartInContext(reqCtx, "http.func")
traceSpan.SetAttributeString("func", funcInfo.shortName)
}
if record, ok := ctx.Value(contextKey).(*requestRecord); ok {
if record := getRequestRecord(ctx); record != nil {
record.lock.Lock()
record.funcInfo = funcInfo
record.lock.Unlock()
@@ -32,22 +38,39 @@ func RecordFuncInfo(ctx context.Context, funcInfo *FuncInfo) (end func()) {
return end
}
// MarkLongPolling marks the request is a long-polling request, and the logger may output different message for it
func MarkLongPolling(resp http.ResponseWriter, req *http.Request) {
record, ok := req.Context().Value(contextKey).(*requestRecord)
if !ok {
return
func GetRequestRecordInfo(reqCtx context.Context) (ret struct {
HasRecord bool
IsLongPolling bool
},
) {
record := getRequestRecord(reqCtx)
if record == nil {
return ret
}
ret.HasRecord = true
record.lock.RLock()
ret.IsLongPolling = record.isLongPolling
record.lock.RUnlock()
return ret
}
record.lock.Lock()
record.isLongPolling = true
record.logLevel = log.TRACE
record.lock.Unlock()
// MarkLongPolling marks the request is a long-polling request, and the logger may output different message for it
func MarkLongPolling() types.PreMiddlewareProvider {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
record := getRequestRecord(req.Context()) // it must exist
record.lock.Lock()
record.isLongPolling = true
record.logLevel = log.TRACE
record.lock.Unlock()
next.ServeHTTP(w, req)
})
}
}
func MarkLogLevelTrace(resp http.ResponseWriter, req *http.Request) {
record, ok := req.Context().Value(contextKey).(*requestRecord)
if !ok {
record := getRequestRecord(req.Context())
if record == nil {
return
}
@@ -58,8 +81,8 @@ func MarkLogLevelTrace(resp http.ResponseWriter, req *http.Request) {
// UpdatePanicError updates a context's error info, a panic may be recovered by other middlewares, but we still need to know that.
func UpdatePanicError(ctx context.Context, err error) {
record, ok := ctx.Value(contextKey).(*requestRecord)
if !ok {
record := getRequestRecord(ctx)
if record == nil {
return
}

View File

@@ -0,0 +1,13 @@
// Copyright 2026 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package types
import "net/http"
// PreMiddlewareProvider is a special middleware provider which will be executed
// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting).
// A route can do something (e.g.: set middleware options) at the place where it is declared,
// and the code will be executed before other middlewares which are added before the declaration.
// Use cases: mark a route with some meta info, set some options for middlewares, etc.
type PreMiddlewareProvider func(next http.Handler) http.Handler

View File

@@ -11,6 +11,7 @@ import (
"code.gitea.io/gitea/modules/reqctx"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web/middleware"
"code.gitea.io/gitea/modules/web/routing"
"github.com/go-chi/chi/v5"
)
@@ -71,10 +72,6 @@ func isRoutePathExpensive(routePattern string) bool {
return false
}
func isRoutePathForLongPolling(routePattern string) bool {
return routePattern == "/user/events"
}
func determineRequestPriority(reqCtx reqctx.RequestContext) (ret struct {
SignedIn bool
Expensive bool
@@ -86,7 +83,7 @@ func determineRequestPriority(reqCtx reqctx.RequestContext) (ret struct {
ret.SignedIn = true
} else {
ret.Expensive = isRoutePathExpensive(chiRoutePath)
ret.LongPolling = isRoutePathForLongPolling(chiRoutePath)
ret.LongPolling = routing.GetRequestRecordInfo(reqCtx).IsLongPolling
}
return ret
}

View File

@@ -25,6 +25,4 @@ func TestBlockExpensive(t *testing.T) {
for _, c := range cases {
assert.Equal(t, c.expensive, isRoutePathExpensive(c.routePath), "routePath: %s", c.routePath)
}
assert.True(t, isRoutePathForLongPolling("/user/events"))
}

View File

@@ -14,6 +14,7 @@ import (
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/templates"
"code.gitea.io/gitea/modules/web/middleware"
"code.gitea.io/gitea/modules/web/routing"
"github.com/bohde/codel"
"github.com/go-chi/chi/v5"
@@ -68,7 +69,7 @@ func QoS() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
reqRecordInfo := routing.GetRequestRecordInfo(ctx)
priority := requestPriority(ctx)
// Check if the request can begin processing.
@@ -79,9 +80,8 @@ func QoS() func(next http.Handler) http.Handler {
return
}
// Release long-polling immediately, so they don't always
// take up an in-flight request
if strings.Contains(req.URL.Path, "/user/events") {
// Release long-polling immediately, so they don't always take up an in-flight request
if reqRecordInfo.IsLongPolling {
c.Release()
} else {
defer c.Release()

View File

@@ -23,6 +23,7 @@ import (
"code.gitea.io/gitea/modules/web"
"code.gitea.io/gitea/modules/web/middleware"
"code.gitea.io/gitea/modules/web/routing"
"code.gitea.io/gitea/modules/web/types"
"code.gitea.io/gitea/routers/common"
"code.gitea.io/gitea/routers/web/admin"
"code.gitea.io/gitea/routers/web/auth"
@@ -91,8 +92,8 @@ func optionsCorsHandler() func(next http.Handler) http.Handler {
}
type AuthMiddleware struct {
AllowOAuth2 web.PreMiddlewareProvider
AllowBasic web.PreMiddlewareProvider
AllowOAuth2 types.PreMiddlewareProvider
AllowBasic types.PreMiddlewareProvider
MiddlewareHandler func(*context.Context)
}
@@ -101,7 +102,7 @@ func newWebAuthMiddleware() *AuthMiddleware {
type keyAllowBasic struct{}
webAuth := &AuthMiddleware{}
middlewareSetContextValue := func(key, val any) web.PreMiddlewareProvider {
middlewareSetContextValue := func(key, val any) types.PreMiddlewareProvider {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dataStore := reqctx.GetRequestDataStore(r.Context())
@@ -588,7 +589,7 @@ func registerWebRoutes(m *web.Router, webAuth *AuthMiddleware) {
})
}, reqSignOut)
m.Any("/user/events", routing.MarkLongPolling, events.Events)
m.Any("/user/events", routing.MarkLongPolling(), events.Events)
m.Group("/login/oauth", func() {
m.Group("", func() {

View File

@@ -177,6 +177,8 @@ func TestRequireSignInView(t *testing.T) {
require.False(t, setting.Service.BlockAnonymousAccessExpensive)
req := NewRequest(t, "GET", "/user2/repo1/src/branch/master")
MakeRequest(t, req, http.StatusOK)
req = NewRequest(t, "GET", "/user/events")
MakeRequest(t, req, http.StatusOK)
})
t.Run("RequireSignInView", func(t *testing.T) {
defer test.MockVariableValue(&setting.Service.RequireSignInViewStrict, true)()
@@ -192,6 +194,8 @@ func TestRequireSignInView(t *testing.T) {
req := NewRequest(t, "GET", "/user2/repo1")
MakeRequest(t, req, http.StatusOK)
req = NewRequest(t, "GET", "/user/events")
MakeRequest(t, req, http.StatusSeeOther)
req = NewRequest(t, "GET", "/user2/repo1/src/branch/master")
resp := MakeRequest(t, req, http.StatusSeeOther)