{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline in the notebook output\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# LineageOT on a convergent trajectory\n\nThis shows results of applying LineageOT to a simulation of convergent trajectories, closely following ``simulation_demo.ipynb`` in the source code.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import copy\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport ot\n\nimport lineageot.simulation as sim\nimport lineageot.evaluation as sim_eval\nimport lineageot.inference as sim_inf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Generating simulated data\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fix the global NumPy seed so results below are reproducible\n# (assuming the simulation draws from NumPy's global RNG)\nnp.random.seed(257)\n\n# Type of simulated flow; later cells branch on this value\nflow_type = 'convergent'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Setting simulation parameters\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Every flow type except 'bifurcation' runs on a 100x longer timescale\ntimescale = 1 if flow_type == 'bifurcation' else 100\n\nx0_speed = 1/timescale\n\nsim_params = sim.SimulationParameters(\n    division_time_std=0.01*timescale,\n    flow_type=flow_type,\n    x0_speed=x0_speed,\n    mutation_rate=1/timescale,\n    mean_division_time=1.1*timescale,\n    timestep=0.001*timescale,\n)\n\n# Sampling times, both multiples of the timescale\ntime_early = 4*timescale  # Time when early cells are sampled\ntime_late = time_early + 4*timescale  # Time when late cells are sampled\nsample_times = {'early': time_early, 'late': time_late}\n\n# Root cell: x0 starts back by the distance covered at x0_speed before time_early\nmean_x0_early = 2\nx0_initial = mean_x0_early - time_early*x0_speed\ninitial_cell = sim.Cell(np.array([x0_initial, 0, 0]), np.zeros(sim_params.barcode_length))\n\n# Choosing which two of the three dimensions to show in later plots\ndimensions_to_plot = [1, 2] if flow_type == 'mismatched_clusters' else [0, 1]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Running the simulation\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Simulate descendants of a copy of the root cell up to the late sampling time\n# (copied, presumably so initial_cell itself is not modified by the simulation)\nsample = sim.sample_descendants(initial_cell.deepcopy(), time_late, sim_params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Processing simulation output\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Extracting trees and barcode matrices\ntrue_trees = {'late': sim_inf.list_tree_to_digraph(sample)}\ntrue_trees['late'].nodes['root']['cell'] = initial_cell\n\ntrue_trees['early'] = sim_inf.truncate_tree(true_trees['late'], sample_times['early'], sim_params)\n\n# Computing the ground-truth coupling\ncouplings = {'true': sim_inf.get_true_coupling(true_trees['early'], true_trees['late'])}\n\ndata_arrays = {'late': sim_inf.extract_data_arrays(true_trees['late'])}\nrna_arrays = {'late': data_arrays['late'][0]}\nbarcode_arrays = {'late': data_arrays['late'][1]}\n\nrna_arrays['early'] = sim_inf.extract_data_arrays(true_trees['early'])[0]\nnum_cells = {'early': rna_arrays['early'].shape[0], 'late': rna_arrays['late'].shape[0]}\n\nprint(\"Times    : \", sample_times)\nprint(\"Number of cells: \", num_cells)\n\n# Creating a copy of the true tree for use in LineageOT\ntrue_trees['late, annotated'] = copy.deepcopy(true_trees['late'])\nsim_inf.add_node_times_from_division_times(true_trees['late, annotated'])\nsim_inf.add_nodes_at_time(true_trees['late, annotated'], sample_times['early'])\n\n# Scatter plot of cell states.\n# Colormaps take floats in [0, 1]; 0.0 and 1.0 select the two ends of the map\n# (integer indices must stay below the map's number of entries).\ncmap = \"coolwarm\"\ncolors = [plt.get_cmap(cmap)(0.0), plt.get_cmap(cmap)(1.0)]\nfor a, label, c in zip([rna_arrays['early'], rna_arrays['late']], ['Early cells', 'Late cells'], colors):\n    plt.scatter(a[:, dimensions_to_plot[0]], a[:, dimensions_to_plot[1]], alpha = 0.4, label = label, color = c)\n\nplt.xlabel('Gene ' + str(dimensions_to_plot[0] + 1))\nplt.ylabel('Gene ' + str(dimensions_to_plot[1] + 1))\nplt.title('Simulated cell states')\nplt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since these are simulations, we can compute and plot inferred ancestor locations based on the true tree.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Infer ancestor locations for the late cells based on the true lineage tree\nobserved_nodes = list(sim_inf.get_leaves(true_trees['late, annotated'], include_root=False))\nsim_inf.add_conditional_means_and_variances(true_trees['late, annotated'], observed_nodes)\n\nancestor_info = {'true tree': sim_inf.get_ancestor_data(true_trees['late, annotated'], sample_times['early'])}\n\n# Scatter plot of cell states, overlaid with the inferred ancestor locations of the late cells\nx_dim, y_dim = dimensions_to_plot\nfor a, label, c in zip([rna_arrays['early'], rna_arrays['late']], ['Early cells', 'Late cells'], colors):\n    plt.scatter(a[:, x_dim], a[:, y_dim], alpha = 0.4, label = label, color = c)\n\ntrue_tree_ancestors = ancestor_info['true tree'][0]\nplt.scatter(true_tree_ancestors[:, x_dim],\n            true_tree_ancestors[:, y_dim],\n            alpha = 0.1,\n            label = 'Inferred ancestors',\n            color = 'green')\nplt.xlabel('Gene ' + str(x_dim + 1))\nplt.ylabel('Gene ' + str(y_dim + 1))\nplt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To better visualize cases where there were two clusters at the early time point,\nwe can color the late cells (and their inferred ancestors) by their cluster of origin.\nA late cell counts as coming from the left when its true ancestor at the early time point has a negative Gene 2 value.\nEarly cells are shown in blue.\nCells in orange are from the late time point with ancestors on the left;\ncells in green are from the late time point with ancestors on the right.\nThough the green and orange distributions substantially overlap, the estimated ancestor distributions\nin red and purple are separate.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# extract_ancestor_data_arrays(...)[0] is assumed to hold each late cell's true ancestor state at the early time;\n# a late cell is 'from the left' when that ancestor's Gene 2 value is negative\ntrue_early_ancestors = sim_inf.extract_ancestor_data_arrays(true_trees['late'], sample_times['early'], sim_params)[0]\nis_from_left = true_early_ancestors[:, 1] < 0\n\n# Observed cells; pass each group's label so it appears in the legend\nfor a, label in zip([rna_arrays['early'], rna_arrays['late'][is_from_left, :], rna_arrays['late'][~is_from_left, :]],\n                    ['Early cells', 'Late cells from left', 'Late cells from right']):\n    plt.scatter(a[:, 1], a[:, 2], alpha = 0.4, label = label)\n\n# Ancestors inferred from the true tree, split by the same origin\nfor a, label in zip([ancestor_info['true tree'][0][is_from_left, :], ancestor_info['true tree'][0][~is_from_left, :]],\n                    ['Left ancestors', 'Right ancestors']):\n    plt.scatter(a[:, 1], a[:, 2], alpha = 0.4, label = label)\n\nplt.xlabel('Gene 2')\nplt.ylabel('Gene 3')\nplt.title('Late cells and inferred ancestors by true origin')\nplt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Running LineageOT\nThe first step is to fit a lineage tree to the observed barcodes, here by neighbor joining on Hamming distances between barcodes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Distances on each ground-truth tree (not used further in this notebook)\ntrue_distances = {name: sim_inf.compute_tree_distances(tree) for name, tree in true_trees.items()}\n\n# Estimate mutation rate from fraction of unmutated barcodes\nrate_estimate = sim_inf.rate_estimator(barcode_arrays['late'], sample_times['late'])\n\n# Hamming distances for neighbor joining, with an all-zero root barcode appended as the last row\nroot_barcode = np.zeros([1, sim_params.barcode_length])\nbarcodes_with_root = np.concatenate([barcode_arrays['late'], root_barcode])\nhamming_distances_with_roots = {'late': sim_inf.barcode_distances(barcodes_with_root)}\n\n# Compute neighbor-joining tree\nfitted_tree = sim_inf.neighbor_join(hamming_distances_with_roots['late'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Once the tree is computed, we need to annotate it with node times and states.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Annotate fitted tree with internal node times.\n# First attach the observed late-time barcodes, RNA states, and sampling time to the leaves.\nsim_inf.add_leaf_barcodes(fitted_tree, barcode_arrays['late'])\nsim_inf.add_leaf_x(fitted_tree, rna_arrays['late'])\nsim_inf.add_leaf_times(fitted_tree, sample_times['late'])\n# Then infer node times by least squares, using the same estimated rate for every barcode site.\nsim_inf.annotate_tree(fitted_tree,\n                  rate_estimate*np.ones(sim_params.barcode_length),\n                  time_inference_method = 'least_squares');\n\n# Add inferred ancestor nodes and states at the early sampling time,\n# conditioning on the observed leaves (the same calls used for the true tree).\nsim_inf.add_node_times_from_division_times(fitted_tree)\nsim_inf.add_nodes_at_time(fitted_tree, sample_times['early'])\n# NOTE: reuses (overwrites) the observed_nodes name from the true-tree cell\nobserved_nodes = [n for n in sim_inf.get_leaves(fitted_tree, include_root = False)]\nsim_inf.add_conditional_means_and_variances(fitted_tree, observed_nodes)\nancestor_info['fitted tree'] = sim_inf.get_ancestor_data(fitted_tree, sample_times['early'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We're now ready to compute LineageOT cost matrices.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Compute cost matrices for each method.\n# ot.utils.dist defaults to squared Euclidean distance. The LineageOT costs scale column j\n# (late cell j's inferred ancestor) by the inverse of ancestor_info[...][1][j], a 1-D array with\n# one entry per late cell. Broadcasting the inverse row does this directly, instead of building an\n# (n_late x n_late) diagonal matrix and paying for an O(n_early * n_late^2) matrix product.\ncoupling_costs = {}\ncoupling_costs['lineageOT, true tree'] = ot.utils.dist(rna_arrays['early'], ancestor_info['true tree'][0]) * ancestor_info['true tree'][1]**(-1)\ncoupling_costs['OT'] = ot.utils.dist(rna_arrays['early'], rna_arrays['late'])\ncoupling_costs['lineageOT, fitted tree'] = ot.utils.dist(rna_arrays['early'], ancestor_info['fitted tree'][0]) * ancestor_info['fitted tree'][1]**(-1)\n\n# Ground-truth costs used for evaluation below\nearly_time_rna_cost = ot.utils.dist(rna_arrays['early'], sim_inf.extract_ancestor_data_arrays(true_trees['late'], sample_times['early'], sim_params)[0])\nlate_time_rna_cost = ot.utils.dist(rna_arrays['late'], rna_arrays['late'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Given the cost matrices, we can fit couplings with a range of entropy parameters.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "epsilons = np.logspace(-2, 3, 15)\n\n# Unregularized couplings (exact EMD; empty weight lists mean uniform marginals)\ncouplings['OT'] = ot.emd([], [], coupling_costs['OT'])\ncouplings['lineageOT'] = ot.emd([], [], coupling_costs['lineageOT, true tree'])\ncouplings['lineageOT, fitted'] = ot.emd([], [], coupling_costs['lineageOT, fitted tree'])\n\n# LineageOT costs are scaled by the inverse of ancestor_info[...][1], so their entropy\n# parameter is multiplied by the mean of that inverse for the corresponding tree\ntrue_tree_epsilon_scale = np.mean(ancestor_info['true tree'][1]**(-1))\nfitted_tree_epsilon_scale = np.mean(ancestor_info['fitted tree'][1]**(-1))\n\nfor e in epsilons:\n    # Epsilon scaling is more robust at smaller epsilon, but slower than simple sinkhorn\n    solver = ot.sinkhorn if e >= 0.1 else ot.bregman.sinkhorn_epsilon_scaling\n    suffix = str(e)\n    couplings['entropic rna ' + suffix] = solver([], [], coupling_costs['OT'], e)\n    couplings['lineage entropic rna ' + suffix] = solver([], [], coupling_costs['lineageOT, true tree'], e*true_tree_epsilon_scale)\n    couplings['fitted lineage rna ' + suffix] = solver([], [], coupling_costs['lineageOT, fitted tree'], e*fitted_tree_epsilon_scale)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Evaluation of couplings\nFirst compute the independent coupling as a reference.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Baseline: the independent coupling puts equal mass on every (early, late) pair\nindependent_coupling = np.ones(couplings['OT'].shape) / couplings['OT'].size\ncouplings['independent'] = independent_coupling\n\n# Baseline errors, used as the `scale` of the error plots below\nind_ancestor_error = sim_inf.OT_cost(independent_coupling, early_time_rna_cost)\nind_descendant_coupling = sim_eval.expand_coupling(independent_coupling, couplings['true'], late_time_rna_cost)\nind_descendant_error = sim_inf.OT_cost(ind_descendant_coupling, late_time_rna_cost)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plotting the accuracy of ancestor prediction\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def ancestor_error(coupling):\n    \"\"\"Return sim_inf.OT_cost of `coupling` against the ground-truth early-time RNA cost.\"\"\"\n    return sim_inf.OT_cost(coupling, early_time_rna_cost)\n\n# Errors are scaled by the independent coupling's ancestor error\nancestor_errors = sim_eval.plot_metrics(couplings,\n                                        ancestor_error,\n                                        'Normalized ancestor error',\n                                        epsilons,\n                                        scale = ind_ancestor_error,\n                                        points=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plotting the accuracy of descendant prediction\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def descendant_error(coupling):\n    \"\"\"Return sim_inf.OT_cost of `coupling`, expanded via the true coupling, against the late-time RNA cost.\"\"\"\n    expanded = sim_eval.expand_coupling(coupling, couplings['true'], late_time_rna_cost)\n    return sim_inf.OT_cost(expanded, late_time_rna_cost)\n\n# Errors are scaled by the independent coupling's descendant error\ndescendant_errors = sim_eval.plot_metrics(couplings,\n                                          descendant_error,\n                                          'Normalized descendant error',\n                                          epsilons,\n                                          scale = ind_descendant_error)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Coupling visualizations\nVisualizing the ground-truth coupling, zero-entropy LineageOT coupling, and zero-entropy optimal transport coupling.\n\nGround truth:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Draw the ground-truth coupling between early and late cells, then the cells on top\nsim_eval.plot2D_samples_mat(rna_arrays['early'][:, dimensions_to_plot],\n                            rna_arrays['late'][:, dimensions_to_plot],\n                            couplings['true'],\n                            c=[0.2, 0.8, 0.5],\n                            alpha_scale = 0.1)\n\nplt.xlabel('Gene ' + str(dimensions_to_plot[0] + 1))\nplt.ylabel('Gene ' + str(dimensions_to_plot[1] + 1))\nplt.title('True coupling')\n\nfor a, label, c in zip([rna_arrays['early'], rna_arrays['late']], ['Early cells', 'Late cells'], colors):\n    plt.scatter(a[:, dimensions_to_plot[0]], a[:, dimensions_to_plot[1]], alpha = 0.4, label = label, color = c)\n\n# The scatter calls set labels; show them\nplt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "LineageOT:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Draw the zero-entropy LineageOT coupling (true tree), then the cells on top\nsim_eval.plot2D_samples_mat(rna_arrays['early'][:, dimensions_to_plot],\n                            rna_arrays['late'][:, dimensions_to_plot],\n                            couplings['lineageOT'],\n                            c=[0.2, 0.8, 0.5],\n                            alpha_scale = 0.1)\nplt.xlabel('Gene ' + str(dimensions_to_plot[0] + 1))\nplt.ylabel('Gene ' + str(dimensions_to_plot[1] + 1))\nplt.title('LineageOT coupling')\n\nfor a, label, c in zip([rna_arrays['early'], rna_arrays['late']], ['Early cells', 'Late cells'], colors):\n    plt.scatter(a[:, dimensions_to_plot[0]], a[:, dimensions_to_plot[1]], alpha = 0.4, label = label, color = c)\n\n# The scatter calls set labels; show them\nplt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optimal transport:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Draw the zero-entropy optimal transport coupling, then the cells on top\nsim_eval.plot2D_samples_mat(rna_arrays['early'][:, dimensions_to_plot],\n                            rna_arrays['late'][:, dimensions_to_plot],\n                            couplings['OT'],\n                            c=[0.2, 0.8, 0.5],\n                            alpha_scale = 0.1)\nplt.xlabel('Gene ' + str(dimensions_to_plot[0] + 1))\nplt.ylabel('Gene ' + str(dimensions_to_plot[1] + 1))\nplt.title('OT coupling')\n\nfor a, label, c in zip([rna_arrays['early'], rna_arrays['late']], ['Early cells', 'Late cells'], colors):\n    plt.scatter(a[:, dimensions_to_plot[0]], a[:, dimensions_to_plot[1]], alpha = 0.4, label = label, color = c)\n\n# The scatter calls set labels; show them\nplt.legend();"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}