go泛型简化数据库读取操作
前提
go 语言已经发布了 1.18 带有正式泛型的版本,但很多文章都仍旧只是限于官方的 Add(数学加减法) 泛型函数例子。因此本文尝试使用泛型来简化数据库操作这一个过程,深入了解及使用这个新版的泛型。当然其中的实现都以简单为主,当做使用泛型的可用例子。 注意本文使用 gorm 作为基础的数据库访问工具,当然你也可以将本文提到的泛型操作改造并用于标准库。 至于项目 package 的布局,暂时以 java 的为主要参考。
api=gin 控制器
dao=数据库访问对象
entity=数据库实体对象(通常指表)
model=数据库映射对象(通常是查询结果的映射)
service=业务服务处理
1.实体定义
定义数据库访问对象的结构体
// entity/user.go
package entity
import "yujinping.top/mall/types"
type User struct {
Id uint64 `json:"id"`
Username string `json:"username"`
Password string `json:"password"`
Name string `json:"name"`
UserGuid string `json:"userGuid"`
CreateTime types.DateTime `json:"createTime"`
}
2.泛型约束定义
定义数据库访问对象泛型约束接口,此种方法要求将所有的 SQL 操作语句要使用的泛型对象,都要加入进来。这个是 go 语言与 java 语言两者泛型的最大区别。
// models.go
package model
import "yujinping.top/mall/entity"
type Model interface {
entity.User | entity.Article | entity.Role
}
3.DAO 的基础定义
定义数据库访问基础
// dao/base.go
package dao
import (
stdLog "log"
"os"
"strings"
"time"
"github.com/rs/zerolog/log"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"yujinping.top/mall/common/model"
"yujinping.top/mall/config"
)
//查询单个对象
func Get[T model.Model](q string, p ...any) T {
var e T
db().Raw(q, p...).Scan(&e)
return e
}
// 查询结果集
func Query[T model.Model](q string, p ...any) []T {
var objects []T
db().Raw(q, p...).Scan(&objects)
return objects
}
// 查询总记录数
func Count(q string, p ...any) int64 {
var count int64
db().Raw(q, p...).Scan(&count)
return count
}
// 查询分页对象
func Page[T model.Model](page int64, size int64, q string, p ...any) model.Pagination[T] {
countSql := buildCountSql(q)
offset := (page - 1) * size
if offset < 0 {
offset = 0
page = 1
}
limitSql, params := buildLimitSql(offset, size, q, p...)
count := Count(countSql, p...)
rows := Query[T](limitSql, params...)
diff := count % size
var pages int64
if count == 0 {
pages = 0
} else {
if diff > 0 {
pages = (count-diff)/size + 1
} else {
pages = count / size
}
}
var pagination model.Pagination[T]
pagination.TotalRows = count
pagination.CurrentPage = page
pagination.TotalPages = pages
pagination.PageSize = size
pagination.Data = rows
return pagination
}
// 构建查询记录总数SQL
func buildCountSql(q string) string {
t := strings.ToUpper(q)
pos := strings.Index(t, " FROM ")
if pos != -1 {
t = "SELECT COUNT(*)" + q[pos:]
}
return t
}
// 构建LIMIT SQL
func buildLimitSql(offset int64, limit int64, q string, p ...any) (string, []any) {
t := q + " LIMIT ?,?"
var params []any
params = append(params, p...)
params = append(params, offset, limit)
return t, params
}
var _db *gorm.DB = nil
func db() *gorm.DB {
if _db == nil {
_config := config.GetDbConfig()
log.Debug().Msgf("Connecting to database[%s] %s:%d ...", _config.Database, _config.Host, _config.Port)
var err error
var cfg = gorm.Config{
SkipDefaultTransaction: true,
PrepareStmt: true,
}
if _config.ShowSql {
cfg.Logger = createConsoleLogger()
}
_db, err = gorm.Open(mysql.Open(_config.GetUrl()), &cfg)
if err != nil {
var msg any = "Failed to connect to database!"
panic(msg)
}
log.Debug().Msg("Connected to database successfully!")
// // GORM 使用 database/sql 维护连接池
sqlDB, _ := _db.DB()
sqlDB.SetMaxIdleConns(20)
// SetMaxOpenConnections 设置打开数据库连接的最大数量。
sqlDB.SetMaxOpenConns(80)
// SetConnMaxLifetime 设置了连接可复用的最大时间。
// sqlDB.SetConnMaxLifetime(30 * time.Second)
}
return _db
}
func createConsoleLogger() logger.Interface {
console := os.Stdout
//console := colorable.NewColorableStdout()
_logger := logger.New(
stdLog.New(console, "\r\n", stdLog.LstdFlags|stdLog.Lshortfile), // io writer(日志输出的目标,前缀和日志包含的内容——译者注)
logger.Config{
SlowThreshold: 2 * time.Second, //1000 * 1000, // 慢 SQL 阈值 2 秒
LogLevel: logger.Info, // 日志级别
IgnoreRecordNotFoundError: true, // 忽略ErrRecordNotFound(记录未找到)错误
Colorful: true, // 彩色打印
},
)
return _logger
}
func Ping() {
log.Debug().Msg("Ping database ...")
db().Exec("SELECT 1 FROM DUAL")
log.Info().Msg("Database is ready.")
}
4.定义各个实体访问对象
// dao/user_dao.go
package dao
import (
"yujinping.top/mall/common/model"
"yujinping.top/mall/common/util"
"yujinping.top/mall/entity"
)
type UserDao struct {
}
func NewUserDao() UserDao {
return UserDao{}
}
func (d *UserDao) QueryUserList() []entity.User {
return Query[entity.User]("SELECT * FROM user")
}
func (d *UserDao) GetUser(id uint64) (entity.User, error) {
var user entity.User
user = Get[entity.User]("SELECT * FROM user WHERE id=?", id)
return user, nil
}
func (d *UserDao) FindUser(username string, password string) (entity.User, error) {
pwd := util.Md5(password)
user := Get[entity.User]("SELECT * FROM user WHERE username=? AND password=?", username, pwd)
return user, nil
}
func (d *UserDao) GetUserPage(page int64, size int64) model.Pagination[entity.User] {
return Page[entity.User](page, size, "SELECT * FROM user")
}
5.定义 gin 控制器
package api
import (
"github.com/gin-gonic/gin"
"github.com/spf13/cast"
"yujinping.top/mall/dao"
"yujinping.top/mall/resp"
)
// 单个用户详情
func GetUser(c *gin.Context) resp.Res {
id := cast.ToUint64(c.Param("id"))
var userDao dao.UserDao
user, err := userDao.GetUser(id)
if err != nil {
return resp.Fail("User not found")
}
return resp.Ok(user)
}
//用户分页管理
func UserPage(c *gin.Context) resp.Res {
page := cast.ToInt64(c.Param("page"))
size := cast.ToInt64(c.Param("size"))
if page == 0 {
page = 1
}
if size == 0 {
size = 20
}
var userDao dao.UserDao
return resp.Ok(userDao.GetUserPage(page, size))
}
注意事项
- go 的泛型标记使用[]中括号,没有使用更为广泛的<>,这个极其不习惯,很容易和数组、切片、map 混淆,这个是非常值得注意的地方。
- 另外泛型约束,显得有点拖沓!按照上面的思路,如果有 N 条 SQL 读取操作不同的对象,那么这些对象定义都要加到泛型约束里去,否则编译时直接报非法的泛型。也许就是因为 go 是强类型的原因吧!