"""
Cellular Partitioning Module
This module implements radial cytoplasmic partitioning analysis, a novel approach for
spatial zoning of cellular structures from nucleus to plasma membrane.
Through creating concentric radial partitions, it enables quantitative analysis of
organelle distribution patterns across different cytoplasmic regions.
"""
import os
import numpy as np
import glob
from scipy import ndimage
import matplotlib.pyplot as plt
from tqdm import tqdm
# Add performance optimization dependencies
try:
import torch
TORCH_AVAILABLE = True
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
except ImportError:
TORCH_AVAILABLE = False
DEVICE = torch.device('cpu')
print("Warning: PyTorch not available. Falling back to CPU processing.")
try:
from joblib import Parallel, delayed
JOBLIB_AVAILABLE = True
except ImportError:
JOBLIB_AVAILABLE = False
print("Warning: Joblib not available. Parallel processing will be disabled.")
[docs]
class Partitioning:
"""
This class provides a comprehensive toolkit for cellular partitioning analysis,
including boundary extraction, partition generation, feature calculation, and
visualization. Primarily used for analyzing organelle spatial distribution
between nuclear envelope and plasma membrane.
Parameters
root_dir : str
Root directory path for result outputs
n_slices : int, default=8
Number of radial partitions, dividing nucleus-membrane distance into n_slices regions
num_cores : int, optional
Number of CPU cores for parallel processing, defaults to 1
Attributes
root_dir : str
Output directory path
n_slices : int
Number of partitions
num_cores : int
Number of CPU cores
"""
def __init__(self, root_dir, n_slices=8, num_cores=None):
"""
Initialize Partitioning instance.
:param root_dir: Directory to save outputs.
:param n_slices: Number of shell partitions.
:param num_cores: Number of CPU cores for parallel processing.
"""
self.root_dir = root_dir
self.n_slices = n_slices
self.num_cores = num_cores
os.makedirs(root_dir, exist_ok=True)
os.makedirs(os.path.join(root_dir, 'results'), exist_ok=True)
[docs]
def load_mask_data(self, ne_path, pm_path):
"""
Load nuclear envelope (NE) and plasma membrane (PM) mask data from files
Supports multiple file formats including numpy arrays (.npy), TIFF images
(.tif/.tiff), MRC format (.mrc), and other common biological imaging formats.
Parameters
----------
ne_path : str
Path to nuclear envelope mask file
pm_path : str
Path to plasma membrane mask file
Returns
-------
tuple of numpy.ndarray
(ne_mask, pm_mask) - Binary mask arrays for nuclear envelope and plasma membrane
Raises
------
FileNotFoundError
When specified file paths do not exist
ValueError
When file format is unsupported or data format is incorrect
Examples
--------
>>> ne_mask, pm_mask = partitioner.load_mask_data("nuclear.npy", "membrane.npy")
>>> print(f"NE mask shape: {ne_mask.shape}, PM mask shape: {pm_mask.shape}")
"""
# Implement loading logic here (e.g., using tifffile or mrcfile)
# Placeholder:
ne_mask = np.load(ne_path)
pm_mask = np.load(pm_path)
return ne_mask, pm_mask
[docs]
@staticmethod
def smooth_mask(mask: np.ndarray, sigma: float = 1.0, close_iter: int = 2) -> np.ndarray:
"""
Smooth binary mask using Gaussian filtering and morphological closing.
Parameters
----------
mask : numpy.ndarray
Input binary mask.
sigma : float
Sigma for Gaussian filter. Higher values produce smoother boundaries.
close_iter : int
Number of iterations for morphological closing to fill small gaps.
Returns
-------
numpy.ndarray
Smoothed binary mask.
"""
from scipy.ndimage import gaussian_filter, binary_closing
# 1. Gaussian smoothing on float mask
smoothed = gaussian_filter(mask.astype(float), sigma=sigma)
# 2. Threshold back to binary
binary_smoothed = smoothed > 0.5
# 3. Morphological closing to bridge small gaps
if close_iter > 0:
binary_smoothed = binary_closing(binary_smoothed, iterations=close_iter)
return binary_smoothed.astype(np.uint8)
[docs]
@staticmethod
def min_angle(ne_points, pm_point, center):
"""
Find nuclear envelope point with minimal angle to plasma membrane point relative to center
Calculates cosine angles between vectors to find the directionally closest
nuclear envelope point. Used for establishing optimal NE-PM point pair matching.
Parameters
----------
ne_points : numpy.ndarray
Nuclear envelope point coordinates array, shape (N, 3)
pm_point : numpy.ndarray
Single plasma membrane point coordinates, shape (3,)
center : numpy.ndarray
Cell center coordinates, shape (3,)
Returns
-------
numpy.ndarray
Nuclear envelope point coordinates with minimal angle, shape (3,)
Notes
-----
Uses cosine similarity to calculate vector angles:
cos(angle) = dot product of vectors divided by product of their magnitudes
Selects the point with maximum cosine value, i.e., minimal angle.
Examples
--------
>>> closest_ne = Partitioning.min_angle(ne_points, pm_point, center)
>>> print(f"Closest NE point: {closest_ne}")
"""
vectors_ne = ne_points - center
vector_pm = pm_point - center
# Vectorized calculation of all angles
norms_ne = np.linalg.norm(vectors_ne, axis=1)
norm_pm = np.linalg.norm(vector_pm)
# Avoid division by zero
valid_mask = norms_ne > 1e-8
cos_angles = np.full(len(vectors_ne), -1.0) # Initialize with minimum value
if np.any(valid_mask):
cos_angles[valid_mask] = np.dot(vectors_ne[valid_mask], vector_pm) / (norms_ne[valid_mask] * norm_pm + 1e-8)
max_idx = np.argmax(cos_angles)
return ne_points[max_idx]
[docs]
def find_ne_pm_pairs(self, center, ne_edge, pm_edge,
step=8, save_txt=False, dataid=None):
"""
Pair nuclear envelope and plasma membrane boundary points using GPU acceleration.
"""
print(f"Original NE edge points: {len(ne_edge)}")
print(f"Original PM edge points: {len(pm_edge)}")
# Sampling control
target_ne_points = min(50000, len(ne_edge))
if len(ne_edge) > target_ne_points:
ne_step = len(ne_edge) // target_ne_points
ne_sampled = ne_edge[::ne_step]
else:
ne_sampled = ne_edge
# Limit PM points to avoid memory issues (max 100k for GPU batch processing)
target_pm_points = min(len(ne_sampled) * 8, 100000, len(pm_edge))
if len(pm_edge) > target_pm_points:
pm_step = len(pm_edge) // target_pm_points
pm_sampled = pm_edge[::pm_step]
else:
pm_sampled = pm_edge
print(f"Sampled NE: {len(ne_sampled)}, PM: {len(pm_sampled)}")
if TORCH_AVAILABLE and DEVICE.type == 'cuda':
return self._find_pairs_gpu(center, ne_sampled, pm_sampled, save_txt, dataid)
else:
return self._find_pairs_cpu(center, ne_sampled, pm_sampled, save_txt, dataid)
def _find_pairs_gpu(self, center, ne_sampled, pm_sampled, save_txt, dataid):
"""GPU-accelerated pairing using PyTorch with dynamic batch sizing."""
import torch
ne_tensor = torch.tensor(ne_sampled, dtype=torch.float32, device=DEVICE)
pm_tensor = torch.tensor(pm_sampled, dtype=torch.float32, device=DEVICE)
center_tensor = torch.tensor(center, dtype=torch.float32, device=DEVICE)
# Normalize vectors for angle matching
pm_vectors = pm_tensor - center_tensor
pm_norms = torch.norm(pm_vectors, dim=1, keepdim=True)
pm_unit = pm_vectors / (pm_norms + 1e-8)
ne_vectors = ne_tensor - center_tensor
ne_norms = torch.norm(ne_vectors, dim=1, keepdim=True)
ne_unit = ne_vectors / (ne_norms + 1e-8)
# Dynamic batch size calculation based on GPU memory
num_pm = len(pm_unit)
# Estimate memory per row: num_pm * 4 bytes (float32) * 2 (for sim_matrix and overhead)
# Leave some margin (e.g., use 50% of available memory for this matrix)
try:
free_mem = torch.cuda.mem_get_info()[0] # Available memory in bytes
target_mem = free_mem * 0.4 # Use 40% of free memory
max_rows = int(target_mem / (num_pm * 4))
batch_size = min(max(100, max_rows), 5000) # Clamp between 100 and 5000
except:
batch_size = 1000 # Fallback fixed value
print(f"[GPU] Using batch size: {batch_size} (PM points: {num_pm})")
matched_indices_list = []
for i in range(0, len(ne_unit), batch_size):
ne_batch = ne_unit[i:i+batch_size]
# Compute similarity matrix for this batch
sim_matrix = torch.mm(ne_batch, pm_unit.t())
# Find best match for each NE point in batch
best_pm_indices = torch.argmax(sim_matrix, dim=1)
matched_indices_list.append(best_pm_indices)
# Clear cache to prevent fragmentation
if i % (batch_size * 5) == 0:
torch.cuda.empty_cache()
# Concatenate all matched indices
all_best_pm_indices = torch.cat(matched_indices_list, dim=0)
pm_matched = pm_tensor[all_best_pm_indices]
pairs_array = torch.cat([ne_tensor, pm_matched], dim=1).cpu().numpy()
# Distance filtering
dists = np.linalg.norm(pairs_array[:, :3] - pairs_array[:, 3:6], axis=1)
valid_mask = (dists >= 15.0) & (dists <= 200.0)
pairs = pairs_array[valid_mask]
if save_txt and dataid:
os.makedirs(f"{self.root_dir}/results", exist_ok=True)
np.savetxt(f"{self.root_dir}/results/{dataid}_pairs.txt", pairs, fmt='%.2f')
print(f"[GPU] Generated {len(pairs)} pairs")
return pairs
def _find_pairs_cpu(self, center, ne_sampled, pm_sampled, save_txt, dataid):
"""Fallback CPU pairing using KDTree."""
from scipy.spatial import cKDTree
pm_vectors = pm_sampled - center
pm_norms = np.linalg.norm(pm_vectors, axis=1)
pm_unit_vectors = pm_vectors / (pm_norms[:, np.newaxis] + 1e-8)
ne_vectors = ne_sampled - center
ne_norms = np.linalg.norm(ne_vectors, axis=1)
ne_unit_vectors = ne_vectors / (ne_norms[:, np.newaxis] + 1e-8)
pm_tree = cKDTree(pm_unit_vectors)
distances, indices = pm_tree.query(ne_unit_vectors, k=1)
pm_matched = pm_sampled[indices]
pairs_array = np.column_stack([ne_sampled, pm_matched])
euclidean_dists = np.linalg.norm(pairs_array[:, :3] - pairs_array[:, 3:6], axis=1)
valid_mask = (euclidean_dists >= 15.0) & (euclidean_dists <= 200.0)
pairs = pairs_array[valid_mask]
if save_txt and dataid:
os.makedirs(f"{self.root_dir}/results", exist_ok=True)
np.savetxt(f"{self.root_dir}/results/{dataid}_pairs.txt", pairs, fmt='%.2f')
print(f"[CPU] Generated {len(pairs)} pairs")
return pairs
[docs]
@staticmethod
def points_along_vector(p1, p2, n_slices):
"""
Generate equally spaced division points between two points
Creates n_slices+1 equally spaced points along the vector from p1 to p2,
including start and end points. These points are used to create boundary
surfaces for radial partitions.
Parameters
----------
p1 : numpy.ndarray
Starting point coordinates, shape (3,)
p2 : numpy.ndarray
Ending point coordinates, shape (3,)
n_slices : int
Number of division segments
Returns
-------
numpy.ndarray
Equally spaced points array, shape (n_slices+1, 3)
Examples
--------
>>> points = Partitioning.points_along_vector([0,0,0], [10,10,10], 4)
>>> print(f"Generated {len(points)} division points")
"""
return np.array([p1 + (p2 - p1) * i / n_slices for i in range(n_slices + 1)])
[docs]
def shell_partition(self, pairs, cell_mask=None):
"""
Generate shell partition points from paired NE-PM points with adaptive partitioning
Creates equally spaced partition points along the connection line for each
NE-PM point pair, with adaptive subdivision for better handling of irregular cell shapes.
Parameters
----------
pairs : numpy.ndarray
Paired points array, shape (K, 6)
cell_mask : numpy.ndarray, optional
Binary mask indicating the active cell regions, shape (H, W, D)
Returns
-------
list of numpy.ndarray
Shell points list, each element is a coordinate array of all points in one shell
Notes
-----
Optimized version with adaptive partitioning for irregular cell shapes
"""
print(f"Creating {self.n_slices} adaptive shells from {len(pairs)} pairs...")
all_points = []
pair_distances = []
# Calculate distances for adaptive partitioning
for pair in pairs:
ne_pt = pair[:3]
pm_pt = pair[3:]
distance = np.linalg.norm(ne_pt - pm_pt)
pair_distances.append(distance)
pair_distances = np.array(pair_distances)
median_distance = np.median(pair_distances)
print(f"Pair distance statistics: min={np.min(pair_distances):.2f}, "
f"median={median_distance:.2f}, max={np.max(pair_distances):.2f}")
for i, pair in enumerate(pairs):
ne_pt = pair[:3]
pm_pt = pair[3:]
current_distance = pair_distances[i]
# Adaptive subdivision based on distance
if current_distance < median_distance * 0.5:
# For very short distances (indentations), use fewer subdivisions
adaptive_n_slices = max(3, self.n_slices // 2)
elif current_distance > median_distance * 2.0:
# For very long distances, use more subdivisions
adaptive_n_slices = self.n_slices + 2
else:
# Normal distances use standard subdivisions
adaptive_n_slices = self.n_slices
# Generate adaptive subdivision points
slice_pts = self.points_along_vector(ne_pt, pm_pt, adaptive_n_slices)
all_points.append((slice_pts, adaptive_n_slices))
# Organize points into shells with weighted contribution
shell_points = [[] for _ in range(self.n_slices + 1)]
for pts_info in all_points:
slice_pts, adaptive_n_slices = pts_info
# Map adaptive points to standard shell indices
for i, pt in enumerate(slice_pts):
# Calculate relative position (0 to 1)
relative_pos = i / adaptive_n_slices
# Map to standard shell index
shell_idx = min(int(relative_pos * self.n_slices), self.n_slices)
shell_points[shell_idx].append(pt)
cleaned_shells = []
for i, shell in enumerate(shell_points):
if len(shell) == 0:
cleaned_shells.append(np.array([]))
continue
shell_arr = np.unique(np.array(shell), axis=0)
if cell_mask is not None:
# Use vectorized operations to filter points
shell_int = shell_arr.astype(int)
valid_mask = (
(shell_int[:, 0] >= 0) & (shell_int[:, 0] < cell_mask.shape[0]) &
(shell_int[:, 1] >= 0) & (shell_int[:, 1] < cell_mask.shape[1]) &
(shell_int[:, 2] >= 0) & (shell_int[:, 2] < cell_mask.shape[2])
)
if np.any(valid_mask):
valid_shell_int = shell_int[valid_mask]
mask_values = cell_mask[valid_shell_int[:, 0], valid_shell_int[:, 1], valid_shell_int[:, 2]]
final_mask = mask_values > 0
shell_arr = shell_arr[valid_mask][final_mask]
else:
shell_arr = np.array([])
cleaned_shells.append(shell_arr)
print(f"Shell {i}: {len(shell_arr)} points")
print(f"Generated {len(cleaned_shells)} adaptive shells total")
return cleaned_shells
def _convert_shells_to_partition_mask_pure_pairs(self, shells, shape, pm_mask=None, ne_mask=None):
"""
Convert shell point cloud data to continuous 3D partition masks using pure NE-PM pair information
This method only uses the shell points generated from NE-PM pairs without any EDT guidance.
It assigns partition IDs based on proximity to shell boundaries and interpolation between shells.
Parameters
----------
shells : list of numpy.ndarray
List of shell point arrays, each containing 3D coordinates
shape : tuple
Shape of the output partition mask (Z, Y, X)
pm_mask : numpy.ndarray, optional
Original plasma membrane mask for defining cell boundaries
ne_mask : numpy.ndarray, optional
Original nuclear envelope mask for defining nuclear boundaries
Returns
-------
numpy.ndarray
3D partition mask where each voxel value indicates partition ID (0=background)
"""
from scipy.ndimage import binary_fill_holes
from scipy.spatial import cKDTree
import numpy as np
print(f"Converting {len(shells)} shells to partition mask...")
# Create cell regions from masks
if pm_mask is not None and ne_mask is not None:
pm_binary = pm_mask > 0
ne_binary = ne_mask > 0
cell_interior = binary_fill_holes(pm_binary)
nuclear_interior = binary_fill_holes(ne_binary)
cytoplasm_region = cell_interior & ~nuclear_interior
print(f"Cytoplasm region voxels: {np.sum(cytoplasm_region)}")
else:
print("Warning: No masks provided, using full volume")
cytoplasm_region = np.ones(shape, dtype=bool)
# Filter out empty shells and build KDTrees for non-empty shells
valid_shells = []
shell_trees = []
shell_indices = []
for i, shell in enumerate(shells):
if len(shell) > 0:
valid_shells.append(shell)
shell_trees.append(cKDTree(shell))
shell_indices.append(i + 1) # Partition IDs start from 1
print(f"Shell {i}: {len(shell)} points -> Partition {i + 1}")
if len(valid_shells) < 2:
print("Error: Need at least 2 valid shells for partitioning")
return np.zeros(shape, dtype=int)
print(f"Built KDTrees for {len(valid_shells)} valid shells")
# Get cytoplasm coordinates
cytoplasm_coords = np.column_stack(np.where(cytoplasm_region))
print(f"Processing {len(cytoplasm_coords)} cytoplasm voxels...")
partition_mask = np.zeros(shape, dtype=int)
# Process in batches for memory efficiency
batch_size = 50000
n_batches = (len(cytoplasm_coords) + batch_size - 1) // batch_size
for batch_idx in range(n_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, len(cytoplasm_coords))
batch_coords = cytoplasm_coords[start_idx:end_idx]
# Find distances to all shells for the batch
batch_distances = []
for tree in shell_trees:
distances, _ = tree.query(batch_coords, k=1)
batch_distances.append(distances)
batch_distances = np.array(batch_distances).T # Shape: (batch_size, n_shells)
# Assign partitions based on corrected logic
batch_partition_ids = np.zeros(len(batch_coords), dtype=int)
for i, coord_distances in enumerate(batch_distances):
# Find the closest shell
closest_shell_idx = np.argmin(coord_distances)
closest_shell_id = shell_indices[closest_shell_idx]
# Check if closest shell is innermost (shell 1) or outermost (shell n)
if closest_shell_id == 1:
# Closest to innermost shell -> directly assign to partition 1
batch_partition_ids[i] = 1
elif closest_shell_id >= len(shell_indices):
# Closest to outermost shell -> directly assign to outermost partition
batch_partition_ids[i] = max(shell_indices)
else:
# Point is closest to a middle shell -> need interpolation between shells
sorted_indices = np.argsort(coord_distances)
closest_shell_idx = sorted_indices[0]
second_closest_shell_idx = sorted_indices[1] if len(sorted_indices) > 1 else closest_shell_idx
closest_dist = coord_distances[closest_shell_idx]
second_closest_dist = coord_distances[second_closest_shell_idx]
if len(sorted_indices) == 1:
# Only one shell available, assign to its partition
batch_partition_ids[i] = shell_indices[closest_shell_idx]
else:
# Multiple shells: determine which partition region this voxel belongs to
shell_id_1 = shell_indices[closest_shell_idx]
shell_id_2 = shell_indices[second_closest_shell_idx]
# Determine partition based on relative position between shells
total_dist = closest_dist + second_closest_dist
if total_dist > 0:
# Relative position: closer to which shell
relative_pos = second_closest_dist / total_dist
# If significantly closer to first shell, assign to its partition
if relative_pos > 0.6: # >60% means much closer to first shell
batch_partition_ids[i] = shell_id_1
# If significantly closer to second shell, assign to its partition
elif relative_pos < 0.4: # <40% means much closer to second shell
batch_partition_ids[i] = shell_id_2
else:
# In between -> assign to the partition between these shells
# Use the smaller shell ID (closer to nucleus)
batch_partition_ids[i] = min(shell_id_1, shell_id_2)
else:
# Fallback: use closest shell's partition
batch_partition_ids[i] = shell_indices[closest_shell_idx]
# Assign batch results to partition mask
for i, coord in enumerate(batch_coords):
z, y, x = coord
partition_mask[z, y, x] = batch_partition_ids[i]
if (batch_idx + 1) % 20 == 0:
print(f"Processed batch {batch_idx + 1}/{n_batches}")
# Post-processing: Fill gaps and smooth boundaries
print("Post-processing partition mask...")
# Fill small gaps within partitions
for partition_id in shell_indices:
partition_region = partition_mask == partition_id
if np.sum(partition_region) > 0:
# Small morphological closing to fill gaps
from scipy.ndimage import binary_closing
filled_region = binary_closing(partition_region, iterations=1)
partition_mask[filled_region & cytoplasm_region] = partition_id
# Verify partition results
unique_partitions = np.unique(partition_mask)
print("Final partition distribution:")
for partition_id in unique_partitions:
if partition_id != 0:
count = np.sum(partition_mask == partition_id)
print(f"Partition {partition_id}: {count} voxels")
print(f"Successfully created {len(unique_partitions)-1} pure pair-based partitions")
return partition_mask
[docs]
def create_nepm_radial_partitions_with_edt(self, ne_edge, pm_edge, shape, n_slices=None, pm_mask=None, ne_mask=None):
"""
Optimized version of radial partition generation
Parameters
----------
ne_edge : numpy.ndarray
Nuclear envelope boundary coordinates array, shape (N, 3)
pm_edge : numpy.ndarray
Plasma membrane boundary coordinates array, shape (M, 3)
shape : tuple
Shape of original image volume (Z, Y, X)
n_slices : int, optional
Number of radial partitions, defaults to self.n_slices
pm_mask : numpy.ndarray, optional
Original plasma membrane mask for improved accuracy
ne_mask : numpy.ndarray, optional
Original nuclear envelope mask for improved accuracy
Returns
-------
numpy.ndarray
3D partition mask array where each voxel value indicates partition ID (0=background)
Notes
-----
Uses local EDT-guided pair matching for better handling of irregular cell shapes
"""
if n_slices is None:
n_slices = self.n_slices
print(f"Creating {n_slices} radial partitions using optimized algorithms...")
# Calculate cell center
center = np.mean(ne_edge, axis=0)
# Generate optimized NE-PM pairs
pairs = self.find_ne_pm_pairs(center, ne_edge, pm_edge, step=8, save_txt=False)
print(f"Generated {len(pairs)} NE-PM pairs")
# Set correct n_slices
original_n_slices = self.n_slices
self.n_slices = n_slices - 1
# Use optimized shell generation
if len(pairs) > 1000:
shells = self.shell_partition_optimized(pairs, cell_mask=pm_mask)
else:
shells = self.shell_partition(pairs, cell_mask=pm_mask)
print(f"Generated {len(shells)} shells for {n_slices} partitions")
# Restore original n_slices
self.n_slices = original_n_slices
# Use optimized partition mask conversion
partition_mask = self._convert_shells_to_partition_mask_optimized(shells, shape, pm_mask, ne_mask)
return partition_mask
def _convert_shells_to_partition_mask_optimized(self, shells, shape, pm_mask=None, ne_mask=None):
"""
Optimized version of partition mask conversion with corrected shell assignment logic
"""
from scipy.ndimage import binary_fill_holes, distance_transform_edt, binary_erosion
from scipy.spatial import cKDTree
print(f"Converting {len(shells)} shells to partition mask using corrected assignment logic...")
# Create cytoplasm region mask
if pm_mask is not None and ne_mask is not None:
pm_binary = pm_mask > 0
ne_binary = ne_mask > 0
cell_interior = binary_fill_holes(pm_binary)
nuclear_interior = binary_fill_holes(ne_binary)
if cell_interior is not None and nuclear_interior is not None:
cytoplasm_region = np.asarray(cell_interior, dtype=bool) & (~np.asarray(nuclear_interior, dtype=bool))
pm_eroded = binary_erosion(pm_binary, iterations=2)
pm_boundary = np.asarray(pm_binary, dtype=bool) & (~np.asarray(pm_eroded, dtype=bool))
ne_eroded = binary_erosion(ne_binary, iterations=1)
ne_boundary = np.asarray(ne_binary, dtype=bool) & (~np.asarray(ne_eroded, dtype=bool))
print(f"Cytoplasm region voxels: {np.sum(cytoplasm_region)}")
else:
print("Error: Failed to create interior regions")
return np.zeros(shape, dtype=int)
else:
print("Error: Original masks are required for optimized method")
return np.zeros(shape, dtype=int)
# Get cytoplasm coordinates
cytoplasm_coords = np.column_stack(np.where(cytoplasm_region))
print(f"Processing {len(cytoplasm_coords)} cytoplasm voxels...")
n_partitions = len([s for s in shells if len(s) > 0])
partition_mask = np.zeros(shape, dtype=int)
if n_partitions < 2:
print("Error: Need at least 2 valid shells for partitioning")
return partition_mask
# Build KDTrees for boundary detection
ne_boundary_coords = np.column_stack(np.where(ne_boundary))
pm_boundary_coords = np.column_stack(np.where(pm_boundary))
if len(ne_boundary_coords) > 0 and len(pm_boundary_coords) > 0:
ne_tree = cKDTree(ne_boundary_coords)
pm_tree = cKDTree(pm_boundary_coords)
print("Built boundary KDTrees for direct assignment")
else:
print("Warning: Cannot build boundary trees, falling back to shell-based method")
ne_tree = pm_tree = None
# Process in batches for memory efficiency
batch_size = 50000
n_batches = (len(cytoplasm_coords) + batch_size - 1) // batch_size
# Define distance thresholds for direct assignment
direct_assignment_threshold = 3.0 # pixels
for batch_idx in range(n_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, len(cytoplasm_coords))
batch_coords = cytoplasm_coords[start_idx:end_idx]
batch_partition_ids = np.zeros(len(batch_coords), dtype=int)
direct_assignment_mask = np.zeros(len(batch_coords), dtype=bool)
# Step 1: Direct assignment for points very close to boundaries
if ne_tree is not None and pm_tree is not None:
# Distance to nuclear boundary
ne_distances, _ = ne_tree.query(batch_coords, k=1)
# Distance to plasma membrane boundary
pm_distances, _ = pm_tree.query(batch_coords, k=1)
# Direct assignment to innermost shell (shell 1) for points very close to nucleus
inner_mask = (ne_distances <= direct_assignment_threshold) & (ne_distances < pm_distances)
batch_partition_ids[inner_mask] = 1
direct_assignment_mask[inner_mask] = True
# Direct assignment to outermost shell for points very close to PM
outer_mask = (pm_distances <= direct_assignment_threshold) & (pm_distances < ne_distances)
batch_partition_ids[outer_mask] = n_partitions
direct_assignment_mask[outer_mask] = True
print(f"Batch {batch_idx}: Direct assignment - {np.sum(inner_mask)} to inner, {np.sum(outer_mask)} to outer")
# Step 2: Shell-based assignment for remaining points
remaining_mask = ~direct_assignment_mask
remaining_coords = batch_coords[remaining_mask]
if len(remaining_coords) > 0:
# Find distances to all shells for remaining points
shell_distances = []
valid_shell_indices = []
for i, shell in enumerate(shells):
if len(shell) > 0:
shell_tree = cKDTree(shell)
distances, _ = shell_tree.query(remaining_coords, k=1)
shell_distances.append(distances)
valid_shell_indices.append(i + 1) # Partition IDs start from 1
if len(shell_distances) > 0:
shell_distances = np.array(shell_distances).T # Shape: (n_remaining, n_shells)
# For each remaining point, find the closest shell
closest_shell_indices = np.argmin(shell_distances, axis=1)
# Assign to closest shell's partition
for i, shell_idx in enumerate(closest_shell_indices):
batch_partition_ids[remaining_mask][i] = valid_shell_indices[shell_idx]
# Assign batch results to partition mask
for i, coord in enumerate(batch_coords):
if batch_partition_ids[i] > 0:
z, y, x = coord
partition_mask[z, y, x] = batch_partition_ids[i]
if (batch_idx + 1) % 20 == 0:
print(f"Processed batch {batch_idx + 1}/{n_batches}")
# Post-processing: Ensure proper shell ordering and fill gaps
print("Post-processing partition mask...")
# Verify and correct partition ordering
unique_partitions = np.unique(partition_mask)
valid_partitions = unique_partitions[unique_partitions > 0]
print("Partition distribution after corrected assignment:")
for partition_id in valid_partitions:
count = np.sum(partition_mask == partition_id)
print(f"Partition {partition_id}: {count} voxels")
# Fill small gaps within partitions
from scipy.ndimage import binary_closing
for partition_id in valid_partitions:
partition_region = partition_mask == partition_id
if np.sum(partition_region) > 0:
filled_region = binary_closing(partition_region, iterations=1)
combined_region = np.asarray(filled_region, dtype=bool) & cytoplasm_region
partition_mask[combined_region] = partition_id
print(f"Successfully created {len(valid_partitions)} corrected partitions")
return partition_mask
def _parallel_shell_processing(self, pairs_chunk, n_slices):
"""
Parallel processing worker function for shell generation
"""
shell_points = [[] for _ in range(n_slices + 1)]
for pair in pairs_chunk:
ne_pt = pair[:3]
pm_pt = pair[3:]
# Generate subdivision points
slice_pts = self.points_along_vector(ne_pt, pm_pt, n_slices)
for i, pt in enumerate(slice_pts):
shell_points[i].append(pt)
return shell_points
[docs]
def shell_partition_optimized(self, pairs, cell_mask=None):
"""
Optimized version of shell partition generation with parallel processing support
"""
print(f"Creating {self.n_slices + 1} shells from {len(pairs)} pairs ...")
# Initialize shell_points for all code paths
shell_points = [[] for _ in range(self.n_slices + 1)]
# Use parallel processing if available and data is large
try:
if len(pairs) > 5000:
from joblib import Parallel, delayed
print("Using parallel processing for shell generation...")
# Chunk processing
n_cores = min(4, os.cpu_count() or 1) # Limit cores to avoid over-parallelization
chunk_size = len(pairs) // n_cores + 1
pairs_chunks = [pairs[i:i+chunk_size] for i in range(0, len(pairs), chunk_size)]
# Parallel processing
results = Parallel(n_jobs=n_cores)(
delayed(self._parallel_shell_processing)(chunk, self.n_slices)
for chunk in pairs_chunks
)
# Merge results
if results is not None:
for result in results:
if result is not None:
for i, shell_pts in enumerate(result):
shell_points[i].extend(shell_pts)
else:
raise ImportError("Data size too small for parallel processing")
except (ImportError, ModuleNotFoundError):
print("Using sequential processing for shell generation...")
# Original sequential processing logic
for pair in pairs:
ne_pt = pair[:3]
pm_pt = pair[3:]
slice_pts = self.points_along_vector(ne_pt, pm_pt, self.n_slices)
for i, pt in enumerate(slice_pts):
shell_points[i].append(pt)
# Cleanup and deduplication
cleaned_shells = []
for i, shell in enumerate(shell_points):
if len(shell) == 0:
cleaned_shells.append(np.array([]))
continue
shell_arr = np.unique(np.array(shell), axis=0)
if cell_mask is not None:
# Vectorized point validation
shell_int = shell_arr.astype(int)
valid_mask = (
(shell_int[:, 0] >= 0) & (shell_int[:, 0] < cell_mask.shape[0]) &
(shell_int[:, 1] >= 0) & (shell_int[:, 1] < cell_mask.shape[1]) &
(shell_int[:, 2] >= 0) & (shell_int[:, 2] < cell_mask.shape[2])
)
if np.any(valid_mask):
valid_shell_int = shell_int[valid_mask]
mask_values = cell_mask[valid_shell_int[:, 0], valid_shell_int[:, 1], valid_shell_int[:, 2]]
final_mask = mask_values > 0
shell_arr = shell_arr[valid_mask][final_mask]
else:
shell_arr = np.array([])
cleaned_shells.append(shell_arr)
print(f"Shell {i}: {len(shell_arr)} points")
print(f"Generated {len(cleaned_shells)} optimized shells total")
return cleaned_shells
[docs]
def save_partition_coords_to_xvg(self, partition_coords, dataid, output_dir):
"""
Save partition coordinates to XVG format file
Parameters
----------
partition_coords : list of numpy.ndarray
List of partition coordinates
dataid : str
Data identifier
output_dir : str
Output directory
Returns
-------
str
Output directory path
"""
xvg_dir = output_dir
os.makedirs(xvg_dir, exist_ok=True)
# Save combined XVG file
combined_xvg_path = os.path.join(xvg_dir, f'{dataid}_partition_coords.xvg')
with open(combined_xvg_path, 'w') as f:
# XVG header
f.write("# XVG file generated by iPA - Partition Coordinates\n")
f.write(f"# All partition coordinates from data ID {dataid}\n")
f.write("# Contains 3D coordinates (z, y, x) and partition index\n")
f.write("@ title \"Partition Coordinates\"\n")
f.write("@ xaxis label \"Z\"\n")
f.write("@ yaxis label \"Y\"\n")
f.write("@TYPE xy\n")
# Define legends for each partition
for i, coords in enumerate(partition_coords):
if len(coords) > 0:
f.write(f"@ s{i} legend \"Partition {i+1}\"\n")
# Write data points
for i, coords in enumerate(partition_coords):
for j, point in enumerate(coords):
if len(point) == 2: # 2D points
f.write(f"{point[0]:.3f} {point[1]:.3f} 0.000 {i+1}\n")
elif len(point) == 3: # 3D points
f.write(f"{point[0]:.3f} {point[1]:.3f} {point[2]:.3f} {i+1}\n")
else:
raise ValueError(f"Unexpected point dimension: {len(point)}")
print(f"[INFO] Saved partition coordinates as XVG to {combined_xvg_path}")
return xvg_dir
[docs]
def create_nepm_radial_partitions(self, ne_edge, pm_edge, shape, n_slices=None, pm_mask=None, ne_mask=None):
"""
Create radial partitions using pure NE-PM pair information without EDT guidance
This method generates radial partitions based solely on the shell points created from
NE-PM pairs, without using Euclidean Distance Transform for edge refinement.
It's more straightforward and computationally efficient than the EDT-guided version.
Parameters
----------
ne_edge : numpy.ndarray
Nuclear envelope boundary coordinates array, shape (N, 3)
pm_edge : numpy.ndarray
Plasma membrane boundary coordinates array, shape (M, 3)
shape : tuple
Shape of original image volume (Z, Y, X)
n_slices : int, optional
Number of radial partitions, defaults to self.n_slices
pm_mask : numpy.ndarray, optional
Original plasma membrane mask for defining cell boundaries
ne_mask : numpy.ndarray, optional
Original nuclear envelope mask for defining nuclear boundaries
Returns
-------
numpy.ndarray
3D partition mask array where each voxel value indicates partition ID (0=background)
Notes
-----
This pure pair-based method is recommended when:
- You want faster computation without EDT overhead
- The cell shape is relatively regular
- You prefer simpler, more interpretable partitioning logic
Examples
--------
>>> partition_mask = partitioner.create_nepm_radial_partitions(
... ne_coords, pm_coords, (100, 200, 200), n_slices=8
... )
>>> print(f"Created partitions with shape: {partition_mask.shape}")
"""
if n_slices is None:
n_slices = self.n_slices
print(f"Creating {n_slices} radial partitions using pure NE-PM pair method...")
# Calculate cell center
center = np.mean(ne_edge, axis=0)
# Generate NE-PM pairs
pairs = self.find_ne_pm_pairs(center, ne_edge, pm_edge, step=8, save_txt=False)
print(f"Generated {len(pairs)} NE-PM pairs")
# Store original n_slices and set for shell generation
original_n_slices = self.n_slices
self.n_slices = n_slices - 1
# Generate shells from pairs
if len(pairs) > 1000:
shells = self.shell_partition_optimized(pairs, cell_mask=pm_mask)
else:
shells = self.shell_partition(pairs, cell_mask=pm_mask)
print(f"Generated {len(shells)} shells for {n_slices} partitions")
# Restore original n_slices
self.n_slices = original_n_slices
# Convert shells to partition mask using pure pair method
partition_mask = self._convert_shells_to_partition_mask_pure_pairs(shells, shape, pm_mask, ne_mask)
return partition_mask
# Keep backward compatibility alias
[docs]
def create_shell_based_partitions(self, ne_edge, pm_edge, shape, n_slices=None, pm_mask=None, ne_mask=None):
"""
Deprecated: Use create_nepm_radial_partitions_with_edt instead.
This method is kept for backward compatibility.
"""
print("Warning: create_shell_based_partitions is deprecated. Use create_nepm_radial_partitions_with_edt instead.")
return self.create_nepm_radial_partitions_with_edt(ne_edge, pm_edge, shape, n_slices, pm_mask, ne_mask)
# Alias for pure pairs method (for clarity)
[docs]
def create_nepm_radial_partitions_pure_pairs(self, ne_edge, pm_edge, shape, n_slices=None, pm_mask=None, ne_mask=None):
"""
Alias for create_nepm_radial_partitions().
This method name explicitly indicates it uses the pure NE-PM pair approach
without EDT guidance. It's provided for clarity and backward compatibility.
See Also
--------
create_nepm_radial_partitions : The main implementation
create_nepm_radial_partitions_with_edt : EDT-guided alternative
"""
return self.create_nepm_radial_partitions(ne_edge, pm_edge, shape, n_slices, pm_mask, ne_mask)