
Score target #4

Merged (15 commits, Mar 6, 2024)
6 changes: 6 additions & 0 deletions crystal_diffusion/__init__.py
@@ -0,0 +1,6 @@
from pathlib import Path

ROOT_DIR = Path(__file__).parent
TOP_DIR = ROOT_DIR.parent
ANALYSIS_RESULTS_DIR = TOP_DIR.joinpath("analysis_results/")
ANALYSIS_RESULTS_DIR.mkdir(exist_ok=True)
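Side note (not part of the diff): importing the package creates analysis_results/ at the repository top level via mkdir(exist_ok=True), so downstream scripts can write figures there without any setup. A minimal sketch, assuming the package is importable; the output file name below is hypothetical:

from crystal_diffusion import ANALYSIS_RESULTS_DIR

# The directory already exists after import; build an output path for a figure.
output_path = ANALYSIS_RESULTS_DIR.joinpath("some_figure.png")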
6 changes: 6 additions & 0 deletions crystal_diffusion/analysis/__init__.py
@@ -0,0 +1,6 @@
from pathlib import Path

PLEASANT_FIG_SIZE = (7.2, 4.45)

ANALYSIS_DIR = Path(__file__).parent
PLOT_STYLE_PATH = ANALYSIS_DIR.joinpath("plot_style.txt")
40 changes: 40 additions & 0 deletions crystal_diffusion/analysis/plot_style.txt
@@ -0,0 +1,40 @@
xtick.color: 323034
ytick.color: 323034
text.color: 323034
lines.markeredgecolor: black
patch.facecolor: bc80bd
patch.force_edgecolor: True
patch.linewidth: 0.8
scatter.edgecolors: black
grid.color: b1afb5
axes.titlesize: 16
legend.title_fontsize: 12
xtick.labelsize: 12
ytick.labelsize: 12
axes.labelsize: 12
font.size: 10
axes.prop_cycle: (cycler('color', ['bc80bd', 'fb8072', 'b3de69', 'fdb462', 'fccde5', '8dd3c7', 'ffed6f', 'bebada', '80b1d3', 'ccebc5', 'd9d9d9']))
mathtext.fontset: stix
font.family: STIXGeneral
lines.linewidth: 2
legend.frameon: True
legend.framealpha: 0.8
legend.fontsize: 10
legend.edgecolor: 0.9
legend.borderpad: 0.2
legend.columnspacing: 1.5
legend.labelspacing: 0.4
text.usetex: False
axes.titlelocation: left
axes.formatter.use_mathtext: True
axes.autolimit_mode: round_numbers
axes.labelpad: 3
axes.formatter.limits: -4, 4
axes.labelcolor: black
axes.edgecolor: black
axes.linewidth: 0.6
axes.spines.right: False
axes.spines.top: False
axes.grid: False
figure.titlesize: 18
figure.dpi: 300
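For reference: this is a plain matplotlib style sheet, loaded through the PLOT_STYLE_PATH constant defined in crystal_diffusion/analysis/__init__.py. A minimal sketch of how it is applied, mirroring the analysis script below:

import matplotlib.pyplot as plt

from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH

plt.style.use(PLOT_STYLE_PATH)  # apply the shared rcParams to all subsequent figures
fig = plt.figure(figsize=PLEASANT_FIG_SIZE)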
91 changes: 91 additions & 0 deletions crystal_diffusion/analysis/target_score_analysis.py
@@ -0,0 +1,91 @@
"""Target score analysis.

This script computes and plots the target score for various values of sigma, showing
that the 'smart' implementation converges quickly and is equal to the expected brute force value.
"""
import matplotlib.pyplot as plt
import numpy as np
import torch

from crystal_diffusion import ANALYSIS_RESULTS_DIR
from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
from crystal_diffusion.score.wrapped_gaussian_score import (
    SIGMA_THRESHOLD, get_expected_sigma_normalized_score_brute_force,
    get_sigma_normalized_score)

plt.style.use(PLOT_STYLE_PATH)

if __name__ == '__main__':

    list_u = np.linspace(0, 1, 101)[:-1]
    relative_positions = torch.from_numpy(list_u)

    # A first figure to compare the "smart" and the "brute force" results
    fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE)
    fig1.suptitle("Smart vs. brute force scores")

    kmax = 4
    ax1 = fig1.add_subplot(121)
    ax2 = fig1.add_subplot(122)

    for sigma_factor, color in zip([0.1, 1., 2.], ['r', 'g', 'b']):
        sigma = sigma_factor * SIGMA_THRESHOLD

        sigmas = torch.ones_like(relative_positions) * sigma
        list_scores_brute = np.array([get_expected_sigma_normalized_score_brute_force(u, sigma) for u in list_u])
        list_scores = get_sigma_normalized_score(relative_positions, sigmas, kmax=kmax).numpy()
        error = list_scores - list_scores_brute

        ax1.plot(list_u, list_scores_brute, '--', c=color, lw=4, label='brute force')
        ax1.plot(list_u, list_scores, '-', c=color, lw=2, label='smart')

        ax2.plot(list_u, error, '-', c=color, label=f'$\\sigma$ = {sigma_factor} $\\sigma_{{th}}$')

    ax1.set_ylabel('$\\sigma^2\\times S$')
    ax2.set_xlabel('u')
    ax2.set_ylabel('Error')

    for ax in [ax1, ax2]:
        ax.set_xlabel('u')
        ax.set_xlim([0, 1])
        ax.legend(loc=0)

    fig1.tight_layout()
    fig1.savefig(ANALYSIS_RESULTS_DIR.joinpath("score_convergence_with_sigma.png"))

    # A second figure to show convergence with kmax
    fig2 = plt.figure(figsize=PLEASANT_FIG_SIZE)
    fig2.suptitle("Convergence with kmax")

    ax3 = fig2.add_subplot(121)
    ax4 = fig2.add_subplot(122)

    sigma_factors = torch.linspace(0.001, 4., 40).to(torch.double)
    sigmas = sigma_factors * SIGMA_THRESHOLD

    u = 0.6
    relative_positions = torch.ones_like(sigmas).to(torch.double) * u

    ms = 8
    for kmax, color in zip([1, 2, 3, 4, 5], ['y', 'r', 'g', 'b', 'k']):
        list_scores = get_sigma_normalized_score(relative_positions, sigmas, kmax=kmax).numpy()
        ax3.semilogy(sigma_factors, list_scores, 'o-',
                     ms=ms, c=color, lw=2, alpha=0.25, label=f'kmax = {kmax}')

        list_scores_brute = np.array([
            get_expected_sigma_normalized_score_brute_force(u, sigma, kmax=4 * kmax) for sigma in sigmas])
        ax4.semilogy(sigma_factors, list_scores_brute, 'o-',
                     ms=ms, c=color, lw=2, alpha=0.25, label=f'kmax = {4 * kmax}')

        ms = 0.75 * ms

    for ax in [ax3, ax4]:
        ax.set_xlabel('$\\sigma$ ($\\sigma_{th}$)')
        ax.set_ylabel(f'$\\sigma^2 \\times S(u={u})$')
        ax.set_xlim([0., 1.1 * sigma_factors.max()])
        ax.legend(loc=0)

    ax3.set_title("Smart implementation")
    ax4.set_title("Brute force implementation")
    fig2.tight_layout()
    fig2.savefig(ANALYSIS_RESULTS_DIR.joinpath("score_convergence_with_k.png"))
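Usage note (a sketch, not part of the diff): the analysis only runs under the __main__ guard, so the two figures can be produced by executing the module; the output file names are the ones saved by the script above. Assuming the package is importable (e.g. an editable install):

import runpy

# Run the analysis module as a script; this saves
# score_convergence_with_sigma.png and score_convergence_with_k.png
# under analysis_results/ at the repository top level.
runpy.run_module("crystal_diffusion.analysis.target_score_analysis", run_name="__main__")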
Empty file.