Skip to content

Commit

Permalink
Merge pull request #14 from jobar8:jobar8/issue12
Browse files Browse the repository at this point in the history
Add Load button
  • Loading branch information
jobar8 authored Oct 17, 2024
2 parents 197286e + e32f0cd commit 91c2687
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 44 deletions.
65 changes: 30 additions & 35 deletions src/attractor_explorer/attractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import inspect
from collections.abc import Callable
from pathlib import Path
from typing import Any

Expand All @@ -14,8 +13,6 @@
from numba import jit
from numpy import cos, fabs, sin, sqrt
from param import concrete_descendents
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from attractor_explorer.maths import trajectory

Expand All @@ -27,18 +24,13 @@ class Attractor(param.Parameterized):

x = param.Number(0, softbounds=(-2, 2), step=0.01, doc='Starting x value', precedence=-1)
y = param.Number(0, softbounds=(-2, 2), step=0.01, doc='Starting y value', precedence=-1)
a = param.Number(1.7, bounds=(-3, 3), step=0.05, doc='Attractor parameter a', precedence=0.2)
b = param.Number(1.7, bounds=(-3, 3), step=0.05, doc='Attractor parameter b', precedence=0.2)
a = param.Number(1.7, softbounds=(-3, 3), step=0.05, doc='Attractor parameter a', precedence=0.2)
b = param.Number(1.7, softbounds=(-3, 3), step=0.05, doc='Attractor parameter b', precedence=0.2)

colormap: str = 'kgy'
equations: tuple[str, ...] = ()
__abstract = True

# This allows pydantic to support this class
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
return core_schema.no_info_after_validator_function(cls, handler(str)) # type: ignore

@staticmethod
@jit(cache=True)
def fn(x, y, a, b, *o):
Expand All @@ -58,8 +50,7 @@ def update(self, args: dict[str, Any]):
self.__setattr__(key, value)

def vals(self):
# return [self.__class__.name] + [self.colormap] + [getattr(self, p) for p in self.signature()]
return [self.__class__.__name__] + [self.colormap] + [getattr(self, p) for p in self.signature()]
return [self.__class__.name] + [self.colormap] + [getattr(self, p) for p in self.signature()]

def signature(self) -> list[str]:
"""Returns the calling signature expected by this attractor function"""
Expand Down Expand Up @@ -210,39 +201,44 @@ def fn(x, y, a, b, gamma, omega, lambda_, degree, *o): # noqa: ARG004
return p * x + gamma * zreal - omega * y, p * y - gamma * zimag + omega * x


class ParameterSets(BaseModel):
class ParameterSets(param.Parameterized):
"""
Allows selection from sets of pre-defined parameters saved in YAML.
"""

data_folder: Path = Path(__file__).parent / 'data'
input_examples_filename: str = 'attractors.yml'
output_examples_filename: str = 'saved_attractors.yml'
current: Callable = lambda: None
input_examples_filename = param.Filename('attractors.yml', search_paths=[data_folder.as_posix()])
output_examples_filename = param.Filename(
'saved_attractors.yml', check_exists=False, search_paths=[data_folder.as_posix()]
)
current = param.Callable(lambda: None, precedence=-1)
attractors: dict[str, Attractor]

example: list[str] = []
examples: list[list[str]] = []
attractors: dict[str, Attractor] = {}
load = param.Action(lambda x: x._load())
randomize = param.Action(lambda x: x._randomize())
sort = param.Action(lambda x: x._sort())
remember_this_one = param.Action(lambda x: x._remember())
# save = param.Action(lambda x: x._save(), precedence=0.8)
example = param.Selector(objects=[[]], precedence=1, instantiate=False)

def __init__(self, **params):
super().__init__(**params)

self._load()

self.attractors = {k: v(name=f'{k} parameters') for k, v in sorted(concrete_descendents(Attractor).items())}
# check parameters for each kind of attractors
for k in self.attractors:
# update attractor instances with the first example of each type
for attractor in self.attractors:
try:
self.get_attractor(k, *self.args(k)[0])
self.get_attractor(attractor, *self.args(attractor)[0])
except IndexError:
pass

def _load(self):
with Path(self.data_folder / self.input_examples_filename).open('r') as f:
with Path(self.input_examples_filename).open('r') as f: # type: ignore
vals = yaml.safe_load(f)
if len(vals) > 0:
# self.param.example.objects[:] = vals
self.examples = vals
self.param.example.objects[:] = vals
self.example = vals[0]

# def _save(self):
Expand All @@ -256,24 +252,23 @@ def __call__(self):
return self.example

