Skip to content

Commit

Permalink
Fix trajopt ifopt collision with fixed states (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
Levi-Armstrong authored Dec 4, 2023
1 parent 7b93537 commit 89661ad
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 24 deletions.
4 changes: 3 additions & 1 deletion trajopt_common/include/trajopt_common/collision_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ std::size_t cantorHash(int shape_id, int subshape_id);
* @brief Remove any results that are invalid.
* Invalid state are contacts that occur at fixed states or have distances outside the threshold.
* @param contact_results Contact results vector to process.
* @param position_vars_fixed Indicate if a state is fixed
*/
void removeInvalidContactResults(tesseract_collision::ContactResultVector& contact_results,
const Eigen::Vector3d& data);
const Eigen::Vector3d& data,
const std::array<bool, 2>& position_vars_fixed);

/**
* @brief Extracts the gradient information based on the contact results
Expand Down
43 changes: 37 additions & 6 deletions trajopt_common/src/collision_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,44 @@ std::size_t cantorHash(int shape_id, int subshape_id)
return static_cast<std::size_t>(1 / 2.0 * (shape_id + subshape_id) * (shape_id + subshape_id + 1) + subshape_id);
}

void removeInvalidContactResults(tesseract_collision::ContactResultVector& contact_results, const Eigen::Vector3d& data)
void removeInvalidContactResults(tesseract_collision::ContactResultVector& contact_results,
const Eigen::Vector3d& data,
const std::array<bool, 2>& position_vars_fixed)
{
auto end = std::remove_if(
contact_results.begin(), contact_results.end(), [=, &data](const tesseract_collision::ContactResult& r) {
/** @todo Is this correct? (Levi)*/
return (!((data[0] + data[1]) > r.distance));
});
auto end = std::remove_if(contact_results.begin(),
contact_results.end(),
[=, &data, &position_vars_fixed](const tesseract_collision::ContactResult& r) {
/** @todo Is this correct? (Levi)*/
if ((!((data[0] + data[1]) > r.distance)))
return true;

if (!position_vars_fixed[0] && !position_vars_fixed[1])
return false;

if (position_vars_fixed[0])
{
if (r.cc_type[0] != tesseract_collision::ContinuousCollisionType::CCType_None &&
r.cc_type[0] != tesseract_collision::ContinuousCollisionType::CCType_Time0)
return false;

if (r.cc_type[1] != tesseract_collision::ContinuousCollisionType::CCType_None &&
r.cc_type[1] != tesseract_collision::ContinuousCollisionType::CCType_Time0)
return false;
}

if (position_vars_fixed[1])
{
if (r.cc_type[0] != tesseract_collision::ContinuousCollisionType::CCType_None &&
r.cc_type[0] != tesseract_collision::ContinuousCollisionType::CCType_Time1)
return false;

if (r.cc_type[1] != tesseract_collision::ContinuousCollisionType::CCType_None &&
r.cc_type[1] != tesseract_collision::ContinuousCollisionType::CCType_Time1)
return false;
}

return true;
});

contact_results.erase(end, contact_results.end());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ class LVSContinuousCollisionEvaluator : public ContinuousCollisionEvaluator
CalcCollisionsCacheDataHelper(const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals1);

void CalcCollisionsHelper(const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
void CalcCollisionsHelper(tesseract_collision::ContactResultMap& dist_results,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals1,
tesseract_collision::ContactResultMap& dist_results);
const std::array<bool, 2>& position_vars_fixed);
};

/**
Expand Down Expand Up @@ -190,9 +191,10 @@ class LVSDiscreteCollisionEvaluator : public ContinuousCollisionEvaluator
CalcCollisionsCacheDataHelper(const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals1);

void CalcCollisionsHelper(const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
void CalcCollisionsHelper(tesseract_collision::ContactResultMap& dist_results,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals1,
tesseract_collision::ContactResultMap& dist_results);
const std::array<bool, 2>& position_vars_fixed);
};
} // namespace trajopt_ifopt
#endif // TRAJOPT_IFOPT_CONTINUOUS_COLLISION_EVALUATOR_H
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ LVSContinuousCollisionEvaluator::CalcCollisionData(const Eigen::Ref<const Eigen:
}

auto data = std::make_shared<trajopt_common::CollisionCacheData>();
CalcCollisionsHelper(dof_vals0, dof_vals1, data->contact_results_map);
CalcCollisionsHelper(data->contact_results_map, dof_vals0, dof_vals1, position_vars_fixed);

for (const auto& pair : data->contact_results_map)
{
Expand Down Expand Up @@ -165,9 +165,10 @@ LVSContinuousCollisionEvaluator::CalcCollisionData(const Eigen::Ref<const Eigen:
return data;
}

void LVSContinuousCollisionEvaluator::CalcCollisionsHelper(const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
void LVSContinuousCollisionEvaluator::CalcCollisionsHelper(tesseract_collision::ContactResultMap& dist_results,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals1,
tesseract_collision::ContactResultMap& dist_results)
const std::array<bool, 2>& position_vars_fixed)
{
// The first step is to see if the distance between two states is larger than the longest valid segment. If larger
// the collision checking is broken up into multiple casted collision checks such that each check is less then
Expand All @@ -186,7 +187,7 @@ void LVSContinuousCollisionEvaluator::CalcCollisionsHelper(const Eigen::Ref<cons
// Don't include contacts at the fixed state
// Don't include contacts with zero coeffs
const auto& zero_coeff_pairs = collision_config_->collision_coeff_data.getPairsWithZeroCoeff();
auto filter = [this, &zero_coeff_pairs](tesseract_collision::ContactResultMap::PairType& pair) {
auto filter = [this, &zero_coeff_pairs, &position_vars_fixed](tesseract_collision::ContactResultMap::PairType& pair) {
// Remove pairs with zero coeffs
if (std::find(zero_coeff_pairs.begin(), zero_coeff_pairs.end(), pair.first) != zero_coeff_pairs.end())
{
Expand All @@ -199,7 +200,7 @@ void LVSContinuousCollisionEvaluator::CalcCollisionsHelper(const Eigen::Ref<cons
pair.first.second);
double coeff = collision_config_->collision_coeff_data.getPairCollisionCoeff(pair.first.first, pair.first.second);
const Eigen::Vector3d data = { dist, collision_config_->collision_margin_buffer, coeff };
trajopt_common::removeInvalidContactResults(pair.second, data); /** @todo Should this be removed? levi */
trajopt_common::removeInvalidContactResults(pair.second, data, position_vars_fixed);
};

