Source code for ndpyramid.reproject

from __future__ import annotations  # noqa: F401

from collections import defaultdict
from collections.abc import Sequence

import numpy as np
import shapely.errors
import xarray as xr
from odc.geo import CRS as OdcCRS
from odc.geo.geobox import GeoBox
from odc.geo.xr import xr_reproject

from .common import Projection, ProjectionOptions
from .utils import add_metadata_and_zarr_encoding, get_levels, get_version, multiscales_template


def _da_reproject(da: xr.DataArray, *, geobox: GeoBox, resampling: str):
    """Warp a single DataArray onto a target GeoBox.

    Parameters
    ----------
    da : xarray.DataArray
        Source array. It is not modified: the fill-value adjustment below is made
        on a shallow copy with a freshly built encoding dict.
    geobox : odc.geo.geobox.GeoBox
        Destination grid. It is supplied by the caller, so no CRS/GeoBox is
        constructed inside this function.
    resampling : str
        Resampling method name, forwarded unchanged to ``xr_reproject``.

    Returns
    -------
    The value returned by ``xr_reproject`` for ``da`` on ``geobox``.

    Raises
    ------
    RuntimeError
        When shapely raises a ``GEOSException`` during the reprojection; the
        original exception is chained as the cause. See
        https://github.com/opendatacube/odc-geo/issues/147.
    """
    # Shallow copy so the encoding can be replaced without touching the caller's object.
    source = da.copy(deep=False)

    # Floating-point data with no explicit fill value gets NaN, so the warp
    # (rasterio/odc-geo) has a sensible fill value to work with.
    has_fill_value = source.encoding.get("_FillValue") is not None
    if not has_fill_value and np.issubdtype(source.dtype, np.floating):
        source.encoding = {**source.encoding, "_FillValue": np.nan}

    try:
        return xr_reproject(source, geobox, resampling=resampling)
    except shapely.errors.GEOSException as geos_error:
        # Translate shapely's "GEOSException: TopologyException" into an actionable
        # message for users hitting https://github.com/opendatacube/odc-geo/issues/147
        message = (
            "Error during reprojection. This can be caused by invalid geometries in the input data. "
            "Try cleaning the geometries or using a different resampling method. If the input data contains dask-arrays, "
            "consider using .compute() to convert them to in-memory arrays before reprojection. "
            "See https://github.com/opendatacube/odc-geo/issues/147 for more details."
        )
        raise RuntimeError(message) from geos_error


def level_reproject(
    ds: xr.Dataset,
    *,
    projection: ProjectionOptions = "web-mercator",
    level: int,
    pixels_per_tile: int = 128,
    resampling: str | dict = "average",
    extra_dim: str | None = None,
    clear_attrs: bool = False,
) -> xr.Dataset:
    """Build one level of a multiscale pyramid by reprojecting a dataset.

    Every data variable of ``ds`` is reprojected onto a square grid of
    ``2**level * pixels_per_tile`` pixels per side in the requested projection.

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to reproject. It must carry a ``spatial_ref`` coordinate.
    projection : str, optional
        Name of the target projection. Default is 'web-mercator'.
    level : int
        The pyramid level (zoom) to create.
    pixels_per_tile : int, optional
        Number of pixels per tile. Default is 128.
    resampling : str or dict
        Resampling method. A string is used for every variable. A dict is indexed
        by variable name and its values are odc-geo supported methods.
    extra_dim : str, optional
        Deprecated/ignored. It is only recorded in the level's metadata; extra
        dimensions are handled natively by odc-geo/xarray broadcasting.
    clear_attrs : bool, False
        Clear the attributes of each reprojected DataArray. Default is False.

    Returns
    -------
    xr.Dataset
        The pyramid level. It starts from ``ds.attrs`` and additionally carries a
        ``multiscales`` attribute describing how the level was produced.

    Raises
    ------
    ValueError
        If ``ds`` has no ``spatial_ref`` coordinate.
    KeyError
        If ``resampling`` is a plain dict that lacks an entry for one of the
        data variables.

    Warning
    -------
    Pyramid generation by level is experimental and subject to change.

    """
    # A CRS on the source is mandatory for reprojection; fail early with guidance.
    if "spatial_ref" not in ds.coords:
        raise ValueError(
            "Source Dataset has no 'spatial_ref' coordinate. Please assign a CRS to the dataset before reprojection. You can use the 'assign_crs' function from odc.geo.xr."
        )

    # Target grid: built once here and shared by every variable of this level.
    target_projection = Projection(name=projection)
    side = 2**level * pixels_per_tile
    target_geobox = GeoBox(
        (side, side),
        target_projection.transform(dim=side),
        OdcCRS(target_projection._crs),
    )

    # Provenance metadata, including the arguments this level was generated with.
    multiscales = multiscales_template(
        datasets=[{"path": ".", "level": level, "crs": target_projection._crs}],
        type="reduce",
        method="pyramid_reproject",
        version=get_version(),
        kwargs={
            "level": level,
            "pixels_per_tile": pixels_per_tile,
            "projection": projection,
            "resampling": resampling,
            "extra_dim": extra_dim,
            "clear_attrs": clear_attrs,
        },
    )

    # A single method name applies to every variable; a mapping is used as given.
    methods: dict = (
        defaultdict(lambda: resampling) if isinstance(resampling, str) else resampling
    )

    # Reproject variable by variable (extra dims are broadcast; no Python loop over them).
    level_ds = xr.Dataset(attrs=ds.attrs)
    for name, source in ds.items():
        warped = _da_reproject(source, geobox=target_geobox, resampling=methods[name])
        if clear_attrs:
            warped.attrs.clear()
        level_ds[name] = warped

    level_ds.attrs["multiscales"] = multiscales
    return level_ds


