Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: derived atom is not recomputed after its dependencies changed #2906 #2907

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
90 changes: 51 additions & 39 deletions src/vanilla/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ type BatchPriority = 0 | 1 | 2

type Batch = [
/** finish recompute */
priority0: Set<() => void>,
priority0: Set<(batch: Batch) => void>,
/** atom listeners */
priority1: Set<() => void>,
priority1: Set<(batch: Batch) => void>,
/** atom mount hooks */
priority2: Set<() => void>,
priority2: Set<(batch: Batch) => void>,
] & {
/** changed Atoms */
C: Set<AnyAtom>
Expand All @@ -189,7 +189,7 @@ const createBatch = (): Batch =>
const addBatchFunc = (
batch: Batch,
priority: BatchPriority,
fn: () => void,
fn: (batch: Batch) => void,
) => {
batch[priority].add(fn)
}
Expand All @@ -203,7 +203,9 @@ const registerBatchAtom = (
batch.C.add(atom)
atomState.u?.(batch)
const scheduleListeners = () => {
atomState.m?.l.forEach((listener) => addBatchFunc(batch, 1, listener))
atomState.m?.l.forEach((listener) =>
addBatchFunc(batch, 1, () => listener()),
)
}
addBatchFunc(batch, 1, scheduleListeners)
}
Expand All @@ -212,9 +214,9 @@ const registerBatchAtom = (
const flushBatch = (batch: Batch) => {
let error: AnyError
let hasError = false
const call = (fn: () => void) => {
const call = (fn: (batch: Batch) => void) => {
try {
fn()
fn(batch)
} catch (e) {
if (!hasError) {
error = e
Expand All @@ -223,7 +225,6 @@ const flushBatch = (batch: Batch) => {
}
}
while (batch.C.size || batch.some((channel) => channel.size)) {
batch.C.clear()
for (const channel of batch) {
channel.forEach(call)
channel.clear()
Expand Down Expand Up @@ -315,7 +316,6 @@ const buildStore = (...storeArgs: StoreArgs): Store => {
atomState.v = valueOrPromise
}
delete atomState.e
delete atomState.x
if (!hasPrevValue || !Object.is(prevValue, atomState.v)) {
++atomState.n
if (pendingPromise) {
Expand Down Expand Up @@ -428,7 +428,6 @@ const buildStore = (...storeArgs: StoreArgs): Store => {
} catch (error) {
delete atomState.v
atomState.e = error
delete atomState.x
++atomState.n
return atomState
} finally {
Expand Down Expand Up @@ -458,11 +457,26 @@ const buildStore = (...storeArgs: StoreArgs): Store => {
return dependents
}

const recomputeDependents = <Value>(
batch: Batch,
atom: Atom<Value>,
atomState: AtomState<Value>,
) => {
const dirtyDependents = <Value>(atomState: AtomState<Value>) => {
const dependents = new Set<AtomState>([atomState])
const stack: AtomState[] = [atomState]
while (stack.length > 0) {
const aState = stack.pop()!
if (aState.x) {
// already dirty
continue
}
aState.x = true
for (const [, s] of getMountedOrBatchDependents(aState)) {
if (!dependents.has(s)) {
dependents.add(s)
stack.push(s)
}
}
}
}

const recomputeDependents = (batch: Batch) => {
// Step 1: traverse the dependency graph to build the topsorted atom list
// We don't bother to check for cycles, which simplifies the algorithm.
// This is a topological sort via depth-first search, slightly modified from
Expand All @@ -477,7 +491,10 @@ const buildStore = (...storeArgs: StoreArgs): Store => {
const visited = new Set<AnyAtom>()
// Visit the root atom. This is the only atom in the dependency graph
// without incoming edges, which is one reason we can simplify the algorithm
const stack: [a: AnyAtom, aState: AtomState][] = [[atom, atomState]]
const stack: [a: AnyAtom, aState: AtomState][] = Array.from(
batch.C,
(atom) => [atom, ensureAtomState(atom)],
)
while (stack.length > 0) {
const [a, aState] = stack[stack.length - 1]!
if (visited.has(a)) {
Expand All @@ -492,8 +509,6 @@ const buildStore = (...storeArgs: StoreArgs): Store => {
topSortedReversed.push([a, aState, aState.n])
// Atom has been visited but not yet processed
visited.add(a)
// Mark atom dirty
aState.x = true
stack.pop()
continue
}
Expand All @@ -508,29 +523,25 @@ const buildStore = (...storeArgs: StoreArgs): Store => {

// Step 2: use the topSortedReversed atom list to recompute all affected atoms
// Track what's changed, so that we can short circuit when possible
const finishRecompute = () => {
const changedAtoms = new Set<AnyAtom>([atom])
for (let i = topSortedReversed.length - 1; i >= 0; --i) {
const [a, aState, prevEpochNumber] = topSortedReversed[i]!
let hasChangedDeps = false
for (const dep of aState.d.keys()) {
if (dep !== a && changedAtoms.has(dep)) {
hasChangedDeps = true
break
}
for (let i = topSortedReversed.length - 1; i >= 0; --i) {
const [a, aState, prevEpochNumber] = topSortedReversed[i]!
let hasChangedDeps = false
for (const dep of aState.d.keys()) {
if (dep !== a && batch.C.has(dep)) {
hasChangedDeps = true
break
}
if (hasChangedDeps) {
readAtomState(batch, a)
mountDependencies(batch, a, aState)
if (prevEpochNumber !== aState.n) {
registerBatchAtom(batch, a, aState)
changedAtoms.add(a)
}
}
if (hasChangedDeps) {
readAtomState(batch, a)
mountDependencies(batch, a, aState)
if (prevEpochNumber !== aState.n) {
registerBatchAtom(batch, a, aState)
}
delete aState.x
}
delete aState.x
}
addBatchFunc(batch, 0, finishRecompute)
batch.C.clear()
}

const writeAtomState = <Value, Args extends unknown[], Result>(
Expand All @@ -557,8 +568,9 @@ const buildStore = (...storeArgs: StoreArgs): Store => {
setAtomStateValueOrPromise(a, aState, v)
mountDependencies(batch, a, aState)
if (prevEpochNumber !== aState.n) {
dirtyDependents(aState)
registerBatchAtom(batch, a, aState)
recomputeDependents(batch, a, aState)
addBatchFunc(batch, 0, recomputeDependents)
}
return undefined as R
} else {
Expand Down Expand Up @@ -679,7 +691,7 @@ const buildStore = (...storeArgs: StoreArgs): Store => {
// unmount self
const onUnmount = atomState.m.u
if (onUnmount) {
addBatchFunc(batch, 2, () => onUnmount(batch))
addBatchFunc(batch, 2, onUnmount)
}
delete atomState.m
atomState.h?.(batch)
Expand Down
56 changes: 17 additions & 39 deletions tests/vanilla/effect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ type Ref = {
cleanup?: Cleanup | undefined
}

const syncEffectChannelSymbol = Symbol()

function syncEffect(effect: Effect): Atom<void> {
const refAtom = atom<Ref>(() => ({ inProgress: 0, epoch: 0 }))
const refreshAtom = atom(0)
Expand Down Expand Up @@ -62,8 +60,7 @@ function syncEffect(effect: Effect): Atom<void> {
store.set(refreshAtom, (v) => v + 1)
} else {
// unmount
const syncEffectChannel = ensureBatchChannel(batch)
syncEffectChannel.add(() => {
scheduleListenerAfterRecompute(batch, () => {
ref.cleanup?.()
delete ref.cleanup
})
Expand All @@ -73,48 +70,29 @@ function syncEffect(effect: Effect): Atom<void> {
internalAtomState.u = (batch) => {
originalUpdateHook?.(batch)
// update
const syncEffectChannel = ensureBatchChannel(batch)
syncEffectChannel.add(runEffect)
scheduleListenerAfterRecompute(batch, runEffect)
}
function scheduleListenerAfterRecompute(
batch: Batch,
listener: () => void,
) {
const scheduleListener = () => {
batch[0].add(listener)
}
if (batch[0].size === 0) {
// no other listeners
// schedule after recomputeDependents
batch[0].add(scheduleListener)
} else {
scheduleListener()
}
}
}
return atom((get) => {
get(internalAtom)
})
}

type BatchWithSyncEffect = Batch & {
[syncEffectChannelSymbol]?: Set<() => void>
}
function ensureBatchChannel(batch: BatchWithSyncEffect) {
// ensure continuation of the flushBatch while loop
const originalQueue = batch[1]
if (!originalQueue) {
throw new Error('batch[1] must be present')
}
if (!batch[syncEffectChannelSymbol]) {
batch[syncEffectChannelSymbol] = new Set<() => void>()
batch[1] = {
...originalQueue,
add(item) {
originalQueue.add(item)
return this
},
clear() {
batch[syncEffectChannelSymbol]!.clear()
originalQueue.clear()
},
forEach(callback) {
batch[syncEffectChannelSymbol]!.forEach(callback)
originalQueue.forEach(callback)
},
get size() {
return batch[syncEffectChannelSymbol]!.size + originalQueue.size
},
}
}
return batch[syncEffectChannelSymbol]!
}

const getAtomStateMap = new WeakMap<Store, GetAtomState>()

/**
Expand Down
18 changes: 18 additions & 0 deletions tests/vanilla/store.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1108,3 +1108,21 @@ it('recomputes dependents of unmounted atoms', () => {
store.set(w)
expect(bRead).not.toHaveBeenCalled()
})

it('recomputes all changed atom dependents together', async () => {
const a = atom([0])
const b = atom([0])
const a0 = atom((get) => get(a)[0]!)
const b0 = atom((get) => get(b)[0]!)
const a0b0 = atom((get) => [get(a0), get(b0)])
const w = atom(null, (_, set) => {
set(a, [0])
set(b, [1])
})
const store = createStore()
store.sub(a0b0, () => {})
store.set(w)
expect(store.get(a0)).toBe(0)
expect(store.get(b0)).toBe(1)
expect(store.get(a0b0)).toEqual([0, 1])
})
Loading