diff --git a/rebl/ed/entity_disambiguation.py b/rebl/ed/entity_disambiguation.py index aaf9138..6fc69a7 100644 --- a/rebl/ed/entity_disambiguation.py +++ b/rebl/ed/entity_disambiguation.py @@ -73,7 +73,8 @@ def create_fields(self): def stream_doc_with_spans(self): data = next(self.stream_parquet_md_file) - for i, raw_data in enumerate(self.stream_raw_source_file): + docs = sorted([d for d in self.stream_raw_source_file], key=lambda a: int(a.strip().split('\t')[0])) + for i, raw_data in enumerate(docs): try: json_content = json.loads(raw_data) except json.decoder.JSONDecodeError: @@ -86,7 +87,8 @@ def stream_doc_with_spans(self): field = self.fields[field_key] current_text = json_content[field] spans, tags, scores = [], [], [] - while data[1]['field'] == field_key and \ + + while data[1]['field'] == field_key or data[1]['field'] == field and \ data[1]['identifier'] == json_content[self.arguments['identifier']]: spans.append((data[1]['start_pos'], data[1]['end_pos'] - data[1]['start_pos'])) tags.append(data[1]['tag'])