
Commit

code clean
also version bump down
ACEnglish committed Jun 3, 2024
1 parent 616d215 commit 8a4b216
Showing 3 changed files with 49 additions and 56 deletions.
10 changes: 5 additions & 5 deletions repo_utils/tdb_ssshtests.sh
@@ -27,10 +27,10 @@ tdb_check() {
assert_exit_code 0
res_name=$1
ans_name=${2:-$1}
if [ "${STRIP,,}" == "true" ]; then
strip_option="--strip"
if [ "${STRIP}" == "true" ]; then
strip_opt="--strip"
fi
$tdb equal $strip_option $INDIR/tdb/$ans_name $OD/$res_name/
$tdb equal $strip_opt --join $INDIR/tdb/$ans_name $OD/$res_name/
assert_equal $? 0
}

@@ -82,13 +82,13 @@ fi

run test_merge $tdb merge -o $OD/merge1.tdb $INDIR/tdb/HG00438_chr14.tdb/ $INDIR/tdb/HG00741_chr14.tdb/ $INDIR/tdb/HG02630_chr14.tdb/
if [ $test_merge ]; then
tdb_check merge1.tdb
STRIP=true tdb_check merge1.tdb
fi

run test_merge_into $tdb create -o $OD/merge_into.tdb $INDIR/vcf/HG00741_chr14.vcf.gz
run test_merge_into $tdb merge --no-compress --mem 1 --into $OD/merge_into.tdb $INDIR/tdb/HG02630_chr14.tdb $INDIR/tdb/HG00438_chr14.tdb
if [ $test_merge_into ]; then
tdb_check merge_into.tdb
STRIP=true tdb_check merge_into.tdb
fi

run test_bad_merge $tdb merge --into $OD/mergex -o $OD/merge1 $INDIR/tdb/HG00438_chr14.tdb/ $INDIR/tdb/HG00438_chr14.tdb/ $INDIR/HG00741_chr14
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@ def read(rel_path):
with codecs.open(os.path.join(here, rel_path), 'r') as fp:
return fp.read()

VERSION = "1.0.0"
VERSION = "0.2.0"

setup(
name="tdb",
93 changes: 43 additions & 50 deletions tdb/create.py
@@ -1,8 +1,5 @@
"""
Faster creation of a tdb
parse each line and write the outputs.
No need to worry about getting crazy with the indexes or joins
Turn a vcf into a tdb
"""
import gc
import os
@@ -18,7 +15,7 @@

import tdb

#pylint: disable=global-statement
# pylint: disable=global-statement

DTYPES = {"LocusID": (pa.uint32(), np.uint32),
"chrom": (pa.string(), str),
@@ -37,10 +34,11 @@
S_COLUMNS = ["LocusID", "allele_number", "spanning_reads", "length_range_lower",
"length_range_upper", "average_methylation"]

AVAILMEM = 1e11 # 100GB default
AVAILMEM = 1e11 # 100GB default
# Give 20% overhead since our memory tracking probably underestimates
USEDMEM = int(AVAILMEM * 0.20)


def check_args(args):
"""
Preflight checks on arguments. Returns True if there is a problem
@@ -57,50 +55,37 @@ def check_args(args):
logging.error(f"Input {args.input} does not exist")
check_fail = True
if not args.input.rstrip('/').endswith((".vcf", ".vcf.gz")):
logging.error(f"Unrecognized file extension on {args.input}. Expected .vcf .vcf.gz")
logging.error(
f"Unrecognized file extension on {args.input}. Expected .vcf .vcf.gz")
check_fail = True
return check_fail

def make_locus_writer(output_file, comp):
"""
Parquet writer for the locus table
"""
schema = pa.schema([
pa.field(key, DTYPES[key][0]) for key in L_COLUMNS
])
return pq.ParquetWriter(output_file, schema, compression=comp)

def make_allele_writer(output_file, comp):
"""
Parquet writer for the allele table
"""
schema = pa.schema([
pa.field(key, DTYPES[key][0]) for key in A_COLUMNS
])
return pq.ParquetWriter(output_file, schema, compression=comp)

def make_sample_writer(output_file, comp):
"""
Parquet writer for the sample table
"""
schema = pa.schema([
pa.field(key, DTYPES[key][0]) for key in S_COLUMNS
])
return pq.ParquetWriter(output_file, schema, compression=comp)

def make_parquets(samples, out_dir, compression):
"""
Make the parquet file handlers for the output database
Parquet writer for the tables
"""
comp = "GZIP" if compression else None
ret = {}
ret['locus'] = make_locus_writer(os.path.join(out_dir, 'locus.pq'), comp)
ret['allele'] = make_allele_writer(os.path.join(out_dir, 'allele.pq'), comp)

comp = "GZIP" if compression else None

fn = os.path.join(out_dir, 'locus.pq')
schema = pa.schema([pa.field(key, DTYPES[key][0]) for key in L_COLUMNS])
ret['locus'] = pq.ParquetWriter(fn, schema, compression=comp)

fn = os.path.join(out_dir, 'allele.pq')
schema = pa.schema([pa.field(key, DTYPES[key][0]) for key in A_COLUMNS])
ret['allele'] = pq.ParquetWriter(fn, schema, compression=comp)

ret['sample'] = {}
schema = pa.schema([pa.field(key, DTYPES[key][0]) for key in S_COLUMNS])
for name in samples:
ret['sample'][name] = make_sample_writer(os.path.join(out_dir, f"sample.{name}.pq"), comp)
fn = os.path.join(out_dir, f"sample.{name}.pq")
ret['sample'][name] = pq.ParquetWriter(fn, schema, compression=comp)

return ret


def sample_extract(locus_id, fmt_fields):
"""
Given a dict from a vcf record sample, turn them into sample rows
@@ -111,7 +96,6 @@ def sample_extract(locus_id, fmt_fields):
fmt_fields['ALLR'],
fmt_fields['AM'])
for an, sd, allr, am in view:
# None isn't imported
if an is None:
continue
lrl, lru = allr.split('-')
@@ -120,6 +104,7 @@
ret.append([locus_id, an, sd, lrl, lru, am])
return ret


def translate_entry(entry, locus_id):
"""
return three things,
@@ -130,25 +115,30 @@
global USEDMEM
locus = [locus_id, entry.chrom, entry.start, entry.stop]
USEDMEM += sys.getsizeof(locus)
alleles = [(locus_id, allele_number, len(sequence), sequence.encode("utf8"))

alleles = [(locus_id, allele_number, len(sequence),
b'' if sequence is None else sequence.encode("utf8"))
for allele_number, sequence in enumerate(entry.alleles)]
USEDMEM += sys.getsizeof(alleles)

samples = {}
for sample, m_d in entry.samples.items():
samples[sample] = sample_extract(locus_id, m_d)
# Approximate usage of each row of a sample table
USEDMEM += 400 * len(samples)

return locus, alleles, samples


def convert_buffer(vcf, samples, stats):
"""
Converts a number of vcf entries.
Tries to monitor memory to not buffer too many
"""
m_buffer = {'locus':[],
'allele':[],
'sample': {_:[] for _ in samples}
}
m_buffer = {'locus': [],
'allele': [],
'sample': {_: [] for _ in samples}
}
# Flag for telling main loop when we're finished
cvt_any = False
while AVAILMEM > USEDMEM:
@@ -158,7 +148,8 @@
break

cvt_any = True
cur_locus, cur_allele, cur_sample = translate_entry(entry, stats['locus'])
cur_locus, cur_allele, cur_sample = translate_entry(
entry, stats['locus'])

m_buffer['locus'].append(cur_locus)
m_buffer['allele'].extend(cur_allele)
@@ -173,10 +164,10 @@

return m_buffer, cvt_any


def write_tables(cur_tables, tables):
"""
Write the cur_tables entries to the output tables
So since cur_tables will have a list of vales, I should probably do pa.Table from pandas
"""
global USEDMEM
schema = pa.schema({key: DTYPES[key][0] for key in L_COLUMNS})
@@ -191,23 +182,25 @@

schema = pa.schema({key: DTYPES[key][0] for key in S_COLUMNS})
for name, out_samp in tables["sample"].items():
sdf = pd.DataFrame(cur_tables["sample"][name], columns=S_COLUMNS, copy=False)
sdf = pd.DataFrame(cur_tables["sample"][name],
columns=S_COLUMNS, copy=False)
sample = pa.Table.from_pandas(sdf, schema=schema, preserve_index=False)
out_samp.write(sample)
# Reset memory
USEDMEM = int(AVAILMEM * 0.20)


def create_main(args):
"""
Create a new tdb from multiple input calls
"""
global AVAILMEM
global USEDMEM
parser = argparse.ArgumentParser(prog="tdb create", description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("-o", "--output", metavar="OUT", required=True,
help="Output tdb directory")
parser.add_argument("--mem", metavar="MEM", type=int, default=100,
parser.add_argument("--mem", metavar="MEM", type=int, default=4,
help="Memory in GB available to buffer reading (%(default)s)")
parser.add_argument("--no-compress", action="store_false",
help="Don't compress the database")
@@ -231,7 +224,7 @@

vcf = pysam.VariantFile(args.input)
samples = list(vcf.header.samples)
stats = {"locus":0, "allele":0, "sample":0}
stats = {"locus": 0, "allele": 0, "sample": 0}

tables = make_parquets(samples, args.output, args.no_compress)
logging.info("Converting VCF with %d samples", len(samples))
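
The create.py refactor above folds the three make_*_writer helpers into make_parquets: each table gets a pq.ParquetWriter built from the shared DTYPES map, and buffered rows are flushed through pandas into Arrow tables. Below is a minimal, self-contained sketch of that writer pattern, not the repository's code; the column names and file path are illustrative placeholders rather than tdb's actual schema.

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

# Illustrative columns only -- tdb's real schemas come from DTYPES and the
# L_COLUMNS / A_COLUMNS / S_COLUMNS lists shown in the diff.
COLUMNS = {"LocusID": pa.uint32(), "chrom": pa.string()}

def make_writer(path, compress=True):
    # One Arrow schema drives the writer; compression is GZIP or none,
    # mirroring the '"GZIP" if compression else None' switch in make_parquets.
    schema = pa.schema([pa.field(name, typ) for name, typ in COLUMNS.items()])
    comp = "GZIP" if compress else None
    return pq.ParquetWriter(path, schema, compression=comp), schema

def flush_rows(writer, schema, rows):
    # Buffered rows -> pandas frame -> Arrow table -> one row group on disk.
    df = pd.DataFrame(rows, columns=list(COLUMNS))
    table = pa.Table.from_pandas(df, schema=schema, preserve_index=False)
    writer.write_table(table)

if __name__ == "__main__":
    writer, schema = make_writer("example_locus.pq")
    flush_rows(writer, schema, [[0, "chr14"], [1, "chr14"]])
    writer.close()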
