Skip to content

Commit

Permalink
change interface to pass args directly to env
Browse files Browse the repository at this point in the history
Signed-off-by: Tin Lai <[email protected]>
  • Loading branch information
soraxas committed Jun 12, 2021
1 parent 005e119 commit 2d02835
Show file tree
Hide file tree
Showing 20 changed files with 75 additions and 58 deletions.
30 changes: 15 additions & 15 deletions collisionChecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import numpy as np

from utils.common import MagicDict
from utils.common import Stats


class CollisionChecker(ABC):
"""Abstract collision checker"""

def __init__(self, args: MagicDict):
self.args = args
def __init__(self, stats: Stats):
self.stats = stats

@abstractmethod
def get_dimension(self) -> int:
Expand Down Expand Up @@ -54,13 +54,13 @@ class ImgCollisionChecker(CollisionChecker):
2D Image Space simulator engine
"""

def __init__(self, img: typing.IO, args: MagicDict):
def __init__(self, img: typing.IO, stats: Stats):
"""
:param img: a file-like object (e.g. a filename) for the image as the
environment that the planning operates in
:param args: the args of the planning problem
:param stats: the Stats object to keep track of stats
"""
super().__init__(args)
super().__init__(stats)
from PIL import Image

image = Image.open(img).convert("L")
Expand Down Expand Up @@ -187,12 +187,12 @@ def get_line(start, end):
class KlamptCollisionChecker(CollisionChecker):
"""A wrapper around Klampt's 3D simulator"""

def __init__(self, xml: str, args: MagicDict):
def __init__(self, xml: str, stats: Stats):
"""
:param xml: the xml filename for Klampt to read the world settings
:param args: the args of the planning problem
:param stats: the Stats object to keep track of stats
"""
super().__init__(args)
super().__init__(stats)
import klampt
from klampt.plan import robotplanning

Expand Down Expand Up @@ -254,12 +254,12 @@ def visible(self, a, b):
a = self.translate_to_klampt(a)
b = self.translate_to_klampt(b)
# print(self.space.visible(a, b))
self.args.env.stats.visible_cnt += 1
self.stats.visible_cnt += 1
return self.space.isVisible(a, b)

def feasible(self, p, stats=False):
p = self.translate_to_klampt(p)
self.args.env.stats.feasible_cnt += 1
self.stats.feasible_cnt += 1
return self.space.feasible(p)


Expand All @@ -271,9 +271,9 @@ class RobotArm4dCollisionChecker(CollisionChecker):
def __init__(
self,
img: typing.IO,
args: MagicDict,
stats: Stats,
map_mat: typing.Optional[np.ndarray] = None,
stick_robot_length_config: typing.Tuple[float, ...] = (35, 35),
stick_robot_length_config: typing.Sequence[float] = (35, 35),
):
"""
Expand All @@ -283,9 +283,9 @@ def __init__(
uses `map_mat` directly as the map
:param stick_robot_length_config: a list of numbers that represents the
length of the stick robotic arm
:param args: the args of the planning problem
:param stats: the Stats object to keep track of stats
"""
super().__init__(args)
super().__init__(stats)
if map_mat is None:
from PIL import Image

Expand Down
8 changes: 4 additions & 4 deletions env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class Env:
"""Represents the planning environment. The main loop happens inside this class"""

def __init__(self, fixed_seed=None, **kwargs):
def __init__(self, args: MagicDict, fixed_seed: int = None):
self.started = False

if fixed_seed is not None:
Expand All @@ -26,15 +26,15 @@ def __init__(self, fixed_seed=None, **kwargs):
print(f"Fixed random seed: {fixed_seed}")

# initialize and prepare screen
self.args = MagicDict(kwargs)
self.args = args
self.stats = Stats(showSampledPoint=self.args.showSampledPoint)

cc_type, self.dist = {
"image": (collisionChecker.ImgCollisionChecker, self.euclidean_dist),
"4d": (collisionChecker.RobotArm4dCollisionChecker, self.euclidean_dist),
"klampt": (collisionChecker.KlamptCollisionChecker, self.radian_dist),
}[kwargs["engine"]]
self.cc = cc_type(self.args.image, args=self.args)
}[self.args.engine]
self.cc = cc_type(self.args.image, stats=self.stats)

self.args["num_dim"] = self.cc.get_dimension()
self.args["image_shape"] = self.cc.get_image_shape()
Expand Down
60 changes: 38 additions & 22 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,41 +104,38 @@

LOGGER = logging.getLogger()

# add all of the registered planners
__doc__ = __doc__.format(
all_available_planners="|".join(
sorted(
(planner.name for planner in planner_registry.PLANNERS.values()),
reverse=True,
)
)
)
RAW_DOC_STRING = __doc__


