# Source code for phenotypic.detect._round_peaks_detector

from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from phenotypic import Image

import gc
from typing import Literal

import numpy as np
import scipy.ndimage as ndimage
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d

from phenotypic.abc_ import ObjectDetector
import skimage.filters as filters
import skimage.morphology as morphology


[docs] class RoundPeaksDetector(ObjectDetector): """ Class for detecting circular colonies in gridded plate images using the gitter algorithm. The RoundPeaksDetector implements an improved Python version of the gitter colony detection algorithm originally developed for R. This method is specifically designed for quantifying pinned microbial cultures arranged in a regular grid pattern on agar plates. The algorithm works by: 1. Thresholding the image to create a binary mask of colonies 2. Analyzing row and column intensity profiles to detect periodic peaks 3. Estimating grid edges based on peak positions 4. Assigning pixels to grid cells and identifying dominant colonies This approach is robust to irregular colonies, noise, variable illumination, and other common plate imaging artifacts. Note: For best results, use preprocessing such as `GaussianBlur` or other enhancement techniques before detection. The detector works best with images where colonies are clearly visible against the background. This detector works best for yeast-like growth where the colonies are circular and less likely to work on filamentous fungi. Warning: Grid inference from the binary mask alone (when not using GridImage) may be less accurate than providing explicit grid information. For optimal results, use with GridImage when grid parameters are known. Attributes: thresh_method (str): Thresholding method to use for binary mask creation. Options: 'otsu', 'mean', 'local', 'triangle', 'minimum', 'isodata'. Default is 'otsu'. subtract_background (bool): Whether to apply white tophat background subtraction before thresholding. Helps with uneven illumination. remove_noise (bool): Whether to apply binary opening to remove small noise artifacts after thresholding. footprint_radius (int): Radius for morphological operations (noise removal and background subtraction kernels). smoothing_sigma (float): Standard deviation for Gaussian smoothing of row/column sums before peak detection. 
Higher values increase robustness to noise but may merge nearby peaks. Set to 0 to disable. min_peak_distance (int | None): Minimum distance between peaks in pixels. If None, automatically estimated from grid dimensions. Prevents detection of spurious peaks too close together. peak_prominence (float | None): Minimum prominence of peaks for detection. If None, automatically estimated from signal statistics. Higher values are more selective. edge_refinement (bool): Whether to refine grid edges using local intensity profiles. Improves accuracy but adds computational cost. References: Wagih, O. and Parts, L. (2014). gitter: a robust and accurate method for quantification of colony sizes from plate images. G3 (Bethesda), 4(3), 547-552. https://omarwagih.github.io/gitter/ """
[docs] def __init__( self, thresh_method: Literal[ "otsu", "mean", "local", "triangle", "minimum", "isodata" ] = "otsu", subtract_background: bool = True, remove_noise: bool = True, footprint_radius: int = 3, smoothing_sigma: float = 2.0, min_peak_distance: int | None = None, peak_prominence: float | None = None, edge_refinement: bool = True, ): """ Initialize the RoundPeaksDetector with specified parameters. Args: thresh_method: Method for thresholding the image. Options are: 'otsu' (default), 'mean', 'local', 'triangle', 'minimum', 'isodata'. subtract_background: If True, apply white tophat transform to remove background variations before thresholding. remove_noise: If True, apply morphological opening to remove small noise artifacts from the binary mask. footprint_radius: Radius in pixels for morphological operations. Larger values remove larger noise but may erode colony edges. smoothing_sigma: Standard deviation for Gaussian smoothing of intensity profiles before peak detection. Set to 0 to disable smoothing. min_peak_distance: Minimum allowed distance between detected peaks. If None, automatically estimated from grid dimensions. peak_prominence: Minimum prominence required for peak detection. If None, automatically calculated as 0.1 * signal range. edge_refinement: If True, refine grid edges using weighted intensity profiles for improved accuracy. """ super().__init__() self.thresh_method = thresh_method self.subtract_background = subtract_background self.footprint_radius = footprint_radius self.remove_noise = remove_noise self.smoothing_sigma = smoothing_sigma self.min_peak_distance = min_peak_distance self.peak_prominence = peak_prominence self.edge_refinement = edge_refinement
def _operate(self, image: Image) -> Image: """ Detect colonies in the image using the gitter algorithm. This method performs the core detection workflow: 1. Threshold the enhanced grayscale image 2. Remove noise if requested 3. Label connected components 4. Determine or estimate grid edges 5. Assign dominant colonies to grid cells 6. Create final object map Args: image: Image object to process. Can be a regular Image or GridImage. Returns: Image: The processed image with updated objmask and objmap. """ from phenotypic import GridImage enh_matrix = image.enh_gray[:] self._log_memory_usage("getting enhanced gray") objmask = self._thresholding(enh_matrix) self._log_memory_usage("after thresholding") if self.remove_noise: objmask = morphology.binary_opening( objmask, morphology.diamond(radius=self.footprint_radius) ) self._log_memory_usage("after noise removal") # Keep a copy of the mask we intend to use for downstream measurements image.objmask[:] = objmask labeled, num_features = ndimage.label( objmask, structure=ndimage.generate_binary_structure(2, 2) ) self._log_memory_usage(f"after labeling ({num_features} features)") # Determine grid edges either from GridImage or by estimating from the binary mask if isinstance(image, GridImage): row_edges = np.round(image.grid.get_row_edges()).astype(int) col_edges = np.round(image.grid.get_col_edges()).astype(int) nrows, ncols = image.nrows, image.ncols else: nrows = ncols = None row_edges = col_edges = None if row_edges is None or col_edges is None: # Estimate edges using peak finding on row/col sums nrows, ncols = self._infer_grid_shape(objmask) self._log_memory_usage(f"inferred grid shape: {nrows}x{ncols}") row_edges = self._estimate_edges(objmask, axis=0, n_bins=nrows) col_edges = self._estimate_edges(objmask, axis=1, n_bins=ncols) self._log_memory_usage("after edge estimation") # Refine edges if requested if self.edge_refinement: row_edges = self._refine_edges(objmask, row_edges, axis=0) col_edges = 
self._refine_edges(objmask, col_edges, axis=1) self._log_memory_usage("after edge refinement") row_edges = np.clip(np.unique(row_edges), 0, objmask.shape[0]) col_edges = np.clip(np.unique(col_edges), 0, objmask.shape[1]) objmap = np.zeros_like(labeled, dtype=image._OBJMAP_DTYPE) label_counter = 1 # Assign dominant colonies to each grid cell for r in range(len(row_edges) - 1): r0, r1 = row_edges[r], row_edges[r + 1] for c in range(len(col_edges) - 1): c0, c1 = col_edges[c], col_edges[c + 1] region = labeled[r0:r1, c0:c1] if region.size == 0: continue uniq, counts = np.unique(region, return_counts=True) valid = uniq != 0 uniq = uniq[valid] counts = counts[valid] if uniq.size == 0: continue dominant_label = uniq[np.argmax(counts)] mask = region == dominant_label if np.any(mask): objmap[r0:r1, c0:c1][mask] = label_counter label_counter += 1 # Fallback if no regions were labeled (e.g., grid inference failed) if label_counter == 1: objmap = labeled.astype(image._OBJMAP_DTYPE, copy=False) self._log_memory_usage("after grid cell assignment") image.objmap[:] = objmap image.objmap.relabel(connectivity=1) gc.collect() # Force garbage collection self._log_memory_usage( "final cleanup", include_process=True, include_tracemalloc=True ) return image def _thresholding(self, matrix: np.ndarray) -> np.ndarray: """ Threshold the image to create a binary mask of foreground colonies. This method applies optional background subtraction followed by one of several thresholding algorithms to separate colonies from background. Args: matrix: 2D enhanced grayscale array with pixel intensities. Returns: np.ndarray: Binary mask where True/1 indicates colony pixels, False/0 indicates background. Raises: ValueError: If an invalid thresholding method is specified. 
""" kernel = morphology.footprint_rectangle( (self.footprint_radius * 2, self.footprint_radius * 2) ) enh_matrix = matrix.copy() # Work on a copy to avoid modifying input # Subtract background using white tophat to handle uneven illumination if self.subtract_background: tophat_res = morphology.white_tophat(enh_matrix, kernel) enh_matrix = enh_matrix - tophat_res # Apply selected thresholding method match self.thresh_method: case "otsu": thresh = filters.threshold_otsu(enh_matrix) case "mean": thresh = filters.threshold_mean(enh_matrix) case "local": block_size = max( self.footprint_radius * 2 + 1, 3 ) # Ensure odd block size thresh = filters.threshold_local(enh_matrix, block_size=block_size) case "triangle": thresh = filters.threshold_triangle(enh_matrix) case "minimum": thresh = filters.threshold_minimum(enh_matrix) case "isodata": thresh = filters.threshold_isodata(enh_matrix) case _: # Default to Otsu if method not recognized thresh = filters.threshold_otsu(enh_matrix) return enh_matrix >= thresh def _clean_and_sum_binary( self, binary_image: np.ndarray, p: float = 0.2, axis: int = 0 ) -> np.ndarray: """ Compute projection sums while removing problematic edge artifacts. This method identifies rows (axis=0) or columns (axis=1) near image edges that contain abnormally long stretches of foreground pixels (likely artifacts or plate edges) and excludes them from the sum to avoid spurious peaks. Args: binary_image: Binary mask of detected colonies. p: Proportion of image dimension to use as threshold for detecting problematic long runs (default: 0.2 = 20%). axis: Direction to sum along following numpy convention. - axis=0: Sum along rows (collapse rows → column sums for row edge detection) - axis=1: Sum along columns (collapse columns → row sums for column edge detection) Returns: np.ndarray: 1D array of cleaned sums along the specified axis. Problematic edge regions are set to 0. 
Note: This cleaning step helps avoid detecting false peaks from plate edges or imaging artifacts that span large portions of rows/columns. """ # Calculate threshold based on image dimensions # For axis=0: we're summing columns, so check for long runs across columns # For axis=1: we're summing rows, so check for long runs across rows if axis == 0: c = p * binary_image.shape[1] # Threshold based on number of columns n_slices = binary_image.shape[0] # Number of rows to iterate through else: c = p * binary_image.shape[0] # Threshold based on number of rows n_slices = binary_image.shape[1] # Number of columns to iterate through # Identify problematic rows/columns with long stretches of 1s problematic = np.zeros(n_slices, dtype=bool) for i in range(n_slices): if axis == 0: slice_data = binary_image[i, :] # Get row i else: slice_data = binary_image[:, i] # Get column i # Run-length encoding to find stretches of 1s diff = np.diff(np.concatenate(([0], slice_data.astype(int), [0]))) starts = np.where(diff == 1)[0] ends = np.where(diff == -1)[0] lengths = ends - starts # Check if any stretch of 1s is longer than threshold if len(lengths) > 0 and np.any(lengths > c): problematic[i] = True # Compute sums along the specified axis sums = np.sum(binary_image, axis=axis, dtype=np.float64) # Split problematic array in half and zero out problematic regions at edges mid = len(problematic) // 2 left_prob = problematic[:mid] right_prob = problematic[mid:] # Zero out sums for problematic regions at edges if np.any(left_prob): last_prob = np.where(left_prob)[0][-1] sums[: last_prob + 1] = 0 if np.any(right_prob): first_prob = np.where(right_prob)[0][0] + mid sums[first_prob:] = 0 return sums def _estimate_edges( self, binary_image: np.ndarray, axis: int, n_bins: int ) -> np.ndarray: """ Estimate grid edges by detecting periodic peaks in row/column intensity sums. This method implements the core of the gitter algorithm by analyzing the projection of colonies onto rows or columns. 
It detects peaks corresponding to colony centers and derives grid edges between them. Args: binary_image: Binary mask of detected colonies. axis: Direction for edge detection (0 for row edges, 1 for column edges). n_bins: Expected number of grid bins (rows or columns). Returns: np.ndarray: Array of edge positions including image borders. Length is n_bins + 1. Note: The method applies smoothing to the intensity profile before peak detection to improve robustness. If automatic peak detection fails to find enough peaks, it falls back to evenly-spaced bins. """ # Get cleaned sums along the specified axis sums = self._clean_and_sum_binary(binary_image, axis=axis) # Apply Gaussian smoothing if requested to reduce noise if self.smoothing_sigma > 0: sums = gaussian_filter1d(sums, sigma=self.smoothing_sigma) # Calculate expected spacing between colonies image_size = binary_image.shape[1 - axis] # Size along the summed dimension expected_spacing = max(image_size // max(n_bins, 1), 1) # Determine peak detection parameters min_distance = ( self.min_peak_distance if self.min_peak_distance is not None else max(expected_spacing // 2, 1) ) # Calculate prominence if not provided if self.peak_prominence is not None: prominence = self.peak_prominence else: # noinspection PyUnresolvedReferences signal_range = np.max(sums) - np.min(sums) prominence = 0.1 * signal_range if signal_range > 0 else None # Detect peaks with prominence and distance constraints peaks, properties = find_peaks( sums, distance=min_distance, prominence=prominence ) if peaks.size < n_bins: # Fallback: enforce evenly spaced peaks if auto detection under-fits peaks = np.linspace( start=expected_spacing // 2, stop=image_size - expected_spacing // 2, num=n_bins, dtype=int, ) elif peaks.size > n_bins: # Keep the strongest n_bins peaks by height peak_heights = sums[peaks] top_indices = np.argsort(peak_heights)[-n_bins:] peaks = np.sort(peaks[top_indices]) # Derive edges midway between peaks if len(peaks) > 1: # Calculate 
midpoints between consecutive peaks midpoints = ((peaks[:-1] + peaks[1:]) / 2).astype(int) # Prepend/append image borders edges = np.concatenate(([0], midpoints, [image_size])) else: # Fallback for single or no peaks: evenly divide the space edges = np.linspace(0, image_size, n_bins + 1, dtype=int) # Ensure we have exactly n_bins + 1 edges if edges.size > n_bins + 1: edges = edges[: n_bins + 1] elif edges.size < n_bins + 1: missing = (n_bins + 1) - edges.size edges = np.concatenate((edges, np.full(missing, image_size))) return edges.astype(int) def _refine_edges( self, binary_image: np.ndarray, edges: np.ndarray, axis: int ) -> np.ndarray: """ Refine grid edges using local intensity profiles for improved accuracy. This method adjusts edge positions by analyzing the intensity distribution near each initial edge estimate. It shifts edges to positions of minimum intensity (background) between colonies. Args: binary_image: Binary mask of detected colonies. edges: Initial edge estimates from peak detection. axis: Direction of edges (0 for row edges, 1 for column edges). Returns: np.ndarray: Refined edge positions. Note: This refinement step can significantly improve accuracy by placing edges in the valleys between colonies rather than at fixed positions. 
""" refined_edges = edges.copy() sums = np.sum(binary_image, axis=axis, dtype=np.float64) # Refine each internal edge (not the borders) for i in range(1, len(edges) - 1): edge_pos = edges[i] # Define search window around current edge search_radius = min(10, (edges[i + 1] - edges[i - 1]) // 4) search_start = max(0, edge_pos - search_radius) search_end = min(len(sums), edge_pos + search_radius + 1) # Find position of minimum intensity in search window search_window = sums[search_start:search_end] if len(search_window) > 0: local_min_idx = np.argmin(search_window) refined_edges[i] = search_start + local_min_idx return refined_edges.astype(int) def _infer_grid_shape(self, binary_image: np.ndarray) -> tuple[int, int]: """ Infer grid dimensions from the binary mask when not explicitly provided. This method estimates the number of rows and columns in the grid by counting connected components and assuming a roughly rectangular layout. Common plate formats (96-well, 384-well) are used as fallbacks. Args: binary_image: Binary mask of detected colonies. Returns: tuple[int, int]: Estimated (n_rows, n_cols) for the grid. Note: This is a best-effort estimate. For accurate results, provide grid dimensions explicitly using GridImage. """ labeled, num = ndimage.label(binary_image) if num == 0: # Default to 96-well plate format (8x12) return 8, 12 # Estimate based on aspect ratio and colony count aspect_ratio = binary_image.shape[1] / binary_image.shape[0] if aspect_ratio > 1.3: # Wide plate (likely 8x12 or similar) # Try 8x12 (96 wells), 16x24 (384 wells), etc. if num <= 100: return 8, 12 elif num <= 400: return 16, 24 else: approx_rows = int(np.ceil(np.sqrt(num / aspect_ratio))) approx_cols = int(np.ceil(np.sqrt(num * aspect_ratio))) return approx_rows, approx_cols else: # Square-ish layout approx_side = int(np.ceil(np.sqrt(num))) return approx_side, max(approx_side, 1)