Merge pull request #108 from mila-iqia/test_greedy_sampling
Test greedy sampling
sblackburn86 authored Nov 29, 2024
2 parents ab6cc55 + eca5015 commit 6a74116
Showing 12 changed files with 618 additions and 241 deletions.

@@ -52,7 +52,7 @@ def initialize(
 
         return fixed_init_composition
 
-    def relative_coordinates_update(
+    def _relative_coordinates_update(
         self,
         relative_coordinates: torch.Tensor,
         sigma_normalized_scores: torch.Tensor,

@@ -32,7 +32,7 @@ def __init__(self, pickle_path: Path, num_classes: int):
         data = torch.load(pickle_path, map_location=torch.device("cpu"))
         logger.info("Done reading data.")
 
-        noise_parameters = NoiseParameters(**data['noise_parameters'][0])
+        noise_parameters = NoiseParameters(**data['noise_parameters'])
         sampler = NoiseScheduler(noise_parameters, num_classes=num_classes)
         self.noise, _ = sampler.get_all_sampling_parameters()

@@ -28,6 +28,7 @@ class SamplingParameters:
         False  # should the predictor and corrector steps be recorded to a file
     )
     record_samples_corrector_steps: bool = False
+    record_atom_type_update: bool = False  # record the information pertaining to generating atom types.
 
 
 class AXLGenerator(ABC):
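
Note on usage: a minimal sketch of toggling the new flag. Only the three recording fields visible in this hunk are taken from the diff; the name record_samples for the truncated first field, and the standalone dataclass, are assumptions for illustration.

# Hypothetical, self-contained mimic of the SamplingParameters fields shown above.
from dataclasses import dataclass

@dataclass
class SamplingParametersSketch:
    record_samples: bool = False  # assumed name: should the predictor and corrector steps be recorded to a file
    record_samples_corrector_steps: bool = False
    record_atom_type_update: bool = False  # record the information pertaining to generating atom types.

params = SamplingParametersSketch(record_atom_type_update=True)
assert params.record_atom_type_update  # atom-type updates will be recorded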

Large diffs are not rendered by default.

@@ -36,9 +36,10 @@ def __init__(
         **kwargs,
     ):
         """Init method."""
+        # T = 1 is a dangerous and meaningless edge case.
         assert (
-            number_of_discretization_steps > 0
-        ), "The number of discretization steps should be larger than zero"
+            number_of_discretization_steps > 1
+        ), "The number of discretization steps should be larger than one"
         assert (
             number_of_corrector_steps >= 0
         ), "The number of corrector steps should be non-negative"
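
Note on the tightened assertion: a minimal illustration of why a single discretization step is degenerate, assuming a linear time grid (the repository's actual noise schedule is not shown in this hunk).

import torch

def linear_time_grid(number_of_discretization_steps: int) -> torch.Tensor:
    # Assumed linear grid from t = 1 (pure noise) down to t = 0 (clean sample).
    return torch.linspace(1.0, 0.0, number_of_discretization_steps + 1)

print(linear_time_grid(1))   # tensor([1., 0.]): one jump from noise to sample, nothing to iterate or correct
print(linear_time_grid(10))  # eleven time points: ten genuine predictor updates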

@@ -99,9 +99,8 @@ def get_probability_at_previous_time_step(
         one-step transition normalized probabilities of dimension [batch_size, number_of_atoms, num_type_atoms]
     """
     if probability_at_zeroth_timestep_are_logits:
-        probability_at_zeroth_timestep = torch.nn.functional.softmax(
-            probability_at_zeroth_timestep, dim=-1
-        )
+        probability_at_zeroth_timestep = get_probability_from_logits(probability_at_zeroth_timestep,
+                                                                     lowest_probability_value=small_epsilon)
 
     numerator1 = einops.einsum(
         probability_at_zeroth_timestep, q_bar_tm1_matrices, "... j, ... j i -> ... i"
@@ -116,12 +115,35 @@
         one_hot_probability_at_current_timestep,
         "... i j, ... j -> ... i",
     )
-    den2 = einops.einsum(
-        probability_at_zeroth_timestep, den1, "... j, ... j -> ..."
-    ).clip(min=small_epsilon)
+    den2 = einops.einsum(probability_at_zeroth_timestep, den1, "... j, ... j -> ...")
 
     denominator = einops.repeat(
         den2, "... -> ... num_classes", num_classes=numerator.shape[-1]
     )
 
     return numerator / denominator
+
+
+def get_probability_from_logits(logits: torch.Tensor, lowest_probability_value: float) -> torch.Tensor:
+    """Get probability from logits.
+
+    Compute the probabilities from the logits, imposing that no class probability can be lower than
+    lowest_probability_value.
+
+    Args:
+        logits: unnormalized values that can be turned into probabilities. Dimensions [..., num_classes].
+        lowest_probability_value: imposed lowest probability value for any class.
+
+    Returns:
+        probabilities: derived from the logits, with minimum values clipped. Dimensions [..., num_classes].
+    """
+    raw_probabilities = torch.nn.functional.softmax(logits, dim=-1)
+    probability_sum = raw_probabilities.sum(dim=-1)
+    torch.testing.assert_close(probability_sum, torch.ones_like(probability_sum),
+                               msg="Logits are pathological: the probabilities do not sum to one.")
+
+    clipped_probabilities = raw_probabilities.clip(min=lowest_probability_value)
+
+    probabilities = clipped_probabilities / clipped_probabilities.sum(dim=-1).unsqueeze(-1)
+    return probabilities
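
Note on the new helper: a behavior sketch, restating the function inline since its module path is not visible in this diff. Very negative logits softmax to numerically zero probabilities, whose logarithms are infinite; clipping and renormalizing imposes a floor, so greedy sampling and log-probability computations stay finite.

import torch

def get_probability_from_logits(logits: torch.Tensor, lowest_probability_value: float) -> torch.Tensor:
    # Same clip-and-renormalize scheme as above, minus the sanity check.
    raw = torch.nn.functional.softmax(logits, dim=-1)
    clipped = raw.clip(min=lowest_probability_value)
    return clipped / clipped.sum(dim=-1).unsqueeze(-1)

logits = torch.tensor([[0.0, -200.0, -200.0]])
print(torch.nn.functional.softmax(logits, dim=-1))        # [[1., 0., 0.]] in float32: exp(-200) underflows
print(get_probability_from_logits(logits, 1.0e-8).log())  # finite everywhere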

@@ -35,5 +35,10 @@ def record(self, key: str, entry: Union[Dict[str, Any], NamedTuple]):
 
     def write_to_pickle(self, path_to_pickle: str):
         """Write data to pickle file."""
+        self._internal_data = dict(self._internal_data)
+        for key, value in self._internal_data.items():
+            if len(value) == 1:
+                self._internal_data[key] = value[0]
+
         with open(path_to_pickle, "wb") as fd:
             torch.save(self._internal_data, fd)
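
Note: a round-trip sketch connecting this writer change to the reader change at the top of the diff; the recorded dictionary contents are illustrative assumptions. Because singleton lists are now unwrapped before saving, a consumer can read data['noise_parameters'] directly instead of data['noise_parameters'][0].

from collections import defaultdict
import torch

internal_data = defaultdict(list)  # mirrors the recorder's list-per-key storage
internal_data["noise_parameters"].append({"total_time_steps": 10})  # assumed contents

# The new unwrapping logic from write_to_pickle above.
internal_data = dict(internal_data)
for key, value in internal_data.items():
    if len(value) == 1:
        internal_data[key] = value[0]  # singleton lists collapse to their element

torch.save(internal_data, "/tmp/recorded_sample.pt")
data = torch.load("/tmp/recorded_sample.pt")
print(data["noise_parameters"])  # a plain dict; no [0] indexing needed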

tests/generators/conftest.py (2 changes: 1 addition and 1 deletion)

@@ -12,7 +12,7 @@
 
 
 class FakeAXLNetwork(ScoreNetwork):
-    """A fake, smooth score network for the ODE solver."""
+    """A fake score network for tests."""
 
     def _forward_unchecked(
         self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False