diff --git a/src/model/ml.cc b/src/model/ml.cc index 2d55950f9..486a84d86 100755 --- a/src/model/ml.cc +++ b/src/model/ml.cc @@ -397,44 +397,35 @@ namespace FT{ // make sure the model fit() method passed if (get_weights().empty()) { - cout << "weight empty; returning zeros\n"; + HANDLE_ERROR_NO_THROW("weight empty; returning zeros"); if (this->prob_type==PT_BINARY) { - CBinaryLabels dlabels(X.cols()); + labels = std::shared_ptr(new CBinaryLabels(X.cols())); for (unsigned i = 0; i < X.cols() ; ++i) { - dlabels.set_value(0,i); - dlabels.set_label(0,i); + dynamic_pointer_cast(labels)->set_value(0,i); + dynamic_pointer_cast(labels)->set_label(i,0); } - cout << "setting labels\n"; - labels =shared_ptr(&dlabels); - cout << "returning\n"; return labels; } else if (this->prob_type == PT_MULTICLASS) { - CMulticlassLabels dlabels(X.cols()); + labels = std::shared_ptr(new CMulticlassLabels(X.cols())); for (unsigned i = 0; i < X.cols() ; ++i) { - dlabels.set_value(0,i); - dlabels.set_label(0,i); + dynamic_pointer_cast(labels)->set_value(0,i); + dynamic_pointer_cast(labels)->set_label(i,0); } - cout << "setting labels\n"; - labels =shared_ptr(&dlabels); - cout << "returning\n"; return labels; } else { - CRegressionLabels dlabels(X.cols()); + labels = std::shared_ptr(new CRegressionLabels(X.cols())); for (unsigned i = 0; i < X.cols() ; ++i) { - dlabels.set_value(0,i); - dlabels.set_label(0,i); + dynamic_pointer_cast(labels)->set_value(0,i); + dynamic_pointer_cast(labels)->set_label(i,0); } - cout << "setting labels\n"; - labels =shared_ptr(&dlabels); - cout << "returning\n"; return labels; } }