Skip to content

Commit

Permalink
Rename c++ ::wfa::any_sketch::Distribution to ::wfa::any_sketch::Base…
Browse files Browse the repository at this point in the history
…Distribution
  • Loading branch information
brianyi9 committed Feb 7, 2024
1 parent 89d7e73 commit 3f2463c
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 37 deletions.
5 changes: 3 additions & 2 deletions src/main/cc/any_sketch/any_sketch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

namespace wfa::any_sketch {

AnySketch::AnySketch(std::vector<std::unique_ptr<Distribution>> indexes,
AnySketch::AnySketch(std::vector<std::unique_ptr<BaseDistribution>> indexes,
std::vector<ValueFunction> values)
: indexes_(indexes.size()), values_(values.size()) {
std::move(indexes.begin(), indexes.end(), indexes_.begin());
Expand Down Expand Up @@ -75,7 +75,8 @@ absl::StatusOr<int64_t> AnySketch::GetIndex(
absl::string_view item, const ItemMetadata& item_metadata) const {
uint64_t product = 1;
uint64_t linearized_index = 0;
for (const std::unique_ptr<Distribution>& distribution : indexes_) {
for (const std::unique_ptr<BaseDistribution>& distribution :
indexes_) {
ASSIGN_OR_RETURN(int64_t distribution_value,
distribution->Apply(item, item_metadata));
int64_t index_part = distribution_value - distribution->min_value();
Expand Down
4 changes: 2 additions & 2 deletions src/main/cc/any_sketch/any_sketch.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class AnySketch {
// Creates a new, empty AnySketch.
//
// The inputs will be moved from.
AnySketch(std::vector<std::unique_ptr<Distribution>> indexes,
AnySketch(std::vector<std::unique_ptr<BaseDistribution>> indexes,
std::vector<ValueFunction> values);

AnySketch(const AnySketch &) = delete;
Expand Down Expand Up @@ -117,7 +117,7 @@ class AnySketch {

private:
absl::flat_hash_map<uint64_t, absl::FixedArray<ValueType>> registers_;
absl::FixedArray<std::unique_ptr<Distribution>> indexes_;
absl::FixedArray<std::unique_ptr<BaseDistribution>> indexes_;
absl::FixedArray<ValueFunction> values_;

size_t register_size() const;
Expand Down
22 changes: 11 additions & 11 deletions src/main/cc/any_sketch/distributions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@

namespace wfa::any_sketch {
namespace {
class BaseDistribution : public Distribution {
class BaseDistributionImpl : public BaseDistribution {
public:
BaseDistribution(int64_t min_value, int64_t max_value)
BaseDistributionImpl(int64_t min_value, int64_t max_value)
: min_value_(min_value), max_value_(max_value) {}

absl::StatusOr<int64_t> Apply(
Expand All @@ -53,7 +53,7 @@ class BaseDistribution : public Distribution {
absl::string_view item, const ItemMetadata& item_metadata) const = 0;
};

absl::StatusOr<int64_t> BaseDistribution::Apply(
absl::StatusOr<int64_t> BaseDistributionImpl::Apply(
absl::string_view item, const ItemMetadata& item_metadata) const {
ASSIGN_OR_RETURN(int64_t value, ApplyInternal(item, item_metadata));

Expand All @@ -69,11 +69,11 @@ absl::StatusOr<int64_t> BaseDistribution::Apply(
return value;
}

class OracleDistribution : public BaseDistribution {
class OracleDistribution : public BaseDistributionImpl {
public:
OracleDistribution(int64_t min_value, int64_t max_value,
absl::string_view feature_name)
: BaseDistribution(min_value, max_value), feature_name_(feature_name) {}
: BaseDistributionImpl(min_value, max_value), feature_name_(feature_name) {}

private:
absl::StatusOr<int64_t> ApplyInternal(
Expand All @@ -90,11 +90,11 @@ class OracleDistribution : public BaseDistribution {
std::string feature_name_;
};

class FingerprintingDistribution : public BaseDistribution {
class FingerprintingDistribution : public BaseDistributionImpl {
public:
FingerprintingDistribution(int64_t min_value, int64_t max_value,
const Fingerprinter* fingerprinter)
: BaseDistribution(min_value, max_value), fingerprinter_(fingerprinter) {}
: BaseDistributionImpl(min_value, max_value), fingerprinter_(fingerprinter) {}

private:
absl::StatusOr<int64_t> ApplyInternal(
Expand Down Expand Up @@ -173,23 +173,23 @@ class GeometricDistribution : public FingerprintingDistribution {
};
} // namespace

std::unique_ptr<Distribution> GetOracleDistribution(
std::unique_ptr<BaseDistribution> GetOracleDistribution(
absl::string_view feature_name, int64_t min_value, int64_t max_value) {
return absl::make_unique<OracleDistribution>(min_value, max_value,
feature_name);
}
std::unique_ptr<Distribution> GetUniformDistribution(
std::unique_ptr<BaseDistribution> GetUniformDistribution(
const Fingerprinter* fingerprinter, int64_t min_value, int64_t max_value) {
return absl::make_unique<UniformDistribution>(min_value, max_value,
fingerprinter);
}
std::unique_ptr<Distribution> GetExponentialDistribution(
std::unique_ptr<BaseDistribution> GetExponentialDistribution(
const Fingerprinter* fingerprinter, double rate, int64_t size) {
ABSL_ASSERT(rate > 0.0);
ABSL_ASSERT(size > 0);
return absl::make_unique<ExponentialDistribution>(rate, size, fingerprinter);
}
std::unique_ptr<Distribution> GetGeometricDistribution(
std::unique_ptr<BaseDistribution> GetGeometricDistribution(
const Fingerprinter* fingerprinter, int64_t min_value, int64_t max_value) {
return absl::make_unique<GeometricDistribution>(min_value, max_value,
fingerprinter);
Expand Down
22 changes: 11 additions & 11 deletions src/main/cc/any_sketch/distributions.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ namespace wfa::any_sketch {

using ItemMetadata = absl::flat_hash_map<std::string, int64_t>;

// Base for representing distributions -- a way of deterministically mapping an
// item and associated metadata to a number.
class Distribution {
// Abstract Base for representing distributions -- a way of deterministically mapping an item
// and associated metadata to a number.
class BaseDistribution {
public:
Distribution(const Distribution&) = delete;
Distribution& operator=(const Distribution&) = delete;
BaseDistribution(const BaseDistribution&) = delete;
BaseDistribution& operator=(const BaseDistribution&) = delete;

virtual ~Distribution() = default;
virtual ~BaseDistribution() = default;

// The smallest value (inclusive) that the Distribution can return.
virtual int64_t min_value() const = 0;
Expand All @@ -52,16 +52,16 @@ class Distribution {
absl::string_view item, const ItemMetadata& item_metadata) const = 0;

protected:
Distribution() = default;
BaseDistribution() = default;
};

std::unique_ptr<Distribution> GetOracleDistribution(
std::unique_ptr<BaseDistribution> GetOracleDistribution(
absl::string_view feature_name, int64_t min_value, int64_t max_value);
std::unique_ptr<Distribution> GetUniformDistribution(
std::unique_ptr<BaseDistribution> GetUniformDistribution(
const Fingerprinter* fingerprinter, int64_t min_value, int64_t max_value);
std::unique_ptr<Distribution> GetExponentialDistribution(
std::unique_ptr<BaseDistribution> GetExponentialDistribution(
const Fingerprinter* fingerprinter, double rate, int64_t size);
std::unique_ptr<Distribution> GetGeometricDistribution(
std::unique_ptr<BaseDistribution> GetGeometricDistribution(
const Fingerprinter* fingerprinter, int64_t min_value, int64_t max_value);

} // namespace wfa::any_sketch
Expand Down
2 changes: 1 addition & 1 deletion src/main/cc/any_sketch/value_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace wfa::any_sketch {
struct ValueFunction {
std::string name;
AggregatorType aggregator_type;
std::unique_ptr<Distribution> distribution;
std::unique_ptr<BaseDistribution> distribution;
};

} // namespace wfa::any_sketch
Expand Down
9 changes: 5 additions & 4 deletions src/test/cc/any_sketch/any_sketch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Matcher<AnySketch::Register> RegisterIs(uint64_t index,
new RegisterIsMatcher(index, std::move(values)));
}

class FakeDistribution : public Distribution {
class FakeDistribution : public BaseDistribution {
public:
absl::StatusOr<int64_t> Apply(
absl::string_view item,
Expand All @@ -82,7 +82,7 @@ class FakeDistribution : public Distribution {
int64_t max_value() const override { return 10; }
};

std::unique_ptr<Distribution> MakeFakeDistribution() {
std::unique_ptr<BaseDistribution> MakeFakeDistribution() {
return absl::make_unique<FakeDistribution>();
}

Expand All @@ -93,12 +93,13 @@ std::vector<T> MakeSingleItemVector(T&& t) {
return v;
}

std::vector<std::unique_ptr<Distribution>> MakeFakeDistributionIndex() {
std::vector<std::unique_ptr<BaseDistribution>> MakeFakeDistributionIndex() {
return MakeSingleItemVector(MakeFakeDistribution());
}

ValueFunction MakeValueFunction(AggregatorType aggregator,
std::unique_ptr<Distribution> distribution) {
std::unique_ptr<BaseDistribution>
distribution) {
return {.name = "SomeValueFunction",
.aggregator_type = aggregator,
.distribution = std::move(distribution)};
Expand Down
8 changes: 4 additions & 4 deletions src/test/cc/any_sketch/distributions_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class FakeFingerprinter : public Fingerprinter {
};

TEST(DistributionsTest, OracleDistribution) {
std::unique_ptr<Distribution> distribution =
std::unique_ptr<BaseDistribution> distribution =
GetOracleDistribution("foo", 3, 10);

ASSERT_FALSE(distribution == nullptr);
Expand All @@ -53,7 +53,7 @@ TEST(DistributionsTest, OracleDistribution) {

TEST(DistributionsTest, UniformDistribution) {
FakeFingerprinter fingerprinter;
std::unique_ptr<Distribution> distribution =
std::unique_ptr<BaseDistribution> distribution =
GetUniformDistribution(&fingerprinter, 3, 10);

ASSERT_FALSE(distribution == nullptr);
Expand All @@ -76,7 +76,7 @@ TEST(DistributionsTest, UniformDistribution) {

TEST(DistributionsTest, ExponentialDistribution) {
FakeFingerprinter fingerprinter;
std::unique_ptr<Distribution> distribution =
std::unique_ptr<BaseDistribution> distribution =
GetExponentialDistribution(&fingerprinter, 2, 10);

ASSERT_FALSE(distribution == nullptr);
Expand All @@ -94,7 +94,7 @@ TEST(DistributionsTest, ExponentialDistribution) {

TEST(DistributionsTest, GeometricDistribution) {
FakeFingerprinter fingerprinter;
std::unique_ptr<Distribution> distribution =
std::unique_ptr<BaseDistribution> distribution =
GetGeometricDistribution(&fingerprinter, 10, 74);

ASSERT_FALSE(distribution == nullptr);
Expand Down
4 changes: 2 additions & 2 deletions src/test/cc/estimation/estimators_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

namespace wfa::estimation {
namespace {
using ::wfa::any_sketch::Distribution;
using ::wfa::any_sketch::BaseDistribution;

MATCHER_P2(EqWithError, value, error, "") {
// Since value and arg might be of different, possibly unsigned types,
Expand Down Expand Up @@ -60,7 +60,7 @@ uint64_t GenerateRandomSketchAndGetSize(double decay_rate,
absl::flat_hash_set<int64_t> indexes;

const Fingerprinter& fingerprinter = GetSha256Fingerprinter();
std::unique_ptr<Distribution> exponential_distribution =
std::unique_ptr<BaseDistribution> exponential_distribution =
wfa::any_sketch::GetExponentialDistribution(&fingerprinter, decay_rate,
num_of_total_registers);

Expand Down

0 comments on commit 3f2463c

Please sign in to comment.