From bc53b37584b21553b98eda948b53345e22fb8c9b Mon Sep 17 00:00:00 2001
From: mr-mikmik
Date: Fri, 5 Apr 2024 10:48:38 -0400
Subject: [PATCH 1/2] adding option for step-dependent dynamics

---
 src/pytorch_icem/icem.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_icem/icem.py b/src/pytorch_icem/icem.py
index d18a69b..f6ee891 100644
--- a/src/pytorch_icem/icem.py
+++ b/src/pytorch_icem/icem.py
@@ -26,6 +26,7 @@ def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=10
                  warmup_iters=100, online_iters=100,
                  includes_x0=False,
                  fixed_H=True,
+                 step_dependent_dynamics=False,
                  device="cpu"):
 
         self.dynamics = dynamics
@@ -47,6 +48,7 @@ def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=10
             self.sigma = sigma
         self.dtype = self.sigma.dtype
 
+        self.step_dependency = step_dependent_dynamics
         self.warmup_iters = warmup_iters
         self.online_iters = online_iters
         self.includes_x0 = includes_x0
@@ -104,8 +106,8 @@ def _cost(self, x, u):
         # return self.problem.objective(xu)
 
     @handle_batch_input(n=2)
-    def _dynamics(self, x, u):
-        return self.dynamics(x, u)
+    def _dynamics(self, x, u, t):
+        return self.dynamics(x, u, t) if self.step_dependency else self.dynamics(x, u)
 
     def _rollout_dynamics(self, x0, u):
         N, H, du = u.shape
@@ -114,7 +116,7 @@ def _rollout_dynamics(self, x0, u):
 
         x = [x0.reshape(1, self.nx).repeat(N, 1)]
         for t in range(self.H):
-            x.append(self._dynamics(x[-1], u[:, t]))
+            x.append(self._dynamics(x[-1], u[:, t], t))
 
         if self.includes_x0:
             return torch.stack(x[:-1], dim=1)

From aa66be6bb1e5ce16d256677526009bf719d49d18 Mon Sep 17 00:00:00 2001
From: mr-mikmik
Date: Fri, 5 Apr 2024 16:17:24 -0400
Subject: [PATCH 2/2] adding option for displaying optimization progress as a
 progress bar using tqdm

---
 src/pytorch_icem/icem.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_icem/icem.py b/src/pytorch_icem/icem.py
index f6ee891..cb1474b 100644
--- a/src/pytorch_icem/icem.py
+++ b/src/pytorch_icem/icem.py
@@ -1,6 +1,7 @@
 import torch
 import colorednoise
 from arm_pytorch_utilities import handle_batch_input
+from tqdm import tqdm
 
 import logging
 
@@ -122,7 +123,7 @@ def _rollout_dynamics(self, x0, u):
             return torch.stack(x[:-1], dim=1)
         return torch.stack(x[1:], dim=1)
 
-    def command(self, state, shift_nominal_trajectory=True, return_full_trajectories=False, **kwargs):
+    def command(self, state, shift_nominal_trajectory=True, return_full_trajectories=False, progress_bar=False, **kwargs):
         if not torch.is_tensor(state):
             state = torch.tensor(state, device=self.device, dtype=self.dtype)
         x = state
@@ -144,7 +145,12 @@ def command(self, state, shift_nominal_trajectory=True, return_full_trajectories
 
         # Shift the keep elites
 
-        for i in range(iterations):
+        if progress_bar:
+            its = tqdm(range(iterations))
+        else:
+            its = range(iterations)
+
+        for i in its:
             if self.kept_elites is None:
                 # Sample actions
                 U = self.sample_action_sequences(x, self.N)
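
Usage sketch (not part of the patches): a minimal example of how the two new options might be used together. It assumes the planner class defined in src/pytorch_icem/icem.py is exposed as pytorch_icem.iCEM, that trajectory_cost receives whole rollouts of shape (N, H, nx) and (N, H, nu), and that the toy dynamics, cost, and numeric values below are purely illustrative.

    import torch
    from pytorch_icem import iCEM  # assumed export; adjust to the actual module path

    nx, nu = 2, 1  # state: [position, velocity]; control: acceleration

    def dynamics(x, u, t):
        # Step-dependent dynamics: _dynamics forwards the rollout step t only when
        # the planner is constructed with step_dependent_dynamics=True.
        dt = 0.05
        drift = 1.0 + 0.01 * t  # toy time-varying term to exercise the new argument
        return x + dt * drift * torch.cat([x[..., 1:], u], dim=-1)

    def trajectory_cost(x, u):
        # Quadratic state cost over the whole rollout plus a small control penalty.
        return (x ** 2).sum(dim=(-2, -1)) + 0.1 * (u ** 2).sum(dim=(-2, -1))

    planner = iCEM(dynamics, trajectory_cost, nx, nu,
                   sigma=torch.ones(nu),
                   step_dependent_dynamics=True)  # new option from PATCH 1/2

    x0 = torch.zeros(nx)
    action = planner.command(x0, progress_bar=True)  # new option from PATCH 2/2: tqdm over the optimization iterations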