Skip to content

Keep pgroll state schema in sync with pgroll binary version #876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ var analyzeCmd = &cobra.Command{
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
ctx := cmd.Context()
state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema())
state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema(), state.WithPgrollVersion(Version))
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) {
skipValidation := flags.SkipValidation()
verbose := flags.Verbose()

state, err := state.New(ctx, pgURL, stateSchema)
state, err := state.New(ctx, pgURL, stateSchema, state.WithPgrollVersion(Version))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var statusCmd = &cobra.Command{
RunE: func(cmd *cobra.Command, _ []string) error {
ctx := cmd.Context()

state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema())
state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema(), state.WithPgrollVersion(Version))
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/testcontainers/testcontainers-go v0.37.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0
github.com/xataio/pg_query_go/v6 v6.0.0-20250425105130-ed1845ee2d75
golang.org/x/mod v0.25.0
golang.org/x/tools v0.34.0
sigs.k8s.io/yaml v1.4.0
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
Expand Down
18 changes: 18 additions & 0 deletions internal/testutils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,24 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn)
}

func WithStateAtVersionAndConnectionToContainer(t *testing.T, version string, fn func(*state.State, string, *sql.DB)) {
t.Helper()
ctx := context.Background()

db, connStr, _ := setupTestDatabase(t)

st, err := state.New(ctx, connStr, "pgroll", state.WithPgrollVersion(version))
if err != nil {
t.Fatal(err)
}

if err := st.Init(ctx); err != nil {
t.Fatal(err)
}

fn(st, connStr, db)
}

func WithUninitializedState(t *testing.T, fn func(*state.State)) {
t.Helper()
ctx := context.Background()
Expand Down
7 changes: 7 additions & 0 deletions pkg/state/init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ ALTER TABLE placeholder.migrations
ALTER COLUMN created_at SET DATA TYPE timestamptz USING created_at AT TIME ZONE 'UTC',
ALTER COLUMN updated_at SET DATA TYPE timestamptz USING updated_at AT TIME ZONE 'UTC';

-- Table to track pgroll binary version
CREATE TABLE IF NOT EXISTS placeholder.pgroll_version (
version text NOT NULL,
initialized_at timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (version)
);

-- Helper functions
-- Are we in the middle of a migration?
CREATE OR REPLACE FUNCTION placeholder.is_active_migration_period (schemaname name)
Expand Down
13 changes: 13 additions & 0 deletions pkg/state/opts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: Apache-2.0

package state

type StateOpt func(s *State)

// WithPgrollVersion sets the version of `pgroll` that is constructing the State
// instance
func WithPgrollVersion(version string) StateOpt {
return func(s *State) {
s.pgrollVersion = version
}
}
60 changes: 53 additions & 7 deletions pkg/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ var sqlInit string
const applicationName = "pgroll-state"

type State struct {
pgConn *sql.DB
schema string
pgConn *sql.DB
pgrollVersion string
schema string
}

func New(ctx context.Context, pgURL, stateSchema string) (*State, error) {
func New(ctx context.Context, pgURL, stateSchema string, opts ...StateOpt) (*State, error) {
dsn, err := pq.ParseURL(pgURL)
if err != nil {
dsn = pgURL
Expand All @@ -46,10 +47,38 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) {
return nil, err
}

return &State{
pgConn: conn,
schema: stateSchema,
}, nil
st := &State{
pgConn: conn,
pgrollVersion: "development",
schema: stateSchema,
}

// Apply options to the State instance
for _, opt := range opts {
opt(st)
}

// Check version compatibility between the pgroll version and the version of
// the pgroll state schema.
compat, err := st.VersionCompatibility(ctx)
if err != nil {
return nil, err
}

// If the state schema is newer than the pgroll version, return an error
if compat == VersionCompatVersionSchemaNewer {
return nil, ErrNewPgrollSchema
}

// if the state schema is older than the pgroll version, re-initialize the
// state schema
if compat == VersionCompatVersionSchemaOlder {
if err := st.Init(ctx); err != nil {
return nil, err
}
}

return st, nil
}

// Init initializes the required pg_roll schema to store the state
Expand All @@ -76,13 +105,30 @@ func (s *State) Init(ctx context.Context) error {
return err
}

// Clear the pgroll_version table
_, err = tx.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s.pgroll_version",
pq.QuoteIdentifier(s.schema)))
if err != nil {
return err
}

// Insert the version of `pgroll` that is being initialized into the
// pgroll_version table
_, err = tx.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s.pgroll_version (version) VALUES ($1)",
pq.QuoteIdentifier(s.schema)),
s.pgrollVersion)
if err != nil {
return err
}

return tx.Commit()
}

func (s *State) PgConn() *sql.DB {
return s.pgConn
}

// IsInitialized checks if the pgroll state schema is initialized.
func (s *State) IsInitialized(ctx context.Context) (bool, error) {
var isInitialized bool
err := s.pgConn.QueryRowContext(ctx,
Expand Down
79 changes: 79 additions & 0 deletions pkg/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,85 @@ func TestReadSchema(t *testing.T) {
})
}

func TestPgrollSchemaVersionUpgrades(t *testing.T) {
t.Parallel()

ctx := context.Background()

tests := []struct {
name string
initialSchemaVersion string
pgrollVersion string
expectedSchemaVersion string
expectedError error
}{
{
name: "pgroll schema is older than the pgroll version - pgroll schema is updated",
initialSchemaVersion: "0.13.0",
pgrollVersion: "0.14.0",
expectedSchemaVersion: "0.14.0",
},
{
name: "pgroll schema is newer than the pgroll version - state initialization fails",
initialSchemaVersion: "0.15.0",
pgrollVersion: "0.14.0",
expectedError: state.ErrNewPgrollSchema,
},
{
name: "pgroll schema is the same as the pgroll version - pgroll schema is not updated",
initialSchemaVersion: "0.13.0",
pgrollVersion: "0.13.0",
expectedSchemaVersion: "0.13.0",
},
{
name: "development versions of pgroll never cause a pgroll schema update",
initialSchemaVersion: "0.13.0",
pgrollVersion: "development",
expectedSchemaVersion: "0.13.0",
},
{
name: "development versions of the pgroll schema are never upgraded",
initialSchemaVersion: "development",
pgrollVersion: "0.13.0",
expectedSchemaVersion: "development",
},
{
name: "invalid pgroll version - pgroll schema is not updated",
initialSchemaVersion: "0.14.0",
pgrollVersion: "banana",
expectedSchemaVersion: "0.14.0",
},
{
name: "invalid pgroll schema version - pgroll schema is not updated",
initialSchemaVersion: "banana",
pgrollVersion: "0.14.0",
expectedSchemaVersion: "banana",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testutils.WithStateAtVersionAndConnectionToContainer(t, tt.initialSchemaVersion, func(st *state.State, connStr string, _ *sql.DB) {
// Create a new state instance with the specified pgroll version. This
// will upgrade the pgroll schema if necessary.
s, err := state.New(ctx, connStr, "pgroll", state.WithPgrollVersion(tt.pgrollVersion))

if tt.expectedError != nil {
require.ErrorIs(t, err, tt.expectedError)
} else {
require.NoError(t, err)
// Get the version of the pgroll schema
schemaVersion, err := s.SchemaVersion(ctx)
require.NoError(t, err)

// Ensure the expected pgroll schema version
require.Equal(t, tt.expectedSchemaVersion, schemaVersion)
}
})
})
}
}

func clearOIDS(s *schema.Schema) {
for k := range s.Tables {
c := s.Tables[k]
Expand Down
Loading
Loading