# Source code for datatree.datatree

from __future__ import annotations

import copy
import itertools
from collections import OrderedDict
from html import escape
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generic,
    Hashable,
    Iterable,
    Iterator,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Set,
    Tuple,
    Union,
    overload,
)

from xarray.core import utils
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
from xarray.core.indexes import Index, Indexes
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
from xarray.core.utils import (
    Default,
    Frozen,
    HybridMappingProxy,
    _default,
    either_dict_or_kwargs,
    maybe_wrap_array,
)
from xarray.core.variable import Variable

from . import formatting, formatting_html
from .common import TreeAttrAccessMixin
from .mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree
from .ops import (
    DataTreeArithmeticMixin,
    MappedDatasetMethodsMixin,
    MappedDataWithCoords,
)
from .render import RenderTree
from .treenode import NamedNode, NodePath, Tree

try:
    from xarray.core.variable import calculate_dimensions
except ImportError:
    # for xarray versions 2022.03.0 and earlier
    from xarray.core.dataset import calculate_dimensions

if TYPE_CHECKING:
    import pandas as pd
    from xarray.core.merge import CoercibleValue
    from xarray.core.types import ErrorOptions

# """
# DEVELOPERS' NOTE
# ----------------
# The idea of this module is to create a `DataTree` class which inherits the tree structure from TreeNode, and also copies
# the entire API of `xarray.Dataset`, but with certain methods decorated to instead map the dataset function over every
# node in the tree. As this API is copied without directly subclassing `xarray.Dataset` we instead create various Mixin
# classes (in ops.py) which each define part of `xarray.Dataset`'s extensive API.
#
# Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine to inherit unaltered
# (normally because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new
# tree) and some will get overridden by the class definition of DataTree.
# """


T_Path = Union[str, NodePath]


def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset:
    if isinstance(data, DataArray):
        ds = data.to_dataset()
    elif isinstance(data, Dataset):
        ds = data
    elif data is None:
        ds = Dataset()
    else:
        raise TypeError(
            f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}"
        )
    return ds
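# Example sketch of the coercion rules above (assumes `import xarray as xr`,
# and a *named* DataArray, since unnamed ones cannot be promoted to Datasets):
#
#     >>> _coerce_to_dataset(None)                            # empty Dataset
#     >>> _coerce_to_dataset(xr.DataArray([1, 2], name="a"))  # Dataset with variable "a"
#     >>> _coerce_to_dataset([1, 2])                          # raises TypeError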


def _check_for_name_collisions(
    children: Iterable[str], variables: Iterable[Hashable]
) -> None:
    colliding_names = set(children).intersection(set(variables))
    if colliding_names:
        raise KeyError(
            f"Some names would collide between variables and children: {list(colliding_names)}"
        )


class DatasetView(Dataset):
    """
    An immutable Dataset-like view onto the data in a single DataTree node.

    In-place operations modifying this object should raise an AttributeError.
    This requires overriding all inherited constructors.

    Operations returning a new result will return a new xarray.Dataset object.
    This includes all API on Dataset, which will be inherited.
    """

    # TODO what happens if user alters (in-place) a DataArray they extracted from this object?

    __slots__ = (
        "_attrs",
        "_cache",
        "_coord_names",
        "_dims",
        "_encoding",
        "_close",
        "_indexes",
        "_variables",
    )

    def __init__(
        self,
        data_vars: Optional[Mapping[Any, Any]] = None,
        coords: Optional[Mapping[Any, Any]] = None,
        attrs: Optional[Mapping[Any, Any]] = None,
    ):
        raise AttributeError("DatasetView objects are not to be initialized directly")

    @classmethod
    def _from_node(
        cls,
        wrapping_node: DataTree,
    ) -> DatasetView:
        """Constructor, using dataset attributes from wrapping node"""

        obj: DatasetView = object.__new__(cls)
        obj._variables = wrapping_node._variables
        obj._coord_names = wrapping_node._coord_names
        obj._dims = wrapping_node._dims
        obj._indexes = wrapping_node._indexes
        obj._attrs = wrapping_node._attrs
        obj._close = wrapping_node._close
        obj._encoding = wrapping_node._encoding

        return obj

    def __setitem__(self, key, val) -> None:
        raise AttributeError(
            "Mutation of the DatasetView is not allowed, please use `.__setitem__` on the wrapping DataTree node, "
            "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`,"
            "use `.copy()` first to get a mutable version of the input dataset."
        )

    def update(self, other) -> None:
        raise AttributeError(
            "Mutation of the DatasetView is not allowed, please use `.update` on the wrapping DataTree node, "
            "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`,"
            "use `.copy()` first to get a mutable version of the input dataset."
        )
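    # Example sketch of the immutability contract described above (assumes a
    # DataTree node ``node`` whose dataset contains a variable "a"):
    #
    #     >>> node.ds["a"] = node.ds["a"] + 1  # raises AttributeError
    #     >>> ds = node.to_dataset()           # mutable Dataset with the same data
    #     >>> ds["a"] = ds["a"] + 1            # fine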

    # FIXME https://github.com/python/mypy/issues/7328
    @overload
    def __getitem__(self, key: Mapping) -> Dataset:  # type: ignore[misc]
        ...

    @overload
    def __getitem__(self, key: Hashable) -> DataArray:  # type: ignore[misc]
        ...

    @overload
    def __getitem__(self, key: Any) -> Dataset:
        ...

    def __getitem__(self, key) -> DataArray:
        # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes
        # For now just call Dataset.__getitem__
        return Dataset.__getitem__(self, key)

    @classmethod
    def _construct_direct(
        cls,
        variables: dict[Any, Variable],
        coord_names: set[Hashable],
        dims: Optional[dict[Any, int]] = None,
        attrs: Optional[dict] = None,
        indexes: Optional[dict[Any, Index]] = None,
        encoding: Optional[dict] = None,
        close: Optional[Callable[[], None]] = None,
    ) -> Dataset:
        """
        Overriding this method (along with ._replace) and modifying it to return a Dataset object
        should hopefully ensure that the return type of any method on this object is a Dataset.
        """
        if dims is None:
            dims = calculate_dimensions(variables)
        if indexes is None:
            indexes = {}
        obj = object.__new__(Dataset)
        obj._variables = variables
        obj._coord_names = coord_names
        obj._dims = dims
        obj._indexes = indexes
        obj._attrs = attrs
        obj._close = close
        obj._encoding = encoding
        return obj

    def _replace(
        self,
        variables: Optional[dict[Hashable, Variable]] = None,
        coord_names: Optional[set[Hashable]] = None,
        dims: Optional[dict[Any, int]] = None,
        attrs: dict[Hashable, Any] | None | Default = _default,
        indexes: Optional[dict[Hashable, Index]] = None,
        encoding: dict | None | Default = _default,
        inplace: bool = False,
    ) -> Dataset:
        """
        Overriding this method (along with ._construct_direct) and modifying it to return a Dataset object
        should hopefully ensure that the return type of any method on this object is a Dataset.
        """

        if inplace:
            raise AttributeError("In-place mutation of the DatasetView is not allowed")

        return Dataset._replace(
            self,
            variables=variables,
            coord_names=coord_names,
            dims=dims,
            attrs=attrs,
            indexes=indexes,
            encoding=encoding,
            inplace=inplace,
        )

    def map(
        self,
        func: Callable,
        keep_attrs: bool | None = None,
        args: Iterable[Any] = (),
        **kwargs: Any,
    ) -> Dataset:
        """Apply a function to each data variable in this dataset

        Parameters
        ----------
        func : callable
            Function which can be called in the form `func(x, *args, **kwargs)`
            to transform each DataArray `x` in this dataset into another
            DataArray.
        keep_attrs : bool or None, optional
            If True, both the dataset's and variables' attributes (`attrs`) will be
            copied from the original objects to the new ones. If False, the new dataset
            and variables will be returned without copying the attributes.
        args : iterable, optional
            Positional arguments passed on to `func`.
        **kwargs : Any
            Keyword arguments passed on to `func`.

        Returns
        -------
        applied : Dataset
            Resulting dataset from applying ``func`` to each data variable.

        Examples
        --------
        >>> da = xr.DataArray(np.random.randn(2, 3))
        >>> ds = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])})
        >>> ds
        <xarray.Dataset>
        Dimensions:  (dim_0: 2, dim_1: 3, x: 2)
        Dimensions without coordinates: dim_0, dim_1, x
        Data variables:
            foo      (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 -0.9773
            bar      (x) int64 -1 2
        >>> ds.map(np.fabs)
        <xarray.Dataset>
        Dimensions:  (dim_0: 2, dim_1: 3, x: 2)
        Dimensions without coordinates: dim_0, dim_1, x
        Data variables:
            foo      (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773
            bar      (x) float64 1.0 2.0
        """

        # Copied from xarray.Dataset so as not to call type(self), which causes problems (see datatree GH188).
        # TODO Refactor xarray upstream to avoid needing to overwrite this.
        # TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated
        variables = {
            k: maybe_wrap_array(v, func(v, *args, **kwargs))
            for k, v in self.data_vars.items()
        }
        # return type(self)(variables, attrs=attrs)
        return Dataset(variables)


class DataTree(
    NamedNode,
    MappedDatasetMethodsMixin,
    MappedDataWithCoords,
    DataTreeArithmeticMixin,
    TreeAttrAccessMixin,
    Generic[Tree],
    Mapping,
):
    """
    A tree-like hierarchical collection of xarray objects.

    Attempts to present an API like that of xarray.Dataset, but methods are wrapped to also update all the tree's
    child nodes.
    """

    # TODO Some way of sorting children by depth

    # TODO do we need a watch out for if methods intended only for root nodes are called on non-root nodes?

    # TODO dataset methods which should not or cannot act over the whole tree, such as .to_array

    # TODO .loc method

    # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from

    # TODO all groupby classes

    # TODO __slots__

    _name: Optional[str]
    _parent: Optional[DataTree]
    _children: OrderedDict[str, DataTree]
    _attrs: Optional[Dict[Hashable, Any]]
    _cache: Dict[str, Any]
    _coord_names: Set[Hashable]
    _dims: Dict[Hashable, int]
    _encoding: Optional[Dict[Hashable, Any]]
    _close: Optional[Callable[[], None]]
    _indexes: Dict[Hashable, Index]
    _variables: Dict[Hashable, Variable]

    __slots__ = (
        "_name",
        "_parent",
        "_children",
        "_attrs",
        "_cache",
        "_coord_names",
        "_dims",
        "_encoding",
        "_close",
        "_indexes",
        "_variables",
    )
    def __init__(
        self,
        data: Optional[Dataset | DataArray] = None,
        parent: Optional[DataTree] = None,
        children: Optional[Mapping[str, DataTree]] = None,
        name: Optional[str] = None,
    ):
        """
        Create a single node of a DataTree.

        The node may optionally contain data in the form of data and coordinate variables, stored in the same way as
        data is stored in an xarray.Dataset.

        Parameters
        ----------
        data : Dataset, DataArray, or None, optional
            Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets.
            Default is None.
        parent : DataTree, optional
            Parent node to this node. Default is None.
        children : Mapping[str, DataTree], optional
            Any child nodes of this node. Default is None.
        name : str, optional
            Name for this node of the tree. Default is None.

        Returns
        -------
        DataTree

        See Also
        --------
        DataTree.from_dict
        """

        # validate input
        if children is None:
            children = {}
        ds = _coerce_to_dataset(data)
        _check_for_name_collisions(children, ds.variables)

        super().__init__(name=name)

        # set data attributes
        self._replace(
            inplace=True,
            variables=ds._variables,
            coord_names=ds._coord_names,
            dims=ds._dims,
            indexes=ds._indexes,
            attrs=ds._attrs,
            encoding=ds._encoding,
        )
        self._close = ds._close

        # set tree attributes (must happen after variables set to avoid initialization errors)
        self.children = children
        self.parent = parent
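    # Example sketch of constructing nodes directly (assumes `import xarray as xr`;
    # attaching a parent registers the new node in that parent's children):
    #
    #     >>> root = DataTree(name="root", data=xr.Dataset({"a": 0}))
    #     >>> child = DataTree(name="child", parent=root)
    #     >>> root.children["child"] is child
    #     True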
    @property
    def parent(self: DataTree) -> DataTree | None:
        """Parent of this node."""
        return self._parent

    @parent.setter
    def parent(self: DataTree, new_parent: DataTree) -> None:
        if new_parent and self.name is None:
            raise ValueError("Cannot set an unnamed node as a child of another node")
        self._set_parent(new_parent, self.name)

    @property
    def ds(self) -> DatasetView:
        """
        An immutable Dataset-like view onto the data in this node.

        For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead.

        See Also
        --------
        DataTree.to_dataset
        """
        return DatasetView._from_node(self)

    @ds.setter
    def ds(self, data: Optional[Union[Dataset, DataArray]] = None) -> None:
        ds = _coerce_to_dataset(data)

        _check_for_name_collisions(self.children, ds.variables)

        self._replace(
            inplace=True,
            variables=ds._variables,
            coord_names=ds._coord_names,
            dims=ds._dims,
            indexes=ds._indexes,
            attrs=ds._attrs,
            encoding=ds._encoding,
        )
        self._close = ds._close

    def _pre_attach(self: DataTree, parent: DataTree) -> None:
        """
        Method which superclass calls before setting parent, here used to prevent having two
        children with duplicate names (or a data variable with the same name as a child).
        """
        super()._pre_attach(parent)
        if self.name in list(parent.ds.variables):
            raise KeyError(
                f"parent {parent.name} already contains a data variable named {self.name}"
            )
    def to_dataset(self) -> Dataset:
        """
        Return the data in this node as a new xarray.Dataset object.

        See Also
        --------
        DataTree.ds
        """
        return Dataset._construct_direct(
            self._variables,
            self._coord_names,
            self._dims,
            self._attrs,
            self._indexes,
            self._encoding,
            self._close,
        )
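    # Example sketch contrasting ``.ds`` with ``.to_dataset()``:
    #
    #     >>> node.ds            # immutable DatasetView onto this node's data
    #     >>> node.to_dataset()  # independent, mutable xarray.Dataset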
    @property
    def has_data(self):
        """Whether or not there are any data variables in this node."""
        return len(self._variables) > 0

    @property
    def has_attrs(self) -> bool:
        """Whether or not there are any metadata attributes in this node."""
        return len(self.attrs.keys()) > 0

    @property
    def is_empty(self) -> bool:
        """False if node contains any data or attrs. Does not look at children."""
        return not (self.has_data or self.has_attrs)

    @property
    def is_hollow(self) -> bool:
        """True if only leaf nodes contain data."""
        return not any(node.has_data for node in self.subtree if not node.is_leaf)

    @property
    def variables(self) -> Mapping[Hashable, Variable]:
        """Low level interface to node contents as dict of Variable objects.

        This ordered dictionary is frozen to prevent mutation that could violate
        Dataset invariants. It contains all variable objects constituting this
        DataTree node, including both data variables and coordinates.
        """
        return Frozen(self._variables)

    @property
    def attrs(self) -> Dict[Hashable, Any]:
        """Dictionary of global attributes on this node object."""
        if self._attrs is None:
            self._attrs = {}
        return self._attrs

    @attrs.setter
    def attrs(self, value: Mapping[Any, Any]) -> None:
        self._attrs = dict(value)

    @property
    def encoding(self) -> Dict:
        """Dictionary of global encoding attributes on this node object."""
        if self._encoding is None:
            self._encoding = {}
        return self._encoding

    @encoding.setter
    def encoding(self, value: Mapping) -> None:
        self._encoding = dict(value)

    @property
    def dims(self) -> Mapping[Hashable, int]:
        """Mapping from dimension names to lengths.

        Cannot be modified directly, but is updated when adding new variables.

        Note that the type of this object differs from `DataArray.dims`.
        See `DataTree.sizes`, `Dataset.sizes`, and `DataArray.sizes` for consistently named
        properties.
        """
        return Frozen(self._dims)

    @property
    def sizes(self) -> Mapping[Hashable, int]:
        """Mapping from dimension names to lengths.

        Cannot be modified directly, but is updated when adding new variables.

        This is an alias for `DataTree.dims` provided for the benefit of
        consistency with `DataArray.sizes`.

        See Also
        --------
        DataArray.sizes
        """
        return self.dims

    @property
    def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
        """Places to look-up items for attribute-style access"""
        yield from self._item_sources
        yield self.attrs

    @property
    def _item_sources(self) -> Iterable[Mapping[Any, Any]]:
        """Places to look-up items for key-completion"""
        yield self.data_vars
        yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords)

        # virtual coordinates
        yield HybridMappingProxy(keys=self.dims, mapping=self)

        # immediate child nodes
        yield self.children

    def _ipython_key_completions_(self) -> List[str]:
        """Provide method for the key-autocompletions in IPython.
        See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion
        for the details.
        """

        # TODO allow auto-completing relative string paths, e.g. `dt['path/to/../ <tab> node'`
        # Would require changes to ipython's autocompleter, see https://github.com/ipython/ipython/issues/12420
        # Instead for now we only list direct paths to all nodes in the subtree explicitly

        items_on_this_node = self._item_sources
        full_file_like_paths_to_all_nodes_in_subtree = {
            node.path[1:]: node for node in self.subtree
        }

        all_item_sources = itertools.chain(
            items_on_this_node, [full_file_like_paths_to_all_nodes_in_subtree]
        )

        items = {
            item
            for source in all_item_sources
            for item in source
            if isinstance(item, str)
        }
        return list(items)

    def __contains__(self, key: object) -> bool:
        """The 'in' operator will return true or false depending on whether
        'key' is either an array stored in the datatree or a child node, or neither.
        """
        return key in self.variables or key in self.children

    def __bool__(self) -> bool:
        return bool(self.ds.data_vars) or bool(self.children)

    def __iter__(self) -> Iterator[Hashable]:
        return itertools.chain(self.ds.data_vars, self.children)

    def __array__(self, dtype=None):
        raise TypeError(
            "cannot directly convert a DataTree into a "
            "numpy array. Instead, create an xarray.DataArray "
            "first, either with indexing on the DataTree or by "
            "invoking the `to_array()` method."
        )

    def __repr__(self) -> str:
        return formatting.datatree_repr(self)

    def __str__(self) -> str:
        return formatting.datatree_repr(self)

    def _repr_html_(self):
        """Make html representation of datatree object"""
        if XR_OPTS["display_style"] == "text":
            return f"<pre>{escape(repr(self))}</pre>"
        return formatting_html.datatree_repr(self)

    @classmethod
    def _construct_direct(
        cls,
        variables: dict[Any, Variable],
        coord_names: set[Hashable],
        dims: Optional[dict[Any, int]] = None,
        attrs: Optional[dict] = None,
        indexes: Optional[dict[Any, Index]] = None,
        encoding: Optional[dict] = None,
        name: str | None = None,
        parent: DataTree | None = None,
        children: Optional[OrderedDict[str, DataTree]] = None,
        close: Optional[Callable[[], None]] = None,
    ) -> DataTree:
        """Shortcut around __init__ for internal use when we want to skip costly validation."""

        # data attributes
        if dims is None:
            dims = calculate_dimensions(variables)
        if indexes is None:
            indexes = {}
        if children is None:
            children = OrderedDict()

        obj: DataTree = object.__new__(cls)
        obj._variables = variables
        obj._coord_names = coord_names
        obj._dims = dims
        obj._indexes = indexes
        obj._attrs = attrs
        obj._close = close
        obj._encoding = encoding

        # tree attributes
        obj._name = name
        obj._children = children
        obj._parent = parent

        return obj

    def _replace(
        self: DataTree,
        variables: Optional[dict[Hashable, Variable]] = None,
        coord_names: Optional[set[Hashable]] = None,
        dims: Optional[dict[Any, int]] = None,
        attrs: dict[Hashable, Any] | None | Default = _default,
        indexes: Optional[dict[Hashable, Index]] = None,
        encoding: dict | None | Default = _default,
        name: str | None | Default = _default,
        parent: DataTree | None = _default,
        children: Optional[OrderedDict[str, DataTree]] = None,
        inplace: bool = False,
    ) -> DataTree:
        """
        Fastpath constructor for internal use.

        Returns an object with optionally replaced attributes.

        Explicitly passed arguments are *not* copied when placed on the new
        datatree. It is up to the caller to ensure that they have the right type
        and are not used elsewhere.
        """
        # TODO Adding new children inplace using this method will cause bugs.
        # You will end up with an inconsistency between the name of the child node and the key the child is stored under.
        # Use ._set() instead for now
        if inplace:
            if variables is not None:
                self._variables = variables
            if coord_names is not None:
                self._coord_names = coord_names
            if dims is not None:
                self._dims = dims
            if attrs is not _default:
                self._attrs = attrs
            if indexes is not None:
                self._indexes = indexes
            if encoding is not _default:
                self._encoding = encoding
            if name is not _default:
                self._name = name
            if parent is not _default:
                self._parent = parent
            if children is not None:
                self._children = children
            obj = self
        else:
            if variables is None:
                variables = self._variables.copy()
            if coord_names is None:
                coord_names = self._coord_names.copy()
            if dims is None:
                dims = self._dims.copy()
            if attrs is _default:
                attrs = copy.copy(self._attrs)
            if indexes is None:
                indexes = self._indexes.copy()
            if encoding is _default:
                encoding = copy.copy(self._encoding)
            if name is _default:
                name = self._name  # no need to copy str objects or None
            if parent is _default:
                parent = copy.copy(self._parent)
            if children is _default:
                children = copy.copy(self._children)
            obj = self._construct_direct(
                variables,
                coord_names,
                dims,
                attrs,
                indexes,
                encoding,
                name,
                parent,
                children,
            )
        return obj
    def copy(
        self: DataTree,
        deep: bool = False,
    ) -> DataTree:
        """
        Returns a copy of this subtree.

        Copies this node and all child nodes.

        If `deep=True`, a deep copy is made of each of the component variables.
        Otherwise, a shallow copy of each of the component variable is made, so
        that the underlying memory region of the new datatree is the same as in
        the original datatree.

        Parameters
        ----------
        deep : bool, default: False
            Whether each component variable is loaded into memory and copied onto
            the new object. Default is False.

        Returns
        -------
        object : DataTree
            New object with dimensions, attributes, coordinates, name, encoding,
            and data of this node and all child nodes copied from original.

        See Also
        --------
        xarray.Dataset.copy
        pandas.DataFrame.copy
        """
        return self._copy_subtree(deep=deep)
    def _copy_subtree(
        self: DataTree,
        deep: bool = False,
        memo: dict[int, Any] | None = None,
    ) -> DataTree:
        """Copy entire subtree"""
        new_tree = self._copy_node(deep=deep)
        for node in self.descendants:
            path = node.relative_to(self)
            new_tree[path] = node._copy_node(deep=deep)
        return new_tree

    def _copy_node(
        self: DataTree,
        deep: bool = False,
    ) -> DataTree:
        """Copy just one node of a tree"""
        new_node: DataTree = DataTree()
        new_node.name = self.name
        new_node.ds = self.to_dataset().copy(deep=deep)
        return new_node

    def __copy__(self: DataTree) -> DataTree:
        return self._copy_subtree(deep=False)

    def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree:
        return self._copy_subtree(deep=True, memo=memo)
    def get(
        self: DataTree, key: str, default: Optional[DataTree | DataArray] = None
    ) -> Optional[DataTree | DataArray]:
        """
        Access child nodes, variables, or coordinates stored in this node.

        Returned object will be either a DataTree or DataArray object depending on whether the key given points to a
        child or variable.

        Parameters
        ----------
        key : str
            Name of variable / child within this node. Must lie in this immediate node (not elsewhere in the tree).
        default : DataTree | DataArray, optional
            A value to return if the specified key does not exist. Default return value is None.
        """
        if key in self.children:
            return self.children[key]
        elif key in self.ds:
            return self.ds[key]
        else:
            return default
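    # Example sketch (assumes a node ``dt`` with a variable "a" and a child "b"):
    #
    #     >>> dt.get("a")           # the DataArray "a"
    #     >>> dt.get("b")           # the child DataTree "b"
    #     >>> dt.get("missing", 0)  # key absent, so the default (0) is returned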
    def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
        """
        Access child nodes, variables, or coordinates stored anywhere in this tree.

        Returned object will be either a DataTree or DataArray object depending on whether the key given points to a
        child or variable.

        Parameters
        ----------
        key : str
            Name of variable / child within this node, or unix-like path to variable / child within another node.

        Returns
        -------
        Union[DataTree, DataArray]
        """

        # Either:
        if utils.is_dict_like(key):
            # dict-like indexing
            raise NotImplementedError("Should this index over whole tree?")
        elif isinstance(key, str):
            # TODO should possibly deal with hashables in general?
            # path-like: a name of a node/variable, or path to a node/variable
            path = NodePath(key)
            return self._get_item(path)
        elif utils.is_list_like(key):
            # iterable of variable names
            raise NotImplementedError(
                "Selecting via tags is deprecated, and selecting multiple items should be "
                "implemented via .subset"
            )
        else:
            raise ValueError(f"Invalid format for key: {key}")
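    # Example sketch of path-like access (assumes `import xarray as xr`):
    #
    #     >>> dt = DataTree.from_dict({"/child": xr.Dataset({"a": 1})})
    #     >>> dt["child"]       # the child node
    #     >>> dt["child/a"]     # the DataArray "a" stored in the child
    #     >>> dt["child"]["a"]  # equivalent to the line above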
    def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
        """
        Set the child node or variable with the specified key to value.

        Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree.
        """
        if isinstance(val, DataTree):
            # create and assign a shallow copy here so as not to alter original name of node in grafted tree
            new_node = val.copy(deep=False)
            new_node.name = key
            new_node.parent = self
        else:
            if not isinstance(val, (DataArray, Variable)):
                # accommodate other types that can be coerced into Variables
                val = DataArray(val)
            self.update({key: val})
    def __setitem__(
        self,
        key: str,
        value: Any,
    ) -> None:
        """
        Add either a child node or an array to the tree, at any position.

        Data can be added anywhere, and new nodes will be created along the path to the new location if necessary.

        If there is already a node at the given location, then if value is a Node class or Dataset it will overwrite
        the data already present at that node, and if value is a single array, it will be merged with it.
        """
        # TODO xarray.Dataset accepts other possibilities, how do we exactly replicate all the behaviour?
        if utils.is_dict_like(key):
            raise NotImplementedError
        elif isinstance(key, str):
            # TODO should possibly deal with hashables in general?
            # path-like: a name of a node/variable, or path to a node/variable
            path = NodePath(key)
            return self._set_item(path, value, new_nodes_along_path=True)
        else:
            raise ValueError("Invalid format for key")
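    # Example sketch - intermediate nodes are created automatically (assumes
    # `import xarray as xr`):
    #
    #     >>> dt = DataTree(name="root")
    #     >>> dt["a/b/var"] = xr.DataArray([1, 2])  # also creates nodes "a" and "a/b"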
    def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None:
        """
        Update this node's children and / or variables.

        Just like `dict.update` this is an in-place operation.
        """
        # TODO separate by type
        new_children = {}
        new_variables = {}
        for k, v in other.items():
            if isinstance(v, DataTree):
                # avoid named node being stored under inconsistent key
                new_child = v.copy()
                new_child.name = k
                new_children[k] = new_child
            elif isinstance(v, (DataArray, Variable)):
                # TODO this should also accommodate other types that can be coerced into Variables
                new_variables[k] = v
            else:
                raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree")

        vars_merge_result = dataset_update_method(self.to_dataset(), new_variables)
        # TODO are there any subtleties with preserving order of children like this?
        merged_children = OrderedDict({**self.children, **new_children})
        self._replace(
            inplace=True, children=merged_children, **vars_merge_result._asdict()
        )
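    # Example sketch - a single update call can mix variables and child nodes
    # (assumes `import xarray as xr`):
    #
    #     >>> dt.update({"var": xr.DataArray([1, 2]), "child": DataTree()})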
    def assign(
        self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any
    ) -> DataTree:
        """
        Assign new data variables or child nodes to a DataTree, returning a new object with all the original items in
        addition to the new ones.

        Parameters
        ----------
        items : mapping of hashable to Any
            Mapping from variable or child node names to the new values. If the new values
            are callable, they are computed on the Dataset and assigned to new
            data variables. If the values are not callable (e.g. a DataTree, DataArray,
            scalar, or array), they are simply assigned.
        **items_kwargs
            The keyword arguments form of ``items``.
            One of items or items_kwargs must be provided.

        Returns
        -------
        dt : DataTree
            A new DataTree with the new variables or children in addition to all the existing items.

        Notes
        -----
        Since ``kwargs`` is a dictionary, the order of your arguments may not
        be preserved, and so the order of the new variables is not well-defined.
        Assigning multiple items within the same ``assign`` is
        possible, but you cannot reference other variables created within the
        same ``assign`` call.

        See Also
        --------
        xarray.Dataset.assign
        pandas.DataFrame.assign
        """
        items = either_dict_or_kwargs(items, items_kwargs, "assign")
        dt = self.copy()
        dt.update(items)
        return dt
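    # Example sketch - assign returns a new tree and leaves the original intact
    # (assumes `import xarray as xr` and that ``dt`` has no variable "foo"):
    #
    #     >>> dt2 = dt.assign(foo=xr.DataArray([1, 2]))
    #     >>> "foo" in dt2, "foo" in dt
    #     (True, False)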
    def drop_nodes(
        self: DataTree, names: str | Iterable[str], *, errors: ErrorOptions = "raise"
    ) -> DataTree:
        """
        Drop child nodes from this node.

        Parameters
        ----------
        names : str or iterable of str
            Name(s) of nodes to drop.
        errors : {"raise", "ignore"}, default: "raise"
            If 'raise', raises a KeyError if any of the node names
            passed are not present as children of this node. If 'ignore',
            any given names that are present are dropped and no error is raised.

        Returns
        -------
        dropped : DataTree
            A copy of the node with the specified children dropped.
        """
        # the Iterable check is required for mypy
        if isinstance(names, str) or not isinstance(names, Iterable):
            names = {names}
        else:
            names = set(names)

        if errors == "raise":
            extra = names - set(self.children)
            if extra:
                raise KeyError(f"Cannot drop all nodes - nodes {extra} not present")

        children_to_keep = OrderedDict(
            {name: child for name, child in self.children.items() if name not in names}
        )
        return self._replace(children=children_to_keep)
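    # Example sketch:
    #
    #     >>> dt = DataTree.from_dict({"/a": None, "/b": None})
    #     >>> list(dt.drop_nodes("a").children)          # only "b" remains
    #     >>> dt.drop_nodes("missing", errors="ignore")  # no error, nothing dropped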
    @classmethod
    def from_dict(
        cls,
        d: MutableMapping[str, Dataset | DataArray | DataTree | None],
        name: Optional[str] = None,
    ) -> DataTree:
        """
        Create a datatree from a dictionary of data objects, organised by paths into the tree.

        Parameters
        ----------
        d : dict-like
            A mapping from path names to xarray.Dataset, xarray.DataArray, or DataTree objects.

            Path names are to be given as unix-like paths. If path names containing more than one part are given, new
            tree nodes will be constructed as necessary.

            To assign data to the root node of the tree use "/" as the path.
        name : Hashable, optional
            Name for the root node of the tree. Default is None.

        Returns
        -------
        DataTree

        Notes
        -----
        If your dictionary is nested you will need to flatten it before using this method.
        """

        # First create the root node
        root_data = d.pop("/", None)
        obj = cls(name=name, data=root_data, parent=None, children=None)

        if d:
            # Populate tree with children determined from data_objects mapping
            for path, data in d.items():
                # Create and set new node
                node_name = NodePath(path).name
                if isinstance(data, cls):
                    new_node = data.copy()
                    new_node.orphan()
                else:
                    new_node = cls(name=node_name, data=data)
                obj._set_item(
                    path,
                    new_node,
                    allow_overwrite=False,
                    new_nodes_along_path=True,
                )

        return obj
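    # Example sketch (assumes `import xarray as xr`; "/" assigns to the root):
    #
    #     >>> dt = DataTree.from_dict(
    #     ...     {
    #     ...         "/": xr.Dataset({"a": 0}),
    #     ...         "/group1": xr.Dataset({"b": 1}),
    #     ...         "/group1/subgroup": None,  # empty node
    #     ...     },
    #     ...     name="root",
    #     ... )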
    def to_dict(self) -> Dict[str, Dataset]:
        """
        Create a dictionary mapping of absolute node paths to the data contained in those nodes.

        Returns
        -------
        Dict[str, Dataset]
        """
        return {node.path: node.to_dataset() for node in self.subtree}
    @property
    def nbytes(self) -> int:
        return sum(node.to_dataset().nbytes for node in self.subtree)

    def __len__(self) -> int:
        return len(self.children) + len(self.data_vars)

    @property
    def indexes(self) -> Indexes[pd.Index]:
        """Mapping of pandas.Index objects used for label based indexing.

        Raises an error if this DataTree node has indexes that cannot be coerced
        to pandas.Index objects.

        See Also
        --------
        DataTree.xindexes
        """
        return self.xindexes.to_pandas_indexes()

    @property
    def xindexes(self) -> Indexes[Index]:
        """Mapping of xarray Index objects used for label based indexing."""
        return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes})

    @property
    def coords(self) -> DatasetCoordinates:
        """Dictionary of xarray.DataArray objects corresponding to coordinate
        variables
        """
        return DatasetCoordinates(self.to_dataset())

    @property
    def data_vars(self) -> DataVariables:
        """Dictionary of DataArray objects corresponding to data variables"""
        return DataVariables(self.to_dataset())
    def isomorphic(
        self,
        other: DataTree,
        from_root: bool = False,
        strict_names: bool = False,
    ) -> bool:
        """
        Two DataTrees are considered isomorphic if every node has the same number of children.

        Nothing about the data in each node is checked.

        Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation,
        such as ``tree1 + tree2``.

        By default this method does not check any part of the tree above the given node.
        Therefore this method can be used as default to check that two subtrees are isomorphic.

        Parameters
        ----------
        other : DataTree
            The other tree object to compare to.
        from_root : bool, optional, default is False
            Whether or not to first traverse to the root of the two trees before checking for isomorphism.
            If neither tree has a parent then this has no effect.
        strict_names : bool, optional, default is False
            Whether or not to also check that every node in the tree has the same name as its counterpart in the other
            tree.

        See Also
        --------
        DataTree.equals
        DataTree.identical
        """
        try:
            check_isomorphic(
                self,
                other,
                require_names_equal=strict_names,
                check_from_root=from_root,
            )
            return True
        except (TypeError, TreeIsomorphismError):
            return False
    def equals(self, other: DataTree, from_root: bool = True) -> bool:
        """
        Two DataTrees are equal if they have isomorphic node structures, with matching node names,
        and if they have matching variables and coordinates, all of which are equal.

        By default this method will check the whole tree above the given node.

        Parameters
        ----------
        other : DataTree
            The other tree object to compare to.
        from_root : bool, optional, default is True
            Whether or not to first traverse to the root of the two trees before checking for isomorphism.
            If neither tree has a parent then this has no effect.

        See Also
        --------
        Dataset.equals
        DataTree.isomorphic
        DataTree.identical
        """
        if not self.isomorphic(other, from_root=from_root, strict_names=True):
            return False

        return all(
            [
                node.ds.equals(other_node.ds)
                for node, other_node in zip(self.subtree, other.subtree)
            ]
        )
    def identical(self, other: DataTree, from_root=True) -> bool:
        """
        Like equals, but will also check all dataset attributes and the attributes on
        all variables and coordinates.

        By default this method will check the whole tree above the given node.

        Parameters
        ----------
        other : DataTree
            The other tree object to compare to.
        from_root : bool, optional, default is True
            Whether or not to first traverse to the root of the two trees before checking for isomorphism.
            If neither tree has a parent then this has no effect.

        See Also
        --------
        Dataset.identical
        DataTree.isomorphic
        DataTree.equals
        """
        if not self.isomorphic(other, from_root=from_root, strict_names=True):
            return False

        return all(
            node.ds.identical(other_node.ds)
            for node, other_node in zip(self.subtree, other.subtree)
        )
    def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree:
        """
        Filter nodes according to a specified condition.

        Returns a new tree containing only the nodes in the original tree for which `filterfunc(node)` is True.
        Will also contain empty nodes at intermediate positions if required to support leaves.

        Parameters
        ----------
        filterfunc: function
            A function which accepts only one DataTree - the node on which filterfunc will be called.

        Returns
        -------
        DataTree

        See Also
        --------
        match
        pipe
        map_over_subtree
        """
        filtered_nodes = {
            node.path: node.ds for node in self.subtree if filterfunc(node)
        }
        return DataTree.from_dict(filtered_nodes, name=self.root.name)
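    # Example sketch - keep only the nodes whose datasets contain a variable "a":
    #
    #     >>> dt.filter(lambda node: "a" in node.ds)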
    def match(self, pattern: str) -> DataTree:
        """
        Return nodes with paths matching pattern.

        Uses unix glob-like syntax for pattern-matching.

        Parameters
        ----------
        pattern: str
            A pattern to match each node path against.

        Returns
        -------
        DataTree

        See Also
        --------
        filter
        pipe
        map_over_subtree

        Examples
        --------
        >>> dt = DataTree.from_dict(
        ...     {
        ...         "/a/A": None,
        ...         "/a/B": None,
        ...         "/b/A": None,
        ...         "/b/B": None,
        ...     }
        ... )
        >>> dt.match("*/B")
        DataTree('None', parent=None)
        ├── DataTree('a')
        │   └── DataTree('B')
        └── DataTree('b')
            └── DataTree('B')
        """
        matching_nodes = {
            node.path: node.ds
            for node in self.subtree
            if NodePath(node.path).match(pattern)
        }
        return DataTree.from_dict(matching_nodes, name=self.root.name)
    def map_over_subtree(
        self,
        func: Callable,
        *args: Iterable[Any],
        **kwargs: Any,
    ) -> DataTree | Tuple[DataTree]:
        """
        Apply a function to every dataset in this subtree, returning a new tree which stores the results.

        The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
        descendant nodes. The returned tree will have the same structure as the original subtree.

        func needs to return a Dataset in order to rebuild the subtree.

        Parameters
        ----------
        func : callable
            Function to apply to datasets with signature:
            `func(node.ds, *args, **kwargs) -> Dataset`.

            Function will not be applied to any nodes without datasets.
        *args : tuple, optional
            Positional arguments passed on to `func`.
        **kwargs : Any
            Keyword arguments passed on to `func`.

        Returns
        -------
        subtrees : DataTree, Tuple of DataTrees
            One or more subtrees containing results from applying ``func`` to the data at each node.
        """
        # TODO this signature means that func has no way to know which node it is being called upon - change?

        # TODO fix this typing error
        return map_over_subtree(func)(self, *args, **kwargs)  # type: ignore[operator]
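    # Example sketch - apply a Dataset -> Dataset function to every node (the
    # second line assumes each dataset in the tree has a "time" dimension):
    #
    #     >>> dt.map_over_subtree(lambda ds: ds * 2)
    #     >>> dt.map_over_subtree(lambda ds: ds.mean(dim="time"))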
    def map_over_subtree_inplace(
        self,
        func: Callable,
        *args: Iterable[Any],
        **kwargs: Any,
    ) -> None:
        """
        Apply a function to every dataset in this subtree, updating data in place.

        Parameters
        ----------
        func : callable
            Function to apply to datasets with signature:
            `func(node.ds, *args, **kwargs) -> Dataset`.

            Function will not be applied to any nodes without datasets.
        *args : tuple, optional
            Positional arguments passed on to `func`.
        **kwargs : Any
            Keyword arguments passed on to `func`.
        """

        # TODO if func fails on some node then the previous nodes will still have been updated...

        for node in self.subtree:
            if node.has_data:
                node.ds = func(node.ds, *args, **kwargs)
    def pipe(
        self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
    ) -> Any:
        """Apply ``func(self, *args, **kwargs)``

        This method replicates the pandas method of the same name.

        Parameters
        ----------
        func : callable
            function to apply to this xarray object (Dataset/DataArray).
            ``args``, and ``kwargs`` are passed into ``func``.
            Alternatively a ``(callable, data_keyword)`` tuple where
            ``data_keyword`` is a string indicating the keyword of
            ``callable`` that expects the xarray object.
        *args
            positional arguments passed into ``func``.
        **kwargs
            a dictionary of keyword arguments passed into ``func``.

        Returns
        -------
        object : Any
            the return type of ``func``.

        Notes
        -----
        Use ``.pipe`` when chaining together functions that expect
        xarray or pandas objects, e.g., instead of writing

        .. code:: python

            f(g(h(dt), arg1=a), arg2=b, arg3=c)

        You can write

        .. code:: python

            (dt.pipe(h).pipe(g, arg1=a).pipe(f, arg2=b, arg3=c))

        If you have a function that takes the data as (say) the second
        argument, pass a tuple indicating which keyword expects the
        data. For example, suppose ``f`` takes its data as ``arg2``:

        .. code:: python

            (dt.pipe(h).pipe(g, arg1=a).pipe((f, "arg2"), arg1=a, arg3=c))

        """
        if isinstance(func, tuple):
            func, target = func
            if target in kwargs:
                raise ValueError(
                    f"{target} is both the pipe target and a keyword argument"
                )
            kwargs[target] = self
        else:
            args = (self,) + args
        return func(*args, **kwargs)
    def render(self):
        """Print tree structure, including any data stored at each node."""
        for pre, fill, node in RenderTree(self):
            # print each node's own name, not the root's
            print(f"{pre}DataTree('{node.name}')")
            # iterate over the lines of the repr (skipping the header line), not its characters
            for ds_line in repr(node.ds).splitlines()[1:]:
                print(f"{fill}{ds_line}")
    def merge(self, datatree: DataTree) -> DataTree:
        """Merge all the leaves of a second DataTree into this one."""
        raise NotImplementedError
    def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree:
        """Merge a set of child nodes into a single new node."""
        raise NotImplementedError

    # TODO some kind of .collapse() or .flatten() method to merge a subtree

    def as_array(self) -> DataArray:
        # Dataset has no `as_dataarray` method; `to_array` is the Dataset API for this conversion
        return self.ds.to_array()

    @property
    def groups(self):
        """Return all netCDF4 groups in the tree, given as a tuple of path-like strings."""
        return tuple(node.path for node in self.subtree)
    def to_netcdf(
        self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs
    ):
        """
        Write datatree contents to a netCDF file.

        Parameters
        ----------
        filepath : str or Path
            Path to which to save this datatree.
        mode : {"w", "a"}, default: "w"
            Write ('w') or append ('a') mode. If mode='w', any existing file at
            this location will be overwritten. If mode='a', existing variables
            will be overwritten. Only applies to the root group.
        encoding : dict, optional
            Nested dictionary with variable names as keys and dictionaries of
            variable specific encodings as values, e.g.,
            ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1,
            "zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for available
            options.
        unlimited_dims : dict, optional
            Mapping of unlimited dimensions per group that should be serialized as unlimited dimensions.
            By default, no dimensions are treated as unlimited dimensions.
            Note that unlimited_dims may also be set via
            ``dataset.encoding["unlimited_dims"]``.
        kwargs :
            Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
        """
        from .io import _datatree_to_netcdf

        _datatree_to_netcdf(
            self,
            filepath,
            mode=mode,
            encoding=encoding,
            unlimited_dims=unlimited_dims,
            **kwargs,
        )
    def to_zarr(
        self,
        store,
        mode: str = "w-",
        encoding=None,
        consolidated: bool = True,
        **kwargs,
    ):
        """
        Write datatree contents to a Zarr store.

        Parameters
        ----------
        store : MutableMapping, str or Path, optional
            Store or path to directory in file system
        mode : {"w", "w-", "a", "r+", None}, default: "w-"
            Persistence mode: "w" means create (overwrite if exists);
            "w-" means create (fail if exists); "a" means override existing
            variables (create if does not exist); "r+" means modify existing
            array values only (raise an error if any metadata or shapes would change).
            The default mode is "a" if ``append_dim`` is set. Otherwise, it is "r+"
            if ``region`` is set and "w-" otherwise.
        encoding : dict, optional
            Nested dictionary with variable names as keys and dictionaries of
            variable specific encodings as values, e.g.,
            ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1}, ...}, ...}``.
            See ``xarray.Dataset.to_zarr`` for available options.
        consolidated : bool
            If True, apply zarr's `consolidate_metadata` function to the store
            after writing metadata for all groups.
        kwargs :
            Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr``
        """
        from .io import _datatree_to_zarr

        _datatree_to_zarr(
            self,
            store,
            mode=mode,
            encoding=encoding,
            consolidated=consolidated,
            **kwargs,
        )
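    # Example sketch of a write/read round trip (``open_datatree`` is exported
    # by this package's top-level namespace):
    #
    #     >>> dt.to_zarr("tree.zarr")
    #     >>> from datatree import open_datatree
    #     >>> roundtripped = open_datatree("tree.zarr", engine="zarr")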
    def plot(self):
        raise NotImplementedError