Skip to content

Commit

Permalink
fix tests and selection of visualiser
Browse files Browse the repository at this point in the history
Signed-off-by: Tin Lai <[email protected]>
  • Loading branch information
soraxas committed Jun 12, 2021
1 parent 2d02835 commit b76c39c
Showing 16 changed files with 68 additions and 32 deletions.
31 changes: 24 additions & 7 deletions env.py
Original file line number Diff line number Diff line change
@@ -15,9 +15,15 @@


class Env:
"""Represents the planning environment. The main loop happens inside this class"""
"""Represents the planning environment.
The main planning loop happens inside this class.
"""

def __init__(self, args: MagicDict, fixed_seed: int = None):
"""
:param args: the dictionary of arguments to config the planning problem
:param fixed_seed: if given, fix the random seed
"""
self.started = False

if fixed_seed is not None:
@@ -36,6 +42,17 @@ def __init__(self, args: MagicDict, fixed_seed: int = None):
}[self.args.engine]
self.cc = cc_type(self.args.image, stats=self.stats)

# setup visualiser
if self.args.no_display:
# use pass-through visualiser
VisualiserSwitcher.choose_visualiser("base")
else:
if self.args.engine in ("image", "4d"):
# use pygame visualiser
VisualiserSwitcher.choose_visualiser("pygame")
elif self.args.engine == "klampt":
VisualiserSwitcher.choose_visualiser("klampt")

self.args["num_dim"] = self.cc.get_dimension()
self.args["image_shape"] = self.cc.get_image_shape()
self.dim = self.args["image_shape"]
@@ -53,27 +70,27 @@ def parse_input_pt(pt_as_str):
)
return tuple(map(float, pt))

self.args["start_pt"] = parse_input_pt(self.args["start_pt"])
self.args["goal_pt"] = parse_input_pt(self.args["goal_pt"])
if type(self.args["start_pt"]) is str:
self.args["start_pt"] = parse_input_pt(self.args["start_pt"])
if type(self.args["goal_pt"]) is str:
self.args["goal_pt"] = parse_input_pt(self.args["goal_pt"])

self.args.planner = self.args.planner_data_pack.planner_class(**self.args)
self.args["planner"] = self.args.planner

self.planner = self.args.planner
self.planner.args.env = self

super().__init__()

self.visualiser = VisualiserSwitcher.env_clname(env_instance=self)
self.visualiser.visualiser_init(no_display=self.args["no_display"])
start_pt, goal_pt = self.visualiser.set_start_goal_points(
start=self.args["start_pt"], goal=self.args["goal_pt"]
)

self.start_pt = self.goal_pt = None
if start_pt:
if start_pt is not None:
self.start_pt = Node(start_pt)
if goal_pt:
if goal_pt is not None:
self.goal_pt = Node(goal_pt)
self.start_pt.is_start = True
self.goal_pt.is_goal = True
41 changes: 16 additions & 25 deletions main.py
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@
import logging
import re
import sys
from typing import Optional
from typing import Optional, Union

import numpy as np
from docopt import docopt
@@ -98,7 +98,6 @@
import planners
from utils import planner_registry
from utils.common import MagicDict
from visualiser import VisualiserSwitcher

assert planners

@@ -108,17 +107,17 @@


