"""Module supporting various connectivity methods and operations."""
import itertools
import re
from itertools import combinations
from typing import TYPE_CHECKING, List
import networkx as nx
from networkx.algorithms import shortest_path_length
if TYPE_CHECKING:
from gmso import Topology
from gmso.core.atom import Site
from gmso.core.bond import Bond
from gmso.core.angle import Angle
from gmso.core.dihedral import Dihedral
from gmso.core.improper import Improper
from gmso.core.virtual_site import VirtualSite
from gmso.exceptions import MissingParameterError
# Maps a connection-type name to the connection class used to build it.
# ``_add_connections`` indexes this dict with its ``conn_type`` argument.
CONNS = {"angle": Angle, "dihedral": Dihedral, "improper": Improper}
[docs]
def identify_connections(
    top: "Topology", index_only: bool = False
) -> "Topology | dict[str, list[tuple[int, ...]]]":
    """Identify all angle, dihedral and improper connections.

    This requires that the bonded connections are fully defined in
    the ``gmso.Topology``.

    Parameters
    ----------
    top : gmso.Topology
        The gmso topology for which to identify connections.
    index_only : bool, optional, default=False
        If True, return integer site indices that would form the connections
        rather than adding the connections to the topology.

    Returns
    -------
    gmso.Topology or dict
        When ``index_only`` is False, the connections are added to ``top``
        in-place and ``top`` itself is returned. When ``index_only`` is
        True, ``top`` is not modified and a dict with the keys ``"angles"``,
        ``"dihedrals"`` and ``"impropers"`` is returned, each mapping to a
        sorted list of integer-index tuples.

    Notes
    -----
    Connections are detected by direct adjacency enumeration over the
    topology's bond graph (built with integer site indices as nodes):

    - Angle (a-b-c):
        for each node b, enumerate pairs of its neighbors.
    - Dihedral (a-b-c-d):
        for each edge (b, c), enumerate (neighbor of b) x (neighbor of c),
        excluding the b-c bond itself.
    - Improper (central; b1, b2, b3):
        for each node with degree >= 3, enumerate all combinations of
        3 neighbors.

    Site objects are only touched at two boundaries:

    - Entry: building the site -> int index map.
    - Exit: ``_add_connections`` resolving int indices back to sites.

    Everything in between operates on plain Python ints.
    """
    # Build the site -> index map once up front.
    site_index_map = {site: i for i, site in enumerate(top.sites)}

    # Integer-node adjacency built directly from the bonds. Only sites that
    # take part in at least one bond get an entry.
    adj: dict[int, set[int]] = {}
    for bond in top.bonds:
        i = site_index_map[bond.connection_members[0]]
        j = site_index_map[bond.connection_members[1]]
        adj.setdefault(i, set()).add(j)
        adj.setdefault(j, set()).add(i)

    angle_matches = _enumerate_angles(adj)
    dihedral_matches = _enumerate_dihedrals(adj)
    improper_matches = _enumerate_impropers(adj)

    if index_only:
        return {
            "angles": angle_matches,
            "dihedrals": dihedral_matches,
            "impropers": improper_matches,
        }

    for conn_type, conn_matches in (
        ("angle", angle_matches),
        ("dihedral", dihedral_matches),
        ("improper", improper_matches),
    ):
        # Skip the call entirely when nothing of this kind was found.
        if conn_matches:
            _add_connections(top, conn_matches, conn_type=conn_type)

    return top
def _add_connections(top, matches, conn_type):
    """Create one connection per index tuple and attach it to ``top``.

    Parameters
    ----------
    top : gmso.Topology
        Topology that receives the new connections.
    matches : iterable of tuple of int
        Site-index tuples; every index is resolved through ``top.sites``.
    conn_type : str
        Key into ``CONNS`` selecting the connection class to instantiate.
    """
    for index_tuple in matches:
        members = [top.sites[position] for position in index_tuple]
        conn_cls = CONNS[conn_type]
        # Each (i, j) entry of the class-level ``connectivity`` names a bonded
        # pair of members; the existing bond is looked up in the topology
        # under a key holding the pair in both orientations.
        member_bonds = tuple(
            top._unique_connections[
                frozenset(
                    [
                        (members[first], members[second]),
                        (members[second], members[first]),
                    ]
                )
            ]
            for first, second in conn_cls.connectivity
        )
        new_connection = conn_cls(connection_members=members, bonds=member_bonds)
        top.add_connection(new_connection, update_types=False)
