# Source code for ipa.processing.segmentation.base

"""
Base class for segmentation methods.

Provides a unified interface for different segmentation tasks
across multiple imaging modalities (SXT, SIM, WFM, Cryo-ET).
"""

from abc import ABC, abstractmethod
import numpy as np
import torch
from typing import Optional, Dict, Any


class BaseSegmenter(ABC):
    """Abstract base class for all segmentation methods.

    Subclasses must implement :meth:`predict`. Model loading, saving and
    training are opt-in: a subclass enables them by defining the private
    hooks ``_load_model_impl(path, **kwargs)``, ``_save_model_impl(path)``
    and ``_train_impl(train_data, val_data, **kwargs)`` respectively. When a
    hook is missing, the corresponding public method raises
    ``NotImplementedError``.
    """

    def __init__(self, modality: str, task: str, device: Optional[str] = None):
        """
        Initialize the segmenter.

        Args:
            modality: Imaging modality ('sxt', 'sim', 'wfm', 'et').
                Stored lower-cased.
            task: Segmentation task ('cell', 'mito', 'er', 'nucleus',
                'filament', etc.). Stored lower-cased.
            device: Device to run on ('cuda' or 'cpu'). Auto-detected if None.
        """
        self.modality = modality.lower()
        self.task = task.lower()

        # Auto-detect device: prefer CUDA when torch reports it is available.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.model = None
        self.is_loaded = False

    @abstractmethod
    def predict(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Predict segmentation mask.

        Args:
            data: Input image/volume data
            **kwargs: Implementation-specific prediction parameters

        Returns:
            Segmentation mask (numpy array)
        """
        pass

    def load_model(self, path: Optional[str] = None, **kwargs):
        """
        Load trained model from file.

        Delegates to the subclass hook ``_load_model_impl`` and marks the
        segmenter as loaded once the hook returns.

        Args:
            path: Path to the saved model. If None, uses default path if available.
            **kwargs: Additional parameters for model loading

        Raises:
            NotImplementedError: If the subclass does not define
                ``_load_model_impl``.
        """
        if not hasattr(self, '_load_model_impl'):
            raise NotImplementedError("Model loading not implemented for this segmenter")

        self._load_model_impl(path, **kwargs)
        self.is_loaded = True

        # A falsy path (None or empty string) is reported as the default path.
        if path:
            print(f"Model loaded from: {path}")
        else:
            print("Model loaded using default path.")

    def save_model(self, path: str):
        """
        Save trained model to file.

        Args:
            path: Path to save the model

        Raises:
            RuntimeError: If no model has been loaded or trained yet.
            NotImplementedError: If the subclass does not define
                ``_save_model_impl``.
        """
        # The loaded-state check comes first, so an unloaded segmenter always
        # reports RuntimeError even when saving is not implemented.
        if not self.is_loaded:
            raise RuntimeError("No model loaded to save!")

        if not hasattr(self, '_save_model_impl'):
            raise NotImplementedError("Model saving not implemented for this segmenter")

        self._save_model_impl(path)
        print(f"Model saved to: {path}")

    def train(self, train_data: Any, val_data: Optional[Any] = None, **kwargs):
        """
        Train the segmentation model (optional).

        Delegates to the subclass hook ``_train_impl`` and marks the
        segmenter as loaded once the hook returns.

        Args:
            train_data: Training data
            val_data: Optional validation data
            **kwargs: Training parameters

        Raises:
            NotImplementedError: If the subclass does not define
                ``_train_impl``.
        """
        if not hasattr(self, '_train_impl'):
            raise NotImplementedError("Training not implemented for this segmenter")

        self._train_impl(train_data, val_data, **kwargs)
        self.is_loaded = True
        print("Training completed!")

    def _normalize_data(self, data: np.ndarray) -> np.ndarray:
        """
        Min-max normalize data to the [0, 1] range.

        The input is always cast to float32. Rescaling happens only when the
        maximum is strictly greater than the minimum; otherwise (e.g. constant
        input) the float32 copy is returned without rescaling.

        Args:
            data: Input data

        Returns:
            Normalized float32 data with the same shape as the input
        """
        data = data.astype(np.float32)

        # min()/max() raise ValueError on zero-size arrays; there is nothing
        # to rescale in that case, so hand back the empty float32 array.
        if data.size == 0:
            return data

        min_val = data.min()
        max_val = data.max()

        # Guard against division by zero for constant input.
        if max_val > min_val:
            data = (data - min_val) / (max_val - min_val)

        return data

    def _ensure_channel_dim(self, data: np.ndarray, n_channels: int = 1) -> np.ndarray:
        """
        Ensure data has channel dimension.

        ``n_channels`` only selects where the channel axis goes (leading when
        it equals 1, trailing otherwise); the data is never replicated or
        validated against it.

        Args:
            data: Input data
            n_channels: Number of channels

        Returns:
            Data with channel dimension:
            2D input becomes (1, H, W) when ``n_channels == 1`` and (H, W, 1)
            otherwise; 3D input becomes (1, D, H, W) when ``n_channels == 1``
            and is returned unchanged otherwise (assumed already (H, W, C));
            any other rank is returned unchanged.
        """
        if data.ndim == 2:
            # (H, W) -> (1, H, W) or (H, W, 1)
            if n_channels == 1:
                return data[np.newaxis, ...]
            return data[..., np.newaxis]

        if data.ndim == 3:
            # Could be (D, H, W) or (H, W, C)
            if n_channels == 1:
                return data[np.newaxis, ...]  # (1, D, H, W)
            return data  # Assume already (H, W, C)

        return data

    def get_info(self) -> Dict[str, str]:
        """Get segmenter information as a dict of strings."""
        return {
            'modality': self.modality,
            'task': self.task,
            'device': str(self.device),
            'is_loaded': str(self.is_loaded),
        }