diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf5604f..e13bed0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ jobs: run-tests: strategy: matrix: - go: ['1.19'] + go: ['1.20'] platform: [ubuntu-latest] runs-on: ubuntu-latest diff --git a/go.mod b/go.mod index 20dc8c4..430e4e2 100644 --- a/go.mod +++ b/go.mod @@ -1,21 +1,21 @@ module gorm.io/driver/postgres -go 1.19 +go 1.20 require ( - github.com/jackc/pgx/v5 v5.5.5 + github.com/jackc/pgx/v5 v5.6.0 gorm.io/gorm v1.25.10 ) require ( github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - golang.org/x/crypto v0.17.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/text v0.21.0 // indirect ) retract v1.5.5 // Published accidentally. diff --git a/go.sum b/go.sum index a65ad5b..50dd830 100644 --- a/go.sum +++ b/go.sum @@ -2,12 +2,12 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -18,12 +18,12 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/migrator.go b/migrator.go index df18db1..6b57ce6 100644 --- a/migrator.go +++ b/migrator.go @@ -3,10 +3,10 @@ package postgres import ( "database/sql" "fmt" + "github.com/jackc/pgx/v5" "regexp" "strings" - "github.com/jackc/pgx/v5" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" @@ -38,28 +38,34 @@ WHERE ` var typeAliasMap = map[string][]string{ - "int": {"integer"}, - "int2": {"smallint"}, - "int4": {"integer"}, - "int8": {"bigint"}, - "smallint": {"int2"}, - "integer": {"int4"}, - "bigint": {"int8"}, - "decimal": {"numeric"}, - "numeric": {"decimal"}, - "timestamptz": {"timestamp with time zone"}, - "timestamp with time zone": {"timestamptz"}, - "bool": {"boolean"}, - "boolean": {"bool"}, - "serial2": {"smallserial"}, - "serial4": {"serial"}, - "serial8": {"bigserial"}, - "varbit": {"bit varying"}, - "char": {"character"}, - "varchar": {"character varying"}, - "float4": {"real"}, - "float8": {"double precision"}, - "timetz": {"time with time zone"}, + "int": {"integer"}, + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "date": {"date"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, + "timestamp": {"timestamp"}, + "timestamptz": {"timestamp with time zone"}, + "timestamp without time zone": {"timestamp"}, + "timestamp with time zone": {"timestamptz"}, + "bool": {"boolean"}, + "boolean": {"bool"}, + "serial2": {"smallserial"}, + "serial4": {"serial"}, + "serial8": {"bigserial"}, + "varbit": {"bit varying"}, + "char": {"character"}, + "varchar": {"character varying"}, + "float4": {"real"}, + "float8": {"double precision"}, + "time": {"time"}, + "timetz": {"time with time zone"}, + "time without time zone": {"time"}, + "time with time zone": {"timetz"}, } type Migrator struct { @@ -130,7 +136,8 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } createIndexSQL += "INDEX " - if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { + hasConcurrentOption := strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" + if hasConcurrentOption { createIndexSQL += "CONCURRENTLY " } @@ -142,6 +149,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " ?" } + if idx.Option != "" && !hasConcurrentOption { + createIndexSQL += " " + idx.Option + } + if idx.Where != "" { createIndexSQL += " WHERE " + idx.Where } @@ -385,10 +396,16 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } + } else if !field.HasDefaultValue { + // case - as-is column has default value and to-be column has no default value + // need to drop default + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } } } return nil @@ -484,8 +501,8 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, column.LengthValue = typeLenValue } - if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && - strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") { + autoIncrementValuePattern := regexp.MustCompile(`^nextval\('"?[^']+seq"?'::regclass\)$`) + if autoIncrementValuePattern.MatchString(column.DefaultValueValue.String) || (identityIncrement.Valid && identityIncrement.String != "") { column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} column.DefaultValueValue = sql.NullString{} } diff --git a/postgres.go b/postgres.go index e865b0f..2d8fd99 100644 --- a/postgres.go +++ b/postgres.go @@ -1,13 +1,16 @@ package postgres import ( + "context" "database/sql" "fmt" "regexp" "strconv" "strings" + "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" "gorm.io/gorm" "gorm.io/gorm/callbacks" @@ -31,7 +34,7 @@ type Config struct { } var ( - timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") + timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone|timezone)=(.*?)($|&| )") defaultIdentifierLength = 63 //maximum identifier length for postgres ) @@ -99,10 +102,23 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol } result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN) + var options []stdlib.OptionOpenDB if len(result) > 2 { config.RuntimeParams["timezone"] = result[2] + options = append(options, stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error { + loc, tzErr := time.LoadLocation(result[2]) + if tzErr != nil { + return tzErr + } + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: loc}, + }) + return nil + })) } - db.ConnPool = stdlib.OpenDB(*config) + db.ConnPool = stdlib.OpenDB(*config, options...) } return } @@ -228,7 +244,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } return "decimal" case schema.String: - if field.Size > 0 { + if field.Size > 0 && field.Size <= 10485760 { return fmt.Sprintf("varchar(%d)", field.Size) } return "text" diff --git a/postgres_test.go b/postgres_test.go new file mode 100644 index 0000000..4f28726 --- /dev/null +++ b/postgres_test.go @@ -0,0 +1,57 @@ +package postgres + +import ( + "testing" +"gorm.io/gorm/schema" +) + +func Test_DataTypeOf(t *testing.T) { + type fields struct { + Config *Config + } + type args struct { + field *schema.Field + } + tests := []struct { + name string + fields fields + args args + want string + } { + { + name: "it should return boolean", + args: args{field: &schema.Field{DataType: schema.Bool}}, + want: "boolean", + }, + { + name: "it should return text -1", + args: args{field: &schema.Field{DataType: schema.String, Size: -1}}, + want: "text", + }, + { + name: "it should return text > 10485760", + args: args{field: &schema.Field{DataType: schema.String, Size: 12345678}}, + want: "text", + }, + { + name: "it should return varchar(100)", + args: args{field: &schema.Field{DataType: schema.String, Size: 100}}, + want: "varchar(100)", + }, + { + name: "it should return varchar(10485760)", + args: args{field: &schema.Field{DataType: schema.String, Size: 10485760}}, + want: "varchar(10485760)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dialector := Dialector{ + Config: tt.fields.Config, + } + if got := dialector.DataTypeOf(tt.args.field); got != tt.want { + t.Errorf("DataTypeOf() = %v, want %v", got, tt.want) + } + }) + } +} \ No newline at end of file