if (collision_config_->type == tesseract_collision::CollisionEvaluatorType::LVS_CONTINUOUS &&
Expand Down Expand Up @@ -329,7 +330,7 @@ LVSDiscreteCollisionEvaluator::CalcCollisionData(const Eigen::Ref<const Eigen::V
}

auto data = std::make_shared<trajopt_common::CollisionCacheData>();
CalcCollisionsHelper(dof_vals0, dof_vals1, data->contact_results_map);
CalcCollisionsHelper(data->contact_results_map, dof_vals0, dof_vals1, position_vars_fixed);
for (const auto& pair : data->contact_results_map)
{
using ShapeGrsType = std::map<std::pair<std::size_t, std::size_t>, trajopt_common::GradientResultsSet>;
Expand Down Expand Up @@ -401,9 +402,10 @@ LVSDiscreteCollisionEvaluator::CalcCollisionData(const Eigen::Ref<const Eigen::V
return data;
}

void LVSDiscreteCollisionEvaluator::CalcCollisionsHelper(const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
void LVSDiscreteCollisionEvaluator::CalcCollisionsHelper(tesseract_collision::ContactResultMap& dist_results,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals0,
const Eigen::Ref<const Eigen::VectorXd>& dof_vals1,
tesseract_collision::ContactResultMap& dist_results)
const std::array<bool, 2>& position_vars_fixed)
{
// If not empty then there are links that are not part of the kinematics object that can move (dynamic environment)
if (!diff_active_link_names_.empty())
Expand All @@ -417,7 +419,7 @@ void LVSDiscreteCollisionEvaluator::CalcCollisionsHelper(const Eigen::Ref<const
// Don't include contacts at the fixed state
// Don't include contacts with zero coeffs
const auto& zero_coeff_pairs = collision_config_->collision_coeff_data.getPairsWithZeroCoeff();
auto filter = [this, &zero_coeff_pairs](tesseract_collision::ContactResultMap::PairType& pair) {
auto filter = [this, &zero_coeff_pairs, &position_vars_fixed](tesseract_collision::ContactResultMap::PairType& pair) {
// Remove pairs with zero coeffs
if (std::find(zero_coeff_pairs.begin(), zero_coeff_pairs.end(), pair.first) != zero_coeff_pairs.end())
{
Expand All @@ -432,7 +434,7 @@ void LVSDiscreteCollisionEvaluator::CalcCollisionsHelper(const Eigen::Ref<const
const Eigen::Vector3d data = { dist, collision_config_->collision_margin_buffer, coeff };

// Don't include contacts at the fixed state
trajopt_common::removeInvalidContactResults(pair.second, data);
trajopt_common::removeInvalidContactResults(pair.second, data, position_vars_fixed);
};

// The first step is to see if the distance between two states is larger than the longest valid segment. If larger
Expand Down
2 changes: 1 addition & 1 deletion trajopt_optimizers/trajopt_sqp/include/trajopt_sqp/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ enum class SQPStatus
NLP_CONVERGED, /**< NLP Successfully converged */
ITERATION_LIMIT, /**< SQP Optimization reached iteration limit */
PENALTY_ITERATION_LIMIT, /**< SQP Optimization reached penalty iteration limit */
TIME_LIMIT, /**< SQP Optimization reached reached limit */
OPT_TIME_LIMIT, /**< SQP Optimization reached reached limit */
QP_SOLVER_ERROR, /**< QP Solver failed */
CALLBACK_STOPPED /**< Optimization stopped because callback returned false */
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ void TrustRegionSQPSolver::solve(const QPProblem::Ptr& qp_problem)
if (elapsed_time > params.max_time)
{
CONSOLE_BRIDGE_logInform("Elapsed time %f has exceeded max time %f", elapsed_time, params.max_time);
status_ = SQPStatus::TIME_LIMIT;
status_ = SQPStatus::OPT_TIME_LIMIT;
break;
}

Expand All @@ -127,7 +127,7 @@ void TrustRegionSQPSolver::solve(const QPProblem::Ptr& qp_problem)
}

// If status is iteration limit or time limit we need to exit penalty iteration loop
if (status_ == SQPStatus::ITERATION_LIMIT || status_ == SQPStatus::TIME_LIMIT)
if (status_ == SQPStatus::ITERATION_LIMIT || status_ == SQPStatus::OPT_TIME_LIMIT)
break;

// Set status to running
Expand Down

0 comments on commit 89661ad

Please sign in to comment.