Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def build_labels_graph(self, connections_path):
labels_graph.add_nodes_from(self.valid_labels)

# Main
for line in util.read_txt(connections_path):
for line in util.read_txt(connections_path).splitlines():
ids = line.split(",")
id_1 = util.get_segment_id(ids[0])
id_2 = util.get_segment_id(ids[1])
Expand Down
112 changes: 111 additions & 1 deletion src/segmentation_skeleton_metrics/data_handling/swc_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
these attributes in the same order.
"""

from botocore import UNSIGNED
from botocore.client import Config
from collections import deque
from concurrent.futures import (
as_completed,
Expand All @@ -29,6 +31,7 @@
from tqdm import tqdm
from zipfile import ZipFile

import boto3
import numpy as np
import os

def read_from_s3(self, s3_path):
    """
    Reads and parses SWC files from an S3 directory.

    Parameters
    ----------
    s3_path : str
        Path to a directory in an S3 bucket containing SWC files or ZIPs
        of SWC files to be read.

    Returns
    -------
    swc_dicts : Deque[dict]
        Dictionaries whose keys and values are the attribute names and
        values from an SWC file. Empty if the prefix contains neither
        SWC files nor ZIP archives.
    """
    # List filenames under the given prefix
    bucket_name, prefix = util.parse_cloud_path(s3_path)
    swc_paths = util.list_s3_filenames(bucket_name, prefix, ".swc")
    zip_paths = util.list_s3_filenames(bucket_name, prefix, ".zip")

    # Dispatch to the appropriate reader. Bug fix: previously fell
    # through and returned None when no SWC or ZIP files were found,
    # which breaks callers that iterate the result.
    if len(swc_paths) > 0:
        return self.read_from_s3_swcs(bucket_name, swc_paths)
    if len(zip_paths) > 0:
        return self.read_from_s3_zips(bucket_name, zip_paths)
    return deque()

def read_from_s3_swcs(self, bucket_name, swc_paths):
    """
    Reads and parses individual SWC files stored in an S3 bucket.

    Parameters
    ----------
    bucket_name : str
        Name of bucket containing SWC files.
    swc_paths : List[str]
        Object keys of the SWC files to be read.

    Returns
    -------
    swc_dicts : Deque[dict]
        Dictionaries whose keys and values are the attribute names and
        values from an SWC file.
    """
    # Parse SWC files. Bug fix: util.read_txt takes a single path and
    # dispatches on the "s3://" scheme; calling it as
    # read_txt(bucket_name, path) raises a TypeError.
    swc_dicts = deque()
    for path in swc_paths:
        content = util.read_txt(f"s3://{bucket_name}/{path}").splitlines()
        filename = os.path.basename(path)
        result = self.parse(content, filename)
        if result:
            swc_dicts.append(result)
    return swc_dicts

def read_from_s3_zips(self, bucket_name, zip_paths):
    """
    Reads SWC files stored in a list of ZIP archives stored in an S3
    bucket.

    Parameters
    ----------
    bucket_name : str
        Name of bucket containing SWC files.
    zip_paths : List[str]
        Object keys of ZIP archives containing SWC files to be read.

    Returns
    -------
    swc_dicts : Deque[dict]
        Dictionaries whose keys and values are the attribute names and
        values from an SWC file.
    """
    with ProcessPoolExecutor() as executor:
        # Submit one process per ZIP archive
        processes = [
            executor.submit(self.read_from_s3_zip, bucket_name, zip_path)
            for zip_path in zip_paths
        ]

        # Store results
        pbar = tqdm(total=len(processes), desc="Read SWCs")
        swc_dicts = deque()
        for process in as_completed(processes):
            result = process.result()
            if result:
                swc_dicts.extend(result)
            # Bug fix: pbar was created but never advanced, so the
            # progress bar sat at 0% for the whole read.
            pbar.update(1)
    return swc_dicts

def read_from_s3_zip(self, bucket_name, path):
    """
    Reads SWC files stored in a ZIP archive downloaded from an S3
    bucket.

    Parameters
    ----------
    bucket_name : str
        Name of bucket containing SWC files.
    path : str
        Path to ZIP archive containing SWC files to be read.

    Returns
    -------
    swc_dicts : Deque[dict]
        Dictionaries whose keys and values are the attribute names and
        values from an SWC file.
    """
    # Download the whole archive into memory (anonymous/unsigned access)
    client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    payload = client.get_object(Bucket=bucket_name, Key=path)["Body"].read()

    # Extract and parse each readable SWC entry, one thread per file
    swc_dicts = deque()
    with ZipFile(BytesIO(payload), "r") as archive:
        with ThreadPoolExecutor() as executor:
            pending = [
                executor.submit(self.read_from_zipped_file, archive, name)
                for name in archive.namelist()
                if self.confirm_read(name)
            ]
            for future in as_completed(pending):
                parsed = future.result()
                if parsed:
                    swc_dicts.append(parsed)
    return swc_dicts

def confirm_read(self, filename):
"""
Checks whether the swc_id corresponding to the given filename is
Expand Down
7 changes: 6 additions & 1 deletion src/segmentation_skeleton_metrics/skeleton_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,12 @@ def __call__(self, gt_graphs, fragment_graphs, merge_sites):
DataFrame where the indices are the dictionary keys and values are
stored under a column called "self.name".
"""
pbar = self.get_pbar(len(merge_sites.index))
# Check if merge sites is non-empty
if len(merge_sites) == 0:
return _

# Compute metric
pbar = self.get_pbar(len(merge_sites))
pair_to_length = dict()
for i in merge_sites.index:
# Extract site info
Expand Down
16 changes: 11 additions & 5 deletions src/segmentation_skeleton_metrics/utils/util.py
Original file line number Diff line number Diff line change
def read_txt(path):
    """
    Reads a txt file stored locally, in an S3 bucket, or in a GCS bucket.

    Parameters
    ----------
    path : str
        Path to the txt file; either a local path, an "s3://" URL, or a
        GCS URL.

    Returns
    -------
    str
        Contents of the txt file. Callers split into lines with
        ".splitlines()" as needed.
    """
    if is_s3_path(path):
        return read_txt_from_s3(path)
    elif is_gcs_path(path):
        return read_txt_from_gcs(path)
    else:
        # Bug fix: return the raw text rather than f.read().splitlines()
        # so the local branch matches the S3/GCS branches (which return
        # str); callers such as load_valid_labels now call .splitlines()
        # on the result and would crash on a list.
        with open(path, "r") as f:
            return f.read()


def update_txt(path, text, verbose=True):
Expand Down Expand Up @@ -393,7 +398,7 @@ def is_s3_path(path):
return path.startswith("s3://")


def list_s3_paths(bucket_name, prefix, extension=""):
def list_s3_filenames(bucket_name, prefix, extension=""):
"""
Lists all object keys in a public S3 bucket under a given prefix,
optionally filters by file extension.
Expand Down Expand Up @@ -427,7 +432,7 @@ def list_s3_paths(bucket_name, prefix, extension=""):
return filenames


def read_txt_from_s3(bucket_name, path):
def read_txt_from_s3(path):
"""
Reads a txt file stored in an S3 bucket.

Expand All @@ -443,6 +448,7 @@ def read_txt_from_s3(bucket_name, path):
str
Contents of txt file.
"""
bucket_name, path = parse_cloud_path(path)
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
obj = s3.get_object(Bucket=bucket_name, Key=path)
return obj['Body'].read().decode('utf-8')
Expand Down Expand Up @@ -558,7 +564,7 @@ def load_valid_labels(path):
Segment IDs that can be assigned to nodes.
"""
valid_labels = set()
for label_str in read_txt(path):
for label_str in read_txt(path).splitlines():
valid_labels.add(int(label_str.split(".")[0]))
return valid_labels

Expand Down
Loading