# Source code for geoh5py.ui_json.enforcers

# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
#  Copyright (c) 2025 Mira Geoscience Ltd.                                     '
#                                                                              '
#  This file is part of geoh5py.                                               '
#                                                                              '
#  geoh5py is free software: you can redistribute it and/or modify             '
#  it under the terms of the GNU Lesser General Public License as published by '
#  the Free Software Foundation, either version 3 of the License, or           '
#  (at your option) any later version.                                         '
#                                                                              '
#  geoh5py is distributed in the hope that it will be useful,                  '
#  but WITHOUT ANY WARRANTY; without even the implied warranty of              '
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               '
#  GNU Lesser General Public License for more details.                         '
#                                                                              '
#  You should have received a copy of the GNU Lesser General Public License    '
#  along with geoh5py.  If not, see <https://www.gnu.org/licenses/>.           '
# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''


from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any
from uuid import UUID

from geoh5py.shared.exceptions import (
    AggregateValidationError,
    BaseValidationError,
    InCollectionValidationError,
    RequiredFormMemberValidationError,
    RequiredObjectDataValidationError,
    RequiredUIJsonParameterValidationError,
    RequiredWorkspaceObjectValidationError,
    TypeUIDValidationError,
    TypeValidationError,
    UUIDValidationError,
    ValueValidationError,
)
from geoh5py.shared.utils import SetDict, is_uuid


class Enforcer(ABC):
    """
    Abstract base for rule enforcers.

    Subclasses identify themselves through the ``enforcer_type`` class
    attribute and implement :meth:`rule` and :meth:`enforce`.

    :param validations: Value(s) the parameter value is validated against.
    """

    enforcer_type: str = ""

    def __init__(self, validations: set):
        self.validations = validations

    @abstractmethod
    def rule(self, value: Any):
        """True if 'value' adheres to the enforcer's rule."""

    @abstractmethod
    def enforce(self, name: str, value: Any):
        """Enforce the rule on the 'value' of parameter 'name'."""

    def __eq__(self, other) -> bool:
        """Equal when 'other' is an instance of this type with equal validations."""
        return isinstance(other, type(self)) and other.validations == self.validations

    def __str__(self):
        return f"<{type(self).__name__}> : {self.validations}"
class TypeEnforcer(Enforcer):
    """
    Enforces that a value is an instance of one of the valid types.

    ``None`` is always accepted in addition to the listed types.

    :param validations: Valid type(s) for the parameter value.
    :raises TypeValidationError: If the value is not an instance of a valid type.
    """

    enforcer_type: str = "type"

    def __init__(self, validations: set[type]):
        super().__init__(validations)

    def enforce(self, name: str, value: Any):
        """Raise TypeValidationError when 'value' fails the type rule."""
        if self.rule(value):
            return
        valid_names = [valid.__name__ for valid in self.validations]
        raise TypeValidationError(name, type(value).__name__, valid_names)

    def rule(self, value) -> bool:
        """True if 'value' is None or an instance of any valid type."""
        accepted = self.validations.union({type(None)})
        return any(isinstance(value, kind) for kind in accepted)
class ValueEnforcer(Enforcer):
    """
    Enforces that a value is one of a restricted set of choices.

    :param validations: Valid parameter value(s).
    :raises ValueValidationError: If the value is not one of the valid choices.
    """

    enforcer_type = "value"

    def __init__(self, validations: set[Any]):
        super().__init__(validations)

    def enforce(self, name: str, value: Any):
        """Raise ValueValidationError when 'value' is not an allowed choice."""
        if self.rule(value):
            return
        choices = list(self.validations)
        raise ValueValidationError(name, value, choices)

    def rule(self, value: Any) -> bool:
        """True if 'value' is among the valid choices."""
        return value in self.validations
class TypeUIDEnforcer(Enforcer):
    """
    Enforces restricted geoh5 entity type uid(s).

    A validations set equal to ``{""}`` disables the check entirely.

    :param validations: Valid geoh5py object type uid(s), given as strings.
    :raises TypeUIDValidationError: If ``value.default_type_uid()`` is not one
        of the valid uids.
    """

    enforcer_type = "type_uid"

    def __init__(self, validations: set[str]):
        super().__init__(validations)

    def enforce(self, name: str, value: Any):
        """Raise TypeUIDValidationError when 'value' fails the type uid rule."""
        if self.rule(value):
            return
        raise TypeUIDValidationError(name, value, list(self.validations))

    def rule(self, value: Any) -> bool:
        """True if the check is disabled or the value's default type uid is valid."""
        if self.validations == {""}:
            return True
        # Evaluate the value's uid before parsing the validations, as before.
        type_uid = value.default_type_uid()
        return type_uid in [UUID(uid) for uid in self.validations]
