from __future__ import annotations
import functools
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Tuple
from xarray import DataArray, Dataset
from .iterators import LevelOrderIter
from .treenode import NodePath, TreeNode
if TYPE_CHECKING:
from .datatree import DataTree
class TreeIsomorphismError(ValueError):
    """Raised when two tree objects turn out not to share the same node structure.

    Subclasses ``ValueError``, so code catching ``ValueError`` also catches this error.
    """
def check_isomorphic(
    a: DataTree,
    b: DataTree,
    require_names_equal: bool = False,
    check_from_root: bool = True,
):
    """
    Check that two trees have the same structure, raising an error if not.

    Does not compare the actual data in the nodes.

    By default (``check_from_root=True``) both arguments are first replaced by the roots of their trees, so the
    entire trees are compared, including any nodes above `a` and `b`. Pass ``check_from_root=False`` to compare
    only the subtrees rooted at `a` and `b`.

    Can optionally check if corresponding nodes should have the same name.

    Parameters
    ----------
    a : DataTree
    b : DataTree
    require_names_equal : bool
        Whether or not to also check that each node has the same name as its counterpart.
    check_from_root : bool
        Whether or not to first traverse to the root of the trees before checking for isomorphism.
        If a & b have no parents then this has no effect.

    Raises
    ------
    TypeError
        If either a or b are not tree objects.
    TreeIsomorphismError
        If a and b are tree objects, but are not isomorphic to one another.
        Also optionally raised if their structure is isomorphic, but the names of any two
        respective nodes are not equal.
    """
    if not isinstance(a, TreeNode):
        raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}")
    if not isinstance(b, TreeNode):
        raise TypeError(f"Argument `b` is not a tree, it is of type {type(b)}")

    if check_from_root:
        # Compare the whole trees, not just the subtrees hanging off a and b.
        a = a.root
        b = b.root

    diff = diff_treestructure(a, b, require_names_equal=require_names_equal)

    if diff:
        raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff)
def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
    """
    Describe the first structural mismatch between two trees.

    Corresponding nodes are paired by walking both trees breadth-first (level order). This treats
    them as ordered trees, so children are matched by position rather than by name.

    Parameters
    ----------
    a, b : DataTree
        Trees to compare; they are called the "left" and "right" object in the returned text.
    require_names_equal : bool
        If True, paired nodes must also carry the same name.

    Returns
    -------
    str
        A two-line explanation of the first mismatch found, or ``""`` if the trees are isomorphic.
    """
    for left, right in zip(LevelOrderIter(a), LevelOrderIter(b)):
        left_path = left.path
        right_path = right.path

        if require_names_equal and left.name != right.name:
            return dedent(
                f"""\
                Node '{left_path}' in the left object has name '{left.name}'
                Node '{right_path}' in the right object has name '{right.name}'"""
            )

        n_left, n_right = len(left.children), len(right.children)
        if n_left != n_right:
            return dedent(
                f"""\
                Number of children on node '{left_path}' of the left object: {n_left}
                Number of children on node '{right_path}' of the right object: {n_right}"""
            )

    # zip() stops at the shorter walk. That is safe: the two walks can only differ in length
    # after some pair of nodes disagrees on its child count, which is reported above.
    return ""
