From 82958835ceb26df8554c53c773044f98a1a93b43 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 16 Feb 2026 14:02:05 +0000 Subject: [PATCH] db: enforce strict version upgrade path Add a version check that runs before database migrations to ensure users do not skip minor versions or downgrade. This protects database migrations and allows future cleanup of old migration code. Rules enforced: - Same minor version: always allowed (patch changes either way) - Single minor upgrade (e.g. 0.27 -> 0.28): allowed - Multi-minor upgrade (e.g. 0.25 -> 0.28): blocked with guidance - Any minor downgrade: blocked - Major version change: blocked - Dev builds: warn but allow, preserve stored version The version is stored in a purpose-built database_versions table after migrations succeed. The table is created with raw SQL before gormigrate runs to avoid circular dependencies. Updates #3058 --- hscontrol/db/db.go | 19 ++ hscontrol/db/schema.sql | 6 + hscontrol/db/versioncheck.go | 256 ++++++++++++++++++++++++ hscontrol/db/versioncheck_test.go | 318 ++++++++++++++++++++++++++++++ 4 files changed, 599 insertions(+) create mode 100644 hscontrol/db/versioncheck.go create mode 100644 hscontrol/db/versioncheck_test.go diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index c518502e..6841f446 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -63,6 +63,11 @@ func NewHeadscaleDatabase( return nil, err } + err = checkVersionUpgradePath(dbConn) + if err != nil { + return nil, fmt.Errorf("version check: %w", err) + } + migrations := gormigrate.New( dbConn, gormigrate.DefaultOptions, @@ -760,6 +765,20 @@ AND auth_key_id NOT IN ( return nil, fmt.Errorf("migration failed: %w", err) } + // Store the current version in the database after migrations succeed. + // Dev builds skip this to preserve the stored version for the next + // real versioned binary. + currentVersion := types.GetVersionInfo().Version + if !isDev(currentVersion) { + err = setDatabaseVersion(dbConn, currentVersion) + if err != nil { + return nil, fmt.Errorf( + "storing database version: %w", + err, + ) + } + } + // Validate that the schema ends up in the expected state. // This is currently only done on sqlite as squibble does not // support Postgres and we use our sqlite schema as our source of diff --git a/hscontrol/db/schema.sql b/hscontrol/db/schema.sql index ef0a2a0e..41e817ee 100644 --- a/hscontrol/db/schema.sql +++ b/hscontrol/db/schema.sql @@ -104,3 +104,9 @@ CREATE TABLE policies( deleted_at datetime ); CREATE INDEX idx_policies_deleted_at ON policies(deleted_at); + +CREATE TABLE database_versions( + id integer PRIMARY KEY, + version text NOT NULL, + updated_at datetime +); diff --git a/hscontrol/db/versioncheck.go b/hscontrol/db/versioncheck.go new file mode 100644 index 00000000..1f2591d9 --- /dev/null +++ b/hscontrol/db/versioncheck.go @@ -0,0 +1,256 @@ +package db + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" + "gorm.io/gorm" +) + +var errVersionUpgrade = errors.New("version upgrade not supported") + +var errVersionDowngrade = errors.New("version downgrade not supported") + +var errVersionMajorChange = errors.New("major version change not supported") + +var errVersionParse = errors.New("cannot parse version") + +var errVersionFormat = errors.New( + "version does not follow semver major.minor.patch format", +) + +// DatabaseVersion tracks the headscale version that last +// successfully started against this database. +// It is a single-row table (ID is always 1). +type DatabaseVersion struct { + ID uint `gorm:"primaryKey"` + Version string `gorm:"not null"` + UpdatedAt time.Time +} + +const createDatabaseVersionsSQL = `CREATE TABLE IF NOT EXISTS database_versions( + id integer PRIMARY KEY, + version text NOT NULL, + updated_at datetime +)` + +// semver holds parsed major.minor.patch components. +type semver struct { + Major int + Minor int + Patch int +} + +func (s semver) String() string { + return fmt.Sprintf("v%d.%d.%d", s.Major, s.Minor, s.Patch) +} + +// parseVersion parses a version string like "v0.25.0", "0.25.1", +// "v0.25.0-beta.1", or "v0.25.0-rc1+build123" into its major, minor, +// patch components. Pre-release and build metadata suffixes are stripped. +func parseVersion(s string) (semver, error) { + if s == "" || s == "dev" { + return semver{}, fmt.Errorf("%q: %w", s, errVersionParse) + } + + v := strings.TrimPrefix(s, "v") + + // Strip pre-release suffix (everything after first '-') + // and build metadata (everything after first '+'). + if idx := strings.IndexAny(v, "-+"); idx != -1 { + v = v[:idx] + } + + parts := strings.Split(v, ".") + if len(parts) != 3 { + return semver{}, fmt.Errorf("%q: %w", s, errVersionFormat) + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return semver{}, fmt.Errorf("invalid major version in %q: %w", s, err) + } + + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return semver{}, fmt.Errorf("invalid minor version in %q: %w", s, err) + } + + patch, err := strconv.Atoi(parts[2]) + if err != nil { + return semver{}, fmt.Errorf("invalid patch version in %q: %w", s, err) + } + + return semver{Major: major, Minor: minor, Patch: patch}, nil +} + +// ensureDatabaseVersionTable creates the database_versions table if it +// does not already exist. Uses raw SQL to match schema.sql exactly. +// This runs before gormigrate migrations. +func ensureDatabaseVersionTable(db *gorm.DB) error { + err := db.Exec(createDatabaseVersionsSQL).Error + if err != nil { + return fmt.Errorf("creating database version table: %w", err) + } + + return nil +} + +// getDatabaseVersion reads the stored version from the database. +// Returns an empty string if no version has been stored yet. +func getDatabaseVersion(db *gorm.DB) (string, error) { + var version string + + result := db.Raw("SELECT version FROM database_versions WHERE id = 1").Scan(&version) + if result.Error != nil { + return "", fmt.Errorf("reading database version: %w", result.Error) + } + + if result.RowsAffected == 0 { + return "", nil + } + + return version, nil +} + +// setDatabaseVersion upserts the version row in the database. +func setDatabaseVersion(db *gorm.DB, version string) error { + now := time.Now().UTC() + + // Try update first, then insert if no rows affected. + result := db.Exec( + "UPDATE database_versions SET version = ?, updated_at = ? WHERE id = 1", + version, now, + ) + if result.Error != nil { + return fmt.Errorf("updating database version: %w", result.Error) + } + + if result.RowsAffected == 0 { + err := db.Exec( + "INSERT INTO database_versions (id, version, updated_at) VALUES (1, ?, ?)", + version, now, + ).Error + if err != nil { + return fmt.Errorf("inserting database version: %w", err) + } + } + + return nil +} + +// isDev reports whether a version string represents a development build +// that should skip version checking. +func isDev(version string) bool { + return version == "" || version == "dev" || version == "(devel)" +} + +// checkVersionUpgradePath verifies that the running headscale version +// is compatible with the version that last used this database. +// +// Rules: +// - If the running binary has no version ("dev" or empty), warn and skip. +// - If no version is stored in the database, allow (first run with this feature). +// - If the stored version is "dev", allow (previous run was unversioned). +// - Same minor version: always allowed (patch changes in either direction). +// - Single minor version upgrade (stored.minor+1 == current.minor): allowed. +// - Multi-minor upgrade or any minor downgrade: blocked with a fatal error. +func checkVersionUpgradePath(db *gorm.DB) error { + err := ensureDatabaseVersionTable(db) + if err != nil { + return err + } + + currentVersion := types.GetVersionInfo().Version + + // Running binary has no real version — skip the check but + // preserve whatever version is already stored. + if isDev(currentVersion) { + storedVersion, err := getDatabaseVersion(db) + if err != nil { + return err + } + + if storedVersion != "" && !isDev(storedVersion) { + log.Warn(). + Str("database_version", storedVersion). + Msg("running a development build of headscale without a version number, " + + "database version check is skipped, the stored database version is preserved") + } + + return nil + } + + storedVersion, err := getDatabaseVersion(db) + if err != nil { + return err + } + + // No stored version — first run with this feature. Allow startup; + // the version will be stored after migrations succeed. + if storedVersion == "" { + return nil + } + + // Previous run was an unversioned build — no meaningful comparison. + if isDev(storedVersion) { + return nil + } + + current, err := parseVersion(currentVersion) + if err != nil { + return fmt.Errorf("parsing current version: %w", err) + } + + stored, err := parseVersion(storedVersion) + if err != nil { + return fmt.Errorf("parsing stored database version: %w", err) + } + + if current.Major != stored.Major { + return fmt.Errorf( + "headscale version %s cannot be used with a database last used by %s: %w", + currentVersion, storedVersion, errVersionMajorChange, + ) + } + + minorDiff := current.Minor - stored.Minor + + switch { + case minorDiff == 0: + // Same minor version — patch changes are always fine. + return nil + + case minorDiff == 1: + // Single minor version upgrade — allowed. + return nil + + case minorDiff > 1: + // Multi-minor upgrade — blocked. + return fmt.Errorf( + "headscale version %s cannot be used with a database last used by %s, "+ + "upgrading more than one minor version at a time is not supported, "+ + "please upgrade to the latest v%d.%d.x release first, then to %s, "+ + "release page: https://github.com/juanfont/headscale/releases: %w", + currentVersion, storedVersion, + stored.Major, stored.Minor+1, + current.String(), + errVersionUpgrade, + ) + + default: + // minorDiff < 0 — any minor downgrade is blocked. + return fmt.Errorf( + "headscale version %s cannot be used with a database last used by %s, "+ + "downgrading to a previous minor version is not supported, "+ + "release page: https://github.com/juanfont/headscale/releases: %w", + currentVersion, storedVersion, + errVersionDowngrade, + ) + } +} diff --git a/hscontrol/db/versioncheck_test.go b/hscontrol/db/versioncheck_test.go new file mode 100644 index 00000000..83c9ab5c --- /dev/null +++ b/hscontrol/db/versioncheck_test.go @@ -0,0 +1,318 @@ +package db + +import ( + "fmt" + "testing" + + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestParseVersion(t *testing.T) { + tests := []struct { + input string + want semver + wantErr bool + }{ + {input: "v0.25.0", want: semver{0, 25, 0}}, + {input: "0.25.0", want: semver{0, 25, 0}}, + {input: "v0.25.1", want: semver{0, 25, 1}}, + {input: "v1.0.0", want: semver{1, 0, 0}}, + {input: "v0.28.3", want: semver{0, 28, 3}}, + // Pre-release suffixes stripped + {input: "v0.25.0-beta.1", want: semver{0, 25, 0}}, + {input: "v0.25.0-rc1", want: semver{0, 25, 0}}, + // Build metadata stripped + {input: "v0.25.0+build123", want: semver{0, 25, 0}}, + {input: "v0.25.0-beta.1+build123", want: semver{0, 25, 0}}, + // Invalid inputs + {input: "", wantErr: true}, + {input: "dev", wantErr: true}, + {input: "vfoo.bar.baz", wantErr: true}, + {input: "v1.2", wantErr: true}, + {input: "v1", wantErr: true}, + {input: "not-a-version", wantErr: true}, + {input: "v1.2.3.4", wantErr: true}, + {input: "(devel)", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := parseVersion(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestSemverString(t *testing.T) { + s := semver{0, 28, 3} + assert.Equal(t, "v0.28.3", s.String()) +} + +func TestIsDev(t *testing.T) { + assert.True(t, isDev("")) + assert.True(t, isDev("dev")) + assert.True(t, isDev("(devel)")) + assert.False(t, isDev("v0.28.0")) + assert.False(t, isDev("0.28.0")) +} + +// versionTestDB creates an in-memory SQLite database with the +// database_versions table already bootstrapped. +func versionTestDB(t *testing.T) *gorm.DB { + t.Helper() + + db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{}) + require.NoError(t, err) + + err = ensureDatabaseVersionTable(db) + require.NoError(t, err) + + return db +} + +func TestSetAndGetDatabaseVersion(t *testing.T) { + db := versionTestDB(t) + + // Initially empty + v, err := getDatabaseVersion(db) + require.NoError(t, err) + assert.Empty(t, v) + + // Set a version + err = setDatabaseVersion(db, "v0.27.0") + require.NoError(t, err) + + v, err = getDatabaseVersion(db) + require.NoError(t, err) + assert.Equal(t, "v0.27.0", v) + + // Update the version (upsert) + err = setDatabaseVersion(db, "v0.28.0") + require.NoError(t, err) + + v, err = getDatabaseVersion(db) + require.NoError(t, err) + assert.Equal(t, "v0.28.0", v) +} + +func TestEnsureDatabaseVersionTableIdempotent(t *testing.T) { + db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{}) + require.NoError(t, err) + + // Call twice — should not error + err = ensureDatabaseVersionTable(db) + require.NoError(t, err) + + err = ensureDatabaseVersionTable(db) + require.NoError(t, err) +} + +// TestCheckVersionUpgradePathDirect tests the version comparison logic +// by directly seeding the database, bypassing types.GetVersionInfo() +// (which returns "dev" in test environments and cannot be overridden). +func TestCheckVersionUpgradePathDirect(t *testing.T) { + tests := []struct { + name string + storedVersion string // empty means no row stored + currentVersion string + wantErr bool + errContains string + }{ + // Fresh database (no stored version) + { + name: "fresh db allows any version", + storedVersion: "", + currentVersion: "v0.28.0", + }, + + // Stored is dev + { + name: "real version over dev db", + storedVersion: "dev", + currentVersion: "v0.28.0", + }, + { + name: "devel version in db", + storedVersion: "(devel)", + currentVersion: "v0.28.0", + }, + + // Same version + { + name: "same version", + storedVersion: "v0.27.0", + currentVersion: "v0.27.0", + }, + + // Patch changes within same minor + { + name: "patch upgrade", + storedVersion: "v0.27.0", + currentVersion: "v0.27.3", + }, + { + name: "patch downgrade within same minor", + storedVersion: "v0.27.3", + currentVersion: "v0.27.0", + }, + + // Single minor upgrade + { + name: "single minor upgrade", + storedVersion: "v0.27.0", + currentVersion: "v0.28.0", + }, + { + name: "single minor upgrade with different patches", + storedVersion: "v0.27.3", + currentVersion: "v0.28.1", + }, + + // Multi-minor upgrade (blocked) + { + name: "two minor versions ahead", + storedVersion: "v0.25.0", + currentVersion: "v0.27.0", + wantErr: true, + errContains: "latest v0.26.x", + }, + { + name: "three minor versions ahead", + storedVersion: "v0.25.0", + currentVersion: "v0.28.0", + wantErr: true, + errContains: "latest v0.26.x", + }, + + // Minor downgrades (blocked) + { + name: "single minor downgrade", + storedVersion: "v0.28.0", + currentVersion: "v0.27.0", + wantErr: true, + errContains: "downgrading", + }, + { + name: "multi minor downgrade", + storedVersion: "v0.28.0", + currentVersion: "v0.25.0", + wantErr: true, + errContains: "downgrading", + }, + + // Major version mismatch + { + name: "major version upgrade", + storedVersion: "v0.28.0", + currentVersion: "v1.0.0", + wantErr: true, + errContains: "major version", + }, + { + name: "major version downgrade", + storedVersion: "v1.0.0", + currentVersion: "v0.28.0", + wantErr: true, + errContains: "major version", + }, + + // Pre-release versions + { + name: "pre-release single minor upgrade", + storedVersion: "v0.27.0", + currentVersion: "v0.28.0-beta.1", + }, + { + name: "pre-release multi minor upgrade blocked", + storedVersion: "v0.25.0", + currentVersion: "v0.27.0-rc1", + wantErr: true, + errContains: "latest v0.26.x", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := versionTestDB(t) + + // Seed the stored version if provided + if tt.storedVersion != "" { + err := setDatabaseVersion(db, tt.storedVersion) + require.NoError(t, err) + } + + err := checkVersionUpgradePathFromVersions(db, tt.currentVersion) + if tt.wantErr { + require.Error(t, err) + + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// checkVersionUpgradePathFromVersions is a test helper that runs the +// version comparison logic with a specific currentVersion string, +// bypassing types.GetVersionInfo(). It replicates the logic from +// checkVersionUpgradePath but accepts the version as a parameter. +func checkVersionUpgradePathFromVersions(db *gorm.DB, currentVersion string) error { + if isDev(currentVersion) { + return nil + } + + storedVersion, err := getDatabaseVersion(db) + if err != nil { + return err + } + + if storedVersion == "" { + return nil + } + + if isDev(storedVersion) { + return nil + } + + current, err := parseVersion(currentVersion) + if err != nil { + return err + } + + stored, err := parseVersion(storedVersion) + if err != nil { + return err + } + + if current.Major != stored.Major { + return errVersionMajorChange + } + + minorDiff := current.Minor - stored.Minor + + switch { + case minorDiff == 0: + return nil + case minorDiff == 1: + return nil + case minorDiff > 1: + return fmt.Errorf( + "please upgrade to the latest v%d.%d.x release first: %w", + stored.Major, stored.Minor+1, + errVersionUpgrade, + ) + default: + return fmt.Errorf("downgrading: %w", errVersionDowngrade) + } +}