diff --git a/_freeze/posts/elastic-metric/osteosarcoma_analysis/execute-results/html.json b/_freeze/posts/elastic-metric/osteosarcoma_analysis/execute-results/html.json index 728f801..a89cba2 100644 --- a/_freeze/posts/elastic-metric/osteosarcoma_analysis/execute-results/html.json +++ b/_freeze/posts/elastic-metric/osteosarcoma_analysis/execute-results/html.json @@ -1,8 +1,8 @@ { - "hash": "2ae8a1628c2596fc937d8be63db017bb", + "hash": "858362269bc7f2063e87e4cc3fd076b4", "result": { "engine": "jupyter", - "markdown": "---\ntitle: Shape Analysis of Cancer Cells\nauthor: Wanxin Li\ndate: \"August 15, 2024\"\ncategories: [biology, bioinformatics] \n---\n\n\nThis notebook is adapted from [this notebook](https://github.com/geomstats/geomstats/blob/main/notebooks/11_real_world_applications__cell_shapes_analysis.ipynb) (Lead author: Nina Miolane). \n\nThis notebook studies *Osteosarcoma* (bone cancer) cells and the impact of drug treatment on their *morphological shapes*, by analyzing cell images obtained from fluorescence microscopy. \n\nThis analysis relies on the *elastic metric between discrete curves* from Geomstats. We will study to which extent this metric can detect how the cell shape is associated with the response to treatment.\n\nThis notebook is adapted from Florent Michel's submission to the [ICLR 2021 Computational Geometry and Topology challenge](https://github.com/geomstats/challenge-iclr-2021).\n\n
\n \n
\n\nFigure 1: Representative images of the cell lines using fluorescence microscopy, studied in this notebook (Image credit : Ashok Prasad). The cells nuclei (blue), the actin cytoskeleton (green) and the lipid membrane (red) of each cell are stained and colored. We only focus on the cell shape in our analysis.\n\n# 1. Introduction and Motivation\n\nBiological cells adopt a variety of shapes, determined by multiple processes and biophysical forces under the control of the cell. These shapes can be studied with different quantitative measures that reflect the cellular morphology [(MGCKCKDDRTWSBCC2018)](#References). With the emergence of large-scale biological cell image data, morphological studies have many applications. For example, measures of irregularity and spreading of cells allow accurate classification and discrimination between cancer cell lines treated with different drugs [(AXCFP2019)](#References).\n\nAs metrics defined on the shape space of curves, the *elastic metrics* [(SKJJ2010)](#References) implemented in Geomstats are a potential tool for analyzing and comparing biological cell shapes. 
Their associated geodesics and geodesic distances provide a natural framework for optimally matching, deforming, and comparing cell shapes.\n\n::: {#42291cac .cell execution_count=1}\n``` {.python .cell-code}\nfrom decimal import Decimal\nimport matplotlib.pyplot as plt\n\nimport geomstats.backend as gs\nimport numpy as np\nfrom common import *\nimport random\nimport os\nimport scipy.stats as stats\nfrom sklearn import manifold\n\ngs.random.seed(2021)\n```\n:::\n\n\n::: {#139743cd .cell execution_count=2}\n``` {.python .cell-code}\nbase_path = \"/home/wanxinli/dyn/dyn/\"\ndata_path = os.path.join(base_path, \"datasets\")\n\ndataset_name = 'osteosarcoma'\nfigs_dir = os.path.join(\"/home/wanxinli/dyn/dyn/figs\", dataset_name)\nsavefig = False\n\n# When running for the first time, we need to compute pairwise distances and run DeCOr-MDS\n# Otherwise, we can just use the pre-computed results\nfirst_time = False\nif savefig:\n    print(f\"Will save figs to {figs_dir}\")\n```\n:::\n\n\n# 2. Dataset Description\n\nWe study a dataset of mouse *Osteosarcoma* imaged cells [(AXCFP2019)](#References). The dataset contains two different cancer cell lines: *DLM8* and *DUNN*, respectively representing a more aggressive and a less aggressive cancer. Among these cells, some have also been treated with different single drugs that perturb the cellular cytoskeleton. Overall, we can label each cell according to its cell line (*DLM8* or *DUNN*), and also according to whether it is a *control* cell (no treatment) or has been treated with one of the following drugs: *Jasp* (jasplakinolide) or *Cytd* (cytochalasin D).\n\nEach cell comes from a raw image containing a set of cells, which was thresholded to generate binarized images.\n\n\n \n\n\nAfter binarizing the images, contouring was used to isolate each cell, and to extract its boundary as a counter-clockwise ordered list of 2D coordinates, which corresponds to the representation of a discrete curve in Geomstats. 
We load these discrete curves into the notebook.\n\n::: {#9146ad00 .cell execution_count=3}\n``` {.python .cell-code}\nimport geomstats.datasets.utils as data_utils\n\ncells, lines, treatments = data_utils.load_cells()\nprint(f\"Total number of cells : {len(cells)}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nTotal number of cells : 650\n```\n:::\n:::\n\n\nThe cells are grouped by treatment class in the dataset : \n- the *control* cells, \n- the cells treated with *Cytd*,\n- and the ones treated with *Jasp*. \n\nAdditionally, in each of these classes, there are two cell lines : \n- the *DLM8* cells, and\n- the *DUNN* ones.\n\nBefore using the dataset, we check for duplicates in the dataset.\n\nWe compute the pairwise distance between two cells. If the pairwise distance is smaller than 0.1, we visualize the corresponding cells to check they are duplicates.\n\n::: {#7b0beca0 .cell execution_count=4}\n``` {.python .cell-code}\ntol = 1e-1\nfor i, cell_i in enumerate(cells):\n for j, cell_j in enumerate(cells):\n if i != j and cell_i.shape[0] == cell_j.shape[0]:\n dist = np.sum(np.sqrt(np.sum((cell_i-cell_j)**2,axis=1)))\n if dist < tol:\n print(f\"cell indices are: {i} and {j}, {lines[i]}, {lines[j]}, {treatments[i]}, {treatments[j]}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncell indices are: 363 and 396, dlm8, dlm8, cytd, cytd\ncell indices are: 396 and 363, dlm8, dlm8, cytd, cytd\ncell indices are: 513 and 519, dlm8, dlm8, jasp, jasp\ncell indices are: 519 and 513, dlm8, dlm8, jasp, jasp\n```\n:::\n:::\n\n\n::: {#eff201bf .cell execution_count=5}\n``` {.python .cell-code}\npair_indices = [363, 396]\n\nfig = plt.figure(figsize=(10, 5))\nfig.add_subplot(121)\nindex_0 = pair_indices[0]\nplt.scatter(cells[index_0][:, 0], cells[index_0][:, 1], s=4)\nplt.axis(\"equal\")\nplt.title(f\"Cell {index_0}\")\n\nfig.add_subplot(122)\nindex_1 = pair_indices[1]\nplt.scatter(cells[index_1][:, 0], cells[index_1][:, 1], 
s=4)\nplt.axis(\"equal\")\nplt.title(f\"Cell {index_1}\")\n```\n\n::: {.cell-output .cell-output-display execution_count=84}\n```\nText(0.5, 1.0, 'Cell 396')\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-6-output-2.png){width=809 height=431}\n:::\n:::\n\n\n::: {#16ba3131 .cell execution_count=6}\n``` {.python .cell-code}\npair_indices = [513, 519]\n\nfig = plt.figure(figsize=(10, 5))\nfig.add_subplot(121)\nindex_0 = pair_indices[0]\nplt.scatter(cells[index_0][:, 0], cells[index_0][:, 1], s=4)\nplt.axis(\"equal\")\nplt.title(f\"Cell {index_0}\")\n\nfig.add_subplot(122)\nindex_1 = pair_indices[1]\nplt.scatter(cells[index_1][:, 0], cells[index_1][:, 1], s=4)\nplt.axis(\"equal\")\nplt.title(f\"Cell {index_1}\")\n```\n\n::: {.cell-output .cell-output-display execution_count=85}\n```\nText(0.5, 1.0, 'Cell 519')\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-7-output-2.png){width=809 height=431}\n:::\n:::\n\n\nCheck the category indices in order to remove corresponding cells in `ds_align`\n\n::: {#cd8116b8 .cell execution_count=7}\n``` {.python .cell-code}\ndelete_indices = [363, 396, 513, 519]\ncategory_count = {}\nglobal_count = 0\nfor i in range(len(cells)):\n treatment = treatments[i]\n line = lines[i]\n if treatment not in category_count:\n category_count[treatment] = {}\n if line not in category_count[treatment]:\n category_count[treatment][line] = 0\n # if global_count in delete_indices:\n # print(treatment, line, category_count[treatment][line])\n category_count[treatment][line] += 1\n global_count += 1\n```\n:::\n\n\nSince 363th, 396th and 513th, 519th are duplicates of each other and after visualization we see they are poor quality cells with overlapping adjacent cells, we remove them from our dataset. 
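
The removal step that follows relies on `del_arr_elements`, which is imported from `common` and whose source is not shown in this notebook. As a rough sketch only, and assuming that helper simply drops the listed positions from a list of arrays with differing numbers of boundary points, the idea could look like the code below. `drop_indices`, `cells_demo` and `labels_demo` are hypothetical names used for illustration; they are not part of `common` or Geomstats.

```python
import numpy as np


def drop_indices(items, indices):
    """Return a new list without the entries at the given positions.

    Hypothetical stand-in for a helper such as `del_arr_elements`;
    the real implementation in `common` may differ.
    """
    to_drop = set(indices)
    return [item for i, item in enumerate(items) if i not in to_drop]


# Toy data: four "cells" with different numbers of boundary points.
cells_demo = [np.zeros((n, 2)) for n in (3, 4, 5, 6)]
labels_demo = ["a", "b", "c", "d"]

# Drop the same positions from the cells and from their labels,
# so the two lists stay aligned.
kept_cells = drop_indices(cells_demo, [1, 3])
kept_labels = drop_indices(labels_demo, [1, 3])
```

Building a new list instead of deleting in place avoids the index shifting that otherwise makes it necessary to delete in reverse order.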
\n\n::: {#acb715d8 .cell execution_count=8}\n``` {.python .cell-code}\ndef remove_cells(cells, lines, treatments, delete_indices):\n \"\"\" \n Remove cells of control group from cells, lines and treatments\n\n :param list[int] delete_indices: the indices to delete\n \"\"\"\n delete_indices = sorted(delete_indices, reverse=True) # to prevent change in index when deleting elements\n \n # Delete elements\n cells = del_arr_elements(cells, delete_indices)\n lines = list(np.delete(np.array(lines), delete_indices, axis=0))\n treatments = list(np.delete(np.array(treatments), delete_indices, axis=0))\n\n return cells, lines, treatments\n```\n:::\n\n\n::: {#368f2a21 .cell execution_count=9}\n``` {.python .cell-code}\ndelete_indices = [363, 396, 513, 519]\ncells, lines, treatments = remove_cells(cells, lines, treatments, delete_indices)\n# print(len(cells), len(lines), len(treatments))\n```\n:::\n\n\nThis is shown by displaying the unique elements in the lists `treatments` and `lines`:\n\n::: {#b0f7461c .cell execution_count=10}\n``` {.python .cell-code}\nimport pandas as pd\n\nTREATMENTS = gs.unique(treatments)\nprint(TREATMENTS)\nLINES = gs.unique(lines)\nprint(LINES)\nMETRICS = ['SRV', 'Linear']\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n['control' 'cytd' 'jasp']\n['dlm8' 'dunn']\n```\n:::\n:::\n\n\nThe size of each class is displayed below:\n\n::: {#10f8e0fd .cell execution_count=11}\n``` {.python .cell-code}\nds = {}\n\nn_cells_arr = gs.zeros((3, 2))\n\nfor i, treatment in enumerate(TREATMENTS):\n print(f\"{treatment} :\")\n ds[treatment] = {}\n for j, line in enumerate(LINES):\n to_keep = gs.array(\n [\n one_treatment == treatment and one_line == line\n for one_treatment, one_line in zip(treatments, lines)\n ]\n )\n ds[treatment][line] = [\n cell_i for cell_i, to_keep_i in zip(cells, to_keep) if to_keep_i\n ]\n nb = len(ds[treatment][line])\n print(f\"\\t {nb} {line}\")\n n_cells_arr[i, j] = nb\n\nn_cells_df = pd.DataFrame({\"dlm8\": n_cells_arr[:, 0], 
\"dunn\": n_cells_arr[:, 1]})\nn_cells_df = n_cells_df.set_index(TREATMENTS)\n\ndisplay(n_cells_df)\n# display(ds)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncontrol :\n\t 114 dlm8\n\t 204 dunn\ncytd :\n\t 80 dlm8\n\t 93 dunn\njasp :\n\t 60 dlm8\n\t 95 dunn\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n
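<table>\n  <thead>\n    <tr><th></th><th>dlm8</th><th>dunn</th></tr>\n  </thead>\n  <tbody>\n    <tr><th>control</th><td>114.0</td><td>204.0</td></tr>\n    <tr><th>cytd</th><td>80.0</td><td>93.0</td></tr>\n    <tr><th>jasp</th><td>60.0</td><td>95.0</td></tr>\n  </tbody>\n</table>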
\n```\n:::\n:::\n\n\nWe have organized the cell data into the dictionary `ds`. Before proceeding to the actual data analysis, we provide an auxiliary function `apply_func_to_ds`.\n\n::: {#6d7e68a1 .cell execution_count=12}\n``` {.python .cell-code}\ndef apply_func_to_ds(input_ds, func):\n    \"\"\"Apply the input function func to the input dictionary input_ds.\n\n    This function goes through the dictionary structure and applies\n    func to every cell in input_ds[treatment][line].\n\n    It stores the result in a dictionary output_ds that is returned\n    to the user.\n\n    Parameters\n    ----------\n    input_ds : dict\n        Input dictionary, with keys treatment-line.\n    func : callable\n        Function to be applied to the values of the dictionary, i.e.\n        the cells.\n\n    Returns\n    -------\n    output_ds : dict\n        Output dictionary, with the same keys as input_ds.\n    \"\"\"\n    output_ds = {}\n    for treatment in TREATMENTS:\n        output_ds[treatment] = {}\n        for line in LINES:\n            output_list = []\n            for one_cell in input_ds[treatment][line]:\n                output_list.append(func(one_cell))\n            output_ds[treatment][line] = gs.array(output_list)\n    return output_ds\n```\n:::\n\n\nNow we can move on to the actual data analysis, starting with a preprocessing of the cell boundaries.\n\n# 3. 
Preprocessing \n\n### Interpolation: Encoding Discrete Curves With Same Number of Points\n\nAs we need discrete curves with the same number of sampled points to compute pairwise distances, the following interpolation is applied to each curve, after setting the number of sampling points.\n\nTo set up the number of sampling points, you can edit the following line in the next cell:\n\n::: {#eb5c4847 .cell execution_count=13}\n``` {.python .cell-code}\ndef interpolate(curve, nb_points):\n \"\"\"Interpolate a discrete curve with nb_points from a discrete curve.\n\n Returns\n -------\n interpolation : discrete curve with nb_points points\n \"\"\"\n old_length = curve.shape[0]\n interpolation = gs.zeros((nb_points, 2))\n incr = old_length / nb_points\n pos = 0\n for i in range(nb_points):\n index = int(gs.floor(pos))\n interpolation[i] = curve[index] + (pos - index) * (\n curve[(index + 1) % old_length] - curve[index]\n )\n pos += incr\n return interpolation\n\n\nk_sampling_points = 2000\n```\n:::\n\n\nTo illustrate the result of this interpolation, we compare for a randomly chosen cell the original curve with the correponding interpolated one (to visualize another cell, you can simply re-run the code).\n\n::: {#f783b7be .cell execution_count=14}\n``` {.python .cell-code}\nindex = 0\ncell_rand = cells[index]\ncell_interpolation = interpolate(cell_rand, k_sampling_points)\n\nfig = plt.figure(figsize=(15, 5))\n\nfig.add_subplot(121)\nplt.scatter(cell_rand[:, 0], cell_rand[:, 1], color='black', s=4)\n\nplt.plot(cell_rand[:, 0], cell_rand[:, 1])\nplt.axis(\"equal\")\nplt.title(f\"Original curve ({len(cell_rand)} points)\")\nplt.axis(\"off\")\n\nfig.add_subplot(122)\nplt.scatter(cell_interpolation[:, 0], cell_interpolation[:, 1], color='black', s=4)\n\nplt.plot(cell_interpolation[:, 0], cell_interpolation[:, 1])\nplt.axis(\"equal\")\nplt.title(f\"Interpolated curve ({k_sampling_points} points)\")\nplt.axis(\"off\")\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, 
\"interpolation.svg\"))\n plt.savefig(os.path.join(figs_dir, \"interpolation.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-15-output-1.png){width=1135 height=409}\n:::\n:::\n\n\nAs the interpolation is working as expected, we use the auxiliary function `apply_func_to_ds` to apply the function `func=interpolate` to the dataset `ds`, i.e. the dictionnary containing the cells boundaries.\n\nWe obtain a new dictionnary, `ds_interp`, with the interpolated cell boundaries.\n\n::: {#69915936 .cell execution_count=15}\n``` {.python .cell-code}\nds_interp = apply_func_to_ds(\n input_ds=ds, func=lambda x: interpolate(x, k_sampling_points)\n)\n```\n:::\n\n\nThe shape of an array of cells in `ds_interp[treatment][cell]` is therefore: `(\"number of cells in treatment-line\", \"number of sampling points\", 2)`, where 2 refers to the fact that we are considering cell shapes in 2D. \n\n### Visualization of Interpolated Dataset of Curves\n\nWe visualize the curves obtained, for a sample of control cells and treated cells (top row shows control, i.e. 
non-treated cells; bottom rows show treated cells) across cell lines (left and blue for dlm8 and right and orange for dunn).\n\n::: {#5bc7bb7f .cell execution_count=16}\n``` {.python .cell-code}\nn_cells_to_plot = 5\n# radius = 800\n\nfig = plt.figure(figsize=(16, 6))\ncount = 1\nfor i, treatment in enumerate(TREATMENTS):\n    for line in LINES:\n        cell_data = ds_interp[treatment][line]\n        for i_to_plot in range(n_cells_to_plot):\n            cell = gs.random.choice(cell_data)\n            fig.add_subplot(3, 2 * n_cells_to_plot, count)\n            count += 1\n            plt.plot(cell[:, 0], cell[:, 1], color=\"C\" + str(i))\n            # plt.xlim(-radius, radius)\n            # plt.ylim(-radius, radius)\n            plt.axis(\"equal\")\n            plt.axis(\"off\")\n            if i_to_plot == n_cells_to_plot // 2:\n                plt.title(f\"{treatment} - {line}\", fontsize=20)\n\nif savefig:\n    plt.savefig(os.path.join(figs_dir, \"sample_cells.svg\"))\n    plt.savefig(os.path.join(figs_dir, \"sample_cells.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-17-output-1.png){width=1210 height=491}\n:::\n:::\n\n\nVisual inspection of these curves seems to indicate that more protrusions appear in treated cells than in control ones. This is in agreement with the physiological impact of the drugs, which are known to perturb the internal cytoskeleton connected to the cell membrane. Using the elastic metric, our goal will be to see if we can quantitatively confirm these differences.\n\n### Remove duplicate sample points in curves\n\nDuring interpolation, some of the discrete curves in the dataset are likely to be downsampled from a higher to a lower number of points. Two sampled points that are close enough may therefore end up overlapping after interpolation, and such points have to be dealt with specifically. 
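
To make the issue concrete, the sketch below flags consecutive points that coincide up to a tolerance. It assumes a curve stored as a NumPy array of shape `(n_points, 2)`; `consecutive_duplicate_mask` and the toy curve are hypothetical illustrations and are not the notebook's `preprocess` function, which goes further and replaces such points.

```python
import numpy as np


def consecutive_duplicate_mask(curve, tol=1e-10):
    """Mark points that lie within `tol` of the point just before them.

    Hypothetical helper for illustration; the first point is never flagged
    because it has no predecessor in the array.
    """
    # Euclidean length of each step between successive sample points.
    steps = np.linalg.norm(np.diff(curve, axis=0), axis=1)
    mask = np.zeros(len(curve), dtype=bool)
    mask[1:] = steps < tol
    return mask


# Toy curve in which the third point repeats the second one.
curve = np.array(
    [[0.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
)
mask = consecutive_duplicate_mask(curve)
n_duplicates = int(mask.sum())
```

A zero-length step of this kind is a problem for velocity-based representations such as the SRV transform, which is why these points are handled before alignment.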
\n\n::: {#4e405157 .cell execution_count=17}\n``` {.python .cell-code}\nimport numpy as np\n\ndef preprocess(curve, tol=1e-10):\n \"\"\"Preprocess curve to ensure that there are no consecutive duplicate points.\n\n Returns\n -------\n curve : discrete curve\n \"\"\"\n\n dist = curve[1:] - curve[:-1]\n dist_norm = np.sqrt(np.sum(np.square(dist), axis=1))\n\n if np.any( dist_norm < tol ):\n for i in range(len(curve)-1):\n if np.sqrt(np.sum(np.square(curve[i+1] - curve[i]), axis=0)) < tol:\n curve[i+1] = (curve[i] + curve[i+2]) / 2\n\n return curve\n```\n:::\n\n\n::: {#4e7c32ec .cell execution_count=18}\n``` {.python .cell-code}\nds_proc = apply_func_to_ds(ds_interp, func=lambda x: preprocess(x))\n```\n:::\n\n\nCheck we did not loss any cells after duplicates\n\n::: {#790de69a .cell execution_count=19}\n``` {.python .cell-code}\nfor treatment in TREATMENTS:\n for line in LINES:\n for metric in METRICS:\n print(f\"{treatment} and {line} using {metric}: {len(ds_proc[treatment][line])}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncontrol and dlm8 using SRV: 114\ncontrol and dlm8 using Linear: 114\ncontrol and dunn using SRV: 204\ncontrol and dunn using Linear: 204\ncytd and dlm8 using SRV: 80\ncytd and dlm8 using Linear: 80\ncytd and dunn using SRV: 93\ncytd and dunn using Linear: 93\njasp and dlm8 using SRV: 60\njasp and dlm8 using Linear: 60\njasp and dunn using SRV: 95\njasp and dunn using Linear: 95\n```\n:::\n:::\n\n\n### Alignment\n\nOur goal is to study the cell boundaries in our dataset, as points in a shape space of closed curves quotiented by translation, scaling, and rotation, so these transformations do not affect our measure of distance between curves.\n\nIn practice, we apply functions that were initially designed to center (substract the barycenter), rescale (divide by the Frobenius norm) and then reparameterize (only for SRV metric).\n\nSince the alignment procedure takes 30 minutes, we ran `osteosarocoma_align.py` and saved the results in 
`~/dyn/datasets/osteosarcoma/aligned`\n\nLoad aligned cells from txt files. These files were generated by calling `align` function in `common.py`.\n\nWe get the aligned cells from preprocessed dataset.\n\nFurthermore, we align the barycenters of the cells to the barycenter of the projected base curve, and (optionally) flip the cell.\n\n::: {#98a18697 .cell execution_count=20}\n``` {.python .cell-code}\ndef align_barycenter(cell, centroid_x, centroid_y, flip):\n \"\"\" \n Align the the barycenter of the cell to ref centeriod and flip the cell against the xaxis of the centriod if flip is True. \n\n :param 2D np array cell: cell to align\n :param float centroid_x: the x coordinates of the projected BASE_CURVE\n :param float centroid_y: the y coordinates of the projected BASE_CURVE\n :param bool flip: flip the cell against x = centroid x if True \n \"\"\"\n \n cell_bc = np.mean(cell, axis=0)\n aligned_cell = cell+[centroid_x, centroid_y]-cell_bc\n\n if flip:\n aligned_cell[:, 0] = 2*centroid_x-aligned_cell[:, 0]\n # Flip the order of the points\n med_index = int(np.floor(aligned_cell.shape[0]/2))\n flipped_aligned_cell = np.concatenate((aligned_cell[med_index:], aligned_cell[:med_index]), axis=0)\n flipped_aligned_cell = np.flipud(flipped_aligned_cell)\n aligned_cell = flipped_aligned_cell\n return aligned_cell\n\ndef get_centroid(base_curve):\n total_space = DiscreteCurvesStartingAtOrigin(k_sampling_points=k_sampling_points)\n proj_base_curve = total_space.projection(base_curve)\n base_centroid = np.mean(proj_base_curve, axis=0)\n return base_centroid[0], base_centroid[1]\n```\n:::\n\n\n::: {#d9bd9379 .cell execution_count=21}\n``` {.python .cell-code}\ndelete_indices = [363, 396, 513, 519]\n\naligned_base_folder = os.path.join(data_path, dataset_name, \"aligned\")\n\nBASE_CURVE = generate_ellipse(k_sampling_points)\ncentroid_x, centroid_y = get_centroid(BASE_CURVE)\n\nds_align = {}\n\nfor metric in METRICS:\n ds_align[metric] = {}\n if metric == 'SRV':\n 
aligned_folder = os.path.join(aligned_base_folder, 'projection_rescale_rotation_reparameterization')\n elif metric == 'Linear':\n aligned_folder = os.path.join(aligned_base_folder, 'projection_rescale_rotation_reparameterization')\n for treatment in TREATMENTS:\n ds_align[metric][treatment] = {}\n for line in LINES:\n ds_align[metric][treatment][line] = []\n cell_num = len(ds_proc[treatment][line])\n if line == 'dlm8' and (treatment == 'cytd' or treatment == 'jasp'):\n cell_num += 2\n for i in range(cell_num):\n # Do not load duplicate cells\n # cytd dlm8 45\n # cytd dlm8 78\n # jasp dlm8 20\n # jasp dlm8 26\n\n if (treatment == 'cytd' and line == 'dlm8' and (i == 45 or i == 78)) or \\\n (treatment == 'jasp' and line == 'dlm8' and (i == 20 or i == 26)):\n continue\n \n file_path = os.path.join(aligned_folder, f\"{treatment}_{line}_{i}.txt\")\n if os.path.exists(file_path):\n cell = np.loadtxt(file_path)\n ds_align[metric][treatment][line].append(cell)\n\n```\n:::\n\n\nCheck we did not loss any cells after alignment\n\n::: {#9b9f6ca7 .cell execution_count=22}\n``` {.python .cell-code}\nfor treatment in TREATMENTS:\n for line in LINES:\n for metric in METRICS:\n print(f\"{treatment} and {line} using {metric}: {len(ds_align[metric][treatment][line])}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncontrol and dlm8 using SRV: 113\ncontrol and dlm8 using Linear: 113\ncontrol and dunn using SRV: 199\ncontrol and dunn using Linear: 199\ncytd and dlm8 using SRV: 74\ncytd and dlm8 using Linear: 74\ncytd and dunn using SRV: 92\ncytd and dunn using Linear: 92\njasp and dlm8 using SRV: 56\njasp and dlm8 using Linear: 56\njasp and dunn using SRV: 91\njasp and dunn using Linear: 91\n```\n:::\n:::\n\n\nUpdate `lines` and `treatments`\n\n::: {#885469f0 .cell execution_count=23}\n``` {.python .cell-code}\ntreatments = []\nlines = []\nfor treatment in TREATMENTS:\n for line in LINES:\n treatments.extend([treatment]*len(ds_align['SRV'][treatment][line]))\n 
lines.extend([line]*len(ds_align['SRV'][treatment][line]))\n\ntreatments = np.array(treatments)\nlines = np.array(lines)\nprint(\"treatment length is:\", len(treatments), \"lines length is:\", len(lines))\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ntreatment length is: 625 lines length is: 625\n```\n:::\n:::\n\n\nVisualize reference cell, unaligned cell and aligned cell.\n\n::: {#64d37242 .cell execution_count=24}\n``` {.python .cell-code}\nindex = 0\nmetric = 'SRV'\nunaligned_cell = ds_proc[\"control\"][\"dlm8\"][index]\naligned_cell = ds_align[metric][\"control\"][\"dlm8\"][index]\n\nfirst_round_aligned_folder = os.path.join(aligned_base_folder, 'projection_rescale_rotation_reparameterization_first_round')\nreference_path = os.path.join(first_round_aligned_folder, f\"reference.txt\")\nmean_first_round = np.loadtxt(reference_path)\n\nfig = plt.figure(figsize=(15, 5))\n\nfig.add_subplot(131)\nplt.plot(mean_first_round[:, 0], mean_first_round[:, 1])\nplt.plot([mean_first_round[-1, 0], mean_first_round[0, 0]], [mean_first_round[-1, 1], mean_first_round[0, 1]], 'tab:blue')\nplt.scatter(mean_first_round[:, 0], mean_first_round[:, 1], s=4, c='black')\nplt.plot(mean_first_round[0, 0], mean_first_round[0, 1], \"ro\")\nplt.axis(\"equal\")\nplt.title(\"Reference curve\")\n\nfig.add_subplot(132)\nplt.plot(unaligned_cell[:, 0], unaligned_cell[:, 1])\nplt.scatter(unaligned_cell[:, 0], unaligned_cell[:, 1], s=4, c='black')\nplt.plot(unaligned_cell[0, 0], unaligned_cell[0, 1], \"ro\")\nplt.axis(\"equal\")\nplt.title(\"Unaligned curve\")\n\nfig.add_subplot(133)\nplt.plot(aligned_cell[:, 0], aligned_cell[:, 1])\nplt.scatter(aligned_cell[:, 0], aligned_cell[:, 1], s=4, c='black')\nplt.plot(aligned_cell[0, 0], aligned_cell[0, 1], \"ro\")\nplt.axis(\"equal\")\nplt.title(\"Aligned curve\")\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, \"alignment.svg\"))\n plt.savefig(os.path.join(figs_dir, \"alignment.pdf\"))\n```\n\n::: {.cell-output 
.cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-25-output-1.png){width=1185 height=431}\n:::\n:::\n\n\nIn the plot above, the red dot shows the start of the parametrization of each curve. The right curve has been rotated from the curve in the middle, to be aligned with the left (reference) curve, which represents the first cell of the dataset. The starting point (in red) of this right curve has been also set to align with the reference.\n\n# 4 Data Analysis\n\n## Compute Mean Cell Shape of the Whole Dataset: \"Global\" Mean Shape\n\nWe want to compute the mean cell shape of the whole dataset. Thus, we first combine all the cell shape data into a single array.\n\n::: {#c3dffaca .cell execution_count=25}\n``` {.python .cell-code}\nCURVES_SPACE_SRV = DiscreteCurvesStartingAtOrigin(ambient_dim=2, k_sampling_points=k_sampling_points)\n```\n:::\n\n\n::: {#44946358 .cell execution_count=26}\n``` {.python .cell-code}\ncell_shapes_list = {}\nfor metric in METRICS:\n cell_shapes_list[metric] = []\n for treatment in TREATMENTS:\n for line in LINES:\n cell_shapes_list[metric].extend(ds_align[metric][treatment][line])\n\ncell_shapes = {}\nfor metric in METRICS:\n cell_shapes[metric] = gs.array(cell_shapes_list[metric])\nprint(cell_shapes['SRV'].shape)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(625, 1999, 2)\n```\n:::\n:::\n\n\nRemove outliers using DeCOr-MDS, together for DUNN and DLM8 cell lines.\n\n::: {#2188512b .cell execution_count=27}\n``` {.python .cell-code}\ndef linear_dist(cell1, cell2):\n return gs.linalg.norm(cell1 - cell2)\n\ndef srv_dist(cell1, cell2):\n CURVES_SPACE_SRV.equip_with_metric(SRVMetric)\n return CURVES_SPACE_SRV.metric.dist(cell1, cell2)\n \n# compute pairwise distances, we only need to compute it once and save the results \npairwise_dists = {}\n\nif first_time:\n metric = 'SRV'\n pairwise_dists[metric] = parallel_dist(cell_shapes[metric], srv_dist, k_sampling_points)\n\n metric = 'Linear' \n pairwise_dists[metric] = 
parallel_dist(cell_shapes[metric], linear_dist, k_sampling_points)\n\n for metric in METRICS:\n np.savetxt(os.path.join(data_path, dataset_name, \"distance_matrix\", f\"{metric}_matrix.txt\"), pairwise_dists[metric])\nelse:\n for metric in METRICS:\n pairwise_dists[metric] = np.loadtxt(os.path.join(data_path, dataset_name, \"distance_matrix\", f\"{metric}_matrix.txt\"))\n```\n:::\n\n\n::: {#cce807fb .cell execution_count=28}\n``` {.python .cell-code}\n# to remove 132 and 199\none_cell = cell_shapes['Linear'][199]\nplt.plot(one_cell[:, 0], one_cell[:, 1], c=f\"gray\")\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-29-output-1.png){width=599 height=411}\n:::\n:::\n\n\n::: {#2f91c4c7 .cell execution_count=29}\n``` {.python .cell-code}\n# run DeCOr-MDS\nmetric = 'SRV'\ndim_start = 2 # we know the subspace dimension is 3, we set start and end to 3 to reduce runtime \ndim_end = 10\n# dim_start = 3\n# dim_end = 3\nstd_multi = 1\nif first_time:\n subspace_dim, outlier_indices = find_subspace_dim(pairwise_dists[metric], dim_start, dim_end, std_multi)\n print(f\"subspace dimension is: {subspace_dim}\")\n print(f\"outlier_indices are: {outlier_indices}\")\n```\n:::\n\n\nVisualize outlier cells to see if they are artifacts\n\n::: {#d7c78148 .cell execution_count=30}\n``` {.python .cell-code}\nif first_time:\n fig, axes = plt.subplots(\n nrows= 1,\n ncols=len(outlier_indices),\n figsize=(2*len(outlier_indices), 2),\n )\n\n for i, outlier_index in enumerate(outlier_indices):\n one_cell = cell_shapes[metric][outlier_index]\n ax = axes[i]\n ax.plot(one_cell[:, 0], one_cell[:, 1], c=f\"C{j}\")\n ax.set_title(f\"{outlier_index}\", fontsize=14)\n # Turn off tick labels\n ax.set_yticklabels([])\n ax.set_xticklabels([])\n ax.set_xticks([])\n ax.set_yticks([])\n ax.spines[\"top\"].set_visible(False)\n ax.spines[\"right\"].set_visible(False)\n ax.spines[\"bottom\"].set_visible(False)\n ax.spines[\"left\"].set_visible(False)\n\n 
plt.tight_layout()\n plt.suptitle(f\"\", y=-0.01, fontsize=24)\n # plt.savefig(os.path.join(figs_dir, \"outlier.svg\"))\n```\n:::\n\n\n::: {#8d951a52 .cell execution_count=31}\n``` {.python .cell-code}\ndelete_indices = [132, 199]\n\n\nfig, axes = plt.subplots(\n nrows= 1,\n ncols=len(delete_indices),\n figsize=(2*len(delete_indices), 2),\n)\n\n\nfor i, outlier_index in enumerate(delete_indices):\n one_cell = cell_shapes[metric][outlier_index]\n ax = axes[i]\n ax.plot(one_cell[:, 0], one_cell[:, 1], c=f\"gray\")\n ax.set_title(f\"{outlier_index}\", fontsize=14)\n # ax.axis(\"off\")\n # Turn off tick labels\n ax.set_yticklabels([])\n ax.set_xticklabels([])\n ax.set_xticks([])\n ax.set_yticks([])\n ax.spines[\"top\"].set_visible(False)\n ax.spines[\"right\"].set_visible(False)\n ax.spines[\"bottom\"].set_visible(False)\n ax.spines[\"left\"].set_visible(False)\n\nplt.tight_layout()\nplt.suptitle(f\"\", y=-0.01, fontsize=24)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, \"delete_outlier.svg\"))\n plt.savefig(os.path.join(figs_dir, \"delete_outlier.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-32-output-1.png){width=374 height=182}\n:::\n:::\n\n\nAfter visual inspection, we decide to remove the outlier cells\n\n::: {#4c1469e0 .cell execution_count=32}\n``` {.python .cell-code}\ndef remove_ds_two_layer(ds, delete_indices):\n global_i = sum(len(v) for values in ds.values() for v in values.values())-1\n\n for treatment in reversed(list(ds.keys())):\n treatment_values = ds[treatment]\n for line in reversed(list(treatment_values.keys())):\n line_cells = treatment_values[line]\n for i, _ in reversed(list(enumerate(line_cells))):\n if global_i in delete_indices:\n print(np.array(ds[treatment][line][:i]).shape, np.array(ds[treatment][line][i+1:]).shape)\n if len(np.array(ds[treatment][line][:i]).shape) == 1:\n ds[treatment][line] = np.array(ds[treatment][line][i+1:])\n elif 
len(np.array(ds[treatment][line][i+1:]).shape) == 1:\n ds[treatment][line] = np.array(ds[treatment][line][:i])\n else:\n ds[treatment][line] = np.concatenate((np.array(ds[treatment][line][:i]), np.array(ds[treatment][line][i+1:])), axis=0) \n global_i -= 1\n return ds\n\n\n\ndef remove_cells_two_layer(cells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align, delete_indices):\n \"\"\" \n Remove cells of control group from cells, cell_shapes, ds,\n the parameters returned from load_treated_osteosarcoma_cells\n Also update n_cells\n\n :param list[int] delete_indices: the indices to delete\n \"\"\"\n delete_indices = sorted(delete_indices, reverse=True) # to prevent change in index when deleting elements\n \n # Delete elements\n cells = del_arr_elements(cells, delete_indices) \n lines = list(np.delete(np.array(lines), delete_indices, axis=0))\n treatments = list(np.delete(np.array(treatments), delete_indices, axis=0))\n ds_proc = remove_ds_two_layer(ds_proc, delete_indices)\n \n for metric in METRICS:\n cell_shapes[metric] = np.delete(np.array(cell_shapes[metric]), delete_indices, axis=0)\n ds_align[metric] = remove_ds_two_layer(ds_align[metric], delete_indices)\n pairwise_dists[metric] = np.delete(pairwise_dists[metric], delete_indices, axis=0)\n pairwise_dists[metric] = np.delete(pairwise_dists[metric], delete_indices, axis=1)\n\n\n return cells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align\n```\n:::\n\n\n::: {#2971cb07 .cell execution_count=33}\n``` {.python .cell-code}\ncells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align = remove_cells_two_layer(cells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align, delete_indices)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(85, 2000, 2) (118, 2000, 2)\n(18, 2000, 2) (184, 2000, 2)\n(86, 1999, 2) (112, 1999, 2)\n(19, 1999, 2) (178, 1999, 2)\n(86, 1999, 2) (112, 1999, 2)\n(19, 1999, 2) (178, 1999, 2)\n```\n:::\n:::\n\n\nCheck we did not loss any 
other cells after the removal\n\n::: {#7a4ba1fc .cell execution_count=34}\n``` {.python .cell-code}\ndef check_num(cell_shapes, treatments, lines, pairwise_dists, ds_align):\n \n print(f\"treatments number is: {len(treatments)}, lines number is: {len(lines)}\")\n for metric in METRICS:\n print(f\"pairwise_dists for {metric} shape is: {pairwise_dists[metric].shape}\")\n print(f\"cell_shapes for {metric} number is : {len(cell_shapes[metric])}\")\n \n for line in LINES:\n for treatment in TREATMENTS:\n print(f\"ds_align {treatment} {line} using {metric}: {len(ds_align[metric][treatment][line])}\")\n```\n:::\n\n\n::: {#d3d61119 .cell execution_count=35}\n``` {.python .cell-code}\ncheck_num(cell_shapes, treatments, lines, pairwise_dists, ds_align)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ntreatments number is: 623, lines number is: 623\npairwise_dists for SRV shape is: (623, 623)\ncell_shapes for SRV number is : 623\nds_align control dlm8 using SRV: 113\nds_align cytd dlm8 using SRV: 74\nds_align jasp dlm8 using SRV: 56\nds_align control dunn using SRV: 197\nds_align cytd dunn using SRV: 92\nds_align jasp dunn using SRV: 91\npairwise_dists for Linear shape is: (623, 623)\ncell_shapes for Linear number is : 623\nds_align control dlm8 using Linear: 113\nds_align cytd dlm8 using Linear: 74\nds_align jasp dlm8 using Linear: 56\nds_align control dunn using Linear: 197\nds_align cytd dunn using Linear: 92\nds_align jasp dunn using Linear: 91\n```\n:::\n:::\n\n\nWe compute the mean cell shape by using the SRV metric defined on the space of curves' shapes. 
The space of curves' shape is a manifold: we use the Frechet mean, associated to the SRV metric, to get the mean cell shape.\n\nDo not include cells with duplicate points when calculating the mean shapes\n\n::: {#449de4e2 .cell execution_count=36}\n``` {.python .cell-code}\ndef check_duplicate(cell):\n \"\"\" \n Return true if there are duplicate points in the cell\n \"\"\"\n for i in range(cell.shape[0]-1):\n cur_coord = cell[i]\n next_coord = cell[i+1]\n if np.linalg.norm(cur_coord-next_coord) == 0:\n return True\n \n # Checking the last point vs the first poit\n if np.linalg.norm(cell[-1]-cell[0]) == 0:\n return True\n \n return False\n```\n:::\n\n\n::: {#fbeada91 .cell execution_count=37}\n``` {.python .cell-code}\ndelete_indices = []\nfor metric in METRICS:\n for i, cell in reversed(list(enumerate(cell_shapes[metric]))):\n if check_duplicate(cell):\n if i not in delete_indices:\n delete_indices.append(i)\n\n\ncells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align = \\\n remove_cells_two_layer(cells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align, delete_indices)\n\n```\n:::\n\n\nRecheck cell number after removing cells with duplicated points\n\n::: {#c8505730 .cell execution_count=38}\n``` {.python .cell-code}\ncheck_num(cell_shapes, treatments, lines, pairwise_dists, ds_align)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ntreatments number is: 623, lines number is: 623\npairwise_dists for SRV shape is: (623, 623)\ncell_shapes for SRV number is : 623\nds_align control dlm8 using SRV: 113\nds_align cytd dlm8 using SRV: 74\nds_align jasp dlm8 using SRV: 56\nds_align control dunn using SRV: 197\nds_align cytd dunn using SRV: 92\nds_align jasp dunn using SRV: 91\npairwise_dists for Linear shape is: (623, 623)\ncell_shapes for Linear number is : 623\nds_align control dlm8 using Linear: 113\nds_align cytd dlm8 using Linear: 74\nds_align jasp dlm8 using Linear: 56\nds_align control dunn using Linear: 197\nds_align cytd dunn 
using Linear: 92\nds_align jasp dunn using Linear: 91\n```\n:::\n:::\n\n\n::: {#387ef673 .cell execution_count=39}\n``` {.python .cell-code}\nfrom geomstats.learning.frechet_mean import FrechetMean\n\nmetric = 'SRV'\nCURVES_SPACE_SRV = DiscreteCurvesStartingAtOrigin(ambient_dim=2, k_sampling_points=k_sampling_points)\nmean = FrechetMean(CURVES_SPACE_SRV)\nprint(cell_shapes[metric].shape)\ncells = cell_shapes[metric]\nmean.fit(cells)\n\nmean_estimate = mean.estimate_\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(623, 1999, 2)\n```\n:::\n:::\n\n\n::: {#502f4e51 .cell execution_count=40}\n``` {.python .cell-code}\nmean_estimate_aligned = {}\n\nmean_estimate_clean = mean_estimate[~gs.isnan(gs.sum(mean_estimate, axis=1)), :]\nmean_estimate_aligned[metric] = (\n mean_estimate_clean - gs.mean(mean_estimate_clean, axis=0)\n)\n```\n:::\n\n\nAlso we compute the linear mean\n\n::: {#63b0b3ae .cell execution_count=41}\n``` {.python .cell-code}\nmetric = 'Linear'\nlinear_mean_estimate = gs.mean(cell_shapes[metric], axis=0)\nlinear_mean_estimate_clean = linear_mean_estimate[~gs.isnan(gs.sum(linear_mean_estimate, axis=1)), :]\n\nmean_estimate_aligned[metric] = (\n linear_mean_estimate_clean - gs.mean(linear_mean_estimate_clean, axis=0)\n)\n```\n:::\n\n\nPlot SRV mean cell versus linear mean cell\n\n::: {#dec784d9 .cell execution_count=42}\n``` {.python .cell-code}\nfig = plt.figure(figsize=(6, 3))\n\nfig.add_subplot(121)\nmetric = 'SRV'\nplt.plot(mean_estimate_aligned[metric][:, 0], mean_estimate_aligned[metric][:, 1])\nplt.axis(\"equal\")\nplt.title(\"SRV\")\nplt.axis(\"off\")\n\nfig.add_subplot(122)\nmetric = 'Linear'\nplt.plot(mean_estimate_aligned[metric][:, 0], mean_estimate_aligned[metric][:, 1])\nplt.axis(\"equal\")\nplt.title(\"Linear\")\nplt.axis(\"off\")\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, \"global_mean.svg\"))\n plt.savefig(os.path.join(figs_dir, \"global_mean.pdf\"))\n```\n\n::: {.cell-output 
.cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-43-output-1.png){width=466 height=261}\n:::\n:::\n\n\n# Analyze Distances to the \"Global\" Mean Shape\n\nWe consider each of the subgroups of cells, defined by their treatment and cell line. We wish to study how far each of these groups is from the global mean shape, so we compute, for each group, the list of distances to the global mean shape.\n\n::: {#5e4d4ee1 .cell execution_count=43}\n``` {.python .cell-code}\nmetric = 'SRV'\ndists_to_global_mean = {}\ndists_to_global_mean_list = {}\nprint(mean_estimate_aligned[metric].shape)\n\ndists_to_global_mean[metric] = apply_func_to_ds(\n ds_align[metric], \n func=lambda x: CURVES_SPACE_SRV.metric.dist(x, mean_estimate_aligned[metric])\n)\n\ndists_to_global_mean_list[metric] = []\nfor t in TREATMENTS:\n for l in LINES:\n dists_to_global_mean_list[metric].extend(dists_to_global_mean[metric][t][l])\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(1999, 2)\n```\n:::\n:::\n\n\nWe then compute the distances to the linear mean.\n\n::: {#aeae3043 .cell execution_count=44}\n``` {.python .cell-code}\nmetric = 'Linear'\ndists_to_global_mean[metric] = apply_func_to_ds(\n ds_align[metric], func=lambda x: gs.linalg.norm(mean_estimate_aligned[metric] - x) \n)\n\ndists_to_global_mean_list[metric] = []\nfor t in TREATMENTS:\n for l in LINES:\n dists_to_global_mean_list[metric].extend(dists_to_global_mean[metric][t][l])\n```\n:::\n\n\n::: {#84c49b54 .cell execution_count=45}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4))\n\nline = 'dlm8'\nkde_dict = {}\nfor j, metric in enumerate(METRICS):\n distances = []\n min_dists = min(dists_to_global_mean_list[metric])\n max_dists = max(dists_to_global_mean_list[metric])\n xx = gs.linspace(gs.floor(min_dists), gs.ceil(max_dists), k_sampling_points)\n kde_dict[metric] = {}\n for i, treatment in enumerate(TREATMENTS):\n distances = 
dists_to_global_mean[metric][treatment][line][~gs.isnan(dists_to_global_mean[metric][treatment][line])]\n \n \n axs[j].hist(distances, bins=20, alpha=0.4, density=True, label=treatment, color=f\"C{i}\")\n kde = stats.gaussian_kde(distances)\n kde_dict[metric][treatment] = kde\n axs[j].plot(xx, kde(xx), color=f\"C{i}\")\n axs[j].set_xlim((min_dists, max_dists))\n axs[j].legend(fontsize=12)\n\n axs[j].set_title(f\"{metric}\", fontsize=14)\n axs[j].set_ylabel(\"Fraction of cells\", fontsize=14)\n\n\n# fig.suptitle(\"Histograms of SRV distances to global mean cell\", fontsize=20)\n \nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_histogram.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_histogram.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-46-output-1.png){width=758 height=374}\n:::\n:::\n\n\nCalculate the ratio of overlapping regions formed by the kde curves\n\n::: {#f03cb1ef .cell execution_count=46}\n``` {.python .cell-code}\ndef calc_ratio(kde1, kde2, min, max):\n xx = np.linspace(min, max, 1000)\n kde1_values = kde1(xx)\n kde2_values = kde2(xx)\n\n overlap = np.minimum(kde1_values, kde2_values)\n overlap_area = np.trapz(overlap, xx)\n\n bound = np.maximum(kde1_values, kde2_values)\n bound_area = np.trapz(bound, xx)\n\n return overlap_area/bound_area\n```\n:::\n\n\n::: {#afe68cfc .cell execution_count=47}\n``` {.python .cell-code}\nfor metric in METRICS:\n min_dists = min(dists_to_global_mean_list[metric])\n max_dists = max(dists_to_global_mean_list[metric])\n for i, tmt1 in enumerate(TREATMENTS):\n for j in range(i+1, len(TREATMENTS)):\n tmt2 = TREATMENTS[j]\n ratio = calc_ratio(kde_dict[metric][tmt1], kde_dict[metric][tmt2], min_dists, max_dists)\n print(f\"Overlap ratio for {line} between {tmt1} and {tmt2} using {metric} metric is: {round(ratio, 2)}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nOverlap ratio for dlm8 between control and cytd using SRV metric is: 
0.28\nOverlap ratio for dlm8 between control and jasp using SRV metric is: 0.53\nOverlap ratio for dlm8 between cytd and jasp using SRV metric is: 0.39\nOverlap ratio for dlm8 between control and cytd using Linear metric is: 0.43\nOverlap ratio for dlm8 between control and jasp using Linear metric is: 0.69\nOverlap ratio for dlm8 between cytd and jasp using Linear metric is: 0.59\n```\n:::\n:::\n\n\n::: {#3310c596 .cell execution_count=48}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4))\n\nline = 'dunn'\n\nnp.set_printoptions(precision=12)\n\nkde_dict = {}\nfor j, metric in enumerate(METRICS):\n distances = []\n min_dists = min(dists_to_global_mean_list[metric])\n max_dists = max(dists_to_global_mean_list[metric])\n xx = gs.linspace(gs.floor(min_dists), gs.ceil(max_dists), k_sampling_points)\n kde_dict[metric] = {}\n \n for i, treatment in enumerate(TREATMENTS):\n \n distances = dists_to_global_mean[metric][treatment][line][~gs.isnan(dists_to_global_mean[metric][treatment][line])]\n counts, bin_edges, _ = axs[j].hist(distances, bins=20, alpha=0.4, density=True, label=treatment, color=f\"C{i}\")\n print(treatment, metric)\n print(\"counts are:\", counts)\n print(\"bin_edges are:\", bin_edges)\n kde = stats.gaussian_kde(distances)\n kde_dict[metric][treatment] = kde\n axs[j].plot(xx, kde(xx), color=f\"C{i}\")\n axs[j].set_xlim((min_dists, max_dists))\n axs[j].legend(fontsize=12)\n\n axs[j].set_title(f\"{metric}\", fontsize=14)\n axs[j].set_ylabel(\"Fraction of cells\", fontsize=14)\n\n\n# fig.suptitle(\"Histograms of SRV distances to global mean cell\", fontsize=20)\n \nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_histogram.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_histogram.pdf\"))\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncontrol SRV\ncounts are: [3.599823688084 9.414923491911 9.138013977443 2.492185630212\n 2.215276115744 2.215276115744 2.492185630212 
4.15364271702\n 6.092009318296 3.876733202552 2.492185630212 1.38454757234\n 1.107638057872 2.492185630212 0.553819028936 0.\n 0.553819028936 0. 0. 0.276909514468]\nbin_edges are: [0.190412844846 0.208744255891 0.227075666936 0.245407077981\n 0.263738489026 0.28206990007 0.300401311115 0.31873272216\n 0.337064133205 0.35539554425 0.373726955295 0.39205836634\n 0.410389777385 0.42872118843 0.447052599475 0.46538401052\n 0.483715421565 0.50204683261 0.520378243655 0.5387096547\n 0.557041065745]\ncytd SRV\ncounts are: [0.627751614862 0. 1.883254844586 1.255503229724\n 1.255503229724 1.255503229724 1.883254844586 5.649764533759\n 4.394261304035 5.649764533759 8.160770993208 5.649764533759\n 6.905267763483 3.138758074311 2.511006459448 1.255503229724\n 3.138758074311 1.883254844586 0.627751614862 0.627751614862]\nbin_edges are: [0.26221861859 0.279533691877 0.296848765164 0.314163838451\n 0.331478911738 0.348793985025 0.366109058312 0.383424131599\n 0.400739204886 0.418054278173 0.43536935146 0.452684424747\n 0.469999498034 0.487314571321 0.504629644608 0.521944717895\n 0.539259791183 0.55657486447 0.573889937757 0.591205011044\n 0.608520084331]\njasp SRV\ncounts are: [0.928427307436 0.928427307436 0.928427307436 2.785281922307\n 3.713709229743 3.713709229743 4.642136537178 6.49899115205\n 6.49899115205 6.49899115205 9.284273074357 8.355845766921\n 6.49899115205 7.427418459485 2.785281922307 4.642136537178\n 1.856854614871 2.785281922307 1.856854614871 1.856854614871]\nbin_edges are: [0.244313646946 0.256149803531 0.267985960117 0.279822116702\n 0.291658273288 0.303494429873 0.315330586458 0.327166743044\n 0.339002899629 0.350839056215 0.3626752128 0.374511369386\n 0.386347525971 0.398183682557 0.410019839142 0.421855995727\n 0.433692152313 0.445528308898 0.457364465484 0.469200622069\n 0.481036778655]\ncontrol Linear\ncounts are: [0.973976940289 1.704459645506 3.895907761156 3.165425055939\n 4.626390466373 4.139401996228 5.35687317159 5.843861641734\n 4.626390466373 
3.408919291012 2.19144811565 2.678436585795\n 1.460965410434 0.730482705217 0.973976940289 0.243494235072\n 0.730482705217 0. 0.243494235072 0.973976940289]\nbin_edges are: [0.084550020208 0.105397093366 0.126244166523 0.147091239681\n 0.167938312838 0.188785385996 0.209632459153 0.230479532311\n 0.251326605468 0.272173678626 0.293020751783 0.313867824941\n 0.334714898098 0.355561971256 0.376409044413 0.397256117571\n 0.418103190728 0.438950263886 0.459797337043 0.480644410201\n 0.501491483358]\ncytd Linear\ncounts are: [2.686991765509 1.343495882754 1.791327843673 2.239159804591\n 2.686991765509 3.582655687345 4.478319609181 4.478319609181\n 5.821815491936 4.030487648263 4.030487648263 1.343495882754\n 1.343495882754 0.447831960918 0. 0.\n 0.447831960918 0. 0. 0.447831960918]\nbin_edges are: [0.18370748819 0.20797901449 0.23225054079 0.25652206709\n 0.28079359339 0.30506511969 0.32933664599 0.35360817229\n 0.37787969859 0.40215122489 0.42642275119 0.45069427749\n 0.47496580379 0.49923733009 0.52350885639 0.54778038269\n 0.57205190899 0.59632343529 0.62059496159 0.64486648789\n 0.669138014189]\njasp Linear\ncounts are: [3.47808161386 5.21712242079 2.608561210395 6.956163227719\n 6.521403025987 6.086642824255 3.912841815592 0.434760201732\n 1.73904080693 0.434760201732 0.434760201732 0.\n 0.434760201732 0.434760201732 0. 0.\n 0. 0. 
0.434760201732 0.434760201732]\nbin_edges are: [0.154345044651 0.179621072552 0.204897100452 0.230173128353\n 0.255449156253 0.280725184154 0.306001212054 0.331277239955\n 0.356553267855 0.381829295756 0.407105323656 0.432381351557\n 0.457657379457 0.482933407358 0.508209435258 0.533485463159\n 0.558761491059 0.58403751896 0.60931354686 0.634589574761\n 0.659865602661]\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-49-output-2.png){width=758 height=374}\n:::\n:::\n\n\nCalculate the ratio of overlapping regions formed by the three kde curves \n\n::: {#56a026f5 .cell execution_count=49}\n``` {.python .cell-code}\nfor metric in METRICS:\n min_dists = min(dists_to_global_mean_list[metric])\n max_dists = max(dists_to_global_mean_list[metric])\n for i, tmt1 in enumerate(TREATMENTS):\n for j in range(i+1, len(TREATMENTS)):\n tmt2 = TREATMENTS[j]\n ratio = calc_ratio(kde_dict[metric][tmt1], kde_dict[metric][tmt2], min_dists, max_dists)\n print(f\"Overlap ratio for {line} between {tmt1} and {tmt2} using {metric} metric is: {round(ratio, 2)}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nOverlap ratio for dunn between control and cytd using SRV metric is: 0.2\nOverlap ratio for dunn between control and jasp using SRV metric is: 0.4\nOverlap ratio for dunn between cytd and jasp using SRV metric is: 0.35\nOverlap ratio for dunn between control and cytd using Linear metric is: 0.32\nOverlap ratio for dunn between control and jasp using Linear metric is: 0.72\nOverlap ratio for dunn between cytd and jasp using Linear metric is: 0.37\n```\n:::\n:::\n\n\nConduct T-test to test if the two samples have the same expected average\n\n::: {#0386376a .cell execution_count=50}\n``` {.python .cell-code}\nfor line in LINES:\n for i in range(len(TREATMENTS)):\n tmt1 = TREATMENTS[i]\n for j in range(i+1, len(TREATMENTS)):\n tmt2 = TREATMENTS[j]\n for metric in METRICS:\n distance1 = 
dists_to_global_mean[metric][tmt1][line][~gs.isnan(dists_to_global_mean[metric][tmt1][line])]\n distance2 = dists_to_global_mean[metric][tmt2][line][~gs.isnan(dists_to_global_mean[metric][tmt2][line])]\n t_statistic, p_value = stats.ttest_ind(distance1, distance2)\n print(f\"Significance of differences for {line} between {tmt1} and {tmt2} using {metric} metric is: {'%.2e' % Decimal(p_value)}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nSignificance of differences for dlm8 between control and cytd using SRV metric is: 5.16e-25\nSignificance of differences for dlm8 between control and cytd using Linear metric is: 3.15e-11\nSignificance of differences for dlm8 between control and jasp using SRV metric is: 6.87e-06\nSignificance of differences for dlm8 between control and jasp using Linear metric is: 1.65e-01\nSignificance of differences for dlm8 between cytd and jasp using SRV metric is: 1.10e-09\nSignificance of differences for dlm8 between cytd and jasp using Linear metric is: 1.77e-04\nSignificance of differences for dunn between control and cytd using SRV metric is: 1.29e-41\nSignificance of differences for dunn between control and cytd using Linear metric is: 3.35e-24\nSignificance of differences for dunn between control and jasp using SRV metric is: 1.74e-14\nSignificance of differences for dunn between control and jasp using Linear metric is: 2.50e-03\nSignificance of differences for dunn between cytd and jasp using SRV metric is: 8.05e-16\nSignificance of differences for dunn between cytd and jasp using Linear metric is: 1.97e-10\n```\n:::\n:::\n\n\nLet's analyze the bimodal distribution of the control group of the dunn cell line under the SRV metric.\n\nWe consider two groups of cells, one around each peak of the histogram: cells whose distance to the global mean lies in (0.2087, 0.2271], and cells whose distance lies in (0.3371, 0.3554]. We plot the cells of each group to compare their shapes.\n\n::: {#f6aecdf7 .cell execution_count=51}\n``` {.python .cell-code}\nline = 'dunn'\ntreatment = 'control'\nmetric = 
'SRV'\ndistances = dists_to_global_mean[metric][treatment][line]\nprint(min(distances), max(distances))\ngroup_1_left = 0.208744255891 \ngroup_1_right = 0.227075666936\ngroup_2_left = 0.337064133205 \ngroup_2_right = 0.35539554425\ngroup_1_indices = [i for i, element in enumerate(distances) if element <= group_1_right and element > group_1_left]\ngroup_2_indices = [i for i, element in enumerate(distances) if element <= group_2_right and element > group_2_left]\nprint(group_1_indices)\nprint(group_2_indices)\ngroup_1_cells = gs.array(ds_align[metric][treatment][line])[group_1_indices,:,:]\ngroup_2_cells = gs.array(ds_align[metric][treatment][line])[group_2_indices,:,:]\n\ncol_num = max(len(group_1_indices), len(group_2_indices))\nfig = plt.figure(figsize=(2*col_num, 2))\ncount = 1\nfor index in range(len(group_1_indices)):\n cell = group_1_cells[index]\n fig.add_subplot(2, col_num, count)\n count += 1\n plt.plot(cell[:, 0], cell[:, 1])\n plt.axis(\"equal\")\n plt.axis(\"off\")\n\ncount = max(len(group_1_indices), len(group_2_indices))+1\nfor index in range(len(group_2_indices)):\n cell = group_2_cells[index]\n fig.add_subplot(2, col_num, count)\n count += 1\n plt.plot(cell[:, 0], cell[:, 1])\n plt.axis(\"equal\")\n plt.axis(\"off\")\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_bimodal_mean.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_bimodal_mean.pdf\"))\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n0.19041284484557636 0.5570410657452678\n[24, 25, 26, 28, 29, 33, 34, 36, 37, 39, 40, 44, 48, 51, 54, 55, 56, 58, 107, 117, 120, 126, 128, 130, 131, 134, 136, 137, 138, 140, 141, 145, 151, 153]\n[2, 10, 16, 17, 64, 66, 77, 80, 86, 87, 90, 91, 95, 100, 101, 104, 157, 167, 168, 170, 175, 196]\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-52-output-2.png){width=5078 height=167}\n:::\n:::\n\n\n# Visualization of the Mean of each Treatment\n\nThe mean distances to the global mean shape 
differ. We also plot the mean shape of each subgroup to get intuition about what it looks like.\n\nWe first calculate the SRV mean.\n\n::: {#76398fc3 .cell execution_count=52}\n``` {.python .cell-code}\nmean_treatment_cells = {}\nmetric = 'SRV'\nfor treatment in TREATMENTS:\n treatment_cells = []\n for line in LINES:\n treatment_cells.extend(ds_align[metric][treatment][line])\n mean_estimator = FrechetMean(space=CURVES_SPACE_SRV)\n mean_estimator.fit(CURVES_SPACE_SRV.projection(gs.array(treatment_cells)))\n mean_treatment_cells[treatment] = mean_estimator.estimate_\n```\n:::\n\n\n::: {#e8fb4979 .cell execution_count=53}\n``` {.python .cell-code}\nmean_line_cells = {}\nfor line in LINES:\n line_cells = []\n for treatment in TREATMENTS:\n line_cells.extend(ds_align[metric][treatment][line])\n mean_estimator = FrechetMean(space=CURVES_SPACE_SRV)\n mean_estimator.fit(CURVES_SPACE_SRV.projection(gs.array(line_cells)))\n mean_line_cells[line] = mean_estimator.estimate_\n```\n:::\n\n\n::: {#2a7bb34f .cell execution_count=54}\n``` {.python .cell-code}\nmean_cells = {}\nmetric = 'SRV'\nmean_cells[metric] = {}\nfor treatment in TREATMENTS:\n mean_cells[metric][treatment] = {}\n for line in LINES:\n mean_estimator = FrechetMean(space=CURVES_SPACE_SRV)\n mean_estimator.fit(CURVES_SPACE_SRV.projection(gs.array(ds_align[metric][treatment][line])))\n mean_cells[metric][treatment][line] = mean_estimator.estimate_\n```\n:::\n\n\nWe then calculate the linear mean.\n\n::: {#4e853e9d .cell execution_count=55}\n``` {.python .cell-code}\nmetric = 'Linear'\nmean_cells[metric] = {}\nfor treatment in TREATMENTS:\n mean_cells[metric][treatment] = {}\n for line in LINES:\n mean_cells[metric][treatment][line] = gs.mean(ds_align[metric][treatment][line], axis=0)\n```\n:::\n\n\nWhile the mean shapes of the control groups (for both cell lines) look regular, we observe that:\n\n- the mean shape for cytd is the most irregular (for both cell lines)\n- while the mean 
shape for jasp is more elongated for dlm8 cell line, and more irregular for dunn cell line.\n\n# Distance of the Cell Shapes to their Own Mean Shape\n\nLastly, we evaluate how each subgroup of cell shapes is distributed around the mean shape of their specific subgroup.\n\n::: {#01a3e4c1 .cell execution_count=56}\n``` {.python .cell-code}\ndists_to_own_mean = {}\n\nfor metric in METRICS:\n dists_to_own_mean[metric] = {}\n for treatment in TREATMENTS:\n dists_to_own_mean[metric][treatment] = {}\n for line in LINES:\n dists = []\n ids = []\n for i_curve, curve in enumerate(ds_align[metric][treatment][line]):\n if metric == 'SRV':\n one_dist = CURVES_SPACE_SRV.metric.dist(curve, mean_cells[metric][treatment][line])\n else:\n one_dist = gs.linalg.norm(curve - mean_cells[metric][treatment][line])\n if ~gs.isnan(one_dist):\n dists.append(one_dist)\n else:\n ids.append(i_curve)\n dists_to_own_mean[metric][treatment][line] = dists\n```\n:::\n\n\n::: {#134269d3 .cell execution_count=57}\n``` {.python .cell-code}\n# Align with ellipse\n\nline = 'dunn'\n\nfig, axes = plt.subplots(\n ncols=len(TREATMENTS),\n nrows=len(METRICS),\n figsize=(2.5*len(TREATMENTS), 2*len(METRICS)))\n\nfor j, metric in enumerate(METRICS):\n for i, treatment in enumerate(TREATMENTS):\n ax = axes[j, i]\n mean_cell = mean_cells[metric][treatment][line]\n ax.plot(mean_cell[:, 0], mean_cell[:, 1], color=f\"C{i}\")\n ax.axis(\"equal\")\n ax.axis(\"off\")\n ax.set_title(f\"{metric}-{treatment}\", fontsize=20)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_own_mean.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_own_mean.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-58-output-1.png){width=587 height=343}\n:::\n:::\n\n\n::: {#558baaed .cell execution_count=58}\n``` {.python .cell-code}\nline = 'dlm8'\n\nfig, axes = plt.subplots(\n ncols=len(TREATMENTS),\n nrows=len(METRICS),\n figsize=(2.5*len(TREATMENTS), 2*len(METRICS)))\n\nfor 
j, metric in enumerate(METRICS):\n for i, treatment in enumerate(TREATMENTS):\n ax = axes[j, i]\n mean_cell = mean_cells[metric][treatment][line]\n ax.plot(mean_cell[:, 0], mean_cell[:, 1], color=f\"C{i}\")\n ax.axis(\"equal\")\n ax.axis(\"off\")\n ax.set_title(f\"{metric}-{treatment}\", fontsize=20)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_own_mean.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_own_mean.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-59-output-1.png){width=587 height=343}\n:::\n:::\n\n\nWe observe that the linear means become narrower toward the right. This happens because the start points of the cells are aligned exactly, on the right, with the start point of the reference cell.\n\nThis artifact appears only for the linear means (especially for the cytd group). Can we argue this is an advantage for SRV (reparameterization + SRV mean)?\n\nThe code below finds a given number of quantiles of the distances to each subgroup's own mean, and plots the corresponding cell for each treatment and each metric, here for the dunn cell line.\n\n::: {#16361f0f .cell execution_count=59}\n``` {.python .cell-code}\nimport scipy.stats as ss\n\nline = 'dunn'\nn_quantiles = 10\n\nfig, axes = plt.subplots(\n nrows=len(TREATMENTS)*len(METRICS),\n ncols=n_quantiles,\n figsize=(20, 2 * len(TREATMENTS) * len(METRICS)),\n)\n\nranks = {}\n\nfor i, treatment in enumerate(TREATMENTS):\n ranks[treatment] = {}\n for j, metric in enumerate(METRICS):\n \n dists_list = dists_to_own_mean[metric][treatment][line]\n dists_list = [d + 0.0001 * gs.random.rand(1)[0] for d in dists_list]\n cells_list = list(ds_align[metric][treatment][line])\n assert len(dists_list) == len(cells_list)\n n_cells = len(dists_list)\n\n ranks[treatment][metric] = ss.rankdata(dists_list)\n\n zipped_lists = zip(dists_list, cells_list)\n sorted_pairs = sorted(zipped_lists)\n\n tuples = zip(*sorted_pairs)\n 
sorted_dists_list, sorted_cells_list = [list(t) for t in tuples]\n for i_quantile in range(n_quantiles):\n quantile = int(0.1 * n_cells * i_quantile)\n one_cell = sorted_cells_list[quantile]\n ax = axes[2*i+j, i_quantile]\n ax.plot(one_cell[:, 0], one_cell[:, 1], c=f\"C{i}\")\n ax.set_title(f\"0.{i_quantile} quantile\", fontsize=14)\n # ax.axis(\"off\")\n # Turn off tick labels\n ax.set_yticklabels([])\n ax.set_xticklabels([])\n ax.set_xticks([])\n ax.set_yticks([])\n ax.spines[\"top\"].set_visible(False)\n ax.spines[\"right\"].set_visible(False)\n ax.spines[\"bottom\"].set_visible(False)\n ax.spines[\"left\"].set_visible(False)\n if i_quantile == 0:\n ax.set_ylabel(f\"{metric} - \\n {treatment}\", rotation=90, fontsize=18)\nplt.tight_layout()\n# plt.suptitle(f\"Quantiles for linear metric using own mean\", y=-0.01, fontsize=24)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_quantile.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_quantile.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-60-output-1.png){width=1909 height=1142}\n:::\n:::\n\n\nWe do not observe any clear patterns between the rank of the cells with distances using SRV metric and with the linear metric.\n\n::: {#c57cf34f .cell execution_count=60}\n``` {.python .cell-code}\nline = 'dlm8'\nn_quantiles = 10\n\nfig, axes = plt.subplots(\n nrows=len(TREATMENTS)*len(METRICS),\n ncols=n_quantiles,\n figsize=(20, 2 * len(TREATMENTS) * len(METRICS)),\n)\n\nfor i, treatment in enumerate(TREATMENTS):\n for j, metric in enumerate(METRICS):\n dists_list = dists_to_own_mean[metric][treatment][line]\n dists_list = [d + 0.0001 * gs.random.rand(1)[0] for d in dists_list]\n cells_list = list(ds_align[metric][treatment][line])\n assert len(dists_list) == len(dists_list)\n n_cells = len(dists_list)\n\n zipped_lists = zip(dists_list, cells_list)\n sorted_pairs = sorted(zipped_lists)\n\n tuples = zip(*sorted_pairs)\n sorted_dists_list, 
sorted_cells_list = [list(t) for t in tuples]\n for i_quantile in range(n_quantiles):\n quantile = int(0.1 * n_cells * i_quantile)\n one_cell = sorted_cells_list[quantile]\n ax = axes[2*i+j, i_quantile]\n ax.plot(one_cell[:, 0], one_cell[:, 1], c=f\"C{i}\")\n ax.set_title(f\"0.{i_quantile} quantile\", fontsize=14)\n # ax.axis(\"off\")\n # Turn off tick labels\n ax.set_yticklabels([])\n ax.set_xticklabels([])\n ax.set_xticks([])\n ax.set_yticks([])\n ax.spines[\"top\"].set_visible(False)\n ax.spines[\"right\"].set_visible(False)\n ax.spines[\"bottom\"].set_visible(False)\n ax.spines[\"left\"].set_visible(False)\n if i_quantile == 0:\n ax.set_ylabel(f\"{metric} - \\n {treatment}\", rotation=90, fontsize=18)\nplt.tight_layout()\n# plt.suptitle(f\"Quantiles for linear metric using own mean\", y=-0.01, fontsize=24)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_quantile.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_quantile.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-61-output-1.png){width=1909 height=1142}\n:::\n:::\n\n\nThe above code find a given number of quantiles within the distance's histogram, using linear metric and own mean, and plots the corresponding cell, for each treatment and each cell line.\n\n# Dimensionality Reduction\n\nWe use the following experiments to illustrate how SRV metric can help with dimensionality reduction \n\n::: {#dd38cf1e .cell execution_count=61}\n``` {.python .cell-code}\ndef scaled_stress(pos, pairwise_dists):\n \"\"\" \n Calculate the scaled stress invariant to scaling using the original stress \\\n statistics and actual pairwise distances\n\n :param float unscaled_stress: the original stress\n :param 2D np.array[float] pairwise_dists: pairwise distance\n \"\"\"\n \n # compute pairwise distance of pos\n pairwise_pos = np.empty(shape=(pos.shape[0], pos.shape[0]))\n for i in range(pos.shape[0]):\n for j in range(pos.shape[0]):\n pairwise_pos[i,j] = 
np.sqrt(np.sum((pos[i]-pos[j])**2))\n \n print(pairwise_pos)\n stress = np.sqrt(np.sum((pairwise_dists-pairwise_pos)**2))\n \n return stress/np.sqrt(np.sum(pairwise_dists**2))\n```\n:::\n\n\n::: {#b0035a2f .cell execution_count=62}\n``` {.python .cell-code}\nmds = {}\npos = {}\ndims = range(2, 11)\nstresses = {}\n\nfor metric in METRICS:\n mds[metric] = {}\n pos[metric] = {}\n stresses[metric] = []\n for dim in dims:\n mds[metric][dim] = manifold.MDS(n_components=dim, random_state=0, dissimilarity=\"precomputed\") # fixed seed for reproducibility\n pos[metric][dim] = mds[metric][dim].fit(pairwise_dists[metric]).embedding_\n stress_val = mds[metric][dim].stress_\n scaled_stress_val = np.sqrt(stress_val/((pairwise_dists[metric]**2).sum()/2))\n # scaled_stress_val = scaled_stress(pos[metric][dim], pairwise_dists[metric])\n\n print(f\"the unscaled stress for {metric} model is for {dim}:\", stress_val)\n stresses[metric].append(scaled_stress_val)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nthe unscaled stress for SRV model is for 2: 0.0015505150986308987\nthe unscaled stress for SRV model is for 3: 0.0009766856050873998\nthe unscaled stress for SRV model is for 4: 0.0007390199671520337\nthe unscaled stress for SRV model is for 5: 0.0005748305174444293\nthe unscaled stress for SRV model is for 6: 0.00047113942181298865\nthe unscaled stress for SRV model is for 7: 0.0003990770585748401\nthe unscaled stress for SRV model is for 8: 0.00034641999727906943\nthe unscaled stress for SRV model is for 9: 0.00030596906074277627\nthe unscaled stress for SRV model is for 10: 0.00027546016788315334\nthe unscaled stress for Linear model is for 2: 0.0012568732933103922\nthe unscaled stress for Linear model is for 3: 0.0008789553123291832\nthe unscaled stress for Linear model is for 4: 0.0007370740946128706\nthe unscaled stress for Linear model is for 5: 0.0006365408960217103\nthe unscaled stress for Linear model is for 6: 0.0005664042865819429\nthe unscaled stress for Linear model is 
for 7: 0.0005223292015115522\nthe unscaled stress for Linear model is for 8: 0.0004846528585517728\nthe unscaled stress for Linear model is for 9: 0.00046151351278745815\nthe unscaled stress for Linear model is for 10: 0.0004397214282582284\n```\n:::\n:::\n\n\n::: {#b6cad3d9 .cell execution_count=63}\n``` {.python .cell-code}\nplt.figure(figsize = (4,4))\nfor metric in METRICS:\n plt.scatter(dims, stresses[metric], label=metric)\n plt.plot(dims, stresses[metric])\nplt.xticks(dims)\nplt.legend()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"MDS_stress.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"MDS_stress.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-64-output-1.png){width=356 height=337}\n:::\n:::\n\n\nIn terms of the scaled stress statistics, we observe that the linear metric performs better than the SRV metric. That is, the linear metric preserves the pairwise distances in the embedded space better than the SRV metric.\n\nCalculate MDS statistics for dimension 2.\n\n::: {#80d1bb19 .cell execution_count=64}\n``` {.python .cell-code}\nmetric = 'SRV'\nmds = manifold.MDS(n_components=2, random_state=0, dissimilarity=\"precomputed\")\npos = mds.fit(pairwise_dists[metric]).embedding_\n```\n:::\n\n\nMDS embedding of cell treatments (control, cytd and jasp) for different cell lines (dunn and dlm8)\n\n::: {#b22689f4 .cell execution_count=65}\n``` {.python .cell-code}\nembs = {}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\nWe draw a comparison with the linear metric using the following code.\n\n::: {#7ce66870 .cell execution_count=66}\n``` {.python .cell-code}\nmetric = 'Linear'\nmds = manifold.MDS(n_components=2, random_state=0, dissimilarity=\"precomputed\")\npos = 
mds.fit(pairwise_dists[metric]).embedding_\nprint(\"the stress for linear model is:\", mds.stress_)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nthe stress for linear model is: 0.0012568732933103922\n```\n:::\n:::\n\n\n::: {#d786171d .cell execution_count=67}\n``` {.python .cell-code}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\nThe stress of the 2D MDS embedding is lower (better) for the linear metric than for the SRV metric. \n\nHowever, if the SRV embedding lends itself to a better visual interpretation, we could still argue that SRV is better at capturing cell heterogeneity. \n\n::: {#4241bccd .cell execution_count=68}\n``` {.python .cell-code}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\n::: {#23f4e310 .cell execution_count=69}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4))\n\nline = 'dunn'\nfor j, metric in enumerate(METRICS):\n for i, treatment in enumerate(TREATMENTS):\n cur_embs = embs[metric][treatment][line]\n axs[j].scatter(\n cur_embs[:, 0],\n cur_embs[:, 1],\n label=treatment,\n s=10,\n alpha=0.4\n )\n # axs[j].set_xlim(-3.5*1e-5, 3.5*1e-5)\n axs[j].set_xlabel(\"First Dimension\")\n axs[j].set_ylabel(\"Second Dimension\")\n axs[j].legend()\n axs[j].set_title(f\"{metric}\")\n# fig.suptitle(\"MDS of cell shapes using SRV metric\", fontsize=20)\n\nplt.tight_layout()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_2D.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_2D.pdf\"))\n```\n\n::: {.cell-output 
.cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-70-output-1.png){width=758 height=374}\n:::\n:::\n\n\n::: {#407a2524 .cell execution_count=70}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4))\n\nline = 'dlm8'\nfor j, metric in enumerate(METRICS):\n distances = []\n for i, treatment in enumerate(TREATMENTS):\n cur_embs = embs[metric][treatment][line]\n axs[j].scatter(\n cur_embs[:, 0],\n cur_embs[:, 1],\n label=treatment,\n s=10,\n alpha=0.4\n )\n # axs[j].set_xlim(-3.5*1e-5, 3.5*1e-5)\n axs[j].set_xlabel(\"First Dimension\")\n axs[j].set_ylabel(\"Second Dimension\")\n axs[j].legend()\n axs[j].set_title(f\"{metric}\")\n# fig.suptitle(\"MDS of cell shapes using SRV metric\", fontsize=20)\n\nplt.tight_layout()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_2D.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_2D.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-71-output-1.png){width=758 height=374}\n:::\n:::\n\n\nWe also consider embedding in 3D. 
\n\n::: {#74f6c096 .cell execution_count=71}\n``` {.python .cell-code}\nmetric = 'SRV'\nmds = manifold.MDS(n_components=3, random_state=0, dissimilarity=\"precomputed\")\npos = mds.fit(pairwise_dists[metric]).embedding_\n```\n:::\n\n\n::: {#af57ad95 .cell execution_count=72}\n``` {.python .cell-code}\nembs = {}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\n::: {#15441d7e .cell execution_count=73}\n``` {.python .cell-code}\nmetric = 'Linear'\nmds = manifold.MDS(n_components=3, random_state=1, dissimilarity=\"precomputed\")\npos = mds.fit(pairwise_dists[metric]).embedding_\nprint(\"the stress for linear model is:\", mds.stress_)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nthe stress for linear model is: 0.0008821306413255005\n```\n:::\n:::\n\n\n::: {#aaecc5c7 .cell execution_count=74}\n``` {.python .cell-code}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\n::: {#8f24e315 .cell execution_count=75}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4), subplot_kw=dict(projection='3d'))\n\nline = 'dunn'\nfor j, metric in enumerate(METRICS):\n distances = []\n for i, treatment in enumerate(TREATMENTS):\n cur_embs = embs[metric][treatment][line]\n axs[j].scatter(\n cur_embs[:, 0],\n cur_embs[:, 1],\n cur_embs[:, 2],\n label=treatment,\n s=10,\n alpha=0.4\n )\n axs[j].set_xlabel(\"First Dimension\")\n axs[j].set_ylabel(\"Second Dimension\")\n axs[j].legend()\n axs[j].set_title(f\"{metric}\")\n# fig.suptitle(\"MDS of cell shapes using linear metric\", 
fontsize=20)\n\nplt.tight_layout()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_3D.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_3D.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-76-output-1.png){width=777 height=398}\n:::\n:::\n\n\n::: {#90b02bd3 .cell execution_count=76}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4), subplot_kw=dict(projection='3d'))\n\nline = 'dlm8'\nfor j, metric in enumerate(METRICS):\n distances = []\n for i, treatment in enumerate(TREATMENTS):\n cur_embs = embs[metric][treatment][line]\n axs[j].scatter(\n cur_embs[:, 0],\n cur_embs[:, 1],\n cur_embs[:, 2],\n label=treatment,\n s=10,\n alpha=0.4\n )\n # axs[j].set_xlim(-3.5*1e-5, 3.5*1e-5)\n axs[j].set_xlabel(\"First Dimension\")\n axs[j].set_ylabel(\"Second Dimension\")\n axs[j].legend()\n axs[j].set_title(f\"{metric}\")\n# fig.suptitle(\"MDS of cell shapes using linear metric\", fontsize=20)\n\nplt.tight_layout()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_3D.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_3D.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-77-output-1.png){width=777 height=394}\n:::\n:::\n\n\n# Multi-class (3-class) classification \n\nWe now consider one cell line at a time, to investigate the effects of the drugs on the cell shapes. 
Applying the MDS again gives the following results:\n\nSince the detected subspace dimension for this dataset is 3, we perform the classification based on 3D embeddings.\n\n::: {#63cc430b .cell execution_count=77}\n``` {.python .cell-code}\nfrom sklearn.metrics import precision_score, recall_score, accuracy_score\n\ndef svm_5_fold_classification(X, y):\n # Initialize a Support Vector Classifier\n svm_classifier = svm.SVC(kernel='poly', degree=4)\n\n # Prepare to split the data into 5 folds, maintaining the percentage of samples for each class\n skf = StratifiedKFold(n_splits=5)\n \n # To store precision and recall per class for each fold\n precisions_per_class = []\n recalls_per_class = []\n accuracy_per_class = []\n\n # Perform 5-fold cross-validation\n for train_index, test_index in skf.split(X, y):\n # Splitting data into training and test sets\n X_train, X_test = X[train_index], X[test_index]\n y_train, y_test = y[train_index], y[test_index]\n\n # Train the model\n svm_classifier.fit(X_train, y_train)\n \n # Predict on the test data\n y_pred = svm_classifier.predict(X_test)\n\n # Calculate precision and recall per class\n precision = precision_score(y_test, y_pred, average=None, zero_division=np.nan)\n recall = recall_score(y_test, y_pred, average=None, zero_division=np.nan)\n accuracy = accuracy_score(y_test, y_pred)\n\n # Store results from each fold\n precisions_per_class.append(precision)\n recalls_per_class.append(recall)\n accuracy_per_class.append(accuracy)\n \n # Calculate the mean precision and recall per class across all folds\n mean_precisions = np.mean(precisions_per_class, axis=0)\n mean_recalls = np.mean(recalls_per_class, axis=0)\n mean_accuracies = np.mean(accuracy_per_class, axis=0)\n \n print(\"Mean precisions per class across all folds:\", round(np.mean(mean_precisions), 2))\n print(\"Mean recalls per class across all folds:\", round(np.mean(mean_recalls), 2))\n print(\"Mean accuracies per class across all folds:\", round(mean_accuracies, 
2))\n\n return mean_precisions, mean_recalls\n```\n:::\n\n\n::: {#89f6db0c .cell execution_count=78}\n``` {.python .cell-code}\nlines = gs.array(lines)\ntreatments = gs.array(treatments)\n```\n:::\n\n\n::: {#d6230677 .cell execution_count=79}\n``` {.python .cell-code}\nfor line in LINES:\n for metric in METRICS:\n control_indexes = gs.where((lines == line) & (treatments == \"control\"))[0]\n cytd_indexes = gs.where((lines == line) & (treatments == \"cytd\"))[0]\n jasp_indexes = gs.where((lines == line) & (treatments == \"jasp\"))[0]\n treatment_indexes = gs.where((lines == line) & (treatments != 'control'))[0]\n\n # indexes = gs.concatenate((jasp_indexes, cytd_indexes, control_indexes))\n indexes = gs.concatenate((control_indexes, treatment_indexes))\n matrix = pairwise_dists[metric][indexes][:, indexes]\n\n mds = manifold.MDS(n_components=2, random_state = 10, dissimilarity=\"precomputed\")\n pos = mds.fit(matrix).embedding_\n\n line_treatments = treatments[lines == line]\n line_treatments_strings, line_treatments_labels = np.unique(line_treatments, return_inverse=True)\n # print(line_treatments_strings)\n # print(line_treatments_labels)\n\n for i, label in enumerate(line_treatments_labels):\n if line_treatments_strings[label] == 'cytd' or line_treatments_strings[label] == 'jasp':\n line_treatments_labels[i] = len(line_treatments_strings)\n \n\n print(f\"Using {metric} on {line}\")\n # print(line_treatments_labels)\n svm_5_fold_classification(pos, line_treatments_labels)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nUsing SRV on dlm8\nMean precisions per class across all folds: 0.71\nMean recalls per class across all folds: 0.7\nMean accuracies per class across all folds: 0.69\nUsing Linear on dlm8\nMean precisions per class across all folds: 0.68\nMean recalls per class across all folds: 0.62\nMean accuracies per class across all folds: 0.6\nUsing SRV on dunn\nMean precisions per class across all folds: 0.73\nMean recalls per class across all folds: 
0.69\nMean accuracies per class across all folds: 0.7\nUsing Linear on dunn\nMean precisions per class across all folds: 0.62\nMean recalls per class across all folds: 0.59\nMean accuracies per class across all folds: 0.6\n```\n:::\n:::\n\n\n", + "markdown": "---\ntitle: Shape Analysis of Cancer Cells\nauthor: Wanxin Li\ndate: \"August 15, 2024\"\ncategories: [biology, bioinformatics] \n---\n\n\nThis notebook is adapted from [this notebook](https://github.com/geomstats/geomstats/blob/main/notebooks/11_real_world_applications__cell_shapes_analysis.ipynb) (Lead author: Nina Miolane). \n\nThis notebook studies *Osteosarcoma* (bone cancer) cells and the impact of drug treatment on their *morphological shapes*, by analyzing cell images obtained from fluorescence microscopy. \n\nThis analysis relies on the *elastic metric between discrete curves* from Geomstats. We will study to which extent this metric can detect how the cell shape is associated with the response to treatment.\n\nThis notebook is adapted from Florent Michel's submission to the [ICLR 2021 Computational Geometry and Topology challenge](https://github.com/geomstats/challenge-iclr-2021).\n\n
\n \n
\n\nFigure 1: Representative images of the cell lines using fluorescence microscopy, studied in this notebook (Image credit : Ashok Prasad). The cells nuclei (blue), the actin cytoskeleton (green) and the lipid membrane (red) of each cell are stained and colored. We only focus on the cell shape in our analysis.\n\n# 1. Introduction and Motivation\n\nBiological cells adopt a variety of shapes, determined by multiple processes and biophysical forces under the control of the cell. These shapes can be studied with different quantitative measures that reflect the cellular morphology [(MGCKCKDDRTWSBCC2018)](#References). With the emergence of large-scale biological cell image data, morphological studies have many applications. For example, measures of irregularity and spreading of cells allow accurate classification and discrimination between cancer cell lines treated with different drugs [(AXCFP2019)](#References).\n\nAs metrics defined on the shape space of curves, the *elastic metrics* [(SKJJ2010)](#References) implemented in Geomstats are a potential tool for analyzing and comparing biological cell shapes. 
Their associated geodesics and geodesic distances provide a natural framework for optimally matching, deforming, and comparing cell shapes.\n\n::: {#65cdabfd .cell execution_count=1}\n``` {.python .cell-code}\nfrom decimal import Decimal\nimport matplotlib.pyplot as plt\n\nimport geomstats.backend as gs\nimport numpy as np\nfrom common import *\nimport random\nimport os\nimport scipy.stats as stats\nfrom sklearn import manifold\n\ngs.random.seed(2021)\n```\n:::\n\n\n::: {#0d1fda83 .cell execution_count=2}\n``` {.python .cell-code}\nbase_path = \"/home/wanxinli/dyn/dyn/\"\ndata_path = os.path.join(base_path, \"datasets\")\n\ndataset_name = 'osteosarcoma'\nfigs_dir = os.path.join(\"/home/wanxinli/dyn/dyn/figs\", dataset_name)\nsavefig = False\n\n# If computing for the first time, we need to compute pairwise distances and run DeCOr-MDS\n# Otherwise, we can just use the pre-computed results\nfirst_time = False\nif savefig:\n print(f\"Will save figs to {figs_dir}\")\n```\n:::\n\n\n# 2. Dataset Description\n\nWe study a dataset of mouse *Osteosarcoma* imaged cells [(AXCFP2019)](#References). The dataset contains two different cancer cell lines: *DLM8* and *DUNN*, respectively representing a more aggressive and a less aggressive cancer. Among these cells, some have also been treated with different single drugs that perturb the cellular cytoskeleton. Overall, we can label each cell according to its cell line (*DLM8* or *DUNN*), and also according to whether it is a *control* cell (no treatment) or has been treated with one of the following drugs: *Jasp* (jasplakinolide) and *Cytd* (cytochalasin D).\n\nEach cell comes from a raw image containing a set of cells, which was thresholded to generate binarized images.\n\n\n \n\n\nAfter binarizing the images, contouring was used to isolate each cell, and to extract their boundaries as a counter-clockwise ordered list of 2D coordinates, which corresponds to the representation of a discrete curve in Geomstats. 
We load these discrete curves into the notebook.\n\n::: {#3436b16b .cell execution_count=3}\n``` {.python .cell-code}\nimport geomstats.datasets.utils as data_utils\n\ncells, lines, treatments = data_utils.load_cells()\nprint(f\"Total number of cells : {len(cells)}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nTotal number of cells : 650\n```\n:::\n:::\n\n\nThe cells are grouped by treatment class in the dataset : \n- the *control* cells, \n- the cells treated with *Cytd*,\n- and the ones treated with *Jasp*. \n\nAdditionally, in each of these classes, there are two cell lines : \n- the *DLM8* cells, and\n- the *DUNN* ones.\n\nBefore using the dataset, we check for duplicates in the dataset.\n\nWe compute the pairwise distance between two cells. If the pairwise distance is smaller than 0.1, we visualize the corresponding cells to check they are duplicates.\n\n::: {#75ab261b .cell execution_count=4}\n``` {.python .cell-code}\ntol = 1e-1\nfor i, cell_i in enumerate(cells):\n for j, cell_j in enumerate(cells):\n if i != j and cell_i.shape[0] == cell_j.shape[0]:\n dist = np.sum(np.sqrt(np.sum((cell_i-cell_j)**2,axis=1)))\n if dist < tol:\n print(f\"cell indices are: {i} and {j}, {lines[i]}, {lines[j]}, {treatments[i]}, {treatments[j]}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncell indices are: 363 and 396, dlm8, dlm8, cytd, cytd\ncell indices are: 396 and 363, dlm8, dlm8, cytd, cytd\ncell indices are: 513 and 519, dlm8, dlm8, jasp, jasp\ncell indices are: 519 and 513, dlm8, dlm8, jasp, jasp\n```\n:::\n:::\n\n\n::: {#b446389f .cell execution_count=5}\n``` {.python .cell-code}\npair_indices = [363, 396]\n\nfig = plt.figure(figsize=(10, 5))\nfig.add_subplot(121)\nindex_0 = pair_indices[0]\nplt.scatter(cells[index_0][:, 0], cells[index_0][:, 1], s=4)\nplt.axis(\"equal\")\nplt.title(f\"Cell {index_0}\")\n\nfig.add_subplot(122)\nindex_1 = pair_indices[1]\nplt.scatter(cells[index_1][:, 0], cells[index_1][:, 1], 
s=4)\nplt.axis(\"equal\")\nplt.title(f\"Cell {index_1}\")\n```\n\n::: {.cell-output .cell-output-display execution_count=5}\n```\nText(0.5, 1.0, 'Cell 396')\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-6-output-2.png){width=809 height=431}\n:::\n:::\n\n\n::: {#493c53c2 .cell execution_count=6}\n``` {.python .cell-code}\npair_indices = [513, 519]\n\nfig = plt.figure(figsize=(10, 5))\nfig.add_subplot(121)\nindex_0 = pair_indices[0]\nplt.scatter(cells[index_0][:, 0], cells[index_0][:, 1], s=4)\nplt.axis(\"equal\")\nplt.title(f\"Cell {index_0}\")\n\nfig.add_subplot(122)\nindex_1 = pair_indices[1]\nplt.scatter(cells[index_1][:, 0], cells[index_1][:, 1], s=4)\nplt.axis(\"equal\")\nplt.title(f\"Cell {index_1}\")\n```\n\n::: {.cell-output .cell-output-display execution_count=6}\n```\nText(0.5, 1.0, 'Cell 519')\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-7-output-2.png){width=809 height=431}\n:::\n:::\n\n\nWe check the category indices in order to remove the corresponding cells in `ds_align`.\n\n::: {#0754d105 .cell execution_count=7}\n``` {.python .cell-code}\ndelete_indices = [363, 396, 513, 519]\ncategory_count = {}\nglobal_count = 0\nfor i in range(len(cells)):\n treatment = treatments[i]\n line = lines[i]\n if treatment not in category_count:\n category_count[treatment] = {}\n if line not in category_count[treatment]:\n category_count[treatment][line] = 0\n # if global_count in delete_indices:\n # print(treatment, line, category_count[treatment][line])\n category_count[treatment][line] += 1\n global_count += 1\n```\n:::\n\n\nSince cells 363 and 396, and cells 513 and 519, are duplicates of each other, and visualization shows that they are poor-quality cells overlapping with adjacent cells, we remove them from our dataset. 
\n\n::: {#6a24c758 .cell execution_count=8}\n``` {.python .cell-code}\ndef remove_cells(cells, lines, treatments, delete_indices):\n \"\"\" \n Remove the cells at delete_indices from cells, lines and treatments\n\n :param list[int] delete_indices: the indices to delete\n \"\"\"\n delete_indices = sorted(delete_indices, reverse=True) # to prevent change in index when deleting elements\n \n # Delete elements\n cells = del_arr_elements(cells, delete_indices)\n lines = list(np.delete(np.array(lines), delete_indices, axis=0))\n treatments = list(np.delete(np.array(treatments), delete_indices, axis=0))\n\n return cells, lines, treatments\n```\n:::\n\n\n::: {#dca3535d .cell execution_count=9}\n``` {.python .cell-code}\ndelete_indices = [363, 396, 513, 519]\ncells, lines, treatments = remove_cells(cells, lines, treatments, delete_indices)\n# print(len(cells), len(lines), len(treatments))\n```\n:::\n\n\nThe treatment classes and cell lines are listed by displaying the unique elements in the lists `treatments` and `lines`:\n\n::: {#b252c007 .cell execution_count=10}\n``` {.python .cell-code}\nimport pandas as pd\n\nTREATMENTS = gs.unique(treatments)\nprint(TREATMENTS)\nLINES = gs.unique(lines)\nprint(LINES)\nMETRICS = ['SRV', 'Linear']\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n['control' 'cytd' 'jasp']\n['dlm8' 'dunn']\n```\n:::\n:::\n\n\nThe size of each class is displayed below:\n\n::: {#bb752d7a .cell execution_count=11}\n``` {.python .cell-code}\nds = {}\n\nn_cells_arr = gs.zeros((3, 2))\n\nfor i, treatment in enumerate(TREATMENTS):\n print(f\"{treatment} :\")\n ds[treatment] = {}\n for j, line in enumerate(LINES):\n to_keep = gs.array(\n [\n one_treatment == treatment and one_line == line\n for one_treatment, one_line in zip(treatments, lines)\n ]\n )\n ds[treatment][line] = [\n cell_i for cell_i, to_keep_i in zip(cells, to_keep) if to_keep_i\n ]\n nb = len(ds[treatment][line])\n print(f\"\\t {nb} {line}\")\n n_cells_arr[i, j] = nb\n\nn_cells_df = pd.DataFrame({\"dlm8\": n_cells_arr[:, 0], 
\"dunn\": n_cells_arr[:, 1]})\nn_cells_df = n_cells_df.set_index(TREATMENTS)\n\ndisplay(n_cells_df)\n# display(ds)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncontrol :\n\t 114 dlm8\n\t 204 dunn\ncytd :\n\t 80 dlm8\n\t 93 dunn\njasp :\n\t 60 dlm8\n\t 95 dunn\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n
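<table>\n <thead>\n <tr><th></th><th>dlm8</th><th>dunn</th></tr>\n </thead>\n <tbody>\n <tr><th>control</th><td>114.0</td><td>204.0</td></tr>\n <tr><th>cytd</th><td>80.0</td><td>93.0</td></tr>\n <tr><th>jasp</th><td>60.0</td><td>95.0</td></tr>\n </tbody>\n</table>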
\n```\n:::\n:::\n\n\nWe have organized the cell data into the dictionary `ds`. Before proceeding to the actual data analysis, we provide an auxiliary function `apply_func_to_ds`.\n\n::: {#24e6d433 .cell execution_count=12}\n``` {.python .cell-code}\ndef apply_func_to_ds(input_ds, func):\n \"\"\"Apply the input function func to the input dictionary input_ds.\n\n This function goes through the dictionary structure and applies\n func to every cell in input_ds[treatment][line].\n\n It stores the result in a dictionary output_ds that is returned\n to the user.\n\n Parameters\n ----------\n input_ds : dict\n Input dictionary, with keys treatment-line.\n func : callable\n Function to be applied to the values of the dictionary, i.e.\n the cells.\n\n Returns\n -------\n output_ds : dict\n Output dictionary, with the same keys as input_ds.\n \"\"\"\n output_ds = {}\n for treatment in TREATMENTS:\n output_ds[treatment] = {}\n for line in LINES:\n output_list = []\n for one_cell in input_ds[treatment][line]:\n output_list.append(func(one_cell))\n output_ds[treatment][line] = gs.array(output_list)\n return output_ds\n```\n:::\n\n\nNow we can move on to the actual data analysis, starting with a preprocessing of the cell boundaries.\n\n# 3. 
Preprocessing \n\n### Interpolation: Encoding Discrete Curves With Same Number of Points\n\nAs we need discrete curves with the same number of sampled points to compute pairwise distances, the following interpolation is applied to each curve, after setting the number of sampling points.\n\nTo set the number of sampling points, edit the last line (`k_sampling_points`) of the next cell:\n\n::: {#362fd3b3 .cell execution_count=13}\n``` {.python .cell-code}\ndef interpolate(curve, nb_points):\n \"\"\"Resample a discrete curve to nb_points points by linear interpolation.\n\n Returns\n -------\n interpolation : discrete curve with nb_points points\n \"\"\"\n old_length = curve.shape[0]\n interpolation = gs.zeros((nb_points, 2))\n incr = old_length / nb_points\n pos = 0\n for i in range(nb_points):\n index = int(gs.floor(pos))\n interpolation[i] = curve[index] + (pos - index) * (\n curve[(index + 1) % old_length] - curve[index]\n )\n pos += incr\n return interpolation\n\n\nk_sampling_points = 2000\n```\n:::\n\n\nTo illustrate the result of this interpolation, we compare, for one cell, the original curve with the corresponding interpolated one (to visualize another cell, change `index` and re-run the code).\n\n::: {#9f4adad4 .cell execution_count=14}\n``` {.python .cell-code}\nindex = 0\ncell_rand = cells[index]\ncell_interpolation = interpolate(cell_rand, k_sampling_points)\n\nfig = plt.figure(figsize=(15, 5))\n\nfig.add_subplot(121)\nplt.scatter(cell_rand[:, 0], cell_rand[:, 1], color='black', s=4)\n\nplt.plot(cell_rand[:, 0], cell_rand[:, 1])\nplt.axis(\"equal\")\nplt.title(f\"Original curve ({len(cell_rand)} points)\")\nplt.axis(\"off\")\n\nfig.add_subplot(122)\nplt.scatter(cell_interpolation[:, 0], cell_interpolation[:, 1], color='black', s=4)\n\nplt.plot(cell_interpolation[:, 0], cell_interpolation[:, 1])\nplt.axis(\"equal\")\nplt.title(f\"Interpolated curve ({k_sampling_points} points)\")\nplt.axis(\"off\")\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, 
\"interpolation.svg\"))\n plt.savefig(os.path.join(figs_dir, \"interpolation.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-15-output-1.png){width=1135 height=409}\n:::\n:::\n\n\nAs the interpolation is working as expected, we use the auxiliary function `apply_func_to_ds` to apply the function `func=interpolate` to the dataset `ds`, i.e. the dictionary containing the cell boundaries.\n\nWe obtain a new dictionary, `ds_interp`, with the interpolated cell boundaries.\n\n::: {#fc2eec09 .cell execution_count=15}\n``` {.python .cell-code}\nds_interp = apply_func_to_ds(\n input_ds=ds, func=lambda x: interpolate(x, k_sampling_points)\n)\n```\n:::\n\n\nThe shape of an array of cells in `ds_interp[treatment][line]` is therefore: `(\"number of cells in treatment-line\", \"number of sampling points\", 2)`, where 2 refers to the fact that we are considering cell shapes in 2D. \n\n### Visualization of Interpolated Dataset of Curves\n\nWe visualize the curves obtained, for a sample of control cells and treated cells (top row shows control, i.e. 
non-treated cells; bottom rows show treated cells) across cell lines (left and blue for dlm8 and right and orange for dunn).\n\n::: {#656ba711 .cell execution_count=16}\n``` {.python .cell-code}\nn_cells_to_plot = 5\n# radius = 800\n\nfig = plt.figure(figsize=(16, 6))\ncount = 1\nfor i, treatment in enumerate(TREATMENTS):\n for line in LINES:\n cell_data = ds_interp[treatment][line]\n for i_to_plot in range(n_cells_to_plot):\n cell = gs.random.choice(cell_data)\n fig.add_subplot(3, 2 * n_cells_to_plot, count)\n count += 1\n plt.plot(cell[:, 0], cell[:, 1], color=\"C\" + str(i))\n # plt.xlim(-radius, radius)\n # plt.ylim(-radius, radius)\n plt.axis(\"equal\")\n plt.axis(\"off\")\n if i_to_plot == n_cells_to_plot // 2:\n plt.title(f\"{treatment} - {line}\", fontsize=20)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, \"sample_cells.svg\"))\n plt.savefig(os.path.join(figs_dir, \"sample_cells.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-17-output-1.png){width=1210 height=491}\n:::\n:::\n\n\nVisual inspection of these curves seems to indicate more protrusions appearing in treated cells, compared with control ones. This is in agreement with the physiological impact of the drugs, which are known to perturb the internal cytoskeleton connected to the cell membrane. Using the elastic metric, our goal will be to see if we can quantitatively confirm these differences.\n\n### Remove duplicate sample points in curves\n\nDuring interpolation, some of the discrete curves in the dataset are likely downsampled from a higher to a lower number of points. Two sampled points that are close enough may therefore end up overlapping after interpolation, and such points have to be handled specifically. 
\n\n::: {#c6bc6e60 .cell execution_count=17}\n``` {.python .cell-code}\nimport numpy as np\n\ndef preprocess(curve, tol=1e-10):\n \"\"\"Preprocess curve to ensure that there are no consecutive duplicate points.\n\n Returns\n -------\n curve : discrete curve\n \"\"\"\n\n dist = curve[1:] - curve[:-1]\n dist_norm = np.sqrt(np.sum(np.square(dist), axis=1))\n\n if np.any( dist_norm < tol ):\n for i in range(len(curve)-1):\n if np.sqrt(np.sum(np.square(curve[i+1] - curve[i]), axis=0)) < tol:\n # Move the duplicate to the midpoint of its neighbors; wrap the index since the curve is closed\n curve[i+1] = (curve[i] + curve[(i+2) % len(curve)]) / 2\n\n return curve\n```\n:::\n\n\n::: {#fc5fd212 .cell execution_count=18}\n``` {.python .cell-code}\nds_proc = apply_func_to_ds(ds_interp, func=lambda x: preprocess(x))\n```\n:::\n\n\nCheck that we did not lose any cells after handling duplicate points\n\n::: {#246ee9ca .cell execution_count=19}\n``` {.python .cell-code}\nfor treatment in TREATMENTS:\n for line in LINES:\n for metric in METRICS:\n print(f\"{treatment} and {line} using {metric}: {len(ds_proc[treatment][line])}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncontrol and dlm8 using SRV: 114\ncontrol and dlm8 using Linear: 114\ncontrol and dunn using SRV: 204\ncontrol and dunn using Linear: 204\ncytd and dlm8 using SRV: 80\ncytd and dlm8 using Linear: 80\ncytd and dunn using SRV: 93\ncytd and dunn using Linear: 93\njasp and dlm8 using SRV: 60\njasp and dlm8 using Linear: 60\njasp and dunn using SRV: 95\njasp and dunn using Linear: 95\n```\n:::\n:::\n\n\n### Alignment\n\nOur goal is to study the cell boundaries in our dataset, as points in a shape space of closed curves quotiented by translation, scaling, and rotation, so these transformations do not affect our measure of distance between curves.\n\nIn practice, we apply functions that were initially designed to center (subtract the barycenter), rescale (divide by the Frobenius norm) and then reparameterize (only for the SRV metric).\n\nSince the alignment procedure takes 30 minutes, we ran `osteosarocoma_align.py` and saved the results in 
`~/dyn/datasets/osteosarcoma/aligned`\n\nLoad aligned cells from txt files. These files were generated by calling the `align` function in `common.py`.\n\nWe get the aligned cells from the preprocessed dataset.\n\nFurthermore, we align the barycenters of the cells to the barycenter of the projected base curve, and (optionally) flip the cell.\n\n::: {#9314e9ec .cell execution_count=20}\n``` {.python .cell-code}\ndef align_barycenter(cell, centroid_x, centroid_y, flip):\n \"\"\" \n Align the barycenter of the cell to the reference centroid and flip the cell about the vertical line x = centroid_x if flip is True. \n\n :param 2D np array cell: cell to align\n :param float centroid_x: the x coordinate of the centroid of the projected BASE_CURVE\n :param float centroid_y: the y coordinate of the centroid of the projected BASE_CURVE\n :param bool flip: flip the cell against x = centroid x if True \n \"\"\"\n \n cell_bc = np.mean(cell, axis=0)\n aligned_cell = cell+[centroid_x, centroid_y]-cell_bc\n\n if flip:\n aligned_cell[:, 0] = 2*centroid_x-aligned_cell[:, 0]\n # Flip the order of the points\n med_index = int(np.floor(aligned_cell.shape[0]/2))\n flipped_aligned_cell = np.concatenate((aligned_cell[med_index:], aligned_cell[:med_index]), axis=0)\n flipped_aligned_cell = np.flipud(flipped_aligned_cell)\n aligned_cell = flipped_aligned_cell\n return aligned_cell\n\ndef get_centroid(base_curve):\n total_space = DiscreteCurvesStartingAtOrigin(k_sampling_points=k_sampling_points)\n proj_base_curve = total_space.projection(base_curve)\n base_centroid = np.mean(proj_base_curve, axis=0)\n return base_centroid[0], base_centroid[1]\n```\n:::\n\n\n::: {#65ea290b .cell execution_count=21}\n``` {.python .cell-code}\ndelete_indices = [363, 396, 513, 519]\n\naligned_base_folder = os.path.join(data_path, dataset_name, \"aligned\")\n\nBASE_CURVE = generate_ellipse(k_sampling_points)\ncentroid_x, centroid_y = get_centroid(BASE_CURVE)\n\nds_align = {}\n\nfor metric in METRICS:\n ds_align[metric] = {}\n if metric == 'SRV':\n 
aligned_folder = os.path.join(aligned_base_folder, 'projection_rescale_rotation_reparameterization')\n elif metric == 'Linear':\n aligned_folder = os.path.join(aligned_base_folder, 'projection_rescale_rotation_reparameterization')\n for treatment in TREATMENTS:\n ds_align[metric][treatment] = {}\n for line in LINES:\n ds_align[metric][treatment][line] = []\n cell_num = len(ds_proc[treatment][line])\n if line == 'dlm8' and (treatment == 'cytd' or treatment == 'jasp'):\n cell_num += 2\n for i in range(cell_num):\n # Do not load duplicate cells\n # cytd dlm8 45\n # cytd dlm8 78\n # jasp dlm8 20\n # jasp dlm8 26\n\n if (treatment == 'cytd' and line == 'dlm8' and (i == 45 or i == 78)) or \\\n (treatment == 'jasp' and line == 'dlm8' and (i == 20 or i == 26)):\n continue\n \n file_path = os.path.join(aligned_folder, f\"{treatment}_{line}_{i}.txt\")\n if os.path.exists(file_path):\n cell = np.loadtxt(file_path)\n ds_align[metric][treatment][line].append(cell)\n\n```\n:::\n\n\nCheck how many cells remain after alignment\n\n::: {#c079e116 .cell execution_count=22}\n``` {.python .cell-code}\nfor treatment in TREATMENTS:\n for line in LINES:\n for metric in METRICS:\n print(f\"{treatment} and {line} using {metric}: {len(ds_align[metric][treatment][line])}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncontrol and dlm8 using SRV: 113\ncontrol and dlm8 using Linear: 113\ncontrol and dunn using SRV: 199\ncontrol and dunn using Linear: 199\ncytd and dlm8 using SRV: 74\ncytd and dlm8 using Linear: 74\ncytd and dunn using SRV: 92\ncytd and dunn using Linear: 92\njasp and dlm8 using SRV: 56\njasp and dlm8 using Linear: 56\njasp and dunn using SRV: 91\njasp and dunn using Linear: 91\n```\n:::\n:::\n\n\nUpdate `lines` and `treatments`\n\n::: {#fba9c599 .cell execution_count=23}\n``` {.python .cell-code}\ntreatments = []\nlines = []\nfor treatment in TREATMENTS:\n for line in LINES:\n treatments.extend([treatment]*len(ds_align['SRV'][treatment][line]))\n 
lines.extend([line]*len(ds_align['SRV'][treatment][line]))\n\ntreatments = np.array(treatments)\nlines = np.array(lines)\nprint(\"treatment length is:\", len(treatments), \"lines length is:\", len(lines))\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ntreatment length is: 625 lines length is: 625\n```\n:::\n:::\n\n\nVisualize reference cell, unaligned cell and aligned cell.\n\n::: {#aa91c348 .cell execution_count=24}\n``` {.python .cell-code}\nindex = 0\nmetric = 'SRV'\nunaligned_cell = ds_proc[\"control\"][\"dlm8\"][index]\naligned_cell = ds_align[metric][\"control\"][\"dlm8\"][index]\n\nfirst_round_aligned_folder = os.path.join(aligned_base_folder, 'projection_rescale_rotation_reparameterization_first_round')\nreference_path = os.path.join(first_round_aligned_folder, f\"reference.txt\")\nmean_first_round = np.loadtxt(reference_path)\n\nfig = plt.figure(figsize=(15, 5))\n\nfig.add_subplot(131)\nplt.plot(mean_first_round[:, 0], mean_first_round[:, 1])\nplt.plot([mean_first_round[-1, 0], mean_first_round[0, 0]], [mean_first_round[-1, 1], mean_first_round[0, 1]], 'tab:blue')\nplt.scatter(mean_first_round[:, 0], mean_first_round[:, 1], s=4, c='black')\nplt.plot(mean_first_round[0, 0], mean_first_round[0, 1], \"ro\")\nplt.axis(\"equal\")\nplt.title(\"Reference curve\")\n\nfig.add_subplot(132)\nplt.plot(unaligned_cell[:, 0], unaligned_cell[:, 1])\nplt.scatter(unaligned_cell[:, 0], unaligned_cell[:, 1], s=4, c='black')\nplt.plot(unaligned_cell[0, 0], unaligned_cell[0, 1], \"ro\")\nplt.axis(\"equal\")\nplt.title(\"Unaligned curve\")\n\nfig.add_subplot(133)\nplt.plot(aligned_cell[:, 0], aligned_cell[:, 1])\nplt.scatter(aligned_cell[:, 0], aligned_cell[:, 1], s=4, c='black')\nplt.plot(aligned_cell[0, 0], aligned_cell[0, 1], \"ro\")\nplt.axis(\"equal\")\nplt.title(\"Aligned curve\")\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, \"alignment.svg\"))\n plt.savefig(os.path.join(figs_dir, \"alignment.pdf\"))\n```\n\n::: {.cell-output 
.cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-25-output-1.png){width=1185 height=431}\n:::\n:::\n\n\nIn the plot above, the red dot shows the start of the parametrization of each curve. The right curve has been rotated from the curve in the middle to align with the left (reference) curve, which represents the first cell of the dataset. The starting point (in red) of the right curve has also been set to align with the reference.\n\n# 4. Data Analysis\n\n## Compute Mean Cell Shape of the Whole Dataset: \"Global\" Mean Shape\n\nWe want to compute the mean cell shape of the whole dataset. Thus, we first combine all the cell shape data into a single array.\n\n::: {#22422a7d .cell execution_count=25}\n``` {.python .cell-code}\nCURVES_SPACE_SRV = DiscreteCurvesStartingAtOrigin(ambient_dim=2, k_sampling_points=k_sampling_points)\n```\n:::\n\n\n::: {#bf629086 .cell execution_count=26}\n``` {.python .cell-code}\ncell_shapes_list = {}\nfor metric in METRICS:\n cell_shapes_list[metric] = []\n for treatment in TREATMENTS:\n for line in LINES:\n cell_shapes_list[metric].extend(ds_align[metric][treatment][line])\n\ncell_shapes = {}\nfor metric in METRICS:\n cell_shapes[metric] = gs.array(cell_shapes_list[metric])\nprint(cell_shapes['SRV'].shape)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(625, 1999, 2)\n```\n:::\n:::\n\n\nRemove outliers using DeCOr-MDS, together for DUNN and DLM8 cell lines.\n\n::: {#9b93d992 .cell execution_count=27}\n``` {.python .cell-code}\ndef linear_dist(cell1, cell2):\n return gs.linalg.norm(cell1 - cell2)\n\ndef srv_dist(cell1, cell2):\n CURVES_SPACE_SRV.equip_with_metric(SRVMetric)\n return CURVES_SPACE_SRV.metric.dist(cell1, cell2)\n \n# compute pairwise distances, we only need to compute it once and save the results \npairwise_dists = {}\n\nif first_time:\n metric = 'SRV'\n pairwise_dists[metric] = parallel_dist(cell_shapes[metric], srv_dist, k_sampling_points)\n\n metric = 'Linear' \n pairwise_dists[metric] = 
parallel_dist(cell_shapes[metric], linear_dist, k_sampling_points)\n\n for metric in METRICS:\n np.savetxt(os.path.join(data_path, dataset_name, \"distance_matrix\", f\"{metric}_matrix.txt\"), pairwise_dists[metric])\nelse:\n for metric in METRICS:\n pairwise_dists[metric] = np.loadtxt(os.path.join(data_path, dataset_name, \"distance_matrix\", f\"{metric}_matrix.txt\"))\n```\n:::\n\n\n::: {#0a21085e .cell execution_count=28}\n``` {.python .cell-code}\n# to remove 132 and 199\none_cell = cell_shapes['Linear'][199]\nplt.plot(one_cell[:, 0], one_cell[:, 1], c=f\"gray\")\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-29-output-1.png){width=599 height=411}\n:::\n:::\n\n\n::: {#45b62d1e .cell execution_count=29}\n``` {.python .cell-code}\n# run DeCOr-MDS\nmetric = 'SRV'\ndim_start = 2 # the subspace dimension is known to be 3; set dim_start and dim_end to 3 (see below) to reduce runtime\ndim_end = 10\n# dim_start = 3\n# dim_end = 3\nstd_multi = 1\nif first_time:\n subspace_dim, outlier_indices = find_subspace_dim(pairwise_dists[metric], dim_start, dim_end, std_multi)\n print(f\"subspace dimension is: {subspace_dim}\")\n print(f\"outlier_indices are: {outlier_indices}\")\n```\n:::\n\n\nVisualize outlier cells to see if they are artifacts\n\n::: {#700d9f1f .cell execution_count=30}\n``` {.python .cell-code}\nif first_time:\n fig, axes = plt.subplots(\n nrows= 1,\n ncols=len(outlier_indices),\n figsize=(2*len(outlier_indices), 2),\n )\n\n for i, outlier_index in enumerate(outlier_indices):\n one_cell = cell_shapes[metric][outlier_index]\n ax = axes[i]\n ax.plot(one_cell[:, 0], one_cell[:, 1], c=f\"C{i}\")\n ax.set_title(f\"{outlier_index}\", fontsize=14)\n # Turn off tick labels\n ax.set_yticklabels([])\n ax.set_xticklabels([])\n ax.set_xticks([])\n ax.set_yticks([])\n ax.spines[\"top\"].set_visible(False)\n ax.spines[\"right\"].set_visible(False)\n ax.spines[\"bottom\"].set_visible(False)\n ax.spines[\"left\"].set_visible(False)\n\n 
plt.tight_layout()\n plt.suptitle(f\"\", y=-0.01, fontsize=24)\n # plt.savefig(os.path.join(figs_dir, \"outlier.svg\"))\n```\n:::\n\n\n::: {#af492cf5 .cell execution_count=31}\n``` {.python .cell-code}\ndelete_indices = [132, 199]\n\n\nfig, axes = plt.subplots(\n nrows= 1,\n ncols=len(delete_indices),\n figsize=(2*len(delete_indices), 2),\n)\n\n\nfor i, outlier_index in enumerate(delete_indices):\n one_cell = cell_shapes[metric][outlier_index]\n ax = axes[i]\n ax.plot(one_cell[:, 0], one_cell[:, 1], c=f\"gray\")\n ax.set_title(f\"{outlier_index}\", fontsize=14)\n # ax.axis(\"off\")\n # Turn off tick labels\n ax.set_yticklabels([])\n ax.set_xticklabels([])\n ax.set_xticks([])\n ax.set_yticks([])\n ax.spines[\"top\"].set_visible(False)\n ax.spines[\"right\"].set_visible(False)\n ax.spines[\"bottom\"].set_visible(False)\n ax.spines[\"left\"].set_visible(False)\n\nplt.tight_layout()\nplt.suptitle(f\"\", y=-0.01, fontsize=24)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, \"delete_outlier.svg\"))\n plt.savefig(os.path.join(figs_dir, \"delete_outlier.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-32-output-1.png){width=374 height=182}\n:::\n:::\n\n\nAfter visual inspection, we decide to remove the outlier cells\n\n::: {#18d83448 .cell execution_count=32}\n``` {.python .cell-code}\ndef remove_ds_two_layer(ds, delete_indices):\n global_i = sum(len(v) for values in ds.values() for v in values.values())-1\n\n for treatment in reversed(list(ds.keys())):\n treatment_values = ds[treatment]\n for line in reversed(list(treatment_values.keys())):\n line_cells = treatment_values[line]\n for i, _ in reversed(list(enumerate(line_cells))):\n if global_i in delete_indices:\n print(np.array(ds[treatment][line][:i]).shape, np.array(ds[treatment][line][i+1:]).shape)\n if len(np.array(ds[treatment][line][:i]).shape) == 1:\n ds[treatment][line] = np.array(ds[treatment][line][i+1:])\n elif 
len(np.array(ds[treatment][line][i+1:]).shape) == 1:\n ds[treatment][line] = np.array(ds[treatment][line][:i])\n else:\n ds[treatment][line] = np.concatenate((np.array(ds[treatment][line][:i]), np.array(ds[treatment][line][i+1:])), axis=0) \n global_i -= 1\n return ds\n\n\n\ndef remove_cells_two_layer(cells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align, delete_indices):\n \"\"\" \n Remove cells of control group from cells, cell_shapes, ds,\n the parameters returned from load_treated_osteosarcoma_cells\n Also update n_cells\n\n :param list[int] delete_indices: the indices to delete\n \"\"\"\n delete_indices = sorted(delete_indices, reverse=True) # to prevent change in index when deleting elements\n \n # Delete elements\n cells = del_arr_elements(cells, delete_indices) \n lines = list(np.delete(np.array(lines), delete_indices, axis=0))\n treatments = list(np.delete(np.array(treatments), delete_indices, axis=0))\n ds_proc = remove_ds_two_layer(ds_proc, delete_indices)\n \n for metric in METRICS:\n cell_shapes[metric] = np.delete(np.array(cell_shapes[metric]), delete_indices, axis=0)\n ds_align[metric] = remove_ds_two_layer(ds_align[metric], delete_indices)\n pairwise_dists[metric] = np.delete(pairwise_dists[metric], delete_indices, axis=0)\n pairwise_dists[metric] = np.delete(pairwise_dists[metric], delete_indices, axis=1)\n\n\n return cells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align\n```\n:::\n\n\n::: {#b758486b .cell execution_count=33}\n``` {.python .cell-code}\ncells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align = remove_cells_two_layer(cells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align, delete_indices)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(85, 2000, 2) (118, 2000, 2)\n(18, 2000, 2) (184, 2000, 2)\n(86, 1999, 2) (112, 1999, 2)\n(19, 1999, 2) (178, 1999, 2)\n(86, 1999, 2) (112, 1999, 2)\n(19, 1999, 2) (178, 1999, 2)\n```\n:::\n:::\n\n\nCheck that we did not lose any 
other cells after the removal\n\n::: {#e62179b1 .cell execution_count=34}\n``` {.python .cell-code}\ndef check_num(cell_shapes, treatments, lines, pairwise_dists, ds_align):\n \n print(f\"treatments number is: {len(treatments)}, lines number is: {len(lines)}\")\n for metric in METRICS:\n print(f\"pairwise_dists for {metric} shape is: {pairwise_dists[metric].shape}\")\n print(f\"cell_shapes for {metric} number is : {len(cell_shapes[metric])}\")\n \n for line in LINES:\n for treatment in TREATMENTS:\n print(f\"ds_align {treatment} {line} using {metric}: {len(ds_align[metric][treatment][line])}\")\n```\n:::\n\n\n::: {#2c23f486 .cell execution_count=35}\n``` {.python .cell-code}\ncheck_num(cell_shapes, treatments, lines, pairwise_dists, ds_align)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ntreatments number is: 623, lines number is: 623\npairwise_dists for SRV shape is: (623, 623)\ncell_shapes for SRV number is : 623\nds_align control dlm8 using SRV: 113\nds_align cytd dlm8 using SRV: 74\nds_align jasp dlm8 using SRV: 56\nds_align control dunn using SRV: 197\nds_align cytd dunn using SRV: 92\nds_align jasp dunn using SRV: 91\npairwise_dists for Linear shape is: (623, 623)\ncell_shapes for Linear number is : 623\nds_align control dlm8 using Linear: 113\nds_align cytd dlm8 using Linear: 74\nds_align jasp dlm8 using Linear: 56\nds_align control dunn using Linear: 197\nds_align cytd dunn using Linear: 92\nds_align jasp dunn using Linear: 91\n```\n:::\n:::\n\n\nWe compute the mean cell shape by using the SRV metric defined on the space of curves' shapes. 
The space of curves' shapes is a manifold: we use the Fréchet mean, associated with the SRV metric, to get the mean cell shape.\n\nDo not include cells with duplicate points when calculating the mean shapes\n\n::: {#032ae826 .cell execution_count=36}\n``` {.python .cell-code}\ndef check_duplicate(cell):\n \"\"\" \n Return True if the cell has consecutive duplicate points\n \"\"\"\n for i in range(cell.shape[0]-1):\n cur_coord = cell[i]\n next_coord = cell[i+1]\n if np.linalg.norm(cur_coord-next_coord) == 0:\n return True\n \n # Check the last point against the first point\n if np.linalg.norm(cell[-1]-cell[0]) == 0:\n return True\n \n return False\n```\n:::\n\n\n::: {#def63565 .cell execution_count=37}\n``` {.python .cell-code}\ndelete_indices = []\nfor metric in METRICS:\n for i, cell in reversed(list(enumerate(cell_shapes[metric]))):\n if check_duplicate(cell):\n if i not in delete_indices:\n delete_indices.append(i)\n\n\ncells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align = \\\n remove_cells_two_layer(cells, cell_shapes, lines, treatments, pairwise_dists, ds_proc, ds_align, delete_indices)\n\n```\n:::\n\n\nRecheck the number of cells after removing cells with duplicate points\n\n::: {#95e617c9 .cell execution_count=38}\n``` {.python .cell-code}\ncheck_num(cell_shapes, treatments, lines, pairwise_dists, ds_align)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ntreatments number is: 623, lines number is: 623\npairwise_dists for SRV shape is: (623, 623)\ncell_shapes for SRV number is : 623\nds_align control dlm8 using SRV: 113\nds_align cytd dlm8 using SRV: 74\nds_align jasp dlm8 using SRV: 56\nds_align control dunn using SRV: 197\nds_align cytd dunn using SRV: 92\nds_align jasp dunn using SRV: 91\npairwise_dists for Linear shape is: (623, 623)\ncell_shapes for Linear number is : 623\nds_align control dlm8 using Linear: 113\nds_align cytd dlm8 using Linear: 74\nds_align jasp dlm8 using Linear: 56\nds_align control dunn using Linear: 197\nds_align cytd dunn 
using Linear: 92\nds_align jasp dunn using Linear: 91\n```\n:::\n:::\n\n\n::: {#c053fc4b .cell execution_count=39}\n``` {.python .cell-code}\nfrom geomstats.learning.frechet_mean import FrechetMean\n\nmetric = 'SRV'\nCURVES_SPACE_SRV = DiscreteCurvesStartingAtOrigin(ambient_dim=2, k_sampling_points=k_sampling_points)\nmean = FrechetMean(CURVES_SPACE_SRV)\nprint(cell_shapes[metric].shape)\ncells = cell_shapes[metric]\nmean.fit(cells)\n\nmean_estimate = mean.estimate_\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(623, 1999, 2)\n```\n:::\n:::\n\n\n::: {#06249ecf .cell execution_count=40}\n``` {.python .cell-code}\nmean_estimate_aligned = {}\n\nmean_estimate_clean = mean_estimate[~gs.isnan(gs.sum(mean_estimate, axis=1)), :]\nmean_estimate_aligned[metric] = (\n mean_estimate_clean - gs.mean(mean_estimate_clean, axis=0)\n)\n```\n:::\n\n\nWe also compute the linear mean\n\n::: {#95edd262 .cell execution_count=41}\n``` {.python .cell-code}\nmetric = 'Linear'\nlinear_mean_estimate = gs.mean(cell_shapes[metric], axis=0)\nlinear_mean_estimate_clean = linear_mean_estimate[~gs.isnan(gs.sum(linear_mean_estimate, axis=1)), :]\n\nmean_estimate_aligned[metric] = (\n linear_mean_estimate_clean - gs.mean(linear_mean_estimate_clean, axis=0)\n)\n```\n:::\n\n\nPlot the SRV mean cell versus the linear mean cell\n\n::: {#7f2e6375 .cell execution_count=42}\n``` {.python .cell-code}\nfig = plt.figure(figsize=(6, 3))\n\nfig.add_subplot(121)\nmetric = 'SRV'\nplt.plot(mean_estimate_aligned[metric][:, 0], mean_estimate_aligned[metric][:, 1])\nplt.axis(\"equal\")\nplt.title(\"SRV\")\nplt.axis(\"off\")\n\nfig.add_subplot(122)\nmetric = 'Linear'\nplt.plot(mean_estimate_aligned[metric][:, 0], mean_estimate_aligned[metric][:, 1])\nplt.axis(\"equal\")\nplt.title(\"Linear\")\nplt.axis(\"off\")\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, \"global_mean.svg\"))\n plt.savefig(os.path.join(figs_dir, \"global_mean.pdf\"))\n```\n\n::: {.cell-output 
.cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-43-output-1.png){width=466 height=261}\n:::\n:::\n\n\n# Analyze Distances to the \"Global\" Mean Shape\n\nWe consider each of the subgroups of cells, defined by their treatment and cell line. We wish to study how far each of these groups is from the global mean shape. To do so, we compute the list of distances to the global mean shape.\n\n::: {#917bf4bc .cell execution_count=43}\n``` {.python .cell-code}\nmetric = 'SRV'\ndists_to_global_mean = {}\ndists_to_global_mean_list = {}\nprint(mean_estimate_aligned[metric].shape)\n\ndists_to_global_mean[metric] = apply_func_to_ds(\n ds_align[metric], \n func=lambda x: CURVES_SPACE_SRV.metric.dist(x, mean_estimate_aligned[metric])\n)\n\ndists_to_global_mean_list[metric] = []\nfor t in TREATMENTS:\n for l in LINES:\n dists_to_global_mean_list[metric].extend(dists_to_global_mean[metric][t][l])\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(1999, 2)\n```\n:::\n:::\n\n\nCompute distances to the linear mean\n\n::: {#92f1f77a .cell execution_count=44}\n``` {.python .cell-code}\nmetric = 'Linear'\ndists_to_global_mean[metric] = apply_func_to_ds(\n ds_align[metric], func=lambda x: gs.linalg.norm(mean_estimate_aligned[metric] - x) \n)\n\ndists_to_global_mean_list[metric] = []\nfor t in TREATMENTS:\n for l in LINES:\n dists_to_global_mean_list[metric].extend(dists_to_global_mean[metric][t][l])\n```\n:::\n\n\n::: {#fd1ddf13 .cell execution_count=45}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4))\n\nline = 'dlm8'\nkde_dict = {}\nfor j, metric in enumerate(METRICS):\n distances = []\n min_dists = min(dists_to_global_mean_list[metric])\n max_dists = max(dists_to_global_mean_list[metric])\n xx = gs.linspace(gs.floor(min_dists), gs.ceil(max_dists), k_sampling_points)\n kde_dict[metric] = {}\n for i, treatment in enumerate(TREATMENTS):\n distances = 
dists_to_global_mean[metric][treatment][line][~gs.isnan(dists_to_global_mean[metric][treatment][line])]\n \n \n axs[j].hist(distances, bins=20, alpha=0.4, density=True, label=treatment, color=f\"C{i}\")\n kde = stats.gaussian_kde(distances)\n kde_dict[metric][treatment] = kde\n axs[j].plot(xx, kde(xx), color=f\"C{i}\")\n axs[j].set_xlim((min_dists, max_dists))\n axs[j].legend(fontsize=12)\n\n axs[j].set_title(f\"{metric}\", fontsize=14)\n axs[j].set_ylabel(\"Fraction of cells\", fontsize=14)\n\n\n# fig.suptitle(\"Histograms of SRV distances to global mean cell\", fontsize=20)\n \nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_histogram.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_histogram.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-46-output-1.png){width=758 height=374}\n:::\n:::\n\n\nCalculate the ratio of overlapping regions formed by the kde curves\n\n::: {#a8bd6478 .cell execution_count=46}\n``` {.python .cell-code}\ndef calc_ratio(kde1, kde2, min, max):\n xx = np.linspace(min, max, 1000)\n kde1_values = kde1(xx)\n kde2_values = kde2(xx)\n\n overlap = np.minimum(kde1_values, kde2_values)\n overlap_area = np.trapz(overlap, xx)\n\n bound = np.maximum(kde1_values, kde2_values)\n bound_area = np.trapz(bound, xx)\n\n return overlap_area/bound_area\n```\n:::\n\n\n::: {#6496b542 .cell execution_count=47}\n``` {.python .cell-code}\nfor metric in METRICS:\n min_dists = min(dists_to_global_mean_list[metric])\n max_dists = max(dists_to_global_mean_list[metric])\n for i, tmt1 in enumerate(TREATMENTS):\n for j in range(i+1, len(TREATMENTS)):\n tmt2 = TREATMENTS[j]\n ratio = calc_ratio(kde_dict[metric][tmt1], kde_dict[metric][tmt2], min_dists, max_dists)\n print(f\"Overlap ratio for {line} between {tmt1} and {tmt2} using {metric} metric is: {round(ratio, 2)}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nOverlap ratio for dlm8 between control and cytd using SRV metric is: 
0.28\nOverlap ratio for dlm8 between control and jasp using SRV metric is: 0.53\nOverlap ratio for dlm8 between cytd and jasp using SRV metric is: 0.39\nOverlap ratio for dlm8 between control and cytd using Linear metric is: 0.43\nOverlap ratio for dlm8 between control and jasp using Linear metric is: 0.69\nOverlap ratio for dlm8 between cytd and jasp using Linear metric is: 0.59\n```\n:::\n:::\n\n\n::: {#82e1bb50 .cell execution_count=48}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4))\n\nline = 'dunn'\n\nnp.set_printoptions(precision=12)\n\nkde_dict = {}\nfor j, metric in enumerate(METRICS):\n distances = []\n min_dists = min(dists_to_global_mean_list[metric])\n max_dists = max(dists_to_global_mean_list[metric])\n xx = gs.linspace(gs.floor(min_dists), gs.ceil(max_dists), k_sampling_points)\n kde_dict[metric] = {}\n \n for i, treatment in enumerate(TREATMENTS):\n \n distances = dists_to_global_mean[metric][treatment][line][~gs.isnan(dists_to_global_mean[metric][treatment][line])]\n counts, bin_edges, _ = axs[j].hist(distances, bins=20, alpha=0.4, density=True, label=treatment, color=f\"C{i}\")\n print(treatment, metric)\n print(\"counts are:\", counts)\n print(\"bin_edges are:\", bin_edges)\n kde = stats.gaussian_kde(distances)\n kde_dict[metric][treatment] = kde\n axs[j].plot(xx, kde(xx), color=f\"C{i}\")\n axs[j].set_xlim((min_dists, max_dists))\n axs[j].legend(fontsize=12)\n\n axs[j].set_title(f\"{metric}\", fontsize=14)\n axs[j].set_ylabel(\"Fraction of cells\", fontsize=14)\n\n\n# fig.suptitle(\"Histograms of SRV distances to global mean cell\", fontsize=20)\n \nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_histogram.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_histogram.pdf\"))\n```\n\n::: {.cell-output .cell-output-stdout}\n```\ncontrol SRV\ncounts are: [3.599823688084 9.414923491911 9.138013977443 2.492185630212\n 2.215276115744 2.215276115744 2.492185630212 
4.15364271702\n 6.092009318296 3.876733202552 2.492185630212 1.38454757234\n 1.107638057872 2.492185630212 0.553819028936 0.\n 0.553819028936 0. 0. 0.276909514468]\nbin_edges are: [0.190412844846 0.208744255891 0.227075666936 0.245407077981\n 0.263738489026 0.28206990007 0.300401311115 0.31873272216\n 0.337064133205 0.35539554425 0.373726955295 0.39205836634\n 0.410389777385 0.42872118843 0.447052599475 0.46538401052\n 0.483715421565 0.50204683261 0.520378243655 0.5387096547\n 0.557041065745]\ncytd SRV\ncounts are: [0.627751614862 0. 1.883254844586 1.255503229724\n 1.255503229724 1.255503229724 1.883254844586 5.649764533759\n 4.394261304035 5.649764533759 8.160770993208 5.649764533759\n 6.905267763483 3.138758074311 2.511006459448 1.255503229724\n 3.138758074311 1.883254844586 0.627751614862 0.627751614862]\nbin_edges are: [0.26221861859 0.279533691877 0.296848765164 0.314163838451\n 0.331478911738 0.348793985025 0.366109058312 0.383424131599\n 0.400739204886 0.418054278173 0.43536935146 0.452684424747\n 0.469999498034 0.487314571321 0.504629644608 0.521944717895\n 0.539259791183 0.55657486447 0.573889937757 0.591205011044\n 0.608520084331]\njasp SRV\ncounts are: [0.928427307436 0.928427307436 0.928427307436 2.785281922307\n 3.713709229743 3.713709229743 4.642136537178 6.49899115205\n 6.49899115205 6.49899115205 9.284273074357 8.355845766921\n 6.49899115205 7.427418459485 2.785281922307 4.642136537178\n 1.856854614871 2.785281922307 1.856854614871 1.856854614871]\nbin_edges are: [0.244313646946 0.256149803531 0.267985960117 0.279822116702\n 0.291658273288 0.303494429873 0.315330586458 0.327166743044\n 0.339002899629 0.350839056215 0.3626752128 0.374511369386\n 0.386347525971 0.398183682557 0.410019839142 0.421855995727\n 0.433692152313 0.445528308898 0.457364465484 0.469200622069\n 0.481036778655]\ncontrol Linear\ncounts are: [0.973976940289 1.704459645506 3.895907761156 3.165425055939\n 4.626390466373 4.139401996228 5.35687317159 5.843861641734\n 4.626390466373 
3.408919291012 2.19144811565 2.678436585795\n 1.460965410434 0.730482705217 0.973976940289 0.243494235072\n 0.730482705217 0. 0.243494235072 0.973976940289]\nbin_edges are: [0.084550020208 0.105397093366 0.126244166523 0.147091239681\n 0.167938312838 0.188785385996 0.209632459153 0.230479532311\n 0.251326605468 0.272173678626 0.293020751783 0.313867824941\n 0.334714898098 0.355561971256 0.376409044413 0.397256117571\n 0.418103190728 0.438950263886 0.459797337043 0.480644410201\n 0.501491483358]\ncytd Linear\ncounts are: [2.686991765509 1.343495882754 1.791327843673 2.239159804591\n 2.686991765509 3.582655687345 4.478319609181 4.478319609181\n 5.821815491936 4.030487648263 4.030487648263 1.343495882754\n 1.343495882754 0.447831960918 0. 0.\n 0.447831960918 0. 0. 0.447831960918]\nbin_edges are: [0.18370748819 0.20797901449 0.23225054079 0.25652206709\n 0.28079359339 0.30506511969 0.32933664599 0.35360817229\n 0.37787969859 0.40215122489 0.42642275119 0.45069427749\n 0.47496580379 0.49923733009 0.52350885639 0.54778038269\n 0.57205190899 0.59632343529 0.62059496159 0.64486648789\n 0.669138014189]\njasp Linear\ncounts are: [3.47808161386 5.21712242079 2.608561210395 6.956163227719\n 6.521403025987 6.086642824255 3.912841815592 0.434760201732\n 1.73904080693 0.434760201732 0.434760201732 0.\n 0.434760201732 0.434760201732 0. 0.\n 0. 0. 
0.434760201732 0.434760201732]\nbin_edges are: [0.154345044651 0.179621072552 0.204897100452 0.230173128353\n 0.255449156253 0.280725184154 0.306001212054 0.331277239955\n 0.356553267855 0.381829295756 0.407105323656 0.432381351557\n 0.457657379457 0.482933407358 0.508209435258 0.533485463159\n 0.558761491059 0.58403751896 0.60931354686 0.634589574761\n 0.659865602661]\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-49-output-2.png){width=758 height=374}\n:::\n:::\n\n\nCalculate the ratio of overlapping regions formed by the three kde curves \n\n::: {#924abd95 .cell execution_count=49}\n``` {.python .cell-code}\nfor metric in METRICS:\n min_dists = min(dists_to_global_mean_list[metric])\n max_dists = max(dists_to_global_mean_list[metric])\n for i, tmt1 in enumerate(TREATMENTS):\n for j in range(i+1, len(TREATMENTS)):\n tmt2 = TREATMENTS[j]\n ratio = calc_ratio(kde_dict[metric][tmt1], kde_dict[metric][tmt2], min_dists, max_dists)\n print(f\"Overlap ratio for {line} between {tmt1} and {tmt2} using {metric} metric is: {round(ratio, 2)}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nOverlap ratio for dunn between control and cytd using SRV metric is: 0.2\nOverlap ratio for dunn between control and jasp using SRV metric is: 0.4\nOverlap ratio for dunn between cytd and jasp using SRV metric is: 0.35\nOverlap ratio for dunn between control and cytd using Linear metric is: 0.32\nOverlap ratio for dunn between control and jasp using Linear metric is: 0.72\nOverlap ratio for dunn between cytd and jasp using Linear metric is: 0.37\n```\n:::\n:::\n\n\nConduct T-test to test if the two samples have the same expected average\n\n::: {#42319195 .cell execution_count=50}\n``` {.python .cell-code}\nfor line in LINES:\n for i in range(len(TREATMENTS)):\n tmt1 = TREATMENTS[i]\n for j in range(i+1, len(TREATMENTS)):\n tmt2 = TREATMENTS[j]\n for metric in METRICS:\n distance1 = 
dists_to_global_mean[metric][tmt1][line][~gs.isnan(dists_to_global_mean[metric][tmt1][line])]\n distance2 = dists_to_global_mean[metric][tmt2][line][~gs.isnan(dists_to_global_mean[metric][tmt2][line])]\n t_statistic, p_value = stats.ttest_ind(distance1, distance2)\n print(f\"Significance of differences for {line} between {tmt1} and {tmt2} using {metric} metric is: {'%.2e' % Decimal(p_value)}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nSignificance of differences for dlm8 between control and cytd using SRV metric is: 5.16e-25\nSignificance of differences for dlm8 between control and cytd using Linear metric is: 3.15e-11\nSignificance of differences for dlm8 between control and jasp using SRV metric is: 6.87e-06\nSignificance of differences for dlm8 between control and jasp using Linear metric is: 1.65e-01\nSignificance of differences for dlm8 between cytd and jasp using SRV metric is: 1.10e-09\nSignificance of differences for dlm8 between cytd and jasp using Linear metric is: 1.77e-04\nSignificance of differences for dunn between control and cytd using SRV metric is: 1.29e-41\nSignificance of differences for dunn between control and cytd using Linear metric is: 3.35e-24\nSignificance of differences for dunn between control and jasp using SRV metric is: 1.74e-14\nSignificance of differences for dunn between control and jasp using Linear metric is: 2.50e-03\nSignificance of differences for dunn between cytd and jasp using SRV metric is: 8.05e-16\nSignificance of differences for dunn between cytd and jasp using Linear metric is: 1.97e-10\n```\n:::\n:::\n\n\nLet's analyze the bimodal distribution of the control group of the dunn cell line under the SRV metric.\n\nWe consider two groups, one per histogram peak: cells whose distance to the mean lies in (0.2087, 0.2271], and cells whose distance to the mean lies in (0.3371, 0.3554]. We then plot the cells in each of the two groups.\n\n::: {#0d68572a .cell execution_count=51}\n``` {.python .cell-code}\nline = 'dunn'\ntreatment = 'control'\nmetric = 
'SRV'\ndistances = dists_to_global_mean[metric][treatment][line]\nprint(min(distances), max(distances))\ngroup_1_left = 0.208744255891 \ngroup_1_right = 0.227075666936\ngroup_2_left = 0.337064133205 \ngroup_2_right = 0.35539554425\ngroup_1_indices = [i for i, element in enumerate(distances) if element <= group_1_right and element > group_1_left]\ngroup_2_indices = [i for i, element in enumerate(distances) if element <= group_2_right and element > group_2_left]\nprint(group_1_indices)\nprint(group_2_indices)\ngroup_1_cells = gs.array(ds_align[metric][treatment][line])[group_1_indices,:,:]\ngroup_2_cells = gs.array(ds_align[metric][treatment][line])[group_2_indices,:,:]\n\ncol_num = max(len(group_1_indices), len(group_2_indices))\nfig = plt.figure(figsize=(2*col_num, 2))\ncount = 1\nfor index in range(len(group_1_indices)):\n cell = group_1_cells[index]\n fig.add_subplot(2, col_num, count)\n count += 1\n plt.plot(cell[:, 0], cell[:, 1])\n plt.axis(\"equal\")\n plt.axis(\"off\")\n\ncount = max(len(group_1_indices), len(group_2_indices))+1\nfor index in range(len(group_2_indices)):\n cell = group_2_cells[index]\n fig.add_subplot(2, col_num, count)\n count += 1\n plt.plot(cell[:, 0], cell[:, 1])\n plt.axis(\"equal\")\n plt.axis(\"off\")\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_bimodal_mean.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_bimodal_mean.pdf\"))\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n0.19041284484557636 0.5570410657452678\n[24, 25, 26, 28, 29, 33, 34, 36, 37, 39, 40, 44, 48, 51, 54, 55, 56, 58, 107, 117, 120, 126, 128, 130, 131, 134, 136, 137, 138, 140, 141, 145, 151, 153]\n[2, 10, 16, 17, 64, 66, 77, 80, 86, 87, 90, 91, 95, 100, 101, 104, 157, 167, 168, 170, 175, 196]\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-52-output-2.png){width=5078 height=167}\n:::\n:::\n\n\n# Visualization of the Mean of each Treatment\n\nThe mean distances to the global mean shape 
differ. We also plot the mean shape of each subgroup, to get an intuition of what it looks like.\n\nWe first calculate the SRV mean\n\n::: {#9b3085ca .cell execution_count=52}\n``` {.python .cell-code}\nmean_treatment_cells = {}\nmetric = 'SRV'\nfor treatment in TREATMENTS:\n treatment_cells = []\n for line in LINES:\n treatment_cells.extend(ds_align[metric][treatment][line])\n mean_estimator = FrechetMean(space=CURVES_SPACE_SRV)\n mean_estimator.fit(CURVES_SPACE_SRV.projection(gs.array(treatment_cells)))\n mean_treatment_cells[treatment] = mean_estimator.estimate_\n```\n:::\n\n\n::: {#c46fc780 .cell execution_count=53}\n``` {.python .cell-code}\nmean_line_cells = {}\nfor line in LINES:\n line_cells = []\n for treatment in TREATMENTS:\n line_cells.extend(ds_align[metric][treatment][line])\n mean_estimator = FrechetMean(space=CURVES_SPACE_SRV)\n mean_estimator.fit(CURVES_SPACE_SRV.projection(gs.array(line_cells)))\n mean_line_cells[line] = mean_estimator.estimate_\n```\n:::\n\n\n::: {#e2fafe71 .cell execution_count=54}\n``` {.python .cell-code}\nmean_cells = {}\nmetric = 'SRV'\nmean_cells[metric] = {}\nfor treatment in TREATMENTS:\n mean_cells[metric][treatment] = {}\n for line in LINES:\n mean_estimator = FrechetMean(space=CURVES_SPACE_SRV)\n mean_estimator.fit(CURVES_SPACE_SRV.projection(gs.array(ds_align[metric][treatment][line])))\n mean_cells[metric][treatment][line] = mean_estimator.estimate_\n```\n:::\n\n\nWe then calculate the linear mean\n\n::: {#bc7e2ed5 .cell execution_count=55}\n``` {.python .cell-code}\nmetric = 'Linear'\nmean_cells[metric] = {}\nfor treatment in TREATMENTS:\n mean_cells[metric][treatment] = {}\n for line in LINES:\n mean_cells[metric][treatment][line] = gs.mean(ds_align[metric][treatment][line], axis=0)\n```\n:::\n\n\nWhile the mean shapes of the control groups (for both cell lines) look regular, we observe that:\n- the mean shape for cytd is the most irregular (for both cell lines)\n- while the mean 
shape for jasp is more elongated for dlm8 cell line, and more irregular for dunn cell line.\n\n# Distance of the Cell Shapes to their Own Mean Shape\n\nLastly, we evaluate how each subgroup of cell shapes is distributed around the mean shape of their specific subgroup.\n\n::: {#86c04b27 .cell execution_count=56}\n``` {.python .cell-code}\ndists_to_own_mean = {}\n\nfor metric in METRICS:\n dists_to_own_mean[metric] = {}\n for treatment in TREATMENTS:\n dists_to_own_mean[metric][treatment] = {}\n for line in LINES:\n dists = []\n ids = []\n for i_curve, curve in enumerate(ds_align[metric][treatment][line]):\n if metric == 'SRV':\n one_dist = CURVES_SPACE_SRV.metric.dist(curve, mean_cells[metric][treatment][line])\n else:\n one_dist = gs.linalg.norm(curve - mean_cells[metric][treatment][line])\n if ~gs.isnan(one_dist):\n dists.append(one_dist)\n else:\n ids.append(i_curve)\n dists_to_own_mean[metric][treatment][line] = dists\n```\n:::\n\n\n::: {#66ddf144 .cell execution_count=57}\n``` {.python .cell-code}\n# Align with ellipse\n\nline = 'dunn'\n\nfig, axes = plt.subplots(\n ncols=len(TREATMENTS),\n nrows=len(METRICS),\n figsize=(2.5*len(TREATMENTS), 2*len(METRICS)))\n\nfor j, metric in enumerate(METRICS):\n for i, treatment in enumerate(TREATMENTS):\n ax = axes[j, i]\n mean_cell = mean_cells[metric][treatment][line]\n ax.plot(mean_cell[:, 0], mean_cell[:, 1], color=f\"C{i}\")\n ax.axis(\"equal\")\n ax.axis(\"off\")\n ax.set_title(f\"{metric}-{treatment}\", fontsize=20)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_own_mean.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_own_mean.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-58-output-1.png){width=587 height=343}\n:::\n:::\n\n\n::: {#b66a6bac .cell execution_count=58}\n``` {.python .cell-code}\nline = 'dlm8'\n\nfig, axes = plt.subplots(\n ncols=len(TREATMENTS),\n nrows=len(METRICS),\n figsize=(2.5*len(TREATMENTS), 2*len(METRICS)))\n\nfor 
j, metric in enumerate(METRICS):\n for i, treatment in enumerate(TREATMENTS):\n ax = axes[j, i]\n mean_cell = mean_cells[metric][treatment][line]\n ax.plot(mean_cell[:, 0], mean_cell[:, 1], color=f\"C{i}\")\n ax.axis(\"equal\")\n ax.axis(\"off\")\n ax.set_title(f\"{metric}-{treatment}\", fontsize=20)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_own_mean.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_own_mean.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-59-output-1.png){width=587 height=343}\n:::\n:::\n\n\nFor the linear means, we observe that the shapes get narrower toward the right. This is because the start points of the cells are aligned exactly, on the right, with the start point of the reference cell.\n\nThis artifact only appears for the linear means (especially for the cytd group). Can we argue this is an advantage for SRV (reparameterization + SRV mean)? \n\nThe code below finds a given number of quantiles within the distance histogram, for each metric and using each subgroup's own mean, and plots the corresponding cells for each treatment.\n\n::: {#3795cbf3 .cell execution_count=59}\n``` {.python .cell-code}\nimport scipy.stats as ss\n\nline = 'dunn'\nn_quantiles = 10\n\nfig, axes = plt.subplots(\n nrows=len(TREATMENTS)*len(METRICS),\n ncols=n_quantiles,\n figsize=(20, 2 * len(TREATMENTS) * len(METRICS)),\n)\n\nranks = {}\n\nfor i, treatment in enumerate(TREATMENTS):\n ranks[treatment] = {}\n for j, metric in enumerate(METRICS):\n \n dists_list = dists_to_own_mean[metric][treatment][line]\n dists_list = [d + 0.0001 * gs.random.rand(1)[0] for d in dists_list]\n cells_list = list(ds_align[metric][treatment][line])\n assert len(dists_list) == len(cells_list)\n n_cells = len(dists_list)\n\n ranks[treatment][metric] = ss.rankdata(dists_list)\n\n zipped_lists = zip(dists_list, cells_list)\n sorted_pairs = sorted(zipped_lists)\n\n tuples = zip(*sorted_pairs)\n 
sorted_dists_list, sorted_cells_list = [list(t) for t in tuples]\n for i_quantile in range(n_quantiles):\n quantile = int(0.1 * n_cells * i_quantile)\n one_cell = sorted_cells_list[quantile]\n ax = axes[2*i+j, i_quantile]\n ax.plot(one_cell[:, 0], one_cell[:, 1], c=f\"C{i}\")\n ax.set_title(f\"0.{i_quantile} quantile\", fontsize=14)\n # ax.axis(\"off\")\n # Turn off tick labels\n ax.set_yticklabels([])\n ax.set_xticklabels([])\n ax.set_xticks([])\n ax.set_yticks([])\n ax.spines[\"top\"].set_visible(False)\n ax.spines[\"right\"].set_visible(False)\n ax.spines[\"bottom\"].set_visible(False)\n ax.spines[\"left\"].set_visible(False)\n if i_quantile == 0:\n ax.set_ylabel(f\"{metric} - \\n {treatment}\", rotation=90, fontsize=18)\nplt.tight_layout()\n# plt.suptitle(f\"Quantiles for linear metric using own mean\", y=-0.01, fontsize=24)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_quantile.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_quantile.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-60-output-1.png){width=1909 height=1142}\n:::\n:::\n\n\nWe do not observe any clear relationship between the ranks of the cells by distance under the SRV metric and under the linear metric.\n\n::: {#2bf24771 .cell execution_count=60}\n``` {.python .cell-code}\nline = 'dlm8'\nn_quantiles = 10\n\nfig, axes = plt.subplots(\n nrows=len(TREATMENTS)*len(METRICS),\n ncols=n_quantiles,\n figsize=(20, 2 * len(TREATMENTS) * len(METRICS)),\n)\n\nfor i, treatment in enumerate(TREATMENTS):\n for j, metric in enumerate(METRICS):\n dists_list = dists_to_own_mean[metric][treatment][line]\n dists_list = [d + 0.0001 * gs.random.rand(1)[0] for d in dists_list]\n cells_list = list(ds_align[metric][treatment][line])\n assert len(dists_list) == len(cells_list)\n n_cells = len(dists_list)\n\n zipped_lists = zip(dists_list, cells_list)\n sorted_pairs = sorted(zipped_lists)\n\n tuples = zip(*sorted_pairs)\n sorted_dists_list, 
sorted_cells_list = [list(t) for t in tuples]\n for i_quantile in range(n_quantiles):\n quantile = int(0.1 * n_cells * i_quantile)\n one_cell = sorted_cells_list[quantile]\n ax = axes[2*i+j, i_quantile]\n ax.plot(one_cell[:, 0], one_cell[:, 1], c=f\"C{i}\")\n ax.set_title(f\"0.{i_quantile} quantile\", fontsize=14)\n # ax.axis(\"off\")\n # Turn off tick labels\n ax.set_yticklabels([])\n ax.set_xticklabels([])\n ax.set_xticks([])\n ax.set_yticks([])\n ax.spines[\"top\"].set_visible(False)\n ax.spines[\"right\"].set_visible(False)\n ax.spines[\"bottom\"].set_visible(False)\n ax.spines[\"left\"].set_visible(False)\n if i_quantile == 0:\n ax.set_ylabel(f\"{metric} - \\n {treatment}\", rotation=90, fontsize=18)\nplt.tight_layout()\n# plt.suptitle(f\"Quantiles for linear metric using own mean\", y=-0.01, fontsize=24)\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_quantile.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_quantile.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-61-output-1.png){width=1909 height=1142}\n:::\n:::\n\n\nThe code above finds a given number of quantiles within the distance histogram, for each metric and using each subgroup's own mean, and plots the corresponding cells for each treatment.\n\n# Dimensionality Reduction\n\nWe use the following experiments to illustrate how the SRV metric can help with dimensionality reduction.\n\n::: {#49f9427f .cell execution_count=61}\n``` {.python .cell-code}\ndef scaled_stress(pos, pairwise_dists):\n \"\"\" \n Calculate the scaled stress, invariant to scaling, from the embedded \\\n positions and the actual pairwise distances\n\n :param 2D np.array[float] pos: embedded positions, one row per point\n :param 2D np.array[float] pairwise_dists: pairwise distance\n \"\"\"\n \n # compute pairwise distance of pos\n pairwise_pos = np.empty(shape=(pos.shape[0], pos.shape[0]))\n for i in range(pos.shape[0]):\n for j in range(pos.shape[0]):\n pairwise_pos[i,j] = 
np.sqrt(np.sum((pos[i]-pos[j])**2))\n \n print(pairwise_pos)\n stress = np.sqrt(np.sum((pairwise_dists-pairwise_pos)**2))\n \n return stress/np.sqrt(np.sum(pairwise_dists**2))\n```\n:::\n\n\n::: {#f57a209a .cell execution_count=62}\n``` {.python .cell-code}\nmds = {}\npos = {}\ndims = range(2, 11)\nstresses = {}\n\nfor metric in METRICS:\n mds[metric] = {}\n pos[metric] = {}\n stresses[metric] = []\n for dim in dims:\n mds[metric][dim] = manifold.MDS(n_components=dim, random_state=0, dissimilarity=\"precomputed\") # fixed seed for reproducibility\n pos[metric][dim] = mds[metric][dim].fit(pairwise_dists[metric]).embedding_\n stress_val = mds[metric][dim].stress_\n scaled_stress_val = np.sqrt(stress_val/((pairwise_dists[metric]**2).sum()/2))\n # scaled_stress_val = scaled_stress(pos[metric][dim], pairwise_dists[metric])\n\n print(f\"the unscaled stress for {metric} model is for {dim}:\", stress_val)\n stresses[metric].append(scaled_stress_val)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nthe unscaled stress for SRV model is for 2: 0.0015505150986308987\nthe unscaled stress for SRV model is for 3: 0.0009766856050873998\nthe unscaled stress for SRV model is for 4: 0.0007390199671520337\nthe unscaled stress for SRV model is for 5: 0.0005748305174444293\nthe unscaled stress for SRV model is for 6: 0.00047113942181298865\nthe unscaled stress for SRV model is for 7: 0.0003990770585748401\nthe unscaled stress for SRV model is for 8: 0.00034641999727906943\nthe unscaled stress for SRV model is for 9: 0.00030596906074277627\nthe unscaled stress for SRV model is for 10: 0.00027546016788315334\nthe unscaled stress for Linear model is for 2: 0.0012568732933103922\nthe unscaled stress for Linear model is for 3: 0.0008789553123291832\nthe unscaled stress for Linear model is for 4: 0.0007370740946128706\nthe unscaled stress for Linear model is for 5: 0.0006365408960217103\nthe unscaled stress for Linear model is for 6: 0.0005664042865819429\nthe unscaled stress for Linear model is 
for 7: 0.0005223292015115522\nthe unscaled stress for Linear model is for 8: 0.0004846528585517728\nthe unscaled stress for Linear model is for 9: 0.00046151351278745815\nthe unscaled stress for Linear model is for 10: 0.0004397214282582284\n```\n:::\n:::\n\n\n::: {#600cc2c8 .cell execution_count=63}\n``` {.python .cell-code}\nplt.figure(figsize = (4,4))\nfor metric in METRICS:\n plt.scatter(dims, stresses[metric], label=metric)\n plt.plot(dims, stresses[metric])\nplt.xticks(dims)\nplt.legend()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"MDS_stress.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"MDS_stress.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-64-output-1.png){width=356 height=337}\n:::\n:::\n\n\nIn terms of the scaled stress statistic, we observe that the linear metric performs better than the SRV metric. That is, the linear metric preserves the pairwise distances in the embedded space better than the SRV metric.\n\nCalculate MDS statistics for dimension 2\n\n::: {#285a1986 .cell execution_count=64}\n``` {.python .cell-code}\nmetric = 'SRV'\nmds = manifold.MDS(n_components=2, random_state=0, dissimilarity=\"precomputed\")\npos = mds.fit(pairwise_dists[metric]).embedding_\n```\n:::\n\n\nMDS embedding of cell treatments (control, cytd and jasp) for different cell lines (dunn and dlm8)\n\n::: {#9416b2ea .cell execution_count=65}\n``` {.python .cell-code}\nembs = {}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\nWe draw a comparison with the linear metric using the following code\n\n::: {#74b60b04 .cell execution_count=66}\n``` {.python .cell-code}\nmetric = 'Linear'\nmds = manifold.MDS(n_components=2, random_state=0, dissimilarity=\"precomputed\")\npos = 
mds.fit(pairwise_dists[metric]).embedding_\nprint(\"the stress for linear model is:\", mds.stress_)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nthe stress for linear model is: 0.0012568732933103922\n```\n:::\n:::\n\n\n::: {#6eb1a6db .cell execution_count=67}\n``` {.python .cell-code}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\nThe stress of the 2D MDS embedding is lower (better) with the linear metric than with the SRV metric. \n\nHowever, if the visual result obtained with the SRV metric can be interpreted more readily, we could still argue that SRV is better at capturing cell heterogeneity. \n\n::: {#3d0d93ac .cell execution_count=68}\n``` {.python .cell-code}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\n::: {#025e5d13 .cell execution_count=69}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4))\n\nline = 'dunn'\nfor j, metric in enumerate(METRICS):\n for i, treatment in enumerate(TREATMENTS):\n cur_embs = embs[metric][treatment][line]\n axs[j].scatter(\n cur_embs[:, 0],\n cur_embs[:, 1],\n label=treatment,\n s=10,\n alpha=0.4\n )\n # axs[j].set_xlim(-3.5*1e-5, 3.5*1e-5)\n axs[j].set_xlabel(\"First Dimension\")\n axs[j].set_ylabel(\"Second Dimension\")\n axs[j].legend()\n axs[j].set_title(f\"{metric}\")\n# fig.suptitle(\"MDS of cell shapes using SRV metric\", fontsize=20)\n\nplt.tight_layout()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_2D.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_2D.pdf\"))\n```\n\n::: {.cell-output 
.cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-70-output-1.png){width=758 height=374}\n:::\n:::\n\n\n::: {#c02d2a85 .cell execution_count=70}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4))\n\nline = 'dlm8'\nfor j, metric in enumerate(METRICS):\n distances = []\n for i, treatment in enumerate(TREATMENTS):\n cur_embs = embs[metric][treatment][line]\n axs[j].scatter(\n cur_embs[:, 0],\n cur_embs[:, 1],\n label=treatment,\n s=10,\n alpha=0.4\n )\n # axs[j].set_xlim(-3.5*1e-5, 3.5*1e-5)\n axs[j].set_xlabel(\"First Dimension\")\n axs[j].set_ylabel(\"Second Dimension\")\n axs[j].legend()\n axs[j].set_title(f\"{metric}\")\n# fig.suptitle(\"MDS of cell shapes using SRV metric\", fontsize=20)\n\nplt.tight_layout()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_2D.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_2D.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-71-output-1.png){width=758 height=374}\n:::\n:::\n\n\nWe also consider embedding in 3D. 
\n\n::: {#6c5df327 .cell execution_count=71}\n``` {.python .cell-code}\nmetric = 'SRV'\nmds = manifold.MDS(n_components=3, random_state=0, dissimilarity=\"precomputed\")\npos = mds.fit(pairwise_dists[metric]).embedding_\n```\n:::\n\n\n::: {#5436b1f2 .cell execution_count=72}\n``` {.python .cell-code}\nembs = {}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\n::: {#78765feb .cell execution_count=73}\n``` {.python .cell-code}\nmetric = 'Linear'\nmds = manifold.MDS(n_components=3, random_state=1, dissimilarity=\"precomputed\")\npos = mds.fit(pairwise_dists[metric]).embedding_\nprint(\"the stress for linear model is:\", mds.stress_)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nthe stress for linear model is: 0.0008821306413255005\n```\n:::\n:::\n\n\n::: {#4dfa79ec .cell execution_count=74}\n``` {.python .cell-code}\nembs[metric] = {}\nindex = 0\nfor treatment in TREATMENTS:\n embs[metric][treatment] = {}\n for line in LINES:\n cell_num = len(ds_align[metric][treatment][line]) \n embs[metric][treatment][line] = pos[index:index+cell_num]\n index += cell_num\n```\n:::\n\n\n::: {#b27b3001 .cell execution_count=75}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4), subplot_kw=dict(projection='3d'))\n\nline = 'dunn'\nfor j, metric in enumerate(METRICS):\n distances = []\n for i, treatment in enumerate(TREATMENTS):\n cur_embs = embs[metric][treatment][line]\n axs[j].scatter(\n cur_embs[:, 0],\n cur_embs[:, 1],\n cur_embs[:, 2],\n label=treatment,\n s=10,\n alpha=0.4\n )\n axs[j].set_xlabel(\"First Dimension\")\n axs[j].set_ylabel(\"Second Dimension\")\n axs[j].legend()\n axs[j].set_title(f\"{metric}\")\n# fig.suptitle(\"MDS of cell shapes using linear metric\", 
fontsize=20)\n\nplt.tight_layout()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_3D.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_3D.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-76-output-1.png){width=777 height=398}\n:::\n:::\n\n\n::: {#c23dcb8a .cell execution_count=76}\n``` {.python .cell-code}\nfig, axs = plt.subplots(1, 2, sharex=False, sharey=False, tight_layout=True, figsize=(8, 4), subplot_kw=dict(projection='3d'))\n\nline = 'dlm8'\nfor j, metric in enumerate(METRICS):\n distances = []\n for i, treatment in enumerate(TREATMENTS):\n cur_embs = embs[metric][treatment][line]\n axs[j].scatter(\n cur_embs[:, 0],\n cur_embs[:, 1],\n cur_embs[:, 2],\n label=treatment,\n s=10,\n alpha=0.4\n )\n # axs[j].set_xlim(-3.5*1e-5, 3.5*1e-5)\n axs[j].set_xlabel(\"First Dimension\")\n axs[j].set_ylabel(\"Second Dimension\")\n axs[j].legend()\n axs[j].set_title(f\"{metric}\")\n# fig.suptitle(\"MDS of cell shapes using linear metric\", fontsize=20)\n\nplt.tight_layout()\n\nif savefig:\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_3D.svg\"))\n plt.savefig(os.path.join(figs_dir, f\"{line}_MDS_3D.pdf\"))\n```\n\n::: {.cell-output .cell-output-display}\n![](osteosarcoma_analysis_files/figure-html/cell-77-output-1.png){width=777 height=394}\n:::\n:::\n\n\n# Multi-class (3-class) classification \n\nWe now consider one cell line at a time, to investigate the effects of the drugs on the cell shapes. 
Applying the MDS again gives the following results:\n\nSince the detected subspace dimension for this dataset is 3, we perform the classification based on 3D embeddings.\n\n::: {#c46249a8 .cell execution_count=77}\n``` {.python .cell-code}\nfrom sklearn.metrics import precision_score, recall_score, accuracy_score\n\ndef svm_5_fold_classification(X, y):\n # Initialize a Support Vector Classifier\n svm_classifier = svm.SVC(kernel='poly', degree=4)\n\n # Prepare to split the data into 5 folds, maintaining the percentage of samples for each class\n skf = StratifiedKFold(n_splits=5)\n \n # To store precision and recall per class for each fold\n precisions_per_class = []\n recalls_per_class = []\n accuracy_per_class = []\n\n # Perform 5-fold cross-validation\n for train_index, test_index in skf.split(X, y):\n # Splitting data into training and test sets\n X_train, X_test = X[train_index], X[test_index]\n y_train, y_test = y[train_index], y[test_index]\n\n # Train the model\n svm_classifier.fit(X_train, y_train)\n \n # Predict on the test data\n y_pred = svm_classifier.predict(X_test)\n\n # Calculate precision and recall per class\n precision = precision_score(y_test, y_pred, average=None, zero_division=np.nan)\n recall = recall_score(y_test, y_pred, average=None, zero_division=np.nan)\n accuracy = accuracy_score(y_test, y_pred)\n\n # Store results from each fold\n precisions_per_class.append(precision)\n recalls_per_class.append(recall)\n accuracy_per_class.append(accuracy)\n \n # Calculate the mean precision and recall per class across all folds\n mean_precisions = np.mean(precisions_per_class, axis=0)\n mean_recalls = np.mean(recalls_per_class, axis=0)\n mean_accuracies = np.mean(accuracy_per_class, axis=0)\n \n print(\"Mean precisions per class across all folds:\", round(np.mean(mean_precisions), 2))\n print(\"Mean recalls per class across all folds:\", round(np.mean(mean_recalls), 2))\n print(\"Mean accuracies per class across all folds:\", round(mean_accuracies, 
2))\n\n return mean_precisions, mean_recalls\n```\n:::\n\n\n::: {#a597b67d .cell execution_count=78}\n``` {.python .cell-code}\nlines = gs.array(lines)\ntreatments = gs.array(treatments)\n```\n:::\n\n\n::: {#7e32955c .cell execution_count=79}\n``` {.python .cell-code}\nfor line in LINES:\n for metric in METRICS:\n control_indexes = gs.where((lines == line) & (treatments == \"control\"))[0]\n cytd_indexes = gs.where((lines == line) & (treatments == \"cytd\"))[0]\n jasp_indexes = gs.where((lines == line) & (treatments == \"jasp\"))[0]\n treatment_indexes = gs.where((lines == line) & (treatments != 'control'))[0]\n\n # indexes = gs.concatenate((jasp_indexes, cytd_indexes, control_indexes))\n indexes = gs.concatenate((control_indexes, treatment_indexes))\n matrix = pairwise_dists[metric][indexes][:, indexes]\n\n mds = manifold.MDS(n_components=2, random_state = 10, dissimilarity=\"precomputed\")\n pos = mds.fit(matrix).embedding_\n\n line_treatments = treatments[lines == line]\n line_treatments_strings, line_treatments_labels = np.unique(line_treatments, return_inverse=True)\n # print(line_treatments_strings)\n # print(line_treatments_labels)\n\n for i, label in enumerate(line_treatments_labels):\n if line_treatments_strings[label] == 'cytd' or line_treatments_strings[label] == 'jasp':\n line_treatments_labels[i] = len(line_treatments_strings)\n \n\n print(f\"Using {metric} on {line}\")\n # print(line_treatments_labels)\n svm_5_fold_classification(pos, line_treatments_labels)\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nUsing SRV on dlm8\nMean precisions per class across all folds: 0.71\nMean recalls per class across all folds: 0.7\nMean accuracies per class across all folds: 0.69\nUsing Linear on dlm8\nMean precisions per class across all folds: 0.68\nMean recalls per class across all folds: 0.62\nMean accuracies per class across all folds: 0.6\nUsing SRV on dunn\nMean precisions per class across all folds: 0.73\nMean recalls per class across all folds: 
0.69\nMean accuracies per class across all folds: 0.7\nUsing Linear on dunn\nMean precisions per class across all folds: 0.62\nMean recalls per class across all folds: 0.59\nMean accuracies per class across all folds: 0.6\n```\n:::\n:::\n\n\n", "supporting": [ "osteosarcoma_analysis_files" ], diff --git a/_freeze/posts/elastic-metric/osteosarcoma_analysis/figure-html/cell-17-output-1.png b/_freeze/posts/elastic-metric/osteosarcoma_analysis/figure-html/cell-17-output-1.png index 58f4cb7..9d2fc5d 100644 Binary files a/_freeze/posts/elastic-metric/osteosarcoma_analysis/figure-html/cell-17-output-1.png and b/_freeze/posts/elastic-metric/osteosarcoma_analysis/figure-html/cell-17-output-1.png differ diff --git a/posts/elastic-metric/introduction.qmd b/posts/elastic-metric/elastic_metric.qmd similarity index 98% rename from posts/elastic-metric/introduction.qmd rename to posts/elastic-metric/elastic_metric.qmd index 2307c51..c51ffcf 100644 --- a/posts/elastic-metric/introduction.qmd +++ b/posts/elastic-metric/elastic_metric.qmd @@ -1,8 +1,8 @@ --- title: "Elastic metric" bibliography: refs.bib -# engine: "jupyter" -# jupyter: "python3" +# engine: /home/wanxinli/miniconda3/envs/main@92c7a58/bin/python3 +# jupyter: "" author: - name: "Wanxin Li" diff --git a/posts/elastic-metric/figs/illustration/binarized_cells.png b/posts/elastic-metric/figs/illustration/binarized_cells.png new file mode 100644 index 0000000..bf89710 Binary files /dev/null and b/posts/elastic-metric/figs/illustration/binarized_cells.png differ diff --git a/posts/elastic-metric/figs/illustration/cells_image.png b/posts/elastic-metric/figs/illustration/cells_image.png new file mode 100644 index 0000000..6c275de Binary files /dev/null and b/posts/elastic-metric/figs/illustration/cells_image.png differ diff --git a/posts/elastic-metric/index.qmd b/posts/elastic-metric/index.qmd deleted file mode 100644 index f561c8d..0000000 --- a/posts/elastic-metric/index.qmd +++ /dev/null @@ -1,35 +0,0 @@ ---- -title: 
"Elastic metric for cell shape analysis" - -engine: "jupyter" -author: - - name: "Clément Soubrier" - email: "c.soubrier @math.ubc.ca" - affiliations: - - name: KDD Group - url: "https://rtviii.xyz/" - - - name: "Khanh Dao Duc" - email: "kdd@math.ubc.ca" - affiliations: - - name: Department of Mathematics, UBC - url: "https://www.math.ubc.ca/" - - name: Department of Computer Science, UBC - url: "https://www.cs.ubc.ca/" - -date: "July 30 2024" -categories: [biology, bioinformatics] - -callout-icon: false -# format: -# pdf: -# include-in-header: -# text: | -# \usepackage{amsmath} - -execute: - echo: false - freeze: auto - pip: ["pyvista", "open3d", "scikit-learn", "mendeleev", "compas", "matplotlib"] - ---- \ No newline at end of file diff --git a/posts/elastic-metric/osteosarcoma_analysis.qmd b/posts/elastic-metric/osteosarcoma_analysis.qmd index 6cfc75c..d28ea2b 100644 --- a/posts/elastic-metric/osteosarcoma_analysis.qmd +++ b/posts/elastic-metric/osteosarcoma_analysis.qmd @@ -14,7 +14,7 @@ This analysis relies on the *elastic metric between discrete curves* from Geomst This notebook is adapted from Florent Michel's submission to the [ICLR 2021 Computational Geometry and Topology challenge](https://github.com/geomstats/challenge-iclr-2021).
- +
Figure 1: Representative images of the cell lines using fluorescence microscopy, studied in this notebook (Image credit : Ashok Prasad). The cells nuclei (blue), the actin cytoskeleton (green) and the lipid membrane (red) of each cell are stained and colored. We only focus on the cell shape in our analysis. @@ -62,7 +62,7 @@ We study a dataset of mouse *Osteosarcoma* imaged cells [(AXCFP2019)](#Reference Each cell comes from a raw image containing a set of cells, which was thresholded to generate binarized images. - + After binarizing the images, contouring was used to isolate each cell, and to extract their boundaries as a counter-clockwise ordered list of 2D coordinates, which corresponds to the representation of discrete curve in Geomstats. We load these discrete curves into the notebook.