@@ -43,6 +43,7 @@ const (
43
43
cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
44
44
cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
45
45
cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
46
+ cAlterTableAddForeignKeySQL = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
46
47
)
47
48
48
49
// NewColumnDuplicator creates a new Duplicator for a column.
@@ -91,7 +92,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
91
92
colNames = append (colNames , name )
92
93
93
94
// Duplicate the column with the new type
94
- // and check and fk constraints
95
95
if sql := d .stmtBuilder .duplicateColumn (c .column , c .asName , c .withoutNotNull , c .withType , d .withoutConstraint ); sql != "" {
96
96
_ , err := d .conn .ExecContext (ctx , sql )
97
97
if err != nil {
@@ -108,6 +108,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
108
108
}
109
109
}
110
110
111
+ // Duplicate the column's comment
111
112
if sql := d .stmtBuilder .duplicateComment (c .column , c .asName ); sql != "" {
112
113
_ , err := d .conn .ExecContext (ctx , sql )
113
114
if err != nil {
@@ -120,7 +121,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
120
121
// if the check constraint is not valid for the new column type, in which case
121
122
// the error is ignored.
122
123
for _ , sql := range d .stmtBuilder .duplicateCheckConstraints (d .withoutConstraint , colNames ... ) {
123
- // Update the check constraint expression to use the new column names if any of the columns are duplicated
124
124
_ , err := d .conn .ExecContext (ctx , sql )
125
125
err = errorIgnoringErrorCode (err , undefinedFunctionErrorCode )
126
126
if err != nil {
@@ -132,12 +132,21 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
132
132
// The constraint is duplicated by adding a unique index on the column concurrently.
133
133
// The index is converted into a unique constraint on migration completion.
134
134
for _ , sql := range d .stmtBuilder .duplicateUniqueConstraints (d .withoutConstraint , colNames ... ) {
135
- // Update the unique constraint columns to use the new column names if any of the columns are duplicated
136
135
if _ , err := d .conn .ExecContext (ctx , sql ); err != nil {
137
136
return err
138
137
}
139
138
}
140
139
140
+ // Generate SQL to duplicate any foreign key constraints on the columns.
141
+ // If the foreign key constraint is not valid for a new column type, the error is ignored.
142
+ for _ , sql := range d .stmtBuilder .duplicateForeignKeyConstraints (d .withoutConstraint , colNames ... ) {
143
+ _ , err := d .conn .ExecContext (ctx , sql )
144
+ err = errorIgnoringErrorCode (err , dataTypeMismatchErrorCode )
145
+ if err != nil {
146
+ return err
147
+ }
148
+ }
149
+
141
150
return nil
142
151
}
143
152
@@ -175,6 +184,26 @@ func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []s
175
184
return stmts
176
185
}
177
186
187
+ func (d * duplicatorStmtBuilder ) duplicateForeignKeyConstraints (withoutConstraint []string , colNames ... string ) []string {
188
+ stmts := make ([]string , 0 , len (d .table .ForeignKeys ))
189
+ for _ , fk := range d .table .ForeignKeys {
190
+ if slices .Contains (withoutConstraint , fk .Name ) {
191
+ continue
192
+ }
193
+ if duplicatedMember , constraintColumns := d .allConstraintColumns (fk .Columns , colNames ... ); duplicatedMember {
194
+ stmts = append (stmts , fmt .Sprintf (cAlterTableAddForeignKeySQL ,
195
+ pq .QuoteIdentifier (d .table .Name ),
196
+ pq .QuoteIdentifier (DuplicationName (fk .Name )),
197
+ strings .Join (quoteColumnNames (constraintColumns ), ", " ),
198
+ pq .QuoteIdentifier (fk .ReferencedTable ),
199
+ strings .Join (quoteColumnNames (fk .ReferencedColumns ), ", " ),
200
+ fk .OnDelete ,
201
+ ))
202
+ }
203
+ }
204
+ return stmts
205
+ }
206
+
178
207
// duplicatedConstraintColumns returns a new slice of constraint columns with
179
208
// the columns that are duplicated replaced with temporary names.
180
209
func (d * duplicatorStmtBuilder ) duplicatedConstraintColumns (constraintColumns []string , duplicatedColumns ... string ) []string {
@@ -213,7 +242,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
213
242
) string {
214
243
const (
215
244
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
216
- cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
217
245
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
218
246
)
219
247
@@ -232,23 +260,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
232
260
)
233
261
}
234
262
235
- // Generate SQL to duplicate any foreign key constraints on the column
236
- for _ , fk := range d .table .ForeignKeys {
237
- if slices .Contains (withoutConstraint , fk .Name ) {
238
- continue
239
- }
240
-
241
- if slices .Contains (fk .Columns , column .Name ) {
242
- sql += fmt .Sprintf (", " + cAddForeignKeySQL ,
243
- pq .QuoteIdentifier (DuplicationName (fk .Name )),
244
- strings .Join (quoteColumnNames (copyAndReplace (fk .Columns , column .Name , asName )), ", " ),
245
- pq .QuoteIdentifier (fk .ReferencedTable ),
246
- strings .Join (quoteColumnNames (fk .ReferencedColumns ), ", " ),
247
- fk .OnDelete ,
248
- )
249
- }
250
- }
251
-
252
263
return sql
253
264
}
254
265
@@ -295,17 +306,6 @@ func StripDuplicationPrefix(name string) string {
295
306
return strings .TrimPrefix (name , "_pgroll_dup_" )
296
307
}
297
308
298
- func copyAndReplace (xs []string , oldValue , newValue string ) []string {
299
- ys := slices .Clone (xs )
300
-
301
- for i , c := range ys {
302
- if c == oldValue {
303
- ys [i ] = newValue
304
- }
305
- }
306
- return ys
307
- }
308
-
309
309
func errorIgnoringErrorCode (err error , code pq.ErrorCode ) error {
310
310
pqErr := & pq.Error {}
311
311
if ok := errors .As (err , & pqErr ); ok {
0 commit comments