Skip to content

Commit

Permalink
feat: delete auth filters
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiebens committed May 28, 2022
1 parent 198b679 commit 2b5439b
Show file tree
Hide file tree
Showing 13 changed files with 488 additions and 186 deletions.
4 changes: 2 additions & 2 deletions internal/broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Broker interface {
AddClient(*Client)
RemoveClient(uint64)

SignalTailnedDeleted()
SignalUpdate()
SignalPeerUpdated(id uint64)
SignalPeersRemoved([]uint64)
SignalDNSUpdated()
Expand Down Expand Up @@ -94,7 +94,7 @@ func (h *broker) RemoveClient(id uint64) {
h.closingClients <- id
}

func (h *broker) SignalTailnedDeleted() {
func (h *broker) SignalUpdate() {
h.signalChannel <- &Signal{}
}

Expand Down
37 changes: 37 additions & 0 deletions internal/cmd/auth_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func authFilterCommand() *coral.Command {

command.AddCommand(createAuthFilterCommand())
command.AddCommand(listAuthFilterCommand())
command.AddCommand(deleteAuthFilterCommand())

return command
}
Expand Down Expand Up @@ -141,3 +142,39 @@ func createAuthFilterCommand() *coral.Command {

return command
}

func deleteAuthFilterCommand() *coral.Command {
command := &coral.Command{
Use: "delete",
SilenceUsage: true,
}

var authFilterID uint64

var target = Target{}
target.prepareCommand(command)

command.Flags().Uint64Var(&authFilterID, "auth-filter-id", 0, "")

command.RunE = func(command *coral.Command, args []string) error {
client, c, err := target.createGRPCClient()
if err != nil {
return err
}
defer safeClose(c)

req := &api.DeleteAuthFilterRequest{
AuthFilterId: authFilterID,
}

_, err = client.DeleteAuthFilter(context.Background(), req)

if err != nil {
return err
}

return nil
}

return command
}
26 changes: 26 additions & 0 deletions internal/domain/auth_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"github.com/hashicorp/go-bexpr"
"github.com/mitchellh/pointerstructure"
"gorm.io/gorm"
)

type AuthFilter struct {
Expand Down Expand Up @@ -56,6 +57,21 @@ func (fs AuthFilters) Evaluate(v interface{}) []Tailnet {
return tailnets
}

func (r *repository) GetAuthFilter(ctx context.Context, id uint64) (*AuthFilter, error) {
var t AuthFilter
tx := r.withContext(ctx).Take(&t, "id = ?", id)

if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
}

if tx.Error != nil {
return nil, tx.Error
}

return &t, nil
}

func (r *repository) SaveAuthFilter(ctx context.Context, m *AuthFilter) error {
tx := r.withContext(ctx).Save(m)

Expand Down Expand Up @@ -95,3 +111,13 @@ func (r *repository) ListAuthFiltersByAuthMethod(ctx context.Context, authMethod

return filters, nil
}

func (r *repository) DeleteAuthFilter(ctx context.Context, id uint64) error {
tx := r.withContext(ctx).Delete(&AuthFilter{ID: id})
return tx.Error
}

func (r *repository) DeleteAuthFiltersByTailnet(ctx context.Context, tailnetID uint64) error {
tx := r.withContext(ctx).Where("tailnet_id = ?", tailnetID).Delete(&AuthFilter{})
return tx.Error
}
21 changes: 21 additions & 0 deletions internal/domain/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,24 @@ func (r *repository) SetMachineLastSeen(ctx context.Context, machineID uint64) e

return nil
}

func (r *repository) ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error) {
now := time.Now().UTC()

subQuery := r.withContext(ctx).
Select("machines.id").
Table("machines").
Joins("JOIN users u on u.id = machines.user_id JOIN accounts a on a.id = u.account_id").
Where("a.auth_method_id = ?", authMethodID)

tx := r.withContext(ctx).
Table("machines").
Where("tags = '' AND (expires_at is null or expires_at > ?) AND id in (?)", &now, subQuery).
Updates(map[string]interface{}{"expires_at": &now})

if tx.Error != nil {
return 0, tx.Error
}

return tx.RowsAffected, nil
}
4 changes: 4 additions & 0 deletions internal/domain/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ type Repository interface {
ListAuthMethods(ctx context.Context) ([]AuthMethod, error)
GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error)

GetAuthFilter(ctx context.Context, id uint64) (*AuthFilter, error)
SaveAuthFilter(ctx context.Context, m *AuthFilter) error
ListAuthFilters(ctx context.Context) (AuthFilters, error)
ListAuthFiltersByAuthMethod(ctx context.Context, authMethodID uint64) (AuthFilters, error)
DeleteAuthFilter(ctx context.Context, id uint64) error
DeleteAuthFiltersByTailnet(ctx context.Context, tailnetID uint64) error

GetAccount(ctx context.Context, accountID uint64) (*Account, error)
GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error)
Expand Down Expand Up @@ -62,6 +65,7 @@ type Repository interface {
ListMachinePeers(ctx context.Context, tailnetID uint64, key string) (Machines, error)
ListInactiveEphemeralMachines(ctx context.Context, checkpoint time.Time) (Machines, error)
SetMachineLastSeen(ctx context.Context, machineID uint64) error
ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error)

Transaction(func(rp Repository) error) error
}
Expand Down
3 changes: 3 additions & 0 deletions internal/domain/tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ func (i *Tags) Scan(destination interface{}) error {
}

func (i Tags) Value() (driver.Value, error) {
if len(i) == 0 {
return "", nil
}
v := "|" + strings.Join(i, "|") + "|"
return v, nil
}
Expand Down
38 changes: 38 additions & 0 deletions internal/service/auth_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,44 @@ func (s *Service) CreateAuthFilter(ctx context.Context, req *api.CreateAuthFilte
return &response, nil
}

func (s *Service) DeleteAuthFilter(ctx context.Context, req *api.DeleteAuthFilterRequest) (*api.DeleteAuthFilterResponse, error) {

err := s.repository.Transaction(func(rp domain.Repository) error {

filter, err := rp.GetAuthFilter(ctx, req.AuthFilterId)
if err != nil {
return err
}

if filter == nil {
return status.Error(codes.NotFound, "auth filter not found")
}

c, err := rp.ExpireMachineByAuthMethod(ctx, filter.AuthMethodID)
if err != nil {
return err
}

if err := rp.DeleteAuthFilter(ctx, filter.ID); err != nil {
return err
}

if c != 0 {
s.brokers(*filter.TailnetID).SignalUpdate()
}

return nil
})

if err != nil {
return nil, err
}

response := api.DeleteAuthFilterResponse{}

return &response, nil
}

func (s *Service) mapToApi(authMethod *domain.AuthMethod, filter domain.AuthFilter) *api.AuthFilter {
result := api.AuthFilter{
Id: filter.ID,
Expand Down
6 changes: 5 additions & 1 deletion internal/service/tailnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *api.DeleteTailnetReque
return err
}

if err := tx.DeleteAuthFiltersByTailnet(ctx, req.TailnetId); err != nil {
return err
}

if err := tx.DeleteACLPolicy(ctx, req.TailnetId); err != nil {
return err
}
Expand All @@ -100,7 +104,7 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *api.DeleteTailnetReque
return nil, err
}

s.brokers(req.TailnetId).SignalTailnedDeleted()
s.brokers(req.TailnetId).SignalUpdate()

return &api.DeleteTailnetResponse{}, nil
}
Loading

0 comments on commit 2b5439b

Please sign in to comment.