Source code for bblean.merges

r"""Merging criteria for BitBIRCH clustering

The functionality in this module is advanced and not needed for normal usage of the
library. Make sure you understand the features here before applying them.
"""

import re
import typing as tp
import numpy as np
from numpy.typing import NDArray

# NOTE: jt_isim_from_sum is equivalent to jt_isim_diameter_compl_from_sum
from bblean.similarity import jt_isim_from_sum, jt_isim_radius_compl_from_sum

# Names of the merge criteria that ship with the library (going by the variable name;
# nothing in the visible part of this module reads the list).
# NOTE(review): "tolerance" is listed here, whereas `_get_merge_accept_fn` in this
# module dispatches on "tolerance-legacy", which is absent from this list -- confirm
# which spelling callers are expected to pass.
_BUILTIN_MERGES = [
    "radius",
    "diameter",
    "flexible-tolerance-diameter",
    "tolerance-diameter",
    "tolerance-radius",
    "tolerance",
    "never",
]


class DiscardSubcluster(Exception):
    r"""If raised in hooks, immediately exit the merge, discarding the subcluster

    The subcluster that is discarded is the incident one. Discarded subclusters will
    not be stored in the final tree, and will only show up if calling
    `bblean.BitBirch.get_assigments` (or the `labels_` attribute if using
    `bblean.sklearn`) with a cluster label of 0.
    """
class RejectMerge(Exception):
    r"""If raised in hooks, immediately exit the merge and reject it"""
class MergeAcceptFunction:
    r"""Base class for user defined merges

    If you want to implement a custom BitBirch merge you can subclass this, and pass
    an instance of the subclass to a `bblean.BitBirch` class upon creation as
    ``BitBirch(..., merge_criterion=instance)``.

    .. warning::

        This is an advanced feature, make sure you fully understand what you are
        doing!
    """

    # For the merge functions, although outputs of jt_isim_from_sum f64, directly using
    # f64 is *not* faster than starting with uint64
    def __call__(
        self,
        thresh: float,
        new_ls: NDArray[np.integer],
        new_n: int,
        old_ls: NDArray[np.integer],
        nom_ls: NDArray[np.integer],
        old_n: int,
        nom_n: int,
        old_idxs: tp.Sequence[int],
        nom_idxs: tp.Sequence[int],
    ) -> bool:
        r"""Run the start hook, the merge check and the end hook, in that order

        Returns the value produced by `MergeAcceptFunction.check_merge`, or ``False``
        if `RejectMerge` is raised by any of the three calls. Any other exception
        (`DiscardSubcluster` included) propagates to the caller.
        """
        try:
            # The start hook may replace the threshold used for this check only
            thresh = self.on_check_merge_start(
                thresh, new_ls, new_n, old_ls, nom_ls, old_n, nom_n, old_idxs, nom_idxs
            )
            accepted = self.check_merge(
                thresh, new_ls, new_n, old_ls, nom_ls, old_n, nom_n, old_idxs, nom_idxs
            )
            self.on_check_merge_end(accepted, old_idxs, nom_idxs)
        except RejectMerge:
            return False
        return accepted

    def on_check_merge_start(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int],
        nominee_idxs: tp.Sequence[int],
    ) -> float:
        r"""Hook called before a merge is checked (meant to be overridden)

        See `MergeAcceptFunction.check_merge` for an explanation of the different args

        .. warning::

            Numpy arrays passed to this function may use uint types, watch out for
            pitfalls of unsigned integer arithmetic.

        This function must return a threshold. Return the received value unchanged to
        leave the check unaffected; if a different value is returned, it will be used
        for this specific merge check.
        """
        return threshold

    def check_merge(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int],
        nominee_idxs: tp.Sequence[int],
    ) -> bool:
        r"""Check if a merge should be accepted.

        All user-defined merges should override this method.

        threshold: Threshold for the merge
        new_sum: old_sum + nominee_sum
        new_n: old_n + nominee_n
        old_sum: col-wise sum of all fingerprints in this cluster
        nominee_sum: col-wise sum of all fingerprints in the nominee cluster
        old_n: size of this cluster
        nominee_n: size of the nominee cluster
        old_idxs: Indices of the mols in cluster before the nominee cluster is merged
        nominee_idxs: Nominee indices to merge. If merging a single molecule (the most
            common case, when calling `bblean.BitBirch.fit`), ``nominee_idxs`` will be
            a list with a *single index*, ``nominee_n = 1``, and ``nominee_sum`` will
            be the molecule fingerprint.

        This function must return a boolean that determines whether the merge was
        accepted

        .. warning::

            Numpy arrays passed to this function may use uint types, watch out for
            pitfalls of unsigned integer arithmetic.
        """
        raise NotImplementedError

    def on_check_merge_end(
        self,
        accepted: bool,
        old_idxs: tp.Sequence[int],
        nominee_idxs: tp.Sequence[int],
    ) -> None:
        r"""Hook called after a merge is checked (meant to be overridden)

        accepted: Whether the merge was accepted
        old_idxs: Indices of the mols in cluster before the nominee cluster is merged
        nominee_idxs: Nominee indices to merge. If merging a single molecule (the most
            common case, when calling `bblean.BitBirch.fit`), ``nominee_idxs`` will be
            a list with a *single index*.

        This function must not return a value
        """

    @property
    def name(self) -> str:
        r"""Dash-separated, lower-case class name without the ``Merge`` word

        For example ``ToleranceDiameterMerge`` gives ``"tolerance-diameter"``.
        """
        # Splitting before every capital letter leaves a first chunk that is empty for
        # 'PascalCase' names, or made of underscores for private '_PascalCase' names.
        # Only those are dropped, so a name such as 'lowerCamelMerge' keeps its first
        # word instead of silently losing it.
        head, *tail = re.split(r"(?=[A-Z])", self.__class__.__name__)
        words = [head, *tail] if head.strip("_") else tail
        return "-".join(w.lower() for w in words if w != "Merge")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
