"""Abstract representation of a Potential object."""
from typing import Any, Dict, Iterator, List, Set
import unyt as u
from pydantic import ConfigDict, Field, field_serializer, field_validator
from gmso.abc.gmso_base import GMSOBase
from gmso.abc.serialization_utils import unyt_to_dict
from gmso.utils.expression import PotentialExpression
class AbstractPotential(GMSOBase):
    __base_doc__ = """An abstract potential class.
    AbstractPotential stores a general interaction between components of a chemical
    topology that can be specified by a mathematical expression. The functional
    form of the potential is stored as a `sympy` expression and the parameters
    are stored explicitly. This class is agnostic to the instantiation of the
    potential, which can be e.g. a non-bonded potential, a bonded potential, an
    angle potential, a dihedral potential, etc. and is designed to be inherited
    by classes that represent these potentials.
    """

    # Fields are stored under trailing-underscore names and exposed to callers
    # through their aliases (``name``, ``potential_expression``, ``tags``) and
    # the read-only properties defined below.
    name_: str = Field(
        "",
        description="The name of the potential. Defaults to class name",
        alias="name",
    )

    potential_expression_: PotentialExpression = Field(
        PotentialExpression(expression="a*x+b", independent_variables={"x"}),
        description="The mathematical expression for the potential",
        alias="potential_expression",
    )

    tags_: Dict[str, Any] = Field(
        {},
        description="Tags associated with the potential",
        alias="tags",
    )

    # Maps each public alias back to the field name it populates.
    model_config = ConfigDict(
        alias_to_fields={
            "name": "name_",
            "potential_expression": "potential_expression_",
            "tags": "tags_",
        }
    )

    def __init__(
        self,
        name="Potential",
        expression="a*x+b",
        independent_variables=None,
        potential_expression=None,
        **kwargs,
    ):
        """Create a potential from an expression or a ready-made PotentialExpression.

        Parameters
        ----------
        name : str, optional, default="Potential"
            Name of the potential.
        expression : optional, default="a*x+b"
            Expression used to build a ``PotentialExpression`` when
            *potential_expression* is not given. ``None`` falls back to
            ``"a*x+b"``.
        independent_variables : optional, default=None
            Independent variables used to build a ``PotentialExpression`` when
            *potential_expression* is not given. ``None`` falls back to
            ``{"x"}``.
        potential_expression : PotentialExpression, optional, default=None
            If given, it is used as-is and *expression* and
            *independent_variables* are ignored.
        **kwargs
            Forwarded to the parent initializer. A missing or empty ``tags``
            entry is replaced by a new empty dict.
        """
        if potential_expression is None:
            if expression is None:
                expression = "a*x+b"

            if independent_variables is None:
                independent_variables = {"x"}

            potential_expression = PotentialExpression(
                expression=expression,
                independent_variables=independent_variables,
                parameters=None,
            )

        # Give every instance its own tag dictionary when none (or an empty
        # one) was supplied.
        if not kwargs.get("tags"):
            kwargs["tags"] = {}

        super().__init__(name=name, potential_expression=potential_expression, **kwargs)

    @property
    def name(self) -> str:
        """The name of the potential."""
        return self.__dict__.get("name_")

    @property
    def independent_variables(self) -> Set:
        """The independent variables in the potential's expression.

        Returns the set of :mod:`sympy` symbols that are *not* parameters —
        typically the interparticle distance ``r`` or an angle ``theta``.
        """
        return self.potential_expression.independent_variables

    @property
    def expression(self):
        """The mathematical expression of the potential as a :class:`sympy.Expr`."""
        return self.potential_expression.expression

    @property
    def potential_expression(self) -> "PotentialExpression":
        """The :class:`~gmso.utils.expression.PotentialExpression` for this potential.

        Bundles the sympy expression, the independent variables, and the
        parameter dictionary into a single validated object.
        """
        return self.__dict__.get("potential_expression_")

    @property
    def tags(self) -> Dict[str, Any]:
        """Arbitrary metadata attached to this potential."""
        return self.__dict__.get("tags_")

    @property
    def tag_names(self) -> List[str]:
        """The names of all tags on this potential, as a new list."""
        return list(self.__dict__.get("tags_"))

    @property
    def tag_names_iter(self) -> Iterator[str]:
        """An iterator over the names of all tags on this potential."""
        return iter(self.__dict__.get("tags_"))

    @field_serializer("potential_expression_")
    def serialize_expression(self, potential_expression_: PotentialExpression):
        """Serialize the potential expression into a plain dictionary.

        Returns
        -------
        dict
            ``"expression"`` holds the expression as a string,
            ``"independent_variables"`` the sorted variable names as strings,
            and ``"parameters"`` each parameter passed through
            ``unyt_to_dict``.
        """
        expr = str(potential_expression_.expression)
        # Sorting keeps the serialized output independent of set ordering.
        ind = sorted(str(ind) for ind in potential_expression_.independent_variables)
        params = {
            param: unyt_to_dict(val)
            for param, val in potential_expression_.parameters.items()
        }
        return {
            "expression": expr,
            "independent_variables": ind,
            "parameters": params,
        }

    def add_tag(self, tag: str, value: Any, overwrite: bool = True) -> None:
        """Add or update a metadata tag on this potential.

        Parameters
        ----------
        tag : str
            Name of the tag to add.
        value : Any
            Value to associate with the tag.
        overwrite : bool, optional, default=True
            If ``False``, raise :exc:`ValueError` when *tag* already exists.

        Raises
        ------
        ValueError
            If *tag* already exists and *overwrite* is ``False``.
        """
        # Test membership rather than truthiness so that an existing tag with
        # a falsy value (0, "", False, None, ...) is still detected.
        if not overwrite and tag in self.tags:
            raise ValueError(
                f"Tag {tag} already exists. Please use overwrite=True to overwrite"
            )
        self.tags[tag] = value

    def get_tag(self, tag: str, throw: bool = False) -> Any:
        """Return the value of a metadata tag.

        Parameters
        ----------
        tag : str
            Name of the tag to retrieve.
        throw : bool, optional, default=False
            If ``True``, raise :exc:`KeyError` when *tag* does not exist.
            If ``False``, return ``None`` for missing tags.
        """
        if throw:
            return self.tags[tag]
        else:
            return self.tags.get(tag)

    def delete_tag(self, tag: str) -> None:
        """Delete a metadata tag.

        Parameters
        ----------
        tag : str
            Name of the tag to delete.

        Raises
        ------
        KeyError
            If *tag* does not exist.
        """
        del self.tags[tag]

    def pop_tag(self, tag: str) -> Any:
        """Remove and return a metadata tag, or ``None`` if absent.

        Parameters
        ----------
        tag : str
            Name of the tag to pop.
        """
        return self.tags.pop(tag, None)

    @field_validator("potential_expression_", mode="before")
    @classmethod
    def validate_potential_expression(cls, v):
        """Build a ``PotentialExpression`` from a dict before field validation.

        A dict is expanded as keyword arguments to ``PotentialExpression``;
        any other value is returned unchanged.
        """
        if isinstance(v, dict):
            v = PotentialExpression(**v)
        return v

    def __setattr__(self, key: Any, value: Any) -> None:
        """Set attributes of the potential.

        ``expression`` and ``independent_variables`` are forwarded to the
        underlying potential expression. Assignments to ``set_ref_`` are
        ignored. Everything else goes through the parent ``__setattr__``.
        """
        if key == "expression":
            self.potential_expression.expression = value
        elif key == "independent_variables":
            self.potential_expression.independent_variables = value
        elif key == "set_ref_":
            # NOTE(review): silently dropped; the reason is not visible in
            # this file — confirm against subclasses before changing.
            return
        else:
            super().__setattr__(key, value)

    def __repr__(self):
        """Return a formatted representation of the potential."""
        desc = (
            f"<{self.__class__.__name__} {self.name},\n "
            f"expression: {self.expression},\n "
            f"id: {id(self)}>"
        )
        return desc

    def __str__(self):
        """Return a string representation of the potential."""
        return (
            f"<{self.__class__.__name__} {self.name}, "
            f"expression: {self.expression}, "
            f"id: {id(self)}>"
        )