Skip to content

Commit

Permalink
Merge pull request #33 from marbl/develop
Browse files Browse the repository at this point in the history
v0.8.5 release
  • Loading branch information
alexsweeten authored Sep 4, 2024
2 parents c2bdc62 + 859ad1e commit 3d62edb
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 18 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ModDotPlot"
version = "0.8.4"
version = "0.8.5"
requires-python = ">= 3.7"
dependencies = [
"pysam",
Expand All @@ -16,7 +16,9 @@ dependencies = [
"mmh3",
"tk",
"setproctitle",
"numpy"
"numpy",
"PIL",
"patchworklib"
]
authors = [
{name = "Alex Sweeten", email = "[email protected]"},
Expand Down
2 changes: 1 addition & 1 deletion src/moddotplot/const.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION = "0.8.4"
VERSION = "0.8.5"
COLS = [
"#query_name",
"query_start",
Expand Down
9 changes: 9 additions & 0 deletions src/moddotplot/moddotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,12 @@ def get_parser():
help="Preserve diagonal when handling strings of ambiguous homopolymers (eg. long runs of N's).",
)

static_parser.add_argument(
"--grid",
action="store_true",
help="Plot comparative plots in an NxN grid like format.",
)

# TODO: Implement static mode logging options

return parser
Expand Down Expand Up @@ -932,6 +938,9 @@ def main():
seq_sparsity = 2 ** (int(math.log2(seq_sparsity - 1)) + 1)
expectation = round(win / seq_sparsity)

if args.grid:
grid_vals = []

for i in range(len(sequences)):
larger_seq = sequences[i][1]

Expand Down
4 changes: 2 additions & 2 deletions src/moddotplot/parse_fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def generateKmersFromFasta(seq: Sequence[str], k: int, quiet: bool) -> Iterable[
suffix="Completed",
length=40,
)

kmer = seq[i : i + k]
# Remove case sensitivity
kmer = seq[i : i + k].upper()
fh = mmh3.hash(kmer)

# Calculate reverse complement hash directly without the need for translation
Expand Down
202 changes: 189 additions & 13 deletions src/moddotplot/static_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
)
import pandas as pd
import numpy as np
from PIL import Image
import patchworklib as pw
import math
import os

from moddotplot.const import (
DIVERGING_PALETTES,
QUALITATIVE_PALETTES,
Expand Down Expand Up @@ -67,6 +71,43 @@ def make_scale(vals: list) -> list:
return make_m(scaled)


def overlap_axis(rotated_plot, filename, prefix):
scale_factor = math.sqrt(2) + 0.04
new_width = int(rotated_plot.width / scale_factor)
new_height = int(rotated_plot.height / scale_factor)
resized_rotated_plot = rotated_plot.resize((new_width, new_height), Image.LANCZOS)

# Step 3: Overlay the resized rotated heatmap onto the original axes

# Open the original heatmap with axes
image_with_axes = Image.open(filename)

# Create a blank image with the same size as the original
final_image = Image.new("RGBA", image_with_axes.size)

# Calculate the position to center the resized rotated image within the original plot area
x_offset = (final_image.width - resized_rotated_plot.width) // 2
y_offset = (final_image.height - resized_rotated_plot.height) // 2
y_offset += 2400
x_offset += 30

# Paste the original image with axes onto the final image
final_image.paste(image_with_axes, (0, 0))

# Paste the resized rotated plot onto the final image
final_image.paste(resized_rotated_plot, (x_offset, y_offset), resized_rotated_plot)
width, height = final_image.size
cropped_image = final_image.crop((0, height // 2.6, width, height))

# Save or show the final image
cropped_image.save(f"{prefix}_TRI.png")
cropped_image.save(f"{prefix}_TRI.pdf", "PDF", resolution=100.0)

# Remove temp files
if os.path.exists(filename):
os.remove(filename)


def get_colors(sdf, ncolors, is_freq, custom_breakpoints):
assert ncolors > 2 and ncolors < 12
bot = math.floor(min(sdf["perID_by_events"]))
Expand Down Expand Up @@ -316,7 +357,89 @@ def make_tri(sdf, title_name, palette, palette_orientation, colors, breaks, xlim
+ scale_x_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks)
+ scale_y_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks)
+ coord_fixed(ratio=1)
+ facet_grid("r ~ q")
+ labs(x="", y="", title=title_name)
)

# Adjust x-axis label size
p += theme(axis_title_x=element_text())

return p


def make_tri2(sdf, title_name, palette, palette_orientation, colors, breaks, xlim):
if not breaks:
breaks = True
else:
breaks = [float(number) for number in breaks]
if not xlim:
xlim = 0
hexcodes = []
new_hexcodes = []
if palette in DIVERGING_PALETTES:
function_name = getattr(diverging, palette)
hexcodes = function_name.hex_colors
if palette_orientation == "+":
palette_orientation = "-"
else:
palette_orientation = "+"
elif palette in QUALITATIVE_PALETTES:
function_name = getattr(qualitative, palette)
hexcodes = function_name.hex_colors
elif palette in SEQUENTIAL_PALETTES:
function_name = getattr(sequential, palette)
hexcodes = function_name.hex_colors
else:
function_name = getattr(sequential, "Spectral_11")
palette_orientation = "-"
hexcodes = function_name.hex_colors

if palette_orientation == "-":
new_hexcodes = hexcodes[::-1]
else:
new_hexcodes = hexcodes
if colors:
new_hexcodes = colors
max_val = max(sdf["q_en"].max(), sdf["r_en"].max(), xlim)
window = max(sdf["q_en"] - sdf["q_st"])
if max_val < 100000:
x_label = "Genomic Position (Kbp)"
elif max_val < 100000000:
x_label = "Genomic Position (Mbp)"
else:
x_label = "Genomic Position (Gbp)"
p = (
ggplot(sdf)
+ geom_tile(
aes(x="q_st", y="r_st", fill="discrete", height=window, width=window),
alpha=0,
)
+ scale_color_discrete(guide=False)
+ scale_fill_manual(
values=new_hexcodes,
guide=False,
)
+ theme(
legend_position="none",
panel_grid_major=element_blank(),
panel_grid_minor=element_blank(),
plot_background=element_blank(),
panel_background=element_blank(),
axis_line=element_line(color="black"), # Adjust axis line size
axis_text=element_text(
family=["DejaVu Sans"]
), # Change axis text font and size
axis_ticks_major=element_line(),
axis_line_x=element_line(), # Keep the x-axis line
axis_line_y=element_blank(), # Remove the y-axis line
axis_ticks_major_x=element_line(), # Keep x-axis ticks
axis_ticks_major_y=element_blank(), # Remove y-axis ticks
axis_text_x=element_line(), # Keep x-axis text
axis_text_y=element_blank(),
plot_title=element_blank(),
)
+ scale_x_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks)
+ scale_y_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks)
+ coord_fixed(ratio=1)
+ labs(x="", y="", title=title_name)
)

