Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding step dependent dynamics #1

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/pytorch_icem/icem.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import colorednoise
from arm_pytorch_utilities import handle_batch_input
from tqdm import tqdm

import logging

Expand All @@ -26,6 +27,7 @@ def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=10
warmup_iters=100, online_iters=100,
includes_x0=False,
fixed_H=True,
step_dependent_dynamics=False,
device="cpu"):

self.dynamics = dynamics
Expand All @@ -47,6 +49,7 @@ def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=10
self.sigma = sigma
self.dtype = self.sigma.dtype

self.step_dependency = step_dependent_dynamics
self.warmup_iters = warmup_iters
self.online_iters = online_iters
self.includes_x0 = includes_x0
Expand Down Expand Up @@ -104,8 +107,8 @@ def _cost(self, x, u):
# return self.problem.objective(xu)

@handle_batch_input(n=2)
def _dynamics(self, x, u, t):
    """Evaluate the dynamics model for a single step.

    The diff artifact here interleaved the old two-argument definition with
    the new three-argument one; this is the reconstructed post-PR method.

    :param x: state batch (leading batch dimension handled by the
        handle_batch_input decorator, n=2).
    :param u: control batch matching x's batch dimension.
    :param t: integer rollout step index; only forwarded to the model when
        step-dependent dynamics were requested at construction time.
    :return: next-state batch as returned by self.dynamics.
    """
    # Backward compatible: models that do not care about the step keep
    # their original (x, u) call signature.
    if self.step_dependency:
        return self.dynamics(x, u, t)
    return self.dynamics(x, u)

def _rollout_dynamics(self, x0, u):
N, H, du = u.shape
Expand All @@ -114,13 +117,13 @@ def _rollout_dynamics(self, x0, u):

x = [x0.reshape(1, self.nx).repeat(N, 1)]
for t in range(self.H):
x.append(self._dynamics(x[-1], u[:, t]))
x.append(self._dynamics(x[-1], u[:, t], t))

if self.includes_x0:
return torch.stack(x[:-1], dim=1)
return torch.stack(x[1:], dim=1)

def command(self, state, shift_nominal_trajectory=True, return_full_trajectories=False, **kwargs):
def command(self, state, shift_nominal_trajectory=True, return_full_trajectories=False, progress_bar=False, **kwargs):
if not torch.is_tensor(state):
state = torch.tensor(state, device=self.device, dtype=self.dtype)
x = state
Expand All @@ -142,7 +145,12 @@ def command(self, state, shift_nominal_trajectory=True, return_full_trajectories

# Shift the keep elites

for i in range(iterations):
if progress_bar:
its = tqdm(range(iterations))
else:
its = range(iterations)

for i in its:
if self.kept_elites is None:
# Sample actions
U = self.sample_action_sequences(x, self.N)
Expand Down