Skip to content

Commit

Permalink
feat(api): permission
Browse files Browse the repository at this point in the history
  • Loading branch information
ttktatakai committed Sep 20, 2024
1 parent 1ea2304 commit 03eec7a
Show file tree
Hide file tree
Showing 15 changed files with 500 additions and 240 deletions.
9 changes: 4 additions & 5 deletions backend/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,14 @@ func RunApi() error {
share.POST("", authAdmin(), c.CreateShare)
share.DELETE("/:id", authAdmin(), c.DeleteShare)
share.GET("", authAdmin(), c.GetShare)
share.GET("/connect/:uuid", c.ConnectShare)
}
// r.GET("/api/oneterm/v1/share/connect/:uuid", Error2Resp(), c.ConnectShare)
r.GET("/api/oneterm/v1/share/connect/:uuid", Error2Resp(), c.ConnectShare)

authorization := v1.Group("/authorization")
authorization := v1.Group("/authorization", authAdmin())
{
authorization.POST("", c.CreateAuthorization)
authorization.POST("", c.UpsertAuthorization)
authorization.DELETE("/:id", c.DeleteAccount)
authorization.PUT("/:id", c.UpdateAuthorization)
authorization.GET("", c.GetAuthorizations)
}
}

Expand Down
43 changes: 25 additions & 18 deletions backend/api/controller/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package controller

