From a0be7af8b7d8597ac9c8c70887798008155a8d6c Mon Sep 17 00:00:00 2001 From: peacefulotter Date: Sat, 3 Feb 2024 14:09:03 +0100 Subject: [PATCH 1/4] Added wikitext task, added text-dataset for core, node and web, renamed task.taskId to task.id --- discojs/discojs-core/src/async_informant.ts | 103 +-- discojs/discojs-core/src/client/base.ts | 198 +++--- .../src/client/decentralized/base.ts | 453 +++++++------ .../discojs-core/src/client/federated/base.ts | 475 ++++++++------ .../src/dataset/data/image_data.spec.ts | 53 +- .../discojs-core/src/dataset/data/index.ts | 11 +- .../data/preprocessing/text_preprocessing.ts | 83 ++- .../src/dataset/data/tabular_data.spec.ts | 108 +-- .../src/dataset/data/text_data.ts | 27 +- .../src/dataset/data_loader/index.ts | 3 +- .../src/dataset/data_loader/text_loader.ts | 157 ++++- discojs/discojs-core/src/dataset/index.ts | 20 +- .../discojs-core/src/default_tasks/cifar10.ts | 118 ++-- .../discojs-core/src/default_tasks/geotags.ts | 124 ++-- .../discojs-core/src/default_tasks/index.ts | 1 + .../src/default_tasks/lus_covid.ts | 180 ++--- .../discojs-core/src/default_tasks/mnist.ts | 132 ++-- .../src/default_tasks/simple_face.ts | 90 +-- .../src/default_tasks/skin_mnist.ts | 172 ++--- .../discojs-core/src/default_tasks/titanic.ts | 172 ++--- .../src/default_tasks/wikitext.ts | 86 +++ discojs/discojs-core/src/task/index.ts | 13 +- .../src/task/model_compile_data.ts | 60 +- discojs/discojs-core/src/task/task.ts | 79 +-- discojs/discojs-core/src/task/task_handler.ts | 37 +- .../discojs-core/src/task/task_provider.ts | 24 +- .../src/task/training_information.ts | 387 ++++++----- discojs/discojs-core/src/training/disco.ts | 250 +++---- discojs/discojs-core/src/training/index.ts | 2 + .../src/training/models/gpt/bun.lockb | Bin 0 -> 8380 bytes .../src/training/models/gpt/config.ts | 77 +++ .../src/training/models/gpt/evaluate.ts | 47 ++ .../src/training/models/gpt/index.ts | 4 + .../src/training/models/gpt/model.ts | 616 ++++++++++++++++++ 
.../src/training/models/gpt/optimizers.ts | 120 ++++ .../src/training/models/gpt/train.ts | 150 +++++ .../discojs-core/src/training/models/index.ts | 2 + .../training/trainer/distributed_trainer.ts | 120 ++-- .../src/training/trainer/local_trainer.ts | 29 +- .../src/training/trainer/trainer.ts | 249 +++---- .../src/training/trainer/trainer_builder.ts | 150 +++-- .../src/validation/validator.spec.ts | 155 +++-- .../src/dataset/data_loader/index.ts | 1 + .../dataset/data_loader/text_loader.spec.ts | 174 +++++ .../src/dataset/data_loader/text_loader.ts | 188 +++++- .../src/dataset/data_loader/text-worker.ts | 95 +++ .../dataset/data_loader/text_loader.spec.ts | 210 ++++++ .../src/dataset/data_loader/text_loader.ts | 148 ++++- discojs/discojs-web/src/memory/memory.ts | 258 ++++---- docs/TASK.md | 121 ++-- server/src/get_server.ts | 99 +-- server/src/router/decentralized/server.ts | 261 ++++---- server/src/router/federated/server.ts | 563 +++++++++------- server/src/router/server.ts | 121 ++-- server/src/router/tasks.ts | 194 +++--- server/src/tasks.ts | 227 ++++--- web-client/cypress/e2e/tasks.cy.ts | 64 +- web-client/src/components/pages/TaskList.vue | 161 ++--- web-client/src/components/testing/Tester.vue | 376 ++++++----- web-client/src/components/testing/Testing.vue | 248 ++++--- .../src/components/training/ModelCaching.vue | 229 +++---- web-client/src/store/tasks.ts | 38 +- 62 files changed, 5803 insertions(+), 3310 deletions(-) create mode 100644 discojs/discojs-core/src/default_tasks/wikitext.ts create mode 100755 discojs/discojs-core/src/training/models/gpt/bun.lockb create mode 100644 discojs/discojs-core/src/training/models/gpt/config.ts create mode 100755 discojs/discojs-core/src/training/models/gpt/evaluate.ts create mode 100644 discojs/discojs-core/src/training/models/gpt/index.ts create mode 100644 discojs/discojs-core/src/training/models/gpt/model.ts create mode 100644 discojs/discojs-core/src/training/models/gpt/optimizers.ts create mode 100644 
discojs/discojs-core/src/training/models/gpt/train.ts create mode 100644 discojs/discojs-core/src/training/models/index.ts create mode 100644 discojs/discojs-node/src/dataset/data_loader/text_loader.spec.ts create mode 100644 discojs/discojs-web/src/dataset/data_loader/text-worker.ts create mode 100644 discojs/discojs-web/src/dataset/data_loader/text_loader.spec.ts diff --git a/discojs/discojs-core/src/async_informant.ts b/discojs/discojs-core/src/async_informant.ts index b7508b9ca..2430d79ef 100644 --- a/discojs/discojs-core/src/async_informant.ts +++ b/discojs/discojs-core/src/async_informant.ts @@ -1,64 +1,67 @@ import { AggregatorBase } from './aggregator' export class AsyncInformant { - private _round = 0 - private _currentNumberOfParticipants = 0 - private _totalNumberOfParticipants = 0 - private _averageNumberOfParticipants = 0 + private _round = 0 + private _currentNumberOfParticipants = 0 + private _totalNumberOfParticipants = 0 + private _averageNumberOfParticipants = 0 - constructor ( - private readonly aggregator: AggregatorBase - ) {} + constructor(private readonly aggregator: AggregatorBase) {} - update (): void { - console.debug('before:') - this.printAllInfos() - if (this.round === 0 || this.round < this.aggregator.round) { - this._round = this.aggregator.round - this._currentNumberOfParticipants = this.aggregator.size - this._averageNumberOfParticipants = this.totalNumberOfParticipants / this.round - this._totalNumberOfParticipants += this.currentNumberOfParticipants - } else { - this._round = this.aggregator.round + update(): void { + console.debug('before:') + this.printAllInfos() + if (this.round === 0 || this.round < this.aggregator.round) { + this._round = this.aggregator.round + this._currentNumberOfParticipants = this.aggregator.size + this._averageNumberOfParticipants = + this.totalNumberOfParticipants / this.round + this._totalNumberOfParticipants += this.currentNumberOfParticipants + } else { + this._round = this.aggregator.round + } + 
console.debug('after:') + this.printAllInfos() } - console.debug('after:') - this.printAllInfos() - } - // Getter functions - get round (): number { - return this._round - } + // Getter functions + get round(): number { + return this._round + } - get currentNumberOfParticipants (): number { - return this._currentNumberOfParticipants - } + get currentNumberOfParticipants(): number { + return this._currentNumberOfParticipants + } - get totalNumberOfParticipants (): number { - return this._totalNumberOfParticipants - } + get totalNumberOfParticipants(): number { + return this._totalNumberOfParticipants + } - get averageNumberOfParticipants (): number { - return this._averageNumberOfParticipants - } + get averageNumberOfParticipants(): number { + return this._averageNumberOfParticipants + } - getAllStatistics (): Record< - 'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number - > { - return { - round: this.round, - currentNumberOfParticipants: this.currentNumberOfParticipants, - totalNumberOfParticipants: this.totalNumberOfParticipants, - averageNumberOfParticipants: this.averageNumberOfParticipants + getAllStatistics(): Record< + | 'round' + | 'currentNumberOfParticipants' + | 'totalNumberOfParticipants' + | 'averageNumberOfParticipants', + number + > { + return { + round: this.round, + currentNumberOfParticipants: this.currentNumberOfParticipants, + totalNumberOfParticipants: this.totalNumberOfParticipants, + averageNumberOfParticipants: this.averageNumberOfParticipants, + } } - } - // Debug - public printAllInfos (): void { - console.debug('task:', this.aggregator.task.taskID) - console.debug('round:', this.round) - console.debug('participants:', this.currentNumberOfParticipants) - console.debug('total:', this.totalNumberOfParticipants) - console.debug('average:', this.averageNumberOfParticipants) - } + // Debug + public printAllInfos(): void { + console.debug('task:', this.aggregator.task.id) + 
console.debug('round:', this.round) + console.debug('participants:', this.currentNumberOfParticipants) + console.debug('total:', this.totalNumberOfParticipants) + console.debug('average:', this.averageNumberOfParticipants) + } } diff --git a/discojs/discojs-core/src/client/base.ts b/discojs/discojs-core/src/client/base.ts index a0bed010d..6700b891a 100644 --- a/discojs/discojs-core/src/client/base.ts +++ b/discojs/discojs-core/src/client/base.ts @@ -1,7 +1,13 @@ import { Set } from 'immutable' import axios from 'axios' -import { tf, Task, TrainingInformant, serialization, WeightsContainer } from '..' +import { + tf, + Task, + TrainingInformant, + serialization, + WeightsContainer, +} from '..' import { NodeID } from './types' import { EventConnection } from './event_connection' import { Aggregator } from '../aggregator' @@ -11,119 +17,119 @@ import { Aggregator } from '../aggregator' * communication with other nodes, be it peers or a server. */ export abstract class Base { - /** - * Own ID provided by the network's server. - */ - protected _ownId?: NodeID - /** - * The network's server. - */ - protected _server?: EventConnection - /** - * The aggregator's result produced after aggregation. - */ - protected aggregationResult?: Promise - - constructor ( /** - * The network server's URL to connect to. + * Own ID provided by the network's server. */ - public readonly url: URL, + protected _ownId?: NodeID /** - * The client's corresponding task. + * The network's server. */ - public readonly task: Task, + protected _server?: EventConnection /** - * The client's aggregator. + * The aggregator's result produced after aggregation. */ - public readonly aggregator: Aggregator - ) {} + protected aggregationResult?: Promise - /** - * Handles the connection process from the client to any sort of network server. - */ - async connect (): Promise {} + constructor( + /** + * The network server's URL to connect to. 
+ */ + public readonly url: URL, + /** + * The client's corresponding task. + */ + public readonly task: Task, + /** + * The client's aggregator. + */ + public readonly aggregator: Aggregator + ) {} - /** - * Handles the disconnection process of the client from any sort of network server. - */ - async disconnect (): Promise {} + /** + * Handles the connection process from the client to any sort of network server. + */ + async connect(): Promise {} - /** - * Fetches the latest model available on the network's server, for the adequate task. - * @returns The latest model - */ - async getLatestModel (): Promise { - const url = new URL('', this.url.href) - if (!url.pathname.endsWith('/')) { - url.pathname += '/' - } - url.pathname += `tasks/${this.task.taskID}/model.json` + /** + * Handles the disconnection process of the client from any sort of network server. + */ + async disconnect(): Promise {} - const response = await axios.get(url.href) + /** + * Fetches the latest model available on the network's server, for the adequate task. + * @returns The latest model + */ + async getLatestModel(): Promise { + const url = new URL('', this.url.href) + if (!url.pathname.endsWith('/')) { + url.pathname += '/' + } + url.pathname += `tasks/${this.task.id}/model.json` - return await serialization.model.decode(response.data) - } + const response = await axios.get(url.href) - /** - * Communication callback called once at the beginning of the training instance. - * @param weights The initial model weights - * @param trainingInformant The training informant - */ - async onTrainBeginCommunication ( - weights: WeightsContainer, - trainingInformant: TrainingInformant - ): Promise {} + return await serialization.model.decode(response.data) + } - /** - * Communication callback called once at the end of the training instance. 
- * @param weights The final model weights - * @param trainingInformant The training informant - */ - async onTrainEndCommunication ( - weights: WeightsContainer, - trainingInformant: TrainingInformant - ): Promise {} + /** + * Communication callback called once at the beginning of the training instance. + * @param weights The initial model weights + * @param trainingInformant The training informant + */ + async onTrainBeginCommunication( + weights: WeightsContainer, + trainingInformant: TrainingInformant + ): Promise {} - /** - * Communication callback called at the beginning of every training round. - * @param weights The most recent local weight updates - * @param round The current training round - * @param trainingInformant The training informant - */ - async onRoundBeginCommunication ( - weights: WeightsContainer, - round: number, - trainingInformant: TrainingInformant - ): Promise {} + /** + * Communication callback called once at the end of the training instance. + * @param weights The final model weights + * @param trainingInformant The training informant + */ + async onTrainEndCommunication( + weights: WeightsContainer, + trainingInformant: TrainingInformant + ): Promise {} - /** - * Communication callback called the end of every training round. - * @param weights The most recent local weight updates - * @param round The current training round - * @param trainingInformant The training informant - */ - async onRoundEndCommunication ( - weights: WeightsContainer, - round: number, - trainingInformant: TrainingInformant - ): Promise {} + /** + * Communication callback called at the beginning of every training round. 
+ * @param weights The most recent local weight updates + * @param round The current training round + * @param trainingInformant The training informant + */ + async onRoundBeginCommunication( + weights: WeightsContainer, + round: number, + trainingInformant: TrainingInformant + ): Promise {} - get nodes (): Set { - return this.aggregator.nodes - } + /** + * Communication callback called the end of every training round. + * @param weights The most recent local weight updates + * @param round The current training round + * @param trainingInformant The training informant + */ + async onRoundEndCommunication( + weights: WeightsContainer, + round: number, + trainingInformant: TrainingInformant + ): Promise {} + + get nodes(): Set { + return this.aggregator.nodes + } - get ownId (): NodeID { - if (this._ownId === undefined) { - throw new Error('the node is not connected') + get ownId(): NodeID { + if (this._ownId === undefined) { + throw new Error('the node is not connected') + } + return this._ownId } - return this._ownId - } - get server (): EventConnection { - if (this._server === undefined) { - throw new Error('server undefined, not connected') + get server(): EventConnection { + if (this._server === undefined) { + throw new Error('server undefined, not connected') + } + return this._server } - return this._server - } } diff --git a/discojs/discojs-core/src/client/decentralized/base.ts b/discojs/discojs-core/src/client/decentralized/base.ts index fca6539a5..76956620d 100644 --- a/discojs/discojs-core/src/client/decentralized/base.ts +++ b/discojs/discojs-core/src/client/decentralized/base.ts @@ -5,7 +5,13 @@ import { TrainingInformant, WeightsContainer, serialization } from '../..' import { Client, NodeID } from '..' 
import { type, ClientConnected } from '../messages' import { timeout } from '../utils' -import { EventConnection, WebSocketServer, waitMessage, PeerConnection, waitMessageWithTimeout } from '../event_connection' +import { + EventConnection, + WebSocketServer, + waitMessage, + PeerConnection, + waitMessageWithTimeout, +} from '../event_connection' import { PeerPool } from './peer_pool' import * as messages from './messages' @@ -16,220 +22,269 @@ import * as messages from './messages' * WebRTC for Node.js. */ export class Base extends Client { - /** - * The pool of peers to communicate with during the current training round. - */ - private pool?: Promise - private connections?: Map - - /** - * Send message to server that this client is ready for the next training round. - */ - private async waitForPeers (round: number): Promise> { - console.info(`[${this.ownId}] is ready for round`, round) - - // Broadcast our readiness - const readyMessage: messages.PeerIsReady = { type: type.PeerIsReady } - - if (this.server === undefined) { - throw new Error('server undefined, could not connect peers') - } - this.server.send(readyMessage) - - // Wait for peers to be connected before sending any update information - try { - const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound) - if (this.nodes.size > 0) { - throw new Error('got new peer list from server but was already received for this round') - } - - const peers = Set(receivedMessage.peers) - console.info(`[${this.ownId}] received peers for round:`, peers.toJS()) - if (this.ownId !== undefined && peers.has(this.ownId)) { - throw new Error('received peer list contains our own id') - } - - this.aggregator.setNodes(peers.add(this.ownId)) - - if (this.pool === undefined) { - throw new Error('waiting for peers but peer pool is undefined') - } - - const pool = await this.pool - const connections = await pool.getPeers( - peers, - this.server, - // Init receipt of peers weights - (conn) => 
this.receivePayloads(conn, round) - ) - - console.info(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS()) - return connections - } catch (e) { - console.error(e) - this.aggregator.setNodes(Set(this.ownId)) - return Map() - } - } - - protected sendMessagetoPeer (peer: PeerConnection, msg: messages.PeerMessage): void { - console.info(`[${this.ownId}] send message to peer`, msg.peer, msg) - peer.send(msg) - } - - /** - * Creation of the WebSocket for the server, connection of client to that WebSocket, - * deals with message reception from the decentralized client's perspective (messages received by client). - */ - private async connectServer (url: URL): Promise { - const server: EventConnection = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer) - - server.on(type.SignalForPeer, (event) => { - console.info(`[${this.ownId}] received signal from`, event.peer) - - if (this.pool === undefined) { - throw new Error('received signal but peer pool is undefined') - } - void this.pool.then((pool) => pool.signal(event.peer, event.signal)) - }) - - return server - } - - async connect (): Promise { - const URL = typeof window !== 'undefined' ? window.URL : nodeUrl.URL - const serverURL = new URL('', this.url.href) - switch (this.url.protocol) { - case 'http:': - serverURL.protocol = 'ws:' - break - case 'https:': - serverURL.protocol = 'wss:' - break - default: - throw new Error(`unknown protocol: ${this.url.protocol}`) - } - serverURL.pathname += `deai/${this.task.taskID}` + /** + * The pool of peers to communicate with during the current training round. + */ + private pool?: Promise + private connections?: Map + + /** + * Send message to server that this client is ready for the next training round. 
+ */ + private async waitForPeers( + round: number + ): Promise> { + console.info(`[${this.ownId}] is ready for round`, round) + + // Broadcast our readiness + const readyMessage: messages.PeerIsReady = { type: type.PeerIsReady } + + if (this.server === undefined) { + throw new Error('server undefined, could not connect peers') + } + this.server.send(readyMessage) + + // Wait for peers to be connected before sending any update information + try { + const receivedMessage = await waitMessageWithTimeout( + this.server, + type.PeersForRound + ) + if (this.nodes.size > 0) { + throw new Error( + 'got new peer list from server but was already received for this round' + ) + } - this._server = await this.connectServer(serverURL) + const peers = Set(receivedMessage.peers) + console.info( + `[${this.ownId}] received peers for round:`, + peers.toJS() + ) + if (this.ownId !== undefined && peers.has(this.ownId)) { + throw new Error('received peer list contains our own id') + } - const msg: ClientConnected = { - type: type.ClientConnected - } - this.server.send(msg) + this.aggregator.setNodes(peers.add(this.ownId)) - const peerIdMsg = await waitMessage(this.server, type.AssignNodeID) - console.info(`[${peerIdMsg.id}] assigned id generated by server`) + if (this.pool === undefined) { + throw new Error('waiting for peers but peer pool is undefined') + } - if (this._ownId !== undefined) { - throw new Error('received id from server but was already received') + const pool = await this.pool + const connections = await pool.getPeers( + peers, + this.server, + // Init receipt of peers weights + (conn) => this.receivePayloads(conn, round) + ) + + console.info( + `[${this.ownId}] received peers for round ${round}:`, + connections.keySeq().toJS() + ) + return connections + } catch (e) { + console.error(e) + this.aggregator.setNodes(Set(this.ownId)) + return Map() + } } - this._ownId = peerIdMsg.id - this.pool = PeerPool.init(peerIdMsg.id) - } - - async disconnect (): Promise { - // 
Disconnect from peers - const pool = await this.pool - pool?.shutdown() - this.pool = undefined - - if (this.connections !== undefined) { - const peers = this.connections.keySeq().toSet() - this.aggregator.setNodes(this.aggregator.nodes.subtract(peers)) + + protected sendMessagetoPeer( + peer: PeerConnection, + msg: messages.PeerMessage + ): void { + console.info(`[${this.ownId}] send message to peer`, msg.peer, msg) + peer.send(msg) } - // Disconnect from server - this.server?.disconnect() - this._server = undefined - this._ownId = undefined - } - - async onRoundBeginCommunication ( - weights: WeightsContainer, - round: number, - trainingInformant: TrainingInformant - ): Promise { - // Reset peers list at each round of training to make sure client works with an updated peers - // list, maintained by the server. Adds any received weights to the aggregator. - this.connections = await this.waitForPeers(round) - // Store the promise for the current round's aggregation result. - this.aggregationResult = this.aggregator.receiveResult() - } - - async onRoundEndCommunication ( - weights: WeightsContainer, - round: number, - trainingInformant: TrainingInformant - ): Promise { - let result = weights - - // Perform the required communication rounds. Each communication round consists in sending our local payload, - // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator. - // A communication round's payload is the aggregation result of the previous communication round. The first - // communication round simply sends our training result, i.e. model weights updates. This scheme allows for - // the aggregator to define any complex multi-round aggregation mechanism. 
- for (let r = 0; r < this.aggregator.communicationRounds; r++) { - // Generate our payloads for this communication round and send them to all ready connected peers - if (this.connections !== undefined) { - const payloads = this.aggregator.makePayloads(result) - try { - await Promise.all(payloads.map(async (payload, id) => { - if (id === this.ownId) { - this.aggregator.add(this.ownId, payload, round, r) - } else { - const connection = this.connections?.get(id) - if (connection !== undefined) { - const encoded = await serialization.weights.encode(payload) - this.sendMessagetoPeer(connection, { - type: type.Payload, - peer: id, - round: r, - payload: encoded - }) - } + /** + * Creation of the WebSocket for the server, connection of client to that WebSocket, + * deals with message reception from the decentralized client's perspective (messages received by client). + */ + private async connectServer(url: URL): Promise { + const server: EventConnection = await WebSocketServer.connect( + url, + messages.isMessageFromServer, + messages.isMessageToServer + ) + + server.on(type.SignalForPeer, (event) => { + console.info(`[${this.ownId}] received signal from`, event.peer) + + if (this.pool === undefined) { + throw new Error('received signal but peer pool is undefined') } - })) - } catch { - throw new Error('error while sending weights') + void this.pool.then((pool) => pool.signal(event.peer, event.signal)) + }) + + return server + } + + async connect(): Promise { + const URL = typeof window !== 'undefined' ? 
window.URL : nodeUrl.URL + const serverURL = new URL('', this.url.href) + switch (this.url.protocol) { + case 'http:': + serverURL.protocol = 'ws:' + break + case 'https:': + serverURL.protocol = 'wss:' + break + default: + throw new Error(`unknown protocol: ${this.url.protocol}`) } - } + serverURL.pathname += `deai/${this.task.id}` - if (this.aggregationResult === undefined) { - throw new TypeError('aggregation result promise is undefined') - } + this._server = await this.connectServer(serverURL) + + const msg: ClientConnected = { + type: type.ClientConnected, + } + this.server.send(msg) - // Wait for aggregation before proceeding to the next communication round. - // The current result will be used as payload for the eventual next communication round. - result = await Promise.race([this.aggregationResult, timeout()]) + const peerIdMsg = await waitMessage(this.server, type.AssignNodeID) + console.info(`[${peerIdMsg.id}] assigned id generated by server`) - // There is at least one communication round remaining - if (r < this.aggregator.communicationRounds - 1) { - // Reuse the aggregation result + if (this._ownId !== undefined) { + throw new Error('received id from server but was already received') + } + this._ownId = peerIdMsg.id + this.pool = PeerPool.init(peerIdMsg.id) + } + + async disconnect(): Promise { + // Disconnect from peers + const pool = await this.pool + pool?.shutdown() + this.pool = undefined + + if (this.connections !== undefined) { + const peers = this.connections.keySeq().toSet() + this.aggregator.setNodes(this.aggregator.nodes.subtract(peers)) + } + + // Disconnect from server + this.server?.disconnect() + this._server = undefined + this._ownId = undefined + } + + async onRoundBeginCommunication( + weights: WeightsContainer, + round: number, + trainingInformant: TrainingInformant + ): Promise { + // Reset peers list at each round of training to make sure client works with an updated peers + // list, maintained by the server. 
Adds any received weights to the aggregator. + this.connections = await this.waitForPeers(round) + // Store the promise for the current round's aggregation result. this.aggregationResult = this.aggregator.receiveResult() - } } - // Reset the peers list for the next round - this.aggregator.resetNodes() - } + async onRoundEndCommunication( + weights: WeightsContainer, + round: number, + trainingInformant: TrainingInformant + ): Promise { + let result = weights + + // Perform the required communication rounds. Each communication round consists in sending our local payload, + // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator. + // A communication round's payload is the aggregation result of the previous communication round. The first + // communication round simply sends our training result, i.e. model weights updates. This scheme allows for + // the aggregator to define any complex multi-round aggregation mechanism. + for (let r = 0; r < this.aggregator.communicationRounds; r++) { + // Generate our payloads for this communication round and send them to all ready connected peers + if (this.connections !== undefined) { + const payloads = this.aggregator.makePayloads(result) + try { + await Promise.all( + payloads.map(async (payload, id) => { + if (id === this.ownId) { + this.aggregator.add( + this.ownId, + payload, + round, + r + ) + } else { + const connection = this.connections?.get(id) + if (connection !== undefined) { + const encoded = + await serialization.weights.encode( + payload + ) + this.sendMessagetoPeer(connection, { + type: type.Payload, + peer: id, + round: r, + payload: encoded, + }) + } + } + }) + ) + } catch { + throw new Error('error while sending weights') + } + } + + if (this.aggregationResult === undefined) { + throw new TypeError('aggregation result promise is undefined') + } - private receivePayloads (connections: Map, round: number): void { - console.info(`[${this.ownId}] Accepting new 
contributions for round ${round}`) - connections.forEach(async (connection, peerId) => { - let receivedPayloads = 0 - do { - try { - const message = await waitMessageWithTimeout(connection, type.Payload) - const decoded = serialization.weights.decode(message.payload) + // Wait for aggregation before proceeding to the next communication round. + // The current result will be used as payload for the eventual next communication round. + result = await Promise.race([this.aggregationResult, timeout()]) - if (!this.aggregator.add(peerId, decoded, round, message.round)) { - console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`) - } - } catch (e) { - console.warn(e instanceof Error ? e.message : e) + // There is at least one communication round remaining + if (r < this.aggregator.communicationRounds - 1) { + // Reuse the aggregation result + this.aggregationResult = this.aggregator.receiveResult() + } } - } while (++receivedPayloads < this.aggregator.communicationRounds) - }) - } + + // Reset the peers list for the next round + this.aggregator.resetNodes() + } + + private receivePayloads( + connections: Map, + round: number + ): void { + console.info( + `[${this.ownId}] Accepting new contributions for round ${round}` + ) + connections.forEach(async (connection, peerId) => { + let receivedPayloads = 0 + do { + try { + const message = await waitMessageWithTimeout( + connection, + type.Payload + ) + const decoded = serialization.weights.decode( + message.payload + ) + + if ( + !this.aggregator.add( + peerId, + decoded, + round, + message.round + ) + ) { + console.warn( + `[${this.ownId}] Failed to add contribution from peer ${peerId}` + ) + } + } catch (e) { + console.warn(e instanceof Error ? 
e.message : e) + } + } while (++receivedPayloads < this.aggregator.communicationRounds) + }) + } } diff --git a/discojs/discojs-core/src/client/federated/base.ts b/discojs/discojs-core/src/client/federated/base.ts index 43fd36acb..7d353de02 100644 --- a/discojs/discojs-core/src/client/federated/base.ts +++ b/discojs/discojs-core/src/client/federated/base.ts @@ -1,11 +1,22 @@ import * as nodeUrl from 'url' import { Map } from 'immutable' -import { serialization, informant, MetadataKey, MetadataValue, WeightsContainer, TrainingInformant } from '../..' +import { + serialization, + informant, + MetadataKey, + MetadataValue, + WeightsContainer, + TrainingInformant, +} from '../..' import { NodeID } from '../types' import { Base as Client } from '../base' import { type, ClientConnected } from '../messages' -import { EventConnection, waitMessageWithTimeout, WebSocketServer } from '../event_connection' +import { + EventConnection, + waitMessageWithTimeout, + WebSocketServer, +} from '../event_connection' import * as messages from './messages' /** @@ -13,238 +24,270 @@ import * as messages from './messages' * a specific task in the federated setting. */ export class Base extends Client { - /** - * Arbitrary node id assigned to the federated server which we are communicating with. - * Indeed, the server acts as a node within the network. In the federated setting described - * by this client class, the server is the only node which we are communicating with. - */ - public static readonly SERVER_NODE_ID = 'federated-server-node-id' - /** - * Most recent server-fetched round. - */ - private serverRound?: number - /** - * Most recent server-fetched aggregated result. - */ - private serverResult?: WeightsContainer - /** - * Statistics curated by the federated server. - */ - private receivedStatistics?: Record - /** - * Map of metadata values for each node id. 
- */ - private metadataMap?: Map - - /** - * Opens a new WebSocket connection with the server and listens to new messages over the channel - */ - private async connectServer (url: URL): Promise { - const server: EventConnection = await WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated) - - return server - } - - /** - * Initializes the connection to the server and get our own node id. - * TODO: In the federated setting, should return the current server-side round - * for the task. - */ - async connect (): Promise { - const URL = typeof window !== 'undefined' ? window.URL : nodeUrl.URL - const serverURL = new URL('', this.url.href) - switch (this.url.protocol) { - case 'http:': - serverURL.protocol = 'ws:' - break - case 'https:': - serverURL.protocol = 'wss:' - break - default: - throw new Error(`unknown protocol: ${this.url.protocol}`) + /** + * Arbitrary node id assigned to the federated server which we are communicating with. + * Indeed, the server acts as a node within the network. In the federated setting described + * by this client class, the server is the only node which we are communicating with. + */ + public static readonly SERVER_NODE_ID = 'federated-server-node-id' + /** + * Most recent server-fetched round. + */ + private serverRound?: number + /** + * Most recent server-fetched aggregated result. + */ + private serverResult?: WeightsContainer + /** + * Statistics curated by the federated server. + */ + private receivedStatistics?: Record + /** + * Map of metadata values for each node id. 
+ */ + private metadataMap?: Map + + /** + * Opens a new WebSocket connection with the server and listens to new messages over the channel + */ + private async connectServer(url: URL): Promise { + const server: EventConnection = await WebSocketServer.connect( + url, + messages.isMessageFederated, + messages.isMessageFederated + ) + + return server } - serverURL.pathname += `feai/${this.task.taskID}` + /** + * Initializes the connection to the server and get our own node id. + * TODO: In the federated setting, should return the current server-side round + * for the task. + */ + async connect(): Promise { + const URL = typeof window !== 'undefined' ? window.URL : nodeUrl.URL + const serverURL = new URL('', this.url.href) + switch (this.url.protocol) { + case 'http:': + serverURL.protocol = 'ws:' + break + case 'https:': + serverURL.protocol = 'wss:' + break + default: + throw new Error(`unknown protocol: ${this.url.protocol}`) + } - this._server = await this.connectServer(serverURL) - this.aggregator.registerNode(Base.SERVER_NODE_ID) + serverURL.pathname += `feai/${this.task.id}` - const msg: ClientConnected = { - type: type.ClientConnected - } - this.server.send(msg) - - const received = await waitMessageWithTimeout(this.server, type.AssignNodeID) - console.info(`[${received.id}] assign id generated by the server`) - this._ownId = received.id - } - - /** - * Disconnection process when user quits the task. - */ - async disconnect (): Promise { - this.server.disconnect() - this._server = undefined - this._ownId = undefined - - this.aggregator.setNodes(this.aggregator.nodes.delete(Base.SERVER_NODE_ID)) - } - - /** - * Send a message containing our local weight updates to the federated server. 
- * @param weights The weight updates to send - */ - async sendPayload (payload: WeightsContainer): Promise { - const msg: messages.SendPayload = { - type: type.SendPayload, - payload: await serialization.weights.encode(payload), - round: this.aggregator.round - } - this.server.send(msg) - } - - /** - * Fetches the server's result for its current (most recent) round and add it to our aggregator. - * Updates the aggregator's round if it's behind the server's. - */ - async receiveResult (): Promise { - this.serverRound = undefined - this.serverResult = undefined - - const msg: messages.MessageBase = { - type: type.ReceiveServerPayload - } - this.server.send(msg) - - try { - const { payload, round } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload) - this.serverRound = round - - // Store the server result only if it is not stale - if (this.aggregator.round <= round) { - this.serverResult = serialization.weights.decode(payload) - // Update the local round to match the server's - if (this.aggregator.round < this.serverRound) { - this.aggregator.setRound(this.serverRound) + this._server = await this.connectServer(serverURL) + this.aggregator.registerNode(Base.SERVER_NODE_ID) + + const msg: ClientConnected = { + type: type.ClientConnected, } - } - } catch (e) { - console.error(e) + this.server.send(msg) + + const received = await waitMessageWithTimeout( + this.server, + type.AssignNodeID + ) + console.info(`[${received.id}] assign id generated by the server`) + this._ownId = received.id } - } - - /** - * Pulls statistics curated by the federated server, which orchestrates the network - * and produces the aggregation result, then display the relevant statistics via the - * given training informant. 
- * @param trainingInformant The training informant - */ - async receiveStatistics ( - trainingInformant: informant.FederatedInformant - ): Promise { - this.receivedStatistics = undefined - - const msg: messages.MessageBase = { - type: type.ReceiveServerStatistics + + /** + * Disconnection process when user quits the task. + */ + async disconnect(): Promise { + this.server.disconnect() + this._server = undefined + this._ownId = undefined + + this.aggregator.setNodes( + this.aggregator.nodes.delete(Base.SERVER_NODE_ID) + ) } - this.server.send(msg) - - try { - const received = await waitMessageWithTimeout(this.server, type.ReceiveServerStatistics) - this.receivedStatistics = received.statistics - trainingInformant.update(this.receivedStatistics) - } catch (e) { - console.error(e) + + /** + * Send a message containing our local weight updates to the federated server. + * @param weights The weight updates to send + */ + async sendPayload(payload: WeightsContainer): Promise { + const msg: messages.SendPayload = { + type: type.SendPayload, + payload: await serialization.weights.encode(payload), + round: this.aggregator.round, + } + this.server.send(msg) } - } - - /** - * Sends metadata to the federated server. Metadata is gathered server-side according - * to the key given by clients. - * @param key The metadata key - * @param value The metadata value - */ - async sendMetadata (key: MetadataKey, value: MetadataValue): Promise { - const msg: messages.SendMetadata = { - type: type.SendMetadata, - taskId: this.task.taskID, - nodeId: this.ownId, - round: this.aggregator.round, - key, - value + + /** + * Fetches the server's result for its current (most recent) round and add it to our aggregator. + * Updates the aggregator's round if it's behind the server's. 
+ */ + async receiveResult(): Promise { + this.serverRound = undefined + this.serverResult = undefined + + const msg: messages.MessageBase = { + type: type.ReceiveServerPayload, + } + this.server.send(msg) + + try { + const { payload, round } = await waitMessageWithTimeout( + this.server, + type.ReceiveServerPayload + ) + this.serverRound = round + + // Store the server result only if it is not stale + if (this.aggregator.round <= round) { + this.serverResult = serialization.weights.decode(payload) + // Update the local round to match the server's + if (this.aggregator.round < this.serverRound) { + this.aggregator.setRound(this.serverRound) + } + } + } catch (e) { + console.error(e) + } } - this.server.send(msg) - } - - /** - * Fetch the metadata values maintained by the federated server, for a given metadata key. - * The values are indexed by node id. - * @param key The metadata key - * @returns The map of node id to metadata value - */ - async receiveMetadataMap (key: MetadataKey): Promise | undefined> { - this.metadataMap = undefined - - const msg: messages.ReceiveServerMetadata = { - type: type.ReceiveServerMetadata, - taskId: this.task.taskID, - nodeId: this.ownId, - round: this.aggregator.round, - key + /** + * Pulls statistics curated by the federated server, which orchestrates the network + * and produces the aggregation result, then display the relevant statistics via the + * given training informant. 
+ * @param trainingInformant The training informant + */ + async receiveStatistics( + trainingInformant: informant.FederatedInformant + ): Promise { + this.receivedStatistics = undefined + + const msg: messages.MessageBase = { + type: type.ReceiveServerStatistics, + } + this.server.send(msg) + + try { + const received = await waitMessageWithTimeout( + this.server, + type.ReceiveServerStatistics + ) + this.receivedStatistics = received.statistics + trainingInformant.update(this.receivedStatistics) + } catch (e) { + console.error(e) + } } - this.server.send(msg) + /** + * Sends metadata to the federated server. Metadata is gathered server-side according + * to the key given by clients. + * @param key The metadata key + * @param value The metadata value + */ + async sendMetadata(key: MetadataKey, value: MetadataValue): Promise { + const msg: messages.SendMetadata = { + type: type.SendMetadata, + taskId: this.task.id, + nodeId: this.ownId, + round: this.aggregator.round, + key, + value, + } - const received = await waitMessageWithTimeout(this.server, type.ReceiveServerMetadata) - if (received.metadataMap !== undefined) { - this.metadataMap = Map( - received.metadataMap.filter(([k, v]) => v !== undefined) as Array<[NodeID, MetadataValue]> - ) + this.server.send(msg) } - return this.metadataMap - } - - async onRoundBeginCommunication ( - weights: WeightsContainer, - round: number, informant: - TrainingInformant - ): Promise { - // Prepare the result promise for the incoming round - this.aggregationResult = this.aggregator.receiveResult() - } - - async onRoundEndCommunication ( - weights: WeightsContainer, - round: number, - trainingInformant: informant.FederatedInformant - ): Promise { - // NB: For now, we suppose a fully-federated setting. - - if (this.aggregationResult === undefined) { - throw new Error('local aggregation result was not set') + /** + * Fetch the metadata values maintained by the federated server, for a given metadata key. 
+ * The values are indexed by node id. + * @param key The metadata key + * @returns The map of node id to metadata value + */ + async receiveMetadataMap( + key: MetadataKey + ): Promise | undefined> { + this.metadataMap = undefined + + const msg: messages.ReceiveServerMetadata = { + type: type.ReceiveServerMetadata, + taskId: this.task.id, + nodeId: this.ownId, + round: this.aggregator.round, + key, + } + + this.server.send(msg) + + const received = await waitMessageWithTimeout( + this.server, + type.ReceiveServerMetadata + ) + if (received.metadataMap !== undefined) { + this.metadataMap = Map( + received.metadataMap.filter( + ([k, v]) => v !== undefined + ) as Array<[NodeID, MetadataValue]> + ) + } + + return this.metadataMap } - // Send our contribution to the server - await this.sendPayload(this.aggregator.makePayloads(weights).first()) - // Fetch the server result - await this.receiveResult() - - // TODO @s314cy: add communication rounds to federated learning - if (this.serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, this.serverResult, round, 0)) { - // Regular case: the server sends us its aggregation result which will serve our - // own aggregation result. - } else { - // Unexpected case: for some reason, the server result is stale. - // We proceed to the next round without its result. - console.info(`[${this.ownId}] Server result is either stale or not received`) - this.aggregator.nextRound() + async onRoundBeginCommunication( + weights: WeightsContainer, + round: number, + informant: TrainingInformant + ): Promise { + // Prepare the result promise for the incoming round + this.aggregationResult = this.aggregator.receiveResult() } - // Pull statistics about the contributors - // await this.receiveStatistics(trainingInformant) - } + async onRoundEndCommunication( + weights: WeightsContainer, + round: number, + trainingInformant: informant.FederatedInformant + ): Promise { + // NB: For now, we suppose a fully-federated setting. 
+ + if (this.aggregationResult === undefined) { + throw new Error('local aggregation result was not set') + } + + // Send our contribution to the server + await this.sendPayload(this.aggregator.makePayloads(weights).first()) + // Fetch the server result + await this.receiveResult() + + // TODO @s314cy: add communication rounds to federated learning + if ( + this.serverResult !== undefined && + this.aggregator.add( + Base.SERVER_NODE_ID, + this.serverResult, + round, + 0 + ) + ) { + // Regular case: the server sends us its aggregation result which will serve our + // own aggregation result. + } else { + // Unexpected case: for some reason, the server result is stale. + // We proceed to the next round without its result. + console.info( + `[${this.ownId}] Server result is either stale or not received` + ) + this.aggregator.nextRound() + } + + // Pull statistics about the contributors + // await this.receiveStatistics(trainingInformant) + } - async onTrainEndCommunication (): Promise {} + async onTrainEndCommunication(): Promise {} } diff --git a/discojs/discojs-core/src/dataset/data/image_data.spec.ts b/discojs/discojs-core/src/dataset/data/image_data.spec.ts index a8b247b8b..6254cfe7f 100644 --- a/discojs/discojs-core/src/dataset/data/image_data.spec.ts +++ b/discojs/discojs-core/src/dataset/data/image_data.spec.ts @@ -4,27 +4,38 @@ import { ImageData } from './image_data' import { tf, Task } from '../..' 
describe('image data checks', () => { - const simplefaceMock: Task = { - taskID: 'simpleface', - displayInformation: {}, - trainingInformation: { - IMAGE_H: 200, - IMAGE_W: 200 - } - } as unknown as Task + const simplefaceMock: Task = { + id: 'simpleface', + displayInformation: {}, + trainingInformation: { + IMAGE_H: 200, + IMAGE_W: 200, + }, + } as unknown as Task - it('throw an error on incorrectly formatted data', async () => { - try { - await ImageData.init(tf.data.array([tf.zeros([150, 150, 3]), tf.zeros([150, 150, 3])]), simplefaceMock, 3) - } catch (e) { - expect(e).to.be.an.instanceOf(Error) - return - } - // no error means we failed - assert(false) - }) + it('throw an error on incorrectly formatted data', async () => { + try { + await ImageData.init( + tf.data.array([ + tf.zeros([150, 150, 3]), + tf.zeros([150, 150, 3]), + ]), + simplefaceMock, + 3 + ) + } catch (e) { + expect(e).to.be.an.instanceOf(Error) + return + } + // no error means we failed + assert(false) + }) - it('do nothing on correctly formatted data', async () => { - await ImageData.init(tf.data.array([tf.zeros([200, 200, 3]), tf.zeros([200, 200, 3])]), simplefaceMock, 3) - }) + it('do nothing on correctly formatted data', async () => { + await ImageData.init( + tf.data.array([tf.zeros([200, 200, 3]), tf.zeros([200, 200, 3])]), + simplefaceMock, + 3 + ) + }) }) diff --git a/discojs/discojs-core/src/dataset/data/index.ts b/discojs/discojs-core/src/dataset/data/index.ts index 2035f14bd..5c7b8db15 100644 --- a/discojs/discojs-core/src/dataset/data/index.ts +++ b/discojs/discojs-core/src/dataset/data/index.ts @@ -1,9 +1,14 @@ -export * as tuple from './data_split' +export * as data_split from './data_split' +export type { DataSplit } from './data_split' export { Data } from './data' export { ImageData } from './image_data' export { TabularData } from './tabular_data' export { TextData } from './text_data' export { - ImagePreprocessing, TabularPreprocessing, TextPreprocessing, - 
IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING + ImagePreprocessing, + TabularPreprocessing, + TextPreprocessing, + IMAGE_PREPROCESSING, + TABULAR_PREPROCESSING, + TEXT_PREPROCESSING, } from './preprocessing' diff --git a/discojs/discojs-core/src/dataset/data/preprocessing/text_preprocessing.ts b/discojs/discojs-core/src/dataset/data/preprocessing/text_preprocessing.ts index 96e53dbed..bc2e5c31b 100644 --- a/discojs/discojs-core/src/dataset/data/preprocessing/text_preprocessing.ts +++ b/discojs/discojs-core/src/dataset/data/preprocessing/text_preprocessing.ts @@ -1,73 +1,64 @@ import { Task, tf } from '../../..' import { PreprocessingFunction } from './base' -import { GPTLMHeadModel } from 'gpt-tfjs' +import defaultTokenizer from 'gpt-tokenizer/model/text-davinci-003' import { List } from 'immutable' /** * Available text preprocessing types. */ export enum TextPreprocessing { - Tokenize, - Padding + Tokenize, + Padding, } interface TextEntry extends tf.TensorContainerObject { - xs: string[] - ys: number[] + xs: string[] + ys: number[] } interface TokenizedEntry extends tf.TensorContainerObject { - xs: tf.Tensor1D - ys: tf.Tensor1D + xs: tf.Tensor1D + ys: tf.Tensor1D } -const minGptTokenizer = GPTLMHeadModel.tokenizer - const padding: PreprocessingFunction = { - type: TextPreprocessing.Padding, - apply: (x: tf.TensorContainer, task: Task) => { - const { xs, ys } = x as TokenizedEntry - // TODO: add to task definition - const maxLength = 64 - if (maxLength === undefined) { - return { xs, ys } - } - return { - xs: xs - .pad([[0, Math.max(0, maxLength - xs.size)]]) - .slice([0], [maxLength]), - ys - } - } + type: TextPreprocessing.Padding, + apply: (x: tf.TensorContainer, task: Task) => { + const { xs, ys } = x as TokenizedEntry + const params = task.trainingInformation + const maxLength = params.blockSize || 64 + // FIXME: Not sure you would want an undefined maxLength + // if (maxLength === undefined) { + // return { xs, ys } + // } + return { + xs: 
xs + .pad([[0, Math.max(0, maxLength - xs.size)]]) + .slice([0], [maxLength]), + ys, + } + }, } const tokenize: PreprocessingFunction = { - type: TextPreprocessing.Tokenize, - apply: (x: tf.TensorContainer, task: Task) => { - const { xs, ys } = x as TextEntry - const params = task.trainingInformation - // TODO: add to task definition - const tokenizer = (params as unknown as any).tokenizer - - let tokenized: number[] - if (tokenizer === undefined) { - tokenized = minGptTokenizer.encode(xs[0]).bpe - } else { - throw new Error('tokenizer not implemented') - } + type: TextPreprocessing.Tokenize, + apply: (x: tf.TensorContainer, task: Task) => { + const { xs, ys } = x as TextEntry + const params = task.trainingInformation + const tokenizer = params.tokenizer || defaultTokenizer + const tokenized = tokenizer.encode(xs[0]) - return { - xs: tf.tensor(tokenized), - ys: tf.tensor(ys) - } - } + return { + xs: tf.tensor(tokenized), + ys: tf.tensor(ys), + } + }, } /** * Available text preprocessing functions. */ -export const AVAILABLE_PREPROCESSING = List.of( - tokenize, - padding -).sortBy((e) => e.type) +export const AVAILABLE_PREPROCESSING = List.of(tokenize, padding).sortBy( + (e) => e.type +) diff --git a/discojs/discojs-core/src/dataset/data/tabular_data.spec.ts b/discojs/discojs-core/src/dataset/data/tabular_data.spec.ts index 26464f025..f1ae82f42 100644 --- a/discojs/discojs-core/src/dataset/data/tabular_data.spec.ts +++ b/discojs/discojs-core/src/dataset/data/tabular_data.spec.ts @@ -5,54 +5,72 @@ import { TabularData } from './tabular_data' import { tf, Task } from '../..' 
describe('tabular data checks', () => { - const titanicMock: Task = { - taskID: 'titanic', - displayInformation: {}, - trainingInformation: { - inputColumns: [ - 'PassengerId', - 'Age', - 'SibSp', - 'Parch', - 'Fare', - 'Pclass' - ], - outputColumns: [ - 'Survived' - ] - } - } as unknown as Task - - const dataConfig = { - features: titanicMock.trainingInformation.inputColumns, - labels: titanicMock.trainingInformation.outputColumns - } + const titanicMock: Task = { + id: 'titanic', + displayInformation: {}, + trainingInformation: { + inputColumns: [ + 'PassengerId', + 'Age', + 'SibSp', + 'Parch', + 'Fare', + 'Pclass', + ], + outputColumns: ['Survived'], + }, + } as unknown as Task - const columnConfigs = Map( - Set(dataConfig.features).map((feature) => [feature, { required: false, isLabel: false }]) - ).merge( - Set(dataConfig.labels).map((label) => [label, { required: true, isLabel: true }]) - ) + const dataConfig = { + features: titanicMock.trainingInformation.inputColumns, + labels: titanicMock.trainingInformation.outputColumns, + } - const csvConfig = { - hasHeader: true, - columnConfigs: columnConfigs.toObject(), - configuredColumnsOnly: true, - delimiter: ',' - } + const columnConfigs = Map( + Set(dataConfig.features).map((feature) => [ + feature, + { required: false, isLabel: false }, + ]) + ).merge( + Set(dataConfig.labels).map((label) => [ + label, + { required: true, isLabel: true }, + ]) + ) - it('throw an error on incorrectly formatted data', async () => { - try { - await TabularData.init(tf.data.csv('file://../../example_training_data/cifar10-labels.csv', csvConfig), titanicMock, 3) - } catch (e) { - expect(e).to.be.an.instanceOf(Error) - return + const csvConfig = { + hasHeader: true, + columnConfigs: columnConfigs.toObject(), + configuredColumnsOnly: true, + delimiter: ',', } - // no error means we failed - assert(false) - }) - it('do nothing on correctly formatted data', async () => { - await 
TabularData.init(tf.data.csv('file://../../example_training_data/titanic_train.csv', csvConfig), titanicMock, 3) - }) + it('throw an error on incorrectly formatted data', async () => { + try { + await TabularData.init( + tf.data.csv( + 'file://../../example_training_data/cifar10-labels.csv', + csvConfig + ), + titanicMock, + 3 + ) + } catch (e) { + expect(e).to.be.an.instanceOf(Error) + return + } + // no error means we failed + assert(false) + }) + + it('do nothing on correctly formatted data', async () => { + await TabularData.init( + tf.data.csv( + 'file://../../example_training_data/titanic_train.csv', + csvConfig + ), + titanicMock, + 3 + ) + }) }) diff --git a/discojs/discojs-core/src/dataset/data/text_data.ts b/discojs/discojs-core/src/dataset/data/text_data.ts index 8a8d80560..1641fd501 100644 --- a/discojs/discojs-core/src/dataset/data/text_data.ts +++ b/discojs/discojs-core/src/dataset/data/text_data.ts @@ -1,23 +1,24 @@ -import { Task } from '../..' +import { List } from 'immutable' +import { Task } from '../../' import { Dataset } from '../dataset' import { Data } from './data' -import { TEXT_PREPROCESSING } from './preprocessing' +import { PreprocessingFunction } from './preprocessing' /** * Disco data made of textual samples. 
*/ export class TextData extends Data { - public readonly availablePreprocessing = TEXT_PREPROCESSING + public readonly availablePreprocessing: List = List() - static async init ( - dataset: Dataset, - task: Task, - size?: number - ): Promise { - return new TextData(dataset, task, size) - } + static async init( + dataset: Dataset, + task: Task, + size?: number + ): Promise { + return new TextData(dataset, task, size) + } - protected create (dataset: Dataset, task: Task, size?: number): TextData { - return new TextData(dataset, task, size) - } + protected create(dataset: Dataset, task: Task, size?: number): TextData { + return new TextData(dataset, task, size) + } } diff --git a/discojs/discojs-core/src/dataset/data_loader/index.ts b/discojs/discojs-core/src/dataset/data_loader/index.ts index 75677e4fd..2d7fbe38d 100644 --- a/discojs/discojs-core/src/dataset/data_loader/index.ts +++ b/discojs/discojs-core/src/dataset/data_loader/index.ts @@ -1,4 +1,5 @@ -export { DataConfig, DataLoader } from './data_loader' +export { type DataConfig, DataLoader } from './data_loader' export { ImageLoader } from './image_loader' export { TabularLoader } from './tabular_loader' export { TextLoader } from './text_loader' +export type * from './text_loader' diff --git a/discojs/discojs-core/src/dataset/data_loader/text_loader.ts b/discojs/discojs-core/src/dataset/data_loader/text_loader.ts index 320197141..1a3fd7e9f 100644 --- a/discojs/discojs-core/src/dataset/data_loader/text_loader.ts +++ b/discojs/discojs-core/src/dataset/data_loader/text_loader.ts @@ -1,16 +1,155 @@ -import { TabularLoader } from './tabular_loader' +import { tf } from '../../' import { Dataset } from '../dataset' -import { TextData, Data } from '../data' +import { TextData, Data, DataSplit } from '../data' +import { DataConfig, DataLoader } from '.' 
+ +export interface TextConfig extends DataConfig { + blockSize: number + vocabSize: number + batchSize?: number +} + +export type BatchedTokenizedTensorSample = { + xs: tf.Tensor2D // tokens of size (B, blockSize) + ys: tf.Tensor3D // one hot encoded vector of size (B, blockSize, vocabSize) +} + +export type TokenizedDataset = Dataset + +export type TokenizedIterResult = IteratorResult< + BatchedTokenizedTensorSample, + BatchedTokenizedTensorSample +> + +export type TextSource = { + train: string[] + validation?: string[] +} + +export type ParsedWSSearchParams = { + id: string + config: TextConfig + file: string +} +export type WSSearchParams = Record + +// type AsyncTokenizedGenerator = AsyncGenerator +type AsyncTokenizedGenerator = AsyncGenerator< + BatchedTokenizedTensorSample, + void, + unknown +> + +type CoreElement = number[] | Buffer | Uint8Array +type CoreIterator = AsyncIterator /** * Text data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely, - * @epfml/discojs-web and @epfml/discojs-node. Loads data from files whose entries are line-separated and each consist of - * a sentence-like sample associated to an optional label. + * @epfml/discojs-web and @epfml/discojs-node. */ -export abstract class TextLoader extends TabularLoader { - abstract loadDatasetFrom (source: Source, config: Record): Promise +// TODO: does shuffle work for the text loader? -> add tests +export abstract class TextLoader extends DataLoader< + string, + TextSource, + TextConfig +> { + // Default config required to define TextConfig but leave DataConfig optional + static DEFAULT_CONFIG: Required> & + DataConfig = { + blockSize: 128, + vocabSize: 50258, + batchSize: 4, + } + + // TODO: remove this when refactor TASK is done + // and requires batchSize, blockSize and vocabSize + // to be required for any text task! 
+ getBatchSize(config: TextConfig): number { + return ( + config.batchSize || + this.task.trainingInformation.datasetBatchSize || + TextLoader.DEFAULT_CONFIG.batchSize + ) + } + + // TODO: remove this when refactor TASK is done + // and finally requires batchSize, blockSize and vocabSize + // to be required for any text task! + resolveConfig(config?: Partial): TextConfig { + return Object.assign({}, TextLoader.DEFAULT_CONFIG, config) + } + + /** + * Core dataset, shared between node and web versions + * Takes an iterator that yields arrays of numbers and turns + * them into structured batch of tuples x, y + * @param config + * @param requestNext + * @returns A TokenizedDataset = tfjs dataset containing xs and ys tensors + */ + async getCoreDataset( + config: TextConfig, + iterator: CoreIterator + ): Promise { + const toUInt16 = (low: number, high: number) => { + low &= 0xff + high &= 0xff + return (high << 8) | low + } + + const { vocabSize, blockSize } = config + const batchSize = this.getBatchSize(config) + const sampleSize = blockSize + 1 + + async function* generator(): AsyncTokenizedGenerator { + let next = iterator.next() + while (true) { + const { value: chunk } = await next + if (!chunk) break + + // pre-fetch the next chunk even before actually requesting it + next = iterator.next() + + const xs = tf.buffer([batchSize, blockSize], 'int32') + const ys = tf.buffer([batchSize, blockSize, vocabSize], 'int32') + + for (let i = 0; i < batchSize; i++) { + for (let j = 0; j < sampleSize; j++) { + const idx = (i * sampleSize + j) * 2 + const low = chunk[idx] + const high = chunk[idx + 1] + const token = toUInt16(low, high) + if (j < sampleSize - 1) xs.set(token, i, j) + if (j > 0) ys.set(1, i, j - 1, token) + } + } + + const x = xs.toTensor() + const y = ys.toTensor() + yield { + xs: x as tf.Tensor2D, + ys: y as tf.Tensor3D, + } + tf.dispose([x, y]) + } + } + + // cast as any because tf.data.generator does not take a type AsyncGenerator (but it works) + return 
tf.data.generator(generator as any) + } + + abstract load(source: string, config: TextConfig): Promise + + // TODO: not a fan of the TextConfig, it becomes tricky to know what parameters are set where + // because of they are set in the task AND/OR in the config + // when Task objects are refactor => try to remove TextConfig entirely + // or at least don't overlap keys with the task trainingInfo keys + abstract loadAll( + source: TextSource, + config?: Partial + ): Promise - async createData (dataset: Dataset): Promise { - return await TextData.init(dataset, this.task) - } + async createData(dataset: Dataset): Promise { + return await TextData.init(dataset, this.task) + } } diff --git a/discojs/discojs-core/src/dataset/index.ts b/discojs/discojs-core/src/dataset/index.ts index f244281f2..c7496338c 100644 --- a/discojs/discojs-core/src/dataset/index.ts +++ b/discojs/discojs-core/src/dataset/index.ts @@ -1,8 +1,18 @@ -export { Dataset } from './dataset' +export * as data from './data' +export * as loader from './data_loader' +export type * from './data_loader' +export type { Dataset } from './dataset' export { DatasetBuilder } from './dataset_builder' -export { ImageLoader, TabularLoader, DataLoader } from './data_loader' +export type { DataSplit } from './data' export { - tuple, Data, TabularData, ImageData, TextData, - ImagePreprocessing, TabularPreprocessing, TextPreprocessing, - IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING + Data, + ImageData, + TextData, + TabularData, + TabularPreprocessing, + TextPreprocessing, + ImagePreprocessing, + IMAGE_PREPROCESSING, + TABULAR_PREPROCESSING, + TEXT_PREPROCESSING, } from './data' diff --git a/discojs/discojs-core/src/default_tasks/cifar10.ts b/discojs/discojs-core/src/default_tasks/cifar10.ts index 03f77e197..1cb01a60f 100644 --- a/discojs/discojs-core/src/default_tasks/cifar10.ts +++ b/discojs/discojs-core/src/default_tasks/cifar10.ts @@ -1,60 +1,70 @@ import { tf, Task, data, TaskProvider, training } 
from '..' export const cifar10: TaskProvider = { - getTask (): Task { - return { - taskID: 'cifar10', - displayInformation: { - taskTitle: 'CIFAR10', - summary: { - preview: 'In this challenge, we ask you to classify images into categories based on the objects shown on the image.', - overview: 'The CIFAR-10 dataset is a collection of images that are commonly used to train machine learning and computer vision algorithms. It is one of the most widely used datasets for machine learning research.' - }, - limitations: 'The training data is limited to small images of size 32x32.', - tradeoffs: 'Training success strongly depends on label distribution', - dataFormatInformation: 'Images should be of .png format and of size 32x32.
The label file should be .csv, where each row contains a file_name, class.

e.g. if you have images: 0.png (of a frog) and 1.png (of a car)
labels.csv contains: (Note that no header is needed)
0.png, frog
1.png, car', - dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.', - dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png' - }, - trainingInformation: { - modelID: 'cifar10-model', - epochs: 10, - roundDuration: 10, - validationSplit: 0.2, - batchSize: 10, - modelCompileData: { - optimizer: 'sgd', - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - }, - dataType: 'image', - preprocessingFunctions: [data.ImagePreprocessing.Resize], - IMAGE_H: 224, - IMAGE_W: 224, - LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], - scheme: 'Decentralized', - noiseScale: undefined, - clippingRadius: 20, - decentralizedSecure: true, - minimumReadyPeers: 3, - maxShareValue: 100 - } - } - }, + getTask(): Task { + return { + id: 'cifar10', + displayInformation: { + taskTitle: 'CIFAR10', + summary: { + preview: + 'In this challenge, we ask you to classify images into categories based on the objects shown on the image.', + overview: + 'The CIFAR-10 dataset is a collection of images that are commonly used to train machine learning and computer vision algorithms. It is one of the most widely used datasets for machine learning research.', + }, + limitations: + 'The training data is limited to small images of size 32x32.', + tradeoffs: + 'Training success strongly depends on label distribution', + dataFormatInformation: + 'Images should be of .png format and of size 32x32.
The label file should be .csv, where each row contains a file_name, class.

e.g. if you have images: 0.png (of a frog) and 1.png (of a car)
labels.csv contains: (Note that no header is needed)
0.png, frog
1.png, car', + dataExampleText: + 'Below you can find 10 random examples from each of the 10 classes in the dataset.', + dataExampleImage: + 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png', + }, + trainingInformation: { + modelID: 'cifar10-model', + epochs: 10, + roundDuration: 10, + validationSplit: 0.2, + batchSize: 10, + modelCompileData: { + optimizer: 'sgd', + loss: 'categoricalCrossentropy', + metrics: ['accuracy'], + }, + dataType: 'image', + preprocessingFunctions: [data.ImagePreprocessing.Resize], + IMAGE_H: 224, + IMAGE_W: 224, + LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], + scheme: 'Decentralized', + noiseScale: undefined, + clippingRadius: 20, + decentralizedSecure: true, + minimumReadyPeers: 3, + maxShareValue: 100, + }, + } + }, - async getModel (): Promise { - const mobilenet = await tf.loadLayersModel( - 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json' - ) - const x = mobilenet.getLayer('global_average_pooling2d_1') - const predictions = tf.layers - .dense({ units: 10, activation: 'softmax', name: 'denseModified' }) - .apply(x.output) as tf.SymbolicTensor + async getModel(): Promise { + const mobilenet = await tf.loadLayersModel( + 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json' + ) + const x = mobilenet.getLayer('global_average_pooling2d_1') + const predictions = tf.layers + .dense({ units: 10, activation: 'softmax', name: 'denseModified' }) + .apply(x.output) as tf.SymbolicTensor - return new training.model.TFJSModel(this.getTask(), tf.model({ - inputs: mobilenet.input, - outputs: predictions, - name: 'modelModified' - })) - } + return new training.model.TFJSModel( + this.getTask(), + tf.model({ + inputs: mobilenet.input, + outputs: predictions, + name: 'modelModified', + }) + ) + }, } diff --git a/discojs/discojs-core/src/default_tasks/geotags.ts b/discojs/discojs-core/src/default_tasks/geotags.ts index 
10ed58164..aa78aa260 100644 --- a/discojs/discojs-core/src/default_tasks/geotags.ts +++ b/discojs/discojs-core/src/default_tasks/geotags.ts @@ -3,67 +3,77 @@ import { Range } from 'immutable' import { LabelTypeEnum } from '../task/label_type' export const geotags: TaskProvider = { - getTask (): Task { - return { - taskID: 'geotags', - displayInformation: { - taskTitle: 'GeoTags', - summary: { - preview: 'In this challenge, we predict the geo-location of a photo given its pixels in terms of a cell number of a grid built on top of Switzerland', - overview: 'The geotags dataset is a collection of images with geo-location information used to train a machine learning algorithm to predict the location of a photo given its pixels.' - }, - limitations: 'The training data is limited to images of size 224x224.', - tradeoffs: 'Training success strongly depends on label distribution', - dataFormatInformation: 'Images should be of .png format and of size 224x224.
The label file should be .csv, where each row contains a file_name, class. The class is the cell number of a the given grid of Switzerland. ', - labelDisplay: { - labelType: LabelTypeEnum.POLYGON_MAP, - mapBaseUrl: 'https://disco-polygon.web.app/' + getTask(): Task { + return { + id: 'geotags', + displayInformation: { + taskTitle: 'GeoTags', + summary: { + preview: + 'In this challenge, we predict the geo-location of a photo given its pixels in terms of a cell number of a grid built on top of Switzerland', + overview: + 'The geotags dataset is a collection of images with geo-location information used to train a machine learning algorithm to predict the location of a photo given its pixels.', + }, + limitations: + 'The training data is limited to images of size 224x224.', + tradeoffs: + 'Training success strongly depends on label distribution', + dataFormatInformation: + 'Images should be of .png format and of size 224x224.
The label file should be .csv, where each row contains a file_name, class. The class is the cell number of a the given grid of Switzerland. ', + labelDisplay: { + labelType: LabelTypeEnum.POLYGON_MAP, + mapBaseUrl: 'https://disco-polygon.web.app/', + }, + }, + trainingInformation: { + modelID: 'geotags-model', + epochs: 10, + roundDuration: 10, + validationSplit: 0.2, + batchSize: 10, + modelCompileData: { + optimizer: 'adam', + loss: 'categoricalCrossentropy', + metrics: ['accuracy'], + }, + dataType: 'image', + IMAGE_H: 224, + IMAGE_W: 224, + preprocessingFunctions: [data.ImagePreprocessing.Resize], + LABEL_LIST: Range(0, 127).map(String).toArray(), + scheme: 'Federated', + noiseScale: undefined, + clippingRadius: 20, + decentralizedSecure: true, + minimumReadyPeers: 3, + maxShareValue: 100, + }, } - }, - trainingInformation: { - modelID: 'geotags-model', - epochs: 10, - roundDuration: 10, - validationSplit: 0.2, - batchSize: 10, - modelCompileData: { - optimizer: 'adam', - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - }, - dataType: 'image', - IMAGE_H: 224, - IMAGE_W: 224, - preprocessingFunctions: [data.ImagePreprocessing.Resize], - LABEL_LIST: Range(0, 127).map(String).toArray(), - scheme: 'Federated', - noiseScale: undefined, - clippingRadius: 20, - decentralizedSecure: true, - minimumReadyPeers: 3, - maxShareValue: 100 - } - } - }, + }, - async getModel (): Promise { - const pretrainedModel = await tf.loadLayersModel( - 'https://storage.googleapis.com/deai-313515.appspot.com/models/geotags/model.json' - ) + async getModel(): Promise { + const pretrainedModel = await tf.loadLayersModel( + 'https://storage.googleapis.com/deai-313515.appspot.com/models/geotags/model.json' + ) - const numLayers = pretrainedModel.layers.length + const numLayers = pretrainedModel.layers.length - pretrainedModel.layers.forEach(layer => { layer.trainable = false }) - pretrainedModel.layers[numLayers - 1].trainable = true + pretrainedModel.layers.forEach((layer) => { + 
layer.trainable = false + }) + pretrainedModel.layers[numLayers - 1].trainable = true - const model = new training.model.TFJSModel(this.getTask(), tf.sequential({ - layers: [ - tf.layers.inputLayer({ inputShape: [224, 224, 3] }), - tf.layers.rescaling({ scale: 1 / 127.5, offset: -1 }), // Rescaling input between -1 and 1 - pretrainedModel - ] - })) + const model = new training.model.TFJSModel( + this.getTask(), + tf.sequential({ + layers: [ + tf.layers.inputLayer({ inputShape: [224, 224, 3] }), + tf.layers.rescaling({ scale: 1 / 127.5, offset: -1 }), // Rescaling input between -1 and 1 + pretrainedModel, + ], + }) + ) - return model - } + return model + }, } diff --git a/discojs/discojs-core/src/default_tasks/index.ts b/discojs/discojs-core/src/default_tasks/index.ts index c9f113752..b8a80bca2 100644 --- a/discojs/discojs-core/src/default_tasks/index.ts +++ b/discojs/discojs-core/src/default_tasks/index.ts @@ -5,3 +5,4 @@ export { titanic } from './titanic' export { simpleFace } from './simple_face' export { geotags } from './geotags' export { skinMnist } from './skin_mnist' +export { wikitext } from './wikitext' diff --git a/discojs/discojs-core/src/default_tasks/lus_covid.ts b/discojs/discojs-core/src/default_tasks/lus_covid.ts index b979db742..eb435e66a 100644 --- a/discojs/discojs-core/src/default_tasks/lus_covid.ts +++ b/discojs/discojs-core/src/default_tasks/lus_covid.ts @@ -1,95 +1,107 @@ import { tf, data, training, Task, TaskProvider } from '..' export const lusCovid: TaskProvider = { - getTask (): Task { - return { - taskID: 'lus_covid', - displayInformation: { - taskTitle: 'COVID Lung Ultrasound', - summary: { - preview: 'Do you have a data of lung ultrasound images on patients suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic?
Learn how to discriminate between COVID positive and negative patients by joining this task.', - overview: "Don’t have a dataset of your own? Download a sample of a few cases here." - }, - model: "We use a simplified* version of the DeepChest model: A deep learning model developed in our lab (intelligent Global Health.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task.

*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below
- Removed: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient
- Replaced: ResNet18 by Mobilenet", - tradeoffs: 'We are using a simpler version of DeepChest in order to be able to run it on the browser.', - dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"', - dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png', - dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png' - }, - trainingInformation: { - modelID: 'lus-covid-model', - epochs: 15, - roundDuration: 10, - validationSplit: 0.2, - batchSize: 2, - modelCompileData: { - optimizer: 'sgd', - loss: 'binaryCrossentropy', - metrics: ['accuracy'] - }, - learningRate: 0.001, - IMAGE_H: 100, - IMAGE_W: 100, - preprocessingFunctions: [data.ImagePreprocessing.Resize], - LABEL_LIST: ['COVID-Positive', 'COVID-Negative'], - dataType: 'image', - scheme: 'Decentralized', - noiseScale: undefined, - clippingRadius: 20, - decentralizedSecure: true, - minimumReadyPeers: 3, - maxShareValue: 100 - } - } - }, + getTask(): Task { + return { + id: 'lus_covid', + displayInformation: { + taskTitle: 'COVID Lung Ultrasound', + summary: { + preview: + 'Do you have a data of lung ultrasound images on patients suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic?
Learn how to discriminate between COVID positive and negative patients by joining this task.', + overview: + "Don’t have a dataset of your own? Download a sample of a few cases here.", + }, + model: "We use a simplified* version of the DeepChest model: A deep learning model developed in our lab (intelligent Global Health.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task.

*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below
- Removed: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient
- Replaced: ResNet18 by Mobilenet", + tradeoffs: + 'We are using a simpler version of DeepChest in order to be able to run it on the browser.', + dataFormatInformation: + 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"', + dataExampleText: + 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png', + dataExampleImage: + 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png', + }, + trainingInformation: { + modelID: 'lus-covid-model', + epochs: 15, + roundDuration: 10, + validationSplit: 0.2, + batchSize: 2, + modelCompileData: { + optimizer: 'sgd', + loss: 'binaryCrossentropy', + metrics: ['accuracy'], + }, + learningRate: 0.001, + IMAGE_H: 100, + IMAGE_W: 100, + preprocessingFunctions: [data.ImagePreprocessing.Resize], + LABEL_LIST: ['COVID-Positive', 'COVID-Negative'], + dataType: 'image', + scheme: 'Decentralized', + noiseScale: undefined, + clippingRadius: 20, + decentralizedSecure: true, + minimumReadyPeers: 3, + maxShareValue: 100, + }, + } + }, - async getModel (): Promise { - const imageHeight = 100 - const imageWidth = 100 - const imageChannels = 3 - const numOutputClasses = 2 - const model = tf.sequential() + async getModel(): Promise { + const imageHeight = 100 + const imageWidth = 100 + const imageChannels = 3 + const numOutputClasses = 2 + const model = tf.sequential() - // In the first layer of our convolutional neural network we have - // to specify the input shape. Then we specify some parameters for - // the convolution operation that takes place in this layer. 
- model.add(tf.layers.conv2d({ - inputShape: [imageHeight, imageWidth, imageChannels], - kernelSize: 5, - filters: 8, - strides: 1, - activation: 'relu', - kernelInitializer: 'varianceScaling' - })) + // In the first layer of our convolutional neural network we have + // to specify the input shape. Then we specify some parameters for + // the convolution operation that takes place in this layer. + model.add( + tf.layers.conv2d({ + inputShape: [imageHeight, imageWidth, imageChannels], + kernelSize: 5, + filters: 8, + strides: 1, + activation: 'relu', + kernelInitializer: 'varianceScaling', + }) + ) - // The MaxPooling layer acts as a sort of downsampling using max values - // in a region instead of averaging. - model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })) + // The MaxPooling layer acts as a sort of downsampling using max values + // in a region instead of averaging. + model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })) - // Repeat another conv2d + maxPooling stack. - // Note that we have more filters in the convolution. - model.add(tf.layers.conv2d({ - kernelSize: 5, - filters: 16, - strides: 1, - activation: 'relu', - kernelInitializer: 'varianceScaling' - })) - model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })) + // Repeat another conv2d + maxPooling stack. + // Note that we have more filters in the convolution. + model.add( + tf.layers.conv2d({ + kernelSize: 5, + filters: 16, + strides: 1, + activation: 'relu', + kernelInitializer: 'varianceScaling', + }) + ) + model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })) - // Now we flatten the output from the 2D filters into a 1D vector to prepare - // it for input into our last layer. This is common practice when feeding - // higher dimensional data to a final classification output layer. 
- model.add(tf.layers.flatten()) + // Now we flatten the output from the 2D filters into a 1D vector to prepare + // it for input into our last layer. This is common practice when feeding + // higher dimensional data to a final classification output layer. + model.add(tf.layers.flatten()) - // Our last layer is a dense layer which has 2 output units, one for each - // output class. - model.add(tf.layers.dense({ - units: numOutputClasses, - kernelInitializer: 'varianceScaling', - activation: 'softmax' - })) + // Our last layer is a dense layer which has 2 output units, one for each + // output class. + model.add( + tf.layers.dense({ + units: numOutputClasses, + kernelInitializer: 'varianceScaling', + activation: 'softmax', + }) + ) - return new training.model.TFJSModel(this.getTask(), model) - } + return new training.model.TFJSModel(this.getTask(), model) + }, } diff --git a/discojs/discojs-core/src/default_tasks/mnist.ts b/discojs/discojs-core/src/default_tasks/mnist.ts index 15d4c4192..e11b5feee 100644 --- a/discojs/discojs-core/src/default_tasks/mnist.ts +++ b/discojs/discojs-core/src/default_tasks/mnist.ts @@ -1,70 +1,76 @@ import { tf, training, Task, TaskProvider } from '..' export const mnist: TaskProvider = { - getTask (): Task { - return { - taskID: 'mnist', - displayInformation: { - taskTitle: 'MNIST', - summary: { - preview: "Test our platform by using a publicly available image dataset.

Download the classic MNIST imagebank of hand-written numbers here.
This model learns to identify hand written numbers.", - overview: 'The MNIST handwritten digit classification problem is a standard dataset used in computer vision and deep learning. Although the dataset is effectively solved, we use it to test our Decentralised Learning algorithms and platform.' - }, - model: 'The current model is a very simple CNN and its main goal is to test the app and the Decentralizsed Learning functionality.', - tradeoffs: 'We are using a simple model, first a 2d convolutional layer > max pooling > 2d convolutional layer > max pooling > convolutional layer > 2 dense layers.', - dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can upload each digit image of your dataset in the box corresponding to its label. The model taskes images of size 28x28 as input.', - dataExampleText: 'Below you can find an example of an expected image representing the digit 9.', - dataExampleImage: 'http://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png' - }, - trainingInformation: { - modelID: 'mnist-model', - epochs: 10, - roundDuration: 10, - validationSplit: 0.2, - batchSize: 30, - modelCompileData: { - optimizer: 'rmsprop', - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - }, - dataType: 'image', - IMAGE_H: 28, - IMAGE_W: 28, - preprocessingFunctions: [], - LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], - scheme: 'Decentralized', - noiseScale: undefined, - clippingRadius: 20, - decentralizedSecure: true, - minimumReadyPeers: 3, - maxShareValue: 100 - } - } - }, + getTask(): Task { + return { + id: 'mnist', + displayInformation: { + taskTitle: 'MNIST', + summary: { + preview: + "Test our platform by using a publicly available image dataset.

Download the classic MNIST imagebank of hand-written numbers here.
This model learns to identify hand written numbers.", + overview: + 'The MNIST handwritten digit classification problem is a standard dataset used in computer vision and deep learning. Although the dataset is effectively solved, we use it to test our Decentralised Learning algorithms and platform.', + }, + model: 'The current model is a very simple CNN and its main goal is to test the app and the Decentralizsed Learning functionality.', + tradeoffs: + 'We are using a simple model, first a 2d convolutional layer > max pooling > 2d convolutional layer > max pooling > convolutional layer > 2 dense layers.', + dataFormatInformation: + 'This model is trained on images corresponding to digits 0 to 9. You can upload each digit image of your dataset in the box corresponding to its label. The model taskes images of size 28x28 as input.', + dataExampleText: + 'Below you can find an example of an expected image representing the digit 9.', + dataExampleImage: + 'http://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png', + }, + trainingInformation: { + modelID: 'mnist-model', + epochs: 10, + roundDuration: 10, + validationSplit: 0.2, + batchSize: 30, + modelCompileData: { + optimizer: 'rmsprop', + loss: 'categoricalCrossentropy', + metrics: ['accuracy'], + }, + dataType: 'image', + IMAGE_H: 28, + IMAGE_W: 28, + preprocessingFunctions: [], + LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], + scheme: 'Decentralized', + noiseScale: undefined, + clippingRadius: 20, + decentralizedSecure: true, + minimumReadyPeers: 3, + maxShareValue: 100, + }, + } + }, - async getModel (): Promise { - const model = tf.sequential() + async getModel(): Promise { + const model = tf.sequential() - model.add( - tf.layers.conv2d({ - inputShape: [28, 28, 3], - kernelSize: 3, - filters: 16, - activation: 'relu' - }) - ) - model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) - model.add( - tf.layers.conv2d({ kernelSize: 3, filters: 32, 
activation: 'relu' }) - ) - model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) - model.add( - tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }) - ) - model.add(tf.layers.flatten({})) - model.add(tf.layers.dense({ units: 64, activation: 'relu' })) - model.add(tf.layers.dense({ units: 10, activation: 'softmax' })) + model.add( + tf.layers.conv2d({ + inputShape: [28, 28, 3], + kernelSize: 3, + filters: 16, + activation: 'relu', + }) + ) + model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) + model.add( + tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }) + ) + model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) + model.add( + tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }) + ) + model.add(tf.layers.flatten({})) + model.add(tf.layers.dense({ units: 64, activation: 'relu' })) + model.add(tf.layers.dense({ units: 10, activation: 'softmax' })) - return new training.model.TFJSModel(this.getTask(), model) - } + return new training.model.TFJSModel(this.getTask(), model) + }, } diff --git a/discojs/discojs-core/src/default_tasks/simple_face.ts b/discojs/discojs-core/src/default_tasks/simple_face.ts index ec6b6b9ce..2bd926b91 100644 --- a/discojs/discojs-core/src/default_tasks/simple_face.ts +++ b/discojs/discojs-core/src/default_tasks/simple_face.ts @@ -1,47 +1,53 @@ import { data, training, Task, TaskProvider } from '..' 
export const simpleFace: TaskProvider = { - getTask (): Task { - return { - taskID: 'simple_face', - displayInformation: { - taskTitle: 'Simple Face', - summary: { - preview: 'Can you detect if the person in a picture is a child or an adult?', - overview: 'Simple face is a small subset of face_task from Kaggle' - }, - limitations: 'The training data is limited to small images of size 200x200.', - tradeoffs: 'Training success strongly depends on label distribution', - dataFormatInformation: '', - dataExampleText: 'Below you find an example', - dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png' - }, - trainingInformation: { - modelID: 'simple_face-model', - epochs: 50, - modelURL: 'https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json', - roundDuration: 1, - validationSplit: 0.2, - batchSize: 10, - preprocessingFunctions: [data.ImagePreprocessing.Normalize], - learningRate: 0.001, - modelCompileData: { - optimizer: 'sgd', - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - }, - dataType: 'image', - IMAGE_H: 200, - IMAGE_W: 200, - LABEL_LIST: ['child', 'adult'], - scheme: 'Federated', // secure aggregation not yet implemented for federated - noiseScale: undefined, - clippingRadius: undefined - } - } - }, + getTask(): Task { + return { + id: 'simple_face', + displayInformation: { + taskTitle: 'Simple Face', + summary: { + preview: + 'Can you detect if the person in a picture is a child or an adult?', + overview: + 'Simple face is a small subset of face_task from Kaggle', + }, + limitations: + 'The training data is limited to small images of size 200x200.', + tradeoffs: + 'Training success strongly depends on label distribution', + dataFormatInformation: '', + dataExampleText: 'Below you find an example', + dataExampleImage: + 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png', + }, + 
trainingInformation: { + modelID: 'simple_face-model', + epochs: 50, + modelURL: + 'https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json', + roundDuration: 1, + validationSplit: 0.2, + batchSize: 10, + preprocessingFunctions: [data.ImagePreprocessing.Normalize], + learningRate: 0.001, + modelCompileData: { + optimizer: 'sgd', + loss: 'categoricalCrossentropy', + metrics: ['accuracy'], + }, + dataType: 'image', + IMAGE_H: 200, + IMAGE_W: 200, + LABEL_LIST: ['child', 'adult'], + scheme: 'Federated', // secure aggregation not yet implemented for federated + noiseScale: undefined, + clippingRadius: undefined, + }, + } + }, - async getModel (): Promise { - throw new Error('Not implemented') - } + async getModel(): Promise { + throw new Error('Not implemented') + }, } diff --git a/discojs/discojs-core/src/default_tasks/skin_mnist.ts b/discojs/discojs-core/src/default_tasks/skin_mnist.ts index cecc3d3fc..3e9b8c9b0 100644 --- a/discojs/discojs-core/src/default_tasks/skin_mnist.ts +++ b/discojs/discojs-core/src/default_tasks/skin_mnist.ts @@ -1,100 +1,102 @@ import { tf, data, training, Task, TaskProvider } from '..' 
export const skinMnist: TaskProvider = { - getTask (): Task { - return { - taskID: 'skin_mnist', - displayInformation: { - taskTitle: 'Skin disease classification', - summary: { - preview: 'Can you determine the skin disease from the dermatoscopic images?', - overview: - 'HAM10000 "Human Against Machine with 10000 training images" dataset is a large collection of multi-source dermatoscopic images of pigmented lesions from Kaggle' - }, - limitations: - 'The training data is limited to small images of size 28x28, similarly to the MNIST dataset.', - tradeoffs: 'Training success strongly depends on label distribution', - dataFormatInformation: '', - dataExampleText: 'Below you find an example', - dataExampleImage: 'http://walidbn.com/ISIC_0024306.jpg' - }, - trainingInformation: { - modelID: 'skin_mnist-model', - epochs: 50, - roundDuration: 1, - validationSplit: 0.1, - batchSize: 32, - preprocessingFunctions: [data.ImagePreprocessing.Normalize], - learningRate: 0.001, - modelCompileData: { - optimizer: 'adam', - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - }, - dataType: 'image', - IMAGE_H: 28, - IMAGE_W: 28, - LABEL_LIST: [ - 'Melanocytic nevi', - 'Melanoma', - 'Benign keratosis-like lesions', - 'Basal cell carcinoma', - 'Actinic keratoses', - 'Vascular lesions', - 'Dermatofibroma' - ], - scheme: 'Federated', - noiseScale: undefined, - clippingRadius: undefined - } - } - }, + getTask(): Task { + return { + id: 'skin_mnist', + displayInformation: { + taskTitle: 'Skin disease classification', + summary: { + preview: + 'Can you determine the skin disease from the dermatoscopic images?', + overview: + 'HAM10000 "Human Against Machine with 10000 training images" dataset is a large collection of multi-source dermatoscopic images of pigmented lesions from Kaggle', + }, + limitations: + 'The training data is limited to small images of size 28x28, similarly to the MNIST dataset.', + tradeoffs: + 'Training success strongly depends on label distribution', + 
dataFormatInformation: '', + dataExampleText: 'Below you find an example', + dataExampleImage: 'http://walidbn.com/ISIC_0024306.jpg', + }, + trainingInformation: { + modelID: 'skin_mnist-model', + epochs: 50, + roundDuration: 1, + validationSplit: 0.1, + batchSize: 32, + preprocessingFunctions: [data.ImagePreprocessing.Normalize], + learningRate: 0.001, + modelCompileData: { + optimizer: 'adam', + loss: 'categoricalCrossentropy', + metrics: ['accuracy'], + }, + dataType: 'image', + IMAGE_H: 28, + IMAGE_W: 28, + LABEL_LIST: [ + 'Melanocytic nevi', + 'Melanoma', + 'Benign keratosis-like lesions', + 'Basal cell carcinoma', + 'Actinic keratoses', + 'Vascular lesions', + 'Dermatofibroma', + ], + scheme: 'Federated', + noiseScale: undefined, + clippingRadius: undefined, + }, + } + }, - async getModel (): Promise { - const numClasses = 7 - const size = 28 + async getModel(): Promise { + const numClasses = 7 + const size = 28 - const model = tf.sequential() + const model = tf.sequential() - model.add( - tf.layers.conv2d({ - inputShape: [size, size, 3], - filters: 256, - kernelSize: 3, - activation: 'relu' - }) - ) + model.add( + tf.layers.conv2d({ + inputShape: [size, size, 3], + filters: 256, + kernelSize: 3, + activation: 'relu', + }) + ) - model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] })) - model.add(tf.layers.dropout({ rate: 0.3 })) + model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] })) + model.add(tf.layers.dropout({ rate: 0.3 })) - model.add( - tf.layers.conv2d({ - filters: 128, - kernelSize: 3, - activation: 'relu' - }) - ) + model.add( + tf.layers.conv2d({ + filters: 128, + kernelSize: 3, + activation: 'relu', + }) + ) - model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] })) - model.add(tf.layers.dropout({ rate: 0.3 })) + model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] })) + model.add(tf.layers.dropout({ rate: 0.3 })) - model.add( - tf.layers.conv2d({ - filters: 64, - kernelSize: 3, - activation: 'relu' - }) - ) + model.add( + tf.layers.conv2d({ + 
filters: 64, + kernelSize: 3, + activation: 'relu', + }) + ) - model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] })) - model.add(tf.layers.dropout({ rate: 0.3 })) + model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] })) + model.add(tf.layers.dropout({ rate: 0.3 })) - model.add(tf.layers.flatten()) + model.add(tf.layers.flatten()) - model.add(tf.layers.dense({ units: 32 })) - model.add(tf.layers.dense({ units: numClasses, activation: 'softmax' })) + model.add(tf.layers.dense({ units: 32 })) + model.add(tf.layers.dense({ units: numClasses, activation: 'softmax' })) - return new training.model.TFJSModel(this.getTask(), model) - } + return new training.model.TFJSModel(this.getTask(), model) + }, } diff --git a/discojs/discojs-core/src/default_tasks/titanic.ts b/discojs/discojs-core/src/default_tasks/titanic.ts index f7ab6bfe2..6a6f6c882 100644 --- a/discojs/discojs-core/src/default_tasks/titanic.ts +++ b/discojs/discojs-core/src/default_tasks/titanic.ts @@ -1,93 +1,93 @@ import { tf, training, Task, TaskProvider } from '..' export const titanic: TaskProvider = { - getTask (): Task { - return { - taskID: 'titanic', - displayInformation: { - taskTitle: 'Titanic', - summary: { - preview: "Test our platform by using a publicly available tabular dataset.

Download the passenger list from the Titanic shipwreck here: titanic_train.csv (more info here).
This model predicts the type of person most likely to survive/die in the historic ship accident, based on their characteristics (sex, age, class etc.).", - overview: 'We all know the unfortunate story of the Titanic: this flamboyant new transatlantic boat that sunk in 1912 in the North Atlantic Ocean. Today, we revist this tragedy by trying to predict the survival odds of the passenger given some basic features.' - }, - model: 'The current model does not normalize the given data and applies only a very simple pre-processing of the data.', - tradeoffs: 'We are using a small model for this task: 4 fully connected layers with few neurons. This allows fast training but can yield to reduced accuracy.', - dataFormatInformation: 'This model takes as input a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.

pclass: A proxy for socio-economic status (SES)
1st = Upper
2nd = Middle
3rd = Lower

age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5

sibsp: The dataset defines family relations in this way:
Sibling = brother, sister, stepbrother, stepsister
Spouse = husband, wife (mistresses and fiancés were ignored)

parch: The dataset defines family relations in this way:
Parent = mother, father
Child = daughter, son, stepdaughter, stepson
Some children travelled only with a nanny, therefore parch=0 for them.

The first line of the CSV contains the header:
PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked

Each susequent row contains the corresponding data.', - dataExampleText: 'Below one can find an example of a datapoint taken as input by our model. In this datapoint, the person is young man named Owen Harris that unfortunnalty perished with the Titanic. He boarded the boat in South Hamptons and was a 3rd class passenger. On the testing & validation page, the data should not contain the label column (Survived).', - dataExample: [ - { columnName: 'PassengerId', columnData: '1' }, - { columnName: 'Survived', columnData: '0' }, - { columnName: 'Name', columnData: 'Braund, Mr. Owen Harris' }, - { columnName: 'Sex', columnData: 'male' }, - { columnName: 'Age', columnData: '22' }, - { columnName: 'SibSp', columnData: '1' }, - { columnName: 'Parch', columnData: '0' }, - { columnName: 'Ticket', columnData: '1/5 21171' }, - { columnName: 'Fare', columnData: '7.25' }, - { columnName: 'Cabin', columnData: 'E46' }, - { columnName: 'Embarked', columnData: 'S' }, - { columnName: 'Pclass', columnData: '3' } - ], - headers: [ - 'PassengerId', - 'Survived', - 'Name', - 'Sex', - 'Age', - 'SibSp', - 'Parch', - 'Ticket', - 'Fare', - 'Cabin', - 'Embarked', - 'Pclass' - ] - }, - trainingInformation: { - modelID: 'titanic-model', - epochs: 20, - roundDuration: 10, - validationSplit: 0.2, - batchSize: 30, - preprocessingFunctions: [], - modelCompileData: { - optimizer: 'sgd', - loss: 'binaryCrossentropy', - metrics: ['accuracy'] - }, - dataType: 'tabular', - inputColumns: [ - 'Age', - 'SibSp', - 'Parch', - 'Fare', - 'Pclass' - ], - outputColumns: [ - 'Survived' - ], - scheme: 'Federated', // secure aggregation not yet implemented for FeAI - noiseScale: undefined, - clippingRadius: undefined - } - } - }, + getTask(): Task { + return { + id: 'titanic', + displayInformation: { + taskTitle: 'Titanic', + summary: { + preview: + "Test our platform by using a publicly available tabular dataset.

Download the passenger list from the Titanic shipwreck here: titanic_train.csv (more info here).
This model predicts the type of person most likely to survive/die in the historic ship accident, based on their characteristics (sex, age, class etc.).", + overview: + 'We all know the unfortunate story of the Titanic: this flamboyant new transatlantic boat that sunk in 1912 in the North Atlantic Ocean. Today, we revist this tragedy by trying to predict the survival odds of the passenger given some basic features.', + }, + model: 'The current model does not normalize the given data and applies only a very simple pre-processing of the data.', + tradeoffs: + 'We are using a small model for this task: 4 fully connected layers with few neurons. This allows fast training but can yield to reduced accuracy.', + dataFormatInformation: + 'This model takes as input a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.

pclass: A proxy for socio-economic status (SES)
1st = Upper
2nd = Middle
3rd = Lower

age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5

sibsp: The dataset defines family relations in this way:
Sibling = brother, sister, stepbrother, stepsister
Spouse = husband, wife (mistresses and fiancés were ignored)

parch: The dataset defines family relations in this way:
Parent = mother, father
Child = daughter, son, stepdaughter, stepson
Some children travelled only with a nanny, therefore parch=0 for them.

The first line of the CSV contains the header:
PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked

Each susequent row contains the corresponding data.', + dataExampleText: + 'Below one can find an example of a datapoint taken as input by our model. In this datapoint, the person is young man named Owen Harris that unfortunnalty perished with the Titanic. He boarded the boat in South Hamptons and was a 3rd class passenger. On the testing & validation page, the data should not contain the label column (Survived).', + dataExample: [ + { columnName: 'PassengerId', columnData: '1' }, + { columnName: 'Survived', columnData: '0' }, + { + columnName: 'Name', + columnData: 'Braund, Mr. Owen Harris', + }, + { columnName: 'Sex', columnData: 'male' }, + { columnName: 'Age', columnData: '22' }, + { columnName: 'SibSp', columnData: '1' }, + { columnName: 'Parch', columnData: '0' }, + { columnName: 'Ticket', columnData: '1/5 21171' }, + { columnName: 'Fare', columnData: '7.25' }, + { columnName: 'Cabin', columnData: 'E46' }, + { columnName: 'Embarked', columnData: 'S' }, + { columnName: 'Pclass', columnData: '3' }, + ], + headers: [ + 'PassengerId', + 'Survived', + 'Name', + 'Sex', + 'Age', + 'SibSp', + 'Parch', + 'Ticket', + 'Fare', + 'Cabin', + 'Embarked', + 'Pclass', + ], + }, + trainingInformation: { + modelID: 'titanic-model', + epochs: 20, + roundDuration: 10, + validationSplit: 0.2, + batchSize: 30, + preprocessingFunctions: [], + modelCompileData: { + optimizer: 'sgd', + loss: 'binaryCrossentropy', + metrics: ['accuracy'], + }, + dataType: 'tabular', + inputColumns: ['Age', 'SibSp', 'Parch', 'Fare', 'Pclass'], + outputColumns: ['Survived'], + scheme: 'Federated', // secure aggregation not yet implemented for FeAI + noiseScale: undefined, + clippingRadius: undefined, + }, + } + }, - async getModel (): Promise { - const model = tf.sequential() + async getModel(): Promise { + const model = tf.sequential() - model.add( - tf.layers.dense({ - inputShape: [5], - units: 124, - activation: 'relu', - kernelInitializer: 'leCunNormal' - }) - ) - model.add(tf.layers.dense({ units: 
64, activation: 'relu' })) - model.add(tf.layers.dense({ units: 32, activation: 'relu' })) - model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' })) + model.add( + tf.layers.dense({ + inputShape: [5], + units: 124, + activation: 'relu', + kernelInitializer: 'leCunNormal', + }) + ) + model.add(tf.layers.dense({ units: 64, activation: 'relu' })) + model.add(tf.layers.dense({ units: 32, activation: 'relu' })) + model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' })) - return new training.model.TFJSModel(this.getTask(), model) - } + return new training.model.TFJSModel(this.getTask(), model) + }, } diff --git a/discojs/discojs-core/src/default_tasks/wikitext.ts b/discojs/discojs-core/src/default_tasks/wikitext.ts new file mode 100644 index 000000000..854adfa7b --- /dev/null +++ b/discojs/discojs-core/src/default_tasks/wikitext.ts @@ -0,0 +1,86 @@ +import { tf, Task, TaskProvider, TrainingSchemes } from '..' +import * as gpt from '../training/models/gpt' +import { TFJSModel, Model } from '../training/model' + +const modelConfig: gpt.GPTConfig = { + modelType: 'gpt-nano', + epochs: 10, + maxIter: 10_000, + batchSize: 4, + blockSize: 128, + lr: 0.001, + vocabSize: 50258, // TODO: think it should be 50257 but somehow the tokenizer sometimes returns 50258, need to test (it appears in tiny-shakespeare) + evaluate: true, + maxEvalBatches: 12, + evaluateEvery: 100, +} as const + +export const wikitext: TaskProvider = { + getTask(): Task { + return { + id: 'wikitext-103', + displayInformation: { + taskTitle: 'Wikitext 103 Raw', + summary: { + preview: + 'In this challenge, we ask you to do next word prediction on a dataset of Wikipedia articles.', + overview: + 'Wikitext-103-raw is a dataset comprising unprocessed text excerpts from Wikipedia articles, designed for tasks related to natural language processing and language modeling.', + }, + limitations: + 'The dataset may contain noise, inconsistencies, and unstructured content due to its raw nature, potentially 
posing challenges for certain NLP tasks.', + tradeoffs: + 'The raw format may lack structured annotations and may require additional preprocessing for specific applications.', + dataFormatInformation: + 'The dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.', + dataExampleText: + 'An example excerpt from the dataset could be: "The history of artificial intelligence dates back to ancient times, with philosophical discussions on the nature of thought and reasoning."', + }, + trainingInformation: { + dataType: 'text', + modelID: 'wikitext-103-raw-model', + validationSplit: 0.2, // TODO: is this used somewhere? because train, eval and test are already split in dataset + maxIterations: modelConfig.maxIter, + epochs: modelConfig.epochs ?? 1, + // constructing a batch is taken care automatically in the dataset to make things faster + // so we fake a batch size of 1 + batchSize: 0, + // this is the real batch size used by the core text loader + datasetBatchSize: modelConfig.batchSize, + learningRate: modelConfig.lr, + modelCompileData: { + optimizer: 'adam', + loss: 'categoricalCrossentropy', + metrics: [], // 'precision', 'mse' , 'perplexity' doesnt exist + }, + modelConfig, + /** + * preprocessing is done prior to training so it is not needed in my case + * but otherwise, one can use the following template to use a custom tokenizer + * and the predefined preprocessing functions + */ + // import tokenizer from 'gpt-tokenizer/model/text-davinci-003' + // ... 
+ // tokenizer, + // preprocessingFunctions: [ + // data.TextPreprocessing.Tokenize, + // data.TextPreprocessing.Padding, + // ], + // vocabSize: 50258 + // blockSize: 64 + scheme: TrainingSchemes.DECENTRALIZED, // FIXME: FEDERATED / DECENTRALIZED is broken because of Bun I think + noiseScale: undefined, + decentralizedSecure: true, + minimumReadyPeers: 3, + maxShareValue: 100, + roundDuration: 10, + }, + } + }, + + async getModel(): Promise { + console.log('[wikitext-103 task] GPT Config:', modelConfig) + const model = gpt.GPT(modelConfig) + return new TFJSModel(this.getTask(), model as any as tf.LayersModel) + }, +} diff --git a/discojs/discojs-core/src/task/index.ts b/discojs/discojs-core/src/task/index.ts index 42f55056a..a278ba26d 100644 --- a/discojs/discojs-core/src/task/index.ts +++ b/discojs/discojs-core/src/task/index.ts @@ -1,7 +1,10 @@ -export { isTask, Task, isTaskID, TaskID } from './task' -export { TaskProvider, isTaskProvider } from './task_provider' -export { isDigest, Digest } from './digest' -export { isDisplayInformation, DisplayInformation } from './display_information' -export { TrainingInformation } from './training_information' +export { isTask, isTaskID, type Task, type TaskID } from './task' +export { isTaskProvider, type TaskProvider } from './task_provider' +export { isDigest, type Digest } from './digest' +export { + isDisplayInformation, + type DisplayInformation, +} from './display_information' +export { type TrainingInformation } from './training_information' export { pushTask, fetchTasks } from './task_handler' export { LabelTypeEnum } from './label_type' diff --git a/discojs/discojs-core/src/task/model_compile_data.ts b/discojs/discojs-core/src/task/model_compile_data.ts index c6eed63e3..ad4fe4b69 100644 --- a/discojs/discojs-core/src/task/model_compile_data.ts +++ b/discojs/discojs-core/src/task/model_compile_data.ts @@ -1,36 +1,40 @@ -export function isModelCompileData (raw: unknown): raw is ModelCompileData { - if (typeof raw 
!== 'object') { - return false - } - if (raw === null) { - return false - } +export function isModelCompileData(raw: unknown): raw is ModelCompileData { + if (typeof raw !== 'object') { + return false + } + if (raw === null) { + return false + } - const { - optimizer, - loss, - metrics - } = raw as Record<'optimizer' | 'loss' | 'metrics', unknown | undefined> + const { optimizer, loss, metrics } = raw as Record< + 'optimizer' | 'loss' | 'metrics', + unknown | undefined + > - if ( - typeof optimizer !== 'string' || - typeof loss !== 'string' - ) { - return false - } + if (typeof optimizer !== 'string' || typeof loss !== 'string') { + return false + } - if (!( - Array.isArray(metrics) && - metrics.every((e) => typeof e === 'string') - )) { - return false - } + if ( + !(Array.isArray(metrics) && metrics.every((e) => typeof e === 'string')) + ) { + return false + } - return true + return true } +// TODO: Use this instead of the original ModelCompileData ? +// Since tf.ModelCompileArgs is unavailable, the following is a way to retrieve it +// as the first parameter of tf.LayersModel.compile() +// type TypeOfClassMethod = T[M] extends Function +// ? 
T[M] +// : never +// type CompileMethod = TypeOfClassMethod +// export type ModelCompileData = Parameters[0] + export interface ModelCompileData { - optimizer: string - loss: string - metrics: string[] + optimizer: string + loss: string + metrics: string[] } diff --git a/discojs/discojs-core/src/task/task.ts b/discojs/discojs-core/src/task/task.ts index 614f02652..ee4735f2f 100644 --- a/discojs/discojs-core/src/task/task.ts +++ b/discojs/discojs-core/src/task/task.ts @@ -1,47 +1,52 @@ import { isDisplayInformation, DisplayInformation } from './display_information' -import { isTrainingInformation, TrainingInformation } from './training_information' +import { + isTrainingInformation, + TrainingInformation, +} from './training_information' import { isDigest, Digest } from './digest' export type TaskID = string -export function isTaskID (obj: unknown): obj is TaskID { - return typeof obj === 'string' +export function isTaskID(obj: unknown): obj is TaskID { + return typeof obj === 'string' } -export function isTask (raw: unknown): raw is Task { - if (typeof raw !== 'object') { - return false - } - if (raw === null) { - return false - } - - const { taskID, digest, displayInformation, trainingInformation } = raw as - Record<'taskID' | 'digest' | 'displayInformation' | 'trainingInformation', unknown | undefined> - - if (typeof taskID !== 'string') { - return false - } - if (digest !== undefined && !isDigest(digest)) { - return false - } - if (!isDisplayInformation(displayInformation)) { - return false - } - if (!isTrainingInformation(trainingInformation)) { - return false - } - - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const _: Task = { taskID, displayInformation, trainingInformation } - - return true +export function isTask(raw: unknown): raw is Task { + if (typeof raw !== 'object') { + return false + } + if (raw === null) { + return false + } + + const { id, digest, displayInformation, trainingInformation } = + raw as Record< + 'id' | 'digest' | 
'displayInformation' | 'trainingInformation', + unknown | undefined + > + + if (!isTaskID(id)) { + return false + } + if (digest !== undefined && !isDigest(digest)) { + return false + } + if (!isDisplayInformation(displayInformation)) { + return false + } + if (!isTrainingInformation(trainingInformation)) { + return false + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const _: Task = { id, displayInformation, trainingInformation } + + return true } -export interface Task { - // TODO rename to ID - taskID: TaskID - digest?: Digest - displayInformation: DisplayInformation - trainingInformation: TrainingInformation +export type Task = { + id: TaskID + digest?: Digest + displayInformation: DisplayInformation + trainingInformation: TrainingInformation } diff --git a/discojs/discojs-core/src/task/task_handler.ts b/discojs/discojs-core/src/task/task_handler.ts index 8a4dfc043..7fcdafe33 100644 --- a/discojs/discojs-core/src/task/task_handler.ts +++ b/discojs/discojs-core/src/task/task_handler.ts @@ -6,28 +6,27 @@ import { isTask, Task, TaskID } from './task' const TASK_ENDPOINT = 'tasks' -export async function pushTask ( - url: URL, - task: Task, - model: tf.LayersModel +export async function pushTask( + url: URL, + task: Task, + model: tf.LayersModel ): Promise { - await axios.post( - url.href + TASK_ENDPOINT, - { - task, - model: await serialization.model.encode(model), - weights: await serialization.weights.encode(WeightsContainer.from(model)) - } - ) + await axios.post(url.href + TASK_ENDPOINT, { + task, + model: await serialization.model.encode(model), + weights: await serialization.weights.encode( + WeightsContainer.from(model) + ), + }) } -export async function fetchTasks (url: URL): Promise> { - const response = await axios.get(new URL(TASK_ENDPOINT, url).href) - const tasks: unknown = response.data +export async function fetchTasks(url: URL): Promise> { + const response = await axios.get(new URL(TASK_ENDPOINT, url).href) + const tasks: 
unknown = response.data - if (!(Array.isArray(tasks) && tasks.every(isTask))) { - throw new Error('invalid tasks response') - } + if (!(Array.isArray(tasks) && tasks.every(isTask))) { + throw new Error('invalid tasks response') + } - return Map(tasks.map((t) => [t.taskID, t])) + return Map(tasks.map((t) => [t.id, t])) } diff --git a/discojs/discojs-core/src/task/task_provider.ts b/discojs/discojs-core/src/task/task_provider.ts index 3d8e8b8e3..1997a5ee2 100644 --- a/discojs/discojs-core/src/task/task_provider.ts +++ b/discojs/discojs-core/src/task/task_provider.ts @@ -1,15 +1,19 @@ import { Task, training } from '..' -export interface TaskProvider { - getTask: () => Task - getModel: () => Promise +export type TaskProvider = { + getTask: () => Task + getModel: () => Promise } -export function isTaskProvider (obj: any): obj is TaskProvider { - if ('getModel' in obj && typeof obj.getModel === 'function' && - 'getTask' in obj && typeof obj.getTask === 'function') { - return true - } else { - return false - } +export function isTaskProvider(obj: any): obj is TaskProvider { + if ( + 'getModel' in obj && + typeof obj.getModel === 'function' && + 'getTask' in obj && + typeof obj.getTask === 'function' + ) { + return true + } else { + return false + } } diff --git a/discojs/discojs-core/src/task/training_information.ts b/discojs/discojs-core/src/task/training_information.ts index 18dd4bc50..f4a085be1 100644 --- a/discojs/discojs-core/src/task/training_information.ts +++ b/discojs/discojs-core/src/task/training_information.ts @@ -1,193 +1,228 @@ +import { TrainingSchemes } from '..' 
import { AggregatorChoice } from '../aggregator/get' import { Preprocessing } from '../dataset/data/preprocessing' +import { Tokenizer } from '../training/model' import { isModelCompileData, ModelCompileData } from './model_compile_data' -export function isTrainingInformation (raw: unknown): raw is TrainingInformation { - if (typeof raw !== 'object') { - return false - } - if (raw === null) { - return false - } +export function isTrainingInformation( + raw: unknown +): raw is TrainingInformation { + if (typeof raw !== 'object') { + return false + } + if (raw === null) { + return false + } - type Fields = - 'dataType' | - 'scheme' | - 'epochs' | - 'roundDuration' | - 'validationSplit' | - 'batchSize' | - 'modelCompileData' | - 'modelID' | - 'preprocessingFunctions' | - 'inputColumns' | - 'outputColumns' | - 'IMAGE_H' | - 'IMAGE_W' | - 'modelURL' | - 'learningRate' | - 'decentralizedSecure' | - 'maxShareValue' | - 'minimumReadyPeers' | - 'LABEL_LIST' | - 'noiseScale' | - 'clippingRadius'| - 'aggregator' | - 'vocabSize' + type Fields = + | 'dataType' + | 'scheme' + | 'epochs' + | 'roundDuration' + | 'validationSplit' + | 'batchSize' + | 'modelCompileData' + | 'modelID' + | 'preprocessingFunctions' + | 'inputColumns' + | 'outputColumns' + | 'IMAGE_H' + | 'IMAGE_W' + | 'modelURL' + | 'learningRate' + | 'decentralizedSecure' + | 'maxShareValue' + | 'minimumReadyPeers' + | 'LABEL_LIST' + | 'noiseScale' + | 'clippingRadius' + | 'aggregator' + | 'vocabSize' - const { - dataType, - scheme, - epochs, - // roundDuration, - validationSplit, - batchSize, - modelCompileData, - modelID, - preprocessingFunctions, - inputColumns, - outputColumns, - IMAGE_H, - IMAGE_W, - roundDuration, - modelURL, - learningRate, - decentralizedSecure, - maxShareValue, - minimumReadyPeers, - LABEL_LIST, - noiseScale, - clippingRadius, - aggregator, - vocabSize - } = raw as Record + const { + dataType, + scheme, + epochs, + // roundDuration, + validationSplit, + batchSize, + modelCompileData, + 
modelID, + preprocessingFunctions, + inputColumns, + outputColumns, + IMAGE_H, + IMAGE_W, + roundDuration, + modelURL, + learningRate, + decentralizedSecure, + maxShareValue, + minimumReadyPeers, + LABEL_LIST, + noiseScale, + clippingRadius, + aggregator, + vocabSize, + } = raw as Record - if ( - typeof dataType !== 'string' || - typeof modelID !== 'string' || - typeof epochs !== 'number' || - typeof batchSize !== 'number' || - typeof roundDuration !== 'number' || - typeof validationSplit !== 'number' || - (modelURL !== undefined && typeof modelURL !== 'string') || - (noiseScale !== undefined && typeof noiseScale !== 'number') || - (clippingRadius !== undefined && typeof clippingRadius !== 'number') || - (learningRate !== undefined && typeof learningRate !== 'number') || - (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') || - (maxShareValue !== undefined && typeof maxShareValue !== 'number') || - (minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') || - (aggregator !== undefined && typeof aggregator !== 'number') || - (vocabSize !== undefined && typeof vocabSize !== 'string') - ) { - return false - } + if ( + typeof dataType !== 'string' || + typeof modelID !== 'string' || + typeof epochs !== 'number' || + typeof batchSize !== 'number' || + typeof roundDuration !== 'number' || + typeof validationSplit !== 'number' || + (modelURL !== undefined && typeof modelURL !== 'string') || + (noiseScale !== undefined && typeof noiseScale !== 'number') || + (clippingRadius !== undefined && typeof clippingRadius !== 'number') || + (learningRate !== undefined && typeof learningRate !== 'number') || + (decentralizedSecure !== undefined && + typeof decentralizedSecure !== 'boolean') || + (maxShareValue !== undefined && typeof maxShareValue !== 'number') || + (minimumReadyPeers !== undefined && + typeof minimumReadyPeers !== 'number') || + (aggregator !== undefined && typeof aggregator !== 'number') || + (vocabSize !== undefined 
&& typeof vocabSize !== 'string') + ) { + return false + } - // interdepences on data type - if (dataType === 'image') { - if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') { - return false + // interdepences on data type + if (dataType === 'image') { + if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') { + return false + } + } else if (dataType in ['text', 'tabular']) { + if ( + !( + Array.isArray(inputColumns) && + inputColumns.every((e) => typeof e === 'string') + ) + ) { + return false + } + if ( + !( + Array.isArray(outputColumns) && + outputColumns.every((e) => typeof e === 'string') + ) + ) { + return false + } } - } else if (dataType in ['text', 'tabular']) { - if (!(Array.isArray(inputColumns) && inputColumns.every((e) => typeof e === 'string'))) { - return false + + // interdepences on scheme + switch (scheme) { + case 'decentralized': + break + case 'federated': + break + case 'local': + break } - if (!(Array.isArray(outputColumns) && outputColumns.every((e) => typeof e === 'string'))) { - return false + + if (!isModelCompileData(modelCompileData)) { + return false } - } - // interdepences on scheme - switch (scheme) { - case 'decentralized': - break - case 'federated': - break - case 'local': - break - } + if ( + LABEL_LIST !== undefined && + !( + Array.isArray(LABEL_LIST) && + LABEL_LIST.every((e) => typeof e === 'string') + ) + ) { + return false + } - if (!isModelCompileData(modelCompileData)) { - return false - } + if ( + preprocessingFunctions !== undefined && + !Array.isArray(preprocessingFunctions) + ) { + return false + } - if ( - LABEL_LIST !== undefined && !( - Array.isArray(LABEL_LIST) && LABEL_LIST.every((e) => typeof e === 'string') - ) - ) { - return false - } + return true +} - if ( - preprocessingFunctions !== undefined && !(Array.isArray(preprocessingFunctions)) - ) { - return false - } +type ModelConfigTask = Config extends {} + ? 
{ modelConfig: Config } + : { modelConfig?: Config } // TODO: ideally modelConfig shouldn't be available at all if Config is null or undefined i.e. if Config does not extends {} - return true -} +export type TrainingInformation = { + // modelID: unique ID for the model + modelID: string + // maxIterations: number of iterations to run training (if epoch is specified, whatever comes first stops training) + maxIterations?: number + // epochs: number of epochs to run training for + epochs: number + // roundDuration: number of batches between each weight sharing round, e.g. if 3 then after every + // 3 batches we share weights (in the distributed setting). + roundDuration: number + // validationSplit: fraction of data to keep for validation, note this only works for image data + validationSplit: number + // batchSize: batch size of training data + batchSize: number + // preprocessingFunctions: preprocessing functions such as resize and normalize + preprocessingFunctions?: Preprocessing[] + // modelCompileData: interface of additional training information (optimizer, loss and metrics) + modelCompileData: ModelCompileData + // dataType, e.g. image or tabular + dataType: string + // inputColumns: for tabular data, the columns to be chosen as input data for the model + inputColumns?: string[] + // outputColumns: for tabular data, the columns to be predicted by the model + outputColumns?: string[] + // IMAGE_H height of image (or RESIZED_IMAGE_H if ImagePreprocessing.Resize in preprocessingFunctions) + IMAGE_H?: number + // IMAGE_W width of image (or RESIZED_IMAGE_W if ImagePreprocessing.Resize in preprocessingFunctions) + IMAGE_W?: number + // Model URL to download the base task model from. Useful for pretrained or/and hosted models. + modelURL?: string + // LABEL_LIST of classes, e.g. if two class of images, one with dogs and one with cats, then we would + // define ['dogs', 'cats']. 
+ LABEL_LIST?: string[] + // learningRate: learning rate for the optimizer + learningRate?: number + // scheme: Distributed training scheme, i.e. Federated and Decentralized + scheme: TrainingSchemes + // noiseScale: Differential Privacy (DP): Affects the variance of the Gaussian noise added to the models / model updates. + // Number or undefined. If undefined, then no noise will be added. + noiseScale?: number + // clippingRadius: Privacy (DP and Secure Aggregation): + // Number or undefined. If undefined, then no model updates will be clipped. + // If number, then model updates will be scaled down if their norm exceeds clippingRadius. + clippingRadius?: number + // decentralizedSecure: Secure Aggregation on/off: + // Boolean. true for secure aggregation to be used, if the training scheme is decentralized, false otherwise + decentralizedSecure?: boolean + // byzantineRobustAggregator: Byzantine robust aggregator on/off: + // Boolean. true to use byzantine robust aggregation, if the training scheme is federated, false otherwise + byzantineRobustAggregator?: boolean + // tauPercentile: it indicates the percentile to take when choosing the tau for byzantine robust aggregator: + // Number (>0 && <1). It must be a number between 0 and 1 and it is used only if byzantineRobustAggregator is true. 
+ tauPercentile?: number + // maxShareValue: Secure Aggregation: maximum absolute value of a number in a randomly generated share + // default is 100, must be a positive number, check the ~/disco/information/PRIVACY.md file for more information on significance of maxShareValue selection + // only relevant if secure aggregation is true (for either federated or decentralized learning) + maxShareValue?: number + // minimumReadyPeers: Decentralized Learning: minimum number of peers who must be ready to participate in aggregation before model updates are shared between clients + // default is 3, range is [3, totalNumberOfPeersParticipating] + minimumReadyPeers?: number + // aggregator: aggregator to be used by the server for federated learning, or by the peers for decentralized learning + // default is 'average', other options include for instance 'bandit' + aggregator?: AggregatorChoice -export interface TrainingInformation { - // modelID: unique ID for the model - modelID: string - // epochs: number of epochs to run training for - epochs: number - // roundDuration: number of batches between each weight sharing round, e.g. if 3 then after every - // 3 batches we share weights (in the distributed setting). - roundDuration: number - // validationSplit: fraction of data to keep for validation, note this only works for image data - validationSplit: number - // batchSize: batch size of training data - batchSize: number - // preprocessingFunctions: preprocessing functions such as resize and normalize - preprocessingFunctions?: Preprocessing[] - // modelCompileData: interface of additional training information (optimizer, loss and metrics) - modelCompileData: ModelCompileData - // dataType, e.g. 
image or tabular - dataType: string - // inputColumns: for tabular data, the columns to be chosen as input data for the model - inputColumns?: string[] - // outputColumns: for tabular data, the columns to be predicted by the model - outputColumns?: string[] - // IMAGE_H height of image (or RESIZED_IMAGE_H if ImagePreprocessing.Resize in preprocessingFunctions) - IMAGE_H?: number - // IMAGE_W width of image (or RESIZED_IMAGE_W if ImagePreprocessing.Resize in preprocessingFunctions) - IMAGE_W?: number - // Model URL to download the base task model from. Useful for pretrained or/and hosted models. - modelURL?: string - // LABEL_LIST of classes, e.g. if two class of images, one with dogs and one with cats, then we would - // define ['dogs', 'cats']. - LABEL_LIST?: string[] - // learningRate: learning rate for the optimizer - learningRate?: number - // scheme: Distributed training scheme, i.e. Federated and Decentralized - scheme: string - // noiseScale: Differential Privacy (DP): Affects the variance of the Gaussian noise added to the models / model updates. - // Number or undefined. If undefined, then no noise will be added. - noiseScale?: number - // clippingRadius: Privacy (DP and Secure Aggregation): - // Number or undefined. If undefined, then no model updates will be clipped. - // If number, then model updates will be scaled down if their norm exceeds clippingRadius. - clippingRadius?: number - // decentralizedSecure: Secure Aggregation on/off: - // Boolean. true for secure aggregation to be used, if the training scheme is decentralized, false otherwise - decentralizedSecure?: boolean - // byzantineRobustAggregator: Byzantine robust aggregator on/off: - // Boolean. true to use byzantine robust aggregation, if the training scheme is federated, false otherwise - byzantineRobustAggregator?: boolean - // tauPercentile: it indicates the percentile to take when choosing the tau for byzantine robust aggregator: - // Number (>0 && <1). 
It must be a number between 0 and 1 and it is used only if byzantineRobustAggregator is true. - tauPercentile?: number - // maxShareValue: Secure Aggregation: maximum absolute value of a number in a randomly generated share - // default is 100, must be a positive number, check the ~/disco/information/PRIVACY.md file for more information on significance of maxShareValue selection - // only relevant if secure aggregation is true (for either federated or decentralized learning) - maxShareValue?: number - // minimumReadyPeers: Decentralized Learning: minimum number of peers who must be ready to participate in aggregation before model updates are shared between clients - // default is 3, range is [3, totalNumberOfPeersParticipating] - minimumReadyPeers?: number - // aggregator: aggregator to be used by the server for federated learning, or by the peers for decentralized learning - // default is 'average', other options include for instance 'bandit' - aggregator?: AggregatorChoice - // TODO: for LLMs - vocabSize?: number -} + /** + * ==== FOR LLMs ==== + */ + // datasetBatchSize: actual batch size used by the core text loader to construct a batch + datasetBatchSize?: number + // vocabSize: vocabulary size of the tokenizer + vocabSize?: number + // tokenizer: tokenizer to be used for tokenizing the text + tokenizer?: Tokenizer + // blockSize: sequence length to be tokenized + blockSize?: number +} & ModelConfigTask diff --git a/discojs/discojs-core/src/training/disco.ts b/discojs/discojs-core/src/training/disco.ts index 4d273de22..582f9f858 100644 --- a/discojs/discojs-core/src/training/disco.ts +++ b/discojs/discojs-core/src/training/disco.ts @@ -1,12 +1,14 @@ import { - client as clients, - data, - Logger, - Task, - TrainingInformant, informant as informants, - TrainingSchemes, - Memory, EmptyMemory, - ConsoleLogger + client as clients, + data, + Logger, + Task, + TrainingInformant, + informant as informants, + TrainingSchemes, + Memory, + EmptyMemory, + ConsoleLogger, 
} from '..' import { Trainer } from './trainer/trainer' import { TrainerBuilder } from './trainer/trainer_builder' @@ -15,13 +17,13 @@ import { Aggregator } from '../aggregator' import { MeanAggregator } from '../aggregator/mean' export interface DiscoOptions { - client?: clients.Client - aggregator?: Aggregator - url?: string | URL - scheme?: TrainingSchemes - informant?: TrainingInformant - logger?: Logger - memory?: Memory + client?: clients.Client + aggregator?: Aggregator + url?: string | URL + scheme?: TrainingSchemes + informant?: TrainingInformant + logger?: Logger + memory?: Memory } /** @@ -30,111 +32,135 @@ export interface DiscoOptions { * communication with nodes, logs and model memory. */ export class Disco { - public readonly task: Task - public readonly logger: Logger - public readonly memory: Memory - private readonly client: clients.Client - private readonly aggregator: Aggregator - private readonly trainer: Promise + public readonly task: Task + public readonly logger: Logger + public readonly memory: Memory + private readonly client: clients.Client + private readonly aggregator: Aggregator + private readonly trainer: Promise - constructor ( - task: Task, - options: DiscoOptions - ) { - if (options.scheme === undefined) { - options.scheme = TrainingSchemes[task.trainingInformation.scheme as keyof typeof TrainingSchemes] - } - if (options.aggregator === undefined) { - options.aggregator = new MeanAggregator(task) - } - if (options.client === undefined) { - if (options.url === undefined) { - throw new Error('could not determine client from given parameters') - } + constructor(task: Task, options: DiscoOptions) { + if (options.scheme === undefined) { + options.scheme = task.trainingInformation.scheme + } + if (options.aggregator === undefined) { + options.aggregator = new MeanAggregator(task) + } + if (options.client === undefined) { + if (options.url === undefined) { + throw new Error( + 'could not determine client from given parameters' + ) + } - 
if (typeof options.url === 'string') { - options.url = new URL(options.url) - } - switch (options.scheme) { - case TrainingSchemes.FEDERATED: - options.client = new clients.federated.FederatedClient(options.url, task, options.aggregator) - break - case TrainingSchemes.DECENTRALIZED: - options.client = new clients.decentralized.DecentralizedClient(options.url, task, options.aggregator) - break - default: - options.client = new clients.Local(options.url, task, options.aggregator) - break - } - } - if (options.informant === undefined) { - switch (options.scheme) { - case TrainingSchemes.FEDERATED: - options.informant = new informants.FederatedInformant(task) - break - case TrainingSchemes.DECENTRALIZED: - options.informant = new informants.DecentralizedInformant(task) - break - default: - options.informant = new informants.LocalInformant(task) - break - } - } - if (options.logger === undefined) { - options.logger = new ConsoleLogger() - } - if (options.memory === undefined) { - options.memory = new EmptyMemory() - } - if (options.client.task !== task) { - throw new Error('client not setup for given task') - } - if (options.informant.task.taskID !== task.taskID) { - throw new Error('informant not setup for given task') - } + if (typeof options.url === 'string') { + options.url = new URL(options.url) + } + switch (options.scheme) { + case TrainingSchemes.FEDERATED: + options.client = new clients.federated.FederatedClient( + options.url, + task, + options.aggregator + ) + break + case TrainingSchemes.DECENTRALIZED: + options.client = + new clients.decentralized.DecentralizedClient( + options.url, + task, + options.aggregator + ) + break + default: + options.client = new clients.Local( + options.url, + task, + options.aggregator + ) + break + } + } + if (options.informant === undefined) { + switch (options.scheme) { + case TrainingSchemes.FEDERATED: + options.informant = new informants.FederatedInformant(task) + break + case TrainingSchemes.DECENTRALIZED: + 
options.informant = new informants.DecentralizedInformant( + task + ) + break + default: + options.informant = new informants.LocalInformant(task) + break + } + } + if (options.logger === undefined) { + options.logger = new ConsoleLogger() + } + if (options.memory === undefined) { + options.memory = new EmptyMemory() + } + if (options.client.task !== task) { + throw new Error('client not setup for given task') + } + if (options.informant.task.id !== task.id) { + throw new Error('informant not setup for given task') + } - this.task = task - this.client = options.client - this.aggregator = options.aggregator - this.memory = options.memory - this.logger = options.logger + this.task = task + this.client = options.client + this.aggregator = options.aggregator + this.memory = options.memory + this.logger = options.logger - const trainerBuilder = new TrainerBuilder(this.memory, this.task, options.informant) - this.trainer = trainerBuilder.build(this.aggregator, this.client, options.scheme !== TrainingSchemes.LOCAL) - } + const trainerBuilder = new TrainerBuilder( + this.memory, + this.task, + options.informant + ) + this.trainer = trainerBuilder.build( + this.aggregator, + this.client, + options.scheme !== TrainingSchemes.LOCAL + ) + } - /** - * Starts a training instance for the Disco object's task on the provided data tuple. - * @param data The data tuple - */ - async fit (data: data.tuple.DataSplit): Promise { - this.logger.success('Thank you for your contribution. Data preprocessing has started') + /** + * Starts a training instance for the Disco object's task on the provided data tuple. + * @param data The data tuple + */ + async fit(data: data.tuple.DataSplit): Promise { + this.logger.success( + 'Thank you for your contribution. 
Data preprocessing has started' + ) - await this.client.connect() - const trainer = await this.trainer - await trainer.fitModel(data) - } + await this.client.connect() + const trainer = await this.trainer + await trainer.fitModel(data) + } - /** - * Stops the ongoing training instance without disconnecting the client. - */ - async pause (): Promise { - const trainer = await this.trainer - await trainer.stopTraining() + /** + * Stops the ongoing training instance without disconnecting the client. + */ + async pause(): Promise { + const trainer = await this.trainer + await trainer.stopTraining() - this.logger.success('Training was successfully interrupted.') - } + this.logger.success('Training was successfully interrupted.') + } - /** - * Completely stops the ongoing training instance. - */ - async close (): Promise { - await this.pause() - await this.client.disconnect() - } + /** + * Completely stops the ongoing training instance. + */ + async close(): Promise { + await this.pause() + await this.client.disconnect() + } - async logs (): Promise { - const trainer = await this.trainer - return trainer.getTrainerLog() - } + async logs(): Promise { + const trainer = await this.trainer + return trainer.getTrainerLog() + } } diff --git a/discojs/discojs-core/src/training/index.ts b/discojs/discojs-core/src/training/index.ts index 89c15fe21..97881c9b4 100644 --- a/discojs/discojs-core/src/training/index.ts +++ b/discojs/discojs-core/src/training/index.ts @@ -1,3 +1,5 @@ export { Disco } from './disco' export { TrainingSchemes } from './training_schemes' +export type { TrainingCallbacks } from './trainer/trainer' export * as model from './model' +export * as trainer from './trainer' diff --git a/discojs/discojs-core/src/training/models/gpt/bun.lockb b/discojs/discojs-core/src/training/models/gpt/bun.lockb new file mode 100755 index 0000000000000000000000000000000000000000..854ef7fec5d747a3e557c427fc0bed6001f31e53 GIT binary patch literal 8380 
zcmeHM2~<-@7mgTYQ>Y54MKp+7R9*;(vPcmHwW5LwD&PVkJRq1nViE)_DAbkK4Mb5? z+zQAd)B-8=K{EborF znN@^LY85D?tVn7;CoqE2g&!db7t9S7ibx_jTq@?viTONT4Hk>lVDjWZq|eOaxZgyF zt+y_6^z%tNzN)d|jZC_<_wEzVPgruG5hP=UU}?9KamunYZa|j}Sgg37kV8@^BO+xY z7HbETb)ZZ}kpe#CgF3KS-J$FYQV(P}MaqR@k|h?3g#sd&ngq`~!*dB=9!kt5qem&r z_8<|@5RhF!nt{{>X$%taM1~3lp};ScfhLfB*$O*efJE*2Qo0)1M)O!|O?BX%&(d$YC11ZA^Yq{I;DQhr?5i( zKW;w=6l4EEs87`|0Q~5`F@EWQcLIDTrOBVF{|?}h|6sp?V!IOV|5xzhkpJT{mg-1w zy9KTA>TtLo<8uLz)(>39_39*U{~qw80pH$!V_u9O1rwjA)Q_LxHf>4Vb{^nI03N?L zv>&uLArACIw&V2&`vH=$KhgUOhz}&v^%;5y2}xA$iVBe0BauEs_#vDT{7^y?mGK7L zTB0`pE_<5|no)=T$EgEL!ONs(H@O!X)=%$z&3E~tA$xUPc6m(QWqfD<+iJfDdXBdj z=ihTp-&Xj;>t)|{%rFvNY0UGVvf|lycXVPTQcfCAXnk`!*ob+tUm>&EF}X#1_6+H{ z-cYMUeB5=d;QJ5Av&|k?PN!FIPj%JdowC0Ab4UL6q*9wvUkxnW66^Xo^{6s7w9^8! z9o{EGE6YwYc$xEq-Qz(aJ4Z8j;h}_dP1!^XgEGCcGkpptb}zmhzq--kfn}7ghHKRK zeX^RGiaPZhe_Y2g_3Dy&Z!(XSQ^6X?8o09=yfFP*68pvdEHbrshvZSL7i`WY!L37I z_BGpC=sWBD`NQ4%_yn%F$MgE)rP;a8r;>Y5O2~R!aDU{pgXgrX6O+dcczpE_vr+~x znh!X!I~@v-`GtME!Ybz=S!G0Rwl7_k>`<^u)OED(-Uy4N^F5}{ad5v~vOZ^8R_(it zgZ8UFd%2X_5NxoY@{BCYp0E;#@c82Sg}B+pl2<{V>2J)7n;&+&)KJ;|QbXfJZfNPZ z>zm%ON-V01MP~})ea&W>noh~%77{7LZoH_!+UbwgbMBtd4k%u0(u={1*I$}9@sQtm zvxG?@E`p>r)9&dHlIJ+zXwoY>T)cd|xXiTH=VDqwbFa-u1Y<(Iw*OX>X61KOd-1zT zgKH88HQ%er*yzsS#V2K&*D>S5x>WY4>R!*ut7orMze@$>SJul%TZIWU=0)dajWzce zf4aFe*YEY@%GW##?`nTfKR1~C5FUHYpxV7DZjR-Y!lX;wUydv< zo(V*FeBlh;lGr;P((K%~HWKFf*2Iq0PJ3M|b5bO{@!HAed3E~!_k8`6&hD*Agxj9dN5-@WPqAC9#Ev-|w#Fe)-03O7|)QhfQ8nOH&^? 
zC5|lE6907YkK2up^jny)X#UbIa{XS}C2OC#&Na+3tH2}JJz@z0SmM^Uyma>RCcKPVmK#NSj$b&c5%p13; zWM71LzW&{X7tFK29&@l-pTP^?2rY@N6+b2<`?zQ5on;d^iTjLiNy2@WpPYF?l;*~n zt0#ZVJzJxBdENG)Vg3bYv^M2lnEZW_X`o;4&l3dOCg{8i>feLGi*(?`F1)beyT`jK zZOyZ}zn@$?z^}kr;0ab(d z3-bEC$sKfJL2gfD<0SW$h4U-E$?w=`(U@mckaVZxX&XQOyj@+X(D-b3+Y`FBXYRu1 zUyHqXpQCx*VkUC8kz010GVN8}(1Gt}ZW5qnl(F#Wk3q|%d#b8(1@Y6Sd-Qyrz4PY8 zgJF+{m_O^g&8&;t0kcl#LihaseasoW-4LNNvDe1Ty*XSMSP`6iJ51lP-v3p^9@8L` zqUUdQ1`U~VAUb;QoO1&ullYSYrt>G(?8)F&G;PYv;C5Nv7DJznr3z;}%mt_t>MwJ4|-mvB_3@%RJM%f}GjcmXz`Kns{8lx3s6c zV2AI-l>O^QmiBI1FwW#x1~1yjabi2Xzj#?U>t!UT&&n#ZZ*H>BR<1ej{cgZm`$u_N z5z~y{O>!$Q|H@0ZLKHh_W}v_(EH%Qx*EBnGY;^AY#(n%u#Th`czu|KM;%3hoW##iz zME@1DyubCocKDD{^mQ*o??lbHbQk?3CxLu0^&%@)?zID3#Vuj2lR-C>@K z{)zVeXVZTp@ZSjhzegZW`T3Zvl)6qbsemh_WOBYp#Kl2muD6sVojEoF7+2 z&c8?rI#;578lBhB*$RDkptA)!2cSJ0or{oew9mRI^=@p3e#Ks@z#-je{YKwlh!5#O z`jB3vAK8NJLiQp3$UbBn`XM{e84B5i>^Y;<&Hfw;kaV49+6+TWRT#W;<48DN!$WL1){bM?P-6$S4P~)#G#*E}p=LPL00zhFaU7gxgC0Q?5neay zDUOt@^x&93j-{(=@a+JOx5xW=7Pssf5c&3 zNJ+3*WECQj6Y}6N8MoiyDJP)EBwrjFA?8zr4Ax7&NCsvpo{C9n2&v4&pb%0R zLF7t`b_5wEltV+Iga{S_j-bd$k&q(W=*&l2gMlA;3LS(gs-}YsGY|`HJZ%(#CXt;Q zv~pGgQkkQN$VLqc&pv?REZB;?plu&u4j*8dMS@-?)Zp<91}x4(FJZYk=nu;68Z(u90$f!f(DJQLJHyKX%qS7`wtz-W7D@*4sSrXC zE(YfhqUSoj=yB<*YE`)nho;+u0$S3;qRxTnOJG3IvR$vBm7RcX^+d&jD_|)jquMb0 sgLVu;8b{9_b^eU@bpRp(>d|a%+y7e@UjVh_>DAME7=kL+@_*j{pNuwyJ^%m! 
literal 0 HcmV?d00001 diff --git a/discojs/discojs-core/src/training/models/gpt/config.ts b/discojs/discojs-core/src/training/models/gpt/config.ts new file mode 100644 index 000000000..46ce3f6a1 --- /dev/null +++ b/discojs/discojs-core/src/training/models/gpt/config.ts @@ -0,0 +1,77 @@ +type ModelType = + | 'gpt2' + | 'gpt2-medium' + | 'gpt2-large' + | 'gpt2-xl' + | 'gpt-mini' + | 'gpt-micro' + | 'gpt-nano' + +type ModelSize = { + nLayer?: number + nHead?: number + nEmbd?: number +} + +export type GPTConfig = { + lr: number + batchSize: number + blockSize: number + vocabSize: number + evaluate?: boolean + maxEvalBatches?: number + evaluateEvery?: number + epochs?: number + maxIter?: number + weightDecay?: number + verbose?: boolean + bias?: boolean + debug?: boolean + dropout?: number + residDrop?: number + embdDrop?: number + tokEmb?: boolean + lmHead?: boolean + modelType: ModelType +} + +export const DEFAULT_CONFIG: Required = { + lr: 0.001, + weightDecay: 0, + batchSize: 2, + epochs: 9999, + maxIter: 10_000, + verbose: false, + modelType: 'gpt-nano', + evaluate: true, + maxEvalBatches: 12, + evaluateEvery: 100, + blockSize: 128, + vocabSize: 50258, + bias: true, + debug: false, + dropout: 0.2, + residDrop: 0.2, + embdDrop: 0.2, + tokEmb: true, + lmHead: true, +} + +export const getModelSizes = (modelType: ModelType): Required => { + switch (modelType) { + case 'gpt2': + return { nLayer: 12, nHead: 12, nEmbd: 768 } + case 'gpt2-medium': + return { nLayer: 24, nHead: 16, nEmbd: 1024 } + case 'gpt2-large': + return { nLayer: 36, nHead: 20, nEmbd: 1280 } + case 'gpt2-xl': + return { nLayer: 48, nHead: 25, nEmbd: 1600 } + case 'gpt-mini': + return { nLayer: 6, nHead: 6, nEmbd: 192 } + case 'gpt-micro': + return { nLayer: 4, nHead: 4, nEmbd: 128 } + case 'gpt-nano': + return { nLayer: 3, nHead: 3, nEmbd: 48 } + } +} diff --git a/discojs/discojs-core/src/training/models/gpt/evaluate.ts b/discojs/discojs-core/src/training/models/gpt/evaluate.ts new file mode 100755 
index 000000000..092a939fe --- /dev/null +++ b/discojs/discojs-core/src/training/models/gpt/evaluate.ts @@ -0,0 +1,47 @@ +import { tf, dataset } from '../../..' +import { GPTConfig } from '.' + +export default async function evaluate( + model: any, + dataset: dataset.Dataset, + config: Required +) { + console.log('Evaluating..') + + const iter = await dataset.iterator() + + let total_loss = 0 + const acc: [number, number] = [0, 0] + + let iteration = 0 + while (iteration < config.maxEvalBatches) { + const next = await iter.next() + if (!next) break + const { xs, ys } = next.value + const logits = model.apply(xs) + + // Loss + const loss = tf.losses.softmaxCrossEntropy(ys, logits) + const lossVal = await loss.array() + total_loss += lossVal as number + + // Accuracy + const acc_tensor = tf.metrics.categoricalAccuracy(ys, logits) + const acc_sum = acc_tensor.sum() + acc[0] += (await acc_sum.array()) as number + acc[1] += acc_tensor.shape[0] * (acc_tensor.shape[1] as number) + + tf.dispose([acc_tensor, acc_sum, loss, logits, xs, ys]) + + iteration++ + } + + const loss = total_loss / iteration + const pp = 2.71828 ** loss + + return { + 'val/loss': loss, + 'val/perplexity': pp, + 'val/acc': acc[0] / acc[1], + } +} diff --git a/discojs/discojs-core/src/training/models/gpt/index.ts b/discojs/discojs-core/src/training/models/gpt/index.ts new file mode 100644 index 000000000..532e9451c --- /dev/null +++ b/discojs/discojs-core/src/training/models/gpt/index.ts @@ -0,0 +1,4 @@ +export * from './train' +export * from './optimizers' +export * from './model' +export * from './config' diff --git a/discojs/discojs-core/src/training/models/gpt/model.ts b/discojs/discojs-core/src/training/models/gpt/model.ts new file mode 100644 index 000000000..671db7e41 --- /dev/null +++ b/discojs/discojs-core/src/training/models/gpt/model.ts @@ -0,0 +1,616 @@ +import { GPTConfig, getModelSizes, DEFAULT_CONFIG } from '.' +import { dataset, tf, training } from '../../..' 
+import { train } from './train' + +const Range = (config: any) => new Range_(config) +class Range_ extends tf.layers.Layer { + computeOutputShape(inputShape: any) { + return inputShape + } + + call(input: any, kwargs: any) { + return tf.tidy(() => { + if (Array.isArray(input)) { + input = input[0] + } + this.invokeCallHook(input, kwargs) + const [B, T] = input.shape + const range = tf.reshape(tf.range(0, T, 1, 'int32'), [1, T]) // .tile([B, 1]) + return range + }) + } + + static get className() { + return 'Range' + } +} +tf.serialization.registerClass(Range_) + +const LogLayer = (config: any) => new LogLayer_(config) +class LogLayer_ extends tf.layers.Layer { + config: any + constructor(config: any) { + super(config) + this.config = config + } + + computeOutputShape(inputShape: any) { + return inputShape + } + + call(input: any, kwargs: any) { + return tf.tidy(() => { + if (Array.isArray(input)) { + input = input[0] + } + this.invokeCallHook(input, kwargs) + const x = tf.util.flatten(input.arraySync()) + console.log( + this.config.name + '>', + input.shape, + x[0], + x[x.length - 1] + ) + return input + }) + } + + static get className() { + return 'LogLayer' + } +} +tf.serialization.registerClass(LogLayer_) + +const CausalSelfAttentionBase = (config: any) => + new CausalSelfAttentionBase_(config) +class CausalSelfAttentionBase_ extends tf.layers.Layer { + config: any + blockSize: any + nHead: any + nEmbd: any + dropout: any + mask: any + + constructor(config: any) { + super(config) + this.config = config + this.blockSize = config.blockSize + this.nEmbd = config.nEmbd + this.nHead = config.nHead + this.dropout = config.dropout + this.mask = tf.linalg.bandPart( + tf.ones([config.blockSize, config.blockSize]), + -1, + 0 + ) + } + + computeOutputShape(inputShape: any) { + return [null, this.blockSize, this.nEmbd] + } + + getConfig() { + const config = super.getConfig() + return Object.assign({}, config, this.config) + } + + call(input: any, kwargs: any) { + return 
tf.tidy(() => { + if (Array.isArray(input)) { + input = input[0] + } + this.invokeCallHook(input, kwargs) + + let [q, k, v] = tf.split(input, 3, -1) + const [B, T, C] = k.shape + const splitHeads = (x: any) => + tf.transpose( + tf.reshape(x, [B, T, this.nHead, C / this.nHead]), + [0, 2, 1, 3] + ) + q = splitHeads(q) + k = splitHeads(k) + v = splitHeads(v) + + let att = tf.mul( + tf.matMul(q, k, false, true), + tf.div( + 1, + tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32')) + ) + ) + att = tf.add(att, tf.mul(tf.sub(1, this.mask), -1e9)) + att = tf.softmax(att, -1) + att = kwargs['training'] ? tf.dropout(att, this.dropout) : att + + let y = tf.matMul(att, v) + y = tf.transpose(y, [0, 2, 1, 3]) + y = tf.reshape(y, [B, T, C]) + + return y + }) + } + + static get className() { + return 'CausalSelfAttentionBase' + } +} +tf.serialization.registerClass(CausalSelfAttentionBase_) + +function CausalSelfAttentionMixed(conf: any) { + const config = Object.assign({ name: 'attn' }, conf) + const csa = CausalSelfAttentionBase(config) + const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] }) + let att + att = tf.layers + .dense({ + name: config.name + '/c_attn', + units: 3 * config.nEmbd, + inputDim: config.nEmbd, + inputShape: [config.blockSize, config.nEmbd], + useBias: config.bias, + }) + .apply(inputs) + att = csa.apply(att) + att = tf.layers + .dense({ + name: config.name + '/proj', + units: config.nEmbd, + inputDim: config.nEmbd, + inputShape: [config.blockSize, config.nEmbd], + useBias: config.bias, + }) + .apply(att) + att = tf.layers + .dropout({ + name: config.name + '/drop', + rate: config.dropout, + }) + .apply(att) + return tf.model({ inputs: inputs, outputs: att as any }) +} + +const CausalSelfAttention = (config: any) => new CausalSelfAttention_(config) +class CausalSelfAttention_ extends tf.layers.Layer { + config: any + blockSize: any + nHead: any + nEmbd: any + dropout: any + bias: any + mask: any + + cAttnKernel: any + cAttnBias: any + 
cProjKernel: any + cProjBias: any + + constructor(config: any) { + super(config) + this.config = Object.assign({ name: 'attn' }, config) + this.blockSize = config.blockSize + this.nEmbd = config.nEmbd + this.nHead = config.nHead + this.dropout = config.dropout + this.bias = config.bias + this.mask = tf.linalg.bandPart( + tf.ones([config.blockSize, config.blockSize]), + -1, + 0 + ) + } + + build(inputShape: any) { + this.cAttnKernel = this.addWeight( + 'c_attn/kernel', + [this.nEmbd, 3 * this.nEmbd], + 'float32', + tf.initializers.glorotNormal({}) + ) + this.cAttnBias = this.addWeight( + 'c_attn/bias', + [3 * this.nEmbd], + 'float32', + tf.initializers.zeros() + ) + this.cProjKernel = this.addWeight( + 'c_proj/kernel', + [this.nEmbd, this.nEmbd], + 'float32', + tf.initializers.glorotNormal({}) + ) + this.cProjBias = this.addWeight( + 'c_proj/bias', + [this.nEmbd], + 'float32', + tf.initializers.zeros() + ) + } + + computeOutputShape(inputShape: any) { + return inputShape + } + + getConfig() { + const config = super.getConfig() + return Object.assign({}, config, this.config) + } + + call(input: any, kwargs: any) { + return tf.tidy(() => { + if (Array.isArray(input)) { + input = input[0] + } + this.invokeCallHook(input, kwargs) + + const dense = (x: any, kernel: any, bias: any) => { + const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]) + const m = tf.matMul(x, k) + if (this.bias) { + return tf.add(m, bias.read()) + } else { + return m + } + } + + const cAttn = dense(input, this.cAttnKernel, this.cAttnBias) + + let [q, k, v] = tf.split(cAttn, 3, -1) + const [B, T, C] = k.shape + + const splitHeads = (x: any) => + tf.transpose( + tf.reshape(x, [B, T, this.nHead, C / this.nHead]), + [0, 2, 1, 3] + ) + + q = splitHeads(q) + k = splitHeads(k) + v = splitHeads(v) + + let att = tf.mul( + tf.matMul(q, k, false, true), + tf.div( + 1, + tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32')) + ) + ) + + const mask = this.mask.slice([0, 0], [T, T]) + att = 
tf.add(att, tf.mul(tf.sub(1, mask), -1e9)) + att = tf.softmax(att, -1) + att = kwargs['training'] ? tf.dropout(att, this.dropout) : att + + let y = tf.matMul(att, v) + y = tf.transpose(y, [0, 2, 1, 3]) + y = tf.reshape(y, [B, T, C]) + y = dense(y, this.cProjKernel, this.cProjBias) + y = kwargs['training'] ? tf.dropout(y, this.dropout) : y + + return y + }) + } + + static get className() { + return 'CausalSelfAttention' + } +} +tf.serialization.registerClass(CausalSelfAttention_) + +const GELU = () => new GELU_() +class GELU_ extends tf.layers.Layer { + constructor() { + super({}) + } + + computeOutputShape(inputShape: any) { + return inputShape + } + + call(input: any, kwargs: any) { + return tf.tidy(() => { + if (Array.isArray(input)) { + input = input[0] + } + this.invokeCallHook(input, kwargs) + const cdf = tf.mul( + 0.5, + tf.add( + 1, + tf.tanh( + tf.mul( + tf.sqrt(tf.div(2, Math.PI)), + tf.add(input, tf.mul(0.044715, tf.pow(input, 3))) + ) + ) + ) + ) + return tf.mul(input, cdf) + }) + } + + static get className() { + return 'GELU' + } +} +tf.serialization.registerClass(GELU_) + +function MLP(conf: any) { + const config = Object.assign({ name: 'mlp' }, conf) + const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] }) + let x + x = tf.layers + .dense({ + name: config.name + '/c_fc', + units: 4 * config.nEmbd, + inputDim: config.nEmbd, + inputShape: [config.blockSize, config.nEmbd], + }) + .apply(inputs) + x = GELU().apply(x) + x = tf.layers + .dense({ + name: config.name + '/c_proj', + units: config.nEmbd, + inputDim: 4 * config.nEmbd, + inputShape: [config.blockSize, 4 * config.nEmbd], + }) + .apply(x) + x = tf.layers + .dropout({ + name: config.name + '/drop', + rate: config.residDrop, + }) + .apply(x) + return tf.model({ inputs: inputs, outputs: x as any }) +} + +function Block(conf: any) { + const config = Object.assign({ name: 'h' }, conf) + const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] }) + let x1, x2 + x1 = tf.layers + 
.layerNormalization({ name: config.name + '/ln_1', epsilon: 1e-5 }) + .apply(inputs) + if (config.debug) { + x1 = LogLayer({ name: config.name + '/ln_1_log' }).apply(x1) + } + x1 = CausalSelfAttention( + Object.assign({}, config, { name: config.name + '/attn' }) + ).apply(x1) + x1 = tf.layers.add().apply([inputs, x1 as any]) + x2 = tf.layers + .layerNormalization({ name: config.name + '/ln_2', epsilon: 1e-5 }) + .apply(x1) + x2 = MLP(Object.assign({}, config, { name: config.name + '/mlp' })).apply( + x2 + ) + x2 = tf.layers.add().apply([x1 as any, x2 as any]) + return tf.model({ name: config.name, inputs: inputs, outputs: x2 as any }) +} + +function GPT(conf: GPTConfig) { + const configDefaults = { + name: 'transformer', + ...DEFAULT_CONFIG, + } + + const modelSizes = getModelSizes(conf.modelType) + const config = Object.assign({}, configDefaults, conf, modelSizes) + + console.log('IN MODEL CONFIG', config) + + const inputs = tf.input({ shape: [null] }) + + const tokEmb = config.tokEmb + ? 
tf.layers + .embedding({ + name: config.name + '/wte', + inputDim: config.vocabSize, + outputDim: config.nEmbd, + embeddingsInitializer: 'zeros', + embeddingsRegularizer: undefined, + activityRegularizer: undefined, + }) + .apply(inputs) + : inputs + + const range = Range({}).apply(inputs) + let posEmb = tf.layers + .embedding({ + name: config.name + '/wpe', + inputDim: config.blockSize, + outputDim: config.nEmbd, + embeddingsInitializer: 'zeros', + }) + .apply(range) + if (config.debug) { + posEmb = LogLayer({ name: 'posEmb' }).apply(posEmb) + } + + let x + x = tf.layers.add().apply([tokEmb as any, posEmb as any]) + x = tf.layers + .dropout({ + name: 'drop', + rate: config.embdDrop, + }) + .apply(x) + if (config.debug) { + x = LogLayer({ name: 'dropadd' }).apply(x) + } + + for (let i = 0; i < config.nLayer; i++) { + x = Block( + Object.assign({}, config, { name: config.name + '/h/' + i }) + ).apply(x) + } + x = tf.layers + .layerNormalization({ name: config.name + '/ln_f', epsilon: 1e-5 }) + .apply(x) + if (config.debug) { + x = LogLayer({ name: 'fin/ln' }).apply(x) + } + + if (config.lmHead) { + x = tf.layers + .dense({ + name: 'lm_head', + units: config.vocabSize, + inputDim: config.nEmbd, + inputShape: [config.blockSize, config.nEmbd], + useBias: false, + }) + .apply(x) + } + return tf.model({ inputs: inputs, outputs: x as any }) +} + +const defaultGenerateConfig = { + maxNewTokens: 20, + temperature: 1.0, + doSample: false, + topK: null, +} + +function prepareIdx(idx: any) { + tf.tidy(() => { + if (idx instanceof tf.Tensor) { + idx = idx.clone() + } else { + idx = tf.tensor(idx) + } + if (idx.dtype !== 'int32') { + idx = idx.toInt() + } + if (idx.shape.length === 1) { + idx = idx.expandDims(0) + } + tf.keep(idx) + }) + return idx +} + +function generateOnce(model: any, idx: any, config: any) { + let idxNext + let timePerToken = performance.now() + tf.tidy(() => { + const block_size = model.inputs[0].shape[1] + const idxCond = + idx.shape[1] <= block_size + ? 
idx + : idx.slice([0, -block_size], [-1, -1]) + const logits = model.predict(idxCond) + timePerToken = performance.now() - timePerToken + const logitsScaled = logits + .slice([0, idx.shape[1] - 1, 0]) + .reshape([logits.shape[0], logits.shape[2]]) + .div(tf.scalar(config.temperature)) + const probs = logitsScaled.softmax(-1) + if (config.doSample) { + idxNext = tf.multinomial(probs, 1) + } else { + idxNext = probs.argMax(-1) + idxNext = idxNext.expandDims(1) + } + tf.keep(idxNext) + }) + return { + idxNext, + timePerToken, + } +} + +function generateSync(model: any, idx: any, conf: any, callback: any) { + const config = Object.assign({}, defaultGenerateConfig, conf) + idx = prepareIdx(idx) + for (let step = 0; step < config.maxNewTokens; step++) { + const { idxNext, timePerToken } = generateOnce(model, idx, config) + const idxNew = idx.concat(idxNext, 1) + tf.dispose(idx) + idx = idxNew + const idxNextArr = (idxNext as any).arraySync() + tf.dispose(idxNext) + if (callback) { + callback({ idxNext: idxNextArr, timePerToken: timePerToken }) + } + } + const idxArr = idx.arraySync() + tf.dispose(idx) + return idxArr +} + +async function generate(model: any, idx: any, conf: any, callback: any) { + const config = Object.assign({}, defaultGenerateConfig, conf) + idx = await prepareIdx(idx) + for (let step = 0; step < config.maxNewTokens; step++) { + const { idxNext, timePerToken } = generateOnce(model, idx, config) + const idxNew = idx.concat(idxNext, 1) + tf.dispose(idx) + idx = idxNew + const idxNextArr = await (idxNext as any).array() + tf.dispose(idxNext) + if (callback) { + await callback({ idxNext: idxNextArr, timePerToken: timePerToken }) + } + } + const idxArr = await idx.array() + tf.dispose(idx) + return idxArr +} + +/** + * tfjs does not export LazyIterator and Dataset... 
+ */ +declare abstract class LazyIterator { + abstract next(): Promise> +} + +declare abstract class Dataset { + abstract iterator(): Promise> + size: number +} + +class GPTModel extends tf.LayersModel { + constructor(protected readonly config: any) { + const gpt = GPT(config) + const { inputs, outputs, name } = gpt + super({ inputs, outputs, name }) + Object.assign(this, gpt) + } + + async fitDataset( + dataset: Dataset, + args: tf.ModelFitDatasetArgs + ): Promise { + console.log('=== GPTModel custom train function ===') + const config = { ...this.config, ...args } + await train( + this, + dataset as dataset.Dataset, + config, + args.callbacks as training.TrainingCallbacks + ) + return {} as tf.History + } + + async load(modelPath: any) { + this.loadWeights(modelPath) + } +} + +class GPTLMHeadModel extends GPTModel { + constructor(config: any) { + super(config) + } + + async generate(idx: any, conf: any, callback: any) { + return await generate(this, idx, conf, callback) + } + + generateSync(idx: any, conf: any, callback: any) { + return generateSync(this, idx, conf, callback) + } +} + +export { GPT, GPTModel, GPTLMHeadModel, generate, generateSync } diff --git a/discojs/discojs-core/src/training/models/gpt/optimizers.ts b/discojs/discojs-core/src/training/models/gpt/optimizers.ts new file mode 100644 index 000000000..df357f378 --- /dev/null +++ b/discojs/discojs-core/src/training/models/gpt/optimizers.ts @@ -0,0 +1,120 @@ +import { tf } from '../../..' 
+ +type Tensor = tf.Tensor + +const ENGINE = tf.engine() + +function l2Loss(tensor: Tensor): Tensor { + return tf.div(tf.sum(tf.square(tensor)), 2) +} + +function globalNorm(tensors: Tensor[]): Tensor { + const halfSquaredNorms: Tensor[] = [] + tensors.forEach((tensor: Tensor, ti: number) => { + halfSquaredNorms.push(l2Loss(tensor)) + }) + const halfSquaredNorm: Tensor = tf.sum(tf.stack(halfSquaredNorms)) + const norm: Tensor = tf.sqrt( + tf.mul(halfSquaredNorm, tf.scalar(2.0, halfSquaredNorm.dtype)) + ) + return norm +} + +function clipByGlobalNorm( + tensors: Tensor[], + clipNorm: number, + useNorm?: Tensor +): Tensor[] { + useNorm = useNorm || globalNorm(tensors) + const scale: Tensor = tf.mul( + clipNorm, + tf.minimum( + tf.div(tf.scalar(1.0), useNorm), + tf.div(tf.scalar(1.0, useNorm.dtype), clipNorm) + ) + ) + const tensorsClipped: Tensor[] = [] + tensors.forEach((tensor: Tensor, ti: number) => { + tensorsClipped.push(tf.clone(tf.mul(tensor, scale))) + }) + return tensorsClipped +} + +function clipByGlobalNormObj( + tensorsObj: { [key: string]: Tensor }, + clipNorm: number, + useNorm?: Tensor +): { [key: string]: Tensor } { + const varNames: string[] = Object.keys(tensorsObj) + const tensorsArr: Tensor[] = varNames.map((n: string) => tensorsObj[n]) + const tensorsArrClipped: Tensor[] = clipByGlobalNorm( + tensorsArr, + clipNorm, + useNorm + ) + const tensorsObjClipped: { [key: string]: Tensor } = {} + tensorsArrClipped.forEach((t: Tensor, ti: number) => { + tensorsObjClipped[varNames[ti]] = t + }) + return tensorsObjClipped +} + +class AdamW extends tf.AdamOptimizer { + weightDecayRate: number + includeInWeightDecay: string[] + excludeFromWeightDecay: string[] + gradientClipNorm: number + + constructor(params: { + learningRate?: number + beta1?: number + beta2?: number + epsilon?: number + weightDecayRate?: number + includeInWeightDecay?: string[] + excludeFromWeightDecay?: string[] + gradientClipNorm?: number + }) { + console.log('Using custom AdamW 
optimizer') + const defaultParams = { + learningRate: 0.1, + beta1: 0.9, + beta2: 0.999, + epsilon: 1e-7, + weightDecayRate: 0, + includeInWeightDecay: [], + excludeFromWeightDecay: [], + gradientClipNorm: 1.0, + } + const p = Object.assign({}, defaultParams, params) + super(p.learningRate, p.beta1, p.beta2, p.epsilon) + this.weightDecayRate = p.weightDecayRate + this.includeInWeightDecay = p.includeInWeightDecay + this.excludeFromWeightDecay = p.excludeFromWeightDecay + this.gradientClipNorm = p.gradientClipNorm + } + + applyGradients(variableGradients: any): void { + const varNames: string[] = Array.isArray(variableGradients) + ? variableGradients.map((v: Tensor) => (v as any).name) + : Object.keys(variableGradients) + + varNames.forEach((name: string, i: number) => { + if (this.includeInWeightDecay.includes(name)) { + const value: any = ENGINE.registeredVariables[name] + const newValue: Tensor = tf.sub( + value, + tf.mul( + this.learningRate, + tf.mul(value, this.weightDecayRate) + ) + ) + value.assign(newValue) + } + }) + + super.applyGradients(variableGradients as any) + } +} + +export { AdamW, clipByGlobalNorm, clipByGlobalNormObj } diff --git a/discojs/discojs-core/src/training/models/gpt/train.ts b/discojs/discojs-core/src/training/models/gpt/train.ts new file mode 100644 index 000000000..ce87e2c74 --- /dev/null +++ b/discojs/discojs-core/src/training/models/gpt/train.ts @@ -0,0 +1,150 @@ +import { dataset, tf, training } from '../../..' 
+import { AdamW, clipByGlobalNormObj } from './optimizers' +import { GPTConfig, DEFAULT_CONFIG } from './config' +import evaluate from './evaluate' + +export type GPTConfigWithWandb = Required + +export const getConfig = (config: GPTConfig): GPTConfigWithWandb => ({ + ...DEFAULT_CONFIG, + ...config, +}) + +const getCustomAdam = (model: any, c: Required): tf.Optimizer => { + const includeInWeightDecay: string[] = [] + const excludeFromWeightDecay: string[] = [] + + model.getNamedWeights().forEach((v: any) => { + if ( + v.name.includes('bias') || + v.name.includes('normalization') || + v.name.includes('emb') + ) { + excludeFromWeightDecay.push(v.name) + } else { + includeInWeightDecay.push(v.name) + } + }) + return new AdamW({ + learningRate: c.lr, + weightDecayRate: c.weightDecay, + includeInWeightDecay, + excludeFromWeightDecay, + }) +} + +export async function train( + model: tf.LayersModel, + ds: dataset.Dataset, + config: GPTConfig, + callbacks: training.TrainingCallbacks, + evalDs?: dataset.Dataset +): Promise { + const c = getConfig(config) + console.log(c) + + const opt = c.weightDecay ? 
getCustomAdam(model, c) : tf.train.adam(c.lr) + + await callbacks.onTrainBegin() + + let epoch = 1 + let iteration = 1 + let iterator = await ds.iterator() + + const start = Date.now() + let time = start + + console.warn('=== Starting training ===') + await callbacks.onEpochBegin(epoch) + + while (true) { + await callbacks.onBatchBegin(iteration) + + // Get new batch of x and y + let datasetTime = Date.now() + let next = await iterator.next() + if (next.done) { + await callbacks.onEpochEnd(epoch) + epoch++ + if (c.epochs && epoch > c.epochs) { + break + } + await callbacks.onEpochBegin(epoch) + iterator = await ds.iterator() + next = await iterator.next() + } + const { xs, ys } = next.value + + datasetTime = Date.now() - datasetTime + + let iterationTime = Date.now() + + // Calculates loss, computes gradients and applies them + const loss = tf.tidy(() => { + let { grads, value: loss } = opt.computeGradients(() => { + const logits = model.apply(xs) + const loss = tf.losses.softmaxCrossEntropy(ys, logits) + return loss as tf.Scalar + }) + let gradsClipped = clipByGlobalNormObj(grads, 1) + opt.applyGradients(gradsClipped) + return loss + }) + + const lossVal = await loss.array() + + await callbacks.onBatchEnd(iteration) + + // Create a WandB log payload, evaluate every + const memory = tf.memory().numBytes * 0.000001 + const payload = { + 'train/perplexity': Math.exp(lossVal), + 'train/loss': lossVal, + iter: iteration, + 'tf-mem': memory, // MB + dt_ms: Date.now() - time, + time_s: (Date.now() - start) / 1000, + } + + if (c.evaluate && iteration % c.evaluateEvery === 0) { + if (!evalDs) { + throw new Error( + 'No evaluation dataset provided but config.evaluate is set' + ) + } + const evalPayload = await evaluate(model, evalDs, c) + Object.assign(payload, evalPayload) + } + + console.log(payload) + time = Date.now() + + tf.dispose([loss, xs, ys]) + + iterationTime = Date.now() - iterationTime + console.log( + `Epoch: ${epoch},\tStep: ${iteration} / ${ + c.maxIter + 
},\tLoss: ${lossVal.toFixed( + 3 + )},\tIteration time: ${iterationTime} ms, \tDataset time: ${datasetTime} ms,\tMemory: ${memory.toFixed( + 2 + )} MB` + ) + + // Check if we should stop + iteration++ + if (c.maxIter && iteration > c.maxIter) { + break + } + + if (c.verbose) { + console.log('Mem:', tf.memory()) + console.log(`Epoch: ${epoch}, Step: ${iteration}, Loss: ${lossVal}`) + } + + await new Promise((resolve) => setTimeout(resolve, 1)) + } + + await callbacks.onTrainEnd() +} diff --git a/discojs/discojs-core/src/training/models/index.ts b/discojs/discojs-core/src/training/models/index.ts new file mode 100644 index 000000000..69eab1f1f --- /dev/null +++ b/discojs/discojs-core/src/training/models/index.ts @@ -0,0 +1,2 @@ +export * as gpt from './gpt' +export type { GPTConfig, GPTConfigWithWandb } from './gpt' diff --git a/discojs/discojs-core/src/training/trainer/distributed_trainer.ts b/discojs/discojs-core/src/training/trainer/distributed_trainer.ts index a27783a50..4a0ef636c 100644 --- a/discojs/discojs-core/src/training/trainer/distributed_trainer.ts +++ b/discojs/discojs-core/src/training/trainer/distributed_trainer.ts @@ -1,4 +1,12 @@ -import { tf, training, Memory, Task, TrainingInformant, WeightsContainer, client as clients } from '../..' +import { + tf, + training, + Memory, + Task, + TrainingInformant, + WeightsContainer, + client as clients, +} from '../..' import { Aggregator } from '../../aggregator' import { Trainer } from './trainer' @@ -6,53 +14,69 @@ import { Trainer } from './trainer' * Class whose role is to train a model in a distributed way with a given dataset. */ export class DistributedTrainer extends Trainer { - private readonly aggregator: Aggregator - - /** - * DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights. 
- */ - constructor ( - task: Task, - trainingInformant: TrainingInformant, - memory: Memory, - model: training.model.Model, - private readonly client: clients.Client - ) { - super(task, trainingInformant, memory, model) - this.aggregator = this.client.aggregator - this.aggregator.setModel(model) - } - - async onTrainBegin (logs?: tf.Logs): Promise { - await super.onTrainBegin(logs) - - const weights = WeightsContainer.from(this.model) - - await this.client.onTrainBeginCommunication(weights, this.trainingInformant) - } - - async onRoundBegin (accuracy: number): Promise { - const weights = WeightsContainer.from(this.model) - - await this.client.onRoundBeginCommunication(weights, this.roundTracker.round, this.trainingInformant) - } - - /** - * Callback called every time a round is over - */ - async onRoundEnd (accuracy: number): Promise { - const weights = WeightsContainer.from(this.model) - - await this.client.onRoundEndCommunication(weights, this.roundTracker.round, this.trainingInformant) - if (this.aggregator.model !== undefined) { - // The aggregator's own aggregation is async. The trainer updates its model to match the aggregator's - // after it has completed a round of training. - this.model.toTfjs().setWeights(this.aggregator.model.toTfjs().getWeights()) + private readonly aggregator: Aggregator + + /** + * DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights. 
+ */ + constructor( + task: Task, + trainingInformant: TrainingInformant, + memory: Memory, + model: training.model.Model, + private readonly client: clients.Client + ) { + super(task, trainingInformant, memory, model) + this.aggregator = this.client.aggregator + this.aggregator.setModel(model) + } + + async onTrainBegin(logs?: tf.Logs): Promise { + await super.onTrainBegin(logs) + + const weights = WeightsContainer.from(this.model) + + await this.client.onTrainBeginCommunication( + weights, + this.trainingInformant + ) } - await this.memory.updateWorkingModel( - { taskID: this.task.taskID, name: this.task.trainingInformation.modelID }, - this.model.toTfjs() - ) - } + async onRoundBegin(accuracy: number): Promise { + const weights = WeightsContainer.from(this.model) + + await this.client.onRoundBeginCommunication( + weights, + this.roundTracker.round, + this.trainingInformant + ) + } + + /** + * Callback called every time a round is over + */ + async onRoundEnd(accuracy: number): Promise { + const weights = WeightsContainer.from(this.model) + + await this.client.onRoundEndCommunication( + weights, + this.roundTracker.round, + this.trainingInformant + ) + if (this.aggregator.model !== undefined) { + // The aggregator's own aggregation is async. The trainer updates its model to match the aggregator's + // after it has completed a round of training. 
+ this.model + .toTfjs() + .setWeights(this.aggregator.model.toTfjs().getWeights()) + } + + await this.memory.updateWorkingModel( + { + taskID: this.task.id, + name: this.task.trainingInformation.modelID, + }, + this.model.toTfjs() + ) + } } diff --git a/discojs/discojs-core/src/training/trainer/local_trainer.ts b/discojs/discojs-core/src/training/trainer/local_trainer.ts index 358e4887b..65a6a451f 100644 --- a/discojs/discojs-core/src/training/trainer/local_trainer.ts +++ b/discojs/discojs-core/src/training/trainer/local_trainer.ts @@ -4,19 +4,22 @@ import { Trainer } from './trainer' /** Class whose role is to locally (alone) train a model on a given dataset, without any collaborators. */ export class LocalTrainer extends Trainer { - async onRoundBegin (accuracy: number): Promise {} + async onRoundBegin(accuracy: number): Promise {} - async onRoundEnd (accuracy: number): Promise { - console.log('on round end') - await this.memory.updateWorkingModel( - { taskID: this.task.taskID, name: this.task.trainingInformation.modelID }, - this.model.toTfjs() - ) - } + async onRoundEnd(accuracy: number): Promise { + console.log('on round end') + await this.memory.updateWorkingModel( + { + taskID: this.task.id, + name: this.task.trainingInformation.modelID, + }, + this.model.toTfjs() + ) + } - onEpochEnd (epoch: number, logs?: tf.Logs): void { - super.onEpochEnd(epoch, logs) - console.log('on epoch end') - this.trainingInformant.update({ currentRound: epoch }) - } + onEpochEnd(epoch: number, logs?: tf.Logs): void { + super.onEpochEnd(epoch, logs) + console.log('on epoch end') + this.trainingInformant.update({ currentRound: epoch }) + } } diff --git a/discojs/discojs-core/src/training/trainer/trainer.ts b/discojs/discojs-core/src/training/trainer/trainer.ts index 5591d0f9b..3540ff58e 100644 --- a/discojs/discojs-core/src/training/trainer/trainer.ts +++ b/discojs/discojs-core/src/training/trainer/trainer.ts @@ -3,6 +3,16 @@ import { RoundTracker } from './round_tracker' import { 
TrainerLogger, TrainerLog } from '../../logging/trainer_logger' import { Model } from '../model' +// From tfjs base_callbacks.d.ts +export interface TrainingCallbacks { + onEpochBegin: (epoch: number, logs?: tf.Logs) => Promise + onEpochEnd: (epoch: number, logs?: tf.Logs) => Promise + onBatchBegin: (batch: number, logs?: tf.Logs) => Promise + onBatchEnd: (batch: number, logs?: tf.Logs) => Promise + onTrainBegin: (logs?: tf.Logs) => Promise + onTrainEnd: (logs?: tf.Logs) => Promise +} + /** Abstract class whose role is to train a model with a given dataset. This can be either done * locally (alone) or in a distributed way with collaborators. The Trainer works as follows: * @@ -13,134 +23,141 @@ import { Model } from '../model' * a round has ended we use the roundTracker object. */ export abstract class Trainer { - public readonly roundTracker: RoundTracker - - private stopTrainingRequested = false - private readonly trainerLogger: TrainerLogger - - /** - * Constructs the training manager. - * @param task the trained task - * @param trainingInformant the training informant - */ - constructor ( - public readonly task: Task, - public readonly trainingInformant: TrainingInformant, - public readonly memory: Memory, - public readonly model: Model - ) { - this.trainerLogger = new TrainerLogger() - this.roundTracker = new RoundTracker(task.trainingInformation.roundDuration) - } - - protected abstract onRoundBegin (accuracy: number): Promise - - /** - * Every time a round ends this function will be called - */ - protected abstract onRoundEnd (accuracy: number): Promise - - /** - * Callback executed on every batch end. When a round ends, onRoundEnd is called - */ - public async onBatchEnd (_: number, logs?: tf.Logs): Promise { - if (logs === undefined) { - return + public readonly roundTracker: RoundTracker + + private stopTrainingRequested = false + private readonly trainerLogger: TrainerLogger + + /** + * Constructs the training manager. 
+ * @param task the trained task + * @param trainingInformant the training informant + */ + constructor( + public readonly task: Task, + public readonly trainingInformant: TrainingInformant, + public readonly memory: Memory, + public readonly model: Model + ) { + this.trainerLogger = new TrainerLogger() + this.roundTracker = new RoundTracker( + task.trainingInformation.roundDuration + ) + } + + protected abstract onRoundBegin(accuracy: number): Promise + + /** + * Every time a round ends this function will be called + */ + protected abstract onRoundEnd(accuracy: number): Promise + + /** + * Callback executed on every batch end. When a round ends, onRoundEnd is called + */ + public async onBatchEnd(_: number, logs?: tf.Logs): Promise { + if (logs === undefined) { + return + } + + this.roundTracker.updateBatch() + this.stopTrainModelIfRequested() + + if (this.roundTracker.roundHasEnded()) { + await this.onRoundEnd(logs.acc) + } } - this.roundTracker.updateBatch() - this.stopTrainModelIfRequested() + async onBatchBegin(_: number, logs?: tf.Logs): Promise { + if (logs === undefined) { + return + } - if (this.roundTracker.roundHasEnded()) { - await this.onRoundEnd(logs.acc) + if (this.roundTracker.roundHasBegun()) { + await this.onRoundBegin(logs.acc) + } } - } - async onBatchBegin (_: number, logs?: tf.Logs): Promise { - if (logs === undefined) { - return + onEpochBegin(epoch: number, logs?: tf.Logs): void {} + + /** + * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd. 
+ */ + onEpochEnd(epoch: number, logs?: tf.Logs): void { + this.trainerLogger.onEpochEnd(epoch, logs) + + if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) { + this.trainingInformant.updateTrainingGraph( + this.roundDecimals(logs.acc) + ) + this.trainingInformant.updateValidationGraph( + this.roundDecimals(logs.val_acc) + ) + } else { + this.trainerLogger.error('onEpochEnd: NaN value') + } } - if (this.roundTracker.roundHasBegun()) { - await this.onRoundBegin(logs.acc) + async onTrainBegin(logs?: tf.Logs): Promise { + this.trainingInformant.addMessage('Training started.') } - } - onEpochBegin (epoch: number, logs?: tf.Logs): void {} + /** + * When the training ends this function will be call + */ + async onTrainEnd(logs?: tf.Logs): Promise { + this.trainingInformant.addMessage('Training finished.') + } + + /** + * Request stop training to be used from the Disco instance or any class that is taking care of the trainer. + */ + async stopTraining(): Promise { + this.stopTrainingRequested = true + } + + /** + * Starts training the model with the given dataset. The exact behavior for model weights updates + * is model-dependent and is thus left to the model. The trainer instance is given for the fit function + * to be able to access the regular TF.js training hooks, which may include communication in the case of + * decentralized & federated learning. + * @param dataset + */ + async fitModel(data: data.tuple.DataSplit): Promise { + this.resetStopTrainerState() + + await this.model.fit(this, data) + } - /** - * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd. 
- */ - onEpochEnd (epoch: number, logs?: tf.Logs): void { - this.trainerLogger.onEpochEnd(epoch, logs) + /** + * Format accuracy + */ + protected roundDecimals( + accuracy: number, + decimalsToRound: number = 2 + ): number { + return +(accuracy * 100).toFixed(decimalsToRound) + } - if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) { - this.trainingInformant.updateTrainingGraph(this.roundDecimals(logs.acc)) - this.trainingInformant.updateValidationGraph(this.roundDecimals(logs.val_acc)) - } else { - this.trainerLogger.error('onEpochEnd: NaN value') + /** + * reset stop training state + */ + protected resetStopTrainerState(): void { + this.model.toTfjs().stopTraining = false + this.stopTrainingRequested = false } - } - - async onTrainBegin (logs?: tf.Logs): Promise { - this.trainingInformant.addMessage('Training started.') - } - - /** - * When the training ends this function will be call - */ - async onTrainEnd (logs?: tf.Logs): Promise { - this.trainingInformant.addMessage('Training finished.') - } - - /** - * Request stop training to be used from the Disco instance or any class that is taking care of the trainer. - */ - async stopTraining (): Promise { - this.stopTrainingRequested = true - } - - /** - * Starts training the model with the given dataset. The exact behavior for model weights updates - * is model-dependent and is thus left to the model. The trainer instance is given for the fit function - * to be able to access the regular TF.js training hooks, which may include communication in the case of - * decentralized & federated learning. 
- * @param dataset - */ - async fitModel ( - data: data.tuple.DataSplit - ): Promise { - this.resetStopTrainerState() - - await this.model.fit(this, data) - } - - /** - * Format accuracy - */ - protected roundDecimals (accuracy: number, decimalsToRound: number = 2): number { - return +(accuracy * 100).toFixed(decimalsToRound) - } - - /** - * reset stop training state - */ - protected resetStopTrainerState (): void { - this.model.toTfjs().stopTraining = false - this.stopTrainingRequested = false - } - - /** - * If stop training is requested, do so - */ - protected stopTrainModelIfRequested (): void { - if (this.stopTrainingRequested) { - this.model.toTfjs().stopTraining = true - this.stopTrainingRequested = false + + /** + * If stop training is requested, do so + */ + protected stopTrainModelIfRequested(): void { + if (this.stopTrainingRequested) { + this.model.toTfjs().stopTraining = true + this.stopTrainingRequested = false + } } - } - getTrainerLog (): TrainerLog { - return this.trainerLogger.log - } + getTrainerLog(): TrainerLog { + return this.trainerLogger.log + } } diff --git a/discojs/discojs-core/src/training/trainer/trainer_builder.ts b/discojs/discojs-core/src/training/trainer/trainer_builder.ts index 51dead500..b92388bd0 100644 --- a/discojs/discojs-core/src/training/trainer/trainer_builder.ts +++ b/discojs/discojs-core/src/training/trainer/trainer_builder.ts @@ -1,4 +1,12 @@ -import { client as clients, Task, TrainingInformant, Memory, ModelType, ModelInfo, training } from '../..' +import { + client as clients, + Task, + TrainingInformant, + Memory, + ModelType, + ModelInfo, + training, +} from '../..' import { Aggregator } from '../../aggregator' import { DistributedTrainer } from './distributed_trainer' @@ -9,79 +17,91 @@ import { Trainer } from './trainer' * A class that helps build the Trainer and auxiliary classes. 
*/ export class TrainerBuilder { - constructor ( - private readonly memory: Memory, - private readonly task: Task, - private readonly trainingInformant: TrainingInformant - ) {} + constructor( + private readonly memory: Memory, + private readonly task: Task, + private readonly trainingInformant: TrainingInformant + ) {} - /** - * Builds a trainer object. - * - * @param client client to share weights with (either distributed or federated) - * @param distributed whether to build a distributed or local trainer - * @returns - */ - async build (aggregator: Aggregator, client: clients.Client, distributed: boolean = false): Promise { - const model = await this.getModel(client) - if (distributed) { - return new DistributedTrainer( - this.task, - this.trainingInformant, - this.memory, - model, - client - ) - } else { - return new LocalTrainer( - this.task, - this.trainingInformant, - this.memory, - model - ) + /** + * Builds a trainer object. + * + * @param client client to share weights with (either distributed or federated) + * @param distributed whether to build a distributed or local trainer + * @returns + */ + async build( + aggregator: Aggregator, + client: clients.Client, + distributed: boolean = false + ): Promise { + const model = await this.getModel(client) + if (distributed) { + return new DistributedTrainer( + this.task, + this.trainingInformant, + this.memory, + model, + client + ) + } else { + return new LocalTrainer( + this.task, + this.trainingInformant, + this.memory, + model + ) + } } - } - /** - * If a model exists in memory, laod it, otherwise load model from server - * @returns - */ - private async getModel (client: clients.Client): Promise { - const modelID = this.task.trainingInformation?.modelID - if (modelID === undefined) { - throw new TypeError('model ID is undefined') - } - - const info: ModelInfo = { type: ModelType.WORKING, taskID: this.task.taskID, name: modelID } + /** + * If a model exists in memory, laod it, otherwise load model from server 
+ * @returns + */ + private async getModel( + client: clients.Client + ): Promise { + const modelID = this.task.trainingInformation?.modelID + if (modelID === undefined) { + throw new TypeError('model ID is undefined') + } - const model = await ( - await this.memory.contains(info) ? this.memory.getModel(info) : client.getLatestModel() - ) + const info: ModelInfo = { + type: ModelType.WORKING, + taskID: this.task.id, + name: modelID, + } - return await this.updateModelInformation(model) - } + const model = await ((await this.memory.contains(info)) + ? this.memory.getModel(info) + : client.getLatestModel()) - private async updateModelInformation (model: training.model.Model): Promise { - const m = model.toTfjs() - // Continue local training from previous epoch checkpoint - if (m.getUserDefinedMetadata() === undefined) { - m.setUserDefinedMetadata({ epoch: 0 }) + return await this.updateModelInformation(model) } - const info = this.task.trainingInformation - if (info === undefined) { - throw new TypeError('training information is undefined') - } + private async updateModelInformation( + model: training.model.Model + ): Promise { + const m = model.toTfjs() + // Continue local training from previous epoch checkpoint + if (m.getUserDefinedMetadata() === undefined) { + m.setUserDefinedMetadata({ epoch: 0 }) + } - m.compile(info.modelCompileData) + const info = this.task.trainingInformation + if (info === undefined) { + throw new TypeError('training information is undefined') + } - if (info.learningRate !== undefined) { - // TODO: Not the right way to change learningRate and hence we cast to any - // the right way is to construct the optimiser and pass learningRate via - // argument. 
- m.optimizer.learningRate = info.learningRate - } + m.compile(info.modelCompileData) + + if (info.learningRate !== undefined) { + // TODO: Not the right way to change learningRate and hence we cast to any + // the right way is to construct the optimiser and pass learningRate via + // argument. + m.optimizer.learningRate = info.learningRate + } - return model - } + return model + } } diff --git a/discojs/discojs-core/src/validation/validator.spec.ts b/discojs/discojs-core/src/validation/validator.spec.ts index 0e060a97c..6fa897eec 100644 --- a/discojs/discojs-core/src/validation/validator.spec.ts +++ b/discojs/discojs-core/src/validation/validator.spec.ts @@ -1,79 +1,100 @@ import { assert } from 'chai' import fs from 'fs' -import { Task, node, Validator, ConsoleLogger, EmptyMemory, client as clients, data, aggregator } from '@epfml/discojs-node' +import { + Task, + node, + Validator, + ConsoleLogger, + EmptyMemory, + client as clients, + data, + aggregator, +} from '@epfml/discojs-node' const simplefaceMock = { - taskID: 'simple_face', - displayInformation: {}, - trainingInformation: { - modelID: 'simple_face-model', - batchSize: 4, - dataType: 'image', - IMAGE_H: 200, - IMAGE_W: 200, - LABEL_LIST: ['child', 'adult'], - modelCompileData: { - optimizer: 'sgd', - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - } - } + id: 'simple_face', + displayInformation: {}, + trainingInformation: { + modelID: 'simple_face-model', + batchSize: 4, + dataType: 'image', + IMAGE_H: 200, + IMAGE_W: 200, + LABEL_LIST: ['child', 'adult'], + modelCompileData: { + optimizer: 'sgd', + loss: 'categoricalCrossentropy', + metrics: ['accuracy'], + }, + }, } as unknown as Task describe('validator', () => { - it('works for simple_face', async () => { - const dir = '../../example_training_data/simple_face/' - const files: string[][] = ['child/', 'adult/'] - .map((subdir: string) => fs.readdirSync(dir + subdir) - .map((file: string) => dir + subdir + file)) - const labels = 
files.flatMap((files, index) => Array(files.length).fill(index)) + it('works for simple_face', async () => { + const dir = '../../example_training_data/simple_face/' + const files: string[][] = ['child/', 'adult/'].map((subdir: string) => + fs + .readdirSync(dir + subdir) + .map((file: string) => dir + subdir + file) + ) + const labels = files.flatMap((files, index) => + Array(files.length).fill(index) + ) - const data: data.Data = (await new node.data.NodeImageLoader(simplefaceMock) - .loadAll(files.flat(), { labels })).train - const buffer = new aggregator.MeanAggregator(simplefaceMock) - const client = new clients.Local(new URL('http://localhost:8080'), simplefaceMock, buffer) - buffer.setModel(await client.getLatestModel()) - const validator = new Validator( - simplefaceMock, - new ConsoleLogger(), - new EmptyMemory(), - undefined, - client - ) - await validator.assess(data) - const size = data.size !== undefined ? data.size : -1 - if (size === -1) { - console.log('data.size was undefined') - } - assert( - validator.visitedSamples === data.size, - `expected ${size} visited samples but got ${validator.visitedSamples}` - ) - assert( - validator.accuracy > 0.3, - `expected accuracy greater than 0.3 but got ${validator.accuracy}` - ) - console.table(validator.confusionMatrix) - }).timeout(10_000) + const data: data.Data = ( + await new node.data.NodeImageLoader(simplefaceMock).loadAll( + files.flat(), + { labels } + ) + ).train + const buffer = new aggregator.MeanAggregator(simplefaceMock) + const client = new clients.Local( + new URL('http://localhost:8080'), + simplefaceMock, + buffer + ) + buffer.setModel(await client.getLatestModel()) + const validator = new Validator( + simplefaceMock, + new ConsoleLogger(), + new EmptyMemory(), + undefined, + client + ) + await validator.assess(data) + const size = data.size !== undefined ? 
data.size : -1 + if (size === -1) { + console.log('data.size was undefined') + } + assert( + validator.visitedSamples === data.size, + `expected ${size} visited samples but got ${validator.visitedSamples}` + ) + assert( + validator.accuracy > 0.3, + `expected accuracy greater than 0.3 but got ${validator.accuracy}` + ) + console.table(validator.confusionMatrix) + }).timeout(10_000) - // TODO: fix titanic model (nan accuracy) - // it('works for titanic', async () => { - // const data: Data = await new NodeTabularLoader(titanic.task, ',') - // .loadAll(['../../example_training_data/titanic.csv'], { - // features: titanic.task.trainingInformation?.inputColumns, - // labels: titanic.task.trainingInformation?.outputColumns - // }) - // const validator = new Validator(titanic.task, new ConsoleLogger(), titanic.model()) - // await validator.assess(data) + // TODO: fix titanic model (nan accuracy) + // it('works for titanic', async () => { + // const data: Data = await new NodeTabularLoader(titanic.task, ',') + // .loadAll(['../../example_training_data/titanic.csv'], { + // features: titanic.task.trainingInformation?.inputColumns, + // labels: titanic.task.trainingInformation?.outputColumns + // }) + // const validator = new Validator(titanic.task, new ConsoleLogger(), titanic.model()) + // await validator.assess(data) - // assert( - // validator.visitedSamples() === data.size, - // `expected ${TITANIC_SAMPLES} visited samples but got ${validator.visitedSamples()}` - // ) - // assert( - // validator.accuracy() > 0.5, - // `expected accuracy greater than 0.5 but got ${validator.accuracy()}` - // ) - // }) + // assert( + // validator.visitedSamples() === data.size, + // `expected ${TITANIC_SAMPLES} visited samples but got ${validator.visitedSamples()}` + // ) + // assert( + // validator.accuracy() > 0.5, + // `expected accuracy greater than 0.5 but got ${validator.accuracy()}` + // ) + // }) }) diff --git a/discojs/discojs-node/src/dataset/data_loader/index.ts 
b/discojs/discojs-node/src/dataset/data_loader/index.ts index 0ea714444..bb9089d34 100644 --- a/discojs/discojs-node/src/dataset/data_loader/index.ts +++ b/discojs/discojs-node/src/dataset/data_loader/index.ts @@ -1,2 +1,3 @@ export { NodeImageLoader } from './image_loader' export { NodeTabularLoader } from './tabular_loader' +export { NodeTextLoader } from './text_loader' diff --git a/discojs/discojs-node/src/dataset/data_loader/text_loader.spec.ts b/discojs/discojs-node/src/dataset/data_loader/text_loader.spec.ts new file mode 100644 index 000000000..6d98f130b --- /dev/null +++ b/discojs/discojs-node/src/dataset/data_loader/text_loader.spec.ts @@ -0,0 +1,174 @@ +import fs from 'fs' +import path from 'path' +import { describe, test, expect } from 'bun:test' +import { encode, decode } from 'gpt-tokenizer/model/text-davinci-003' + +import { tf, defaultTasks, dataset } from '../..' +import { NodeTextLoader } from '.' + +/** + * ================================================ + * Assumes you have followed the installation steps + * in disco/experiment (see README.md) + * ================================================ + */ + +const datasetsFolder = path.join( + /* @ts-ignore */ + import.meta.dir, + '../../../../../experiment', + 'datasets', + 'wikitext-103' +) +const trainFile = 'test' + +const source: dataset.TextSource = { + train: [path.join(datasetsFolder, `${trainFile}.tokens`)], + // validation: [path.join(datasetsFolder, 'validation.tokens')], +} + +const task = defaultTasks.wikitext.getTask() +const config = { + ...task.trainingInformation.modelConfig, + blockSize: 16, + batchSize: 4, + vocabSize: 50257, +} + +const BENCHMARK_ITERATIONS = 1_000 +const BENCHMARK_BATCH_SIZES = [4, 16, 32] +const BENCHMARK_BLOCK_SIZES = [64, 128, 256, 512] + +// config: gpt.GPTConfig +const getDataset = async (config: Partial) => { + const loaded = await new NodeTextLoader(task).loadAll(source, config) + const ds = loaded.train.dataset as dataset.TokenizedDataset + return ds +} 
+ +// config: gpt.GPTConfig +const getIterator = async (config: Partial) => { + const ds = await getDataset(config) + const iter = await ds.iterator() + return { + next: async () => { + const { value } = (await iter.next()) as dataset.TokenizedIterResult + return { xs: value.xs, ys: value.ys } + }, + } +} + +const getIteratorArray = async (config: any) => { + const iter = await getIterator(config) + return { + next: async () => { + const { xs, ys } = await iter.next() + const x = await xs.array() + const y = await (ys.argMax(2) as tf.Tensor2D).array() // get indices of max values along last axis + return { x, y } + }, + } +} + +/** + * Reads the RAW dataset (not preprocessed) and tokenizes the equivalent of the first batch. + */ +const getRawTokenizedSample = async ( + sampleSize: number, + tokensLength: number +) => { + const wikiRaw = fs.createReadStream(path.join(datasetsFolder, trainFile), { + encoding: 'utf8', + start: 0, + end: sampleSize * 1.5, // * 1.5 to make sure we have enough tokens + }) + const iter = wikiRaw.iterator() + const { value: chunk } = await iter.next() + const tokens = encode(chunk).slice(0, tokensLength) + return tokens +} + +const correctShapeTest = ( + xs: tf.Tensor2D, + ys: tf.Tensor3D, + config: Required> +) => { + expect(xs.shape).toEqual([config.batchSize, config.blockSize]) + expect(ys.shape).toEqual([ + config.batchSize, + config.blockSize, + config.vocabSize, + ]) +} + +describe('node text loader', () => { + test('loads a batched sample with correct x and y shapes', async () => { + const iter = await getIterator(config) + const { xs, ys } = await iter.next() + + correctShapeTest(xs, ys, config) + + tf.dispose([xs, ys]) + }) + + test('x without [0] equals y without [-1]', async () => { + const TEST_SIZE = 10 + const iter = await getIteratorArray(config) + for (let i = 0; i < TEST_SIZE; i++) { + const { x, y } = await iter.next() + for (let i = 0; i < config.batchSize; i++) { + // console.log('x=', decode(x_arr[i]).trim()) + // 
console.log('y=', decode(y_arr[i]).trim()) + expect(x[i].slice(1)).toEqual(y[i].slice(0, -1)) + } + } + }) + + test('dataset is tokenized properly', async () => { + const iter = await getIteratorArray(config) + const { x, y } = await iter.next() + + /** + * Flatten the batch by taking the first token in x and the rest in y, since y is x shifted by 1 + 1 token + * e.g. [a, b, c, d, e, f] -> x = [a, b, c, d, e] and y = [b, c, d, e, f] + * thus x[0] + y = [a, b, c, d, e, f] + **/ + const sample: number[] = [] + for (let i = 0; i < config.batchSize; i++) { + sample.push(x[i][0], ...y[i]) + } + const textLength = decode(sample).length + const tokens = await getRawTokenizedSample(textLength, sample.length) + + expect(sample.length).toBe(tokens.length) + expect(sample).toEqual(tokens) + }) + + test(`benchmark ${BENCHMARK_ITERATIONS} iterations for batch sizes: ${BENCHMARK_BATCH_SIZES} and block sizes: ${BENCHMARK_BLOCK_SIZES}`, async () => { + for (const batchSize of BENCHMARK_BATCH_SIZES) { + for (const blockSize of BENCHMARK_BLOCK_SIZES) { + const c = { + ...config, + batchSize, + blockSize, + } + const iter = await getIterator(c) + const benchmarkStart = Date.now() + for (let i = 0; i < BENCHMARK_ITERATIONS; i++) { + const { xs, ys } = await iter.next() + if (i === 0) correctShapeTest(xs, ys, c) + tf.dispose([xs, ys]) + } + const benchmarkEnd = Date.now() + const ms = benchmarkEnd - benchmarkStart + console.log( + `[batchSize=${c.batchSize}, blockSize=${ + c.blockSize + }] Time per iteration: ${( + ms / BENCHMARK_ITERATIONS + ).toFixed(3)}ms` + ) + } + } + }, 256_000) +}) diff --git a/discojs/discojs-node/src/dataset/data_loader/text_loader.ts b/discojs/discojs-node/src/dataset/data_loader/text_loader.ts index 0044e07fa..f392d09b3 100644 --- a/discojs/discojs-node/src/dataset/data_loader/text_loader.ts +++ b/discojs/discojs-node/src/dataset/data_loader/text_loader.ts @@ -1,30 +1,158 @@ -// import fs from 'node:fs' - -// import split2 from 'split2' - -// import { tf } 
from '../..' -// import { TextLoader } from '../../core/dataset/data_loader/text_loader' -// import { Dataset } from '../../core/dataset' -// import { DataConfig } from '../../core/dataset/data_loader' - -// export class NodeTextLoader extends TextLoader { -// async loadDatasetFrom (source: string, config?: DataConfig): Promise { -// const prefix = 'file://' -// if (source.slice(0, 7) !== prefix) { -// source = prefix + source -// } -// // create stream being read by generator -// const stream = fs.createReadStream(source, { encoding: 'utf-8' }) -// // eslint-disable-next-line @typescript-eslint/no-this-alias -// const self = this - -// async function * dataGenerator (): AsyncGenerator { -// // TODO @s314cy -// const withLabels = config?.labels !== undefined -// stream.pipe(split2()) -// stream.on('data', (data) => yield self.tokenize(data)) -// } - -// return tf.data.generator(dataGenerator) -// } -// } +import { dataset } from '../..' +import fs from 'fs' +import { List } from 'immutable' + +/* + +TODO: Bun.file().stream() is kind of broken for now +See: https://github.com/oven-sh/bun/pull/7506 +and: https://github.com/oven-sh/bun/issues/7057 +When the PR is merged, we can probably use the following code: + +const stream = Bun.file(source).stream() +async function* generator() { + for await (const chunk of stream) { + console.log('GENERATOR', chunk.length) + yield chunk + } +} +return { stream, iter: generator() } + + +async getInfiniteBufferIteratorFromFile( + source: string, + config: dataset.TextConfig + ): Promise> { + const getStream = async () => { + return await this.getFileStream(source, config) + } + let { stream, iter } = await getStream() + return { + next: async () => { + let sample = await iter.next() + if (!sample || !sample.value || sample.done) { + await stream.cancel() + const newStream = await getStream() + stream = newStream.stream + iter = newStream.iter + sample = await iter.next() + if (!sample || !sample.value || sample.done) { + throw new 
Error( + 'Getting a sample from the file stream still fails after retrying, most likely the file at ' + + source + + ' is empty..' + ) + } + } + return sample as IteratorResult + }, + } + } + +*/ + +export class NodeTextLoader extends dataset.loader.TextLoader { + /** + * Creates a file stream from a dataset filename. + * This stream will contain a specific number of bytes + * defined by the highWaterMark parameter which depends on the + * block size and batch size. This ensures that reading the stream + * always return a chunk of data of the same, required, size. + * @param source: dataset filename to stream from + * @param config: TextConfig + * @returns a file stream + */ + getFileStream(source: string, chunkSize: number) { + return new Promise((resolve) => { + const stream = fs.createReadStream(source, { + fd: undefined, + highWaterMark: chunkSize, + }) + stream.on('readable', () => resolve(stream)) + }) + } + + getChunkSize(config: dataset.TextConfig) { + const batchSize = this.getBatchSize(config) + // blockSize + 1 = input size (size of x = blockSize, size of y = blockSize shifted right by 1, thus the + 1) + // * batchSize to retrieve a batch at once + // * 2 because tokens are stored as uint16 and thus require 2 bytes + return (config.blockSize + 1) * batchSize * 2 + } + + /** + * Creates an infinite iterator from a file stream + * meaning when the stream reaches the end of the file + * it will start again from the beginning + * @param source: dataset filename to stream from + * @param config: TextConfig + * @returns an infinite iterator over a file stream + */ + async getInfiniteBufferIteratorFromFile( + source: string, + config: dataset.TextConfig + ): Promise> { + const chunkSize = this.getChunkSize(config) + + if (isNaN(chunkSize)) + throw new Error( + 'chunk size, is NaN but is supposed to be of type number' + ) + + const getStream = async () => + await this.getFileStream(source, chunkSize) + + let stream = await getStream() + return { + next: async () 
=> { + let buffer = (await stream.read(chunkSize)) as + | Buffer + | undefined + if (!buffer) { + stream.close() + stream = await getStream() + buffer = await stream.read(chunkSize) + if (!buffer) { + throw new Error( + 'Getting a sample from the file stream still fails after retrying, most likely the file at ' + + source + + ' is empty..' + ) + } + } + return { value: buffer, done: false } + }, + } + } + + async load( + source: string, + config: dataset.TextConfig + ): Promise { + const requestNext = await this.getInfiniteBufferIteratorFromFile( + source, + config + ) + const dataset = await this.getCoreDataset(config, requestNext) + return dataset + } + + async loadAll( + source: dataset.TextSource, + config?: Partial + ): Promise { + const _config = this.resolveConfig(config) + const split: Partial = {} + for await (const [k, files] of Object.entries(source)) { + const datasets = await Promise.all( + files.map(async (src) => await this.load(src, _config)) + ) + const dataset = List(datasets).reduce( + (acc: dataset.Dataset, dataset) => acc.concatenate(dataset) + ) + const data = await this.createData(dataset) + ;(split as dataset.DataSplit)[k as keyof typeof split] = data + } + return split as dataset.DataSplit + } +} diff --git a/discojs/discojs-web/src/dataset/data_loader/text-worker.ts b/discojs/discojs-web/src/dataset/data_loader/text-worker.ts new file mode 100644 index 000000000..b3cbe7026 --- /dev/null +++ b/discojs/discojs-web/src/dataset/data_loader/text-worker.ts @@ -0,0 +1,95 @@ +// prevents TS errors +declare var self: Worker + +import { v4 as randomUUID } from 'uuid' +import { dataset } from '../..' 
+import { Cache } from './cache' + +type MessageData = { + value: { + type: 'Buffer' + data: number[] + } + done: boolean + pos: number +} + +export type CacheData = { + value: number[] + done: boolean + pos: number +} + +// TODO: make brokerURL configurable and at least stored in .env +// or automatically retrieved and compatible with websocket server somehow +const BROKER_URL = 'ws://localhost:3001/ws' + +const { FILE, CONFIG, CACHE_SIZE } = process.env as { + ID: string + FILE: string + CONFIG: string + CACHE_SIZE: string +} + +/** + * Creates a url and connect to the websocket server + * @param file: filename corresponding to the file the websocket server will stream + * @param config entries: all the config key, value pairs. The config object will be reconstructed in the websocket server side + */ +const getWebsocket = () => { + const url = new URL(BROKER_URL) + + const id = randomUUID() + const searchParams: dataset.WSSearchParams = { + id, + config: CONFIG, + file: FILE, + } + for (const [k, v] of Object.entries(searchParams)) + url.searchParams.append(k, v) + + const ws = new WebSocket(url) + + ws.onerror = (err) => { + console.error(err) + } + + return { ws, id } +} + +const { ws, id } = getWebsocket() + +const proceed = async () => { + console.log('worker', id, 'connected') + + const request = (pos: number) => { + // console.log(Date.now(), 'WORKER requesting next value', pos) + ws.send(JSON.stringify({ pos, id })) + } + + const cache = await Cache.init( + parseInt(CACHE_SIZE), + request, + (c) => { + ws.onmessage = (payload: MessageEvent) => { + const { value, done, pos } = JSON.parse( + payload.data as string + ) as MessageData + // console.log(Date.now(), 'WORKER received from ws', pos) + c.put(pos, { value: value.data, done, pos }) + } + } + ) + + self.onmessage = async (event: MessageEvent) => { + // console.log(Date.now(), 'WORKER onmessage') + // console.time('onmessage') + const sample = await cache.next() + // console.timeEnd('onmessage') + 
postMessage(JSON.stringify(sample)) + } + + self.postMessage('connected') +} + +ws.onopen = proceed diff --git a/discojs/discojs-web/src/dataset/data_loader/text_loader.spec.ts b/discojs/discojs-web/src/dataset/data_loader/text_loader.spec.ts new file mode 100644 index 000000000..2576d854e --- /dev/null +++ b/discojs/discojs-web/src/dataset/data_loader/text_loader.spec.ts @@ -0,0 +1,210 @@ +/// +import { GlobalRegistrator } from '@happy-dom/global-registrator' +const oldConsole = console +GlobalRegistrator.register() +window.console = oldConsole + +import fs from 'fs' +import path from 'path' +import { describe, test, expect } from 'bun:test' +import { encode, decode } from 'gpt-tokenizer/esm/model/text-davinci-003' +import { tf, dataset, defaultTasks, Task, Disco } from '../..' +import { WebTextLoader } from '.' + +/** + * ================================================ + * Assumes you have followed the installation steps + * in disco/experiment (see README.md) + * ================================================ + */ + +const datasetsFolder = path.join( + '..', + '..', + 'experiment', + 'datasets', + 'wikitext-103' +) + +const trainFile = 'test' + +const source: dataset.TextSource = { + train: [path.join(datasetsFolder, `${trainFile}.tokens`)], + // validation: [path.join(datasetsFolder, 'validation.tokens')], +} + +const task = defaultTasks.wikitext.getTask() +const config: Required> = { + ...task.trainingInformation.modelConfig, + blockSize: 16, + batchSize: 4, + vocabSize: 50258, +} + +const BENCHMARK_ITERATIONS = 1000 +const BENCHMARK_BATCH_SIZES = [4, 16, 32] +const BENCHMARK_BLOCK_SIZES = [64, 128, 256, 512] + +const getDataset = async (config: Partial) => { + const loaded = await new WebTextLoader(task).loadAll(source, config) + const ds = loaded.train.dataset as dataset.TokenizedDataset + return ds +} + +// config: gpt.GPTConfig +const getIterator = async (config: Partial) => { + const ds = await getDataset(config) + const iter = await ds.iterator() + 
return { + next: async () => { + const { value } = (await iter.next()) as dataset.TokenizedIterResult + return { xs: value.xs, ys: value.ys } + }, + } +} + +const getIteratorArray = async (config: any) => { + const iter = await getIterator(config) + return { + next: async () => { + const { xs, ys } = await iter.next() + const x = await xs.array() + const y = await (ys.argMax(2) as tf.Tensor2D).array() // get indices of max values along last axis + return { x, y } + }, + } +} + +/** + * Reads the RAW dataset (not preprocessed) and tokenizes the equivalent of the first batch. + */ +const getRawTokenizedSample = async ( + sampleSize: number, + tokensLength: number +) => { + const wikiRaw = fs.createReadStream( + path.join( + /* @ts-ignore */ + import.meta.dir, + '..', + '..', + '..', + datasetsFolder, + trainFile + ), + { + encoding: 'utf8', + start: 0, + end: sampleSize * 1.5, // * 1.5 to make sure we have enough tokens + } + ) + const iter = wikiRaw.iterator() + const { value: chunk } = await iter.next() + const tokens = encode(chunk).slice(0, tokensLength) + return tokens +} + +const correctShapeTest = ( + xs: tf.Tensor2D, + ys: tf.Tensor3D, + config: Required> +) => { + expect(xs.shape).toEqual([config.batchSize, config.blockSize]) + expect(ys.shape).toEqual([ + config.batchSize, + config.blockSize, + config.vocabSize, + ]) +} + +describe('web text loader', () => { + test('loads a batched sample', async () => { + const iter = await getIterator(config) + const { xs, ys } = await iter.next() + + correctShapeTest(xs, ys, config) + + tf.dispose([xs, ys]) + }) + + test('x without [0] equals y without [-1]', async () => { + const TEST_SIZE = 10 + const iter = await getIteratorArray(config) + for (let i = 0; i < TEST_SIZE; i++) { + const { x, y } = await iter.next() + for (let j = 0; j < config.batchSize; j++) { + // console.log('x=', decode(x_arr[i]).trim()) + // console.log('y=', decode(y_arr[i]).trim()) + expect(x[j].slice(1)).toEqual(y[j].slice(0, -1)) + } + } + }) + 
+ test('dataset is tokenized properly', async () => { + const iter = await getIteratorArray(config) + const { x, y } = await iter.next() + + /** + * Flatten the batch by taking the first token in x and the rest in y, since y is x shifted by 1 + 1 token + * e.g. [a, b, c, d, e, f] -> x = [a, b, c, d, e] and y = [b, c, d, e, f] + * thus x[0] + y = [a, b, c, d, e, f] + **/ + + const sample: number[] = [] + + for (let i = 0; i < config.batchSize; i++) { + sample.push(x[i][0], ...y[i]) + } + const textLength = decode(sample).length + const tokens = await getRawTokenizedSample(textLength, sample.length) + + expect(sample.length).toBe(tokens.length) + expect(sample).toEqual(tokens) + }) + + test(`benchmark ${BENCHMARK_ITERATIONS} iterations for batch sizes: ${BENCHMARK_BATCH_SIZES} and block sizes: ${BENCHMARK_BLOCK_SIZES}`, async () => { + for (const batchSize of BENCHMARK_BATCH_SIZES) { + for (const blockSize of BENCHMARK_BLOCK_SIZES) { + const c = { + ...config, + batchSize, + blockSize, + } + const iter = await getIterator(c) + const benchmarkStart = Date.now() + for (let i = 0; i < BENCHMARK_ITERATIONS; i++) { + const { xs, ys } = await iter.next() + if (i === 0) correctShapeTest(xs, ys, c) + } + const benchmarkEnd = Date.now() + const ms = benchmarkEnd - benchmarkStart + console.log( + `[batchSize=${c.batchSize}, blockSize=${ + c.blockSize + }] Time per iteration: ${( + ms / BENCHMARK_ITERATIONS + ).toFixed(3)}ms` + ) + } + } + }, 256_000) + + test.skip(`one iteration on gpt with block size ${config.blockSize} and batch size ${config.batchSize}`, async () => { + const t: Task = { + ...task, + trainingInformation: { + ...task.trainingInformation, + maxIterations: 10, + }, + } + + const ds = await getDataset(config) + const data = await dataset.TextData.init(ds, t) + const url = new URL('', 'http://localhost:8000') + const d = new Disco(t, { url }) + // Stop training and disconnect from the remote server + const trainer = await (d as any).trainer + trainer.fit({ 
train: data }) + await d.close() + }) +}) diff --git a/discojs/discojs-web/src/dataset/data_loader/text_loader.ts b/discojs/discojs-web/src/dataset/data_loader/text_loader.ts index 09b813850..821630548 100644 --- a/discojs/discojs-web/src/dataset/data_loader/text_loader.ts +++ b/discojs/discojs-web/src/dataset/data_loader/text_loader.ts @@ -1,14 +1,136 @@ -import { tf } from '../..' -import { Dataset } from '../../core/dataset' -import { TextLoader } from '../../core/dataset/data_loader/text_loader' - -export class WebTextLoader extends TextLoader { - async loadDatasetFrom (source: File, config?: Record): Promise { - const file = new tf.data.FileDataSource(source) - if (config !== undefined) { - return new tf.data.CSVDataset(file, config) - } else { - return new tf.data.TextLineDataset(file) - } - } +import { v4 as randomUUID } from 'uuid' +import { dataset } from '../..' +import { Deferred } from './cache' + +type MessageData = { + value: { + type: 'Buffer' + data: number[] + } + done: boolean +} + +export class WebTextLoader extends dataset.loader.TextLoader { + static readonly CACHE_SIZE: number = 10 + + // ========================= WORKER + CACHE ========================= + // TODO:? 
make this faster than just having a ws instance in the loader + // getWorker = (file: string, config: dataset.TextConfig) => { + // const workerURL = new URL('worker.ts', import.meta.url).href + // const worker = new Worker(workerURL, { + // env: { + // FILE: file, + // CONFIG: JSON.stringify(config), + // CACHE_SIZE: WebTextLoader.CACHE_SIZE, + // }, + // } as WorkerOptions) + + // return new Promise((resolve) => { + // // waiting for a message from the worker to inform the loader + // // that the websocket connection is opened + // worker.onmessage = () => { + // resolve(worker) + // } + // }) + // } + + getWebsocket(file: string, config: dataset.TextConfig) { + const BROKER_URL = 'ws://localhost:3001/ws' + const url = new URL(BROKER_URL) + + const id = randomUUID() + const searchParams: dataset.WSSearchParams = { + id, + config: JSON.stringify(config), + file, + } + for (const [k, v] of Object.entries(searchParams)) + url.searchParams.append(k, v) + + const ws = new WebSocket(url) + + ws.onerror = (err) => { + console.error(err) + } + + return new Promise<{ ws: WebSocket; id: string }>((resolve) => { + ws.onopen = () => { + resolve({ ws, id }) + } + }) + } + + async load( + file: string, + config: dataset.TextConfig + ): Promise { + // TODO: /!\ implement a way to close websocket at the end of training + // onTrainEnd = () => ws.close() + + // ========================= WORKER + CACHE ========================= + // TODO:? 
make this faster than just having a ws instance in the loader + // const worker = await this.getWorker(file, config) + // const cache = await Cache.init( + // WebTextLoader.CACHE_SIZE, + // (pos, init) => { + // worker.postMessage(JSON.stringify({ pos, init })) + // }, + // (c) => { + // worker.onmessage = (payload: globalThis.MessageEvent) => { + // const sample = JSON.parse( + // payload.data as string + // ) as CacheData + // c.put(sample.pos, sample) + // } + // } + // ) + const { ws, id } = await this.getWebsocket(file, config) + const cache = new Deferred<{ value: number[]; done: boolean }>() + + ws.onmessage = (payload: globalThis.MessageEvent) => { + const sample = JSON.parse(payload.data as string) as MessageData + cache.resolve({ value: sample.value.data, done: sample.done }) + } + + const iterator = { + next: async () => { + ws.send(JSON.stringify({ id })) + const sample = await cache.promise + cache.reset() + return sample + }, + } + + const dataset = await this.getCoreDataset(config, iterator) + return dataset + } + + async loadAll( + source: dataset.TextSource, + config?: Partial | undefined + ): Promise { + const _config = this.resolveConfig(config) + + const loadFromSources = async (files: string[]) => { + const datasets = await Promise.all( + files.map((f) => this.load(f, _config)) ?? [] + ) + const ds = + datasets.length > 1 + ? 
datasets + .slice(1) + .reduce( + (acc, cur) => acc.concatenate(cur), + datasets[0] + ) + : datasets[0] + return await dataset.TextData.init(ds, this.task) + } + + return { + train: await loadFromSources(source.train), + validation: + source.validation && (await loadFromSources(source.validation)), + } + } } diff --git a/discojs/discojs-web/src/memory/memory.ts b/discojs/discojs-web/src/memory/memory.ts index cc146a843..e66f432a2 100644 --- a/discojs/discojs-web/src/memory/memory.ts +++ b/discojs/discojs-web/src/memory/memory.ts @@ -12,148 +12,168 @@ import path from 'path' import { tf, Memory, ModelType, Path, ModelInfo, ModelSource } from '..' export class IndexedDB extends Memory { - pathFor (source: ModelSource): Path { - if (typeof source === 'string') { - return source + pathFor(source: ModelSource): Path { + if (typeof source === 'string') { + return source + } + + if ( + source.type === undefined || + source.taskID === undefined || + source.name === undefined + ) { + throw new TypeError('source incomplete') + } + + const version = source.version ?? 0 + + return `indexeddb://${path.join( + source.type, + source.taskID, + source.name + )}@${version}` } - if (source.type === undefined || source.taskID === undefined || source.name === undefined) { - throw new TypeError('source incomplete') - } - - const version = source.version ?? 0 + infoFor(source: ModelSource): ModelInfo { + if (typeof source !== 'string') { + return source + } + const [stringType, taskID, fullName] = source.split('/').splice(2) - return `indexeddb://${path.join(source.type, source.taskID, source.name)}@${version}` - } + const type = + stringType === 'working' ? ModelType.WORKING : ModelType.SAVED - infoFor (source: ModelSource): ModelInfo { - if (typeof source !== 'string') { - return source + const [name, versionSuffix] = fullName.split('@') + const version = versionSuffix === undefined ? 
0 : Number(versionSuffix) + return { type, taskID, name, version } } - const [stringType, taskID, fullName] = source.split('/').splice(2) - - const type = stringType === 'working' ? ModelType.WORKING : ModelType.SAVED - const [name, versionSuffix] = fullName.split('@') - const version = versionSuffix === undefined ? 0 : Number(versionSuffix) - return { type, taskID, name, version } - } - - async getModelMetadata (source: ModelSource): Promise { - const models = await tf.io.listModels() - return models[this.pathFor(source)] - } + async getModelMetadata( + source: ModelSource + ): Promise { + const models = await tf.io.listModels() + return models[this.pathFor(source)] + } - async contains (source: ModelSource): Promise { - return await this.getModelMetadata(source) !== undefined - } + async contains(source: ModelSource): Promise { + return (await this.getModelMetadata(source)) !== undefined + } - async getModel (source: ModelSource): Promise { - return await tf.loadLayersModel(this.pathFor(source)) - } + async getModel(source: ModelSource): Promise { + return await tf.loadLayersModel(this.pathFor(source)) + } - async deleteModel (source: ModelSource): Promise { - await tf.io.removeModel(this.pathFor(source)) - } + async deleteModel(source: ModelSource): Promise { + await tf.io.removeModel(this.pathFor(source)) + } - async loadModel (source: ModelSource): Promise { - const src = this.infoFor(source) - if (src.type === ModelType.WORKING) { - // Model is already loaded - return + async loadModel(source: ModelSource): Promise { + const src = this.infoFor(source) + if (src.type === ModelType.WORKING) { + // Model is already loaded + return + } + await tf.io.copyModel( + this.pathFor(src), + this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }) + ) } - await tf.io.copyModel( - this.pathFor(src), - this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }) - ) - } - - /** - * Saves the working model to the source. 
- * @param source the destination - * @param model the model - */ - async updateWorkingModel (source: ModelSource, model: tf.LayersModel): Promise { - const src: ModelInfo = this.infoFor(source) - if (src.type !== undefined && src.type !== ModelType.WORKING) { - throw new Error('expected working model') + + /** + * Saves the working model to the source. + * @param source the destination + * @param model the model + */ + async updateWorkingModel( + source: ModelSource, + model: tf.LayersModel + ): Promise { + const src: ModelInfo = this.infoFor(source) + if (src.type !== undefined && src.type !== ModelType.WORKING) { + throw new Error('expected working model') + } + // Enforce version 0 to always keep a single working model at a time + await model.save( + this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }) + ) } - // Enforce version 0 to always keep a single working model at a time - await model.save(this.pathFor({ ...src, type: ModelType.WORKING, version: 0 })) - } - /** - * Creates a saved copy of the working model corresponding to the source. - * @param source the source - */ - async saveWorkingModel (source: ModelSource): Promise { - const src: ModelInfo = this.infoFor(source) - if (src.type !== undefined && src.type !== ModelType.WORKING) { - throw new Error('expected working model') + /** + * Creates a saved copy of the working model corresponding to the source. 
+ * @param source the source + */ + async saveWorkingModel(source: ModelSource): Promise { + const src: ModelInfo = this.infoFor(source) + if (src.type !== undefined && src.type !== ModelType.WORKING) { + throw new Error('expected working model') + } + const dst = this.pathFor( + await this.duplicateSource({ ...src, type: ModelType.SAVED }) + ) + await tf.io.copyModel( + this.pathFor({ ...src, type: ModelType.WORKING }), + dst + ) + return dst } - const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED })) - await tf.io.copyModel( - this.pathFor({ ...src, type: ModelType.WORKING }), - dst - ) - return dst - } - - async saveModel (source: ModelSource, model: tf.LayersModel): Promise { - const src: ModelInfo = this.infoFor(source) - if (src.type !== undefined && src.type !== ModelType.SAVED) { - throw new Error('expected saved model') + + async saveModel(source: ModelSource, model: tf.LayersModel): Promise { + const src: ModelInfo = this.infoFor(source) + if (src.type !== undefined && src.type !== ModelType.SAVED) { + throw new Error('expected saved model') + } + const dst = this.pathFor( + await this.duplicateSource({ ...src, type: ModelType.SAVED }) + ) + await model.save(dst) + return dst } - const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED })) - await model.save(dst) - return dst - } - - /** - * Downloads the model corresponding to the source. - * @param source the source - */ - async downloadModel (source: ModelSource): Promise { - const src: ModelInfo = this.infoFor(source) - await tf.io.copyModel( - this.pathFor(source), - `downloads://${src.taskID}_${src.name}` - ) - } - - async latestDuplicate (source: ModelSource): Promise { - if (typeof source !== 'string') { - source = this.pathFor({ ...source, version: 0 }) + + /** + * Downloads the model corresponding to the source. 
+ * @param source the source + */ + async downloadModel(source: ModelSource): Promise { + const src: ModelInfo = this.infoFor(source) + await tf.io.copyModel( + this.pathFor(source), + `downloads://${src.taskID}_${src.name}` + ) } - // perform a single memory read - const paths = Map(await tf.io.listModels()) + async latestDuplicate(source: ModelSource): Promise { + if (typeof source !== 'string') { + source = this.pathFor({ ...source, version: 0 }) + } - if (!paths.has(source)) { - return undefined - } + // perform a single memory read + const paths = Map(await tf.io.listModels()) + + if (!paths.has(source)) { + return undefined + } - const latest = Map(paths) - .keySeq() - .toList() - .map((p) => this.infoFor(p).version) - .max() + const latest = Map(paths) + .keySeq() + .toList() + .map((p) => this.infoFor(p).version) + .max() - if (latest === undefined) { - return 0 + if (latest === undefined) { + return 0 + } + + return latest } - return latest - } + async duplicateSource(source: ModelSource): Promise { + const latestDuplicate = await this.latestDuplicate(source) + source = this.infoFor(source) - async duplicateSource (source: ModelSource): Promise { - const latestDuplicate = await this.latestDuplicate(source) - source = this.infoFor(source) + if (latestDuplicate === undefined) { + return source + } - if (latestDuplicate === undefined) { - return source + return { ...source, version: latestDuplicate + 1 } } - - return { ...source, version: latestDuplicate + 1 } - } } diff --git a/docs/TASK.md b/docs/TASK.md index 16ed7ee36..4e5af2917 100644 --- a/docs/TASK.md +++ b/docs/TASK.md @@ -6,22 +6,22 @@ Disco.js currently allows learning of arbitrary machine learning tasks, where ta 2. New tasks defined via the [**task creation form**](https://epfml.github.io/disco/#/create), via the Disco web UI, without programming knowledge needed 3. 
New **custom tasks** - ## Bringing your ML model to Disco To use an existing model in Disco, we first need to convert the model to TensorFlowJS format, consisting of a TensorFlowJS model file in a JSON format for the neural network architecture, and an optional weight file in .bin format if you want to start from a particular initialization or a pretrained model. If your model comes from another framework than TensorflowJS, like Pytorch or Tensorflow/Keras, but you still want to bring it to DisCo, we indicate the appropriate procedure as follows. - ### Importing models or weights from PyTorch to TensorflowJS The simplest way to obtain a TensorflowJS model is to first obtain a Python Tensorflow/Keras model, stored as a .h5 file, and then convert it using TensorflowJS's converter tool, which transforms any Tensorflow/Keras model to TensorflowJS. One recommended way to obtain a Python Tensorflow/Keras model it to directly develop the model in Keras: most of PyTorch components have their equivalent counterpart in Tensorflow/Keras, and translating model architectures between these two frameworks can be done in a straightforward way. One caveat is that for more complex models, pretrained weights can currently not automatically be converted from the Python `.pth` format to the Keras `.h5` format. If you plan to retrain the model from scratch in Disco, this is no problem. On the other hand if you want to import pretrained Python model weights you currently have to first obtain corresponding Keras weights, from which you can then TF.js weights. 
Given your keras model file, to convert it to a TensorFlowJS model: + ```bash $ tensorflowjs_converter --input_format=keras my_model_name.h5 /tfjs_model ``` Side Note: If you already have a TensorFlow (Python) saved model ([LayersModel](https://www.tensorflow.org/js/guide/models_and_layers)), then the conversion to TensorFlowJS is straightforward with the following command: + ```bash $ tensorflowjs_converter --input_format=tf_saved_model my_tensorflow_saved_model /tmp/tfjs_model ``` @@ -31,24 +31,23 @@ Make sure to convert to TF.js [LayersModel](https://www.tensorflow.org/js/guide/ Following the `tensorflowjs_converter` command, you will recover two files : a .json describing your model architecture, and a collection of .bin files describing your model weights, which are ready to be uploaded on DisCo. We describe this procedure in the paragraphs below. Note that the following conversion is only possible in cases of models for which TensorFlowJS possesses the [corresponding modules](https://js.tensorflow.org/api/latest/). -*Side Note : There exist several libraries that try to perform automatic conversion between frameworks, which we do not recommend as most of the tools have compatibility issues for models containing components which differ strongly in implementation between the two frameworks.* - - - - +_Side Note : There exist several libraries that try to perform automatic conversion between frameworks, which we do not recommend as most of the tools have compatibility issues for models containing components which differ strongly in implementation between the two frameworks._ ## 1) Simple use case: Using the user interface directly for creating a new task + I am a user who wants to define my custom task and bring my model to Disco, without doing any programming. In this case, you use our existing supported data modalities and preprocessing (such as tabular, images, text etc). For this use case, an initial `.bin` weight file of your TF.js model is mandatory. 
- - Through the Disco user interface, click on the *create* button on "Add your own model to be trained in a DISCOllaborative" - - Fill in all the relevant information for your task and model - - Upload the .json + .bin model in the *Model Files* box. - Your task has been successfully instantiated. +- Through the Disco user interface, click on the _create_ button on "Add your own model to be trained in a DISCOllaborative" +- Fill in all the relevant information for your task and model +- Upload the .json + .bin model in the _Model Files_ box. + Your task has been successfully instantiated. ## 2) Procedure for adding a custom task + In order to add a completely new custom task to Disco.js using our own code (such as for data loading, preprocessing etc), we need to defined a `TaskProvider` which need to implement two methods: - * `getTask` which returns a `Task` as defined [here](../discojs/discojs-core/src/task/task.ts), the `Task` contains all the crucial information from training to the mode - * `getModel` which returns a `Promise` specifying a model architecture for the task + +- `getTask` which returns a `Task` as defined [here](../discojs/discojs-core/src/task/task.ts), the `Task` contains all the crucial information from training to the model +- `getModel` which returns a `Promise` specifying a model architecture for the task You can find examples of `TaskProvider` currently used in our Disco server in `discojs/discojs-core/src/default_tasks/`. These tasks are all loaded by our server by default. @@ -59,8 +58,9 @@ For the task creation of new custom tasks, if you can not go through the user in **I am a developper who wants to define my own custom task** If you want to add a new task to our production DISCO server you have two possibilities: - * using the user interface as described above (no coding required) - * exporting your own `TaskProvider` from `discojs/discojs-core/src/default_tasks/` and adding a new default task by contributing to the code. 
(describing the task in Typescript code) + +- using the user interface as described above (no coding required) +- exporting your own `TaskProvider` from `discojs/discojs-core/src/default_tasks/` and adding a new default task by contributing to the code. (describing the task in Typescript code) To export a new task in the code, make sure to export the `TaskProvider` in the `discojs/discojs-core/src/default_tasks/index.ts` file as follows: @@ -82,24 +82,24 @@ import { Disco, tf } from '@epfml/disco-server' // Define your own task provider (task definition + model) const customTask: TaskProvider = { getTask(): Task { - return { - // Your task definition - } + return { + // Your task definition + } }, - + async getModel(): Promise { - const model = tf.sequential() - // Configure your model architechture - return model - } - } + const model = tf.sequential() + // Configure your model architechture + return model + }, +} async function runServer() { - const disco = new Disco() - // Add your own custom task - await disco.addTask(customTask) - // Start the server - disco.serve() + const disco = new Disco() + // Add your own custom task + await disco.addTask(customTask) + // Start the server + disco.serve() } runServer() @@ -111,11 +111,9 @@ For your custom model, the JSON model architecture is necessary, but the .bin we For more detail about how to define a `Task` and a `tf.LayersModel` for your own `TaskProvider`, continue reading. - - ### Model -The interface let you load your model however you want, as long as you return a `tf.LayersModel` at the end. If you use a +The interface let you load your model however you want, as long as you return a `tf.LayersModel` at the end. If you use a pre-trained model, you can simply load and return said model in the function via `tf.loadLayersModel(modelPath)`. ```js @@ -125,56 +123,55 @@ async function getModel (_: string): Promise { // Add layers model.add(...) 
- + return model ``` -Alternatively we can also load a pre-existing model; if we only provide a `model.json` file, then only the architecture of the model will be +Alternatively we can also load a pre-existing model; if we only provide a `model.json` file, then only the architecture of the model will be loaded. If however in the same path we also include `weights.bin`, then pre-trained weights stored in these files will also be loaded to the model. ```js -async function getModel (modelPath: string): Promise { - return await tf.loadLayersModel(`file://${modelPath}`) +async function getModel(modelPath: string): Promise { + return await tf.loadLayersModel(`file://${modelPath}`) } ``` > Reminder that the tasks and models definition are used by the server. The server then exposes the initial models to the clients that want to train them locally. So the server need to be able to retrieve the model if it's stored in a remote location. > When the training begin, the client retrieves the **initial** model stored on the server. Then depending on the scheme the model **updates** (without training data) are: -> -> * Sent to the server for aggregation (**federated scheme**) -> * At some point the server will update its stored model to benefit future client trainings -> * Shared between peers for aggregation (no interaction with server) (**decentralized scheme**) -> * In this case, the server never have the opportunity to update the initial model as it's kept between peers. +> +> - Sent to the server for aggregation (**federated scheme**) +> - At some point the server will update its stored model to benefit future client trainings +> - Shared between peers for aggregation (no interaction with server) (**decentralized scheme**) +> - In this case, the server never have the opportunity to update the initial model as it's kept between peers. 
In summary here are the most common ways of loading a model: -* Loading the model from the web (example in [cifar10](../discojs/discojs-core/src/default_tasks/cifar10.ts)) -* Loading the model from the local filesystem (similar to the web with a file path from the server filesystem) -* Defining the architecture directly in the `TaskProvider` (example in [luscovid](../discojs/discojs-core/src/default_tasks/lus_covid.ts)) +- Loading the model from the web (example in [cifar10](../discojs/discojs-core/src/default_tasks/cifar10.ts)) +- Loading the model from the local filesystem (similar to the web with a file path from the server filesystem) +- Defining the architecture directly in the `TaskProvider` (example in [luscovid](../discojs/discojs-core/src/default_tasks/lus_covid.ts)) At runtime, the models are stored in `disco/server/models/`, and it is also in the server side that we let disco know where exactly they are saved. > If you are using a pre-existing model, and the data shape does not match the input of the model, then it is possible -to use preprocessing functions to resize the data (we also describe how to add custom preprocessing). +> to use preprocessing functions to resize the data (we also describe how to add custom preprocessing). ### Task -The `Task` class contains all the crucial information for training the model (batchSize, learningRate, ...) and also the +The `Task` class contains all the crucial information for training the model (batchSize, learningRate, ...) and also the scheme of distributed learning (federated or decentralized), along with other meta data about the model and data. -> In the appendix (end of this document) you find all possible [`TrainingInformation`](../discojs/discojs-core/src/task/training_information.ts) parameters with a short description. +> In the appendix (end of this document) you find all possible [`TrainingInformation`](../discojs/discojs-core/src/task/training_information.ts) parameters with a short description. 
As an example, the task class for `simple-face` can be found [here](../discojs/discojs-core/src/default_tasks/simple_face.ts), suppose our own task is a binary classification for age detection (similar to simple face), then we could write: - ```js import { ImagePreprocessing } from '../dataset/preprocessing' export const customTask: TaskProvider = { getTask (): Task { return { - taskID: 'my_new_task', + id: 'my_new_task', displayInformation: { taskTitle: 'My new task', summary: 'Can you detect if the person in a picture is a child or an adult?', @@ -205,14 +202,14 @@ export const customTask: TaskProvider = { } ``` -The `Task` interface has three fields: a mandatory `taskID` (of `string` type), an optional `displayInformation`, and an optional `trainingInformation`. The interfaces for the optional fields are [`DisplayInformation`](../discojs/discojs-core/src/task/display_information.ts) and [`TrainingInformation`](../discojs/discojs-core/src/task/training_information.ts). +The `Task` interface has three fields: a mandatory `id` (of `string` type), an optional `displayInformation`, and an optional `trainingInformation`. The interfaces for the optional fields are [`DisplayInformation`](../discojs/discojs-core/src/task/display_information.ts) and [`TrainingInformation`](../discojs/discojs-core/src/task/training_information.ts). ### Preprocessing In the Task object we can optionally choose to add preprocessing functions. Preprocessing is defined [here](../discojs/discojs-core/src/dataset/data/preprocessing.ts), and is currently only implemented for images (e.g. resize, normalize, ...). -Suppose we want our custom preprocessing that divides each pixel value by 2. In the [preprocessing](../discojs/discojs-core/src/dataset/data/preprocessing.ts) file, +Suppose we want our custom preprocessing that divides each pixel value by 2. 
In the [preprocessing](../discojs/discojs-core/src/dataset/data/preprocessing.ts) file, first we add the enum of our custom function: ```js @@ -225,14 +222,13 @@ export enum ImagePreprocessing { If your task requires a preprocessing function to be applied to the data before training, you can specifiy it in the `preprocessingFunctions` field of the `trainingInformation` parameter in the task object. In order to add custom preprocessing function, either extend the `Preprocessing` type and define your preprocessing functions in the [preprocessing](../discojs/discojs-core/src/dataset/data/preprocessing.ts) file. If the preprocessing function is challenging to implement in JS (e.g requires complex audio preprocessing for JS), we recommend implementing in some other language which supports the desired preprocessing (e.g. Python) and feed the preprocessed data to the task. - #### Rebuild Then we define our custom function ```js -function custom (image: tf.Tensor3D): tf.Tensor3D { -return image.div(tf.scalar(2)) +function custom(image: tf.Tensor3D): tf.Tensor3D { + return image.div(tf.scalar(2)) } ``` @@ -274,21 +270,20 @@ export const task: Task = { > Note that you need to rebuild discojs every time you make changes to it (`cd discojs; rm -rf dist/; npm run build`). -## Summary +## Summary -- In ```disco/discojs/discojs-core/src/default_tasks/``` define your new custom task by implementing the `TaskProvider` interface. You will need to have your model in the .json + .bin format. 
- - In ```disco/discojs/discojs-core/src/default_tasks/index.ts``` export your newly defined task
 - Run the ```./build.sh``` script from ```discojs/discojs-core```
 - Reinstall cleanly the server by running ```npm ci``` from ```disco/server```
 - Reinstall cleanly the client by running ```npm ci``` from ```disco/web-client```
 - Instantiate a Disco server by running ```npm run dev``` from ```disco/server```
 - Instanciate a Disco client by running ```npm run dev``` from ```disco/web-client```
+- In `disco/discojs/discojs-core/src/default_tasks/` define your new custom task by implementing the `TaskProvider` interface. You will need to have your model in the .json + .bin format.
+- In `disco/discojs/discojs-core/src/default_tasks/index.ts` export your newly defined task
+- Run the `./build.sh` script from `discojs/discojs-core`
+- Reinstall cleanly the server by running `npm ci` from `disco/server`
+- Reinstall cleanly the client by running `npm ci` from `disco/web-client`
+- Instantiate a Disco server by running `npm run dev` from `disco/server`
+- Instantiate a Disco client by running `npm run dev` from `disco/web-client`

Your task has been successfully uploaded.

**Or** just use the NPM `disco-server` package and add your own custom `TaskProvider` directly to the server.
- ## Appendix The [`TrainingInformation`](../discojs/src/task/training_information.ts) of a task contains the following customizable parameters diff --git a/server/src/get_server.ts b/server/src/get_server.ts index 0e9ce9feb..19ee3a0f9 100644 --- a/server/src/get_server.ts +++ b/server/src/get_server.ts @@ -9,63 +9,68 @@ import { tf, Task, TaskProvider } from '@epfml/discojs-node' import * as http from 'http' export class Disco { - private readonly _app: express.Application - private readonly tasksAndModels: TasksAndModels + private readonly _app: express.Application + private readonly tasksAndModels: TasksAndModels - constructor () { - this._app = express() - this.tasksAndModels = new TasksAndModels() - } + constructor() { + this._app = express() + this.tasksAndModels = new TasksAndModels() + } - public get server (): express.Application { - return this._app - } + public get server(): express.Application { + return this._app + } - // Load tasks provided by default with disco server - async addDefaultTasks (): Promise { - await this.tasksAndModels.loadDefaultTasks() - } + // Load tasks provided by default with disco server + async addDefaultTasks(): Promise { + await this.tasksAndModels.loadDefaultTasks() + } - // If a model is not provided, its url must be provided in the task object - async addTask (task: Task | TaskProvider, model?: tf.LayersModel | URL): Promise { - await this.tasksAndModels.addTaskAndModel(task, model) - } + // If a model is not provided, its url must be provided in the task object + async addTask( + task: Task | TaskProvider, + model?: tf.LayersModel | URL + ): Promise { + await this.tasksAndModels.addTaskAndModel(task, model) + } - serve (port?: number): http.Server { - const wsApplier = expressWS(this.server, undefined, { leaveRouterUntouched: true }) - const app = wsApplier.app + serve(port?: number): http.Server { + const wsApplier = expressWS(this.server, undefined, { + leaveRouterUntouched: true, + }) + const app = wsApplier.app - 
app.enable('trust proxy') - app.use(cors()) - app.use(express.json({ limit: '50mb' })) - app.use(express.urlencoded({ limit: '50mb', extended: false })) + app.enable('trust proxy') + app.use(cors()) + app.use(express.json({ limit: '50mb' })) + app.use(express.urlencoded({ limit: '50mb', extended: false })) - const baseRouter = new Router(wsApplier, this.tasksAndModels, CONFIG) - app.use('/', baseRouter.router) + const baseRouter = new Router(wsApplier, this.tasksAndModels, CONFIG) + app.use('/', baseRouter.router) - const server = app.listen(port ?? CONFIG.serverPort, () => { - console.log(`Disco Server listening on ${CONFIG.serverUrl.href}`) - }) + const server = app.listen(port ?? CONFIG.serverPort, () => { + console.log(`Disco Server listening on ${CONFIG.serverUrl.href}`) + }) - console.info('Disco Server initially loaded the tasks below\n') - console.table( - Array.from(this.tasksAndModels.tasksAndModels).map(t => { - return { - ID: t[0].taskID, - Title: t[0].displayInformation.taskTitle, - 'Data Type': t[0].trainingInformation.dataType, - Scheme: t[0].trainingInformation.scheme - } - }) - ) - console.log() + console.info('Disco Server initially loaded the tasks below\n') + console.table( + Array.from(this.tasksAndModels.tasksAndModels).map((t) => { + return { + ID: t[0].id, + Title: t[0].displayInformation.taskTitle, + 'Data Type': t[0].trainingInformation.dataType, + Scheme: t[0].trainingInformation.scheme, + } + }) + ) + console.log() - return server - } + return server + } } -export async function runDefaultServer (port?: number): Promise { - const disco = new Disco() - await disco.addDefaultTasks() - return disco.serve(port) +export async function runDefaultServer(port?: number): Promise { + const disco = new Disco() + await disco.addDefaultTasks() + return disco.serve(port) } diff --git a/server/src/router/decentralized/server.ts b/server/src/router/decentralized/server.ts index d90a2b4d7..1fedac63b 100644 --- a/server/src/router/decentralized/server.ts 
+++ b/server/src/router/decentralized/server.ts @@ -15,131 +15,152 @@ import AssignNodeID = client.messages.AssignNodeID import MessageTypes = client.messages.type export class Decentralized extends Server { - /** - * Map associating task ids to their sets of nodes who have contributed. - */ - private readyNodes: Map> = Map() - /** - * Map associating node ids to their open WebSocket connections. - */ - private connections: Map = Map() - - protected get description (): string { - return 'Disco Decentralized Server' - } - - protected buildRoute (task: Task): string { - return `/${task.taskID}` - } - - public isValidUrl (url: string | undefined): boolean { - const splittedUrl = url?.split('/') - - return ( - splittedUrl !== undefined && - splittedUrl.length === 3 && - splittedUrl[0] === '' && - this.isValidTask(splittedUrl[1]) && - this.isValidWebSocket(splittedUrl[2]) - ) - } - - protected initTask (task: Task, model: tf.LayersModel): void {} - - protected handle ( - task: Task, - ws: import('ws'), - model: tf.LayersModel, - req: express.Request< - ParamsDictionary, - any, - any, - ParsedQs, - Record - > - ): void { - // TODO @s314cy: add to task definition, to be used as threshold in aggregator - const minimumReadyPeers = task.trainingInformation?.minimumReadyPeers ?? 3 - - // Peer id of the message sender - let peerId = randomUUID() - while (this.connections.has(peerId)) { - peerId = randomUUID() + /** + * Map associating task ids to their sets of nodes who have contributed. + */ + private readyNodes: Map> = Map() + /** + * Map associating node ids to their open WebSocket connections. 
+ */ + private connections: Map = Map() + + protected get description(): string { + return 'Disco Decentralized Server' } - // How the server responds to messages - ws.on('message', (data: Buffer) => { - try { - const msg: unknown = msgpack.decode(data) - if (!messages.isMessageToServer(msg)) { - console.warn('invalid message received:', msg) - return - } + protected buildRoute(task: Task): string { + return `/${task.id}` + } - switch (msg.type) { - // A new peer joins the network - case MessageTypes.ClientConnected: { - this.connections = this.connections.set(peerId, ws) - const msg: AssignNodeID = { - type: MessageTypes.AssignNodeID, - id: peerId - } - console.info('Peer', peerId, 'joined', task.taskID) + public isValidUrl(url: string | undefined): boolean { + const splittedUrl = url?.split('/') - // Add the new task and its set of nodes - if (!this.readyNodes.has(task.taskID)) { - this.readyNodes = this.readyNodes.set(task.taskID, Set()) - } + return ( + splittedUrl !== undefined && + splittedUrl.length === 3 && + splittedUrl[0] === '' && + this.isValidTask(splittedUrl[1]) && + this.isValidWebSocket(splittedUrl[2]) + ) + } - ws.send(msgpack.encode(msg), { binary: true }) - break - } + protected initTask(task: Task, model: tf.LayersModel): void {} + + protected handle( + task: Task, + ws: import('ws'), + model: tf.LayersModel, + req: express.Request< + ParamsDictionary, + any, + any, + ParsedQs, + Record + > + ): void { + // TODO @s314cy: add to task definition, to be used as threshold in aggregator + const minimumReadyPeers = + task.trainingInformation?.minimumReadyPeers ?? 
3 + + // Peer id of the message sender + let peerId = randomUUID() + while (this.connections.has(peerId)) { + peerId = randomUUID() + } - // Forwards a peer's message to another destination peer - case MessageTypes.SignalForPeer: { - const forward: messages.SignalForPeer = { - type: MessageTypes.SignalForPeer, - peer: peerId, - signal: msg.signal - } - this.connections.get(msg.peer)?.send(msgpack.encode(forward)) - break - } - case MessageTypes.PeerIsReady: { - const peers = this.readyNodes.get(task.taskID)?.add(peerId) - if (peers === undefined) { - throw new Error(`task ${task.taskID} doesn't exist in ready buffer`) - } - this.readyNodes = this.readyNodes.set(task.taskID, peers) - - if (peers.size >= minimumReadyPeers) { - this.readyNodes = this.readyNodes.set(task.taskID, Set()) - - peers - .map((id) => { - const readyPeerIDs: messages.PeersForRound = { - type: MessageTypes.PeersForRound, - peers: peers.delete(id).toArray() - } - const encoded = msgpack.encode(readyPeerIDs) - return [id, encoded] as [client.NodeID, Buffer] - }) - .map(([id, encoded]) => { - const conn = this.connections.get(id) - if (conn === undefined) { - throw new Error(`peer ${id} marked as ready but not connection to it`) - } - return [conn, encoded] as [WebSocket, Buffer] - }).forEach(([conn, encoded]) => - conn.send(encoded) - ) + // How the server responds to messages + ws.on('message', (data: Buffer) => { + try { + const msg: unknown = msgpack.decode(data) + if (!messages.isMessageToServer(msg)) { + console.warn('invalid message received:', msg) + return + } + + switch (msg.type) { + // A new peer joins the network + case MessageTypes.ClientConnected: { + this.connections = this.connections.set(peerId, ws) + const msg: AssignNodeID = { + type: MessageTypes.AssignNodeID, + id: peerId, + } + console.info('Peer', peerId, 'joined', task.id) + + // Add the new task and its set of nodes + if (!this.readyNodes.has(task.id)) { + this.readyNodes = this.readyNodes.set( + task.id, + Set() + ) + } 
+ + ws.send(msgpack.encode(msg), { binary: true }) + break + } + + // Forwards a peer's message to another destination peer + case MessageTypes.SignalForPeer: { + const forward: messages.SignalForPeer = { + type: MessageTypes.SignalForPeer, + peer: peerId, + signal: msg.signal, + } + this.connections + .get(msg.peer) + ?.send(msgpack.encode(forward)) + break + } + case MessageTypes.PeerIsReady: { + const peers = this.readyNodes.get(task.id)?.add(peerId) + if (peers === undefined) { + throw new Error( + `task ${task.id} doesn't exist in ready buffer` + ) + } + this.readyNodes = this.readyNodes.set(task.id, peers) + + if (peers.size >= minimumReadyPeers) { + this.readyNodes = this.readyNodes.set( + task.id, + Set() + ) + + peers + .map((id) => { + const readyPeerIDs: messages.PeersForRound = + { + type: MessageTypes.PeersForRound, + peers: peers.delete(id).toArray(), + } + const encoded = msgpack.encode(readyPeerIDs) + return [id, encoded] as [ + client.NodeID, + Buffer + ] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error( + `peer ${id} marked as ready but not connection to it` + ) + } + return [conn, encoded] as [ + WebSocket, + Buffer + ] + }) + .forEach(([conn, encoded]) => + conn.send(encoded) + ) + } + break + } + } + } catch (e) { + console.error('when processing WebSocket message:', e) } - break - } - } - } catch (e) { - console.error('when processing WebSocket message:', e) - } - }) - } + }) + } } diff --git a/server/src/router/federated/server.ts b/server/src/router/federated/server.ts index b8a1fc77f..4cca57ab6 100644 --- a/server/src/router/federated/server.ts +++ b/server/src/router/federated/server.ts @@ -6,16 +6,16 @@ import { List, Map } from 'immutable' import msgpack from 'msgpack-lite' import { - client, - tf, - serialization, - AsyncInformant, - Task, - TaskID, - aggregator as aggregators, - WeightsContainer, - MetadataKey, - MetadataValue + client, + tf, + serialization, + 
AsyncInformant, + Task, + TaskID, + aggregator as aggregators, + WeightsContainer, + MetadataKey, + MetadataValue, } from '@epfml/discojs-node' import { Server } from '../server' @@ -34,265 +34,326 @@ import MessageTypes = client.messages.type * - the timestamp at which the request was made */ interface Log { - timestamp: Date - task: TaskID - round: number - nodeId: client.NodeID - type: MessageTypes + timestamp: Date + task: TaskID + round: number + nodeId: client.NodeID + type: MessageTypes } export class Federated extends Server { - /** - * Aggregators for each hosted task. - */ - private aggregators = Map() - /** - * Promises containing the current round's results. To be awaited on when providing clients - * with the most recent result. - */ - private results = Map>() - /** - * Training informants for each hosted task. - */ - private informants = Map>() - /** - * Contains metadata used for training by clients for a given task and round. - * Stored by task id, round number, node id and metadata key. - */ - private metadataMap = Map>>>() - // TODO use real log system - /** - * Logs of successful requests made to the server. - */ - private logs = List() - - private rounds = Map() - - protected get description (): string { - return 'Disco Federated Server' - } - - protected buildRoute (task: Task): string { - return `/${task.taskID}` - } - - public isValidUrl (url: string | undefined): boolean { - const splittedUrl = url?.split('/') - - return ( - splittedUrl !== undefined && - splittedUrl.length === 3 && - splittedUrl[0] === '' && - this.isValidTask(splittedUrl[1]) && - this.isValidWebSocket(splittedUrl[2]) - ) - } - - /** - * Loop storing aggregation results, every time an aggregation result promise resolves. - * This happens once per round. - * @param aggregator The aggregation handler - */ - private async storeAggregationResult (aggregator: aggregators.Aggregator): Promise { - // Renew the aggregation result promise. 
- const result = aggregator.receiveResult() - // Store the result promise somewhere for the server to fetch from, so that it can await - // the result on client request. - this.results = this.results.set(aggregator.task.taskID, result) - await result - void this.storeAggregationResult(aggregator) - } - - protected initTask (task: Task, model: tf.LayersModel): void { - const aggregator = new aggregators.MeanAggregator(task, model) - - this.aggregators = this.aggregators.set(task.taskID, aggregator) - this.informants = this.informants.set(task.taskID, new AsyncInformant(aggregator)) - this.rounds = this.rounds.set(task.taskID, 0) - - void this.storeAggregationResult(aggregator) - } - - protected handle ( - task: Task, - ws: WebSocket, - model: tf.LayersModel, - req: express.Request - ): void { - const taskAggregator = this.aggregators.get(task.taskID) - if (taskAggregator === undefined) { - throw new Error('connecting to a non-existing task') + /** + * Aggregators for each hosted task. + */ + private aggregators = Map() + /** + * Promises containing the current round's results. To be awaited on when providing clients + * with the most recent result. + */ + private results = Map>() + /** + * Training informants for each hosted task. + */ + private informants = Map>() + /** + * Contains metadata used for training by clients for a given task and round. + * Stored by task id, round number, node id and metadata key. + */ + private metadataMap = Map< + TaskID, + Map>> + >() + // TODO use real log system + /** + * Logs of successful requests made to the server. 
+ */ + private logs = List() + + private rounds = Map() + + protected get description(): string { + return 'Disco Federated Server' } - // Client id of the message sender - let clientId = randomUUID() - while (!taskAggregator.registerNode(clientId)) { - clientId = randomUUID() - } - - ws.on('message', (data: Buffer) => { - const msg = msgpack.decode(data) - - if (msg.type === MessageTypes.ClientConnected) { - let aggregator = this.aggregators.get(task.taskID) - if (aggregator === undefined) { - aggregator = new aggregators.MeanAggregator(task) - this.aggregators = this.aggregators.set(task.taskID, aggregator) - } - console.info('client', clientId, 'joined', task.taskID) - this.logsAppend(task.taskID, clientId, MessageTypes.ClientConnected, 0) - - const msg: AssignNodeID = { - type: MessageTypes.AssignNodeID, - id: clientId - } - ws.send(msgpack.encode(msg)) - } else if (msg.type === MessageTypes.SendPayload) { - const { payload, round } = msg + protected buildRoute(task: Task): string { + return `/${task.id}` + } - const aggregator = this.aggregators.get(task.taskID) + public isValidUrl(url: string | undefined): boolean { + const splittedUrl = url?.split('/') - this.logsAppend( - task.taskID, - clientId, - MessageTypes.SendPayload, - msg.round + return ( + splittedUrl !== undefined && + splittedUrl.length === 3 && + splittedUrl[0] === '' && + this.isValidTask(splittedUrl[1]) && + this.isValidWebSocket(splittedUrl[2]) ) + } - if (!( - Array.isArray(payload) && - payload.every((e) => typeof e === 'number') - )) { - throw new Error('received invalid weights format') - } - - const serialized = serialization.weights.decode(payload) + /** + * Loop storing aggregation results, every time an aggregation result promise resolves. + * This happens once per round. + * @param aggregator The aggregation handler + */ + private async storeAggregationResult( + aggregator: aggregators.Aggregator + ): Promise { + // Renew the aggregation result promise. 
+ const result = aggregator.receiveResult() + // Store the result promise somewhere for the server to fetch from, so that it can await + // the result on client request. + this.results = this.results.set(aggregator.task.id, result) + await result + void this.storeAggregationResult(aggregator) + } - if (aggregator === undefined) { - throw new Error(`received weights for unknown task: ${task.taskID}`) - } + protected initTask(task: Task, model: tf.LayersModel): void { + const aggregator = new aggregators.MeanAggregator(task, model) - // TODO @s314cy: add communication rounds to federated learning - if (!aggregator.add(clientId, serialized, round, 0)) { - console.info('Dropped contribution from client', clientId, 'for round', round) - } - } else if (msg.type === MessageTypes.ReceiveServerStatistics) { - const statistics = this.informants - .get(task.taskID) - ?.getAllStatistics() - - const msg: messages.ReceiveServerStatistics = { - type: MessageTypes.ReceiveServerStatistics, - statistics: statistics ?? 
{} - } - - ws.send(msgpack.encode(msg)) - } else if (msg.type === MessageTypes.ReceiveServerPayload) { - const aggregator = this.aggregators.get(task.taskID) - if (aggregator === undefined) { - throw new Error(`requesting round of unknown task: ${task.taskID}`) - } + this.aggregators = this.aggregators.set(task.id, aggregator) + this.informants = this.informants.set( + task.id, + new AsyncInformant(aggregator) + ) + this.rounds = this.rounds.set(task.id, 0) - this.logsAppend(task.taskID, clientId, MessageTypes.ReceiveServerPayload, 0) + void this.storeAggregationResult(aggregator) + } - if (model === undefined) { - throw new Error('aggregator model was not set') + protected handle( + task: Task, + ws: WebSocket, + model: tf.LayersModel, + req: express.Request + ): void { + const taskAggregator = this.aggregators.get(task.id) + if (taskAggregator === undefined) { + throw new Error('connecting to a non-existing task') } - - const promisedResult = this.results.get(task.taskID) - if (promisedResult === undefined) { - throw new Error(`result promise was not set for task ${task.taskID}`) + // Client id of the message sender + let clientId = randomUUID() + while (!taskAggregator.registerNode(clientId)) { + clientId = randomUUID() } - // Wait for aggregation result with timeout, giving the network a time window - // to contribute to the model sent to the requesting client. 
- void Promise.race([promisedResult, client.utils.timeout()]) - .then((result) => - [result, aggregator.round - 1] as [WeightsContainer, number]) - .then(async ([result, round]) => - [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) - .then(([serialized, round]) => { - const msg: messages.ReceiveServerPayload = { - type: MessageTypes.ReceiveServerPayload, - round, - payload: serialized + ws.on('message', (data: Buffer) => { + const msg = msgpack.decode(data) + + if (msg.type === MessageTypes.ClientConnected) { + let aggregator = this.aggregators.get(task.id) + if (aggregator === undefined) { + aggregator = new aggregators.MeanAggregator(task) + this.aggregators = this.aggregators.set(task.id, aggregator) + } + console.info('client', clientId, 'joined', task.id) + + this.logsAppend( + task.id, + clientId, + MessageTypes.ClientConnected, + 0 + ) + + const msg: AssignNodeID = { + type: MessageTypes.AssignNodeID, + id: clientId, + } + ws.send(msgpack.encode(msg)) + } else if (msg.type === MessageTypes.SendPayload) { + const { payload, round } = msg + + const aggregator = this.aggregators.get(task.id) + + this.logsAppend( + task.id, + clientId, + MessageTypes.SendPayload, + msg.round + ) + + if ( + !( + Array.isArray(payload) && + payload.every((e) => typeof e === 'number') + ) + ) { + throw new Error('received invalid weights format') + } + + const serialized = serialization.weights.decode(payload) + + if (aggregator === undefined) { + throw new Error( + `received weights for unknown task: ${task.id}` + ) + } + + // TODO @s314cy: add communication rounds to federated learning + if (!aggregator.add(clientId, serialized, round, 0)) { + console.info( + 'Dropped contribution from client', + clientId, + 'for round', + round + ) + } + } else if (msg.type === MessageTypes.ReceiveServerStatistics) { + const statistics = this.informants + .get(task.id) + ?.getAllStatistics() + + const msg: messages.ReceiveServerStatistics = { + 
type: MessageTypes.ReceiveServerStatistics, + statistics: statistics ?? {}, + } + + ws.send(msgpack.encode(msg)) + } else if (msg.type === MessageTypes.ReceiveServerPayload) { + const aggregator = this.aggregators.get(task.id) + if (aggregator === undefined) { + throw new Error( + `requesting round of unknown task: ${task.id}` + ) + } + + this.logsAppend( + task.id, + clientId, + MessageTypes.ReceiveServerPayload, + 0 + ) + + if (model === undefined) { + throw new Error('aggregator model was not set') + } + + const promisedResult = this.results.get(task.id) + if (promisedResult === undefined) { + throw new Error( + `result promise was not set for task ${task.id}` + ) + } + + // Wait for aggregation result with timeout, giving the network a time window + // to contribute to the model sent to the requesting client. + void Promise.race([promisedResult, client.utils.timeout()]) + .then( + (result) => + [result, aggregator.round - 1] as [ + WeightsContainer, + number + ] + ) + .then( + async ([result, round]) => + [ + await serialization.weights.encode(result), + round, + ] as [serialization.weights.Encoded, number] + ) + .then(([serialized, round]) => { + const msg: messages.ReceiveServerPayload = { + type: MessageTypes.ReceiveServerPayload, + round, + payload: serialized, + } + ws.send(msgpack.encode(msg)) + }) + .catch(console.error) + } else if (msg.type === MessageTypes.SendMetadata) { + const { round, key, value } = msg + + this.logsAppend( + task.id, + clientId, + MessageTypes.SendMetadata, + round + ) + + if (this.metadataMap.hasIn([task.id, round, clientId, key])) { + throw new Error('metadata already set') + } + this.metadataMap = this.metadataMap.setIn( + [task, round, clientId, key], + value + ) + } else if (msg.type === MessageTypes.ReceiveServerMetadata) { + const key = msg.metadataId + const round = Number.parseInt(msg.round, 0) + + const taskMetadata = this.metadataMap.get(task.id) + + if ( + !Number.isNaN(round) && + round >= 0 && + taskMetadata !== 
undefined + ) { + // Find the most recent entry round-wise for the given task (upper bounded + // by the given round). Allows for sporadic entries in the metadata map. + const latestRound = taskMetadata.keySeq().max() ?? round + + // Fetch the required metadata from the general metadata structure stored + // server-side and construct the queried metadata's map accordingly. This + // essentially creates a "ID -> metadata" single-layer map. + const queriedMetadataMap = Map( + taskMetadata + .get( + latestRound, + Map>() + ) + .filter((entries) => entries.has(key)) + .mapEntries(([id, entries]) => [ + id, + entries.get(key), + ]) + ) + + this.logsAppend( + task.id, + clientId, + MessageTypes.ReceiveServerMetadata, + round + ) + + const msg: messages.ReceiveServerMetadata = { + type: MessageTypes.ReceiveServerMetadata, + taskId: task.id, + nodeId: clientId, + key, + round: round, + metadataMap: Array.from(queriedMetadataMap), + } + + ws.send(msgpack.encode(msg)) + } } - ws.send(msgpack.encode(msg)) - }) - .catch(console.error) - } else if (msg.type === MessageTypes.SendMetadata) { - const { round, key, value } = msg - - this.logsAppend(task.taskID, clientId, MessageTypes.SendMetadata, round) + }) + } - if (this.metadataMap.hasIn([task.taskID, round, clientId, key])) { - throw new Error('metadata already set') - } - this.metadataMap = this.metadataMap.setIn( - [task, round, clientId, key], - value - ) - } else if (msg.type === MessageTypes.ReceiveServerMetadata) { - const key = msg.metadataId - const round = Number.parseInt(msg.round, 0) - - const taskMetadata = this.metadataMap.get(task.taskID) - - if (!Number.isNaN(round) && round >= 0 && taskMetadata !== undefined) { - // Find the most recent entry round-wise for the given task (upper bounded - // by the given round). Allows for sporadic entries in the metadata map. - const latestRound = taskMetadata.keySeq().max() ?? 
round - - // Fetch the required metadata from the general metadata structure stored - // server-side and construct the queried metadata's map accordingly. This - // essentially creates a "ID -> metadata" single-layer map. - const queriedMetadataMap = Map( - taskMetadata - .get(latestRound, Map>()) - .filter((entries) => entries.has(key)) - .mapEntries(([id, entries]) => [id, entries.get(key)]) - ) - - this.logsAppend(task.taskID, clientId, MessageTypes.ReceiveServerMetadata, round) - - const msg: messages.ReceiveServerMetadata = { - type: MessageTypes.ReceiveServerMetadata, - taskId: task.taskID, - nodeId: clientId, - key, - round: round, - metadataMap: Array.from(queriedMetadataMap) - } - - ws.send(msgpack.encode(msg)) + /** + * Appends a request to the logs. + * @param taskId The task id for which the request was made + * @param nodeId The node id who made the request + * @param type The request type + * @param round The round for which the request was made + */ + private logsAppend( + taskId: TaskID, + nodeId: client.NodeID, + type: MessageTypes, + round: number | undefined = undefined + ): void { + if (round === undefined) { + return } - } - }) - } - - /** - * Appends a request to the logs. 
- * @param taskId The task id for which the request was made - * @param nodeId The node id who made the request - * @param type The request type - * @param round The round for which the request was made - */ - private logsAppend ( - taskId: TaskID, - nodeId: client.NodeID, - type: MessageTypes, - round: number | undefined = undefined - ): void { - if (round === undefined) { - return - } - this.logs = this.logs.push({ - timestamp: new Date(), - task: taskId, - round, - nodeId, - type - }) - } + this.logs = this.logs.push({ + timestamp: new Date(), + task: taskId, + round, + nodeId, + type, + }) + } } diff --git a/server/src/router/server.ts b/server/src/router/server.ts index 0fae93aa9..ab5ee0aac 100644 --- a/server/src/router/server.ts +++ b/server/src/router/server.ts @@ -7,72 +7,73 @@ import { tf, Task } from '@epfml/discojs-node' import { TasksAndModels } from '../tasks' export abstract class Server { - private readonly ownRouter: expressWS.Router + private readonly ownRouter: expressWS.Router - private readonly tasks: string[] = new Array() - private readonly UUIDRegexExp = /^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}$/gi + private readonly tasks: string[] = new Array() + private readonly UUIDRegexExp = + /^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}$/gi - constructor (wsApplier: expressWS.Instance, tasksAndModels: TasksAndModels) { - this.ownRouter = express.Router() - wsApplier.applyTo(this.ownRouter) + constructor(wsApplier: expressWS.Instance, tasksAndModels: TasksAndModels) { + this.ownRouter = express.Router() + wsApplier.applyTo(this.ownRouter) - this.ownRouter.get('/', (_, res) => res.send(this.description + '\n')) + this.ownRouter.get('/', (_, res) => res.send(this.description + '\n')) - // delay listener because this (object) isn't fully constructed yet. 
The lambda function inside process.nextTick is executed after the current operation on the JS stack runs to completion and before the event loop is allowed to continue. - /* this.onNewTask is registered as a listener to tasksAndModels, which has 2 consequences: + // delay listener because this (object) isn't fully constructed yet. The lambda function inside process.nextTick is executed after the current operation on the JS stack runs to completion and before the event loop is allowed to continue. + /* this.onNewTask is registered as a listener to tasksAndModels, which has 2 consequences: - this.onNewTask is executed on all the default tasks (which are already loaded in tasksAndModels) - Every time a new task and model are added to tasksAndModels, this.onNewTask is executed on them. For every task and model, this.onNewTask creates a path /taskID and routes it to this.handle. */ - process.nextTick(() => - tasksAndModels.addListener('taskAndModel', (t, m) => - this.onNewTask(t, m) - ) - ) - } - - public get router (): express.Router { - return this.ownRouter - } - - private onNewTask (task: Task, model: tf.LayersModel): void { - this.tasks.push(task.taskID) - this.initTask(task, model) - - this.ownRouter.ws(this.buildRoute(task), (ws, req) => { - if (this.isValidUrl(req.url)) { - this.handle(task, ws, model, req) - } else { - ws.terminate() - ws.close() - } - }) - } - - protected isValidTask (taskId: string): boolean { - return this.tasks.filter(e => e === taskId).length === 1 - } - - protected isValidClientId (clientId: string): boolean { - return new RegExp(this.UUIDRegexExp).test(clientId) - } - - protected isValidWebSocket (urlEnd: string): boolean { - return urlEnd === '.websocket' - } - - public abstract isValidUrl (url?: string): boolean - - protected abstract get description (): string - - protected abstract buildRoute (task: Task): string - - protected abstract initTask (task: Task, model: tf.LayersModel): void - - protected abstract handle ( - task: Task, - 
ws: WebSocket, - model: tf.LayersModel, - req: express.Request, - ): void + process.nextTick(() => + tasksAndModels.addListener('taskAndModel', (t, m) => + this.onNewTask(t, m) + ) + ) + } + + public get router(): express.Router { + return this.ownRouter + } + + private onNewTask(task: Task, model: tf.LayersModel): void { + this.tasks.push(task.id) + this.initTask(task, model) + + this.ownRouter.ws(this.buildRoute(task), (ws, req) => { + if (this.isValidUrl(req.url)) { + this.handle(task, ws, model, req) + } else { + ws.terminate() + ws.close() + } + }) + } + + protected isValidTask(taskId: string): boolean { + return this.tasks.filter((e) => e === taskId).length === 1 + } + + protected isValidClientId(clientId: string): boolean { + return new RegExp(this.UUIDRegexExp).test(clientId) + } + + protected isValidWebSocket(urlEnd: string): boolean { + return urlEnd === '.websocket' + } + + public abstract isValidUrl(url?: string): boolean + + protected abstract get description(): string + + protected abstract buildRoute(task: Task): string + + protected abstract initTask(task: Task, model: tf.LayersModel): void + + protected abstract handle( + task: Task, + ws: WebSocket, + model: tf.LayersModel, + req: express.Request + ): void } diff --git a/server/src/router/tasks.ts b/server/src/router/tasks.ts index 51b11b8c5..e8c9eb0c8 100644 --- a/server/src/router/tasks.ts +++ b/server/src/router/tasks.ts @@ -7,99 +7,113 @@ import { Config } from '../config' import { TasksAndModels } from '../tasks' export class Tasks { - private readonly ownRouter: express.Router - - private tasksAndModels = Set<[Task, tf.LayersModel]>() - - constructor ( - private readonly config: Config, - tasksAndModels: TasksAndModels - ) { - this.ownRouter = express.Router() - - this.ownRouter.get('/', (req, res, next) => { - this.getTasksMetadata(req, res).catch(next) - }) - - this.ownRouter.post('/', (req, res) => { - const model = req.body.model - const newTask = req.body.task - - if (!( - model !== 
undefined && - newTask !== undefined && - isTask(newTask) - )) { - res.status(400) - return - } - - serialization.model.decode(model) - .then(async (model) => { - await tasksAndModels.addTaskAndModel(newTask, model) + private readonly ownRouter: express.Router + + private tasksAndModels = Set<[Task, tf.LayersModel]>() + + constructor( + private readonly config: Config, + tasksAndModels: TasksAndModels + ) { + this.ownRouter = express.Router() + + this.ownRouter.get('/', (req, res, next) => { + this.getTasksMetadata(req, res).catch(next) }) - .then(() => res.status(200).end('Successful task upload')) - .catch(console.error) - }) - - // delay listening - process.nextTick(() => - tasksAndModels.addListener('taskAndModel', (t, m) => - this.onNewTask(t, m))) - } - - public get router (): express.Router { - return this.ownRouter - } - - onNewTask (task: Task, model: tf.LayersModel): void { - this.ownRouter.get(`/${task.taskID}/:file`, (req, res, next) => { - this.getLatestModel(task.taskID, req, res).catch(next) - }) - - this.tasksAndModels = this.tasksAndModels.add([task, model]) - } - - /** - * Request handler called when a client sends a GET request asking for all the - * tasks metadata stored in the server's tasks.json file. This is used for - * generating the client's list of tasks. It requires no prior connection to the - * server and is thus publicly available data. - * @param request received from client - * @param response sent to client - */ - private async getTasksMetadata (request: Request, response: Response): Promise { - response - .status(200) - .send(this.tasksAndModels.map(([t, _]) => t).toArray()) - } - - /** - * Request handler called when a client sends a GET request asking for the - * TFJS model files of a given task. The files consist of the model's - * architecture file model.json and its layer weights file weights.bin. - * It requires no prior connection to the server and is thus publicly available - * data. 
- * @param request received from client - * @param response sent to client - */ - private async getLatestModel (taskID: TaskID, request: Request, response: Response): Promise { - const validModelFiles = Set.of('model.json', 'weights.bin') - - const file = request.params.file - if (!validModelFiles.has(file)) { - response.status(404) - return + + this.ownRouter.post('/', (req, res) => { + const model = req.body.model + const newTask = req.body.task + + if ( + !( + model !== undefined && + newTask !== undefined && + isTask(newTask) + ) + ) { + res.status(400) + return + } + + serialization.model + .decode(model) + .then(async (model) => { + await tasksAndModels.addTaskAndModel(newTask, model) + }) + .then(() => res.status(200).end('Successful task upload')) + .catch(console.error) + }) + + // delay listening + process.nextTick(() => + tasksAndModels.addListener('taskAndModel', (t, m) => + this.onNewTask(t, m) + ) + ) } - const taskAndModel = this.tasksAndModels.find(([t, _]) => t.taskID === taskID) - if (taskAndModel === undefined) { - response.status(404) - return + + public get router(): express.Router { + return this.ownRouter } - const encoded = await serialization.model.encode(taskAndModel[1]) + onNewTask(task: Task, model: tf.LayersModel): void { + this.ownRouter.get(`/${task.id}/:file`, (req, res, next) => { + this.getLatestModel(task.id, req, res).catch(next) + }) + + this.tasksAndModels = this.tasksAndModels.add([task, model]) + } - response.status(200).send(encoded) - console.log(`${file} download for task ${taskID} succeeded`) - } + /** + * Request handler called when a client sends a GET request asking for all the + * tasks metadata stored in the server's tasks.json file. This is used for + * generating the client's list of tasks. It requires no prior connection to the + * server and is thus publicly available data. 
+ * @param request received from client + * @param response sent to client + */ + private async getTasksMetadata( + request: Request, + response: Response + ): Promise { + response + .status(200) + .send(this.tasksAndModels.map(([t, _]) => t).toArray()) + } + + /** + * Request handler called when a client sends a GET request asking for the + * TFJS model files of a given task. The files consist of the model's + * architecture file model.json and its layer weights file weights.bin. + * It requires no prior connection to the server and is thus publicly available + * data. + * @param request received from client + * @param response sent to client + */ + private async getLatestModel( + taskID: TaskID, + request: Request, + response: Response + ): Promise { + const validModelFiles = Set.of('model.json', 'weights.bin') + + const file = request.params.file + if (!validModelFiles.has(file)) { + response.status(404) + return + } + const taskAndModel = this.tasksAndModels.find( + ([t, _]) => t.id === taskID + ) + if (taskAndModel === undefined) { + response.status(404) + return + } + + const encoded = await serialization.model.encode(taskAndModel[1]) + + response.status(200).send(encoded) + console.log(`${file} download for task ${taskID} succeeded`) + } } diff --git a/server/src/tasks.ts b/server/src/tasks.ts index b1c3069d4..1907d96ea 100644 --- a/server/src/tasks.ts +++ b/server/src/tasks.ts @@ -3,123 +3,142 @@ import { EventEmitter } from 'node:events' import { createHash } from 'node:crypto' import fs from 'node:fs' -import { tf, Task, Path, Digest, isTaskProvider, TaskProvider, defaultTasks } from '@epfml/discojs-node' +import { + tf, + Task, + Path, + Digest, + isTaskProvider, + TaskProvider, + defaultTasks, +} from '@epfml/discojs-node' // default tasks and added ones // register 'taskAndModel' event to get tasks // TODO save and load from disk export class TasksAndModels extends EventEmitter { - tasksAndModels = Set<[Task, tf.LayersModel]>() - - constructor () { - 
super({ captureRejections: true }) - - this.on('newListener', (event, listener) => { - if (event !== 'taskAndModel') { - throw new Error('unknown event') - } - this.tasksAndModels.forEach(([t, m]) => listener(t, m)) - }) - } - - async loadDefaultTasks (): Promise { - const tasks = Object.values(defaultTasks) - await Promise.all(tasks.map(async (t: TaskProvider) => { - await this.addTaskAndModel(t) - })) - } - - // Returns already saved model in priority, then the model from the task definition - private async loadModelFromTask (task: Task | TaskProvider): Promise { - const discoTask = isTaskProvider(task) ? task.getTask() : task - let model: tf.LayersModel | undefined - - const modelPath = `./models/${discoTask.taskID}/` - if (fs.existsSync(modelPath)) { - // either a model has already been trained, or the pretrained - // model has already been downloaded - return await tf.loadLayersModel(`file://${modelPath}/model.json`) - } else { - const modelURL = discoTask.trainingInformation.modelURL - if (modelURL !== undefined) { - model = await tf.loadLayersModel(modelURL) - } else if (isTaskProvider(task)) { - model = await task.getModel() - } else { - throw new Error('model not provided in task definition') - } - } - - fs.mkdirSync(modelPath, { recursive: true }) - await model.save(`file://${modelPath}`) - - // Check digest if provided - if (discoTask.digest !== undefined) { - try { - this.checkDigest(discoTask.digest, modelPath) - } catch (e: any) { - TasksAndModels.removeModelFiles(modelPath) - throw new Error(e) - } - } - - return model - } + tasksAndModels = Set<[Task, tf.LayersModel]>() - private checkDigest (digest: Digest, modelPath: Path): void { - const hash = createHash(digest.algorithm) - const modelConfigRaw = fs.readFileSync(`${modelPath}/model.json`) + constructor() { + super({ captureRejections: true }) - const modelConfig = JSON.parse(modelConfigRaw.toString()) - const weightsFiles = modelConfig.weightsManifest[0].paths - if (!( - 
Array.isArray(weightsFiles) && - typeof weightsFiles[0] === 'string' - )) { - throw new Error() + this.on('newListener', (event, listener) => { + if (event !== 'taskAndModel') { + throw new Error('unknown event') + } + this.tasksAndModels.forEach(([t, m]) => listener(t, m)) + }) } - weightsFiles.forEach((file: string) => { - const data = fs.readFileSync(`${modelPath}/${file}`) - hash.update(data) - }) - - const computedDigest = hash.digest('base64') - if (computedDigest !== digest.value) { - console.warn(`digest was\n ${computedDigest}\nbut expected\n${digest.value}`) - throw new Error('digest mismatch') - } else { - console.info('digest verified') - } - } - async addTaskAndModel (task: Task | TaskProvider, model?: tf.LayersModel | URL): Promise { - let tfModel: tf.LayersModel - let discoTask: Task + async loadDefaultTasks(): Promise { + const tasks = Object.values(defaultTasks) + await Promise.all( + tasks.map(async (t: TaskProvider) => { + await this.addTaskAndModel(t) + }) + ) + } - if (isTaskProvider(task)) { - discoTask = task.getTask() - } else { - discoTask = task + // Returns already saved model in priority, then the model from the task definition + private async loadModelFromTask( + task: Task | TaskProvider + ): Promise { + const discoTask = isTaskProvider(task) ? 
task.getTask() : task + let model: tf.LayersModel | undefined + + const modelPath = `./models/${discoTask.id}/` + if (fs.existsSync(modelPath)) { + // either a model has already been trained, or the pretrained + // model has already been downloaded + return await tf.loadLayersModel(`file://${modelPath}/model.json`) + } else { + const modelURL = discoTask.trainingInformation.modelURL + if (modelURL !== undefined) { + model = await tf.loadLayersModel(modelURL) + } else if (isTaskProvider(task)) { + model = await task.getModel() + } else { + throw new Error('model not provided in task definition') + } + } + + fs.mkdirSync(modelPath, { recursive: true }) + await model.save(`file://${modelPath}`) + + // Check digest if provided + if (discoTask.digest !== undefined) { + try { + this.checkDigest(discoTask.digest, modelPath) + } catch (e: any) { + TasksAndModels.removeModelFiles(modelPath) + throw new Error(e) + } + } + + return model } - if (model === undefined) { - tfModel = await this.loadModelFromTask(task) - } else if (model instanceof tf.LayersModel) { - tfModel = model - } else if (model instanceof URL) { - tfModel = await tf.loadLayersModel(model.href) - } else { - throw new Error('invalid model') + private checkDigest(digest: Digest, modelPath: Path): void { + const hash = createHash(digest.algorithm) + const modelConfigRaw = fs.readFileSync(`${modelPath}/model.json`) + + const modelConfig = JSON.parse(modelConfigRaw.toString()) + const weightsFiles = modelConfig.weightsManifest[0].paths + if ( + !( + Array.isArray(weightsFiles) && + typeof weightsFiles[0] === 'string' + ) + ) { + throw new Error() + } + weightsFiles.forEach((file: string) => { + const data = fs.readFileSync(`${modelPath}/${file}`) + hash.update(data) + }) + + const computedDigest = hash.digest('base64') + if (computedDigest !== digest.value) { + console.warn( + `digest was\n ${computedDigest}\nbut expected\n${digest.value}` + ) + throw new Error('digest mismatch') + } else { + 
console.info('digest verified') + } } - this.tasksAndModels = this.tasksAndModels.add([discoTask, tfModel]) - this.emit('taskAndModel', task, model) - } + async addTaskAndModel( + task: Task | TaskProvider, + model?: tf.LayersModel | URL + ): Promise { + let tfModel: tf.LayersModel + let discoTask: Task + + if (isTaskProvider(task)) { + discoTask = task.getTask() + } else { + discoTask = task + } + + if (model === undefined) { + tfModel = await this.loadModelFromTask(task) + } else if (model instanceof tf.LayersModel) { + tfModel = model + } else if (model instanceof URL) { + tfModel = await tf.loadLayersModel(model.href) + } else { + throw new Error('invalid model') + } + + this.tasksAndModels = this.tasksAndModels.add([discoTask, tfModel]) + this.emit('taskAndModel', task, model) + } - static removeModelFiles (path: Path): void { - console.warn('removing nodel files at', path) - fs.rm(path, { recursive: true, force: true }, (err) => { - if (err !== null) console.error(err) - }) - } + static removeModelFiles(path: Path): void { + console.warn('removing nodel files at', path) + fs.rm(path, { recursive: true, force: true }, (err) => { + if (err !== null) console.error(err) + }) + } } diff --git a/web-client/cypress/e2e/tasks.cy.ts b/web-client/cypress/e2e/tasks.cy.ts index 37d961342..60fbbad34 100644 --- a/web-client/cypress/e2e/tasks.cy.ts +++ b/web-client/cypress/e2e/tasks.cy.ts @@ -1,39 +1,41 @@ /* eslint-disable no-undef */ -import { defaultTasks } from '@epfml/discojs' +import { defaultTasks } from "@epfml/discojs"; // most basic disco tasks export const TASK_LIST = [ defaultTasks.titanic.getTask(), defaultTasks.mnist.getTask(), - defaultTasks.cifar10.getTask() -] + defaultTasks.cifar10.getTask(), +]; -describe('tasks page', () => { - it('displays tasks', () => { - cy.intercept('tasks', TASK_LIST).as('tasks') - cy.visit('list') - cy.wait('@tasks').then((interception) => { - assert.lengthOf(interception.response.body, TASK_LIST.length) - }) - 
cy.get('div[id="tasks"]').children().should('have.length', TASK_LIST.length) - }) - it('redirects to training', () => { - cy.intercept('tasks', TASK_LIST).as('tasks') - cy.visit('list') - cy.wait('@tasks') +describe("tasks page", () => { + it("displays tasks", () => { + cy.intercept("tasks", TASK_LIST).as("tasks"); + cy.visit("list"); + cy.wait("@tasks").then((interception) => { + assert.lengthOf(interception.response.body, TASK_LIST.length); + }); + cy.get('div[id="tasks"]') + .children() + .should("have.length", TASK_LIST.length); + }); + it("redirects to training", () => { + cy.intercept("tasks", TASK_LIST).as("tasks"); + cy.visit("list"); + cy.wait("@tasks"); TASK_LIST.forEach((task) => { - cy.get(`div[id="${task.taskID}"]`).find('button').click() - cy.url().should('eq', Cypress.config().baseUrl + task.taskID) - cy.get('button').contains('previous', { matchCase: false }).click() - cy.url().should('eq', Cypress.config().baseUrl + 'list') - }) - }) - it('displays error message', () => { - cy.intercept('tasks', (req) => { - req.reply({ statusCode: 404 }) - }).as('tasks') - cy.visit('list') - cy.wait('@tasks') - cy.get('button').contains('reload page', { matchCase: false }) - }) -}) + cy.get(`div[id="${task.id}"]`).find("button").click(); + cy.url().should("eq", Cypress.config().baseUrl + task.id); + cy.get("button").contains("previous", { matchCase: false }).click(); + cy.url().should("eq", Cypress.config().baseUrl + "list"); + }); + }); + it("displays error message", () => { + cy.intercept("tasks", (req) => { + req.reply({ statusCode: 404 }); + }).as("tasks"); + cy.visit("list"); + cy.wait("@tasks"); + cy.get("button").contains("reload page", { matchCase: false }); + }); +}); diff --git a/web-client/src/components/pages/TaskList.vue b/web-client/src/components/pages/TaskList.vue index cd4fa53e9..79cae9004 100644 --- a/web-client/src/components/pages/TaskList.vue +++ b/web-client/src/components/pages/TaskList.vue @@ -1,17 +1,10 @@