class RadiusMerge(MergeAcceptFunction):
    r"""Accept a merge when `jt_isim_radius_compl_from_sum`, evaluated on the merged
    cluster, is at least the threshold
    """

    def check_merge(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int] | None = None,
        nominee_idxs: tp.Sequence[int] | None = None,
    ) -> bool:
        r""":meta private:"""
        # Only the merged cluster matters, the remaining arguments are unused
        merged_score = jt_isim_radius_compl_from_sum(new_sum, new_n)
        return merged_score >= threshold
class DiameterMerge(MergeAcceptFunction):
    r"""Accept a merge when `jt_isim_from_sum`, evaluated on the merged cluster, is at
    least the threshold
    """

    def check_merge(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int] | None = None,
        nominee_idxs: tp.Sequence[int] | None = None,
    ) -> bool:
        r""":meta private:"""
        # Only the merged cluster matters, the remaining arguments are unused
        merged_score = jt_isim_from_sum(new_sum, new_n)
        return merged_score >= threshold
class FlexibleToleranceDiameterMerge(MergeAcceptFunction):
    # NOTE: Equivalent to tolerance-diameter but uses min(old_dc, threshold) as the
    # criteria
    def __init__(
        self,
        tolerance: float = 0.05,
        n_max: int = 1000,
        decay: float = 1e-3,
        adaptive: bool = True,
    ) -> None:
        # The offset is always computed from the arguments, and then both the decay
        # and the offset are zeroed if the tolerance is not adaptive, which makes the
        # tolerance independent of the cluster size
        offset = np.exp(-decay * n_max)
        if not adaptive:
            decay, offset = 0.0, 0.0
        self.tolerance = tolerance
        self.decay = decay
        self.offset = offset

    def check_merge(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int] | None = None,
        nominee_idxs: tp.Sequence[int] | None = None,
    ) -> bool:
        r""":meta private:"""
        merged_dc = jt_isim_from_sum(new_sum, new_n)
        # Below the threshold the merge is never accepted
        if merged_dc < threshold:
            return False
        # A cluster with a single fingerprint is merged without further checks
        if old_n == 1:
            return True
        previous_dc = jt_isim_from_sum(old_sum, old_n)
        # tolerance = max{ alpha * (exp(-decay * N_old) - offset), 0}
        size_factor = np.exp(-self.decay * old_n) - self.offset
        slack = max(self.tolerance * size_factor, 0.0)
        # NOTE(review): for finite values this comparison looks always true at this
        # point, since min(previous_dc, threshold) - slack <= threshold <= merged_dc
        # after the guard above. Confirm whether that is the intended criterion.
        return merged_dc >= min(previous_dc, threshold) - slack

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tolerance})"
class ToleranceDiameterMerge(MergeAcceptFunction):
    # NOTE: The reliability of the estimate of the cluster should be a function of the
    # size of the old cluster, so in this metric, tolerance is larger for small clusters
    # tolerance = max{ alpha * (exp(-decay * N_old) - offset), 0}
    def __init__(
        self,
        tolerance: float = 0.05,
        n_max: int = 1000,
        decay: float = 1e-3,
        adaptive: bool = True,
    ) -> None:
        # The offset is always computed from the arguments, and then both the decay
        # and the offset are zeroed if the tolerance is not adaptive, which makes the
        # tolerance independent of the cluster size
        offset = np.exp(-decay * n_max)
        if not adaptive:
            decay, offset = 0.0, 0.0
        self.tolerance = tolerance
        self.decay = decay
        self.offset = offset

    def check_merge(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int] | None = None,
        nominee_idxs: tp.Sequence[int] | None = None,
    ) -> bool:
        r""":meta private:"""
        merged_dc = jt_isim_from_sum(new_sum, new_n)
        if merged_dc < threshold:
            return False
        # With a single fp in the old cluster the old_dc is undefined, so the merge is
        # accepted directly (infinite tolerance)
        if old_n == 1:
            return True
        # Otherwise the merged value must not fall below the old one by more than a
        # tolerance, which decays with the size of the old cluster
        previous_dc = jt_isim_from_sum(old_sum, old_n)
        size_factor = np.exp(-self.decay * old_n) - self.offset
        slack = max(self.tolerance * size_factor, 0.0)
        return merged_dc >= previous_dc - slack

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tolerance})"
class ToleranceRadiusMerge(ToleranceDiameterMerge):
    # Same steps as the check in the parent class, but every value is obtained from
    # jt_isim_radius_compl_from_sum instead of jt_isim_from_sum
    def check_merge(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int] | None = None,
        nominee_idxs: tp.Sequence[int] | None = None,
    ) -> bool:
        r""":meta private:"""
        merged_rc = jt_isim_radius_compl_from_sum(new_sum, new_n)
        if merged_rc < threshold:
            return False
        # A cluster with a single fingerprint is merged without further checks
        if old_n == 1:
            return True
        previous_rc = jt_isim_radius_compl_from_sum(old_sum, old_n)
        # tolerance = max{ alpha * (exp(-decay * N_old) - offset), 0}
        size_factor = np.exp(-self.decay * old_n) - self.offset
        slack = max(self.tolerance * size_factor, 0.0)
        return merged_rc >= previous_rc - slack

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tolerance})"
class NeverMerge(ToleranceDiameterMerge):
    r""":meta private:"""

    # Every proposed merge is rejected, no matter the arguments
    def check_merge(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int] | None = None,
        nominee_idxs: tp.Sequence[int] | None = None,
    ) -> bool:
        return False

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class ToleranceLegacyMerge(MergeAcceptFunction):
    r""":meta private:"""

    def __init__(self, tolerance: float = 0.05) -> None:
        self.tolerance = tolerance

    def check_merge(
        self,
        threshold: float,
        new_sum: NDArray[np.integer],
        new_n: int,
        old_sum: NDArray[np.integer],
        nominee_sum: NDArray[np.integer],
        old_n: int,
        nominee_n: int,
        old_idxs: tp.Sequence[int] | None = None,
        nominee_idxs: tp.Sequence[int] | None = None,
    ) -> bool:
        # First two branches are equivalent to 'diameter'
        new_dc = jt_isim_from_sum(new_sum, new_n)
        if new_dc < threshold:
            return False
        # The tolerance is only applied when a single fingerprint is merged into a
        # cluster that has more than one
        if old_n == 1 or nominee_n != 1:
            return True
        # 'new_dc >= threshold' and 'new_n == old_n + 1' are guaranteed here
        old_dc = jt_isim_from_sum(old_sum, old_n)
        return (new_dc * new_n - old_dc * (old_n - 1)) / 2 >= old_dc - self.tolerance

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tolerance})"


