Skip to content

Commit

Permalink
Checking for max tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
MihaiSurdeanu committed Sep 16, 2023
1 parent c270b3e commit a8c707b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.clulab.scala_transformers.tokenizer.LongTokenization
*/
class TokenClassifier(
val encoder: Encoder,
val maxTokens: Int,
val tasks: Array[LinearLayer],
val tokenizer: Tokenizer
) {
Expand All @@ -45,7 +46,11 @@ class TokenClassifier(
val tokenization = LongTokenization(tokenizer.tokenize(words.toArray))
val inputIds = tokenization.tokenIds
val wordIds = tokenization.wordIds
//val tokens = tokenization.tokens
val tokens = tokenization.tokens

if(inputIds.length > maxTokens) {
throw new EncoderMaxTokensRuntimeException(s"Encoder error: the following text contains more tokens than the maximum number accepted by this encoder ($maxTokens): ${tokens.mkString(", ")}")
}

// run the sentence through the transformer encoder
val encOutput = encoder.forward(inputIds)
Expand Down Expand Up @@ -150,6 +155,8 @@ class TokenClassifier(
}
}

/**
  * Thrown when a tokenized input contains more tokens than the encoder's
  * configured maximum (see `TokenClassifier.maxTokens`).
  *
  * @param msg human-readable description, including the offending tokens
  */
class EncoderMaxTokensRuntimeException(msg: String) extends RuntimeException(msg)

object TokenClassifier {

def fromFiles(modelDir: String): TokenClassifier = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ abstract class TokenClassifierFactory(val tokenClassifierLayout: TokenClassifier
/** Tests whether a resource is available at `place`; implemented by concrete factories. */
def exists(place: String): Boolean

/** Reads the encoder name from the layout's `name` resource (first line of the source). */
def name: String = {
  val source = newSource(tokenClassifierLayout.name)
  sourceLine(source)
}
/**
  * Reads the encoder's maximum token count from the layout's `maxTokens` resource.
  * NOTE(review): `.toInt` throws NumberFormatException if the stored value is not
  * a valid integer — confirm the resource is always written by a trusted exporter.
  */
def maxTokens: Int = {
  val source = newSource(tokenClassifierLayout.maxTokens)
  sourceLine(source).toInt
}

def taskCount: Int = 0.until(Int.MaxValue)
.find { index =>
Expand All @@ -30,10 +31,16 @@ abstract class TokenClassifierFactory(val tokenClassifierLayout: TokenClassifier

/** Convenience overload: opens a source at `place` and reads a boolean from it. */
def sourceBoolean(place: String): Boolean = {
  val source = newSource(place)
  sourceBoolean(source)
}

/**
  * Builds a [[TokenClassifier]] from its pre-loaded parts.
  *
  * @param encoder          encoder network producing token representations
  * @param tokenizerName    name used to look up the JNI-backed tokenizer
  * @param encoderMaxTokens maximum number of tokens the encoder accepts;
  *                         forwarded to the classifier so it can reject longer inputs
  * @param addPrefixSpace   whether the tokenizer prepends a space
  *                         (presumably needed for byte-level BPE vocabularies — confirm)
  * @param tasks            linear task heads applied on top of the encoder output
  */
protected def newTokenClassifier(
    encoder: Encoder,
    tokenizerName: String,
    encoderMaxTokens: Int,
    addPrefixSpace: Boolean,
    tasks: Array[LinearLayer]): TokenClassifier = {
  val tokenizer = ScalaJniTokenizer(tokenizerName, addPrefixSpace)

  new TokenClassifier(encoder, encoderMaxTokens, tasks, tokenizer)
}

def newTokenClassifier: TokenClassifier = {
Expand All @@ -46,7 +53,7 @@ abstract class TokenClassifierFactory(val tokenClassifierLayout: TokenClassifier

linearLayerFactory.newLinearLayer
}
val tokenClassifier = newTokenClassifier(newEncoder, name, addPrefixSpace, linearLayers.toArray)
val tokenClassifier = newTokenClassifier(newEncoder, name, maxTokens, addPrefixSpace, linearLayers.toArray)

logger.info("Load complete.")
tokenClassifier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ class TokenClassifierLayout(val baseName: String) {

/** Resource path holding the encoder's name. */
def name: String = baseName + "/encoder.name"

/** Resource path holding the encoder's maximum token count. */
def maxTokens: String = baseName + "/encoder.maxtokens"

/** Resource path under which the task definitions live. */
def tasks: String = baseName + "/tasks"

/** Resource path for the task head at the given index. */
def task(index: Int): String = tasks + "/" + index.toString
Expand Down

0 comments on commit a8c707b

Please sign in to comment.