# Copyright (c) 2024 Mira Geoscience Ltd.
#
# This file is part of geoh5py.
#
# geoh5py is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# geoh5py is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with geoh5py. If not, see <https://www.gnu.org/licenses/>.
# pylint: disable=too-few-public-methods
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from uuid import UUID
import numpy as np
from geoh5py import TYPE_UID_TO_CLASS, Workspace
from geoh5py.groups import Group, PropertyGroup
from geoh5py.objects import ObjectBase
from geoh5py.shared import Entity
from geoh5py.shared.exceptions import (
AssociationValidationError,
AtLeastOneValidationError,
OptionalValidationError,
PropertyGroupValidationError,
RequiredValidationError,
ShapeValidationError,
TypeValidationError,
UUIDValidationError,
ValueValidationError,
iterable,
)
[docs]
def to_path(value: list[str | Path]) -> list[Path]:
    """
    Promote path strings to pathlib.Path objects.

    :param value: Sequence of paths given as strings or ``Path`` objects.

    :return: New list in which every ``str`` entry is converted to ``Path``;
        any other entry is passed through unchanged.
    """
    return [Path(path) if isinstance(path, str) else path for path in value]
[docs]
def to_list(value: Any) -> list[Any]:
    """
    Wrap a value in a list unless it already is one.

    A string containing ``;`` separators is split into its parts first.

    :param value: Any value; a ``list`` is returned as the same object.

    :return: The value expressed as a list.
    """
    if isinstance(value, str) and ";" in value:
        return value.split(";")
    return value if isinstance(value, list) else [value]
[docs]
def to_uuid(values):
    """
    Convert string entries to ``uuid.UUID`` and pass anything else through.

    :param values: Iterable of values, some of which may be UUID strings.

    :return: New list with every string replaced by its UUID.

    :raises ValueError: If a string entry is not a valid UUID representation.
    """
    return [UUID(item) if isinstance(item, str) else item for item in values]
[docs]
def class_or_raise(value: UUID) -> type[ObjectBase] | type[Group]:
    """
    Look up the geoh5py class registered for a type uid.

    :param value: Type uid of a geoh5py object or group.

    :return: The class mapped to ``value`` in ``TYPE_UID_TO_CLASS``.

    :raises ValueError: If ``value`` is not a key of ``TYPE_UID_TO_CLASS``.
    """
    if value in TYPE_UID_TO_CLASS:
        return TYPE_UID_TO_CLASS[value]

    raise ValueError(
        f"Provided type_uid string {value!s} is not a recognized "
        f"geoh5py object or group type uid."
    )
[docs]
def to_class(
    values: list[UUID | type[ObjectBase] | type[Group]],
) -> list[type[ObjectBase] | type[Group]]:
    """
    Promote type uids to geoh5py classes.

    UUID entries are resolved with ``class_or_raise``, which raises for
    unknown uids. Classes deriving from ``ObjectBase`` or ``Group`` are kept
    as-is; other classes are dropped. A value that is neither a UUID nor a
    class makes ``issubclass`` raise ``TypeError``.

    :param values: Type uids and/or geoh5py classes.

    :return: List of resolved classes, in input order.
    """
    classes = []
    for item in values:
        if isinstance(item, UUID):
            classes.append(class_or_raise(item))
            continue
        if issubclass(item, (ObjectBase, Group)):
            classes.append(item)
    return classes
[docs]
def empty_string_to_uid(value):
    """
    Replace an empty string with the nil UUID.

    :param value: Any value.

    :return: The all-zero UUID when ``value == ""``, otherwise ``value``
        unchanged.
    """
    return UUID("00000000-0000-0000-0000-000000000000") if value == "" else value
[docs]
class BaseValidator(ABC):
    """
    Abstract base class for validators.

    Sub-classes define a ``validator_type`` identifier and implement the
    ``validate`` class method. Instances are callable and forward their
    positional arguments to ``validate``.
    """

    validator_type: str

    def __init__(self, **kwargs):
        # Only keywords naming an attribute that already exists on the
        # instance (or its class) are applied; unknown keys are ignored.
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)

    def __call__(self, *args):
        """Run ``validate`` with the given positional arguments."""
        self.validate(*args)

    @classmethod
    @abstractmethod
    def validate(cls, name: str, value: Any, valid: Any):
        """
        Custom validation function.

        :param name: Parameter identifier.
        :param value: Input parameter value.
        :param valid: Validation criterion; its meaning depends on the
            sub-class.

        :raises NotImplementedError: Always, when not overridden.
        """
        raise NotImplementedError(
            "The 'validate' method must be implemented by the sub-class. "
            f"Must contain a 'name' {name}, 'value' {value} and 'valid' {valid} argument."
        )
