✨ feat(layer_seq): RMSNormSeq (#123)
jean-francoisreboud authored Jun 16, 2024
1 parent d97e520 commit 2d65e95
Showing 19 changed files with 2,154 additions and 673 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.

## [unreleased]

**layer_seq:** RMSNormSeq ([123](https://github.com/owkin/GrAIdient/pull/123))\
**layer_seq:** EmbeddingSeq ([122](https://github.com/owkin/GrAIdient/pull/122))\
🚀 **perf:** use half in Metal kernels ([121](https://github.com/owkin/GrAIdient/pull/121))\
🔨 **refactor:** handle float16 along float on GPU ([120](https://github.com/owkin/GrAIdient/pull/120))\
75 changes: 72 additions & 3 deletions Sources/GrAIdient/Core/Function/Normalization.swift
@@ -54,6 +54,23 @@ class Normalization
let outsNew = vDSP.add(β, vDSP.multiply(Ɣ, xHat))
return outsNew
}

///
/// Forward Gradient Checking RMSNorm CPU.
///
/// - Parameters:
/// - outs: The data to normalize.
/// - Ɣ: The weights to scale the normalization result.
/// - Returns: The data normalized.
///
static func forwardΣGC(outs: [Double],
Ɣ: [Double]) -> [Double]
{
let σ2 = vDSP.meanSquare(outs)
let xHat = vDSP.divide(outs, sqrt(σ2 + Ɛ))
let outsNew = vDSP.multiply(Ɣ, xHat)
return outsNew
}
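
For reference, here is a scalar re-implementation of the same RMSNorm forward pass without Accelerate. This is a minimal sketch for illustration only: the function name is hypothetical, and Ɛ stands for the small numerical-stability constant that the Normalization class defines outside this hunk (1e-5 assumed here).

// Hypothetical scalar version of forwardΣGC, for illustration.
// Ɛ is assumed to match the class's stability constant.
func rmsNormForward(outs: [Double], Ɣ: [Double], Ɛ: Double = 1e-5) -> [Double]
{
    // σ2 is the mean of the squared inputs: (1/N) Σ xⱼ².
    let σ2 = outs.map { $0 * $0 }.reduce(0, +) / Double(outs.count)
    // Normalize by the root mean square, then scale by the weights.
    return zip(outs, Ɣ).map { x, g in g * x / (σ2 + Ɛ).squareRoot() }
}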

///
/// Forward Training CPU.
@@ -118,6 +135,30 @@
μ: μ,
σ2: σ2)
}

///
/// Forward RMSNorm CPU.
///
/// - Parameters:
/// - outs: The data to normalize.
/// - Ɣ: The weights to scale the normalization result.
/// - Returns: (The data normalized,
/// The data normalized without taking into account the weight,
/// The mean of the squared data).
///
static func forwardΣ(outs: [Double],
Ɣ: [Double]) -> (outsNew: [Double],
xHat: [Double],
σ2: Double)
{
let σ2 = vDSP.meanSquare(outs)
let xHat = vDSP.divide(outs, sqrt(σ2 + Ɛ))
let outsNew = vDSP.multiply(Ɣ, xHat)

return (outsNew: outsNew,
xHat: xHat,
σ2: σ2)
}
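
As a quick sanity check, a hypothetical call on a small vector (with Ɛ ≈ 1e-5, small enough not to visibly shift the result):

// Worked example: outs = [1, 2, 3], Ɣ = [1, 1, 1].
// σ2 = (1 + 4 + 9) / 3 = 14/3 ≈ 4.6667 and sqrt(σ2 + Ɛ) ≈ 2.1602.
let (outsNew, xHat, σ2) = Normalization.forwardΣ(
    outs: [1.0, 2.0, 3.0],
    Ɣ: [1.0, 1.0, 1.0]
)
// xHat ≈ [0.4629, 0.9258, 1.3887]; with unit weights, outsNew equals xHat.
print(outsNew, xHat, σ2)

The caller holds on to xHat and σ2: they are exactly the values that backwardΣ below needs to compute the gradient without a second pass over the data.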

///
/// Forward Inference CPU.
Expand Down Expand Up @@ -191,9 +232,7 @@ class Normalization
/// - xHat: The data normalized without taking into account the bias and the weight.
/// - σ2: The deviation of the data.
/// - Ɣ: The weights that scaled the normalization result.
/// - Returns: (The gradient taking into account the normalization,
/// The gradient of β,
/// The gradient of Ɣ).
/// - Returns: The gradient taking into account the normalization.
///
static func backward(delta: [Double],
xHat: [Double],
@@ -215,6 +254,36 @@

return deltaNew
}

///
/// Backward RMSNorm CPU.
///
/// - Parameters:
/// - delta: The gradients to back propagate.
/// - xHat: The data normalized without taking into account the weight.
/// - σ2: The mean of the squared data.
/// - Ɣ: The weights that scaled the normalization result.
/// - Returns: The gradient taking into account the normalization.
///
static func backwardΣ(delta: [Double],
xHat: [Double],
σ2: Double,
Ɣ: [Double]) -> [Double]
{
let nbElems = delta.count
let factor = 1.0 / (Double(nbElems) * sqrt(σ2 + Ɛ))

let Ɣdelta = vDSP.multiply(Ɣ, delta)
let sum2 = vDSP.sum(vDSP.multiply(Ɣdelta, xHat))

let tmp1 = vDSP.add(
multiplication: (Ɣdelta, Double(nbElems)),
multiplication: (xHat, -sum2))
let deltaNew = vDSP.add(
multiplication: (tmp1, factor), 0)

return deltaNew
}
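
The closed form implemented above can be checked by differentiating the RMSNorm output directly. Writing $N$ for nbElems and $s = \sqrt{\sigma^2 + \epsilon}$, a sketch of the derivation:

$$
\hat{x}_i = \frac{x_i}{s}, \qquad y_i = \gamma_i \hat{x}_i, \qquad
\sigma^2 = \frac{1}{N} \sum_j x_j^2
$$

$$
\frac{\partial L}{\partial x_i}
= \sum_k \delta_k \gamma_k \frac{\partial \hat{x}_k}{\partial x_i}
= \frac{\gamma_i \delta_i}{s} - \frac{\hat{x}_i}{N s} \sum_k \gamma_k \delta_k \hat{x}_k
= \frac{1}{N s} \Bigl( N \gamma_i \delta_i - \hat{x}_i \sum_k \gamma_k \delta_k \hat{x}_k \Bigr)
$$

which is exactly the vDSP expression: factor is $1/(N s)$, sum2 is $\sum_k \gamma_k \delta_k \hat{x}_k$, and deltaNew combines the two terms in one fused multiply-add.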

///
/// Backward Inference CPU.
