Skip to content

Commit

Permalink
Catch invalid model specification for duplicate channels, samples and modifiers.
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzennio committed Dec 8, 2023
1 parent a1a31f1 commit e6887d9
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 0 deletions.
47 changes: 47 additions & 0 deletions src/pyhf/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from typing import Any, Sequence

from pyhf import exceptions
from pyhf.typing import Channel

log = logging.getLogger(__name__)
Expand All @@ -21,6 +22,10 @@ class _ChannelSummaryMixin:
def __init__(self, *args: Any, **kwargs: Sequence[Channel]):
channels = kwargs.pop('channels')
super().__init__(*args, **kwargs)

# check for duplicates
self._check_for_duplicates(channels)

self._channels: list[str] = []
self._samples: list[str] = []
self._modifiers: list[tuple[str, str]] = []
Expand Down Expand Up @@ -89,3 +94,45 @@ def channel_slices(self) -> dict[str, slice]:
Dictionary mapping channel name to the bin slices in the model.
"""
return self._channel_slices

def _check_for_duplicates(self, channels: Sequence[Channel]) -> None:
"""
Check for duplicate channels.
Check for duplicate samples within each channel.
Check for duplicate modifiers within each sample.
"""
channel_names = [channel['name'] for channel in channels]
if len(channel_names) != len(set(channel_names)):
duplicates = sorted(
set([f"'{x}'" for x in channel_names if channel_names.count(x) > 1])
)
raise exceptions.InvalidNameReuse(
"Duplicate channels "
+ ", ".join(duplicates)
+ " found in the model. Rename one of them."
)
for channel in channels:
sample_names = [samples['name'] for samples in channel['samples']]
if len(sample_names) != len(set(sample_names)):
duplicates = sorted(
set([f"'{x}'" for x in sample_names if sample_names.count(x) > 1])
)
raise exceptions.InvalidNameReuse(
"Duplicate samples "
+ ", ".join(duplicates)
+ f" found in the channel '{channel['name']}'. Rename one of them."
)
for sample in channel['samples']:
modifiers = [
(modifier['name'], modifier['type'])
for modifier in sample['modifiers']
]
if len(modifiers) != len(set(modifiers)):
duplicates = sorted(
set([f"'{x[0]}'" for x in modifiers if modifiers.count(x) > 1])
)
raise exceptions.InvalidNameReuse(
"Duplicate modifiers "
+ ", ".join(duplicates)
+ f" of the same type found in channel '{channel['name']}' and sample '{sample['name']}'. Rename one of them."
)
15 changes: 15 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,18 @@ def test_schema_tensor_type_disallowed(mocker, backend):
}
with pytest.raises(pyhf.exceptions.InvalidSpecification):
pyhf.schema.validate(spec, "model.json")


@pytest.mark.parametrize(
    'model_file',
    [
        'model_duplicate_channels.json',
        'model_duplicate_samples.json',
        'model_duplicate_modifiers.json',
    ],
)
def test_schema_catch_duplicates(datadir, model_file):
    """Building a model from each duplicate-name fixture raises ``InvalidNameReuse``."""
    spec_path = datadir.joinpath(model_file)
    model_spec = json.loads(spec_path.read_text(encoding="utf-8"))
    with pytest.raises(pyhf.exceptions.InvalidNameReuse):
        pyhf.Model(model_spec)
50 changes: 50 additions & 0 deletions tests/test_schema/model_duplicate_channels.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"channels": [
{
"name": "singlechannel",
"samples": [
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "background",
"data": [15],
"modifiers": [
{
"name": "uncorr_bkguncrt",
"type": "shapesys",
"data": [5]
}
]
}
]
},
{
"name": "singlechannel",
"samples": [
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "background",
"data": [15],
"modifiers": [
{
"name": "uncorr_bkguncrt",
"type": "shapesys",
"data": [5]
}
]
}
]
}
]
}
28 changes: 28 additions & 0 deletions tests/test_schema/model_duplicate_modifiers.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"channels": [
{
"name": "singlechannel",
"samples": [
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null},
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "background",
"data": [15],
"modifiers": [
{
"name": "uncorr_bkguncrt",
"type": "shapesys",
"data": [5]
}
]
}
]
}
]
}
34 changes: 34 additions & 0 deletions tests/test_schema/model_duplicate_samples.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"channels": [
{
"name": "singlechannel",
"samples": [
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "background",
"data": [15],
"modifiers": [
{
"name": "uncorr_bkguncrt",
"type": "shapesys",
"data": [5]
}
]
}
]
}
]
}

0 comments on commit e6887d9

Please sign in to comment.