# Make these ones leaner with less calls so they don't take up any extra time at all in
# the default case
class _FastMerge(MergeAcceptFunction):
    def __call__(
        self,
        thresh: float,
        new_ls: NDArray[np.integer],
        new_n: int,
        old_ls: NDArray[np.integer],
        nom_ls: NDArray[np.integer],
        old_n: int,
        nom_n: int,
        old_idxs: tp.Sequence[int],
        nom_idxs: tp.Sequence[int],
    ) -> bool:
        # Unlike the base class, neither of the hooks is called and RejectMerge is
        # not caught here
        return self.check_merge(
            thresh, new_ls, new_n, old_ls, nom_ls, old_n, nom_n, old_idxs, nom_idxs
        )

    @property
    def name(self) -> str:
        # Same as the base class name, with the 'Fast' word removed too, so that
        # e.g. '_FastRadiusMerge' gives 'radius'
        return "-".join(
            s.lower()
            for s in re.split(r"(?=[A-Z])", self.__class__.__name__)[1:]
            if s not in ["Merge", "Fast"]
        )


class _FastDiameterMerge(_FastMerge, DiameterMerge):
    pass


class _FastRadiusMerge(_FastMerge, RadiusMerge):
    pass


class _FastToleranceDiameterMerge(_FastMerge, ToleranceDiameterMerge):
    pass


