Source code for ndpyramid.regrid

from __future__ import annotations  # noqa: F401

import itertools
import typing
from collections.abc import Sequence

import numpy as np
import xarray as xr

from .common import Projection
from .utils import add_metadata_and_zarr_encoding, get_version, multiscales_template


def xesmf_weights_to_xarray(regridder) -> xr.Dataset:
    w = regridder.weights.data
    dim = "n_s"
    ds = xr.Dataset(
        {
            "S": (dim, w.data),
            "col": (dim, w.coords[1, :] + 1),
            "row": (dim, w.coords[0, :] + 1),
        }
    )
    ds.attrs = {"n_in": regridder.n_in, "n_out": regridder.n_out}
    return ds


def _reconstruct_xesmf_weights(ds_w: xr.Dataset) -> xr.DataArray:
    """Reconstruct weights into format that xESMF understands"""
    import sparse
    import xarray as xr

    col = ds_w["col"].values - 1
    row = ds_w["row"].values - 1
    s = ds_w["S"].values
    n_out, n_in = ds_w.attrs["n_out"], ds_w.attrs["n_in"]
    crds = np.stack([row, col])
    return xr.DataArray(
        sparse.COO(crds, s, (n_out, n_in)), dims=("out_dim", "in_dim"), name="weights"
    )


