Skip to content

Commit

Permalink
Fix logic error in faa_aligner
Browse files Browse the repository at this point in the history
  • Loading branch information
bjascob committed Jan 19, 2021
1 parent d792038 commit bde002e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion amrlib/alignments/faa_aligner/faa_aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def align_sents(self, sents, gstrings):
gstrings = [to_graph_line(g) for g in gstrings]
eng_td_lines, amr_td_lines = preprocess_infer(self.working_dir, sents, gstrings)
fa_out_lines = self.aligner.align(eng_td_lines, amr_td_lines)
amr_surface_aligns, alignment_strings = postprocess(self.working_dir, fa_out_lines)
amr_surface_aligns, alignment_strings = postprocess(self.working_dir, fa_out_lines, sents, gstrings)
return amr_surface_aligns, alignment_strings


Expand Down
13 changes: 7 additions & 6 deletions amrlib/alignments/faa_aligner/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


# if model_out_lines is None, read from the file
def postprocess(wk_dir, model_out_lines=None, **kwargs):
def postprocess(wk_dir, model_out_lines=None, eng_lines=None, amr_lines=None, **kwargs):
# Input filenames
eng_fn = os.path.join(wk_dir, kwargs.get('eng_fn', 'sents.txt'))
amr_fn = os.path.join(wk_dir, kwargs.get('amr_fn', 'gstrings.txt'))
Expand All @@ -21,11 +21,12 @@ def postprocess(wk_dir, model_out_lines=None, **kwargs):
align_to_str_fn = os.path.join(wk_dir, kwargs.get('align_to_str_fn', 'align_to_str.err'))

# Read the input files and get the number of lines, which must be the same
with open(eng_fn) as f:
eng_lines = [l.strip() for l in f]
with open(amr_fn) as f:
amr_lines = [l.strip() for l in f]
assert len(eng_lines) == len(amr_lines)
if eng_lines is None or amr_lines is None:
with open(eng_fn) as f:
eng_lines = [l.strip() for l in f]
with open(amr_fn) as f:
amr_lines = [l.strip() for l in f]
assert len(eng_lines) == len(amr_lines)
lines_number = len(eng_lines)

# Read the output of the aligner or use the supplied input above
Expand Down

0 comments on commit bde002e

Please sign in to comment.