Skip to content

Commit

Permalink
torch dev for ring_cnn + 2p spatial
Browse files Browse the repository at this point in the history
  • Loading branch information
mannypaeza committed Aug 26, 2024
1 parent e33642e commit 9532291
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
26 changes: 13 additions & 13 deletions caiman/source_extraction/cnmf/online_cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,9 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False:
loaded_model = None
self.params.set('online', {'sniper_mode': False})
# self.tf_in = None
# self.tf_out = None
self.use_torch = None #fix
self.tf_in = None
self.tf_out = None
# self.use_torch = None
else:
try:
from keras.models import load_model
Expand All @@ -340,12 +340,12 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
# uses online model -> be careful
model_path = ".".join(path + ["keras"])
loaded_model = model_load(model_path)
self.use_torch = False
# self.use_torch = False
else:
model_path = '.'.join(path + ['pt'])
loaded_model = load_graph(model_path)
loaded_model = torch.load(model_file)
self.use_torch = True
# self.use_torch = True

self.loaded_model = loaded_model

Expand Down Expand Up @@ -547,8 +547,8 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
sniper_mode=self.params.get('online', 'sniper_mode'),
use_peak_max=self.params.get('online', 'use_peak_max'),
mean_buff=self.estimates.mean_buff,
# tf_in=self.tf_in, tf_out=self.tf_out,
use_torch=self.use_torch,
tf_in=self.tf_in, tf_out=self.tf_out,
# use_torch=self.use_torch,
ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
b0=self.estimates.b0 if self.is1p else None,
corr_img=self.estimates.corr_img if use_corr else None,
Expand Down Expand Up @@ -2003,8 +2003,8 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
patch_size=50, loaded_model=None, test_both=False,
thresh_CNN_noisy=0.5, use_peak_max=False,
thresh_std_peak_resid = 1, mean_buff=None,
# tf_in=None, tf_out=None):
use_torch=None):
tf_in=None, tf_out=None):
# use_torch=None):
"""
Extract new candidate components from the residual buffer and test them
using space correlation or the CNN classifier. The function runs the CNN
Expand Down Expand Up @@ -2146,8 +2146,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
corr_img=None, first_moment=None, second_moment=None,
crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
max_img=None, downscale_matrix=None, upscale_matrix=None,
# tf_in=None, tf_out=None):
torch_in=None, torch_out=None):
tf_in=None, tf_out=None):
# torch_in=None, torch_out=None):
"""
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
"""
Expand Down Expand Up @@ -2177,8 +2177,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
# tf_in=tf_in, tf_out=tf_out)
torch_in=torch_in, torch_out=torch_out)
tf_in=tf_in, tf_out=tf_out)
#torch_in=torch_in, torch_out=torch_out)

ind_new_all = ijsig_all

Expand Down
4 changes: 2 additions & 2 deletions caiman/utils/nn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def fit_NL_model(model_NL, Y, patience=5, val_split=0.2, batch_size=32,
Y = np.expand_dims(Y, axis=-1)
run_logdir = get_run_logdir()
os.mkdir(run_logdir)
path_to_model = os.path.join(run_logdir, 'model.h5')
path_to_model = os.path.join(run_logdir, 'model.weights.h5')
chk = ModelCheckpoint(filepath=path_to_model,
verbose=0, save_best_only=True, save_weights_only=True)
es = EarlyStopping(monitor='val_loss', patience=patience,
Expand All @@ -566,7 +566,7 @@ def fit_NL_model(model_NL, Y, patience=5, val_split=0.2, batch_size=32,
history_NL = model_NL.fit(Y, Y, epochs=epochs, batch_size=batch_size,
shuffle=True, validation_split=val_split,
callbacks=callbacks)
model_NL.load_weights(os.path.join(run_logdir, 'model.h5'))
model_NL.load_weights(os.path.join(run_logdir, 'model.weights.h5'))
return model_NL, history_NL, path_to_model

def get_MCNN_model(Y, gSig=5, n_channels=8, lr=1e-4, pct=10, r_factor=1.5,
Expand Down

0 comments on commit 9532291

Please sign in to comment.