Skip to content

Commit

Permalink
Native quaternion implementation (#166)
Browse files Browse the repository at this point in the history
This PR removes the dependency on boost::math::quaternion in favor of a minimal native implementation, which includes simple arithmetic and comparison operators, and conj, abs, norm and pow function implementations.  The Quat object is serializeable and provides a numpy buffer interface.

Closes #143.
  • Loading branch information
arahlin authored Nov 14, 2024
1 parent 862c407 commit 136e6e8
Show file tree
Hide file tree
Showing 23 changed files with 899 additions and 354 deletions.
3 changes: 0 additions & 3 deletions core/include/core/G3Map.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include <G3Frame.h>
#include <G3Vector.h>
#include <G3Quat.h>
#include <map>
#include <sstream>
#include <complex>
Expand Down Expand Up @@ -80,8 +79,6 @@ G3MAP_OF(std::string, G3VectorVectorString, G3MapVectorVectorString);
G3MAP_OF(std::string, std::vector<std::complex<double> >, G3MapVectorComplexDouble);
G3MAP_OF(std::string, G3VectorTime, G3MapVectorTime);
G3MAP_OF(std::string, std::string, G3MapString);
G3MAP_OF(std::string, quat, G3MapQuat);
G3MAP_OF(std::string, G3VectorQuat, G3MapVectorQuat);

#define G3MAP_SPLIT(key, value, name, version) \
typedef G3Map< key, value > name; \
Expand Down
149 changes: 102 additions & 47 deletions core/include/core/G3Quat.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,104 @@
#define _CORE_G3QUAT_H

#include <G3Vector.h>
#include <G3Map.h>

#include <boost/math/quaternion.hpp>
#include <cereal/types/vector.hpp>
class Quat
{
public:
Quat() : a_(0), b_(0), c_(0), d_(0) {}
Quat(double a, double b, double c, double d) :
a_(a), b_(b), c_(c), d_(d) {}
Quat(const Quat &q) : a_(q.a_), b_(q.b_), c_(q.c_), d_(q.d_) {}

double a() const { return a_; }
double b() const { return b_; }
double c() const { return c_; }
double d() const { return d_; }

double real() const;
Quat unreal() const;
Quat conj() const;
double norm() const;
double vnorm() const;
double abs() const;
double dot3(const Quat &b) const;
Quat cross3(const Quat &b) const;

Quat operator -() const;
Quat operator ~() const;

Quat &operator +=(const Quat &);
Quat &operator -=(const Quat &);
Quat &operator *=(double);
Quat &operator *=(const Quat &);
Quat &operator /=(double);
Quat &operator /=(const Quat &);

Quat operator +(const Quat &) const;
Quat operator -(const Quat &) const;
Quat operator *(double) const;
Quat operator *(const Quat &) const;
Quat operator /(double) const;
Quat operator /(const Quat &) const;

bool operator ==(const Quat &) const;
bool operator !=(const Quat &) const;

typedef boost::math::quaternion<double> quat;
template <class A> void serialize(A &ar, unsigned v);
private:
double a_, b_, c_, d_;
};

namespace cereal
{
// Define cereal serialization for the Quaternions
template<class A>
void serialize(A & ar, quat & q, unsigned version)
{
using namespace cereal;
double a, b, c, d;
a = q.R_component_1();
b = q.R_component_2();
c = q.R_component_3();
d = q.R_component_4();
ar & make_nvp("a", a);
ar & make_nvp("b", b);
ar & make_nvp("c", c);
ar & make_nvp("d", d);
q = quat(a,b,c,d);
}
std::ostream& operator<<(std::ostream& os, const Quat &);

namespace cereal {
template <class A> struct specialize<A, Quat, cereal::specialization::member_serialize> {};
}

quat cross3(quat a, quat b);
double dot3(quat a, quat b);
CEREAL_CLASS_VERSION(Quat, 1);

Quat operator *(double, const Quat &);
Quat operator /(double, const Quat &);

inline double real(const Quat &q) { return q.real(); };
inline Quat unreal(const Quat &q) { return q.unreal(); };
inline Quat conj(const Quat &q) { return q.conj(); };
inline double norm(const Quat &q) { return q.norm(); }
inline double vnorm(const Quat &q) { return q.vnorm(); }
inline double abs(const Quat &q) { return q.abs(); }

G3VECTOR_OF(quat, G3VectorQuat);
Quat pow(const Quat &, int);

Quat cross3(const Quat &a, const Quat &b);
double dot3(const Quat &a, const Quat &b);

// Frame object data wrapper

class G3Quat : public G3FrameObject {
public:
Quat value;

G3Quat() {}
G3Quat(const Quat &val) : value(val) {}

template <class A> void serialize(A &ar, unsigned v);
std::string Description() const;
bool operator==(const G3Quat & other) const {return value == other.value;}
};

G3_POINTERS(G3Quat);
G3_SERIALIZABLE(G3Quat, 1);

G3VECTOR_OF(Quat, G3VectorQuat);

class G3TimestreamQuat : public G3VectorQuat
{
public:
G3TimestreamQuat() : G3VectorQuat() {}
G3TimestreamQuat(std::vector<quat>::size_type s) : G3VectorQuat(s) {}
G3TimestreamQuat(std::vector<quat>::size_type s,
const quat &val) : G3VectorQuat(s, val) {}
G3TimestreamQuat(std::vector<Quat>::size_type s) : G3VectorQuat(s) {}
G3TimestreamQuat(std::vector<Quat>::size_type s,
const Quat &val) : G3VectorQuat(s, val) {}
G3TimestreamQuat(const G3TimestreamQuat &r) : G3VectorQuat(r),
start(r.start), stop(r.stop) {}
G3TimestreamQuat(const G3VectorQuat &r) : G3VectorQuat(r) {}
Expand All @@ -62,52 +122,47 @@ namespace cereal {
G3_POINTERS(G3TimestreamQuat);
G3_SERIALIZABLE(G3TimestreamQuat, 1);

namespace boost {
namespace math {
quat operator ~(quat);
};
};

G3VectorQuat operator ~ (const G3VectorQuat &);
G3VectorQuat operator * (const G3VectorQuat &, double);
G3VectorQuat &operator *= (G3VectorQuat &, double);
G3VectorQuat operator / (const G3VectorQuat &, double);
G3VectorQuat operator / (double, const G3VectorQuat &);
G3VectorQuat operator / (const G3VectorQuat &, const quat &);
G3VectorQuat operator / (const quat &, const G3VectorQuat &);
G3VectorQuat operator / (const G3VectorQuat &, const Quat &);
G3VectorQuat operator / (const Quat &, const G3VectorQuat &);
G3VectorQuat operator / (const G3VectorQuat &, const G3VectorQuat &);
G3VectorQuat &operator /= (G3VectorQuat &, double);
G3VectorQuat &operator /= (G3VectorQuat &, const quat &);
G3VectorQuat &operator /= (G3VectorQuat &, const Quat &);
G3VectorQuat &operator /= (G3VectorQuat &, const G3VectorQuat &);
G3VectorQuat operator * (const G3VectorQuat &, const G3VectorQuat &);
G3VectorQuat &operator *= (G3VectorQuat &, const G3VectorQuat &);
G3VectorQuat operator * (double, const G3VectorQuat &);
G3VectorQuat operator * (const G3VectorQuat &, quat);
G3VectorQuat operator * (quat, const G3VectorQuat &);
G3VectorQuat &operator *= (G3VectorQuat &, quat);
G3VectorQuat operator * (const G3VectorQuat &, const Quat &);
G3VectorQuat operator * (const Quat &, const G3VectorQuat &);
G3VectorQuat &operator *= (G3VectorQuat &, const Quat &);

G3VectorQuat pow(const G3VectorQuat &a, double b);
G3VectorQuat pow(const G3VectorQuat &a, int b);

G3TimestreamQuat operator ~ (const G3TimestreamQuat &);
G3TimestreamQuat operator * (const G3TimestreamQuat &, double);
G3TimestreamQuat operator * (double, const G3TimestreamQuat &);
G3TimestreamQuat operator / (const G3TimestreamQuat &, double);
G3TimestreamQuat operator / (double, const G3TimestreamQuat &);
G3TimestreamQuat operator / (const G3TimestreamQuat &, const quat &);
G3TimestreamQuat operator / (const quat &, const G3TimestreamQuat &);
G3TimestreamQuat operator / (const G3TimestreamQuat &, const Quat &);
G3TimestreamQuat operator / (const Quat &, const G3TimestreamQuat &);
G3TimestreamQuat operator / (const G3TimestreamQuat &, const G3VectorQuat &);
G3TimestreamQuat &operator /= (G3TimestreamQuat &, double);
G3TimestreamQuat &operator /= (G3TimestreamQuat &, const quat &);
G3TimestreamQuat &operator /= (G3TimestreamQuat &, const Quat &);
G3TimestreamQuat &operator /= (G3TimestreamQuat &, const G3VectorQuat &);
G3TimestreamQuat operator * (const G3TimestreamQuat &, const G3VectorQuat &);
G3TimestreamQuat &operator *= (G3TimestreamQuat &, const G3VectorQuat &);
G3TimestreamQuat operator * (double, const G3TimestreamQuat &);
G3TimestreamQuat operator * (const G3TimestreamQuat &, quat);
G3TimestreamQuat operator * (quat, const G3TimestreamQuat &);
G3TimestreamQuat &operator *= (G3TimestreamQuat &, quat);
G3TimestreamQuat operator * (const G3TimestreamQuat &, const Quat &);
G3TimestreamQuat operator * (const Quat &, const G3TimestreamQuat &);
G3TimestreamQuat &operator *= (G3TimestreamQuat &, const Quat &);

G3TimestreamQuat pow(const G3TimestreamQuat &a, double b);
G3TimestreamQuat pow(const G3TimestreamQuat &a, int b);

G3MAP_OF(std::string, G3VectorQuat, G3MapVectorQuat);
G3MAP_OF(std::string, Quat, G3MapQuat);

#endif
1 change: 1 addition & 0 deletions core/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ def fix_logging_crash():
from .dataextensions import *
from .frameextensions import *
from .timestreamextensions import *
from .quatextensions import *

from .g3decorators import cache_frame_data, scan_func_cache_data
42 changes: 42 additions & 0 deletions core/python/quatextensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
from . import Quat, G3VectorQuat, G3TimestreamQuat

__all__ = []

quat_types = (Quat, G3VectorQuat, G3TimestreamQuat)


def quat_ufunc(self, ufunc, method, *inputs, **kwargs):
"""Numpy ufunc interface for vectorized quaternion methods."""
if ufunc.__name__ in ["isinf", "isnan", "isfinite"] and len(inputs) == 1:
v = getattr(ufunc, method)(np.asarray(inputs[0]), **kwargs)
if ufunc.__name__ == "isfinite":
return np.all(v, axis=-1)
return np.any(v, axis=-1)
if ufunc.__name__.startswith("logical"):
args = []
for arg in inputs:
if isinstance(arg, quat_types):
arg = np.any(np.asarray(arg), axis=-1)
args.append(arg)
return getattr(ufunc, method)(*args, **kwargs)
if method != "__call__" or kwargs:
return NotImplemented
if len(inputs) == 1:
if ufunc.__name__ == "absolute":
return self.abs()
if ufunc.__name__ == "negative":
return self.__neg__()
if ufunc.__name__ == "conjugate":
return self.conj()
if ufunc.__name__ == "reciprocal":
return Quat(1, 0, 0, 0) / self
if len(inputs) == 2 and np.isscalar(inputs[1]):
if ufunc.__name__ == "power":
return self.__pow__(inputs[1])
return NotImplemented


Quat.__array_ufunc__ = quat_ufunc
G3VectorQuat.__array_ufunc__ = quat_ufunc
G3TimestreamQuat.__array_ufunc__ = quat_ufunc
9 changes: 9 additions & 0 deletions core/src/G3Frame.cxx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <G3Frame.h>
#include <G3Data.h>
#include <G3Quat.h>
#include <serialization.h>
#include <pybindings.h>
#include <dataio.h>
Expand Down Expand Up @@ -493,6 +494,12 @@ static void g3frame_python_put(G3Frame &f, std::string name, bp::object obj)
return;
}

bp::extract<Quat> extquat(obj);
if (extquat.check()) {
f.Put(name, boost::make_shared<G3Quat>(extquat()));
return;
}

bp::extract<std::string> extstr(obj);
if (extstr.check())
f.Put(name, boost::make_shared<G3String>(extstr()));
Expand Down Expand Up @@ -520,6 +527,8 @@ static bp::object g3frame_python_get(G3Frame &f, std::string name)
return bp::object(boost::dynamic_pointer_cast<const G3String>(element)->value);
else if (!!boost::dynamic_pointer_cast<const G3Bool>(element))
return bp::object(boost::dynamic_pointer_cast<const G3Bool>(element)->value);
else if (!!boost::dynamic_pointer_cast<const G3Quat>(element))
return bp::object(boost::dynamic_pointer_cast<const G3Quat>(element)->value);
else
return bp::object(boost::const_pointer_cast<G3FrameObject>(element));
}
Expand Down
6 changes: 0 additions & 6 deletions core/src/G3Map.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,12 @@ std::string G3MapFrameObject::Description() const
G3_SERIALIZABLE_CODE(G3MapDouble);
G3_SERIALIZABLE_CODE(G3MapMapDouble);
G3_SERIALIZABLE_CODE(G3MapString);
G3_SERIALIZABLE_CODE(G3MapQuat);
G3_SERIALIZABLE_CODE(G3MapVectorBool);
G3_SERIALIZABLE_CODE(G3MapVectorDouble);
G3_SERIALIZABLE_CODE(G3MapVectorString);
G3_SERIALIZABLE_CODE(G3MapVectorVectorString);
G3_SERIALIZABLE_CODE(G3MapVectorComplexDouble);
G3_SERIALIZABLE_CODE(G3MapVectorTime);
G3_SERIALIZABLE_CODE(G3MapVectorQuat);

G3_SPLIT_SERIALIZABLE_CODE(G3MapInt);
G3_SPLIT_SERIALIZABLE_CODE(G3MapVectorInt);
Expand All @@ -245,8 +243,6 @@ PYBINDINGS("core") {
register_g3map<G3MapInt>("G3MapInt", "Mapping from strings to ints.");
register_g3map<G3MapString>("G3MapString", "Mapping from strings to "
"strings.");
register_g3map<G3MapQuat>("G3MapQuat", "Mapping from strings to "
"quaternions.");
register_g3map<G3MapVectorBool>("G3MapVectorBool", "Mapping from "
"strings to arrays of booleans.");
register_g3map<G3MapVectorDouble>("G3MapVectorDouble", "Mapping from "
Expand All @@ -261,8 +257,6 @@ PYBINDINGS("core") {
"Mapping from strings to lists of lists of strings.");
register_g3map<G3MapVectorTime>("G3MapVectorTime", "Mapping from "
"strings to lists of G3 time objects.");
register_g3map<G3MapVectorQuat>("G3MapVectorQuat", "Mapping from "
"strings to lists of quaternions.");

// Special handling to get the object proxying right
register_g3map<G3MapFrameObject, true>("G3MapFrameObject", "Mapping "
Expand Down
Loading

0 comments on commit 136e6e8

Please sign in to comment.