Skip to content

Commit

Permalink
FastaStream class
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrg committed Oct 31, 2024
1 parent daf438e commit 80b14eb
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 16 deletions.
63 changes: 48 additions & 15 deletions src/tola/fasta/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import re
import sys
from functools import cached_property
from pathlib import Path

from tola.assembly.assembly import Assembly
Expand All @@ -18,9 +19,17 @@ class IndexUsageError(Exception):
"""Unexpected usage of FastaIndex"""


class FastaIndex:
__slots__ = "fasta_file", "fai_file", "agp_file", "index", "assembly"
# IUPAC nucleotide codes and their complements, upper case only. The
# lower-case half of the translation table is derived from these so the two
# cases cannot drift apart.
_IUPAC_CODES = b"ACGTRYMKSWHBVDN"
_IUPAC_PARTNERS = b"TGCAYRKMSWDVBHN"

# 256-entry byte translation table; bytes that are not IUPAC codes map to
# themselves.
IUPAC_COMPLEMENT = bytes.maketrans(
    _IUPAC_CODES + _IUPAC_CODES.lower(),
    _IUPAC_PARTNERS + _IUPAC_PARTNERS.lower(),
)


def reverse_complement(seq: bytes) -> bytes:
    """Return the reverse complement of the nucleotide sequence *seq*.

    Upper and lower case IUPAC codes are complemented with their case
    preserved; any other byte is passed through unchanged.
    """
    complemented = seq.translate(IUPAC_COMPLEMENT)
    return complemented[::-1]


class FastaIndex:
def __init__(self, fasta_file: Path | str):
if not isinstance(fasta_file, Path):
fasta_file = Path(fasta_file)
Expand Down Expand Up @@ -110,6 +119,39 @@ def run_indexing(self):
self.write_index()
self.write_assembly()

@cached_property
def fasta_fileandle(self):
    """Binary-mode file handle on ``self.fasta_file``.

    The file is opened on first access and the same handle is returned on
    every later access. Nothing here closes it.
    """
    handle = self.fasta_file.open(mode="rb")
    return handle

def get_sequence(self, frag: Fragment):
    """Yield the residues covered by *frag*, read from the FASTA file.

    The index entry for ``frag.name`` supplies ``file_offset`` (where the
    sequence's first residue is in the file), ``residues_per_line`` and
    ``max_line_length`` (residues plus line-ending bytes). These are used
    to seek directly to ``frag.start`` (treated as 1-based) and to skip the
    line-ending bytes between lines, so only the requested residues are
    read.

    Yields:
        A single ``io.BytesIO`` holding exactly ``frag.length`` residues
        with no line endings (fewer only if the file ends early). Its
        position is left at the end of the data, so callers must
        ``seek(0)`` before reading.

    Raises:
        ValueError: if the index has no entry for ``frag.name``. Because
            this is a generator, the error surfaces on first iteration.

    NOTE(review): residues are returned as stored in the file; no
    orientation / reverse-complement handling is applied here.
    """
    fh = self.fasta_fileandle
    info = self.index.get(frag.name)
    if not info:
        msg = f"No sequence in index named '{frag.name}'"
        raise ValueError(msg)
    rpl = info.residues_per_line
    mll = info.max_line_length
    line_end_bytes = mll - rpl
    seq = io.BytesIO()

    # Locate the line containing the first residue, and the offset within it.
    first_line, line_offset = divmod(frag.start - 1, rpl)
    fh.seek(info.file_offset + mll * first_line + line_offset)

    remainder = frag.length
    if line_offset:
        # The fragment starts part-way through a line. Read to the end of
        # that line, but never more than the fragment itself: previously a
        # fragment shorter than the rest of the line over-read, and the
        # resulting negative remainder pulled in further spurious residues.
        head = min(rpl - line_offset, remainder)
        seq.write(fh.read(head))
        # Skip the line ending. Harmless when the fragment ended mid-line,
        # since nothing more is read in that case.
        fh.seek(line_end_bytes, 1)
        remainder -= head

    whole_lines, tail = divmod(remainder, rpl)
    for _ in range(whole_lines):
        seq.write(fh.read(rpl))
        fh.seek(line_end_bytes, 1)
    if tail:
        seq.write(fh.read(tail))
    yield seq


class FastaInfo:
__slots__ = (
Expand Down Expand Up @@ -157,15 +199,8 @@ def fai_row(self, name):
)
return f"{name}\t{numbers}\n"

def regions(self):
    """Return a text listing of ``self.seq_regions``, one line per region.

    Each line is the region length (inclusive of both ends), right-aligned
    in 14 columns with thousands separators, followed by
    ``name:start-end``.
    """
    return "".join(
        f"{end - start + 1:14,d} {self.name}:{start}-{end}\n"
        for start, end in self.seq_regions
    )


def index_fasta_file(file: Path, buffer_size: int = 10_000_000):
def index_fasta_file(file: Path, buffer_size: int = 250_000):
name = None
seq_length = None
file_offset = None
Expand Down Expand Up @@ -235,9 +270,8 @@ def process_seq_buffer():

seq_length += len(seq_bytes)

