Skip to content

Commit

Permalink
FastaSeq object to improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrg committed Nov 4, 2024
1 parent afaf833 commit ba9c137
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 39 deletions.
67 changes: 44 additions & 23 deletions src/tola/fasta/index.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/env python3

import io
import logging
import re
import sys
from functools import cached_property
from io import BytesIO
from pathlib import Path

from tola.assembly.assembly import Assembly
Expand All @@ -13,26 +13,13 @@
from tola.assembly.gap import Gap
from tola.assembly.parser import parse_agp
from tola.assembly.scaffold import Scaffold
from tola.fasta.simple import FastaSeq, revcomp_bytes_io


class IndexUsageError(Exception):
    """Raised when a FastaIndex is used in an unexpected way."""


IUPAC_COMPLEMENT = bytes.maketrans(
b"ACGTRYMKSWHBVDNacgtrymkswhbvdn",
b"TGCAYRKMSWDVBHNtgcayrkmswdvbhn",
)


def reverse_complement(seq: bytes):
return seq[::-1].translate(IUPAC_COMPLEMENT)


def revcomp_bytes_io(seq: io.BytesIO):
return io.BytesIO(reverse_complement(seq.getvalue()))


class FastaInfo:
__slots__ = (
"length",
Expand Down Expand Up @@ -137,7 +124,7 @@ def load_index(self):
residues_per_line,
max_line_length,
)
self.index = idx_dict
self.index = idx_dict

def write_index(self):
idx_dict = self.index
Expand Down Expand Up @@ -177,11 +164,33 @@ def run_indexing(self):
def fasta_fileandle(self):
return self.fasta_file.open("rb")

def get_sequence(self, frag: Fragment):
info = self.index.get(frag.name)
def get_info(self, name):
    """
    Returns the entry stored in `self.index` under `name`.

    Raises `ValueError` if the index has no entry for `name` (or the
    entry it holds is falsy).
    """
    info = self.index.get(name)
    if not info:
        # The message is built from the `name` argument; this method has
        # no `frag` in scope.
        msg = f"No sequence in index named '{name}'"
        raise ValueError(msg)
    return info

def get_gap_iter(self, gap: Gap, gap_character=b"N"):
    """
    Returns an iterator of `BytesIO` objects which together hold
    `gap.length` copies of `gap_character`.

    Each chunk holds at most `self.buffer_size` repeats of
    `gap_character`, which keeps memory usage bounded for large gaps.
    Empty chunks are never produced, so a zero length gap yields nothing.
    """
    max_length = self.buffer_size
    length = gap.length

    # Step through the chunk start positions instead of pre-computing a
    # chunk count: a count of `1 + length // max_length` produces a
    # trailing empty chunk whenever `length` is an exact multiple of
    # `max_length`.
    for chunk_start in range(0, length, max_length):
        chunk_end = min(length, chunk_start + max_length)
        yield BytesIO(gap_character * (chunk_end - chunk_start))

def get_sequence_iter(self, frag: Fragment):
"""
Returns an iterator of `BytesIO` objects for sequence characters of
the `Fragment`, keeping memory usage by the sequence data below
`buffer_size`.
"""
info = self.get_info(frag.name)

if frag.strand == -1:
return self.rev_chunks(info, frag.start, frag.end)
Expand All @@ -200,13 +209,25 @@ def fwd_chunks(self, info: FastaInfo, start, end):
def rev_chunks(self, info: FastaInfo, start, end):
    """
    Generator of `BytesIO` objects holding the reverse-complement of the
    region `start` to `end` of the sequence described by `info`.

    The region is divided into chunks of at most `self.buffer_size`
    residues, counted from `start`. Chunks are produced from the last
    to the first, each one reverse-complemented, so that the output
    concatenated in order is the reverse-complement of the whole region.
    """
    step = self.buffer_size
    last_chunk = (end - start) // step

    # Walk the chunk numbers from the highest down to zero, yielding the
    # reverse-complement of each chunk.
    for chunk_number in reversed(range(last_chunk + 1)):
        chunk_start = start + chunk_number * step
        # `end` is passed on as an inclusive coordinate, so a full chunk
        # finishes `step - 1` residues after its start.
        chunk_end = min(end, chunk_start + step - 1)
        chunk = self.sequence_bytes(info, chunk_start, chunk_end)
        yield revcomp_bytes_io(chunk)

