Browse Source

添加aes tag

yjp 1 year ago
parent
commit
7c055ba576
7 changed files with 174 additions and 61 deletions
  1. 1 0
      README.md
  2. 22 8
      sql/sql.go
  3. 12 2
      sql/sql_mapping.go
  4. 18 18
      sql/sql_tpl/condition.go
  5. 2 2
      sql/sql_tpl/table_row.go
  6. 93 14
      sql/sql_tpl/value.go
  7. 26 17
      test/sdk_test.go

+ 1 - 0
README.md

@@ -264,6 +264,7 @@ column和notUpdate两个sqlmapping Tag,column指定了该字段对应的数据
 | key         | 该列是否作为逻辑键(实际到底哪个字段为键,是由DataContainer定义确定的)使用,如果一个结构的多个字段使用了key,这几个字段将被作为联合键使用 |
 | notUpdate   | 不对该列进行更新操作                                                                    |
 | updateClear | 允许将该列清空为零值                                                                    |
+| aes         | 进行aes加密并传递aes的密钥,密钥长度为32字节,不能包含';'                                            |
 
 ### SQL语句执行
 

+ 22 - 8
sql/sql.go

@@ -58,7 +58,12 @@ func InsertEntity[T any](executor Executor, tableName string, e T) error {
 			value = now
 		}
 
-		tableRow.Add(sqlColumn.Name, value)
+		var opts []sql_tpl.AfterParsedStrValueOption
+		if strutils.IsStringNotEmpty(sqlColumn.AESKey) {
+			opts = append(opts, sql_tpl.WithAESKey(sqlColumn.AESKey))
+		}
+
+		tableRow.Add(sqlColumn.Name, value, opts...)
 	}
 
 	executeParamsMap, err := sql_tpl.InsertExecuteParams{
@@ -112,7 +117,12 @@ func DeleteEntity[T any](executor Executor, tableName string, e T) error {
 			return errors.New("键字段没有传值")
 		}
 
-		conditions.Equal(sqlColumn.Name, sqlColumn.ValueFieldValue.Interface())
+		var opts []sql_tpl.AfterParsedStrValueOption
+		if strutils.IsStringNotEmpty(sqlColumn.AESKey) {
+			opts = append(opts, sql_tpl.WithAESKey(sqlColumn.AESKey))
+		}
+
+		conditions.Equal(sqlColumn.Name, sqlColumn.ValueFieldValue.Interface(), opts...)
 	}
 
 	executeParamsMap, err := sql_tpl.DeleteExecuteParams{
@@ -180,10 +190,15 @@ func UpdateEntity[T any](executor Executor, tableName string, e T) error {
 			value = now
 		}
 
+		var opts []sql_tpl.AfterParsedStrValueOption
+		if strutils.IsStringNotEmpty(sqlColumn.AESKey) {
+			opts = append(opts, sql_tpl.WithAESKey(sqlColumn.AESKey))
+		}
+
 		if sqlColumn.IsKey {
-			conditions.Equal(sqlColumn.Name, value)
+			conditions.Equal(sqlColumn.Name, value, opts...)
 		} else {
-			tableRows.Add(sqlColumn.Name, value)
+			tableRows.Add(sqlColumn.Name, value, opts...)
 		}
 	}
 
