db.go 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. package operations
  2. import (
  3. "fmt"
  4. "github.com/pkg/errors"
  5. "gorm.io/driver/postgres"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/logger"
  8. )
  // Accepted values for Config.LogLevel. chooseLogLevel maps each of them to
// the corresponding GORM logger level; they are listed here from most to
// least verbose.
const (
	// logLevelInfo selects logger.Info.
	logLevelInfo = "info"
	// logLevelWarn selects logger.Warn.
	logLevelWarn = "warn"
	// logLevelError selects logger.Error.
	logLevelError = "error"
	// logLevelSilent selects logger.Silent.
	logLevelSilent = "silent"
)
  15. func newGormDB(dbConfig *Config) (*gorm.DB, error) {
  16. if dbConfig == nil {
  17. return nil, errors.New("没有传递数据库配置")
  18. }
  19. gormDB, err := newPostgresGormDB(dbConfig)
  20. if err != nil {
  21. return nil, err
  22. }
  23. return gormDB, nil
  24. }
  25. func newPostgresGormDB(dbConfig *Config) (*gorm.DB, error) {
  26. dsn := "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable TimeZone=Asia/Shanghai"
  27. connStr := fmt.Sprintf(dsn, dbConfig.Address, dbConfig.Port, dbConfig.UserName, dbConfig.Password, dbConfig.Database)
  28. gormDB, err := gorm.Open(postgres.Open(connStr), &gorm.Config{
  29. Logger: logger.Default.LogMode(chooseLogLevel(dbConfig.LogLevel)),
  30. PrepareStmt: true,
  31. })
  32. if err != nil {
  33. return nil, errors.New(err.Error())
  34. }
  35. return gormDB, nil
  36. }
  37. func destroyGormDB(gormDB *gorm.DB) error {
  38. if gormDB == nil {
  39. return nil
  40. }
  41. db, err := gormDB.DB()
  42. if err != nil {
  43. return errors.New(err.Error())
  44. }
  45. err = db.Close()
  46. if err != nil {
  47. return errors.New(err.Error())
  48. }
  49. db = nil
  50. return nil
  51. }
  52. func chooseLogLevel(logLevel string) logger.LogLevel {
  53. switch logLevel {
  54. case logLevelSilent:
  55. return logger.Silent
  56. case logLevelError:
  57. return logger.Error
  58. case logLevelWarn:
  59. return logger.Warn
  60. case logLevelInfo:
  61. return logger.Info
  62. default:
  63. return logger.Info
  64. }
  65. }