Source code for nlmod.cache

import functools
import hashlib
import importlib
import inspect
import json
import logging
import numbers
import os
import pathlib
import pickle

import flopy
import numpy as np
import pandas as pd
import xarray as xr
from dask.diagnostics import ProgressBar
from xarray.testing import assert_identical

from .config import NLMOD_CACHE_OPTIONS

logger = logging.getLogger(__name__)


def clear_cache(cachedir, prompt=True):
    """Clear the cache in a given cache directory.

    Removes every ``.pklz`` file in `cachedir` and, when present, the netCDF
    file with the same base name.

    Parameters
    ----------
    cachedir : str
        path to cache directory.
    prompt : bool, optional
        Ask for confirmation before removing the cache. The default is True.

    Returns
    -------
    None.
    """
    if prompt:
        ans = input(
            f"this will remove all cached files in {cachedir} are you sure [Y/N]"
        )
        if ans.lower() != "y":
            return

    for fname in os.listdir(cachedir):
        # assuming all pklz files belong to a cached netcdf file
        if not fname.endswith(".pklz"):
            continue

        # swap the trailing extension only, a ".pklz" elsewhere in the name
        # has to stay untouched
        fname_nc = fname[: -len(".pklz")] + ".nc"

        # remove pklz file
        os.remove(os.path.join(cachedir, fname))
        msg = f"removed {fname}"
        logger.info(msg)

        # remove netcdf file; test the path inside cachedir, not the bare file
        # name (which would be resolved against the working directory)
        fpath_nc = os.path.join(cachedir, fname_nc)
        if os.path.exists(fpath_nc):
            # make sure cached netcdf is closed
            cached_ds = xr.open_dataset(fpath_nc, decode_coords="all")
            cached_ds.close()
            os.remove(fpath_nc)
            msg = f"removed {fname_nc}"
            logger.info(msg)
