Merge pull request #45 from flatironinstitute/dev
v0.5.6
asistradition authored Aug 13, 2021
2 parents 1111e3d + 8e0c089 commit f1e9d36
Showing 26 changed files with 1,056 additions and 313 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -22,7 +22,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade pip wheel
python -m pip install -r requirements.txt
python -m pip install -r requirements-test.txt
python -m pip install -r requirements-multiprocessing.txt
2 changes: 1 addition & 1 deletion README.md
@@ -1,7 +1,7 @@
# Inferelator 3.0

[![PyPI version](https://badge.fury.io/py/inferelator.svg)](https://badge.fury.io/py/inferelator)
[![Travis](https://travis-ci.org/flatironinstitute/inferelator.svg?branch=release)](https://travis-ci.org/flatironinstitute/inferelator)
[![CI](https://github.com/flatironinstitute/inferelator/actions/workflows/python-package.yml/badge.svg)](https://github.com/flatironinstitute/inferelator/actions/workflows/python-package.yml/)
[![codecov](https://codecov.io/gh/flatironinstitute/inferelator/branch/release/graph/badge.svg)](https://codecov.io/gh/flatironinstitute/inferelator)
[![Documentation Status](https://readthedocs.org/projects/inferelator/badge/?version=latest)](https://inferelator.readthedocs.io/en/latest/?badge=latest)

72 changes: 62 additions & 10 deletions inferelator/amusr_workflow.py
@@ -4,6 +4,8 @@
import copy
import gc
import warnings
import pandas as pd
from inferelator import utils

from inferelator.utils import Debug
from inferelator import workflow
@@ -41,6 +43,9 @@ class MultitaskLearningWorkflow(single_cell_workflow.SingleCellWorkflow):
# Multi-task result processor
_result_processor_driver = ResultsProcessorMultiTask

# Prior noise taskwise flag
add_prior_noise_to_task_priors = True

@property
def _num_obs(self):
if self._task_objects is not None:
@@ -64,6 +69,13 @@ def _num_tfs(self):
else:
return None

@property
def _gene_names(self):
if self._task_objects is not None:
return set().union([t if t is not None else [] for t in map(lambda x: x.data.gene_names, self._task_objects)])
else:
return None
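# Note: this union of per-task gene names is used by _process_default_priors below to
# build an all-zero prior from gene names x tf_names when no prior matrix or gold
# standard has been supplied.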

def set_task_filters(self, regulator_expression_filter=None, target_expression_filter=None):
"""
Set the filtering criteria for regulators and targets between tasks
@@ -89,6 +101,11 @@ def startup_run(self):
"""

self.get_data()

# Set the random seed in the task to the same as the parent
for tobj in self._task_objects:
tobj.random_seed = self.random_seed

self.validate_data()

def get_data(self):
@@ -196,9 +213,6 @@ def _load_tasks(self):
except AttributeError:
pass

# Set the random seed in the task to the same as the parent
tobj.random_seed = self.random_seed

# Set the num_bootstraps in the task to the same as the parent
tobj.num_bootstraps = self.num_bootstraps

@@ -218,8 +232,8 @@ def validate_data(self):
:raises ValueError: Raises a ValueError if any tasks have invalid priors or gold standard structures
"""
if self.gold_standard is None:
raise ValueError("A gold standard must be provided to `gold_standard_file` in MultiTaskLearningWorkflow")

super().validate_data(check_prior=False)

# Check to see if there are any tasks which don't have priors
no_priors = sum(map(lambda x: x.priors_data is None, self._task_objects))
@@ -231,7 +245,21 @@ def _process_default_priors(self):
Process the default priors in the parent workflow for crossvalidation or shuffling
"""

priors = self.priors_data if self.priors_data is not None else self.gold_standard.copy()
# Use priors if given to the MTL workflow
if self.priors_data is not None:
priors = self.priors_data

# If they all have priors don't worry about it - use a 0 prior here for crossvalidation selection if needed
elif self.priors_data is None and self.gold_standard is not None:
priors = pd.DataFrame(0, index=self.gold_standard.index, columns=self.gold_standard.columns)

elif self.priors_data is None and self.tf_names is not None:
priors = pd.DataFrame(0, index=self._gene_names, columns=self.tf_names)

# If there's no gold standard or use_no_prior isn't set, raise a RuntimeError
else:
_msg = "No base prior or gold standard or TF list has been provided."
raise RuntimeError(_msg)

# Crossvalidation
if self.split_gold_standard_for_crossvalidation:
@@ -251,6 +279,19 @@ def _process_default_priors(self):
if self.shuffle_prior_axis is not None:
priors = self.prior_manager.shuffle_priors(priors, self.shuffle_prior_axis, self.random_seed)

# Add prior noise now (to the base prior) if add_prior_noise_to_task_priors is False
# Otherwise add later to the task priors (will be different for each task)
if self.add_prior_noise is not None and not self.add_prior_noise_to_task_priors:
priors = self.prior_manager.add_prior_noise(priors, self.add_prior_noise, self.random_seed)

_has_prior = [t.priors_data is not None for t in self._task_objects]
if any(_has_prior):
_msg = "Overriding task priors in {tn} because add_prior_noise_to_task_priors is False"
utils.Debug.vprint(_msg.format(tn=_has_prior), level=0)

for t in self._task_objects:
t.priors_data = priors.copy()

# Reset the priors_data in the parent workflow if it exists
self.priors_data = priors if self.priors_data is not None else None

@@ -263,22 +304,25 @@ def _process_task_priors(self):
# Set priors if task-specific priors are not present
if task_obj.priors_data is None and self.priors_data is None:
raise ValueError("No priors exist in the main workflow or in tasks")

elif task_obj.priors_data is None:
task_obj.priors_data = self.priors_data.copy()

# Set gene names if task-specific gene names is not present
if task_obj.gene_names is None:
task_obj.gene_names = copy.copy(self.gene_names)
task_obj.gene_names = copy.deepcopy(self.gene_names)

# Set tf_names if task-specific tf names are not present
if task_obj.tf_names is None:
task_obj.tf_names = copy.copy(self.tf_names)
task_obj.tf_names = copy.deepcopy(self.tf_names)

_add_prior_noise = self.add_prior_noise if self.add_prior_noise_to_task_priors is True else None
# Process priors in the task data
task_obj.process_priors_and_gold_standard(gold_standard=self.gold_standard,
cv_flag=self.split_gold_standard_for_crossvalidation,
cv_axis=self.cv_split_axis,
shuffle_priors=self.shuffle_prior_axis)
shuffle_priors=self.shuffle_prior_axis,
add_prior_noise=_add_prior_noise)

def _process_task_data(self):
"""
@@ -437,15 +481,19 @@ def set_run_parameters(self):

warnings.warn("Task-specific `num_bootstraps` and `random_seed` is not supported. Set on parent workflow.")

def process_priors_and_gold_standard(self, gold_standard=None, cv_flag=None, cv_axis=None, shuffle_priors=None):
def process_priors_and_gold_standard(self, gold_standard=None, cv_flag=None, cv_axis=None, shuffle_priors=None,
add_prior_noise=None):
"""
Make sure that the priors for this task are correct
This will remove circularity from the task priors based on the parent gold standard
"""

gold_standard = self.gold_standard if gold_standard is None else gold_standard
cv_flag = self.split_gold_standard_for_crossvalidation if cv_flag is None else cv_flag
cv_axis = self.cv_split_axis if cv_axis is None else cv_axis
shuffle_priors = self.shuffle_prior_axis if shuffle_priors is None else shuffle_priors
add_prior_noise = self.add_prior_noise if add_prior_noise is None else add_prior_noise

# Remove circularity from the gold standard
if cv_flag:
@@ -462,6 +510,10 @@ def process_priors_and_gold_standard(self, gold_standard=None, cv_flag=None, cv_
if shuffle_priors is not None:
self.priors_data = self.prior_manager.shuffle_priors(self.priors_data, shuffle_priors, self.random_seed)

if add_prior_noise is not None:
self.priors_data = self.prior_manager.add_prior_noise(self.priors_data, add_prior_noise,
self.random_seed)

if min(self.priors_data.shape) == 0:
raise ValueError("Priors for task {n} have an axis of length 0".format(n=self.task_name))

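For orientation, a minimal sketch of how the new add_prior_noise_to_task_priors flag is meant to be used; the direct construction of MultitaskLearningWorkflow and the example noise value are illustrative assumptions, and only the attribute names and the two code paths appear in the diff above.

# Illustrative sketch, not part of this commit
from inferelator.amusr_workflow import MultitaskLearningWorkflow

worker = MultitaskLearningWorkflow()   # normally built through the package's workflow factory
worker.add_prior_noise = 0.1           # value passed through to prior_manager.add_prior_noise

# Default (True): noise is added per task inside process_priors_and_gold_standard,
# so each task receives an independently noised copy of its prior.
worker.add_prior_noise_to_task_priors = True

# False: noise is added once to the base prior in _process_default_priors, and that
# same noised prior then overrides every task's prior.
worker.add_prior_noise_to_task_priors = False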
Empty file.
153 changes: 153 additions & 0 deletions inferelator/benchmarking/celloracle.py
@@ -0,0 +1,153 @@
import gc
import copy

from inferelator import utils
from inferelator.single_cell_workflow import SingleCellWorkflow
from inferelator.regression.base_regression import _RegressionWorkflowMixin

import numpy as np

# These are required to run this module but nothing else
# They are therefore not package dependencies
import scanpy as sc
import celloracle as co


class CellOracleWorkflow(SingleCellWorkflow):

oracle = None

def startup_finish(self):
"""
Skip inferelator preprocessing and do celloracle preprocessing
As per https://github.com/morris-lab/CellOracle/issues/58
"""

self.align_priors_and_expression()

self.data.convert_to_float()

adata = self.data._adata

if "paga" not in adata.uns:
utils.Debug.vprint("Normalizing data {sh}".format(sh=adata.shape))

sc.pp.filter_genes(adata, min_counts=1)
sc.pp.normalize_per_cell(adata, key_n_counts='n_counts_all')

adata.raw = adata
adata.layers["raw_count"] = adata.raw.X.copy()

utils.Debug.vprint("Scaling data")

sc.pp.log1p(adata)
sc.pp.scale(adata)

utils.Debug.vprint("PCA Preprocessing")

sc.tl.pca(adata, svd_solver='arpack')

utils.Debug.vprint("Diffmap Preprocessing")

sc.pp.neighbors(adata, n_neighbors=4, n_pcs=20)
sc.tl.diffmap(adata)
sc.pp.neighbors(adata, n_neighbors=10, use_rep='X_diffmap')

utils.Debug.vprint("Clustering Preprocessing")

sc.tl.louvain(adata, resolution=0.8)
sc.tl.paga(adata, groups='louvain')
sc.pl.paga(adata)
sc.tl.draw_graph(adata, init_pos='paga', random_state=123)

# Restore counts
adata.X = adata.layers["raw_count"].copy()
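# Note: X is reset to raw counts because import_anndata_as_raw_count (used below in
# run_regression) expects raw counts; the scanpy steps above exist only to produce
# the louvain clustering and embeddings that CellOracle uses later.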

else:
# Assume all the preprocessing is done and just move along

utils.Debug.vprint("Using saved preprocessing for CellOracle")


@staticmethod
def reprocess_prior_to_base_GRN(priors_data):

base_GRN = priors_data.copy()
base_GRN.index.name = "Target"
base_GRN = base_GRN.melt(ignore_index=False, var_name="Regulator").reset_index()
base_GRN = base_GRN.loc[base_GRN['value'] != 0, :].copy()
base_GRN.drop("value", axis=1, inplace=True)
return {k: v["Regulator"].tolist() for k, v in base_GRN.groupby("Target")}
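# Worked example: for a prior matrix
#          TF1  TF2
#   GeneA    1    0
#   GeneB    1   -1
# this returns {"GeneA": ["TF1"], "GeneB": ["TF1", "TF2"]} - one regulator list per target.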


@staticmethod
def reprocess_co_output_to_inferelator_results(co_out):

betas = [r.pivot(index='target', columns='source', values='coef_mean').fillna(0) for k, r in co_out.items()]
rankers = [r.pivot(index='target', columns='source', values='-logp').fillna(0) for k, r in co_out.items()]

return betas, rankers
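# Note: co_out is CellOracle's per-cluster links dictionary; each value is a long-format
# DataFrame with source, target, coef_mean and -logp columns, pivoted here into a
# target x regulator matrix, giving one betas and one ranking matrix per cluster-level GRN.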


class CellOracleRegression(_RegressionWorkflowMixin):

oracle_imputation = True

def run_regression(self):

utils.Debug.vprint("Creating Oracle Object")

# Set up oracle object
oracle = co.Oracle()
oracle.import_anndata_as_raw_count(adata=self.data._adata,
cluster_column_name="louvain",
embedding_name="X_pca")

# Apparently PCA is not transferred from the adata object
oracle.perform_PCA(100)

# Add prior
oracle.addTFinfo_dictionary(self.reprocess_prior_to_base_GRN(self.priors_data))

utils.Debug.vprint("Imputation Preprocessing")

if self.oracle_imputation:

# Heuristics from Celloracle documentation
n_comps = np.where(np.diff(np.diff(np.cumsum(oracle.pca.explained_variance_ratio_))>0.002))[0][0]
k = int(0.025 * oracle.adata.shape[0])
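# Note: np.diff(np.cumsum(explained_variance_ratio_)) is the per-component gain in
# explained variance, so n_comps is (roughly) the first PC whose incremental gain drops
# to 0.002 or below; k is ~2.5% of the number of cells, per CellOracle's tutorial heuristics.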

# Make sure n_comps is between 10 and 50
# It likes to go to 0 for noise controls
n_comps = max(min(n_comps, 50), 10)

# Make sure k is at least 25 too I guess
k = max(k, 25)

oracle.knn_imputation(n_pca_dims=n_comps, k=k, balanced=True, b_sight=k*8,
b_maxl=k*4, n_jobs=4)

# Pretend to do imputation
else:
oracle.adata.layers["imputed_count"] = oracle.adata.layers["normalized_count"].copy()

utils.Debug.vprint("CellOracle GRN inference")

# Call GRN inference
links = oracle.get_links(cluster_name_for_GRN_unit="louvain", alpha=10,
verbose_level=0, test_mode=False)

# Deepcopy the result dict that we want
result = copy.deepcopy(links.links_dict)

# Try to clean up some of these circular references
del links
del oracle
del self.data._adata
del self.data

# Call an explicit GC cycle
gc.collect()

return self.reprocess_co_output_to_inferelator_results(result)
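To show how the two classes above compose, a minimal sketch follows; the runner class name, the mixin ordering, and the commented-out calls mirror the usual inferelator pattern of pairing a _RegressionWorkflowMixin with a workflow, and are assumptions rather than part of this commit.

# Illustrative sketch, not part of this commit
from inferelator.benchmarking.celloracle import CellOracleWorkflow, CellOracleRegression

class CellOracleRunner(CellOracleRegression, CellOracleWorkflow):
    # Regression mixin listed first so its run_regression() is the one the workflow calls
    pass

worker = CellOracleRunner()
# ... set expression data, priors and gold standard as for any other workflow ...
# network_result = worker.run()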