Source code for gmso.abc.gmso_base

"""Base model all classes extend."""

import json
import logging
from abc import ABC
from typing import Any, Type

from pydantic import BaseModel, ConfigDict, validators

from gmso.abc.serialization_utils import dict_to_unyt

def dict_validator(value: Any) -> dict:
    """Coerce ``value`` into a ``dict``.

    Stand-in for pydantic v1's ``dict_validator``, which pydantic v2 no longer
    provides. The previous binding, ``validators.getattr_migration("dict_validator")``,
    produced pydantic's import-migration wrapper instead of a validator, and
    calling it with a value failed with ``KeyError``.

    Parameters
    ----------
    value : Any
        A dict (returned unchanged, not copied) or anything ``dict()`` accepts,
        e.g. an iterable of key/value pairs.

    Returns
    -------
    dict

    Raises
    ------
    TypeError
        If ``value`` cannot be converted to a dict.
    """
    if isinstance(value, dict):
        return value
    try:
        return dict(value)
    except (TypeError, ValueError) as err:
        # v1 raised DictError, a TypeError subclass; keep the same family.
        raise TypeError(f"value is not a valid dict: {value!r}") from err


# Placeholder type aliases used only in return annotations below.
Model = Type["Model"]
DictStrAny = Type["DictStrAny"]
TupleGenerator = Type["TupleGenerator"]
CallableGenerator = Type["CallableGenerator"]

logger = logging.getLogger(__name__)


def _alias_to_fields(model: Any) -> dict:
    """Return the ``alias_to_fields`` mapping from ``model.model_config``.

    The mapping goes from external (alias) names to internal field names.
    An empty dict is returned when the model declares none, so callers can
    test membership without guarding against ``None``.
    """
    return model.model_config.get("alias_to_fields") or {}


def _add_field_names(selection: Any, alias_map: dict) -> Any:
    """Extend a pydantic ``include``/``exclude`` selection with field names.

    For every alias in ``selection`` that is a key of ``alias_map``, the
    matching internal field name is added alongside it.

    Parameters
    ----------
    selection : Any
        ``None``/empty (returned unchanged), a mapping of name -> nested
        selection (returned as a new dict), or an iterable of names
        (returned as a new set).
    alias_map : dict
        Alias -> internal field name.
    """
    if not selection:
        return selection
    if isinstance(selection, dict):
        extended = dict(selection)
        for name, nested in selection.items():
            if name in alias_map:
                # Do not override an explicit entry for the field itself.
                extended.setdefault(alias_map[name], nested)
        return extended
    # Materialize once so one-shot iterables are not consumed twice.
    names = set(selection)
    return names | {alias_map[name] for name in names if name in alias_map}


def _prepare_dump_kwargs(model: Any, kwargs: dict) -> dict:
    """Normalize keyword arguments for ``model_dump``/``model_dump_json``.

    Forces ``by_alias=True`` and adds internal field names for any aliases
    given in ``include`` or ``exclude``. Mutates and returns ``kwargs``.
    """
    kwargs["by_alias"] = True
    alias_map = _alias_to_fields(model)
    for key in ("include", "exclude"):
        if key in kwargs:
            kwargs[key] = _add_field_names(kwargs[key], alias_map)
    return kwargs


class GMSOBase(BaseModel, ABC):
    """A BaseClass to all abstract classes in GMSO.

    Subclasses may declare ``alias_to_fields`` in their ``model_config``: a
    mapping from external (alias) attribute names to the internal field names
    that store them. Attribute assignment and serialization translate through
    that mapping; a model without it behaves as if the mapping were empty.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
        populate_by_name=True,
    )

    def __hash__(self):
        """Return the unique hash of the object (based on its identity)."""
        return id(self)

    def __eq__(self, other):
        """Test if two objects are equivalent; only identity counts."""
        return self is other

    def __setattr__(self, name: Any, value: Any) -> None:
        """Set an attribute, translating alias names to internal field names.

        Using an internal field name directly still works but is logged at
        INFO level, since external (alias) names are preferred.
        """
        alias_map = _alias_to_fields(self)
        if name in alias_map:
            name = alias_map[name]
        elif name in alias_map.values():
            logger.info(
                "Use of internal fields is discouraged. "
                "Please use external fields to set attributes."
            )
        super().__setattr__(name, value)

    @classmethod
    def model_validate(cls: Model, obj: Any, **kwargs) -> Model:
        """Validate ``obj`` into a model after preprocessing it with ``dict_to_unyt``.

        Extra keyword arguments (e.g. ``strict``, ``context``) are forwarded
        to pydantic's ``BaseModel.model_validate``.
        """
        # NOTE(review): presumably converts serialized unyt quantities in
        # ``obj`` in place (its return value is ignored) — confirm.
        dict_to_unyt(obj)
        return super(GMSOBase, cls).model_validate(obj, **kwargs)

    def model_dump(self, **kwargs) -> DictStrAny:
        """Dump the model to a dict, always keyed by alias.

        Aliases listed in ``include``/``exclude`` also select the internal
        field they map to. ``exclude=None`` and mapping-style selections are
        accepted, as in pydantic.
        """
        return super().model_dump(**_prepare_dump_kwargs(self, kwargs))

    def model_dump_json(self, **kwargs):
        """Dump the model to a JSON string, always keyed by alias.

        Alias handling of ``include``/``exclude`` matches ``model_dump``.
        """
        return super().model_dump_json(**_prepare_dump_kwargs(self, kwargs))

    def json_dict(self, **kwargs):
        """Return a JSON serializable dictionary from the object.

        Produced by round-tripping through ``model_dump_json`` and
        ``json.loads``; ``kwargs`` are forwarded to ``model_dump_json``.
        """
        raw_json = self.model_dump_json(**kwargs)
        return json.loads(raw_json)

    def _iter(self, **kwargs) -> TupleGenerator:
        """Iterate via the parent ``_iter``, replacing aliases with field names.

        Aliases in ``include``/``exclude`` are replaced (not supplemented) by
        their internal field names. Reads the v2 ``model_config`` rather than
        the pydantic-v1 ``Config`` class, which v2 models do not have.
        """
        # NOTE(review): pydantic v2 marks BaseModel._iter as deprecated;
        # confirm whether anything still calls this hook.
        alias_map = _alias_to_fields(self)
        include = kwargs.get("include")
        exclude = kwargs.get("exclude")
        if include:
            kwargs["include"] = {alias_map.get(name, name) for name in include}
        if exclude:
            kwargs["exclude"] = {alias_map.get(name, name) for name in exclude}
        yield from super()._iter(**kwargs)

    @classmethod
    def validate(cls, value):
        """Ensure that the object is validated before use.

        Instances of ``cls`` are returned unchanged. Anything else is coerced
        to a dict (mappings as-is, otherwise via ``dict(value)``) and passed
        as keyword arguments to ``cls``.

        Raises
        ------
        TypeError
            If ``value`` is neither an instance of ``cls`` nor dict-convertible.
        """
        if isinstance(value, cls):
            return value
        if not isinstance(value, dict):
            try:
                value = dict(value)
            except (TypeError, ValueError) as err:
                raise TypeError(
                    f"{cls.__name__}.validate expected a {cls.__name__} "
                    f"instance or a mapping, got {type(value).__name__}"
                ) from err
        return cls(**value)

    @classmethod
    def __get_validators__(cls) -> CallableGenerator:
        """Get the validators of the object (pydantic v1-style hook)."""
        yield cls.validate