From 5f7834fee82902950cc27be8962bb32f18821f9c Mon Sep 17 00:00:00 2001 From: HesNobi <36954719+HesNobi@users.noreply.github.com> Date: Fri, 25 Jun 2021 12:15:02 +0800 Subject: [PATCH] Update __init__.py The 'sequence_dataset()' index error is fixed, and optional step and episode limits have been added to the generator. --- d4rl/__init__.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/d4rl/__init__.py b/d4rl/__init__.py index 6ad9b470..1cd8c2b3 100644 --- a/d4rl/__init__.py +++ b/d4rl/__init__.py @@ -133,8 +133,7 @@ def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs): 'terminals': np.array(done_), } - -def sequence_dataset(env, dataset=None, **kwargs): +def sequence_dataset(env, dataset=None, max_steps = None, max_episodes=None, **kwargs): """ Returns an iterator through trajectories. @@ -151,10 +150,21 @@ def sequence_dataset(env, dataset=None, **kwargs): rewards terminals """ + # TODO: Some serious performance issues. + # TODO: Randomize the episode selection without extracting all of them. + # TODO: Adding discounted reward returns. 
+ if dataset is None: dataset = env.get_dataset(**kwargs) - N = dataset['rewards'].shape[0] + total_steps = dataset['rewards'].shape[0] + if max_steps is None: + max_steps = total_steps + + assert max_steps <= dataset['rewards'].shape[0],\ + "\"max_steps ={} \" should be smaller (or equal) than total number of steps = {}.".format( + max_steps, total_steps) + data_ = collections.defaultdict(list) # The newer version of the dataset adds an explicit @@ -163,15 +173,21 @@ def sequence_dataset(env, dataset=None, **kwargs): if 'timeouts' in dataset: use_timeouts = True + key_list = [] + for k_index in dataset: + if isinstance(dataset[k_index], np.ndarray) \ + and dataset[k_index].shape[0] == total_steps: + key_list.append(k_index) + episode_step = 0 - for i in range(N): + for i in range(max_steps): done_bool = bool(dataset['terminals'][i]) if use_timeouts: final_timestep = dataset['timeouts'][i] else: final_timestep = (episode_step == env._max_episode_steps - 1) - for k in dataset: + for k in key_list: data_[k].append(dataset[k][i]) if done_bool or final_timestep: @@ -181,6 +197,13 @@ def sequence_dataset(env, dataset=None, **kwargs): episode_data[k] = np.array(data_[k]) yield episode_data data_ = collections.defaultdict(list) + if max_episodes: + max_episodes -= 1 + if max_episodes < 1: + break episode_step += 1 + if max_episodes is not None and max_episodes > 0: + import warnings + warnings.warn("[WARNING] Not enough steps in the dataset to generate the requested number of episodes")