fixing edac
davidzhu27 committed May 8, 2024
1 parent 5e907da commit 80ea873
Showing 1 changed file with 68 additions and 6 deletions.
74 changes: 68 additions & 6 deletions algorithms/offline/edac.py
@@ -19,8 +19,25 @@
from torch.distributions import Normal
from tqdm import trange

from pbrl import scale_rewards, generate_pbrl_dataset, make_latent_reward_dataset, train_latent, predict_and_label_latent_reward
from pbrl import label_by_trajectory_reward, generate_pbrl_dataset_no_overlap, small_d4rl_dataset
from pbrl import label_by_trajectory_reward_multiple_bernoullis, label_by_original_rewards
from ipl_helper import save_preference_dataset
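# Note: pbrl and ipl_helper appear to be modules local to this repository
# (not PyPI packages); they provide the preference-based RL (PBRL) dataset
# helpers used in train() below.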

@dataclass
class TrainConfig:
# PBRL
num_t: int = 1000
len_t: int = 20
latent_reward: int = 0
bin_label: int = 0
bin_label_trajectory_batch: int = 0
bin_label_allow_overlap: int = 1
num_berno: int = 1
out_name: str = ""
quick_stop: int = 0
dataset_size_multiplier: float = 1.0
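# Inferred semantics of the PBRL fields above (not documented in this commit):
#   num_t: number of trajectory segments sampled for preference pairs
#   len_t: length of each segment
#   latent_reward / bin_label: select the reward-relabeling scheme in train()
#   bin_label_trajectory_batch: make ReplayBuffer.sample() draw whole segments
#   bin_label_allow_overlap: allow preference segments to share transitions
#   num_berno: number of Bernoulli preference samples drawn per pair
#   out_name: overrides the auto-generated run name when non-empty
#   quick_stop: stop right after the relabeled dataset is built
#   dataset_size_multiplier: scales the final dataset size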

# wandb params
project: str = "CORL"
group: str = "EDAC-D4RL"
@@ -55,6 +72,8 @@ class TrainConfig:

def __post_init__(self):
self.name = f"{self.name}-{self.env_name}-{str(uuid.uuid4())[:8]}"
if self.out_name:
self.name = self.out_name
if self.checkpoints_path is not None:
self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)

@@ -116,10 +135,16 @@ def __init__(
action_dim: int,
buffer_size: int,
device: str = "cpu",
bin_label_trajectory_batch: bool = False,
num_t: int = 1000,
len_t: int = 20,
):
self._buffer_size = buffer_size
self._pointer = 0
self._size = 0
self._bin_label_trajectory_batch = bin_label_trajectory_batch
self._num_t = num_t
self._len_t = len_t
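# num_t and len_t describe the PBRL segment layout of the loaded dataset;
# sample() uses them to draw whole segments when bin_label_trajectory_batch
# is enabled.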

self._states = torch.zeros(
(buffer_size, state_dim), dtype=torch.float32, device=device
@@ -157,7 +182,14 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
print(f"Dataset size: {n_transitions}")

def sample(self, batch_size: int) -> TensorBatch:
    if self._bin_label_trajectory_batch:
        # Draw whole segments instead of independent transitions. This assumes
        # the buffer holds num_t * 2 segments of length len_t stored
        # contiguously, so segment i occupies indices i * len_t .. (i + 1) * len_t - 1
        # (e.g. with len_t = 20, segment 3 covers indices 60..79).
        num_t = self._num_t
        len_t = self._len_t
        indices_of_traj = np.random.randint(0, num_t * 2, size=batch_size // len_t)
        indices = np.array(
            [np.arange(i * len_t, (i + 1) * len_t) for i in indices_of_traj]
        ).flatten()
    else:
        indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size)

states = self._states[indices]
actions = self._actions[indices]
rewards = self._rewards[indices]
@@ -554,12 +586,42 @@ def train(config: TrainConfig):
if config.normalize_reward:
modify_reward(d4rl_dataset, config.env_name)

########################################
buffer = ReplayBuffer(
    state_dim=state_dim,
    action_dim=action_dim,
    buffer_size=config.buffer_size,
    device=config.device,
    bin_label_trajectory_batch=config.bin_label_trajectory_batch,
    num_t=config.num_t,
    len_t=config.len_t,
)
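# Caveat (inferred from ReplayBuffer.sample() above): with
# bin_label_trajectory_batch enabled, batch_size should be a multiple of
# len_t, since each batch is assembled from batch_size // len_t whole segments.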
dataset = d4rl_dataset.copy()
num_t = config.num_t
len_t = config.len_t
num_trials = config.num_berno
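# Three relabeling modes follow (semantics inferred from the pbrl helper names):
#   latent_reward: fit a latent reward model on preference pairs, then relabel
#     every transition with its predictions;
#   bin_label: label transitions via binary trajectory-level preferences, with
#     segments either overlapping or disjoint;
#   otherwise: label with the original environment rewards as a baseline.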
if config.latent_reward:
dataset = scale_rewards(dataset)
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'CORL/saved/pbrl_datasets/pbrl_dataset_{config.env_name}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
latent_reward_model, indices = train_latent(dataset, pbrl_dataset, num_berno=num_trials, num_t=num_t, len_t=len_t)
dataset = predict_and_label_latent_reward(dataset, latent_reward_model, indices)
elif config.bin_label:
dataset = scale_rewards(dataset)
if config.bin_label_allow_overlap:
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'CORL/saved/pbrl_datasets/pbrl_dataset_{config.env_name}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
else:
pbrl_dataset = generate_pbrl_dataset_no_overlap(dataset, pbrl_dataset_file_path=f'CORL/saved/pbrl_datasets_no_ovlp/pbrl_dataset_{config.env_name}_{num_t}_{len_t}_numTrials={num_trials}', num_t=num_t, len_t=len_t)
dataset = label_by_trajectory_reward(dataset, pbrl_dataset, num_t=num_t, len_t=len_t, num_trials=num_trials)
else:
pbrl_dataset = generate_pbrl_dataset(dataset, pbrl_dataset_file_path=f'CORL/saved/pbrl_datasets/pbrl_dataset_{config.env_name}_{num_t}_{len_t}_numTrials={num_trials}.npz', num_t=num_t, len_t=len_t)
dataset = label_by_original_rewards(dataset, pbrl_dataset, num_t)
dataset = small_d4rl_dataset(dataset, dataset_size_multiplier=config.dataset_size_multiplier)
print(f'Dataset size: {(dataset["observations"]).shape[0]}')
if config.quick_stop:
return
d4rl_dataset = dataset.copy()
########################################

buffer.load_d4rl_dataset(d4rl_dataset)

# Actor & Critic setup
@@ -635,4 +697,4 @@ def train(config: TrainConfig):


if __name__ == "__main__":
    train()
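# Example invocation (assuming the pyrallis CLI wrapper that CORL's offline
# algorithms use; the flag values shown are illustrative, not prescribed):
#   python algorithms/offline/edac.py --env_name walker2d-medium-v2 \
#       --bin_label 1 --bin_label_allow_overlap 0 --num_t 1000 --len_t 20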
