Support for Custom URI Schemes in OAuth2 Redirect URIs (#37356)

Fix #34349

By the way, remove `(ctx *APIContext) HasAPIError() ` and `(ctx
*APIContext) GetErrMsg()` because they do nothing, the error handling
has been done in API's middeware

The existing OAuth2 tests were not quite right, refactored them together
This commit is contained in:
wxiaoguang
2026-04-23 05:33:27 +08:00
committed by GitHub
parent 8cfcef32c6
commit 83bdfc2a57
21 changed files with 340 additions and 512 deletions

View File

@@ -99,6 +99,7 @@ var OAuth2 = struct {
JWTClaimIssuer string `ini:"JWT_CLAIM_ISSUER"`
MaxTokenLength int
DefaultApplications []string
CustomSchemes []string
}{
Enabled: true,
AccessTokenExpirationTime: 3600,

View File

@@ -13,7 +13,6 @@ import (
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/glob"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/util"
"gitea.com/go-chi/binding"
)
@@ -51,7 +50,6 @@ func (j jsonProvider) NewEncoder(writer io.Writer) binding.JSONEncoder {
func AddBindingRules() {
binding.JSONProvider = jsonProvider{}
addGitRefNameBindingRule()
addValidURLListBindingRule()
addValidURLBindingRule()
addValidSiteURLBindingRule()
addGlobPatternRule()
@@ -80,33 +78,6 @@ func addGitRefNameBindingRule() {
})
}
func addValidURLListBindingRule() {
// URL validation rule
binding.AddRule(&binding.Rule{
IsMatch: func(rule string) bool {
return rule == "ValidUrlList"
},
IsValid: func(errs binding.Errors, name string, val any) (bool, binding.Errors) {
str := fmt.Sprintf("%v", val)
if len(str) == 0 {
errs.Add([]string{name}, binding.ERR_URL, "Url")
return false, errs
}
ok := true
urls := util.SplitTrimSpace(str, "\n")
for _, u := range urls {
if !IsValidURL(u) {
ok = false
errs.Add([]string{name}, binding.ERR_URL, u)
}
}
return ok, errs
},
})
}
func addValidURLBindingRule() {
// URL validation rule
binding.AddRule(&binding.Rule{

View File

@@ -27,7 +27,6 @@ type (
TestForm struct {
BranchName string `form:"BranchName" binding:"GitRefName"`
URL string `form:"ValidUrl" binding:"ValidUrl"`
URLs string `form:"ValidUrls" binding:"ValidUrlList"`
GlobPattern string `form:"GlobPattern" binding:"GlobPattern"`
RegexPattern string `form:"RegexPattern" binding:"RegexPattern"`
}

View File

@@ -1,157 +0,0 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package validation
import (
"testing"
"gitea.com/go-chi/binding"
)
func Test_ValidURLListValidation(t *testing.T) {
AddBindingRules()
// This is a copy of all the URL tests cases, plus additional ones to
// account for multiple URLs
urlListValidationTestCases := []validationTestCase{
{
description: "Empty URL",
data: TestForm{
URLs: "",
},
expectedErrors: binding.Errors{},
},
{
description: "URL without port",
data: TestForm{
URLs: "http://test.lan/",
},
expectedErrors: binding.Errors{},
},
{
description: "URL with port",
data: TestForm{
URLs: "http://test.lan:3000/",
},
expectedErrors: binding.Errors{},
},
{
description: "URL with IPv6 address without port",
data: TestForm{
URLs: "http://[::1]/",
},
expectedErrors: binding.Errors{},
},
{
description: "URL with IPv6 address with port",
data: TestForm{
URLs: "http://[::1]:3000/",
},
expectedErrors: binding.Errors{},
},
{
description: "Invalid URL",
data: TestForm{
URLs: "http//test.lan/",
},
expectedErrors: binding.Errors{
binding.Error{
FieldNames: []string{"URLs"},
Classification: binding.ERR_URL,
Message: "http//test.lan/",
},
},
},
{
description: "Invalid schema",
data: TestForm{
URLs: "ftp://test.lan/",
},
expectedErrors: binding.Errors{
binding.Error{
FieldNames: []string{"URLs"},
Classification: binding.ERR_URL,
Message: "ftp://test.lan/",
},
},
},
{
description: "Invalid port",
data: TestForm{
URLs: "http://test.lan:3x4/",
},
expectedErrors: binding.Errors{
binding.Error{
FieldNames: []string{"URLs"},
Classification: binding.ERR_URL,
Message: "http://test.lan:3x4/",
},
},
},
{
description: "Invalid port with IPv6 address",
data: TestForm{
URLs: "http://[::1]:3x4/",
},
expectedErrors: binding.Errors{
binding.Error{
FieldNames: []string{"URLs"},
Classification: binding.ERR_URL,
Message: "http://[::1]:3x4/",
},
},
},
{
description: "Multi URLs",
data: TestForm{
URLs: "http://test.lan:3000/\nhttp://test.local/",
},
expectedErrors: binding.Errors{},
},
{
description: "Multi URLs with newline",
data: TestForm{
URLs: "http://test.lan:3000/\nhttp://test.local/\n",
},
expectedErrors: binding.Errors{},
},
{
description: "List with invalid entry",
data: TestForm{
URLs: "http://test.lan:3000/\nhttp://[::1]:3x4/",
},
expectedErrors: binding.Errors{
binding.Error{
FieldNames: []string{"URLs"},
Classification: binding.ERR_URL,
Message: "http://[::1]:3x4/",
},
},
},
{
description: "List with two invalid entries",
data: TestForm{
URLs: "ftp://test.lan:3000/\nhttp://[::1]:3x4/\n",
},
expectedErrors: binding.Errors{
binding.Error{
FieldNames: []string{"URLs"},
Classification: binding.ERR_URL,
Message: "ftp://test.lan:3000/",
},
binding.Error{
FieldNames: []string{"URLs"},
Classification: binding.ERR_URL,
Message: "http://[::1]:3x4/",
},
},
},
}
for _, testCase := range urlListValidationTestCases {
t.Run(testCase.description, func(t *testing.T) {
performValidationTest(t, testCase)
})
}
}

View File

@@ -78,82 +78,97 @@ func GetInclude(field reflect.StructField) string {
return getRuleBody(field, "Include(")
}
// Validate validate
func ReportValidationError(errs binding.Errors, data map[string]any, fieldName, classification, errorMsg string) binding.Errors {
errs.Add([]string{fieldName}, classification, errorMsg)
data["HasError"] = true
data["ErrorMsg"] = fieldName + ": " + errorMsg
data["Err_"+fieldName] = true
// there is already a reported validation error, so no need to generate default error messages in Validate()
data["HasErrorFormValidation"] = true
return errs
}
func Validate(errs binding.Errors, data map[string]any, f Form, l translation.Locale) binding.Errors {
if errs.Len() == 0 {
// try to restore the form's values as much as possible,
// especially for RenderWithErrDeprecated to re-render the form with errors
AssignForm(f, data)
if errs.Len() == 0 || data["HasErrorFormValidation"] == true {
return errs
}
// if HasError=true, then must set default error message
// because still a lot of places use `ctx.Data["ErrorMsg"].(string)` even if the error fields can't be found
data["HasError"] = true
// If the field with name errs[0].FieldNames[0] is not found in form
// somehow, some code later on will panic on Data["ErrorMsg"].(string).
// So initialize it to some default.
data["ErrorMsg"] = l.Tr("form.unknown_error")
AssignForm(f, data)
data["ErrorMsg"] = l.TrString("form.unknown_error")
typ := reflect.TypeOf(f)
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if field, ok := typ.FieldByName(errs[0].FieldNames[0]); ok {
fieldName := field.Tag.Get("form")
if fieldName != "-" {
data["Err_"+field.Name] = true
trName := field.Tag.Get("locale")
if len(trName) == 0 {
trName = l.TrString("form." + field.Name)
} else {
trName = l.TrString(trName)
}
switch errs[0].Classification {
case binding.ERR_REQUIRED:
data["ErrorMsg"] = trName + l.TrString("form.require_error")
case binding.ERR_ALPHA_DASH:
data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_error")
case binding.ERR_ALPHA_DASH_DOT:
data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_dot_error")
case validation.ErrGitRefName:
data["ErrorMsg"] = trName + l.TrString("form.git_ref_name_error")
case binding.ERR_SIZE:
data["ErrorMsg"] = trName + l.TrString("form.size_error", GetSize(field))
case binding.ERR_MIN_SIZE:
data["ErrorMsg"] = trName + l.TrString("form.min_size_error", GetMinSize(field))
case binding.ERR_MAX_SIZE:
data["ErrorMsg"] = trName + l.TrString("form.max_size_error", GetMaxSize(field))
case binding.ERR_EMAIL:
data["ErrorMsg"] = trName + l.TrString("form.email_error")
case binding.ERR_URL:
data["ErrorMsg"] = trName + l.TrString("form.url_error", errs[0].Message)
case binding.ERR_INCLUDE:
data["ErrorMsg"] = trName + l.TrString("form.include_error", GetInclude(field))
case validation.ErrGlobPattern:
data["ErrorMsg"] = trName + l.TrString("form.glob_pattern_error", errs[0].Message)
case validation.ErrRegexPattern:
data["ErrorMsg"] = trName + l.TrString("form.regex_pattern_error", errs[0].Message)
case validation.ErrUsername:
data["ErrorMsg"] = trName + l.TrString("form.username_error")
case validation.ErrInvalidGroupTeamMap:
data["ErrorMsg"] = trName + l.TrString("form.invalid_group_team_map_error", errs[0].Message)
case validation.ErrInvalidBadgeSlug:
data["ErrorMsg"] = trName + l.TrString("form.invalid_slug_error")
default:
msg := errs[0].Classification
if msg != "" && errs[0].Message != "" {
msg += ": "
}
msg += errs[0].Message
if msg == "" {
msg = l.TrString("form.unknown_error")
}
data["ErrorMsg"] = trName + ": " + msg
}
return errs
}
field, fieldExists := typ.FieldByName(errs[0].FieldNames[0])
if !fieldExists {
return errs
}
if field.Tag.Get("form") == "-" {
return errs
}
data["Err_"+field.Name] = true
trName := field.Tag.Get("locale")
if len(trName) == 0 {
trName = l.TrString("form." + field.Name)
} else {
trName = l.TrString(trName)
}
switch errs[0].Classification {
case binding.ERR_REQUIRED:
data["ErrorMsg"] = trName + l.TrString("form.require_error")
case binding.ERR_ALPHA_DASH:
data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_error")
case binding.ERR_ALPHA_DASH_DOT:
data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_dot_error")
case validation.ErrGitRefName:
data["ErrorMsg"] = trName + l.TrString("form.git_ref_name_error")
case binding.ERR_SIZE:
data["ErrorMsg"] = trName + l.TrString("form.size_error", GetSize(field))
case binding.ERR_MIN_SIZE:
data["ErrorMsg"] = trName + l.TrString("form.min_size_error", GetMinSize(field))
case binding.ERR_MAX_SIZE:
data["ErrorMsg"] = trName + l.TrString("form.max_size_error", GetMaxSize(field))
case binding.ERR_EMAIL:
data["ErrorMsg"] = trName + l.TrString("form.email_error")
case binding.ERR_URL:
data["ErrorMsg"] = trName + l.TrString("form.url_error", errs[0].Message)
case binding.ERR_INCLUDE:
data["ErrorMsg"] = trName + l.TrString("form.include_error", GetInclude(field))
case validation.ErrGlobPattern:
data["ErrorMsg"] = trName + l.TrString("form.glob_pattern_error", errs[0].Message)
case validation.ErrRegexPattern:
data["ErrorMsg"] = trName + l.TrString("form.regex_pattern_error", errs[0].Message)
case validation.ErrUsername:
data["ErrorMsg"] = trName + l.TrString("form.username_error")
case validation.ErrInvalidGroupTeamMap:
data["ErrorMsg"] = trName + l.TrString("form.invalid_group_team_map_error", errs[0].Message)
case validation.ErrInvalidBadgeSlug:
data["ErrorMsg"] = trName + l.TrString("form.invalid_slug_error")
default:
msg := errs[0].Classification
if msg != "" && errs[0].Message != "" {
msg += ": "
}
msg += errs[0].Message
if msg == "" {
msg = l.TrString("form.unknown_error")
}
data["ErrorMsg"] = trName + ": " + msg
}
return errs
}