diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index c518502e..6841f446 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -63,6 +63,11 @@ func NewHeadscaleDatabase( return nil, err } + err = checkVersionUpgradePath(dbConn) + if err != nil { + return nil, fmt.Errorf("version check: %w", err) + } + migrations := gormigrate.New( dbConn, gormigrate.DefaultOptions, @@ -760,6 +765,20 @@ AND auth_key_id NOT IN ( return nil, fmt.Errorf("migration failed: %w", err) } + // Store the current version in the database after migrations succeed. + // Dev builds skip this to preserve the stored version for the next + // real versioned binary. + currentVersion := types.GetVersionInfo().Version + if !isDev(currentVersion) { + err = setDatabaseVersion(dbConn, currentVersion) + if err != nil { + return nil, fmt.Errorf( + "storing database version: %w", + err, + ) + } + } + // Validate that the schema ends up in the expected state. // This is currently only done on sqlite as squibble does not // support Postgres and we use our sqlite schema as our source of diff --git a/hscontrol/db/schema.sql b/hscontrol/db/schema.sql index ef0a2a0e..41e817ee 100644 --- a/hscontrol/db/schema.sql +++ b/hscontrol/db/schema.sql @@ -104,3 +104,9 @@ CREATE TABLE policies( deleted_at datetime ); CREATE INDEX idx_policies_deleted_at ON policies(deleted_at); + +CREATE TABLE database_versions( + id integer PRIMARY KEY, + version text NOT NULL, + updated_at datetime +); diff --git a/hscontrol/db/versioncheck.go b/hscontrol/db/versioncheck.go new file mode 100644 index 00000000..1f2591d9 --- /dev/null +++ b/hscontrol/db/versioncheck.go @@ -0,0 +1,256 @@ +package db + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" + "gorm.io/gorm" +) + +var errVersionUpgrade = errors.New("version upgrade not supported") + +var errVersionDowngrade = errors.New("version downgrade not supported") + +var errVersionMajorChange = errors.New("major version change not supported") + +var errVersionParse = errors.New("cannot parse version") + +var errVersionFormat = errors.New( + "version does not follow semver major.minor.patch format", +) + +// DatabaseVersion tracks the headscale version that last +// successfully started against this database. +// It is a single-row table (ID is always 1). +type DatabaseVersion struct { + ID uint `gorm:"primaryKey"` + Version string `gorm:"not null"` + UpdatedAt time.Time +} + +const createDatabaseVersionsSQL = `CREATE TABLE IF NOT EXISTS database_versions( + id integer PRIMARY KEY, + version text NOT NULL, + updated_at datetime +)` + +// semver holds parsed major.minor.patch components. +type semver struct { + Major int + Minor int + Patch int +} + +func (s semver) String() string { + return fmt.Sprintf("v%d.%d.%d", s.Major, s.Minor, s.Patch) +} + +// parseVersion parses a version string like "v0.25.0", "0.25.1", +// "v0.25.0-beta.1", or "v0.25.0-rc1+build123" into its major, minor, +// patch components. Pre-release and build metadata suffixes are stripped. +func parseVersion(s string) (semver, error) { + if s == "" || s == "dev" { + return semver{}, fmt.Errorf("%q: %w", s, errVersionParse) + } + + v := strings.TrimPrefix(s, "v") + + // Strip pre-release suffix (everything after first '-') + // and build metadata (everything after first '+'). + if idx := strings.IndexAny(v, "-+"); idx != -1 { + v = v[:idx] + } + + parts := strings.Split(v, ".") + if len(parts) != 3 { + return semver{}, fmt.Errorf("%q: %w", s, errVersionFormat) + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return semver{}, fmt.Errorf("invalid major version in %q: %w", s, err) + } + + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return semver{}, fmt.Errorf("invalid minor version in %q: %w", s, err) + } + + patch, err := strconv.Atoi(parts[2]) + if err != nil { + return semver{}, fmt.Errorf("invalid patch version in %q: %w", s, err) + } + + return semver{Major: major, Minor: minor, Patch: patch}, nil +} + +// ensureDatabaseVersionTable creates the database_versions table if it +// does not already exist. Uses raw SQL to match schema.sql exactly. +// This runs before gormigrate migrations. +func ensureDatabaseVersionTable(db *gorm.DB) error { + err := db.Exec(createDatabaseVersionsSQL).Error + if err != nil { + return fmt.Errorf("creating database version table: %w", err) + } + + return nil +} + +// getDatabaseVersion reads the stored version from the database. +// Returns an empty string if no version has been stored yet. +func getDatabaseVersion(db *gorm.DB) (string, error) { + var version string + + result := db.Raw("SELECT version FROM database_versions WHERE id = 1").Scan(&version) + if result.Error != nil { + return "", fmt.Errorf("reading database version: %w", result.Error) + } + + if result.RowsAffected == 0 { + return "", nil + } + + return version, nil +} + +// setDatabaseVersion upserts the version row in the database. +func setDatabaseVersion(db *gorm.DB, version string) error { + now := time.Now().UTC() + + // Try update first, then insert if no rows affected. + result := db.Exec( + "UPDATE database_versions SET version = ?, updated_at = ? WHERE id = 1", + version, now, + ) + if result.Error != nil { + return fmt.Errorf("updating database version: %w", result.Error) + } + + if result.RowsAffected == 0 { + err := db.Exec( + "INSERT INTO database_versions (id, version, updated_at) VALUES (1, ?, ?)", + version, now, + ).Error + if err != nil { + return fmt.Errorf("inserting database version: %w", err) + } + } + + return nil +} + +// isDev reports whether a version string represents a development build +// that should skip version checking. +func isDev(version string) bool { + return version == "" || version == "dev" || version == "(devel)" +} + +// checkVersionUpgradePath verifies that the running headscale version +// is compatible with the version that last used this database. +// +// Rules: +// - If the running binary has no version ("dev" or empty), warn and skip. +// - If no version is stored in the database, allow (first run with this feature). +// - If the stored version is "dev", allow (previous run was unversioned). +// - Same minor version: always allowed (patch changes in either direction). +// - Single minor version upgrade (stored.minor+1 == current.minor): allowed. +// - Multi-minor upgrade or any minor downgrade: blocked with a fatal error. +func checkVersionUpgradePath(db *gorm.DB) error { + err := ensureDatabaseVersionTable(db) + if err != nil { + return err + } + + currentVersion := types.GetVersionInfo().Version + + // Running binary has no real version — skip the check but + // preserve whatever version is already stored. + if isDev(currentVersion) { + storedVersion, err := getDatabaseVersion(db) + if err != nil { + return err + } + + if storedVersion != "" && !isDev(storedVersion) { + log.Warn(). + Str("database_version", storedVersion). + Msg("running a development build of headscale without a version number, " + + "database version check is skipped, the stored database version is preserved") + } + + return nil + } + + storedVersion, err := getDatabaseVersion(db) + if err != nil { + return err + } + + // No stored version — first run with this feature. Allow startup; + // the version will be stored after migrations succeed. + if storedVersion == "" { + return nil + } + + // Previous run was an unversioned build — no meaningful comparison. + if isDev(storedVersion) { + return nil + } + + current, err := parseVersion(currentVersion) + if err != nil { + return fmt.Errorf("parsing current version: %w", err) + } + + stored, err := parseVersion(storedVersion) + if err != nil { + return fmt.Errorf("parsing stored database version: %w", err) + } + + if current.Major != stored.Major { + return fmt.Errorf( + "headscale version %s cannot be used with a database last used by %s: %w", + currentVersion, storedVersion, errVersionMajorChange, + ) + } + + minorDiff := current.Minor - stored.Minor + + switch { + case minorDiff == 0: + // Same minor version — patch changes are always fine. + return nil + + case minorDiff == 1: + // Single minor version upgrade — allowed. + return nil + + case minorDiff > 1: + // Multi-minor upgrade — blocked. + return fmt.Errorf( + "headscale version %s cannot be used with a database last used by %s, "+ + "upgrading more than one minor version at a time is not supported, "+ + "please upgrade to the latest v%d.%d.x release first, then to %s, "+ + "release page: https://github.com/juanfont/headscale/releases: %w", + currentVersion, storedVersion, + stored.Major, stored.Minor+1, + current.String(), + errVersionUpgrade, + ) + + default: + // minorDiff < 0 — any minor downgrade is blocked. + return fmt.Errorf( + "headscale version %s cannot be used with a database last used by %s, "+ + "downgrading to a previous minor version is not supported, "+ + "release page: https://github.com/juanfont/headscale/releases: %w", + currentVersion, storedVersion, + errVersionDowngrade, + ) + } +} diff --git a/hscontrol/db/versioncheck_test.go b/hscontrol/db/versioncheck_test.go new file mode 100644 index 00000000..83c9ab5c --- /dev/null +++ b/hscontrol/db/versioncheck_test.go @@ -0,0 +1,318 @@ +package db + +import ( + "fmt" + "testing" + + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestParseVersion(t *testing.T) { + tests := []struct { + input string + want semver + wantErr bool + }{ + {input: "v0.25.0", want: semver{0, 25, 0}}, + {input: "0.25.0", want: semver{0, 25, 0}}, + {input: "v0.25.1", want: semver{0, 25, 1}}, + {input: "v1.0.0", want: semver{1, 0, 0}}, + {input: "v0.28.3", want: semver{0, 28, 3}}, + // Pre-release suffixes stripped + {input: "v0.25.0-beta.1", want: semver{0, 25, 0}}, + {input: "v0.25.0-rc1", want: semver{0, 25, 0}}, + // Build metadata stripped + {input: "v0.25.0+build123", want: semver{0, 25, 0}}, + {input: "v0.25.0-beta.1+build123", want: semver{0, 25, 0}}, + // Invalid inputs + {input: "", wantErr: true}, + {input: "dev", wantErr: true}, + {input: "vfoo.bar.baz", wantErr: true}, + {input: "v1.2", wantErr: true}, + {input: "v1", wantErr: true}, + {input: "not-a-version", wantErr: true}, + {input: "v1.2.3.4", wantErr: true}, + {input: "(devel)", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := parseVersion(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestSemverString(t *testing.T) { + s := semver{0, 28, 3} + assert.Equal(t, "v0.28.3", s.String()) +} + +func TestIsDev(t *testing.T) { + assert.True(t, isDev("")) + assert.True(t, isDev("dev")) + assert.True(t, isDev("(devel)")) + assert.False(t, isDev("v0.28.0")) + assert.False(t, isDev("0.28.0")) +} + +// versionTestDB creates an in-memory SQLite database with the +// database_versions table already bootstrapped. +func versionTestDB(t *testing.T) *gorm.DB { + t.Helper() + + db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{}) + require.NoError(t, err) + + err = ensureDatabaseVersionTable(db) + require.NoError(t, err) + + return db +} + +func TestSetAndGetDatabaseVersion(t *testing.T) { + db := versionTestDB(t) + + // Initially empty + v, err := getDatabaseVersion(db) + require.NoError(t, err) + assert.Empty(t, v) + + // Set a version + err = setDatabaseVersion(db, "v0.27.0") + require.NoError(t, err) + + v, err = getDatabaseVersion(db) + require.NoError(t, err) + assert.Equal(t, "v0.27.0", v) + + // Update the version (upsert) + err = setDatabaseVersion(db, "v0.28.0") + require.NoError(t, err) + + v, err = getDatabaseVersion(db) + require.NoError(t, err) + assert.Equal(t, "v0.28.0", v) +} + +func TestEnsureDatabaseVersionTableIdempotent(t *testing.T) { + db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{}) + require.NoError(t, err) + + // Call twice — should not error + err = ensureDatabaseVersionTable(db) + require.NoError(t, err) + + err = ensureDatabaseVersionTable(db) + require.NoError(t, err) +} + +// TestCheckVersionUpgradePathDirect tests the version comparison logic +// by directly seeding the database, bypassing types.GetVersionInfo() +// (which returns "dev" in test environments and cannot be overridden). +func TestCheckVersionUpgradePathDirect(t *testing.T) { + tests := []struct { + name string + storedVersion string // empty means no row stored + currentVersion string + wantErr bool + errContains string + }{ + // Fresh database (no stored version) + { + name: "fresh db allows any version", + storedVersion: "", + currentVersion: "v0.28.0", + }, + + // Stored is dev + { + name: "real version over dev db", + storedVersion: "dev", + currentVersion: "v0.28.0", + }, + { + name: "devel version in db", + storedVersion: "(devel)", + currentVersion: "v0.28.0", + }, + + // Same version + { + name: "same version", + storedVersion: "v0.27.0", + currentVersion: "v0.27.0", + }, + + // Patch changes within same minor + { + name: "patch upgrade", + storedVersion: "v0.27.0", + currentVersion: "v0.27.3", + }, + { + name: "patch downgrade within same minor", + storedVersion: "v0.27.3", + currentVersion: "v0.27.0", + }, + + // Single minor upgrade + { + name: "single minor upgrade", + storedVersion: "v0.27.0", + currentVersion: "v0.28.0", + }, + { + name: "single minor upgrade with different patches", + storedVersion: "v0.27.3", + currentVersion: "v0.28.1", + }, + + // Multi-minor upgrade (blocked) + { + name: "two minor versions ahead", + storedVersion: "v0.25.0", + currentVersion: "v0.27.0", + wantErr: true, + errContains: "latest v0.26.x", + }, + { + name: "three minor versions ahead", + storedVersion: "v0.25.0", + currentVersion: "v0.28.0", + wantErr: true, + errContains: "latest v0.26.x", + }, + + // Minor downgrades (blocked) + { + name: "single minor downgrade", + storedVersion: "v0.28.0", + currentVersion: "v0.27.0", + wantErr: true, + errContains: "downgrading", + }, + { + name: "multi minor downgrade", + storedVersion: "v0.28.0", + currentVersion: "v0.25.0", + wantErr: true, + errContains: "downgrading", + }, + + // Major version mismatch + { + name: "major version upgrade", + storedVersion: "v0.28.0", + currentVersion: "v1.0.0", + wantErr: true, + errContains: "major version", + }, + { + name: "major version downgrade", + storedVersion: "v1.0.0", + currentVersion: "v0.28.0", + wantErr: true, + errContains: "major version", + }, + + // Pre-release versions + { + name: "pre-release single minor upgrade", + storedVersion: "v0.27.0", + currentVersion: "v0.28.0-beta.1", + }, + { + name: "pre-release multi minor upgrade blocked", + storedVersion: "v0.25.0", + currentVersion: "v0.27.0-rc1", + wantErr: true, + errContains: "latest v0.26.x", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := versionTestDB(t) + + // Seed the stored version if provided + if tt.storedVersion != "" { + err := setDatabaseVersion(db, tt.storedVersion) + require.NoError(t, err) + } + + err := checkVersionUpgradePathFromVersions(db, tt.currentVersion) + if tt.wantErr { + require.Error(t, err) + + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// checkVersionUpgradePathFromVersions is a test helper that runs the +// version comparison logic with a specific currentVersion string, +// bypassing types.GetVersionInfo(). It replicates the logic from +// checkVersionUpgradePath but accepts the version as a parameter. +func checkVersionUpgradePathFromVersions(db *gorm.DB, currentVersion string) error { + if isDev(currentVersion) { + return nil + } + + storedVersion, err := getDatabaseVersion(db) + if err != nil { + return err + } + + if storedVersion == "" { + return nil + } + + if isDev(storedVersion) { + return nil + } + + current, err := parseVersion(currentVersion) + if err != nil { + return err + } + + stored, err := parseVersion(storedVersion) + if err != nil { + return err + } + + if current.Major != stored.Major { + return errVersionMajorChange + } + + minorDiff := current.Minor - stored.Minor + + switch { + case minorDiff == 0: + return nil + case minorDiff == 1: + return nil + case minorDiff > 1: + return fmt.Errorf( + "please upgrade to the latest v%d.%d.x release first: %w", + stored.Major, stored.Minor+1, + errVersionUpgrade, + ) + default: + return fmt.Errorf("downgrading: %w", errVersionDowngrade) + } +}