Skip to content

Commit

Permalink
generate: add option --variants (nr of nbest seqs)
Browse files Browse the repository at this point in the history
  • Loading branch information
bertsky committed Mar 11, 2024
1 parent 6952e54 commit 1f7b5e5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
9 changes: 5 additions & 4 deletions ocrd_keraslm/lib/rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def predict(self, candidates, initial_states, context=None):
return preds, final_states

# todo: also allow specifying suffix
def generate(self, prefix, number, context=None):
def generate(self, prefix, length, context=None, variants=1):
'''Generate a number of characters after a prefix.
Calculate the hidden layer state after reading the string `prefix`
Expand Down Expand Up @@ -674,7 +674,7 @@ def generate(self, prefix, number, context=None):
cost=0.0)]

# beam search
for _ in range(number): # iterate over number of characters to be generated
for _ in range(length): # iterate over number of characters to be generated
fringe = next_fringe
preds, states = self.predict([n.value for n in fringe],
[n.state for n in fringe],
Expand All @@ -692,8 +692,9 @@ def generate(self, prefix, number, context=None):
n_new = Node(parent=n, state=state, value=self.mapping[1][best], cost=cost)
insort_left(next_fringe, n_new) # add alternative to tree
next_fringe = next_fringe[:256] # keep 256-best paths (equals batch size)
best = next_fringe[0] # best-scoring
result = ''.join([n.value for n in best.to_sequence()])
best = next_fringe[0:variants] # best-scoring
result = [''.join([n.value for n in res.to_sequence()])
for res in best]

return result # without prefix

Expand Down
10 changes: 6 additions & 4 deletions ocrd_keraslm/scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,11 @@ def test(model, data):
@cli.command(short_help='sample characters from language model')
@click.option('-m', '--model', required=True, help='model file', type=click.Path(dir_okay=False, exists=True))
@click.option('-n', '--number', default=1, help='number of characters to sample', type=click.IntRange(min=1, max=10000))
@click.option('-v', '--variants', default=1, help='number of character sequences to sample', type=click.IntRange(min=1, max=10000))
@click.option('-c', '--context', default=None, help='constant meta-data input')
@click.argument('prefix', type=click.STRING)
# todo: also allow specifying suffix
def generate(model, number, prefix, context):
def generate(model, number, variants, context, prefix):
"""Apply a language model, generating the most probable characters (starting with PREFIX string)."""

# load model
Expand All @@ -148,12 +149,13 @@ def generate(model, number, prefix, context):
rater.load_weights(model)

if context:
context = list(map(lambda x: ceil(int(x)/10), context.split(' ')))
context = [ceil(int(x)/10) for x in context.split(' ')]
else:
context = rater.underspecify_contexts()

result = rater.generate(prefix, number, context)
click.echo(prefix[:-1] + result)
result = rater.generate(prefix, number, context, variants)
for res in result:
click.echo(prefix[:-1] + res)

@cli.command(short_help='Print the mapped characters')
@click.option('-m', '--model', required=True, help='model file', type=click.Path(dir_okay=False, exists=True))
Expand Down

0 comments on commit 1f7b5e5

Please sign in to comment.