| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- package sql
- import (
- "errors"
- "fmt"
- "git.sxidc.com/go-tools/utils/encoding"
- "git.sxidc.com/go-tools/utils/reflectutils"
- "git.sxidc.com/go-tools/utils/strutils"
- "reflect"
- "strings"
- "time"
- )
// Go reference-time layouts for the timestamp strings that appear in SQL
// result rows, listed from coarsest to finest fractional-second precision.
// Note that "+08:00" is literal text in these layouts (it is not one of the
// time package's zone placeholders such as "-07:00"), so each layout only
// matches values that carry exactly that offset.
const (
	// sqlResultTimeSecFormat: whole seconds, no fractional part.
	sqlResultTimeSecFormat = "2006-01-02T15:04:05+08:00"

	// sqlResultTimeMilliFormat: exactly three fractional-second digits.
	sqlResultTimeMilliFormat = "2006-01-02T15:04:05.000+08:00"

	// sqlResultTimeMicroFormat: exactly six fractional-second digits.
	sqlResultTimeMicroFormat = "2006-01-02T15:04:05.000000+08:00"
)
- func parseSqlTableRowTimeStr(timeStr string) (time.Time, error) {
- var layout string
- if strings.HasSuffix(timeStr, ".000000+08:00") {
- layout = sqlResultTimeMicroFormat
- } else if strings.HasSuffix(timeStr, ".000+08:00") {
- layout = sqlResultTimeMilliFormat
- } else {
- layout = sqlResultTimeSecFormat
- }
- return time.ParseInLocation(layout, timeStr, time.Local)
- }
- func ParseSqlTableRow(input any, output any) error {
- if input == nil || output == nil {
- return nil
- }
- // 输出的Type,可以是slice的指针或者是结构的指针
- outputType := reflect.TypeOf(output)
- if outputType.Kind() != reflect.Ptr {
- return errors.New("输出实体应该为结构的slice或者是结构的指针")
- }
- // 取元素类型
- if outputType.Kind() == reflect.Ptr {
- outputType = outputType.Elem()
- }
- // 检查元素类型是否为slice或者结构
- if outputType.Kind() != reflect.Slice && outputType.Kind() != reflect.Struct {
- return errors.New("输出实体应该为结构的slice或者是结构的指针")
- }
- // 如果输出类型为slice,则取slice元素类型
- outputElemType := outputType
- if outputElemType.Kind() == reflect.Slice {
- outputElemType = outputElemType.Elem()
- }
- // 校验元素类型是否为结构类型
- if outputElemType.Kind() != reflect.Struct {
- return errors.New("输出实体slice应该为结构的slice指针")
- }
- // 构造需要遍历的tableRows
- tableRows, ok := input.([]map[string]any)
- if !ok {
- tableRow, ok := input.(map[string]any)
- if !ok {
- return errors.New("输入数据应该为[]map[string]any或[]map[string]any")
- }
- tableRows = []map[string]any{tableRow}
- }
- // 构造输出实体slice
- outputEntities := reflect.MakeSlice(reflect.SliceOf(outputElemType), 0, 0)
- for _, tableRow := range tableRows {
- // 构造输出实体
- outputEntityValue := reflect.New(outputElemType).Elem().Addr()
- outputEntity := outputEntityValue.Interface()
- err := formOutputEntity(tableRow, outputEntity)
- if err != nil {
- return err
- }
- // 保存输出实体
- outputEntities = reflect.Append(outputEntities, outputEntityValue.Elem())
- }
- // 将输出实体赋值给输出指针变量
- outputValue := reflect.Indirect(reflect.ValueOf(output))
- if outputType.Kind() == reflect.Slice {
- outputValue.Set(outputEntities)
- } else {
- outputValue.Set(outputEntities.Index(0))
- }
- return nil
- }
- func formOutputEntity(tableRow map[string]any, outputEntity any) error {
- sqlResult, err := ParseSqlResult(outputEntity)
- if err != nil {
- return err
- }
- for fieldName, resultElement := range sqlResult.ResultElement {
- switch element := resultElement.(type) {
- case *ResultStruct:
- err := formOutputEntity(tableRow, element.FieldValueElem.Addr().Interface())
- if err != nil {
- return err
- }
- case *ResultColumn:
- tableRowValue, ok := tableRow[element.Name]
- if !ok {
- continue
- }
- // 构造结构字段,如果结构字段是指针且为nil,需要构造元素
- fieldValue := element.FieldValueElem
- outputKind := reflectutils.GroupValueKind(fieldValue)
- switch outputKind {
- case reflect.Bool:
- err := reflectutils.AssignBoolValue(tableRowValue, fieldValue)
- if err != nil {
- return err
- }
- case reflect.String:
- strValue := tableRowValue.(string)
- if strutils.IsStringNotEmpty(element.ParseTime) {
- parsedTime, err := parseSqlTableRowTimeStr(strValue)
- if err != nil {
- return err
- }
- strValue = parsedTime.Format(element.ParseTime)
- } else if strutils.IsStringNotEmpty(element.AESKey) {
- if strutils.IsStringNotEmpty(strValue) {
- decryptedValue, err := encoding.AESDecrypt(strValue, element.AESKey)
- if err != nil {
- return err
- }
- strValue = decryptedValue
- }
- }
- err = reflectutils.AssignStringValue(strValue, fieldValue)
- if err != nil {
- return err
- }
- case reflect.Int64:
- err := reflectutils.AssignIntValue(tableRowValue, fieldValue)
- if err != nil {
- return err
- }
- case reflect.Uint64:
- err := reflectutils.AssignUintValue(tableRowValue, fieldValue)
- if err != nil {
- return err
- }
- case reflect.Float64:
- err := reflectutils.AssignFloatValue(tableRowValue, fieldValue)
- if err != nil {
- return err
- }
- case reflect.Struct:
- if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
- parsedTime, err := parseSqlTableRowTimeStr(tableRowValue.(string))
- if err != nil {
- return err
- }
- fieldValue.Set(reflect.ValueOf(parsedTime))
- continue
- }
- return fmt.Errorf("字段: %s 列: %s 不支持的类型: %s",
- fieldName, element.Name, reflect.TypeOf(tableRowValue).String())
- default:
- return fmt.Errorf("字段: %s 列: %s 不支持的类型: %s",
- fieldName, element.Name, reflect.TypeOf(tableRowValue).String())
- }
- default:
- return errors.New("不支持的元素类型")
- }
- }
- return nil
- }
|