# Copyright (c) 2024 Mira Geoscience Ltd.
#
# This file is part of geoh5py.
#
# geoh5py is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# geoh5py is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with geoh5py. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import warnings
from abc import ABC
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from uuid import UUID
import h5py
import numpy as np
if TYPE_CHECKING:
from ..workspace import Workspace
from .entity import Entity
KEY_MAP = {
"cells": "Cells",
"color_map": "Color map",
"concatenated_attributes": "Attributes",
"concatenated_object_ids": "Concatenated object IDs",
"layers": "Layers",
"metadata": "Metadata",
"octree_cells": "Octree Cells",
"options": "options",
"prisms": "Prisms",
"property_groups": "PropertyGroups",
"property_group_ids": "Property Group IDs",
"surveys": "Surveys",
"trace": "Trace",
"trace_depth": "TraceDepth",
"u_cell_delimiters": "U cell delimiters",
"v_cell_delimiters": "V cell delimiters",
"values": "Data",
"vertices": "Vertices",
"z_cell_delimiters": "Z cell delimiters",
"INVALID": "Invalid",
"INTEGER": "Integer",
"FLOAT": "Float",
"TEXT": "Text",
"BOOLEAN": "Boolean",
"REFERENCED": "Referenced",
"FILENAME": "Filename",
"BLOB": "Blob",
"VECTOR": "Vector",
"DATETIME": "DateTime",
"GEOMETRIC": "Geometric",
"MULTI_TEXT": "Multi-Text",
"UNKNOWN": "Unknown",
"OBJECT": "Object",
"CELL": "Cell",
"VERTEX": "Vertex",
"FACE": "Face",
"GROUP": "Group",
}
INV_KEY_MAP = {value: key for key, value in KEY_MAP.items()}
PNG_KWARGS = {"format": "PNG", "compress_level": 9}
JPG_KWARGS = {"format": "JPEG", "quality": 85}
TIF_KWARGS = {"format": "TIFF"}
PILLOW_ARGUMENTS = {
"1": PNG_KWARGS,
"L": PNG_KWARGS,
"P": PNG_KWARGS,
"RGB": PNG_KWARGS,
"RGBA": PNG_KWARGS,
"CMYK": JPG_KWARGS,
"YCbCr": JPG_KWARGS,
"I": TIF_KWARGS,
"F": TIF_KWARGS,
}
[docs]
@contextmanager
def fetch_active_workspace(workspace: Workspace | None, mode: str = "r"):
"""
Open a workspace in the requested 'mode'.
If receiving an opened Workspace instead, merely return the given workspace.
:param workspace: A Workspace class
:param mode: Set the h5 read/write mode
:return h5py.File: Handle to an opened Workspace.
"""
if (
workspace is None
or getattr(workspace, "_geoh5")
and mode in workspace.geoh5.mode
):
try:
yield workspace
finally:
pass
else:
if getattr(workspace, "_geoh5"):
warnings.warn(
f"Closing the workspace in mode '{workspace.geoh5.mode}' "
f"and re-opening in mode '{mode}'."
)
workspace.close()
try:
yield workspace.open(mode=mode)
finally:
workspace.close()
[docs]
@contextmanager
def fetch_h5_handle(file: str | h5py.File | Path, mode: str = "r") -> h5py.File:
"""
Open in read+ mode a geoh5 file from string.
If receiving a file instead of a string, merely return the given file.
:param file: Name or handle to a geoh5 file.
:param mode: Set the h5 read/write mode
:return h5py.File: Handle to an opened h5py file.
"""
if isinstance(file, h5py.File):
try:
yield file
finally:
pass
else:
if Path(file).suffix != ".geoh5":
raise ValueError("Input h5 file must have a 'geoh5' extension.")
h5file = h5py.File(file, mode)
try:
yield h5file
finally:
h5file.close()
[docs]
def match_values(vec_a, vec_b, collocation_distance=1e-4) -> np.ndarray:
"""
Find indices of matching values between two arrays, within collocation_distance.
:param: vec_a, list or numpy.ndarray
Input sorted values
:param: vec_b, list or numpy.ndarray
Query values
:return: indices, numpy.ndarray
Pairs of indices for matching values between the two arrays such
that vec_a[ind[:, 0]] == vec_b[ind[:, 1]].
"""
ind_sort = np.argsort(vec_a)
ind = np.minimum(
np.searchsorted(vec_a[ind_sort], vec_b, side="right"), vec_a.shape[0] - 1
)
nearests = np.c_[ind, ind - 1]
match = np.where(
np.abs(vec_a[ind_sort][nearests] - vec_b[:, None]) < collocation_distance
)
indices = np.c_[ind_sort[nearests[match[0], match[1]]], match[0]]
return indices
[docs]
def merge_arrays(
head,
tail,
replace="A->B",
mapping=None,
collocation_distance=1e-4,
return_mapping=False,
) -> np.ndarray:
"""
Given two numpy.arrays of different length, find the matching values and append both arrays.
:param: head, numpy.array of float
First vector of shape(M,) to be appended.
:param: tail, numpy.array of float
Second vector of shape(N,) to be appended
:param: mapping=None, numpy.ndarray of int
Optional array where values from the head are replaced by the tail.
:param: collocation_distance=1e-4, float
Tolerance between matching values.
:return: numpy.array shape(O,)
Unique values from head to tail without repeats, within collocation_distance.
"""
if mapping is None:
mapping = match_values(head, tail, collocation_distance=collocation_distance)
if mapping.shape[0] > 0:
if replace == "B->A":
head[mapping[:, 0]] = tail[mapping[:, 1]]
else:
tail[mapping[:, 1]] = head[mapping[:, 0]]
tail = np.delete(tail, mapping[:, 1])
if return_mapping:
return np.r_[head, tail], mapping
return np.r_[head, tail]
[docs]
def clear_array_attributes(entity: Entity, recursive: bool = False):
"""
Clear all stashed values of attributes from an entity to free up memory.
:param entity: Entity to clear attributes from.
:param recursive: Clear attributes from children entities.
"""
if isinstance(entity.workspace.h5file, BytesIO):
return
for attribute in ["vertices", "cells", "values", "prisms", "layers"]:
if hasattr(entity, attribute):
setattr(entity, f"_{attribute}", None)
if recursive and hasattr(entity, "children"):
for child in entity.children:
clear_array_attributes(child, recursive=recursive)
[docs]
def are_objects_similar(obj1, obj2, ignore: list[str] | None):
"""
Compare two objects to see if they are similar. This is a shallow comparison.
:param obj1: The first object.
:param obj2: The first object.
:param ignore: List of attributes to ignore.
:return: If attributes similar or not.
"""
assert isinstance(obj1, type(obj2)), "Objects are not the same type."
attributes1 = getattr(obj1, "__dict__", obj1)
attributes2 = getattr(obj2, "__dict__", obj2)
# remove the ignore attributes
if isinstance(ignore, list) and isinstance(attributes1, dict):
for item in ignore:
attributes1.pop(item, None)
attributes2.pop(item, None)
return attributes1 == attributes2
[docs]
def compare_arrays(object_a, object_b, attribute: str, decimal: int = 6):
if getattr(object_b, attribute) is None:
raise ValueError(f"attr {attribute} is None for object {object_b.name}")
attr_a = getattr(object_a, attribute).tolist()
if len(attr_a) > 0 and isinstance(attr_a[0], str):
assert all(
a == b
for a, b in zip(getattr(object_a, attribute), getattr(object_b, attribute))
), f"Error comparing attribute '{attribute}'."
else:
np.testing.assert_array_almost_equal(
attr_a,
getattr(object_b, attribute).tolist(),
decimal=decimal,
err_msg=f"Error comparing attribute '{attribute}'.",
)
[docs]
def compare_floats(object_a, object_b, attribute: str, decimal: int = 6):
np.testing.assert_almost_equal(
getattr(object_a, attribute),
getattr(object_b, attribute),
decimal=decimal,
err_msg=f"Error comparing attribute '{attribute}'.",
)
[docs]
def compare_list(object_a, object_b, attribute: str, ignore: list[str] | None):
get_object_a = getattr(object_a, attribute)
get_object_b = getattr(object_b, attribute)
assert isinstance(get_object_a, list)
assert len(get_object_a) == len(get_object_b)
for obj_a, obj_b in zip(get_object_a, get_object_b):
assert are_objects_similar(obj_a, obj_b, ignore)
[docs]
def compare_bytes(object_a, object_b):
assert (
object_a == object_b
), f"{type(object_a)} objects: {object_a}, {object_b} are not equal."
[docs]
def compare_entities(
object_a, object_b, ignore: list[str] | None = None, decimal: int = 6
) -> None:
if isinstance(object_a, bytes):
compare_bytes(object_a, object_b)
return
base_ignore = ["_workspace", "_children", "_visual_parameters", "_entity_class"]
ignore_list = base_ignore + ignore if ignore else base_ignore
for attr in [k for k in object_a.__dict__.keys() if k not in ignore_list]:
if isinstance(getattr(object_a, attr[1:]), ABC):
compare_entities(
getattr(object_a, attr[1:]),
getattr(object_b, attr[1:]),
ignore=ignore,
decimal=decimal,
)
else:
if isinstance(getattr(object_a, attr[1:]), np.ndarray):
compare_arrays(object_a, object_b, attr[1:], decimal=decimal)
elif isinstance(getattr(object_a, attr[1:]), float):
compare_floats(object_a, object_b, attr[1:], decimal=decimal)
elif isinstance(getattr(object_a, attr[1:]), list):
compare_list(object_a, object_b, attr[1:], ignore)
else:
assert np.all(
getattr(object_a, attr[1:]) == getattr(object_b, attr[1:])
), f"Output attribute '{attr[1:]}' for {object_a} do not match input {object_b}"
[docs]
def iterable(value: Any, checklen: bool = False) -> bool:
"""
Checks if object is iterable.
Parameters
----------
value : Object to check for iterableness.
checklen : Restrict objects with __iter__ method to len > 1.
Returns
-------
True if object has __iter__ attribute but is not string or dict type.
"""
only_array_like = (not isinstance(value, str)) & (not isinstance(value, dict))
if (hasattr(value, "__iter__")) & only_array_like:
return not (checklen and (len(value) == 1))
return False
[docs]
def iterable_message(valid: list[Any] | None) -> str:
"""Append possibly iterable valid: "Must be (one of): {valid}."."""
if valid is None:
msg = ""
elif iterable(valid, checklen=True):
vstr = "'" + "', '".join(str(k) for k in valid) + "'"
msg = f" Must be one of: {vstr}."
else:
msg = f" Must be: '{valid[0]}'."
return msg
[docs]
def is_uuid(value: str) -> bool:
"""Check if a string is UUID compliant."""
try:
UUID(str(value))
return True
except ValueError:
return False
[docs]
def entity2uuid(value: Any) -> UUID | Any:
"""Convert an entity to its UUID."""
if hasattr(value, "uid"):
return value.uid
return value
[docs]
def uuid2entity(value: UUID, workspace: Workspace) -> Entity | Any:
"""Convert UUID to a known entity."""
if isinstance(value, UUID):
if value in workspace.list_entities_name:
return workspace.get_entity(value)[0]
# Search for property groups
for obj in workspace.objects:
if getattr(obj, "property_groups", None) is not None:
prop_group = [
prop_group
for prop_group in getattr(obj, "property_groups")
if prop_group.uid == value
]
if prop_group:
return prop_group[0]
return None
return value
[docs]
def str2uuid(value: Any) -> UUID | Any:
"""Convert string to UUID"""
if isinstance(value, bytes):
value = value.decode("utf-8")
if is_uuid(value):
# TODO insert validation
return UUID(str(value))
return value
[docs]
def as_str_if_uuid(value: UUID | Any) -> str | Any:
"""Convert :obj:`UUID` to string used in geoh5."""
if isinstance(value, UUID):
return "{" + str(value) + "}"
return value
[docs]
def bool_value(value: np.int8) -> bool:
"""Convert logical int8 to bool."""
return bool(value)
[docs]
def as_str_if_utf8_bytes(value) -> str:
"""Convert bytes to string"""
if isinstance(value, bytes):
value = value.decode("utf-8")
return value
[docs]
def ensure_uuid(value: UUID | str) -> UUID:
"""
Ensure that the value is a UUID.
If not, it raises a type error.
:param value: The value to ensure is a UUID.
:return: The verified UUID.
"""
value = str2uuid(value)
if not isinstance(value, UUID):
raise TypeError(f"Value {value} is not a UUID but a {type(value)}.")
return value
[docs]
def dict_mapper(val, string_funcs: list[Callable], *args, omit: dict | None = None):
"""
Recursion through nested dictionaries and applies mapping functions to values.
:param val: Value (could be another dictionary) to apply transform functions.
:param string_funcs: Functions to apply on values within the input dictionary.
:param omit: Dictionary of functions to omit.
:return val: Transformed values
"""
if isinstance(val, dict):
for key, values in val.items():
short_list = string_funcs.copy()
if omit is not None:
short_list = [
fun for fun in string_funcs if fun not in omit.get(key, [])
]
val[key] = dict_mapper(values, short_list)
if isinstance(val, list):
out = []
for elem in val:
for fun in string_funcs:
elem = fun(elem, *args)
out += [elem]
return out
for fun in string_funcs:
val = fun(val, *args)
return val
[docs]
def box_intersect(extent_a: np.ndarray, extent_b: np.ndarray) -> bool:
"""
Compute the intersection of two axis-aligned bounding extents defined by their
arrays of minimum and maximum bounds in N-D space.
:param extent_a: First extent or shape (2, N)
:param extent_b: Second extent or shape (2, N)
:return: Logic if the box extents intersect along all dimensions.
"""
for extent in [extent_a, extent_b]:
if not isinstance(extent, np.ndarray) or extent.ndim != 2:
raise TypeError("Input extents must be 2D numpy.ndarrays.")
if extent.shape[0] != 2 or not np.all(extent[0, :] <= extent[1, :]):
raise ValueError(
"Extents must be of shape (2, N) containing the minimum and maximum "
"bounds in nd-space on the first and second row respectively."
)
for comp_a, comp_b in zip(extent_a.T, extent_b.T):
min_ext = max(comp_a[0], comp_b[0])
max_ext = min(comp_a[1], comp_b[1])
if min_ext > max_ext:
return False
return True
[docs]
def mask_by_extent(
locations: np.ndarray, extent: np.ndarray, inverse: bool = False
) -> np.ndarray:
"""
Find indices of locations within a rectangular extent.
:param locations: shape(*, 3) or shape(*, 2) Coordinates to be evaluated.
:param extent: shape(2, 2) Limits defined by the South-West and
North-East corners. Extents can also be provided as 3D coordinates
with shape(2, 3) defining the top and bottom limits.
:param inverse: Return the complement of the mask extent.
:returns: Array of bool for the locations inside or outside the box extent.
"""
if not isinstance(extent, np.ndarray) or extent.ndim != 2:
raise ValueError("Input 'extent' must be a 2D array-like.")
if not isinstance(locations, np.ndarray) or locations.ndim != 2:
raise ValueError(
"Input 'locations' must be an array-like of shape(*, 3) or (*, 2)."
)
indices = np.ones(locations.shape[0], dtype=bool)
for loc, lim in zip(locations.T, extent.T):
indices &= (lim[0] <= loc) & (loc <= lim[1])
if inverse:
return ~indices
return indices
[docs]
def get_attributes(entity, omit_list=(), attributes=None):
"""Extract the attributes of an object with omissions."""
if attributes is None:
attributes = {}
for key in vars(entity):
if key not in omit_list:
if key[0] == "_":
key = key[1:]
attr = getattr(entity, key)
attributes[key] = attr
return attributes
[docs]
def xy_rotation_matrix(angle: float) -> np.ndarray:
"""
Rotation matrix about the z-axis.
:param angle: Rotation angle in radians.
:return rot: Rotation matrix.
"""
return np.array(
[
[np.cos(angle), -np.sin(angle), 0.0],
[np.sin(angle), np.cos(angle), 0.0],
[0.0, 0.0, 1.0],
]
)
[docs]
def yz_rotation_matrix(angle: float) -> np.ndarray:
"""
Rotation matrix about the x-axis.
:param angle: Rotation angle in radians.
:return: rot: Rotation matrix.
"""
return np.array(
[
[1, 0, 0],
[0, np.cos(angle), -np.sin(angle)],
[0, np.sin(angle), np.cos(angle)],
]
)
[docs]
def dip_points(points: np.ndarray, dip: float, rotation: float = 0) -> np.ndarray:
"""
Rotate points about the x-axis by the dip angle and then about the z-axis by the rotation angle.
:param points: an array of points to rotate
:param dip: the dip angle in radians
:param rotation: the rotation angle in radians
:return: the rotated points
"""
# Assert points is a numpy array containing 3D points
if not isinstance(points, np.ndarray) and points.ndim != 2 and points.shape[1] != 3:
raise TypeError("Input points must be a 2D numpy array of shape (N, 3).")
# rotate the points about the z-axis by the inverse rotation angle
points = xy_rotation_matrix(-rotation) @ points.T
# Rotate points with the dip angle
points = yz_rotation_matrix(dip) @ points
# Rotate back the points to initial orientation
points = xy_rotation_matrix(rotation) @ points
return points.T
[docs]
def map_attributes(object_, **kwargs):
"""
Map attributes to an object. The object must have an '_attribute_map'.
:param object_: The object to map the attributes to.
:param kwargs: The kwargs to map to the object.
"""
if not hasattr(object_, "_attribute_map"):
warnings.warn(f"Object {object_} does not have an attribute map.")
return
for attr, item in kwargs.items():
try:
if attr in getattr(object_, "_attribute_map"):
attr = getattr(object_, "_attribute_map")[attr]
setattr(object_, attr, item)
except AttributeError:
continue
[docs]
def to_tuple(value: Any) -> tuple:
"""
Convert value to a tuple.
:param value: The value to convert.
:return: A tuple
"""
# ensure the names are a tuple
if isinstance(value, tuple):
return value
if isinstance(value, list):
return tuple(value)
return (value,)
[docs]
class SetDict(dict):
def __init__(self, **kwargs):
kwargs = {k: self.make_set(v) for k, v in kwargs.items()}
super().__init__(kwargs)
[docs]
def make_set(self, value):
if isinstance(value, (set, tuple, list)):
value = set(value)
else:
value = {value}
return value
def __setitem__(self, key, value):
value = self.make_set(value)
super().__setitem__(key, value)
[docs]
def update(self, value: dict, **kwargs) -> None: # type: ignore
for key, val in value.items():
val = self.make_set(val)
if key in self:
val = self[key].union(val)
value[key] = val
super().update(value, **kwargs)