diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/LinearLayer.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/LinearLayer.scala index a1f1489..d0464e1 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/LinearLayer.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/LinearLayer.scala @@ -4,6 +4,7 @@ import breeze.linalg.`*` import breeze.linalg.argmax import breeze.linalg.DenseMatrix import breeze.linalg.DenseVector +import scala.collection.mutable.ArrayBuffer /** Implements one linear layer */ class LinearLayer( @@ -56,7 +57,7 @@ class LinearLayer( /** Predict all labels and their scores per token */ def predictWithScores(inputSentence: DenseMatrix[Float], - heads: Option[Array[Int]], + heads: Option[Array[Array[Int]]], masks: Option[Array[Boolean]]): Array[Array[(String, Float)]] = { val batchSentences = Array(inputSentence) val batchHeads = heads.map(Array(_)) @@ -66,7 +67,7 @@ class LinearLayer( /** Predict all labels and their scores per token in each sentence in the batch */ def predictWithScores(inputBatch: Array[DenseMatrix[Float]], - batchHeads: Option[Array[Array[Int]]], + batchHeads: Option[Array[Array[Array[Int]]]], batchMasks: Option[Array[Array[Boolean]]]): Array[Array[Array[(String, Float)]]] = { if (dual) predictDualWithScores(inputBatch, batchHeads, batchMasks) else predictPrimalWithScores(inputBatch) @@ -77,7 +78,7 @@ class LinearLayer( headRelativePositions: Array[Int]): DenseMatrix[Float] = { // this matrix concatenates the hidden states of modifier + corresponding head - // rows = number of tokens in the sentence; cols = hidden state size + // rows = number of tokens in the sentence; cols = hidden state size x 2 val concatMatrix = DenseMatrix.zeros[Float](rows = sentenceHiddenStates.rows, cols = 2 * sentenceHiddenStates.cols) // traverse all modifiers @@ -100,6 +101,37 @@ class LinearLayer( concatMatrix } + /** + * Generates a 1-row matrix containing a concatenation of the modifier and head embeddings + * + */ + def concatenateModifierAndHead( + sentenceHiddenStates: DenseMatrix[Float], + modifierAbsolutePosition: Int, + headRelativePosition: Int): DenseMatrix[Float] = { + + // this matrix concatenates the hidden states of modifier + corresponding head + // rows = 1; cols = hidden state size x 2 + val concatMatrix = DenseMatrix.zeros[Float](rows = 1, cols = 2 * sentenceHiddenStates.cols) + + // embedding of the modifier + val modHiddenState = sentenceHiddenStates(modifierAbsolutePosition, ::) + + // embedding of the head + val rawHeadAbsPos = modifierAbsolutePosition + headRelativePosition + val headAbsolutePosition = + if(rawHeadAbsPos >= 0 && rawHeadAbsPos < sentenceHiddenStates.rows) rawHeadAbsPos + else modifierAbsolutePosition // if the absolute position is invalid (e.g., root node or incorrect prediction) duplicate the mod embedding + val headHiddenState = sentenceHiddenStates(headAbsolutePosition, ::) + + // concatenation of the modifier and head embeddings + // vector concatenation in Breeze operates over vertical vectors, hence the transposing here + val concatState = DenseVector.vertcat(modHiddenState.t, headHiddenState.t).t + + concatMatrix(0, ::) :+= concatState + concatMatrix + } + /** Predict the top label for each combination of modifier token and corresponding head token */ def predictDual(inputBatch: Array[DenseMatrix[Float]], batchHeads: Option[Array[Array[Int]]] = None, @@ -134,41 +166,57 @@ class LinearLayer( outputBatch } + // predicts the top label for each of the candidate heads // out dimensions: sentence in batch x token in sentence x label/score per token + // batchHeads dimensions: sentence in batch x token in sentence x heads per token // labels are sorted in descending order of their scores def predictDualWithScores(inputBatch: Array[DenseMatrix[Float]], - batchHeads: Option[Array[Array[Int]]] = None, + batchHeads: Option[Array[Array[Array[Int]]]] = None, batchMasks: Option[Array[Array[Boolean]]] = None): Array[Array[Array[(String, Float)]]] = { assert(batchHeads.isDefined) assert(batchMasks.isDefined) val indexToLabel = labelsOpt.getOrElse(throw new RuntimeException("ERROR: can't predict without labels!")) + // dimensions: sent in batch x token in sentence x label per candidate head val outputBatch = new Array[Array[Array[(String, Float)]]](inputBatch.length) + // TODO: maybe the triple for loop below can be improved? + // we process one sentence at a time because the dual setting makes it harder to batch for (i <- inputBatch.indices) { val input = inputBatch(i) - val heads = batchHeads.get(i) + val headCandidatesPerSentence = batchHeads.get(i) - // generate a matrix that is twice as wide to concatenate the embeddings of the mod + head - val concatInput = concatenateModifiersAndHeads(input, heads) + // now process each token separately + val outputsPerSentence = new ArrayBuffer[Array[(String, Float)]]() + for (j <- headCandidatesPerSentence.indices) { + val modifierAbsolutePosition = j + val headCandidatesPerToken = headCandidatesPerSentence(j) - // get the logits for the current sentence produced by this linear layer - val logitsPerSentence = forward(Array(concatInput))(0) + // process each head candidate for this token + val outputsPerToken = new ArrayBuffer[(String, Float)]() + for(headRelativePosition <- headCandidatesPerToken) { + // generate a matrix that is twice as wide to concatenate the embeddings of the mod + head + val concatInput = concatenateModifierAndHead(input, modifierAbsolutePosition, headRelativePosition) - // one token per row; store scores for all labels for this token - val allLabels = Range(0, logitsPerSentence.rows).map { i => - // picks line i from a 2D matrix and converts it to Array - val scores = logitsPerSentence(i, ::).t.toArray - // extract the label at each position in the row and its score - val labelsAndScores = indexToLabel.zip(scores) + // get the logits for the current pair of modifier and head + val logitsPerSentence = forward(Array(concatInput))(0) - // keep scores in descending order (largest first) - labelsAndScores.sortBy(-_._2) - } + val labelScores = logitsPerSentence(0, ::) + val bestIndex = argmax(labelScores.t) + val bestScore = labelScores(bestIndex) + val bestLabel = indexToLabel(bestIndex) - outputBatch(i) = allLabels.toArray - } + // println(s"Top prediction for mod $modifierAbsolutePosition and relative head $headRelativePosition is $bestLabel with score $bestScore") + + outputsPerToken += Tuple2(bestLabel, bestScore) + } // end head candidates for this token + + outputsPerSentence += outputsPerToken.toArray + } // end this token + + outputBatch(i) = outputsPerSentence.toArray + } // end sentence batch outputBatch } @@ -206,7 +254,7 @@ class LinearLayer( val labelsAndScores = labels.zip(scores) // keep scores in descending order (largest first) - labelsAndScores.sortBy(_._2) + labelsAndScores.sortBy(- _._2) // - score guarantees sorting in descending order of scores } allLabels.toArray diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/TokenClassifier.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/TokenClassifier.scala index 8050d18..f57c46d 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/TokenClassifier.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/TokenClassifier.scala @@ -45,13 +45,16 @@ class TokenClassifier( val tokenization = LongTokenization(tokenizer.tokenize(words.toArray)) val inputIds = tokenization.tokenIds val wordIds = tokenization.wordIds + //val tokens = tokenization.tokens // run the sentence through the transformer encoder val encOutput = encoder.forward(inputIds) // outputs for all tasks stored here: task x tokens in sentence x scores per token val allLabels = new Array[Array[Array[(String, Float)]]](tasks.length) - var heads: Option[Array[Int]] = None + // all heads predicted for every token + // dimensions: token x heads + var heads: Option[Array[Array[Int]]] = None // now generate token label predictions for all primary tasks (not dual!) for (i <- tasks.indices) { @@ -61,17 +64,19 @@ class TokenClassifier( allLabels(i) = wordLabels // if this is the task that predicts head positions, then save them for the dual tasks - // here we save only the head predicted with the highest score (hence the .head) + // we save all the heads predicted for each token if (tasks(i).name == headTaskName) { - heads = Some(tokenLabels.map(_.head._1.toInt)) + heads = Some(tokenLabels.map(_.map(_._1.toInt))) } } } // generate outputs for the dual tasks, if heads were predicted by one of the primary tasks + // the dual task(s) must be aligned with the heads. + // that is, we predict the top label for each of the head candidates if (heads.isDefined) { //println("Tokens: " + tokens.mkString(", ")) - //println("Heads: " + heads.get.mkString(", ")) + //println("Heads:\n\t" + heads.get.map(_.slice(0, 3).mkString(", ")).mkString("\n\t")) //println("Masks: " + TokenClassifier.mkTokenMask(wordIds).mkString(", ")) val masks = Some(TokenClassifier.mkTokenMask(wordIds)) @@ -102,6 +107,7 @@ class TokenClassifier( val tokenization = LongTokenization(tokenizer.tokenize(words.toArray)) val inputIds = tokenization.tokenIds val wordIds = tokenization.wordIds + //val tokens = tokenization.tokens // run the sentence through the transformer encoder val encOutput = encoder.forward(inputIds)