V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
The Go Programming Language
http://golang.org/
Go Playground
Go Projects
Revel Web Framework
cooooing
V2EX  ›  Go 编程语言

Go 中如何处理需要方法泛型的写法?

  •  
  •   cooooing · 19 小时 2 分钟前 · 824 次点击

    首先是 Stream Map 的例子,需要将一个类型映射为另一个类型:

    不支持这种写法:

    func (s Stream[T])Map[T any, R any](stream Stream[T], mapper func(T) R) Stream[R] {
    	...
    }
    

    只能写成这种,链式调用啥的就断了。有点难受

    func Map[T any, R any] (stream Stream[T], mapper func(T) R) Stream[R] {
    	...
    }
    

    另一个例子是 orm 的,有一个管理多数据源的类,从中获取特定类型的查询(直接返回对应类型的结果):

    type DataSource struct {
    	DB *sql.DB
        // 是否打印日志、执行超时时间等数据源的全局设置。大概是个工厂吧
    }
    
    func(db *DataSource) GetExcutor[T any]() *Excutor[T] {
    	return &Executor[T]{DB:db.DB}
    }
    
    func (e *Executor[T]) List() ([]T, error) {
    	...
    }
    

    Go 不支持这个,有没有什么好的写法实现

    10 条回复    2025-08-28 15:51:54 +08:00
    devhxy
        1
    devhxy  
       18 小时 37 分钟前
    go 的泛型本就不完善,不要用 javaer 的思考方式去写 go
    morebuff
        2
    morebuff  
       18 小时 27 分钟前
    package db

    import (
    "context"
    "errors"

    "gorm.io/gorm"
    "gorm.io/gorm/clause"
    )

    type Dao[T any] interface {
    //Save 保存
    Save(ctx context.Context, model *T) error
    // Create 添加
    Create(ctx context.Context, model *T) error
    // CreateOrUpdate 批量添加或更新
    CreateOrUpdate(ctx context.Context, model *T) error
    // CreateBatch 批量添加
    CreateBatch(ctx context.Context, models []T) error
    // CreateBatchSize 批量添加指定大小
    CreateBatchSize(ctx context.Context, models []T, size int) error
    // Update 更新
    Update(ctx context.Context, condition, model *T) error
    //UpdaterWrapper 更新
    UpdaterWrapper(ctx context.Context, wrapper *Wrapper, model *T) error
    // GetByID 根据 ID 获取
    GetByID(ctx context.Context, ID any) (*T, error)
    // GetByCondition 根据条件获取
    GetByCondition(ctx context.Context, condition T) (*T, error)
    //GetByWrapper 根据条件获取
    GetByWrapper(ctx context.Context, wrapper *Wrapper) (*T, error)
    // Page 分页查询
    Page(ctx context.Context, page, size int) ([]T, int64, error)
    // PageByCondition 分页查询
    PageByCondition(ctx context.Context, page, size int, condition *T) ([]T, int64, error)
    // PageByWrapper 分页查询
    PageByWrapper(ctx context.Context, page, size int, wrapper *Wrapper) ([]T, int64, error)
    //ListByCondition 根据条件获取多个
    ListByCondition(ctx context.Context, condition T) ([]T, error)
    // ListByWrapper 根据条件获取多个
    ListByWrapper(ctx context.Context, wrapper *Wrapper) ([]T, error)
    // Count 统计数量
    Count(ctx context.Context) (int64, error)
    // CountByWrapper 统计数量
    CountByWrapper(ctx context.Context, wrapper *Wrapper) (int64, error)
    // Delete 删除
    Delete(ctx context.Context, value any) error
    //DeleteByWrapper 删除
    DeleteByWrapper(ctx context.Context, wrapper *Wrapper) error
    }

    type Curd[T any] struct {
    model T
    }

    // Save 保存
    func (c *Curd[T]) Save(ctx context.Context, model *T) error {
    return DB().Model(&c.model).WithContext(ctx).Save(&model).Error
    }

    // Create 添加
    func (c *Curd[T]) Create(ctx context.Context, model *T) error {
    err := DB().Model(&c.model).WithContext(ctx).Create(&model).Error
    if err != nil {
    if errors.Is(err, gorm.ErrDuplicatedKey) {
    return errors.New("record already exists")
    }
    return err
    }
    return nil
    }

    // CreateOrUpdate 添加或更新
    func (c *Curd[T]) CreateOrUpdate(ctx context.Context, model *T) error {
    err := DB().Model(&c.model).WithContext(ctx).Clauses(clause.OnConflict{
    UpdateAll: true,
    }).Create(&model).Error
    if err != nil {
    if errors.Is(err, gorm.ErrDuplicatedKey) {
    return errors.New("record already exists")
    }
    return err
    }
    return nil
    }

    // CreateBatch 批量添加
    func (c *Curd[T]) CreateBatch(ctx context.Context, models []T) error {
    err := DB().Model(&c.model).WithContext(ctx).Create(&models).Error
    if err != nil {
    if errors.Is(err, gorm.ErrDuplicatedKey) {
    return errors.New("数据重复")
    }
    return err
    }
    return nil
    }

    // CreateBatchSize 批量添加指定大小
    func (c *Curd[T]) CreateBatchSize(ctx context.Context, models []T, size int) error {
    err := DB().Model(&c.model).WithContext(ctx).CreateInBatches(&models, size).Error
    if err != nil {
    if errors.Is(err, gorm.ErrDuplicatedKey) {
    return errors.New("数据重复")
    }
    return err
    }
    return nil
    }

    // Update 更新 传入的 model 必须包含主键
    func (c *Curd[T]) Update(ctx context.Context, condition, model *T) error {
    err := DB().Model(&model).WithContext(ctx).Updates(&model).Where(&condition).Error
    if err != nil {
    if errors.Is(err, gorm.ErrDuplicatedKey) {
    return errors.New("数据重复")
    }
    return err
    }
    return nil
    }

    // UpdaterWrapper 根据构造器更新
    func (c *Curd[T]) UpdaterWrapper(ctx context.Context, wrapper *Wrapper, model *T) error {
    err := DB().Model(&model).WithContext(ctx).Scopes(wrapper.Build()).Updates(&model).Error
    if err != nil {
    if errors.Is(err, gorm.ErrDuplicatedKey) {
    return errors.New("数据重复")
    }
    return err
    }
    return nil
    }

    // Delete 删除 value 可以为单个主键或者多个主键切片
    func (c *Curd[T]) Delete(ctx context.Context, value any) error {
    return DB().Model(&c.model).WithContext(ctx).Delete(&c.model, value).Error
    }

    // DeleteByWrapper 根据构造器删除
    func (c *Curd[T]) DeleteByWrapper(ctx context.Context, wrapper *Wrapper) error {
    return DB().Model(&c.model).WithContext(ctx).Scopes(wrapper.Build()).Delete(&c.model).Error
    }

    // GetByID 根据 ID 获取
    func (c *Curd[T]) GetByID(ctx context.Context, ID any) (*T, error) {
    var model T
    err := DB().Model(&c.model).WithContext(ctx).Take(&model, ID).Error
    if err != nil {
    return nil, err
    }
    return &model, nil
    }

    // GetByCondition 根据条件获取
    func (c *Curd[T]) GetByCondition(ctx context.Context, condition *T) (*T, error) {
    var model T
    err := DB().Model(&c.model).WithContext(ctx).Where(&condition).Take(&model).Error
    if err != nil {
    return nil, err
    }
    return &model, nil
    }

    // GetByWrapper 根据构造条件获取
    func (c *Curd[T]) GetByWrapper(ctx context.Context, wrapper *Wrapper) (*T, error) {
    var model T
    err := DB().Model(&c.model).WithContext(ctx).Scopes(wrapper.Build()).Take(&model).Error
    if err != nil {
    return nil, err
    }
    return &model, nil
    }

    // Page 分页查询
    func (c *Curd[T]) Page(ctx context.Context, page, size int) (models []T, total int64, err error) {
    models = make([]T, 0)
    if err = DB().Model(&c.model).WithContext(ctx).Count(&total).Scopes(Page(page, size)).Find(&models).Error; err != nil {
    return models, 0, err
    }
    return models, total, nil
    }

    // PageByCondition 分页查询 根据条件
    func (c *Curd[T]) PageByCondition(ctx context.Context, page, size int, condition *T) (models []T, total int64, err error) {
    models = make([]T, 0)
    if err = DB().Model(&c.model).WithContext(ctx).Where(&condition).Count(&total).Scopes(Page(page, size)).Find(&models).Error; err != nil {
    return models, 0, err
    }
    return models, total, nil
    }

    // PageByWrapper 分页查询 根据构造器
    func (c *Curd[T]) PageByWrapper(ctx context.Context, page, size int, wrapper *Wrapper) (models []T, total int64, err error) {
    models = make([]T, 0)
    err = DB().WithContext(ctx).Scopes(wrapper.Build()).Count(&total).Scopes(Page(page, size)).Find(&models).Error
    if err != nil {
    return models, 0, err
    }
    return models, total, nil
    }

    // ListByCondition 根据条件获取多个
    func (c *Curd[T]) ListByCondition(ctx context.Context, condition *T) ([]T, error) {
    models := make([]T, 0)
    err := DB().Model(&c.model).WithContext(ctx).Where(&condition).Find(&models).Error
    if err != nil {
    return models, err
    }
    return models, nil
    }

    // ListByWrapper 根据构造器获取多个
    func (c *Curd[T]) ListByWrapper(ctx context.Context, wrapper *Wrapper) (models []T, err error) {
    models = make([]T, 0)
    err = DB().Model(&c.model).WithContext(ctx).Scopes(wrapper.Build()).Find(&models).Error
    if err != nil {
    return models, err
    }
    return models, nil
    }

    // Count 统计数量
    func (c *Curd[T]) Count(ctx context.Context) (total int64, err error) {
    if err = DB().WithContext(ctx).Model(&c.model).Count(&total).Error; err != nil {
    return 0, err
    }
    return total, nil
    }

    // CountByWrapper 统计数量 根据构造器
    func (c *Curd[T]) CountByWrapper(ctx context.Context, wrapper *Wrapper) (total int64, err error) {
    if err = DB().WithContext(ctx).Scopes(wrapper.Build()).Count(&total).Error; err != nil {
    return 0, err
    }
    return total, nil
    }





    //用例
    type memberDao struct {
    db.Curd[model.Member]
    }
    cooooing
        3
    cooooing  
    OP
       18 小时 13 分钟前
    @morebuff 有完整的 github 仓库地址吗,学习下
    Dorathea
        4
    Dorathea  
       18 小时 12 分钟前
    没能理解你说的
    比如 "链式调用", 你指的是 "class.method().method()" 么?
    那应该是可以的啊 [代码]( https://go.dev/play/p/LVRcU1G3kCQ)
    还有你说的

    这个看起来在上面的代码里也有?
    scopeccsky1111
        5
    scopeccsky1111  
       18 小时 4 分钟前
    @cooooing #3 下面的方式是支持的啊, 只是不支持方法中 func (s *Struct) method[T any]() 这种形式的
    morebuff
        6
    morebuff  
       18 小时 0 分钟前
    @cooooing 我自己项目中用的,模仿 mybatis-plus 封装了一点,平时用着方便
    package db

    import (
    "context"

    "gorm.io/gorm"
    )

    type Wrapper struct {
    db *gorm.DB
    }

    func BuildWrapper() *Wrapper {
    return &Wrapper{db: DB()}
    }

    // DB 获取原始 DB 和错误的方法
    func (w *Wrapper) DB() *gorm.DB {
    return w.db
    }

    func (w *Wrapper) GetError() error {
    return w.db.Error
    }

    // EQ 等于
    func (w *Wrapper) EQ(column string, value any) *Wrapper {
    w.db = w.db.Where(column+" = ?", value)
    return w
    }

    // NE 不等于
    func (w *Wrapper) NE(column string, value any) *Wrapper {
    w.db = w.db.Where(column+" <> ?", value)
    return w
    }

    // GT 大于
    func (w *Wrapper) GT(column string, value any) *Wrapper {
    w.db = w.db.Where(column+" > ?", value)
    return w
    }

    // GE 大于等于
    func (w *Wrapper) GE(column string, value any) *Wrapper {
    w.db = w.db.Where(column+" >= ?", value)
    return w
    }

    // LT 小于
    func (w *Wrapper) LT(column string, value any) *Wrapper {
    w.db = w.db.Where(column+" < ?", value)
    return w
    }

    // LE 小于等于
    func (w *Wrapper) LE(column string, value any) *Wrapper {
    w.db = w.db.Where(column+" <= ?", value)
    return w
    }

    // Like 模糊查询
    func (w *Wrapper) Like(column string, value any) *Wrapper {
    w.db = w.db.Where(column+" LIKE ?", value)
    return w
    }

    // NotLike 模糊查询
    func (w *Wrapper) NotLike(column string, value any) *Wrapper {
    w.db = w.db.Where(column+" NOT LIKE ?", value)
    return w
    }

    // LikeLeft 模糊查询
    func (w *Wrapper) LikeLeft(column string, value string) *Wrapper {
    w.db = w.db.Where(column+" LIKE ?", "%"+value)
    return w
    }

    // LikeRight 模糊查询
    func (w *Wrapper) LikeRight(column string, value string) *Wrapper {
    w.db = w.db.Where(column+" LIKE ?", value+"%")
    return w
    }

    // NotLikeLeft 模糊查询
    func (w *Wrapper) NotLikeLeft(column string, value string) *Wrapper {
    w.db = w.db.Where(column+" NOT LIKE ?", "%"+value)
    return w
    }

    // NotLikeRight 模糊查询
    func (w *Wrapper) NotLikeRight(column string, value string) *Wrapper {
    w.db = w.db.Where(column+" NOT LIKE ?", value+"%")
    return w
    }

    // IN 批量查询
    func (w *Wrapper) IN(column string, values any) *Wrapper {
    w.db = w.db.Where(column+" IN ?", values)
    return w
    }

    // NotIN 批量查询
    func (w *Wrapper) NotIN(column string, values any) *Wrapper {
    w.db = w.db.Not(column+" IN ?", values)
    return w
    }

    // Between 区间查询
    func (w *Wrapper) Between(column string, before, after any) *Wrapper {
    w.db = w.db.Where(column+" BETWEEN ? AND ?", before, after)
    return w
    }

    // NotBetween 区间查询
    func (w *Wrapper) NotBetween(column string, before, after any) *Wrapper {
    w.db = w.db.Where(column+" NOT BETWEEN ? AND ?", before, after)
    return w
    }

    // Or 或者
    func (w *Wrapper) Or(condition, value any) *Wrapper {
    w.db = w.db.Or(condition, value)
    return w
    }

    // OrderBy 排序 desc 是否倒序
    func (w *Wrapper) OrderBy(column string, desc bool) *Wrapper {
    if desc {
    w.db = w.db.Order(column + " DESC")
    } else {
    w.db = w.db.Order(column + " ASC")
    }
    return w
    }

    // OrderByAsc 排序
    func (w *Wrapper) OrderByAsc(column string) *Wrapper {
    w.db = w.db.Order(column + " ASC")
    return w
    }

    // OrderByDesc 排序 desc 是否倒序
    func (w *Wrapper) OrderByDesc(column string) *Wrapper {
    w.db = w.db.Order(column + " DESC")
    return w
    }

    // GroupBy 分组
    func (w *Wrapper) GroupBy(column string) *Wrapper {
    w.db = w.db.Group(column)
    return w
    }

    // Limit 限制条数
    func (w *Wrapper) Limit(limit int) *Wrapper {
    w.db = w.db.Limit(limit)
    return w
    }

    // Offset 偏移量
    func (w *Wrapper) Offset(offset int) *Wrapper {
    w.db = w.db.Offset(offset)
    return w
    }

    // Columns 指定字段
    func (w *Wrapper) Columns(columns ...string) *Wrapper {
    w.db = w.db.Select(columns)
    return w
    }

    // Page 分页
    func (w *Wrapper) Page(page, size int) *Wrapper {
    w.db = w.db.Scopes(Page(page, size))
    return w
    }

    // Scopes 使用构造器
    func (w *Wrapper) Scopes(fn func(db *gorm.DB) *gorm.DB) *Wrapper {
    w.db = w.db.Scopes(fn)
    return w
    }

    // IsNull 判断字段为 NULL
    func (w *Wrapper) IsNull(column string) *Wrapper {
    w.db = w.db.Where(column + " IS NULL")
    return w
    }

    // IsNotNull 判断字段不为 NULL
    func (w *Wrapper) IsNotNull(column string) *Wrapper {
    w.db = w.db.Where(column + " IS NOT NULL")
    return w
    }

    // Joins 连接查询
    func (w *Wrapper) Joins(query string, args ...interface{}) *Wrapper {
    w.db = w.db.Joins(query, args...)
    return w
    }

    // Find 查询
    func (w *Wrapper) Find(dest interface{}) error {
    return w.db.Find(dest).Error
    }

    // First 查询第一条
    func (w *Wrapper) First(dest interface{}) error {
    return w.db.First(dest).Error
    }

    // Count 统计
    func (w *Wrapper) Count(count *int64) error {
    return w.db.Count(count).Error
    }

    // Build 构建查询
    func (w *Wrapper) Build() func(db *gorm.DB) *gorm.DB {
    return func(db *gorm.DB) *gorm.DB {
    return db.Where(w.db.Statement.Clauses["WHERE"].Expression)
    }
    }

    // Clone 克隆当前 Wrapper
    func (w *Wrapper) Clone() *Wrapper {
    return &Wrapper{db: w.db.Session(&gorm.Session{})}
    }

    // WithContext 设置上下文
    func (w *Wrapper) WithContext(ctx context.Context) *Wrapper {
    w.db = w.db.WithContext(ctx)
    return w
    }

    // Transaction 添加事务
    func (w *Wrapper) Transaction(fn func(tx *Wrapper) error) error {
    return w.db.Transaction(func(tx *gorm.DB) error {
    txWrapper := &Wrapper{db: tx}
    return fn(txWrapper)
    })
    }

    // Scan 扫描结果到目标
    func (w *Wrapper) Scan(dest any) error {
    return w.db.Scan(dest).Error
    }

    // Prepare 预编译 SQL 语句
    func (w *Wrapper) Prepare() *Wrapper {
    w.db = w.db.Session(&gorm.Session{PrepareStmt: true})
    return w
    }

    // Select 仅选择需要的字段
    func (w *Wrapper) Select(fields ...string) *Wrapper {
    w.db = w.db.Select(fields)
    return w
    }
    cooooing
        7
    cooooing  
    OP
       17 小时 59 分钟前
    @Dorathea
    链式调用是指在 map 的时候改变了 stream 的泛型具体类型,比如 int->string

    Stream[int].
    Filter(...). // 正常链式调用
    Map(func(i int) string { return fmt.Sprintf("item %d", i) }). // 这里 map 返回的是 Stream[string]
    ToArray()

    只能写成
    Map(Stream[int].Filter(...)).ToArray()

    第二个那个是从一个没有泛型的类,创建返回一个有泛型的其他类,就需要第一个类的方法上使用泛型传入类型信息

    类似的还有这种写法:
    type Processor interface {
    Process[T any](data T)
    }
    sakeven
        8
    sakeven  
       17 小时 30 分钟前
    试了一下,只能定义成这样:

    ```go

    package main

    import (
    "fmt"
    )

    type Stream[T any, R any] struct {
    data []T
    }

    func (s Stream[T,R])Map(stream Stream[T,R], mapper func(T) R) Stream[R,T] {
    new := Stream[R,T]{data: make([]R, len(s.data))}

    for i := range s.data {
    new.data[i] = mapper(s.data[i])
    }
    return new
    }

    func main() {
    intStream := Stream[int, string]{data: []int{1, 2, 3, 4, 5}}
    squaredStream := intStream.Map(intStream, func(x int) string { return fmt.Sprintf("a %d", x) })
    fmt.Println(squaredStream.data)
    }

    ```


    R 的类型需要在创建 Stream 实例的时候,声明好具体类型。
    Gilfoyle26
        9
    Gilfoyle26  
       17 小时 25 分钟前
    感觉 go 已经很克制了,像隔壁的 swift 整出 200 多个关键字出来
    strobber16
        10
    strobber16  
       17 小时 6 分钟前
    关于   ·   帮助文档   ·   自助推广系统   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   3966 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 21ms · UTC 00:58 · PVG 08:58 · LAX 17:58 · JFK 20:58
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.