[wpimath] Simplify pose estimator (#6705)
KangarooKoala authored Jun 29, 2024
1 parent 5e745bc commit 512a4bf
Showing 4 changed files with 278 additions and 207 deletions.
228 changes: 127 additions & 101 deletions wpimath/src/main/java/edu/wpi/first/math/estimator/PoseEstimator.java
@@ -5,22 +5,21 @@
package edu.wpi.first.math.estimator;

import edu.wpi.first.math.MathSharedStore;
+import edu.wpi.first.math.MathUtil;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.VecBuilder;
import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.math.geometry.Rotation2d;
import edu.wpi.first.math.geometry.Twist2d;
-import edu.wpi.first.math.interpolation.Interpolatable;
import edu.wpi.first.math.interpolation.TimeInterpolatableBuffer;
import edu.wpi.first.math.kinematics.Kinematics;
import edu.wpi.first.math.kinematics.Odometry;
import edu.wpi.first.math.numbers.N1;
import edu.wpi.first.math.numbers.N3;
-import java.util.Map;
-import java.util.NoSuchElementException;
-import java.util.Objects;
+import java.util.NavigableMap;
import java.util.Optional;
+import java.util.TreeMap;

/**
* This class wraps {@link Odometry} to fuse latency-compensated vision measurements with encoder
@@ -38,14 +37,20 @@
* @param <T> Wheel positions type.
*/
public class PoseEstimator<T> {
-private final Kinematics<?, T> m_kinematics;
private final Odometry<T> m_odometry;
private final Matrix<N3, N1> m_q = new Matrix<>(Nat.N3(), Nat.N1());
private final Matrix<N3, N3> m_visionK = new Matrix<>(Nat.N3(), Nat.N3());

private static final double kBufferDuration = 1.5;
-private final TimeInterpolatableBuffer<InterpolationRecord> m_poseBuffer =
+// Maps timestamps to odometry-only pose estimates
+private final TimeInterpolatableBuffer<Pose2d> m_odometryPoseBuffer =
TimeInterpolatableBuffer.createBuffer(kBufferDuration);
+// Maps timestamps to vision updates
+// Always contains one entry before the oldest entry in m_odometryPoseBuffer, unless there have
+// been no vision measurements after the last reset
+private final NavigableMap<Double, VisionUpdate> m_visionUpdates = new TreeMap<>();
+
+private Pose2d m_poseEstimate;

/**
* Constructs a PoseEstimator.
@@ -59,14 +64,16 @@ public class PoseEstimator<T> {
* in meters, y position in meters, and heading in radians). Increase these numbers to trust
* the vision pose measurement less.
*/
+@SuppressWarnings("PMD.UnusedFormalParameter")
public PoseEstimator(
Kinematics<?, T> kinematics,
Odometry<T> odometry,
Matrix<N3, N1> stateStdDevs,
Matrix<N3, N1> visionMeasurementStdDevs) {
-m_kinematics = kinematics;
m_odometry = odometry;

+m_poseEstimate = m_odometry.getPoseMeters();

for (int i = 0; i < 3; ++i) {
m_q.set(i, 0, stateStdDevs.get(i, 0) * stateStdDevs.get(i, 0));
}
@@ -113,7 +120,9 @@ public final void setVisionMeasurementStdDevs(Matrix<N3, N1> visionMeasurementSt
public void resetPosition(Rotation2d gyroAngle, T wheelPositions, Pose2d poseMeters) {
// Reset state estimate and error covariance
m_odometry.resetPosition(gyroAngle, wheelPositions, poseMeters);
-m_poseBuffer.clear();
+m_odometryPoseBuffer.clear();
+m_visionUpdates.clear();
+m_poseEstimate = m_odometry.getPoseMeters();
}

/**
@@ -122,7 +131,7 @@ public void resetPosition(Rotation2d gyroAngle, T wheelPositions, Pose2d poseMet
* @return The estimated robot pose in meters.
*/
public Pose2d getEstimatedPosition() {
-return m_odometry.getPoseMeters();
+return m_poseEstimate;
}

/**
@@ -132,7 +141,54 @@ public Pose2d getEstimatedPosition() {
* @return The pose at the given timestamp (or Optional.empty() if the buffer is empty).
*/
public Optional<Pose2d> sampleAt(double timestampSeconds) {
-return m_poseBuffer.getSample(timestampSeconds).map(record -> record.poseMeters);
+// Step 0: If there are no odometry updates to sample, skip.
+if (m_odometryPoseBuffer.getInternalBuffer().isEmpty()) {
+return Optional.empty();
+}
+
+// Step 1: Make sure timestamp matches the sample from the odometry pose buffer. (When sampling,
+// the buffer will always use a timestamp between the first and last timestamps)
+double oldestOdometryTimestamp = m_odometryPoseBuffer.getInternalBuffer().firstKey();
+double newestOdometryTimestamp = m_odometryPoseBuffer.getInternalBuffer().lastKey();
+timestampSeconds =
+MathUtil.clamp(timestampSeconds, oldestOdometryTimestamp, newestOdometryTimestamp);
+
+// Step 2: If there are no applicable vision updates, use the odometry-only information.
+if (m_visionUpdates.isEmpty() || timestampSeconds < m_visionUpdates.firstKey()) {
+return m_odometryPoseBuffer.getSample(timestampSeconds);
+}
+
+// Step 3: Get the latest vision update from before or at the timestamp to sample at.
+double floorTimestamp = m_visionUpdates.floorKey(timestampSeconds);
+var visionUpdate = m_visionUpdates.get(floorTimestamp);
+
+// Step 4: Get the pose measured by odometry at the time of the sample.
+var odometryEstimate = m_odometryPoseBuffer.getSample(timestampSeconds);
+
+// Step 5: Apply the vision compensation to the odometry pose.
+return odometryEstimate.map(odometryPose -> visionUpdate.compensate(odometryPose));
}
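
The compensation in Step 5 is plain rigid-transform algebra on Pose2d: the sampled pose is the vision pose composed with however far odometry has moved since the vision update. A minimal standalone sketch of that identity, assuming only the public Pose2d/Transform2d API (the class name and pose values are made up for illustration):

import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.math.geometry.Rotation2d;
import edu.wpi.first.math.geometry.Transform2d;

class CompensateSketch {
  public static void main(String[] args) {
    // Where odometry said the robot was when the vision frame was captured,
    // and where vision said it actually was.
    Pose2d odometryAtUpdate = new Pose2d(1.0, 2.0, Rotation2d.fromDegrees(0.0));
    Pose2d visionAtUpdate = new Pose2d(1.2, 2.1, Rotation2d.fromDegrees(5.0));

    // Odometry-only pose at the later sample time.
    Pose2d odometryAtSample = new Pose2d(2.0, 2.0, Rotation2d.fromDegrees(10.0));

    // Same math as VisionUpdate.compensate below: take the delta the robot has
    // moved according to odometry, and re-root it at the vision pose.
    Transform2d delta = odometryAtSample.minus(odometryAtUpdate);
    Pose2d sampled = visionAtUpdate.plus(delta);

    System.out.println(sampled);
  }
}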

+/** Removes stale vision updates that won't affect sampling. */
+private void cleanUpVisionUpdates() {
+// Step 0: If there are no odometry samples, skip.
+if (m_odometryPoseBuffer.getInternalBuffer().isEmpty()) {
+return;
+}
+
+// Step 1: Find the oldest timestamp that needs a vision update.
+double oldestOdometryTimestamp = m_odometryPoseBuffer.getInternalBuffer().firstKey();
+
+// Step 2: If there are no vision updates before that timestamp, skip.
+if (m_visionUpdates.isEmpty() || oldestOdometryTimestamp < m_visionUpdates.firstKey()) {
+return;
+}
+
+// Step 3: Find the newest vision update timestamp before or at the oldest timestamp.
+double newestNeededVisionUpdateTimestamp = m_visionUpdates.floorKey(oldestOdometryTimestamp);
+
+// Step 4: Remove all entries strictly before the newest timestamp we need.
+m_visionUpdates.headMap(newestNeededVisionUpdateTimestamp, false).clear();
+}
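
The pruning in cleanUpVisionUpdates leans on two NavigableMap guarantees: floorKey returns the newest key at or before the given timestamp, and headMap(key, false) is a live view of everything strictly older, so clearing it mutates the backing map. A small sketch of those semantics, with arbitrary timestamps and placeholder values:

import java.util.NavigableMap;
import java.util.TreeMap;

class HeadMapSketch {
  public static void main(String[] args) {
    NavigableMap<Double, String> updates = new TreeMap<>();
    updates.put(1.0, "a");
    updates.put(2.0, "b");
    updates.put(3.0, "c");

    // Newest entry at or before the oldest odometry timestamp (2.5) is 2.0.
    double newestNeeded = updates.floorKey(2.5);

    // headMap is a live view, so clear() drops the 1.0 entry from the backing map.
    updates.headMap(newestNeeded, false).clear();

    System.out.println(updates); // {2.0=b, 3.0=c}
  }
}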

/**
@@ -156,50 +212,51 @@ public Optional<Pose2d> sampleAt(double timestampSeconds) {
*/
public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
// Step 0: If this measurement is old enough to be outside the pose buffer's timespan, skip.
-try {
-if (m_poseBuffer.getInternalBuffer().lastKey() - kBufferDuration > timestampSeconds) {
-return;
-}
-} catch (NoSuchElementException ex) {
+if (m_odometryPoseBuffer.getInternalBuffer().isEmpty()
+|| m_odometryPoseBuffer.getInternalBuffer().lastKey() - kBufferDuration
+> timestampSeconds) {
return;
}

-// Step 1: Get the pose odometry measured at the moment the vision measurement was made.
-var sample = m_poseBuffer.getSample(timestampSeconds);
+// Step 1: Clean up any old entries
+cleanUpVisionUpdates();
+
+// Step 2: Get the pose measured by odometry at the moment the vision measurement was made.
+var odometrySample = m_odometryPoseBuffer.getSample(timestampSeconds);

-if (sample.isEmpty()) {
+if (odometrySample.isEmpty()) {
return;
}

-// Step 2: Measure the twist between the odometry pose and the vision pose.
-var twist = sample.get().poseMeters.log(visionRobotPoseMeters);
+// Step 3: Get the vision-compensated pose estimate at the moment the vision measurement was
+// made.
+var visionSample = sampleAt(timestampSeconds);
+
-// Step 3: We should not trust the twist entirely, so instead we scale this twist by a Kalman
+if (visionSample.isEmpty()) {
+return;
+}
+
+// Step 4: Measure the twist between the old pose estimate and the vision pose.
+var twist = visionSample.get().log(visionRobotPoseMeters);
+
+// Step 5: We should not trust the twist entirely, so instead we scale this twist by a Kalman
// gain matrix representing how much we trust vision measurements compared to our current pose.
var k_times_twist = m_visionK.times(VecBuilder.fill(twist.dx, twist.dy, twist.dtheta));

-// Step 4: Convert back to Twist2d.
+// Step 6: Convert back to Twist2d.
var scaledTwist =
new Twist2d(k_times_twist.get(0, 0), k_times_twist.get(1, 0), k_times_twist.get(2, 0));

-// Step 5: Reset Odometry to state at sample with vision adjustment.
-m_odometry.resetPosition(
-sample.get().gyroAngle,
-sample.get().wheelPositions,
-sample.get().poseMeters.exp(scaledTwist));
-
-// Step 6: Record the current pose to allow multiple measurements from the same timestamp
-m_poseBuffer.addSample(
-timestampSeconds,
-new InterpolationRecord(
-getEstimatedPosition(), sample.get().gyroAngle, sample.get().wheelPositions));
-
-// Step 7: Replay odometry inputs between sample time and latest recorded sample to update the
-// pose buffer and correct odometry.
-for (Map.Entry<Double, InterpolationRecord> entry :
-m_poseBuffer.getInternalBuffer().tailMap(timestampSeconds).entrySet()) {
-updateWithTime(entry.getKey(), entry.getValue().gyroAngle, entry.getValue().wheelPositions);
-}
+// Step 7: Calculate and record the vision update.
+var visionUpdate = new VisionUpdate(visionSample.get().exp(scaledTwist), odometrySample.get());
+m_visionUpdates.put(timestampSeconds, visionUpdate);
+
+// Step 8: Remove later vision measurements. (Matches previous behavior)
+m_visionUpdates.tailMap(timestampSeconds, false).entrySet().clear();
+
+// Step 9: Update latest pose estimate. Since we cleared all updates after this vision update,
+// it's guaranteed to be the latest vision update.
+m_poseEstimate = visionUpdate.compensate(m_odometry.getPoseMeters());
}
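
Step 5 above scales the twist by the Kalman gain, and because that gain matrix is diagonal, it reduces to three independent blends between the current estimate (gain 0) and the vision pose (gain 1) along the twist's dx, dy, and dtheta components. A sketch of the scale-and-apply steps, assuming a flat gain of 0.5 on every axis (the real m_visionK is derived from the state and vision standard deviations in code not shown in this hunk):

import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.VecBuilder;
import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.math.geometry.Rotation2d;
import edu.wpi.first.math.geometry.Twist2d;
import edu.wpi.first.math.numbers.N1;
import edu.wpi.first.math.numbers.N3;

class TwistScalingSketch {
  public static void main(String[] args) {
    Pose2d estimate = new Pose2d(1.0, 1.0, Rotation2d.fromDegrees(0.0));
    Pose2d vision = new Pose2d(1.4, 1.0, Rotation2d.fromDegrees(8.0));

    // The twist that would move the estimate exactly onto the vision pose.
    Twist2d twist = estimate.log(vision);

    // Assumed gain: trust odometry and vision equally on every axis.
    Matrix<N3, N3> visionK = new Matrix<>(Nat.N3(), Nat.N3());
    for (int i = 0; i < 3; ++i) {
      visionK.set(i, i, 0.5);
    }

    // Scale the twist by the gain, then apply only that fraction of the correction.
    Matrix<N3, N1> scaled = visionK.times(VecBuilder.fill(twist.dx, twist.dy, twist.dtheta));
    Twist2d scaledTwist = new Twist2d(scaled.get(0, 0), scaled.get(1, 0), scaled.get(2, 0));

    System.out.println(estimate.exp(scaledTwist));
  }
}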

/**
@@ -258,83 +315,52 @@ public Pose2d update(Rotation2d gyroAngle, T wheelPositions) {
* @return The estimated pose of the robot in meters.
*/
public Pose2d updateWithTime(double currentTimeSeconds, Rotation2d gyroAngle, T wheelPositions) {
-m_odometry.update(gyroAngle, wheelPositions);
-m_poseBuffer.addSample(
-currentTimeSeconds,
-new InterpolationRecord(
-getEstimatedPosition(), gyroAngle, m_kinematics.copy(wheelPositions)));
+var odometryEstimate = m_odometry.update(gyroAngle, wheelPositions);
+
+m_odometryPoseBuffer.addSample(currentTimeSeconds, odometryEstimate);
+
+if (m_visionUpdates.isEmpty()) {
+m_poseEstimate = odometryEstimate;
+} else {
+var visionUpdate = m_visionUpdates.get(m_visionUpdates.lastKey());
+m_poseEstimate = visionUpdate.compensate(odometryEstimate);
+}

return getEstimatedPosition();
}

/**
-* Represents an odometry record. The record contains the inputs provided as well as the pose that
-* was observed based on these inputs, as well as the previous record and its inputs.
+* Represents a vision update record. The record contains the vision-compensated pose estimate as
+* well as the corresponding odometry pose estimate.
*/
-private final class InterpolationRecord implements Interpolatable<InterpolationRecord> {
-// The pose observed given the current sensor inputs and the previous pose.
-private final Pose2d poseMeters;
+private static final class VisionUpdate {
+// The vision-compensated pose estimate.
+private final Pose2d visionPose;

-// The current gyro angle.
-private final Rotation2d gyroAngle;
-
-// The current encoder readings.
-private final T wheelPositions;
+// The pose estimated based solely on odometry.
+private final Pose2d odometryPose;

/**
-* Constructs an Interpolation Record with the specified parameters.
+* Constructs a vision update record with the specified parameters.
*
-* @param poseMeters The pose observed given the current sensor inputs and the previous pose.
-* @param gyro The current gyro angle.
-* @param wheelPositions The current encoder readings.
+* @param visionPose The vision-compensated pose estimate.
+* @param odometryPose The pose estimate based solely on odometry.
*/
-private InterpolationRecord(Pose2d poseMeters, Rotation2d gyro, T wheelPositions) {
-this.poseMeters = poseMeters;
-this.gyroAngle = gyro;
-this.wheelPositions = wheelPositions;
+private VisionUpdate(Pose2d visionPose, Pose2d odometryPose) {
+this.visionPose = visionPose;
+this.odometryPose = odometryPose;
}

/**
-* Return the interpolated record. This object is assumed to be the starting position, or lower
-* bound.
+* Returns the vision-compensated version of the pose. Specifically, changes the pose from being
+* relative to this record's odometry pose to being relative to this record's vision pose.
*
-* @param endValue The upper bound, or end.
-* @param t How far between the lower and upper bound we are. This should be bounded in [0, 1].
-* @return The interpolated value.
+* @param pose The pose to compensate.
+* @return The compensated pose.
*/
-@Override
-public InterpolationRecord interpolate(InterpolationRecord endValue, double t) {
-if (t < 0) {
-return this;
-} else if (t >= 1) {
-return endValue;
-} else {
-// Find the new wheel distances.
-var wheelLerp = m_kinematics.interpolate(wheelPositions, endValue.wheelPositions, t);
-
-// Find the new gyro angle.
-var gyroLerp = gyroAngle.interpolate(endValue.gyroAngle, t);
-
-// Create a twist to represent the change based on the interpolated sensor inputs.
-Twist2d twist = m_kinematics.toTwist2d(wheelPositions, wheelLerp);
-twist.dtheta = gyroLerp.minus(gyroAngle).getRadians();
-
-return new InterpolationRecord(poseMeters.exp(twist), gyroLerp, wheelLerp);
-}
-}
-
-@Override
-public boolean equals(Object obj) {
-return this == obj
-|| obj instanceof PoseEstimator<?>.InterpolationRecord record
-&& Objects.equals(gyroAngle, record.gyroAngle)
-&& Objects.equals(wheelPositions, record.wheelPositions)
-&& Objects.equals(poseMeters, record.poseMeters);
-}
-
-@Override
-public int hashCode() {
-return Objects.hash(gyroAngle, wheelPositions, poseMeters);
+public Pose2d compensate(Pose2d pose) {
+var delta = pose.minus(this.odometryPose);
+return this.visionPose.plus(delta);
}
}
}
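
None of this changes the public API: callers still feed odometry every loop and vision poses as they arrive, with latency compensation handled internally. A minimal usage sketch through the differential-drive subclass (the track width, standard deviations, timestamps, and measurements are placeholder values):

import edu.wpi.first.math.VecBuilder;
import edu.wpi.first.math.estimator.DifferentialDrivePoseEstimator;
import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.math.geometry.Rotation2d;
import edu.wpi.first.math.kinematics.DifferentialDriveKinematics;

class PoseEstimatorUsageSketch {
  public static void main(String[] args) {
    var kinematics = new DifferentialDriveKinematics(0.6); // track width, meters
    var estimator =
        new DifferentialDrivePoseEstimator(
            kinematics,
            new Rotation2d(),
            0.0, // left distance, meters
            0.0, // right distance, meters
            new Pose2d(),
            VecBuilder.fill(0.02, 0.02, 0.01), // state std devs
            VecBuilder.fill(0.1, 0.1, 0.1)); // vision std devs

    // Every loop: this fills the odometry pose buffer with odometry-only poses.
    estimator.updateWithTime(0.02, new Rotation2d(), 0.05, 0.05);

    // When a vision frame arrives, pass its capture timestamp, not "now";
    // the estimator applies the correction from that point forward.
    estimator.addVisionMeasurement(new Pose2d(0.1, 0.0, new Rotation2d()), 0.015);

    System.out.println(estimator.getEstimatedPosition());
  }
}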