[docs]
class OptionalValidator(BaseValidator):
    """Validate that a None value is only given to an optional parameter."""

    validator_type = "optional"

    @classmethod
    def validate(
        cls,
        name: str,
        value: Any | None,
        valid: bool,
    ) -> None:
        """
        Raise if a parameter without the optional flag is set to None.

        :param name: Parameter identifier.
        :param value: Input parameter value.
        :param valid: True if optional keyword in form for parameter.

        :raises OptionalValidationError: If ``value`` is None and ``valid``
            is falsy.
        """
        if value is not None or valid:
            return
        raise OptionalValidationError(name, value, valid)
[docs]
class AssociationValidator(BaseValidator):
    """Validate the association between data and parent object."""

    validator_type = "association"

    @classmethod
    def validate(
        cls,
        name: str,
        value: Entity | PropertyGroup | UUID | None,
        valid: Entity | Workspace | list | None,
    ) -> None:
        """
        Check that ``value`` is found among the children of ``valid``.

        :param name: Parameter identifier.
        :param value: Input parameter value, given as an entity, a property
            group or a uid. Any other value (including None) is not checked.
        :param valid: Parent entity or workspace expected to contain
            ``value``. None skips validation. A list (multiSelect dependent)
            emits a warning and skips validation.

        :raises ValueError: If ``valid`` is not None, a list, an ``Entity``
            or a ``Workspace``.
        :raises AssociationValidationError: If the uid of ``value`` is not
            found among the children of ``valid``.
        """
        if valid is None:
            return

        if isinstance(valid, list):
            warnings.warn(
                "Data associated with multiSelect dependent is not supported. Validation ignored."
            )
            return

        if not isinstance(valid, (Entity, Workspace)):
            raise ValueError(
                "'AssociationValidator.validate' requires a 'valid'"
                " input of type 'Entity', 'Workspace' or None. "
                f"Provided '{valid}' of type {type(valid)} for parameter '{name}'"
            )

        if isinstance(value, UUID):
            uid = value
        elif isinstance(value, (Entity, PropertyGroup)):
            uid = value.uid
        else:
            return

        if isinstance(valid, Workspace):
            # TODO add a generic method to workspace to get all uuid
            # Direct lookup first. NOTE(review): presumably get_entity yields
            # None entries for uids it cannot resolve; in that case fall back
            # to scanning every entity below the workspace root.
            children = valid.get_entity(uid)
            if None in children:
                children = valid.fetch_children(valid.root, recursively=True)
        else:
            # The isinstance guard above guarantees ``valid`` is an Entity here.
            children = valid.workspace.fetch_children(valid, recursively=True)

        if uid not in [getattr(child, "uid", None) for child in children]:
            raise AssociationValidationError(name, value, valid)
[docs]
class PropertyGroupValidator(BaseValidator):
    """Validate property_group from parent entity."""

    validator_type = "property_group_type"

    @classmethod
    def validate(cls, name: str, value: PropertyGroup, valid: str | list[str]) -> None:
        """
        Check the type of a property group against the accepted type(s).

        :param name: Parameter identifier.
        :param value: Property group to check; None is not checked.
        :param valid: Accepted property group type, or list of types.

        :raises PropertyGroupValidationError: If ``value.property_group_type``
            is not among the accepted types.
        """
        accepted = [valid] if isinstance(valid, str) else valid
        if value is None:
            return
        if value.property_group_type not in accepted:
            raise PropertyGroupValidationError(name, value, accepted)
[docs]
class AtLeastOneValidator(BaseValidator):
    """Validate that at least one value in a group is truthy."""

    validator_type = "one_of"

    @classmethod
    def validate(cls, name, value, valid):
        """
        :param name: Identifier of the group being checked.
        :param value: Mapping whose values are tested for truthiness
            (presumably parameter name -> value; confirm against callers).
        :param valid: Unused.

        :raises AtLeastOneValidationError: If no value in ``value`` is truthy,
            including when ``value`` is empty.
        """
        if any(value.values()):
            return
        raise AtLeastOneValidationError(name, value)
