Ver Fonte

添加table结构

yjp há 1 ano atrás
pai
commit
a99691b789
4 ficheiros alterados com 342 adições e 34 exclusões
  1. 6 6
      db_operations/db_operations.go
  2. 24 14
      db_operations/operations.go
  3. 294 0
      db_operations/table_row.go
  4. 18 14
      demo/demo.go

+ 6 - 6
db_operations/db_operations.go

@@ -42,15 +42,15 @@ type BaseDBOperations interface {
 	Paging(pageNo int, pageSize int) DBOperations
 
 	// 写方法
-	Create(tableRow map[string]any) error
-	CreateBatch(tableRows []map[string]any) error
+	Create(tableRow *TableRow) error
+	CreateBatch(tableRows []TableRow) error
 	Delete() error
-	Updates(newTableRow map[string]any) error
-	UpdatesWithRowsAffected(newTableRow map[string]any) (int64, error)
+	Updates(newTableRow *TableRow) error
+	UpdatesWithRowsAffected(newTableRow *TableRow) (int64, error)
 
 	// 查询方法
-	Rows(pageNo int, pageSize int) ([]map[string]any, error)
-	Row() (map[string]any, error)
+	Rows(pageNo int, pageSize int) ([]TableRow, error)
+	Row() (*TableRow, error)
 
 	// 其他方法
 	Count(count *int64) error

+ 24 - 14
db_operations/operations.go

@@ -178,8 +178,8 @@ func (op *Operations) Paging(pageNo int, pageSize int) DBOperations {
 	return op
 }
 
-func (op *Operations) Create(tableRow map[string]any) error {
-	err := op.processDB.Create(tableRow).Error
+func (op *Operations) Create(tableRow *TableRow) error {
+	err := op.processDB.Create(tableRow.ToMap()).Error
 	if err != nil {
 		if strings.Contains(err.Error(), "SQLSTATE 23505") {
 			return dberr.ErrDBRecordHasExist
@@ -191,8 +191,13 @@ func (op *Operations) Create(tableRow map[string]any) error {
 	return nil
 }
 
-func (op *Operations) CreateBatch(tableRows []map[string]any) error {
-	err := op.processDB.Create(tableRows).Error
+func (op *Operations) CreateBatch(tableRows []TableRow) error {
+	tableRowMaps := make([]map[string]any, 0)
+	for _, tableRow := range tableRows {
+		tableRowMaps = append(tableRowMaps, tableRow.ToMap())
+	}
+
+	err := op.processDB.Create(tableRowMaps).Error
 	if err != nil {
 		if strings.Contains(err.Error(), "SQLSTATE 23505") {
 			return dberr.ErrDBRecordHasExist
@@ -208,8 +213,8 @@ func (op *Operations) Delete() error {
 	return op.processDB.Delete(make(map[string]any)).Error
 }
 
-func (op *Operations) Updates(newTableRow map[string]any) error {
-	err := op.processDB.Updates(newTableRow).Error
+func (op *Operations) Updates(newTableRow *TableRow) error {
+	err := op.processDB.Updates(newTableRow.ToMap()).Error
 	if err != nil {
 		if strings.Contains(err.Error(), "SQLSTATE 23505") {
 			return dberr.ErrDBRecordHasExist
@@ -221,8 +226,8 @@ func (op *Operations) Updates(newTableRow map[string]any) error {
 	return nil
 }
 
-func (op *Operations) UpdatesWithRowsAffected(newTableRow map[string]any) (int64, error) {
-	op.processDB = op.processDB.Updates(newTableRow)
+func (op *Operations) UpdatesWithRowsAffected(newTableRow *TableRow) (int64, error) {
+	op.processDB = op.processDB.Updates(newTableRow.ToMap())
 	if op.processDB.Error != nil {
 		return 0, op.processDB.Error
 	}
@@ -230,7 +235,7 @@ func (op *Operations) UpdatesWithRowsAffected(newTableRow map[string]any) (int64
 	return op.processDB.RowsAffected, nil
 }
 
-func (op *Operations) Rows(pageNo int, pageSize int) ([]map[string]any, error) {
+func (op *Operations) Rows(pageNo int, pageSize int) ([]TableRow, error) {
 	if pageNo != 0 && pageSize != 0 {
 		offset := (pageNo - 1) * pageSize
 		op.processDB = op.processDB.Offset(offset).Limit(pageSize)
@@ -240,16 +245,21 @@ func (op *Operations) Rows(pageNo int, pageSize int) ([]map[string]any, error) {
 		op.processDB = op.processDB.Offset(-1).Limit(-1)
 	}()
 
-	valueMaps := make([]map[string]any, 0)
-	err := op.processDB.Scan(&valueMaps).Error
+	tableRowMaps := make([]map[string]any, 0)
+	err := op.processDB.Scan(&tableRowMaps).Error
 	if err != nil {
 		return nil, err
 	}
 
-	return valueMaps, nil
+	tableRows := make([]TableRow, 0)
+	for _, tableRowMap := range tableRowMaps {
+		tableRows = append(tableRows, *NewTableRowFromMap(tableRowMap))
+	}
+
+	return tableRows, nil
 }
 
-func (op *Operations) Row() (map[string]any, error) {
+func (op *Operations) Row() (*TableRow, error) {
 	valueMap := make(map[string]any)
 	err := op.processDB.Scan(&valueMap).Error
 	if err != nil {
@@ -260,7 +270,7 @@ func (op *Operations) Row() (map[string]any, error) {
 		return nil, dberr.ErrDBRecordNotExist
 	}
 
-	return valueMap, nil
+	return NewTableRowFromMap(valueMap), nil
 }
 
 func (op *Operations) Count(count *int64) error {

+ 294 - 0
db_operations/table_row.go

@@ -0,0 +1,294 @@
+package db_operations
+
+import (
+	"reflect"
+	"time"
+)
+
+type TableRow struct {
+	row map[string]any
+}
+
+func NewTableRow() *TableRow {
+	return &TableRow{row: make(map[string]any)}
+}
+
+func NewTableRowFromMap(m map[string]any) *TableRow {
+	tableRow := NewTableRow()
+
+	for key, value := range m {
+		v := value
+
+		valueType := reflect.TypeOf(value)
+		if valueType.Kind() == reflect.Ptr {
+			v = reflect.ValueOf(value).Elem().Interface()
+		}
+
+		switch typedValue := v.(type) {
+		case string:
+			tableRow.row[key] = typedValue
+		case int:
+			tableRow.row[key] = uint64(typedValue)
+		case uint:
+			tableRow.row[key] = uint64(typedValue)
+		case int8:
+			tableRow.row[key] = uint64(typedValue)
+		case uint8:
+			tableRow.row[key] = uint64(typedValue)
+		case int16:
+			tableRow.row[key] = uint64(typedValue)
+		case uint16:
+			tableRow.row[key] = uint64(typedValue)
+		case int32:
+			tableRow.row[key] = uint64(typedValue)
+		case uint32:
+			tableRow.row[key] = uint64(typedValue)
+		case int64:
+			tableRow.row[key] = uint64(typedValue)
+		case uint64:
+			tableRow.row[key] = typedValue
+		case float32:
+			tableRow.row[key] = float64(typedValue)
+		case float64:
+			tableRow.row[key] = typedValue
+		case bool:
+			tableRow.row[key] = typedValue
+		case []byte:
+			tableRow.row[key] = typedValue
+		case time.Time:
+			tableRow.row[key] = typedValue
+		default:
+			panic("未支持的数据类型")
+		}
+	}
+
+	return tableRow
+}
+
+func (tableRow *TableRow) ToMap() map[string]any {
+	return tableRow.row
+}
+
+func (tableRow *TableRow) AddColumnValueTime(columnName string, value time.Time) *TableRow {
+	tableRow.row[columnName] = value
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueBool(columnName string, value bool) *TableRow {
+	tableRow.row[columnName] = value
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueString(columnName string, value string) *TableRow {
+	tableRow.row[columnName] = value
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueBytes(columnName string, value []byte) *TableRow {
+	tableRow.row[columnName] = value
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueInt(columnName string, value int) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueInt8(columnName string, value int8) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueInt16(columnName string, value int16) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueInt32(columnName string, value int32) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueInt64(columnName string, value int64) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueUint(columnName string, value uint) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueUint8(columnName string, value uint8) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueUint16(columnName string, value uint16) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueUint32(columnName string, value uint32) *TableRow {
+	tableRow.row[columnName] = uint64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueUint64(columnName string, value uint64) *TableRow {
+	tableRow.row[columnName] = value
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueFloat32(columnName string, value float32) *TableRow {
+	tableRow.row[columnName] = float64(value)
+	return tableRow
+}
+
+func (tableRow *TableRow) AddColumnValueFloat64(columnName string, value float64) *TableRow {
+	tableRow.row[columnName] = value
+	return tableRow
+}
+
+func (tableRow *TableRow) ColumnValueTime(columnName string) time.Time {
+	value, ok := tableRow.row[columnName].(time.Time)
+	if !ok {
+		return time.Time{}
+	}
+
+	return value
+}
+
+func (tableRow *TableRow) ColumnValueBool(columnName string) bool {
+	value, ok := tableRow.row[columnName].(bool)
+	if !ok {
+		return false
+	}
+
+	return value
+}
+
+func (tableRow *TableRow) ColumnValueString(columnName string) string {
+	value, ok := tableRow.row[columnName].(string)
+	if !ok {
+		return ""
+	}
+
+	return value
+}
+
+func (tableRow *TableRow) ColumnValueBytes(columnName string) []byte {
+	value, ok := tableRow.row[columnName].([]byte)
+	if !ok {
+		return make([]byte, 0)
+	}
+
+	return value
+}
+
+func (tableRow *TableRow) ColumnValueInt(columnName string) int {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return int(value)
+}
+
+func (tableRow *TableRow) ColumnValueInt8(columnName string) int8 {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return int8(value)
+}
+
+func (tableRow *TableRow) ColumnValueInt16(columnName string) int16 {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return int16(value)
+}
+
+func (tableRow *TableRow) ColumnValueInt32(columnName string) int32 {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return int32(value)
+}
+
+func (tableRow *TableRow) ColumnValueInt64(columnName string) int64 {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return int64(value)
+}
+
+func (tableRow *TableRow) ColumnValueUint(columnName string) uint {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return uint(value)
+}
+
+func (tableRow *TableRow) ColumnValueUint8(columnName string) uint8 {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return uint8(value)
+}
+
+func (tableRow *TableRow) ColumnValueUint16(columnName string) uint16 {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return uint16(value)
+}
+
+func (tableRow *TableRow) ColumnValueUint32(columnName string) uint32 {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return uint32(value)
+}
+
+func (tableRow *TableRow) ColumnValueUint64(columnName string) uint64 {
+	value, ok := tableRow.row[columnName].(uint64)
+	if !ok {
+		return 0
+	}
+
+	return value
+}
+
+func (tableRow *TableRow) ColumnValueFloat32(columnName string) float32 {
+	value, ok := tableRow.row[columnName].(float64)
+	if !ok {
+		return 0
+	}
+
+	return float32(value)
+}
+
+func (tableRow *TableRow) ColumnValueFloat64(columnName string) float64 {
+	value, ok := tableRow.row[columnName].(float64)
+	if !ok {
+		return 0
+	}
+
+	return value
+}

+ 18 - 14
demo/demo.go

@@ -2,6 +2,7 @@ package main
 
 import (
 	"git.sxidc.com/go-tools/utils/strutils"
+	"git.sxidc.com/service-supports/ds-sdk/db_operations"
 	"git.sxidc.com/service-supports/ds-sdk/sdk"
 	"math/rand"
 )
@@ -92,11 +93,11 @@ func main() {
 
 	err = sdk.GetInstance().GetDBOperations().NewSession().
 		Table("test.classes").
-		Create(map[string]any{
+		Create(db_operations.NewTableRowFromMap(map[string]any{
 			"id":          classID1,
 			"name":        className1,
 			"student_num": studentNum1,
-		})
+		}))
 	if err != nil {
 		panic(err)
 	}
@@ -111,17 +112,16 @@ func main() {
 
 	err = sdk.GetInstance().GetDBOperations().NewSession().
 		Table("test.classes").
-		CreateBatch([]map[string]any{
-			{
-				"id":          classID1,
-				"name":        className1,
-				"student_num": studentNum1,
-			},
-			{
+		CreateBatch([]db_operations.TableRow{
+			*db_operations.NewTableRow().
+				AddColumnValueString("id", classID1).
+				AddColumnValueString("name", className1).
+				AddColumnValueInt32("student_num", studentNum1),
+			*db_operations.NewTableRowFromMap(map[string]any{
 				"id":          classID2,
 				"name":        className2,
 				"student_num": studentNum2,
-			},
+			}),
 		})
 	if err != nil {
 		panic(err)
@@ -147,7 +147,9 @@ func main() {
 		panic(err)
 	}
 
-	if tableRow["id"] != classID1 || tableRow["name"] != newClassName1 || tableRow["student_num"] != newStudentNum1 {
+	if tableRow.ColumnValueString("id") != classID1 ||
+		tableRow.ColumnValueString("name") != newClassName1 ||
+		tableRow.ColumnValueInt32("student_num") != newStudentNum1 {
 		panic("数据查询错误")
 	}
 
@@ -159,12 +161,14 @@ func main() {
 	}
 
 	for _, tableRow := range tableRows {
-		if tableRow["id"] == classID1 {
-			if tableRow["name"] != newClassName1 || tableRow["student_num"] != newStudentNum1 {
+		if tableRow.ColumnValueString("id") == classID1 {
+			if tableRow.ColumnValueString("name") != newClassName1 ||
+				tableRow.ColumnValueInt32("student_num") != newStudentNum1 {
 				panic("数据查询错误")
 			}
 		} else {
-			if tableRow["name"] != className2 || tableRow["student_num"] != studentNum2 {
+			if tableRow.ColumnValueString("name") != className2 ||
+				tableRow.ColumnValueInt32("student_num") != studentNum2 {
 				panic("数据查询错误")
 			}
 		}