def generate_args(
planner_id: Optional[str],
map_fname: Optional[str],
start: Optional[np.ndarray] = None,
goal: Optional[np.ndarray] = None,
planner_id: Optional[str],
map_fname: Optional[str],
start_pt: Optional[Union[np.ndarray, str]] = None,
goal_pt: Optional[Union[np.ndarray, str]] = None,
) -> MagicDict:
"""Get the default set of arguments
:param planner_id: the planner to use in the planning environment
:param map_fname: the filename of the map to use in the planning environment
:param start: overrides the starting configuration
:param goal: overrides the goal configuration
:param start_pt: overrides the starting configuration
:param goal_pt: overrides the goal configuration
:return: the default dictionary of arguments to config the planning problem
"""
@@ -134,14 +133,16 @@ def generate_args(
sys.argv[1:] = [planner_id, map_fname]

# add all of the registered planners
args = docopt(format_doc_with_registered_planners(RAW_DOC_STRING),
version="SBP-Env Research v2.0")
args = docopt(
format_doc_with_registered_planners(RAW_DOC_STRING),
version="SBP-Env Research v2.0",
)

# allow the map filename, start and goal point to be override.
if start is not None:
args["<start_x1,x2,..,xn>"] = start
if goal is not None:
args["<goal_x1,x2,..,xn>"] = goal
if start_pt is not None:
args["<start_x1,x2,..,xn>"] = start_pt
if goal_pt is not None:
args["<goal_x1,x2,..,xn>"] = goal_pt

# setup environment engine
args["--engine"] = "" if args["--engine"] is None else args["--engine"].lower()
@@ -172,16 +173,6 @@ def generate_args(
)
LOGGER.info(_notice.format(args["--engine"], _file_extension))

if args["--no-display"]:
# use pass-through visualiser
VisualiserSwitcher.choose_visualiser("base")
else:
if args["--engine"] in ("image", "4d"):
# use pygame visualiser
VisualiserSwitcher.choose_visualiser("pygame")
elif args["--engine"] == "klampt":
VisualiserSwitcher.choose_visualiser("klampt")

########################################

if args["--verbose"] > 2:
@@ -206,7 +197,7 @@ def generate_args(
if (not a.startswith("--") and args[a] is True and a not in ("start", "goal"))
]
assert (
len(planner_canidates) == 1
len(planner_canidates) == 1
), f"Planner to use '{planner_canidates}' has length {len(planner_canidates)}"
planner_to_use = planner_canidates[0]
print(planner_to_use)
2 changes: 2 additions & 0 deletions tests/test_birrtPlanner.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import numpy as np

import visualiser
from env import Env
from samplers.birrtSampler import BiRRTSampler
from tests.common_vars import template_args, MockNumpyEquality
@@ -14,6 +15,7 @@
class TestBiRRTPlanner(TestRRTPlanner):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = BiRRTSampler()
2 changes: 2 additions & 0 deletions tests/test_birrtSampler.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import numpy as np

import visualiser
from env import Env
from samplers.birrtSampler import BiRRTSampler
from tests.common_vars import template_args
@@ -12,6 +13,7 @@
class TestBiRRTSampler(TestCase):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = BiRRTSampler()
2 changes: 2 additions & 0 deletions tests/test_informedSampler.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import numpy as np

import visualiser
from env import Env
from samplers.informedSampler import InformedSampler
from tests.common_vars import template_args
@@ -12,6 +13,7 @@
class TestInformedSampler(TestCase):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = InformedSampler()
2 changes: 2 additions & 0 deletions tests/test_informedrrtPlanner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy
from unittest.mock import MagicMock

import visualiser
from env import Env
from samplers.informedSampler import InformedSampler
from tests.common_vars import template_args
@@ -12,6 +13,7 @@
class TestInformedRRTPlanner(TestRRTPlanner):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = InformedSampler(prob_block_size=10)
2 changes: 2 additions & 0 deletions tests/test_likelihoodPlanner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy
from unittest.mock import MagicMock

import visualiser
from env import Env
from samplers.likelihoodPolicySampler import LikelihoodPolicySampler
from tests.common_vars import template_args
@@ -12,6 +13,7 @@
class LikelihoodPolicyPlanner(TestRRTPlanner):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = LikelihoodPolicySampler(prob_block_size=10)
2 changes: 2 additions & 0 deletions tests/test_likelihoodPolicySampler.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import numpy as np

import visualiser
from env import Env
from samplers.likelihoodPolicySampler import LikelihoodPolicySampler
from tests.common_vars import template_args
@@ -13,6 +14,7 @@
class TestLikelihoodPolicySampler(TestCase):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = LikelihoodPolicySampler(10)
2 changes: 2 additions & 0 deletions tests/test_nearbyPlanner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy
from unittest.mock import MagicMock

import visualiser
from env import Env
from samplers.nearbyPolicySampler import NearbyPolicySampler
from tests.common_vars import template_args
@@ -12,6 +13,7 @@
class TestNearbyPolicyPlanner(TestRRTPlanner):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = NearbyPolicySampler(prob_block_size=10)
2 changes: 2 additions & 0 deletions tests/test_nearbyPolicySampler.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import numpy as np

import visualiser
from env import Env
from planners.rrdtPlanner import Node
from samplers.nearbyPolicySampler import NearbyPolicySampler
@@ -13,6 +14,7 @@
class TestNearbyPolicySampler(TestCase):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = NearbyPolicySampler(prob_block_size=10)
2 changes: 2 additions & 0 deletions tests/test_prmPlanner.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import numpy as np

import visualiser
from env import Env
from samplers.prmSampler import PRMSampler
from tests.common_vars import template_args, MockNumpyEquality
@@ -14,6 +15,7 @@
class PRMPlanner(TestRRTPlanner):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = PRMSampler(prob_block_size=10)
2 changes: 2 additions & 0 deletions tests/test_prmSampler.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import numpy as np

import visualiser
from env import Env
from planners.rrdtPlanner import Node
from samplers.prmSampler import PRMSampler
@@ -13,6 +14,7 @@
class TestPRMSampler(TestCase):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = PRMSampler()
2 changes: 2 additions & 0 deletions tests/test_randomPolicySampler.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import numpy as np

import visualiser
from env import Env
from samplers.randomPolicySampler import RandomPolicySampler
from tests.common_vars import template_args
@@ -14,6 +15,7 @@ class TestRandomPolicySampler(TestCase):
def setUp(self) -> None:

args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = RandomPolicySampler()
2 changes: 2 additions & 0 deletions tests/test_rrdtPlanner.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import numpy as np
from tqdm import tqdm

import visualiser
from env import Env
from planners.rrdtPlanner import RRdTSampler
from tests.common_vars import template_args, MockNumpyEquality
@@ -18,6 +19,7 @@
class TestRRdTPlanner(TestRRTPlanner):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = RRdTSampler()
2 changes: 2 additions & 0 deletions tests/test_rrdtSampler.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

import numpy as np

import visualiser
from env import Env
from planners.rrdtPlanner import RRdTSampler, Node
from tests.common_vars import template_args, MockNumpyEquality
@@ -13,6 +14,7 @@
class TestRRdTSampler(TestCase):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = RRdTSampler()
2 changes: 2 additions & 0 deletions tests/test_rrtPlanner.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

import numpy as np

import visualiser
from env import Env
from samplers.randomPolicySampler import RandomPolicySampler
from tests.common_vars import template_args, MockNumpyEquality
@@ -14,6 +15,7 @@
class TestRRTPlanner(TestCase):
def setUp(self) -> None:
args = deepcopy(template_args)
visualiser.VisualiserSwitcher.choose_visualiser("base")

# setup to use the correct sampler
args["sampler"] = RandomPolicySampler()

0 comments on commit b76c39c

Please sign in to comment.