def cache_netcdf(
    coords_2d=False,
    coords_3d=False,
    coords_time=False,
    attrs_ds=False,
    datavars=None,
    coords=None,
    attrs=None,
    nc_hash=True,
):
    """Decorator to read/write the result of a function from/to a file to speed up
    function calls with the same arguments.

    Should only be applied to functions that:

        - return an Xarray Dataset (or DataArray)
        - have no more than one xarray dataset as function argument
        - have functions arguments of types that can be checked using the
          _same_function_arguments function

    1. The directory and filename of the cache should be defined by the person
       calling a function with this decorator. If not defined no cache is created
       nor used.
    2. Create a new cached file if it is impossible to check if the function
       arguments used to create the cached file are the same as the current
       function arguments. This can happen if one of the function arguments has a
       type that cannot be checked using the _same_function_arguments function.
    3. Function arguments are pickled together with the cache to check later if
       the cache is valid.
    4. If one of the function arguments is an xarray Dataset it is not pickled.
       Instead, hashes of its coordinates and data variables are stored (when
       enabled in `NLMOD_CACHE_OPTIONS`) and compared on the next call.
    5. This function uses `functools.wraps` and some home made magic in
       _update_docstring_and_signature to add arguments of the decorator to the
       decorated function.

    If all kwargs are left to their defaults, the function caches the full dataset.

    Parameters
    ----------
    coords_2d : bool, optional
        Shorthand for adding 2D coordinates. The default is False.
    coords_3d : bool, optional
        Shorthand for adding 3D coordinates. The default is False.
    coords_time : bool, optional
        Shorthand for adding time coordinates. The default is False.
    attrs_ds : bool, optional
        Shorthand for adding model dataset attributes. The default is False.
    datavars : list, optional
        List of data variables to check for. The default is an empty list.
    coords : list, optional
        List of coordinates to check for. The default is an empty list.
    attrs : list, optional
        List of attributes to check for. The default is an empty list.
    nc_hash : bool, optional
        check if the pickled function arguments belong to the cached netcdf file.
        Default is True.

    Returns
    -------
    decorator : callable
        decorator that adds the keyword arguments `cachedir` and `cachename` to
        the decorated function.
    """
    # exceptions `pickle.load` may raise for an unreadable or outdated file
    unpickle_errors = (
        pickle.UnpicklingError,
        AttributeError,
        EOFError,
        ImportError,
        IndexError,
    )

    def _minimal_dataset(ds):
        """Reduce a dataset argument to the parts requested in the decorator."""
        # hand over copies so the decorator's own lists can never be modified
        # by the callee and carried over to the next call
        return ds_contains(
            ds,
            coords_2d=coords_2d,
            coords_3d=coords_3d,
            coords_time=coords_time,
            attrs_ds=attrs_ds,
            datavars=None if datavars is None else list(datavars),
            coords=None if coords is None else list(coords),
            attrs=None if attrs is None else list(attrs),
        )

    def decorator(func):
        # add cachedir and cachename to docstring
        _update_docstring_and_signature(func)

        @functools.wraps(func)
        def wrapper(*args, cachedir=None, cachename=None, **kwargs):
            # 1 check if cachedir and name are provided
            if cachedir is None or cachename is None:
                return func(*args, **kwargs)

            if not cachename.endswith(".nc"):
                cachename += ".nc"

            fname_cache = os.path.join(cachedir, cachename)  # netcdf file
            # swap the trailing extension only; str.replace would also rewrite
            # any other ".nc" that happens to be part of the path
            fname_pickle_cache = fname_cache[: -len(".nc")] + ".pklz"

            # adjust args and kwargs with minimal dataset
            args_adj = []
            kwargs_adj = {}

            datasets = []
            func_args_dic = {}

            for i, arg in enumerate(args):
                if isinstance(arg, xr.Dataset):
                    arg_adj = _minimal_dataset(arg)
                    args_adj.append(arg_adj)
                    datasets.append(arg_adj)
                else:
                    args_adj.append(arg)
                    func_args_dic[f"arg{i}"] = arg

            for key, arg in kwargs.items():
                if isinstance(arg, xr.Dataset):
                    arg_adj = _minimal_dataset(arg)
                    kwargs_adj[key] = arg_adj
                    datasets.append(arg_adj)
                else:
                    kwargs_adj[key] = arg
                    func_args_dic[key] = arg

            if not datasets:
                dataset = None
            elif len(datasets) == 1:
                dataset = datasets[0]
            else:
                raise NotImplementedError(
                    "Function was called with multiple xarray dataset arguments. "
                    "Currently unsupported."
                )

            # only use cache if the cache file and the pickled function arguments exist
            if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache):
                # check if you can read the pickle, there are several reasons why a
                # pickle can not be read.
                # NOTE: unpickling executes code stored in the file, the cache
                # directory must therefore be trusted.
                try:
                    with open(fname_pickle_cache, "rb") as f:
                        func_args_dic_cache = pickle.load(f)
                    pickle_check = True
                except unpickle_errors:
                    logger.info("could not read pickle, not using cache")
                    pickle_check = False
                argument_check = False

                # check if the module where the function is defined was changed
                # after the cache was created
                time_mod_func = _get_modification_time(func)
                time_mod_cache = os.path.getmtime(fname_cache)
                modification_check = time_mod_cache > time_mod_func

                if not modification_check:
                    logger.info(
                        f"module of function {func.__name__} recently modified, "
                        "not using cache"
                    )

                with xr.open_dataset(fname_cache, decode_coords="all") as cached_ds:
                    cached_ds.load()

                if pickle_check:
                    # Ensure that the pickle pairs with the netcdf, see #66.
                    if NLMOD_CACHE_OPTIONS["nc_hash"] and nc_hash:
                        with open(fname_cache, "rb") as myfile:
                            cache_bytes = myfile.read()
                        func_args_dic["_nc_hash"] = hashlib.sha256(
                            cache_bytes
                        ).hexdigest()

                    if dataset is not None:
                        # fix layer dtype if necessary; only possible when the
                        # dataset argument has a layer coordinate itself
                        if "layer" in cached_ds.coords and "layer" in dataset.coords:
                            if dataset["layer"].dtype != cached_ds["layer"].dtype:
                                # cached layer dtype might be read as fixed width dtype
                                # modify dataset dtype to make hashes match
                                dataset = dataset.assign_coords(
                                    {
                                        "layer": dataset["layer"].values.astype(
                                            cached_ds["layer"].dtype
                                        )
                                    }
                                )

                        if NLMOD_CACHE_OPTIONS["dataset_coords_hash"]:
                            # Check the coords of the dataset argument,
                            # 20250228: metadata is currently excluded as this was
                            # causing differences that are not relevant to the cache...
                            func_args_dic["_dataset_coords_hash"] = hash_xarray_coords(
                                dataset, include_metadata=False
                            )
                        else:
                            func_args_dic_cache.pop("_dataset_coords_hash", None)
                            logger.warning(
                                "cache -> dataset coordinates not checked, "
                                "disabled in global config. See "
                                "`nlmod.config.NLMOD_CACHE_OPTIONS`."
                            )
                            if not NLMOD_CACHE_OPTIONS[
                                "explicit_dataset_coordinate_comparison"
                            ]:
                                logger.warning(
                                    "It is recommended to turn on "
                                    "`explicit_dataset_coordinate_comparison` "
                                    "in global config when hash check is turned off!"
                                )

                        if NLMOD_CACHE_OPTIONS["dataset_data_vars_hash"]:
                            # Check the data_vars of the dataset argument
                            # 20250228: metadata is currently excluded as this was
                            # causing differences that are not relevant to the cache...
                            func_args_dic["_dataset_data_vars_hash"] = (
                                hash_xarray_data_vars(dataset, include_metadata=False)
                            )
                        else:
                            func_args_dic_cache.pop("_dataset_data_vars_hash", None)
                            logger.warning(
                                "cache -> dataset data vars not checked, "
                                "disabled in global config. See "
                                "`nlmod.config.NLMOD_CACHE_OPTIONS`."
                            )

                    # check if cache was created with same function arguments as
                    # function call
                    argument_check = _same_function_arguments(
                        func_args_dic, func_args_dic_cache
                    )

                    # explicit check on input dataset coordinates and cached
                    # dataset, skipped when the arguments already differ
                    if (
                        argument_check
                        and NLMOD_CACHE_OPTIONS["explicit_dataset_coordinate_comparison"]
                        and isinstance(dataset, (xr.DataArray, xr.Dataset))
                    ):
                        argument_check = _explicit_dataset_coordinate_comparison(
                            dataset, cached_ds
                        )

                cached_ds = _check_for_data_array(cached_ds)
                if modification_check and argument_check and pickle_check:
                    msg = f"using cached data -> {cachename}"
                    logger.info(msg)
                    return cached_ds

            # create cache
            result = func(*args_adj, **kwargs_adj)
            msg = f"caching data -> {cachename}"
            logger.info(msg)

            if isinstance(result, xr.DataArray):
                # set the DataArray as a variable in a new Dataset
                result = xr.Dataset({"__xarray_dataarray_variable__": result})

            if not isinstance(result, xr.Dataset):
                msg = f"expected xarray Dataset, got {type(result)} instead"
                raise TypeError(msg)

            # close cached netcdf (otherwise it is impossible to overwrite)
            if os.path.exists(fname_cache):
                with xr.open_dataset(fname_cache, decode_coords="all") as cached_ds:
                    cached_ds.load()

            # write netcdf cache
            # check if dataset is chunked for writing with dask.delayed; a
            # dataset without data variables is written directly
            first_data_var = next(iter(result.data_vars), None)
            if first_data_var is not None and result[first_data_var].chunks:
                delayed = result.to_netcdf(fname_cache, compute=False)
                with ProgressBar():
                    delayed.compute()
                # close and reopen dataset to ensure data is read from
                # disk, and not from opendap
                result.close()
                result = xr.open_dataset(
                    fname_cache, decode_coords="all", chunks="auto"
                )
            else:
                result.to_netcdf(fname_cache)

            # add netcdf hash to function arguments dic, see #66
            if NLMOD_CACHE_OPTIONS["nc_hash"] and nc_hash:
                with open(fname_cache, "rb") as myfile:
                    cache_bytes = myfile.read()
                func_args_dic["_nc_hash"] = hashlib.sha256(cache_bytes).hexdigest()

            # Add dataset argument hash to function arguments dic
            if dataset is not None:
                if NLMOD_CACHE_OPTIONS["dataset_coords_hash"]:
                    # 20250228: metadata is currently excluded as this was
                    # causing differences that are not relevant to the cache...
                    func_args_dic["_dataset_coords_hash"] = hash_xarray_coords(
                        dataset, include_metadata=False
                    )
                else:
                    logger.warning(
                        "cache -> not writing dataset coordinates hash to "
                        "pickle file, disabled in global config. See "
                        "`nlmod.config.NLMOD_CACHE_OPTIONS`."
                    )
                if NLMOD_CACHE_OPTIONS["dataset_data_vars_hash"]:
                    # 20250228: metadata is currently excluded as this was
                    # causing differences that are not relevant to the cache...
                    func_args_dic["_dataset_data_vars_hash"] = (
                        hash_xarray_data_vars(dataset, include_metadata=False)
                    )
                else:
                    logger.warning(
                        "cache -> not writing dataset data vars hash to "
                        "pickle file, disabled in global config. See "
                        "`nlmod.config.NLMOD_CACHE_OPTIONS`."
                    )

            # pickle function arguments
            with open(fname_pickle_cache, "wb") as fpklz:
                pickle.dump(func_args_dic, fpklz)

            return _check_for_data_array(result)

        return wrapper

    return decorator
