diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index b2ec1ed..55c2c69 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -11,7 +11,7 @@ jobs: sqlite: strategy: matrix: - go: ["1.21", "1.20", "1.19"] + go: ["1.22", "1.21", "1.20"] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -41,7 +41,7 @@ jobs: strategy: matrix: dbversion: ["postgres:latest"] - go: ["1.21", "1.20", "1.19"] + go: ["1.22", "1.21", "1.20"] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..f30699f --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +@hashicorp/boundary \ No newline at end of file diff --git a/coverage/coverage.html b/coverage/coverage.html index a3f2b62..bd12817 100644 --- a/coverage/coverage.html +++ b/coverage/coverage.html @@ -61,11 +61,11 @@ - + - + @@ -77,7 +77,7 @@ - + @@ -204,9 +204,11 @@ // column values for database operations. See: Expr(...) // // Set name column to null example: +// // SetColumnValues(map[string]interface{}{"name": Expr("NULL")}) // // Set exp_time column to N seconds from now: +// // SetColumnValues(map[string]interface{}{"exp_time": Expr("wt_add_seconds_to_now(?)", 10)}) func Expr(expr string, args ...interface{}) ExprValue { return ExprValue{Sql: expr, Vars: args} @@ -244,7 +246,6 @@ // OnConflict specifies how to handle alternative actions to take when an insert // results in a unique constraint or exclusion constraint error. type OnConflict struct { - // Target specifies what conflict you want to define a policy for. This can // be any one of these: // Columns: the name of a specific column or columns @@ -468,6 +469,18 @@ // DeleteOp is a delete operation DeleteOp OpType = 3 + + // DefaultBatchSize is the default batch size for bulk operations like + // CreateItems. This value is used if the caller does not specify a size + // using the WithBatchSize(...) option. Note: some databases have a limit + // on the number of query parameters (postgres is currently 64k and sqlite + // is 32k) and/or size of a SQL statement (sqlite is currently 1bn bytes), + // so this value should be set to a value that is less than the limits for + // your target db. + // See: + // - https://www.postgresql.org/docs/current/limits.html + // - https://www.sqlite.org/limits.html + DefaultBatchSize = 1000 ) // VetForWriter provides an interface that Create and Update can use to vet the @@ -624,57 +637,137 @@ } // CreateItems will create multiple items of the same type. Supported options: -// WithDebug, WithBeforeWrite, WithAfterWrite, WithReturnRowsAffected, -// OnConflict, WithVersion, WithTable, and WithWhere. WithLookup is not a supported option. -func (rw *RW) CreateItems(ctx context.Context, createItems []interface{}, opt ...Option) error { +// WithBatchSize, WithDebug, WithBeforeWrite, WithAfterWrite, +// WithReturnRowsAffected, OnConflict, WithVersion, WithTable, and WithWhere. +// WithLookup is not a supported option. +func (rw *RW) CreateItems(ctx context.Context, createItems interface{}, opt ...Option) error { const op = "dbw.CreateItems" - if rw.underlying == nil { - return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) - } - if len(createItems) == 0 { - return fmt.Errorf("%s: missing interfaces: %w", op, ErrInvalidParameter) - } + switch { + case rw.underlying == nil: + return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) + case isNil(createItems): + return fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter) + } + valCreateItems := reflect.ValueOf(createItems) + switch { + case valCreateItems.Kind() != reflect.Slice: + return fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter) + case valCreateItems.Len() == 0: + return fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter) + } if err := raiseErrorOnHooks(createItems); err != nil { return fmt.Errorf("%s: %w", op, err) } opts := GetOpts(opt...) - if opts.WithLookup { - return fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) - } - // verify that createItems are all the same type. + switch { + case opts.WithLookup: + return fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) + } var foundType reflect.Type - for i, v := range createItems { + for i := 0; i < valCreateItems.Len(); i++ { + // verify that createItems are all the same type and do some bits on each item if i == 0 { - foundType = reflect.TypeOf(v) + foundType = reflect.TypeOf(valCreateItems.Index(i).Interface()) + } + currentType := reflect.TypeOf(valCreateItems.Index(i).Interface()) + if currentType == nil { + return fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter) } - currentType := reflect.TypeOf(v) - if foundType != currentType { + if foundType != currentType { return fmt.Errorf("%s: create items contains disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter) } + + // these fields should be nil, since they are not writeable and we want the + // db to manage them + setFieldsToNil(valCreateItems.Index(i).Interface(), NonCreatableFields()) + + // vet each item + if !opts.WithSkipVetForWrite { + if vetter, ok := valCreateItems.Index(i).Interface().(VetForWriter); ok { + if err := vetter.VetForWrite(ctx, rw, CreateOp); err != nil { + return fmt.Errorf("%s: %w", op, err) + } + } + } } + if opts.WithBeforeWrite != nil { if err := opts.WithBeforeWrite(createItems); err != nil { return fmt.Errorf("%s: error before write: %w", op, err) } } - var rowsAffected int64 - for _, item := range createItems { - if err := rw.Create(ctx, item, - WithOnConflict(opts.WithOnConflict), - WithReturnRowsAffected(&rowsAffected), - WithDebug(opts.WithDebug), - WithVersion(opts.WithVersion), - WithWhere(opts.WithWhereClause, opts.WithWhereClauseArgs...), - WithTable(opts.WithTable), - ); err != nil { - return fmt.Errorf("%s: %w", op, err) - } + + db := rw.underlying.wrapped.WithContext(ctx) + if opts.WithOnConflict != nil { + c := clause.OnConflict{} + switch opts.WithOnConflict.Target.(type) { + case Constraint: + c.OnConstraint = string(opts.WithOnConflict.Target.(Constraint)) + case Columns: + columns := make([]clause.Column, 0, len(opts.WithOnConflict.Target.(Columns))) + for _, name := range opts.WithOnConflict.Target.(Columns) { + columns = append(columns, clause.Column{Name: name}) + } + c.Columns = columns + default: + return fmt.Errorf("%s: invalid conflict target %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Target), ErrInvalidParameter) + } + + switch opts.WithOnConflict.Action.(type) { + case DoNothing: + c.DoNothing = true + case UpdateAll: + c.UpdateAll = true + case []ColumnValue: + updates := opts.WithOnConflict.Action.([]ColumnValue) + set := make(clause.Set, 0, len(updates)) + for _, s := range updates { + // make sure it's not one of the std immutable columns + if contains([]string{"createtime", "publicid"}, strings.ToLower(s.Column)) { + return fmt.Errorf("%s: cannot do update on conflict for column %s: %w", op, s.Column, ErrInvalidParameter) + } + switch sv := s.Value.(type) { + case Column: + set = append(set, sv.toAssignment(s.Column)) + case ExprValue: + set = append(set, sv.toAssignment(s.Column)) + default: + set = append(set, rawAssignment(s.Column, s.Value)) + } + } + c.DoUpdates = set + default: + return fmt.Errorf("%s: invalid conflict action %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Action), ErrInvalidParameter) + } + if opts.WithVersion != nil || opts.WithWhereClause != "" { + // this is a bit of a hack, but we need to pass in one of the items + // to get the where clause since we need to get the gorm Model and + // Parse the gorm statement to build the where clause + where, args, err := rw.whereClausesFromOpts(ctx, valCreateItems.Index(0).Interface(), opts) + if err != nil { + return fmt.Errorf("%s: %w", op, err) + } + whereConditions := db.Statement.BuildCondition(where, args...) + c.Where = clause.Where{Exprs: whereConditions} + } + db = db.Clauses(c) } + if opts.WithDebug { + db = db.Debug() + } + if opts.WithTable != "" { + db = db.Table(opts.WithTable) + } + + tx := db.CreateInBatches(createItems, opts.WithBatchSize) + if tx.Error != nil { + return fmt.Errorf("%s: create failed: %w", op, tx.Error) + } if opts.WithRowsAffected != nil { - *opts.WithRowsAffected = rowsAffected + *opts.WithRowsAffected = tx.RowsAffected } - if opts.WithAfterWrite != nil { - if err := opts.WithAfterWrite(createItems, int(rowsAffected)); err != nil { + if tx.RowsAffected > 0 && opts.WithAfterWrite != nil { + if err := opts.WithAfterWrite(createItems, int(tx.RowsAffected)); err != nil { return fmt.Errorf("%s: error after write: %w", op, err) } } @@ -755,7 +848,7 @@ "github.com/hashicorp/go-hclog" "github.com/jackc/pgconn" - _ "github.com/jackc/pgx/v4" // required to load postgres drivers + _ "github.com/jackc/pgx/v5" // required to load postgres drivers "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -1064,56 +1157,117 @@ } // DeleteItems will delete multiple items of the same type. Options supported: -// WithDebug, WithTable -func (rw *RW) DeleteItems(ctx context.Context, deleteItems []interface{}, opt ...Option) (int, error) { +// WithWhereClause, WithDebug, WithTable +func (rw *RW) DeleteItems(ctx context.Context, deleteItems interface{}, opt ...Option) (int, error) { const op = "dbw.DeleteItems" - if rw.underlying == nil { - return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) - } - if len(deleteItems) == 0 { - return noRowsAffected, fmt.Errorf("%s: no interfaces to delete: %w", op, ErrInvalidParameter) - } + switch { + case rw.underlying == nil: + return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) + case isNil(deleteItems): + return noRowsAffected, fmt.Errorf("%s: no interfaces to delete: %w", op, ErrInvalidParameter) + } + valDeleteItems := reflect.ValueOf(deleteItems) + switch { + case valDeleteItems.Kind() != reflect.Slice: + return noRowsAffected, fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter) + case valDeleteItems.Len() == 0: + return noRowsAffected, fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter) + + } if err := raiseErrorOnHooks(deleteItems); err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } + opts := GetOpts(opt...) - if opts.WithLookup { - return noRowsAffected, fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) - } - // verify that createItems are all the same type. + switch { + case opts.WithLookup: + return noRowsAffected, fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) + case opts.WithVersion != nil: + return noRowsAffected, fmt.Errorf("%s: with version is not a supported option: %w", op, ErrInvalidParameter) + } + + // we need to dig out the stmt so in just a sec we can make sure the PKs are + // set for all the items, so we'll just use the first item to do so. + mDb := rw.underlying.wrapped.Model(valDeleteItems.Index(0).Interface()) + err := mDb.Statement.Parse(valDeleteItems.Index(0).Interface()) + switch { + case err != nil: + return noRowsAffected, fmt.Errorf("%s: (internal error) error parsing stmt: %w", op, err) + case err == nil && mDb.Statement.Schema == nil: + return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown) + } + + // verify that deleteItems are all the same type, among a myriad of + // other things on the set of items var foundType reflect.Type - for i, v := range deleteItems { + + for i := 0; i < valDeleteItems.Len(); i++ { if i == 0 { - foundType = reflect.TypeOf(v) - } - currentType := reflect.TypeOf(v) - if foundType != currentType { - return noRowsAffected, fmt.Errorf("%s: items contain disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter) + foundType = reflect.TypeOf(valDeleteItems.Index(i).Interface()) } + currentType := reflect.TypeOf(valDeleteItems.Index(i).Interface()) + switch { + case isNil(valDeleteItems.Index(i).Interface()) || currentType == nil: + return noRowsAffected, fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter) + case foundType != currentType: + return noRowsAffected, fmt.Errorf("%s: items contain disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter) + } + if opts.WithWhereClause == "" { + // make sure the PK is set for the current item + reflectValue := reflect.Indirect(reflect.ValueOf(valDeleteItems.Index(i).Interface())) + for _, pf := range mDb.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(ctx, reflectValue); isZero { + return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter) + } + } + } } + if opts.WithBeforeWrite != nil { if err := opts.WithBeforeWrite(deleteItems); err != nil { return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err) } } - rowsDeleted := 0 - for _, item := range deleteItems { - cnt, err := rw.Delete(ctx, item, - WithDebug(opts.WithDebug), - WithTable(opts.WithTable), - ) - rowsDeleted += cnt - if err != nil { - return rowsDeleted, fmt.Errorf("%s: %w", op, err) + + db := rw.underlying.wrapped.WithContext(ctx) + if opts.WithDebug { + db = db.Debug() + } + + if opts.WithWhereClause != "" { + where, args, err := rw.whereClausesFromOpts(ctx, valDeleteItems.Index(0).Interface(), opts) + if err != nil { + return noRowsAffected, fmt.Errorf("%s: %w", op, err) } + db = db.Where(where, args...) } - if rowsDeleted > 0 && opts.WithAfterWrite != nil { + + switch { + case opts.WithTable != "": + db = db.Table(opts.WithTable) + default: + tabler, ok := valDeleteItems.Index(0).Interface().(tableNamer) + if ok { + db = db.Table(tabler.TableName()) + } + } + + db = db.Delete(deleteItems) + if db.Error != nil { + return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error) + } + rowsDeleted := int(db.RowsAffected) + if rowsDeleted > 0 && opts.WithAfterWrite != nil { if err := opts.WithAfterWrite(deleteItems, int(rowsDeleted)); err != nil { return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err) } } return rowsDeleted, nil } + +type tableNamer interface { + TableName() string +}