# Source code for ndpyramid.resample

from __future__ import annotations  # noqa: F401

import typing
import warnings
from collections import defaultdict
from collections.abc import Sequence

import numpy as np
import xarray as xr
from odc.geo.xr import assign_crs
from pyproj.crs import CRS

from .common import Projection, ProjectionOptions
from .utils import add_metadata_and_zarr_encoding, get_levels, get_version, multiscales_template

ResamplingOptions = typing.Literal["bilinear", "nearest"]


def _da_resample(
    da: xr.DataArray,
    *,
    dim: int,
    projection_model: Projection,
    pixels_per_tile: int,
    other_chunk: int,
    resampling: ResamplingOptions,
):
    """Resample one DataArray onto the square target grid of ``projection_model``.

    Uses pyresample's block-wise gradient-search resampler (``resample_blocks``)
    on ``da.data``. The ImportError message below names dask as a requirement,
    so ``da.data`` is presumably expected to be a dask array — TODO confirm.

    Parameters
    ----------
    da : xarray.DataArray
        Data to resample. Its spatial dimensions/coordinates must be named
        ``x`` and ``y`` (they are accessed as ``da.x``/``da.y`` and
        ``da.sizes["x"]``/``da.sizes["y"]``). 3-D input is assumed to be ordered
        ``(other, y, x)``.
    dim : int
        Width and height, in pixels, of the target grid (``shape=(dim, dim)``).
    projection_model : Projection
        Supplies the target ``name``, ``_crs`` and ``_area_extent``.
    pixels_per_tile : int
        Chunk size of the output along ``y`` and ``x``.
    other_chunk : int
        Chunk size of the output along the leading (non-spatial) dimension.
    resampling : {"bilinear", "nearest"}
        Block interpolator to use.

    Returns
    -------
    xarray.DataArray
        Resampled data with target-grid ``x``/``y`` coordinates. The ``crs``
        coordinate is dropped and ``attrs`` are emptied.

    Raises
    ------
    ImportError
        If the pyresample helpers cannot be imported.
    ValueError
        If ``resampling`` is not recognised, or if the source CRS cannot be
        determined in the coordinate-based fallback path.

    Notes
    -----
    Side effect: for floating-point data without a ``_FillValue`` encoding,
    ``da.encoding["_FillValue"]`` is set to NaN on the *input* array.
    """
    try:
        from pyresample.area_config import create_area_def
        from pyresample.future.resamplers.resampler import (
            add_crs_xy_coords,
            update_resampled_coords,
        )
        from pyresample.gradient import (
            block_bilinear_interpolator,
            block_nn_interpolator,
            gradient_resampler_indices_block,
        )
        from pyresample.resampler import resample_blocks
        from pyresample.utils.cf import load_cf_area
    except ImportError as e:
        raise ImportError(
            "The use of pyramid_resample requires the packages pyresample and dask"
        ) from e
    # NOTE: this mutates the caller's DataArray encoding (see Notes above).
    if da.encoding.get("_FillValue") is None and np.issubdtype(da.dtype, np.floating):
        da.encoding["_FillValue"] = np.nan
    if resampling == "bilinear":
        fun = block_bilinear_interpolator
    elif resampling == "nearest":
        fun = block_nn_interpolator
    else:
        raise ValueError(f"Unrecognized interpolation method {resampling} for gradient resampling.")
    target_area_def = create_area_def(
        area_id=projection_model.name,
        projection=projection_model._crs,
        shape=(dim, dim),
        area_extent=projection_model._area_extent,
    )
    try:
        source_area_def = load_cf_area(da.to_dataset(name="var"), variable="var")[0]
    except ValueError as e:
        warnings.warn(
            f"Automatic determination of source AreaDefinition from CF conventions failed with {e}."
            " Falling back to AreaDefinition creation from coordinates.",
            stacklevel=2,
        )
        # Coordinates are cell centres: extend half a cell beyond the first and
        # last centres to get the outer edges. ``uy`` is the edge beyond row 0 and
        # ``ly`` the edge beyond the last row (top/bottom for north-up, i.e.
        # descending-y, data).
        lx = da.x[0] - (da.x[1] - da.x[0]) / 2
        rx = da.x[-1] + (da.x[-1] - da.x[-2]) / 2
        uy = da.y[0] - (da.y[1] - da.y[0]) / 2
        ly = da.y[-1] + (da.y[-1] - da.y[-2]) / 2
        # Retrieve CRS from odc-geo accessor (assigned via assign_crs in fixtures/pipeline)
        try:
            odc_crs = da.odc.crs  # odc.geo.CRS instance
            source_crs = CRS.from_string(str(odc_crs))
        except Exception as e:  # pragma: no cover - fallback path
            raise ValueError("Unable to determine source CRS for resampling") from e
        source_area_def = create_area_def(
            area_id=2,
            projection=source_crs,
            shape=(da.sizes["y"], da.sizes["x"]),
            area_extent=(lx.values, ly.values, rx.values, uy.values),
        )
    chunk_size = (other_chunk, pixels_per_tile, pixels_per_tile)
    # First pass: compute source pixel indices for every target pixel.
    indices_xy = resample_blocks(
        gradient_resampler_indices_block,
        source_area_def,
        [],
        target_area_def,
        chunk_size=chunk_size,
        dtype=float,
    )
    # Second pass: interpolate the data using those indices.
    resampled = resample_blocks(
        fun,
        source_area_def,
        [da.data],
        target_area_def,
        dst_arrays=[indices_xy],
        chunk_size=chunk_size,
        dtype=da.dtype,
    )
    # Keep the real name of the leading (non-spatial) dimension for 3-D input.
    # A hard-coded "time" would mislabel e.g. a "band" dimension. Coordinates on
    # that dimension, copied from ``da`` by update_resampled_coords, would then no
    # longer match any dimension. Other inputs keep the previous labelling.
    leading = da.dims[0] if da.ndim == 3 and da.dims[0] not in ("x", "y") else "time"
    resampled_da = xr.DataArray(resampled, dims=(leading, "y", "x"))
    resampled_da = update_resampled_coords(da, resampled_da, target_area_def)
    resampled_da = add_crs_xy_coords(resampled_da, target_area_def)
    resampled_da = resampled_da.drop_vars("crs")
    resampled_da.attrs = {}
    return resampled_da