def _read_cache_pickle(fname):
    """Try to unpickle `fname` and return ``(True, obj)``, or ``(False, None)``."""
    # NOTE: unpickling executes code stored in the file, the cache directory
    # must therefore be trusted.
    try:
        with open(fname, "rb") as f:
            return True, pickle.load(f)
    except (
        pickle.UnpicklingError,
        AttributeError,
        EOFError,
        ImportError,
        IndexError,
    ):
        # there are several reasons why a pickle can not be read (truncated
        # file, class that moved or disappeared, ...); all of them mean that
        # the cache cannot be used
        logger.info("could not read pickle, not using cache")
        return False, None


def _add_pklz_hash(func_args_dic, obj):
    """Store the joblib hash of `obj` under '_pklz_hash', if joblib is installed."""
    try:
        import joblib
    except ImportError:
        logger.warning(
            "joblib is not installed, cannot add dataframe hash to function arguments"
        )
        return
    func_args_dic["_pklz_hash"] = joblib.hash(obj)


def cache_pickle(func):
    """Decorator to read/write the result of a function from/to a file to speed up
    function calls with the same arguments.

    Should only be applied to functions that:

        - return a pandas DataFrame (a TypeError is raised for any other result)
        - have functions arguments of types that can be checked using the
          _same_function_arguments function

    1. The directory and filename of the cache should be defined by the person
       calling a function with this decorator. If not defined no cache is created
       nor used.
    2. Create a new cached file if it is impossible to check if the function
       arguments used to create the cached file are the same as the current
       function arguments. This can happen if one of the function arguments has a
       type that cannot be checked using the _same_function_arguments function.
    3. Function arguments are pickled together with the cache to check later if
       the cache is valid.
    4. This function uses `functools.wraps` and some home made magic in
       _update_docstring_and_signature to add arguments of the decorator to the
       decorated function.

    Parameters
    ----------
    func : callable
        function that is decorated.

    Returns
    -------
    callable
        wrapped function that accepts the extra keyword arguments `cachedir` and
        `cachename`.
    """
    # add cachedir and cachename to docstring
    _update_docstring_and_signature(func)

    @functools.wraps(func)
    def wrapper(*args, cachedir=None, cachename=None, **kwargs):
        # 1 check if cachedir and name are provided
        if cachedir is None or cachename is None:
            return func(*args, **kwargs)

        if not cachename.endswith(".pklz"):
            cachename += ".pklz"

        fname_cache = os.path.join(cachedir, cachename)  # pklz file
        # swap the trailing extension only; str.replace would also rewrite any
        # other ".pklz" that happens to be part of the path
        fname_pickle_cache = fname_cache[: -len(".pklz")] + "__cache__.pklz"

        # create dictionary with function arguments
        func_args_dic = {f"arg{i}": arg for i, arg in enumerate(args)}
        func_args_dic.update(kwargs)

        # only use cache if the cache file and the pickled function arguments exist
        if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache):
            # check if you can read the function argument pickle
            pickle_check, func_args_dic_cache = _read_cache_pickle(
                fname_pickle_cache
            )

            # check if the module where the function is defined was changed
            # after the cache was created
            time_mod_func = _get_modification_time(func)
            time_mod_cache = os.path.getmtime(fname_cache)
            modification_check = time_mod_cache > time_mod_func

            if not modification_check:
                msg = (
                    f"module of function {func.__name__} recently modified, "
                    "not using cache"
                )
                logger.info(msg)

            # check if you can read the cached pickle
            cached_readable, cached_pklz = _read_cache_pickle(fname_cache)
            pickle_check = pickle_check and cached_readable

            argument_check = False
            if pickle_check:
                # add dataframe hash to function arguments dic
                _add_pklz_hash(func_args_dic, cached_pklz)

                # check if cache was created with same function arguments as
                # function call
                argument_check = _same_function_arguments(
                    func_args_dic, func_args_dic_cache
                )

            if modification_check and argument_check and pickle_check:
                msg = f"using cached data -> {cachename}"
                logger.info(msg)
                return cached_pklz

        # create cache
        result = func(*args, **kwargs)
        msg = f"caching data -> {cachename}"
        logger.info(msg)

        if not isinstance(result, pd.DataFrame):
            msg = f"expected DataFrame, got {type(result)} instead"
            raise TypeError(msg)

        # write pklz cache
        result.to_pickle(fname_cache)

        # add dataframe hash to function arguments dic; hash the object as it
        # is read back from disk so it matches what a later call will load
        with open(fname_cache, "rb") as f:
            temp = pickle.load(f)
        _add_pklz_hash(func_args_dic, temp)

        # pickle function arguments
        with open(fname_pickle_cache, "wb") as fpklz:
            pickle.dump(func_args_dic, fpklz)

        return result

    return wrapper
