Skip to content

Commit 3d6ac6d

Browse files
Add WithNoVersionSchemaForRawSQL option (#365)
Add a new `WithNoVersionSchemaForRawSQL` option to control whether or not version schema should be created for raw SQL migrations. With the option set, a raw SQL migration: * Has no version schema created on migration start. * Leaves the previous version schema in place on migration completion.
1 parent 7cef8b1 commit 3d6ac6d

File tree

7 files changed

+146
-22
lines changed

7 files changed

+146
-22
lines changed

pkg/migrations/migrations.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,13 @@ func (m *Migration) Validate(ctx context.Context, s *schema.Schema) error {
8585

8686
return nil
8787
}
88+
89+
// ContainsRawSQLOperation returns true if the migration contains a raw SQL operation
90+
func (m *Migration) ContainsRawSQLOperation() bool {
91+
for _, op := range m.Operations {
92+
if _, ok := op.(*OpRawSQL); ok {
93+
return true
94+
}
95+
}
96+
return false
97+
}

pkg/roll/execute.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func (m *Roll) StartDDLOperations(ctx context.Context, migration *migrations.Mig
9090
}
9191
}
9292

93-
if m.disableVersionSchemas {
93+
if m.disableVersionSchemas || migration.ContainsRawSQLOperation() && m.noVersionSchemaForRawSQL {
9494
// skip creating version schemas
9595
return tablesToBackfill, nil
9696
}
@@ -131,7 +131,7 @@ func (m *Roll) Complete(ctx context.Context) error {
131131
}
132132

133133
// Drop the old schema
134-
if !m.disableVersionSchemas {
134+
if !m.disableVersionSchemas && (!migration.ContainsRawSQLOperation() || !m.noVersionSchemaForRawSQL) {
135135
prevVersion, err := m.state.PreviousVersion(ctx, m.schema)
136136
if err != nil {
137137
return fmt.Errorf("unable to get name of previous version: %w", err)
@@ -172,7 +172,9 @@ func (m *Roll) Complete(ctx context.Context) error {
172172
}
173173

174174
if _, ok := op.(migrations.RequiresSchemaRefreshOperation); ok {
175-
refreshViews = true
175+
if _, ok := op.(*migrations.OpRawSQL); !ok || !m.noVersionSchemaForRawSQL {
176+
refreshViews = true
177+
}
176178
}
177179
}
178180

pkg/roll/execute_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,54 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) {
155155
})
156156
}
157157

158+
func TestNoVersionSchemaForRawSQLMigrationsOptionIsRespected(t *testing.T) {
159+
t.Parallel()
160+
161+
opts := []roll.Option{roll.WithNoVersionSchemaForRawSQL()}
162+
163+
testutils.WithMigratorAndStateAndConnectionToContainerWithOptions(t, opts, func(mig *roll.Roll, st *state.State, db *sql.DB) {
164+
ctx := context.Background()
165+
166+
// Apply a create table migration
167+
err := mig.Start(ctx, &migrations.Migration{Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}})
168+
require.NoError(t, err)
169+
err = mig.Complete(ctx)
170+
require.NoError(t, err)
171+
172+
// Apply a raw SQL migration - no version schema should be created for this version
173+
err = mig.Start(ctx, &migrations.Migration{
174+
Name: "02_create_table",
175+
Operations: migrations.Operations{&migrations.OpRawSQL{
176+
Up: "CREATE TABLE table2(a int)",
177+
}},
178+
})
179+
require.NoError(t, err)
180+
err = mig.Complete(ctx)
181+
require.NoError(t, err)
182+
183+
// Start a third create table migration
184+
err = mig.Start(ctx, &migrations.Migration{Name: "03_create_table", Operations: migrations.Operations{createTableOp("table3")}})
185+
require.NoError(t, err)
186+
187+
// The previous version is migration 01 because there is no version schema
188+
// for migration 02 due to the `WithNoVersionSchemaForRawSQL` option
189+
prevVersion, err := st.PreviousVersion(ctx, "public")
190+
require.NoError(t, err)
191+
require.NotNil(t, prevVersion)
192+
assert.Equal(t, "01_create_table", *prevVersion)
193+
194+
// Complete the third migration
195+
err = mig.Complete(ctx)
196+
require.NoError(t, err)
197+
198+
// The latest version version is migration 03
199+
latestVersion, err := st.LatestVersion(ctx, "public")
200+
require.NoError(t, err)
201+
require.NotNil(t, latestVersion)
202+
assert.Equal(t, "03_create_table", *latestVersion)
203+
})
204+
}
205+
158206
func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) {
159207
t.Parallel()
160208

pkg/roll/options.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ type options struct {
1818

1919
// disable pgroll version schemas creation and deletion
2020
disableVersionSchemas bool
21-
migrationHooks MigrationHooks
21+
22+
// disable creation of version schema for raw SQL migrations
23+
noVersionSchemaForRawSQL bool
24+
25+
migrationHooks MigrationHooks
2226
}
2327

2428
// MigrationHooks defines hooks that can be set to be called at various points
@@ -58,6 +62,12 @@ func WithDisableViewsManagement() Option {
5862
}
5963
}
6064

