Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debug arclength_sampling #586

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dynamo/prediction/state_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def state_graph(
)

if arc_sample:
Y, arclength, T = arclength_sampling(Y, arclength / 1000, t=t[~T_bool])
Y, arclength, T = arclength_sampling(Y, arclength / 1000, n_steps=1000, t=t[~T_bool])
else:
T = t[~T_bool]
else:
Expand Down
32 changes: 22 additions & 10 deletions dynamo/prediction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def integrate_vf_ivp(
x = x[:idx]
_, arclen, _ = remove_redundant_points_trajectory(x, tol=1e-4, output_discard=True)
arc_stepsize = arclen / interpolation_num
cur_Y, alen, t_[i] = arclength_sampling(x, step_length=arc_stepsize, t=tau[:idx])
cur_Y, alen, t_[i] = arclength_sampling(x, step_length=arc_stepsize, n_steps=interpolation_num, t=tau[:idx])

if integration_direction == "both":
neg_t_len = sum(np.array(t_[i]) < 0)
Expand Down Expand Up @@ -429,7 +429,7 @@ def remove_redundant_points_trajectory(X, tol=1e-4, output_discard=False):
return (X, arclength)


def arclength_sampling(X, step_length, t=None):
def arclength_sampling(X, step_length, n_steps: int, t=None):
"""uniformly sample data points on an arc curve that generated from vector field predictions."""
Y = []
x0 = X[0]
Expand All @@ -439,20 +439,29 @@ def arclength_sampling(X, step_length, t=None):
terminate = False
arclength = 0

def _calculate_new_point():
x = x0 if j == i else X[j - 1]
cur_y = x + (step_length - L) * tangent / d

if t is not None:
cur_tau = t0 if j == i else t[j - 1]
cur_tau += (step_length - L) / d * (t[j] - cur_tau)
T.append(cur_tau)
else:
cur_tau = None

Y.append(cur_y)

return cur_y, cur_tau

while i < len(X) - 1 and not terminate:
L = 0
for j in range(i, len(X)):
tangent = X[j] - x0 if j == i else X[j] - X[j - 1]
d = np.linalg.norm(tangent)
if L + d >= step_length:
x = x0 if j == i else X[j - 1]
y = x + (step_length - L) * tangent / d
if t is not None:
tau = t0 if j == i else t[j - 1]
tau += (step_length - L) / d * (t[j] - tau)
T.append(tau)
t0 = tau
Y.append(y)
y, tau = _calculate_new_point()
t0 = tau if t is not None else None
x0 = y
i = j
break
Expand All @@ -464,6 +473,9 @@ def arclength_sampling(X, step_length, t=None):
if L + d < step_length:
terminate = True

if terminate and len(Y) < n_steps:
_, _ = _calculate_new_point()

if T is not None:
return np.array(Y), arclength, T
else:
Expand Down
Loading