Skip to content

Commit

Permalink
Refactor CMSInterferenceFunc to use base class CMSExternalMorph
Browse files Browse the repository at this point in the history
This should make extending external morphs, e.g. to replace
RooParametricHist, much simpler.
  • Loading branch information
nsmith- committed Nov 1, 2023
1 parent 06a807f commit 6123b59
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 50 deletions.
45 changes: 45 additions & 0 deletions interface/CMSExternalMorph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#ifndef CMSExternalMorph_h
#define CMSExternalMorph_h
#include <vector>

#include "RooAbsReal.h"
#include "RooRealVar.h"
#include "RooRealProxy.h"

class CMSExternalMorph : public RooAbsReal {
public:
CMSExternalMorph();
/*
* All subclasses need to provide an edges array of length nbins+1
* of the observable (x)
* TODO: CMSHistFunc and CMSHistSum do not check the binning is compatible
* with their binning other than having the correct length
*/
CMSExternalMorph(
const char* name,
const char* title,
RooRealVar& x,
const std::vector<double>& edges
);
CMSExternalMorph(CMSExternalMorph const& other, const char* name = 0);
virtual ~CMSExternalMorph();

/* Batch accessor for CMSHistFunc / CMSHistSum, to be overriden by concrete
* implementations. hasChanged() should indicate whether or not
* batchGetBinValues() would return a new vector, given the state of
* any dependent variables.
*/
virtual bool hasChanged() const = 0;
virtual const std::vector<double>& batchGetBinValues() const = 0;

protected:
RooRealProxy x_;
std::vector<double> edges_;

double evaluate() const;

private:
ClassDef(CMSExternalMorph, 1)
};

#endif // CMSExternalMorph_h
5 changes: 2 additions & 3 deletions interface/CMSHistFunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "CMSHistV.h"
#include "FastTemplate_Old.h"
#include "SimpleCacheSentry.h"
#include "CMSInterferenceFunc.h"
#include "CMSExternalMorph.h"

class CMSHistFuncWrapper;