65+
func WithNoVersionSchemaForRawSQL() Option {
66+
return func(o *options) {
67+
o.noVersionSchemaForRawSQL = true
68+
}
69+
}
70+
6171
// WithMigrationHooks sets the migration hooks for the Roll instance
6272
// Migration hooks are called at various points during the migration process
6373
// to allow for custom behavior to be injected

pkg/roll/roll.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ type Roll struct {
2727
// disable pgroll version schemas creation and deletion
2828
disableVersionSchemas bool
2929

30+
// disable creation of version schema for raw SQL migrations
31+
noVersionSchemaForRawSQL bool
32+
3033
migrationHooks MigrationHooks
3134
state *state.State
3235
pgVersion PGVersion
@@ -58,13 +61,14 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...
5861
}
5962

6063
return &Roll{
61-
pgConn: &db.RDB{DB: conn},
62-
schema: schema,
63-
state: state,
64-
pgVersion: PGVersion(pgMajorVersion),
65-
disableVersionSchemas: rollOpts.disableVersionSchemas,
66-
migrationHooks: rollOpts.migrationHooks,
67-
sqlTransformer: sqlTransformer,
64+
pgConn: &db.RDB{DB: conn},
65+
schema: schema,
66+
state: state,
67+
pgVersion: PGVersion(pgMajorVersion),
68+
disableVersionSchemas: rollOpts.disableVersionSchemas,
69+
noVersionSchemaForRawSQL: rollOpts.noVersionSchemaForRawSQL,
70+
migrationHooks: rollOpts.migrationHooks,
71+
sqlTransformer: sqlTransformer,
6872
}, nil
6973
}
7074

pkg/state/state.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,22 +70,27 @@ LANGUAGE SQL
7070
STABLE;
7171
7272
-- Get the name of the previous version of the schema, or NULL if there is none.
73+
-- This ignores previous versions for which no version schema exists, such as
74+
-- versions corresponding to inferred migrations.
7375
CREATE OR REPLACE FUNCTION %[1]s.previous_version(schemaname NAME) RETURNS text
7476
AS $$
75-
WITH RECURSIVE find_ancestor AS (
76-
SELECT schema, name, parent, migration_type FROM %[1]s.migrations
77-
WHERE name = (SELECT %[1]s.latest_version(schemaname)) AND schema = schemaname
77+
WITH RECURSIVE ancestors AS (
78+
SELECT name, parent, migration_type, 0 AS depth FROM %[1]s.migrations
79+
WHERE name = %[1]s.latest_version(schemaname)
7880
79-
UNION ALL
81+
UNION ALL
8082
81-
SELECT m.schema, m.name, m.parent, m.migration_type FROM %[1]s.migrations m
82-
INNER JOIN find_ancestor fa ON fa.parent = m.name AND fa.schema = m.schema
83-
WHERE m.migration_type = 'inferred'
83+
SELECT m.name, m.parent, m.migration_type, a.depth + 1
84+
FROM %[1]s.migrations m
85+
JOIN ancestors a ON m.name = a.parent
8486
)
85-
SELECT a.parent
86-
FROM find_ancestor AS a
87-
JOIN %[1]s.migrations AS b ON a.parent = b.name AND a.schema = b.schema
88-
WHERE b.migration_type = 'pgroll';
87+
SELECT a.name FROM ancestors a
88+
JOIN information_schema.schemata s
89+
ON s.schema_name = schemaname || '_' || a.name
90+
WHERE migration_type = 'pgroll'
91+
AND a.depth > 0
92+
ORDER by a.depth ASC
93+
LIMIT 1;
8994
$$
9095
LANGUAGE SQL
9196
STABLE;

pkg/testutils/util.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,51 @@ func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schem
174174
fn(mig, db)
175175
}
176176

177+
func WithMigratorAndStateAndConnectionToContainerWithOptions(t *testing.T, opts []roll.Option, fn func(*roll.Roll, *state.State, *sql.DB)) {
178+
t.Helper()
179+
ctx := context.Background()
180+
181+
db, connStr, dbName := setupTestDatabase(t)
182+
183+
st, err := state.New(ctx, connStr, "pgroll")
184+
if err != nil {
185+
t.Fatal(err)
186+
}
187+
188+
err = st.Init(ctx)
189+
if err != nil {
190+
t.Fatal(err)
191+
}
192+
193+
mig, err := roll.New(ctx, connStr, "public", st, opts...)
194+
if err != nil {
195+
t.Fatal(err)
196+
}
197+
198+
t.Cleanup(func() {
199+
if err := mig.Close(); err != nil {
200+
t.Fatalf("Failed to close migrator connection: %v", err)
201+
}
202+
})
203+
204+
_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", "public"))
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
209+
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON SCHEMA %s TO pgroll", "public"))
210+
if err != nil {
211+
t.Fatal(err)
212+
}
213+
214+
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO pgroll", dbName))
215+
if err != nil {
216+
t.Fatal(err)
217+
}
218+
219+
fn(mig, st, db)
220+
}
221+
177222
func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) {
178223
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, schema, []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
179224
}

0 commit comments

Comments
 (0)