def pyramid_reproject(
    ds: xr.Dataset,
    *,
    projection: ProjectionOptions = "web-mercator",
    levels: int | None = None,
    level_list: Sequence[int] | None = None,
    pixels_per_tile: int = 128,
    other_chunks: dict | None = None,
    resampling: str | dict = "average",
    extra_dim: str | None = None,
    clear_attrs: bool = False,
) -> xr.DataTree:
    """Create a multiscale pyramid of a dataset via reprojection.

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to create a multiscale pyramid of.
    projection : str, optional
        The projection to use. Default is 'web-mercator'.
    levels : int, optional
        The number of (contiguous) levels to create starting at 0. Mutually
        exclusive with ``level_list``. If both ``levels`` and ``level_list`` are
        ``None`` then an attempt is made to infer the number of levels via
        ``get_levels`` (currently not implemented).
    level_list : Sequence[int], optional
        Explicit list of zoom levels to generate (e.g. ``[4]`` to only build Z4,
        or ``[2,4,6]`` for a sparse pyramid). Mutually exclusive with ``levels``.
    pixels_per_tile : int, optional
        Number of pixels per tile, by default 128
    other_chunks : dict
        Chunks for non-spatial dims to pass to :py:meth:`~xr.Dataset.chunk`.
        Default is None
    resampling : str or dict, optional
        Resampling method to use. Default is 'average'. If a dict, keys are
        variable names and values are resampling methods.
    extra_dim : str, optional
        Deprecated/ignored. Extra dimensions are handled natively by
        odc-geo/xarray broadcasting.
    clear_attrs : bool, False
        Clear the attributes of the DataArrays within the multiscale pyramid.
        Default is False.

    Returns
    -------
    xr.DataTree
        The multiscale pyramid.

    Raises
    ------
    ValueError
        If both ``levels`` and ``level_list`` are given.

    """
    if levels is not None and level_list is not None:
        raise ValueError("Specify only one of 'levels' or 'level_list'.")

    if level_list is not None:
        # sanitize and sort unique levels
        level_indices = sorted({int(idx) for idx in level_list})
    else:
        # A falsy ``levels`` (None, and also 0) falls back to inferring the count.
        if not levels:
            levels = get_levels(ds)
        level_indices = list(range(int(levels)))

    projection_model = Projection(name=projection)

    save_kwargs = {
        # store the explicit list for reproducibility
        "levels": level_indices,
        "pixels_per_tile": pixels_per_tile,
        "projection": projection,
        "other_chunks": other_chunks,
        "resampling": resampling,
        "extra_dim": extra_dim,
        "clear_attrs": clear_attrs,
    }
    attrs = {
        "multiscales": multiscales_template(
            datasets=[
                {"path": str(i), "level": i, "crs": projection_model._crs}
                for i in level_indices
            ],
            type="reduce",
            method="pyramid_reproject",
            version=get_version(),
            kwargs=save_kwargs,
        )
    }

    # One reprojected dataset per requested level, keyed by the level number as a string.
    plevels = {
        str(level): level_reproject(
            ds,
            projection=projection,
            level=level,
            pixels_per_tile=pixels_per_tile,
            resampling=resampling,
            extra_dim=extra_dim,
            clear_attrs=clear_attrs,
        )
        for level in level_indices
    }

    # create the final multiscale pyramid; the root node only carries the metadata
    plevels["/"] = xr.Dataset(attrs=attrs)
    pyramid = xr.DataTree.from_dict(plevels)

    pyramid = add_metadata_and_zarr_encoding(
        pyramid,
        levels=level_indices,
        pixels_per_tile=pixels_per_tile,
        other_chunks=other_chunks,
        projection=projection_model,
    )
    return pyramid