Expand Down Expand Up @@ -149,8 +149,7 @@ class CMSHistFunc : public RooAbsReal {
friend class CMSHistV<CMSHistFunc>;
friend class CMSHistSum;

// TODO: allow any class that implements hasChanged() and batchGetBinValues()
void injectExternalMorph(CMSInterferenceFunc& morph);
void injectExternalMorph(CMSExternalMorph& morph);
/*
– RooAbsArg::setVerboseEval(Int_t level) • Level 0 – No messages
Expand Down
3 changes: 1 addition & 2 deletions interface/CMSHistSum.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ class CMSHistSum : public RooAbsReal {
static void EnableFastVertical();
friend class CMSHistV<CMSHistSum>;

// TODO: allow any class that implements hasChanged() and batchGetBinValues()
void injectExternalMorph(int idx, CMSInterferenceFunc& morph);
void injectExternalMorph(int idx, CMSExternalMorph& morph);

protected:
RooRealProxy x_;
Expand Down
20 changes: 6 additions & 14 deletions interface/CMSInterferenceFunc.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
#ifndef CMSInterferenceFunc_h
#define CMSInterferenceFunc_h
#include <vector>

#include "RooAbsReal.h"
#include "RooRealVar.h"
#include "RooRealProxy.h"
#include "RooListProxy.h"
#include "SimpleCacheSentry.h"

#include "CMSExternalMorph.h"

class _InterferenceEval;

class CMSInterferenceFunc : public RooAbsReal {
class CMSInterferenceFunc : public CMSExternalMorph {
public:
CMSInterferenceFunc();
CMSInterferenceFunc(CMSInterferenceFunc const& other, const char* name = 0);
Expand All @@ -24,8 +21,8 @@ class CMSInterferenceFunc : public RooAbsReal {
const char* name,
const char* title,
RooRealVar& x,
const RooArgList& coefficients,
const std::vector<double>& edges,
const RooArgList& coefficients,
const std::vector<std::vector<double>> binscaling
);
virtual ~CMSInterferenceFunc();
Expand All @@ -38,21 +35,16 @@ class CMSInterferenceFunc : public RooAbsReal {
std::ostream& os, Int_t contents, Bool_t verbose, TString indent
) const override;

// batch accessor for CMSHistFunc / CMSHistSum
bool hasChanged() const { return !sentry_.good(); };
const std::vector<double>& batchGetBinValues() const;
bool hasChanged() const override { return !sentry_.good(); };
const std::vector<double>& batchGetBinValues() const override;

protected:
RooRealProxy x_;
RooListProxy coefficients_;
std::vector<double> edges_;
std::vector<std::vector<double>> binscaling_;

mutable SimpleCacheSentry sentry_; //!
mutable std::unique_ptr<_InterferenceEval> evaluator_; //!

double evaluate() const override;

private:
void initialize() const;
void updateCache() const;
Expand Down
2 changes: 1 addition & 1 deletion python/InterferenceModels.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def done(self):
for sbin in item["scaling"]:
scaling_array.push_back(sbin)

self.modelBuilder.out.safe_import(ROOT.CMSInterferenceFunc(funcname, "", histfunc.getXVar(), params, edges, scaling_array))
self.modelBuilder.out.safe_import(ROOT.CMSInterferenceFunc(funcname, "", histfunc.getXVar(), edges, params, scaling_array))
func = self.modelBuilder.out.function(funcname)
if isinstance(histfunc, ROOT.CMSHistFunc):
histfunc.injectExternalMorph(func)
Expand Down
33 changes: 33 additions & 0 deletions src/CMSExternalMorph.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "../interface/CMSExternalMorph.h"

CMSExternalMorph::CMSExternalMorph() {}

CMSExternalMorph::CMSExternalMorph(
const char* name,
const char* title,
RooRealVar& x,
const std::vector<double>& edges
) :
RooAbsReal(name, title),
x_("x", "", this, x),
edges_(edges)
{
}

CMSExternalMorph::CMSExternalMorph(CMSExternalMorph const& other, const char* name) :
RooAbsReal(other, name),
x_("x", this, other.x_),
edges_(other.edges_)
{
}

CMSExternalMorph::~CMSExternalMorph() = default;

double CMSExternalMorph::evaluate() const {
auto it = std::upper_bound(std::begin(edges_), std::end(edges_), x_->getVal());
if ( (it == std::begin(edges_)) or (it == std::end(edges_)) ) {
return 0.0;
}
size_t idx = std::distance(std::begin(edges_), it) - 1;
return batchGetBinValues()[idx];
}
6 changes: 3 additions & 3 deletions src/CMSHistFunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ void CMSHistFunc::updateCache() const {
if (mcache_.size() == 0) mcache_.resize(storage_.size());
}

bool external_morph_updated = (external_morph_.getSize() && static_cast<CMSInterferenceFunc*>(external_morph_.at(0))->hasChanged());
bool external_morph_updated = (external_morph_.getSize() && static_cast<CMSExternalMorph*>(external_morph_.at(0))->hasChanged());
if (step1 || external_morph_updated) {
fast_vertical_ = false;
}
Expand Down Expand Up @@ -567,7 +567,7 @@ void CMSHistFunc::updateCache() const {
std::cout << "Template before external morph update:" << mcache_[idx].step2.Integral() << "\n";
mcache_[idx].step2.Dump();
#endif
auto& extdata = static_cast<CMSInterferenceFunc*>(external_morph_.at(0))->batchGetBinValues();
auto& extdata = static_cast<CMSExternalMorph*>(external_morph_.at(0))->batchGetBinValues();
for(size_t i=0; i<extdata.size(); ++i) {
mcache_[idx].step2[i] *= extdata[i];
}
Expand Down Expand Up @@ -1220,7 +1220,7 @@ void CMSHistFunc::EnableFastVertical() {
}


void CMSHistFunc::injectExternalMorph(CMSInterferenceFunc& morph) {
void CMSHistFunc::injectExternalMorph(CMSExternalMorph& morph) {
if ( morph.batchGetBinValues().size() != cache_.size() ) {
throw std::runtime_error("Mismatched binning between external morph and CMSHistFunc");
// equal edges are user responsibility for now
Expand Down
6 changes: 3 additions & 3 deletions src/CMSHistSum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,10 @@ void CMSHistSum::initialize() const {

void CMSHistSum::updateMorphs() const {
// set up pointers ahead of time for quick loop
std::vector<CMSInterferenceFunc*> process_morphs(compcache_.size(), nullptr);
std::vector<CMSExternalMorph*> process_morphs(compcache_.size(), nullptr);
// if any external morphs are dirty, disable fast_mode_
for(size_t i=0; i < external_morph_indices_.size(); ++i) {
auto* morph = static_cast<CMSInterferenceFunc*>(external_morphs_.at(i));
auto* morph = static_cast<CMSExternalMorph*>(external_morphs_.at(i));
process_morphs[external_morph_indices_[i]] = morph;
if (morph->hasChanged()) {
fast_mode_ = 0;
Expand Down Expand Up @@ -780,7 +780,7 @@ void CMSHistSum::EnableFastVertical() {
enable_fast_vertical_ = true;
}

void CMSHistSum::injectExternalMorph(int idx, CMSInterferenceFunc& morph) {
void CMSHistSum::injectExternalMorph(int idx, CMSExternalMorph& morph) {
if ( idx >= coeffpars_.getSize() ) {
throw std::runtime_error("Process index larger than number of processes in CMSHistSum");
}
Expand Down
29 changes: 9 additions & 20 deletions src/CMSInterferenceFunc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,24 @@ CMSInterferenceFunc::CMSInterferenceFunc() {};
CMSInterferenceFunc::CMSInterferenceFunc(
CMSInterferenceFunc const& other, const char* name
) :
RooAbsReal(other, name), x_("x", this, other.x_),
CMSExternalMorph(other, name),
coefficients_("coefficients", this, other.coefficients_),
edges_(other.edges_), binscaling_(other.binscaling_),
binscaling_(other.binscaling_),
sentry_(name ? TString(name) + "_sentry" : TString(other.GetName())+"_sentry", "")
{
}

CMSInterferenceFunc::CMSInterferenceFunc(
const char* name, const char* title, RooRealVar& x,
RooArgList const& coefficients, const std::vector<double>& edges,
const char* name,
const char* title,
RooRealVar& x,
const std::vector<double>& edges,
RooArgList const& coefficients,
const std::vector<std::vector<double>> binscaling
) :
RooAbsReal(name, title), x_("x", "", this, x),
CMSExternalMorph(name, title, x, edges),
coefficients_("coefficients", "", this),
edges_(edges), binscaling_(binscaling),
binscaling_(binscaling),
sentry_(TString(name) + "_sentry", "")
{
coefficients_.add(coefficients);
Expand Down Expand Up @@ -117,21 +120,7 @@ void CMSInterferenceFunc::updateCache() const {
sentry_.reset();
}

double CMSInterferenceFunc::evaluate() const {
if ( not evaluator_ ) initialize();
if ( not sentry_.good() ) updateCache();

auto it = std::upper_bound(std::begin(edges_), std::end(edges_), x_->getVal());
if ( (it == std::begin(edges_)) or (it == std::end(edges_)) ) {
return 0.0;
}
size_t idx = std::distance(std::begin(edges_), it) - 1;
return evaluator_->getValues()[idx];
}

const std::vector<double>& CMSInterferenceFunc::batchGetBinValues() const {
// we don't really expect the cache to be valid, as upstream callers are
// managing their own and calling this only when dirty, but let's check anyway
if ( not evaluator_ ) initialize();
if ( not sentry_.good() ) updateCache();
return evaluator_->getValues();
Expand Down
1 change: 1 addition & 0 deletions src/classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@
#include "HiggsAnalysis/CombinedLimit/interface/RooCheapProduct.h"
#include "HiggsAnalysis/CombinedLimit/interface/CMSHggFormula.h"
#include "HiggsAnalysis/CombinedLimit/interface/SimpleProdPdf.h"
#include "HiggsAnalysis/CombinedLimit/interface/CMSExternalMorph.h"
#include "HiggsAnalysis/CombinedLimit/interface/CMSInterferenceFunc.h"
#include "HiggsAnalysis/CombinedLimit/interface/RooEFTScalingFunction.h"
1 change: 1 addition & 0 deletions src/classes_def.xml
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@
<class name="CMSHggFormulaD1" />
<class name="CMSHggFormulaD2" />
<class name="SimpleProdPdf" />
<class name="CMSExternalMorph" />
<class name="CMSInterferenceFunc" />
<class name="RooEFTScalingFunction" />
</lcgdict>
12 changes: 8 additions & 4 deletions test/test_interference.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,17 @@ def setvars(x, kl, kv, k2v):

# toy generation is different between the histsum and histfunc models, somehow
ntoys = 10
ret = subprocess.call(f"combine -M GenerateOnly card.root -t {ntoys} --saveToys".split(" "))
ret = subprocess.call("combine -M GenerateOnly card.root -t {ntoys} --saveToys".format(ntoys=ntoys).split(" "))
assert ret == 0

ret = subprocess.call(f"combine -M MultiDimFit card.root -t {ntoys} --toysFile higgsCombineTest.GenerateOnly.mH120.123456.root -n HistFunc".split(" "))
ret = subprocess.call(
"combine -M MultiDimFit card.root -t {ntoys} --toysFile higgsCombineTest.GenerateOnly.mH120.123456.root -n HistFunc".format(ntoys=ntoys).split(" ")
)
assert ret == 0

ret = subprocess.call(f"combine -M MultiDimFit card_histsum.root -t {ntoys} --toysFile higgsCombineTest.GenerateOnly.mH120.123456.root -n HistSum".split(" "))
ret = subprocess.call(
"combine -M MultiDimFit card_histsum.root -t {ntoys} --toysFile higgsCombineTest.GenerateOnly.mH120.123456.root -n HistSum".format(ntoys=ntoys).split(" ")
)
assert ret == 0

f_histfunc = ROOT.TFile.Open("higgsCombineHistFunc.MultiDimFit.mH120.123456.root")
Expand All @@ -195,4 +199,4 @@ def setvars(x, kl, kv, k2v):
if abs(row1.k2v - row2.k2v) > 1e-4:
ndiff["k2v"] += 1

print(f"Out of {ntoys} toys, {ndiff} are not matching (tolerance: 1e-4) between CMSHistFunc and CMSHistSum")
print("Out of {ntoys} toys, {ndiff} are not matching (tolerance: 1e-4) between CMSHistFunc and CMSHistSum".format(ntoys=ntoys, ndiff=ndiff))

0 comments on commit 6123b59

Please sign in to comment.