class UUIDEnforcer(Enforcer):
    """
    Enforces that a value is None or a uuid accepted by ``is_uuid``.

    :param validations: Not used by the rule; may be left as None or an empty set.
    :raises UUIDValidationError: If the value is not a valid uuid string.
    """

    enforcer_type = "uuid"

    def __init__(self, validations=None):
        super().__init__(validations)

    def enforce(self, name: str, value: Any):
        """Raise UUIDValidationError when 'value' fails the uuid rule."""
        if self.rule(value):
            return
        raise UUIDValidationError(name, value)

    def rule(self, value: Any) -> bool:
        """True if 'value' is None or passes ``is_uuid``."""
        return value is None or is_uuid(value)
class RequiredEnforcer(Enforcer):
    """
    Enforces that required items are present in a collection.

    Subclasses may override ``validation_error`` to raise a more specific
    error, and :meth:`collection` to change what is searched.

    :param validations: Items that must be found in the collection.
    :raises InCollectionValidationError: If any required item is absent from
        the collection.
    """

    enforcer_type = "required"
    validation_error = InCollectionValidationError

    def __init__(self, validations: set[str | tuple[str, str]]):
        super().__init__(validations)

    def enforce(self, name: str, value: Any):
        """Raise ``validation_error`` listing every required item that is absent."""
        if self.rule(value):
            return
        absent = [
            item for item in self.validations if item not in self.collection(value)
        ]
        raise self.validation_error(name, absent)

    def rule(self, value: Any) -> bool:
        """True if every required item is found in the collection built from 'value'."""
        for item in self.validations:
            if item not in self.collection(value):
                return False
        return True

    def collection(self, value: Any) -> list[Any]:
        """Return the collection searched for required items; by default 'value' itself."""
        return value
class RequiredUIJsonParameterEnforcer(RequiredEnforcer):
    """
    Enforces that required ui.json parameters are present.

    Membership is checked against the inherited ``collection``, i.e. directly
    against the value being enforced.

    :raises RequiredUIJsonParameterValidationError: If any required parameter
        is missing.
    """

    validation_error = RequiredUIJsonParameterValidationError
    enforcer_type = "required_uijson_parameters"
class RequiredFormMemberEnforcer(RequiredEnforcer):
    """
    Enforces that required members are present in a form.

    Membership is checked against the inherited ``collection``, i.e. directly
    against the value being enforced.

    :raises RequiredFormMemberValidationError: If any required member is missing.
    """

    validation_error = RequiredFormMemberValidationError
    enforcer_type = "required_form_members"
class RequiredWorkspaceObjectEnforcer(RequiredEnforcer):
    """
    Enforces that entities referenced by parameters exist in the workspace.

    :param validations: Names of parameters in the enforced dictionary; the
        object stored under ``value[name]["value"]`` must have a ``uid`` that is
        listed by the workspace stored under ``value["geoh5"]``.
    :raises RequiredWorkspaceObjectValidationError: With the names of the
        parameters whose objects are not found in the workspace.
    """

    enforcer_type = "required_workspace_object"
    validation_error = RequiredWorkspaceObjectValidationError

    def enforce(self, name: str, value: Any):
        """
        Raise ``validation_error`` listing only the parameters whose object is
        missing from the workspace.

        The inherited implementation compared parameter names against workspace
        uids, so every required parameter was reported as missing.
        """
        if not self.rule(value):
            raise self.validation_error(name, self._missing(value))

    def rule(self, value: Any) -> bool:
        """True if all objects are in the workspace."""
        return not self._missing(value)

    def collection(self, value: dict[str, Any]) -> list[UUID]:
        """
        Return the entity uids listed by the workspace under ``value["geoh5"]``.

        Assumes ``list_entities_name`` iterates over entity uids (per the return
        annotation) — NOTE(review): confirm against Workspace.
        """
        return list(value["geoh5"].list_entities_name)

    def _missing(self, value: dict[str, Any]) -> list[str]:
        """Return the required parameter names whose object uid is not in the workspace."""
        uids = {key: value[key]["value"].uid for key in self.validations}
        # Build the lookup once instead of rebuilding the workspace listing per item.
        available = set(self.collection(value))
        return [key for key, uid in uids.items() if uid not in available]