@@ -439,13 +454,12 @@ const (
 
 func ParseSqlResults(results any, e any) error {
 	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
-		DecodeHook: func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
-
-			if f.Kind() != reflect.String {
+		DecodeHook: func(fromType reflect.Type, toType reflect.Type, data interface{}) (interface{}, error) {
+			if fromType.Kind() != reflect.String {
 				return data, nil
 			}
 
-			if t != reflect.TypeOf(time.Time{}) {
+			if toType != reflect.TypeOf(time.Time{}) {
 				return data, nil
 			}
 

+ 12 - 2
sql/sql_mapping.go

@@ -20,6 +20,7 @@ const (
 	sqlMappingKey         = "key"
 	sqlMappingNotUpdate   = "notUpdate"
 	sqlMappingUpdateClear = "updateClear"
+	sqlMappingAes         = "aes"
 )
 
 type Mapping struct {
@@ -73,6 +74,7 @@ type MappingColumn struct {
 	IsKey          bool
 	CanUpdate      bool
 	CanUpdateClear bool
+	AESKey         string
 
 	// 原字段的反射结构
 	OriginFieldType  reflect.Type
@@ -90,7 +92,7 @@ func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value)
 	if valueFieldType.Kind() == reflect.Ptr {
 		valueFieldType = valueFieldType.Elem()
 
-		if valueFieldValue.IsZero() {
+		if !valueFieldValue.IsValid() || valueFieldValue.IsNil() || valueFieldValue.IsZero() {
 			valueFieldValue = reflect.Zero(valueFieldType)
 		} else {
 			valueFieldValue = fieldValue.Elem()
@@ -102,6 +104,7 @@ func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value)
 		IsKey:            false,
 		CanUpdate:        true,
 		CanUpdateClear:   false,
+		AESKey:           "",
 		OriginFieldType:  field.Type,
 		OriginFieldValue: fieldValue,
 		ValueFieldType:   valueFieldType,
@@ -110,6 +113,7 @@ func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value)
 
 	if sqlColumn.Name == defaultKeyColumnName {
 		sqlColumn.IsKey = true
+		sqlColumn.CanUpdate = false
 	}
 
 	sqlMappingTag, ok := field.Tag.Lookup(sqlMappingTagKey)
@@ -124,7 +128,7 @@ func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value)
 	sqlMappingParts := strings.Split(sqlMappingTag, sqlMappingTagPartSeparator)
 	if sqlMappingParts != nil || len(sqlMappingParts) != 0 {
 		for _, sqlMappingPart := range sqlMappingParts {
-			sqlPartKeyValue := strings.Split(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator)
+			sqlPartKeyValue := strings.SplitN(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator, 2)
 			switch sqlPartKeyValue[0] {
 			case sqlMappingColumn:
 				sqlColumn.Name = strings.TrimSpace(sqlPartKeyValue[1])
@@ -135,6 +139,12 @@ func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value)
 				sqlColumn.CanUpdate = false
 			case sqlMappingUpdateClear:
 				sqlColumn.CanUpdateClear = true
+			case sqlMappingAes:
+				if len(strings.TrimSpace(sqlPartKeyValue[1])) != 32 {
+					return nil, errors.New("AES密钥长度应该为32个字节")
+				}
+
+				sqlColumn.AESKey = strings.TrimSpace(sqlPartKeyValue[1])
 			default:
 				continue
 			}

+ 18 - 18
sql/sql_tpl/condition.go

@@ -16,12 +16,12 @@ func (conditions *Conditions) AddCondition(condition string) *Conditions {
 	return conditions
 }
 
-func (conditions *Conditions) Equal(columnName string, value any) *Conditions {
+func (conditions *Conditions) Equal(columnName string, value any, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -32,12 +32,12 @@ func (conditions *Conditions) Equal(columnName string, value any) *Conditions {
 	return conditions
 }
 
-func (conditions *Conditions) Like(columnName string, value string) *Conditions {
+func (conditions *Conditions) Like(columnName string, value string, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -48,12 +48,12 @@ func (conditions *Conditions) Like(columnName string, value string) *Conditions
 	return conditions
 }
 
-func (conditions *Conditions) In(columnName string, value any) *Conditions {
+func (conditions *Conditions) In(columnName string, value any, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -64,12 +64,12 @@ func (conditions *Conditions) In(columnName string, value any) *Conditions {
 	return conditions
 }
 
-func (conditions *Conditions) NotIn(columnName string, value any) *Conditions {
+func (conditions *Conditions) NotIn(columnName string, value any, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -80,12 +80,12 @@ func (conditions *Conditions) NotIn(columnName string, value any) *Conditions {
 	return conditions
 }
 
-func (conditions *Conditions) Not(columnName string, value any) *Conditions {
+func (conditions *Conditions) Not(columnName string, value any, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -96,12 +96,12 @@ func (conditions *Conditions) Not(columnName string, value any) *Conditions {
 	return conditions
 }
 
-func (conditions *Conditions) LessThan(columnName string, value any) *Conditions {
+func (conditions *Conditions) LessThan(columnName string, value any, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -112,12 +112,12 @@ func (conditions *Conditions) LessThan(columnName string, value any) *Conditions
 	return conditions
 }
 
-func (conditions *Conditions) LessThanAndEqual(columnName string, value any) *Conditions {
+func (conditions *Conditions) LessThanAndEqual(columnName string, value any, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -128,12 +128,12 @@ func (conditions *Conditions) LessThanAndEqual(columnName string, value any) *Co
 	return conditions
 }
 
-func (conditions *Conditions) GreaterThan(columnName string, value any) *Conditions {
+func (conditions *Conditions) GreaterThan(columnName string, value any, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions
@@ -144,12 +144,12 @@ func (conditions *Conditions) GreaterThan(columnName string, value any) *Conditi
 	return conditions
 }
 
-func (conditions *Conditions) GreaterThanAndEqual(columnName string, value any) *Conditions {
+func (conditions *Conditions) GreaterThanAndEqual(columnName string, value any, opts ...AfterParsedStrValueOption) *Conditions {
 	if conditions.err != nil {
 		return conditions
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		conditions.err = err
 		return conditions

+ 2 - 2
sql/sql_tpl/table_row.go

@@ -16,12 +16,12 @@ func NewTableRows() *TableRows {
 	}
 }
 
-func (tableRows *TableRows) Add(column string, value any) *TableRows {
+func (tableRows *TableRows) Add(column string, value any, opts ...AfterParsedStrValueOption) *TableRows {
 	if tableRows.err != nil {
 		return tableRows
 	}
 
-	parsedValue, err := parseValue(value)
+	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
 		tableRows.err = err
 		return tableRows

+ 93 - 14
sql/sql_tpl/value.go

@@ -2,6 +2,7 @@ package sql_tpl
 
 import (
 	"errors"
+	"git.sxidc.com/go-tools/utils/encoding"
 	"reflect"
 	"strconv"
 	"time"
@@ -11,7 +12,20 @@ const (
 	timeWriteFormat = time.DateTime + ".000000 +08:00"
 )
 
-func parseValue(value any) (string, error) {
+type AfterParsedStrValueOption func(strValue string) (string, error)
+
+func WithAESKey(aesKey string) AfterParsedStrValueOption {
+	return func(strValue string) (string, error) {
+		encrypted, err := encoding.AESEncrypt(strValue, aesKey)
+		if err != nil {
+			return "", err
+		}
+
+		return "'aes::" + encrypted + "'", nil
+	}
+}
+
+func parseValue(value any, opts ...AfterParsedStrValueOption) (string, error) {
 	valueValue := reflect.ValueOf(value)
 
 	if !valueValue.IsValid() {
@@ -26,34 +40,99 @@ func parseValue(value any) (string, error) {
 		valueValue = valueValue.Elem()
 	}
 
+	var parsedValue string
+
 	switch v := valueValue.Interface().(type) {
 	case string:
-		return "'" + v + "'", nil
+		parsedValue = v
+
+		if opts == nil || len(opts) == 0 {
+			return "'" + parsedValue + "'", nil
+		}
 	case bool:
-		return strconv.FormatBool(v), nil
+		parsedValue = strconv.FormatBool(v)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case time.Time:
-		return "'" + v.Format(timeWriteFormat) + "'", nil
+		parsedValue = v.Format(timeWriteFormat)
+
+		if opts == nil || len(opts) == 0 {
+			return "'" + parsedValue + "'", nil
+		}
 	case int:
-		return strconv.Itoa(v), nil
+		parsedValue = strconv.Itoa(v)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case int8:
-		return strconv.FormatInt(int64(v), 10), nil
+		parsedValue = strconv.FormatInt(int64(v), 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case int16:
-		return strconv.FormatInt(int64(v), 10), nil
+		parsedValue = strconv.FormatInt(int64(v), 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case int32:
-		return strconv.FormatInt(int64(v), 10), nil
+		parsedValue = strconv.FormatInt(int64(v), 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case int64:
-		return strconv.FormatInt(v, 10), nil
+		parsedValue = strconv.FormatInt(v, 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case uint:
-		return strconv.FormatUint(uint64(v), 10), nil
+		parsedValue = strconv.FormatUint(uint64(v), 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case uint8:
-		return strconv.FormatUint(uint64(v), 10), nil
+		parsedValue = strconv.FormatUint(uint64(v), 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case uint16:
-		return strconv.FormatUint(uint64(v), 10), nil
+		parsedValue = strconv.FormatUint(uint64(v), 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case uint32:
-		return strconv.FormatUint(uint64(v), 10), nil
+		parsedValue = strconv.FormatUint(uint64(v), 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	case uint64:
-		return strconv.FormatUint(v, 10), nil
+		parsedValue = strconv.FormatUint(v, 10)
+
+		if opts == nil || len(opts) == 0 {
+			return parsedValue, nil
+		}
 	default:
 		return "", errors.New("不支持的类型")
 	}
+
+	for _, opt := range opts {
+		innerParsedValue, err := opt(parsedValue)
+		if err != nil {
+			return "", err
+		}
+
+		parsedValue = innerParsedValue
+	}
+
+	return parsedValue, nil
 }

+ 26 - 17
test/sdk_test.go

@@ -16,7 +16,7 @@ import (
 
 type Class struct {
 	ID              string     `mapstructure:"id"`
-	Name            string     `sqlmapping:"updateClear;" mapstructure:"name"`
+	Name            string     `sqlmapping:"updateClear;aes:@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L;" mapstructure:"name"`
 	StudentNum      int        `sqlmapping:"column:student_num;notUpdate;" mapstructure:"student_num_alias"`
 	GraduatedTime   time.Time  `mapstructure:"graduated_time"`
 	CreatedTime     *time.Time `mapstructure:"created_time"`
@@ -358,16 +358,25 @@ func TestSqlMapping(t *testing.T) {
 			t.Fatal("列名不正确")
 		}
 
-		if sqlColumn.IsKey && sqlColumn.Name != "id" {
-			t.Fatal("键字段不正确")
+		if sqlColumn.Name == "id" {
+			if !sqlColumn.IsKey || sqlColumn.CanUpdate || sqlColumn.CanUpdateClear ||
+				strutils.IsStringNotEmpty(sqlColumn.AESKey) {
+				t.Fatal("id字段Tag不正确")
+			}
 		}
 
-		if !sqlColumn.CanUpdate && (sqlColumn.Name != "id" && sqlColumn.Name != "student_num") {
-			t.Fatal("不可更新字段不正确")
+		if sqlColumn.Name == "name" {
+			if sqlColumn.IsKey || !sqlColumn.CanUpdate || !sqlColumn.CanUpdateClear ||
+				strutils.IsStringEmpty(sqlColumn.AESKey) || sqlColumn.AESKey != "@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L" {
+				t.Fatal("name字段Tag不正确")
+			}
 		}
 
-		if sqlColumn.CanUpdateClear && sqlColumn.Name != "name" {
-			t.Fatal("可清除字段不正确")
+		if sqlColumn.Name == "student_num" {
+			if sqlColumn.IsKey || sqlColumn.CanUpdate || sqlColumn.CanUpdateClear ||
+				strutils.IsStringNotEmpty(sqlColumn.AESKey) {
+				t.Fatal("student_num字段Tag不正确")
+			}
 		}
 	}
 }
@@ -383,7 +392,7 @@ func TestSql(t *testing.T) {
 	insertExecuteParams, err := sql_tpl.InsertExecuteParams{
 		TableName: tableName,
 		TableRows: sql_tpl.NewTableRows().Add("id", classID).
-			Add("name", className).
+			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
 			Add("graduated_time", now).
 			Add("created_time", now).
@@ -459,7 +468,7 @@ func TestSql(t *testing.T) {
 	err = sql.Insert(sdk.GetInstance(), &sql_tpl.InsertExecuteParams{
 		TableName: tableName,
 		TableRows: sql_tpl.NewTableRows().Add("id", classID).
-			Add("name", className).
+			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
 			Add("graduated_time", now).
 			Add("created_time", now).
@@ -472,7 +481,7 @@ func TestSql(t *testing.T) {
 	err = sql.Update(sdk.GetInstance(), &sql_tpl.UpdateExecuteParams{
 		TableName: tableName,
 		TableRows: sql_tpl.NewTableRows().Add("id", classID).
-			Add("name", newClassName).
+			Add("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", newStudentNum),
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID),
@@ -519,7 +528,7 @@ func TestSql(t *testing.T) {
 		err = sql.Insert(tx, &sql_tpl.InsertExecuteParams{
 			TableName: tableName,
 			TableRows: sql_tpl.NewTableRows().Add("id", classID).
-				Add("name", className).
+				Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 				Add("student_num", studentNum).
 				Add("graduated_time", now).
 				Add("created_time", now).
@@ -532,7 +541,7 @@ func TestSql(t *testing.T) {
 		err = sql.Update(tx, &sql_tpl.UpdateExecuteParams{
 			TableName: tableName,
 			TableRows: sql_tpl.NewTableRows().Add("id", classID).
-				Add("name", newClassName).
+				Add("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 				Add("student_num", newStudentNum),
 			Conditions: sql_tpl.NewConditions().
 				Equal("id", classID),
@@ -576,7 +585,7 @@ func TestSql(t *testing.T) {
 		SelectColumns: []string{"id", "name"},
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID).
-			Equal("name", className).
+			Equal("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Equal("student_num", studentNum),
 		PageNo:   0,
 		PageSize: 0,
@@ -608,7 +617,7 @@ func TestSql(t *testing.T) {
 		SelectColumns: []string{"id", "name"},
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID).
-			Equal("name", className).
+			Equal("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Equal("student_num", studentNum),
 	})
 	if err != nil {
@@ -633,7 +642,7 @@ func TestSql(t *testing.T) {
 		TableName: tableName,
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID).
-			Equal("name", className).
+			Equal("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Equal("student_num", studentNum),
 	})
 	if err != nil {
@@ -648,7 +657,7 @@ func TestSql(t *testing.T) {
 		TableName: tableName,
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID).
-			Equal("name", className).
+			Equal("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Equal("student_num", studentNum),
 	})
 	if err != nil {
@@ -663,7 +672,7 @@ func TestSql(t *testing.T) {
 		TableName: tableName,
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID).
-			Equal("name", className).
+			Equal("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Equal("student_num", studentNum),
 	})
 	if err != nil {