[docs] def _same_function_arguments(func_args_dic, func_args_dic_cache): """Checks if two dictionaries with function arguments are identical. The following items are checked: 1. if they have the same keys 2. if the items have the same type 3. if the items have the same values (only implemented for the types: int, float, bool, str, bytes, list, tuple, dict, np.ndarray, xr.DataArray, flopy.mf6.ModflowGwf). Parameters ---------- func_args_dic : dictionary dictionary with all the args and kwargs of a function call. func_args_dic_cache : dictionary dictionary with all the args and kwargs of a previous function call of which the results are cached. Returns ------- bool if True the dictionaries are identical which means that the cached data was created using the same function arguments as the requested data. Notes ----- Keys that end with '_hash' are assumed to be hashes and not function arguments. They are checked equally. """ for key, item in func_args_dic.items(): # check if cache and function call have same argument names if key not in func_args_dic_cache: msg = ( f"cache was created using different function argument '{key}' " "not in cached arguments, do not use cached data" ) logger.info(msg) return False # check if cache and function call have same argument types if not isinstance(item, type(func_args_dic_cache[key])): msg = ( f"cache was created using different function argument types for {key}: " f"current '{type(item)}' cache: '{type(func_args_dic_cache[key])}', " "do not use cached data" ) logger.info(msg) return False # check if cache and function call have same argument values if item is None: # Value of None type is always None so the check happens in previous if statement pass elif isinstance( item, (numbers.Number, bool, str, bytes, list, tuple, pathlib.PurePath) ): if item != func_args_dic_cache[key]: if key.endswith("_hash") and isinstance(item, str): logger.info( f"cached hashes do not match: {key}, do not use cached data" ) else: logger.info( f"cache 
was created using different function argument: {key}, " "do not use cached data" ) logger.debug(f"{key}: {item} != {func_args_dic_cache[key]}") return False elif isinstance(item, np.ndarray): if not np.allclose(item, func_args_dic_cache[key]): logger.info( f"cache was created using different numpy array for: {key}, " "do not use cached data" ) logger.debug( f"array '{key}' max difference with stored copy is " f"{np.max(np.abs(item - func_args_dic_cache[key]))}" ) return False elif isinstance(item, (pd.DataFrame, pd.Series, xr.DataArray)): if not item.equals(func_args_dic_cache[key]): logger.info( "cache was created using different DataFrame/Series/DataArray for: " f"{key}, do not use cached data" ) return False elif isinstance(item, dict): # recursive checking if not _same_function_arguments(item, func_args_dic_cache[key]): logger.info( f"cache was created using a different dictionary for: {key}, " "do not use cached data" ) return False elif isinstance(item, (flopy.mf6.ModflowGwf, flopy.modflow.mf.Modflow)): if str(item) != str(func_args_dic_cache[key]): logger.info( "cache was created using different groundwater flow model for: " f"{key}, do not use cached data" ) return False elif isinstance(item, flopy.utils.gridintersect.GridIntersect): i2 = func_args_dic_cache[key] is_method_equal = item.method == i2.method # check if mfgrid is equal except for cache_dict and polygons excl = ("_cache_dict", "_polygons") mfgrid1 = {k: v for k, v in item.mfgrid.__dict__.items() if k not in excl} mfgrid2 = {k: v for k, v in i2.mfgrid.__dict__.items() if k not in excl} is_same_length_props = all( np.all(np.size(v) == np.size(mfgrid2[k])) for k, v in mfgrid1.items() ) if ( not is_method_equal or mfgrid1.keys() != mfgrid2.keys() or not is_same_length_props ): logger.info( f"cache was created using different gridintersect object: {key}, " "do not use cached data" ) return False is_other_props_equal = all( np.all(v == mfgrid2[k]) for k, v in mfgrid1.items() ) if not 
is_other_props_equal: logger.info( f"cache was created using different gridintersect object: {key}, " "do not use cached data" ) return False else: logger.info( f"cannot check if cache argument {key} is valid, assuming invalid cache" f", function argument of type {type(item)}" ) return False return True
def _get_modification_time(func):
    """Return the modification time of the module where func is defined.

    Parameters
    ----------
    func : function
        function.

    Returns
    -------
    float
        modification time of module.
    """
    # Resolve the module by its full dotted name. Walking down from the top
    # level package with getattr breaks when a package exposes an attribute
    # with the same name as one of its submodules (e.g. `from .foo import foo`).
    module = importlib.import_module(func.__module__)
    return os.path.getmtime(module.__file__)
