Skip to content

Commit 9fe8d9b

Browse files
Duplicate foreign key constraints
Ensure that foreign key constraints, including multi-column constraints, are duplicated correctly.
1 parent d12406a commit 9fe8d9b

File tree

2 files changed

+85
-32
lines changed

2 files changed

+85
-32
lines changed

pkg/migrations/duplicate.go

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ const (
4343
cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
4444
cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
4545
cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
46+
cAlterTableAddForeignKeySQL = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
4647
)
4748

4849
// NewColumnDuplicator creates a new Duplicator for a column.
@@ -91,7 +92,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
9192
colNames = append(colNames, name)
9293

9394
// Duplicate the column with the new type
94-
// and check and fk constraints
9595
if sql := d.stmtBuilder.duplicateColumn(c.column, c.asName, c.withoutNotNull, c.withType, d.withoutConstraint); sql != "" {
9696
_, err := d.conn.ExecContext(ctx, sql)
9797
if err != nil {
@@ -108,6 +108,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
108108
}
109109
}
110110

111+
// Duplicate the column's comment
111112
if sql := d.stmtBuilder.duplicateComment(c.column, c.asName); sql != "" {
112113
_, err := d.conn.ExecContext(ctx, sql)
113114
if err != nil {
@@ -120,7 +121,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
120121
// if the check constraint is not valid for the new column type, in which case
121122
// the error is ignored.
122123
for _, sql := range d.stmtBuilder.duplicateCheckConstraints(d.withoutConstraint, colNames...) {
123-
// Update the check constraint expression to use the new column names if any of the columns are duplicated
124124
_, err := d.conn.ExecContext(ctx, sql)
125125
err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode)
126126
if err != nil {
@@ -132,12 +132,21 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
132132
// The constraint is duplicated by adding a unique index on the column concurrently.
133133
// The index is converted into a unique constraint on migration completion.
134134
for _, sql := range d.stmtBuilder.duplicateUniqueConstraints(d.withoutConstraint, colNames...) {
135-
// Update the unique constraint columns to use the new column names if any of the columns are duplicated
136135
if _, err := d.conn.ExecContext(ctx, sql); err != nil {
137136
return err
138137
}
139138
}
140139

140+
// Generate SQL to duplicate any foreign key constraints on the columns.
141+
// If the foreign key constraint is not valid for a new column type, the error is ignored.
142+
for _, sql := range d.stmtBuilder.duplicateForeignKeyConstraints(d.withoutConstraint, colNames...) {
143+
_, err := d.conn.ExecContext(ctx, sql)
144+
err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode)
145+
if err != nil {
146+
return err
147+
}
148+
}
149+
141150
return nil
142151
}
143152

@@ -175,6 +184,26 @@ func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []s
175184
return stmts
176185
}
177186

187+
func (d *duplicatorStmtBuilder) duplicateForeignKeyConstraints(withoutConstraint []string, colNames ...string) []string {
188+
stmts := make([]string, 0, len(d.table.ForeignKeys))
189+
for _, fk := range d.table.ForeignKeys {
190+
if slices.Contains(withoutConstraint, fk.Name) {
191+
continue
192+
}
193+
if duplicatedMember, constraintColumns := d.allConstraintColumns(fk.Columns, colNames...); duplicatedMember {
194+
stmts = append(stmts, fmt.Sprintf(cAlterTableAddForeignKeySQL,
195+
pq.QuoteIdentifier(d.table.Name),
196+
pq.QuoteIdentifier(DuplicationName(fk.Name)),
197+
strings.Join(quoteColumnNames(constraintColumns), ", "),
198+
pq.QuoteIdentifier(fk.ReferencedTable),
199+
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
200+
fk.OnDelete,
201+
))
202+
}
203+
}
204+
return stmts
205+
}
206+
178207
// duplicatedConstraintColumns returns a new slice of constraint columns with
179208
// the columns that are duplicated replaced with temporary names.
180209
func (d *duplicatorStmtBuilder) duplicatedConstraintColumns(constraintColumns []string, duplicatedColumns ...string) []string {
@@ -213,7 +242,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
213242
) string {
214243
const (
215244
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
216-
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
217245
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
218246
)
219247

@@ -232,23 +260,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
232260
)
233261
}
234262