def generate_args(
planner_id: Optional[str],
map_fname: Optional[str],
start: Optional[np.ndarray] = None,
goal: Optional[np.ndarray] = None,
planner_id: Optional[str],
map_fname: Optional[str],
start: Optional[np.ndarray] = None,
goal: Optional[np.ndarray] = None,
) -> MagicDict:
"""The entry point of the planning scene module
"""Get the default set of arguments
:param map_fname: overrides the map to test
:param planner_id: the planner to use in the planning environment
:param map_fname: the filename of the map to use in the planning environment
:param start: overrides the starting configuration
:param goal: overrides the goal configuration
:return: the default dictionary of arguments to config the planning problem
"""
if map_fname is None or planner_id is None:
if len(sys.argv) < 3:
if __name__ != "__main__":
# being called as a standalone function (as oppose to cli)
if map_fname is None or planner_id is None:
raise ValueError(
"Both `map_fname` and `planner_id` must be provided to " "generate args"
)
else:
if planner_id not in planner_registry.PLANNERS:
raise ValueError(f"The given planner id '{planner_id}' does not exists.")
# inject the inputs for docopt to parse
sys.argv[1:] = [planner_id, map_fname]

args = docopt(__doc__, version="SBP-Env Research v2.0")
# add all of the registered planners
args = docopt(format_doc_with_registered_planners(RAW_DOC_STRING),
version="SBP-Env Research v2.0")

# allow the map filename, start and goal point to be override.
if start is not None:
Expand Down Expand Up @@ -209,7 +206,7 @@ def generate_args(
if (not a.startswith("--") and args[a] is True and a not in ("start", "goal"))
]
assert (
len(planner_canidates) == 1
len(planner_canidates) == 1
), f"Planner to use '{planner_canidates}' has length {len(planner_canidates)}"
planner_to_use = planner_canidates[0]
print(planner_to_use)
Expand Down Expand Up @@ -251,8 +248,27 @@ def generate_args(
return planning_option


def format_doc_with_registered_planners(doc: str):
"""
Format the main.py's doc with the current registered planner
:param doc: the doc string to be formatted
"""
return doc.format(
all_available_planners="|".join(
sorted(
(planner.name for planner in planner_registry.PLANNERS.values()),
reverse=True,
)
)
)


if __name__ == "__main__":
planning_option = generate_args(planner_id=None, map_fname=None)
# The entry point of the planning scene module from the cli
default_arguments = generate_args(planner_id=None, map_fname=None)

environment = env.Env(**planning_option)
environment = env.Env(args=default_arguments)
environment.run()

__doc__ = format_doc_with_registered_planners(__doc__)
3 changes: 2 additions & 1 deletion tests/test_4d_arm_collision_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
create_test_image,
pt,
)
from utils.common import Stats


class TestRobotArm4dCollisionChecker(TestCase):
def setUp(self) -> None:
self.cc = RobotArm4dCollisionChecker(
create_test_image(), stick_robot_length_config=[0.1, 0.1]
create_test_image(), stick_robot_length_config=[0.1, 0.1], stats=Stats()
)
self.target = mock_image_as_np == 255
self.target = self.target.astype(self.cc.image.dtype)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_birrtPlanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["birrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler
self.planner = self.env.args.planner

Expand Down
2 changes: 1 addition & 1 deletion tests/test_birrtSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["birrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler

def test_init(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_image_space_collision_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collisionChecker import ImgCollisionChecker
from tests.common_vars import create_test_image, mock_image_as_np
from utils.common import Stats

eps = 1e-5

Expand All @@ -23,7 +24,7 @@ def pts_pair(args1, args2):

class TestImgCollisionChecker(TestCase):
def setUp(self) -> None:
self.cc = ImgCollisionChecker(create_test_image())
self.cc = ImgCollisionChecker(create_test_image(), stats=Stats())
self.target = mock_image_as_np == 255
self.target = self.target.astype(self.cc.image.dtype)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_informedSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler

def test_get_next_pos(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_informedrrtPlanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler
self.planner = self.env.args.planner

Expand Down
2 changes: 1 addition & 1 deletion tests/test_likelihoodPlanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler
self.planner = self.env.args.planner

Expand Down
2 changes: 1 addition & 1 deletion tests/test_likelihoodPolicySampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler

def test_likelihood_increases(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nearbyPlanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler
self.planner = self.env.args.planner

Expand Down
2 changes: 1 addition & 1 deletion tests/test_nearbyPolicySampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler

def test_likelihood_increases(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_prmPlanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["prm"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler
self.planner = self.env.args.planner

Expand Down
2 changes: 1 addition & 1 deletion tests/test_prmSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["prm"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler

def test_init(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_randomPolicySampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler

def test_init(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rrdtPlanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrdt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler
self.planner = self.env.args.planner

Expand Down
2 changes: 1 addition & 1 deletion tests/test_rrdtSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrdt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler
# self.env.args.planner.run_once()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_rrtPlanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setUp(self) -> None:
# use some suitable planner
args["planner_data_pack"] = planner_registry.PLANNERS["rrt"]

self.env = Env(**args)
self.env = Env(args)
self.sampler = self.env.args.sampler
self.planner = self.env.args.planner

Expand Down
1 change: 0 additions & 1 deletion utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ class Stats:
:ivar visible_cnt: the number of calls to visibility test in the collision checker
:ivar feasible_cnt: the number of calls to feasibility test in the collision checker
:type invalid_samples_connections: int
:type invalid_samples_obstacles: int
:type valid_sample: int
Expand Down

0 comments on commit 2d02835

Please sign in to comment.