class RequiredObjectDataEnforcer(Enforcer):
    """
    Enforces parent/child relationships between object and data parameters.

    :param validations: Pairs of parameter names ``(parent, child)``. The
        object stored under ``value[child]["value"]`` must have its ``uid``
        among the uids of ``value[parent]["value"].children``.
    :raises RequiredObjectDataValidationError: With the pairs whose child is
        not a child of its parent.
    """

    enforcer_type = "required_object_data"
    validation_error = RequiredObjectDataValidationError

    def enforce(self, name: str, value: Any):
        """Raise ``validation_error`` listing every pair that breaks the relationship."""
        if not self.rule(value):
            raise self.validation_error(
                name,
                [
                    pair
                    for pair, child_uids in self._paired_children(value)
                    if value[pair[1]]["value"].uid not in child_uids
                ],
            )

    def rule(self, value: Any) -> bool:
        """True if object/data have parent/child relationship."""
        return all(
            value[pair[1]]["value"].uid in child_uids
            for pair, child_uids in self._paired_children(value)
        )

    def collection(self, value: dict[str, Any]) -> list[list[UUID]]:
        """Return the children uids of each parent, in ``validations`` order."""
        return [
            [child.uid for child in value[pair[0]]["value"].children]
            for pair in self.validations
        ]

    def _paired_children(self, value: dict[str, Any]):
        """Pair each validation with its parent's children uids (collection built once)."""
        # collection() iterates self.validations in the same (unmodified set)
        # order, so zipping keeps each pair aligned with its parent's children.
        return zip(self.validations, self.collection(value))
class EnforcerPool:
    """
    Validate data on a collection of enforcers.

    :param name: Name of parameter.
    :param enforcers: List (pool) of enforcers.
    """

    enforcer_types = {
        "type": TypeEnforcer,
        "value": ValueEnforcer,
        "uuid": UUIDEnforcer,
        "type_uid": TypeUIDEnforcer,
        "required": RequiredEnforcer,
        "required_uijson_parameters": RequiredUIJsonParameterEnforcer,
        "required_form_members": RequiredFormMemberEnforcer,
        "required_workspace_object": RequiredWorkspaceObjectEnforcer,
        "required_object_data": RequiredObjectDataEnforcer,
    }

    def __init__(self, name: str, enforcers: list[Enforcer]):
        self.name = name
        self.enforcers: list[Enforcer] = enforcers
        self._errors: list[BaseValidationError] = []

    @classmethod
    def from_validations(
        cls,
        name: str,
        validations: SetDict,
    ) -> EnforcerPool:
        """
        Create enforcers pool from validations.

        :param name: Name of parameter.
        :param validations: Encodes validations as enforcer type and
            validation key value pairs.

        :return: A pool holding one enforcer per validations entry.
        :raises ValueError: If a key of 'validations' is not a known enforcer type.
        """
        return cls(name, cls._recruit(validations))

    @property
    def validations(self) -> SetDict:
        """Returns an enforcer type / validation dictionary from pool."""
        return SetDict(**{k.enforcer_type: k.validations for k in self.enforcers})

    @staticmethod
    def _recruit(validations: SetDict):
        """Recruit enforcers from validations."""
        return [EnforcerPool._recruit_enforcer(k, v) for k, v in validations.items()]

    @staticmethod
    def _recruit_enforcer(enforcer_type: str, validation: set) -> Enforcer:
        """
        Create enforcer from enforcer type and validation.

        :param enforcer_type: Type of enforcer to create.
        :param validation: Enforcer validation.

        :raises ValueError: If 'enforcer_type' is not in ``enforcer_types``.
        """
        if enforcer_type not in EnforcerPool.enforcer_types:
            raise ValueError(f"Invalid enforcer type: {enforcer_type}.")

        return EnforcerPool.enforcer_types[enforcer_type](validation)

    def enforce(self, value: Any):
        """
        Enforce rules from all enforcers in the pool.

        Every enforcer is run so that all failures are reported together.

        :param value: Value to validate.
        :raises AggregateValidationError: If more than one enforcer fails.
        :raises BaseValidationError: The single captured error if exactly one
            enforcer fails.
        """
        # Start from a fresh list: errors from a previous call must not leak
        # into (or be re-raised by) this one.
        self._errors = []
        for enforcer in self.enforcers:
            self._capture_error(enforcer, value)

        self._raise_errors()

    def _capture_error(self, enforcer: Enforcer, value: Any):
        """Catch and store 'BaseValidationError's for aggregation."""
        try:
            enforcer.enforce(self.name, value)
        except BaseValidationError as err:
            self._errors.append(err)

    def _raise_errors(self):
        """Raise errors if any exist, aggregate if more than one."""
        # Hand the current list off (rebind, don't clear) so the raised
        # aggregate keeps its errors while the pool no longer retains them.
        errors, self._errors = self._errors, []
        if len(errors) > 1:
            raise AggregateValidationError(self.name, errors)
        if errors:
            raise errors[0]