def _enumerate_angles(adj):
"""Enumerate all angles by direct adjacency traversal.
An angle is any triple (a, b, c) where a and c are both bonded to b.
For each node b with at least 2 neighbors, we enumerate all unordered
pairs of those neighbors.
Canonicalization: smaller terminal index first, so (a, b, c) with a < c.
This is imposed at construction time so the set handles deduplication.
Parameters
----------
adj : dict[int, set[int]]
Adjacency dict of the integer-node bond graph.
Returns
-------
list of tuple[int, int, int]
Sorted list of (end0, middle, end1) triples with end0 < end1,
ordered by (middle, end0, end1).
"""
matches = set()
for b, neighbors in adj.items():
if len(neighbors) < 2:
continue
for a, c in combinations(neighbors, 2):
# Smaller end first.
if a < c:
matches.add((a, b, c))
else:
matches.add((c, b, a))
return sorted(matches, key=lambda x: (x[1], x[0], x[2]))
def _enumerate_dihedrals(adj):
"""Enumerate all dihedrals by direct adjacency traversal.
A dihedral is any quadruple (a, b, c, d) where a-b, b-c, and c-d are
all bonds and a != c, b != d. For each bond (b, c), we enumerate all
valid (a, d) pairs where a is a neighbor of b other than c, and d is a
neighbor of c other than b.
Each bond is only visited once (enforced by c > b) to avoid emitting
both (a,b,c,d) and its mirror (d,c,b,a) before canonicalisation.
Canonicalisation (smaller terminal first) then handles any remaining
orientation ambiguity.
Parameters
----------
adj : dict[int, set[int]]
Adjacency dict of the integer-node bond graph.
Returns
-------
list of tuple[int, int, int, int]
Sorted list of (a, b, c, d) quadruples with a < d (canonical form),
ordered by (b, c, a, d).
"""
matches = set()
for b, b_neighbors in adj.items():
for c in b_neighbors:
if c <= b:
# Process each bond once, only add when c > b
continue
for a in b_neighbors:
if a == c:
continue
for d in adj[c]:
if d == b:
continue
if d == a:
continue
if a < d:
matches.add((a, b, c, d))
else:
matches.add((d, c, b, a))
return sorted(matches, key=lambda x: (x[1], x[2], x[0], x[3]))
def _enumerate_impropers(adj):
"""Enumerate all impropers by direct adjacency traversal.
An improper is any quadruple (central, b1, b2, b3) where central is
bonded to all three of b1, b2, b3. For each node with degree >= 3,
we enumerate all combinations of 3 neighbors as the branch atoms.
Canonicalization: branches are sorted ascending. Central node identity
is unambiguous so no further orientation handling is needed.
Parameters
----------
adj : dict[int, set[int]]
Adjacency dict of the integer-node bond graph.
Returns
-------
list of tuple[int, int, int, int]
Sorted list of (central, b1, b2, b3) with b1 < b2 < b3,
ordered by (central, b1, b2, b3).
"""
matches = set()
for central, neighbors in adj.items():
if len(neighbors) < 3:
continue
for trio in combinations(neighbors, 3):
b1, b2, b3 = sorted(trio)
matches.add((central, b1, b2, b3))
return sorted(matches, key=lambda x: (x[0], x[1], x[2], x[3]))
def _unique_terminal_pairs(connections, graph, sort_key, n_bonds):
    """Collect unique end-site pairs that are exactly ``n_bonds`` apart in ``graph``."""
    seen = set()
    collected = []
    for conn in connections:
        members = conn.connection_members
        pair = sorted((members[0], members[-1]), key=sort_key)
        pair_id = tuple(pair)
        if pair_id in seen:
            continue
        # Keep the pair only when the shortest bonded route between its ends
        # has the expected length (a shorter route can exist, e.g. in rings).
        if shortest_path_length(graph, pair[0], pair[1]) == n_bonds:
            seen.add(pair_id)
            collected.append(pair)
    return collected


