Skip to content

Commit

Permalink
Aligned label and head predictions in predictWithScores. This is needed
for ensemble parsing.
Browse files Browse the repository at this point in the history
  • Loading branch information
MihaiSurdeanu committed Aug 11, 2023
1 parent 4cf79d9 commit 1014eed
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import breeze.linalg.`*`
import breeze.linalg.argmax
import breeze.linalg.DenseMatrix
import breeze.linalg.DenseVector
import scala.collection.mutable.ArrayBuffer

/** Implements one linear layer */
class LinearLayer(
Expand Down Expand Up @@ -56,7 +57,7 @@ class LinearLayer(

/** Predict all labels and their scores per token */
def predictWithScores(inputSentence: DenseMatrix[Float],
heads: Option[Array[Int]],
heads: Option[Array[Array[Int]]],
masks: Option[Array[Boolean]]): Array[Array[(String, Float)]] = {
val batchSentences = Array(inputSentence)
val batchHeads = heads.map(Array(_))
Expand All @@ -66,7 +67,7 @@ class LinearLayer(

/** Predict all labels and their scores per token in each sentence in the batch */
def predictWithScores(inputBatch: Array[DenseMatrix[Float]],
batchHeads: Option[Array[Array[Int]]],
batchHeads: Option[Array[Array[Array[Int]]]],
batchMasks: Option[Array[Array[Boolean]]]): Array[Array[Array[(String, Float)]]] = {
if (dual) predictDualWithScores(inputBatch, batchHeads, batchMasks)
else predictPrimalWithScores(inputBatch)
Expand All @@ -77,7 +78,7 @@ class LinearLayer(
headRelativePositions: Array[Int]): DenseMatrix[Float] = {

// this matrix concatenates the hidden states of modifier + corresponding head
// rows = number of tokens in the sentence; cols = hidden state size
// rows = number of tokens in the sentence; cols = hidden state size x 2
val concatMatrix = DenseMatrix.zeros[Float](rows = sentenceHiddenStates.rows, cols = 2 * sentenceHiddenStates.cols)

// traverse all modifiers
Expand All @@ -100,6 +101,37 @@ class LinearLayer(
concatMatrix
}

/**
 * Builds a 1 x (2 * hiddenSize) matrix holding the embedding of a modifier
 * token concatenated with the embedding of one candidate head.
 * If the head's absolute position falls outside the sentence (e.g., the root
 * node, or an incorrect prediction), the modifier embedding is duplicated.
 */
def concatenateModifierAndHead(
  sentenceHiddenStates: DenseMatrix[Float],
  modifierAbsolutePosition: Int,
  headRelativePosition: Int): DenseMatrix[Float] = {

  // resolve the head's absolute position, falling back to the modifier itself
  // when the computed position is not a valid row in the sentence
  val candidateHeadPos = modifierAbsolutePosition + headRelativePosition
  val headAbsolutePosition =
    if (candidateHeadPos >= 0 && candidateHeadPos < sentenceHiddenStates.rows) candidateHeadPos
    else modifierAbsolutePosition

  // hidden states of the modifier and of its (resolved) head
  val modifierRow = sentenceHiddenStates(modifierAbsolutePosition, ::)
  val headRow = sentenceHiddenStates(headAbsolutePosition, ::)

  // Breeze's vertcat operates on column vectors, hence the transposes
  val joinedState = DenseVector.vertcat(modifierRow.t, headRow.t).t

  // rows = 1; cols = hidden state size x 2
  val result = DenseMatrix.zeros[Float](rows = 1, cols = 2 * sentenceHiddenStates.cols)
  result(0, ::) :+= joinedState
  result
}

/** Predict the top label for each combination of modifier token and corresponding head token */
def predictDual(inputBatch: Array[DenseMatrix[Float]],
batchHeads: Option[Array[Array[Int]]] = None,
Expand Down Expand Up @@ -134,41 +166,57 @@ class LinearLayer(
outputBatch
}

// Predicts the top label for each candidate head of each token.
// Output dimensions: sentence in batch x token in sentence x (label, score) per head candidate.
// batchHeads dimensions: sentence in batch x token in sentence x candidate heads (relative positions).
// Note: the output pairs follow the order of the provided head candidates; no sorting happens here.
def predictDualWithScores(inputBatch: Array[DenseMatrix[Float]],
  batchHeads: Option[Array[Array[Array[Int]]]] = None,
  batchMasks: Option[Array[Array[Boolean]]] = None): Array[Array[Array[(String, Float)]]] = {
  assert(batchHeads.isDefined)
  assert(batchMasks.isDefined)
  val indexToLabel = labelsOpt.getOrElse(throw new RuntimeException("ERROR: can't predict without labels!"))

  // dimensions: sentence in batch x token in sentence x label per candidate head
  val outputBatch = new Array[Array[Array[(String, Float)]]](inputBatch.length)

  // we process one sentence at a time because the dual setting makes it harder to batch
  for (i <- inputBatch.indices) {
    val input = inputBatch(i)
    val headCandidatesPerSentence = batchHeads.get(i)

    // for each token (modifier), score every candidate head and keep the best label for each
    outputBatch(i) = Array.tabulate(headCandidatesPerSentence.length) { modifierAbsolutePosition =>
      headCandidatesPerSentence(modifierAbsolutePosition).map { headRelativePosition =>
        // one-row matrix, twice as wide: concatenation of the modifier + head embeddings
        val concatInput = concatenateModifierAndHead(input, modifierAbsolutePosition, headRelativePosition)

        // logits for this (modifier, head candidate) pair; the result has a single row
        val labelScores = forward(Array(concatInput))(0)(0, ::)
        val bestIndex = argmax(labelScores.t)
        (indexToLabel(bestIndex), labelScores(bestIndex))
      }
    }
  }

  outputBatch
}
Expand Down Expand Up @@ -206,7 +254,7 @@ class LinearLayer(
val labelsAndScores = labels.zip(scores)

// keep scores in descending order (largest first)
labelsAndScores.sortBy(_._2)
labelsAndScores.sortBy(- _._2) // - score guarantees sorting in descending order of scores
}

allLabels.toArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,16 @@ class TokenClassifier(
val tokenization = LongTokenization(tokenizer.tokenize(words.toArray))
val inputIds = tokenization.tokenIds
val wordIds = tokenization.wordIds
//val tokens = tokenization.tokens

// run the sentence through the transformer encoder
val encOutput = encoder.forward(inputIds)

// outputs for all tasks stored here: task x tokens in sentence x scores per token
val allLabels = new Array[Array[Array[(String, Float)]]](tasks.length)
var heads: Option[Array[Int]] = None
// all heads predicted for every token
// dimensions: token x heads
var heads: Option[Array[Array[Int]]] = None

// now generate token label predictions for all primary tasks (not dual!)
for (i <- tasks.indices) {
Expand All @@ -61,17 +64,19 @@ class TokenClassifier(
allLabels(i) = wordLabels

// if this is the task that predicts head positions, then save them for the dual tasks
// here we save only the head predicted with the highest score (hence the .head)
// we save all the heads predicted for each token
if (tasks(i).name == headTaskName) {
heads = Some(tokenLabels.map(_.head._1.toInt))
heads = Some(tokenLabels.map(_.map(_._1.toInt)))
}
}
}

// generate outputs for the dual tasks, if heads were predicted by one of the primary tasks
// the dual task(s) must be aligned with the heads.
// that is, we predict the top label for each of the head candidates
if (heads.isDefined) {
//println("Tokens: " + tokens.mkString(", "))
//println("Heads: " + heads.get.mkString(", "))
//println("Heads:\n\t" + heads.get.map(_.slice(0, 3).mkString(", ")).mkString("\n\t"))
//println("Masks: " + TokenClassifier.mkTokenMask(wordIds).mkString(", "))
val masks = Some(TokenClassifier.mkTokenMask(wordIds))

Expand Down Expand Up @@ -102,6 +107,7 @@ class TokenClassifier(
val tokenization = LongTokenization(tokenizer.tokenize(words.toArray))
val inputIds = tokenization.tokenIds
val wordIds = tokenization.wordIds
//val tokens = tokenization.tokens

// run the sentence through the transformer encoder
val encOutput = encoder.forward(inputIds)
Expand Down

0 comments on commit 1014eed

Please sign in to comment.