From bce3d25b3b179e745cb7e7621b96ffb5bca48897 Mon Sep 17 00:00:00 2001 From: SugarMGP <2350745751@qq.com> Date: Tue, 3 Dec 2024 20:10:46 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84AES=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/config/config.go | 69 ----------------------- app/config/encrypt.go | 24 -------- app/config/init.go | 19 ------- app/midwares/checkInit.go | 17 ------ app/models/config.go | 11 ---- app/services/userService/create.go | 5 +- app/services/userService/getUser.go | 17 ++++-- app/services/userService/utils.go | 22 +++++--- app/utils/aes/aes.go | 86 +++++++++++++++++++++++++++++ app/utils/aestools.go | 74 ------------------------- config.example.yaml | 3 + config/database/migrations.go | 1 - config/router/router.go | 2 +- docs/README.md | 7 --- main.go | 4 ++ 15 files changed, 125 insertions(+), 236 deletions(-) delete mode 100644 app/config/config.go delete mode 100644 app/config/encrypt.go delete mode 100644 app/config/init.go delete mode 100644 app/midwares/checkInit.go delete mode 100644 app/models/config.go create mode 100644 app/utils/aes/aes.go delete mode 100644 app/utils/aestools.go diff --git a/app/config/config.go b/app/config/config.go deleted file mode 100644 index 8a4366a..0000000 --- a/app/config/config.go +++ /dev/null @@ -1,69 +0,0 @@ -package config - -import ( - "context" - "errors" - "time" - - "4u-go/app/models" - "4u-go/config/database" - "4u-go/config/redis" - "gorm.io/gorm" -) - -// 上下文用于 Redis 操作 -var ctx = context.Background() - -// getConfig 从 Redis 获取配置,如果不存在则从数据库中获取,并缓存到 Redis -func getConfig(key string) string { - val, err := redis.GlobalClient.Get(ctx, key).Result() - if err == nil { - return val - } - print(err) - var config = &models.Config{} - database.DB.Model(models.Config{}).Where( - &models.Config{ - Key: key, - }).First(&config) - - redis.GlobalClient.Set(ctx, key, config.Value, 0) - return config.Value -} - -// setConfig 设置指定的配置项,如果不存在则创建新的配置。 -func setConfig(key, value string) error { - redis.GlobalClient.Set(ctx, key, value, 0) - var config models.Config - result := database.DB.Where("`key` = ?", key).First(&config) - if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { - return result.Error - } - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - config = models.Config{ - Key: key, - Value: value, - UpdateTime: time.Now(), - } - result = database.DB.Create(&config) - } else { - config.Value = value - config.UpdateTime = time.Now() - result = database.DB.Updates(&config) - } - return result.Error -} - -// checkConfig 检查指定的配置项是否存在于 Redis 中。 -func checkConfig(key string) bool { - intCmd := redis.GlobalClient.Exists(ctx, key) - return intCmd.Val() == 1 -} - -func delConfig(key string) error { - redis.GlobalClient.Del(ctx, key) - res := database.DB.Where(&models.Config{ - Key: key, - }).Delete(models.Config{}) - return res.Error -} diff --git a/app/config/encrypt.go b/app/config/encrypt.go deleted file mode 100644 index ee8813f..0000000 --- a/app/config/encrypt.go +++ /dev/null @@ -1,24 +0,0 @@ -package config - -// encryptKey 是用于加密的配置键。 -const encryptKey = "encryptKey" - -// SetEncryptKey 设置加密密钥的值 -func SetEncryptKey(value string) error { - return setConfig(encryptKey, value) -} - -// GetEncryptKey 获取当前配置的加密密钥值 -func GetEncryptKey() string { - return getConfig(encryptKey) -} - -// IsSetEncryptKey 检查是否设置了加密密钥 -func IsSetEncryptKey() bool { - return checkConfig(encryptKey) -} - -// DelEncryptKey 删除加密密钥 -func DelEncryptKey() error { - return delConfig(encryptKey) -} diff --git a/app/config/init.go b/app/config/init.go deleted file mode 100644 index 196ffc6..0000000 --- a/app/config/init.go +++ /dev/null @@ -1,19 +0,0 @@ -package config - -// initKey 是用于初始化状态的配置键。 -const initKey = "initKey" - -// SetInit 设置初始化状态为 "True"。 -func SetInit() error { - return setConfig(initKey, "True") -} - -// ResetInit 设置初始化状态为 "False"。 -func ResetInit() error { - return setConfig(initKey, "False") -} - -// GetInit 获取当前的初始化状态。 -func GetInit() bool { - return getConfig(initKey) == "True" -} diff --git a/app/midwares/checkInit.go b/app/midwares/checkInit.go deleted file mode 100644 index 2004184..0000000 --- a/app/midwares/checkInit.go +++ /dev/null @@ -1,17 +0,0 @@ -package midwares - -import ( - "4u-go/app/apiException" - "4u-go/app/config" - "github.com/gin-gonic/gin" -) - -// CheckInit 中间件用于检查系统是否已初始化。 -func CheckInit(c *gin.Context) { - inited := config.GetInit() - if !inited { - apiException.AbortWithException(c, apiException.NotInit, nil) - return - } - c.Next() -} diff --git a/app/models/config.go b/app/models/config.go deleted file mode 100644 index 8cefd97..0000000 --- a/app/models/config.go +++ /dev/null @@ -1,11 +0,0 @@ -package models - -import "time" - -// Config 系统配置项的结构体 -type Config struct { - ID uint `gorm:"primaryKey"` // ID 是配置项的唯一标识 - Key string // Key 是配置项的键,必须唯一且不能为空 - Value string // Value 是配置项的值,不能为空 - UpdateTime time.Time `gorm:"comment:'设置时间';type:timestamp"` // UpdateTime 是配置项的最后更新时间 -} diff --git a/app/services/userService/create.go b/app/services/userService/create.go index 0fd0b37..68702b5 100644 --- a/app/services/userService/create.go +++ b/app/services/userService/create.go @@ -33,7 +33,10 @@ func CreateStudentUser( StudentID: studentID, } - EncryptUserKeyInfo(user) + err = EncryptUserKeyInfo(user) + if err != nil { + return nil, err + } res := database.DB.Create(&user) return user, res.Error diff --git a/app/services/userService/getUser.go b/app/services/userService/getUser.go index a521801..3a75fc7 100644 --- a/app/services/userService/getUser.go +++ b/app/services/userService/getUser.go @@ -17,7 +17,10 @@ func GetUserByWechatOpenID(openid string) (*models.User, error) { return nil, result.Error } - DecryptUserKeyInfo(&user) + err := DecryptUserKeyInfo(&user) + if err != nil { + return nil, err + } return &user, nil } @@ -29,11 +32,14 @@ func GetUserByStudentID(sid string) (*models.User, error) { StudentID: sid, }, ).First(&user) - if result.Error != nil { return nil, result.Error } - DecryptUserKeyInfo(&user) + + err := DecryptUserKeyInfo(&user) + if err != nil { + return nil, err + } return &user, nil } @@ -49,6 +55,9 @@ func GetUserByID(id uint) (*models.User, error) { return nil, result.Error } - DecryptUserKeyInfo(&user) + err := DecryptUserKeyInfo(&user) + if err != nil { + return nil, err + } return &user, nil } diff --git a/app/services/userService/utils.go b/app/services/userService/utils.go index 63e8a87..7922419 100644 --- a/app/services/userService/utils.go +++ b/app/services/userService/utils.go @@ -1,24 +1,30 @@ package userService import ( - "4u-go/app/config" "4u-go/app/models" - "4u-go/app/utils" + "4u-go/app/utils/aes" ) // DecryptUserKeyInfo 解密用户信息 -func DecryptUserKeyInfo(user *models.User) { - key := config.GetEncryptKey() +func DecryptUserKeyInfo(user *models.User) error { if user.PhoneNum != "" { - slt := utils.AesDecrypt(user.PhoneNum, key) + slt, err := aes.Decrypt(user.PhoneNum) + if err != nil { + return err + } user.PhoneNum = slt[0 : len(slt)-len(user.StudentID)] } + return nil } // EncryptUserKeyInfo 加密用户信息 -func EncryptUserKeyInfo(user *models.User) { - key := config.GetEncryptKey() +func EncryptUserKeyInfo(user *models.User) error { if user.PhoneNum != "" { - user.PhoneNum = utils.AesEncrypt(user.PhoneNum+user.StudentID, key) + num, err := aes.Encrypt(user.PhoneNum + user.StudentID) + if err != nil { + return err + } + user.PhoneNum = num } + return nil } diff --git a/app/utils/aes/aes.go b/app/utils/aes/aes.go new file mode 100644 index 0000000..c7954ea --- /dev/null +++ b/app/utils/aes/aes.go @@ -0,0 +1,86 @@ +package aes + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "errors" + + "4u-go/config/config" +) + +var encryptKey []byte + +// Init 读入 AES 密钥配置 +func Init() error { + key := config.Config.GetString("aes.encryptKey") + if len(key) != 16 && len(key) != 24 && len(key) != 32 { + return errors.New("AES 密钥长度必须为 16、24 或 32 字节") + } + encryptKey = []byte(key) + return nil +} + +// Encrypt AES 加密 +func Encrypt(orig string) (string, error) { + origData := []byte(orig) + + // 分组秘钥 + block, err := aes.NewCipher(encryptKey) + if err != nil { + return "", err + } + + // 进行 PKCS7 填充 + blockSize := block.BlockSize() + origData = PKCS7Padding(origData, blockSize) + + // 使用 CBC 加密模式 + blockMode := cipher.NewCBCEncrypter(block, encryptKey[:blockSize]) + cryted := make([]byte, len(origData)) + blockMode.CryptBlocks(cryted, origData) + + // 使用 RawURLEncoding 编码为 Base64,适合放入 URL + return base64.RawURLEncoding.EncodeToString(cryted), nil +} + +// Decrypt AES 解密 +func Decrypt(cryted string) (string, error) { + // 解码 Base64 字符串 + crytedByte, err := base64.RawURLEncoding.DecodeString(cryted) + if err != nil { + return "", err + } + + // 分组秘钥 + block, err := aes.NewCipher(encryptKey) + if err != nil { + return "", err + } + + // CBC 模式解密 + blockSize := block.BlockSize() + blockMode := cipher.NewCBCDecrypter(block, encryptKey[:blockSize]) + orig := make([]byte, len(crytedByte)) + blockMode.CryptBlocks(orig, crytedByte) + + // 去除 PKCS7 填充 + orig = PKCS7UnPadding(orig) + + return string(orig), nil +} + +// PKCS7Padding 填充数据,使长度为 blockSize 的倍数 +func PKCS7Padding(data []byte, blockSize int) []byte { + padding := blockSize - len(data)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padtext...) +} + +// PKCS7UnPadding 去除填充 +func PKCS7UnPadding(origData []byte) []byte { + length := len(origData) + unpadding := int(origData[length-1]) + return origData[:(length - unpadding)] +} diff --git a/app/utils/aestools.go b/app/utils/aestools.go deleted file mode 100644 index b79307b..0000000 --- a/app/utils/aestools.go +++ /dev/null @@ -1,74 +0,0 @@ -package utils - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "encoding/base64" - "fmt" -) - -// AesEncrypt aes加密 -func AesEncrypt(orig string, key string) string { - // 转成字节数组 - origData := []byte(orig) - k := []byte(key) - - // 分组秘钥 - block, err := aes.NewCipher(k) - if err != nil { - panic(fmt.Sprintf("key 长度必须 16/24/32长度: %s", err.Error())) - } - // 获取秘钥块的长度 - blockSize := block.BlockSize() - // 补全码 - origData = PKCS7Padding(origData, blockSize) - // 加密模式 - blockMode := cipher.NewCBCEncrypter(block, k[:blockSize]) - // 创建数组 - cryted := make([]byte, len(origData)) - // 加密 - blockMode.CryptBlocks(cryted, origData) - // 使用RawURLEncoding 不要使用StdEncoding - // 不要使用StdEncoding 放在url参数中回导致错误 - return base64.RawURLEncoding.EncodeToString(cryted) -} - -// AesDecrypt aes解密 -func AesDecrypt(cryted string, key string) string { - // 使用RawURLEncoding 不要使用StdEncoding - // 不要使用StdEncoding 放在url参数中回导致错误 - crytedByte, _ := base64.RawURLEncoding.DecodeString(cryted) //nolint:errcheck - k := []byte(key) - - // 分组秘钥 - block, err := aes.NewCipher(k) - if err != nil { - panic(fmt.Sprintf("key 长度必须 16/24/32长度: %s", err.Error())) - } - // 获取秘钥块的长度 - blockSize := block.BlockSize() - // 加密模式 - blockMode := cipher.NewCBCDecrypter(block, k[:blockSize]) - // 创建数组 - orig := make([]byte, len(crytedByte)) - // 解密 - blockMode.CryptBlocks(orig, crytedByte) - // 去补全码 - orig = PKCS7UnPadding(orig) - return string(orig) -} - -// PKCS7Padding 补码 -func PKCS7Padding(ciphertext []byte, blocksize int) []byte { - padding := blocksize - len(ciphertext)%blocksize - padtext := bytes.Repeat([]byte{byte(padding)}, padding) - return append(ciphertext, padtext...) -} - -// PKCS7UnPadding 去码 -func PKCS7UnPadding(origData []byte) []byte { - length := len(origData) - unpadding := int(origData[length-1]) - return origData[:(length - unpadding)] -} diff --git a/config.example.yaml b/config.example.yaml index 7c7c26e..3f7d8b4 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -29,6 +29,9 @@ user: admin: key: # 管理员密钥 +aes: # AES 密钥,长度必须为 16、24 或 32 字节 + encryptKey: + minio: # minio 存储配置 accessKey: # 用于身份验证的访问密钥 secretKey: # 用于身份验证的秘密密钥 diff --git a/config/database/migrations.go b/config/database/migrations.go index 3716902..6eed935 100755 --- a/config/database/migrations.go +++ b/config/database/migrations.go @@ -7,7 +7,6 @@ import ( func autoMigrate(db *gorm.DB) error { return db.AutoMigrate( - &models.Config{}, &models.User{}, &models.Announcement{}, &models.Activity{}, diff --git a/config/router/router.go b/config/router/router.go index df150ea..dc68570 100644 --- a/config/router/router.go +++ b/config/router/router.go @@ -16,7 +16,7 @@ import ( func Init(r *gin.Engine) { const pre = "/api" - api := r.Group(pre, midwares.CheckInit) + api := r.Group(pre) { user := api.Group("/user") { diff --git a/docs/README.md b/docs/README.md index 5630ad5..bbccada 100644 --- a/docs/README.md +++ b/docs/README.md @@ -68,13 +68,6 @@ cp config.example.yaml config.yaml copy config.example.yaml config.yaml ``` -在配置数据库后,向 config 表插入如下两条记录来完成初始化 - -| key | value | -|---|---| -| encryptKey | *16位的整数倍的字符串 | -| initKey | True | - 3. 启动程序 ```shell diff --git a/main.go b/main.go index 2375e53..68ff7fc 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "4u-go/app/midwares" + "4u-go/app/utils/aes" "4u-go/app/utils/log" "4u-go/config/config" "4u-go/config/database" @@ -27,6 +28,9 @@ func main() { r.NoRoute(midwares.HandleNotFound) log.ZapInit() redis.Init() + if err := aes.Init(); err != nil { + zap.L().Fatal(err.Error()) + } if err := database.Init(); err != nil { zap.L().Fatal(err.Error()) }