mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-14 06:05:15 +09:00
Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850
339 lines
9.4 KiB
Go
339 lines
9.4 KiB
Go
package util
|
|
|
|
import (
|
|
"cmp"
|
|
"errors"
|
|
"fmt"
|
|
"net/netip"
|
|
"net/url"
|
|
"os"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/util/cmpver"
|
|
)
|
|
|
|
// Errors returned by ParseLoginURLFromCLILogin.
var (
	// ErrMultipleURLsFound is wrapped and returned when the CLI output
	// contains more than one line that looks like a URL.
	ErrMultipleURLsFound = errors.New("multiple URLs found")

	// ErrNoURLFound is returned when the CLI output contains no line that
	// looks like a URL.
	ErrNoURLFound = errors.New("no URL found")
)

// Errors used while parsing traceroute/tracert output with ParseTraceroute.
var (
	// ErrEmptyTracerouteOutput indicates that the traceroute output was empty.
	ErrEmptyTracerouteOutput = errors.New("empty traceroute output")

	// ErrTracerouteHeaderParse is wrapped and returned when the first line of
	// the output is not a recognised traceroute/tracert header.
	ErrTracerouteHeaderParse = errors.New("parsing traceroute header")

	// ErrTracerouteDidNotReach is stored in Traceroute.Err when no hop in the
	// parsed route matched the target address.
	ErrTracerouteDidNotReach = errors.New("traceroute did not reach target")
)
|
|
|
|
func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool {
|
|
if cmpver.Compare(minimum, toCheck) <= 0 ||
|
|
toCheck == "unstable" ||
|
|
toCheck == "head" {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// ParseLoginURLFromCLILogin parses the output of the tailscale up command to extract the login URL.
|
|
// It returns an error if not exactly one URL is found.
|
|
func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
|
|
lines := strings.Split(output, "\n")
|
|
|
|
var urlStr string
|
|
|
|
for _, line := range lines {
|
|
line = strings.TrimSpace(line)
|
|
if strings.HasPrefix(line, "http://") || strings.HasPrefix(line, "https://") {
|
|
if urlStr != "" {
|
|
return nil, fmt.Errorf("%w: %s and %s", ErrMultipleURLsFound, urlStr, line)
|
|
}
|
|
|
|
urlStr = line
|
|
}
|
|
}
|
|
|
|
if urlStr == "" {
|
|
return nil, ErrNoURLFound
|
|
}
|
|
|
|
loginURL, err := url.Parse(urlStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing URL: %w", err)
|
|
}
|
|
|
|
return loginURL, nil
|
|
}
|
|
|
|
type (
	// TraceroutePath describes a single hop of a parsed traceroute.
	TraceroutePath struct {
		// Hop is the hop number as printed by the traceroute command.
		Hop int

		// Hostname is the name or address shown for the hop, or "*" when the
		// hop's probes timed out.
		Hostname string

		// IP is the hop's IP address, or the zero netip.Addr when it could not
		// be determined from the output.
		IP netip.Addr

		// Latencies holds the latencies reported for this hop, in the order
		// they appear in the output.
		Latencies []time.Duration
	}

	// Traceroute is the parsed result of a traceroute/tracert run.
	Traceroute struct {
		// Hostname is the target name taken from the header line.
		Hostname string

		// IP is the target address taken from the header line.
		IP netip.Addr

		// Route lists the hops in the order they were printed.
		Route []TraceroutePath

		// Success is true when some hop in Route has the target IP.
		Success bool

		// Err describes why the traceroute was unsuccessful; it is nil on
		// success.
		Err error
	}
)
|
|
|
|
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct.
|
|
func ParseTraceroute(output string) (Traceroute, error) {
|
|
lines := strings.Split(strings.TrimSpace(output), "\n")
|
|
if len(lines) < 1 {
|
|
return Traceroute{}, ErrEmptyTracerouteOutput
|
|
}
|
|
|
|
// Parse the header line - handle both 'traceroute' and 'tracert' (Windows)
|
|
headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`)
|
|
|
|
headerMatches := headerRegex.FindStringSubmatch(lines[0])
|
|
if len(headerMatches) < 2 {
|
|
return Traceroute{}, fmt.Errorf("%w: %s", ErrTracerouteHeaderParse, lines[0])
|
|
}
|
|
|
|
hostname := headerMatches[1]
|
|
// IP can be in either capture group 2 or 3 depending on format
|
|
ipStr := headerMatches[2]
|
|
if ipStr == "" {
|
|
ipStr = headerMatches[3]
|
|
}
|
|
|
|
ip, err := netip.ParseAddr(ipStr)
|
|
if err != nil {
|
|
return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err)
|
|
}
|
|
|
|
result := Traceroute{
|
|
Hostname: hostname,
|
|
IP: ip,
|
|
Route: []TraceroutePath{},
|
|
Success: false,
|
|
}
|
|
|
|
// More flexible regex that handles various traceroute output formats
|
|
// Main pattern handles: "hostname (IP)", "hostname [IP]", "IP only", "* * *"
|
|
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(.*)$`)
|
|
// Patterns for parsing the hop details
|
|
hostIPRegex := regexp.MustCompile(`^([^ ]+) \(([^)]+)\)`)
|
|
hostIPBracketRegex := regexp.MustCompile(`^([^ ]+) \[([^\]]+)\]`)
|
|
// Pattern for latencies with flexible spacing and optional '<'
|
|
latencyRegex := regexp.MustCompile(`(<?\d+(?:\.\d+)?)\s*ms\b`)
|
|
|
|
for i := 1; i < len(lines); i++ {
|
|
line := strings.TrimSpace(lines[i])
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
matches := hopRegex.FindStringSubmatch(line)
|
|
if len(matches) == 0 {
|
|
continue
|
|
}
|
|
|
|
hop, err := strconv.Atoi(matches[1])
|
|
if err != nil {
|
|
// Skip lines that don't start with a hop number
|
|
continue
|
|
}
|
|
|
|
remainder := strings.TrimSpace(matches[2])
|
|
|
|
var (
|
|
hopHostname string
|
|
hopIP netip.Addr
|
|
latencies []time.Duration
|
|
)
|
|
|
|
// Check for Windows tracert format which has latencies before hostname
|
|
// Format: " 1 <1 ms <1 ms <1 ms router.local [192.168.1.1]"
|
|
latencyFirst := false
|
|
|
|
if strings.Contains(remainder, " ms ") && !strings.HasPrefix(remainder, "*") {
|
|
// Check if latencies appear before any hostname/IP
|
|
firstSpace := strings.Index(remainder, " ")
|
|
if firstSpace > 0 {
|
|
firstPart := remainder[:firstSpace]
|
|
if _, err := strconv.ParseFloat(strings.TrimPrefix(firstPart, "<"), 64); err == nil { //nolint:noinlineerr
|
|
latencyFirst = true
|
|
}
|
|
}
|
|
}
|
|
|
|
if latencyFirst {
|
|
// Windows format: extract latencies first
|
|
for {
|
|
latMatch := latencyRegex.FindStringSubmatchIndex(remainder)
|
|
if latMatch == nil || latMatch[0] > 0 {
|
|
break
|
|
}
|
|
// Extract and remove the latency from the beginning
|
|
latStr := strings.TrimPrefix(remainder[latMatch[2]:latMatch[3]], "<")
|
|
|
|
ms, err := strconv.ParseFloat(latStr, 64)
|
|
if err == nil {
|
|
// Round to nearest microsecond to avoid floating point precision issues
|
|
duration := time.Duration(ms * float64(time.Millisecond))
|
|
latencies = append(latencies, duration.Round(time.Microsecond))
|
|
}
|
|
|
|
remainder = strings.TrimSpace(remainder[latMatch[1]:])
|
|
}
|
|
}
|
|
|
|
// Now parse hostname/IP from remainder
|
|
if strings.HasPrefix(remainder, "*") {
|
|
// Timeout hop
|
|
hopHostname = "*"
|
|
// Skip any remaining asterisks
|
|
remainder = strings.TrimLeft(remainder, "* ")
|
|
} else if hostMatch := hostIPRegex.FindStringSubmatch(remainder); len(hostMatch) >= 3 {
|
|
// Format: hostname (IP)
|
|
hopHostname = hostMatch[1]
|
|
hopIP, _ = netip.ParseAddr(hostMatch[2])
|
|
remainder = strings.TrimSpace(remainder[len(hostMatch[0]):])
|
|
} else if hostMatch := hostIPBracketRegex.FindStringSubmatch(remainder); len(hostMatch) >= 3 {
|
|
// Format: hostname [IP] (Windows)
|
|
hopHostname = hostMatch[1]
|
|
hopIP, _ = netip.ParseAddr(hostMatch[2])
|
|
remainder = strings.TrimSpace(remainder[len(hostMatch[0]):])
|
|
} else {
|
|
// Try to parse as IP only or hostname only
|
|
parts := strings.Fields(remainder)
|
|
if len(parts) > 0 {
|
|
hopHostname = parts[0]
|
|
if ip, err := netip.ParseAddr(parts[0]); err == nil { //nolint:noinlineerr
|
|
hopIP = ip
|
|
}
|
|
|
|
remainder = strings.TrimSpace(strings.Join(parts[1:], " "))
|
|
}
|
|
}
|
|
|
|
// Extract latencies from the remaining part (if not already done)
|
|
if !latencyFirst {
|
|
latencyMatches := latencyRegex.FindAllStringSubmatch(remainder, -1)
|
|
for _, match := range latencyMatches {
|
|
if len(match) > 1 {
|
|
// Remove '<' prefix if present (e.g., "<1 ms")
|
|
latStr := strings.TrimPrefix(match[1], "<")
|
|
|
|
ms, err := strconv.ParseFloat(latStr, 64)
|
|
if err == nil {
|
|
// Round to nearest microsecond to avoid floating point precision issues
|
|
duration := time.Duration(ms * float64(time.Millisecond))
|
|
latencies = append(latencies, duration.Round(time.Microsecond))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
path := TraceroutePath{
|
|
Hop: hop,
|
|
Hostname: hopHostname,
|
|
IP: hopIP,
|
|
Latencies: latencies,
|
|
}
|
|
|
|
result.Route = append(result.Route, path)
|
|
|
|
// Check if we've reached the target
|
|
if hopIP == ip {
|
|
result.Success = true
|
|
}
|
|
}
|
|
|
|
// If we didn't reach the target, it's unsuccessful
|
|
if !result.Success {
|
|
result.Err = ErrTracerouteDidNotReach
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// IsCI reports whether the process appears to be running under continuous
// integration, i.e. whether the CI or GITHUB_RUN_ID environment variable is
// set. Any value counts, including the empty string.
func IsCI() bool {
	for _, name := range []string{"CI", "GITHUB_RUN_ID"} {
		if _, set := os.LookupEnv(name); set {
			return true
		}
	}

	return false
}
|
|
|
|
// EnsureHostname guarantees a valid hostname for node registration.
|
|
// It extracts a hostname from Hostinfo, providing sensible defaults
|
|
// if Hostinfo is nil or Hostname is empty. This prevents nil pointer dereferences
|
|
// and ensures nodes always have a valid hostname.
|
|
// The hostname is truncated to 63 characters to comply with DNS label length limits (RFC 1123).
|
|
// This function never fails - it always returns a valid hostname.
|
|
//
|
|
// Strategy:
|
|
// 1. If hostinfo is nil/empty → generate default from keys
|
|
// 2. If hostname is provided → normalise it
|
|
// 3. If normalisation fails → generate invalid-<random> replacement
|
|
//
|
|
// Returns the guaranteed-valid hostname to use.
|
|
func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string {
|
|
if !hostinfo.Valid() || hostinfo.Hostname() == "" {
|
|
key := cmp.Or(machineKey, nodeKey)
|
|
if key == "" {
|
|
return "unknown-node"
|
|
}
|
|
|
|
keyPrefix := key
|
|
if len(key) > 8 {
|
|
keyPrefix = key[:8]
|
|
}
|
|
|
|
return "node-" + keyPrefix
|
|
}
|
|
|
|
lowercased := strings.ToLower(hostinfo.Hostname())
|
|
|
|
err := ValidateHostname(lowercased)
|
|
if err == nil {
|
|
return lowercased
|
|
}
|
|
|
|
return InvalidString()
|
|
}
|
|
|
|
// GenerateRegistrationKey generates a vanity key for tracking web authentication
|
|
// registration flows in logs. This key is NOT stored in the database and does NOT use bcrypt -
|
|
// it's purely for observability and correlating log entries during the registration process.
|
|
func GenerateRegistrationKey() (string, error) {
|
|
const (
|
|
registerKeyPrefix = "hskey-reg-" //nolint:gosec // This is a vanity key for logging, not a credential
|
|
registerKeyLength = 64
|
|
)
|
|
|
|
randomPart, err := GenerateRandomStringURLSafe(registerKeyLength)
|
|
if err != nil {
|
|
return "", fmt.Errorf("generating registration key: %w", err)
|
|
}
|
|
|
|
return registerKeyPrefix + randomPart, nil
|
|
}
|