diff --git a/caiman/source_extraction/cnmf/online_cnmf.py b/caiman/source_extraction/cnmf/online_cnmf.py index ce7a55ca7..1fe25d75d 100644 --- a/caiman/source_extraction/cnmf/online_cnmf.py +++ b/caiman/source_extraction/cnmf/online_cnmf.py @@ -323,9 +323,9 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None): if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False: loaded_model = None self.params.set('online', {'sniper_mode': False}) - # self.tf_in = None - # self.tf_out = None - self.use_torch = None #fix + self.tf_in = None + self.tf_out = None + # self.use_torch = None else: try: from keras.models import load_model @@ -340,12 +340,12 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None): # uses online model -> be careful model_path = ".".join(path + ["keras"]) loaded_model = model_load(model_path) - self.use_torch = False + # self.use_torch = False else: model_path = '.'.join(path + ['pt']) loaded_model = load_graph(model_path) loaded_model = torch.load(model_file) - self.use_torch = True + # self.use_torch = True self.loaded_model = loaded_model @@ -547,8 +547,8 @@ def fit_next(self, t, frame_in, num_iters_hals=3): sniper_mode=self.params.get('online', 'sniper_mode'), use_peak_max=self.params.get('online', 'use_peak_max'), mean_buff=self.estimates.mean_buff, - # tf_in=self.tf_in, tf_out=self.tf_out, - use_torch=self.use_torch, + tf_in=self.tf_in, tf_out=self.tf_out, + # use_torch=self.use_torch, ssub_B=ssub_B, W=self.estimates.W if self.is1p else None, b0=self.estimates.b0 if self.is1p else None, corr_img=self.estimates.corr_img if use_corr else None, @@ -2003,8 +2003,8 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5), patch_size=50, loaded_model=None, test_both=False, thresh_CNN_noisy=0.5, use_peak_max=False, thresh_std_peak_resid = 1, mean_buff=None, - # tf_in=None, tf_out=None): - use_torch=None): + tf_in=None, tf_out=None): + # use_torch=None): """ Extract new candidate components from the residual buffer and test them using space correlation or the CNN classifier. The function runs the CNN @@ -2146,8 +2146,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf, corr_img=None, first_moment=None, second_moment=None, crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None, max_img=None, downscale_matrix=None, upscale_matrix=None, - # tf_in=None, tf_out=None): - torch_in=None, torch_out=None): + tf_in=None, tf_out=None): + # torch_in=None, torch_out=None): """ Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests """ @@ -2177,8 +2177,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf, sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50, loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy, use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff, - # tf_in=tf_in, tf_out=tf_out) - torch_in=torch_in, torch_out=torch_out) + tf_in=tf_in, tf_out=tf_out) + #torch_in=torch_in, torch_out=torch_out) ind_new_all = ijsig_all diff --git a/caiman/utils/nn_models.py b/caiman/utils/nn_models.py index 262e97dc1..adc3f02f0 100644 --- a/caiman/utils/nn_models.py +++ b/caiman/utils/nn_models.py @@ -555,7 +555,7 @@ def fit_NL_model(model_NL, Y, patience=5, val_split=0.2, batch_size=32, Y = np.expand_dims(Y, axis=-1) run_logdir = get_run_logdir() os.mkdir(run_logdir) - path_to_model = os.path.join(run_logdir, 'model.h5') + path_to_model = os.path.join(run_logdir, 'model.weights.h5') chk = ModelCheckpoint(filepath=path_to_model, verbose=0, save_best_only=True, save_weights_only=True) es = EarlyStopping(monitor='val_loss', patience=patience, @@ -566,7 +566,7 @@ def fit_NL_model(model_NL, Y, patience=5, val_split=0.2, batch_size=32, history_NL = model_NL.fit(Y, Y, epochs=epochs, batch_size=batch_size, shuffle=True, validation_split=val_split, callbacks=callbacks) - model_NL.load_weights(os.path.join(run_logdir, 'model.h5')) + model_NL.load_weights(os.path.join(run_logdir, 'model.weights.h5')) return model_NL, history_NL, path_to_model def get_MCNN_model(Y, gSig=5, n_channels=8, lr=1e-4, pct=10, r_factor=1.5,