Skip to content

Commit

Permalink
polish docs, improve tests, add new test classes
Browse files Browse the repository at this point in the history
  • Loading branch information
cahity committed Nov 1, 2024
1 parent c214e88 commit 48eea4b
Showing 1 changed file with 67 additions and 51 deletions.
118 changes: 67 additions & 51 deletions vectoptal/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def plot_2d_cone(ordering_cone, path: Optional[Union[str, PathLike]] = None):
ax.set_ylim(ylim)

ax.set(xticks=[], xticklabels=[], yticks=[], yticklabels=[])
ax.spines['bottom'].set_position('zero')
ax.spines['left'].set_position('zero')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["bottom"].set_position("zero")
ax.spines["left"].set_position("zero")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

W = ordering_cone.W
cone_degree = ordering_cone.cone_degree if hasattr(ordering_cone, "cone_degree") else 90
Expand All @@ -45,28 +45,34 @@ def plot_2d_cone(ordering_cone, path: Optional[Union[str, PathLike]] = None):
y_left = 0 + m2 * (x_left - 0)

if cone_degree > 90:
verts = np.array([
[0, 0],
[x_left[0], y_left[0]],
[xlim[1], ylim[1]],
[x_right[-1], y_right[-1]],
])
verts = np.array(
[
[0, 0],
[x_left[0], y_left[0]],
[xlim[1], ylim[1]],
[x_right[-1], y_right[-1]],
]
)
elif cone_degree == 90:
verts = np.array([
[0, 0],
[0, ylim[1]],
[xlim[1], ylim[1]],
[xlim[1], 0],
])
verts = np.array(
[
[0, 0],
[0, ylim[1]],
[xlim[1], ylim[1]],
[xlim[1], 0],
]
)
else:
verts = np.array([
[0, 0],
[x_left[-1], y_left[-1]],
[xlim[1], ylim[1]],
[x_right[-1], y_right[-1]],
])
verts = np.array(
[
[0, 0],
[x_left[-1], y_left[-1]],
[xlim[1], ylim[1]],
[x_right[-1], y_right[-1]],
]
)

ax.add_patch(Polygon(verts, color='blue', alpha=0.5))
ax.add_patch(Polygon(verts, color="blue", alpha=0.5))

if path is not None:
fig.savefig(path)
Expand All @@ -87,21 +93,21 @@ def plot_3d_cone(ordering_cone, path: Optional[Union[str, PathLike]] = None):
zlim = [-5, 5]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax = fig.add_subplot(111, projection="3d")
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)

ax.set(xticks=[], xticklabels=[], yticks=[], yticklabels=[], zticks=[], zticklabels=[])
ax.spines['bottom'].set_position('zero')
ax.spines['left'].set_position('zero')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["bottom"].set_position("zero")
ax.spines["left"].set_position("zero")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# Add X, Y, and Z axis lines at the middle of the region
ax.plot([xlim[0], xlim[1]], [0, 0], [0, 0], color='black')
ax.plot([0, 0], [ylim[0], ylim[1]], [0, 0], color='black')
ax.plot([0, 0], [0, 0], [zlim[0], zlim[1]], color='black')
ax.plot([xlim[0], xlim[1]], [0, 0], [0, 0], color="black")
ax.plot([0, 0], [ylim[0], ylim[1]], [0, 0], color="black")
ax.plot([0, 0], [0, 0], [zlim[0], zlim[1]], color="black")

x_pts = np.linspace(xlim[0], xlim[1], 25)
y_pts = np.linspace(ylim[0], ylim[1], 25)
Expand All @@ -112,17 +118,15 @@ def plot_3d_cone(ordering_cone, path: Optional[Union[str, PathLike]] = None):
pts = pts[ordering_cone.is_inside(pts)]
X, Y, Z = pts[:, 0], pts[:, 1], pts[:, 2]

ax.scatter(X, Y, Z, alpha=0.3, c='blue', s=8)
ax.scatter(X, Y, Z, alpha=0.3, c="blue", s=8)

if path is not None:
fig.savefig(path)

return fig


def plot_pareto_front(
order, elements: np.ndarray, path: Optional[Union[str, PathLike]] = None
):
def plot_pareto_front(order, elements: np.ndarray, path: Optional[Union[str, PathLike]] = None):
dim = elements.shape[1]
assert elements.ndim == 2, "Elements array should be N-by-dim."
assert dim in [2, 3], "Only 2D and 3D plots are supported."
Expand All @@ -137,43 +141,55 @@ def plot_pareto_front(
ax = fig.add_subplot(111)

ax.set(xticks=[], xticklabels=[], yticks=[], yticklabels=[])
ax.spines['bottom'].set_position('center')
ax.spines['left'].set_position('center')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["bottom"].set_position("center")
ax.spines["left"].set_position("center")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

ax.scatter(
elements[pareto_indices][:, 0],
elements[pareto_indices][:, 1], c="mediumslateblue", label="Pareto", alpha=0.6
elements[pareto_indices][:, 1],
c="mediumslateblue",
label="Pareto",
alpha=0.6,
)
ax.scatter(
elements[non_pareto_indices][:, 0],
elements[non_pareto_indices][:, 1], c="tab:blue", label="Non Pareto", alpha=0.6
elements[non_pareto_indices][:, 1],
c="tab:blue",
label="Non Pareto",
alpha=0.6,
)
else:
ax = fig.add_subplot(111, projection='3d')
ax = fig.add_subplot(111, projection="3d")

ax.set(xticks=[], xticklabels=[], yticks=[], yticklabels=[], zticks=[], zticklabels=[])
ax.spines['bottom'].set_position('center')
ax.spines['left'].set_position('center')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["bottom"].set_position("center")
ax.spines["left"].set_position("center")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

ax.scatter(
elements[pareto_indices][:, 0],
elements[pareto_indices][:, 1],
elements[pareto_indices][:, 2], c="mediumslateblue", label="Pareto", alpha=0.6
elements[pareto_indices][:, 2],
c="mediumslateblue",
label="Pareto",
alpha=0.6,
)
ax.scatter(
elements[non_pareto_indices][:, 0],
elements[non_pareto_indices][:, 1],
elements[non_pareto_indices][:, 2], c="tab:blue", label="Non Pareto", alpha=0.6
elements[non_pareto_indices][:, 2],
c="tab:blue",
label="Non Pareto",
alpha=0.6,
)

# Add X, Y, and Z axis lines at the middle of the region
ax.plot([ax.get_xlim()[0], ax.get_xlim()[1]], [0, 0], [0, 0], color='black')
ax.plot([0, 0], [ax.get_ylim()[0], ax.get_ylim()[1]], [0, 0], color='black')
ax.plot([0, 0], [0, 0], [ax.get_xlim()[0], ax.get_zlim()[1]], color='black')
ax.plot([ax.get_xlim()[0], ax.get_xlim()[1]], [0, 0], [0, 0], color="black")
ax.plot([0, 0], [ax.get_ylim()[0], ax.get_ylim()[1]], [0, 0], color="black")
ax.plot([0, 0], [0, 0], [ax.get_xlim()[0], ax.get_zlim()[1]], color="black")

ax.legend(loc="lower left")
fig.tight_layout()
Expand Down

0 comments on commit 48eea4b

Please sign in to comment.