diff --git a/algorithms/offline/edac.py b/algorithms/offline/edac.py
index 413801c..a8f347a 100644
--- a/algorithms/offline/edac.py
+++ b/algorithms/offline/edac.py
@@ -19,8 +19,25 @@
 from torch.distributions import Normal
 from tqdm import trange
 
+from pbrl import scale_rewards, generate_pbrl_dataset, make_latent_reward_dataset, train_latent, predict_and_label_latent_reward
+from pbrl import label_by_trajectory_reward, generate_pbrl_dataset_no_overlap, small_d4rl_dataset
+from pbrl import label_by_trajectory_reward_multiple_bernoullis, label_by_original_rewards
+from ipl_helper import save_preference_dataset
+
 @dataclass
 class TrainConfig:
+    # PBRL params
+    num_t: int = 1000  # number of preference-labeled trajectory pairs
+    len_t: int = 20  # transitions per trajectory snippet
+    latent_reward: int = 0  # 1: learn a latent reward model from preferences
+    bin_label: int = 0  # 1: label rewards by binary trajectory comparisons
+    bin_label_trajectory_batch: int = 0  # 1: sample whole trajectories per batch
+    bin_label_allow_overlap: int = 1  # 1: trajectory pairs may share transitions
+    num_berno: int = 1  # Bernoulli draws per preference label
+    out_name: str = ""  # if set, overrides the auto-generated run name
+    quick_stop: int = 0  # 1: exit right after dataset generation
+    dataset_size_multiplier: float = 1.0  # scale factor for the final dataset size
+
     # wandb params
     project: str = "CORL"
     group: str = "EDAC-D4RL"
@@ -55,6 +72,8 @@ class TrainConfig:
 
     def __post_init__(self):
         self.name = f"{self.name}-{self.env_name}-{str(uuid.uuid4())[:8]}"
+        if self.out_name:
+            self.name = self.out_name
         if self.checkpoints_path is not None:
             self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
 
@@ -116,10 +135,16 @@ def __init__(
         action_dim: int,
         buffer_size: int,
         device: str = "cpu",
+        bin_label_trajectory_batch: bool = False,
+        num_t: int = 1000,
+        len_t: int = 20,
     ):
         self._buffer_size = buffer_size
         self._pointer = 0
         self._size = 0
+        self._bin_label_trajectory_batch = bin_label_trajectory_batch  # sample whole trajectories instead of i.i.d. transitions
+        self._num_t = num_t
+        self._len_t = len_t
         self._states = torch.zeros(
             (buffer_size, state_dim), dtype=torch.float32, device=device
         )
@@ -157,7 +182,14 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
         print(f"Dataset size: {n_transitions}")
 
     def sample(self, batch_size: int) -> TensorBatch:
-        indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size)
+        if self._bin_label_trajectory_batch:
+            num_t = self._num_t
+            len_t = self._len_t
+            indices_of_traj = np.random.randint(0, num_t * 2, size=batch_size // len_t)  # 2 * num_t trajectories stored contiguously
+            indices = np.array([np.arange(i * len_t, (i + 1) * len_t) for i in indices_of_traj]).flatten()  # expand to per-step indices
+        else:
+            indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size)
+
         states = self._states[indices]
         actions = self._actions[indices]
         rewards = self._rewards[indices]
@@ -554,12 +586,44 @@ def train(config: TrainConfig):
     if config.normalize_reward:
         modify_reward(d4rl_dataset, config.env_name)
 
     buffer = ReplayBuffer(
         state_dim=state_dim,
         action_dim=action_dim,
         buffer_size=config.buffer_size,
         device=config.device,
+        bin_label_trajectory_batch=config.bin_label_trajectory_batch,
+        num_t=config.num_t,
+        len_t=config.len_t,
     )
+    # PBRL: build preference pairs and relabel rewards before filling the buffer
+    dataset = d4rl_dataset.copy()
+    num_t = config.num_t
+    len_t = config.len_t
+    num_trials = config.num_berno
+    if config.latent_reward:
+        # Learn a latent reward model from the preference pairs, then relabel transitions.
+        dataset = scale_rewards(dataset)
+        pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'CORL/saved/pbrl_datasets/pbrl_dataset_{config.env_name}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
+        latent_reward_model, indices = train_latent(dataset, pbrl_dataset, num_berno=num_trials, num_t=num_t, len_t=len_t)
+        dataset = predict_and_label_latent_reward(dataset, latent_reward_model, indices)
+    elif config.bin_label:
+        # Label each trajectory's transitions with its binary comparison outcome.
+        dataset = scale_rewards(dataset)
+        if config.bin_label_allow_overlap:
+            pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'CORL/saved/pbrl_datasets/pbrl_dataset_{config.env_name}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
+        else:
+            pbrl_dataset = generate_pbrl_dataset_no_overlap(dataset, pbrl_dataset_file_path=f'CORL/saved/pbrl_datasets_no_ovlp/pbrl_dataset_{config.env_name}_{num_t}_{len_t}_numTrials={num_trials}', num_t=num_t, len_t=len_t)
+        dataset = label_by_trajectory_reward(dataset, pbrl_dataset, num_t=num_t, len_t=len_t, num_trials=num_trials)
+    else:
+        # Baseline: keep the original rewards for the sampled trajectory pairs.
+        pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'CORL/saved/pbrl_datasets/pbrl_dataset_{config.env_name}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
+        dataset = label_by_original_rewards(dataset, pbrl_dataset, num_t)
+    dataset = small_d4rl_dataset(dataset, dataset_size_multiplier=config.dataset_size_multiplier)
+    print(f'Dataset size: {dataset["observations"].shape[0]}')
+    if config.quick_stop:  # stop after dataset generation, before any training
+        return
+    d4rl_dataset = dataset.copy()
+
     buffer.load_d4rl_dataset(d4rl_dataset)
 
     # Actor & Critic setup
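Note on the trajectory-batch sampling path added to ReplayBuffer.sample: when
bin_label_trajectory_batch is set, the buffer is assumed to hold the 2 * num_t
preference trajectories contiguously, len_t transitions each, and a batch is
drawn as whole trajectories rather than as i.i.d. transitions. The sketch below
is a standalone illustration of that index construction, not part of the patch;
the concrete num_t, len_t, and batch_size values are assumptions.

    import numpy as np

    num_t, len_t = 1000, 20  # 1000 preference pairs -> 2 * num_t = 2000 stored trajectories
    batch_size = 256         # effective batch rounds down to (batch_size // len_t) * len_t = 240

    # Pick random trajectories, then expand each into its contiguous transition indices.
    indices_of_traj = np.random.randint(0, num_t * 2, size=batch_size // len_t)
    indices = np.array([np.arange(i * len_t, (i + 1) * len_t) for i in indices_of_traj]).flatten()

    assert indices.shape[0] == (batch_size // len_t) * len_t
    assert indices.max() < num_t * 2 * len_t  # stays within the 2 * num_t * len_t stored transitions

One consequence worth noting: because the trajectory count is batch_size // len_t,
any remainder of batch_size modulo len_t is silently dropped, so batch_size should
be chosen as a multiple of len_t.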