{ "cells": [ { "cell_type": "markdown", "id": "32ef1567-c0a8-4ae3-a9b0-1d259729b526", "metadata": {}, "source": [ "Example 3: Clustering based on structural similarity\n", "====================================================" ] }, { "cell_type": "code", "execution_count": 1, "id": "86789648-6b77-4383-9a7b-cb64797f22fe", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/elijah/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import numpy as np\n", "import itertools as it\n", "import multiprocess\n", "import matplotlib.pyplot as plt\n", "import time\n", "import os\n", "from umap.umap_ import UMAP\n", "\n", "from GWProt.GW_protein import *\n", "from GWProt.GW_protein_pI import *" ] }, { "cell_type": "markdown", "id": "dfb1bd4d-730d-4086-832d-59b902fe21e5", "metadata": {}, "source": [ "Riboviruses are a diverse group of viruses that include many human pathogens. They all encode an RNA-dependent RNA polymerase (RdRp), which is essential for replication, evolves rapidly, and often retains very little sequence similarity across distantly related viruses. For these reasons, structural comparison is frequently more informative than sequence alone for identification and classification (see [1](https://peerj.com/articles/14055/) and [2](https://www.nature.com/articles/s41586-021-04332-2)).\n", "\n", "In this notebook we use a curated set of ~300 computationally folded RdRp core domains to illustrate how Gromov–Wasserstein (GW) distances can quantify structural similarity. We compare several GW variants available in GWProt, including fused GW (FGW), to show the tradeoff each makes between classification accuracy and runtime. 
The PDB files and metadata for this example are in the GWProt repository: https://github.com/CamaraLab/GWProt/tree/main/docs/Examples/Example_Data/RdRp%20Proteins.\n", "\n", "Below we define a small helper for computing all-vs-all GW distance matrices in parallel." ] }, { "cell_type": "code", "execution_count": 2, "id": "889ab742-5074-438d-84bf-1fd0da9ec275", "metadata": {}, "outputs": [], "source": [ "import multiprocess  # multiprocessing fork that serializes with dill, so closures can be sent to workers\n", "import itertools as it\n", "import numpy as np\n", "import time\n", "import GWProt.GW_protein\n", "\n", "def compute_in_parallel(proteins, comparison_method, **kwargs):\n", "    \"\"\"Return the symmetric N x N matrix of comparison_method(proteins[i], proteins[j], **kwargs).\"\"\"\n", "    N = len(proteins)\n", "    dist_mat = np.zeros((N, N))  # the diagonal (self-comparisons) stays 0\n", "    start_time = time.time()\n", "\n", "    def compare_pair(pair):\n", "        i, j = pair\n", "        return i, j, comparison_method(proteins[i], proteins[j], **kwargs)\n", "\n", "    # Each of the N*(N-1)/2 unordered pairs is computed once, in parallel over all cores.\n", "    with multiprocess.Pool() as pool:\n", "        for i, j, d in pool.imap(compare_pair, it.combinations(range(N), 2), chunksize=32):\n", "            dist_mat[i, j] = d\n", "            dist_mat[j, i] = d\n", "\n", "    run_time = int(time.time() - start_time)\n", "    print(f'run time = {run_time // 60} min, {run_time % 60} sec')\n", "    return dist_mat" ] }, { "cell_type": "markdown", "id": "7fa31a4b-291b-44dc-b2a2-ae8007bccf39", "metadata": {}, "source": [ "The main class in GWProt is `GW_protein`. An instance stores the data needed for structural comparisons: alpha-carbon coordinates, amino-acid sequence, and the intra-protein distance matrix (pairwise distances between alpha-carbons). You can initialize `GW_protein` objects from PDB files. 
In this example the PDB files are in the folder `Example_Data/RdRp Proteins/palmstrub_little/`, and we load each one into a `GW_protein` instance:" ] }, { "cell_type": "code", "execution_count": null, "id": "e1638a7d-8ede-4e44-8f7a-ea86e01ae059", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "palmstrub_dir = \"Example_Data/RdRp Proteins/palmstrub_little/\"\n", "\n", "# Sort the file names so the row/column order of the distance matrices is reproducible.\n", "all_files = sorted(os.listdir(palmstrub_dir))\n", "base_prots = [GWProt.GW_protein.GW_protein.make_protein_from_pdb(os.path.join(palmstrub_dir, file)) for file in all_files]" ] }, { "cell_type": "markdown", "id": "724fe111", "metadata": {}, "source": [ "We can look at the distribution of the number of alpha-carbons per protein:" ] }, { "cell_type": "code", "execution_count": null, "id": "57adca3c", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.hist([len(m) for m in base_prots], bins=50, edgecolor='black')\n", "plt.xlabel('Number of alpha-carbons')\n", "plt.ylabel('Number of proteins')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "d67e0561", "metadata": {}, "source": [ "![Example_Data/Images/Carbons.PNG](Example_Data/Images/Carbons.PNG)" ] }, { "cell_type": "markdown", "id": "2a121a9b", "metadata": {}, "source": [ "Computing GW distances using all alpha-carbons is computationally intensive: with ~300 proteins there are roughly 45,000 pairs, each requiring its own GW optimization. For this dataset the following command may take several hours on a typical desktop (8 cores, 64 GB RAM). Consider downsampling or running on a cluster for faster results." ] }, { "cell_type": "code", "execution_count": null, "id": "35570074-a67f-43b3-8c2f-91250fd10ed8", "metadata": {}, "outputs": [], "source": [ "GW_dist_mat = compute_in_parallel(proteins=base_prots, comparison_method=GWProt.GW_protein.GW_protein.run_GW)" ] }, { "cell_type": "markdown", "id": "23fbd85f", "metadata": {}, "source": [ "We can embed the GW distance matrix in two dimensions with UMAP and color points by ribovirus class to visually inspect how well the classes separate. The following code builds an interactive Plotly scatter plot for exploration."
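, "\n",
"\n",
"UMAP with `metric='precomputed'` reads each row of the matrix as one protein's distances to all the others, so the matrix should be square and symmetric with a zero diagonal, which `compute_in_parallel` guarantees by construction. As a dependency-free illustration of turning a precomputed distance matrix into 2D coordinates, here is a minimal sketch of classical (Torgerson) multidimensional scaling in NumPy. It is a swapped-in technique for intuition only, not what UMAP does internally, and `check_distance_matrix` and `classical_mds` are hypothetical helpers that are not part of GWProt or umap-learn.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def check_distance_matrix(D, tol=1e-8):\n",
"    # Hypothetical helper: sanity checks before passing D to a method that\n",
"    # expects precomputed distances.\n",
"    D = np.asarray(D, dtype=float)\n",
"    if D.ndim != 2 or D.shape[0] != D.shape[1]:\n",
"        raise ValueError('D must be square')\n",
"    if not np.allclose(D, D.T, atol=tol):\n",
"        raise ValueError('D must be symmetric')\n",
"    if not np.allclose(np.diag(D), 0.0, atol=tol):\n",
"        raise ValueError('the diagonal of D must be zero')\n",
"    if np.any(D < -tol):\n",
"        raise ValueError('distances must be non-negative')\n",
"    return D\n",
"\n",
"def classical_mds(D, dim=2):\n",
"    # Classical MDS (a stand-in, not UMAP): double-center the squared distances\n",
"    # to get a Gram matrix, then keep the eigenvectors with the largest eigenvalues.\n",
"    D = check_distance_matrix(D)\n",
"    n = D.shape[0]\n",
"    J = np.eye(n) - np.ones((n, n)) / n          # centering matrix\n",
"    B = -0.5 * J @ (D ** 2) @ J                  # Gram matrix of centered points\n",
"    evals, evecs = np.linalg.eigh(B)             # eigenvalues in ascending order\n",
"    order = np.argsort(evals)[::-1][:dim]        # indices of the `dim` largest\n",
"    scale = np.sqrt(np.clip(evals[order], 0.0, None))\n",
"    return evecs[:, order] * scale               # coordinates, shape (n, dim)\n",
"\n",
"# Toy check: the corners of a unit square are recovered up to rotation/reflection.\n",
"pts = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])\n",
"D_toy = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)\n",
"emb = classical_mds(D_toy)\n",
"D_emb = np.linalg.norm(emb[:, None, :] - emb[None, :, :], axis=-1)\n",
"print(np.allclose(D_emb, D_toy))\n",
"```\n",
"\n",
"Classical MDS reproduces the input distances exactly only when they come from points in a low-dimensional Euclidean space, as in the toy check; for GW distance matrices it gives an approximate linear view. Calling `check_distance_matrix(GW_dist_mat)` before `fit_transform` is a cheap way to catch a malformed matrix."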
] }, { "cell_type": "code", "execution_count": null, "id": "1f48c3f3", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import umap\n", "import plotly.express as px\n", "import plotly.io as pio\n", "\n", "df = pd.read_csv('Example_Data/RdRp Proteins/TableS2.csv', index_col='GenPept accession')\n", "\n", "# UMAP on the precomputed GW distances (rows/columns follow the order of base_prots).\n", "reducer = umap.UMAP(metric='precomputed')\n", "embedding = reducer.fit_transform(GW_dist_mat)\n", "\n", "# Ribovirus class of each protein, looked up in Table S2 by the identifier parsed from its name.\n", "clas = [df['Class'].loc[m.name.split(\".\")[-1]] for m in base_prots]\n", "\n", "df2 = pd.DataFrame({\n", "    'UMAP1': embedding[:, 0],\n", "    'UMAP2': embedding[:, 1],\n", "    'Class': clas\n", "})\n", "\n", "pio.renderers.default = 'notebook_connected'  # Set to the appropriate plotly renderer\n", "\n", "fig = px.scatter(\n", "    df2,\n", "    x='UMAP1',\n", "    y='UMAP2',\n", "    color='Class',\n", "    color_discrete_sequence=px.colors.qualitative.Set2  # You can change this palette\n", ")\n", "\n", "fig.update_layout(\n", "    plot_bgcolor='white',\n", "    paper_bgcolor='white',\n", "    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, scaleanchor='y'),\n", "    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)\n", ")\n", "\n", "fig.update_traces(marker=dict(size=3, opacity=0.8))\n", "fig.show()" ] }, { "cell_type": "markdown", "id": "ef778719", "metadata": {}, "source": [ "![Example_Data/Images/RdRpUMAP.PNG](Example_Data/Images/RdRpUMAP.PNG)" ] }, { "cell_type": "markdown", "id": "ac86a4da", "metadata": {}, "source": [ "We train a k-nearest-neighbor classifier directly on the precomputed GW distance matrix (not on the 2D UMAP coordinates) to predict the ribovirus class, and evaluate it with stratified 5-fold cross-validation using the Matthews correlation coefficient (MCC). MCC uses every entry of the confusion matrix (true and false positives and negatives, generalized to multiple classes), so it remains informative when classes differ in size; it ranges from -1 to 1, where 1 means perfect prediction and 0 means no better than chance. In this example the classifier yields an MCC of 0.912, indicating strong class separability in the GW-derived space."
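, "\n",
"\n",
"For intuition about what `matthews_corrcoef` reports with more than two classes, the sketch below computes MCC from label counts with the multiclass formula given in the scikit-learn documentation, `MCC = (c*s - sum_k p_k*t_k) / sqrt((s^2 - sum_k p_k^2) * (s^2 - sum_k t_k^2))`, where `c` is the number of correct predictions, `s` the number of samples, and `t_k` and `p_k` how often class `k` occurs as a true and as a predicted label. It pairs this with a leave-one-out 1-nearest-neighbor predictor on a precomputed distance matrix. Both helpers (`multiclass_mcc`, `loo_1nn_predict`) are hypothetical, and leave-one-out 1-NN is a simplification of the 5-fold, 3-neighbor, distance-weighted classifier used in this notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def multiclass_mcc(y_true, y_pred):\n",
"    # c = correct predictions, s = samples,\n",
"    # t_k / p_k = how often class k appears as a true / predicted label.\n",
"    y_true = np.asarray(y_true)\n",
"    y_pred = np.asarray(y_pred)\n",
"    classes = np.union1d(y_true, y_pred)\n",
"    t = np.array([np.sum(y_true == k) for k in classes], dtype=float)\n",
"    p = np.array([np.sum(y_pred == k) for k in classes], dtype=float)\n",
"    c = float(np.sum(y_true == y_pred))\n",
"    s = float(len(y_true))\n",
"    denom = np.sqrt((s ** 2 - np.sum(p ** 2)) * (s ** 2 - np.sum(t ** 2)))\n",
"    if denom == 0:\n",
"        return 0.0   # degenerate case, e.g. only one class present\n",
"    return float((c * s - np.sum(p * t)) / denom)\n",
"\n",
"def loo_1nn_predict(D, labels):\n",
"    # Leave-one-out 1-NN on a precomputed distance matrix: each item receives the\n",
"    # label of its closest other item (the diagonal is masked out).\n",
"    D = np.array(D, dtype=float)   # copy, so the caller's matrix is not modified\n",
"    np.fill_diagonal(D, np.inf)\n",
"    labels = np.asarray(labels)\n",
"    return labels[np.argmin(D, axis=1)]\n",
"\n",
"# Toy data: two well-separated groups of three points on a line.\n",
"pos = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 5.2])\n",
"D_toy = np.abs(pos[:, None] - pos[None, :])\n",
"y = np.array(['A', 'A', 'A', 'B', 'B', 'B'])\n",
"pred = loo_1nn_predict(D_toy, y)\n",
"print(pred, multiclass_mcc(y, pred))\n",
"```\n",
"\n",
"On the toy data the two groups are perfectly separated, so every prediction is correct and the MCC is 1. When all samples share one label the denominator is zero, and the sketch returns 0 in that case, as scikit-learn does."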
] }, { "cell_type": "code", "execution_count": null, "id": "7997dc45", "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.model_selection import StratifiedKFold, cross_val_predict\n", "from sklearn.metrics import matthews_corrcoef\n", "\n", "# With metric='precomputed', X is the distance matrix itself; scikit-learn slices it\n", "# into test-vs-train blocks for each fold.\n", "clf = KNeighborsClassifier(metric=\"precomputed\", n_neighbors=3, weights=\"distance\")\n", "cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)\n", "cvp = cross_val_predict(clf, X=GW_dist_mat, y=clas, cv=cv)\n", "\n", "# matthews_corrcoef takes (y_true, y_pred)\n", "print(matthews_corrcoef(clas, cvp))" ] }, { "cell_type": "markdown", "id": "e5e0b2d2-ad17-4361-b2a8-4fd44e6ce8ff", "metadata": {}, "source": [ "To speed up GW computations, we can downsample each protein to an evenly spaced subset of its residues (here with `downsample_n(n=100)`). Downsampling reduces runtime substantially at the cost of some structural detail: keeping fewer residues is faster but typically lowers accuracy. Use this option when exploring large datasets or when computing resources are limited." ] }, { "cell_type": "code", "execution_count": null, "id": "a7f117a0-d36d-46a9-96d9-ca2104062360", "metadata": {}, "outputs": [], "source": [ "downsampled_100_prots = [p.downsample_n(n=100) for p in base_prots]\n", "GW_dist_mat = compute_in_parallel(proteins=downsampled_100_prots, comparison_method=GWProt.GW_protein.GW_protein.run_GW)" ] }, { "cell_type": "markdown", "id": "ecc2cee0-a156-4b42-87e8-5a54897c0640", "metadata": {}, "source": [ "On our machine, the downsampled run completed in about 4 minutes, and re-running the classifier cell above on the new `GW_dist_mat` gives an MCC of 0.899, only slightly below the 0.912 obtained with all alpha-carbons.\n", "\n", "Another strategy is to rescale the intra-protein distance matrices, for example by applying a concave transform such as the square root. A concave transform compresses long distances more than short ones, so short-range structure carries relatively more weight; this often improves classification accuracy with little runtime penalty. Below we apply rescaling to the downsampled proteins; the order (downsampling → scaling vs. scaling → downsampling) typically does not affect the final result."
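, "\n",
"\n",
"The two preprocessing steps can be illustrated on raw coordinates with a short NumPy sketch: compute the intra-protein distance matrix, keep an evenly spaced subset of residues, and apply an elementwise square root. The helper names are hypothetical, and GWProt's `downsample_n` and `scale_ipdm` may choose residues or transform distances differently. Under the assumptions made here (index-based subsampling and an elementwise transform), the two orders give exactly the same matrix.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def intra_protein_distances(coords):\n",
"    # Pairwise Euclidean distances between alpha-carbon coordinates, shape (n, n).\n",
"    coords = np.asarray(coords, dtype=float)\n",
"    diff = coords[:, None, :] - coords[None, :, :]\n",
"    return np.sqrt(np.sum(diff ** 2, axis=-1))\n",
"\n",
"def evenly_spaced_indices(n_residues, n_keep):\n",
"    # Hypothetical downsampling rule: n_keep indices spread evenly along the chain.\n",
"    if n_keep >= n_residues:\n",
"        return np.arange(n_residues)\n",
"    return np.unique(np.round(np.linspace(0, n_residues - 1, n_keep)).astype(int))\n",
"\n",
"def downsample_ipdm(ipdm, idx):\n",
"    # Keep only the selected residues (the matching rows and columns).\n",
"    return ipdm[np.ix_(idx, idx)]\n",
"\n",
"def sqrt_scale(ipdm):\n",
"    # Concave elementwise rescaling: long distances are compressed more than short ones.\n",
"    return np.sqrt(ipdm)\n",
"\n",
"# Synthetic stand-in for a 250-residue protein (random points, not a real structure).\n",
"rng = np.random.default_rng(0)\n",
"coords = rng.normal(size=(250, 3)) * 10.0\n",
"ipdm = intra_protein_distances(coords)\n",
"idx = evenly_spaced_indices(len(coords), 100)\n",
"\n",
"a = sqrt_scale(downsample_ipdm(ipdm, idx))   # downsample, then scale\n",
"b = downsample_ipdm(sqrt_scale(ipdm), idx)   # scale, then downsample\n",
"print(a.shape, np.array_equal(a, b))\n",
"```\n",
"\n",
"Selecting rows and columns and applying an elementwise function act on each entry independently, so the two orders agree exactly here. If a scaling step instead depended on the whole matrix (for example, dividing by its maximum), the order could matter."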
] }, { "cell_type": "code", "execution_count": null, "id": "b795b944-4984-4bd3-96cc-1006c8a47501", "metadata": {}, "outputs": [], "source": [ "# Rescale the intra-protein distance matrix of each downsampled protein\n", "scaled_prots = [p.scale_ipdm(inplace=False) for p in downsampled_100_prots]\n", "GW_dist_mat = compute_in_parallel(proteins=scaled_prots, comparison_method=GWProt.GW_protein.GW_protein.run_GW)" ] }, { "cell_type": "markdown", "id": "6e62ed1e-879f-43c4-bf15-2f9b760af637", "metadata": {}, "source": [ "After rescaling the intra-protein distances, the MCC rises to 0.921 in this example, above both the downsampled-only result (0.899) and the full-resolution GW result (0.912).\n", "\n", "Next, we demonstrate fused Gromov–Wasserstein (FGW) using a residue difference matrix derived from BLOSUM62. FGW combines the geometric GW term with a term that compares sequence or other per-residue features, and the `alpha` parameter sets the balance between the two; its effect on accuracy depends on the chosen feature distances. FGW generally has a runtime comparable to standard GW." ] }, { "cell_type": "code", "execution_count": null, "id": "7ac74beb-bc5c-4479-8958-c52840936855", "metadata": {}, "outputs": [], "source": [ "import GWProt.FGW_matrices\n", "\n", "# Fused GW with a BLOSUM62-derived residue difference dictionary passed as d.\n", "GW_dist_mat = compute_in_parallel(\n", "    proteins=scaled_prots,\n", "    comparison_method=GWProt.GW_protein.GW_protein.run_FGW_dict,\n", "    alpha=0.05,\n", "    d=GWProt.FGW_matrices.get_BLOSUM_dict(n=62),\n", ")" ] }, { "cell_type": "markdown", "id": "f17e12e8", "metadata": {}, "source": [ "GWProt provides several feature-based distance options for fused GW in addition to BLOSUM-based scores, such as residue isoelectric-point distances (module `GW_protein_pI`) and hydrophobicity-based distances; users can also supply custom distance matrices. These fused metrics can improve performance when the feature differences carry information about class membership."
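, "\n",
"\n",
"To make the role of a feature-distance matrix concrete, the sketch below builds a residue-by-residue cost matrix from a per-residue scalar feature and evaluates the square-loss fused GW objective for one fixed coupling `T`. It uses the convention of the POT library's `ot.gromov.fused_gromov_wasserstein`, where `alpha` weights the structural term and `1 - alpha` the feature term; whether GWProt's `alpha` follows the same convention is an assumption to check in its documentation. The values in `toy_feature` and the 1D positions are placeholders, not real isoelectric points, hydrophobicity scores, or alpha-carbon coordinates, and all function names are hypothetical. An actual FGW distance minimizes this objective over all couplings with the prescribed marginals, which the sketch does not attempt.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def feature_cost_matrix(seq1, seq2, feature):\n",
"    # M[i, j] = |f(seq1[i]) - f(seq2[j])| for a per-residue scalar feature f.\n",
"    f1 = np.array([feature[aa] for aa in seq1], dtype=float)\n",
"    f2 = np.array([feature[aa] for aa in seq2], dtype=float)\n",
"    return np.abs(f1[:, None] - f2[None, :])\n",
"\n",
"def fgw_objective(C1, C2, M, T, alpha):\n",
"    # Square-loss fused GW cost of a fixed coupling T (rows: protein 1, cols: protein 2):\n",
"    #   (1 - alpha) * sum_ij M_ij T_ij + alpha * sum_ijkl (C1_ik - C2_jl)^2 T_ij T_kl\n",
"    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2   # L[i, j, k, l], shape (n, m, n, m)\n",
"    structure = np.einsum('ijkl,ij,kl->', L, T, T)\n",
"    feature = np.sum(M * T)\n",
"    return (1 - alpha) * feature + alpha * structure\n",
"\n",
"# Toy example with placeholder values (not real residue properties or coordinates).\n",
"toy_feature = {'A': 0.0, 'G': 1.0, 'K': 3.0}\n",
"seq = 'AGK'\n",
"x = np.array([0.0, 1.0, 3.0])              # 1D toy positions of three residues\n",
"C = np.abs(x[:, None] - x[None, :])        # intra-protein distance matrix\n",
"M = feature_cost_matrix(seq, seq, toy_feature)\n",
"T_identity = np.eye(3) / 3                 # match residue i to residue i\n",
"T_reversed = np.eye(3)[::-1] / 3           # match residue i to residue 2 - i\n",
"\n",
"print(fgw_objective(C, C, M, T_identity, alpha=0.5))\n",
"print(fgw_objective(C, C, M, T_reversed, alpha=0.5))\n",
"```\n",
"\n",
"Matching each residue to itself costs nothing, while the reversed matching is penalized by both terms. With `alpha=1.0` the feature term vanishes and the function returns the plain GW objective of the coupling. The 4D tensor `L` needs O(n²m²) memory, so this direct form only suits toy sizes."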
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "toc-autonumbering": true, "toc-showmarkdowntxt": true }, "nbformat": 4, "nbformat_minor": 5 }