235-
// Generate SQL to duplicate any foreign key constraints on the column
236-
for _, fk := range d.table.ForeignKeys {
237-
if slices.Contains(withoutConstraint, fk.Name) {
238-
continue
239-
}
240-
241-
if slices.Contains(fk.Columns, column.Name) {
242-
sql += fmt.Sprintf(", "+cAddForeignKeySQL,
243-
pq.QuoteIdentifier(DuplicationName(fk.Name)),
244-
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, column.Name, asName)), ", "),
245-
pq.QuoteIdentifier(fk.ReferencedTable),
246-
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
247-
fk.OnDelete,
248-
)
249-
}
250-
}
251-
252263
return sql
253264
}
254265

@@ -295,17 +306,6 @@ func StripDuplicationPrefix(name string) string {
295306
return strings.TrimPrefix(name, "_pgroll_dup_")
296307
}
297308

298-
func copyAndReplace(xs []string, oldValue, newValue string) []string {
299-
ys := slices.Clone(xs)
300-
301-
for i, c := range ys {
302-
if c == oldValue {
303-
ys[i] = newValue
304-
}
305-
}
306-
return ys
307-
}
308-
309309
func errorIgnoringErrorCode(err error, code pq.ErrorCode) error {
310310
pqErr := &pq.Error{}
311311
if ok := errors.As(err, &pqErr); ok {

pkg/migrations/duplicate_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ var table = &schema.Table{
3232
"new_york_adults": {Name: "new_york_adults", Columns: []string{"city", "age"}, Definition: `"city" = 'New York' AND "age" > 21`},
3333
"different_nick": {Name: "different_nick", Columns: []string{"name", "nick"}, Definition: `"name" != "nick"`},
3434
},
35+
ForeignKeys: map[string]schema.ForeignKey{
36+
"fk_city": {Name: "fk_city", Columns: []string{"city"}, ReferencedTable: "cities", ReferencedColumns: []string{"id"}, OnDelete: "NO ACTION"},
37+
"fk_name_nick": {Name: "fk_name_nick", Columns: []string{"name", "nick"}, ReferencedTable: "users", ReferencedColumns: []string{"name", "nick"}, OnDelete: "CASCADE"},
38+
},
3539
}
3640

3741
func TestDuplicateStmtBuilderCheckConstraints(t *testing.T) {
@@ -121,3 +125,52 @@ func TestDuplicateStmtBuilderUniqueConstraints(t *testing.T) {
121125
})
122126
}
123127
}
128+
129+
func TestDuplicateStmtBuilderForeignKeyConstraints(t *testing.T) {
130+
d := &duplicatorStmtBuilder{table}
131+
for name, testCases := range map[string]struct {
132+
columns []string
133+
expectedStmts []string
134+
}{
135+
"duplicate single column with no FK constraint": {
136+
columns: []string{"description"},
137+
expectedStmts: []string{},
138+
},
139+
"single-column FK with single column duplicated": {
140+
columns: []string{"city"},
141+
expectedStmts: []string{
142+
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
143+
},
144+
},
145+
"single-column FK with multiple columns duplicated": {
146+
columns: []string{"city", "description"},
147+
expectedStmts: []string{
148+
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
149+
},
150+
},
151+
"multi-column FK with single column duplicated": {
152+
columns: []string{"name"},
153+
expectedStmts: []string{
154+
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
155+
},
156+
},
157+
"multi-column FK with multiple unrelated column duplicated": {
158+
columns: []string{"name", "description"},
159+
expectedStmts: []string{
160+
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
161+
},
162+
},
163+
"multi-column FK with multiple columns": {
164+
columns: []string{"name", "nick"},
165+
expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "_pgroll_new_nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`},
166+
},
167+
} {
168+
t.Run(name, func(t *testing.T) {
169+
stmts := d.duplicateForeignKeyConstraints(nil, testCases.columns...)
170+
assert.Equal(t, len(testCases.expectedStmts), len(stmts))
171+
for _, stmt := range stmts {
172+
assert.Contains(t, testCases.expectedStmts, stmt)
173+
}
174+
})
175+
}
176+
}

0 commit comments

Comments
 (0)