Skip to content

Commit

Permalink
Merge pull request #10 from jjshoots/test_jit2
Browse files Browse the repository at this point in the history
Jit some expensive functions
  • Loading branch information
jjshoots authored Aug 9, 2023
2 parents cf315a9 + 65c7c31 commit a89a0a2
Show file tree
Hide file tree
Showing 10 changed files with 330 additions and 138 deletions.
53 changes: 43 additions & 10 deletions PyFlyt/core/abstractions/gimbals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from pybullet_utils import bullet_client

from PyFlyt.utils import jitter


class Gimbals:
"""A set of actuated gimbals.
Expand Down Expand Up @@ -136,7 +138,7 @@ def physics_update(self):
"`state_update` does not need to be called for gimbals, call `compute_rotation` instead."
)

def compute_rotation(self, gimbal_command) -> np.ndarray:
def compute_rotation(self, gimbal_command: np.ndarray) -> np.ndarray:
"""Returns a rotation vector after the gimbal rotation.
Args:
Expand All @@ -154,25 +156,56 @@ def compute_rotation(self, gimbal_command) -> np.ndarray:
gimbal_command - self.gimbal_state
)

# precompute some things
gimbal_angles = np.expand_dims(
self.gimbal_state * self.gimbal_range_radians, axis=(-1, -2)
# compute gimbal euler angles
gimbal_angles = self.gimbal_state * self.gimbal_range_radians
gimbal_angles = gimbal_angles.reshape(*gimbal_angles.shape, 1, 1)

# compute gimbal rotation matrix
(rotation1, rotation2) = self._jitted_compute_rotation(
gimbal_angles,
self.w1,
self.w2,
self.w1_squared,
self.w2_squared,
)
return rotation1 @ rotation2

@staticmethod
@jitter
def _jitted_compute_rotation(
gimbal_angles: np.ndarray,
w1: np.ndarray,
w2: np.ndarray,
w1_squared: np.ndarray,
w2_squared: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""Compute the rotation matrix given the gimbal action values.
Args:
gimbal_angles (np.ndarray): gimbal_angles
w1 (np.ndarray): w1 from self
w2 (np.ndarray): w2 from self
w1_squared (np.ndarray): w1_squared from self
w2_squared (np.ndarray): w2_squared from self
Returns:
tuple[np.ndarray, np.ndarray]:
"""
# precompute some things
sin_angles = np.sin(gimbal_angles)
sin_half_angles = np.sin(gimbal_angles / 2.0)

# start calculating rotation matrices
# https://math.stackexchange.com/questions/142821/matrix-for-rotation-around-a-vector
rotation1 = (
np.eye(3)
+ sin_angles[:, 0, ...] * self.w1
+ 2 * (sin_half_angles[:, 0, ...] ** 2) * self.w1_squared
+ sin_angles[:, 0, ...] * w1
+ 2 * (sin_half_angles[:, 0, ...] ** 2) * w1_squared
)
rotation2 = (
np.eye(3)
+ sin_angles[:, 1, ...] * self.w2
+ 2 * (sin_half_angles[:, 1, ...] ** 2) * self.w2_squared
+ sin_angles[:, 1, ...] * w2
+ 2 * (sin_half_angles[:, 1, ...] ** 2) * w2_squared
)

# get the final thrust vector
return rotation1 @ rotation2
return rotation1, rotation2
Loading

0 comments on commit a89a0a2

Please sign in to comment.