package db
import (
"context"
"errors"
"
gorm.io/gorm"
"
gorm.io/gorm/clause"
)
type Dao[T any] interface {
//Save 保存
Save(ctx context.Context, model *T) error
// Create 添加
Create(ctx context.Context, model *T) error
// CreateOrUpdate 批量添加或更新
CreateOrUpdate(ctx context.Context, model *T) error
// CreateBatch 批量添加
CreateBatch(ctx context.Context, models []T) error
// CreateBatchSize 批量添加指定大小
CreateBatchSize(ctx context.Context, models []T, size int) error
// Update 更新
Update(ctx context.Context, condition, model *T) error
//UpdaterWrapper 更新
UpdaterWrapper(ctx context.Context, wrapper *Wrapper, model *T) error
// GetByID 根据 ID 获取
GetByID(ctx context.Context, ID any) (*T, error)
// GetByCondition 根据条件获取
GetByCondition(ctx context.Context, condition T) (*T, error)
//GetByWrapper 根据条件获取
GetByWrapper(ctx context.Context, wrapper *Wrapper) (*T, error)
// Page 分页查询
Page(ctx context.Context, page, size int) ([]T, int64, error)
// PageByCondition 分页查询
PageByCondition(ctx context.Context, page, size int, condition *T) ([]T, int64, error)
// PageByWrapper 分页查询
PageByWrapper(ctx context.Context, page, size int, wrapper *Wrapper) ([]T, int64, error)
//ListByCondition 根据条件获取多个
ListByCondition(ctx context.Context, condition T) ([]T, error)
// ListByWrapper 根据条件获取多个
ListByWrapper(ctx context.Context, wrapper *Wrapper) ([]T, error)
// Count 统计数量
Count(ctx context.Context) (int64, error)
// CountByWrapper 统计数量
CountByWrapper(ctx context.Context, wrapper *Wrapper) (int64, error)
// Delete 删除
Delete(ctx context.Context, value any) error
//DeleteByWrapper 删除
DeleteByWrapper(ctx context.Context, wrapper *Wrapper) error
}
type Curd[T any] struct {
model T
}
// Save 保存
func (c *Curd[T]) Save(ctx context.Context, model *T) error {
return DB().Model(&c.model).WithContext(ctx).Save(&model).Error
}
// Create 添加
func (c *Curd[T]) Create(ctx context.Context, model *T) error {
err := DB().Model(&c.model).WithContext(ctx).Create(&model).Error
if err != nil {
if
errors.Is(err, gorm.ErrDuplicatedKey) {
return errors.New("record already exists")
}
return err
}
return nil
}
// CreateOrUpdate 添加或更新
func (c *Curd[T]) CreateOrUpdate(ctx context.Context, model *T) error {
err := DB().Model(&c.model).WithContext(ctx).Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(&model).Error
if err != nil {
if
errors.Is(err, gorm.ErrDuplicatedKey) {
return errors.New("record already exists")
}
return err
}
return nil
}
// CreateBatch 批量添加
func (c *Curd[T]) CreateBatch(ctx context.Context, models []T) error {
err := DB().Model(&c.model).WithContext(ctx).Create(&models).Error
if err != nil {
if
errors.Is(err, gorm.ErrDuplicatedKey) {
return errors.New("数据重复")
}
return err
}
return nil
}
// CreateBatchSize 批量添加指定大小
func (c *Curd[T]) CreateBatchSize(ctx context.Context, models []T, size int) error {
err := DB().Model(&c.model).WithContext(ctx).CreateInBatches(&models, size).Error
if err != nil {
if
errors.Is(err, gorm.ErrDuplicatedKey) {
return errors.New("数据重复")
}
return err
}
return nil
}
// Update 更新 传入的 model 必须包含主键
func (c *Curd[T]) Update(ctx context.Context, condition, model *T) error {
err := DB().Model(&model).WithContext(ctx).Updates(&model).Where(&condition).Error
if err != nil {
if
errors.Is(err, gorm.ErrDuplicatedKey) {
return errors.New("数据重复")
}
return err
}
return nil
}
// UpdaterWrapper 根据构造器更新
func (c *Curd[T]) UpdaterWrapper(ctx context.Context, wrapper *Wrapper, model *T) error {
err := DB().Model(&model).WithContext(ctx).Scopes(
wrapper.Build()).Updates(&model).Error
if err != nil {
if
errors.Is(err, gorm.ErrDuplicatedKey) {
return errors.New("数据重复")
}
return err
}
return nil
}
// Delete 删除 value 可以为单个主键或者多个主键切片
func (c *Curd[T]) Delete(ctx context.Context, value any) error {
return DB().Model(&c.model).WithContext(ctx).Delete(&c.model, value).Error
}
// DeleteByWrapper 根据构造器删除
func (c *Curd[T]) DeleteByWrapper(ctx context.Context, wrapper *Wrapper) error {
return DB().Model(&c.model).WithContext(ctx).Scopes(
wrapper.Build()).Delete(&c.model).Error
}
// GetByID 根据 ID 获取
func (c *Curd[T]) GetByID(ctx context.Context, ID any) (*T, error) {
var model T
err := DB().Model(&c.model).WithContext(ctx).Take(&model, ID).Error
if err != nil {
return nil, err
}
return &model, nil
}
// GetByCondition 根据条件获取
func (c *Curd[T]) GetByCondition(ctx context.Context, condition *T) (*T, error) {
var model T
err := DB().Model(&c.model).WithContext(ctx).Where(&condition).Take(&model).Error
if err != nil {
return nil, err
}
return &model, nil
}
// GetByWrapper 根据构造条件获取
func (c *Curd[T]) GetByWrapper(ctx context.Context, wrapper *Wrapper) (*T, error) {
var model T
err := DB().Model(&c.model).WithContext(ctx).Scopes(
wrapper.Build()).Take(&model).Error
if err != nil {
return nil, err
}
return &model, nil
}
// Page 分页查询
func (c *Curd[T]) Page(ctx context.Context, page, size int) (models []T, total int64, err error) {
models = make([]T, 0)
if err = DB().Model(&c.model).WithContext(ctx).Count(&total).Scopes(Page(page, size)).Find(&models).Error; err != nil {
return models, 0, err
}
return models, total, nil
}
// PageByCondition 分页查询 根据条件
func (c *Curd[T]) PageByCondition(ctx context.Context, page, size int, condition *T) (models []T, total int64, err error) {
models = make([]T, 0)
if err = DB().Model(&c.model).WithContext(ctx).Where(&condition).Count(&total).Scopes(Page(page, size)).Find(&models).Error; err != nil {
return models, 0, err
}
return models, total, nil
}
// PageByWrapper 分页查询 根据构造器
func (c *Curd[T]) PageByWrapper(ctx context.Context, page, size int, wrapper *Wrapper) (models []T, total int64, err error) {
models = make([]T, 0)
err = DB().WithContext(ctx).Scopes(
wrapper.Build()).Count(&total).Scopes(Page(page, size)).Find(&models).Error
if err != nil {
return models, 0, err
}
return models, total, nil
}
// ListByCondition 根据条件获取多个
func (c *Curd[T]) ListByCondition(ctx context.Context, condition *T) ([]T, error) {
models := make([]T, 0)
err := DB().Model(&c.model).WithContext(ctx).Where(&condition).Find(&models).Error
if err != nil {
return models, err
}
return models, nil
}
// ListByWrapper 根据构造器获取多个
func (c *Curd[T]) ListByWrapper(ctx context.Context, wrapper *Wrapper) (models []T, err error) {
models = make([]T, 0)
err = DB().Model(&c.model).WithContext(ctx).Scopes(
wrapper.Build()).Find(&models).Error
if err != nil {
return models, err
}
return models, nil
}
// Count 统计数量
func (c *Curd[T]) Count(ctx context.Context) (total int64, err error) {
if err = DB().WithContext(ctx).Model(&c.model).Count(&total).Error; err != nil {
return 0, err
}
return total, nil
}
// CountByWrapper 统计数量 根据构造器
func (c *Curd[T]) CountByWrapper(ctx context.Context, wrapper *Wrapper) (total int64, err error) {
if err = DB().WithContext(ctx).Scopes(
wrapper.Build()).Count(&total).Error; err != nil {
return 0, err
}
return total, nil
}
//用例
type memberDao struct {
db.Curd[model.Member]
}