Improving the Python typing and argument names.
Also, save the checkpoint every 1,000 iterations to save disk space.

PiperOrigin-RevId: 634833220
Change-Id: I1fb6399e8ed6472f53bd56b63a927b7d4b9cd6b0
esonghori authored and copybara-github committed May 17, 2024
1 parent 28e1fdf commit 26b4a84
Showing 4 changed files with 47 additions and 44 deletions.
78 changes: 40 additions & 38 deletions circuit_training/environment/environment.py
@@ -18,7 +18,7 @@
import math
import os
import time
-from typing import Any, Callable, Dict, Optional, Text, Tuple
+from typing import Any, Callable, Protocol

from absl import logging
from circuit_training.dreamplace import dreamplace_core
@@ -35,8 +35,8 @@
from tf_agents.environments import suite_gym
from tf_agents.environments import wrappers

-ObsType = Dict[Text, np.ndarray]
-InfoType = Dict[Text, float]
+ObsType = dict[str, np.ndarray]
+InfoType = dict[str, float]


DREAMPLACE_RUNTIME = 'dreamplace_runtime'
@@ -68,6 +68,17 @@ def __str__(self):
COST_COMPONENTS = ['wirelength', 'congestion', 'density']


+class CostInfoFunctionCallable(Protocol):
+
+  def __call__(
+      self,
+      plc: plc_client.PlacementCost,
+      done: bool,
+      infeasible_state: bool = False,
+  ) -> tuple[float, dict[str, float]]:
+    ...


@gin.configurable
def cost_info_function(
plc: plc_client.PlacementCost,
@@ -76,7 +87,7 @@ def cost_info_function(
wirelength_weight: float = 1.0,
density_weight: float = 1.0,
congestion_weight: float = 0.5,
-) -> Tuple[float, Dict[Text, float]]:
+) -> tuple[float, dict[str, float]]:
"""Returns the RL cost and info.
Args:
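
For context, the new Protocol enables structural typing: any plain function whose signature matches __call__ can be passed as cost_info_fn with its keyword arguments type-checked, which is why the pytype pragmas are dropped further down. A minimal sketch of a conforming custom cost function (illustrative only, not part of this commit; get_cost() is an assumed accessor mirroring what cost_info_function uses):

def wirelength_only_cost_info(
    plc: plc_client.PlacementCost,
    done: bool,
    infeasible_state: bool = False,
) -> tuple[float, dict[str, float]]:
  # Intermediate steps carry no cost; only the terminal state is scored.
  if not done or infeasible_state:
    return 0.0, {'wirelength': 0.0}
  wirelength = plc.get_cost()  # assumed wirelength accessor
  return wirelength, {'wirelength': wirelength}

# No subclassing required; the assignment type-checks structurally.
fn: CostInfoFunctionCallable = wirelength_only_cost_info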
@@ -126,26 +137,23 @@ class CircuitEnv(object):

def __init__(
self,
-      netlist_file: Text = '',
-      init_placement: Text = '',
+      netlist_file: str = '',
+      init_placement: str = '',
create_placement_cost_fn: Callable[
..., plc_client.PlacementCost
] = placement_util.create_placement_cost,
-      std_cell_placer_mode: Text = 'fd',
-      cost_info_fn: Callable[
-          [plc_client.PlacementCost, bool, bool],
-          Tuple[float, Dict[Text, float]],
-      ] = cost_info_function,
+      std_cell_placer_mode: str = 'fd',
+      cost_info_fn: CostInfoFunctionCallable = cost_info_function,
global_seed: int = 0,
netlist_index: int = 0,
-      is_eval: bool = False,
-      save_best_cost: bool = False,
-      output_plc_file: Text = '',
+      save_placement: bool = False,
+      save_best_cost: bool = True,
+      output_plc_file: str = '',
cd_finetune: bool = False,
-      cd_plc_file: Text = 'ppo_cd_placement.plc',
-      train_step: Optional[tf.Variable] = None,
+      cd_plc_file: str = 'ppo_cd_placement.plc',
+      train_step: tf.Variable | None = None,
output_all_features: bool = False,
-      node_order: Text = 'descending_size_macro_first',
+      node_order: str = 'descending_size_macro_first',
save_snapshot: bool = True,
save_partial_placement: bool = False,
):
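
A hedged usage sketch of the renamed flag (paths are placeholders): callers that previously passed is_eval=True together with save_best_cost=True now pass only save_placement=True, since save_best_cost defaults to True after this change.

env = environment.CircuitEnv(
    netlist_file='/path/to/netlist.pb.txt',
    init_placement='/path/to/initial.plc',
    save_placement=True,  # formerly is_eval=True
    output_plc_file='/tmp/rl_opt_placement.plc',
)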
@@ -164,7 +172,7 @@ def __init__(
global_seed: Global seed for initializing env features. This seed should
be the same across actors.
netlist_index: Netlist index in the model static features.
-      is_eval: If set, save the final placement in output_dir.
+      save_placement: If set, save the final placement in output_dir.
      save_best_cost: Boolean, if set, saves the placement if its cost is better
        than the previously saved placement.
output_plc_file: The path to save the final placement.
@@ -188,7 +196,7 @@
self.netlist_file = netlist_file
self._std_cell_placer_mode = std_cell_placer_mode
self._cost_info_fn = cost_info_fn
-    self._is_eval = is_eval
+    self._save_placement = save_placement
self._save_best_cost = save_best_cost
self._output_plc_file = output_plc_file
self._output_plc_dir = os.path.dirname(output_plc_file)
@@ -297,7 +305,7 @@ def action_space(self) -> gym.spaces.Space:
return gym.spaces.Discrete(self._observation_config.max_grid_size**2)

@property
-  def environment_name(self) -> Text:
+  def environment_name(self) -> str:
return self.netlist_file

@property
Expand Down Expand Up @@ -329,10 +337,8 @@ def get_static_obs(self):
"""
return self._observation_extractor.get_static_features()

-  def get_cost_info(
-      self, done: bool = False
-  ) -> Tuple[float, Dict[Text, float]]:
-    return self._cost_info_fn(plc=self._plc, done=done, infeasible_state=False)  # pytype: disable=wrong-keyword-args  # trace-all-classes
+  def get_cost_info(self, done: bool = False) -> tuple[float, dict[str, float]]:
+    return self._cost_info_fn(plc=self._plc, done=done, infeasible_state=False)

def _get_mask(self) -> np.ndarray:
"""Gets the node mask for the current node.
@@ -341,9 +347,7 @@
List of 0s and 1s indicating if action is feasible or not.
"""
if self._done:
-      mask = np.zeros(
-          self._observation_config.max_grid_size**2, dtype=np.int32
-      )
+      mask = np.zeros(self._observation_config.max_grid_size**2, dtype=np.int32)
else:
node_index = self._sorted_node_indices[self._current_node]
mask = np.asarray(self._plc.get_node_mask(node_index), dtype=np.int32)
@@ -400,12 +404,12 @@ def _run_cd(self):
# Plc modified by CD will be reset at the end of the episode.

def cost_fn(plc):
-      return self._cost_info_fn(plc=plc, done=True, infeasible_state=False)  # pytype: disable=wrong-keyword-args  # trace-all-classes
+      return self._cost_info_fn(plc=plc, done=True, infeasible_state=False)

cd = cd_placer.CoordinateDescentPlacer(plc=self._plc, cost_fn=cost_fn)
cd.place()

-  def _save_placement(self, cost: float) -> None:
+  def _save_placement_fn(self, cost: float) -> None:
"""Saves the current placement.
Args:
@@ -425,10 +429,9 @@ def _save_placement(self, cost: float) -> None:
placement_util.save_placement(
self._plc, self._output_plc_file, user_comments
)
-    ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

self._saved_cost = cost

+    ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
if self._save_snapshot:
ppo_snapshot_file = os.path.join(
self._output_plc_dir,
@@ -442,7 +445,7 @@ def _save_placement(self, cost: float) -> None:
if self._cd_finetune:
self._run_cd()
cost = self._cost_info_fn(
-          plc=self._plc, done=True, infeasible_state=False  # pytype: disable=wrong-keyword-args  # trace-all-classes
+          plc=self._plc, done=True, infeasible_state=False
)[0]
cd_plc_file = os.path.join(self._output_plc_dir, self._cd_plc_file)
placement_util.save_placement(self._plc, cd_plc_file, user_comments)
@@ -456,7 +459,7 @@ def _save_placement(self, cost: float) -> None:

def call_analytical_placer_and_get_cost(
self, infeasible_state=False
-  ) -> Tuple[float, InfoType]:
+  ) -> tuple[float, InfoType]:
"""Calls analytical placer.
Calls analystical placer and evaluates cost when all nodes are placed. Also,
Expand All @@ -478,18 +481,17 @@ def call_analytical_placer_and_get_cost(
    # This is realized by setting intermediate step costs to zero and
    # propagating the final cost with discount factor 1 in the replay buffer.
cost, info = self._cost_info_fn(
-        plc=self._plc, done=self._done, infeasible_state=infeasible_state  # pytype: disable=wrong-keyword-args  # trace-all-classes
+        plc=self._plc, done=self._done, infeasible_state=infeasible_state
)
info[DREAMPLACE_RUNTIME] = total_time

-    # Only saves placement in eval.
-    if self._is_eval:
+    # Happens when the episode is done, when RL places all nodes, or when we
+    # want to save a partial placement regardless of whether RL has placed all
+    # nodes.
+    if self._save_placement:
if self._current_node == self._num_hard_macros or (
self._done and self._save_partial_placement
):
-        self._save_placement(cost)
+        self._save_placement_fn(cost)

info[TOTAL_EPISODE_RUNTIME] = time.time() - self._episode_start_time

@@ -565,7 +567,7 @@ def analytical_placer(self) -> None:
% (self._std_cell_placer_mode)
)

-  def step(self, action: int) -> Tuple[ObsType, float, bool, Any]:
+  def step(self, action: int) -> tuple[ObsType, float, bool, Any]:
"""Steps the environment.
Args:
3 changes: 1 addition & 2 deletions circuit_training/environment/environment_test.py
@@ -144,8 +144,7 @@ def test_save_file_train_step(self):
env = environment.CircuitEnv(
netlist_file=netlist_file,
init_placement=init_placement,
-        is_eval=True,
-        save_best_cost=True,
+        save_placement=True,
output_plc_file=output_plc_file,
cd_finetune=True,
train_step=train_step,
3 changes: 1 addition & 2 deletions circuit_training/learning/eval.py
@@ -98,8 +98,7 @@ def main(_):
environment.create_circuit_environment,
netlist_file=FLAGS.netlist_file,
init_placement=FLAGS.init_placement,
-      is_eval=True,
-      save_best_cost=True,
+      save_placement=True,
output_plc_file=output_plc_file,
global_seed=FLAGS.global_seed,
cd_finetune=FLAGS.cd_finetune,
7 changes: 5 additions & 2 deletions circuit_training/learning/train_ppo_lib.py
@@ -137,6 +137,7 @@ def compute_total_training_step(
'num_iterations',
'num_episodes_per_iteration',
'init_learning_rate',
+        'policy_save_interval',
]
)
def train(
@@ -165,6 +166,7 @@ def train(
# num_replicas.
num_episodes_per_iteration: int = 256,
init_learning_rate: float = 0.004,
+    policy_save_interval: int = 1000,
num_netlists: int = 1,
debug_summaries: bool = False,
) -> None:
@@ -195,6 +197,7 @@ def train(
num_episodes_per_iteration: This is the number of episodes we train in each
epoch.
init_learning_rate: Initial learning rate.
+    policy_save_interval: How often policies are saved.
    num_netlists: Number of netlists used for normalizing the advantage. If
      larger than 1, the advantage is normalized first across netlists and then
      over the entire batch.
@@ -257,8 +260,8 @@ def train(
saved_model_dir,
tf_agent,
train_step,
-      start=-num_episodes_per_iteration,
-      interval=num_episodes_per_iteration,
+      start=-policy_save_interval,
+      interval=policy_save_interval,
)

# Create the variable container.
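
To see why the new interval saves disk space, here is a toy sketch of the trigger arithmetic (an approximation of the tf_agents interval-trigger semantics, not its API): with interval=num_episodes_per_iteration (256 by default) a policy snapshot was written roughly every 256 train steps, while policy_save_interval=1000 writes one every 1,000 steps.

def should_save(train_step: int, interval: int, start: int) -> bool:
  # Fires once train_step reaches start + interval, then every `interval`.
  return train_step >= start + interval and (train_step - start) % interval == 0

# Rough snapshot counts over 100,000 train steps:
snapshots_before = sum(should_save(s, 256, -256) for s in range(100_000))   # ~391
snapshots_after = sum(should_save(s, 1000, -1000) for s in range(100_000))  # ~100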
