Cannot reproduce results in Table 1, especially the IS score #83

Open
adreamwu opened this issue Aug 15, 2024 · 1 comment

adreamwu commented Aug 15, 2024

Has anyone been able to reproduce the results in Table 1 of the paper? Could you please share the inference script?

We used B=50 samples per class (1,000 classes, so 50,000 images in total) and the var_d16 checkpoint for evaluation.

|            | FID  | IS     | Precision | Recall |
|------------|------|--------|-----------|--------|
| Reported   | 3.30 | 274.4  | 0.84      | 0.51   |
| Reproduced | 3.49 | 281.71 | 0.85      | 0.50   |
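The gap between the two rows can be quantified directly. The minimal sketch below uses only the numbers from the table. `gaps` is a hypothetical helper written for illustration. Keep in mind that lower FID is better, while higher IS, precision, and recall are better.

```python
# Metrics copied from the table: (reported, reproduced)
metrics = {
    "FID": (3.30, 3.49),    # lower is better
    "IS": (274.4, 281.71),  # higher is better
    "Pre": (0.84, 0.85),    # higher is better
    "Rec": (0.51, 0.50),    # higher is better
}


def gaps(table):
    """Return {name: (reproduced - reported, same difference as % of reported)}."""
    out = {}
    for name, (reported, reproduced) in table.items():
        diff = reproduced - reported
        out[name] = (diff, 100.0 * diff / reported)
    return out


if __name__ == "__main__":
    for name, (diff, pct) in gaps(metrics).items():
        print(f"{name}: {diff:+.2f} ({pct:+.1f}%)")
```

On these numbers, the reproduced FID is about 5.8% worse than reported, and the reproduced IS is about 2.7% higher.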
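Our sampling script:

```python
import os

import numpy as np
import torch
from PIL import Image as PImage

# Defined elsewhere (not shown): `var` (the loaded var_d16 model), `device`,
# FOLDER (output directory name), NUM_PER_CLASS (assumed to be 50; a smaller
# value would overwrite files), and create_npz_from_sample_folder (packs the
# PNGs into a single .npz for the FID/IS evaluator).
all_labels = np.arange(1000)  # all 1,000 ImageNet classes
count = 0
for c in all_labels:
    class_labels = [c] * 50  # 50 samples per class -> 50,000 images in total
    B = len(class_labels)
    label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
    with torch.inference_mode():
        with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
            recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=1.5, top_k=900, top_p=0.96, g_seed=0, more_smooth=False)
    # Images come out in [0, 1]; convert them to uint8 HWC arrays for saving.
    recon_B3HW = recon_B3HW.mul(255).add(0.5).clamp(0, 255).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
    for i in range(recon_B3HW.shape[0]):
        # File name: <class>_<index within class>.png
        PImage.fromarray(recon_B3HW[i]).save(os.path.join(f"./{FOLDER}/", f"{c}_{count}.png"))
        count = (count + 1) % NUM_PER_CLASS
create_npz_from_sample_folder(f"./{FOLDER}")
```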
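The script passes the folder to `create_npz_from_sample_folder`, but that helper is not shown. The sketch below shows what such a helper might do. It assumes the OpenAI guided-diffusion evaluator, which reads the sample batch from the `arr_0` key of an .npz file. `pack_samples_to_npz` is a hypothetical name, not a function from the VAR repo.

The sketch globs for `*.png` instead of expecting fixed file names, because the script writes `<class>_<index>.png`. One widely copied version of this helper, in DiT's `sample_ddp.py`, opens files named by a zero-padded running index (`000000.png`, `000001.png`, ...). A helper reused unchanged from there would therefore not match these names.

```python
import glob
import os

import numpy as np
from PIL import Image


def pack_samples_to_npz(sample_dir: str) -> str:
    """Hypothetical stand-in for create_npz_from_sample_folder.

    Loads every PNG in `sample_dir` in sorted filename order, stacks them into a
    uint8 array of shape (N, H, W, 3), saves it under the key "arr_0", and
    returns the path of the written .npz file (sample_dir + ".npz").
    """
    paths = sorted(glob.glob(os.path.join(sample_dir, "*.png")))
    if not paths:
        raise FileNotFoundError(f"no .png files in {sample_dir}")
    samples = np.stack(
        [np.asarray(Image.open(p).convert("RGB"), dtype=np.uint8) for p in paths]
    )
    assert samples.ndim == 4 and samples.shape[-1] == 3, samples.shape
    npz_path = os.path.normpath(sample_dir) + ".npz"
    np.savez(npz_path, arr_0=samples)
    return npz_path
```

File order has little effect on these metrics. It is still worth checking that the saved array has shape (50000, H, W, 3) before running the evaluator. This is a cheap way to rule out missing or overwritten files.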
@eyedealism

Could you share your evaluation script? I can only reproduce an IS of around 60 to 80 across the different model depths (d).
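When comparing IS numbers across evaluation scripts, it helps to pin down what the score computes. The sketch below implements the standard definition, IS = exp(E_x[KL(p(y|x) || p(y))]), from a matrix of classifier softmax outputs. It assumes `probs` holds Inception-v3 class probabilities for each sample. `inception_score` is a hypothetical helper, not code from the VAR repo or the OpenAI evaluator. Reference implementations typically split the samples into several groups and average the per-split scores; this sketch uses a single split.

```python
import numpy as np


def inception_score(probs: np.ndarray, eps: float = 1e-12) -> float:
    """Single-split Inception Score from classifier softmax outputs.

    probs: array of shape (N, K); row i is p(y | x_i) over K classes.
    Returns exp(mean_i KL(p(y|x_i) || p(y))), where p(y) is the mean row.
    """
    probs = np.asarray(probs, dtype=np.float64)
    marginal = probs.mean(axis=0, keepdims=True)  # p(y), shape (1, K)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))


if __name__ == "__main__":
    k = 1000
    # Uniform predictions: the classifier is never confident, so IS is ~1 (the minimum).
    print(inception_score(np.full((k, k), 1.0 / k)))
    # Confident predictions spread evenly over all classes: IS is ~K (the maximum).
    print(inception_score(np.eye(k)))
```

IS lies between 1 and the number of classes (1,000 for ImageNet), so both ~280 and 60 to 80 are possible values. IS is a deterministic function of the classifier outputs, so two scripts given the same samples and classifier should agree closely. A gap this large suggests the pipelines differ somewhere upstream, for example in the samples, the .npz packing, or the classifier.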
