From a16776f9148558bb3cee07d34c075ea7c029aa89 Mon Sep 17 00:00:00 2001 From: icey-yu <1186114839@qq.com> Date: Tue, 15 Oct 2024 15:56:22 +0800 Subject: [PATCH] fix: cherry-pick --- internal/api/admin/admin.go | 1 - internal/rpc/chat/login.go | 19 +--- internal/rpc/chat/update.go | 55 +++++++++++ internal/rpc/chat/user.go | 80 ++-------------- internal/rpc/chat/utils.go | 5 - pkg/common/constant/constant.go | 7 ++ pkg/common/db/database/chat.go | 19 +++- pkg/common/db/table/chat/credential.go | 2 + pkg/email/mail_test.go | 124 ++++++++++--------------- 9 files changed, 140 insertions(+), 172 deletions(-) diff --git a/internal/api/admin/admin.go b/internal/api/admin/admin.go index 3474fb35d..6bc4b921c 100644 --- a/internal/api/admin/admin.go +++ b/internal/api/admin/admin.go @@ -13,7 +13,6 @@ import ( "github.com/openimsdk/chat/internal/api/util" "github.com/openimsdk/chat/pkg/common/apistruct" "github.com/openimsdk/chat/pkg/common/config" - chatConstant "github.com/openimsdk/chat/pkg/common/constant" "github.com/openimsdk/chat/pkg/common/imapi" "github.com/openimsdk/chat/pkg/common/mctx" "github.com/openimsdk/chat/pkg/common/xlsx" diff --git a/internal/rpc/chat/login.go b/internal/rpc/chat/login.go index def0ddc51..e672ccddf 100644 --- a/internal/rpc/chat/login.go +++ b/internal/rpc/chat/login.go @@ -227,9 +227,6 @@ func (o *chatSvr) RegisterUser(ctx context.Context, req *chat.RegisterUserReq) ( if !o.AllowRegister { return nil, errs.ErrNoPermission.WrapMsg("register user is disabled") } - if req.User.UserType != constant.CommonUser { - return nil, errs.ErrNoPermission.WrapMsg("can only register common user") - } if req.User.UserID != "" { return nil, errs.ErrNoPermission.WrapMsg("only admin can set user id") } @@ -284,9 +281,8 @@ func (o *chatSvr) RegisterUser(ctx context.Context, req *chat.RegisterUserReq) ( } } var ( - credentials []*chatdb.Credential - allowChangeRule = datautil.If(req.User.UserType == constant.CommonUser, true, false) - registerType int32 + credentials []*chatdb.Credential + registerType int32 ) if req.User.PhoneNumber != "" { @@ -295,7 +291,7 @@ func (o *chatSvr) RegisterUser(ctx context.Context, req *chat.RegisterUserReq) ( UserID: req.User.UserID, Account: BuildCredentialPhone(req.User.AreaCode, req.User.PhoneNumber), Type: constant.CredentialPhone, - AllowChange: allowChangeRule, + AllowChange: true, }) } @@ -304,7 +300,7 @@ func (o *chatSvr) RegisterUser(ctx context.Context, req *chat.RegisterUserReq) ( UserID: req.User.UserID, Account: req.User.Account, Type: constant.CredentialAccount, - AllowChange: allowChangeRule, + AllowChange: true, }) registerType = constant.AccountRegister } @@ -315,7 +311,7 @@ func (o *chatSvr) RegisterUser(ctx context.Context, req *chat.RegisterUserReq) ( UserID: req.User.UserID, Account: req.User.Email, Type: constant.CredentialEmail, - AllowChange: allowChangeRule, + AllowChange: true, }) } register := &chatdb.Register{ @@ -352,11 +348,6 @@ func (o *chatSvr) RegisterUser(ctx context.Context, req *chat.RegisterUserReq) ( AllowAddFriend: constant.DefaultAllowAddFriend, RegisterType: registerType, } - if req.User.UserType == constant.OrgUser { - attribute.EnglishName = datautil.ToPtr(req.User.EnglishName.GetValue()) - attribute.Station = datautil.ToPtr(req.User.Station.GetValue()) - attribute.Telephone = datautil.ToPtr(req.User.Telephone.GetValue()) - } if err := o.Database.RegisterUser(ctx, register, account, attribute, credentials); err != nil { return nil, err } diff --git a/internal/rpc/chat/update.go b/internal/rpc/chat/update.go index 4579f1c3e..8f69de2f2 100644 --- a/internal/rpc/chat/update.go +++ b/internal/rpc/chat/update.go @@ -15,6 +15,8 @@ package chat import ( + "github.com/openimsdk/chat/pkg/common/constant" + chatdb "github.com/openimsdk/chat/pkg/common/db/table/chat" "time" "github.com/openimsdk/tools/errs" @@ -68,3 +70,56 @@ func ToDBAttributeUpdate(req *chat.UpdateUserInfoReq) (map[string]any, error) { //} return update, nil } + +func ToDBCredentialUpdate(req *chat.UpdateUserInfoReq, allowChange bool) ([]*chatdb.Credential, []*chatdb.Credential, error) { + update := make([]*chatdb.Credential, 0) + del := make([]*chatdb.Credential, 0) + if req.Account != nil { + if req.Account.GetValue() == "" { + del = append(del, &chatdb.Credential{ + UserID: req.UserID, + Type: constant.CredentialAccount, + }) + } else { + update = append(update, &chatdb.Credential{ + UserID: req.UserID, + Account: req.Account.GetValue(), + Type: constant.CredentialAccount, + AllowChange: allowChange, + }) + } + } + + if req.Email != nil { + if req.Email.GetValue() == "" { + del = append(del, &chatdb.Credential{ + UserID: req.UserID, + Type: constant.CredentialEmail, + }) + } else { + update = append(update, &chatdb.Credential{ + UserID: req.UserID, + Account: req.Account.GetValue(), + Type: constant.CredentialEmail, + AllowChange: allowChange, + }) + } + } + if req.PhoneNumber != nil { + if req.PhoneNumber.GetValue() == "" { + del = append(del, &chatdb.Credential{ + UserID: req.UserID, + Type: constant.CredentialPhone, + }) + } else { + update = append(update, &chatdb.Credential{ + UserID: req.UserID, + Account: BuildCredentialPhone(req.AreaCode.GetValue(), req.PhoneNumber.GetValue()), + Type: constant.CredentialPhone, + AllowChange: allowChange, + }) + } + } + + return update, del, nil +} diff --git a/internal/rpc/chat/user.go b/internal/rpc/chat/user.go index 3ea07a520..057cc6d78 100644 --- a/internal/rpc/chat/user.go +++ b/internal/rpc/chat/user.go @@ -19,7 +19,6 @@ import ( "errors" "github.com/openimsdk/chat/pkg/eerrs" "github.com/openimsdk/protocol/wrapperspb" - "github.com/openimsdk/tools/utils/datautil" "github.com/openimsdk/tools/utils/stringutil" "strconv" "strings" @@ -162,52 +161,8 @@ func (o *chatSvr) UpdateUserInfo(ctx context.Context, req *chat.UpdateUserInfoRe return nil, err } - isOrgUser, err := o.Database.IsOrgUser(ctx, req.UserID) - if err != nil { - return nil, err - } - switch userType { case constant.NormalUser: - if isOrgUser { - if req.AreaCode != nil { - return nil, errs.ErrNoPermission.WrapMsg("areaCode can not be updated") - } - if req.PhoneNumber != nil { - return nil, errs.ErrNoPermission.WrapMsg("phoneNumber can not be updated") - } - if req.Account != nil { - return nil, errs.ErrNoPermission.WrapMsg("account can not be updated") - } - if req.Email != nil { - return nil, errs.ErrNoPermission.WrapMsg("email can not be updated") - } - if req.Level != nil { - return nil, errs.ErrNoPermission.WrapMsg("level can not be updated") - } - - if req.Nickname != nil { - return nil, errs.ErrNoPermission.WrapMsg("nickname can not be updated") - } - if req.FaceURL != nil { - return nil, errs.ErrNoPermission.WrapMsg("faceURL can not be updated") - } - if req.Gender != nil { - return nil, errs.ErrNoPermission.WrapMsg("gender can not be updated") - } - if req.Birth != nil { - return nil, errs.ErrNoPermission.WrapMsg("birth can not be updated") - } - if req.EnglishName != nil { - return nil, errs.ErrNoPermission.WrapMsg("englishName can not be updated") - } - if req.Station != nil { - return nil, errs.ErrNoPermission.WrapMsg("station can not be updated") - } - if req.Telephone != nil { - return nil, errs.ErrNoPermission.WrapMsg("telephone can not be updated") - } - } if req.RegisterType != nil { return nil, errs.ErrNoPermission.WrapMsg("registerType can not be updated") } @@ -220,11 +175,11 @@ func (o *chatSvr) UpdateUserInfo(ctx context.Context, req *chat.UpdateUserInfoRe return nil, errs.ErrNoPermission.WrapMsg("user type error") } - update, err := ToDBAttributeUpdate(req, isOrgUser) + update, err := ToDBAttributeUpdate(req) if err != nil { return nil, err } - credUpdate, credDel, err := ToDBCredentialUpdate(req, !isOrgUser) + credUpdate, credDel, err := ToDBCredentialUpdate(req, true) if err != nil { return nil, err } @@ -284,8 +239,7 @@ func (o *chatSvr) AddUserAccount(ctx context.Context, req *chat.AddUserAccountRe } var ( - credentials []*chatdb.Credential - allowChangeRule = datautil.If(req.User.UserType == constant.CommonUser, true, false) + credentials []*chatdb.Credential ) if req.User.PhoneNumber != "" { @@ -293,7 +247,7 @@ func (o *chatSvr) AddUserAccount(ctx context.Context, req *chat.AddUserAccountRe UserID: req.User.UserID, Account: BuildCredentialPhone(req.User.AreaCode, req.User.PhoneNumber), Type: constant.CredentialPhone, - AllowChange: allowChangeRule, + AllowChange: true, }) } @@ -302,7 +256,7 @@ func (o *chatSvr) AddUserAccount(ctx context.Context, req *chat.AddUserAccountRe UserID: req.User.UserID, Account: req.User.Account, Type: constant.CredentialAccount, - AllowChange: allowChangeRule, + AllowChange: true, }) } @@ -311,7 +265,7 @@ func (o *chatSvr) AddUserAccount(ctx context.Context, req *chat.AddUserAccountRe UserID: req.User.UserID, Account: req.User.Email, Type: constant.CredentialEmail, - AllowChange: allowChangeRule, + AllowChange: true, }) } @@ -348,11 +302,6 @@ func (o *chatSvr) AddUserAccount(ctx context.Context, req *chat.AddUserAccountRe AllowAddFriend: constant.DefaultAllowAddFriend, } - if req.User.UserType == constant.OrgUser { - attribute.EnglishName = datautil.ToPtr(req.User.EnglishName.GetValue()) - attribute.Station = datautil.ToPtr(req.User.Station.GetValue()) - attribute.Telephone = datautil.ToPtr(req.User.Telephone.GetValue()) - } if err := o.Database.RegisterUser(ctx, register, account, attribute, credentials); err != nil { return nil, err } @@ -373,23 +322,6 @@ func (o *chatSvr) SearchUserPublicInfo(ctx context.Context, req *chat.SearchUser }, nil } -func (o *chatSvr) SearchUserID(ctx context.Context, req *chat.SearchUserIDReq) (*chat.SearchUserIDResp, error) { - if req.Pagination == nil { - return nil, errs.ErrArgs.WrapMsg("pagination is nil") - } - if _, _, err := mctx.Check(ctx); err != nil { - return nil, err - } - total, userIDs, err := o.Database.SearchID(ctx, req.Keyword, req.OrUserIDs, req.Pagination) - if err != nil { - return nil, err - } - return &chat.SearchUserIDResp{ - Total: uint32(total), - UserIDs: userIDs, - }, nil -} - func (o *chatSvr) FindUserFullInfo(ctx context.Context, req *chat.FindUserFullInfoReq) (*chat.FindUserFullInfoResp, error) { if _, _, err := mctx.Check(ctx); err != nil { return nil, err diff --git a/internal/rpc/chat/utils.go b/internal/rpc/chat/utils.go index 403ce4536..b8c1f1740 100644 --- a/internal/rpc/chat/utils.go +++ b/internal/rpc/chat/utils.go @@ -7,7 +7,6 @@ import ( "github.com/openimsdk/chat/pkg/eerrs" "github.com/openimsdk/chat/pkg/protocol/chat" "github.com/openimsdk/chat/pkg/protocol/common" - "github.com/openimsdk/protocol/wrapperspb" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/utils/datautil" "github.com/openimsdk/tools/utils/stringutil" @@ -52,10 +51,6 @@ func DbToPbUserFullInfo(attribute *table.Attribute) *common.UserFullInfo { AllowVibration: attribute.AllowVibration, GlobalRecvMsgOpt: attribute.GlobalRecvMsgOpt, RegisterType: attribute.RegisterType, - - EnglishName: wrapperspb.StringPtr(attribute.EnglishName), - Station: wrapperspb.StringPtr(attribute.Station), - Telephone: wrapperspb.StringPtr(attribute.Telephone), } } diff --git a/pkg/common/constant/constant.go b/pkg/common/constant/constant.go index f82b6dd16..0c0acec5d 100644 --- a/pkg/common/constant/constant.go +++ b/pkg/common/constant/constant.go @@ -104,3 +104,10 @@ const ( GenderMale = 1 // male GenderUnknown = 2 // unknown ) + +// Credential Type +const ( + CredentialAccount = iota + CredentialPhone + CredentialEmail +) diff --git a/pkg/common/db/database/chat.go b/pkg/common/db/database/chat.go index 03431b32d..55acd558a 100644 --- a/pkg/common/db/database/chat.go +++ b/pkg/common/db/database/chat.go @@ -31,7 +31,7 @@ import ( type ChatDatabaseInterface interface { GetUser(ctx context.Context, userID string) (account *chatdb.Account, err error) - UpdateUseInfo(ctx context.Context, userID string, attribute map[string]any) (err error) + UpdateUseInfo(ctx context.Context, userID string, attribute map[string]any, updateCred, delCred []*chatdb.Credential) (err error) FindAttribute(ctx context.Context, userIDs []string) ([]*chatdb.Attribute, error) FindAttributeByAccount(ctx context.Context, accounts []string) ([]*chatdb.Attribute, error) TakeAttributeByPhone(ctx context.Context, areaCode string, phoneNumber string) (*chatdb.Attribute, error) @@ -114,8 +114,21 @@ func (o *ChatDatabase) GetUser(ctx context.Context, userID string) (account *cha return o.account.Take(ctx, userID) } -func (o *ChatDatabase) UpdateUseInfo(ctx context.Context, userID string, attribute map[string]any) (err error) { - return o.attribute.Update(ctx, userID, attribute) +func (o *ChatDatabase) UpdateUseInfo(ctx context.Context, userID string, attribute map[string]any, updateCred, delCred []*chatdb.Credential) (err error) { + return o.tx.Transaction(ctx, func(ctx context.Context) error { + if err = o.attribute.Update(ctx, userID, attribute); err != nil { + return err + } + for _, credential := range updateCred { + if err = o.credential.CreateOrUpdateAccount(ctx, credential); err != nil { + return err + } + } + if err = o.credential.DeleteByUserIDType(ctx, delCred...); err != nil { + return err + } + return nil + }) } func (o *ChatDatabase) FindAttribute(ctx context.Context, userIDs []string) ([]*chatdb.Attribute, error) { diff --git a/pkg/common/db/table/chat/credential.go b/pkg/common/db/table/chat/credential.go index 08378aa24..e4d23d6c8 100644 --- a/pkg/common/db/table/chat/credential.go +++ b/pkg/common/db/table/chat/credential.go @@ -18,6 +18,7 @@ func (Credential) TableName() string { type CredentialInterface interface { Create(ctx context.Context, credential ...*Credential) error + CreateOrUpdateAccount(ctx context.Context, credential *Credential) error Update(ctx context.Context, userID string, data map[string]any) error Find(ctx context.Context, userID string) ([]*Credential, error) FindAccount(ctx context.Context, accounts []string) ([]*Credential, error) @@ -27,4 +28,5 @@ type CredentialInterface interface { SearchNormalUser(ctx context.Context, keyword string, forbiddenID []string, pagination pagination.Pagination) (int64, []*Credential, error) SearchUser(ctx context.Context, keyword string, userIDs []string, pagination pagination.Pagination) (int64, []*Credential, error) Delete(ctx context.Context, userIDs []string) error + DeleteByUserIDType(ctx context.Context, credentials ...*Credential) error } diff --git a/pkg/email/mail_test.go b/pkg/email/mail_test.go index c6824510a..0f9ebe348 100644 --- a/pkg/email/mail_test.go +++ b/pkg/email/mail_test.go @@ -1,77 +1,51 @@ -// Copyright © 2023 OpenIM open source community. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package email -import ( - "context" - "errors" - "fmt" - "io/ioutil" - "os" - "testing" - - "github.com/openimsdk/chat/pkg/common/config" - "gopkg.in/yaml.v3" -) - -func TestEmail(T *testing.T) { - if err := InitConfig(); err != nil { - fmt.Fprintf(os.Stderr, "\n\nexit -1: \n%+v\n\n", err) - os.Exit(-1) - } - tests := []struct { - name string - ctx context.Context - mail string - code string - want error - }{ - { - name: "success send email", - ctx: context.Background(), - mail: "test@gmail.com", - code: "5555", - want: errors.New("nil"), - }, - { - name: "fail send email", - ctx: context.Background(), - mail: "", - code: "5555", - want: errors.New("dial tcp :0: connectex: The requested address is not valid in its context."), - }, - } - mail := NewMail() - - for _, tt := range tests { - T.Run(tt.name, func(t *testing.T) { - if got := mail.SendMail(tt.ctx, tt.mail, tt.code); errors.Is(got, tt.want) { - t.Errorf("%v have a err,%v", tt.name, tt.want) - } - }) - } -} - -func InitConfig() error { - yam, err := ioutil.ReadFile("../../config/config.yaml") - if err != nil { - return err - } - err = yaml.Unmarshal(yam, &config.Config) - if err != nil { - return err - } - return nil -} +//func TestEmail(T *testing.T) { +// if err := InitConfig(); err != nil { +// fmt.Fprintf(os.Stderr, "\n\nexit -1: \n%+v\n\n", err) +// os.Exit(-1) +// } +// tests := []struct { +// name string +// ctx context.Context +// mail string +// code string +// want error +// }{ +// { +// name: "success send email", +// ctx: context.Background(), +// mail: "test@gmail.com", +// code: "5555", +// want: errors.New("nil"), +// }, +// { +// name: "fail send email", +// ctx: context.Background(), +// mail: "", +// code: "5555", +// want: errors.New("dial tcp :0: connectex: The requested address is not valid in its context."), +// }, +// } +// mail := NewMail() +// +// for _, tt := range tests { +// T.Run(tt.name, func(t *testing.T) { +// if got := mail.SendMail(tt.ctx, tt.mail, tt.code); errors.Is(got, tt.want) { +// t.Errorf("%v have a err,%v", tt.name, tt.want) +// } +// }) +// } +//} +// +//func InitConfig() error { +// yam, err := ioutil.ReadFile("../../config/config.yaml") +// if err != nil { +// return err +// } +// err = yaml.Unmarshal(yam, &config.Config) +// if err != nil { +// return err +// } +// return nil +//}