import itertools
from typing import List, Optional, Sequence
from pydantic import ConfigDict, Field, model_validator
from gmso.abc.abstract_site import Site
from gmso.abc.gmso_base import GMSOBase
from gmso.exceptions import GMSOError
[docs]
class Connection(GMSOBase):
__base_doc__ = """An abstract class that stores data about connections between sites.
This class functions as a super-class for any connected groups (bonds, angles, dihedrals, etc).
Each instance will have a property for the conection_type (bond_type, angle_type, dihedral_type)
"""
name_: str = Field(
default="",
description="Name of the connection. Defaults to class name.",
alias="name",
)
connection_members_: Optional[Sequence[Site]] = Field(
default=None,
description="A list of constituents in this connection, in order.",
alias="connection_members",
)
model_config = ConfigDict(
alias_to_fields={
"name": "name_",
"connection_members": "connection_members_",
}
)
@property
def connection_members(self) -> Optional[Sequence[Site]]:
"""Return the ordered sequence of sites that form this connection."""
return self.__dict__.get("connection_members_")
@property
def name(self) -> str:
"""Return the name of this connection."""
return self.__dict__.get("name_")
@property
def member_types(self) -> Optional[List[str]]:
"""Return the atom-type name of each connection member.
Returns the names from the connection's ``connection_type`` when
available, otherwise falls back to the individual members' atom types.
Returns ``None`` when no type information is present.
"""
return self._get_members_types_or_classes("member_types")
@property
def member_classes(self) -> Optional[List[str]]:
"""Return the atom-type class of each connection member.
Returns the classes from the connection's ``connection_type`` when
available, otherwise falls back to the individual members' atom types.
Returns ``None`` when no type information is present.
"""
return self._get_members_types_or_classes("member_classes")
def _has_typed_members(self):
"""Check if all the members of this connection are typed."""
return all(
member.atom_type for member in self.__dict__.get("connection_members_")
)
def _get_members_types_or_classes(self, to_return):
"""Return types or classes for connection members if they exist."""
assert to_return in {"member_types", "member_classes"}
ctype = getattr(self, "connection_type")
ctype_attr = getattr(ctype, to_return) if ctype else None
if ctype_attr:
return list(ctype_attr)
elif self._has_typed_members():
tc = [
(
member.atom_type.name
if to_return == "member_types"
else member.atom_type.atomclass
)
for member in self.__dict__.get("connection_members_")
]
return tc if all(tc) else None
[docs]
@model_validator(mode="before")
def validate_fields(cls, values):
if "connection_members" in values:
connection_members = values.get("connection_members")
else:
connection_members = values.get("connection_members_")
if all(isinstance(member, dict) for member in connection_members):
connection_members = [
cls.__members_creator__(x) for x in connection_members
]
if not all(isinstance(x, Site) for x in connection_members):
raise TypeError("A non-site object provided to be a connection member")
if len(set(connection_members)) != len(connection_members):
raise GMSOError(
f"Trying to create a {cls.__name__} between "
f"same sites. A {cls.__name__} between same "
f"{type(connection_members[0]).__name__}s is not allowed"
)
if not values.get("name"):
values["name"] = cls.__name__
return values
def __repr__(self):
return (
f"<{self.__class__.__name__} {self.name},\n "
f"connection_members: {self.connection_members},\n "
f"potential: {str(self.connection_type)},\n "
f"id: {id(self)}>"
)
def __str__(self):
return f"<{self.__class__.__name__} {self.name}, id: {id(self)}> "
[docs]
def get_connection_identifiers(self):
borderDict = {1: "-", 2: "=", 3: "#", 0: "~", None: "~", 1.5: ":"}
choices = [
(site.atom_type.name, site.atom_type.atomclass, "*")
for site in self.connection_members
]
if not getattr(self, "bonds", None):
bond_identifiers = [borderDict[self.bond_order]]
else:
bond_identifiers = [borderDict[b.bond_order] for b in self.bonds]
choices += [(val, "~") for val in bond_identifiers]
return itertools.product(*choices)