def _find_returns_header(doc):
    """Return the index of the "Returns" section header in a docstring.

    The header is the occurrence of "Returns" that stands alone on its line and
    is followed by an underline of dashes. If no such occurrence exists, the
    index of the last "Returns" is returned (-1 if the word does not occur).
    """
    last = -1
    idx = doc.find("Returns")
    while idx != -1:
        last = idx
        line_start = doc[:idx].rsplit("\n", 1)[-1]
        rest = doc[idx + len("Returns") :].lstrip(" \t")
        if (
            not line_start.strip()
            and rest.startswith("\n")
            and rest[1:].lstrip(" \t").startswith("---")
        ):
            return idx
        idx = doc.find("Returns", idx + 1)
    return last


def _update_docstring_and_signature(func):
    """Add function arguments 'cachedir' and 'cachename' to the docstring and
    signature of a function.

    The function arguments are added before the "Returns" header in the
    docstring. If the function has no docstring, or no "Returns" in the
    docstring, a warning is logged and the docstring is left unchanged.

    Parameters
    ----------
    func : function
        function that is decorated.

    Returns
    -------
    None
    """
    # add cachedir and cachename to signature
    sig = inspect.signature(func)
    cur_param = list(sig.parameters.values())

    # a variadic keyword parameter (**kwargs, whatever its name) must stay last
    if cur_param and cur_param[-1].kind is inspect.Parameter.VAR_KEYWORD:
        var_keyword = [cur_param.pop()]
    else:
        var_keyword = []

    # behind *args or keyword-only parameters only keyword-only parameters are
    # allowed, otherwise `Signature.replace` raises a ValueError
    keyword_only_required = any(
        p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.KEYWORD_ONLY)
        for p in cur_param
    )
    if keyword_only_required:
        kind = inspect.Parameter.KEYWORD_ONLY
    else:
        kind = inspect.Parameter.POSITIONAL_OR_KEYWORD

    cache_param = [
        inspect.Parameter(name, kind, default=None)
        for name in ("cachedir", "cachename")
    ]
    func.__signature__ = sig.replace(
        parameters=[*cur_param, *cache_param, *var_keyword]
    )

    # add cachedir and cachename to docstring
    original_doc = func.__doc__
    if original_doc is None:
        msg = f'Function "{func.__name__}" has no docstring'
        logger.warning(msg)
        return
    if "Returns" not in original_doc:
        msg = f'Function "{func.__name__}" has no "Returns" header in docstring'
        logger.warning(msg)
        return

    # split at the section header; the word "Returns" may also be used in the
    # summary or elsewhere in the text
    split_at = _find_returns_header(original_doc)
    before = original_doc[:split_at]
    after = original_doc[split_at + len("Returns") :]

    # the inserted entries follow the numpydoc layout of the docstrings in this
    # file (4 spaces for a parameter, 8 spaces for its description)
    mod_before = (
        before.strip() + "\n    cachedir : str or None, optional\n"
        "        directory to save cache. If None no cache is used."
        " Default is None.\n    cachename : str or None, optional\n"
        "        filename of netcdf cache. If None no cache is used."
        " Default is None.\n\n    Returns"
    )
    func.__doc__ = f"{mod_before}{after}"
    return