class _FastFlexibleToleranceDiameterMerge(_FastMerge, FlexibleToleranceDiameterMerge):
    pass


class _FastToleranceRadiusMerge(_FastMerge, ToleranceRadiusMerge):
    pass


class _FastToleranceLegacyMerge(_FastMerge, ToleranceLegacyMerge):
    pass


class _FastNeverMerge(_FastMerge, NeverMerge):
    pass


def _get_merge_accept_fn(
    merge_criterion: str, tolerance: float = 0.05
) -> MergeAcceptFunction:
    r"""Build the hook-free merge function registered under ``merge_criterion``

    merge_criterion: Name of the criterion. The accepted names are the keys of the
        two tables below.
    tolerance: Passed as the only argument to the tolerance-based classes, and
        ignored for "radius" and "diameter".

    Raises a ``ValueError`` that lists all the accepted names if the criterion is
    not one of them.
    """
    # Criteria that are built without arguments
    without_tolerance: dict[str, tp.Callable[[], MergeAcceptFunction]] = {
        "radius": _FastRadiusMerge,
        "diameter": _FastDiameterMerge,
    }
    # Criteria that are built with the tolerance as their only argument
    with_tolerance: dict[str, tp.Callable[[float], MergeAcceptFunction]] = {
        "tolerance-legacy": _FastToleranceLegacyMerge,
        "tolerance-diameter": _FastToleranceDiameterMerge,
        "flexible-tolerance-diameter": _FastFlexibleToleranceDiameterMerge,
        "tolerance-radius": _FastToleranceRadiusMerge,
        "never": _FastNeverMerge,
    }
    # Non-str input skips the lookups, so it ends in the ValueError below instead
    # of failing with a TypeError when hashed
    if isinstance(merge_criterion, str):
        if merge_criterion in without_tolerance:
            return without_tolerance[merge_criterion]()
        if merge_criterion in with_tolerance:
            return with_tolerance[merge_criterion](tolerance)
    # The names are taken from the tables, so the message can't get out of date
    valid = "|".join([*without_tolerance, *with_tolerance])
    raise ValueError(
        f"Unknown merge criterion {merge_criterion!r}. Valid criteria are: {valid}"
    )