def generate_pairs_lists(
    top, molecule=None, sort_key=None, refer_from_scaling_factor=False
):
    """Generate the 1-2, 1-3 and 1-4 pair lists of a topology or one molecule.

    Parameters
    ----------
    top : gmso.Topology
        The Topology to generate the pairs lists from.
    molecule : molecule namedtuple, optional, default=None
        Generate only the pairs lists of a particular molecule.
    sort_key : function, optional, default=None
        Function used as key for sorting site pairs. If None is provided,
        ``top.get_index`` is used.
    refer_from_scaling_factor : bool, optional, default=False
        If True, only generate the pair lists whose non-bonded or coulombic
        scaling factor is non-zero.

    Returns
    -------
    pairs_lists : dict of list
        ``{"pairs12": pairs12, "pairs13": pairs13, "pairs14": pairs14}``.
        When ``refer_from_scaling_factor`` is True, keys whose scaling
        factors are both zero are omitted.

    Notes
    -----
    This function assumes that the topology has already been loaded with
    angles and dihedrals (through ``top.identify_connections()``).
    1-3 and 1-4 pairs are de-duplicated and kept only when the shortest
    bonded path between the two sites is exactly 2 or 3 bonds respectively.
    """
    from gmso.external import to_networkx
    from gmso.parameterization.molecule_utils import (
        molecule_angles,
        molecule_bonds,
        molecule_dihedrals,
    )

    nb_scalings, coulombic_scalings = top.scaling_factors

    if sort_key is None:
        sort_key = top.get_index

    graph = to_networkx(top, parse_angles=False, parse_dihedrals=False)

    if refer_from_scaling_factor:
        pairs_dict = {
            f"pairs1{i + 2}": []
            for i in range(3)
            if nb_scalings[i] or coulombic_scalings[i]
        }
    else:
        pairs_dict = {f"pairs1{i + 2}": [] for i in range(3)}

    if molecule is None:
        bonds, angles, dihedrals = top.bonds, top.angles, top.dihedrals
    else:
        bonds = molecule_bonds(top, molecule)
        angles = molecule_angles(top, molecule)
        dihedrals = molecule_dihedrals(top, molecule)

    if "pairs12" in pairs_dict:
        pairs_dict["pairs12"] = [
            sorted(bond.connection_members, key=sort_key) for bond in bonds
        ]

    if "pairs13" in pairs_dict:
        pairs_dict["pairs13"] = _unique_terminal_pairs(angles, graph, sort_key, 2)

    if "pairs14" in pairs_dict:
        pairs_dict["pairs14"] = _unique_terminal_pairs(
            dihedrals, graph, sort_key, 3
        )

    for key in pairs_dict:
        pairs_dict[key] = sorted(
            pairs_dict[key],
            key=lambda pairs: (sort_key(pairs[0]), sort_key(pairs[1])),
        )

    return pairs_dict
def identify_virtual_sites(
    topology: "Topology",
    sites: List["Site"],
    bonds: List["Bond"],
    virtual_types: List[VirtualSite],
):
    """Find virtual sites in a typed topology by matching virtual types.

    Parameters
    ----------
    topology : gmso.Topology
        Topology that receives every identified virtual site.
    sites : List[gmso.core.abstract_site.Site]
        Sites belonging to the search; each must have a truthy ``atom_type``.
        Can be all sites in the topology, or a subset of sites.
    bonds : List[gmso.core.bonds.Bond]
        Bonds used to build the search graph. Can be all bonds in the
        topology, or a subset of bonds. Only sites that appear in these
        bonds become nodes of the graph.
    virtual_types
        Virtual types, presumably from a gmso.ForceField, matched against
        the bond graph.
        NOTE(review): the code iterates ``virtual_types.values()``, so a
        mapping appears to be expected despite the ``List`` annotation;
        confirm against callers.

    Returns
    -------
    virtual_sites : List[gmso.core.virtual_site.VirtualSite]
        VirtualSite instances identified in the topology.

    Raises
    ------
    MissingParameterError
        If any entry of ``sites`` has no atom type.
    """
    for candidate in sites:
        if not candidate.atom_type:
            raise MissingParameterError(candidate.atom_type, "atom_type")

    # Graph of bonded sites; each node carries the matching member type as
    # its ``identifier`` attribute.
    bond_graph = nx.Graph()
    for bond in bonds:
        for position in (0, 1):
            bond_graph.add_node(
                bond.connection_members[position],
                identifier=bond.member_types[position],
            )
        bond_graph.add_edge(bond.connection_members[0], bond.connection_members[1])

    found = []
    for vtype in virtual_types.values():
        pattern_graph = _graph_from_vtype(vtype)
        accepted = _get_graph_isomorphism_matches(bond_graph, pattern_graph)
        for mapping in accepted.values():
            new_vsite = VirtualSite(parent_sites=mapping.keys())
            found.append(new_vsite)
            topology._add_virtual_site(new_vsite)
    return found
