Merge pull request #282 from weefuzzy/empty_object_json_crashes
Crashes with JSON generated by empty `KNN*` objects
weefuzzy authored Oct 12, 2024
2 parents 3ceddcd + c3e99a9 commit ef77dba
Showing 2 changed files with 30 additions and 21 deletions.
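
The fix is in two parts, both visible in the diff below: the `KNN*` data structs now default-construct their `algorithm::KDTree` member instead of constructing it with 0 dimensions, and each client's `check_json()` goes on to validate the nested objects before `from_json()` is ever called, the aim (per the commit title) being that JSON generated by empty, never-fitted objects is handled gracefully instead of crashing. As a rough illustration only, here is what that layered-validation idea looks like in plain nlohmann::json; the helper name and the nested "points"/"data" fields are hypothetical, not FluCoMa's actual schema or API:

// Illustration only: layered JSON validation in the spirit of this commit,
// using plain nlohmann::json. looks_like_knn_data is a hypothetical helper,
// not FluCoMa's check_json, and the field names are made up for the example.
#include <nlohmann/json.hpp>

#include <iostream>

// First check the top-level keys and types, then the shape of each nested
// object, so deserialisation never reaches a malformed payload.
bool looks_like_knn_data(const nlohmann::json& j)
{
  return j.is_object() && j.contains("tree") && j.contains("labels") &&
         j.at("tree").is_object() && j.at("labels").is_object() &&
         j.at("tree").contains("points") &&
         j.at("tree").at("points").is_array() &&
         j.at("labels").contains("data") &&
         j.at("labels").at("data").is_array();
}

int main()
{
  // The kind of JSON an empty, never-fitted object might write: the keys are
  // present but the nested payloads are missing.
  nlohmann::json empty = {{"tree", nlohmann::json::object()},
                          {"labels", nlohmann::json::object()}};

  nlohmann::json fitted = {{"tree", {{"points", {1.0, 2.0}}}},
                           {"labels", {{"data", {"a", "b"}}}}};

  std::cout << std::boolalpha;
  std::cout << looks_like_knn_data(empty) << '\n';  // false: rejected up front
  std::cout << looks_like_knn_data(fitted) << '\n'; // true
}
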
23 changes: 13 additions & 10 deletions include/clients/nrt/KNNClassifierClient.hpp
@@ -22,11 +22,11 @@ namespace knnclassifier {

 struct KNNClassifierData
 {
-  algorithm::KDTree tree{0};
+  algorithm::KDTree tree{algorithm::KDTree()};
   FluidDataSet<std::string, std::string, 1> labels{1};
-  index size() const { return labels.size(); }
-  index dims() const { return tree.dims(); }
-  void clear()
+  index size() const { return labels.size(); }
+  index dims() const { return tree.dims(); }
+  void clear()
   {
     labels = FluidDataSet<std::string, std::string, 1>(1);
     tree.clear();
@@ -43,7 +43,10 @@ void to_json(nlohmann::json& j, const KNNClassifierData& data)
 bool check_json(const nlohmann::json& j, const KNNClassifierData&)
 {
   return fluid::check_json(j, {"tree", "labels"},
-                           {JSONTypes::OBJECT, JSONTypes::OBJECT});
+                           {JSONTypes::OBJECT, JSONTypes::OBJECT}) &&
+         fluid::algorithm::check_json(j.at("tree"), algorithm::KDTree()) &&
+         fluid::check_json(j.at("labels"),
+                           FluidDataSet<std::string, std::string, 1>());
 }

 void from_json(const nlohmann::json& j, KNNClassifierData& data)
@@ -132,14 +135,14 @@ class KNNClassifierClient : public FluidBaseClient,
     algorithm::KNNClassifier classifier;
     RealVector point(mAlgorithm.tree.dims());
     point <<= BufferAdaptor::ReadAccess(data.get())
-                  .samps(0, mAlgorithm.tree.dims(), 0);
+                  .samps(0, mAlgorithm.tree.dims(), 0);
     std::string result = classifier.predict(mAlgorithm.tree, point,
                                             mAlgorithm.labels, k, weight);
     return result;
   }

-  MessageResult<void> predict(InputDataSetClientRef source,
-                              LabelSetClientRef dest) const
+  MessageResult<void> predict(InputDataSetClientRef source,
+                              LabelSetClientRef dest) const
   {
     index k = get<kNumNeighbors>();
     bool weight = get<kWeight>() != 0;
@@ -163,7 +166,7 @@
     {
       RealVectorView point = data.row(i);
       StringVector label = {classifier.predict(mAlgorithm.tree, point,
-                                               mAlgorithm.labels, k, weight)};
+                                               mAlgorithm.labels, k, weight)};
       result.add(ids(i), label);
     }
     destPtr->setLabelSet(result);
@@ -186,7 +189,7 @@
                     makeMessage("read", &KNNClassifierClient::read));
   }

-  index encodeIndex(std::string const& label) const
+  index encodeIndex(std::string const& label) const
   {
     return mLabelSetEncoder.encodeIndex(label);
   }

28 changes: 17 additions & 11 deletions include/clients/nrt/KNNRegressorClient.hpp
@@ -20,7 +20,7 @@ namespace knnregressor {

 struct KNNRegressorData
 {
-  algorithm::KDTree tree{0};
+  algorithm::KDTree tree{algorithm::KDTree()};
   FluidDataSet<std::string, double, 1> target{1};
   index size() const { return target.size(); }
   index dims() const { return tree.dims(); }
@@ -41,7 +41,10 @@ void to_json(nlohmann::json& j, const KNNRegressorData& data)
 bool check_json(const nlohmann::json& j, const KNNRegressorData&)
 {
   return fluid::check_json(j, {"tree", "target"},
-                           {JSONTypes::OBJECT, JSONTypes::OBJECT});
+                           {JSONTypes::OBJECT, JSONTypes::OBJECT}) &&
+         fluid::algorithm::check_json(j.at("tree"), algorithm::KDTree()) &&
+         fluid::check_json(j.at("labels"),
+                           FluidDataSet<std::string, std::string, 1>());
 }

 void from_json(const nlohmann::json& j, KNNRegressorData& data)
@@ -128,24 +131,25 @@ class KNNRegressorClient : public FluidBaseClient,
     if (mAlgorithm.tree.size() < k) return Error(NotEnoughData);

     InBufferCheck bufCheck(mAlgorithm.tree.dims());
-    if (!bufCheck.checkInputs(in.get()))
-      return Error(bufCheck.error());
+    if (!bufCheck.checkInputs(in.get())) return Error(bufCheck.error());
     BufferAdaptor::ReadAccess inBuf(in.get());
-    BufferAdaptor::Access outBuf(out.get());
+    BufferAdaptor::Access outBuf(out.get());
     if (!outBuf.exists()) return Error(InvalidBuffer);
-    Result resizeResult = outBuf.resize(mAlgorithm.target.dims(), 1, inBuf.sampleRate());
+    Result resizeResult =
+        outBuf.resize(mAlgorithm.target.dims(), 1, inBuf.sampleRate());
     if (!resizeResult.ok()) return Error(BufferAlloc);
     algorithm::KNNRegressor regressor;
     RealVector input(mAlgorithm.tree.dims());
     RealVector output(mAlgorithm.target.dims());
     input <<= inBuf.samps(0, mAlgorithm.tree.dims(), 0);
-    regressor.predict(mAlgorithm.tree, mAlgorithm.target, input, output, k, weight);
+    regressor.predict(mAlgorithm.tree, mAlgorithm.target, input, output, k,
+                      weight);
     outBuf.samps(0) <<= output;
     return OK();
   }

   MessageResult<void> predict(InputDataSetClientRef source,
-                              DataSetClientRef dest) const
+                              DataSetClientRef dest) const
   {
     index k = get<kNumNeighbors>();
     bool weight = get<kWeight>() != 0;
@@ -169,7 +173,8 @@
     for (index i = 0; i < dataSet.size(); i++)
     {
       RealVectorView point = data.row(i);
-      regressor.predict(mAlgorithm.tree, mAlgorithm.target, point, prediction, k, weight);
+      regressor.predict(mAlgorithm.tree, mAlgorithm.target, point, prediction,
+                        k, weight);
       result.add(ids(i), prediction);
     }
     destPtr->setDataSet(result);
@@ -262,9 +267,10 @@ class KNNRegressorQuery : public FluidBaseClient, ControlIn, ControlOut
       RealVector output(algorithm.target.dims(), c.allocator());

       input <<= BufferAdaptor::ReadAccess(get<kInputBuffer>().get())
-                    .samps(0, algorithm.tree.dims(), 0);
+                    .samps(0, algorithm.tree.dims(), 0);

-      regressor.predict(algorithm.tree, algorithm.target, input, output, k, weight, c.allocator());
+      regressor.predict(algorithm.tree, algorithm.target, input, output, k,
+                        weight, c.allocator());
       outBuf.samps(0) <<= output;
     }
   }
