66 changes: 30 additions & 36 deletions scripts/create_data_fast_sample.py
@@ -7,16 +7,13 @@
import numpy as np
from segger.data.parquet._utils import get_polygons_from_xy

xenium_data_dir = Path('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/')
segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb_200_final')
xenium_data_dir = Path("data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/")
segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb_200_final")


scrnaseq_file = Path('/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad')
celltype_column = 'celltype_minor'
gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(
sc.read(scrnaseq_file),
celltype_column
)
scrnaseq_file = Path("/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad")
celltype_column = "celltype_minor"
gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(sc.read(scrnaseq_file), celltype_column)

sample = STSampleParquet(
base_dir=xenium_data_dir,
@@ -43,30 +40,29 @@


sample.save(
data_dir=segger_data_dir,
k_bd=3,
dist_bd=15,
k_tx=3,
dist_tx=5,
tile_width=200,
tile_height=200,
neg_sampling_ratio=5.0,
frac=1.0,
val_prob=0.3,
test_prob=0,
data_dir=segger_data_dir,
k_bd=3,
dist_bd=15,
k_tx=3,
dist_tx=5,
tile_width=200,
tile_height=200,
neg_sampling_ratio=5.0,
frac=1.0,
val_prob=0.3,
test_prob=0,
)


xenium_data_dir = Path('data_tidy/bc_5k')
segger_data_dir = Path('data_tidy/pyg_datasets/bc_5k_emb_new')

xenium_data_dir = Path("data_tidy/bc_5k")
segger_data_dir = Path("data_tidy/pyg_datasets/bc_5k_emb_new")


sample = STSampleParquet(
base_dir=xenium_data_dir,
n_workers=8,
sample_type='xenium',
weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available
sample_type="xenium",
weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available
)


@@ -88,16 +84,14 @@


sample.save(
data_dir=segger_data_dir,
k_bd=3,
dist_bd=15.0,
k_tx=15,
dist_tx=3,
tile_size=50_000,
neg_sampling_ratio=5.0,
frac=0.1,
val_prob=0.1,
test_prob=0.1,
data_dir=segger_data_dir,
k_bd=3,
dist_bd=15.0,
k_tx=15,
dist_tx=3,
tile_size=50_000,
neg_sampling_ratio=5.0,
frac=0.1,
val_prob=0.1,
test_prob=0.1,
)
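
Note on the embedding used above: the weights passed to STSampleParquet come from calculate_gene_celltype_abundance_embedding, whose implementation is not shown in this diff. As orientation only, a minimal sketch of one plausible computation (fraction of cells of each cell type expressing each gene); the helper name, formula, and output layout here are assumptions, not the repository's code.

import anndata as ad
import numpy as np
import pandas as pd

def sketch_abundance_embedding(adata: ad.AnnData, celltype_column: str) -> pd.DataFrame:
    # Assumed definition: fraction of cells in each cell type with nonzero counts per gene.
    expressed = adata.X > 0
    if not isinstance(expressed, np.ndarray):
        expressed = expressed.toarray()  # densify if adata.X is sparse
    df = pd.DataFrame(expressed, columns=adata.var_names, index=adata.obs_names)
    df[celltype_column] = adata.obs[celltype_column].values
    return df.groupby(celltype_column).mean().T  # genes x cell types

A genes-by-cell-types matrix of this shape would then play the role of the weights argument above.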


4 changes: 2 additions & 2 deletions scripts/predict_model_sample.py
@@ -22,8 +22,8 @@
seg_tag = "bc_fast_data_emb_major"
model_version = 1

segger_data_dir = Path('data_tidy/pyg_datasets') / seg_tag
models_dir = Path("./models") / seg_tag
segger_data_dir = Path("data_tidy/pyg_datasets") / seg_tag
models_dir = Path("./models") / seg_tag
benchmarks_dir = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc")
transcripts_file = "data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet"
# Initialize the Lightning data module
27 changes: 12 additions & 15 deletions scripts/train_model_sample.py
@@ -15,7 +15,7 @@
import os


segger_data_dir = segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb_final_200')
segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb_final_200")
models_dir = Path("./models/bc_rep1_emb_final_200")

# Base directory to store Pytorch Lightning models
@@ -35,37 +35,34 @@

# If you use custom gene embeddings, use the following two lines instead:
is_token_based = False
num_tx_tokens = dm.train[0].x_dict["tx"].shape[1] # Set the number of tokens to the number of genes
num_tx_tokens = dm.train[0].x_dict["tx"].shape[1] # Set the number of tokens to the number of genes


num_bd_features = dm.train[0].x_dict["bd"].shape[1]

# Initialize the Lightning model
ls = LitSegger(
is_token_based = is_token_based,
num_node_features = {"tx": num_tx_tokens, "bd": num_bd_features},
init_emb=8,
is_token_based=is_token_based,
num_node_features={"tx": num_tx_tokens, "bd": num_bd_features},
init_emb=8,
hidden_channels=64,
out_channels=16,
heads=4,
num_mid_layers=3,
aggr='sum',
learning_rate=1e-3
aggr="sum",
learning_rate=1e-3,
)

# Initialize the Lightning trainer
trainer = Trainer(
accelerator='cuda',
strategy='auto',
precision='16-mixed',
devices=2, # set higher number if more gpus are available
accelerator="cuda",
strategy="auto",
precision="16-mixed",
devices=2, # set higher number if more gpus are available
max_epochs=400,
default_root_dir=models_dir,
logger=CSVLogger(models_dir),
)


trainer.fit(
model=ls,
datamodule=dm
)
trainer.fit(model=ls, datamodule=dm)
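
Optional follow-up once training finishes: the CSVLogger configured above writes metrics under models_dir. A short sketch for inspecting the training curve; the version_0 subdirectory is Lightning's default layout and may differ between runs.

import pandas as pd

# Path assumes Lightning's default "lightning_logs/version_<n>/metrics.csv" layout.
metrics = pd.read_csv("./models/bc_rep1_emb_final_200/lightning_logs/version_0/metrics.csv")
print(metrics.filter(like="loss").dropna(how="all").tail())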
10 changes: 5 additions & 5 deletions src/segger/data/utils.py
@@ -43,7 +43,7 @@ def try_import(module_name):
from datetime import timedelta


def filter_transcripts( #ONLY FOR XENIUM
def filter_transcripts( # ONLY FOR XENIUM
transcripts_df: pd.DataFrame,
min_qv: float = 20.0,
) -> pd.DataFrame:
@@ -65,14 +65,14 @@ def filter_transcripts( #ONLY FOR XENIUM
"DeprecatedCodeword_",
"UnassignedCodeword_",
)
transcripts_df['feature_name'] = transcripts_df['feature_name'].apply(

transcripts_df["feature_name"] = transcripts_df["feature_name"].apply(
lambda x: x.decode("utf-8") if isinstance(x, bytes) else x
)
mask_quality = transcripts_df['qv'] >= min_qv
mask_quality = transcripts_df["qv"] >= min_qv

# Apply the filter for unwanted codewords using Dask string functions
mask_codewords = ~transcripts_df['feature_name'].str.startswith(filter_codewords)
mask_codewords = ~transcripts_df["feature_name"].str.startswith(filter_codewords)

# Combine the filters and return the filtered Dask DataFrame
mask = mask_quality & mask_codewords
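
For context, filter_transcripts keeps rows with qv >= min_qv and drops control codewords by feature_name prefix. A small usage sketch with made-up values; the two prefixes used are ones visible in this hunk.

import pandas as pd
from segger.data.utils import filter_transcripts

toy = pd.DataFrame({
    "feature_name": [b"ACTB", "UnassignedCodeword_0001", "DeprecatedCodeword_0002", "EPCAM"],
    "qv": [35.0, 40.0, 12.0, 25.0],
})
filtered = filter_transcripts(toy, min_qv=20.0)
# Expected: ACTB and EPCAM remain; both control codewords are dropped regardless of qv.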
52 changes: 33 additions & 19 deletions src/segger/prediction/predict_parquet.py
@@ -13,13 +13,7 @@
from pathlib import Path
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from segger.data.utils import (
get_edge_index,
format_time,
create_anndata,
coo_to_dense_adj,
filter_transcripts
)
from segger.data.utils import get_edge_index, format_time, create_anndata, coo_to_dense_adj, filter_transcripts
from segger.training.train import LitSegger
from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.boundary import generate_boundaries
@@ -36,7 +30,7 @@
from cupyx.scipy.sparse import coo_matrix
from torch.utils.dlpack import to_dlpack, from_dlpack

from dask.distributed import Client, LocalCluster
from dask.distributed import Client, LocalCluster, Future
import cupy as cp
import numpy as np
import pandas as pd
@@ -286,6 +280,7 @@ def sparse_multiply(embeddings, edge_index, shape) -> coo_matrix:


def predict_batch(
client: Client,
lit_segger: torch.nn.Module,
batch: Batch,
score_cut: float,
@@ -295,12 +290,13 @@
edge_index_save_path: Union[str, Path] = None,
output_ddf_save_path: Union[str, Path] = None,
gpu_id: int = 0, # Added argument for GPU ID
):
) -> tuple[Future | None, Future | None]:
"""
Predict cell assignments for a batch of transcript data using a segmentation model.
Writes both the assignments and edge_index directly into Parquet files incrementally.

Args:
client (Client): The client to connect to and submit computation to a dask cluster.
lit_segger (torch.nn.Module): The lightning module wrapping the segmentation model.
batch (Batch): A batch of transcript and cell data.
score_cut (float): The threshold for assigning transcripts to cells based on similarity scores.
@@ -315,6 +311,8 @@
gpu_id (int, optional): The GPU ID to use for the computations. Defaults to 0.
"""

delayed_write_edge_index_future, delayed_write_output_ddf_future = None, None

def _get_id():
"""Generate a random Xenium-style ID."""
return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx"
@@ -410,7 +408,7 @@ def _get_id():
delayed_write_edge_index = delayed(edge_index_ddf.to_parquet)(
edge_index_save_path, append=True, ignore_divisions=True
)
delayed_write_edge_index.persist() # Schedule writing
delayed_write_edge_index_future = client.persist(delayed_write_edge_index) # Schedule writing

assignments = {
"transcript_id": assignments["transcript_id"].astype("str"),
@@ -428,12 +426,14 @@
delayed_write_output_ddf = delayed(batch_ddf.to_parquet)(
output_ddf_save_path, append=True, ignore_divisions=True
)
delayed_write_output_ddf.persist() # Schedule writing
delayed_write_output_ddf_future = client.persist(delayed_write_output_ddf) # Schedule writing

# Free memory after computation
cp.get_default_memory_pool().free_all_blocks() # Free CuPy memory
torch.cuda.empty_cache()

return delayed_write_edge_index_future, delayed_write_output_ddf_future


def segment(
model: LitSegger,
@@ -482,6 +482,7 @@ def segment(
None. Saves the result to disk in various formats and logs the parameter choices.
"""

client = Client()
start_time = time()

# Create a subdirectory with important parameter info (receptive field values)
@@ -511,6 +512,9 @@
val_dataloader = dm.val_dataloader()
test_dataloader = dm.test_dataloader()

delayed_write_edge_index_futures = []
delayed_write_output_ddf_futures = []

# Loop through the data loaders (train, val, and test)
for loader_name, loader in zip(
["Train", "Validation", "Test"], [train_dataloader, val_dataloader, test_dataloader]
@@ -522,7 +526,8 @@
for batch in tqdm(loader, desc=f"Processing {loader_name} batches"):
gpu_id = random.choice(gpu_ids)
# Call predict_batch for each batch
predict_batch(
delayed_write_edge_index_future, delayed_write_output_ddf_future = predict_batch(
client,
model,
batch,
score_cut,
@@ -534,23 +539,30 @@
gpu_id=gpu_id,
)

if delayed_write_edge_index_future is not None:
delayed_write_edge_index_futures.append(delayed_write_edge_index_future)

if delayed_write_output_ddf_future is not None:
delayed_write_output_ddf_futures.append(delayed_write_output_ddf_future)

if verbose:
elapsed_time = time() - step_start_time
print(f"Batch processing completed in {elapsed_time:.2f} seconds.")

client.gather(delayed_write_output_ddf_futures)
assert os.path.exists(output_ddf_save_path)
seg_final_dd = pd.read_parquet(output_ddf_save_path)

step_start_time = time()
if verbose:
print(f"Applying max score selection logic...")
output_ddf_save_path = save_dir / "transcripts_df.parquet"



seg_final_dd = pd.read_parquet(output_ddf_save_path)
seg_final_filtered = seg_final_dd.sort_values(
"score", ascending=False
).drop_duplicates(subset="transcript_id", keep="first")

seg_final_filtered = seg_final_dd.sort_values("score", ascending=False).drop_duplicates(
subset="transcript_id", keep="first"
)

if verbose:
elapsed_time = time() - step_start_time
@@ -570,7 +582,7 @@

# Outer merge to include all transcripts, even those without assigned cell ids
transcripts_df_filtered = transcripts_df.merge(seg_final_filtered, on="transcript_id", how="outer")

if verbose:
elapsed_time = time() - step_start_time
print(f"Merged segmentation results with transcripts in {elapsed_time:.2f} seconds.")
@@ -581,6 +593,8 @@
if verbose:
print(f"Computing connected components for unassigned transcripts...")
# Load edge indices from saved Parquet
client.gather(delayed_write_edge_index_futures)
assert os.path.exists(edge_index_save_path)
edge_index_dd = pd.read_parquet(edge_index_save_path)

# Step 2: Get unique transcript_ids from edge_index_dd and their positional indices
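
The substantive change in this file: predict_batch now receives the Dask Client, schedules its Parquet writes on the cluster, and returns futures that segment collects and gathers before reading the written files back. Below is a minimal, self-contained sketch of that schedule-then-gather pattern with illustrative names and paths, not the repository's API.

from pathlib import Path

import dask
import pandas as pd
from dask.distributed import Client

out_dir = Path("example_out")  # illustrative output location
out_dir.mkdir(exist_ok=True)

def write_batch(df: pd.DataFrame, path: Path, part: int) -> None:
    # Write one batch of results as its own Parquet part file.
    df.to_parquet(path / f"part-{part}.parquet")

client = Client(processes=False)  # small local cluster, just for the example
futures = []
for i in range(3):
    batch = pd.DataFrame({"transcript_id": [i], "score": [0.5 * i]})
    delayed_write = dask.delayed(write_batch)(batch, out_dir, i)
    futures.append(client.compute(delayed_write))  # schedule without blocking

client.gather(futures)  # block until every scheduled write has finished
combined = pd.read_parquet(out_dir)  # only safe to read after gathering
print(combined)
client.close()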