def sequence_bytes(self, info: FastaInfo, start, end):
def all_fasta_seq(self):
    """
    Generator which yields the result of `get_fasta_seq()` for every
    name in `self.index`, in the iteration order of the index.
    """
    yield from map(self.get_fasta_seq, self.index)

def get_fasta_seq(self, name) -> FastaSeq:
    """
    Returns a `FastaSeq` holding the complete sequence stored in the
    index under `name`.

    Raises whatever `get_info()` raises if `name` is not in the index.
    """
    info = self.get_info(name)
    # Request the region from position 1 to `info.length`, i.e. the
    # whole of the sequence.
    whole_sequence = self.sequence_bytes(info, 1, info.length)
    return FastaSeq(name, whole_sequence.getvalue())

def sequence_bytes(self, info: FastaInfo, start, end) -> BytesIO:
start -= 1 # Switch to Python coordinates
rpl = info.residues_per_line
mll = info.max_line_length
Expand All @@ -222,7 +243,7 @@ def sequence_bytes(self, info: FastaInfo, start, end):
fh = self.fasta_fileandle
fh.seek(info.file_offset + frst_offset + mll * frst_line)

seq = io.BytesIO()
seq = BytesIO()
if frst_line == last_line:
# Sequence fragment is all on one line of the FASTA file
seq.write(fh.read(end - start))
Expand Down Expand Up @@ -253,7 +274,7 @@ def index_fasta_file(file: Path, buffer_size: int = 250_000):
region_end = None
seq_regions = None
line_end_bytes = None
seq_buffer = io.BytesIO()
seq_buffer = BytesIO()

