Skip to content

Commit 2aee64f

Browse files
[baseline] Add pgroll baseline command (#832)
Add a new `pgroll baseline` command to `pgroll` that allows users to establish a baseline migration from an existing database schema. ### Usage Run: ```bash $ pgroll baseline 01_initial_version migrations/ ``` This does two things: * Creates a new entry in `pgroll.migrations` with a new migration type of `baseline`. The migration JSON is empty and the `resulting_schema` field contains the schema at the time the command was run. * Creates a placeholder `migrations/01_initial_version.yaml` file for the user to populate (likely with the use of `pg_dump`). ### Features - New `pgroll baseline <version> <target directory>` command (currently hidden) - Creates a `baseline` migration in the `pgroll.migrations` table that captures the current schema state without applying changes. - Validation to prevent baseline creation during active migrations - Adds test coverage for baseline functionality ### Implementation Details - Added `Roll.CreateBaseline` method and corresponding tests - Added `State.CreateBaseline` method for writing `baseline` migrations to `pgroll.migrations` - Extended migration type definition to include 'baseline' type The `pgroll baseline` command is currently hidden. Un-hiding the command will happen when the `pull` and `migrate` commands are made baseline-aware. This is the first part of #364
1 parent cc8fa32 commit 2aee64f

File tree

9 files changed

+297
-29
lines changed

9 files changed

+297
-29
lines changed

cli-definition.json

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,25 @@
1010
"subcommands": [],
1111
"args": []
1212
},
13+
{
14+
"name": "baseline",
15+
"short": "Create a baseline migration for an existing database schema",
16+
"use": "baseline <version> <target directory>",
17+
"example": "",
18+
"flags": [
19+
{
20+
"name": "json",
21+
"shorthand": "j",
22+
"description": "output in JSON format instead of YAML",
23+
"default": "false"
24+
}
25+
],
26+
"subcommands": [],
27+
"args": [
28+
"version",
29+
"directory"
30+
]
31+
},
1332
{
1433
"name": "complete",
1534
"short": "Complete an ongoing migration with the operations present in the given file",

cmd/baseline.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package cmd
4+
5+
import (
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"os"
10+
11+
"github.com/pterm/pterm"
12+
"github.com/spf13/cobra"
13+
"github.com/xataio/pgroll/pkg/migrations"
14+
)
15+
16+
func baselineCmd() *cobra.Command {
17+
var useJSON bool
18+
19+
baselineCmd := &cobra.Command{
20+
Use: "baseline <version> <target directory>",
21+
Short: "Create a baseline migration for an existing database schema",
22+
Args: cobra.ExactArgs(2),
23+
ValidArgs: []string{"version", "directory"},
24+
Hidden: true,
25+
RunE: func(cmd *cobra.Command, args []string) error {
26+
version := args[0]
27+
targetDir := args[1]
28+
29+
ctx := cmd.Context()
30+
31+
// Create a roll instance
32+
m, err := NewRollWithInitCheck(ctx)
33+
if err != nil {
34+
return err
35+
}
36+
defer m.Close()
37+
38+
// Ensure that the target directory exists
39+
if err := ensureDirectoryExists(targetDir); err != nil {
40+
return err
41+
}
42+
43+
// Prompt for confirmation
44+
fmt.Println("Creating a baseline migration will restart the migration history.")
45+
ok, _ := pterm.DefaultInteractiveConfirm.Show()
46+
if !ok {
47+
return nil
48+
}
49+
50+
// Create a placeholder baseline migration
51+
ops := migrations.Operations{&migrations.OpRawSQL{Up: ""}}
52+
opsJSON, err := json.Marshal(ops)
53+
if err != nil {
54+
return fmt.Errorf("failed to marshal operations: %w", err)
55+
}
56+
mig := &migrations.RawMigration{
57+
Name: version,
58+
Operations: opsJSON,
59+
}
60+
61+
// Write the placeholder migration to disk
62+
filePath, err := writeMigrationToFile(mig, targetDir, "", useJSON)
63+
if err != nil {
64+
return fmt.Errorf("failed to write placeholder baseline migration: %w", err)
65+
}
66+
67+
sp, _ := pterm.DefaultSpinner.WithText(fmt.Sprintf("Creating baseline migration %q...", version)).Start()
68+
69+
// Create the baseline in the target database
70+
err = m.CreateBaseline(ctx, version)
71+
if err != nil {
72+
sp.Fail(fmt.Sprintf("Failed to create baseline: %s", err))
73+
err = errors.Join(err, os.Remove(filePath))
74+
return err
75+
}
76+
77+
sp.Success(fmt.Sprintf("Baseline created successfully. Placeholder migration %q written", filePath))
78+
return nil
79+
},
80+
}
81+
82+
baselineCmd.Flags().BoolVarP(&useJSON, "json", "j", false, "output in JSON format instead of YAML")
83+
84+
return baselineCmd
85+
}

cmd/pull.go

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,9 @@ func pullCmd() *cobra.Command {
3434
}
3535
defer m.Close()
3636

37-
// Ensure that the target directory is valid, creating it if it doesn't
38-
// exist
39-
_, err = os.Stat(targetDir)
40-
if err != nil {
41-
if os.IsNotExist(err) {
42-
err := os.MkdirAll(targetDir, 0o755)
43-
if err != nil {
44-
return fmt.Errorf("failed to create target directory: %w", err)
45-
}
46-
} else {
47-
return fmt.Errorf("failed to stat directory: %w", err)
48-
}
37+
// Ensure that the target directory exists
38+
if err := ensureDirectoryExists(targetDir); err != nil {
39+
return err
4940
}
5041

5142
// Get the list of missing migrations (those that have been applied to
@@ -61,9 +52,9 @@ func pullCmd() *cobra.Command {
6152
if withPrefixes {
6253
prefix = fmt.Sprintf("%04d", i+1) + "_"
6354
}
64-
err := writeMigrationToFile(mig, targetDir, prefix, useJSON)
55+
filePath, err := writeMigrationToFile(mig, targetDir, prefix, useJSON)
6556
if err != nil {
66-
return fmt.Errorf("failed to write migration %q: %w", mig.Name, err)
57+
return fmt.Errorf("failed to write migration %q: %w", filePath, err)
6758
}
6859
}
6960
return nil
@@ -76,13 +67,30 @@ func pullCmd() *cobra.Command {
7667
return pullCmd
7768
}
7869

70+
// ensureDirectoryExists ensures that the target directory exists, creating it if it doesn't.
71+
// Returns an error if the directory cannot be created or if there's an issue checking its existence.
72+
func ensureDirectoryExists(targetDir string) error {
73+
_, err := os.Stat(targetDir)
74+
if err != nil {
75+
if os.IsNotExist(err) {
76+
err := os.MkdirAll(targetDir, 0o755)
77+
if err != nil {
78+
return fmt.Errorf("failed to create target directory: %w", err)
79+
}
80+
} else {
81+
return fmt.Errorf("failed to stat directory: %w", err)
82+
}
83+
}
84+
return nil
85+
}
86+
7987
// WriteToFile writes the migration to a file in `targetDir`, prefixing the
8088
// filename with `prefix`. The output format defaults to YAML, but can
81-
// be changed to JSON by setting `useJSON` to true.
82-
func writeMigrationToFile(m *migrations.RawMigration, targetDir, prefix string, useJSON bool) error {
83-
err := os.MkdirAll(targetDir, 0o755)
84-
if err != nil {
85-
return err
89+
// be changed to JSON by setting `useJSON` to true. The function returns
90+
// the full path of the created file or an error if the operation fails.
91+
func writeMigrationToFile(m *migrations.RawMigration, targetDir, prefix string, useJSON bool) (string, error) {
92+
if err := ensureDirectoryExists(targetDir); err != nil {
93+
return "", err
8694
}
8795

8896
format := migrations.NewMigrationFormat(useJSON)
@@ -91,9 +99,13 @@ func writeMigrationToFile(m *migrations.RawMigration, targetDir, prefix string,
9199

92100
file, err := os.Create(filePath)
93101
if err != nil {
94-
return err
102+
return "", err
95103
}
96104
defer file.Close()
97105

98-
return migrations.NewWriter(file, format).WriteRaw(m)
106+
err = migrations.NewWriter(file, format).WriteRaw(m)
107+
if err != nil {
108+
return "", err
109+
}
110+
return filePath, nil
99111
}

cmd/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ func Prepare() *cobra.Command {
106106
rootCmd.AddCommand(pullCmd())
107107
rootCmd.AddCommand(latestCmd())
108108
rootCmd.AddCommand(convertCmd())
109+
rootCmd.AddCommand(baselineCmd())
109110

110111
return rootCmd
111112
}

pkg/roll/baseline.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package roll
4+
5+
import (
6+
"context"
7+
)
8+
9+
// CreateBaseline creates a baseline migration for an existing database schema.
10+
// This is used when starting pgroll with an existing database - it captures
11+
// the current schema state as a baseline version without applying any changes.
12+
// Future migrations will build upon this baseline version.
13+
func (m *Roll) CreateBaseline(ctx context.Context, baselineVersion string) error {
14+
// Log the operation
15+
m.logger.Info("Creating baseline version %q for schema %q", baselineVersion, m.schema)
16+
17+
// Delegate to state to create the actual baseline migration record
18+
return m.state.CreateBaseline(ctx, m.schema, baselineVersion)
19+
}

pkg/roll/baseline_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package roll_test
4+
5+
import (
6+
"context"
7+
"database/sql"
8+
"testing"
9+
10+
"github.com/stretchr/testify/require"
11+
"github.com/xataio/pgroll/internal/testutils"
12+
"github.com/xataio/pgroll/pkg/roll"
13+
"github.com/xataio/pgroll/pkg/schema"
14+
"github.com/xataio/pgroll/pkg/state"
15+
)
16+
17+
func TestBaseline(t *testing.T) {
18+
t.Parallel()
19+
20+
t.Run("baseline migration captures the current schema", func(t *testing.T) {
21+
testutils.WithMigratorAndStateAndConnectionToContainerWithOptions(t, nil, func(roll *roll.Roll, st *state.State, db *sql.DB) {
22+
ctx := context.Background()
23+
24+
// Create a table in the database to simulate an existing schema
25+
_, err := db.ExecContext(ctx, "CREATE TABLE users (id int)")
26+
require.NoError(t, err)
27+
28+
// Create a baseline migration
29+
err = roll.CreateBaseline(ctx, "01_initial_version")
30+
require.NoError(t, err)
31+
32+
// Get the captured database schema after the baseline migration was applied
33+
sc, err := st.SchemaAfterMigration(ctx, "public", "01_initial_version")
34+
require.NoError(t, err)
35+
36+
// Define the expected schema
37+
wantSchema := &schema.Schema{
38+
Name: "public",
39+
Tables: map[string]*schema.Table{
40+
"users": {
41+
Name: "users",
42+
Columns: map[string]*schema.Column{
43+
"id": {
44+
Name: "id",
45+
Type: "integer",
46+
Nullable: true,
47+
PostgresType: "base",
48+
},
49+
},
50+
PrimaryKey: []string{},
51+
Indexes: map[string]*schema.Index{},
52+
ForeignKeys: map[string]*schema.ForeignKey{},
53+
CheckConstraints: map[string]*schema.CheckConstraint{},
54+
UniqueConstraints: map[string]*schema.UniqueConstraint{},
55+
ExcludeConstraints: map[string]*schema.ExcludeConstraint{},
56+
},
57+
},
58+
}
59+
60+
// Clear OIDs from the schema to avoid comparison issues
61+
clearOIDS(sc)
62+
63+
// Assert the the schema matches the expected schema
64+
require.Equal(t, wantSchema, sc)
65+
})
66+
})
67+
}
68+
69+
func clearOIDS(s *schema.Schema) {
70+
for k := range s.Tables {
71+
c := s.Tables[k]
72+
c.OID = ""
73+
s.Tables[k] = c
74+
}
75+
}

pkg/roll/execute_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
)
2121

2222
const (
23-
schema = "public"
23+
cSchema = "public"
2424
)
2525

2626
func TestMain(m *testing.M) {
@@ -41,7 +41,7 @@ func TestSchemaIsCreatedAfterMigrationStart(t *testing.T) {
4141
//
4242
// Check that the schema exists
4343
//
44-
if !schemaExists(t, db, roll.VersionedSchemaName(schema, version)) {
44+
if !schemaExists(t, db, roll.VersionedSchemaName(cSchema, version)) {
4545
t.Errorf("Expected schema %q to exist", version)
4646
}
4747
})
@@ -61,7 +61,7 @@ func TestDisabledSchemaManagement(t *testing.T) {
6161
//
6262
// Check that the schema doesn't get created
6363
//
64-
if schemaExists(t, db, roll.VersionedSchemaName(schema, version)) {
64+
if schemaExists(t, db, roll.VersionedSchemaName(cSchema, version)) {
6565
t.Errorf("Expected schema %q to not exist", version)
6666
}
6767

@@ -78,7 +78,7 @@ func TestDisabledSchemaManagement(t *testing.T) {
7878
t.Fatalf("Failed to complete migration: %v", err)
7979
}
8080

81-
if schemaExists(t, db, roll.VersionedSchemaName(schema, version)) {
81+
if schemaExists(t, db, roll.VersionedSchemaName(cSchema, version)) {
8282
t.Errorf("Expected schema %q to not exist", version)
8383
}
8484
})
@@ -111,7 +111,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) {
111111
//
112112
// Check that the schema for the first version has been dropped
113113
//
114-
if schemaExists(t, db, roll.VersionedSchemaName(schema, firstVersion)) {
114+
if schemaExists(t, db, roll.VersionedSchemaName(cSchema, firstVersion)) {
115115
t.Errorf("Expected schema %q to not exist", firstVersion)
116116
}
117117
})
@@ -150,7 +150,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) {
150150
//
151151
// Check that the schema for the first version has been dropped
152152
//
153-
if schemaExists(t, db, roll.VersionedSchemaName(schema, firstVersion)) {
153+
if schemaExists(t, db, roll.VersionedSchemaName(cSchema, firstVersion)) {
154154
t.Errorf("Expected schema %q to not exist", firstVersion)
155155
}
156156
})
@@ -198,7 +198,7 @@ func TestNoVersionSchemaForRawSQLMigrationsOptionIsRespected(t *testing.T) {
198198
require.NoError(t, err)
199199
require.NotNil(t, prevVersion)
200200
assert.Equal(t, "02_create_table", *prevVersion)
201-
assert.False(t, schemaExists(t, db, roll.VersionedSchemaName(schema, "02_create_table")))
201+
assert.False(t, schemaExists(t, db, roll.VersionedSchemaName(cSchema, "02_create_table")))
202202

203203
// Complete the third migration
204204
err = mig.Complete(ctx)
@@ -229,7 +229,7 @@ func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) {
229229
//
230230
// Check that the schema has been dropped
231231
//
232-
if schemaExists(t, db, roll.VersionedSchemaName(schema, version)) {
232+
if schemaExists(t, db, roll.VersionedSchemaName(cSchema, version)) {
233233
t.Errorf("Expected schema %q to not exist", version)
234234
}
235235
})

pkg/state/init.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ CREATE UNIQUE INDEX IF NOT EXISTS history_is_linear ON placeholder.migrations (s
3131
ALTER TABLE placeholder.migrations
3232
ADD COLUMN IF NOT EXISTS migration_type varchar(32) DEFAULT 'pgroll' CONSTRAINT migration_type_check CHECK (migration_type IN ('pgroll', 'inferred'));
3333

34+
-- Update the `migration_type` column to also allow a `baseline` migration type.
35+
ALTER TABLE placeholder.migrations
36+
DROP CONSTRAINT migration_type_check;
37+
38+
ALTER TABLE placeholder.migrations
39+
ADD CONSTRAINT migration_type_check CHECK (migration_type IN ('pgroll', 'inferred', 'baseline'));
40+
3441
-- Change timestamp columns to use timestamptz
3542
ALTER TABLE placeholder.migrations
3643
ALTER COLUMN created_at SET DATA TYPE timestamptz USING created_at AT TIME ZONE 'UTC',

0 commit comments

Comments
 (0)