Source code for geoh5py.shared.validators

#  Copyright (c) 2024 Mira Geoscience Ltd.
#
#  This file is part of geoh5py.
#
#  geoh5py is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Lesser General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  geoh5py is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Lesser General Public License for more details.
#
#  You should have received a copy of the GNU Lesser General Public License
#  along with geoh5py.  If not, see <https://www.gnu.org/licenses/>.
#
#  This file is part of geoapps.
#
#  geoapps is distributed under the terms and conditions of the MIT License
#  (see LICENSE file at the root of this source code package).

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Any
from uuid import UUID

import numpy as np

from geoh5py import Workspace
from geoh5py.groups import PropertyGroup
from geoh5py.shared import Entity
from geoh5py.shared.exceptions import (
    AssociationValidationError,
    AtLeastOneValidationError,
    OptionalValidationError,
    PropertyGroupValidationError,
    RequiredValidationError,
    ShapeValidationError,
    TypeValidationError,
    UUIDValidationError,
    ValueValidationError,
)
from geoh5py.shared.utils import iterable


[docs] class BaseValidator(ABC): """Concrete base class for validators.""" def __init__(self, **kwargs): for key, value in kwargs.items(): if hasattr(self, key): setattr(self, key, value) def __call__(self, *args): self.validate(*args)
[docs] @classmethod @abstractmethod def validate(cls, name: str, value: Any, valid: Any): """ Custom validation function. """ raise NotImplementedError( "The 'validate' method must be implemented by the sub-class. " f"Must contain a 'name' {name}, 'value' {value} and 'valid' {valid} argument." )
@property @abstractmethod def validator_type(self) -> str: """ Validation type identifier. """ raise NotImplementedError("Must implement the validator_type property.")
[docs] class OptionalValidator(BaseValidator): """Validate that forms contain optional parameter if None value is given.""" validator_type = "optional"
[docs] @classmethod def validate( cls, name: str, value: Any | None, valid: bool, ) -> None: """ :param name: Parameter identifier. :param value: Input parameter value. :param valid: True if optional keyword in form for parameter. """ if value is None and not valid: raise OptionalValidationError(name, value, valid)
[docs] class AssociationValidator(BaseValidator): """Validate the association between data and parent object.""" validator_type = "association"
[docs] @classmethod def validate( cls, name: str, value: Entity | PropertyGroup | UUID | None, valid: Entity | Workspace, ) -> None: """ :param name: Parameter identifier. :param value: Input parameter value. :param valid: Expected value shape """ if valid is None: return if isinstance(valid, list): warnings.warn( "Data associated with multiSelect dependent is not supported. Validation ignored." ) return if not isinstance(valid, (Entity, Workspace)): raise ValueError( "'AssociationValidator.validate' requires a 'valid'" " input of type 'Entity', 'Workspace' or None. " f"Provided '{valid}' of type {type(valid)} for parameter '{name}'" ) if isinstance(value, UUID): uid = value elif isinstance(value, (Entity, PropertyGroup)): uid = value.uid else: return if isinstance(valid, Workspace): # TODO add a generic method to workspace to get all uuid children = valid.get_entity(uid) if None in children: children = valid.fetch_children(valid.root, recursively=True) elif isinstance(valid, Entity): children = valid.workspace.fetch_children(valid, recursively=True) if uid not in [getattr(child, "uid", None) for child in children]: raise AssociationValidationError(name, value, valid)
[docs] class PropertyGroupValidator(BaseValidator): """Validate property_group from parent entity.""" validator_type = "property_group_type"
[docs] @classmethod def validate(cls, name: str, value: PropertyGroup, valid: str) -> None: if (value is not None) and (value.property_group_type != valid): raise PropertyGroupValidationError(name, value, valid)
[docs] class AtLeastOneValidator(BaseValidator): validator_type = "one_of"
[docs] @classmethod def validate(cls, name, value, valid): if not any(v for v in value.values()): raise AtLeastOneValidationError(name, value)
[docs] class RequiredValidator(BaseValidator): """ Validate that required keys are present in parameter. """ validator_type = "required"
[docs] @classmethod def validate(cls, name: str, value: Any, valid: bool) -> None: """ :param name: Parameter identifier. :param value: Input parameter value. :param valid: Assert to be required """ if value is None and valid: raise RequiredValidationError(name)
[docs] class ShapeValidator(BaseValidator): """Validate the shape of provided value.""" validator_type = "shape"
[docs] @classmethod def validate(cls, name: str, value: Any, valid: tuple[int]) -> None: """ :param name: Parameter identifier. :param value: Input parameter value. :param valid: Expected value shape """ if value is None: return if isinstance(value, np.ndarray): pshape = value.shape elif isinstance(value, list): pshape = (len(value),) else: pshape = (1,) if pshape != valid: raise ShapeValidationError(name, pshape, valid)
[docs] class TypeValidator(BaseValidator): """ Validate the value type from a list of valid types. """ validator_type = "types"
[docs] @classmethod def validate(cls, name: str, value: Any, valid: type | list[type]) -> None: """ :param name: Parameter identifier. :param value: Input parameter value. :param valid: List of accepted value types """ if isinstance(valid, type): valid = [valid] if not isinstance(valid, list): raise TypeError("Input `valid` options must be a type or list of types.") if not iterable(value) or (isinstance(value, list) and list in tuple(valid)): value = (value,) for val in value: if not isinstance(val, tuple(valid)): valid_names = [t.__name__ for t in valid if hasattr(t, "__name__")] type_name = type(val).__name__ raise TypeValidationError(name, type_name, valid_names)
[docs] class UUIDValidator(BaseValidator): """Validate a uuui.UUID value or uuid string.""" validator_type = "uuid"
[docs] @classmethod def validate(cls, name: str, value: Any, valid: None = None) -> None: """ :param name: Parameter identifier. :param value: Input parameter uuid. :param valid: [Optional] Validate uuid from parental entity or known uuids """ if isinstance(value, str): try: value = UUID(value) except ValueError as exception: raise UUIDValidationError(name, str(value)) from exception
[docs] class ValueValidator(BaseValidator): """ Validator that ensures that values are valid entries. """ validator_type = "values"
[docs] @classmethod def validate(cls, name: str, value: Any, valid: list[float | str]) -> None: """ :param name: Parameter identifier. :param value: Input parameter value. :param valid: List of accepted values """ if value is None: return if not isinstance(value, (list, tuple)): value = [value] for val in value: if val is not None and val not in valid: raise ValueValidationError(name, val, valid)