Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Native quaternion implementation #166

Merged
merged 17 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions core/include/core/G3Map.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include <G3Frame.h>
#include <G3Vector.h>
#include <G3Quat.h>
#include <map>
#include <sstream>
#include <complex>
Expand Down Expand Up @@ -80,8 +79,6 @@ G3MAP_OF(std::string, G3VectorVectorString, G3MapVectorVectorString);
G3MAP_OF(std::string, std::vector<std::complex<double> >, G3MapVectorComplexDouble);
G3MAP_OF(std::string, G3VectorTime, G3MapVectorTime);
G3MAP_OF(std::string, std::string, G3MapString);
G3MAP_OF(std::string, quat, G3MapQuat);
G3MAP_OF(std::string, G3VectorQuat, G3MapVectorQuat);

#define G3MAP_SPLIT(key, value, name, version) \
typedef G3Map< key, value > name; \
Expand Down
149 changes: 102 additions & 47 deletions core/include/core/G3Quat.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,104 @@
#define _CORE_G3QUAT_H

#include <G3Vector.h>
#include <G3Map.h>

#include <boost/math/quaternion.hpp>
#include <cereal/types/vector.hpp>
class Quat
{
public:
Quat() : a_(0), b_(0), c_(0), d_(0) {}
Quat(double a, double b, double c, double d) :
a_(a), b_(b), c_(c), d_(d) {}
Quat(const Quat &q) : a_(q.a_), b_(q.b_), c_(q.c_), d_(q.d_) {}

double a() const { return a_; }
double b() const { return b_; }
double c() const { return c_; }
double d() const { return d_; }

double real() const;
Quat unreal() const;
Quat conj() const;
double norm() const;
double vnorm() const;
double abs() const;
double dot3(const Quat &b) const;
Quat cross3(const Quat &b) const;

Quat operator -() const;
Quat operator ~() const;

Quat &operator +=(const Quat &);
Quat &operator -=(const Quat &);
Quat &operator *=(double);
Quat &operator *=(const Quat &);
Quat &operator /=(double);
Quat &operator /=(const Quat &);

Quat operator +(const Quat &) const;
Quat operator -(const Quat &) const;
Quat operator *(double) const;
Quat operator *(const Quat &) const;
Quat operator /(double) const;
Quat operator /(const Quat &) const;

bool operator ==(const Quat &) const;
bool operator !=(const Quat &) const;

typedef boost::math::quaternion<double> quat;
template <class A> void serialize(A &ar, unsigned v);
private:
double a_, b_, c_, d_;
};

