Skip to content

Commit 95d7f41

Browse files
committed
new dbaction: createfkconstraint
1 parent 16173dd commit 95d7f41

File tree

3 files changed

+78
-63
lines changed

3 files changed

+78
-63
lines changed

pkg/migrations/dbactions.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,49 @@ func (a *validateConstraintAction) Execute(ctx context.Context) error {
394394
pq.QuoteIdentifier(a.constraint)))
395395
return err
396396
}
397+
398+
// createFKConstraintAction is a DBAction that creates a new foreign key constraint
399+
type createFKConstraintAction struct {
400+
conn db.DB
401+
table string
402+
constraint string
403+
columns []string
404+
initiallyDeferred bool
405+
deferrable bool
406+
reference *TableForeignKeyReference
407+
skipValidation bool
408+
}
409+
410+
func NewCreateFKConstraintAction(conn db.DB, table, constraint string, columns []string, reference *TableForeignKeyReference, initiallyDeferred, deferrable, skipValidation bool) *createFKConstraintAction {
411+
return &createFKConstraintAction{
412+
conn: conn,
413+
table: table,
414+
constraint: constraint,
415+
columns: columns,
416+
reference: reference,
417+
initiallyDeferred: initiallyDeferred,
418+
deferrable: deferrable,
419+
skipValidation: skipValidation,
420+
}
421+
}
422+
423+
func (a *createFKConstraintAction) Execute(ctx context.Context) error {
424+
sql := fmt.Sprintf("ALTER TABLE %s ADD ", pq.QuoteIdentifier(a.table))
425+
writer := &ConstraintSQLWriter{
426+
Name: a.constraint,
427+
Columns: a.columns,
428+
InitiallyDeferred: a.initiallyDeferred,
429+
Deferrable: a.deferrable,
430+
SkipValidation: a.skipValidation,
431+
}
432+
sql += writer.WriteForeignKey(
433+
a.reference.Table,
434+
a.reference.Columns,
435+
a.reference.OnDelete,
436+
a.reference.OnUpdate,
437+
a.reference.OnDeleteSetColumns,
438+
a.reference.MatchType)
439+
440+
_, err := a.conn.ExecContext(ctx, sql)
441+
return err
442+
}

pkg/migrations/op_create_constraint.go

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, la
9595
case OpCreateConstraintTypeCheck:
9696
return table, o.addCheckConstraint(ctx, conn, table.Name)
9797
case OpCreateConstraintTypeForeignKey:
98-
return table, o.addForeignKeyConstraint(ctx, conn, table)
98+
return table, NewCreateFKConstraintAction(conn, o.Name, table.Name, temporaryNames(o.Columns), o.References, false, false, true).Execute(ctx)
9999
}
100100

101101
return table, nil
@@ -296,28 +296,6 @@ func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB,
296296
return err
297297
}
298298

299-
func (o *OpCreateConstraint) addForeignKeyConstraint(ctx context.Context, conn db.DB, table *schema.Table) error {
300-
sql := fmt.Sprintf("ALTER TABLE %s ADD ", pq.QuoteIdentifier(table.Name))
301-
302-
writer := &ConstraintSQLWriter{
303-
Name: o.Name,
304-
Columns: temporaryNames(o.Columns),
305-
SkipValidation: true,
306-
}
307-
sql += writer.WriteForeignKey(
308-
o.References.Table,
309-
o.References.Columns,
310-
o.References.OnDelete,
311-
o.References.OnUpdate,
312-
o.References.OnDeleteSetColumns,
313-
o.References.MatchType,
314-
)
315-
316-
_, err := conn.ExecContext(ctx, sql)
317-
318-
return err
319-
}
320-
321299
func temporaryNames(columns []string) []string {
322300
names := make([]string, len(columns))
323301
for i, col := range columns {

pkg/migrations/op_set_fk.go

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ import (
66
"context"
77
"fmt"
88

9-
"github.com/lib/pq"
10-
119
"github.com/xataio/pgroll/pkg/db"
1210
"github.com/xataio/pgroll/pkg/schema"
1311
)
@@ -26,9 +24,39 @@ func (o *OpSetForeignKey) Start(ctx context.Context, l Logger, conn db.DB, lates
2624
l.LogOperationStart(o)
2725

2826
table := s.GetTable(o.Table)
27+
if table == nil {
28+
return nil, TableDoesNotExistError{Name: o.Table}
29+
}
30+
column := table.GetColumn(o.Column)
31+
if column == nil {
32+
return nil, ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
33+
}
34+
referencedTable := s.GetTable(o.References.Table)
35+
if referencedTable == nil {
36+
return nil, TableDoesNotExistError{Name: o.References.Table}
37+
}
38+
39+
referencedColumn := referencedTable.GetColumn(o.References.Column)
40+
if referencedColumn == nil {
41+
return nil, ColumnDoesNotExistError{Table: o.References.Table, Name: o.References.Column}
42+
}
2943

3044
// Create a NOT VALID foreign key constraint on the new column.
31-
if err := o.addForeignKeyConstraint(ctx, conn, s); err != nil {
45+
if err := NewCreateFKConstraintAction(conn,
46+
o.Table,
47+
o.References.Name,
48+
[]string{o.Column},
49+
&TableForeignKeyReference{
50+
Table: o.References.Table,
51+
Columns: []string{o.References.Column},
52+
MatchType: o.References.MatchType,
53+
OnDelete: o.References.OnDelete,
54+
OnUpdate: o.References.OnUpdate,
55+
},
56+
o.References.InitiallyDeferred,
57+
o.References.Deferrable,
58+
true,
59+
).Execute(ctx); err != nil {
3260
return nil, fmt.Errorf("failed to add foreign key constraint: %w", err)
3361
}
3462

@@ -84,40 +112,3 @@ func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error
84112

85113
return nil
86114
}
87-
88-
func (o *OpSetForeignKey) addForeignKeyConstraint(ctx context.Context, conn db.DB, s *schema.Schema) error {
89-
table := s.GetTable(o.Table)
90-
if table == nil {
91-
return TableDoesNotExistError{Name: o.Table}
92-
}
93-
column := table.GetColumn(o.Column)
94-
if column == nil {
95-
return ColumnDoesNotExistError{Table: o.Table, Name: o.Column}
96-
}
97-
referencedTable := s.GetTable(o.References.Table)
98-
if referencedTable == nil {
99-
return TableDoesNotExistError{Name: o.References.Table}
100-
}
101-
102-
referencedColumn := referencedTable.GetColumn(o.References.Column)
103-
if referencedColumn == nil {
104-
return ColumnDoesNotExistError{Table: o.References.Table, Name: o.References.Column}
105-
}
106-
107-
sql := fmt.Sprintf("ALTER TABLE %s ADD ", pq.QuoteIdentifier(table.Name))
108-
writer := &ConstraintSQLWriter{
109-
Name: o.References.Name,
110-
Columns: []string{column.Name},
111-
SkipValidation: true,
112-
}
113-
sql += writer.WriteForeignKey(
114-
referencedTable.Name,
115-
[]string{referencedColumn.Name},
116-
o.References.OnDelete,
117-
o.References.OnUpdate,
118-
nil,
119-
o.References.MatchType)
120-
121-
_, err := conn.ExecContext(ctx, sql)
122-
return err
123-
}

0 commit comments

Comments
 (0)