"""
Unified Segmentation Interface
Provides a simplified, unified API for all segmentation tasks while maintaining
backward compatibility with existing implementations.
Supported Modalities:
- SXT (Soft X-ray Tomography): cell, nucleus, mitochondria, ISG
- SIM (Structured Illumination Microscopy): ER, mitochondria, ISG
- WFM (Widefield Microscopy): cell, nucleus, general organelles
- Cryo-ET (Cryo-Electron Tomography): filaments, membranes
"""
import os
import sys
import numpy as np
from typing import Optional, Dict, Any
from .base import BaseSegmenter
# Resolve the model root relative to the directory that contains this module,
# so the default checkpoint paths do not depend on the current working
# directory of the calling script.
PACKAGE_ROOT = os.path.dirname(os.path.abspath(__file__))
MODEL_ROOT = os.path.join(PACKAGE_ROOT, 'models')
# Default checkpoint files, grouped by modality sub-directory under MODEL_ROOT.
DEFAULT_SXT_CELL_MODEL = os.path.join(MODEL_ROOT, 'sxt', 'cell_nucleus_unet_best.pth')
DEFAULT_SXT_MITO_MODEL = os.path.join(MODEL_ROOT, 'sxt', 'mito_unet_best.pth')
DEFAULT_SXT_ISG_MODEL = os.path.join(MODEL_ROOT, 'sxt', 'isg_unet_best.pth')
DEFAULT_SXT_ISG_MASKRCNN_MODEL = os.path.join(MODEL_ROOT, 'sxt', 'isg_mask_rcnn_final.pth')
DEFAULT_SIM_ER_MODEL = os.path.join(MODEL_ROOT, 'sim', 'ernet_final.pth')
class SXTCellSegmenter(BaseSegmenter):
    """SXT cell and nucleus segmentation with a 2D U-Net run slice by slice."""

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sxt', task='cell', device=device)
        self.model = None

    def _load_model_impl(self, path: str = None, **kwargs):
        """Load the SXT cell/nucleus U-Net weights.

        Args:
            path: Checkpoint file; defaults to ``DEFAULT_SXT_CELL_MODEL``.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if path is None:
            path = DEFAULT_SXT_CELL_MODEL
        if not os.path.exists(path):
            raise FileNotFoundError(f"SXT Cell model not found at {path}.")

        import torch
        # Use the Unet class that matches the nu_best.pth structure
        from .segmentation_sxt.model_sxt_cell_mito.networks.Unet import Unet

        # n_class=3 (bg, mem, nu)
        self.model = Unet(n_class=3, is_dropout=True)

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        # Checkpoints saved from nn.DataParallel carry a leading 'module.' on
        # every key. Strip it only as a prefix so that parameter names which
        # merely contain that text elsewhere are left intact.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint.items()
        }
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.model_path = path

    def predict(self, data: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
        """Predict cell and nucleus masks.

        Args:
            data: Input 3D volume (D, H, W).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Dictionary with 'cell_mask' (class 1) and 'nucleus_mask' (class 2),
            each a uint8 array with the same shape as ``data``.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.is_loaded or self.model is None:
            raise RuntimeError("Model not loaded!")

        import torch
        import cv2

        # Normalize input to [0, 255] uint8 as expected by the legacy model.
        # A constant volume has zero intensity range; map it to all zeros
        # instead of dividing by zero and casting NaNs to uint8.
        v_min, v_max = data.min(), data.max()
        if v_max - v_min > 0:
            data_norm = ((data - v_min) / (v_max - v_min) * 255).astype(np.uint8)
        else:
            data_norm = np.zeros(data.shape, dtype=np.uint8)

        D, H, W = data.shape
        pred_volume = np.zeros((D, H, W), dtype=np.uint8)

        self.model.eval()
        with torch.no_grad():
            for i in range(D):
                # Legacy model expects slices resized to (288, 480); cv2 takes
                # the target size as (width, height).
                inputslice = cv2.resize(data_norm[i], (480, 288), interpolation=cv2.INTER_NEAREST)
                # Normalization constants used in training
                transform_input = (inputslice.astype(np.float32) / 255.0 - 0.456) / 0.224
                input_tensor = (
                    torch.from_numpy(transform_input).float()
                    .unsqueeze(0).unsqueeze(0).to(self.device)
                )
                output = self.model(input_tensor)
                # argmax over logits equals argmax over softmax probabilities
                pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()
                # Resize the label map back to the original slice shape
                pred_volume[i] = cv2.resize(
                    pred.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST
                )

        # Extract masks: 1 = membrane/cell, 2 = nucleus
        cell_mask = (pred_volume == 1).astype(np.uint8)
        nucleus_mask = (pred_volume == 2).astype(np.uint8)
        return {'cell_mask': cell_mask, 'nucleus_mask': nucleus_mask}