[docs]
class RequiredValidator(BaseValidator):
    """
    Validate that required keys are present in parameter.
    """

    validator_type = "required"

    @classmethod
    def validate(cls, name: str, value: Any, valid: bool) -> None:
        """
        Raise when a required parameter has no value.

        :param name: Parameter identifier.
        :param value: Input parameter value.
        :param valid: Assert to be required

        :raises RequiredValidationError: If ``value`` is None and ``valid``
            is truthy.
        """
        if value is not None:
            return
        if valid:
            raise RequiredValidationError(name)
[docs]
class ShapeValidator(BaseValidator):
    """Validate the shape of provided value."""

    validator_type = "shape"

    @classmethod
    def validate(cls, name: str, value: Any, valid: tuple[int, ...]) -> None:
        """
        Compare the shape of ``value`` with the expected shape.

        Numpy arrays report their ``shape``, lists report ``(len,)`` and any
        other value (tuples included) is treated as a scalar of shape
        ``(1,)``. None is not checked.

        :param name: Parameter identifier.
        :param value: Input parameter value.
        :param valid: Expected value shape

        :raises ShapeValidationError: If the computed shape differs from
            ``valid``.
        """
        if value is None:
            return

        shape = (1,)
        if isinstance(value, np.ndarray):
            shape = value.shape
        elif isinstance(value, list):
            shape = (len(value),)

        if shape != valid:
            raise ShapeValidationError(name, shape, valid)
[docs]
class TypeValidator(BaseValidator):
    """
    Validate the value type from a list of valid types.
    """

    validator_type = "types"

    @classmethod
    def validate(cls, name: str, value: Any, valid: type | list[type]) -> None:
        """
        Check that ``value`` (or each of its items) is an accepted type.

        Non-iterable values are checked as a single item. A list is also
        checked as a single item when ``list`` itself is an accepted type;
        otherwise each of its items is checked.

        :param name: Parameter identifier.
        :param value: Input parameter value.
        :param valid: List of accepted value types

        :raises TypeError: If ``valid`` is neither a type nor a list.
        :raises TypeValidationError: On the first item of a non-accepted type.
        """
        if isinstance(valid, type):
            valid = [valid]
        if not isinstance(valid, list):
            raise TypeError("Input `valid` options must be a type or list of types.")

        accepted = tuple(valid)
        as_single = not iterable(value) or (
            isinstance(value, list) and list in accepted
        )
        items = (value,) if as_single else value

        for item in items:
            if isinstance(item, accepted):
                continue
            valid_names = [t.__name__ for t in valid if hasattr(t, "__name__")]
            raise TypeValidationError(name, type(item).__name__, valid_names)
[docs]
class UUIDValidator(BaseValidator):
    """Validate a uuid.UUID value or uuid string."""

    validator_type = "uuid"

    @classmethod
    def validate(cls, name: str, value: Any, valid: None = None) -> None:
        """
        Check that a string value parses as a UUID; non-strings pass.

        :param name: Parameter identifier.
        :param value: Input parameter uuid.
        :param valid: [Optional] Validate uuid from parental entity or known uuids
            (not used by the current implementation).

        :raises UUIDValidationError: If ``value`` is a string that is not a
            valid UUID representation.
        """
        if not isinstance(value, str):
            return
        try:
            UUID(value)
        except ValueError as error:
            raise UUIDValidationError(name, str(value)) from error
[docs]
class ValueValidator(BaseValidator):
    """
    Validator that ensures that values are valid entries.
    """

    validator_type = "values"

    @classmethod
    def validate(cls, name: str, value: Any, valid: list[float | str]) -> None:
        """
        Check that ``value`` (or each entry of a list/tuple) is accepted.

        None, whether given as ``value`` or as an entry, is skipped.

        :param name: Parameter identifier.
        :param value: Input parameter value.
        :param valid: List of accepted values

        :raises ValueValidationError: On the first entry not found in
            ``valid``.
        """
        if value is None:
            return

        entries = value if isinstance(value, (list, tuple)) else [value]
        for entry in entries:
            if entry is None:
                continue
            if entry not in valid:
                raise ValueValidationError(name, entry, valid)