def _check_for_data_array(ds):
    """Unwrap a cached DataArray from the Dataset it was stored in.

    A DataArray is written to the cache as a Dataset with a single variable
    named "__xarray_dataarray_variable__". When that variable is present, it is
    returned instead of the Dataset. If the Dataset also contains "spatial_ref",
    it is attached to the returned DataArray as a coordinate, because saving
    stored it as a separate variable.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset that was read from, or written to, the cache.

    Returns
    -------
    ds : xr.Dataset or xr.DataArray
        The DataArray if the Dataset wraps one, otherwise the unchanged Dataset.
    """
    wrapper_name = "__xarray_dataarray_variable__"
    if wrapper_name not in ds:
        # a regular Dataset, nothing to unwrap
        return ds

    # look up spatial_ref before the Dataset is replaced by its variable
    spatial_ref = ds.spatial_ref if "spatial_ref" in ds else None

    data_array = ds[wrapper_name]
    if spatial_ref is None:
        return data_array
    return data_array.assign_coords({"spatial_ref": spatial_ref})
def ds_contains(
    ds,
    coords_2d=False,
    coords_3d=False,
    coords_time=False,
    attrs_ds=False,
    datavars=None,
    coords=None,
    attrs=None,
):
    """Returns a Dataset containing only the required data.

    If all kwargs are left to their defaults, the function returns the full dataset.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with dimensions and coordinates.
    coords_2d : bool, optional
        Shorthand for adding 2D coordinates. The default is False.
    coords_3d : bool, optional
        Shorthand for adding 3D coordinates. The default is False.
    coords_time : bool, optional
        Shorthand for adding time coordinates. The default is False.
    attrs_ds : bool, optional
        Shorthand for adding model dataset attributes. The default is False.
    datavars : list, optional
        List of data variables to check for. The default is an empty list.
    coords : list, optional
        List of coordinates to check for. The default is an empty list.
    attrs : list, optional
        List of attributes to check for. The default is an empty list.

    Returns
    -------
    ds : xr.Dataset
        A Dataset containing only the required data.

    Raises
    ------
    ValueError
        If `ds` is None, or if a required data variable, coordinate or
        attribute is missing from `ds`.
    """
    # Return the full dataset if not configured
    if ds is None:
        msg = "No dataset provided"
        raise ValueError(msg)

    isdefault_args = not any(
        [coords_2d, coords_3d, coords_time, attrs_ds, datavars, coords, attrs]
    )
    if isdefault_args:
        return ds

    # a missing gridtype is reported below, and only if it is actually required
    isvertex = ds.attrs.get("gridtype") == "vertex"

    # Work on copies: the shorthands below extend the lists, which must not
    # leak into the lists owned by the caller.
    datavars = [] if datavars is None else list(datavars)
    coords = [] if coords is None else list(coords)
    attrs = [] if attrs is None else list(attrs)

    # Add coords, datavars and attrs via shorthands
    if coords_2d or coords_3d:
        coords.extend(["x", "y"])
        attrs.extend(["extent", "gridtype"])

        if isvertex:
            datavars.extend(["xv", "yv", "icvert"])

        if "angrot" in ds.attrs:
            # set by `nlmod.base.to_model_ds()` and `nlmod.dims.resample._set_angrot_attributes()`
            attrs_angrot_required = ["angrot", "xorigin", "yorigin"]
            attrs.extend(attrs_angrot_required)

    if coords_3d:
        coords.append("layer")
        datavars.extend(["top", "botm"])

    if coords_time:
        coords.append("time")
        datavars.extend(["steady", "nstp", "tsmult"])

    if attrs_ds:
        # set by `nlmod.base.to_model_ds()` and `nlmod.base.set_ds_attrs()`,
        # excluding "created_on"
        attrs_ds_required = [
            "model_name",
            "mfversion",
            "exe_name",
            "model_ws",
            "figdir",
            "cachedir",
            "transport",
        ]
        attrs.extend(attrs_ds_required)

    # User-friendly error messages if missing from ds
    if "northsea" in datavars and "northsea" not in ds.data_vars:
        msg = "Northsea not in dataset. Run nlmod.read.rws.add_northsea() first."
        raise ValueError(msg)

    if coords_time:
        if "time" not in ds.coords:
            msg = "time not in dataset. Run nlmod.time.set_ds_time() first."
            raise ValueError(msg)

        # Check if time-coord is complete
        time_attrs_required = ["start", "time_units"]

        for t_attr in time_attrs_required:
            if t_attr not in ds["time"].attrs:
                msg = (
                    f"{t_attr} not in dataset['time'].attrs. "
                    + "Run nlmod.time.set_ds_time() to set time."
                )
                raise ValueError(msg)

    if attrs_ds:
        for attr in attrs_ds_required:
            if attr not in ds.attrs:
                msg = f"{attr} not in dataset.attrs. Run nlmod.set_ds_attrs() first."
                raise ValueError(msg)

    # User-unfriendly error messages
    for datavar in datavars:
        if datavar not in ds.data_vars:
            msg = f"{datavar} not in dataset.data_vars"
            raise ValueError(msg)

    for coord in coords:
        if coord not in ds.coords:
            msg = f"{coord} not in dataset.coords"
            raise ValueError(msg)

    for attr in attrs:
        if attr not in ds.attrs:
            msg = f"{attr} not in dataset.attrs"
            raise ValueError(msg)

    # Return only the required data
    return xr.Dataset(
        data_vars={k: ds.data_vars[k] for k in datavars},
        coords={k: ds.coords[k] for k in coords},
        attrs={k: ds.attrs[k] for k in attrs},
    )
