# Source code for mdpath.src.graph

"""Graph --- :mod:`mdpath.src.graph`
==============================================================================

This module contains the class `GraphBuilder` which generates a graph of residues within a certain distance of each other.
Graph edges are assigned weights based on mutual information differences.
Paths between distant residues are calculated based on the shortest path with the highest total weight.

Classes
--------

:class:`GraphBuilder`
"""

import heapq
import numpy as np
import networkx as nx
import pandas as pd
from scipy.spatial import cKDTree
from Bio import PDB
from typing import Tuple, List
from mdpath.src.structure import StructureCalculations


class GraphBuilder:
    """Build and analyze residue interaction graphs.

    The graph connects residues that lie within a distance cutoff of each
    other and weights each edge with the mutual information difference of
    the corresponding residue pair.

    Attributes:
        pdb (str): Path to the PDB file.
        end (int): The last residue number to consider in the graph.
        mi_diff_df (pd.DataFrame): DataFrame containing mutual information
            differences between residue pairs. It is read through the
            columns ``"Residue Pair"`` and ``"MI Difference"``.
        dist (int): Cutoff distance for graph edges in Angstroms.
        graph (nx.Graph): The constructed residue interaction graph.
    """

    def __init__(
        self, pdb: str, last_residue: int, mi_diff_df: pd.DataFrame, graphdist: int
    ) -> None:
        """Store the inputs and build the weighted residue graph.

        Args:
            pdb (str): Path to the PDB file.
            last_residue (int): The last residue number to consider.
            mi_diff_df (pd.DataFrame): Mutual information differences between
                residue pairs.
            graphdist (int): Cutoff distance for graph edges in Angstroms.
        """
        self.pdb = pdb
        self.end = last_residue
        self.mi_diff_df = mi_diff_df
        self.dist = graphdist
        # Built last: graph_builder() reads every attribute set above.
        self.graph = self.graph_builder()

    def graph_skeleton(self) -> nx.Graph:
        """Generate a graph of residues with edges for residues within a given distance of each other.

        Only amino-acid residues numbered up to ``self.end`` are considered,
        and only their C, N, O and S atoms. Two different residues are
        connected when any pair of those atoms lies within ``self.dist``.
        Nodes are added through their edges only, so a residue without any
        neighbour inside the cutoff does not appear in the graph.

        Returns:
            residue_graph (nx.Graph): Graph of residues within a certain
                distance of each other. Every edge starts with weight 0. The
                graph is empty when no matching atoms are found.
        """
        residue_graph = nx.Graph()
        parser = PDB.PDBParser(QUIET=True)
        structure = parser.get_structure("pdb_structure", self.pdb)
        heavy_atoms = {"C", "N", "O", "S"}

        coords = []
        res_ids = []
        for res in structure.get_residues():
            if not PDB.Polypeptide.is_aa(res):
                continue
            rid = res.get_id()[1]
            if rid > self.end:
                continue
            for atom in res:
                if atom.element in heavy_atoms:
                    coords.append(atom.coord)
                    # Parallel list: res_ids[k] is the residue owning coords[k].
                    res_ids.append(rid)

        if not coords:
            return residue_graph

        coords = np.array(coords)
        res_ids = np.array(res_ids)
        tree = cKDTree(coords)
        # All atom index pairs (i, j) closer than the cutoff.
        atom_pairs = tree.query_pairs(r=self.dist)
        for i, j in atom_pairs:
            # Atom pairs inside one residue do not create an edge.
            if res_ids[i] != res_ids[j]:
                # int() turns the NumPy scalars into plain Python node labels.
                residue_graph.add_edge(int(res_ids[i]), int(res_ids[j]), weight=0)
        return residue_graph

    def graph_assign_weights(self, residue_graph: nx.Graph) -> nx.Graph:
        """Assign edge weights to the graph based on mutual information between the residue pair.

        Each edge ``(u, v)`` is looked up as ``("Res u", "Res v")`` in the
        ``"Residue Pair"`` column. Because the graph is undirected, an edge
        may be reported in either orientation, so the reversed pair is tried
        as a fallback. Edges without a matching pair keep their current
        weight.

        Args:
            residue_graph (nx.Graph): Base residue graph (graph skeleton).

        Returns:
            residue_graph (nx.Graph): The same graph object with edge weights
                assigned in place.
        """
        # An empty frame may lack the expected columns; nothing to assign.
        if self.mi_diff_df.empty:
            return residue_graph

        # Column-wise iteration avoids building one Series per row.
        # When a pair occurs more than once the last occurrence wins.
        weight_lookup = {
            tuple(pair): weight
            for pair, weight in zip(
                self.mi_diff_df["Residue Pair"], self.mi_diff_df["MI Difference"]
            )
        }

        for u, v in residue_graph.edges():
            forward = (f"Res {u}", f"Res {v}")
            backward = (forward[1], forward[0])
            if forward in weight_lookup:
                residue_graph.edges[u, v]["weight"] = weight_lookup[forward]
            elif backward in weight_lookup:
                # Same undirected edge, listed in the opposite orientation.
                residue_graph.edges[u, v]["weight"] = weight_lookup[backward]
        return residue_graph

    def graph_builder(self) -> nx.Graph:
        """Wrapper function to build the residue graph.

        Returns:
            residue_graph (nx.Graph): Full residue graph with edge weights assigned.
        """
        graph = self.graph_skeleton()
        graph = self.graph_assign_weights(graph)
        return graph

    def max_weight_shortest_path(self, source: int, target: int) -> Tuple[List, float]:
        """Find the shortest path between 2 nodes with the highest total weight among all shortest paths.

        Paths are ranked first by number of edges (fewer is better) and then
        by the sum of their edge weights (higher is better).

        Args:
            source (int): Starting node.
            target (int): Target node.

        Returns:
            best_path (List): List of nodes in the shortest path with the highest weight.
            total_weight (float): Total weight of the shortest path.

        Raises:
            nx.NetworkXNoPath: If ``target`` cannot be reached from ``source``,
                including the case where ``source`` is not a node of the graph.
        """
        if source == target:
            return [source], 0
        # A missing source is reported like a missing target ("no path")
        # instead of leaking the NetworkXError raised by Graph.neighbors().
        if source not in self.graph:
            raise nx.NetworkXNoPath(f"No path between {source} and {target}.")

        # best[node] = (fewest edges, highest weight at that edge count) so far.
        best = {source: (0, 0)}
        # Heap entries: (edges, -weight, node, path). Negating the weight makes
        # the min-heap pop the heaviest path first among equally long ones.
        heap = [(0, 0, source, [source])]
        unreached = (float("inf"), -float("inf"))

        while heap:
            dist, neg_w, u, path = heapq.heappop(heap)
            acc_w = -neg_w
            if u == target:
                return path, acc_w

            prev_dist, prev_w = best.get(u, unreached)
            # Skip stale entries superseded by a better one for this node.
            if dist > prev_dist or (dist == prev_dist and acc_w < prev_w):
                continue

            for v in self.graph.neighbors(u):
                edge_w = self.graph[u][v].get("weight", 0)
                new_dist = dist + 1
                new_w = acc_w + edge_w
                known_dist, known_w = best.get(v, unreached)
                if new_dist < known_dist or (new_dist == known_dist and new_w > known_w):
                    best[v] = (new_dist, new_w)
                    heapq.heappush(heap, (new_dist, -new_w, v, path + [v]))

        raise nx.NetworkXNoPath(f"No path between {source} and {target}.")

    def collect_path_total_weights(self, df_distant_residues: pd.DataFrame) -> list:
        """Wrapper function to collect the shortest path and total weight between distant residues.

        Residue pairs without a connecting path are skipped.

        Args:
            df_distant_residues (pd.DataFrame): Pandas dataframe with distant
                residues, read through the columns ``"Residue1"`` and
                ``"Residue2"``.

        Returns:
            path_total_weights (list): List of tuples with the shortest path
                and total weight between distant residues.
        """
        path_total_weights = []
        # An empty frame may lack the expected columns; nothing to collect.
        if df_distant_residues.empty:
            return path_total_weights

        # Column-wise iteration keeps each column's own dtype; iterrows()
        # can upcast integer residue ids when other columns hold floats.
        residue_pairs = zip(
            df_distant_residues["Residue1"], df_distant_residues["Residue2"]
        )
        for source, target in residue_pairs:
            try:
                shortest_path, total_weight = self.max_weight_shortest_path(
                    source, target
                )
            except nx.NetworkXNoPath:
                continue
            path_total_weights.append((shortest_path, total_weight))
        return path_total_weights