class SXTMitoSegmenter(BaseSegmenter):
    """SXT mitochondria semantic segmentation using a 2D U-Net.

    The network takes a 3-channel input and predicts two classes
    (background, mitochondria). Volumes are processed slice by slice.
    """

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sxt', task='mito', device=device)
        self.model = None

    def _load_model_impl(self, path: str = None, **kwargs):
        """Load the SXT mitochondria U-Net weights.

        Args:
            path: Checkpoint file; defaults to ``DEFAULT_SXT_MITO_MODEL``.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if path is None:
            path = DEFAULT_SXT_MITO_MODEL
        if not os.path.exists(path):
            raise FileNotFoundError(f"SXT Mito model not found at {path}.")

        import torch
        from .segmentation_sxt.model_sxt_isg.unet.unet_model import UNet

        # 3 input channels, 2 classes: bg + mito
        self.model = UNet(n_channels=3, n_classes=2, bilinear=False)
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        state_dict = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.model_path = path

    def predict(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Predict a binary mitochondria mask using the U-Net.

        Args:
            data: Input 3D volume (D, H, W).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            uint8 mask with the same shape as ``data``; 1 marks mitochondria.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.is_loaded or self.model is None:
            raise RuntimeError("Model not loaded!")

        import torch
        import cv2

        # Normalize input to [0, 255] uint8. A constant volume has zero
        # intensity range; map it to zeros rather than dividing by zero.
        v_min, v_max = data.min(), data.max()
        if v_max - v_min > 0:
            data_norm = ((data - v_min) / (v_max - v_min) * 255).astype(np.uint8)
        else:
            data_norm = np.zeros(data.shape, dtype=np.uint8)

        D, H, W = data.shape
        pred_volume = np.zeros((D, H, W), dtype=np.uint8)

        self.model.eval()
        with torch.no_grad():
            for i in range(D):
                # Preprocess: resize to width 480 x height 288 and normalize
                inputslice = cv2.resize(data_norm[i], (480, 288), interpolation=cv2.INTER_NEAREST)
                transform_input = (inputslice.astype(np.float32) / 255.0 - 0.456) / 0.224
                # Replicate the grey slice into the 3 channels the model expects
                stacked = np.stack([transform_input] * 3, axis=0)
                input_tensor = torch.from_numpy(stacked).float().unsqueeze(0).to(self.device)
                output = self.model(input_tensor)
                # argmax over logits equals argmax over softmax probabilities
                pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()
                # Resize the label map back to the original slice shape
                pred_volume[i] = cv2.resize(
                    pred.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST
                )

        # Class 1 is mitochondria
        return (pred_volume == 1).astype(np.uint8)
class SXTISGSegmenter(BaseSegmenter):
    """SXT ISG (Insulin Granule) Segmentation using U-Net."""

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sxt', task='isg', device=device)
        self.model = None

    def _load_model_impl(self, path: str = None, **kwargs):
        """Load the SXT ISG U-Net weights.

        Args:
            path: Checkpoint file; defaults to ``DEFAULT_SXT_ISG_MODEL``.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if path is None:
            path = DEFAULT_SXT_ISG_MODEL
        if not os.path.exists(path):
            raise FileNotFoundError(f"SXT ISG model not found at {path}.")

        import torch
        from .segmentation_sxt.model_sxt_isg.unet.unet_model import UNet
        from PIL import Image

        # Single-channel input, 2 output classes
        self.model = UNet(n_channels=1, n_classes=2, bilinear=False)
        self.model.to(self.device)

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint)
        self.model.eval()

        # Kept as an attribute for backward compatibility with existing users.
        self.PIL_Image = Image
        self.model_path = path

    def predict(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Predict a binary ISG (insulin granule) mask.

        Each slice is min-max normalised, zero-padded to a centred square,
        resized to 400x400 for the network, and the prediction is mapped back
        through the inverse of that geometry (resize to the square, then crop
        the padding) so it lines up with the input slice.

        Args:
            data: Input 3D volume (D, H, W)
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            uint8 binary mask with the same shape as ``data``.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.is_loaded or self.model is None:
            raise RuntimeError("Model not loaded! Call load_model() first.")

        import torch
        from PIL import Image
        from skimage.transform import resize

        D, H, W = data.shape
        pred_volume = np.zeros((D, H, W), dtype=np.uint8)

        # Padding geometry is the same for every slice of the volume.
        side = max(H, W)
        top = (side - H) // 2
        left = (side - W) // 2

        self.model.eval()
        with torch.no_grad():
            for z in range(D):
                slice_data = data[z]

                # 1. Per-slice normalization (0-1) written into the padded
                #    square. A constant slice carries no contrast, so it is
                #    left as zeros instead of passing raw values through.
                img_padded = np.zeros((side, side), dtype=np.float32)
                v_min, v_max = slice_data.min(), slice_data.max()
                if v_max - v_min > 0:
                    img_padded[top:top + H, left:left + W] = (
                        (slice_data - v_min) / (v_max - v_min)
                    )

                # 2. Resize to 400x400
                pil_img = Image.fromarray((img_padded * 255).astype(np.uint8))
                pil_img = pil_img.resize((400, 400), resample=Image.BILINEAR)
                input_tensor = (
                    torch.from_numpy(np.asarray(pil_img) / 255.0).float()
                    .unsqueeze(0).unsqueeze(0).to(self.device)
                )

                # 3. Predict
                output = self.model(input_tensor)
                pred_slice = torch.argmax(output, dim=1).squeeze().cpu().numpy()

                # 4. Undo the preprocessing geometry. The label map is cast to
                #    float and resized with preserve_range=True so skimage
                #    does not rescale the integer class ids.
                pred_square = resize(
                    pred_slice.astype(np.float32),
                    (side, side),
                    order=0,
                    preserve_range=True,
                    anti_aliasing=False,
                )
                cropped = pred_square[top:top + H, left:left + W]
                pred_volume[z] = (cropped > 0.5).astype(np.uint8)

        return pred_volume
class SXTISGMaskRCNNSegmenter(BaseSegmenter):
    """SXT ISG Instance Segmentation using Mask R-CNN.

    Runs a torchvision Mask R-CNN (ResNet50-FPN, 2 classes) on every slice and
    labels the detected foreground so that each connected ISG region gets a
    unique ID across the returned volume.
    """

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sxt', task='isg_maskrcnn', device=device)
        self.model = None

    def _load_model_impl(self, path: str = None, **kwargs):
        """Load the Mask R-CNN weights.

        Args:
            path: Checkpoint file; defaults to
                ``DEFAULT_SXT_ISG_MASKRCNN_MODEL``.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if path is None:
            path = DEFAULT_SXT_ISG_MASKRCNN_MODEL
        if not os.path.exists(path):
            raise FileNotFoundError(f"SXT ISG Mask R-CNN model not found at {path}.")

        import torch
        from torchvision.models.detection import maskrcnn_resnet50_fpn

        # 2 classes: background + ISG
        self.model = maskrcnn_resnet50_fpn(weights=None, num_classes=2)
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        state_dict = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.model_path = path

    def predict(self, data: np.ndarray, threshold: float = 0.5, **kwargs) -> np.ndarray:
        """Predict an ISG instance mask for a 3D volume by processing 2D slices.

        Args:
            data: Input 3D volume (D, H, W)
            threshold: Confidence threshold for detection
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Labeled uint16 3D mask where each ISG has a unique integer ID.
            Instances are labeled per slice; IDs are not linked across slices.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded! Call load_model() first.")

        import torch
        from scipy import ndimage

        D, H, W = data.shape
        # For true 3D instance consistency a 3D model would be needed; here
        # each slice is segmented in 2D and labeled independently.
        final_3d_mask = np.zeros((D, H, W), dtype=np.uint16)
        global_inst_id = 1

        self.model.eval()
        with torch.no_grad():
            for z in range(D):
                slice_data = data[z]

                # Normalize to 0-1; a constant slice has no contrast and is
                # fed to the network as zeros.
                v_min, v_max = slice_data.min(), slice_data.max()
                if v_max - v_min > 0:
                    slice_norm = (slice_data - v_min) / (v_max - v_min)
                else:
                    slice_norm = np.zeros(slice_data.shape, dtype=np.float32)

                # Mask R-CNN expects a 3-channel image tensor
                img_tensor = torch.from_numpy(
                    np.stack([slice_norm] * 3, axis=0)
                ).float().to(self.device)

                prediction = self.model([img_tensor])[0]
                scores = prediction['scores'].cpu().numpy()
                # 'masks' is (N, 1, H, W); drop the channel axis but keep all
                # N detections. An empty result stays a valid (0, H, W) array.
                masks = prediction['masks'].cpu().numpy()[:, 0]

                keep = scores > threshold
                if not keep.any():
                    continue

                # Combine all high-confidence masks into one binary slice and
                # label its connected components.
                binary_slice = (masks[keep].max(axis=0) > 0.5).astype(np.uint8)
                labeled_slice, n = ndimage.label(binary_slice)
                if n == 0:
                    continue

                # Shift the slice-local labels into the global ID range
                foreground = labeled_slice > 0
                final_3d_mask[z][foreground] = labeled_slice[foreground] + (global_inst_id - 1)
                global_inst_id += n

        return final_3d_mask
class SIMERSegmenter(BaseSegmenter):
    """SIM Endoplasmic Reticulum Segmentation using ERNet."""

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sim', task='er', device=device)
        # Initialised here, like the other model-backed segmenters, so that
        # predict() can detect a missing model instead of failing on an
        # undefined attribute.
        self.model = None

    def _load_model_impl(self, path: str = None, image_size: int = 1000, **kwargs):
        """Load the ERNet model.

        Args:
            path: Checkpoint file; defaults to ``DEFAULT_SIM_ER_MODEL``.
            image_size: Stored on the instance as ``self.image_size``.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if path is None:
            path = DEFAULT_SIM_ER_MODEL
        if not os.path.exists(path):
            raise FileNotFoundError(f"SIM ER model not found at {path}. Please provide a valid model path.")

        import torch
        from types import SimpleNamespace
        from .segmentation_sim_wfm.ERNet.models import GetModel
        from .segmentation_sim_wfm.ERNet.datahandler import toTensor

        # Configuration consumed by ERNet's GetModel
        opt = SimpleNamespace(
            model='rcan',
            nch_in=1,
            nch_out=2,
            n_resgroups=5,
            n_resblocks=10,
            n_feats=64,
            reduction=16,
            narch=0,
            multigpu=False,
            cpu=(self.device.type == 'cpu'),
            undomulti=False,
        )

        self.model = GetModel(opt)
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        # Accept both wrapped checkpoints ({'state_dict': ...}) and bare
        # state dicts.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            raw_state = checkpoint['state_dict']
        else:
            raw_state = checkpoint

        # Handle DataParallel wrapper: strip a leading 'module.' from keys
        from collections import OrderedDict
        new_state_dict = OrderedDict(
            (key[7:] if key.startswith('module.') else key, value)
            for key, value in raw_state.items()
        )
        self.model.load_state_dict(new_state_dict)
        self.model.to(self.device)
        self.model.eval()

        self.toTensor = toTensor
        self.image_size = image_size

    def predict(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Predict ER mask.

        Args:
            data: Input image (H, W) or (H, W, C)
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            uint8 array obtained by thresholding the squeezed network output
            at 0.5.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        import torch

        if not self.is_loaded or self.model is None:
            raise RuntimeError("Model not loaded! Call load_model() first.")

        # Preprocess
        data_normalized = self._normalize_data(data)

        # Convert to tensor, add a batch axis and move to device
        input_tensor = self.toTensor(data_normalized)
        input_tensor = input_tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(input_tensor)

        # NOTE(review): the model is configured with nch_out=2, so the
        # squeezed output presumably has a leading channel axis of size 2 and
        # the returned mask keeps it. Confirm whether callers expect a
        # per-channel threshold or an argmax over channels.
        mask = output.squeeze().cpu().numpy()
        return (mask > 0.5).astype(np.uint8)
class SIMMitoSegmenter(BaseSegmenter):
    """SIM Mitochondria Segmentation using thresholding-based approach.

    Implements a thresholding-based pipeline inspired by Lefebvre et al. (2021):
    1. Background correction (Gaussian blur + subtraction)
    2. Global thresholding (Otsu unless a threshold is supplied)
    3. Morphological filtering (opening/closing, hole filling)
    4. Connected-component labeling with size filtering
    """

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sim', task='mito', device=device)
        # Thresholding-based method, no model needed
        self.is_loaded = True

    def predict(self, data: np.ndarray, threshold: Optional[float] = None,
                min_size: int = 50, max_size: Optional[int] = None,
                sigma_bg: float = 20.0, **kwargs) -> np.ndarray:
        """Segment mitochondria using the thresholding-based pipeline.

        Args:
            data: Input image (H, W) or (D, H, W)
            threshold: Segmentation threshold applied to the background-
                corrected image rescaled to [0, 1] (Otsu if None)
            min_size: Minimum object size to keep (default: 50 pixels)
            max_size: Maximum object size (None for no limit)
            sigma_bg: Gaussian sigma for background estimation (default: 20)
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Binary uint8 segmentation mask (same shape as input)
        """
        from skimage import filters, morphology, measure
        from scipy import ndimage

        data_float = data.astype(np.float32)

        # Step 1: background correction - subtract a heavily blurred copy
        background = filters.gaussian(data_float, sigma=sigma_bg)
        corrected = data_float - background

        # Normalize to [0, 1]
        corrected = corrected - corrected.min()
        peak = corrected.max()
        if peak > 0:
            corrected = corrected / peak

        # Step 2: thresholding
        if threshold is None:
            threshold = filters.threshold_otsu(corrected)
        binary = corrected > threshold

        # Step 3: morphology, with the structuring element chosen by
        # dimensionality (disk for 2D, ball otherwise)
        if data.ndim == 2:
            elem_open = morphology.disk(1)
            elem_close = morphology.disk(2)
        else:
            elem_open = morphology.ball(1)
            elem_close = morphology.ball(2)

        # Opening removes small noise; closing fills small gaps and smooths
        # boundaries; then enclosed holes are filled.
        binary = morphology.binary_opening(binary, elem_open)
        binary = morphology.binary_closing(binary, elem_close)
        binary = ndimage.binary_fill_holes(binary)

        # Step 4: size filtering. Component sizes are counted in a single pass
        # and applied through a per-label lookup table, instead of rescanning
        # the whole image once per component.
        labeled = measure.label(binary)
        sizes = np.bincount(labeled.ravel(), minlength=1)
        keep = sizes >= min_size
        if max_size is not None:
            keep &= sizes <= max_size
        keep[0] = False  # label 0 is background
        return keep[labeled].astype(np.uint8)
class SIMISGSegmenter(BaseSegmenter):
    """SIM ISG Instance Segmentation using intensity-based watershed.

    Pipeline for detecting individual insulin granules:
    1. Spot enhancement with a Laplacian-of-Gaussian (LoG) filter
    2. Local-maxima detection on the LoG response for seed points
    3. Marker-controlled watershed for instance separation
    4. Size filtering and consecutive relabeling
    """

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sim', task='isg', device=device)
        self.is_loaded = True  # No model needed

    def predict(self, data: np.ndarray, sigma: float = 2.0,
                threshold_rel: float = 0.1, min_size: int = 10, **kwargs) -> np.ndarray:
        """Segment individual ISGs.

        Args:
            data: Input image (H, W) or (D, H, W)
            sigma: Sigma for LoG filter (controls spot size)
            threshold_rel: Factor applied to the Otsu threshold of the LoG
                response to obtain the absolute peak threshold
            min_size: Minimum object size in pixels
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Labeled uint16 mask where each ISG has a unique integer ID
        """
        from skimage import filters, segmentation
        from skimage.feature import peak_local_max
        from scipy import ndimage

        data_float = data.astype(np.float32)
        if data_float.max() > 0:
            data_float = data_float / data_float.max()

        # 1. Enhance spots using LoG. Bright blobs give a negative LoG
        #    response, so negate it and rescale to [0, 1].
        spots = -ndimage.gaussian_laplace(data_float, sigma=sigma)
        spots = spots - spots.min()
        if spots.max() > 0:
            spots = spots / spots.max()

        # 2. Find markers (seeds) using local maxima. min_distance is kept at
        #    1 or more so a small sigma cannot turn every pixel into a peak.
        threshold = filters.threshold_otsu(spots) * threshold_rel
        coordinates = peak_local_max(
            spots, min_distance=max(1, int(sigma * 2)), threshold_abs=threshold
        )

        markers = np.zeros_like(data_float, dtype=bool)
        if len(coordinates) > 0:
            # One index array per axis; works for both 2D and 3D inputs
            markers[tuple(coordinates.T)] = True
        markers = ndimage.label(markers)[0]

        # 3. Watershed on the inverted intensity, restricted to pixels whose
        #    normalised intensity exceeds 0.1
        elevation_map = 1 - data_float
        labeled_mask = segmentation.watershed(elevation_map, markers, mask=data_float > 0.1)

        # 4. Filter by size and relabel consecutively in ascending label
        #    order, using one counting pass and a lookup table.
        sizes = np.bincount(labeled_mask.ravel(), minlength=1)
        keep = (sizes >= min_size) & (sizes > 0)
        keep[0] = False  # label 0 is background
        lookup = np.zeros(sizes.shape, dtype=np.uint16)
        lookup[keep] = np.arange(1, np.count_nonzero(keep) + 1)
        return lookup[labeled_mask]
class WFMSegmenter(BaseSegmenter):
    """General-purpose organelle segmentation for widefield microscopy."""

    def __init__(self, task: str = 'cell', device: Optional[str] = None):
        super().__init__(modality='wfm', task=task, device=device)

    def predict(self, data: np.ndarray, threshold: Optional[float] = None,
                min_size: int = 100, **kwargs) -> np.ndarray:
        """Segment organelles in WFM images with classical image processing.

        Args:
            data: Input image or volume. 4D and 5D inputs are reduced to a
                single slice before processing.
            threshold: Threshold on the normalised data (Otsu if None).
            min_size: Minimum component size to keep.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Binary uint8 mask with the shape of the (possibly reduced) input.
        """
        from skimage import filters, morphology, measure
        from scipy import ndimage

        # Reduce higher-dimensional stacks (e.g. 5D time series) to one slice
        ndim = data.ndim
        if ndim > 3:
            print(f"Warning: Input has {data.ndim} dimensions. Extracting t=0 slice...")
            if ndim == 5:
                data = data[0, data.shape[1] // 2, 0]
            elif ndim == 4:
                data = data[0, 0]

        normalized = self._normalize_data(data)

        # Fall back to Otsu when the caller gives no threshold
        cutoff = filters.threshold_otsu(normalized) if threshold is None else threshold
        foreground = normalized > cutoff

        # Morphological cleaning: opening, then hole filling
        if data.ndim == 3:
            footprint = morphology.ball(1)
        else:
            footprint = morphology.disk(1)
        foreground = morphology.binary_opening(foreground, footprint)
        foreground = ndimage.binary_fill_holes(foreground)

        # Keep only the connected components that are large enough
        components = measure.label(foreground)
        kept_labels = [
            region.label
            for region in measure.regionprops(components)
            if region.area >= min_size
        ]
        return np.isin(components, kept_labels).astype(np.uint8)
class ETFilamentSegmenter(BaseSegmenter):
    """Cryo-ET Filament Segmentation and Skeletonization."""

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='et', task='filament', device=device)

    def predict(self, data: np.ndarray, skeletonize: bool = True,
                threshold_multiplier: float = 1.2, **kwargs) -> Dict[str, np.ndarray]:
        """Segment and skeletonize filaments in Cryo-ET data.

        Args:
            data: Input 3D volume (D, H, W)
            skeletonize: Whether to extract skeleton (default: True)
            threshold_multiplier: Multiplier for automatic threshold (default: 1.2)
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            Dictionary with 'mask' and, when ``skeletonize`` is true,
            'skeleton' (both uint8).
        """
        from .segmentation_et.segment_et import skeletonization_et_segmentation

        data_normalized = self._normalize_data(data)

        # Segment using threshold multiplier
        mask = skeletonization_et_segmentation(
            data_normalized,
            threshold_multiplier=threshold_multiplier
        )
        result = {'mask': mask.astype(np.uint8)}

        if skeletonize:
            # Imported under an alias so the function does not rebind the
            # boolean ``skeletonize`` parameter.
            from skimage.morphology import skeletonize as _skeletonize
            result['skeleton'] = _skeletonize(mask > 0).astype(np.uint8)

        return result
class ETMembraneSegmenter(BaseSegmenter):
    """Threshold-based membrane segmentation for Cryo-ET volumes."""

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='et', task='membrane', device=device)

    def predict(self, data: np.ndarray, threshold: Optional[float] = None,
                **kwargs) -> np.ndarray:
        """Segment membranes by thresholding the normalised volume.

        Args:
            data: Input 3D volume (D, H, W)
            threshold: Threshold on the normalised data; Otsu's method is
                used when None
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            uint8 membrane mask with components smaller than 100 removed.
        """
        # Plain thresholding for now; more elaborate methods can replace it.
        normalized = self._normalize_data(data)

        cutoff = threshold
        if cutoff is None:
            from skimage.filters import threshold_otsu
            cutoff = threshold_otsu(normalized)

        foreground = normalized > cutoff

        # Discard small connected components
        from skimage.morphology import remove_small_objects
        cleaned = remove_small_objects(foreground, min_size=100)
        return cleaned.astype(np.uint8)
class SXTMitoInstanceSegmenter(BaseSegmenter):
    """SXT Mitochondria Instance Segmentation.

    Runs the semantic mitochondria segmenter, then splits every connected
    component of its mask into instances with the ``blob_fit`` separation
    routine.
    """

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sxt', task='mito_instance', device=device)
        self.semantic_segmenter = SXTMitoSegmenter(device=device)

    def _load_model_impl(self, path: str = None, **kwargs):
        """Load the underlying semantic model."""
        self.semantic_segmenter.load_model(path=path, **kwargs)
        self.is_loaded = True

    def predict(self, data: np.ndarray, min_voxels: int = 50, **kwargs) -> np.ndarray:
        """Predict a mitochondria instance mask using blob separation.

        Args:
            data: Input 3D volume (D, H, W)
            min_voxels: Minimum number of voxels a connected component must
                have to be processed (default: 50)
            **kwargs: Additional parameters passed to semantic segmenter

        Returns:
            Labeled uint16 mask where each mitochondrion has a unique integer ID
        """
        from .segmentation_sxt.blob_separation import blob_fit
        from scipy import ndimage

        # 1. Get semantic mask
        semantic_mask = self.semantic_segmenter.predict(data, **kwargs)

        # Prepare raw data (normalize to 0-1 range)
        raw_data = data.astype(np.float32)
        v_min, v_max = raw_data.min(), raw_data.max()
        if v_max - v_min > 0:
            raw_data = (raw_data - v_min) / (v_max - v_min)
        else:
            raw_data = np.zeros_like(raw_data)

        # 2. Get initial connected components
        labeled, num_features = ndimage.label(semantic_mask)
        final_mask = np.zeros_like(semantic_mask, dtype=np.uint16)
        if num_features == 0:
            return final_mask

        # Component sizes and bounding boxes come from one pass each, rather
        # than scanning the full volume once per component.
        sizes = np.bincount(labeled.ravel(), minlength=num_features + 1)
        bounding_boxes = ndimage.find_objects(labeled)

        current_global_id = 1
        pad = 5
        print(f"Starting mito instance separation for {num_features} components...")

        # 3. Process each component with blob separation
        for local_id, bbox in enumerate(bounding_boxes, start=1):
            if bbox is None or sizes[local_id] < min_voxels:
                continue

            # Bounding box grown by ``pad`` voxels and clipped to the volume
            roi = tuple(
                slice(max(0, axis_slice.start - pad), min(extent, axis_slice.stop + pad))
                for axis_slice, extent in zip(bbox, labeled.shape)
            )

            # Restrict the ROI mask to this component only. The padded box may
            # overlap neighbouring components; including them would separate
            # them again here and overwrite or duplicate their instance IDs.
            roi_mask = (labeled[roi] == local_id).astype(np.uint8)
            roi_raw = raw_data[roi]

            # Run blob separation (same algorithm as ISG)
            separated_roi = blob_fit(roi_mask, roi_raw, min_r=1.5, theta=0.8, check_=False)

            # Map results back to the global mask through a view of the ROI
            target = final_mask[roi]
            instance_ids = np.unique(separated_roi)
            for lid in instance_ids[instance_ids > 0]:
                target[separated_roi == lid] = current_global_id
                current_global_id += 1

        print(f"Mito instance separation complete. Found {current_global_id - 1} individual mitochondria.")
        return final_mask
class SXTISGInstanceSegmenter(BaseSegmenter):
    """SXT ISG Instance Segmentation using Organelle Separation.

    Runs the semantic ISG segmenter and converts its mask into instances,
    either by plain connected-component labeling with size filtering or by
    splitting each component with the ``blob_fit`` separation routine.
    """

    def __init__(self, device: Optional[str] = None):
        super().__init__(modality='sxt', task='isg_instance', device=device)
        self.semantic_segmenter = SXTISGSegmenter(device=device)

    def _load_model_impl(self, path: str = None, **kwargs):
        """Load the underlying semantic model."""
        self.semantic_segmenter.load_model(path=path, **kwargs)
        self.is_loaded = True

    def predict(self, data: np.ndarray, min_voxels: int = 8, use_advanced_separation: bool = True, **kwargs) -> np.ndarray:
        """Predict an ISG instance mask.

        Args:
            data: Input 3D volume (D, H, W)
            min_voxels: Minimum number of voxels a connected component must
                have to be kept/processed (default: 8)
            use_advanced_separation: Whether to use the advanced blob separation algorithm
            **kwargs: Additional parameters passed to semantic segmenter

        Returns:
            Labeled uint16 mask where each ISG has a unique integer ID
        """
        from scipy import ndimage

        # 1. Get semantic mask and its connected components
        semantic_mask = self.semantic_segmenter.predict(data, **kwargs)
        labeled, num_features = ndimage.label(semantic_mask)
        if num_features == 0:
            return np.zeros_like(semantic_mask, dtype=np.uint16)

        sizes = np.bincount(labeled.ravel(), minlength=num_features + 1)

        if not use_advanced_separation:
            return self._relabel_large_components(labeled, sizes, min_voxels)
        return self._separate_components(data, semantic_mask, labeled, sizes, min_voxels)

    @staticmethod
    def _relabel_large_components(labeled, sizes, min_voxels):
        """Drop components smaller than ``min_voxels`` and renumber the rest 1..K."""
        keep = sizes >= min_voxels
        keep[0] = False  # label 0 is background
        # Lookup table from old label to new consecutive ID (0 for dropped)
        lookup = np.zeros(sizes.shape, dtype=np.uint16)
        lookup[keep] = np.arange(1, np.count_nonzero(keep) + 1)
        return lookup[labeled]

    @staticmethod
    def _separate_components(data, semantic_mask, labeled, sizes, min_voxels):
        """Split every sufficiently large component into instances with ``blob_fit``."""
        from scipy import ndimage
        from .segmentation_sxt.blob_separation import blob_fit

        # Prepare raw data (normalize to 0-1 range)
        raw_data = data.astype(np.float32)
        v_min, v_max = raw_data.min(), raw_data.max()
        if v_max - v_min > 0:
            raw_data = (raw_data - v_min) / (v_max - v_min)
        else:
            raw_data = np.zeros_like(raw_data)

        final_mask = np.zeros_like(semantic_mask, dtype=np.uint16)
        current_global_id = 1
        pad = 5
        num_features = len(sizes) - 1
        print(f"Starting advanced blob separation for {num_features} components...")

        # Bounding boxes for all components in one pass
        for local_id, bbox in enumerate(ndimage.find_objects(labeled), start=1):
            if bbox is None or sizes[local_id] < min_voxels:
                continue

            # Bounding box grown by ``pad`` voxels and clipped to the volume
            roi = tuple(
                slice(max(0, axis_slice.start - pad), min(extent, axis_slice.stop + pad))
                for axis_slice, extent in zip(bbox, labeled.shape)
            )

            # Restrict the ROI mask to this component only. The padded box may
            # overlap neighbouring components; including them would separate
            # them again here and overwrite or duplicate their instance IDs.
            roi_mask = (labeled[roi] == local_id).astype(np.uint8)
            roi_raw = raw_data[roi]

            separated_roi = blob_fit(roi_mask, roi_raw, min_r=1.5, theta=0.8, check_=False)

            # Map results back to the global mask through a view of the ROI
            target = final_mask[roi]
            instance_ids = np.unique(separated_roi)
            for lid in instance_ids[instance_ids > 0]:
                target[separated_roi == lid] = current_global_id
                current_global_id += 1

        print(f"Advanced separation complete. Found {current_global_id - 1} instances.")
        return final_mask
# Factory function for easy access
def create_segmenter(modality: str, task: str, device: Optional[str] = None) -> BaseSegmenter:
    """
    Create a segmenter instance based on modality and task.

    Args:
        modality: 'sxt', 'sim', 'wfm', or 'et' (case-insensitive)
        task: Segmentation task (see table below; case-insensitive)
        device: 'cuda' or 'cpu'

    Returns:
        Segmenter instance

    Raises:
        ValueError: If the modality/task combination is not registered.

    Supported Modality-Task Combinations:
        | Modality | Task          | Description                        |
        |----------|---------------|------------------------------------|
        | sxt      | cell          | Cell and nucleus segmentation      |
        | sxt      | mito          | Mitochondria segmentation          |
        | sxt      | isg           | ISG (insulin granule) seg          |
        | sxt      | isg_instance  | ISG instance segmentation          |
        | sxt      | mito_instance | Mitochondria instance segmentation |
        | sxt      | isg_maskrcnn  | ISG instances via Mask R-CNN       |
        | sim      | er            | ER segmentation (ERNet)            |
        | sim      | mito          | Mitochondria segmentation          |
        | sim      | isg           | ISG instance segmentation          |
        | wfm      | cell          | Cell shape segmentation            |
        | wfm      | nucleus       | Nucleus segmentation               |
        | wfm      | sphere        | Spherical organelles               |
        | wfm      | vesicle       | Vesicle segmentation               |
        | et       | filament      | Filament skeletonization           |
        | et       | membrane      | Membrane segmentation              |

    Example:
        >>> # SXT cell segmentation
        >>> segmenter = create_segmenter('sxt', 'cell')
        >>> segmenter.load_model('model.pth')
        >>> mask = segmenter.predict(data)
        >>>
        >>> # WFM nucleus segmentation (no model needed)
        >>> segmenter = create_segmenter('wfm', 'nucleus')
        >>> mask = segmenter.predict(data)
    """
    key = f"{modality.lower()}_{task.lower()}"

    # Every value is a callable that accepts ``device`` as a keyword: either a
    # segmenter class or a lambda that pins the WFM task.
    segmenters = {
        # SXT
        'sxt_cell': SXTCellSegmenter,
        'sxt_mito': SXTMitoSegmenter,
        'sxt_isg': SXTISGSegmenter,
        'sxt_isg_instance': SXTISGInstanceSegmenter,
        'sxt_mito_instance': SXTMitoInstanceSegmenter,
        'sxt_isg_maskrcnn': SXTISGMaskRCNNSegmenter,
        # SIM
        'sim_er': SIMERSegmenter,
        'sim_mito': SIMMitoSegmenter,
        'sim_isg': SIMISGSegmenter,
        # WFM
        'wfm_cell': lambda device=None: WFMSegmenter(task='cell', device=device),
        'wfm_nucleus': lambda device=None: WFMSegmenter(task='nucleus', device=device),
        'wfm_sphere': lambda device=None: WFMSegmenter(task='sphere', device=device),
        'wfm_vesicle': lambda device=None: WFMSegmenter(task='vesicle', device=device),
        # Cryo-ET
        'et_filament': ETFilamentSegmenter,
        'et_membrane': ETMembraneSegmenter,
    }

    if key not in segmenters:
        available = list(segmenters.keys())
        raise ValueError(
            f"Unsupported modality/task combination: '{modality}/{task}'.\n"
            f"Key: '{key}'\n"
            f"Available combinations: {available}\n\n"
            f"Example usage:\n"
            f"  create_segmenter('sxt', 'cell')\n"
            f"  create_segmenter('sim', 'er')\n"
            f"  create_segmenter('wfm', 'nucleus')\n"
            f"  create_segmenter('et', 'filament')"
        )

    # Classes and lambdas alike take ``device`` by keyword, so a single call
    # form covers every entry. The previous ``hasattr(creator, '__class__')``
    # test was true for every object, so its other branch never ran.
    return segmenters[key](device=device)