Skip to content
39 changes: 32 additions & 7 deletions pkg/migrations/duplicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

// Duplicator duplicates a column in a table, including all constraints and
// comments.
// Column must be set with WithColumn before calling Duplicate.
type Duplicator struct {
conn db.DB
table *schema.Table
Expand All @@ -24,6 +25,7 @@ type Duplicator struct {
withoutNotNull bool
withType string
withoutConstraint string
duplicatedUnique map[string]schema.UniqueConstraint
}

const (
Expand All @@ -32,13 +34,11 @@ const (
)

// NewColumnDuplicator creates a new Duplicator for a column.
func NewColumnDuplicator(conn db.DB, table *schema.Table, column *schema.Column) *Duplicator {
func NewColumnDuplicator(conn db.DB, table *schema.Table) *Duplicator {
return &Duplicator{
conn: conn,
table: table,
column: column,
asName: TemporaryName(column.Name),
withType: column.Type,
conn: conn,
table: table,
duplicatedUnique: make(map[string]schema.UniqueConstraint),
}
}

Expand All @@ -60,9 +60,20 @@ func (d *Duplicator) WithoutNotNull() *Duplicator {
return d
}

func (d *Duplicator) WithColumn(c *schema.Column) *Duplicator {
d.column = c
d.asName = TemporaryName(c.Name)
d.withType = c.Type
return d
}

// Duplicate duplicates a column in the table, including all constraints and
// comments.
func (d *Duplicator) Duplicate(ctx context.Context) error {
if d.column == nil {
return errors.New("column not set")
}

const (
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
Expand Down Expand Up @@ -171,16 +182,30 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
}

if slices.Contains(uc.Columns, d.column.Name) {
cols := uc.Columns
// Drop the existing unique index if it is a duplicated unique constraint
if existingConstraint, ok := d.duplicatedUnique[DuplicationName(uc.Name)]; ok {
dropIndex := fmt.Sprintf("DROP INDEX CONCURRENTLY %s", pq.QuoteIdentifier(DuplicationName(uc.Name)))
_, err := d.conn.ExecContext(ctx, dropIndex)
if err != nil {
return err
}
cols = existingConstraint.Columns
}
sql = fmt.Sprintf(cCreateUniqueIndexSQL,
pq.QuoteIdentifier(DuplicationName(uc.Name)),
pq.QuoteIdentifier(d.table.Name),
strings.Join(quoteColumnNames(copyAndReplace(uc.Columns, d.column.Name, d.asName)), ", "),
strings.Join(quoteColumnNames(copyAndReplace(cols, d.column.Name, d.asName)), ", "),
)

_, err = d.conn.ExecContext(ctx, sql)
if err != nil {
return err
}
d.duplicatedUnique[DuplicationName(uc.Name)] = schema.UniqueConstraint{
Name: DuplicationName(uc.Name),
Columns: copyAndReplace(cols, d.column.Name, d.asName),
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/migrations/op_alter_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ func (o *OpAlterColumn) subOperations() []Operation {

// duplicatorForOperations returns a Duplicator for the given operations
func duplicatorForOperations(ops []Operation, conn db.DB, table *schema.Table, column *schema.Column) *Duplicator {
d := NewColumnDuplicator(conn, table, column)
d := NewColumnDuplicator(conn, table).WithColumn(column)

for _, op := range ops {
switch op := (op).(type) {
Expand Down
26 changes: 12 additions & 14 deletions pkg/migrations/op_create_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ var _ Operation = (*OpCreateConstraint)(nil)

func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
var err error
var table *schema.Table
table := s.GetTable(o.Table)
d := NewColumnDuplicator(conn, table)
for _, col := range o.Columns {
if table, err = o.duplicateColumnBeforeStart(ctx, conn, latestSchema, tr, col, s); err != nil {
if err = o.duplicateColumnBeforeStart(ctx, conn, latestSchema, tr, table, col, d, s); err != nil {
return nil, err
}
}
Expand All @@ -32,18 +33,15 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema
return table, nil
}

func (o *OpCreateConstraint) duplicateColumnBeforeStart(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, colName string, s *schema.Schema) (*schema.Table, error) {
table := s.GetTable(o.Table)
column := table.GetColumn(colName)

d := NewColumnDuplicator(conn, table, column)
func (o *OpCreateConstraint) duplicateColumnBeforeStart(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, table *schema.Table, colName string, d *Duplicator, s *schema.Schema) error {
d.WithColumn(table.GetColumn(colName))
if err := d.Duplicate(ctx); err != nil {
return nil, fmt.Errorf("failed to duplicate column for new constraint: %w", err)
return fmt.Errorf("failed to duplicate column for new constraint: %w", err)
}

upSQL, ok := o.Up[colName]
if !ok {
return nil, fmt.Errorf("up migration is missing for column %s", colName)
return fmt.Errorf("up migration is missing for column %s", colName)
}
physicalColumnName := TemporaryName(colName)
err := createTrigger(ctx, conn, tr, triggerConfig{
Expand All @@ -57,16 +55,16 @@ func (o *OpCreateConstraint) duplicateColumnBeforeStart(ctx context.Context, con
SQL: upSQL,
})
if err != nil {
return nil, fmt.Errorf("failed to create up trigger: %w", err)
return fmt.Errorf("failed to create up trigger: %w", err)
}

table.AddColumn(colName, schema.Column{
d.table.AddColumn(colName, schema.Column{
Name: physicalColumnName,
})

downSQL, ok := o.Down[colName]
if !ok {
return nil, fmt.Errorf("down migration is missing for column %s", colName)
return fmt.Errorf("down migration is missing for column %s", colName)
}
err = createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, physicalColumnName),
Expand All @@ -79,9 +77,9 @@ func (o *OpCreateConstraint) duplicateColumnBeforeStart(ctx context.Context, con
SQL: downSQL,
})
if err != nil {
return nil, fmt.Errorf("failed to create down trigger: %w", err)
return fmt.Errorf("failed to create down trigger: %w", err)
}
return table, nil
return nil
}

func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
Expand Down
2 changes: 1 addition & 1 deletion pkg/migrations/op_drop_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, latestSchema s
column := table.GetColumn(table.GetConstraintColumns(o.Name)[0])

// Create a copy of the column on the underlying table.
d := NewColumnDuplicator(conn, table, column).WithoutConstraint(o.Name)
d := NewColumnDuplicator(conn, table).WithColumn(column).WithoutConstraint(o.Name)
if err := d.Duplicate(ctx); err != nil {
return nil, fmt.Errorf("failed to duplicate column: %w", err)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/migrations/rename.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ func RenameDuplicatedColumn(ctx context.Context, conn db.DB, table *schema.Table
if err != nil {
return fmt.Errorf("failed to create unique constraint from index %q: %w", ui.Name, err)
}
delete(table.Indexes, ui.Name)
}
}

Expand Down