Skip to content

Commit

Permalink
FastaIndex object for building and loading FAI and AGP indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrg committed Oct 30, 2024
1 parent 8bdf90b commit daf438e
Show file tree
Hide file tree
Showing 3 changed files with 734 additions and 48 deletions.
212 changes: 173 additions & 39 deletions src/tola/fasta/index.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,150 @@
#!/usr/bin/env python3

import io
import logging
import re
import sys
from pathlib import Path

from tola.assembly.assembly import Assembly
from tola.assembly.format import format_agp
from tola.assembly.fragment import Fragment
from tola.assembly.gap import Gap
from tola.assembly.parser import parse_agp
from tola.assembly.scaffold import Scaffold


class IndexUsageError(Exception):
    """Unexpected usage of FastaIndex.

    Raised when an index or assembly is loaded a second time, or when a
    write is requested before any data has been loaded or built.
    """


class FastaIndex:
__slots__ = "fasta_file", "fai_file", "agp_file", "index", "assembly"

def __init__(self, fasta_file: Path | str):
if not isinstance(fasta_file, Path):
fasta_file = Path(fasta_file)
if not fasta_file.exists():
missing = str(fasta_file)
raise FileNotFoundError(missing)
self.fasta_file = fasta_file
self.fai_file = Path(str(fasta_file) + ".fai")
self.agp_file = Path(str(fasta_file) + ".agp")
self.index = None
self.assembly = None

def auto_load(self):
if self.check_for_index_files():
self.load_index()
self.load_assembly()
else:
self.run_indexing()

def check_for_index_files(self):
"""
Check that the .agp and fai files exist and are newer than the FASTA
sequence file.
"""
fasta_mtime = self.fasta_file.stat().st_mtime
for idx_file in self.fai_file, self.agp_file:
if not idx_file.exists():
return False
if not idx_file.stat().st_mtime > fasta_mtime:
logging.warning(
f"Index file '{idx_file}' is older than"
f" FASTA file '{self.fasta_file}'"
)
return False
return True

def load_index(self):
if self.index:
msg = "Index FAI already loaded"
raise IndexUsageError(msg)

idx_dict = {}
with self.fai_file.open() as idx:
for line in idx:
name, length, file_offset, residues_per_line, max_line_length = (
line.split()
)
idx_dict[name] = FastaInfo(
length,
file_offset,
residues_per_line,
max_line_length,
)
self.index = idx_dict

def write_index(self):
idx_dict = self.index
if not idx_dict:
msg = "No index data to write to FAI file"
raise IndexUsageError(msg)
if self.fai_file.exists():
logging.warning(f"Overwriting FAI index file '{self.fai_file}'")
with self.fai_file.open("w") as idx_fh:
for name, info in idx_dict.items():
idx_fh.write(info.fai_row(name))

def load_assembly(self):
if self.assembly:
msg = "Assembly AGP already loaded"
raise IndexUsageError(msg)
self.assembly = parse_agp(self.agp_file.open(), self.fasta_file.name)

def write_assembly(self):
asm = self.assembly
if not asm:
msg = "No assembly data to write to AGP file"
raise IndexUsageError(msg)
if self.agp_file.exists():
logging.warning(f"Overwriting AGP assembly file '{self.agp_file}'")
with self.agp_file.open("w") as agp_fh:
format_agp(asm, agp_fh)

def run_indexing(self):
idx_dict, assembly = index_fasta_file(self.fasta_file)
self.index = idx_dict
self.assembly = assembly
self.write_index()
self.write_assembly()


class FastaInfo:
__slots__ = (
"name",
"length",
"file_offset",
"residues_per_line",
"max_line_length",
"seq_regions",
)

def __init__(
self,
name,
length,
file_offset,
residues_per_line,
max_line_length,
seq_regions=None,
):
self.name = name
self.length = length
self.file_offset = file_offset
self.residues_per_line = residues_per_line
self.max_line_length = max_line_length
self.seq_regions = seq_regions

def fai_row(self):
self.length = int(length)
self.file_offset = int(file_offset)
self.residues_per_line = int(residues_per_line)
self.max_line_length = int(max_line_length)