Expand Down Expand Up @@ -376,6 +499,28 @@ def make_hist(sdf, palette, palette_orientation, custom_colors, custom_breakpoin
return p


def create_grid(
singles,
doubles,
directory,
name_x,
name_y,
palette,
palette_orientation,
no_hist,
width,
dpi,
is_freq,
xlim,
custom_colors,
custom_breakpoints,
from_file,
is_pairwise,
axes_label,
):
print(singles)


def create_plots(
sdf,
directory,
Expand All @@ -394,7 +539,6 @@ def create_plots(
is_pairwise,
axes_labels,
):
# TODO: Implement xlim
df = read_df(
sdf,
palette,
Expand Down Expand Up @@ -477,6 +621,15 @@ def create_plots(
axes_labels,
xlim,
)
tri_plot_axis_only = make_tri2(
sdf,
plot_filename,
palette,
palette_orientation,
custom_colors,
axes_labels,
xlim,
)
full_plot = make_dot(
check_st_en_equality(sdf),
plot_filename,
Expand All @@ -486,27 +639,51 @@ def create_plots(
axes_labels,
xlim,
)

print(f"Creating plots and saving to {plot_filename}...\n")
triplot_no_axis = tri_plot + theme(
axis_text_x=element_blank(),
axis_text_y=element_blank(),
axis_title_x=element_blank(),
axis_title_y=element_blank(),
axis_line_x=element_blank(),
axis_line_y=element_blank(),
axis_ticks_major=element_blank(),
axis_ticks_minor=element_blank(),
panel_background=element_blank(),
panel_grid_major=element_blank(),
panel_grid_minor=element_blank(),
plot_title=element_blank(),
)
ggsave(
tri_plot,
triplot_no_axis,
width=9,
height=9,
dpi=dpi,
format="pdf",
filename=f"{plot_filename}_TRI.pdf",
dpi=600,
format="png",
filename=f"{plot_filename}_TRI_NOAXIS.png",
verbose=False,
)
ggsave(
tri_plot,
tri_plot_axis_only,
width=9,
height=9,
dpi=dpi,
dpi=600,
format="png",
filename=f"{plot_filename}_TRI.png",
filename=f"{plot_filename}_AXIS.png",
verbose=False,
)

png_no_axes = Image.open(f"{plot_filename}_TRI_NOAXIS.png")
rotated_png = png_no_axes.rotate(315, expand=True)

rotated_png.save(f"{plot_filename}_ROTATED_TRI_NOAXIS.png")
overlap_axis(rotated_png, f"{plot_filename}_AXIS.png", plot_filename)

if os.path.exists(f"{plot_filename}_ROTATED_TRI_NOAXIS.png"):
os.remove(f"{plot_filename}_ROTATED_TRI_NOAXIS.png")
if os.path.exists(f"{plot_filename}_TRI_NOAXIS.png"):
os.remove(f"{plot_filename}_TRI_NOAXIS.png")

ggsave(
full_plot,
width=9,
Expand All @@ -525,10 +702,9 @@ def create_plots(
filename=f"{plot_filename}_FULL.png",
verbose=False,
)

if no_hist:
print(
f"{plot_filename}_TRI.png, {plot_filename}_TRI.pdf, {plot_filename}_FULL.png and {plot_filename}_FULL.png saved sucessfully. \n"
f"{plot_filename}_TRI.png, {plot_filename}_TRI.pdf, {plot_filename}_FULL.png and {plot_filename}_FULL.pdf saved sucessfully. \n"
)
else:
ggsave(
Expand All @@ -550,5 +726,5 @@ def create_plots(
verbose=False,
)
print(
f"{plot_filename}_TRI.png, {plot_filename}_TRI.pdf, {plot_filename}_FULL.png, {plot_filename}_FULL.png, {plot_filename}_HIST.png and {plot_filename}_HIST.pdf, saved sucessfully. \n"
f"{plot_filename}_TRI.png, {plot_filename}_TRI.pdf, {plot_filename}_FULL.png, {plot_filename}_FULL.pdf, {plot_filename}_HIST.png and {plot_filename}_HIST.pdf, saved sucessfully. \n"
)

0 comments on commit 3d62edb

Please sign in to comment.