torch edits for 2p + ring cnn
mannypaeza committed Aug 26, 2024
1 parent 3251217 commit e33642e
Showing 13 changed files with 7,542 additions and 119 deletions.
57 changes: 26 additions & 31 deletions caiman/components_evaluation.py
@@ -6,7 +6,7 @@
import numpy as np
import os
import peakutils
import tensorflow as tf
import torch
import scipy
from scipy.sparse import csc_matrix
from scipy.stats import norm
@@ -273,42 +273,37 @@ def evaluate_components_CNN(A,
if not isGPU and 'CAIMAN_ALLOW_GPU' not in os.environ:
print("GPU run not requested, disabling use of GPUs")
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
try:
os.environ["KERAS_BACKEND"] = "tensorflow"
from tensorflow.keras.models import model_from_json
use_keras = True
logger.info('Using Keras')
try:
os.environ["KERAS_BACKEND"] = "torch"
from keras.models import load_model
use_keras = True
logger.info('Using Keras')
except (ModuleNotFoundError):
use_keras = False
logger.info('Using Tensorflow')
use_keras = False
logger.info('Using Torch')
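The backend-selection pattern introduced here can be exercised on its own: set KERAS_BACKEND before the first keras import, and fall back to plain torch if Keras 3 is not available. A minimal standalone sketch of that logic, assuming PyTorch is installed (with or without Keras 3):

import os

os.environ["KERAS_BACKEND"] = "torch"        # must be set before keras is imported for the first time
try:
    from keras.models import load_model      # Keras 3 running on top of the torch backend
    backend = "keras"
except ModuleNotFoundError:
    import torch                              # plain PyTorch fallback
    backend = "torch"
print(f"classifier backend: {backend}")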

if loaded_model is None:
if use_keras:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".json")):
model_file = os.path.join(caiman_datadir(), model_name + ".json")
model_weights = os.path.join(caiman_datadir(), model_name + ".h5")
elif os.path.isfile(model_name + ".json"):
model_file = model_name + ".json"
model_weights = model_name + ".h5"
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".keras")):
model_file = os.path.join(caiman_datadir(), model_name + ".keras")
elif os.path.isfile(model_name + ".keras"):
model_file = model_name + ".keras"
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
with open(model_file, 'r') as json_file:
print(f"USING MODEL (keras API): {model_file}")
loaded_model_json = json_file.read()

loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(model_name + '.h5')

print(f"USING MODEL (keras API): {model_file}")
loaded_model = load_model(model_file)
else:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".h5.pb")):
model_file = os.path.join(caiman_datadir(), model_name + ".h5.pb")
elif os.path.isfile(model_name + ".h5.pb"):
model_file = model_name + ".h5.pb"
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".pt")):
model_file = os.path.join(caiman_datadir(), model_name + ".pt")
elif os.path.isfile(model_name + ".pt"):
model_file = model_name + ".pt"
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
print(f"USING MODEL (tensorflow API): {model_file}")
loaded_model = caiman.utils.utils.load_graph(model_file)
loaded_model = torch.load(model_file)

logger.debug("Loaded model from disk")
logger.debug("Loaded model from disk")

half_crop = np.minimum(gSig[0] * 4 + 1, patch_size), np.minimum(gSig[1] * 4 + 1, patch_size)
dims = np.array(dims)
@@ -323,11 +318,11 @@ if use_keras:
if use_keras:
predictions = loaded_model.predict(final_crops[:, :, :, np.newaxis], batch_size=32, verbose=1)
else:
tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_20_input:0')
tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
with tf.Session(graph=loaded_model) as sess:
predictions = sess.run(tf_out, feed_dict={tf_in: final_crops[:, :, :, np.newaxis]})
sess.close()
final_crops_t = torch.tensor(final_crops[:, np.newaxis, :, :], dtype=torch.float32)  # channels-first (NCHW) layout for the torch model
with torch.no_grad():
    predictions = loaded_model(final_crops_t).numpy()  # same (n_comps, n_classes) array as the Keras branch

return predictions, final_crops

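For reference, the torch branch above amounts to the following inference path: the (n_comps, patch, patch) crops gain a channel axis in channels-first position, the forward pass runs without gradient tracking, and the scores come back as the same numpy array the Keras branch produces. A minimal sketch, with run_cnn_on_crops as a hypothetical helper name and a stand-in model:

import numpy as np
import torch
import torch.nn as nn

def run_cnn_on_crops(model: nn.Module, crops: np.ndarray) -> np.ndarray:
    """crops: (n_comps, patch, patch) float array; returns (n_comps, n_classes) scores."""
    model.eval()                                                        # inference mode for dropout/batch-norm
    x = torch.tensor(crops[:, np.newaxis, :, :], dtype=torch.float32)  # NCHW
    with torch.no_grad():
        return model(x).numpy()

# stand-in classifier, only to make the sketch runnable
demo_model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(),
                           nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2))
scores = run_cnn_on_crops(demo_model, np.random.rand(5, 50, 50))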
74 changes: 42 additions & 32 deletions caiman/source_extraction/cnmf/online_cnmf.py
@@ -13,6 +13,9 @@
imaging data in real time. In Advances in Neural Information Processing Systems
(pp. 2381-2391).
@url http://papers.nips.cc/paper/6832-onacid-online-analysis-of-calcium-imaging-data-in-real-time
Implemented in PyTorch
Date: July 18, 2024
"""

import cv2
@@ -26,7 +29,7 @@
from scipy.stats import norm
from sklearn.decomposition import NMF
from sklearn.preprocessing import normalize
import tensorflow as tf
import torch
from time import time

import caiman
@@ -320,34 +323,30 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False:
loaded_model = None
self.params.set('online', {'sniper_mode': False})
self.tf_in = None
self.tf_out = None
# self.tf_in = None
# self.tf_out = None
self.use_torch = None  # no CNN classifier in use
else:
try:
from tensorflow.keras.models import model_from_json
logger.info('Using Keras')
try:
from keras.models import load_model
logger.info('Using Keras')
use_keras = True
except(ModuleNotFoundError):
use_keras = False
logger.info('Using Tensorflow')
use_keras = False
logger.info('Using Torch')

path = self.params.get('online', 'path_to_model').split(".")[:-1]
if use_keras:
path = self.params.get('online', 'path_to_model').split(".")[:-1]
json_path = ".".join(path + ["json"])
model_path = ".".join(path + ["h5"])
json_file = open(json_path, 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(model_path)
self.tf_in = None
self.tf_out = None
else:
path = self.params.get('online', 'path_to_model').split(".")[:-1]
model_path = '.'.join(path + ['h5', 'pb'])
# online (sniper) model: path_to_model must point at the matching file
model_path = ".".join(path + ["keras"])
loaded_model = load_model(model_path)
self.use_torch = False
else:
model_path = '.'.join(path + ['pt'])
loaded_model = load_graph(model_path)
self.tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_1_input:0')
self.tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
loaded_model = tf.Session(graph=loaded_model)
loaded_model = torch.load(model_path)
self.use_torch = True

self.loaded_model = loaded_model

if self.is1p:
@@ -548,7 +547,8 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
sniper_mode=self.params.get('online', 'sniper_mode'),
use_peak_max=self.params.get('online', 'use_peak_max'),
mean_buff=self.estimates.mean_buff,
tf_in=self.tf_in, tf_out=self.tf_out,
# tf_in=self.tf_in, tf_out=self.tf_out,
use_torch=self.use_torch,
ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
b0=self.estimates.b0 if self.is1p else None,
corr_img=self.estimates.corr_img if use_corr else None,
@@ -2003,7 +2003,8 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
patch_size=50, loaded_model=None, test_both=False,
thresh_CNN_noisy=0.5, use_peak_max=False,
thresh_std_peak_resid = 1, mean_buff=None,
tf_in=None, tf_out=None):
# tf_in=None, tf_out=None):
use_torch=None):
"""
Extract new candidate components from the residual buffer and test them
using space correlation or the CNN classifier. The function runs the CNN
@@ -2084,12 +2085,19 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
Ain2 /= np.std(Ain2,axis=1)[:,None]
Ain2 = np.reshape(Ain2,(-1,) + tuple(np.diff(ijSig_cnn).squeeze()),order= 'F')
Ain2 = np.stack([cv2.resize(ain,(patch_size ,patch_size)) for ain in Ain2])
if tf_in is None:
if not use_torch:
predictions = loaded_model.predict(Ain2[:,:,:,np.newaxis], batch_size=min_num_trial, verbose=0)
keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])
else:
predictions = loaded_model.run(tf_out, feed_dict={tf_in: Ain2[:, :, :, np.newaxis]})
keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])
cnn_pos = Ain2[keep_cnn]
final_crops = torch.tensor(Ain2[:, np.newaxis, :, :], dtype=torch.float32)  # channels-first (NCHW) crops
with torch.no_grad():
    predictions = loaded_model(final_crops).numpy()
keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])

cnn_pos = Ain2[keep_cnn]
else:
keep_cnn = [] # list(range(len(Ain_cnn)))
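Whichever backend produced the scores, the acceptance test is identical: keep the crops whose first-column (accept-class) score exceeds thresh_CNN_noisy. A tiny worked example on made-up predictions:

import numpy as np

predictions = np.array([[0.9, 0.1],
                        [0.2, 0.8],
                        [0.7, 0.3]])          # made-up CNN outputs for three candidate crops
thresh_CNN_noisy = 0.5
keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])   # indices 0 and 2 survive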

@@ -2138,7 +2146,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
corr_img=None, first_moment=None, second_moment=None,
crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
max_img=None, downscale_matrix=None, upscale_matrix=None,
tf_in=None, tf_out=None):
# tf_in=None, tf_out=None):
use_torch=None):
"""
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
"""
@@ -2168,7 +2177,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
tf_in=tf_in, tf_out=tf_out)
# tf_in=tf_in, tf_out=tf_out)
use_torch=use_torch)

ind_new_all = ijsig_all

49 changes: 49 additions & 0 deletions caiman/tests/test_pytorch.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python

import numpy as np
import os

from caiman.paths import caiman_datadir

try:
os.environ["KERAS_BACKEND"] = "torch"
from keras.models import load_model
use_keras = True
except(ModuleNotFoundError):
import torch
use_keras = False

def test_torch():
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

try:
model_name = os.path.join(caiman_datadir(), 'model', 'cnn_model')
if use_keras:
model_file = model_name + ".keras"
print('USING MODEL:' + model_file)

loaded_model = load_model(model_file)
loaded_model.compile('sgd', 'mse')
else:
model_file = model_name + ".pth"
loaded_model = torch.load(model_file)
except:
raise Exception(f'NN model could not be loaded. use_keras = {use_keras}')

A = np.random.randn(10, 50, 50, 1)
try:
if use_keras:
predictions = loaded_model.predict(A, batch_size=32)
else:
A = torch.tensor(A, dtype=torch.float32)
A = torch.reshape(A, (-1, A.shape[-1], A.shape[1], A.shape[2]))
with torch.no_grad():
predictions = loaded_model(A)
except:
raise Exception('NN model could not be deployed. use_keras = ' + str(use_keras))

if __name__ == "__main__":
test_torch()
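One layout detail worth flagging in this test: torch.reshape only lands on the correct channels-first tensor here because the channel dimension is 1; with more channels, the NHWC-to-NCHW conversion needs an axis permutation rather than a reshape. A short sketch of the general conversion:

import torch

A = torch.randn(10, 50, 50, 3)                    # NHWC batch with 3 channels
A_nchw = A.permute(0, 3, 1, 2).contiguous()       # (10, 3, 50, 50); a plain reshape would scramble the pixels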
7 changes: 7 additions & 0 deletions caiman/train/__init__.py
@@ -0,0 +1,7 @@
#!/usr/bin/env python
import pkg_resources

from caiman.train.train_cnn_model_keras import cnn_model_keras, save_model_keras, load_model_keras
from caiman.train.train_cnn_model_pytorch import cnn_model_pytorch, train_test_split, train, validate, get_batch_accuracy, save_model_pytorch, load_model_pytorch

__version__ = pkg_resources.get_distribution('caiman').version
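The torch-side save/load helpers exported here presumably wrap whole-module serialization; their exact signatures and the real cnn_model_pytorch architecture are not shown in this commit, so the round trip is sketched below with a stand-in network and a generic filename:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))  # stand-in classifier
torch.save(model, "demo_cnn.pt")                           # whole-module serialization, as consumed by torch.load(...) above
reloaded = torch.load("demo_cnn.pt", weights_only=False)   # weights_only=False is required for full-module pickles on newer PyTorch
reloaded.eval()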