def _explicit_dataset_coordinate_comparison(ds_in, ds_cache):
    """Compare every coordinate of the cached dataset with the input dataset.

    Each coordinate of `ds_cache` is looked up by name in `ds_in` and both are
    compared with `xarray.testing.assert_identical()`. The comparison stops at
    the first coordinate that differs.

    Parameters
    ----------
    ds_in : xr.Dataset
        Input dataset.
    ds_cache : xr.Dataset
        Cached dataset.

    Returns
    -------
    bool
        True if all coordinates of `ds_cache` are identical in `ds_in`, False
        as soon as one of them differs.

    Notes
    -----
    The AssertionError of a failed comparison is caught and logged at debug
    level. An error raised by the lookup `ds_in[coord]` itself is not handled
    here.
    """
    logger.debug("cache -> performing explicit dataset coordinate comparison")

    for name in ds_cache.coords:
        logger.debug(f"cache -> comparing coordinate {name}")
        try:
            assert_identical(ds_in[name], ds_cache[name])
        except AssertionError as err:
            logger.debug(f"cache -> coordinate {name} not equal")
            logger.debug(err)
            return False
    else:
        # reached only when no coordinate differed
        logger.debug("cache -> all coordinates equal")

    return True
class NumpyEncoder(json.JSONEncoder):
    """Special json encoder for numpy types.

    Converts numpy integers, floats, booleans and arrays to their plain Python
    counterparts. Every other object is passed on to `json.JSONEncoder`, which
    raises a TypeError for objects it cannot serialize.
    """

    def default(self, o):
        """Return a JSON serializable version of `o`."""
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.bool_):
            return bool(o)
        if isinstance(o, np.ndarray):
            # tolist() also converts the elements to Python scalars
            return o.tolist()
        return super().default(o)
