"""
LeafletFinder
=============
Classes
-------
.. autoclass:: LeafletFinder
:members:
"""
from typing import Optional, Union
import numpy as np
from MDAnalysis.core.universe import Universe
from MDAnalysis.core.groups import AtomGroup
from MDAnalysis.selections import get_writer
from MDAnalysis.analysis.distances import capped_distance, distance_array
from .grouping import GraphMethod, SpectralClusteringMethod
from ..lib.utils import cached_property
from ..lib import mdautils
class LeafletFinder:
    """Identify atoms in the same leaflet of a lipid bilayer.

    You can use a predefined method ("graph", "spectralclustering").
    Alternatively, you can pass in your own function
    as a method.

    Parameters
    ----------
    universe : Universe or AtomGroup
        Atoms to apply the algorithm to
    select : str
        A :meth:`Universe.select_atoms` selection string
        for atoms that define the lipid head groups, e.g.
        universe.atoms.PO4 or "name PO4" or "name P*"
    select_tailgroups : str (optional)
        Selection string applied to all atoms of the selected residues to
        pick the tail groups. If ``None``, the tail groups are every atom
        of those residues that is not part of the head group selection.
    cutoff : float (optional)
        cutoff distance for computing distances (for the spectral clustering
        method) or determining connectivity in the same leaflet (for the graph
        method). In spectral clustering, it just has to be suitably large to
        cover a significant part of the leaflet, but lower values increase
        computational efficiency. Please see the :func:`optimize_cutoff`
        function for help with values for the graph method.
    pbc : bool (optional)
        If ``False``, does not follow the minimum image convention when
        computing distances
    method: str or function (optional)
        method to use to assign groups to leaflets. Choose
        "graph" for :class:`~lipyds.leafletfinder.grouping.GraphMethod`;
        "spectralclustering" for
        :class:`~lipyds.leafletfinder.grouping.SpectralClusteringMethod`;
        anything else is stored as-is and called as a function with the
        keyword arguments ``selection``, ``tailgroups``, ``cutoff``, ``box``
        and ``**kwargs``. String names are matched case-insensitively and
        with underscores removed.
    n_leaflets : int (optional)
        Number of leaflets. Passed to the spectral clustering method, and
        the number of largest clusters kept in ``leaflet_indices``,
        ``leaflet_residues`` and ``leaflet_atoms``.
    normal_axis : str (optional)
        One of "x", "y" or "z": the axis along which leaflets are ordered.
    update_TopologyAttr : bool (optional)
        If ``True``, :meth:`run` adds a "leaflet" topology attribute to the
        universe and sets it on every residue of every leaflet.
    **kwargs:
        Passed to ``method``

    Attributes
    ----------
    universe: Universe
    atomgroup: AtomGroup
        Head group atoms that the analysis is applied to (the result
        of ``select``)
    residues: ResidueGroup
        residues that the analysis is applied to
    atoms_by_residue: List of AtomGroup
        Atoms that the analysis is applied to, grouped by residue.
    tailgroups: AtomGroup
        Tail group atoms passed to the method
    pbc: bool
        Whether to use PBC or not
    leaflet_indices_by_size: list of list of indices
        List of residue indices in each leaflet. This is the index
        of residues in ``residues``, *not* the canonical ``resindex``
        attribute from MDAnalysis. Leaflets are sorted by size such
        that the largest leaflet is first.
    leaflet_residues_by_size: list of ResidueGroup
        List of ResidueGroups in each leaflet.
        Leaflets are sorted by size such
        that the largest leaflet is first.
    leaflet_atoms_by_size: list of AtomGroup
        List of AtomGroups in each leaflet.
        Leaflets are sorted by size such
        that the largest leaflet is first.
    leaflet_indices: list of list of indices
        List of residue indices in each of the ``n_leaflets`` largest
        leaflets. This is the index of residues in ``residues``, *not*
        the canonical ``resindex`` attribute from MDAnalysis. Ordered
        like ``leaflet_residues``.
    leaflet_residues: list of ResidueGroup
        List of ResidueGroups in each of the ``n_leaflets`` largest leaflets.
        The leaflets are sorted by their mean coordinate along the
        normal axis, in descending order, so that the upper-most
        leaflet is first.
    leaflet_atoms: list of AtomGroup
        List of AtomGroups in each of the ``n_leaflets`` largest leaflets.
        Ordered like ``leaflet_residues``.
    """

    def __init__(self, universe: Union[AtomGroup, Universe],
                 select: Optional[str] = 'all',
                 select_tailgroups: Optional[str] = None,
                 cutoff: float = 40,
                 pbc: bool = True,
                 method: str = "spectralclustering",
                 n_leaflets: int = 2,
                 normal_axis: str = "z",
                 update_TopologyAttr: bool = False,
                 **kwargs):
        # Reset by run(); presumably the backing store used by the
        # project's ``cached_property`` decorator -- TODO confirm.
        self._cache = {}
        self.universe = universe.universe
        self.pbc = pbc
        self.n_leaflets = n_leaflets
        self.cutoff = cutoff
        self.kwargs = dict(**kwargs)
        # Column of the coordinate array that holds the normal axis.
        self._normal_axis = ["x", "y", "z"].index(normal_axis)

        self.atomgroup = universe.select_atoms(select, periodic=pbc)
        self.atoms_by_residue = self.atomgroup.split("residue")
        # One representative (first selected) atom per residue.
        self._first_residue_atoms = sum(ag[0] for ag in self.atoms_by_residue)
        self.residues = self.atomgroup.residues
        self.n_residues = len(self.residues)

        if select_tailgroups is not None:
            self.tailgroups = self.residues.atoms.select_atoms(select_tailgroups,
                                                               periodic=pbc)
        else:
            self.tailgroups = self.residues.atoms - self.atomgroup

        # The box is looked up lazily so it follows the current frame.
        if pbc:
            self._get_box = lambda: self.universe.dimensions
        else:
            self._get_box = lambda: None

        if isinstance(method, str):
            method = method.lower().replace('_', '')
        if method == "graph":
            self.method = GraphMethod(self.atomgroup, self.tailgroups,
                                      cutoff=self.cutoff, pbc=self.pbc,
                                      **kwargs)
            self._method = self.method.run
        elif method == "spectralclustering":
            self.method = SpectralClusteringMethod(self.atomgroup, self.tailgroups,
                                                   cutoff=self.cutoff, pbc=self.pbc,
                                                   n_leaflets=self.n_leaflets,
                                                   **kwargs)
            self._method = self.method.run
        else:
            # Anything else is taken to be a user-supplied callable.
            self._method = self.method = method
        self._update_TopologyAttr = update_TopologyAttr

    @property
    def box(self):
        """Current ``universe.dimensions`` if ``pbc`` is set, else ``None``."""
        return self._get_box()

    def run(self):
        """Run the leaflet assignment on the current frame.

        This clears the cache for lazy running. If the finder was created
        with ``update_TopologyAttr=True``, a "leaflet" topology attribute
        is added to the universe and set to the leaflet number of each
        residue in ``leaflet_residues``.
        """
        self._cache = {}
        # Touch the property to trigger the actual computation.
        self._output_leaflet_indices
        if self._update_TopologyAttr:
            self.atomgroup.universe.add_TopologyAttr("leaflet")
            for i, residues in enumerate(self.leaflet_residues):
                for residue in residues:
                    residue.leaflet = i

    def write_selection(self, filename, mode="w", format=None, **kwargs):
        """Write selections for the leaflets to *filename*.

        The format is typically determined by the extension of *filename*
        (e.g. "vmd", "pml", or "ndx" for VMD, PyMol, or Gromacs).
        See :class:`MDAnalysis.selections.base.SelectionWriter` for all
        options.
        """
        sw = get_writer(filename, format)
        with sw(filename, mode=mode,
                preamble=f"Leaflets found by {repr(self)}\n",
                **kwargs) as writer:
            for i, ag in enumerate(self.leaflet_atoms, 1):
                writer.write(ag, name=f"leaflet_{i:d}")

    def __repr__(self):
        return (f"LeafletFinder(method={self.method}, select='{self.atomgroup}', "
                f"cutoff={self.cutoff:.1f} Å, pbc={self.pbc})")

    @cached_property
    def _output_leaflet_indices(self):
        """Clusters returned by the method, each as a sorted list of indices."""
        clusters = self._method(selection=self.atomgroup,
                                tailgroups=self.tailgroups,
                                cutoff=self.cutoff, box=self.box,
                                **self.kwargs)
        return [sorted(x) for x in clusters]

    @cached_property
    def residue_leaflets(self):
        """Array of length ``n_residues`` with the leaflet number of each
        residue (position in ``leaflet_indices``), or -1 if unassigned."""
        arr = np.full(self.n_residues, -1, dtype=int)
        for leaflet, residues in enumerate(self.leaflet_indices):
            for residue_index in residues:
                arr[residue_index] = leaflet
        return arr

    @cached_property
    def _output_leaflet_residues(self):
        """ResidueGroup for each cluster, in the method's output order."""
        return [self.residues[x] for x in self._output_leaflet_indices]

    def _get_atomgroup_by_indices(self, indices):
        """Combine the per-residue head group atoms at ``indices``.

        Returns an empty AtomGroup when ``indices`` is empty (in which
        case the built-in ``sum`` yields the integer 0).
        """
        ag = sum(self.atoms_by_residue[i] for i in indices)
        if not ag:
            return self.atomgroup[[]]
        return ag

    @cached_property
    def _output_leaflet_atoms(self):
        """AtomGroup for each cluster, in the method's output order."""
        return [self._get_atomgroup_by_indices(x)
                for x in self._output_leaflet_indices]

    @cached_property
    def leaflet_indices_by_size(self):
        """Cluster indices, largest cluster first."""
        return sorted(self._output_leaflet_indices, key=len, reverse=True)

    @cached_property
    def leaflet_residues_by_size(self):
        """ResidueGroup for each cluster, largest cluster first."""
        return [self.residues[x] for x in self.leaflet_indices_by_size]

    @cached_property
    def leaflet_atoms_by_size(self):
        """AtomGroup for each cluster, largest cluster first."""
        return [self._get_atomgroup_by_indices(x)
                for x in self.leaflet_indices_by_size]

    def _argsort_by_normal(self, groups):
        """Return the order that sorts ``groups`` by descending mean
        coordinate along the normal axis.

        Positions are passed through ``mdautils.unwrap_coordinates`` with
        the group's first position and the current box before averaging
        (presumably to undo periodic wrapping -- confirm against mdautils).
        """
        positions = [x.positions for x in groups]
        unwrapped = [mdautils.unwrap_coordinates(x, x[0], self.box) for x in positions]
        vals = [np.mean(x[:, self._normal_axis]) for x in unwrapped]
        # Reverse the ascending argsort: highest mean coordinate first.
        args = np.argsort(vals)[::-1]
        return args

    @cached_property
    def _output_by_normal(self):
        """Ordering of all clusters along the normal axis."""
        return self._argsort_by_normal(self._output_leaflet_atoms)

    @cached_property
    def leaflet_indices_by_normal(self):
        """Indices of all clusters, ordered along the normal axis."""
        return [self._output_leaflet_indices[i]
                for i in self._output_by_normal]

    @cached_property
    def leaflet_residues_by_normal(self):
        """ResidueGroups of all clusters, ordered along the normal axis."""
        return [self._output_leaflet_residues[i]
                for i in self._output_by_normal]

    @cached_property
    def leaflet_atoms_by_normal(self):
        """AtomGroups of all clusters, ordered along the normal axis."""
        return [self._output_leaflet_atoms[i]
                for i in self._output_by_normal]

    @cached_property
    def _argsort_by_size_and_normal(self):
        """Ordering along the normal axis of the ``n_leaflets`` largest
        clusters; values index into the ``*_by_size`` lists."""
        atoms = self.leaflet_atoms_by_size[:self.n_leaflets]
        return self._argsort_by_normal(atoms)

    @cached_property
    def leaflet_indices(self):
        """Indices of the ``n_leaflets`` largest clusters, ordered along
        the normal axis."""
        return [self.leaflet_indices_by_size[i]
                for i in self._argsort_by_size_and_normal]

    @cached_property
    def leaflet_residues(self):
        """ResidueGroups of the ``n_leaflets`` largest clusters, ordered
        along the normal axis."""
        return [self.leaflet_residues_by_size[i]
                for i in self._argsort_by_size_and_normal]

    @cached_property
    def leaflet_atoms(self):
        """AtomGroups of the ``n_leaflets`` largest clusters, ordered
        along the normal axis."""
        return [self.leaflet_atoms_by_size[i]
                for i in self._argsort_by_size_and_normal]

    @cached_property
    def leaflet_coordinates(self):
        """Per-residue centers of each leaflet in ``leaflet_atoms``.

        Every leaflet's centers are passed through
        ``mdautils.unwrap_coordinates`` with the first center of the first
        leaflet as the common reference. The coordinates are returned
        as-is; they are *not* shifted to the overall mean.
        """
        by_leaflet = [mdautils.get_centers_by_residue(ag, box=self.box)
                      for ag in self.leaflet_atoms]
        return [mdautils.unwrap_coordinates(x, center=by_leaflet[0][0], box=self.box)
                for x in by_leaflet]

    @cached_property
    def resindex_to_leaflet(self):
        """Dictionary mapping the ``resindex`` of every residue in
        ``leaflet_residues`` to its leaflet number."""
        r2l = {}
        for i, residues in enumerate(self.leaflet_residues):
            for resindex in residues.resindices:
                r2l[resindex] = i
        return r2l

    def atom_leaflet_by_distance(self, atom, cutoff=10):
        """Assign a single atom to a leaflet from its neighbours.

        Representative atoms (the first selected atom of each residue)
        within ``self.cutoff`` of ``atom`` vote with their leaflet number
        and the most common one wins. If none are within ``self.cutoff``,
        the leaflet of the single nearest representative atom is used.

        Parameters
        ----------
        atom : Atom
            Atom to assign; only ``atom.position`` is read.
        cutoff : float (optional)
            If the nearest representative atom is further away than
            this, the atom is not assigned.

        Returns
        -------
        int
            Leaflet number, or -1 if the atom is too far from every
            representative atom or all its neighbours are unassigned.
        """
        reference = self._first_residue_atoms.positions
        pairs, dists = capped_distance(atom.position,
                                       reference,
                                       max_cutoff=self.cutoff,
                                       box=self.box, return_distances=True)
        if len(pairs):
            if dists.min() > cutoff:
                return -1
            # pairs[:, 1] indexes the representative atoms, assumed to be
            # in the same order as ``self.residues`` -- TODO confirm.
            neighbors = self.residue_leaflets[pairs[:, 1]]
            # Unassigned residues (-1) must not vote: np.bincount
            # rejects negative values.
            assigned = neighbors[neighbors >= 0]
            if not len(assigned):
                return -1
            return np.bincount(assigned).argmax()

        # Nothing within self.cutoff: fall back to the nearest atom.
        distances = distance_array(atom.position,
                                   reference,
                                   box=self.box).reshape(-1)
        nearest = distances.argmin()
        if distances[nearest] > cutoff:
            return -1
        return self.residue_leaflets[nearest]

    def assign_atoms_by_distance(self, atomgroup, cutoff=10):
        """Return an integer array with the result of
        :meth:`atom_leaflet_by_distance` for every atom in ``atomgroup``."""
        leaflets = np.full(len(atomgroup), -1, dtype=int)
        for i, atom in enumerate(atomgroup):
            leaflets[i] = self.atom_leaflet_by_distance(atom, cutoff=cutoff)
        return leaflets

    def get_first_outside_atoms(self, residues):
        """Return the first atom of every residue in ``residues`` that is
        not one of this finder's ``residues``.

        ``residues`` may be any group with ``atoms`` and ``residues``
        attributes. An empty AtomGroup is returned if every residue is
        already part of the analysis.
        """
        atoms = residues.atoms.split("residue")
        residues = residues.residues
        outside_ix = np.where(~np.in1d(residues, self.residues))[0]
        if not len(outside_ix):
            return residues.atoms[[]]
        outside_atoms = sum([atoms[i][0] for i in outside_ix])
        return outside_atoms