idx_dict = {}
asm = Assembly(
Expand Down Expand Up @@ -329,7 +350,7 @@ def process_seq_buffer():
# character and taking the first element of the array.
# (This also allows space characters following the ">"
# character of the header.)
name = line[1:].split()[0].decode("utf8")
name = line[1:].split()[0].decode()
if not name:
msg = f"Failed to parse sequence name from line:\n{line}"
raise ValueError(msg)
Expand Down
66 changes: 66 additions & 0 deletions src/tola/fasta/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import io

# Translation table which maps each IUPAC nucleotide code to the code for its
# complement. The lower case half of the table is derived from the upper case
# alphabets, so lower case input gives lower case output.
_IUPAC_CODES = b"ACGTRYMKSWHBVDN"
_IUPAC_PARTNERS = b"TGCAYRKMSWDVBHN"
IUPAC_COMPLEMENT = bytes.maketrans(
    _IUPAC_CODES + _IUPAC_CODES.lower(),
    _IUPAC_PARTNERS + _IUPAC_PARTNERS.lower(),
)


def reverse_complement(seq: bytes):
    """
    Returns the reverse-complement of `seq`: each byte is translated via
    the `IUPAC_COMPLEMENT` table and the order of the bytes is reversed.
    """
    complemented = seq.translate(IUPAC_COMPLEMENT)
    return complemented[::-1]


def revcomp_bytes_io(seq: io.BytesIO):
    """
    Returns a new `BytesIO` containing the reverse-complement of the
    entire contents of `seq`.
    """
    forward = seq.getvalue()
    return io.BytesIO(reverse_complement(forward))


class FastaSeq:
    """
    A named sequence held in memory, with helpers to format it as FASTA.

    `name` is a `str`, `sequence` is `bytes`, and the optional
    `description` is a `str` which, when set, follows the name on the
    header line separated by a single space.
    """

    __slots__ = "name", "description", "sequence"

    def __init__(self, name: str, sequence: bytes, description: str = None):
        self.name = name
        self.sequence = sequence
        self.description = description

    @property
    def length(self):
        """Number of bytes in `sequence`."""
        return len(self.sequence)

    def __str__(self, line_length=60):
        """
        FASTA formatted text for the sequence, i.e. the decoded output of
        `fasta_bytes()`, so that the two formats cannot drift apart.
        """
        return self.fasta_bytes(line_length).decode()

    def fasta_bytes(self, line_length=60):
        """
        Returns the sequence as FASTA formatted `bytes`: a header line
        followed by the sequence wrapped at `line_length` bytes per line.
        Every line, including the last, ends with a newline. An empty
        sequence produces the header line only.

        Raises `ValueError` if `line_length` is less than 1. (Zero would
        otherwise fail with a division error, and a negative value would
        silently produce a header with no sequence.)
        """
        if line_length < 1:
            msg = f"line_length must be at least 1, got {line_length}"
            raise ValueError(msg)

        header = b">" + self.name.encode()
        if desc := self.description:
            header += b" " + desc.encode()

        seq = self.sequence
        lines = [header]
        lines.extend(
            seq[x : x + line_length] for x in range(0, len(seq), line_length)
        )
        return b"\n".join(lines) + b"\n"

    def rev_comp(self):
        """
        Returns a new `FastaSeq` with the same name and description whose
        sequence is the reverse-complement of this one.
        """
        return FastaSeq(
            self.name,
            reverse_complement(self.sequence),
            self.description,
        )
9 changes: 3 additions & 6 deletions src/tola/fasta/stream.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from io import BufferedIOBase, BytesIO
from io import BufferedIOBase

from tola.assembly.assembly import Assembly
from tola.assembly.gap import Gap
Expand Down Expand Up @@ -32,9 +32,9 @@ def write_scaffold(self, scaffold: Scaffold):
out.write(f">{scaffold.name}\n".encode())
for row in scaffold.rows:
itr = (
self.gap_seq(row)
fai.get_gap_iter(row, self.gap_character)
if isinstance(row, Gap)
else fai.get_sequence(row)
else fai.get_sequence_iter(row)
)
for chunk in itr:
chunk.seek(0)
Expand All @@ -50,6 +50,3 @@ def write_scaffold(self, scaffold: Scaffold):

if want != line_length:
out.write(b"\n")

def gap_seq(self, gap: Gap):
yield BytesIO(self.gap_character * gap.length)
42 changes: 32 additions & 10 deletions tests/fasta_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

# from tola.assembly.fragment import Fragment
# from tola.assembly.scaffold import Scaffold
from tola.fasta.index import FastaIndex, index_fasta_file, reverse_complement
from tola.fasta.index import FastaIndex, index_fasta_file
from tola.fasta.simple import FastaSeq, reverse_complement
from tola.fasta.stream import FastaStream


Expand All @@ -27,22 +28,43 @@ def test_fai(fasta_file):
assert str(asm) == str(idx.assembly)


def test_stream_fetch():
def test_simple_fasta_bytes():
    """
    A 60 residue sequence with a description formats as a header line
    followed by a single line of sequence, as both `str` and `bytes`.
    """
    name, desc = "test", "A test sequence"
    seq = b"n" * 60
    expected = ">" + name + " " + desc + "\n" + seq.decode() + "\n"

    fst = FastaSeq(name, seq, desc)
    assert str(fst) == expected
    assert fst.fasta_bytes() == expected.encode()


@pytest.mark.parametrize("buf_size", [5, 7, 100, 200])
def test_stream_fetch(buf_size):
fasta_file = pathlib.Path(__file__).parent / "fasta/test.fa"
fai = FastaIndex(fasta_file, buffer_size=7)
ref_fai = FastaIndex(fasta_file)
ref_fai.load_index()

# Check we have the first and last sequence
assert ref_fai.index.get('RAND-001')
assert ref_fai.index.get('RAND-100')

ref_io = io.BytesIO()
for seq in ref_fai.all_fasta_seq():
ref_io.write(seq.fasta_bytes())
ref_bytes = ref_io.getvalue()

fai = FastaIndex(fasta_file, buffer_size=buf_size)
fai.load_index()
fai.load_assembly()

out = io.BytesIO()
fst = FastaStream(out, fai, gap_character=b"n")
fst.write_assembly(fai.assembly)
fst_bytes = out.getvalue()

print(out.getvalue().decode(), end="")

ref_bytes = fasta_file.read_bytes().replace(b"\r", b"").replace(b"\n", b"")
fst_bytes = out.getvalue().replace(b"\n", b"")

assert len(ref_bytes) == len(fst_bytes)
assert ref_bytes == fst_bytes
# Decode bytes to string so that pytest diff works
assert ref_bytes.decode() == fst_bytes.decode()


def test_revcomp():
Expand Down

0 comments on commit ba9c137

Please sign in to comment.