
Commit

Replace NumPy with Torch in examples/fabric/ (Lightning-AI#17279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ryan597 and pre-commit-ci[bot] authored Apr 6, 2023
1 parent ef7da5c commit fb775e0
Showing 7 changed files with 39 additions and 41 deletions.
8 changes: 3 additions & 5 deletions examples/fabric/meta_learning/train_fabric.py
@@ -16,7 +16,6 @@
 """
 import cherry
 import learn2learn as l2l
-import numpy as np
 import torch
 
 from lightning.fabric import Fabric, seed_everything
@@ -31,10 +30,9 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways):
     data, labels = batch
 
     # Separate data into adaptation/evalutation sets
-    adaptation_indices = np.zeros(data.size(0), dtype=bool)
-    adaptation_indices[np.arange(shots * ways) * 2] = True
-    evaluation_indices = torch.from_numpy(~adaptation_indices)
-    adaptation_indices = torch.from_numpy(adaptation_indices)
+    adaptation_indices = torch.zeros(data.size(0), dtype=bool)
+    adaptation_indices[torch.arange(shots * ways) * 2] = True
+    evaluation_indices = ~adaptation_indices
     adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
     evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
 
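The new mask construction above is a drop-in replacement for the NumPy version it removes. A minimal sketch checking the equivalence; the shots, ways, and batch-size values are made up for illustration and do not come from the example:

# Verify the old NumPy mask and the new Torch mask select the same indices.
import numpy as np
import torch

shots, ways, batch_size = 5, 2, 20  # hypothetical values

np_mask = np.zeros(batch_size, dtype=bool)
np_mask[np.arange(shots * ways) * 2] = True

torch_mask = torch.zeros(batch_size, dtype=torch.bool)
torch_mask[torch.arange(shots * ways) * 2] = True

assert torch.equal(torch.from_numpy(np_mask), torch_mask)
assert torch.equal(torch.from_numpy(~np_mask), ~torch_mask)

The same substitution appears in train_torch.py below.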
9 changes: 3 additions & 6 deletions examples/fabric/meta_learning/train_torch.py
@@ -20,7 +20,6 @@
 
 import cherry
 import learn2learn as l2l
-import numpy as np
 import torch
 import torch.distributed as dist
 
@@ -35,10 +34,9 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
     data, labels = data.to(device), labels.to(device)
 
     # Separate data into adaptation/evalutation sets
-    adaptation_indices = np.zeros(data.size(0), dtype=bool)
-    adaptation_indices[np.arange(shots * ways) * 2] = True
-    evaluation_indices = torch.from_numpy(~adaptation_indices)
-    adaptation_indices = torch.from_numpy(adaptation_indices)
+    adaptation_indices = torch.zeros(data.size(0), dtype=bool)
+    adaptation_indices[torch.arange(shots * ways) * 2] = True
+    evaluation_indices = ~adaptation_indices
     adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
     evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
 
@@ -76,7 +74,6 @@ def main(
     seed = seed + rank
 
     random.seed(seed)
-    np.random.seed(seed)
     torch.manual_seed(seed)
     device = torch.device("cpu")
     if cuda and torch.cuda.device_count():
22 changes: 13 additions & 9 deletions examples/fabric/reinforcement_learning/rl/agent.py
@@ -1,7 +1,7 @@
+import math
 from typing import Dict, Tuple
 
 import gymnasium as gym
-import numpy as np
 import torch
 import torch.nn.functional as F
 from rl.loss import entropy_loss, policy_loss, value_loss
@@ -24,7 +24,8 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_
             raise ValueError("Unrecognized activation function: `act_fun` must be either `relu` or `tanh`")
         self.critic = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -33,7 +34,8 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_
         )
         self.actor = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -81,10 +83,10 @@ def estimate_returns_and_advantages(
         lastgaelam = 0
         for t in reversed(range(num_steps)):
             if t == num_steps - 1:
-                nextnonterminal = 1.0 - next_done
+                nextnonterminal = torch.logical_not(next_done)
                 nextvalues = next_value
             else:
-                nextnonterminal = 1.0 - dones[t + 1]
+                nextnonterminal = torch.logical_not(dones[t + 1])
                 nextvalues = values[t + 1]
             delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
             advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
@@ -119,7 +121,8 @@ def __init__(
         self.normalize_advantages = normalize_advantages
         self.critic = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -128,7 +131,8 @@ def __init__(
         )
         self.actor = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -179,10 +183,10 @@ def estimate_returns_and_advantages(
         lastgaelam = 0
         for t in reversed(range(num_steps)):
             if t == num_steps - 1:
-                nextnonterminal = 1.0 - next_done
+                nextnonterminal = torch.logical_not(next_done)
                 nextvalues = next_value
             else:
-                nextnonterminal = 1.0 - dones[t + 1]
+                nextnonterminal = torch.logical_not(dones[t + 1])
                 nextvalues = values[t + 1]
             delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
             advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
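The agent changes make two swaps: math.prod(shape) replaces np.array(shape).prod() when sizing the first Linear layer, and torch.logical_not(done) replaces 1.0 - done in the GAE bootstrap factor. A minimal sketch checking both equivalences; the observation shape and done flags are made-up values:

import math

import numpy as np
import torch

obs_shape = (4,)  # hypothetical observation shape
assert math.prod(obs_shape) == int(np.array(obs_shape).prod())

next_done = torch.tensor([0.0, 1.0, 0.0])  # hypothetical 0/1 done flags
delta = torch.tensor([0.5, 0.5, 0.5])
old_factor = 1.0 - next_done               # float 0/1 mask
new_factor = torch.logical_not(next_done)  # bool mask; promotes to float when multiplied
assert torch.allclose(delta * old_factor, delta * new_factor)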
16 changes: 10 additions & 6 deletions examples/fabric/reinforcement_learning/rl/utils.py
@@ -1,12 +1,11 @@
 import argparse
+import math
 import os
 from distutils.util import strtobool
 from typing import Optional, TYPE_CHECKING, Union
 
 import gymnasium as gym
-import numpy as np
 import torch
-from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 
 if TYPE_CHECKING:
@@ -119,7 +118,12 @@ def parse_args():
     return args
 
 
-def layer_init(layer: torch.nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0, ortho_init: bool = True):
+def layer_init(
+    layer: torch.nn.Module,
+    std: float = math.sqrt(2),
+    bias_const: float = 0.0,
+    ortho_init: bool = True,
+):
     if ortho_init:
         torch.nn.init.orthogonal_(layer.weight, std)
         torch.nn.init.constant_(layer.bias, bias_const)
@@ -157,16 +161,16 @@ def test(
     step = 0
     done = False
     cumulative_rew = 0
-    next_obs = Tensor(env.reset(seed=args.seed)[0]).to(device)
+    next_obs = torch.tensor(env.reset(seed=args.seed)[0], device=device)
     while not done:
         # Act greedly through the environment
         action = agent.get_greedy_action(next_obs)
 
         # Single environment step
         next_obs, reward, done, truncated, info = env.step(action.cpu().numpy())
-        done = np.logical_or(done, truncated)
+        done = done or truncated
         cumulative_rew += reward
-        next_obs = Tensor(next_obs).to(device)
+        next_obs = torch.tensor(next_obs, device=device)
         step += 1
     logger.add_scalar("Test/cumulative_reward", cumulative_rew, 0)
     env.close()
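In utils.py, layer_init now defaults std to math.sqrt(2) instead of np.sqrt(2), and test() builds tensors with torch.tensor(obs, device=device) rather than Tensor(obs).to(device), creating the tensor on the target device in one step. A minimal sketch of the tensor-creation difference, assuming a float32 NumPy observation and a CPU device:

import numpy as np
import torch

device = torch.device("cpu")
obs = np.zeros((4,), dtype=np.float32)  # hypothetical observation

old_style = torch.Tensor(obs).to(device)      # legacy constructor: default float dtype (float32)
new_style = torch.tensor(obs, device=device)  # dtype inferred from the array

assert old_style.dtype == new_style.dtype == torch.float32
assert torch.equal(old_style, new_style)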
7 changes: 3 additions & 4 deletions examples/fabric/reinforcement_learning/train_fabric.py
@@ -24,7 +24,6 @@
 from typing import Dict
 
 import gymnasium as gym
-import numpy as np
 import torch
 import torchmetrics
 from rl.agent import PPOLightningAgent
@@ -128,7 +127,7 @@ def main(args: argparse.Namespace):
     num_updates = args.total_timesteps // single_global_rollout
 
     # Get the first environment observation and start the optimization
-    next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
+    next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
     next_done = torch.zeros(args.num_envs, device=device)
     for update in range(1, num_updates + 1):
        # Learning rate annealing
@@ -150,9 +149,9 @@
 
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
-            done = np.logical_or(done, truncated)
+            done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
             rewards[step] = torch.tensor(reward, device=device).view(-1)
-            next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
+            next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:
                 for agent_final_info in info["final_info"]:
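The vectorized environment returns done and truncated as NumPy boolean arrays, so the new code wraps them in tensors before combining them with torch.logical_or. A minimal sketch checking that this matches the old np.logical_or result; the flags are made-up values for a hypothetical 4-env rollout:

import numpy as np
import torch

done = np.array([True, False, False, True])
truncated = np.array([False, False, True, True])

old = np.logical_or(done, truncated)                                  # old NumPy path
new = torch.logical_or(torch.tensor(done), torch.tensor(truncated))   # new Torch path

assert torch.equal(torch.from_numpy(old), new)

The same pattern is applied in train_fabric_decoupled.py and train_torch.py below.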
10 changes: 4 additions & 6 deletions examples/fabric/reinforcement_learning/train_fabric_decoupled.py
@@ -23,11 +23,9 @@
 from datetime import datetime
 
 import gymnasium as gym
-import numpy as np
 import torch
 from rl.agent import PPOLightningAgent
 from rl.utils import linear_annealing, make_env, parse_args, test
-from torch import Tensor
 from torch.utils.data import BatchSampler, DistributedSampler
 from torchmetrics import MeanMetric
 
@@ -108,7 +106,7 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
     world_collective.broadcast(update_t, src=0)
 
     # Get the first environment observation and start the optimization
-    next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
+    next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
     next_done = torch.zeros(args.num_envs).to(device)
     for update in range(1, num_updates + 1):
         for step in range(0, args.num_steps):
@@ -124,9 +122,9 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
 
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
-            done = np.logical_or(done, truncated)
-            rewards[step] = torch.tensor(reward).to(device).view(-1)
-            next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
+            done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
+            rewards[step] = torch.tensor(reward, device=device).view(-1)
+            next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:
                 for agent_final_info in info["final_info"]:
8 changes: 3 additions & 5 deletions examples/fabric/reinforcement_learning/train_torch.py
@@ -25,7 +25,6 @@
 from typing import Dict
 
 import gymnasium as gym
-import numpy as np
 import torch
 import torch.distributed as distributed
 import torch.nn as nn
@@ -118,7 +117,6 @@ def main(args: argparse.Namespace):
 
     # Seed everything
     random.seed(args.seed)
-    np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     torch.cuda.manual_seed_all(args.seed)
     torch.backends.cudnn.deterministic = args.torch_deterministic
@@ -181,7 +179,7 @@ def main(args: argparse.Namespace):
     num_updates = args.total_timesteps // single_global_step
 
     # Get the first environment observation and start the optimization
-    next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
+    next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
     next_done = torch.zeros(args.num_envs, device=device)
     for update in range(1, num_updates + 1):
         # Learning rate annealing
@@ -204,9 +202,9 @@
 
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
-            done = np.logical_or(done, truncated)
+            done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
             rewards[step] = torch.tensor(reward, device=device).view(-1)
-            next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
+            next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:
                 for agent_final_info in info["final_info"]:
