diff --git a/discojs/discojs-core/src/aggregator/base.ts b/discojs/discojs-core/src/aggregator/base.ts index 7dd2f94e6..ac802431e 100644 --- a/discojs/discojs-core/src/aggregator/base.ts +++ b/discojs/discojs-core/src/aggregator/base.ts @@ -67,7 +67,7 @@ export abstract class Base { */ protected readonly roundCutoff = 0, /** - * The number of communication rounds occuring during any given aggregation round. + * The number of communication rounds occurring during any given aggregation round. */ public readonly communicationRounds = 1 ) { @@ -272,7 +272,7 @@ export abstract class Base { } /** - * The current commnication round. + * The current communication round. */ get communicationRound (): number { return this._communicationRound diff --git a/discojs/discojs-core/src/client/federated/base.ts b/discojs/discojs-core/src/client/federated/base.ts index 43fd36acb..515cbfd0f 100644 --- a/discojs/discojs-core/src/client/federated/base.ts +++ b/discojs/discojs-core/src/client/federated/base.ts @@ -19,14 +19,6 @@ export class Base extends Client { * by this client class, the server is the only node which we are communicating with. */ public static readonly SERVER_NODE_ID = 'federated-server-node-id' - /** - * Most recent server-fetched round. - */ - private serverRound?: number - /** - * Most recent server-fetched aggregated result. - */ - private serverResult?: WeightsContainer /** * Statistics curated by the federated server. */ @@ -92,41 +84,37 @@ export class Base extends Client { /** * Send a message containing our local weight updates to the federated server. 
+ * And waits for the server to reply with the most recent aggregated weights * @param weights The weight updates to send */ - async sendPayload (payload: WeightsContainer): Promise { + private async sendPayloadAndReceiveResult (payload: WeightsContainer): Promise { const msg: messages.SendPayload = { type: type.SendPayload, payload: await serialization.weights.encode(payload), round: this.aggregator.round } this.server.send(msg) + // It is important that the client immediately awaits the server result or it may miss it + return await this.receiveResult() } /** - * Fetches the server's result for its current (most recent) round and add it to our aggregator. + * Waits for the server's result for its current (most recent) round and adds it to our aggregator. * Updates the aggregator's round if it's behind the server's. */ - async receiveResult (): Promise { - this.serverRound = undefined - this.serverResult = undefined - - const msg: messages.MessageBase = { - type: type.ReceiveServerPayload - } - this.server.send(msg) - + private async receiveResult (): Promise { try { const { payload, round } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload) - this.serverRound = round + const serverRound = round // Store the server result only if it is not stale if (this.aggregator.round <= round) { - this.serverResult = serialization.weights.decode(payload) + const serverResult = serialization.weights.decode(payload) // Update the local round to match the server's - if (this.aggregator.round < this.serverRound) { - this.aggregator.setRound(this.serverRound) + if (this.aggregator.round < serverRound) { + this.aggregator.setRound(serverRound) } + return serverResult } } catch (e) { console.error(e) @@ -226,13 +214,11 @@ export class Base extends Client { throw new Error('local aggregation result was not set') } - // Send our contribution to the server - await this.sendPayload(this.aggregator.makePayloads(weights).first()) - // Fetch the server result - await 
this.receiveResult() + // Send our local contribution to the server + // and receive the most recent weights as an answer to our contribution + const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first()) - // TODO @s314cy: add communication rounds to federated learning - if (this.serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, this.serverResult, round, 0)) { + if (serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) { // Regular case: the server sends us its aggregation result which will serve our // own aggregation result. } else { diff --git a/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts b/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts index 00cbbe976..34e4ddd67 100644 --- a/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts +++ b/discojs/discojs-core/src/dataset/data/preprocessing/tabular_preprocessing.ts @@ -1,3 +1,4 @@ +import { Task, tf } from '../../..' import { List } from 'immutable' import { PreprocessingFunction } from './base' @@ -9,7 +10,25 @@ export enum TabularPreprocessing { Normalize } +interface TabularEntry extends tf.TensorContainerObject { + xs: number[] + ys: tf.Tensor1D | number | undefined +} + +const sanitize: PreprocessingFunction = { + type: TabularPreprocessing.Sanitize, + apply: (entry: tf.TensorContainer, task: Task): tf.TensorContainer => { + const { xs, ys } = entry as TabularEntry + return { + xs: xs.map(i => i === undefined ? 0 : i), + ys: ys + } + } +} + /** * Available tabular preprocessing functions. 
*/ -export const AVAILABLE_PREPROCESSING = List() +export const AVAILABLE_PREPROCESSING = List([ + sanitize] +).sortBy((e) => e.type) diff --git a/discojs/discojs-core/src/dataset/data/tabular_data.ts b/discojs/discojs-core/src/dataset/data/tabular_data.ts index 5b2416f18..cb0e92a74 100644 --- a/discojs/discojs-core/src/dataset/data/tabular_data.ts +++ b/discojs/discojs-core/src/dataset/data/tabular_data.ts @@ -21,7 +21,8 @@ export class TabularData extends Data { try { await dataset.iterator() } catch (e) { - throw new Error('Data input format is not compatible with the chosen task') + console.error('Data input format is not compatible with the chosen task.') + throw (e) } return new TabularData(dataset, task, size) diff --git a/discojs/discojs-core/src/default_tasks/titanic.ts b/discojs/discojs-core/src/default_tasks/titanic.ts index 5c7214266..c51be7490 100644 --- a/discojs/discojs-core/src/default_tasks/titanic.ts +++ b/discojs/discojs-core/src/default_tasks/titanic.ts @@ -1,4 +1,4 @@ -import { tf, Task, TaskProvider } from '..' +import { tf, Task, TaskProvider, data } from '..' 
export const titanic: TaskProvider = { getTask (): Task { @@ -49,7 +49,8 @@ export const titanic: TaskProvider = { roundDuration: 10, validationSplit: 0.2, batchSize: 30, - preprocessingFunctions: [], + preprocessingFunctions: [data.TabularPreprocessing.Sanitize], + learningRate: 0.001, modelCompileData: { optimizer: 'sgd', loss: 'binaryCrossentropy', diff --git a/discojs/discojs-core/src/validation/validator.spec.ts b/discojs/discojs-core/src/validation/validator.spec.ts index 0e060a97c..b9cec0acc 100644 --- a/discojs/discojs-core/src/validation/validator.spec.ts +++ b/discojs/discojs-core/src/validation/validator.spec.ts @@ -1,7 +1,8 @@ import { assert } from 'chai' import fs from 'fs' -import { Task, node, Validator, ConsoleLogger, EmptyMemory, client as clients, data, aggregator } from '@epfml/discojs-node' +import { Task, node, Validator, ConsoleLogger, EmptyMemory, + client as clients, data, aggregator, defaultTasks } from '@epfml/discojs-node' const simplefaceMock = { taskID: 'simple_face', @@ -55,25 +56,36 @@ describe('validator', () => { `expected accuracy greater than 0.3 but got ${validator.accuracy}` ) console.table(validator.confusionMatrix) - }).timeout(10_000) + }).timeout(15_000) - // TODO: fix titanic model (nan accuracy) - // it('works for titanic', async () => { - // const data: Data = await new NodeTabularLoader(titanic.task, ',') - // .loadAll(['../../example_training_data/titanic.csv'], { - // features: titanic.task.trainingInformation?.inputColumns, - // labels: titanic.task.trainingInformation?.outputColumns - // }) - // const validator = new Validator(titanic.task, new ConsoleLogger(), titanic.model()) - // await validator.assess(data) - - // assert( - // validator.visitedSamples() === data.size, - // `expected ${TITANIC_SAMPLES} visited samples but got ${validator.visitedSamples()}` - // ) - // assert( - // validator.accuracy() > 0.5, - // `expected accuracy greater than 0.5 but got ${validator.accuracy()}` - // ) - // }) + it('works for 
titanic', async () => { + const titanicTask = defaultTasks.titanic.getTask() + const files = ['../../example_training_data/titanic_train.csv'] + const data: data.Data = (await new node.data.NodeTabularLoader(titanicTask, ',').loadAll(files, { + features: titanicTask.trainingInformation.inputColumns, + labels: titanicTask.trainingInformation.outputColumns, + shuffle: false + })).train + const buffer = new aggregator.MeanAggregator(titanicTask) + const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, buffer) + buffer.setModel(await client.getLatestModel()) + const validator = new Validator(titanicTask, + new ConsoleLogger(), + new EmptyMemory(), + undefined, + client) + await validator.assess(data) + // data.size is undefined because tfjs handles dataset lazily + // instead we count the dataset size manually + let size = 0 + await data.dataset.forEachAsync(() => size+=1) + assert( + validator.visitedSamples === size, + `expected ${size} visited samples but got ${validator.visitedSamples}` + ) + assert( + validator.accuracy > 0.5, + `expected accuracy greater than 0.5 but got ${validator.accuracy}` + ) + }).timeout(15_000) }) diff --git a/docs/node_example/data.ts b/docs/node_example/data.ts index 7e4f6f171..27136720c 100644 --- a/docs/node_example/data.ts +++ b/docs/node_example/data.ts @@ -1,7 +1,7 @@ import fs from 'fs' import Rand from 'rand-seed' -import { node, data, Task } from '@epfml/discojs-node' +import { node, data, Task, defaultTasks } from '@epfml/discojs-node' const rand = new Rand('1234') @@ -45,3 +45,14 @@ export async function loadData (task: Task): Promise { return await new node.data.NodeImageLoader(task).loadAll(files, { labels: labels }) } + +export async function loadTitanicData (task:Task): Promise { + const files = ['../../example_training_data/titanic_train.csv'] + const titanicTask = defaultTasks.titanic.getTask() + return await new node.data.NodeTabularLoader(task, ',').loadAll(files, { + features: 
titanicTask.trainingInformation.inputColumns, + labels: titanicTask.trainingInformation.outputColumns, + shuffle: false + }) +} \ No newline at end of file diff --git a/docs/node_example/example.ts b/docs/node_example/example.ts index 22258a299..244cbfeb1 100644 --- a/docs/node_example/example.ts +++ b/docs/node_example/example.ts @@ -1,7 +1,5 @@ import { data, Disco, fetchTasks, Task } from '@epfml/discojs-node' - -import { startServer } from './start_server' -import { loadData } from './data' +import { loadTitanicData } from './data' /** * Example of discojs API, we load data, build the appropriate loggers, the disco object @@ -18,24 +16,22 @@ async function runUser (url: URL, task: Task, dataset: data.DataSplit): Promise< async function main (): Promise { - const [server, serverUrl] = await startServer() + // First have a server instance running before running this script + const serverUrl = new URL('http://localhost:8080/') + const tasks = await fetchTasks(serverUrl) // Choose your task to train - const task = tasks.get('simple_face') as Task + const task = tasks.get('titanic') as Task - const dataset = await loadData(task) + const dataset = await loadTitanicData(task) // Add more users to the list to simulate more clients await Promise.all([ runUser(serverUrl, task, dataset), - runUser(serverUrl, task, dataset) + runUser(serverUrl, task, dataset), + runUser(serverUrl, task, dataset), ]) - - await new Promise((resolve, reject) => { - server.once('close', resolve) - server.close(reject) - }) } main().catch(console.error) diff --git a/docs/node_example/start_server.ts b/docs/node_example/start_server.ts deleted file mode 100644 index f8bea0f63..000000000 --- a/docs/node_example/start_server.ts +++ /dev/null @@ -1,33 +0,0 @@ -import http from 'node:http' - -import { Disco } from '@epfml/disco-server' - -export async function startServer (): Promise<[http.Server, URL]> { - const disco = new Disco() - await disco.addDefaultTasks() - - const server = 
disco.serve(8000) - await new Promise((resolve, reject) => { - server.once('listening', resolve) - server.once('error', reject) - server.on('error', console.error) - }) - - let addr: string - const rawAddr = server.address() - if (rawAddr === null) { - throw new Error('unable to get server address') - } else if (typeof rawAddr === 'string') { - addr = rawAddr - } else if (typeof rawAddr === 'object') { - if (rawAddr.family === '4') { - addr = `${rawAddr.address}:${rawAddr.port}` - } else { - addr = `[${rawAddr.address}]:${rawAddr.port}` - } - } else { - throw new Error('unable to get address to server') - } - - return [server, new URL('', `http://${addr}`)] -} \ No newline at end of file diff --git a/docs/node_example/tsconfig.json b/docs/node_example/tsconfig.json index 1e8eb3ae8..74dbc77ec 100644 --- a/docs/node_example/tsconfig.json +++ b/docs/node_example/tsconfig.json @@ -14,7 +14,7 @@ "declaration": true, - "typeRoots": ["node_modules/@types", "discojs-core/types"] + "typeRoots": ["node_modules/@types", "../../discojs/discojs-core/types"] }, "include": ["*.ts"], "exclude": ["node_modules"] diff --git a/server/src/router/federated/server.ts b/server/src/router/federated/server.ts index b8a1fc77f..8237c0a90 100644 --- a/server/src/router/federated/server.ts +++ b/server/src/router/federated/server.ts @@ -89,17 +89,23 @@ export class Federated extends Server { } /** - * Loop storing aggregation results, every time an aggregation result promise resolves. - * This happens once per round. + * Loop creating an aggregation result promise at each round. + * Because clients contribute to the round asynchronously, a promise is used to let them wait + * until the server has aggregated the weights. This loop creates a promise whenever the previous + * one resolved and awaits until it resolves. The promise is used in createPromiseForWeights. 
* @param aggregator The aggregation handler */ private async storeAggregationResult (aggregator: aggregators.Aggregator): Promise { - // Renew the aggregation result promise. + // Create a promise on the future aggregated weights const result = aggregator.receiveResult() - // Store the result promise somewhere for the server to fetch from, so that it can await - // the result on client request. + // Store the promise such that it is accessible from other methods this.results = this.results.set(aggregator.task.taskID, result) + // The promise resolves once the server received enough contributions (through the handle method) + // and the aggregator aggregated the weights. await result + // Update the server round with the aggregator round + this.rounds = this.rounds.set(aggregator.task.taskID, aggregator.round) + // Create a new promise for the next round void this.storeAggregationResult(aggregator) } @@ -113,6 +119,85 @@ export class Federated extends Server { void this.storeAggregationResult(aggregator) } + /** + * This method is called when a client sends its contribution to the server. 
The server + * first adds the contribution to the aggregator and then replies with the aggregated weights + * + * @param msg the client message received of type SendPayload which contains the local client's weights + * @param task the task for which the client is contributing + * @param clientId the clientID of the contribution + * @param ws the websocket through which to send the aggregated weights + */ + private async addContributionAndSendModel (msg: messages.SendPayload, task: Task, + clientId: client.NodeID, ws: WebSocket): Promise { + const { payload, round } = msg + const aggregator = this.aggregators.get(task.taskID) + + if (!(Array.isArray(payload) && + payload.every((e) => typeof e === 'number'))) { + throw new Error('received invalid weights format') + } + if (aggregator === undefined) { + throw new Error(`received weights for unknown task: ${task.taskID}`) + } + + // It is important to create a promise for the weights BEFORE adding the contribution + // Otherwise the server might go to the next round before sending the + // aggregated weights. Once the server has aggregated the weights it will + // send the new weights to the client. + // Intentionally not awaited: the reply is sent later, once the aggregation result resolves; errors are logged through catch + this.createPromiseForWeights(task, aggregator, ws) + .catch(console.error) + + const serialized = serialization.weights.decode(payload) + // Add the contribution to the aggregator, + // which returns false if the contribution is too old + if (!aggregator.add(clientId, serialized, round, 0)) { + console.info('Dropped contribution from client', clientId, 'for round', round) + } + } + + /** + * This method is called after receiving a local update. + * It puts the client on hold until the server has aggregated the weights + * by creating a Promise which will resolve once the server has received + * enough contributions. 
Relying on a promise is useful since clients may + * send their contributions at different times and a promise lets the server + * wait asynchronously for the results + * + * @param task the task to which the client is contributing + * @param aggregator the server aggregator, in order to access the current round + * @param ws the websocket through which to send the aggregated weights + */ + private async createPromiseForWeights ( + task: Task, + aggregator: aggregators.Aggregator, + ws: WebSocket): Promise { + const promisedResult = this.results.get(task.taskID) + if (promisedResult === undefined) { + throw new Error(`result promise was not set for task ${task.taskID}`) + } + + // Wait for aggregation result to resolve with timeout, giving the network a time window + // to contribute to the model + void Promise.race([promisedResult, client.utils.timeout()]) + .then((result) => + // Reply with round - 1 because the round number should match the round at which the client sent its weights + // After the server aggregated the weights it also incremented the round so the server replies with round - 1 + [result, aggregator.round - 1] as [WeightsContainer, number]) + .then(async ([result, round]) => + [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) + .then(([serialized, round]) => { + const msg: messages.ReceiveServerPayload = { + type: MessageTypes.ReceiveServerPayload, + round, + payload: serialized + } + ws.send(msgpack.encode(msg)) + }) + .catch(console.error) + } + protected handle ( task: Task, ws: WebSocket, @@ -133,6 +218,8 @@ export class Federated extends Server { const msg = msgpack.decode(data) if (msg.type === MessageTypes.ClientConnected) { + this.logsAppend(task.taskID, clientId, MessageTypes.ClientConnected, 0) + let aggregator = this.aggregators.get(task.taskID) if (aggregator === undefined) { aggregator = new aggregators.MeanAggregator(task) @@ -140,42 +227,19 @@ export class Federated extends Server { } 
console.info('client', clientId, 'joined', task.taskID) - this.logsAppend(task.taskID, clientId, MessageTypes.ClientConnected, 0) - const msg: AssignNodeID = { type: MessageTypes.AssignNodeID, id: clientId } ws.send(msgpack.encode(msg)) } else if (msg.type === MessageTypes.SendPayload) { - const { payload, round } = msg - - const aggregator = this.aggregators.get(task.taskID) - - this.logsAppend( - task.taskID, - clientId, - MessageTypes.SendPayload, - msg.round - ) - - if (!( - Array.isArray(payload) && - payload.every((e) => typeof e === 'number') - )) { - throw new Error('received invalid weights format') - } - - const serialized = serialization.weights.decode(payload) - - if (aggregator === undefined) { - throw new Error(`received weights for unknown task: ${task.taskID}`) - } + this.logsAppend(task.taskID, clientId, MessageTypes.SendPayload, msg.round) - // TODO @s314cy: add communication rounds to federated learning - if (!aggregator.add(clientId, serialized, round, 0)) { - console.info('Dropped contribution from client', clientId, 'for round', round) + if (model === undefined) { + throw new Error('aggregator model was not set') } + this.addContributionAndSendModel(msg, task, clientId, ws) + .catch(console.error) } else if (msg.type === MessageTypes.ReceiveServerStatistics) { const statistics = this.informants .get(task.taskID) @@ -188,37 +252,16 @@ export class Federated extends Server { ws.send(msgpack.encode(msg)) } else if (msg.type === MessageTypes.ReceiveServerPayload) { + this.logsAppend(task.taskID, clientId, MessageTypes.ReceiveServerPayload, 0) const aggregator = this.aggregators.get(task.taskID) if (aggregator === undefined) { throw new Error(`requesting round of unknown task: ${task.taskID}`) } - - this.logsAppend(task.taskID, clientId, MessageTypes.ReceiveServerPayload, 0) - if (model === undefined) { throw new Error('aggregator model was not set') } - const promisedResult = this.results.get(task.taskID) - if (promisedResult === undefined) { - 
throw new Error(`result promise was not set for task ${task.taskID}`) - } - - // Wait for aggregation result with timeout, giving the network a time window - // to contribute to the model sent to the requesting client. - void Promise.race([promisedResult, client.utils.timeout()]) - .then((result) => - [result, aggregator.round - 1] as [WeightsContainer, number]) - .then(async ([result, round]) => - [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) - .then(([serialized, round]) => { - const msg: messages.ReceiveServerPayload = { - type: MessageTypes.ReceiveServerPayload, - round, - payload: serialized - } - ws.send(msgpack.encode(msg)) - }) + this.createPromiseForWeights(task, aggregator, ws) .catch(console.error) } else if (msg.type === MessageTypes.SendMetadata) { const { round, key, value } = msg diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 3cae23333..b4f50e258 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -11,7 +11,7 @@ import { getClient, startServer } from '../utils' const SCHEME = TrainingSchemes.FEDERATED describe('end-to-end federated', function () { - this.timeout(90_000) + this.timeout(120_000) let server: Server beforeEach(async () => {