diff --git a/internal/domain/machine.go b/internal/domain/machine.go index c3b8433d..44eda827 100644 --- a/internal/domain/machine.go +++ b/internal/domain/machine.go @@ -372,14 +372,14 @@ func (r *repository) SetMachineLastSeen(ctx context.Context, machineID uint64) e return nil } -func (r *repository) ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error) { +func (r *repository) ExpireMachineByAuthMethod(ctx context.Context, tailnetID, authMethodID uint64) (int64, error) { now := time.Now().UTC() subQuery := r.withContext(ctx). Select("machines.id"). Table("machines"). Joins("JOIN users u on u.id = machines.user_id JOIN accounts a on a.id = u.account_id"). - Where("a.auth_method_id = ?", authMethodID) + Where("machines.tailnet_id = ? AND a.auth_method_id = ?", tailnetID, authMethodID) tx := r.withContext(ctx). Table("machines"). diff --git a/internal/domain/repository.go b/internal/domain/repository.go index ef2ed52c..664d8288 100644 --- a/internal/domain/repository.go +++ b/internal/domain/repository.go @@ -65,7 +65,7 @@ type Repository interface { ListMachinePeers(ctx context.Context, tailnetID uint64, key string) (Machines, error) ListInactiveEphemeralMachines(ctx context.Context, checkpoint time.Time) (Machines, error) SetMachineLastSeen(ctx context.Context, machineID uint64) error - ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error) + ExpireMachineByAuthMethod(ctx context.Context, tailnetID, authMethodID uint64) (int64, error) SaveRegistrationRequest(ctx context.Context, request *RegistrationRequest) error GetRegistrationRequestByKey(ctx context.Context, key string) (*RegistrationRequest, error) diff --git a/internal/service/auth_filters.go b/internal/service/auth_filters.go index a77a5ad1..c4f2c73f 100644 --- a/internal/service/auth_filters.go +++ b/internal/service/auth_filters.go @@ -96,7 +96,7 @@ func (s *Service) DeleteAuthFilter(ctx context.Context, req *api.DeleteAuthFilte return status.Error(codes.NotFound, "auth filter not found") } - c, err := rp.ExpireMachineByAuthMethod(ctx, filter.AuthMethodID) + c, err := rp.ExpireMachineByAuthMethod(ctx, *filter.TailnetID, filter.AuthMethodID) if err != nil { return err }