Skip to content

Commit 06bfebd

Browse files
Convert ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) SQL to pgroll operation (#507)
Convert SQL DDL of the form: ```sql "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)" ``` To the equivalent `pgroll` operation: ```json [ { "create_constraint": { "type": "unique", "table": "foo", "name": "bar", "columns": ["a"], "up": { "a": "...", }, "down": { "a": "..." } } } ] ``` We need to be conservative when converting SQL statements to `pgroll` operations to ensure that information present in the SQL is not lost during the conversion. There are several options possible as part of `ADD CONSTRAINT ... UNIQUE` statements that aren't currently representable by the `OpCreateConstraint` operation, for example: ```sql ALTER TABLE foo ADD CONSTRAINT bar UNIQUE NULLS NOT DISTINCT (a) ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) INCLUDE (b) ``` In these cases we must resort to converting to an `OpRawSQL`. Tests are added to cover these unrepresentable cases. Part of #504
1 parent fd94011 commit 06bfebd

File tree

5 files changed

+166
-4
lines changed

5 files changed

+166
-4
lines changed

pkg/sql2pgroll/alter_table.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,30 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err
3333
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false)
3434
case pgq.AlterTableType_AT_AlterColumnType:
3535
op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd)
36+
case pgq.AlterTableType_AT_AddConstraint:
37+
op, err = convertAlterTableAddConstraint(stmt, alterTableCmd)
3638
}
3739

3840
if err != nil {
3941
return nil, err
4042
}
4143

44+
if op == nil {
45+
return nil, nil
46+
}
47+
4248
ops = append(ops, op)
4349
}
4450

4551
return ops, nil
4652
}
4753

54+
// convertAlterTableSetNotNull converts SQL statements like:
55+
//
56+
// `ALTER TABLE foo ALTER COLUMN a SET NOT NULL`
57+
// `ALTER TABLE foo ALTER COLUMN a DROP NOT NULL`
58+
//
59+
// to an OpAlterColumn operation.
4860
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) {
4961
return &migrations.OpAlterColumn{
5062
Table: stmt.GetRelation().GetRelname(),
@@ -55,6 +67,11 @@ func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCm
5567
}, nil
5668
}
5769

70+
// convertAlterTableAlterColumnType converts a SQL statement like:
71+
//
72+
// `ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text`
73+
//
74+
// to an OpAlterColumn operation.
5875
func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
5976
node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef)
6077
if !ok {
@@ -70,6 +87,89 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa
7087
}, nil
7188
}
7289

90+
// convertAlterTableAddConstraint converts SQL statements like:
91+
//
92+
// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)`
93+
//
94+
// To an OpCreateConstraint operation.
95+
func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
96+
node, ok := cmd.GetDef().Node.(*pgq.Node_Constraint)
97+
if !ok {
98+
return nil, fmt.Errorf("expected constraint definition, got %T", cmd.GetDef().Node)
99+
}
100+
101+
var op migrations.Operation
102+
var err error
103+
switch node.Constraint.GetContype() {
104+
case pgq.ConstrType_CONSTR_UNIQUE:
105+
op, err = convertAlterTableAddUniqueConstraint(stmt, node.Constraint)
106+
default:
107+
return nil, nil
108+
}
109+
110+
if err != nil {
111+
return nil, err
112+
}
113+
114+
return op, nil
115+
}
116+
117+
// convertAlterTableAddUniqueConstraint converts SQL statements like:
118+
//
119+
// `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)`
120+
//
121+
// to an OpCreateConstraint operation.
122+
func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint *pgq.Constraint) (migrations.Operation, error) {
123+
if !canConvertUniqueConstraint(constraint) {
124+
return nil, nil
125+
}
126+
127+
// Extract the columns covered by the unique constraint
128+
columns := make([]string, 0, len(constraint.GetKeys()))
129+
for _, keyNode := range constraint.GetKeys() {
130+
key, ok := keyNode.Node.(*pgq.Node_String_)
131+
if !ok {
132+
return nil, fmt.Errorf("expected string key, got %T", keyNode)
133+
}
134+
columns = append(columns, key.String_.GetSval())
135+
}
136+
137+
// Build the up and down SQL placeholders for each column covered by the
138+
// constraint
139+
upDown := make(map[string]string, len(columns))
140+
for _, column := range columns {
141+
upDown[column] = PlaceHolderSQL
142+
}
143+
144+
return &migrations.OpCreateConstraint{
145+
Type: migrations.OpCreateConstraintTypeUnique,
146+
Name: constraint.GetConname(),
147+
Table: stmt.GetRelation().GetRelname(),
148+
Columns: columns,
149+
Down: upDown,
150+
Up: upDown,
151+
}, nil
152+
}
153+
154+
// canConvertUniqueConstraint checks if the unique constraint `constraint` can
155+
// be faithfully converted to an OpCreateConstraint operation without losing
156+
// information.
157+
func canConvertUniqueConstraint(constraint *pgq.Constraint) bool {
158+
if constraint.GetNullsNotDistinct() {
159+
return false
160+
}
161+
if len(constraint.GetIncluding()) > 0 {
162+
return false
163+
}
164+
if len(constraint.GetOptions()) > 0 {
165+
return false
166+
}
167+
if constraint.GetIndexspace() != "" {
168+
return false
169+
}
170+
return true
171+
}
172+
73173
func ptr[T any](x T) *T {
74174
return &x
75175
}