import (
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/samber/lo"
Expand All @@ -12,12 +14,17 @@ import (
"gorm.io/gorm"

"github.com/veops/oneterm/acl"
redis "github.com/veops/oneterm/cache"
"github.com/veops/oneterm/conf"
mysql "github.com/veops/oneterm/db"
"github.com/veops/oneterm/model"
"github.com/veops/oneterm/util"
)

const (
kFmtAccountIds = "accountIds-%d"
)

var (
accountPreHooks = []preHook[*model.Account]{
func(ctx *gin.Context, data *model.Account) {
Expand Down Expand Up @@ -145,7 +152,7 @@ func (c *Controller) GetAccounts(ctx *gin.Context) {
}

if info && !acl.IsAdmin(currentUser) {
ids, err := getAccountIdsByAuthorization(ctx)
ids, err := GetAccountIdsByAuthorization(ctx)
if err != nil {
return
}
Expand All @@ -157,29 +164,29 @@ func (c *Controller) GetAccounts(ctx *gin.Context) {
doGet[*model.Account](ctx, !info, db, acl.GetResourceTypeName(conf.RESOURCE_ACCOUNT), accountPostHooks...)
}

func getAccountIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
assetIds, err := getAssertIdsByAuthorization(ctx)
func GetAccountIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx)

k := fmt.Sprintf(kFmtAccountIds, currentUser.GetUid())
if err = redis.Get(ctx, k, &ids); err == nil {
return
}

assetIds, err := GetAssetIdsByAuthorization(ctx)
if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return
}
assets := make([]*model.Asset, 0)
if err = mysql.DB.Model(&model.Asset{}).Where("id IN ?", assetIds).Find(&assets).Error; err != nil {
ss := make([][]int, 0)
if err = mysql.DB.Model(&model.Asset{}).Where("id IN ?", assetIds).Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return
}
authorizationIds, _ := ctx.Value("authorizationIds").([]*model.AuthorizationIds)
parentNodeIds, _ := ctx.Value("parentNodeIds").([]int)
for _, a := range assets {
if lo.Contains(parentNodeIds, a.Id) {
ids = append(ids, lo.Keys(a.Authorization)...)
}
accountIds := lo.Uniq(
lo.Map(lo.Filter(authorizationIds, func(item *model.AuthorizationIds, _ int) bool {
return item.AssetId != nil && *item.AssetId == a.Id && item.AccountId != nil
}),
func(item *model.AuthorizationIds, _ int) int { return *item.AccountId }))
ids = append(ids, accountIds...)
}
ids = lo.Uniq(lo.Flatten(ss))
_, _, accountIds := getIdsByAuthorizationIds(ctx)
ids = lo.Uniq(append(ids, accountIds...))

redis.SetEx(ctx, k, ids, time.Minute)

return
}
99 changes: 74 additions & 25 deletions backend/api/controller/asset.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
package controller

import (
"context"
"fmt"
"net/http"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/spf13/cast"
"go.uber.org/zap"

"github.com/veops/oneterm/acl"
redis "github.com/veops/oneterm/cache"
"github.com/veops/oneterm/conf"
mysql "github.com/veops/oneterm/db"
"github.com/veops/oneterm/logger"
"github.com/veops/oneterm/model"
"github.com/veops/oneterm/schedule"
)

const (
kFmtAssetIds = "assetIds-%d"
kAuthorizationIds = "authorizationIds"
kParentNodeIds = "parentNodeIds"
kAccountIds = "accountIds"
)

var (
assetPreHooks = []preHook[*model.Asset]{
func(ctx *gin.Context, data *model.Asset) {
Expand Down Expand Up @@ -92,7 +101,7 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","), func(s string, _ int) int { return cast.ToInt(s) }))
}
if q, ok := ctx.GetQuery("parent_id"); ok {
parentIds, err := handleParentId(cast.ToInt(q))
parentIds, err := handleParentId(ctx, cast.ToInt(q))
if err != nil {
logger.L().Error("parent id found failed", zap.Error(err))
return
Expand All @@ -101,7 +110,7 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
}

if info && !acl.IsAdmin(currentUser) {
ids, err := getAssertIdsByAuthorization(ctx)
ids, err := GetAssetIdsByAuthorization(ctx)
if err != nil {
return
}
Expand Down Expand Up @@ -147,28 +156,32 @@ func assetPostHookAuth(ctx *gin.Context, data []*model.Asset) {
return
}
authorizationIds, _ := ctx.Value("authorizationIds").([]*model.AuthorizationIds)
parentNodeIds, _ := ctx.Value("parentNodeIds").([]int)
parentNodeIds, _, accountIds := getIdsByAuthorizationIds(ctx)
for _, a := range data {
if lo.Contains(parentNodeIds, a.Id) {
continue
}
accountIds := lo.Uniq(
ids := lo.Uniq(
lo.Map(lo.Filter(authorizationIds, func(item *model.AuthorizationIds, _ int) bool {
return item.AssetId != nil && *item.AssetId == a.Id && item.AccountId != nil
}),
func(item *model.AuthorizationIds, _ int) int { return *item.AccountId }))

for k := range a.Authorization {
if !lo.Contains(accountIds, k) {
if !lo.Contains(ids, k) && !lo.Contains(accountIds, k) {
delete(a.Authorization, k)
}
}
}
}

func handleParentId(parentId int) (pids []int, err error) {
func handleParentId(ctx context.Context, parentId int) (pids []int, err error) {
nodes := make([]*model.NodeIdPid, 0)
if err = mysql.DB.Model(&model.Node{}).Find(&nodes).Error; err != nil {
return
if err = redis.Get(ctx, kFmtAllNodes, &nodes); err != nil {
if err = mysql.DB.Model(&model.Node{}).Find(&nodes).Error; err != nil {
return
}
redis.SetEx(ctx, kFmtAllNodes, nodes, time.Hour)
}
g := make(map[int][]int)
for _, n := range nodes {
Expand All @@ -186,33 +199,69 @@ func handleParentId(parentId int) (pids []int, err error) {
return
}

func getAssertIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
authorizationResourceIds, err := getAutorizationResourceIds(ctx)
func GetAssetIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx)

authIds, err := getAuthorizationIds(ctx)
if err != nil {
return
}
ctx.Set(kAuthorizationIds, authIds)

k := fmt.Sprintf(kFmtAssetIds, currentUser.GetUid())
if err = redis.Get(ctx, k, &ids); err == nil {
return
}

parentNodeIds, ids, accountIds := getIdsByAuthorizationIds(ctx)

tmp, err := handleSelfChild(ctx, parentNodeIds)
if err != nil {
handleRemoteErr(ctx, err)
return
}
authIds := make([]*model.AuthorizationIds, 0)
if err = mysql.DB.Model(authIds).Where("resource_id IN ?", authorizationResourceIds).Find(&ids).Error; err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
parentNodeIds = append(parentNodeIds, tmp...)
ctx.Set(kParentNodeIds, parentNodeIds)
ctx.Set(kAccountIds, accountIds)
tmp, err = getAssetIdsByNodeAccount(ctx, parentNodeIds, accountIds)
if err != nil {
return
}
ctx.Set("authorizationIds", authIds)
ids = lo.Uniq(append(ids, tmp...))

redis.SetEx(ctx, k, ids, time.Minute)

return
}

func getIdsByAuthorizationIds(ctx context.Context) (parentNodeIds, assetIds, accountIds []int) {
authIds, _ := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)

parentNodeIds := make([]int, 0)
for _, a := range authIds {
if a.NodeId != nil {
if a.NodeId != nil && a.AssetId == nil && a.AccountId == nil {
parentNodeIds = append(parentNodeIds, *a.NodeId)
} else if a.AssetId != nil {
ids = append(ids, *a.AssetId)
}
if a.AssetId != nil && a.NodeId == nil && a.AccountId == nil {
assetIds = append(assetIds, *a.AssetId)
}
if a.AccountId != nil && a.AssetId == nil && a.NodeId == nil {
accountIds = append(accountIds, *a.AccountId)
}
}
ctx.Set("parentNodeIds", parentNodeIds)
tmp := make([]int, 0)
if err = mysql.DB.Model(&model.Asset{}).Where("parent_id IN?", parentNodeIds).Pluck("id", &tmp).Error; err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return
}

func getAuthorizationIds(ctx *gin.Context) (authIds []*model.AuthorizationIds, err error) {
resourceIds, err := getAutorizationResourceIds(ctx)
if err != nil {
handleRemoteErr(ctx, err)
return
}
ids = append(ids, tmp...)

err = mysql.DB.Model(authIds).Where("resource_id IN ?", resourceIds).Find(&authIds).Error
return
}

func getAssetIdsByNodeAccount(ctx context.Context, parentNodeIds, accountIds []int) (assetIds []int, err error) {
err = mysql.DB.Model(&model.Asset{}).Where("parent_id IN?", parentNodeIds).Or("JSON_KEYS(authorization) IN ?", accountIds).Pluck("id", &assetIds).Error
return
}
Loading

0 comments on commit 03eec7a

Please sign in to comment.