Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add regex for floatNuisances #851

Merged
merged 5 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions interface/Combine.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Combine {
private:
bool mklimit(RooWorkspace *w, RooStats::ModelConfig *mc_s, RooStats::ModelConfig *mc_b, RooAbsData &data, double &limit, double &limitErr) ;

std::string parseRegex(std::string instr, const RooArgSet *nuisances, RooWorkspace *w) ;
void addDiscreteNuisances(RooWorkspace *);
void addNuisances(const RooArgSet *);
void addFloatingParameters(const RooArgSet &);
Expand Down
182 changes: 102 additions & 80 deletions src/Combine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Combine::Combine() :
("validateModel,V", "Perform some sanity checks on the model and abort if they fail.")
("saveToys", "Save results of toy MC in output file")
("floatAllNuisances", po::value<bool>(&floatAllNuisances_)->default_value(false), "Make all nuisance parameters floating")
("floatParameters", po::value<string>(&floatNuisances_)->default_value(""), "Set to floating these parameters (note freeze will take priority over float)")
("floatParameters", po::value<string>(&floatNuisances_)->default_value(""), "Set to floating these parameters (note freeze will take priority over float), also accepts regexp with syntax 'rgx{<my regexp>}' or 'var{<my regexp>}'")
("freezeAllGlobalObs", po::value<bool>(&freezeAllGlobalObs_)->default_value(true), "Make all global observables constant")
;
miscOptions_.add_options()
Expand Down Expand Up @@ -215,6 +215,62 @@ void Combine::applyOptions(const boost::program_options::variables_map &vm) {
makeToyGenSnapshot_ = (method == "FitDiagnostics" && !vm.count("justFit"));
}

std::string Combine::parseRegex(std::string instr, const RooArgSet *nuisances, RooWorkspace *w) {
// expand regexps inside the "rgx{}" option
while (instr.find("rgx{") != std::string::npos) {
size_t pos1 = instr.find("rgx{");
size_t pos2 = instr.find("}",pos1);
std::string prestr = instr.substr(0,pos1);
std::string poststr = instr.substr(pos2+1,instr.size()-pos2);
std::string reg_esp = instr.substr(pos1+4,pos2-pos1-4);

std::regex rgx( reg_esp, std::regex::ECMAScript);

std::string matchingParams="";
std::unique_ptr<TIterator> iter(nuisances->createIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {
const std::string &target = a->GetName();
std::smatch match;
if (std::regex_match(target, match, rgx)) {
matchingParams = matchingParams + target + ",";
}
}

instr = prestr+matchingParams+poststr;
instr = boost::replace_all_copy(instr, ",,", ",");
}

// expand regexps inside the "var{}" option
while (instr.find("var{") != std::string::npos) {
size_t pos1 = instr.find("var{");
size_t pos2 = instr.find("}",pos1);
std::string prestr = instr.substr(0,pos1);
std::string poststr = instr.substr(pos2+1,instr.size()-pos2);
std::string reg_esp = instr.substr(pos1+4,pos2-pos1-4);

std::regex rgx( reg_esp, std::regex::ECMAScript);

std::string matchingParams="";
std::unique_ptr<TIterator> iter(w->componentIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {

if ( ! (a->IsA()->InheritsFrom(RooRealVar::Class()) || a->IsA()->InheritsFrom(RooCategory::Class()))) continue;

const std::string &target = a->GetName();
// std::cout<<"var "<<target<<std::endl;
std::smatch match;
if (std::regex_match(target, match, rgx)) {
matchingParams = matchingParams + target + ",";
}
}

instr = prestr+matchingParams+poststr;
instr = boost::replace_all_copy(instr, ",,", ",");
}

return instr;
}

bool Combine::mklimit(RooWorkspace *w, RooStats::ModelConfig *mc_s, RooStats::ModelConfig *mc_b, RooAbsData &data, double &limit, double &limitErr) {
TStopwatch timer;

Expand Down Expand Up @@ -589,77 +645,43 @@ void Combine::run(TString hlfFile, const std::string &dataset, double &limit, do
}

if (floatNuisances_ != "") {
RooArgSet toFloat((floatNuisances_=="all")?*nuisances:(w->argSet(floatNuisances_.c_str())));
floatNuisances_ = parseRegex(floatNuisances_, nuisances, w);

RooArgSet toFloat;
if (floatNuisances_=="all") {
toFloat.add(*nuisances);
} else {
std::vector<std::string> nuisToFloat;
boost::split(nuisToFloat, floatNuisances_, boost::is_any_of(","), boost::token_compress_on);
for (int k=0; k<(int)nuisToFloat.size(); k++) {
if (nuisToFloat[k]=="") continue;
else if(nuisToFloat[k]=="all") {
toFloat.add(*nuisances);
continue;
}
else if (!w->fundArg(nuisToFloat[k].c_str())) {
std::cout<<"WARNING: cannot float nuisance parameter "<<nuisToFloat[k].c_str()<<" if it doesn't exist!"<<std::endl;
continue;
}
const RooAbsArg *arg = (RooAbsArg*)w->fundArg(nuisToFloat[k].c_str());
toFloat.add(*arg);
}
}

if (verbose > 0) {
std::cout << "Set floating the following parameters: "; toFloat.Print("");
Logger::instance().log(std::string(Form("Combine.cc: %d -- Set floating the following parameters: ",__LINE__)),Logger::kLogLevelInfo,__func__);
std::cout << "Floating the following parameters: "; toFloat.Print("");
Logger::instance().log(std::string(Form("Combine.cc: %d -- Floating the following parameters: ",__LINE__)),Logger::kLogLevelInfo,__func__);
std::unique_ptr<TIterator> iter(toFloat.createIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {
Logger::instance().log(std::string(Form("Combine.cc: %d %s ",__LINE__,a->GetName())),Logger::kLogLevelInfo,__func__);
}
}
}
utils::setAllConstant(toFloat, false);
}

if (freezeNuisances_ != "") {
freezeNuisances_ = parseRegex(freezeNuisances_, nuisances, w);

// expand regexps
while (freezeNuisances_.find("rgx{") != std::string::npos) {
size_t pos1 = freezeNuisances_.find("rgx{");
size_t pos2 = freezeNuisances_.find("}",pos1);
std::string prestr = freezeNuisances_.substr(0,pos1);
std::string poststr = freezeNuisances_.substr(pos2+1,freezeNuisances_.size()-pos2);
std::string reg_esp = freezeNuisances_.substr(pos1+4,pos2-pos1-4);

//std::cout<<"interpreting "<<reg_esp<<" as regex "<<std::endl;
std::regex rgx( reg_esp, std::regex::ECMAScript);

std::string matchingParams="";
std::unique_ptr<TIterator> iter(nuisances->createIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {
const std::string &target = a->GetName();
std::smatch match;
if (std::regex_match(target, match, rgx)) {
matchingParams = matchingParams + target + ",";
}
}

freezeNuisances_ = prestr+matchingParams+poststr;
freezeNuisances_ = boost::replace_all_copy(freezeNuisances_, ",,", ",");

}

// expand regexps
while (freezeNuisances_.find("var{") != std::string::npos) {
size_t pos1 = freezeNuisances_.find("var{");
size_t pos2 = freezeNuisances_.find("}",pos1);
std::string prestr = freezeNuisances_.substr(0,pos1);
std::string poststr = freezeNuisances_.substr(pos2+1,freezeNuisances_.size()-pos2);
std::string reg_esp = freezeNuisances_.substr(pos1+4,pos2-pos1-4);

// std::cout<<"interpreting "<<reg_esp<<" as regex "<<std::endl;
std::regex rgx( reg_esp, std::regex::ECMAScript);

std::string matchingParams="";
std::unique_ptr<TIterator> iter(w->componentIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {

if ( ! (a->IsA()->InheritsFrom(RooRealVar::Class()) || a->IsA()->InheritsFrom(RooCategory::Class()))) continue;

const std::string &target = a->GetName();
// std::cout<<"var "<<target<<std::endl;
std::smatch match;
if (std::regex_match(target, match, rgx)) {
matchingParams = matchingParams + target + ",";
}
}

freezeNuisances_ = prestr+matchingParams+poststr;
freezeNuisances_ = boost::replace_all_copy(freezeNuisances_, ",,", ",");

}

//RooArgSet toFreeze((freezeNuisances_=="all")?*nuisances:(w->argSet(freezeNuisances_.c_str())));
RooArgSet toFreeze;
if (freezeNuisances_=="allConstrainedNuisances") {
toFreeze.add(*nuisances);
Expand Down Expand Up @@ -687,7 +709,7 @@ void Combine::run(TString hlfFile, const std::string &dataset, double &limit, do
std::unique_ptr<TIterator> iter(toFreeze.createIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {
Logger::instance().log(std::string(Form("Combine.cc: %d %s ",__LINE__,a->GetName())),Logger::kLogLevelInfo,__func__);
}
}
}
utils::setAllConstant(toFreeze, true);
if (nuisances) {
Expand All @@ -705,24 +727,24 @@ void Combine::run(TString hlfFile, const std::string &dataset, double &limit, do
for (std::vector<string>::iterator ng_it=nuisanceGroups.begin();ng_it!=nuisanceGroups.end();ng_it++){
bool freeze_complement=false;
if (boost::algorithm::starts_with((*ng_it),"^")){
freeze_complement=true;
(*ng_it).erase(0,1);
}
freeze_complement=true;
(*ng_it).erase(0,1);
}

if (!w->set(Form("group_%s",(*ng_it).c_str()))){
std::cerr << "Unknown nuisance group: " << (*ng_it) << std::endl;
throw std::invalid_argument("Unknown nuisance group name");
}
RooArgSet groupNuisances(*(w->set(Form("group_%s",(*ng_it).c_str()))));
RooArgSet toFreeze;
if (!w->set(Form("group_%s",(*ng_it).c_str()))){
std::cerr << "Unknown nuisance group: " << (*ng_it) << std::endl;
throw std::invalid_argument("Unknown nuisance group name");
}
RooArgSet groupNuisances(*(w->set(Form("group_%s",(*ng_it).c_str()))));
RooArgSet toFreeze;

if (freeze_complement) {
RooArgSet still_floating(*mc->GetNuisanceParameters());
still_floating.remove(groupNuisances,true,true);
toFreeze.add(still_floating);
} else {
toFreeze.add(groupNuisances);
}
if (freeze_complement) {
RooArgSet still_floating(*mc->GetNuisanceParameters());
still_floating.remove(groupNuisances,true,true);
toFreeze.add(still_floating);
} else {
toFreeze.add(groupNuisances);
}

if (verbose > 0) { std::cout << "Freezing the following nuisance parameters: "; toFreeze.Print(""); }
utils::setAllConstant(toFreeze, true);
Expand All @@ -732,7 +754,7 @@ void Combine::run(TString hlfFile, const std::string &dataset, double &limit, do
mc->SetNuisanceParameters(newnuis);
if (mc_bonly) mc_bonly->SetNuisanceParameters(newnuis);
nuisances = mc->GetNuisanceParameters();
}
}
}
}

Expand Down