Skip to content

Commit 470b5ef

Browse files
kvchandrew-farries
andauthored
Return list of DBAction from Start (#944)
This PR returns adds a new return value to `Start` to return the list of `DBActions`. However, returning 3 values is too many, so in a follow-up PR I am combining the list of `DBActions` and `backfill.Task` into the same data structure. The name is pending, my ideas so far `StartStep` or `StartOperation`. If you have a better name, I am open to using them. New data structure in a follow-up PR: ```golang type StartOperation struct { Actions []DBAction Task *backfill.Task } ``` --------- Co-authored-by: Andrew Farries <andyrb@gmail.com>
1 parent 6dfeb32 commit 470b5ef

25 files changed

+217
-213
lines changed

pkg/migrations/migrations.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ import (
1616

1717
// Operation is an operation that can be applied to a schema
1818
type Operation interface {
19-
// Start will apply the required changes to enable supporting the new schema
19+
// Start will return the list of required changes to enable supporting the new schema
2020
// version in the database (through a view)
2121
// update the given views to expose the new schema version
2222
// Returns the table that requires backfilling, if any.
23-
Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) (*backfill.Task, error)
23+
Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) ([]DBAction, *backfill.Task, error)
2424

2525
// Complete will update the database schema to match the current version
2626
// after calling Start.
@@ -118,7 +118,7 @@ func (m *Migration) UpdateVirtualSchema(ctx context.Context, s *schema.Schema) e
118118
// Run `Start` on each operation using the fake DB. Updates will be made to
119119
// the in-memory schema `s` without touching the physical database.
120120
for _, op := range m.Operations {
121-
if _, err := op.Start(ctx, NewNoopLogger(), db, s); err != nil {
121+
if _, _, err := op.Start(ctx, NewNoopLogger(), db, s); err != nil {
122122
return err
123123
}
124124
}

pkg/migrations/op_add_column.go

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ var (
1919
_ Createable = (*OpAddColumn)(nil)
2020
)
2121

22-
func (o *OpAddColumn) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) (*backfill.Task, error) {
22+
func (o *OpAddColumn) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) ([]DBAction, *backfill.Task, error) {
2323
l.LogOperationStart(o)
2424

2525
table := s.GetTable(o.Table)
2626
if table == nil {
27-
return nil, TableDoesNotExistError{Name: o.Table}
27+
return nil, nil, TableDoesNotExistError{Name: o.Table}
2828
}
2929

3030
// If the column has a DEFAULT, check if it can be added using the fast path
@@ -33,20 +33,19 @@ func (o *OpAddColumn) Start(ctx context.Context, l Logger, conn db.DB, s *schema
3333
if o.Column.HasDefault() {
3434
v, err := defaults.UsesFastPath(ctx, conn, table.Name, o.Column.Type, *o.Column.Default)
3535
if err != nil {
36-
return nil, fmt.Errorf("failed to check for fast path default optimization: %w", err)
36+
return nil, nil, fmt.Errorf("failed to check for fast path default optimization: %w", err)
3737
}
3838
fastPathDefault = v
3939
}
4040

41-
if err := addColumn(ctx, conn, *o, table, fastPathDefault); err != nil {
42-
return nil, fmt.Errorf("failed to start add column operation: %w", err)
41+
action, err := addColumn(conn, *o, table, fastPathDefault)
42+
if err != nil {
43+
return nil, nil, err
4344
}
45+
dbActions := []DBAction{action}
4446

4547
if o.Column.Comment != nil {
46-
err := NewCommentColumnAction(conn, table.Name, TemporaryName(o.Column.Name), o.Column.Comment).Execute(ctx)
47-
if err != nil {
48-
return nil, fmt.Errorf("failed to add comment to column: %w", err)
49-
}
48+
dbActions = append(dbActions, NewCommentColumnAction(conn, table.Name, TemporaryName(o.Column.Name), o.Column.Comment))
5049
}
5150

5251
// If the column is `NOT NULL` and there is no default value (either because
@@ -56,47 +55,48 @@ func (o *OpAddColumn) Start(ctx context.Context, l Logger, conn db.DB, s *schema
5655
skipInherit := false
5756
skipValidate := true
5857
if !o.Column.IsNullable() && (o.Column.Default == nil || !fastPathDefault) {
59-
if err := NewCreateCheckConstraintAction(
60-
conn,
61-
table.Name,
62-
NotNullConstraintName(o.Column.Name),
63-
fmt.Sprintf("%s IS NOT NULL", o.Column.Name),
64-
[]string{o.Column.Name},
65-
skipInherit,
66-
skipValidate,
67-
).Execute(ctx); err != nil {
68-
return nil, fmt.Errorf("failed to add not null constraint: %w", err)
69-
}
58+
dbActions = append(dbActions,
59+
NewCreateCheckConstraintAction(
60+
conn,
61+
table.Name,
62+
NotNullConstraintName(o.Column.Name),
63+
fmt.Sprintf("%s IS NOT NULL", o.Column.Name),
64+
[]string{o.Column.Name},
65+
skipInherit,
66+
skipValidate,
67+
))
7068
}
7169

7270
if o.Column.Check != nil {
73-
if err := NewCreateCheckConstraintAction(
74-
conn,
75-
table.Name,
76-
o.Column.Check.Name,
77-
o.Column.Check.Constraint,
78-
[]string{o.Column.Name},
79-
skipInherit,
80-
skipValidate,
81-
).Execute(ctx); err != nil {
82-
return nil, fmt.Errorf("failed to add check constraint: %w", err)
83-
}
71+
dbActions = append(dbActions,
72+
NewCreateCheckConstraintAction(
73+
conn,
74+
table.Name,
75+
o.Column.Check.Name,
76+
o.Column.Check.Constraint,
77+
[]string{o.Column.Name},
78+
skipInherit,
79+
skipValidate,
80+
))
8481
}
8582

8683
if o.Column.Unique {
87-
createIndex := NewCreateUniqueIndexConcurrentlyAction(conn, s.Name, UniqueIndexName(o.Column.Name), table.Name, TemporaryName(o.Column.Name))
88-
err := createIndex.Execute(ctx)
89-
if err != nil {
90-
return nil, fmt.Errorf("failed to add unique index: %w", err)
91-
}
84+
dbActions = append(dbActions,
85+
NewCreateUniqueIndexConcurrentlyAction(
86+
conn,
87+
s.Name,
88+
UniqueIndexName(o.Column.Name),
89+
table.Name,
90+
TemporaryName(o.Column.Name),
91+
))
9292
}
9393

9494
// If the column has a DEFAULT that cannot be set using the fast path
9595
// optimization, the `up` SQL expression must be used to set the DEFAULT
9696
// value for the column.
9797
if o.Column.HasDefault() && !fastPathDefault {
9898
if o.Up != *o.Column.Default {
99-
return nil, UpSQLMustBeColumnDefaultError{Column: o.Column.Name}
99+
return nil, nil, UpSQLMustBeColumnDefaultError{Column: o.Column.Name}
100100
}
101101
}
102102

@@ -118,7 +118,7 @@ func (o *OpAddColumn) Start(ctx context.Context, l Logger, conn db.DB, s *schema
118118
tmpColumn.Name = TemporaryName(o.Column.Name)
119119
table.AddColumn(o.Column.Name, tmpColumn)
120120

121-
return task, nil
121+
return dbActions, task, nil
122122
}
123123

124124
func toSchemaColumn(c Column) *schema.Column {
@@ -253,7 +253,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
253253
return nil
254254
}
255255

256-
func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, fastPathDefault bool) error {
256+
func addColumn(conn db.DB, o OpAddColumn, t *schema.Table, fastPathDefault bool) (DBAction, error) {
257257
// don't add non-nullable columns with no default directly
258258
// they are handled by:
259259
// - adding the column as nullable
@@ -266,7 +266,7 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table,
266266
}
267267

268268
if o.Column.Generated != nil {
269-
return fmt.Errorf("adding generated columns to existing tables is not supported")
269+
return nil, fmt.Errorf("adding generated columns to existing tables is not supported")
270270
}
271271

272272
// Don't add a column with a CHECK constraint directly.
@@ -299,7 +299,7 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table,
299299
o.Column.Name = TemporaryName(o.Column.Name)
300300

301301
withPK := true
302-
return NewAddColumnAction(conn, t.Name, o.Column, withPK).Execute(ctx)
302+
return NewAddColumnAction(conn, t.Name, o.Column, withPK), nil
303303
}
304304

305305
// upgradeNotNullConstraintToNotNullAttribute validates and upgrades a NOT NULL

pkg/migrations/op_alter_column.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,24 @@ var (
1818
_ Createable = (*OpAlterColumn)(nil)
1919
)
2020

21-
func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) (*backfill.Task, error) {
21+
func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) ([]DBAction, *backfill.Task, error) {
2222
l.LogOperationStart(o)
2323

2424
table := s.GetTable(o.Table)
2525
if table == nil {
26-
return nil, TableDoesNotExistError{Name: o.Table}
26+
return nil, nil, TableDoesNotExistError{Name: o.Table}
2727
}
2828
column := table.GetColumn(o.Column)
2929
if column == nil {
30-
return nil, ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
30+
return nil, nil, ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
3131
}
3232
ops := o.subOperations()
3333

3434
// Duplicate the column on the underlying table.
3535
d := duplicatorForOperations(ops, conn, table, column).
3636
WithName(column.Name, TemporaryName(o.Column))
3737
if err := d.Execute(ctx); err != nil {
38-
return nil, fmt.Errorf("failed to duplicate column: %w", err)
38+
return nil, nil, fmt.Errorf("failed to duplicate column: %w", err)
3939
}
4040

4141
// Copy the columns from table columns, so we can use it later
@@ -79,16 +79,19 @@ func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, s *sche
7979
},
8080
)
8181
task := backfill.NewTask(table, triggers...)
82+
83+
var dbActions []DBAction
8284
// perform any operation specific start steps
8385
for _, op := range ops {
84-
bf, err := op.Start(ctx, l, conn, s)
86+
actions, bf, err := op.Start(ctx, l, conn, s)
8587
if err != nil {
86-
return nil, err
88+
return nil, nil, err
8789
}
8890
task.AddTriggers(bf)
91+
dbActions = append(dbActions, actions...)
8992
}
9093

91-
return task, nil
94+
return dbActions, task, nil
9295
}
9396

9497
func (o *OpAlterColumn) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {

pkg/migrations/op_change_type.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ type OpChangeType struct {
2020

2121
var _ Operation = (*OpChangeType)(nil)
2222

23-
func (o *OpChangeType) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) (*backfill.Task, error) {
23+
func (o *OpChangeType) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) ([]DBAction, *backfill.Task, error) {
2424
l.LogOperationStart(o)
2525

2626
table := s.GetTable(o.Table)
2727
if table == nil {
28-
return nil, TableDoesNotExistError{Name: o.Table}
28+
return nil, nil, TableDoesNotExistError{Name: o.Table}
2929
}
3030

31-
return backfill.NewTask(table), nil
31+
return nil, backfill.NewTask(table), nil
3232
}
3333

3434
func (o *OpChangeType) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {

pkg/migrations/op_create_constraint.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package migrations
44

55
import (
66
"context"
7-
"fmt"
87

98
"github.com/xataio/pgroll/pkg/backfill"
109
"github.com/xataio/pgroll/pkg/db"
@@ -16,19 +15,19 @@ var (
1615
_ Createable = (*OpCreateConstraint)(nil)
1716
)
1817

19-
func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) (*backfill.Task, error) {
18+
func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) ([]DBAction, *backfill.Task, error) {
2019
l.LogOperationStart(o)
2120

2221
table := s.GetTable(o.Table)
2322
if table == nil {
24-
return nil, TableDoesNotExistError{Name: o.Table}
23+
return nil, nil, TableDoesNotExistError{Name: o.Table}
2524
}
2625

2726
columns := make([]*schema.Column, len(o.Columns))
2827
for i, colName := range o.Columns {
2928
columns[i] = table.GetColumn(colName)
3029
if columns[i] == nil {
31-
return nil, ColumnDoesNotExistError{Table: o.Table, Name: colName}
30+
return nil, nil, ColumnDoesNotExistError{Table: o.Table, Name: colName}
3231
}
3332
}
3433

@@ -37,9 +36,7 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, s
3736
for _, colName := range o.Columns {
3837
d = d.WithName(table.GetColumn(colName).Name, TemporaryName(colName))
3938
}
40-
if err := d.Execute(ctx); err != nil {
41-
return nil, fmt.Errorf("failed to duplicate columns for new constraint: %w", err)
42-
}
39+
dbActions := []DBAction{d}
4340

4441
// Copy the columns from table columns, so we can use it later
4542
// in the down trigger with the physical name
@@ -89,14 +86,25 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, s
8986

9087
switch o.Type {
9188
case OpCreateConstraintTypeUnique, OpCreateConstraintTypePrimaryKey:
92-
return task, NewCreateUniqueIndexConcurrentlyAction(conn, s.Name, o.Name, table.Name, temporaryNames(o.Columns)...).Execute(ctx)
89+
dbActions = append(dbActions,
90+
NewCreateUniqueIndexConcurrentlyAction(conn, s.Name, o.Name, table.Name, temporaryNames(o.Columns)...),
91+
)
92+
return dbActions, task, nil
93+
9394
case OpCreateConstraintTypeCheck:
94-
return task, NewCreateCheckConstraintAction(conn, table.Name, o.Name, *o.Check, o.Columns, o.NoInherit, true).Execute(ctx)
95+
dbActions = append(dbActions,
96+
NewCreateCheckConstraintAction(conn, table.Name, o.Name, *o.Check, o.Columns, o.NoInherit, true),
97+
)
98+
return dbActions, task, nil
99+
95100
case OpCreateConstraintTypeForeignKey:
96-
return task, NewCreateFKConstraintAction(conn, table.Name, o.Name, temporaryNames(o.Columns), o.References, false, false, true).Execute(ctx)
101+
dbActions = append(dbActions,
102+
NewCreateFKConstraintAction(conn, table.Name, o.Name, temporaryNames(o.Columns), o.References, false, false, true),
103+
)
104+
return dbActions, task, nil
97105
}
98106

99-
return task, nil
107+
return dbActions, task, nil
100108
}
101109

102110
func (o *OpCreateConstraint) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {

pkg/migrations/op_create_index.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ var (
1818
_ Createable = (*OpCreateIndex)(nil)
1919
)
2020

21-
func (o *OpCreateIndex) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) (*backfill.Task, error) {
21+
func (o *OpCreateIndex) Start(ctx context.Context, l Logger, conn db.DB, s *schema.Schema) ([]DBAction, *backfill.Task, error) {
2222
l.LogOperationStart(o)
2323

2424
table := s.GetTable(o.Table)
2525
if table == nil {
26-
return nil, TableDoesNotExistError{Name: o.Table}
26+
return nil, nil, TableDoesNotExistError{Name: o.Table}
2727
}
2828

2929
cols := make(map[string]IndexField, len(o.Columns))
@@ -32,18 +32,20 @@ func (o *OpCreateIndex) Start(ctx context.Context, l Logger, conn db.DB, s *sche
3232
cols[physicalName[0]] = settings
3333
}
3434

35-
err := NewCreateIndexConcurrentlyAction(
36-
conn,
37-
table.Name,
38-
o.Name,
39-
string(o.Method),
40-
o.Unique,
41-
cols,
42-
o.StorageParameters,
43-
o.Predicate,
44-
).Execute(ctx)
45-
46-
return nil, err
35+
dbActions := []DBAction{
36+
NewCreateIndexConcurrentlyAction(
37+
conn,
38+
table.Name,
39+
o.Name,
40+
string(o.Method),
41+
o.Unique,
42+
cols,
43+
o.StorageParameters,
44+
o.Predicate,
45+
),
46+
}
47+
48+
return dbActions, nil, nil
4749
}
4850

4951
func (o *OpCreateIndex) Complete(l Logger, conn db.DB, s *schema.Schema) ([]DBAction, error) {

0 commit comments

Comments
 (0)