Clustering methods

This tutorial gives a bit more detail about clustering methods and how to implement your own.

# Prerequisites
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams["figure.dpi"] = 300
plt.rcParams["figure.figsize"] = (12, 5)

# Include this once we have a published release to fetch test data
# from toad.utils import download_test_data
# download_test_data()
from toad import TOAD
from toad.shifts import ASDETECT

td = TOAD("test_data/garbe_2020_antarctica.nc", time_dim="GMST")
td.data = td.data.coarsen(x=3, y=3, GMST=3, boundary="trim").reduce(np.mean)
td.compute_shifts("thk", method=ASDETECT(), overwrite=True)
INFO: New shifts variable thk_dts: min/mean/max=-1.000/-0.223/0.897 using 1642 grid cells. Skipped 58.6% grid cells: 0 NaN, 2327 constant.
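
The detected shifts are stored in the dataset as thk_dts (visible in the dataset repr further below). Since td.data exposes the underlying xr.Dataset, the shifts can be inspected directly, e.g.:

# The shifts variable created above, as a plain xarray DataArray
shifts = td.data["thk_dts"]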

The td.compute_clusters function accepts clustering methods from the sklearn.cluster module. Below we run two of them, HDBSCAN and DBSCAN, on the same shifts variable:

from sklearn.cluster import DBSCAN, HDBSCAN

# HDBSCAN
td.compute_clusters(
    "thk",
    method=HDBSCAN(
        min_cluster_size=5,
    ),
    shift_threshold=0.5,
)

# DBSCAN
td.compute_clusters(
    "thk",
    method=DBSCAN(
        eps=0.1,
        min_samples=5,
        metric="euclidean",  # optional, defaults to 'euclidean'
    ),
    shift_threshold=0.5,
)
INFO: New cluster variable thk_dts_cluster: Identified 25 clusters in 1,633 pts; Left 23.1% as noise (378 pts).
INFO: New cluster variable thk_dts_cluster_1: Identified 28 clusters in 1,633 pts; Left 87.0% as noise (1,420 pts).
td

TOAD Object

Variable Hierarchy:

base var thk (1 shifts, 2 clusterings)

Hint: to access the xr.dataset call td.data

<xarray.Dataset> Size: 7MB
Dimensions:            (GMST: 116, y: 63, x: 63)
Coordinates:
  * GMST               (GMST) float64 928B 0.0701 0.1901 0.3101 ... 13.75 13.87
  * y                  (y) float64 504B -3e+06 -2.904e+06 ... 2.952e+06
  * x                  (x) float64 504B -3e+06 -2.904e+06 ... 2.952e+06
Data variables:
    thk                (GMST, y, x) float32 2MB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
    thk_dts            (GMST, y, x) float32 2MB nan nan nan nan ... nan nan nan
    thk_dts_cluster    (GMST, y, x) int32 2MB -1 -1 -1 -1 -1 ... -1 -1 -1 -1 -1
    thk_dts_cluster_1  (GMST, y, x) int32 2MB -1 -1 -1 -1 -1 ... -1 -1 -1 -1 -1
Attributes: (12/13)
    CDI:            Climate Data Interface version 1.9.6 (http://mpimet.mpg.d...
    proj4:          +lon_0=0.0 +ellps=WGS84 +datum=WGS84 +lat_ts=-71.0 +proj=...
    CDO:            Climate Data Operators version 1.9.6 (http://mpimet.mpg.d...
    source:         PISM (development v1.0-535-gb3de48787 committed by Julius...
    institution:    PIK / Potsdam Institute for Climate Impact Research
    author:         Julius Garbe (julius.garbe@pik-potsdam.de)
    ...             ...
    title:          garbe_2020_antarctica
    Conventions:    CF-1.9
    projection:     Polar Stereographic South (71S,0E)
    ice_density:    910. kg m-3
    NCO:            netCDF Operators version 4.7.8 (Homepage = http://nco.sf....
    Modifications:  Modified by Jakob Harteg (jakob.harteg@pik-potsdam.de) Se...
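
Each call to td.compute_clusters adds a new cluster variable (here thk_dts_cluster and thk_dts_cluster_1). As used further below in this tutorial, td.cluster_vars lists them and td.get_clusters retrieves one as a DataArray:

# List all cluster variables and fetch the labels of the most recent one
print(td.cluster_vars)  # ['thk_dts_cluster', 'thk_dts_cluster_1']
labels = td.get_clusters(td.cluster_vars[-1])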

Defining your own clustering method

You can also define your own clustering method by extending the sklearn.base.ClusterMixin and sklearn.base.BaseEstimator classes:

from sklearn.base import ClusterMixin, BaseEstimator


# Your custom clustering class must inherit from ClusterMixin and BaseEstimator
class MyClusterer(ClusterMixin, BaseEstimator):
    # Pass params to your method here
    def __init__(self, my_param):
        self.my_param = my_param

    # Required method that performs the clustering
    def fit_predict(self, X: np.ndarray, y=None, **kwargs):
        # X = point coordinates, one row per grid cell: (time, x, y) or (time, x, y, z)
        # y = weights / detection signal

        cluster_labels_array = ...  # your clustering algorithm here
        # Must return one integer label per row of X (-1 marks noise)
        return cluster_labels_array


# Then apply it with TOAD
# td.compute_clusters('thk',
#     method=MyClusterer(
#         my_param=(1, 2.0, 2.0), # time, x, y thresholds
#     ),
#     shift_threshold=0.8,
#     overwrite=True,
# )

A full, working example of a custom clustering method:

from sklearn.base import BaseEstimator, ClusterMixin


class ExampleClusterer(ClusterMixin, BaseEstimator):
    # Hyperparameters are set in __init__, per sklearn convention
    def __init__(self, my_param=(0.5, 1.0, 1.0)):
        self.my_param = my_param

    # required method
    def fit_predict(self, X: np.ndarray, y=None, **kwargs):
        # X = coords (time, x, y, z)
        # y = weights / detection signal

        # Perform extremely crude clustering: assign each point to the first
        # cluster whose seed point lies within my_param along every axis,
        # otherwise start a new cluster seeded at this point
        clusters = []
        cluster_labels_array = []
        for point in X:
            for i, centroid in enumerate(clusters):
                if all(abs(point - centroid) <= self.my_param):
                    break
            else:  # no existing cluster matched: open a new one
                clusters.append(point)
                i = len(clusters) - 1
            cluster_labels_array.append(i)

        return np.asarray(cluster_labels_array)
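
Before handing the clusterer to TOAD, it can be sanity-checked standalone on a few synthetic (time, x, y) points; the expected labels follow directly from the threshold logic:

# Minimal standalone check (synthetic coordinates, not TOAD data)
pts = np.array(
    [
        [0.0, 0.0, 0.0],  # seeds cluster 0
        [0.1, 0.5, 0.5],  # within (0.5, 1.0, 1.0) of the seed -> cluster 0
        [5.0, 9.0, 9.0],  # outside all thresholds -> seeds cluster 1
    ]
)
print(ExampleClusterer().fit_predict(pts))  # [0 0 1]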


td.compute_clusters(
    "thk",
    method=ExampleClusterer(
        my_param=(0.5, 2.0, 2.0),  # time, x, y thresholds
    ),
    shift_threshold=0.9,
    # overwrite=True,
)
INFO: New cluster variable thk_dts_cluster_2: Identified 17 clusters in 1,103 pts; Left 0.0% as noise (0 pts).
td

TOAD Object

Variable Hierarchy:

base var thk (1 shifts, 3 clusterings)

Hint: to access the xr.dataset call td.data

<xarray.Dataset> Size: 9MB
Dimensions:            (GMST: 116, y: 63, x: 63)
Coordinates:
  * GMST               (GMST) float64 928B 0.0701 0.1901 0.3101 ... 13.75 13.87
  * y                  (y) float64 504B -3e+06 -2.904e+06 ... 2.952e+06
  * x                  (x) float64 504B -3e+06 -2.904e+06 ... 2.952e+06
Data variables:
    thk                (GMST, y, x) float32 2MB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
    thk_dts            (GMST, y, x) float32 2MB nan nan nan nan ... nan nan nan
    thk_dts_cluster    (GMST, y, x) int32 2MB -1 -1 -1 -1 -1 ... -1 -1 -1 -1 -1
    thk_dts_cluster_1  (GMST, y, x) int32 2MB -1 -1 -1 -1 -1 ... -1 -1 -1 -1 -1
    thk_dts_cluster_2  (GMST, y, x) int32 2MB -1 -1 -1 -1 -1 ... -1 -1 -1 -1 -1
Attributes: (12/13)
    CDI:            Climate Data Interface version 1.9.6 (http://mpimet.mpg.d...
    proj4:          +lon_0=0.0 +ellps=WGS84 +datum=WGS84 +lat_ts=-71.0 +proj=...
    CDO:            Climate Data Operators version 1.9.6 (http://mpimet.mpg.d...
    source:         PISM (development v1.0-535-gb3de48787 committed by Julius...
    institution:    PIK / Potsdam Institute for Climate Impact Research
    author:         Julius Garbe (julius.garbe@pik-potsdam.de)
    ...             ...
    title:          garbe_2020_antarctica
    Conventions:    CF-1.9
    projection:     Polar Stereographic South (71S,0E)
    ice_density:    910. kg m-3
    NCO:            netCDF Operators version 4.7.8 (Homepage = http://nco.sf....
    Modifications:  Modified by Jakob Harteg (jakob.harteg@pik-potsdam.de) Se...

We can inspect all method parameters in the attributes of the resulting cluster variable:

# get attributes of last cluster variable
td.get_clusters(td.cluster_vars[-1]).attrs
{'standard_name': 'land_ice_thickness',
 'long_name': 'land ice thickness',
 'units': 'm',
 'pism_intent': 'model_state',
 'time_dim': 'GMST',
 'method_name': 'ExampleClusterer',
 'toad_version': '0.3',
 'base_variable': 'thk',
 'variable_type': 'cluster',
 'method_ignore_nan_warnings': 'False',
 'method_lmin': '5',
 'method_segmentation': 'two_sided',
 'cluster_ids': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16]),
 'shift_threshold': 0.9,
 'shift_selection': 'local',
 'shift_direction': 'both',
 'scaler': 'StandardScaler',
 'time_weight': 1,
 'n_data_points': 1103,
 'runtime_preprocessing': 0.020125150680541992,
 'runtime_clustering': 0.015676021575927734,
 'runtime_total': 0.03580117225646973,
 'shifts_variable': 'thk_dts',
 'method_my_param': '(0.5, 2.0, 2.0)'}
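
Note that method parameters are serialized as strings in the attributes. To read one back, index the dictionary shown above:

# Read individual parameters back from the attributes
attrs = td.get_clusters(td.cluster_vars[-1]).attrs
print(attrs["method_my_param"])  # '(0.5, 2.0, 2.0)', stored as a string
print(attrs["shift_threshold"])  # 0.9
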
td.plot.overview(td.cluster_vars[-1], map_style={"projection": "south_pole"});
[Figure: overview plot of the last cluster variable on a south-polar stereographic map]