Skip to content

Commit

Permalink
Merge pull request #716 from MatrixAI/feature-discovery-task-atomization
Browse files Browse the repository at this point in the history
Claim Discovery Task Atomization to Apply Task Deadline Individually Across Each Claim
  • Loading branch information
amydevs authored May 23, 2024
2 parents 62945f0 + 4d98422 commit cc06e16
Show file tree
Hide file tree
Showing 10 changed files with 592 additions and 167 deletions.
7 changes: 6 additions & 1 deletion src/claims/payloads/claimLinkIdentity.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import type { Claim, SignedClaim } from '../types';
import type { NodeIdEncoded, ProviderIdentityIdEncoded } from '../../ids/types';
import type {
NodeIdEncoded,
ProviderIdentityClaimId,
ProviderIdentityIdEncoded,
} from '../../ids/types';
import * as ids from '../../ids';
import * as claimsUtils from '../utils';
import * as tokensUtils from '../../tokens/utils';
Expand All @@ -13,6 +17,7 @@ interface ClaimLinkIdentity extends Claim {
typ: 'ClaimLinkIdentity';
iss: NodeIdEncoded;
sub: ProviderIdentityIdEncoded;
providerIdentityClaimId?: ProviderIdentityClaimId;
}

function assertClaimLinkIdentity(
Expand Down
191 changes: 143 additions & 48 deletions src/discovery/Discovery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import type IdentitiesManager from '../identities/IdentitiesManager';
import type {
IdentityData,
IdentityId,
IdentitySignedClaim,
ProviderId,
ProviderIdentityClaimId,
ProviderIdentityId,
ProviderPaginationToken,
} from '../identities/types';
import type KeyRing from '../keys/KeyRing';
import type { ClaimIdEncoded, SignedClaim } from '../claims/types';
Expand Down Expand Up @@ -165,10 +167,8 @@ class Discovery {
}),
);
} catch (e) {
if (
e instanceof tasksErrors.ErrorTaskStop ||
e === discoveryStoppingTaskReason
) {
// We need to reschedule if the task was cancelled due to discovery domain stopping
if (e === discoveryStoppingTaskReason) {
// We need to recreate the task for the vertex
const vertexId = gestaltsUtils.decodeGestaltId(vertex);
if (vertexId == null) never();
Expand All @@ -177,6 +177,7 @@ class Discovery {
undefined,
undefined,
gestaltsUtils.decodeGestaltId(parent ?? undefined),
true,
);
return;
}
Expand Down Expand Up @@ -276,6 +277,10 @@ class Discovery {
this.discoverVertexHandlerId,
this.discoverVertexHandler,
);
this.taskManager.registerHandler(
this.checkRediscoveryHandlerId,
this.checkRediscoveryHandler,
);
// Start up rediscovery task
await this.taskManager.scheduleTask({
handlerId: this.discoverVertexHandlerId,
Expand Down Expand Up @@ -303,6 +308,7 @@ class Discovery {
}
await Promise.all(taskPromises);
this.taskManager.deregisterHandler(this.discoverVertexHandlerId);
this.taskManager.deregisterHandler(this.checkRediscoveryHandlerId);
this.logger.info(`Stopped ${this.constructor.name}`);
}

Expand Down Expand Up @@ -450,14 +456,14 @@ class Discovery {
switch (signedClaim.payload.typ) {
case 'ClaimLinkNode':
await this.processClaimLinkNode(
signedClaim,
signedClaim as SignedClaim<ClaimLinkNode>,
nodeId,
lastProcessedCutoffTime,
);
break;
case 'ClaimLinkIdentity':
await this.processClaimLinkIdentity(
signedClaim,
signedClaim as SignedClaim<ClaimLinkIdentity>,
nodeId,
ctx,
lastProcessedCutoffTime,
Expand All @@ -474,7 +480,7 @@ class Discovery {
}

protected async processClaimLinkNode(
signedClaim: SignedClaim,
signedClaim: SignedClaim<ClaimLinkNode>,
nodeId: NodeId,
lastProcessedCutoffTime = Date.now() - this.rediscoverSkipTime,
): Promise<void> {
Expand All @@ -498,25 +504,25 @@ class Discovery {
);
return;
}
const linkedVertexNodeId = node1Id.equals(nodeId) ? node2Id : node1Id;
const linkedNodeId = node1Id.equals(nodeId) ? node2Id : node1Id;
const linkedVertexNodeInfo: GestaltNodeInfo = {
nodeId: linkedVertexNodeId,
nodeId: linkedNodeId,
};
await this.gestaltGraph.linkNodeAndNode(
{
nodeId,
},
linkedVertexNodeInfo,
{
claim: signedClaim as SignedClaim<ClaimLinkNode>,
claim: signedClaim,
meta: {},
},
);
const claimId = decodeClaimId(signedClaim.payload.jti);
if (claimId == null) never();
await this.gestaltGraph.setClaimIdNewest(nodeId, claimId);
// Add this vertex to the queue if it hasn't already been visited
const linkedGestaltId: GestaltId = ['node', linkedVertexNodeId];
const linkedGestaltId: GestaltId = ['node', linkedNodeId];
if (
!(await this.processedTimeGreaterThan(
linkedGestaltId,
Expand All @@ -533,7 +539,7 @@ class Discovery {
}

protected async processClaimLinkIdentity(
signedClaim: SignedClaim,
signedClaim: SignedClaim<ClaimLinkIdentity>,
nodeId: NodeId,
ctx: ContextTimed,
lastProcessedCutoffTime = Date.now() - this.rediscoverSkipTime,
Expand Down Expand Up @@ -565,19 +571,8 @@ class Discovery {
return;
}
// Need to get the corresponding claim for this
let providerIdentityClaimId: ProviderIdentityClaimId | null = null;
const identityClaims = await this.verifyIdentityClaims(
providerId,
identityId,
);
for (const [id, claim] of Object.entries(identityClaims)) {
const issuerNodeId = nodesUtils.decodeNodeId(claim.payload.iss);
if (issuerNodeId == null) continue;
if (nodeId.equals(issuerNodeId)) {
providerIdentityClaimId = id as ProviderIdentityClaimId;
break;
}
}
const providerIdentityClaimId = signedClaim.payload
.providerIdentityClaimId as ProviderIdentityClaimId | null;
if (providerIdentityClaimId == null) {
this.logger.warn(
`Failed to get corresponding identity claim for ${providerId}:${identityId}`,
Expand All @@ -591,7 +586,7 @@ class Discovery {
},
identityInfo,
{
claim: signedClaim as SignedClaim<ClaimLinkIdentity>,
claim: signedClaim,
meta: {
providerIdentityClaimId: providerIdentityClaimId,
url: identityInfo.url,
Expand Down Expand Up @@ -633,27 +628,41 @@ class Discovery {
identityId,
ctx,
);
let lastProviderPaginationToken = await this.gestaltGraph
.getIdentity(providerIdentityId)
.then((identity) => identity?.lastProviderPaginationToken);
// If we don't have identity info, simply skip this vertex
if (vertexIdentityInfo == null) {
return;
}
// Getting and verifying claims
const claims = await this.verifyIdentityClaims(providerId, identityId);
// Link the identity with each node from its claims on the provider
// Iterate over each of the claims
for (const [claimId, claim] of Object.entries(claims)) {
if (ctx.signal.aborted) throw ctx.signal.reason;
const {
identityClaims,
lastProviderPaginationToken: lastProviderPaginationToken_,
} = await this.verifyIdentityClaims(
providerId,
identityId,
lastProviderPaginationToken,
ctx,
);
const isAborted = ctx.signal.aborted;
lastProviderPaginationToken = lastProviderPaginationToken_;
// Iterate over each of the claims, even if ctx has aborted
for (const [claimId, claim] of Object.entries(identityClaims)) {
// Claims on an identity provider will always be node -> identity
// So just cast payload data as such
const linkedVertexNodeId = nodesUtils.decodeNodeId(claim.payload.iss);
if (linkedVertexNodeId == null) never();
const linkedNodeId = nodesUtils.decodeNodeId(claim.payload.iss);
if (linkedNodeId == null) never();
// With this verified chain, we can link
const linkedVertexNodeInfo = {
nodeId: linkedVertexNodeId,
nodeId: linkedNodeId,
};
await this.gestaltGraph.linkNodeAndIdentity(
linkedVertexNodeInfo,
vertexIdentityInfo,
{
...vertexIdentityInfo,
lastProviderPaginationToken,
},
{
claim: claim,
meta: {
Expand All @@ -662,8 +671,8 @@ class Discovery {
},
},
);
// Add this vertex to the queue if it is not present
const gestaltNodeId: GestaltId = ['node', linkedVertexNodeId];
// Check and schedule node for processing
const gestaltNodeId: GestaltId = ['node', linkedNodeId];
if (
!(await this.processedTimeGreaterThan(
gestaltNodeId,
Expand All @@ -678,6 +687,11 @@ class Discovery {
);
}
}
// Throw after we have processed the node claims if the signal aborted whilst running verifyIdentityClaims
if (isAborted) {
throw ctx.signal.reason;
}
// Only setVertexProcessedTime if we have succeeded in processing all identities
await this.gestaltGraph.setVertexProcessedTime(
['identity', providerIdentityId],
Date.now(),
Expand Down Expand Up @@ -732,6 +746,7 @@ class Discovery {
delay?: number,
lastProcessedCutoffTime?: number,
parent?: GestaltId,
ignoreActive: boolean = false,
tran?: DBTransaction,
) {
if (tran == null) {
Expand All @@ -741,6 +756,7 @@ class Discovery {
delay,
lastProcessedCutoffTime,
parent,
ignoreActive,
tran,
),
);
Expand All @@ -762,6 +778,8 @@ class Discovery {
[this.constructor.name, this.discoverVertexHandlerId, gestaltIdEncoded],
tran,
)) {
// Ignore active tasks
if (ignoreActive && task.status === 'active') continue;
if (taskExisting == null) {
taskExisting = task;
continue;
Expand Down Expand Up @@ -845,33 +863,43 @@ class Discovery {
* Helper function to retrieve and verify the claims of an identity on a given
* provider. Connects with each node the identity claims to be linked with,
* and verifies the claim with the public key of the node.
*
* This method never throws if ctx has aborted, opting instead to return early
* with a lastProviderPaginationToken so that the caller can process the partially
* requested Claims as well as resume progress when calling again.
*/
protected async verifyIdentityClaims(
providerId: ProviderId,
identityId: IdentityId,
): Promise<Record<ProviderIdentityClaimId, SignedClaim<ClaimLinkIdentity>>> {
providerPaginationToken: ProviderPaginationToken | undefined,
ctx: ContextTimed,
): Promise<{
identityClaims: Record<
ProviderIdentityClaimId,
SignedClaim<ClaimLinkIdentity>
>;
lastProviderPaginationToken?: ProviderPaginationToken;
}> {
const provider = this.identitiesManager.getProvider(providerId);
// If we don't have this provider, no identity info to find
if (provider == null) {
return {};
return { identityClaims: {} };
}
// Get our own auth identity id
const authIdentityIds = await provider.getAuthIdentityIds();
// If we don't have one then we can't request data so just skip
if (authIdentityIds.length === 0 || authIdentityIds[0] == null) {
return {};
return { identityClaims: {} };
}
const authIdentityId = authIdentityIds[0];
const identityClaims: Record<
ProviderIdentityClaimId,
SignedClaim<ClaimLinkIdentity>
> = {};
for await (const identitySignedClaim of provider.getClaims(
authIdentityId,
identityId,
)) {
identitySignedClaim.claim;
// Claims on an identity provider will always be node -> identity

const identitySignedClaimDb = (
identitySignedClaim: IdentitySignedClaim,
) => {
const claim = identitySignedClaim.claim;
const data = claim.payload;
// Verify the claim with the public key of the node
Expand All @@ -883,8 +911,75 @@ class Discovery {
if (token.verifyWithPublicKey(publicKey)) {
identityClaims[identitySignedClaim.id] = claim;
}
}
return identityClaims;
};
let nextPaginationToken: ProviderPaginationToken | undefined =
providerPaginationToken;
let processed: boolean;
do {
// Refresh before each request made with identitySignedClaimGenerator
ctx.timer.refresh();
processed = false;
if (provider.preferGetClaimsPage) {
const iterator = provider.getClaimIdsPage(
authIdentityId,
identityId,
nextPaginationToken,
);
for await (const wrapper of iterator) {
processed = true;
// This will:
// 1. throw if the getClaimIdsPage takes too much time
// 2. the rest of this loop iteration takes too much time
if (ctx.signal.aborted) {
return {
identityClaims: identityClaims,
lastProviderPaginationToken: nextPaginationToken,
};
}
const claimId = wrapper.claimId;
if (wrapper.nextPaginationToken != null) {
nextPaginationToken = wrapper.nextPaginationToken;
}
// Refresh timer in preparation for request
ctx.timer.refresh();
const identitySignedClaim = await provider.getClaim(
authIdentityId,
claimId,
);
if (identitySignedClaim == null) {
continue;
}
// Claims on an identity provider will always be node -> identity
identitySignedClaimDb(identitySignedClaim);
}
} else {
const iterator = provider.getClaimsPage(
authIdentityId,
identityId,
nextPaginationToken,
);
for await (const wrapper of iterator) {
processed = true;
// This will:
// 1. throw if the getClaimIdsPage takes too much time
// 2. the rest of this loop iteration takes too much time
if (ctx.signal.aborted) {
return {
identityClaims: identityClaims,
lastProviderPaginationToken: nextPaginationToken,
};
}
if (wrapper.nextPaginationToken != null) {
nextPaginationToken = wrapper.nextPaginationToken;
}
// Claims on an identity provider will always be node -> identity
identitySignedClaimDb(wrapper.claim);
}
}
} while (processed);
return {
identityClaims,
};
}

/**
Expand Down
Loading

0 comments on commit cc06e16

Please sign in to comment.