# Opening the file in bytes mode means that Windows ("\r\n") or UNIX
# ("\n") line endings are preserved. It is also about 10% faster than
# decoding to UTF-8.
# Reading the file in bytes mode is about 10% faster than text mode, which
# has the overhead of decoding to UTF-8.
with file.open("rb") as fh:
for line in fh:
# ord(">") == 62
Expand Down Expand Up @@ -287,12 +321,11 @@ def process_seq_buffer():
return idx_dict, asm
else:
msg = f"No data in FASTA file '{file.absolute()}'"
raise ValueError(msg)


if __name__ == "__main__":
    # Script use: index each FASTA file named on the command line and print
    # one `fai_row` line per sequence to stdout.
    for file in sys.argv[1:]:
        idx_dict, _ = index_fasta_file(Path(file))
        sys.stdout.writelines(
            info.fai_row(name) for name, info in idx_dict.items()
        )
55 changes: 55 additions & 0 deletions src/tola/fasta/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from io import BufferedIOBase, BytesIO

from tola.assembly.assembly import Assembly
from tola.assembly.gap import Gap
from tola.assembly.scaffold import Scaffold
from tola.fasta.index import FastaIndex


class FastaStream:
    """Write scaffolds to a binary stream as line-wrapped FASTA records.

    Sequence for each non-gap row is fetched through ``index.get_sequence``;
    gap rows are filled with ``gap_character``.
    """

    def __init__(
        self,
        out: BufferedIOBase,
        index: FastaIndex,
        line_length=60,
        gap_character=b"N",
    ):
        """Store the output stream, sequence index and formatting options."""
        self.out = out
        self.index = index
        self.line_length = line_length
        self.gap_character = gap_character

    def write_assembly(self, assembly: Assembly):
        """Write every scaffold of *assembly*, in order."""
        for scaffold in assembly.scaffolds:
            self.write_scaffold(scaffold)

    def write_scaffold(self, scaffold: Scaffold):
        """Write one FASTA record for *scaffold*.

        The rows' sequences are concatenated and re-wrapped so that every
        line except possibly the last holds ``line_length`` residues.
        """
        write = self.out.write
        index = self.index
        width = self.line_length
        room = width  # residues still needed to fill the current line

        # NOTE(review): the whole header line is lower-cased, which also
        # lower-cases the scaffold name -- confirm this is intended.
        write(f">{scaffold.name}\n".encode().lower())
        for row in scaffold.rows:
            if isinstance(row, Gap):
                buffers = self.gap_seq(row)
            else:
                buffers = index.get_sequence(row)
            for buf in buffers:
                # Buffers may arrive positioned at their end; rewind first.
                buf.seek(0)
                while piece := buf.read(room):
                    write(piece)
                    room -= len(piece)
                    if room == 0:
                        write(b"\n")
                        room = width

        # Terminate a final, partially filled line.
        if room != width:
            write(b"\n")

    def gap_seq(self, gap: Gap):
        """Yield one buffer of ``gap.length`` copies of the gap character."""
        yield BytesIO(self.gap_character * gap.length)
37 changes: 36 additions & 1 deletion tests/fasta_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import io
import pathlib

import pytest

from tola.fasta.index import FastaIndex, index_fasta_file
from tola.assembly.fragment import Fragment
from tola.assembly.scaffold import Scaffold
from tola.fasta.index import FastaIndex, index_fasta_file, reverse_complement
from tola.fasta.stream import FastaStream


def list_fasta_files():
Expand All @@ -21,3 +25,34 @@ def test_fai(fasta_file):
idx.load_assembly()
asm.header = idx.assembly.header = []
assert str(asm) == str(idx.assembly)


def test_stream_fetch():
    """Stream the test assembly back out through FastaStream.

    NOTE(review): as written this is a smoke test only. It checks that
    writing the assembly does not raise and prints the result; the
    comparison against the source file sits after a bare ``return`` and
    never runs.
    """
    fasta_file = pathlib.Path(__file__).parent / "fasta/test.fa"
    fai = FastaIndex(fasta_file)
    fai.load_index()
    fai.load_assembly()
    out = io.BytesIO()
    fst = FastaStream(out, fai, gap_character=b"N")
    # Alternative driver: one whole-sequence fragment per index entry.
    # for name, info in fai.index.items():
    #     frag = Fragment(name, 1, info.length, 1)
    #     fst.write_scaffold(Scaffold(name, rows=[frag]))
    fst.write_assembly(fai.assembly)
    # Drop the references so the index's file handle can be released; this
    # relies on garbage collection rather than an explicit close.
    fst = fai = None  # Close filehandle
    print(out.getvalue().decode())

    # NOTE(review): everything below this return is unreachable, so the
    # round-trip assertions are currently disabled.
    return
    # Compare output and input with all line endings removed.
    fst_bytes = out.getvalue().replace(b"\n", b"")
    ref_bytes = fasta_file.read_bytes().replace(b"\r", b"").replace(b"\n", b"")

    assert len(ref_bytes) == len(fst_bytes)
    assert ref_bytes == fst_bytes


def test_revcomp():
    """Every IUPAC code, in both cases, is complemented and reversed."""
    forward = b"ACGTRYMKSWHBVDNacgtrymkswhbvdn"
    expected = b"nhbvdwsmkryacgtNHBVDWSMKRYACGT"
    assert reverse_complement(forward) == expected


# Allow running the streaming test directly as a script (it prints the
# streamed FASTA) without going through pytest.
if __name__ == "__main__":
    test_stream_fetch()

0 comments on commit 80b14eb

Please sign in to comment.