def make_grid_ds(
    level: int,
    pixels_per_tile: int = 128,
    projection: typing.Literal["web-mercator", "equidistant-cylindrical"] = "web-mercator",
) -> xr.Dataset:
    """Make a dataset representing a target grid

    Parameters
    ----------
    level : int
        The zoom level to compute the grid for. Level zero is the furthest out zoom level
    pixels_per_tile : int, optional
        Number of pixels to include along each axis in individual tiles, by default 128
    projection : str, optional
        The projection to use for the grid, by default 'equidistant-cylindrical'

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following variables:
        - "x": X coordinate in Web Mercator projection (grid cell center)
        - "y": Y coordinate in Web Mercator projection (grid cell center)
        - "lat": latitude coordinate (grid cell center)
        - "lon": longitude coordinate (grid cell center)
        - "lat_b": latitude bounds for grid cell
        - "lon_b": longitude bounds for grid cell

    """
    projection_model = Projection(name=projection)

    dim = (2**level) * pixels_per_tile

    transform = projection_model.transform(dim=dim)

    if projection_model.name == "equidistant-cylindrical":
        title = "Equidistant Cylindrical Grid"

    elif projection_model.name == "web-mercator":
        title = "Web Mercator Grid"

    else:
        title = "Unknown Projection Grid"

    p = projection_model._proj

    grid_shape = (dim, dim)
    bounds_shape = (dim + 1, dim + 1)

    xs = np.empty(grid_shape)
    ys = np.empty(grid_shape)
    lat = np.empty(grid_shape)
    lon = np.empty(grid_shape)
    lat_b = np.zeros(bounds_shape)
    lon_b = np.zeros(bounds_shape)

    # calc grid cell center coordinates
    ii, jj = np.meshgrid(np.arange(dim) + 0.5, np.arange(dim) + 0.5)
    for i, j in itertools.product(range(grid_shape[0]), range(grid_shape[1])):
        locs = [ii[i, j], jj[i, j]]
        xs[i, j], ys[i, j] = transform * locs
        lon[i, j], lat[i, j] = p(xs[i, j], ys[i, j], inverse=True)

    # calc grid cell bounds
    iib, jjb = np.meshgrid(np.arange(dim + 1), np.arange(dim + 1))
    for i, j in itertools.product(range(bounds_shape[0]), range(bounds_shape[1])):
        locs = [iib[i, j], jjb[i, j]]
        x, y = transform * locs
        lon_b[i, j], lat_b[i, j] = p(x, y, inverse=True)

    return xr.Dataset(
        {
            "x": xr.DataArray(xs[0, :], dims=["x"]),
            "y": xr.DataArray(ys[:, 0], dims=["y"]),
            "lat": xr.DataArray(lat, dims=["y", "x"]),
            "lon": xr.DataArray(lon, dims=["y", "x"]),
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
        attrs=dict(title=title, Conventions="CF-1.8"),
    )


def make_grid_pyramid(
    levels: int = 6,
    *,
    level_list: Sequence[int] | None = None,
    projection: typing.Literal["web-mercator", "equidistant-cylindrical"] = "web-mercator",
    pixels_per_tile: int = 128,
) -> xr.DataTree:
    """Helper function to create a grid pyramid for use with xesmf

    Parameters
    ----------
    levels : int, optional
        Number of contiguous levels (0..levels-1) to build. Ignored if ``level_list`` is provided.
    level_list : Sequence[int], optional
        Explicit list of zoom levels to build. Useful for sparse pyramids. Mutually exclusive with
        ``levels``.

    Returns
    -------
    pyramid : xr.DataTree
        Multiscale grid definition

    """
    if level_list is not None:
        level_indices = sorted({int(i) for i in level_list})
    else:
        level_indices = list(range(levels))

    plevels = {
        str(level): make_grid_ds(
            level, projection=projection, pixels_per_tile=pixels_per_tile
        ).chunk(-1)
        for level in level_indices
    }
    return xr.DataTree.from_dict(plevels)


def generate_weights_pyramid(
    ds_in: xr.Dataset,
    levels: int | None = None,
    *,
    level_list: Sequence[int] | None = None,
    method: str = "bilinear",
    regridder_kws: dict | None = None,
    projection: typing.Literal["web-mercator", "equidistant-cylindrical"] = "web-mercator",
) -> xr.DataTree:
    """Helper function to generate weights for a multiscale regridder

    Parameters
    ----------
    ds_in : xr.Dataset
        Input dataset to regrid
    levels : int, optional
        Number of contiguous levels (0..levels-1) to build. Ignored if ``level_list`` is provided.
    level_list : Sequence[int], optional
        Explicit list of zoom levels to build (sparse weights). Mutually exclusive with ``levels``.
    method : str, optional
        Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
    regridder_kws : dict
        Keyword arguments to pass to :py:class:`~xesmf.Regridder`. Default is `{'periodic': True}`
    projection : str, optional
        The projection to use for the grid, by default 'web-mercator'

    Returns
    -------
    weights : xr.DataTree
        Multiscale weights

    """
    import xesmf as xe

    regridder_kws = {} if regridder_kws is None else regridder_kws
    regridder_kws = {"periodic": True, **regridder_kws}

    if levels is not None and level_list is not None:
        raise ValueError("Specify only one of 'levels' or 'level_list'.")
    if level_list is not None:
        level_indices = sorted({int(i) for i in level_list})
    else:
        if levels is None:
            raise ValueError("Must provide either 'levels' or 'level_list'.")
        level_indices = list(range(levels))

    plevels = {}
    for level in level_indices:
        ds_out = make_grid_ds(level=level, projection=projection)
        regridder = xe.Regridder(ds_in, ds_out, method, **regridder_kws)
        ds = xesmf_weights_to_xarray(regridder)

        plevels[str(level)] = ds

    root_levels_attr: typing.Any = level_indices if level_list is not None else len(level_indices)
    root = xr.Dataset(attrs={"levels": root_levels_attr, "regrid_method": method})
    plevels["/"] = root
    return xr.DataTree.from_dict(plevels)


[docs] def pyramid_regrid( ds: xr.Dataset, projection: typing.Literal["web-mercator", "equidistant-cylindrical"] = "web-mercator", target_pyramid: xr.DataTree | None = None, levels: int | None = None, *, level_list: Sequence[int] | None = None, parallel_weights: bool = True, weights_pyramid: xr.DataTree | None = None, method: str = "bilinear", regridder_kws: dict | None = None, regridder_apply_kws: dict | None = None, other_chunks: dict | None = None, pixels_per_tile: int = 128, ) -> xr.DataTree: """Make a pyramid using xesmf's regridders Parameters ---------- ds : xr.Dataset Input dataset projection : str, optional Projection to use for the grid, by default 'web-mercator' target_pyramid : xr.DataTree, optional Target grids, if not provided, they will be generated, by default None levels : int, optional Number of contiguous levels to build (0..levels-1). Ignored if ``level_list`` provided. level_list : Sequence[int], optional Explicit list of zoom levels to build (sparse). Mutually exclusive with ``levels``. weights_pyramid : xr.DataTree, optional pyramid containing pregenerated weights parallel_weights : Bool Use dask to generate parallel weights method : str, optional Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear' regridder_kws : dict Keyword arguments to pass to regridder. Default is `{'periodic': True}` regridder_apply_kws : dict Keyword arguments such as `keep_attrs`, `skipna`, `na_thres` to pass to :py:meth:`~xesmf.Regridder.__call__`. Default is None other_chunks : dict Chunks for non-spatial dims to pass to :py:meth:`~xr.Dataset.chunk`. Default is None pixels_per_tile : int, optional Number of pixels per tile, by default 128 Returns ------- pyramid : xr.DataTree Multiscale data pyramid """ import xesmf as xe if target_pyramid is None: if levels is not None and level_list is not None: raise ValueError("Specify only one of 'levels' or 'level_list'.") if levels is not None or level_list is not None: target_pyramid = make_grid_pyramid( levels if levels is not None else 0, level_list=level_list, projection=projection, pixels_per_tile=pixels_per_tile, ) else: raise ValueError( "must either provide a target_pyramid or number of levels / level_list" ) # determine list of level indices from target_pyramid keys (excluding root if present) level_indices = sorted([int(k) for k in target_pyramid.keys() if k != "/"]) # backward compatibility: if levels specified ensure it matches if levels is not None and level_list is None and levels != len(level_indices): raise ValueError("Provided 'levels' does not match target_pyramid contents") regridder_kws = {} if regridder_kws is None else regridder_kws regridder_kws = {"periodic": True, **regridder_kws} # multiscales spec projection_model = Projection(name=projection) save_kwargs = { "levels": level_indices, "pixels_per_tile": pixels_per_tile, "projection": projection, "other_chunks": other_chunks, "method": method, "regridder_kws": regridder_kws, "regridder_apply_kws": regridder_apply_kws, } attrs = { "multiscales": multiscales_template( datasets=[ {"path": str(i), "level": i, "crs": projection_model._crs} for i in level_indices ], type="reduce", method="pyramid_regrid", version=get_version(), kwargs=save_kwargs, ) } save_kwargs.pop("levels") save_kwargs.pop("other_chunks") # set up pyramid plevels = {} # pyramid data for level in level_indices: grid = target_pyramid[str(level)].ds.load() # get the regridder object if weights_pyramid is None: regridder = xe.Regridder(ds, grid, method, parallel=parallel_weights, **regridder_kws) else: # Reconstruct weights into format that xESMF understands # this is a hack that assumes the weights were generated by # the `generate_weights_pyramid` function ds_w = weights_pyramid[str(level)].ds weights = _reconstruct_xesmf_weights(ds_w) regridder = xe.Regridder( ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws ) # regrid if regridder_apply_kws is None: regridder_apply_kws = {} regridder_apply_kws = {**{"keep_attrs": True}, **regridder_apply_kws} plevels[str(level)] = regridder(ds, **regridder_apply_kws) level_attrs = { "multiscales": multiscales_template( datasets=[{"path": ".", "level": level, "crs": projection_model._crs}], type="reduce", method="pyramid_regrid", version=get_version(), kwargs=save_kwargs, ) } plevels[str(level)].attrs["multiscales"] = level_attrs["multiscales"] root = xr.Dataset(attrs=attrs) plevels["/"] = root pyramid = xr.DataTree.from_dict(plevels) pyramid = add_metadata_and_zarr_encoding( pyramid, levels=level_indices, other_chunks=other_chunks, pixels_per_tile=pixels_per_tile, projection=Projection(name=projection), ) return pyramid