前提

go 语言已经发布了 1.18 带有正式泛型的版本,但很多文章都仍旧只是限于官方的 Add(数学加减法) 泛型函数例子。因此本文尝试使用泛型来简化数据库操作这一个过程,深入了解及使用这个新版的泛型。当然其中的实现都以简单为主,当做使用泛型的可用例子。 注意本文使用 gorm 作为基础的数据库访问工具,当然你也可以将本文提到的泛型操作改造并用于标准库。 至于项目 package 的布局,暂时以 java 的为主要参考。

api=gin 控制器
dao=数据库访问对象
entity=数据库实体对象(通常指表)
model=数据库映射对象(通常是查询结果的映射)
service=业务服务处理

1.实体定义

定义数据库访问对象的结构体

// entity/user.go
package entity

import "yujinping.top/mall/types"

type User struct {
	Id         uint64         `json:"id"`
	Username   string         `json:"username"`
	Password   string         `json:"password"`
	Name       string         `json:"name"`
	UserGuid   string         `json:"userGuid"`
	CreateTime types.DateTime `json:"createTime"`
}

2.泛型约束定义

定义数据库访问对象泛型约束接口,此种方法要求将所有的 SQL 操作语句要使用的泛型对象,都要加入进来。这个是 go 语言与 java 语言两者泛型的最大区别。

// models.go
package model

import "yujinping.top/mall/entity"

type Model interface {
	entity.User | entity.Article | entity.Role
}

3.DAO 的基础定义

定义数据库访问基础

// dao/base.go
package dao

import (
	stdLog "log"
	"os"
	"strings"
	"time"

	"github.com/rs/zerolog/log"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"yujinping.top/mall/common/model"
	"yujinping.top/mall/config"
)

//查询单个对象
func Get[T model.Model](q string, p ...any) T {
	var e T
	db().Raw(q, p...).Scan(&e)
	return e
}

// 查询结果集
func Query[T model.Model](q string, p ...any) []T {
	var objects []T
	db().Raw(q, p...).Scan(&objects)
	return objects
}

// 查询总记录数
func Count(q string, p ...any) int64 {
	var count int64
	db().Raw(q, p...).Scan(&count)
	return count
}

// 查询分页对象
func Page[T model.Model](page int64, size int64, q string, p ...any) model.Pagination[T] {
	countSql := buildCountSql(q)
	offset := (page - 1) * size
	if offset < 0 {
		offset = 0
		page = 1
	}
	limitSql, params := buildLimitSql(offset, size, q, p...)
	count := Count(countSql, p...)
	rows := Query[T](limitSql, params...)

	diff := count % size
	var pages int64
	if count == 0 {
		pages = 0
	} else {
		if diff > 0 {
			pages = (count-diff)/size + 1
		} else {
			pages = count / size
		}
	}

	var pagination model.Pagination[T]
	pagination.TotalRows = count
	pagination.CurrentPage = page
	pagination.TotalPages = pages
	pagination.PageSize = size
	pagination.Data = rows
	return pagination
}

// 构建查询记录总数SQL
func buildCountSql(q string) string {
	t := strings.ToUpper(q)
	pos := strings.Index(t, " FROM ")
	if pos != -1 {
		t = "SELECT COUNT(*)" + q[pos:]
	}
	return t
}

// 构建LIMIT SQL
func buildLimitSql(offset int64, limit int64, q string, p ...any) (string, []any) {
	t := q + " LIMIT ?,?"
	var params []any
	params = append(params, p...)
	params = append(params, offset, limit)
	return t, params
}

var _db *gorm.DB = nil