def _get_graph_isomorphism_matches(g1, g2, match_by="identifier"):
    """Return the subgraph matches of the pattern ``g2`` inside ``g1``.

    Nodes are compared on the ``match_by`` attribute (missing values default
    to ``""``). Matches covering the same set of ``g1`` nodes are collapsed:
    the dict is keyed by the frozenset of matched ``g1`` nodes and the last
    mapping produced for a given node set wins.
    """
    isomorphism = nx.algorithms.isomorphism
    matcher = isomorphism.GraphMatcher(
        g1,
        g2,
        node_match=isomorphism.categorical_node_match(match_by, default=""),
    )
    matches_by_nodes = {}
    for mapping in matcher.subgraph_isomorphisms_iter():
        # Store a copy of the mapping under the set of its g1 nodes.
        matches_by_nodes[frozenset(mapping)] = dict(mapping)
    return matches_by_nodes
def _graph_from_vtype(vtype):
    """Build a linear chain graph describing a virtual type.

    Uses ``vtype.member_types`` when it is truthy, otherwise
    ``vtype.member_classes``. Node ``i`` carries the i-th member as its
    ``identifier`` attribute, and consecutive nodes are joined by an edge.
    """
    source_attr = "member_types" if vtype.member_types else "member_classes"
    chain = nx.Graph()
    for position, member in enumerate(getattr(vtype, source_attr)):
        chain.add_node(position, identifier=member)
    last_index = len(getattr(vtype, source_attr)) - 1
    for position in range(last_index):
        chain.add_edge(position, position + 1)
    return chain
def connection_identifier_to_string(identifier):
    """Interleave the sites and bonds of an identifier into one string.

    Parameters
    ----------
    identifier : tuple, list
        All site identifiers followed by all bond identifiers of a
        connection. For example, a dihedral would look like:
            combination = dihedral.connection_members + dihedral.bonds

    Returns
    -------
    pattern : str
        The sites joined by their bonds. An improper might look like:
            `central_atom-atom2-atom3=atom4`
        where the combination was:
            ["central_atom", "atom2", "atom3", "atom4", "-", "-", "="]
    """
    n_sites = len(identifier) // 2 + 1
    site_part, bond_part = identifier[:n_sites], identifier[n_sites:]
    pattern = site_part[0]
    # Pair each bond with the site that follows it.
    for position, bond in enumerate(bond_part, start=1):
        if position >= len(site_part):
            break
        pattern = pattern + (bond + site_part[position])
    return pattern
def yield_connection_identifiers(identifier):
    """Yield all possible bond identifiers from a tuple or string identifier.

    Parameters
    ----------
    identifier : str, tuple or list
        Either a string such as ``"a-b=c"`` (sites separated by one of the
        bond tokens ``= ~ - # :``), or a sequence holding all site
        identifiers followed by all bond identifiers, e.g.
        ``("a", "b", "c", "-", "=")``.

    Returns
    -------
    itertools.product
        Iterator over tuples in which every site is either itself or the
        wildcard ``"*"`` and every bond is either itself or the wildcard
        ``"~"``. The fully specific identifier comes first.
    """
    if isinstance(identifier, str):
        bond_tokens = r"([\=\~\-\#\:])"
        # The capturing group keeps the bond tokens: the split alternates
        # site, bond, site, ...; regroup as all sites then all bonds.
        tokens = re.split(bond_tokens, identifier)
        identifier = tokens[::2] + tokens[1::2]
    # Count sites on the tokenised sequence, never on the raw string, whose
    # length is a character count rather than a token count.
    n_sites = len(identifier) // 2 + 1
    site_identifiers = identifier[:n_sites]
    bond_identifiers = identifier[n_sites:]
    choices = [(site_identifier, "*") for site_identifier in site_identifiers]
    choices += [(val, "~") for val in bond_identifiers]
    return itertools.product(*choices)