Skip to content

Commit

Permalink
remove svn download command, switch to git
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Feb 17, 2024
1 parent 0c948f2 commit 0a2146c
Showing 1 changed file with 40 additions and 27 deletions.
67 changes: 40 additions & 27 deletions tests/test_paper_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
##################################################################

import os
import warnings
import unittest
import shutil
import subprocess

import numpy as np
Expand All @@ -16,25 +18,33 @@
os.mkdir(ci_data_folder)


def download_models(models_url):
def download_models(repository_urls, folder_name):
"""
function to download model directly from github url
"""
for model_url in models_url:
model_folder_name = os.path.basename(model_url)
if not os.path.exists(os.path.join(ci_data_folder, model_folder_name)):
download_args = ["svn", "export", model_url, os.path.join(ci_data_folder, model_folder_name)]
res = subprocess.Popen(download_args, stdout=subprocess.PIPE)
output, _error = res.communicate()
if not _error:
pass
else:
raise ConnectionError(f"Error downloading the models {model_url}")
else: # if the model is cached on Github Action, do a sanity check on remote folder without downloading it
check_args = ["svn", "log", model_url]
result = subprocess.Popen(check_args)
text = result.communicate()[0]
assert result.returncode == 0, f"Remote folder does not exist at {model_url}"
repo_folder = repository_urls.split("/")[-1]
if not os.path.exists(os.path.join(ci_data_folder, folder_name)):
download_args = ["git", "clone", "-n", "--depth=1", "--filter=tree:0", repository_urls]
res = subprocess.Popen(download_args, stdout=subprocess.PIPE)
output, _error = res.communicate()

checkout_args = ["git", "sparse-checkout", "set", "--no-cone", folder_name]
res = subprocess.Popen(checkout_args, stdout=subprocess.PIPE, cwd=repo_folder)
output, _error = res.communicate()

checkout_args = ["git", "checkout"]
res = subprocess.Popen(checkout_args, stdout=subprocess.PIPE, cwd=repo_folder)
output, _error = res.communicate()

if not _error:
pass
else:
raise ConnectionError(f"Error downloading the models {folder_name} from {repository_urls}")

shutil.move(os.path.join(repo_folder, folder_name), os.path.join(ci_data_folder, folder_name))
shutil.rmtree(repo_folder)
else: # if the model is cached on Github Action, do a sanity check on remote folder without downloading it
warnings.warn(f"Folder {folder_name} already exists, skipping download")


class PapersModelsCase(unittest.TestCase):
Expand All @@ -51,10 +61,11 @@ def test_arXiv_1808_04428(self):

# first model
models_url = [
"https://github.com/henrysky/astroNN_spectra_paper_figures/trunk/astroNN_0606_run001",
"https://github.com/henrysky/astroNN_spectra_paper_figures/trunk/astroNN_0617_run001",
{"repository_urls": "https://github.com/henrysky/astroNN_spectra_paper_figures", "folder_name": "astroNN_0606_run001"},
{"repository_urls": "https://github.com/henrysky/astroNN_spectra_paper_figures", "folder_name": "astroNN_0617_run001"},
]
download_models(models_url)
for url in models_url:
download_models(**url)

opened_fits = fits.open(
visit_spectra(dr=14, location=4405, apogee="2M19060637+4717296")
Expand Down Expand Up @@ -99,11 +110,12 @@ def test_arXiv_1902_08634(self):

# first model
models_url = [
"https://github.com/henrysky/astroNN_gaia_dr2_paper/trunk/astroNN_no_offset_model",
"https://github.com/henrysky/astroNN_gaia_dr2_paper/trunk/astroNN_constant_model",
"https://github.com/henrysky/astroNN_gaia_dr2_paper/trunk/astroNN_multivariate_model",
{"repository_urls": "https://github.com/henrysky/astroNN_gaia_dr2_paper", "folder_name": "astroNN_no_offset_model"},
{"repository_urls": "https://github.com/henrysky/astroNN_gaia_dr2_paper", "folder_name": "astroNN_constant_model"},
{"repository_urls": "https://github.com/henrysky/astroNN_gaia_dr2_paper", "folder_name": "astroNN_multivariate_model"},
]
download_models(models_url)
for url in models_url:
download_models(**url)

opened_fits = fits.open(
visit_spectra(dr=14, location=4405, apogee="2M19060637+4717296")
Expand Down Expand Up @@ -174,12 +186,13 @@ def test_arXiv_2302_05479(self):

# first model
models_url = [
"https://github.com/henrysky/astroNN_ages/trunk/models/astroNN_VEncoderDecoder"
{"repository_urls": "https://github.com/henrysky/astroNN_ages", "folder_name": "models"},
]
download_models(models_url)

for url in models_url:
download_models(**url)

# load the trained encoder-decoder model with astroNN
neuralnet = load_folder(os.path.join(ci_data_folder, "astroNN_VEncoderDecoder"))
neuralnet = load_folder(os.path.join(ci_data_folder, "models/astroNN_VEncoderDecoder"))

# arbitrary spectrum
opened_fits = fits.open(
Expand Down

0 comments on commit 0a2146c

Please sign in to comment.