func db() *gorm.DB {
	if _db == nil {
		_config := config.GetDbConfig()
		log.Debug().Msgf("Connecting to database[%s] %s:%d ...", _config.Database, _config.Host, _config.Port)
		var err error
		var cfg = gorm.Config{
			SkipDefaultTransaction: true,
			PrepareStmt:            true,
		}
		if _config.ShowSql {
			cfg.Logger = createConsoleLogger()
		}
		_db, err = gorm.Open(mysql.Open(_config.GetUrl()), &cfg)

		if err != nil {
			var msg any = "Failed to connect to database!"
			panic(msg)
		}
		log.Debug().Msg("Connected to database successfully!")
		// // GORM 使用 database/sql 维护连接池
		sqlDB, _ := _db.DB()
		sqlDB.SetMaxIdleConns(20)
		// SetMaxOpenConnections 设置打开数据库连接的最大数量。
		sqlDB.SetMaxOpenConns(80)
		// SetConnMaxLifetime 设置了连接可复用的最大时间。
		// sqlDB.SetConnMaxLifetime(30 * time.Second)
	}
	return _db
}

func createConsoleLogger() logger.Interface {
	console := os.Stdout
	//console := colorable.NewColorableStdout()
	_logger := logger.New(
		stdLog.New(console, "\r\n", stdLog.LstdFlags|stdLog.Lshortfile), // io writer(日志输出的目标,前缀和日志包含的内容——译者注)
		logger.Config{
			SlowThreshold:             2 * time.Second, //1000 * 1000, // 慢 SQL 阈值 2 秒
			LogLevel:                  logger.Info,     // 日志级别
			IgnoreRecordNotFoundError: true,            // 忽略ErrRecordNotFound(记录未找到)错误
			Colorful:                  true,            // 彩色打印
		},
	)
	return _logger
}

func Ping() {
	log.Debug().Msg("Ping database ...")
	db().Exec("SELECT 1 FROM DUAL")
	log.Info().Msg("Database is ready.")
}

4.定义各个实体访问对象

// dao/user_dao.go
package dao

import (
	"yujinping.top/mall/common/model"
	"yujinping.top/mall/common/util"
	"yujinping.top/mall/entity"
)

type UserDao struct {
}

func NewUserDao() UserDao {
	return UserDao{}
}
func (d *UserDao) QueryUserList() []entity.User {
	return Query[entity.User]("SELECT * FROM user")
}

func (d *UserDao) GetUser(id uint64) (entity.User, error) {
	var user entity.User
	user = Get[entity.User]("SELECT * FROM user WHERE id=?", id)
	return user, nil
}

func (d *UserDao) FindUser(username string, password string) (entity.User, error) {
	pwd := util.Md5(password)
	user := Get[entity.User]("SELECT * FROM user WHERE username=? AND password=?", username, pwd)
	return user, nil
}

func (d *UserDao) GetUserPage(page int64, size int64) model.Pagination[entity.User] {
	return Page[entity.User](page, size, "SELECT * FROM user")
}

5.定义 gin 控制器

package api

import (
	"github.com/gin-gonic/gin"
	"github.com/spf13/cast"
	"yujinping.top/mall/dao"
	"yujinping.top/mall/resp"
)

// 单个用户详情
func GetUser(c *gin.Context) resp.Res {
	id := cast.ToUint64(c.Param("id"))
  var userDao dao.UserDao
	user, err := userDao.GetUser(id)
	if err != nil {
		return resp.Fail("User not found")
	}
	return resp.Ok(user)
}

//用户分页管理
func UserPage(c *gin.Context) resp.Res {
	page := cast.ToInt64(c.Param("page"))
	size := cast.ToInt64(c.Param("size"))
	if page == 0 {
		page = 1
	}
	if size == 0 {
		size = 20
	}
  var userDao dao.UserDao
	return resp.Ok(userDao.GetUserPage(page, size))
}

注意事项

  • go 的泛型标记使用[]中括号,没有使用更为广泛的<>,这个极其不习惯,很容易和数组、切片、map 混淆,这个是非常值得注意的地方。
  • 另外泛型约束,显得有点拖沓!按照上面的思路,如果有 N 条 SQL 读取操作不同的对象,那么这些对象定义都要加到泛型约束里去,否则编译时直接报非法的泛型。也许就是因为 go 是强类型的原因吧!