package encoding

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/md5"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
)

// AESEncrypt encrypts originData with AES in CBC mode and returns the
// ciphertext encoded with standard base64.
//
// key is used as raw bytes and must be a valid AES key length (16, 24 or 32
// bytes); otherwise the error from aes.NewCipher is returned. The plaintext
// is PKCS#7-padded to the cipher block size before encryption.
//
// NOTE(review): the IV is the first block-size bytes of the key, so the same
// plaintext and key always yield the same ciphertext. It is kept this way
// because AESDecrypt derives its IV the same way.
func AESEncrypt(originData string, key string) (string, error) {
	rawKey := []byte(key)

	blk, err := aes.NewCipher(rawKey)
	if err != nil {
		return "", err
	}

	size := blk.BlockSize()
	plain := pkcs7Padding([]byte(originData), size)

	// CBC output has exactly the same length as the padded input.
	out := make([]byte, len(plain))
	cipher.NewCBCEncrypter(blk, rawKey[:size]).CryptBlocks(out, plain)

	return base64.StdEncoding.EncodeToString(out), nil
}

// AESDecrypt reverses AESEncrypt: it base64-decodes encrypted, decrypts it
// with AES in CBC mode and strips the PKCS#7 padding.
//
// key is used as raw bytes and must be a valid AES key length (16, 24 or 32
// bytes). The IV is the first block-size bytes of the key.
//
// An error is returned when the key is invalid, when encrypted is not valid
// standard base64, when the decoded ciphertext is empty or not a whole number
// of blocks, or when the decrypted data does not end in valid PKCS#7 padding
// (which is the usual symptom of a wrong key or corrupted input).
func AESDecrypt(encrypted string, key string) (string, error) {
	keyBytes := []byte(key)

	block, err := aes.NewCipher(keyBytes)
	if err != nil {
		return "", err
	}

	decoded, err := base64.StdEncoding.DecodeString(encrypted)
	if err != nil {
		return "", err
	}

	blockSize := block.BlockSize()
	// CryptBlocks panics on partial blocks, and an empty buffer has no
	// padding byte to read, so reject both before decrypting.
	if len(decoded) == 0 || len(decoded)%blockSize != 0 {
		return "", fmt.Errorf("aes decrypt: ciphertext length %d is not a positive multiple of block size %d", len(decoded), blockSize)
	}

	blockMode := cipher.NewCBCDecrypter(block, keyBytes[:blockSize])
	originData := make([]byte, len(decoded))
	blockMode.CryptBlocks(originData, decoded)

	// Validate the padding instead of trusting the last byte blindly: with a
	// wrong key it is effectively random and may exceed the data length.
	padLen := int(originData[len(originData)-1])
	if padLen == 0 || padLen > blockSize {
		return "", fmt.Errorf("aes decrypt: invalid PKCS#7 padding")
	}
	dataLen := len(originData) - padLen
	for _, b := range originData[dataLen:] {
		if int(b) != padLen {
			return "", fmt.Errorf("aes decrypt: invalid PKCS#7 padding")
		}
	}

	return string(originData[:dataLen]), nil
}

// MD5 returns the lowercase hexadecimal MD5 digest of origin with salt
// appended to it. An empty salt hashes origin alone.
func MD5(origin string, salt string) string {
	// md5.Sum hashes its argument. The previous md5.New().Sum(b) form only
	// appended the digest of an empty message to b, so the input itself was
	// never hashed.
	digest := md5.Sum([]byte(origin + salt))
	return hex.EncodeToString(digest[:])
}

// SHA256 returns the lowercase hexadecimal SHA-256 digest of origin with salt
// appended to it. An empty salt hashes origin alone.
func SHA256(origin string, salt string) string {
	// Appending an empty salt is a no-op, so no special case is needed.
	digest := sha256.Sum256([]byte(origin + salt))
	return fmt.Sprintf("%x", digest[:])
}

// SHA1 returns the lowercase hexadecimal SHA-1 digest of origin with salt
// appended to it. An empty salt hashes origin alone.
func SHA1(origin string, salt string) string {
	// Appending an empty salt is a no-op, so no special case is needed.
	digest := sha1.Sum([]byte(origin + salt))
	return hex.EncodeToString(digest[:])
}

// Base64Encode returns content encoded with the standard, padded base64
// alphabet.
func Base64Encode(content []byte) string {
	enc := base64.StdEncoding
	buf := make([]byte, enc.EncodedLen(len(content)))
	enc.Encode(buf, content)
	return string(buf)
}

// Base64Decode decodes content from the standard, padded base64 alphabet.
// On malformed input it returns the decoder's error together with the bytes
// that were decoded before the failure.
func Base64Decode(content string) ([]byte, error) {
	enc := base64.StdEncoding
	buf := make([]byte, enc.DecodedLen(len(content)))
	n, err := enc.Decode(buf, []byte(content))
	return buf[:n], err
}

// pkcs7Padding returns a copy of ciphertext extended with PKCS#7 padding so
// that its length is a multiple of blockSize. Between 1 and blockSize bytes
// are added, each holding the number of bytes added; input that is already
// block-aligned gains one full block.
//
// The result is built in a fresh buffer so the caller's backing array is
// never written to, even when ciphertext is a sub-slice with spare capacity.
// blockSize must be in the range 1..255 for the padding byte to be
// meaningful.
func pkcs7Padding(ciphertext []byte, blockSize int) []byte {
	padding := blockSize - len(ciphertext)%blockSize

	padded := make([]byte, 0, len(ciphertext)+padding)
	padded = append(padded, ciphertext...)
	return append(padded, bytes.Repeat([]byte{byte(padding)}, padding)...)
}

// pkcs7UnPadding strips PKCS#7 padding from originData by reading the last
// byte as the padding length and dropping that many trailing bytes. The
// returned slice shares originData's backing array.
//
// The signature has no error result, so input that cannot carry the claimed
// padding (empty input, or a final byte larger than the data length) is
// returned unchanged rather than causing a panic. The padding bytes other
// than the last one are not inspected.
func pkcs7UnPadding(originData []byte) []byte {
	length := len(originData)
	if length == 0 {
		return originData
	}

	unPadding := int(originData[length-1])
	if unPadding > length {
		// Slicing would go negative; nothing sensible can be removed.
		return originData
	}
	return originData[:length-unPadding]
}