def __eq__(self, othr):
for attr in self.__slots__:
if getattr(self, attr) != getattr(othr, attr):
return False
return True

def __repr__(self):
return (
"FastaInfo("
+ (", ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__))
+ ")"
)

def fai_row(self, name):
"""Returns a row for a Fasta Index (.fai) file."""
numbers = "\t".join(
str(x)
Expand All @@ -43,7 +155,7 @@ def fai_row(self):
self.max_line_length,
)
)
return f"{self.name}\t{numbers}\n"
return f"{name}\t{numbers}\n"

def regions(self):
s = io.StringIO()
Expand All @@ -64,27 +176,47 @@ def index_fasta_file(file: Path, buffer_size: int = 10_000_000):
line_end_bytes = None
seq_buffer = io.BytesIO()

info = []
idx_dict = {}
asm = Assembly(
file.name,
header=[f"Built from FASTA file '{file.absolute()}'"],
)

def load_info():
def store_info():
process_seq_buffer()
if region_end:
seq_regions.append((region_start + 1, region_end))
info.append(
FastaInfo(
name,
seq_length,
file_offset,
residues_per_line,
residues_per_line + line_end_bytes,
seq_regions,
)
seq_regions.append((region_start, region_end))

if idx_dict.get(name):
msg = f"More than one sequence named '{name}' in FASTA file '{file}'"
raise ValueError(msg)
idx_dict[name] = FastaInfo(
seq_length,
file_offset,
residues_per_line,
residues_per_line + line_end_bytes,
)

scffld = Scaffold(name)
prev = (0, 0)
for region in seq_regions:
start, end = region
if start != prev[1]:
gap_length = start - prev[1]
scffld.add_row(Gap(gap_length, "scaffold"))
scffld.add_row(Fragment(name, start + 1, end, 1))
prev = region
if rem := seq_length - prev[1]:
scffld.add_row(Gap(rem, "scaffold"))

asm.add_scaffold(scffld)

def process_seq_buffer():
nonlocal seq_length
nonlocal region_start
nonlocal region_end
# Outer scope variables which we "rebind" in this function.
# See https://peps.python.org/pep-3104/ for explanation.
nonlocal seq_length, region_start, region_end

# Take the value from the sequence buffer and empty it
seq_bytes = seq_buffer.getvalue()
seq_buffer.seek(0)
seq_buffer.truncate(0)
Expand All @@ -97,13 +229,12 @@ def process_seq_buffer():
region_end = end
else:
if region_end:
seq_regions.append((region_start + 1, region_end))
seq_regions.append((region_start, region_end))
region_start = start
region_end = end

seq_length += len(seq_bytes)


# Opening the file in bytes mode means that Windows ("\r\n") or UNIX
# ("\n") line endings are preserved. It is also about 10% faster than
# decoding to UTF-8.
Expand All @@ -114,12 +245,12 @@ def process_seq_buffer():
# If this isn't the first sequence in the file, store the
# accumulated data from the previous sequence.
if name:
load_info()
store_info()

# Get new name by splitting on whitespace beyond the first
# character and taking the first element of the array. This
# also allows space characters following the ">" character of
# the header.
# character and taking the first element of the array.
# (This also allows space characters following the ">"
# character of the header.)
name = line[1:].split()[0].decode("utf8")
if not name:
msg = f"Failed to parse sequence name from line:\n{line}"
Expand Down Expand Up @@ -150,15 +281,18 @@ def process_seq_buffer():

# Store info for the last sequence in the file
if name:
load_info()
store_info()

return info
if idx_dict:
return idx_dict, asm
else:
msg = f"No data in FASTA file '{file.absolute()}'"


if __name__ == "__main__":
for file in sys.argv[1:]:
info = index_fasta_file(Path(file))
for fst in info:
idx_dict, asm = index_fasta_file(Path(file))
for name, info in idx_dict.items():
# sys.stdout.write("\n")
sys.stdout.write(fst.fai_row())
sys.stdout.write(info.fai_row(name))
# sys.stdout.write(fst.regions())
Loading

0 comments on commit daf438e

Please sign in to comment.