db.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. package operations
  2. import (
  3. "errors"
  4. "fmt"
  5. "gorm.io/driver/postgres"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/logger"
  8. )
  // databaseTypePostgres is the DBConfig.Type value that newGormDB matches
// to select the PostgreSQL driver.
const databaseTypePostgres = "postgres"
  // DBConfig holds the settings used to open the database connection. The
// struct tags give the key names used when it is decoded with mapstructure.
type DBConfig struct {
	// Type selects the database driver. newGormDB matches it against
	// databaseTypePostgres ("postgres") and opens PostgreSQL for any value.
	Type string `mapstructure:"type"`

	// UserName and Password are the login credentials written into the DSN.
	UserName string `mapstructure:"user_name"`
	Password string `mapstructure:"password"`

	// Address is the database server host (DSN "host").
	Address string `mapstructure:"address"`

	// Port is the database server port, as a string (DSN "port").
	Port string `mapstructure:"port"`

	// Database is the name of the database to connect to (DSN "dbname").
	Database string `mapstructure:"database"`

	// MaxConnections and MaxIdleConnections are connection-pool limits.
	// NOTE(review): confirm where these are applied to the *sql.DB pool.
	MaxConnections     int `mapstructure:"max_connections"`
	MaxIdleConnections int `mapstructure:"max_idle_connections"`
}
  22. func newGormDB(dbConfig *DBConfig) (*gorm.DB, error) {
  23. if dbConfig == nil {
  24. return nil, errors.New("没有传递数据库配置")
  25. }
  26. var gormDB *gorm.DB
  27. switch dbConfig.Type {
  28. case databaseTypePostgres:
  29. innerGormDB, err := newPostgresGormDB(dbConfig)
  30. if err != nil {
  31. return nil, err
  32. }
  33. gormDB = innerGormDB
  34. default:
  35. innerGormDB, err := newPostgresGormDB(dbConfig)
  36. if err != nil {
  37. return nil, err
  38. }
  39. gormDB = innerGormDB
  40. }
  41. return gormDB, nil
  42. }
  43. func newPostgresGormDB(dbConfig *DBConfig) (*gorm.DB, error) {
  44. dsn := "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable TimeZone=Asia/Shanghai"
  45. connStr := fmt.Sprintf(dsn, dbConfig.Address, dbConfig.Port, dbConfig.UserName, dbConfig.Password, dbConfig.Database)
  46. return gorm.Open(postgres.Open(connStr), &gorm.Config{
  47. Logger: logger.Default.LogMode(logger.Info),
  48. PrepareStmt: true,
  49. })
  50. }
  51. func destroyGormDB(gormDB *gorm.DB) error {
  52. if gormDB == nil {
  53. return nil
  54. }
  55. db, err := gormDB.DB()
  56. if err != nil {
  57. return err
  58. }
  59. return db.Close()
  60. }