def hash_xarray_coords(ds, include_metadata: bool = False):
    """Create a hash of xarray coordinate(s) using array bytes and optionally
    metadata.

    The coordinates are processed in the order in which `ds.coords` yields
    them: for every coordinate the raw bytes of its values are hashed, followed
    by its JSON encoded metadata if requested.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        The xarray object whose coordinates are hashed.
    include_metadata : bool, optional
        Whether to include metadata (name, dims, attrs, dtype and shape of
        each coordinate) in the hash. Default is False.

    Returns
    -------
    str
        The hexadecimal hash string
    """
    # feed the hash incrementally; this yields the same digest as hashing the
    # concatenation of all bytes, without building that concatenation in memory
    hasher = hashlib.sha256()

    for coord in ds.coords:
        coord_da = ds[coord]

        # the raw bytes from the numpy array values
        hasher.update(coord_da.values.tobytes())

        if include_metadata:
            # metadata as JSON
            metadata = {
                "name": coord,
                "dims": coord_da.dims,
                "attrs": coord_da.attrs,
                "dtype": str(coord_da.dtype),
                "shape": coord_da.shape,
            }
            metadata_bytes = json.dumps(
                metadata, sort_keys=True, cls=NumpyEncoder
            ).encode("utf-8")
            hasher.update(metadata_bytes)

    return hasher.hexdigest()
def hash_xarray_data_vars(
    ds,
    include_metadata: bool = False,
):
    """Create a hash of xarray data variables using array bytes and optionally
    metadata.

    For a Dataset the data variables are processed sorted by name, so the hash
    does not depend on the order in which the variables were added.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        The xarray object whose data variable(s) are hashed.
    include_metadata : bool, optional
        Whether to include metadata (name, dims, attrs, dtype, shape and a hash
        per coordinate of each data variable) in the hash. Default is False.

    Returns
    -------
    str
        The hexadecimal hash string

    Raises
    ------
    TypeError
        If `ds` is neither an xarray Dataset nor an xarray DataArray.
    """
    if isinstance(ds, xr.Dataset):
        # sort data vars to ensure hashes remain the same
        data_arrays = [ds[da] for da in sorted(ds.data_vars)]
    elif isinstance(ds, xr.DataArray):
        data_arrays = [ds]
    else:
        raise TypeError("Input must be an xarray Dataset or DataArray")

    # feed the hash incrementally; this yields the same digest as hashing the
    # concatenation of all bytes, without building that concatenation in memory
    hasher = hashlib.sha256()

    for da in data_arrays:
        # the raw bytes from the numpy array values
        hasher.update(da.values.tobytes())

        if include_metadata:
            # hash each coordinate separately
            coord_hashes = {}
            for coord_name, coord in sorted(da.coords.items()):
                coord_hashes[coord_name] = hash_xarray_coords(
                    coord, include_metadata=False
                )

            # metadata as JSON
            metadata = {
                "name": da.name,
                "dims": da.dims,
                "attrs": da.attrs,
                "dtype": str(da.dtype),
                "shape": da.shape,
                "coord_hashes": coord_hashes,
            }
            metadata_bytes = json.dumps(
                metadata, sort_keys=True, cls=NumpyEncoder
            ).encode("utf-8")
            hasher.update(metadata_bytes)

    return hasher.hexdigest()