Skip to content

Commit

Permalink
ENH: compute hash without PYTHONHASHSEED (#444)
Browse files Browse the repository at this point in the history
* BEHAVIOR: compute hash with `md5`
* DX: add more hashing tests
* DX: shorten hash test values to first 7 digits
* ENH: set `pickle` protocol to highest
  • Loading branch information
redeboer authored Dec 20, 2024
1 parent 2bbf608 commit f955d0a
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 140 deletions.
2 changes: 0 additions & 2 deletions .envrc
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
uv sync --all-extras --quiet
source .venv/bin/activate

export PYTHONHASHSEED=0
2 changes: 0 additions & 2 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
name: Benchmark
env:
PYTHONHASHSEED: "0"

on:
push:
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/ci-qrules-v0.9.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

env:
PYTHONHASHSEED: "0"

on:
push:
branches:
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ concurrency:
cancel-in-progress: |-
${{ github.ref != format('refs/heads/{0}', github.event.repository.default_branch) }}
env:
PYTHONHASHSEED: "0"

on:
push:
branches:
Expand Down
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ repos:
- id: check-dev-files
args:
- --doc-apt-packages=graphviz
- --environment-variables=PYTHONHASHSEED=0
- --repo-name=ampform
- --repo-title=AmpForm
- --update-lock-files=outsource
Expand Down
4 changes: 2 additions & 2 deletions docs/_extend_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# pyright: reportMissingImports=false
from __future__ import annotations

import hashlib
import inspect
import logging
import pickle
Expand All @@ -24,6 +23,7 @@
from ampform.io import aslatex
from ampform.kinematics.lorentz import ArraySize, FourMomentumSymbol
from ampform.sympy._array_expressions import ArrayMultiplication
from ampform.sympy._cache import get_readable_hash

if TYPE_CHECKING:
from qrules.transition import ReactionInfo, SpinFormalism
Expand Down Expand Up @@ -727,7 +727,7 @@ def __generate_transitions_cached(
) -> ReactionInfo:
version = get_package_version("qrules")
obj = (initial_state, final_state, formalism)
h = hashlib.sha256(pickle.dumps(obj)).hexdigest()
h = get_readable_hash(obj)
docs_dir = Path(__file__).parent
file_name = docs_dir / ".cache" / f"reaction-qrules-v{version}-{h}.pickle"
file_name.parent.mkdir(exist_ok=True)
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,6 @@ commands =
pytest {posargs}
description = Run all unit tests
passenv = *
setenv =
PYTHONHASHSEED = 0
[testenv:bench]
allowlist_externals =
Expand Down
4 changes: 0 additions & 4 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,6 @@ def perform_cached_doit(
:file:`ampform` under the system cache directory (see
:func:`.get_system_cache_directory`).
.. tip:: For a faster cache, set `PYTHONHASHSEED
<https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ to a
fixed value.
.. versionadded:: 0.14.4
.. automodule:: ampform.sympy._cache
"""
Expand Down
50 changes: 8 additions & 42 deletions src/ampform/sympy/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@

from __future__ import annotations

import functools
import hashlib
import logging
import os
import pickle # noqa: S403
import sys
from textwrap import dedent

import sympy as sp

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,48 +36,18 @@ def get_system_cache_directory() -> str:
return os.path.expanduser("~/.cache")


def get_readable_hash(obj, ignore_hash_seed: bool = False) -> str:
def get_readable_hash(obj) -> str:
"""Get a human-readable hash of any hashable Python object.
The algorithm is fastest if `PYTHONHASHSEED
<https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ is set.
Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`.
Args:
obj: Any hashable object, mutable or immutable, to be hashed.
ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If
:code:`True`, the hash seed is ignored and the hash is computed with
:func:`hashlib.sha256`.
"""
python_hash_seed = _get_python_hash_seed()
if ignore_hash_seed or python_hash_seed is None:
b = _to_bytes(obj)
return hashlib.sha256(b).hexdigest()
return f"pythonhashseed-{python_hash_seed}{hash(obj):+}"


def _to_bytes(obj) -> bytes:
if isinstance(obj, sp.Expr):
# Using the str printer is slower and not necessarily unique,
# but pickle.dumps() does not always result in the same bytes stream.
_warn_about_unsafe_hash()
return str(obj).encode()
return pickle.dumps(obj)
b = to_bytes(obj)
h = hashlib.md5(b) # noqa: S324
return h.hexdigest()


def _get_python_hash_seed() -> int | None:
python_hash_seed = os.environ.get("PYTHONHASHSEED", "")
if python_hash_seed is not None and python_hash_seed.isdigit():
return int(python_hash_seed)
return None


@functools.cache # warn once
def _warn_about_unsafe_hash() -> None:
message = """
PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
"""
message = dedent(message).replace("\n", " ").strip()
_LOGGER.warning(message)
def to_bytes(obj) -> bytes:
    """Convert any picklable object to a byte string that can be hashed.

    Args:
        obj: The object to serialize. Instances of `bytes` and `bytearray` are not
            pickled; their content is returned directly as `bytes`.

    Returns:
        The content of :code:`obj` if it is bytes-like, otherwise the result of
        :func:`pickle.dumps` with the highest available protocol.
    """
    # A tuple of types (instead of `bytes | bytearray`) keeps this check working on
    # Python versions older than 3.10, where `isinstance` rejects union types.
    if isinstance(obj, (bytes, bytearray)):
        # Convert so that the declared return type also holds for `bytearray` input.
        return bytes(obj)
    return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
146 changes: 67 additions & 79 deletions tests/sympy/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,42 @@

import logging
import os
import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar

import pytest
import qrules
import sympy as sp

from ampform import get_builder
from ampform.dynamics import EnergyDependentWidth
from ampform.sympy._cache import _warn_about_unsafe_hash, get_readable_hash
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from ampform.sympy._cache import get_readable_hash

if TYPE_CHECKING:
from _pytest.logging import LogCaptureFixture
from qrules.transition import SpinFormalism

from ampform.helicity import HelicityModel

_GH = "CI" in os.environ and os.uname().sysname != "Darwin"


@pytest.mark.parametrize(
("assumptions", "expected_hashes"),
("expected_hash", "assumptions"),
[
(
dict(),
{
"3.7": 7060330373292767180,
"3.8": 7459658071388516764,
"3.9": 7459658071388516764,
"3.10": 7459658071388516764,
"3.11": 8778804591879682108,
"3.12": 8778804591879682108,
},
),
(
dict(real=True),
{
"3.7": 118635607833730864,
"3.8": 3665410414623666716,
"3.9": 3665410414623666716,
"3.10": 3665410414623666716,
"3.11": -7967572625470457155,
"3.12": -7967572625470457155,
},
),
(
dict(rational=True),
{
"3.7": -1011754479721050016,
"3.8": -7926839224244779605,
"3.9": -7926839224244779605,
"3.10": -7926839224244779605,
"3.11": -8321323707982755013,
"3.12": -8321323707982755013,
},
),
("a7559ca", dict()),
("f4b1fad", dict(real=True)),
("d5bdc74", dict(rational=True)),
],
ids=["symbol", "symbol-real", "symbol-rational"],
)
def test_get_readable_hash(assumptions, expected_hashes, caplog: LogCaptureFixture):
python_version = ".".join(map(str, sys.version_info[:2]))
expected_hash = expected_hashes[python_version]
def test_get_readable_hash(
assumptions: dict, expected_hash: str, caplog: LogCaptureFixture
):
caplog.set_level(logging.WARNING)
x, y = sp.symbols("x y", **assumptions)
expr = x**2 + y
h_str = get_readable_hash(expr)
python_hash_seed = os.environ.get("PYTHONHASHSEED")
if python_hash_seed is None:
assert h_str[:7] == "bbc9833"
if _warn_about_unsafe_hash.cache_info().hits == 0:
assert "PYTHONHASHSEED has not been set." in caplog.text
caplog.clear()
elif python_hash_seed == "0":
h = int(h_str.replace("pythonhashseed-0", ""))
assert h == expected_hash
else:
pytest.skip(f"PYTHONHASHSEED has been set, but is {python_hash_seed}, not 0")
h = get_readable_hash(expr)[:7]
assert h == expected_hash
assert not caplog.text


Expand All @@ -88,31 +53,54 @@ def test_get_readable_hash_energy_dependent_width():
angular_momentum=angular_momentum,
meson_radius=d,
)
h = get_readable_hash(expr)
python_hash_seed = os.environ.get("PYTHONHASHSEED")
if python_hash_seed is None:
pytest.skip("PYTHONHASHSEED has not been set")
if python_hash_seed != "0":
pytest.skip(f"PYTHONHASHSEED is not set to 0, but to {python_hash_seed}")
if sys.version_info >= (3, 11):
assert h == "pythonhashseed-0+4377931190501974271"
else:
assert h == "pythonhashseed-0+8267198661922532208"
h = get_readable_hash(expr)[:7]
assert h == "ccafec3"


class TestLargeHash:
initial_state: ClassVar = [("J/psi(1S)", [-1, 1])]
final_state: ClassVar = ["gamma", "pi0", "pi0"]
allowed_intermediate_particles: ClassVar = ["f(0)(980)", "f(0)(1500)"]
allowed_interaction_types: ClassVar = "strong"

@pytest.mark.parametrize(
    ("expected_hash", "formalism"),
    [
        ("762cc00", "canonical-helicity"),
        ("17fefe5", "helicity"),
    ],
    ids=["canonical-helicity", "helicity"],
)
def test_reaction(self, expected_hash: str, formalism: SpinFormalism):
    """Compare the short hash of a generated reaction with a reference value.

    The reaction is generated from the class-level definition for the given
    spin formalism. Only the first seven characters of the hash are compared.
    """
    settings = dict(
        initial_state=self.initial_state,
        final_state=self.final_state,
        allowed_intermediate_particles=self.allowed_intermediate_particles,
        allowed_interaction_types=self.allowed_interaction_types,
        formalism=formalism,
    )
    transitions = qrules.generate_transitions(**settings)
    full_hash = get_readable_hash(transitions)
    short_hash = full_hash[:7]
    assert short_hash == expected_hash

def test_get_readable_hash_large(amplitude_model: tuple[str, HelicityModel]):
python_hash_seed = os.environ.get("PYTHONHASHSEED")
if python_hash_seed != "0":
pytest.skip("PYTHONHASHSEED is not 0")
formalism, model = amplitude_model
if sys.version_info >= (3, 11):
expected_hash = {
"canonical-helicity": "pythonhashseed-0-8140852268928771574",
"helicity": "pythonhashseed-0-991855900379383849",
}[formalism]
else:
expected_hash = {
"canonical-helicity": "pythonhashseed-0+3166036244969111461",
"helicity": "pythonhashseed-0+4247688887304834148",
}[formalism]
assert get_readable_hash(model.expression) == expected_hash
@pytest.mark.parametrize(
    ("expected_hash", "formalism"),
    [
        ("87c4839" if _GH else "01bb112", "canonical-helicity"),
        ("c147bdd" if _GH else "0638a0e", "helicity"),
    ],
    ids=["canonical-helicity", "helicity"],
)
def test_amplitude_model(self, expected_hash: str, formalism: SpinFormalism):
    """Compare the short hash of a full amplitude model with a reference value.

    The model is formulated for the class-level reaction definition, with a
    relativistic Breit-Wigner with form factor assigned to every intermediate
    particle. The reference value depends on ``_GH``, that is, on whether the
    test runs in a CI environment on a non-Darwin system.
    """
    # Read the reaction definition from the class attributes instead of repeating
    # the literals, so that it stays in sync with the other tests of this class.
    reaction = qrules.generate_transitions(
        initial_state=self.initial_state,
        final_state=self.final_state,
        allowed_intermediate_particles=self.allowed_intermediate_particles,
        allowed_interaction_types=self.allowed_interaction_types,
        formalism=formalism,
    )
    builder = get_builder(reaction)
    for name in reaction.get_intermediate_particles().names:
        builder.dynamics.assign(name, create_relativistic_breit_wigner_with_ff)
    model = builder.formulate()
    # Only the first seven characters of the hash are compared.
    h = get_readable_hash(model.expression)[:7]
    assert h == expected_hash

0 comments on commit f955d0a

Please sign in to comment.