Browse Source

修改接口

yjp 1 year ago
parent
commit
046232cf81
4 changed files with 118 additions and 22 deletions
  1. 93 0
      db_operations/condition.go
  2. 7 4
      db_operations/db_operations.go
  3. 12 12
      db_operations/operations.go
  4. 6 6
      demo/demo.go

+ 93 - 0
db_operations/condition.go

@@ -0,0 +1,93 @@
+package db_operations
+
+import "gorm.io/gorm"
+
+type Conditions struct {
+	queries   []string
+	queryArgs [][]any
+}
+
+func NewConditions() *Conditions {
+	return &Conditions{
+		queries:   make([]string, 0),
+		queryArgs: make([][]any, 0),
+	}
+}
+
+func (clause *Conditions) Equal(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" = ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) Like(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" LIKE ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) In(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" IN ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) NotIn(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" NOT IN ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) Not(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" != ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) LessThan(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" < ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) LessThanAndEqual(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" <= ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) GreaterThan(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" > ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) GreaterThanAndEqual(columnName string, value any) *Conditions {
+	clause.queries = append(clause.queries, columnName+" >= ?")
+	clause.queryArgs = append(clause.queryArgs, []any{value})
+	return clause
+}
+
+func (clause *Conditions) where(db *gorm.DB) *gorm.DB {
+	for i, query := range clause.queries {
+		db = db.Where(query, clause.queryArgs[i]...)
+	}
+
+	return db
+}
+
+func (clause *Conditions) or(db *gorm.DB) *gorm.DB {
+	for i, query := range clause.queries {
+		db = db.Or(query, clause.queryArgs[i]...)
+	}
+
+	return db
+}
+
+func (clause *Conditions) having(db *gorm.DB) *gorm.DB {
+	for i, query := range clause.queries {
+		db = db.Having(query, clause.queryArgs[i]...)
+	}
+
+	return db
+}

+ 7 - 4
db_operations/db_operations.go

@@ -24,6 +24,9 @@ type BaseDBOperations interface {
 	// 会重置数据库连接的方法
 	Table(name string, args ...any) DBOperations
 
+	// 会重置数据库连接的方法,一般配合Raw使用
+	NewSession() DBOperations
+
 	// 执行SQL语句,使用Raw之后,为了触发SQL执行,需要调用Row或者Rows
 	// 如果是查询语句,使用Rows或Row均可,主要看自己需要查询的是单行还是多行
 	// 如果是写语句,必须使用Rows,否则由于没有返回结果,Rows会报错
@@ -32,12 +35,12 @@ type BaseDBOperations interface {
 
 	// 组织SQL语句相关的方法
 	Select(query string, args ...any) DBOperations
-	Where(query string, args ...any) DBOperations
-	Or(query string, args ...any) DBOperations
+	Joins(query string, args ...any) DBOperations
+	Where(conditions *Conditions) DBOperations
+	Or(conditions *Conditions) DBOperations
+	Having(conditions *Conditions) DBOperations
 	GroupBy(groupBy string) DBOperations
 	OrderBy(orderBy string) DBOperations
-	Joins(query string, args ...any) DBOperations
-	Having(query string, args ...any) DBOperations
 	Paging(pageNo int, pageSize int) DBOperations
 
 	// 写方法

+ 12 - 12
db_operations/operations.go

@@ -130,33 +130,33 @@ func (op *Operations) Select(query string, args ...any) DBOperations {
 	return op
 }
 
-func (op *Operations) Where(query string, args ...any) DBOperations {
-	op.processDB = op.processDB.Where(query, args...)
+func (op *Operations) Joins(query string, args ...any) DBOperations {
+	op.processDB = op.processDB.Joins(query, args...)
 	return op
 }
 
-func (op *Operations) Or(query string, args ...any) DBOperations {
-	op.processDB = op.processDB.Or(query, args...)
+func (op *Operations) Where(conditions *Conditions) DBOperations {
+	op.processDB = conditions.where(op.processDB)
 	return op
 }
 
-func (op *Operations) GroupBy(groupBy string) DBOperations {
-	op.processDB = op.processDB.Group(groupBy)
+func (op *Operations) Or(conditions *Conditions) DBOperations {
+	op.processDB = conditions.or(op.processDB)
 	return op
 }
 
-func (op *Operations) OrderBy(orderBy string) DBOperations {
-	op.processDB = op.processDB.Order(orderBy)
+func (op *Operations) Having(conditions *Conditions) DBOperations {
+	op.processDB = conditions.having(op.processDB)
 	return op
 }
 
-func (op *Operations) Joins(query string, args ...any) DBOperations {
-	op.processDB = op.processDB.Joins(query, args...)
+func (op *Operations) GroupBy(groupBy string) DBOperations {
+	op.processDB = op.processDB.Group(groupBy)
 	return op
 }
 
-func (op *Operations) Having(query string, args ...any) DBOperations {
-	op.processDB = op.processDB.Having(query, args...)
+func (op *Operations) OrderBy(orderBy string) DBOperations {
+	op.processDB = op.processDB.Order(orderBy)
 	return op
 }
 

+ 6 - 6
demo/demo.go

@@ -91,7 +91,7 @@ func main() {
 		panic(err)
 	}
 
-	err = sdk.GetInstance().GetDBOperations().NewSession().
+	err = sdk.GetInstance().GetDBOperations().
 		Table("test.classes").
 		Create(db_operations.NewTableRowFromMap(map[string]any{
 			"id":          classID1,
@@ -102,15 +102,15 @@ func main() {
 		panic(err)
 	}
 
-	err = sdk.GetInstance().GetDBOperations().NewSession().
+	err = sdk.GetInstance().GetDBOperations().
 		Table("test.classes").
-		Where("id = ?", classID1).
+		Where(db_operations.NewConditions().Equal("id", classID1)).
 		Delete()
 	if err != nil {
 		panic(err)
 	}
 
-	err = sdk.GetInstance().GetDBOperations().NewSession().
+	err = sdk.GetInstance().GetDBOperations().
 		Table("test.classes").
 		CreateBatch([]db_operations.TableRow{
 			*db_operations.NewTableRow().
@@ -139,9 +139,9 @@ func main() {
 
 	tx.CommitTransaction()
 
-	tableRow, err := sdk.GetInstance().GetDBOperations().NewSession().
+	tableRow, err := sdk.GetInstance().GetDBOperations().
 		Table("test.classes").
-		Where("id = ?", classID1).
+		Where(db_operations.NewConditions().Equal("id", classID1)).
 		Row()
 	if err != nil {
 		panic(err)