mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-03 14:32:39 +09:00 
			
		
		
		
	Set CSRF cookies for OIDC (#2328)
	
		
			
	
		
	
	
		
	
		
			Some checks are pending
		
		
	
	
		
			
				
	
				Build / build-nix (push) Waiting to run
				
			
		
			
				
	
				Build / build-cross (GOARCH=386   GOOS=linux) (push) Waiting to run
				
			
		
			
				
	
				Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Waiting to run
				
			
		
			
				
	
				Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Waiting to run
				
			
		
			
				
	
				Build / build-cross (GOARCH=arm   GOOS=linux GOARM=5) (push) Waiting to run
				
			
		
			
				
	
				Build / build-cross (GOARCH=arm   GOOS=linux GOARM=6) (push) Waiting to run
				
			
		
			
				
	
				Build / build-cross (GOARCH=arm   GOOS=linux GOARM=7) (push) Waiting to run
				
			
		
			
				
	
				Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Waiting to run
				
			
		
			
				
	
				Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Waiting to run
				
			
		
			
				
	
				Tests / test (push) Waiting to run
				
			
		
		
	
	
				
					
				
			
		
			Some checks are pending
		
		
	
	Build / build-nix (push) Waiting to run
				
			Build / build-cross (GOARCH=386   GOOS=linux) (push) Waiting to run
				
			Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Waiting to run
				
			Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Waiting to run
				
			Build / build-cross (GOARCH=arm   GOOS=linux GOARM=5) (push) Waiting to run
				
			Build / build-cross (GOARCH=arm   GOOS=linux GOARM=6) (push) Waiting to run
				
			Build / build-cross (GOARCH=arm   GOOS=linux GOARM=7) (push) Waiting to run
				
			Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Waiting to run
				
			Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Waiting to run
				
			Tests / test (push) Waiting to run
				
			* set state and nounce in oidc to prevent csrf Fixes #2276 * try to fix new postgres issue Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							
								
								
									
										6
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							@@ -34,4 +34,10 @@ jobs:
 | 
			
		||||
 | 
			
		||||
      - name: Run tests
 | 
			
		||||
        if: steps.changed-files.outputs.files == 'true'
 | 
			
		||||
        env:
 | 
			
		||||
          # As of 2025-01-06, these env vars was not automatically
 | 
			
		||||
          # set anymore which breaks the initdb for postgres on
 | 
			
		||||
          # some of the database migration tests.
 | 
			
		||||
          LC_ALL: "en_US.UTF-8"
 | 
			
		||||
          LC_CTYPE: "en_US.UTF-8"
 | 
			
		||||
        run: nix develop --command -- gotestsum
 | 
			
		||||
 
 | 
			
		||||
@@ -3,9 +3,7 @@ package hscontrol
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	_ "embed"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"html/template"
 | 
			
		||||
@@ -157,13 +155,19 @@ func (a *AuthProviderOIDC) RegisterHandler(
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	randomBlob := make([]byte, randomByteSize)
 | 
			
		||||
	if _, err := rand.Read(randomBlob); err != nil {
 | 
			
		||||
	// Set the state and nonce cookies to protect against CSRF attacks
 | 
			
		||||
	state, err := setCSRFCookie(writer, req, "state")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(writer, "Internal server error", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	stateStr := hex.EncodeToString(randomBlob)[:32]
 | 
			
		||||
	// Set the state and nonce cookies to protect against CSRF attacks
 | 
			
		||||
	nonce, err := setCSRFCookie(writer, req, "nonce")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(writer, "Internal server error", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Initialize registration info with machine key
 | 
			
		||||
	registrationInfo := RegistrationInfo{
 | 
			
		||||
@@ -191,11 +195,12 @@ func (a *AuthProviderOIDC) RegisterHandler(
 | 
			
		||||
	for k, v := range a.cfg.ExtraParams {
 | 
			
		||||
		extras = append(extras, oauth2.SetAuthURLParam(k, v))
 | 
			
		||||
	}
 | 
			
		||||
	extras = append(extras, oidc.Nonce(nonce))
 | 
			
		||||
 | 
			
		||||
	// Cache the registration info
 | 
			
		||||
	a.registrationCache.Set(stateStr, registrationInfo)
 | 
			
		||||
	a.registrationCache.Set(state, registrationInfo)
 | 
			
		||||
 | 
			
		||||
	authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
 | 
			
		||||
	authURL := a.oauth2Config.AuthCodeURL(state, extras...)
 | 
			
		||||
	log.Debug().Msgf("Redirecting to %s for authentication", authURL)
 | 
			
		||||
 | 
			
		||||
	http.Redirect(writer, req, authURL, http.StatusFound)
 | 
			
		||||
@@ -228,11 +233,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Debug().Interface("cookies", req.Cookies()).Msg("Received oidc callback")
 | 
			
		||||
	cookieState, err := req.Cookie("state")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(writer, "state not found", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if state != cookieState.Value {
 | 
			
		||||
		http.Error(writer, "state did not match", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idToken, err := a.extractIDToken(req.Context(), code, state)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(writer, err.Error(), http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nonce, err := req.Cookie("nonce")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(writer, "nonce not found", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if idToken.Nonce != nonce.Value {
 | 
			
		||||
		http.Error(writer, "nonce did not match", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
 | 
			
		||||
 | 
			
		||||
	var claims types.OIDCClaims
 | 
			
		||||
@@ -592,3 +620,22 @@ func getUserName(
 | 
			
		||||
 | 
			
		||||
	return userName, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
 | 
			
		||||
	val, err := util.GenerateRandomStringURLSafe(64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return val, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c := &http.Cookie{
 | 
			
		||||
		Path:     "/oidc/callback",
 | 
			
		||||
		Name:     name,
 | 
			
		||||
		Value:    val,
 | 
			
		||||
		MaxAge:   int(time.Hour.Seconds()),
 | 
			
		||||
		Secure:   r.TLS != nil,
 | 
			
		||||
		HttpOnly: true,
 | 
			
		||||
	}
 | 
			
		||||
	http.SetCookie(w, c)
 | 
			
		||||
 | 
			
		||||
	return val, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,8 @@ import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/cookiejar"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"net/netip"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -747,6 +749,24 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type LoggingRoundTripper struct{}
 | 
			
		||||
 | 
			
		||||
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
 | 
			
		||||
	noTls := &http.Transport{
 | 
			
		||||
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
 | 
			
		||||
	}
 | 
			
		||||
	resp, err := noTls.RoundTrip(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Printf("---")
 | 
			
		||||
	log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
 | 
			
		||||
	log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
 | 
			
		||||
 | 
			
		||||
	return resp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AuthOIDCScenario) runTailscaleUp(
 | 
			
		||||
	userStr, loginServer string,
 | 
			
		||||
) error {
 | 
			
		||||
@@ -758,35 +778,39 @@ func (s *AuthOIDCScenario) runTailscaleUp(
 | 
			
		||||
	log.Printf("running tailscale up for user %s", userStr)
 | 
			
		||||
	if user, ok := s.users[userStr]; ok {
 | 
			
		||||
		for _, client := range user.Clients {
 | 
			
		||||
			c := client
 | 
			
		||||
			tsc := client
 | 
			
		||||
			user.joinWaitGroup.Go(func() error {
 | 
			
		||||
				loginURL, err := c.LoginWithURL(loginServer)
 | 
			
		||||
				loginURL, err := tsc.LoginWithURL(loginServer)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
 | 
			
		||||
					log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP())
 | 
			
		||||
				loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetHostname())
 | 
			
		||||
				loginURL.Scheme = "http"
 | 
			
		||||
 | 
			
		||||
				if len(headscale.GetCert()) > 0 {
 | 
			
		||||
					loginURL.Scheme = "https"
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				insecureTransport := &http.Transport{
 | 
			
		||||
					TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
 | 
			
		||||
				httptest.NewRecorder()
 | 
			
		||||
				hc := &http.Client{
 | 
			
		||||
					Transport: LoggingRoundTripper{},
 | 
			
		||||
				}
 | 
			
		||||
				hc.Jar, err = cookiejar.New(nil)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Printf("failed to create cookie jar: %s", err)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
 | 
			
		||||
				log.Printf("%s login url: %s\n", tsc.Hostname(), loginURL.String())
 | 
			
		||||
 | 
			
		||||
				log.Printf("%s logging in with url", c.Hostname())
 | 
			
		||||
				httpClient := &http.Client{Transport: insecureTransport}
 | 
			
		||||
				log.Printf("%s logging in with url", tsc.Hostname())
 | 
			
		||||
				ctx := context.Background()
 | 
			
		||||
				req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
 | 
			
		||||
				resp, err := httpClient.Do(req)
 | 
			
		||||
				resp, err := hc.Do(req)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Printf(
 | 
			
		||||
						"%s failed to login using url %s: %s",
 | 
			
		||||
						c.Hostname(),
 | 
			
		||||
						tsc.Hostname(),
 | 
			
		||||
						loginURL,
 | 
			
		||||
						err,
 | 
			
		||||
					)
 | 
			
		||||
@@ -794,8 +818,10 @@ func (s *AuthOIDCScenario) runTailscaleUp(
 | 
			
		||||
					return err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
 | 
			
		||||
 | 
			
		||||
				if resp.StatusCode != http.StatusOK {
 | 
			
		||||
					log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
 | 
			
		||||
					log.Printf("%s response code of oidc login request was %s", tsc.Hostname(), resp.Status)
 | 
			
		||||
					body, _ := io.ReadAll(resp.Body)
 | 
			
		||||
					log.Printf("body: %s", body)
 | 
			
		||||
 | 
			
		||||
@@ -806,12 +832,12 @@ func (s *AuthOIDCScenario) runTailscaleUp(
 | 
			
		||||
 | 
			
		||||
				_, err = io.ReadAll(resp.Body)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Printf("%s failed to read response body: %s", c.Hostname(), err)
 | 
			
		||||
					log.Printf("%s failed to read response body: %s", tsc.Hostname(), err)
 | 
			
		||||
 | 
			
		||||
					return err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				log.Printf("Finished request for %s to join tailnet", c.Hostname())
 | 
			
		||||
				log.Printf("Finished request for %s to join tailnet", tsc.Hostname())
 | 
			
		||||
				return nil
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user