Merge branch 'develop' into 611-ui-crashes-julien
JulienVig committed Feb 8, 2024
2 parents 0cd3de8 + 27fb55e commit 998f6db
Showing 12 changed files with 196 additions and 160 deletions.
4 changes: 2 additions & 2 deletions discojs/discojs-core/src/aggregator/base.ts
@@ -67,7 +67,7 @@ export abstract class Base<T> {
*/
protected readonly roundCutoff = 0,
/**
* The number of communication rounds occuring during any given aggregation round.
* The number of communication rounds occurring during any given aggregation round.
*/
public readonly communicationRounds = 1
) {
@@ -272,7 +272,7 @@ export abstract class Base<T> {
}

/**
* The current commnication round.
* The current communication round.
*/
get communicationRound (): number {
return this._communicationRound
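The distinction between aggregation rounds and communication rounds is easy to miss. The following standalone sketch (illustrative only, not the library's implementation) shows how two such counters typically relate, assuming one aggregation round spans `communicationRounds` communication rounds:

```ts
// Illustrative counters: several communication rounds make up one aggregation round.
class RoundCounters {
  private _round = 0                // aggregation round
  private _communicationRound = 0   // communication round within the current aggregation round

  constructor (public readonly communicationRounds = 1) {}

  get round (): number { return this._round }
  get communicationRound (): number { return this._communicationRound }

  // Called once per completed communication round.
  nextCommunicationRound (): void {
    this._communicationRound += 1
    if (this._communicationRound >= this.communicationRounds) {
      // A full set of communication rounds completes one aggregation round.
      this._communicationRound = 0
      this._round += 1
    }
  }
}
```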
44 changes: 15 additions & 29 deletions discojs/discojs-core/src/client/federated/base.ts
@@ -19,14 +19,6 @@ export class Base extends Client {
* by this client class, the server is the only node which we are communicating with.
*/
public static readonly SERVER_NODE_ID = 'federated-server-node-id'
/**
* Most recent server-fetched round.
*/
private serverRound?: number
/**
* Most recent server-fetched aggregated result.
*/
private serverResult?: WeightsContainer
/**
* Statistics curated by the federated server.
*/
@@ -92,41 +84,37 @@ export class Base extends Client {

/**
* Send a message containing our local weight updates to the federated server.
* And waits for the server to reply with the most recent aggregated weights
* @param weights The weight updates to send
*/
async sendPayload (payload: WeightsContainer): Promise<void> {
private async sendPayloadAndReceiveResult (payload: WeightsContainer): Promise<WeightsContainer|undefined> {
const msg: messages.SendPayload = {
type: type.SendPayload,
payload: await serialization.weights.encode(payload),
round: this.aggregator.round
}
this.server.send(msg)
// It is important that the client immediately awaits the server result or it may miss it
return await this.receiveResult()
}

/**
* Fetches the server's result for its current (most recent) round and add it to our aggregator.
* Waits for the server's result for its current (most recent) round and add it to our aggregator.
* Updates the aggregator's round if it's behind the server's.
*/
async receiveResult (): Promise<void> {
this.serverRound = undefined
this.serverResult = undefined

const msg: messages.MessageBase = {
type: type.ReceiveServerPayload
}
this.server.send(msg)

private async receiveResult (): Promise<WeightsContainer|undefined> {
try {
const { payload, round } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload)
this.serverRound = round
const serverRound = round

// Store the server result only if it is not stale
if (this.aggregator.round <= round) {
this.serverResult = serialization.weights.decode(payload)
const serverResult = serialization.weights.decode(payload)
// Update the local round to match the server's
if (this.aggregator.round < this.serverRound) {
this.aggregator.setRound(this.serverRound)
if (this.aggregator.round < serverRound) {
this.aggregator.setRound(serverRound)
}
return serverResult
}
} catch (e) {
console.error(e)
@@ -226,13 +214,11 @@ export class Base extends Client {
throw new Error('local aggregation result was not set')
}

// Send our contribution to the server
await this.sendPayload(this.aggregator.makePayloads(weights).first())
// Fetch the server result
await this.receiveResult()
// Send our local contribution to the server
// and receive the most recent weights as an answer to our contribution
const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first())

// TODO @s314cy: add communication rounds to federated learning
if (this.serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, this.serverResult, round, 0)) {
if (serverResult !== undefined && this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
// Regular case: the server sends us its aggregation result which will serve our
// own aggregation result.
} else {
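The refactor above merges sending the local payload and awaiting the server's reply into a single private method, so the client starts listening for the aggregated result immediately after sending its contribution. A generic sketch of that pattern (names and types are illustrative, not the DISCO API) looks like this:

```ts
// Send a request and immediately await the corresponding reply, so the reply
// cannot be missed by subscribing to it too late.
async function sendAndAwaitReply<Request, Reply> (
  send: (request: Request) => void,
  nextReply: () => Promise<Reply>, // resolves with the next matching server message
  request: Request
): Promise<Reply | undefined> {
  send(request)
  try {
    // Start waiting right away rather than in a separate, later call.
    return await nextReply()
  } catch (e) {
    // For instance, a timeout while waiting for the server's answer.
    console.error(e)
    return undefined
  }
}
```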
@@ -1,3 +1,4 @@
import { Task, tf } from '../../..'
import { List } from 'immutable'
import { PreprocessingFunction } from './base'

@@ -9,7 +10,25 @@ export enum TabularPreprocessing {
Normalize
}

interface TabularEntry extends tf.TensorContainerObject {
xs: number[]
ys: tf.Tensor1D | number | undefined
}

const sanitize: PreprocessingFunction = {
type: TabularPreprocessing.Sanitize,
apply: (entry: tf.TensorContainer, task: Task): tf.TensorContainer => {
const { xs, ys } = entry as TabularEntry
return {
xs: xs.map(i => i === undefined ? 0 : i),
ys: ys
}
}
}

/**
* Available tabular preprocessing functions.
*/
export const AVAILABLE_PREPROCESSING = List<PreprocessingFunction>()
export const AVAILABLE_PREPROCESSING = List([
sanitize]
).sortBy((e) => e.type)
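The new `sanitize` preprocessing step replaces missing tabular feature values with 0 before tensors are built. A minimal standalone sketch of the same idea (not the library code, which also threads the `Task` through) is:

```ts
interface Entry {
  xs: Array<number | undefined>
  ys: number | undefined
}

// Replace missing feature values with 0 so the row can be turned into a tensor.
function sanitizeEntry (entry: Entry): { xs: number[], ys: number | undefined } {
  return {
    xs: entry.xs.map(x => x === undefined ? 0 : x),
    ys: entry.ys
  }
}

console.log(sanitizeEntry({ xs: [3, undefined, 7.25], ys: 1 }))
// -> { xs: [ 3, 0, 7.25 ], ys: 1 }
```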
3 changes: 2 additions & 1 deletion discojs/discojs-core/src/dataset/data/tabular_data.ts
@@ -21,7 +21,8 @@ export class TabularData extends Data {
try {
await dataset.iterator()
} catch (e) {
throw new Error('Data input format is not compatible with the chosen task')
console.error('Data input format is not compatible with the chosen task.')
throw (e)
}

return new TabularData(dataset, task, size)
5 changes: 3 additions & 2 deletions discojs/discojs-core/src/default_tasks/titanic.ts
@@ -1,4 +1,4 @@
import { tf, Task, TaskProvider } from '..'
import { tf, Task, TaskProvider, data } from '..'

export const titanic: TaskProvider = {
getTask (): Task {
@@ -49,7 +49,8 @@ export const titanic: TaskProvider = {
roundDuration: 10,
validationSplit: 0.2,
batchSize: 30,
preprocessingFunctions: [],
preprocessingFunctions: [data.TabularPreprocessing.Sanitize],
learningRate: 0.001,
modelCompileData: {
optimizer: 'sgd',
loss: 'binaryCrossentropy',
54 changes: 33 additions & 21 deletions discojs/discojs-core/src/validation/validator.spec.ts
@@ -1,7 +1,8 @@
import { assert } from 'chai'
import fs from 'fs'

import { Task, node, Validator, ConsoleLogger, EmptyMemory, client as clients, data, aggregator } from '@epfml/discojs-node'
import { Task, node, Validator, ConsoleLogger, EmptyMemory,
client as clients, data, aggregator, defaultTasks } from '@epfml/discojs-node'

const simplefaceMock = {
taskID: 'simple_face',
@@ -55,25 +56,36 @@ describe('validator', () => {
`expected accuracy greater than 0.3 but got ${validator.accuracy}`
)
console.table(validator.confusionMatrix)
}).timeout(10_000)
}).timeout(15_000)

// TODO: fix titanic model (nan accuracy)
// it('works for titanic', async () => {
// const data: Data = await new NodeTabularLoader(titanic.task, ',')
// .loadAll(['../../example_training_data/titanic.csv'], {
// features: titanic.task.trainingInformation?.inputColumns,
// labels: titanic.task.trainingInformation?.outputColumns
// })
// const validator = new Validator(titanic.task, new ConsoleLogger(), titanic.model())
// await validator.assess(data)

// assert(
// validator.visitedSamples() === data.size,
// `expected ${TITANIC_SAMPLES} visited samples but got ${validator.visitedSamples()}`
// )
// assert(
// validator.accuracy() > 0.5,
// `expected accuracy greater than 0.5 but got ${validator.accuracy()}`
// )
// })
it('works for titanic', async () => {
const titanicTask = defaultTasks.titanic.getTask()
const files = ['../../example_training_data/titanic_train.csv']
const data: data.Data = (await new node.data.NodeTabularLoader(titanicTask, ',').loadAll(files, {
features: titanicTask.trainingInformation.inputColumns,
labels: titanicTask.trainingInformation.outputColumns,
shuffle: false
})).train
const buffer = new aggregator.MeanAggregator(titanicTask)
const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, buffer)
buffer.setModel(await client.getLatestModel())
const validator = new Validator(titanicTask,
new ConsoleLogger(),
new EmptyMemory(),
undefined,
client)
await validator.assess(data)
// data.size is undefined because tfjs handles dataset lazily
// instead we count the dataset size manually
let size = 0
await data.dataset.forEachAsync(() => size+=1)
assert(
validator.visitedSamples === size,
`expected ${size} visited samples but got ${validator.visitedSamples}`
)
assert(
validator.accuracy > 0.5,
`expected accuracy greater than 0.5 but got ${validator.accuracy}`
)
}).timeout(15_000)
})
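As the comment in the test notes, tf.js datasets are evaluated lazily, so their size may not be known up front; iterating over the dataset is a reliable way to count its elements. A small self-contained sketch of that pattern (assuming `@tensorflow/tfjs` and illustrative names):

```ts
import * as tf from '@tensorflow/tfjs'

// Count the elements of a lazily-evaluated dataset by iterating over it once.
async function countElements<T extends tf.TensorContainer> (dataset: tf.data.Dataset<T>): Promise<number> {
  let count = 0
  await dataset.forEachAsync(() => { count += 1 })
  return count
}

// Example with a three-element in-memory dataset.
countElements(tf.data.array([1, 2, 3]))
  .then(n => console.log(n)) // -> 3
  .catch(console.error)
```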
13 changes: 12 additions & 1 deletion docs/node_example/data.ts
@@ -1,7 +1,7 @@
import fs from 'fs'
import Rand from 'rand-seed'

import { node, data, Task } from '@epfml/discojs-node'
import { node, data, Task, defaultTasks } from '@epfml/discojs-node'

const rand = new Rand('1234')

@@ -45,3 +45,14 @@ export async function loadData (task: Task): Promise<data.DataSplit> {

return await new node.data.NodeImageLoader(task).loadAll(files, { labels: labels })
}

export async function loadTitanicData (task: Task): Promise<data.DataSplit> {
const files = ['../../example_training_data/titanic_train.csv']
const titanicTask = defaultTasks.titanic.getTask()
return await new node.data.NodeTabularLoader(task, ',').loadAll(files, {
features: titanicTask.trainingInformation.inputColumns,
labels: titanicTask.trainingInformation.outputColumns,
shuffle: false
})
}
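A possible way to use the new loader from another script (a sketch assuming a DISCO server is already running at http://localhost:8080 and exposes the 'titanic' task):

```ts
import { fetchTasks, Task } from '@epfml/discojs-node'
import { loadTitanicData } from './data'

async function preview (): Promise<void> {
  const tasks = await fetchTasks(new URL('http://localhost:8080/'))
  const task = tasks.get('titanic') as Task
  const split = await loadTitanicData(task)
  // Print the first sample of the training split.
  console.log(await split.train.dataset.take(1).toArray())
}

preview().catch(console.error)
```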
20 changes: 8 additions & 12 deletions docs/node_example/example.ts
@@ -1,7 +1,5 @@
import { data, Disco, fetchTasks, Task } from '@epfml/discojs-node'

import { startServer } from './start_server'
import { loadData } from './data'
import { loadTitanicData } from './data'

/**
* Example of discojs API, we load data, build the appropriate loggers, the disco object
@@ -18,24 +16,22 @@ async function runUser (url: URL, task: Task, dataset: data.DataSplit): Promise<

async function main (): Promise<void> {

const [server, serverUrl] = await startServer()
// First have a server instance running before running this script
const serverUrl = new URL('http://localhost:8080/')

const tasks = await fetchTasks(serverUrl)

// Choose your task to train
const task = tasks.get('simple_face') as Task
const task = tasks.get('titanic') as Task

const dataset = await loadData(task)
const dataset = await loadTitanicData(task)

// Add more users to the list to simulate more clients
await Promise.all([
runUser(serverUrl, task, dataset),
runUser(serverUrl, task, dataset)
runUser(serverUrl, task, dataset),
runUser(serverUrl, task, dataset),
])

await new Promise((resolve, reject) => {
server.once('close', resolve)
server.close(reject)
})
}

main().catch(console.error)
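Since the example no longer spawns its own server (start_server.ts is deleted below), the script assumes a standalone server is already listening on serverUrl. A small helper like the following (purely illustrative, assuming Node 18+ with a global fetch) could wait for the server before calling main():

```ts
// Poll the server until it answers, or give up after a number of retries.
async function waitForServer (url: URL, retries = 20, delayMs = 500): Promise<void> {
  for (let attempt = 0; attempt < retries; attempt++) {
    try {
      await fetch(url) // any HTTP response means the server is reachable
      return
    } catch {
      await new Promise(resolve => setTimeout(resolve, delayMs))
    }
  }
  throw new Error(`server at ${url.href} is not reachable`)
}
```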
33 changes: 0 additions & 33 deletions docs/node_example/start_server.ts

This file was deleted.

2 changes: 1 addition & 1 deletion docs/node_example/tsconfig.json
@@ -14,7 +14,7 @@

"declaration": true,

"typeRoots": ["node_modules/@types", "discojs-core/types"]
"typeRoots": ["node_modules/@types", "../../discojs/discojs-core/types"]
},
"include": ["*.ts"],
"exclude": ["node_modules"]