from __future__ import annotations
import numpy as np
import pandas as pd
from joblib import delayed, Parallel
from scipy.stats import permutation_test
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from phenotypic.tools.constants_ import MeasurementInfo
from .abc_ import SetAnalyzer
class EDGE_CORRECTION(MeasurementInfo):
@classmethod
def category(cls) -> str:
return "EdgeCorrection"
CORRECTED_CAP = "CorrectedCap", "The carrying capacity for the target measurement"
[docs]
class EdgeCorrector(SetAnalyzer):
"""Analyzer for detecting and correcting edge effects in colony detection.
This class identifies colonies at grid edges (missing orthogonal neighbors) and
caps their measurement values to prevent edge effects in growth assays. Edge
colonies often show artificially inflated measurements due to lack of competition
for resources.
"""
[docs]
def __init__(
self,
on: str,
groupby: list[str],
time_label: str = "Metadata_Time",
nrows: int = 8,
ncols: int = 12,
top_n: int = 3,
pvalue: float = 0.05,
connectivity: int = 4,
agg_func: str = "mean",
num_workers: int = 1,
):
"""
Initializes the class with specified parameters to configure the state of the object.
The class is aimed at processing and analyzing connectivity data with multiple grouping
and aggregation options, while ensuring input validation.
Args:
on (str): The dataset column to analyze or process.
groupby (list[str]): List of column names for grouping the data.
time_label (str): Specific time reference column, defaulting to "Metadata_Time".
nrows (int): Number of rows in the dataset, must be positive.
ncols (int): Number of columns in the dataset, must be positive.
top_n (int): Number of top results to analyze. Must be a positive integer.
pvalue (float): Statistical threshold for significance testing between the surrounded and edge colonies.
defaults to 0.05. Set to 0.0 to apply to all plates.
connectivity (int): The connectivity mode to use. Must be either 4 or 8.
agg_func (str): Aggregation function to apply, defaulting to 'mean'.
num_workers (int): Number of workers for parallel processing.
Raises:
ValueError: If `connectivity` is not 4 or 8.
ValueError: If `nrows` or `ncols` are not positive integers.
ValueError: If `top_n` is not a positive integer.
"""
super().__init__(
on=on, groupby=groupby, agg_func=agg_func, num_workers=num_workers
)
if connectivity not in (4, 8):
raise ValueError(f"connectivity must be 4 or 8, got {connectivity}")
if nrows <= 0 or ncols <= 0:
raise ValueError(
f"nrows and ncols must be positive, got nrows={nrows}, ncols={ncols}"
)
if top_n <= 0:
raise ValueError(f"top_n must be positive, got {top_n}")
self.nrows = nrows
self.ncols = ncols
self.top_n = top_n
self.connectivity = connectivity
self.time_label = time_label
self.pvalue = pvalue
self._original_data: pd.DataFrame = pd.DataFrame()
@staticmethod
def _surrounded_positions(
active_idx: np.ndarray | list[int],
shape: tuple[int, int],
connectivity: int = 4,
min_neighbors: int | None = None,
return_counts: bool = False,
dtype: np.dtype = np.int64,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Find grid cells that are surrounded by active neighbors.
This function identifies cells in a 2D grid that have a sufficient number
of active neighbors based on the specified connectivity pattern. Input uses
flattened indices in C-order (row-major).
Args:
active_idx: Flattened indices of active cells. Will be deduplicated.
shape: Grid dimensions as (rows, cols).
connectivity: Neighbor pattern. Must be 4 (N,S,E,W) or 8 (adds diagonals).
min_neighbors: Minimum number of active neighbors required. If None,
requires all neighbors in the connectivity pattern to be active
(fully surrounded). Border cells cannot qualify when None.
return_counts: If True, also return the neighbor counts for selected indices.
dtype: Data type for output arrays.
Returns:
If return_counts is False:
Sorted array of flattened indices meeting the neighbor criterion.
If return_counts is True:
Tuple of (indices, counts) where counts[i] is the number of active
neighbors for indices[i].
Raises:
ValueError: If connectivity is not 4 or 8, if any active_idx is out of
bounds, if min_neighbors is invalid, or if shape is invalid.
Notes:
- Flattening uses C-order: idx = row * cols + col
- When min_neighbors=None, border cells are geometrically excluded since
they cannot have all neighbors active
- Results are always sorted for deterministic output
Examples:
.. dropdown:: Finding fully surrounded and partially surrounded cells on an 8×12 grid
>>> import numpy as np
>>> # 8×12 plate; 3×3 active block centered at (4,6)
>>> rows, cols = 8, 12
>>> block_rc = [(r, c) for r in range(3, 6) for c in range(5, 8)]
>>> active = np.array([r*cols + c for r, c in block_rc], dtype=np.int64)
>>>
>>> # Fully surrounded (default, since min_neighbors=None → all)
>>> res_all = EdgeCorrector._surrounded_positions(active, (rows, cols), connectivity=4)
>>> assert np.array_equal(res_all, np.array([4*cols + 6], dtype=np.int64))
>>>
>>> # Threshold: at least 3 of 4 neighbors
>>> idxs, counts = EdgeCorrector._surrounded_positions(
... active, (rows, cols), connectivity=4, min_neighbors=3, return_counts=True
... )
>>> assert (counts >= 3).all()
>>> assert (4*cols + 6) in idxs # center has 4
"""
# Validate connectivity
if connectivity not in (4, 8):
raise ValueError(f"connectivity must be 4 or 8, got {connectivity}")
# Validate shape
if len(shape) != 2 or shape[0] <= 0 or shape[1] <= 0:
raise ValueError(f"shape must be two positive integers, got {shape}")
rows, cols = shape
total_cells = rows * cols
# Coerce active_idx to 1D unique array
active_idx = np.asarray(active_idx, dtype=dtype).ravel()
active_idx = np.unique(active_idx)
# Validate bounds
if len(active_idx) > 0:
if active_idx.min() < 0 or active_idx.max() >= total_cells:
raise ValueError(
f"All active_idx must be in [0, {total_cells}), "
f"got range [{active_idx.min()}, {active_idx.max()}]"
)
# Determine max_neighbors and validate min_neighbors
max_neighbors = connectivity
if min_neighbors is None:
min_neighbors = max_neighbors
else:
if not (1 <= min_neighbors <= max_neighbors):
raise ValueError(
f"min_neighbors must be in [1, {max_neighbors}], got {min_neighbors}"
)
# Handle empty input
if len(active_idx) == 0:
if return_counts:
return np.array([], dtype=dtype), np.array([], dtype=dtype)
return np.array([], dtype=dtype)
# Build active mask
active_mask = np.zeros((rows, cols), dtype=bool)
rows_idx = active_idx // cols
cols_idx = active_idx % cols
active_mask[rows_idx, cols_idx] = True
# Define neighbor offsets based on connectivity
if connectivity == 4:
offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
else: # connectivity == 8
offsets = [
(-1, 0),
(1, 0),
(0, -1),
(0, 1), # cardinal
(-1, -1),
(-1, 1),
(1, -1),
(1, 1), # diagonal
]
# Accumulate neighbor counts using aligned slicing
neighbor_count = np.zeros((rows, cols), dtype=np.int32)
for dr, dc in offsets:
# Calculate slice bounds for source (active_mask)
src_r_start = max(0, -dr)
src_r_end = rows - max(0, dr)
src_c_start = max(0, -dc)
src_c_end = cols - max(0, dc)
# Calculate slice bounds for destination (neighbor_count)
dst_r_start = max(0, dr)
dst_r_end = rows - max(0, -dr)
dst_c_start = max(0, dc)
dst_c_end = cols - max(0, -dc)
# Extract views
src_view = active_mask[src_r_start:src_r_end, src_c_start:src_c_end]
dst_view = neighbor_count[dst_r_start:dst_r_end, dst_c_start:dst_c_end]
# Accumulate
dst_view += src_view.astype(np.int32)
# Select cells that are active AND have sufficient neighbors
sufficient_neighbors = neighbor_count >= min_neighbors
selected_mask = active_mask & sufficient_neighbors
# Convert back to flattened indices
selected_rows, selected_cols = np.where(selected_mask)
result_idx = (selected_rows * cols + selected_cols).astype(dtype)
result_idx = np.sort(result_idx)
if return_counts:
# Get counts for selected indices
counts = neighbor_count[selected_rows, selected_cols].astype(dtype)
# Sort counts to match sorted indices
sort_order = np.argsort(selected_rows * cols + selected_cols)
counts = counts[sort_order]
return result_idx, counts
return result_idx
[docs]
def analyze(self, data: pd.DataFrame) -> pd.DataFrame:
"""Analyze and apply edge correction to grid-based colony measurements.
This method processes the input DataFrame by grouping according to specified
columns and applying edge correction to each group independently. Edge colonies
(those missing orthogonal neighbors) have their measurements capped to prevent
artificially inflated values.
Args:
data: DataFrame containing grid section numbers (GRID.SECTION_NUM) and
measurement data. Must include all columns specified in self.groupby
and self.on.
Returns:
DataFrame with corrected measurement values. Original structure is preserved
with only the measurement column modified for edge-affected rows.
Raises:
KeyError: If required columns are missing from input DataFrame.
ValueError: If data is empty or malformed.
Examples:
.. dropdown:: Applying edge correction to a 96-well plate dataset
>>> import pandas as pd
>>> import numpy as np
>>> from phenotypic.analysis import EdgeCorrector
>>> from phenotypic.tools.constants_ import GRID
>>>
>>> # Create sample grid data with measurements
>>> np.random.seed(42)
>>> data = pd.DataFrame({
... 'ImageName': ['img1'] * 96,
... GRID.SECTION_NUM: range(96),
... 'Area': np.random.uniform(100, 500, 96)
... })
>>>
>>> # Apply edge correction
>>> corrector = EdgeCorrector(
... on='Area',
... groupby=['ImageName'],
... nrows=8,
... ncols=12,
... top_n=10
... )
>>> corrected = corrector.analyze(data)
>>>
>>> # Check results
>>> results = corrector.results()
Notes:
- Stores original data in self._original_data for comparison
- Stores corrected data in self._latest_measurements for retrieval
- Groups are processed independently with their own thresholds
"""
from phenotypic.tools.constants_ import GRID
# Validate input
if data is None or len(data) == 0:
raise ValueError("Input data cannot be empty")
# Store original data for comparison
self._original_data = data
# Check required columns
section_col = str(GRID.SECTION_NUM)
required_cols = set(self.groupby + [section_col, self.on])
missing_cols = required_cols - set(data.columns)
if missing_cols:
raise KeyError(f"Missing required columns: {missing_cols}")
# Prepare configuration for _apply2group_func
config = {
"nrows": self.nrows,
"ncols": self.ncols,
"top_n": self.top_n,
"connectivity": self.connectivity,
"on": self.on,
"pvalue": self.pvalue,
"time_label": self.time_label,
}
# Build aggregation dictionary to preserve all columns
groupby_cols = self.groupby + [section_col]
if self.time_label in data:
groupby_cols = groupby_cols + [self.time_label]
# Determine which columns to aggregate
agg_dict = {}
for col in data.columns:
if col not in groupby_cols:
# Use specified agg_func for measurement column, 'first' for others
if col == self.on:
agg_dict[col] = self.agg_func
else:
agg_dict[col] = "first"
agg_data = data.groupby(by=groupby_cols, as_index=False).agg(agg_dict)
# Handle empty groupby case
if len(self.groupby) == 0:
# Process entire dataset as single group
corrected_data = [self.__class__._apply2group_func(agg_data, **config)]
else:
grouped = agg_data.groupby(by=self.groupby, as_index=False)
corrected_data = Parallel(n_jobs=self.n_jobs)(
delayed(self.__class__._apply2group_func)(group, **config)
for _, group in grouped
)
# Store results
if corrected_data:
self._latest_measurements = pd.concat(corrected_data, ignore_index=True)
else:
self._latest_measurements = pd.DataFrame()
return self._latest_measurements
[docs]
def show(
self,
figsize: tuple[int, int] | None = None,
max_groups: int = 20,
collapsed: bool = True,
criteria: dict[str, any] | None = None,
**kwargs,
) -> tuple[Figure, plt.Axes]:
"""Visualize edge correction results.
Displays the distribution of measurements for the last time point, highlighting
surrounded vs. edge colonies and the calculated correction threshold.
Args:
figsize: Figure size (width, height).
max_groups: Maximum number of groups to display.
collapsed: If True, show groups stacked vertically.
criteria: Filtering criteria.
**kwargs: Additional matplotlib parameters to customize the plot. Common options include:
- dpi: Figure resolution (default 100)
- facecolor: Figure background color
- edgecolor: Figure edge color
- grid_alpha: Alpha value for grid lines
- legend_loc: Legend location (default 'best')
- legend_fontsize: Font size for legend (default 8 or 9)
- marker_alpha: Alpha value for scatter plot markers
- line_width: Line width for box plots and fence lines
Returns:
Tuple of (Figure, Axes).
"""
if self._original_data.empty:
raise RuntimeError("No results to display. Call analyze() first.")
data = self._original_data.copy()
if criteria is not None:
data = self._filter_by(df=data, criteria=criteria, copy=False)
if data.empty:
raise ValueError("No data matches the specified criteria")
# Determine groups
if len(self.groupby) == 1:
groups = data[self.groupby[0]].unique()
group_col = self.groupby[0]
else:
data["_group_key"] = data[self.groupby].astype(str).agg(" | ".join, axis=1)
groups = data["_group_key"].unique()
group_col = "_group_key"
if len(groups) > max_groups:
print(f"Warning: Displaying first {max_groups} groups out of {len(groups)}")
groups = groups[:max_groups]
if collapsed:
return self._show_collapsed(data, groups, group_col, figsize, **kwargs)
else:
return self._show_individual(data, groups, group_col, figsize, **kwargs)
def _show_collapsed(
self,
data: pd.DataFrame,
groups,
group_col: str,
figsize: tuple[int, int] | None,
**kwargs,
) -> tuple[Figure, plt.Axes]:
# Extract figure-level kwargs
fig_kwargs = {
k: v for k, v in kwargs.items() if k in ("dpi", "facecolor", "edgecolor")
}
legend_fontsize = kwargs.get("legend_fontsize", 9)
n_groups = len(groups)
if figsize is None:
figsize = (10, max(6, 0.5 * n_groups + 2))
fig, ax = plt.subplots(figsize=figsize, **fig_kwargs)
added_labels = set()
for idx, group_name in enumerate(groups):
y_pos = n_groups - idx
group_data = data[data[group_col] == group_name]
stats = self._calculate_group_stats(group_data)
if stats is None:
continue
lt_df = stats["last_time_df"]
threshold = stats["threshold"]
surrounded_mask = stats["surrounded_mask"]
edge_mask = stats["edge_mask"]
# Range line
vals = lt_df[self.on].values
if len(vals) > 0:
ax.hlines(
y_pos, vals.min(), vals.max(), colors="lightgray", lw=1.5, zorder=1
)
# Threshold
if not np.isinf(threshold):
lbl = "Threshold"
if lbl not in added_labels:
added_labels.add(lbl)
else:
lbl = None
ax.plot(
[threshold, threshold],
[y_pos - 0.2, y_pos + 0.2],
color="#F4A261",
lw=2.5,
label=lbl,
zorder=2,
)
# Jitter
y_jitter = np.random.normal(y_pos, 0.05, len(lt_df))
is_clipped = lt_df[self.on] > threshold
# Helper for scatter plots
def add_scatter(mask, color, marker, label_key):
if mask.any():
lbl = label_key
if lbl not in added_labels:
added_labels.add(lbl)
else:
lbl = None
ax.scatter(
lt_df.loc[mask, self.on],
y_jitter[mask],
c=color,
marker=marker,
s=30 if marker == "o" else 40,
alpha=0.6 if marker == "o" else 0.8,
label=lbl,
zorder=3,
)
# Inner Pass
add_scatter(surrounded_mask & (~is_clipped), "#2E86AB", "o", "Inner (Pass)")
# Inner Clipped
add_scatter(surrounded_mask & is_clipped, "#2E86AB", "x", "Inner (Clipped)")
# Edge Pass
add_scatter(edge_mask & (~is_clipped), "#E63946", "o", "Edge (Pass)")
# Edge Clipped
add_scatter(edge_mask & is_clipped, "#E63946", "x", "Edge (Clipped)")
# Means
inner_vals = lt_df.loc[surrounded_mask, self.on]
edge_vals = lt_df.loc[edge_mask, self.on]
if len(inner_vals) > 0:
lbl = "Inner Mean"
if lbl not in added_labels:
added_labels.add(lbl)
else:
lbl = None
mean_val = inner_vals.mean()
ax.plot(
[mean_val, mean_val],
[y_pos - 0.25, y_pos + 0.25],
color="#2E86AB",
linewidth=2.5,
label=lbl,
zorder=4,
linestyle="--",
)
if len(edge_vals) > 0:
lbl = "Edge Mean"
if lbl not in added_labels:
added_labels.add(lbl)
else:
lbl = None
mean_val = edge_vals.mean()
ax.plot(
[mean_val, mean_val],
[y_pos - 0.25, y_pos + 0.25],
color="#E63946",
linewidth=2.5,
label=lbl,
zorder=4,
linestyle="--",
)
# P-value
if self.pvalue != 0 and len(inner_vals) > 0 and len(edge_vals) > 0:
pval = self._perm_test(inner_vals, edge_vals)
mean_inner = inner_vals.mean()
mean_edge = edge_vals.mean()
# Bracket parameters
bracket_y = y_pos + 0.3
bracket_h = 0.05
# Draw bracket
ax.plot(
[mean_inner, mean_inner, mean_edge, mean_edge],
[
bracket_y,
bracket_y + bracket_h,
bracket_y + bracket_h,
bracket_y,
],
color="black",
linewidth=1,
zorder=5,
)
# Add p-value text
mid_x = (mean_inner + mean_edge) / 2
ax.text(
mid_x,
bracket_y + bracket_h + 0.05,
f"p={pval:.3f}",
ha="center",
va="bottom",
fontsize=8,
)
ax.set_yticks(range(1, n_groups + 1))
ax.set_yticklabels(groups[::-1])
ax.set_xlabel(self.on)
ax.set_title(f"Edge Correction (Top N={self.top_n}, p={self.pvalue})")
ax.legend(loc="best", fontsize=legend_fontsize)
plt.tight_layout()
return fig, ax
def _show_individual(
self,
data: pd.DataFrame,
groups,
group_col: str,
figsize: tuple[int, int] | None,
**kwargs,
) -> tuple[Figure, plt.Axes]:
# Extract figure-level kwargs
fig_kwargs = {
k: v for k, v in kwargs.items() if k in ("dpi", "facecolor", "edgecolor")
}
legend_fontsize = kwargs.get("legend_fontsize", 8)
n_groups = len(groups)
n_cols = min(3, n_groups)
n_rows = (n_groups + n_cols - 1) // n_cols
if figsize is None:
figsize = (5 * n_cols, 4 * n_rows)
fig, axes = plt.subplots(
n_rows, n_cols, figsize=figsize, squeeze=False, **fig_kwargs
)
axes = axes.flatten()
for idx, group_name in enumerate(groups):
ax = axes[idx]
group_data = data[data[group_col] == group_name]
stats = self._calculate_group_stats(group_data)
if stats is None:
ax.text(0.5, 0.5, "Insufficient Data", ha="center")
continue
lt_df = stats["last_time_df"]
threshold = stats["threshold"]
surrounded_mask = stats["surrounded_mask"]
edge_mask = stats["edge_mask"]
vals = lt_df[self.on].values
is_clipped = lt_df[self.on] > threshold
ax.boxplot(
[vals],
positions=[1],
widths=0.3,
patch_artist=True,
showfliers=False,
boxprops=dict(facecolor="lightgray", alpha=0.3),
)
x_jitter = np.random.normal(1, 0.04, len(lt_df))
# Inner Pass
mask_ip = surrounded_mask & (~is_clipped)
if mask_ip.any():
ax.scatter(
x_jitter[mask_ip],
lt_df.loc[mask_ip, self.on],
c="#2E86AB",
marker="o",
s=30,
alpha=0.6,
label="Inner (Pass)",
)
# Inner Clipped
mask_ic = surrounded_mask & is_clipped
if mask_ic.any():
ax.scatter(
x_jitter[mask_ic],
lt_df.loc[mask_ic, self.on],
c="#2E86AB",
marker="x",
s=40,
alpha=0.8,
label="Inner (Clipped)",
)
# Edge Pass
mask_ep = edge_mask & (~is_clipped)
if mask_ep.any():
ax.scatter(
x_jitter[mask_ep],
lt_df.loc[mask_ep, self.on],
c="#E63946",
marker="o",
s=30,
alpha=0.6,
label="Edge (Pass)",
)
# Edge Clipped
mask_ec = edge_mask & is_clipped
if mask_ec.any():
ax.scatter(
x_jitter[mask_ec],
lt_df.loc[mask_ec, self.on],
c="#E63946",
marker="x",
s=40,
alpha=0.8,
label="Edge (Clipped)",
)
if not np.isinf(threshold):
ax.axhline(
y=threshold, color="#F4A261", linestyle="--", label="Threshold"
)
ax.set_title(group_name)
ax.set_ylabel(self.on)
ax.set_xticks([])
if idx == 0:
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(
by_label.values(),
by_label.keys(),
loc="best",
fontsize=legend_fontsize,
)
for idx in range(n_groups, len(axes)):
axes[idx].set_visible(False)
plt.tight_layout()
return fig, axes
def _calculate_group_stats(self, group: pd.DataFrame):
from phenotypic.tools.constants_ import GRID
if len(group) == 0:
return None
tmax = group[self.time_label].max()
last_time_group = group[group[self.time_label] == tmax].copy()
present_sections = last_time_group[GRID.SECTION_NUM].dropna().unique()
if len(present_sections) == 0:
return None
active_indices = present_sections.astype(int)
try:
surrounded_idx = self._surrounded_positions(
active_idx=active_indices,
shape=(self.nrows, self.ncols),
connectivity=self.connectivity,
min_neighbors=None,
return_counts=False,
)
except ValueError:
return None
surrounded_idx_set = set(surrounded_idx)
if len(surrounded_idx_set) == 0:
return {
"last_time_df": last_time_group,
"threshold": np.inf,
"surrounded_mask": pd.Series(False, index=last_time_group.index),
"edge_mask": pd.Series(True, index=last_time_group.index),
}
surrounded_mask = last_time_group[GRID.SECTION_NUM].isin(surrounded_idx_set)
edge_mask = ~surrounded_mask & last_time_group[GRID.SECTION_NUM].isin(
present_sections
)
if self.on not in group.columns:
return None
last_inner_values = last_time_group.loc[surrounded_mask, self.on]
threshold = np.inf
should_correct = True
if self.pvalue != 0:
last_edge_values = last_time_group.loc[edge_mask, self.on]
if len(last_edge_values) > 0 and len(last_inner_values) > 0:
perm_results = permutation_test(
data=(last_inner_values, last_edge_values),
statistic=lambda x, y: np.mean(x) - np.mean(y),
permutation_type="independent",
n_resamples=1000,
alternative="two-sided",
)
if perm_results.pvalue > self.pvalue:
should_correct = False
if should_correct:
actual_top_n = min(self.top_n, len(last_inner_values))
if actual_top_n > 0:
top_values = last_inner_values.nlargest(actual_top_n)
threshold = top_values.mean()
return {
"last_time_df": last_time_group,
"threshold": threshold,
"surrounded_mask": surrounded_mask,
"edge_mask": edge_mask,
}
[docs]
def results(self) -> pd.DataFrame:
"""Return the corrected measurement DataFrame.
Returns the DataFrame with edge-corrected measurements from the most recent
call to analyze(). This allows retrieval of results after processing.
Returns:
DataFrame with corrected measurements. If analyze() has not been called,
returns an empty DataFrame.
Examples:
.. dropdown:: Retrieving corrected measurements after analysis
>>> corrector = EdgeCorrector(
... on='Area',
... groupby=['ImageName']
... )
>>> corrected = corrector.analyze(data)
>>> results = corrector.results() # Same as corrected
>>> assert results.equals(corrected)
Notes:
- Returns the DataFrame stored in self._latest_measurements
- Contains the same structure as input but with corrected values
- Use this method to retrieve results after calling analyze()
"""
return self._latest_measurements
@staticmethod
def _apply2group_func(
group: pd.DataFrame,
on: str,
nrows: int,
ncols: int,
top_n: int,
time_label: str,
connectivity: int,
pvalue: float,
) -> pd.DataFrame:
"""
Note:
- assumes "Grid_SectionNum" from a `GridFinder` is in the dataframe groups
= applies permutation test on the last time-point to see if theres a statistically significant difference
- caps clips all the data to the last time point
"""
from phenotypic.tools.constants_ import GRID
section_col = GRID.SECTION_NUM
# Handle empty groups
if len(group) == 0:
return group
# Make a copy to avoid modifying the original
group: pd.DataFrame = group.copy()
if time_label in group.columns:
tmax = group.loc[:, time_label].max()
last_time_group = group.loc[group.loc[:, time_label] == tmax, :]
else:
last_time_group = group
# Get unique section numbers present in the data
present_sections = last_time_group.loc[:, section_col].dropna().unique()
# Handle case where no sections are present
if len(present_sections) == 0:
return group
# Convert section numbers to 0-indexed flattened indices
# Assuming section numbers are 0-indexed already (row * ncols + col)
active_indices = present_sections.astype(int)
# Find fully-surrounded (interior) sections
try:
surrounded_idx = EdgeCorrector._surrounded_positions(
active_idx=active_indices,
shape=(nrows, ncols),
connectivity=connectivity,
min_neighbors=None, # Require all neighbors (fully surrounded)
return_counts=False,
)
except ValueError:
# If validation fails, return group unchanged
return group
# Identify edge sections (all sections - surrounded sections)
surrounded_idx = set(surrounded_idx)
edge_idx = [sec for sec in present_sections if sec not in surrounded_idx]
# If no inner sections, return unchanged
if len(surrounded_idx) == 0:
return group
# Calculate threshold from top N inner values
# ===========================================
if on not in group.columns:
return group
last_inner_values: pd.Series = last_time_group.loc[
last_time_group.loc[:, GRID.SECTION_NUM].isin(surrounded_idx), on
]
if pvalue != 0:
last_edge_values: pd.Series = last_time_group.loc[
last_time_group.loc[:, GRID.SECTION_NUM].isin(edge_idx), on
]
# If difference is not statistically significant, don't apply correction
if EdgeCorrector._perm_test(last_inner_values, last_edge_values) > pvalue:
return group
# Use actual number of values if fewer than top_n available
actual_top_n = min(top_n, len(last_inner_values))
if actual_top_n == 0: # If no inner colonies
return group
# Get top N values and calculate threshold
top_values = last_inner_values.nlargest(actual_top_n)
threshold = top_values.mean()
# Apply correction: cap ALL values that exceed for fairness
group.loc[:, on] = np.clip(group.loc[:, on], a_min=0, a_max=threshold)
return group
@staticmethod
def _perm_test(surrounded, edge):
return permutation_test(
data=(surrounded, edge),
statistic=lambda x, y: np.mean(x) - np.mean(y),
permutation_type="independent",
n_resamples=1000,
alternative="two-sided",
).pvalue
EdgeCorrector.__doc__ = EDGE_CORRECTION.append_rst_to_doc(EdgeCorrector.__doc__)