def map_over_subtree(func: Callable) -> Callable:
    """
    Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.

    Applies a function to every dataset in one or more subtrees, returning new trees which store the results.

    The function will be applied to any data-containing dataset stored in any of the nodes in the trees. Nodes
    without data are never passed to `func`. A node that only carries attributes has its dataset copied unchanged
    into every output tree. A completely empty node becomes an empty node in every output tree. The returned trees
    will have the same structure as the supplied trees.

    `func` needs to return a Dataset, a DataArray, None, or a tuple of Datasets/DataArrays, so that the subtrees
    can be rebuilt after mapping. The per-node results are assembled into new trees via `DataTree.from_dict`.
    If `func` returns a tuple, each element of the tuple is stacked into a separate tree, and all of those trees
    are returned together as a tuple. `func` should return the same number of values for every node.

    The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
    similarly, but all the output trees will have nodes named in the same way as the first tree passed.

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:

        `func(*args, **kwargs) -> Union[Dataset, DataArray, Tuple[Union[Dataset, DataArray], ...]]`.

        (i.e. func must accept at least one Dataset and return at least one Dataset or DataArray.)
        Function will not be applied to any nodes without datasets.

    Returns
    -------
    mapped : callable
        Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset
        at each node. It accepts:

        *args : tuple, optional
            Positional arguments passed on to `func`. For DataTree arguments, each data-containing node is
            converted to a Dataset via ``.ds``.
        **kwargs : Any
            Keyword arguments passed on to `func`. For DataTree arguments, each data-containing node is
            converted to a Dataset via ``.ds``.

        It raises ``TypeError`` if no DataTree is passed, or if the values returned by ``func`` are invalid or
        inconsistent between nodes. It raises ``TreeIsomorphismError`` if the passed trees are not isomorphic.

    See also
    --------
    DataTree.map_over_subtree
    DataTree.map_over_subtree_inplace
    DataTree.subtree
    """

    # TODO examples in the docstring
    # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        """Internal function which maps func over every node in tree, returning a tree of the results."""
        from .datatree import DataTree

        all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
            a for a in kwargs.values() if isinstance(a, DataTree)
        ]
        if not all_tree_inputs:
            raise TypeError("Must pass at least one tree object")
        first_tree, *other_trees = all_tree_inputs

        for other_tree in other_trees:
            # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
            check_isomorphic(
                first_tree, other_tree, require_names_equal=False, check_from_root=False
            )

        # Walk all trees simultaneously, applying func to all nodes that lie in the same position in different trees.
        # We don't know which arguments are DataTrees, so every argument is turned into an iterable of tree length
        # (non-tree arguments are simply repeated) and all of them are zipped together.
        # Results are stored per node path because we don't yet know how many trees we need to rebuild to return.
        out_data_objects = {}
        args_as_tree_length_iterables = [
            a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
        ]
        n_args = len(args_as_tree_length_iterables)
        kwarg_names = list(kwargs)
        kwargs_as_tree_length_iterables = [
            v.subtree if isinstance(v, DataTree) else repeat(v) for v in kwargs.values()
        ]

        for node_of_first_tree, *all_node_args in zip(
            first_tree.subtree,
            *args_as_tree_length_iterables,
            *kwargs_as_tree_length_iterables,
        ):
            if node_of_first_tree.has_data:
                # Swap tree nodes for their datasets; non-tree arguments pass through unchanged.
                node_args_as_datasetviews = [
                    a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
                ]
                node_kwargs_as_datasetviews = {
                    k: v.ds if isinstance(v, DataTree) else v
                    for k, v in zip(kwarg_names, all_node_args[n_args:])
                }
                # Only needed when func actually runs; attaches the node path to any error raised.
                func_with_error_context = _handle_errors_with_path_context(
                    node_of_first_tree.path
                )(func)
                # call func on the data in this particular set of corresponding nodes
                results = func_with_error_context(
                    *node_args_as_datasetviews, **node_kwargs_as_datasetviews
                )
            elif node_of_first_tree.has_attrs:
                # propagate attrs
                results = node_of_first_tree.ds
            else:
                # nothing to propagate so use fastpath to create empty node in new tree
                results = None

            # TODO implement mapping over multiple trees in-place using if conditions from here on?
            out_data_objects[node_of_first_tree.path] = results

        # Find out how many return values we received
        num_return_values = _check_all_return_values(out_data_objects)

        # Paths in the output trees are relative to the first tree, which discards parentage so that new trees
        # don't include parents of input nodes. These paths are the same for every output tree, so compute them once.
        original_root_path = first_tree.path

        def _relative(path):
            relative_path = str(NodePath(path).relative_to(original_root_path))
            return "/" if relative_path == "." else relative_path

        node_paths = [(n.path, _relative(n.path)) for n in first_tree.subtree]

        # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
        result_trees = []
        for i in range(num_return_values):
            out_tree_contents = {}
            for p, relative_path in node_paths:
                node_result = out_data_objects.get(p)
                # A non-tuple result (single Dataset/DataArray, attrs-only Dataset or None) goes into every output tree
                out_tree_contents[relative_path] = (
                    node_result[i] if isinstance(node_result, tuple) else node_result
                )
            result_trees.append(
                DataTree.from_dict(out_tree_contents, name=first_tree.name)
            )

        # If only one result then don't wrap it in a tuple
        if len(result_trees) == 1:
            return result_trees[0]
        return tuple(result_trees)

    return _map_over_subtree
def _handle_errors_with_path_context(path):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""
def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if sys.version_info >= (3, 11):
# Add the context information to the error message
e.add_note(
f"Raised whilst mapping function over node with path {path}"
)
raise
return wrapper
return decorator
def add_note(err: BaseException, msg: str) -> None:
    """Attach ``msg`` to ``err`` as a PEP 678 exception note, on any supported Python version."""
    # TODO: remove once python 3.10 can be dropped
    if sys.version_info >= (3, 11):
        err.add_note(msg)
        return

    # Older interpreters lack BaseException.add_note: emulate it by rebuilding __notes__.
    existing_notes = getattr(err, "__notes__", [])
    err.__notes__ = existing_notes + [msg]  # type: ignore[attr-defined]
def _check_single_set_return_values(path_to_node, obj):
    """
    Validate one node's result from func and report how many return values it holds.

    Parameters
    ----------
    path_to_node
        Path of the node the result came from; only used in error messages.
    obj
        Value returned by func for that node.

    Returns
    -------
    int
        1 for a lone Dataset or DataArray, otherwise the length of the tuple.

    Raises
    ------
    TypeError
        If obj is neither a Dataset/DataArray nor a tuple made up solely of them.
    """
    allowed = (Dataset, DataArray)

    if isinstance(obj, allowed):
        return 1

    if not isinstance(obj, tuple):
        raise TypeError(
            f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
            f"Dataset or DataArray, nor a tuple of such types."
        )

    for element in obj:
        if not isinstance(element, allowed):
            raise TypeError(
                f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
                f"of type {type(element)}, not Dataset or DataArray."
            )
    return len(obj)
def _check_all_return_values(returned_objects):
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""
if all(r is None for r in returned_objects.values()):
raise TypeError(
"Called supplied function on all nodes but found a return value of None for"
"all of them."
)
result_data_objects = [
(path_to_node, r)
for path_to_node, r in returned_objects.items()
if r is not None
]
if len(result_data_objects) == 1:
# Only one node in the tree: no need to check consistency of results between nodes
path_to_node, result = result_data_objects[0]
num_return_values = _check_single_set_return_values(path_to_node, result)
else:
prev_path, _ = result_data_objects[0]
prev_num_return_values, num_return_values = None, None
for path_to_node, obj in result_data_objects[1:]:
num_return_values = _check_single_set_return_values(path_to_node, obj)
if (
num_return_values != prev_num_return_values
and prev_num_return_values is not None
):
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
f"values, whereas calling func on the nodes at position {prev_path} instead returns "
f"{prev_num_return_values} separate return values."
)
prev_path, prev_num_return_values = path_to_node, num_return_values
return num_return_values