diff --git a/cmd/analyze.go b/cmd/analyze.go index b5b2202b7..03914da09 100644 --- a/cmd/analyze.go +++ b/cmd/analyze.go @@ -19,7 +19,7 @@ var analyzeCmd = &cobra.Command{ Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() - state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema()) + state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema(), state.WithPgrollVersion(Version)) if err != nil { return err } diff --git a/cmd/root.go b/cmd/root.go index 65541c029..14a21aa47 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -25,7 +25,7 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) { skipValidation := flags.SkipValidation() verbose := flags.Verbose() - state, err := state.New(ctx, pgURL, stateSchema) + state, err := state.New(ctx, pgURL, stateSchema, state.WithPgrollVersion(Version)) if err != nil { return nil, err } diff --git a/cmd/status.go b/cmd/status.go index ea05fe0b3..3a374eea5 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -18,7 +18,7 @@ var statusCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() - state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema()) + state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema(), state.WithPgrollVersion(Version)) if err != nil { return err } diff --git a/go.mod b/go.mod index 4a5ff3883..dc25c4ed2 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/testcontainers/testcontainers-go v0.37.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0 github.com/xataio/pg_query_go/v6 v6.0.0-20250425105130-ed1845ee2d75 + golang.org/x/mod v0.25.0 golang.org/x/tools v0.34.0 sigs.k8s.io/yaml v1.4.0 ) diff --git a/go.sum b/go.sum index 7cd26aa75..8d2ee9ec9 100644 --- a/go.sum +++ b/go.sum @@ -258,6 +258,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/internal/testutils/util.go b/internal/testutils/util.go index 7613b8496..0dd1b1369 100644 --- a/internal/testutils/util.go +++ b/internal/testutils/util.go @@ -131,6 +131,24 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql. WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn) } +func WithStateAtVersionAndConnectionToContainer(t *testing.T, version string, fn func(*state.State, string, *sql.DB)) { + t.Helper() + ctx := context.Background() + + db, connStr, _ := setupTestDatabase(t) + + st, err := state.New(ctx, connStr, "pgroll", state.WithPgrollVersion(version)) + if err != nil { + t.Fatal(err) + } + + if err := st.Init(ctx); err != nil { + t.Fatal(err) + } + + fn(st, connStr, db) +} + func WithUninitializedState(t *testing.T, fn func(*state.State)) { t.Helper() ctx := context.Background() diff --git a/pkg/state/init.sql b/pkg/state/init.sql index cc9958b2b..03a8742c1 100644 --- a/pkg/state/init.sql +++ b/pkg/state/init.sql @@ -43,6 +43,13 @@ ALTER TABLE placeholder.migrations ALTER COLUMN created_at SET DATA TYPE timestamptz USING created_at AT TIME ZONE 'UTC', ALTER COLUMN updated_at SET DATA TYPE timestamptz USING updated_at AT TIME ZONE 'UTC'; +-- Table to track pgroll binary version +CREATE TABLE IF NOT EXISTS placeholder.pgroll_version ( + version text NOT NULL, + initialized_at timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (version) +); + -- Helper functions -- Are we in the middle of a migration? CREATE OR REPLACE FUNCTION placeholder.is_active_migration_period (schemaname name) diff --git a/pkg/state/opts.go b/pkg/state/opts.go new file mode 100644 index 000000000..1d715c07e --- /dev/null +++ b/pkg/state/opts.go @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 + +package state + +type StateOpt func(s *State) + +// WithPgrollVersion sets the version of `pgroll` that is constructing the State +// instance +func WithPgrollVersion(version string) StateOpt { + return func(s *State) { + s.pgrollVersion = version + } +} diff --git a/pkg/state/state.go b/pkg/state/state.go index 7cc14f6a9..4823a75b0 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -25,11 +25,12 @@ var sqlInit string const applicationName = "pgroll-state" type State struct { - pgConn *sql.DB - schema string + pgConn *sql.DB + pgrollVersion string + schema string } -func New(ctx context.Context, pgURL, stateSchema string) (*State, error) { +func New(ctx context.Context, pgURL, stateSchema string, opts ...StateOpt) (*State, error) { dsn, err := pq.ParseURL(pgURL) if err != nil { dsn = pgURL @@ -46,10 +47,38 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) { return nil, err } - return &State{ - pgConn: conn, - schema: stateSchema, - }, nil + st := &State{ + pgConn: conn, + pgrollVersion: "development", + schema: stateSchema, + } + + // Apply options to the State instance + for _, opt := range opts { + opt(st) + } + + // Check version compatibility between the pgroll version and the version of + // the pgroll state schema. + compat, err := st.VersionCompatibility(ctx) + if err != nil { + return nil, err + } + + // If the state schema is newer than the pgroll version, return an error + if compat == VersionCompatVersionSchemaNewer { + return nil, ErrNewPgrollSchema + } + + // if the state schema is older than the pgroll version, re-initialize the + // state schema + if compat == VersionCompatVersionSchemaOlder { + if err := st.Init(ctx); err != nil { + return nil, err + } + } + + return st, nil } // Init initializes the required pg_roll schema to store the state @@ -76,6 +105,22 @@ func (s *State) Init(ctx context.Context) error { return err } + // Clear the pgroll_version table + _, err = tx.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s.pgroll_version", + pq.QuoteIdentifier(s.schema))) + if err != nil { + return err + } + + // Insert the version of `pgroll` that is being initialized into the + // pgroll_version table + _, err = tx.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s.pgroll_version (version) VALUES ($1)", + pq.QuoteIdentifier(s.schema)), + s.pgrollVersion) + if err != nil { + return err + } + return tx.Commit() } @@ -83,6 +128,7 @@ func (s *State) PgConn() *sql.DB { return s.pgConn } +// IsInitialized checks if the pgroll state schema is initialized. func (s *State) IsInitialized(ctx context.Context) (bool, error) { var isInitialized bool err := s.pgConn.QueryRowContext(ctx, diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index d14c91885..78d266089 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -1353,6 +1353,85 @@ func TestReadSchema(t *testing.T) { }) } +func TestPgrollSchemaVersionUpgrades(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + tests := []struct { + name string + initialSchemaVersion string + pgrollVersion string + expectedSchemaVersion string + expectedError error + }{ + { + name: "pgroll schema is older than the pgroll version - pgroll schema is updated", + initialSchemaVersion: "0.13.0", + pgrollVersion: "0.14.0", + expectedSchemaVersion: "0.14.0", + }, + { + name: "pgroll schema is newer than the pgroll version - state initialization fails", + initialSchemaVersion: "0.15.0", + pgrollVersion: "0.14.0", + expectedError: state.ErrNewPgrollSchema, + }, + { + name: "pgroll schema is the same as the pgroll version - pgroll schema is not updated", + initialSchemaVersion: "0.13.0", + pgrollVersion: "0.13.0", + expectedSchemaVersion: "0.13.0", + }, + { + name: "development versions of pgroll never cause a pgroll schema update", + initialSchemaVersion: "0.13.0", + pgrollVersion: "development", + expectedSchemaVersion: "0.13.0", + }, + { + name: "development versions of the pgroll schema are never upgraded", + initialSchemaVersion: "development", + pgrollVersion: "0.13.0", + expectedSchemaVersion: "development", + }, + { + name: "invalid pgroll version - pgroll schema is not updated", + initialSchemaVersion: "0.14.0", + pgrollVersion: "banana", + expectedSchemaVersion: "0.14.0", + }, + { + name: "invalid pgroll schema version - pgroll schema is not updated", + initialSchemaVersion: "banana", + pgrollVersion: "0.14.0", + expectedSchemaVersion: "banana", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testutils.WithStateAtVersionAndConnectionToContainer(t, tt.initialSchemaVersion, func(st *state.State, connStr string, _ *sql.DB) { + // Create a new state instance with the specified pgroll version. This + // will upgrade the pgroll schema if necessary. + s, err := state.New(ctx, connStr, "pgroll", state.WithPgrollVersion(tt.pgrollVersion)) + + if tt.expectedError != nil { + require.ErrorIs(t, err, tt.expectedError) + } else { + require.NoError(t, err) + // Get the version of the pgroll schema + schemaVersion, err := s.SchemaVersion(ctx) + require.NoError(t, err) + + // Ensure the expected pgroll schema version + require.Equal(t, tt.expectedSchemaVersion, schemaVersion) + } + }) + }) + } +} + func clearOIDS(s *schema.Schema) { for k := range s.Tables { c := s.Tables[k] diff --git a/pkg/state/version.go b/pkg/state/version.go new file mode 100644 index 000000000..10dac5ad3 --- /dev/null +++ b/pkg/state/version.go @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: Apache-2.0 + +package state + +import ( + "context" + "errors" + "fmt" + + "github.com/lib/pq" + "golang.org/x/mod/semver" +) + +var ErrNewPgrollSchema = errors.New("pgroll binary version is older than pgroll schema version") + +// VersionCompatibility represents the result of comparing pgroll binary and +// state schema versions +type VersionCompatibility int + +const ( + VersionCompatCheckSkipped VersionCompatibility = iota + VersionCompatNotInitialized + VersionCompatVersionSchemaOlder + VersionCompatVersionSchemaEqual + VersionCompatVersionSchemaNewer +) + +// VersionCompatibility compares the pgroll version that was used to initialize +// the `State` instance with the version of the pgroll state schema. +func (s *State) VersionCompatibility(ctx context.Context) (VersionCompatibility, error) { + pgrollVersion := s.pgrollVersion + + // Development versions of pgroll are not checked for compatibility + if pgrollVersion == "development" { + return VersionCompatCheckSkipped, nil + } + + // Only perform compatibility check if pgroll is initialized + ok, err := s.IsInitialized(ctx) + if err != nil { + return 0, fmt.Errorf("failed to check initialization status: %w", err) + } + if !ok { + return VersionCompatCheckSkipped, nil + } + + // Check if this is a legacy schema (pgroll schema exists but there is no + // `pgroll_version` table). + versionTableExists, err := s.versionTableExists(ctx) + if err != nil { + return 0, fmt.Errorf("failed to check version table existence: %w", err) + } + if !versionTableExists { + return VersionCompatVersionSchemaOlder, nil + } + + // Get the pgroll version that was used to initialize the pgroll schema + schemaVersion, err := s.SchemaVersion(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get stored version: %w", err) + } + + // pgroll schemas created by development versions of pgroll are not checked + // for compatibility. + if schemaVersion == "development" { + return VersionCompatCheckSkipped, nil + } + + // Ensure both versions have the 'v' prefix for compatibility with Go's + // semver package + schemaVersion = ensureVPrefix(schemaVersion) + pgrollVersion = ensureVPrefix(pgrollVersion) + + // If either the schema version or the pgroll version is invalid, do not make + // any assumptions about compatibility + if !semver.IsValid(schemaVersion) || !semver.IsValid(pgrollVersion) { + return VersionCompatCheckSkipped, nil + } + + // Canonicalize both versions to ensure they are in the correct format + schemaVersion = semver.Canonical(schemaVersion) + pgrollVersion = semver.Canonical(pgrollVersion) + + // Compare versions + cmp := semver.Compare(schemaVersion, pgrollVersion) + if cmp < 0 { + return VersionCompatVersionSchemaOlder, nil + } + if cmp > 0 { + return VersionCompatVersionSchemaNewer, nil + } + + return VersionCompatVersionSchemaEqual, nil +} + +// schemaVersion retrieves the version stored in the pgroll_version table. +func (s *State) SchemaVersion(ctx context.Context) (string, error) { + query := fmt.Sprintf("SELECT version FROM %s.pgroll_version ORDER BY initialized_at DESC LIMIT 1", + pq.QuoteIdentifier(s.schema)) + + var version string + err := s.pgConn.QueryRowContext(ctx, query).Scan(&version) + if err != nil { + return "", err + } + + return version, nil +} + +// versionTableExists checks if the pgroll_version table exists in the state +// schema. +func (s *State) versionTableExists(ctx context.Context) (bool, error) { + query := `SELECT EXISTS ( + SELECT 1 FROM information_schema.tables + WHERE table_schema = $1 AND table_name = 'pgroll_version' + )` + + var exists bool + err := s.pgConn.QueryRowContext(ctx, query, s.schema).Scan(&exists) + return exists, err +} + +// Ensure that the given version string starts with 'v' to ensure compatibility +// with the`golang.org/x/mod/semver` package +func ensureVPrefix(version string) string { + if len(version) > 0 && version[0] != 'v' { + return "v" + version + } + return version +}