
Add KAN readout options for MACE with possible better accuracy #655

Open · wants to merge 13 commits into base: develop

Conversation

Hongyu-yu (Contributor)

MACE + KAN

With an additional KAN readout for MACE, more complex combinations of the spherical basis emerge, yielding a more accurate MACE model and even making MACE more interpretable. Tests and benchmark results will be updated at https://arxiv.org/abs/2409.03430v1.

dill is used for torch.save/load
pykan is used for MultKAN

Usage: add --KAN_readout on the command line, e.g. mace_run_train --KAN_readout ...
Hope this pull request provides a more accurate MACE model to the community!
If it helps and you use it, please consider citing https://arxiv.org/abs/2409.03430v1 and http://arxiv.org/abs/2408.10205.

Hongyu Yu

@gabor1 (Collaborator) commented Oct 24, 2024

Nice work. A couple of questions and comments.

Do you have mace results that you can share? The improvement in validation error for Allegro is impressive, but Allegro is generally less accurate than mace. So I'd like to see how the improvements translate.

You should really have the y axis of your bar charts start at zero. This kind of thing is covered in the first chapter of "How to Lie with Statistics": https://g.co/kgs/RGfduso

@Hongyu-yu (Contributor, Author) commented Oct 24, 2024

@gabor1 Thanks for the quick feedback!
For the results: I recently attended an MLIP competition and tried MACE and MACE-KAN on perovskite (t1). MACE ends at 1.0 meV/atom and 7.5 meV/Å, while MACE-KAN ends at 0.9 meV/atom and 5.2 meV/Å, so we see a clearly better result for forces (the same holds for other datasets in the competition). For more details and the official benchmark, runs are still in progress and we will make them public as soon as possible.
And thanks for pointing out the statistical uncertainty and the confusing plots! We will add random-seed testing and give more convincing results with plots whose y axis starts at zero.
Thanks again for your valuable feedback!

@ilyes319 (Contributor) commented Oct 24, 2024

Dear @Hongyu-yu, thank you very much for your PR and nice work.
Could you please make the PR as small as possible, and break out any changes that are not related to the KAN readout into another PR? Ideally, only blocks.py, run_train.py, the argparser, and model.py would be changed, plus an additional test.

Just an extra note: I am very skeptical about KANs in general, and I am a bit surprised that they lead to any real improvement. So I would wait to see more results on MACE before merging this.

@Hongyu-yu (Contributor, Author) commented Oct 24, 2024

@ilyes319 Thanks for the quick feedback!
Actually, the changes in this PR are all related to KAN; some of them just switch torch.save/load to use dill, which is needed to deploy KAN. The core changes are in mace/modules/models.py and mace/modules/blocks.py.
As for the additional test, I will add one tomorrow, ASAP.
For the improvement, I agree with you that a systematic check and experiments on benchmarks are needed, and we are working on them. So far, our results show that KAN gives better results in the competition above, but it has not been tested on benchmarks yet. We open-sourced the code now to meet the competition requirements. We will update benchmark results ASAP and report them here. Maybe then we can come to a definite conclusion about whether MultKAN works.
From the viewpoint of the spherical basis: the last output layer is effectively a mixture of basis functions reduced to a scalar energy. MultKAN (KAN 2.0) may provide a more complex combination than an MLP of the basis given by the MACE interaction part, which could result in better accuracy. This could be why KAN works better at decoding the latent basis features.
Happy to hear your opinion!
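The contrast drawn above can be sketched in plain PyTorch: a KAN-style readout gives each latent feature its own learnable univariate function before summing to a scalar energy, whereas an MLP readout mixes features through fixed nonlinearities. This is an illustrative toy only, not the pykan MultKAN this PR actually uses; the class name and the Fourier parameterisation of the univariate functions are hypothetical choices made for the sketch:

```python
import torch
import torch.nn as nn


class ToyKANReadout(nn.Module):
    """Toy KAN-style readout (NOT the pykan MultKAN used in this PR).

    Each input feature h_i passes through its own learnable univariate
    function phi_i, parameterised here by a small Fourier basis:
        phi_i(h_i) = sum_k c_{ik} * sin(k * h_i)
    The per-feature outputs are summed to one scalar energy per node.
    """

    def __init__(self, in_features: int, n_basis: int = 8):
        super().__init__()
        # frequencies 1..n_basis, shared by all features
        self.register_buffer(
            "freqs", torch.arange(1, n_basis + 1, dtype=torch.float32)
        )
        # one coefficient set per input feature (zero-initialised)
        self.coeff = nn.Parameter(torch.zeros(in_features, n_basis))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (n_nodes, in_features) invariant latent features
        basis = torch.sin(h.unsqueeze(-1) * self.freqs)  # (N, F, K)
        per_feature = (basis * self.coeff).sum(dim=-1)   # phi_i(h_i): (N, F)
        return per_feature.sum(dim=-1, keepdim=True)     # energy: (N, 1)


readout = ToyKANReadout(in_features=4)
energies = readout(torch.randn(2, 4))
print(tuple(energies.shape))  # (2, 1)
```

The point of the sketch is only that the nonlinearity itself is learned per feature, which is the extra flexibility the comment above attributes to KAN over a fixed-activation MLP readout.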

@ilyes319 (Contributor)

Thanks, how are the changes to torch.save related to KAN?

@Hongyu-yu (Contributor, Author)

torch.save does not work directly for KAN; it fails with this error:

INFO: Saving model to checkpoints/base_run-123.model
Traceback (most recent call last):
  File "/public/home/yuhongyu/anaconda3/envs/ace/bin/mace_run_train", line 8, in <module>
    sys.exit(main())
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/mace/cli/run_train.py", line 63, in main
    run(args)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/mace/cli/run_train.py", line 734, in run
    torch.save(model, model_path)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/torch/serialization.py", line 589, in _save
    pickler.dump(obj)
AttributeError: Can't pickle local object 'Symbolic_KANLayer.__init__.<locals>.<listcomp>.<listcomp>.<lambda>'

But it works with torch.save(model, model_path, pickle_module=dill)
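The failure in the traceback can be reproduced with the stdlib pickler alone, with no torch or pykan involved. The helper below is a hypothetical stand-in mimicking the nested local lambdas that pykan's Symbolic_KANLayer.__init__ builds:

```python
import pickle


def build_layer():
    # Mimics Symbolic_KANLayer.__init__, which builds nested lists of
    # local lambdas; torch.save uses the stdlib pickler by default,
    # and pickle cannot serialize local objects (it stores functions
    # by reference, and '<locals>' names are not importable).
    return [[lambda x: 0.0 * x for _ in range(2)] for _ in range(2)]


layer = build_layer()
try:
    pickle.dumps(layer)
    picklable = True
except (AttributeError, pickle.PicklingError):
    picklable = False

print(picklable)  # False
```

dill, by contrast, serializes function objects by value, which is why passing it as pickle_module to torch.save/torch.load works around the error.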

3 participants