def level_resample(
    ds: xr.Dataset,
    *,
    x: str,
    y: str,
    projection: ProjectionOptions = "web-mercator",
    level: int,
    pixels_per_tile: int = 128,
    other_chunks: dict | None = None,
    resampling: ResamplingOptions | dict = "bilinear",
    clear_attrs: bool = False,
) -> xr.Dataset:
    """Create a level of a multiscale pyramid of a dataset via resampling.

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to create a multiscale pyramid of.
    x : string
        name of the variable to use as 'x' axis of the CF area definition
    y : string
        name of the variable to use as 'y' axis of the CF area definition
    projection : str, optional
        The projection to use. Default is 'web-mercator'.
    level : int
        The level of the pyramid to create. The target grid is square with
        ``2**level * pixels_per_tile`` pixels per side.
    pixels_per_tile : int, optional
        Number of pixels per tile
    other_chunks : dict, optional
        Chunks for non-spatial dims. Only the first value is used, as the chunk
        size along each variable's leading dimension. When None or empty, the
        full size of each variable's first dimension is used.
    resampling : str or dict, optional
        Pyresample resampling method to use ('bilinear' or 'nearest').
        Default is 'bilinear'. If a dict, keys are variable names and values
        are resampling methods; every data variable needs an entry.
    clear_attrs : bool, False
        Clear the attributes of each input DataArray before resampling.
        Default is False.

    Returns
    -------
    xr.Dataset
        The multiscale pyramid level.

    Raises
    ------
    NotImplementedError
        If a data variable has more than 3 dimensions.
    KeyError
        If ``resampling`` is a dict without an entry for a data variable.
    ValueError
        If a resampling method other than 'bilinear' or 'nearest' is requested.

    Warning
    -------
    Pyramid generation by level is experimental and subject to change.

    """
    dim = 2**level * pixels_per_tile
    projection_model = Projection(name=projection)
    save_kwargs = {"pixels_per_tile": pixels_per_tile}
    attrs = {
        "multiscales": multiscales_template(
            datasets=[{"path": ".", "level": level, "crs": projection_model._crs}],
            type="reduce",
            method="pyramid_resample",
            version=get_version(),
            kwargs=save_kwargs,
        )
    }

    # Normalise ``resampling`` to a per-variable lookup.
    if isinstance(resampling, str):
        resampling_dict: dict = defaultdict(lambda: resampling)
    else:
        resampling_dict = resampling
    # Rename the spatial coordinates to the "x"/"y" names _da_resample expects.
    ds = ds.rename({x: "x", y: "y"})
    ds_level = xr.Dataset(attrs=ds.attrs)
    for k, da in ds.items():
        if clear_attrs:
            da.attrs.clear()
        if da.ndim > 3:
            raise NotImplementedError(
                "4+ dimensional datasets are not currently supported for pyramid_resample."
            )
        # An empty ``other_chunks`` dict is treated like None instead of
        # failing with an IndexError.
        other_chunk = next(iter(other_chunks.values())) if other_chunks else da.shape[0]
        try:
            method = resampling_dict[k]
        except KeyError:
            raise KeyError(
                f"No resampling method specified for variable {k!r} in the 'resampling' dict"
            ) from None
        if method not in ("bilinear", "nearest"):
            raise ValueError(f"Unsupported resampling method '{method}' for pyramid_resample")
        ds_level[k] = _da_resample(
            da,
            dim=dim,
            projection_model=projection_model,
            pixels_per_tile=pixels_per_tile,
            other_chunk=other_chunk,
            resampling=method,  # type: ignore[arg-type]
        )
    ds_level.attrs["multiscales"] = attrs["multiscales"]
    ds_level = assign_crs(ds_level, projection_model._crs)
    return ds_level