pkg/sql2pgroll/alter_table_test.go

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ func TestConvertAlterTableStatements(t *testing.T) {
3535
sql: "ALTER TABLE foo ALTER COLUMN a TYPE text",
3636
expectedOp: expect.AlterTableOp3,
3737
},
38+
{
39+
sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)",
40+
expectedOp: expect.AlterTableOp4,
41+
},
42+
{
43+
sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a, b)",
44+
expectedOp: expect.AlterTableOp5,
45+
},
3846
}
3947

4048
for _, tc := range tests {
@@ -44,10 +52,31 @@ func TestConvertAlterTableStatements(t *testing.T) {
4452

4553
require.Len(t, ops, 1)
4654

47-
alterColumnOps, ok := ops[0].(*migrations.OpAlterColumn)
48-
require.True(t, ok)
55+
assert.Equal(t, tc.expectedOp, ops[0])
56+
})
57+
}
58+
}
59+
60+
func TestUnconvertableAlterTableAddConstraintStatements(t *testing.T) {
61+
t.Parallel()
62+
63+
tests := []string{
64+
// UNIQUE constraints with various options that are not representable by
65+
// `OpCreateConstraint` operations
66+
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE NULLS NOT DISTINCT (a)",
67+
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) INCLUDE (b)",
68+
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) WITH (fillfactor=70)",
69+
"ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a) USING INDEX TABLESPACE baz",
70+
}
71+
72+
for _, sql := range tests {
73+
t.Run(sql, func(t *testing.T) {
74+
ops, err := sql2pgroll.Convert(sql)
75+
require.NoError(t, err)
76+
77+
require.Len(t, ops, 1)
4978

50-
assert.Equal(t, tc.expectedOp, alterColumnOps)
79+
assert.Equal(t, expect.RawSQLOp(sql), ops[0])
5180
})
5281
}
5382
}

pkg/sql2pgroll/create_table.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
)
99

1010
// convertCreateStmt converts a CREATE TABLE statement to a pgroll operation.
11-
func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
11+
func convertCreateStmt(stmt *pgq.CreateStmt) (migrations.Operations, error) {
1212
columns := make([]migrations.Column, 0, len(stmt.TableElts))
1313
for _, elt := range stmt.TableElts {
1414
columns = append(columns, convertColumnDef(elt.GetColumnDef()))

pkg/sql2pgroll/expect/alter_table.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,30 @@ var AlterTableOp3 = &migrations.OpAlterColumn{
3131
Down: sql2pgroll.PlaceHolderSQL,
3232
}
3333

34+
var AlterTableOp4 = &migrations.OpCreateConstraint{
35+
Type: migrations.OpCreateConstraintTypeUnique,
36+
Name: "bar",
37+
Table: "foo",
38+
Columns: []string{"a"},
39+
Down: map[string]string{"a": sql2pgroll.PlaceHolderSQL},
40+
Up: map[string]string{"a": sql2pgroll.PlaceHolderSQL},
41+
}
42+
43+
var AlterTableOp5 = &migrations.OpCreateConstraint{
44+
Type: migrations.OpCreateConstraintTypeUnique,
45+
Name: "bar",
46+
Table: "foo",
47+
Columns: []string{"a", "b"},
48+
Down: map[string]string{
49+
"a": sql2pgroll.PlaceHolderSQL,
50+
"b": sql2pgroll.PlaceHolderSQL,
51+
},
52+
Up: map[string]string{
53+
"a": sql2pgroll.PlaceHolderSQL,
54+
"b": sql2pgroll.PlaceHolderSQL,
55+
},
56+
}
57+
3458
func ptr[T any](v T) *T {
3559
return &v
3660
}

pkg/sql2pgroll/expect/raw_sql.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package expect
4+
5+
import "github.com/xataio/pgroll/pkg/migrations"
6+
7+
func RawSQLOp(sql string) *migrations.OpRawSQL {
8+
return &migrations.OpRawSQL{Up: sql}
9+
}

0 commit comments

Comments
 (0)