def _randomize(self):
# RNG.shuffle(self.param.example.objects)
RNG.shuffle(self.examples)
RNG.shuffle(self.param.example.objects)
self.example = self.param.example.objects[0]

def _sort(self):
# self.param.example.objects[:] = sorted(self.param.example.objects)
self.examples = sorted(self.examples)
self.param.example.objects[:] = sorted(self.param.example.objects)
self.example = self.param.example.objects[0]

def _add_item(self, item):
self.examples += [item]
self.param.example.objects += [item]
self.example = item

# def _remember(self):
# vals = self.current().vals()
# self._add_item(vals)
def _remember(self):
vals = self.current().vals() # type: ignore
self._add_item(vals)

def args(self, name):
# return [v[1:] for v in self.param.example.objects if v[0] == name]
return [v[1:] for v in self.examples if v[0] == name]
return [v[1:] for v in self.param.example.objects if v[0] == name]

def get_attractor(self, name: str, *args) -> Attractor:
"""Factory function to return an Attractor object with the given name and arg values."""
Expand Down
48 changes: 41 additions & 7 deletions src/attractor_explorer/attractors_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
:root {
--background-color: black
--design-background-color: black;
--design-background-text-color: white;
--design-background-text-color: #ff7f00; # orange
--panel-surface-color: black;
}
#sidebar {
Expand All @@ -42,13 +42,13 @@
pn.config.throttled = True


params = at.ParameterSets(name='Attractors')


class AttractorsExplorer(pn.viewable.Viewer):
"""Select and render attractors."""

attractor_type = param.Selector(default=params.attractors['Clifford'], objects=params.attractors, precedence=0.5)
param_sets = at.ParameterSets(name='Attractors')
attractor_type = param.Selector(
default=param_sets.attractors['Clifford'], objects=param_sets.attractors, precedence=0.5
)

resolution = param.Selector(
doc='Resolution (n points)',
Expand Down Expand Up @@ -94,9 +94,17 @@ def equations(self):
def set_npoints(self):
self.n_points = RESOLUTIONS[self.resolution]

@pn.depends('param_sets.example', watch=True)
def update_attractor(self):
a = self.param_sets.get_attractor(*self.param_sets.example)
if a is not self.attractor_type:
self.param.update(attractor_type=a)
self.colormap.value = colormaps[self.param_sets.example[1]] # type: ignore


ats = AttractorsExplorer(name='Attractors Explorer')
params.current = lambda: ats.attractor_type
ats.param_sets.current = lambda: ats.attractor_type


pn.template.FastListTemplate(
title='Attractor Explorer',
Expand All @@ -109,6 +117,7 @@ def set_npoints(self):
'orientation': 'vertical',
'button_type': 'warning',
'button_style': 'outline',
'stylesheets': [':host(.outline) .bk-btn-group .bk-btn-warning.bk-active {color:white}'],
},
'resolution': {'widget_type': pn.widgets.RadioButtonGroup, 'button_type': 'success'},
},
Expand All @@ -117,8 +126,33 @@ def set_npoints(self):
),
ats.interpolation,
ats.colormap,
pn.layout.Card(
pn.Param(
ats.param_sets.param,
widgets={
'input_examples_filename': {
'widget_type': pn.widgets.TextInput,
'stylesheets': ['.bk-input-group > label {background-color: black}'],
'name': '',
},
'example': {
'stylesheets': ['bk-panel-models-widgets-CustomSelect {background: #2b3035; color: black}'],
'name': '',
},
},
show_name=False,
parameters=['input_examples_filename', 'load', 'example', 'remember_this_one'],
),
title='Load and Save',
collapsed=True,
header_color='white',
header_background='#2c71b4',
),
],
main=[
ats.equations,
ats,
],
main=[ats.equations, ats],
main_layout=None,
sidebar_width=SIDEBAR_WIDTH,
sidebar_footer=__version__,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_attractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_parametersets():
"""Test the ParameterSets class."""
params = at.ParameterSets(name='Attractors')
assert params.example == ['Svensson', 'fire', 0, 0, 1.4, 1.56, 1.4, -6.56]
assert len(params.examples) == 75
assert len(params.param.example.objects) == 75
assert len(params.args('Svensson')) == 4
assert params.attractors['GumowskiMira'].__class__.__name__ == 'GumowskiMira'
assert params.attractors['GumowskiMira'].__class__.name == 'GumowskiMira'
assert isinstance(params.get_attractor(*['FractalDream', 'kgy', 0.0, 0.0, 1.7, 1.7, 1.15, 2.34]), at.Attractor)

0 comments on commit 91c2687

Please sign in to comment.