Fix joins in db.Find(AndCount) (#28978)

This commit is contained in:
KN4CK3R 2024-01-30 03:37:24 +01:00 committed by GitHub
parent 8ef53c871b
commit 27d4c11ec3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -133,17 +133,21 @@ type FindOptionsOrder interface {
// Find represents a common find function which accept an options interface // Find represents a common find function which accept an options interface
func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) { func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {
sess := GetEngine(ctx) sess := GetEngine(ctx).Where(opts.ToConds())
if joinOpt, ok := opts.(FindOptionsJoin); ok && len(joinOpt.ToJoins()) > 0 { if joinOpt, ok := opts.(FindOptionsJoin); ok {
for _, joinFunc := range joinOpt.ToJoins() { for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil { if err := joinFunc(sess); err != nil {
return nil, err return nil, err
} }
} }
} }
if orderOpt, ok := opts.(FindOptionsOrder); ok {
if order := orderOpt.ToOrders(); order != "" {
sess.OrderBy(order)
}
}
sess = sess.Where(opts.ToConds())
page, pageSize := opts.GetPage(), opts.GetPageSize() page, pageSize := opts.GetPage(), opts.GetPageSize()
if !opts.IsListAll() && pageSize > 0 { if !opts.IsListAll() && pageSize > 0 {
if page == 0 { if page == 0 {
@ -151,9 +155,6 @@ func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {
} }
sess.Limit(pageSize, (page-1)*pageSize) sess.Limit(pageSize, (page-1)*pageSize)
} }
if newOpt, ok := opts.(FindOptionsOrder); ok && newOpt.ToOrders() != "" {
sess.OrderBy(newOpt.ToOrders())
}
findPageSize := defaultFindSliceSize findPageSize := defaultFindSliceSize
if pageSize > 0 { if pageSize > 0 {
@ -168,8 +169,8 @@ func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {
// Count represents a common count function which accept an options interface // Count represents a common count function which accept an options interface
func Count[T any](ctx context.Context, opts FindOptions) (int64, error) { func Count[T any](ctx context.Context, opts FindOptions) (int64, error) {
sess := GetEngine(ctx) sess := GetEngine(ctx).Where(opts.ToConds())
if joinOpt, ok := opts.(FindOptionsJoin); ok && len(joinOpt.ToJoins()) > 0 { if joinOpt, ok := opts.(FindOptionsJoin); ok {
for _, joinFunc := range joinOpt.ToJoins() { for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil { if err := joinFunc(sess); err != nil {
return 0, err return 0, err
@ -178,7 +179,7 @@ func Count[T any](ctx context.Context, opts FindOptions) (int64, error) {
} }
var object T var object T
return sess.Where(opts.ToConds()).Count(&object) return sess.Count(&object)
} }
// FindAndCount represents a common findandcount function which accept an options interface // FindAndCount represents a common findandcount function which accept an options interface
@ -188,8 +189,17 @@ func FindAndCount[T any](ctx context.Context, opts FindOptions) ([]*T, int64, er
if !opts.IsListAll() && pageSize > 0 && page >= 1 { if !opts.IsListAll() && pageSize > 0 && page >= 1 {
sess.Limit(pageSize, (page-1)*pageSize) sess.Limit(pageSize, (page-1)*pageSize)
} }
if newOpt, ok := opts.(FindOptionsOrder); ok && newOpt.ToOrders() != "" { if joinOpt, ok := opts.(FindOptionsJoin); ok {
sess.OrderBy(newOpt.ToOrders()) for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return nil, 0, err
}
}
}
if orderOpt, ok := opts.(FindOptionsOrder); ok {
if order := orderOpt.ToOrders(); order != "" {
sess.OrderBy(order)
}
} }
findPageSize := defaultFindSliceSize findPageSize := defaultFindSliceSize