import itertools
from typing import Any, Callable, Dict, List, Literal, Tuple, Union
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.optimize as optimize
from joblib import delayed, Parallel
from sklearn.metrics import (
mean_absolute_error,
mean_squared_error,
root_mean_squared_error,
)
from phenotypic.analysis.abc_ import ModelFitter
from phenotypic.tools.constants_ import MeasurementInfo
class LOG_GROWTH_MODEL(MeasurementInfo):
@classmethod
def category(cls) -> str:
return "LogGrowthModel"
R_FIT = "r", "The intrinsic growth rate"
K_FIT = "K", "The carrying capacity"
N0_FIT = "N0", "The initial number of the colony size metric being fitted"
LAM = (
"lamda",
"The regularization factor applied to the max specific growth rate and initial population size",
)
ALPHA = (
"alpha",
(
"The penalty factor applied to relative difference of "
"the carrying capacity from the largest measurement"
),
)
GROWTH_RATE = "µmax", "The growth rate of the colony calculated as (K*r)/4"
K_MAX = "Kmax", "The upper bound of the carrying capacity for model fitting"
NUM_SAMPLES = "NumSamples", "The number of samples used for model fitting"
LOSS = "OptimizerLoss", "The loss of model fitting"
STATUS = "OptimizerStatus", "The output of the optimizer status"
MAE = "MAE", "The mean absolute error"
MSE = "MSE", "The mean squared error"
RMSE = "RMSE", "The root mean squared error"
[docs]
class LogGrowthModel(ModelFitter):
r"""
Represents a log growth model fitter.
This class defines methods and attributes to configure and fit logarithmic
growth models to grouped data. It provides functionality for analyzing and
visualizing the fitted models as well as exposing the results for further
processing.
Logistic Kinetics Model:
.. math::
N(t) = \frac{K}{1 + \frac{K - N_0}{N_0} e^{-rt}}
:math:`N_t`: population size at time :math:`t`
:math:`N_0`: initial population size at time :math:`t`
:math:`r`: growth rate
:math:`K`: carrying capacity (maximum population size)
From this we derive:
.. math::
\mu_{\max} = \frac{K r}{4}
:math:`\mu_{\max}`: maximum specific growth rate
Loss Function:
To solve for the parameters, we use the following loss function with the
SciPy linear least-squares solver:
.. math::
J(K, N_0, r) =
\frac{1}{n}\sum_{i=1}^{n}
\frac{1}{2}\left(f_{K,N_0,r}(t^{(i)}) - N_t^{(i)}\right)^2
+ \lambda\left(\left(\frac{dN}{dt}\right)^2 + N_0^2\right)
+ \alpha \frac{\lvert K - \max(N_t) \rvert}{N_t}
:math:`\lambda`: regularization term for growth rate and initial population size
:math:`\alpha`: penalty term for deviations in carrying capacity relative to
the largest measurement
Attributes:
lam (float): The penalty factor applied to growth rates.
alpha (float): The maximum penalty factor applied to the carrying
capacity.
loss (Literal["linear"]): The loss calculation method used for fitting.
verbose (bool): A flag to enable or disable detailed logging.
time_label (str): The column name representing the time dimension
in the input data.
Kmax_label (str | None): The column name for the maximum carrying capacity
values, if provided.
"""
[docs]
def __init__(
self,
on: str,
groupby: List[str],
time_label: str = "Metadata_Time",
agg_func: Callable | str | list | dict | None = "mean",
lam=1.2,
alpha=2,
Kmax_label: str | None = None,
loss: Literal["linear"] = "linear",
verbose: bool = False,
n_jobs: int = 1,
):
"""
This class initializes parameters for a data processing or modeling procedure.
It takes configuration arguments for handling data grouping, time management,
aggregation, penalties, loss calculation, and verbosity.
Args:
on (str): The target variable or column to process.
groupby (List[str]): The columns that define the grouping structure.
time_label (str): Column name that represents time in the data. Defaults to
'Metadata_Time'.
agg_func (Callable | str | list | dict | None): Aggregation function(s) to
apply to grouped data. Parameter is fed to
`pandas.DataFrame.groupby.agg()`. Defaults to 'mean'.
lam: The penalty factor applied to growth rates. Defaults to 1.2.
alpha: The maximum penalty factor applied to the carrying capacity.
Defaults to 2.
Kmax_label (str | None): Column name that provides maximum K value for
processing. Defaults to None.
loss (Literal["linear"]): Loss calculation method to apply. Defaults to
"linear".
verbose (bool): If True, enables detailed logging for process execution.
Defaults to False.
n_jobs (int): Number of parallel jobs to execute. Defaults to 1.
"""
super().__init__(on=on, groupby=groupby, agg_func=agg_func, num_workers=n_jobs)
self.lam = lam
self.alpha = alpha
self.loss = loss
self.verbose = verbose
self.time_label = time_label
self.Kmax_label = Kmax_label
[docs]
def analyze(self, data: pd.DataFrame) -> pd.DataFrame:
data: pd.DataFrame = data.copy(deep=True)
data.loc[:, self.time_label] = self._ensure_float_array(
data.loc[:, self.time_label]
)
self._latest_measurements = data
apply2group_kwargs = dict(
groupby_names=self.groupby,
model=self.__class__.model_func,
time_label=self.time_label,
size_label=self.on,
Kmax_label=self.Kmax_label,
lam=self.lam,
alpha=self.alpha,
loss=self.loss,
verbose=self.verbose,
)
# aggregate so that only one sample per timepoint
agg_dict = {self.on: self.agg_func}
if self.Kmax_label is not None:
agg_dict[self.Kmax_label] = (
"max" # Use max for Kmax as it's a carrying capacity
)
agg_data = data.groupby(
by=self.groupby + [self.time_label], as_index=False
).agg(agg_dict)
# Create groups
grouped = agg_data.groupby(by=self.groupby, as_index=True)
if self.n_jobs == 1:
model_res = []
for key, group in grouped:
model_res.append(
self.__class__._apply2group_func(key, group,
**apply2group_kwargs)
)
else:
model_res = Parallel(n_jobs=self.n_jobs)(
delayed(self.__class__._apply2group_func)(
key, group, **apply2group_kwargs
)
for key, group in grouped
)
self._latest_model_scores = pd.concat(model_res, axis=0).reset_index(drop=False)
self._latest_model_scores.insert(
loc=len(self._latest_model_scores.columns),
column=LOG_GROWTH_MODEL.LAM,
value=self.lam,
)
self._latest_model_scores.insert(
loc=len(self._latest_model_scores.columns),
column=LOG_GROWTH_MODEL.ALPHA,
value=self.alpha,
)
return self._latest_model_scores
[docs]
def show(self,
tmax: int | float | None = None,
criteria: Dict[str, Union[Any, List[Any]]] | None = None,
figsize=(6, 4),
cmap: str | None = "tab20",
legend=True,
ax: plt.Axes = None,
**kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
"""
Visualizes model predictions alongside measurements, allowing optional
filtering by specified criteria and plotting configuration.
Args:
tmax (int | float | None, optional): The maximum time value for plotting. If
set to None, the maximum time value will be determined from the data
automatically.
criteria (Dict[str, Union[Any, List[Any]]] | None, optional): A dictionary
specifying filtering criteria for data selection. When provided, only
data matching the criteria will be used for plotting.
figsize (tuple, optional): A tuple specifying the size of the figure.
Defaults to (6, 4).
cmap (str | None, optional): A string representing either a matplotlib colormap name
or a single color (e.g., 'red', '#FF0000'). If a matplotlib colormap is provided,
colors will be cycled through it. If a single color is provided, all lines will
use that color. Defaults to 'tab20'.
legend (bool, optional): A boolean that controls whether a legend is
displayed on the plot. Defaults to True.
ax (plt.Axes, optional): A matplotlib Axes object on which to plot. If not
provided, a new figure and axes object will be created.
**kwargs: Additional matplotlib parameters to customize the plot. Common options include:
- dpi: Figure resolution (default 100)
- facecolor: Figure background color
- edgecolor: Figure edge color
- line_width: Line width for prediction lines
- marker_size: Size of data point markers
- elinewidth: Error bar line width
- capsize: Error bar cap size
- title: Custom figure title
- xlabel: Custom x-axis label
- ylabel: Custom y-axis label
- legend_loc: Legend location (default 'best')
- legend_fontsize: Font size for legend
Returns:
Tuple[plt.Figure, plt.Axes]: A tuple containing the matplotlib Figure and
Axes objects used for plotting.
Raises:
KeyError: If the group keys for model results and measurements do not
align, or if specified columns are missing from the input data.
"""
# Extract figure-level kwargs
fig_kwargs = {
k: v for k, v in kwargs.items() if k in ("dpi", "facecolor", "edgecolor")
}
line_width = kwargs.get("line_width", None)
marker_size = kwargs.get("marker_size", None)
elinewidth = kwargs.get("elinewidth", 1)
capsize = kwargs.get("capsize", 2)
legend_loc = kwargs.get("legend_loc", "best")
legend_fontsize = kwargs.get("legend_fontsize", None)
if ax is None:
fig, ax = plt.subplots(figsize=figsize, **fig_kwargs)
else:
fig = ax.get_figure()
# apply filtering of data table
if criteria is not None:
filtered_model_scores = self._filter_by(
df=self._latest_model_scores, criteria=criteria, copy=True
)
filtered_measurements = self._filter_by(
df=self._latest_measurements, criteria=criteria, copy=True
)
else:
filtered_model_scores = self._latest_model_scores.copy()
filtered_measurements = self._latest_measurements.copy()
if filtered_measurements.empty:
import warnings
warnings.warn("No data found matching the criteria. Returning empty plot.")
return fig, ax
model_groups = {
model_keys: model_groups
for model_keys, model_groups in filtered_model_scores.groupby(
by=self.groupby
)
}
meas_groups = {
meas_keys: meas_groups
for meas_keys, meas_groups in filtered_measurements.groupby(by=self.groupby)
}
filtered_measurements.loc[:, self.time_label] = self._ensure_float_array(
filtered_measurements.loc[:, self.time_label]
)
timepoints = pd.Series(filtered_measurements.loc[:, self.time_label].unique())
step = np.abs(np.mean(timepoints.sort_values().diff().dropna()))
if np.isnan(step) or step <= 0:
step = 1.0
tmax = timepoints.max() if tmax is None else tmax
t = np.arange(stop=tmax + step, step=step)
# Handle both colormap and single color
if cmap is not None:
try:
# Try to get as a matplotlib colormap
cmap_obj = cm.get_cmap(cmap)
color_iter = itertools.cycle(
cmap_obj(
np.linspace(
start=0, stop=1, num=len(model_groups),
endpoint=False
)
)
)
except (ValueError, AttributeError):
# If it fails, treat as a single color and create infinite cycle
color_iter = itertools.cycle([cmap])
else:
# Default to None for automatic matplotlib coloring
color_iter = itertools.cycle([None]*len(model_groups))
for model_key, model_group in model_groups.items():
curr_meas = meas_groups[model_key]
curr_color = next(color_iter)
y_pred = self.model_func(
t=t,
r=model_group[LOG_GROWTH_MODEL.R_FIT].iloc[0],
K=model_group[LOG_GROWTH_MODEL.K_FIT].iloc[0],
N0=model_group[LOG_GROWTH_MODEL.N0_FIT].iloc[0],
)
plot_kwargs = {}
if curr_color is not None:
plot_kwargs["color"] = curr_color
if line_width is not None:
plot_kwargs["linewidth"] = line_width
ax.plot(t, y_pred, **plot_kwargs)
curr_time_groups = curr_meas.groupby(by=self.time_label)
curr_mean = curr_time_groups[self.on].mean()
curr_stddev = curr_time_groups[self.on].std()
curr_stderr = curr_stddev/np.sqrt(curr_time_groups[self.on].count())
# noinspection PyUnresolvedReferences
errorbar_kwargs = {
"x" : curr_mean.index.values,
"y" : curr_mean.values,
"yerr" : curr_stderr,
"fmt" : "o",
"elinewidth": elinewidth,
"capsize" : capsize,
"label" : kwargs.get("label", f"{model_key[0]}"),
}
if curr_color is not None:
errorbar_kwargs["color"] = curr_color
errorbar_kwargs["ecolor"] = curr_color
if marker_size is not None:
errorbar_kwargs["markersize"] = marker_size
ax.errorbar(**errorbar_kwargs)
if legend:
# Create legend and check if it fits within the axes
legend_kwargs = {"loc": legend_loc}
if legend_fontsize is not None:
legend_kwargs["fontsize"] = legend_fontsize
legend_obj = ax.legend(**legend_kwargs)
# Draw to ensure bounding boxes are available
fig.canvas.draw()
# Get bounding boxes in display coordinates
legend_bbox = legend_obj.get_window_extent()
axes_bbox = ax.get_window_extent()
# Check if legend is larger than axes (with small tolerance)
if (
legend_bbox.width > axes_bbox.width*0.95
or legend_bbox.height > axes_bbox.height*0.95
):
legend_obj.remove()
ax.set_title("mean±SE")
return fig, ax
[docs]
def results(self) -> pd.DataFrame:
return self._latest_model_scores
[docs]
@staticmethod
def model_func(t: np.ndarray[float] | float, r: float, K: float, N0: float):
"""
Computes the value of the logistic growth model for a given time point or array
of time points and parameters. The logistic model describes growth that
initially increases exponentially but levels off as the population reaches
a carrying capacity.
This static method uses the formula:
N(t) = K / (1 + [(K - N0) / N0] * exp(-r * t))
Where:
t: Time (independent variable, can be scalar or array).
r: Growth rate.
K: Carrying capacity (maximum population size).
N0: Initial population size.
Args:
t (np.ndarray[float] | float): Time at which the population is calculated.
Can be a single value or an array of values.
r (float): Growth rate of the population.
K (float): Carrying capacity or the maximum population size.
N0 (float): Initial population size at time t=0.
Returns:
float | np.ndarray[float]: The computed population size at the given time
or array of times based on the logistic growth model.
"""
a = (K - N0)/N0
return K/(1 + a*np.exp(-r*t))
@staticmethod
def _loss_func(params, t, y, lam, alpha):
r"""
Computes a combined loss which includes both the residuals from the predicted
values using a logarithmic growth model, a regularization term, and a penalty
for deviations in the carrying capacity (K).
To solve for the parameters, we use the following loss function with the
SciPy linear least-squares solver:
.. math::
J(K, N_0, r) =
\frac{1}{n}\sum_{i=1}^{n}
\frac{1}{2}\left(f_{K,N_0,r}(t^{(i)}) - N_t^{(i)}\right)^2
+ \lambda\left(\left(\frac{dN}{dt}\right)^2 + N_0^2\right)
+ \alpha \frac{\lvert K - \max(N_t) \rvert}{N_t}
:math:`\lambda`: regularization term for growth rate and initial population size
:math:`\alpha`: penalty term for deviations in carrying capacity relative to
the largest measurement
The function calculates the residuals (difference between actual and predicted
values), a regularization term based on biological parameters, and applies a
penalty proportional to the deviation of K from the observed maximum value
within the data. A small epsilon is used to ensure numerical stability during
penalty calculation.
Note:
This function is used in conjunction with the `scipy.optimize.least_squares`
solver to find the parameters
Args:
params (List[float]): A list containing the parameters [r, K, N0], where:
r: Growth rate.
K: Carrying capacity.
N0: Initial population size.
t (Union[List[float], np.ndarray]): Time points for the observations.
y (Union[List[float], np.ndarray, pd.Series]): Observed population size
corresponding to the time points t. Can be a list, numpy array, or
pandas.Series object.
lam (float): Regularization parameter for the specific growth rate and
initial population size.
alpha (float): Scaling parameter for the K-based penalty.
Returns:
np.ndarray: A combined loss array consisting of the residuals,
regularization sterms, and the K penalty. The array includes:
- Residuals: Difference between observed and model-predicted values.
- Regularization terms: Regularization applied to dN/dt and N0.
- K Penalty: Penalty term applied based on the deviation of K.
"""
r, K, N0 = params
# cost function (residuals)
cost_func = y - LogGrowthModel.model_func(t=t, r=r, K=K, N0=N0)
# regularization term
dN_dt = r*K/4
reg_term = np.sqrt(lam)*np.array([dN_dt, N0])
# K-based penalty
if hasattr(y, "values"):
y_array = y.values
else:
y_array = np.array(y)
# Get the value of the last time point
# assumes thats the best indicator for carrying capacity
y_max_observed = y_array[np.argmax(t)]
# Numerical stability epsilon
epsilon = 1e-8*np.median(np.abs(y_array[y_array > 0]))
if epsilon == 0 or np.isnan(epsilon):
epsilon = 1e-8
# Relative K penalty
K_penalty_weight = np.sqrt(alpha)
K_penalty = (
K_penalty_weight*np.abs(K - y_max_observed)/(y_max_observed + epsilon)
)
return np.hstack(
[cost_func, reg_term, [K_penalty]]
) # K_penalty is scalar so we need to wrap in a list
@staticmethod
def _apply2group_func(
group_key: tuple,
group: pd.DataFrame,
groupby_names: tuple,
model: Callable,
time_label: str,
size_label: str,
Kmax_label: str | None,
lam: float,
alpha: float,
loss: Literal["linear"],
verbose: bool,
):
t_data = group[time_label]
size_data = group[size_label]
i_min = 0
n_samples = len(t_data)
r_min, r_max = 1e-5, np.inf
N0_min, N0_max = 0, size_data.min()
# Safety check since max bound must be higher than min bound
if N0_max <= N0_min:
N0_max = N0_min + 1
if Kmax_label is None:
K_max = size_data.max()
else:
K_max = group[Kmax_label].max()
if K_max == np.nan:
K_max = size_data.max() + 1
K_min = i_min
try:
out = optimize.least_squares(
LogGrowthModel._loss_func,
x0=[1e-5, size_data.max(), 0],
bounds=(
[r_min, K_min, N0_min],
[r_max, K_max, N0_max],
),
kwargs=dict(
t=t_data,
y=size_data,
lam=lam,
alpha=alpha,
),
verbose=verbose,
method="trf",
loss=loss,
)
x = out.x
fitted_values = {
LOG_GROWTH_MODEL.R_FIT : x[0],
LOG_GROWTH_MODEL.K_FIT : x[1],
LOG_GROWTH_MODEL.N0_FIT : x[2],
LOG_GROWTH_MODEL.GROWTH_RATE: (x[0]*x[1])/4,
}
y_pred = LogGrowthModel.model_func(t=t_data, r=x[0], K=x[1], N0=x[2])
model_stats = {
LOG_GROWTH_MODEL.K_MAX : K_max,
LOG_GROWTH_MODEL.NUM_SAMPLES: n_samples,
LOG_GROWTH_MODEL.LOSS : out.cost,
LOG_GROWTH_MODEL.STATUS : out.status,
LOG_GROWTH_MODEL.MAE : mean_absolute_error(size_data, y_pred),
LOG_GROWTH_MODEL.MSE : mean_squared_error(size_data, y_pred),
LOG_GROWTH_MODEL.RMSE : root_mean_squared_error(size_data,
y_pred),
}
except ValueError:
fitted_values = {
LOG_GROWTH_MODEL.R_FIT : np.nan,
LOG_GROWTH_MODEL.K_FIT : np.nan,
LOG_GROWTH_MODEL.N0_FIT : np.nan,
LOG_GROWTH_MODEL.GROWTH_RATE: np.nan,
}
model_stats = {
LOG_GROWTH_MODEL.K_MAX : np.nan,
LOG_GROWTH_MODEL.NUM_SAMPLES: np.nan,
LOG_GROWTH_MODEL.LOSS : np.nan,
LOG_GROWTH_MODEL.STATUS : np.nan,
LOG_GROWTH_MODEL.MAE : np.nan,
LOG_GROWTH_MODEL.MSE : np.nan,
LOG_GROWTH_MODEL.RMSE : np.nan,
}
return pd.DataFrame(
data={**fitted_values, **model_stats},
index=pd.MultiIndex.from_tuples(tuples=[group_key],
names=groupby_names),
)