def pyramid_resample(
    ds: xr.Dataset,
    *,
    x: str,
    y: str,
    projection: ProjectionOptions = "web-mercator",
    levels: int | None = None,
    level_list: Sequence[int] | None = None,
    pixels_per_tile: int = 128,
    other_chunks: dict | None = None,
    resampling: ResamplingOptions | dict = "bilinear",
    clear_attrs: bool = False,
) -> xr.DataTree:
    """Create a multiscale pyramid of a dataset via resampling.

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to create a multiscale pyramid of.
    x : string
        name of the variable to use as ``x`` axis of the CF area definition
    y : string
        name of the variable to use as ``y`` axis of the CF area definition
    projection : str, optional
        The projection to use. Default is ``web-mercator``.
    levels : int, optional
        Number of contiguous levels starting at 0 to create. Mutually exclusive
        with ``level_list``. If neither is given (or ``levels`` is 0), the number
        of levels is obtained from ``get_levels(ds)``.
    level_list : Sequence[int], optional
        Explicit list of zoom levels to build (e.g. ``[4]``). Mutually exclusive
        with ``levels``. Duplicates are dropped and levels are built in
        ascending order.
    pixels_per_tile : int, optional
        Number of pixels per tile, by default 128
    other_chunks : dict
        Chunks for non-spatial dims to pass to :py:meth:`~xr.Dataset.chunk`.
        Default is None
    resampling : str or dict, optional
        Pyresample resampling method to use (``bilinear`` or ``nearest``).
        Default is ``bilinear``. If a dict, keys are variable names and values
        are resampling methods.
    clear_attrs : bool, False
        Clear the attributes of the DataArrays within the multiscale pyramid.
        Default is False.

    Returns
    -------
    xr.DataTree
        The multiscale pyramid.

    Raises
    ------
    ValueError
        If both ``levels`` and ``level_list`` are given, if ``level_list``
        contains a negative level, or if no levels would be created.

    Warnings
    --------
    - Pyresample expects longitude ranges between -180 and 180 degrees and
      latitude ranges between -90 and 90 degrees.
    - 3-D datasets are expected to have a dimension order of ``(time, y, x)``.

    ``Ndpyramid`` and ``pyresample`` do not check the validity of these
    assumptions to improve performance.
    """
    if levels is not None and level_list is not None:
        raise ValueError("Specify only one of 'levels' or 'level_list'.")
    if level_list is not None:
        # De-duplicate and sort so each level is built once, in ascending order.
        level_indices = sorted({int(i) for i in level_list})
        if level_indices and level_indices[0] < 0:
            raise ValueError(f"Pyramid levels must be non-negative, got {level_indices[0]}.")
    else:
        if not levels:
            levels = get_levels(ds)
        level_indices = list(range(int(levels)))
    if not level_indices:
        # An empty pyramid has no data and would only fail later, so reject it
        # up front (e.g. level_list=[] or a negative ``levels``).
        raise ValueError(
            "No pyramid levels to create; pass a positive 'levels' or a non-empty 'level_list'."
        )

    save_kwargs = {"levels": level_indices, "pixels_per_tile": pixels_per_tile}
    attrs = {
        "multiscales": multiscales_template(
            datasets=[{"path": str(i)} for i in level_indices],
            type="reduce",
            method="pyramid_resample",
            version=get_version(),
            kwargs=save_kwargs,
        )
    }

    plevels = {
        str(level): level_resample(
            ds,
            x=x,
            y=y,
            projection=projection,
            level=level,
            pixels_per_tile=pixels_per_tile,
            other_chunks=other_chunks,
            resampling=resampling,
            clear_attrs=clear_attrs,
        )
        for level in level_indices
    }
    # create the final multiscale pyramid; the root node carries the multiscales metadata
    plevels["/"] = xr.Dataset(attrs=attrs)
    pyramid = xr.DataTree.from_dict(plevels)

    projection_model = Projection(name=projection)
    pyramid = add_metadata_and_zarr_encoding(
        pyramid,
        levels=level_indices,
        pixels_per_tile=pixels_per_tile,
        other_chunks=other_chunks,
        projection=projection_model,
    )
    return pyramid