namespace cereal
{
// Define cereal serialization for the Quaternions
template<class A>
void serialize(A & ar, quat & q, unsigned version)
{
using namespace cereal;
double a, b, c, d;
a = q.R_component_1();
b = q.R_component_2();
c = q.R_component_3();
d = q.R_component_4();
ar & make_nvp("a", a);
ar & make_nvp("b", b);
ar & make_nvp("c", c);
ar & make_nvp("d", d);
q = quat(a,b,c,d);
}
std::ostream& operator<<(std::ostream& os, const Quat &);

namespace cereal {
template <class A> struct specialize<A, Quat, cereal::specialization::member_serialize> {};
}

quat cross3(quat a, quat b);
double dot3(quat a, quat b);
CEREAL_CLASS_VERSION(Quat, 1);

Quat operator *(double, const Quat &);
Quat operator /(double, const Quat &);

inline double real(const Quat &q) { return q.real(); };
inline Quat unreal(const Quat &q) { return q.unreal(); };
inline Quat conj(const Quat &q) { return q.conj(); };
inline double norm(const Quat &q) { return q.norm(); }
inline double vnorm(const Quat &q) { return q.vnorm(); }
inline double abs(const Quat &q) { return q.abs(); }

G3VECTOR_OF(quat, G3VectorQuat);
Quat pow(const Quat &, int);

Quat cross3(const Quat &a, const Quat &b);
double dot3(const Quat &a, const Quat &b);

// Frame object data wrapper

class G3Quat : public G3FrameObject {
public:
Quat value;

G3Quat() {}
G3Quat(const Quat &val) : value(val) {}

template <class A> void serialize(A &ar, unsigned v);
std::string Description() const;
bool operator==(const G3Quat & other) const {return value == other.value;}
};

G3_POINTERS(G3Quat);
G3_SERIALIZABLE(G3Quat, 1);

G3VECTOR_OF(Quat, G3VectorQuat);

class G3TimestreamQuat : public G3VectorQuat
{
public:
G3TimestreamQuat() : G3VectorQuat() {}
G3TimestreamQuat(std::vector<quat>::size_type s) : G3VectorQuat(s) {}
G3TimestreamQuat(std::vector<quat>::size_type s,
const quat &val) : G3VectorQuat(s, val) {}
G3TimestreamQuat(std::vector<Quat>::size_type s) : G3VectorQuat(s) {}
G3TimestreamQuat(std::vector<Quat>::size_type s,
const Quat &val) : G3VectorQuat(s, val) {}
G3TimestreamQuat(const G3TimestreamQuat &r) : G3VectorQuat(r),
start(r.start), stop(r.stop) {}
G3TimestreamQuat(const G3VectorQuat &r) : G3VectorQuat(r) {}
Expand All @@ -62,52 +122,47 @@ namespace cereal {
G3_POINTERS(G3TimestreamQuat);
G3_SERIALIZABLE(G3TimestreamQuat, 1);

namespace boost {
namespace math {
quat operator ~(quat);
};
};

G3VectorQuat operator ~ (const G3VectorQuat &);
G3VectorQuat operator * (const G3VectorQuat &, double);
G3VectorQuat &operator *= (G3VectorQuat &, double);
G3VectorQuat operator / (const G3VectorQuat &, double);
G3VectorQuat operator / (double, const G3VectorQuat &);
G3VectorQuat operator / (const G3VectorQuat &, const quat &);
G3VectorQuat operator / (const quat &, const G3VectorQuat &);
G3VectorQuat operator / (const G3VectorQuat &, const Quat &);
G3VectorQuat operator / (const Quat &, const G3VectorQuat &);
G3VectorQuat operator / (const G3VectorQuat &, const G3VectorQuat &);
G3VectorQuat &operator /= (G3VectorQuat &, double);
G3VectorQuat &operator /= (G3VectorQuat &, const quat &);
G3VectorQuat &operator /= (G3VectorQuat &, const Quat &);
G3VectorQuat &operator /= (G3VectorQuat &, const G3VectorQuat &);
G3VectorQuat operator * (const G3VectorQuat &, const G3VectorQuat &);
G3VectorQuat &operator *= (G3VectorQuat &, const G3VectorQuat &);
G3VectorQuat operator * (double, const G3VectorQuat &);
G3VectorQuat operator * (const G3VectorQuat &, quat);
G3VectorQuat operator * (quat, const G3VectorQuat &);
G3VectorQuat &operator *= (G3VectorQuat &, quat);
G3VectorQuat operator * (const G3VectorQuat &, const Quat &);
G3VectorQuat operator * (const Quat &, const G3VectorQuat &);
G3VectorQuat &operator *= (G3VectorQuat &, const Quat &);

G3VectorQuat pow(const G3VectorQuat &a, double b);
G3VectorQuat pow(const G3VectorQuat &a, int b);

G3TimestreamQuat operator ~ (const G3TimestreamQuat &);
G3TimestreamQuat operator * (const G3TimestreamQuat &, double);
G3TimestreamQuat operator * (double, const G3TimestreamQuat &);
G3TimestreamQuat operator / (const G3TimestreamQuat &, double);
G3TimestreamQuat operator / (double, const G3TimestreamQuat &);
G3TimestreamQuat operator / (const G3TimestreamQuat &, const quat &);
G3TimestreamQuat operator / (const quat &, const G3TimestreamQuat &);
G3TimestreamQuat operator / (const G3TimestreamQuat &, const Quat &);
G3TimestreamQuat operator / (const Quat &, const G3TimestreamQuat &);
G3TimestreamQuat operator / (const G3TimestreamQuat &, const G3VectorQuat &);
G3TimestreamQuat &operator /= (G3TimestreamQuat &, double);
G3TimestreamQuat &operator /= (G3TimestreamQuat &, const quat &);
G3TimestreamQuat &operator /= (G3TimestreamQuat &, const Quat &);
G3TimestreamQuat &operator /= (G3TimestreamQuat &, const G3VectorQuat &);
G3TimestreamQuat operator * (const G3TimestreamQuat &, const G3VectorQuat &);
G3TimestreamQuat &operator *= (G3TimestreamQuat &, const G3VectorQuat &);
G3TimestreamQuat operator * (double, const G3TimestreamQuat &);
G3TimestreamQuat operator * (const G3TimestreamQuat &, quat);
G3TimestreamQuat operator * (quat, const G3TimestreamQuat &);
G3TimestreamQuat &operator *= (G3TimestreamQuat &, quat);
G3TimestreamQuat operator * (const G3TimestreamQuat &, const Quat &);
G3TimestreamQuat operator * (const Quat &, const G3TimestreamQuat &);
G3TimestreamQuat &operator *= (G3TimestreamQuat &, const Quat &);

G3TimestreamQuat pow(const G3TimestreamQuat &a, double b);
G3TimestreamQuat pow(const G3TimestreamQuat &a, int b);

G3MAP_OF(std::string, G3VectorQuat, G3MapVectorQuat);
G3MAP_OF(std::string, Quat, G3MapQuat);

#endif
1 change: 1 addition & 0 deletions core/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ def fix_logging_crash():
from .dataextensions import *
from .frameextensions import *
from .timestreamextensions import *
from .quatextensions import *

from .g3decorators import cache_frame_data, scan_func_cache_data
42 changes: 42 additions & 0 deletions core/python/quatextensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
from . import Quat, G3VectorQuat, G3TimestreamQuat

__all__ = []

quat_types = (Quat, G3VectorQuat, G3TimestreamQuat)


def quat_ufunc(self, ufunc, method, *inputs, **kwargs):
"""Numpy ufunc interface for vectorized quaternion methods."""
if ufunc.__name__ in ["isinf", "isnan", "isfinite"] and len(inputs) == 1:
v = getattr(ufunc, method)(np.asarray(inputs[0]), **kwargs)
if ufunc.__name__ == "isfinite":
return np.all(v, axis=-1)
return np.any(v, axis=-1)
if ufunc.__name__.startswith("logical"):
args = []
for arg in inputs:
if isinstance(arg, quat_types):
arg = np.any(np.asarray(arg), axis=-1)
args.append(arg)
return getattr(ufunc, method)(*args, **kwargs)
if method != "__call__" or kwargs:
return NotImplemented
if len(inputs) == 1:
if ufunc.__name__ == "absolute":
return self.abs()
if ufunc.__name__ == "negative":
return self.__neg__()
if ufunc.__name__ == "conjugate":
return self.conj()
if ufunc.__name__ == "reciprocal":
return Quat(1, 0, 0, 0) / self
if len(inputs) == 2 and np.isscalar(inputs[1]):
if ufunc.__name__ == "power":
return self.__pow__(inputs[1])
return NotImplemented


Quat.__array_ufunc__ = quat_ufunc
G3VectorQuat.__array_ufunc__ = quat_ufunc
G3TimestreamQuat.__array_ufunc__ = quat_ufunc
9 changes: 9 additions & 0 deletions core/src/G3Frame.cxx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <G3Frame.h>
#include <G3Data.h>
#include <G3Quat.h>
#include <serialization.h>
#include <pybindings.h>
#include <dataio.h>
Expand Down Expand Up @@ -493,6 +494,12 @@ static void g3frame_python_put(G3Frame &f, std::string name, bp::object obj)
return;
}

bp::extract<Quat> extquat(obj);
if (extquat.check()) {
f.Put(name, boost::make_shared<G3Quat>(extquat()));
return;
}

bp::extract<std::string> extstr(obj);
if (extstr.check())
f.Put(name, boost::make_shared<G3String>(extstr()));
Expand Down Expand Up @@ -520,6 +527,8 @@ static bp::object g3frame_python_get(G3Frame &f, std::string name)
return bp::object(boost::dynamic_pointer_cast<const G3String>(element)->value);
else if (!!boost::dynamic_pointer_cast<const G3Bool>(element))
return bp::object(boost::dynamic_pointer_cast<const G3Bool>(element)->value);
else if (!!boost::dynamic_pointer_cast<const G3Quat>(element))
return bp::object(boost::dynamic_pointer_cast<const G3Quat>(element)->value);
else
return bp::object(boost::const_pointer_cast<G3FrameObject>(element));
}
Expand Down
6 changes: 0 additions & 6 deletions core/src/G3Map.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,12 @@ std::string G3MapFrameObject::Description() const
G3_SERIALIZABLE_CODE(G3MapDouble);
G3_SERIALIZABLE_CODE(G3MapMapDouble);
G3_SERIALIZABLE_CODE(G3MapString);
G3_SERIALIZABLE_CODE(G3MapQuat);
G3_SERIALIZABLE_CODE(G3MapVectorBool);
G3_SERIALIZABLE_CODE(G3MapVectorDouble);
G3_SERIALIZABLE_CODE(G3MapVectorString);
G3_SERIALIZABLE_CODE(G3MapVectorVectorString);
G3_SERIALIZABLE_CODE(G3MapVectorComplexDouble);
G3_SERIALIZABLE_CODE(G3MapVectorTime);
G3_SERIALIZABLE_CODE(G3MapVectorQuat);

G3_SPLIT_SERIALIZABLE_CODE(G3MapInt);
G3_SPLIT_SERIALIZABLE_CODE(G3MapVectorInt);
Expand All @@ -245,8 +243,6 @@ PYBINDINGS("core") {
register_g3map<G3MapInt>("G3MapInt", "Mapping from strings to ints.");
register_g3map<G3MapString>("G3MapString", "Mapping from strings to "
"strings.");
register_g3map<G3MapQuat>("G3MapQuat", "Mapping from strings to "
"quaternions.");
register_g3map<G3MapVectorBool>("G3MapVectorBool", "Mapping from "
"strings to arrays of booleans.");
register_g3map<G3MapVectorDouble>("G3MapVectorDouble", "Mapping from "
Expand All @@ -261,8 +257,6 @@ PYBINDINGS("core") {
"Mapping from strings to lists of lists of strings.");
register_g3map<G3MapVectorTime>("G3MapVectorTime", "Mapping from "
"strings to lists of G3 time objects.");
register_g3map<G3MapVectorQuat>("G3MapVectorQuat", "Mapping from "
"strings to lists of quaternions.");

// Special handling to get the object proxying right
register_g3map<G3MapFrameObject, true>("G3MapFrameObject", "Mapping "
Expand Down
Loading