Skip to content

Commit

Permalink
Improved 3D visualizations - annotations and legend
Browse files Browse the repository at this point in the history
  • Loading branch information
mcgalcode committed Jul 30, 2024
1 parent 3f2ccad commit aa6fd08
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/pylattica/core/runner/asynchronous_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,4 @@ def _add_sites_to_queue():
break

result.set_output(live_state)
return result
return result
16 changes: 10 additions & 6 deletions src/pylattica/core/simulation_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@ def from_file(cls, fpath):
def from_dict(cls, res_dict):
diffs = res_dict["diffs"]
compress_freq = res_dict.get("compress_freq", 1)
res = cls(SimulationState.from_dict(res_dict["initial_state"]), compress_freq=compress_freq)
res = cls(
SimulationState.from_dict(res_dict["initial_state"]),
compress_freq=compress_freq,
)
for diff in diffs:
if SITES in diff:
diff[SITES] = { int(k): v for k, v in diff[SITES].items() }
diff[SITES] = {int(k): v for k, v in diff[SITES].items()}
if GENERAL not in diff and SITES not in diff:
diff = { int(k): v for k, v in diff.items() }
diff = {int(k): v for k, v in diff.items()}
res.add_step(diff)

return res
Expand Down Expand Up @@ -176,8 +179,10 @@ def compress_result(result: SimulationResult, num_steps: int):
# total steps is the actual number of diffs stored, not the number of original simulation steps taken
total_steps = len(result)
if num_steps >= total_steps:
raise ValueError(f"Cannot upsample SimulationResult of length {total_steps} to size {num_steps}.")

raise ValueError(
f"Cannot upsample SimulationResult of length {total_steps} to size {num_steps}."
)

exact_sample_freq = total_steps / (num_steps)
# print(total_steps, current_sample_freq)
total_compress_freq = exact_sample_freq * result.compress_freq
Expand All @@ -196,4 +201,3 @@ def compress_result(result: SimulationResult, num_steps: int):
compressed_result.add_step(live_state.as_state_update())
next_sample_step += exact_sample_freq
return compressed_result

4 changes: 2 additions & 2 deletions src/pylattica/core/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_site_state(self, site_id: int) -> Dict:
"""
return self._state[SITES].get(site_id)

def get_general_state(self, key: str = None, default = None) -> Dict:
def get_general_state(self, key: str = None, default=None) -> Dict:
"""Returns the general state.
Returns
Expand Down Expand Up @@ -177,7 +177,7 @@ def copy(self) -> SimulationState:
The copy of this SimulationState
"""
return SimulationState(self._state)

def as_state_update(self) -> Dict:
return copy.deepcopy(self._state)

Expand Down
29 changes: 19 additions & 10 deletions src/pylattica/visualization/result_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,24 @@

from PIL import Image

from typing import Callable

_dsr_globals = {}


def default_annotation_builder(step, step_no):
return f"Step {step_no}"


class ResultArtist:
"""A class for rendering simulation results as animated GIFs."""

def __init__(self, step_artist: StructureArtist, result: SimulationResult):
def __init__(
self,
step_artist: StructureArtist,
result: SimulationResult,
annotation_builder: Callable = default_annotation_builder,
):
"""Instantiates the ResultArtist class.
Parameters
Expand All @@ -26,6 +37,7 @@ def __init__(self, step_artist: StructureArtist, result: SimulationResult):
"""
self._step_artist = step_artist
self.result = result
self.annotation_builder = annotation_builder

def _get_images(self, **kwargs):
draw_freq = kwargs.get("draw_freq", 1)
Expand All @@ -47,21 +59,18 @@ def _get_images(self, **kwargs):
with mp.get_context("fork").Pool(PROCESSES) as pool:
params = []
for idx in indices:
label = f"Step {idx}"
step_kwargs = {**kwargs, "label": label}
step = self.result.get_step(idx)
label = self.annotation_builder(step, idx)
step_kwargs = {**kwargs, "label": label}

params.append([step, step_kwargs])

for img in pool.starmap(_get_img_parallel, params):
imgs.append(img)

return imgs

def jupyter_show_step(
self,
step_no: int,
cell_size=20,
) -> None:
def jupyter_show_step(self, step_no: int, cell_size=20, **kwargs) -> None:
"""In a jupyter notebook environment, visualizes the step as a color coded phase grid.
Parameters
Expand All @@ -71,10 +80,10 @@ def jupyter_show_step(
cell_size : int, optional
The size of each simulation cell, in pixels, by default 20
"""
label = f"Step {step_no}" # pragma: no cover
step = self.result.get_step(step_no) # pragma: no cover
label = self.annotation_builder(step, step_no)
self._step_artist.jupyter_show(
step, label=label, cell_size=cell_size
step, label=label, cell_size=cell_size, **kwargs
) # pragma: no cover

def jupyter_play(self, cell_size: int = 20, wait: int = 1, **kwargs):
Expand Down
44 changes: 42 additions & 2 deletions src/pylattica/visualization/square_grid_artist_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.lines import Line2D


class SquareGridArtist3D(StructureArtist):
Expand Down Expand Up @@ -46,11 +47,50 @@ def _draw_image(self, state: SimulationState, **kwargs):
colors = [0.8, 0.8, 0.8, 0.2]
ax.voxels(data, facecolors=colors, edgecolor="k", linewidth=0)
else:
colors = np.array(color_cache[color]) / 255
colors = list(np.array(color_cache[color]) / 255)
ax.voxels(data, facecolors=colors, edgecolor="k", linewidth=0.25)

ax.legend()
if kwargs.get("show_legend") == True:
legend = self.cell_artist.get_legend(state)
legend_handles = []
for phase, color in legend.items():
legend_handles.append(
Line2D(
[0],
[0],
marker="s",
color="w",
markerfacecolor=list(np.array(color) / 255),
markersize=10,
label=phase,
)
)

# Add custom legend to the plot
legend_font_props = {"family": "Lato", "size": 14}

plt.legend(
handles=legend_handles,
loc="lower center",
prop=legend_font_props,
ncols=5,
frameon=False,
)
plt.axis("off")
if kwargs.get("label") is not None:
x_text, y_text, z_text = 18, -5, 30

# Add the text
annotation_font = {
"size": 16,
"family": "Lato",
"color": np.array([194, 29, 63]) / 255,
"weight": "bold",
}
ax.text(
x_text, y_text, z_text, kwargs.get("label"), fontdict=annotation_font
)

fig = ax.get_figure()
buf = io.BytesIO()
fig.savefig(buf)
Expand Down

0 comments on commit aa6fd08

Please sign in to comment.