Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 38 additions & 35 deletions alphadia/fragcomp/fragcomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ def _get_fragment_overlap(
frag_mz_2: np.ndarray,
mass_tol_ppm: float = 10,
) -> int:
"""
Get the number of overlapping fragments between two spectra.
"""Get the number of overlapping fragments between two spectra.

Parameters
----------

frag_mz_1: np.ndarray
The m/z values of the first spectrum.

Expand All @@ -39,38 +37,41 @@ def _get_fragment_overlap(
-------
int
The number of overlapping fragments.

"""
frag_mz_1 = frag_mz_1.reshape(-1, 1)
frag_mz_2 = frag_mz_2.reshape(1, -1)
delta_mz = np.abs(frag_mz_1 - frag_mz_2)
ppm_delta_mz = delta_mz / frag_mz_1 * 1e6
frag_overlap = np.sum(ppm_delta_mz < mass_tol_ppm)
return frag_overlap
return np.sum(ppm_delta_mz < mass_tol_ppm)


@timsutils.pjit(cache=USE_NUMBA_CACHING)
def _compete_for_fragments(
thread_idx: int,
def _compete_for_fragments( # noqa: PLR0913 # Too many arguments
thread_idx: int, # pjit decorator changes the passed argument from an iterable to single index
precursor_start_idxs: np.ndarray,
precursor_stop_idxs: np.ndarray,
rt: np.ndarray,
valid: np.ndarray,
frag_start_idx: np.ndarray,
frag_stop_idx: np.ndarray,
fragment_mz: np.ndarray,
rt_tol_seconds: float = 3,
mass_tol_ppm: float = 15,
):
"""
Remove PSMs that share fragments with other PSMs.
rt_tol_seconds: float,
mass_tol_ppm: float,
Comment on lines +59 to +60
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed double defaults

) -> None:
"""Remove PSMs that share fragments with other PSMs.

The function is applied on a dia window basis.

The pjit decorator thread-parallelizes over the first argument index and additionally wraps with numba.njit(nogil=True).
Make sure to read and understand the pjit decorator, especially how it changes the type of the first argument.

Parameters
----------

thread_idx: int
The thread index. Each thread will handle one dia window.
The function will be wrapped in a pjit decorator and will be parallelized over this index.
The pjit decorator effectively changes the type of this argument to `np.ndarray` and thread-parallelizes
over it.

precursor_start_idxs: np.ndarray
Array of length n_windows. The start indices of the precursors in the PSM dataframe.
Expand All @@ -82,7 +83,7 @@ def _compete_for_fragments(
The retention times of the precursors.

valid: np.ndarray
Array of length n_psms. The validity of each PSM.
Array of length n_psms. The validity of each PSM. This is where the method output will be stored.

frag_start_idx: np.ndarray
Array of length n_psms. The start indices of the fragments in the fragment dataframe.
Expand All @@ -98,8 +99,12 @@ def _compete_for_fragments(

mass_tol_ppm: float
The mass tolerance in ppm.
"""

Returns
-------
None, but modifies the `valid` array in place.

"""
precursor_start_idx = precursor_start_idxs[thread_idx]
precursor_stop_idx = precursor_stop_idxs[thread_idx]

Expand Down Expand Up @@ -130,22 +135,22 @@ def _compete_for_fragments(
],
mass_tol_ppm=mass_tol_ppm,
)
if fragment_overlap >= 3:
if fragment_overlap >= 3: # noqa: PLR2004
valid_window[j] = False

valid[precursor_start_idx:precursor_stop_idx] = valid_window


class FragmentCompetition:
"""Fragment competition class to remove PSMs that share fragments with other PSMs."""

def __init__(
self, rt_tol_seconds: int = 3, mass_tol_ppm: int = 15, thread_count: int = 8
):
"""
Remove PSMs that share fragments with other PSMs.
"""Remove PSMs that share fragments with other PSMs.

Parameters
----------

rt_tol_seconds: int
The retention time tolerance in seconds.

Expand All @@ -161,13 +166,11 @@ def __init__(
self.thread_count = thread_count

@staticmethod
def _add_window_idx(psm_df: pd.DataFrame, cycle: np.ndarray):
"""
Add the window index to the PSM dataframe.
def _add_window_idx(psm_df: pd.DataFrame, cycle: np.ndarray) -> pd.DataFrame:
"""Add the window index to the PSM dataframe.

Parameters
----------

psm_df: pd.DataFrame
The PSM dataframe.

Expand All @@ -178,8 +181,8 @@ def _add_window_idx(psm_df: pd.DataFrame, cycle: np.ndarray):
-------
pd.DataFrame
The PSM dataframe with the window index.
"""

"""
if "window_idx" in psm_df.columns:
logger.warning("Window index already present in PSM dataframe. Skipping.")
return psm_df
Expand All @@ -195,21 +198,21 @@ def _add_window_idx(psm_df: pd.DataFrame, cycle: np.ndarray):
return psm_df

@staticmethod
def _get_thread_plan_df(psm_df: pd.DataFrame):
"""
Expects a dataframe sorted by window idxs and qvals.
def _get_thread_plan_df(psm_df: pd.DataFrame) -> pd.DataFrame:
"""Expects a dataframe sorted by window idxs and qvals.

Returns a dataframe with start and stop indices of the threads.

Parameters
----------

psm_df: pd.DataFrame
The PSM dataframe.

Returns
-------
pd.DataFrame
The thread plan dataframe.

"""
psm_df["_thread_idx"] = np.arange(len(psm_df))
index_df = psm_df.groupby("window_idx", as_index=False).agg(
Expand All @@ -224,12 +227,10 @@ def _get_thread_plan_df(psm_df: pd.DataFrame):
def __call__(
self, psm_df: pd.DataFrame, frag_df: pd.DataFrame, cycle: np.ndarray
) -> pd.DataFrame:
"""
Remove PSMs that share fragments with other PSMs.
"""Remove PSMs that share fragments with other PSMs.

Parameters
----------

psm_df: pd.DataFrame
The PSM dataframe.

Expand All @@ -238,13 +239,13 @@ def __call__(

cycle: np.ndarray
DIA cycle as provided by alphatims.

Returns
-------

pd.DataFrame
The PSM dataframe with the valid column.
"""

"""
psm_df["_candidate_idx"] = utils.candidate_hash(
psm_df["precursor_idx"].values, psm_df["rank"].values
)
Expand All @@ -265,14 +266,16 @@ def __call__(
thread_plan_df = self._get_thread_plan_df(psm_df)

_compete_for_fragments(
np.arange(len(thread_plan_df)),
np.arange(len(thread_plan_df)), # type: ignore # noqa: PGH003 # function is wrapped by pjit -> will be turned into single index and passed to the method
thread_plan_df["start_idx"].values,
thread_plan_df["stop_idx"].values,
psm_df["rt_observed"].values,
valid,
psm_df["_frag_start_idx"].values,
psm_df["_frag_stop_idx"].values,
frag_df["mz_observed"].values,
self.rt_tol_seconds,
self.mass_tol_ppm,
)

psm_df["valid"] = valid
Expand Down
13 changes: 8 additions & 5 deletions alphadia/fragcomp/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Utility methods for fragment competition."""

import logging

import numpy as np
Expand All @@ -6,14 +8,15 @@
logger = logging.getLogger(__name__)


def add_frag_start_stop_idx(psm_df: pd.DataFrame, frag_df: pd.DataFrame):
"""
The fragment dataframe is indexed by the precursor index.
def add_frag_start_stop_idx(
psm_df: pd.DataFrame, frag_df: pd.DataFrame
) -> pd.DataFrame:
"""The fragment dataframe is indexed by the precursor index.

This function adds the start and stop indices of the fragments to the PSM dataframe.

Parameters
----------

psm_df: pd.DataFrame
The PSM dataframe.

Expand All @@ -24,8 +27,8 @@ def add_frag_start_stop_idx(psm_df: pd.DataFrame, frag_df: pd.DataFrame):
-------
pd.DataFrame
The PSM dataframe with the start and stop indices of the fragments.
"""

"""
if "_frag_start_idx" in psm_df.columns and "_frag_stop_idx" in psm_df.columns:
logger.warning(
"Fragment start and stop indices already present in PSM dataframe. Skipping."
Expand Down
1 change: 0 additions & 1 deletion ruff-lint-strict.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ exclude = [
"alphadia/cli.py",
"alphadia/exceptions.py",
"alphadia/plexscoring/features/**.py",
"alphadia/fragcomp/**.py",
"alphadia/grouping.py",
"alphadia/outputaccumulator.py",
"alphadia/outputtransform/**.py",
Expand Down
18 changes: 17 additions & 1 deletion tests/unit_tests/test_fragcomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def test_compete_for_fragments():
frag_start_idx,
frag_stop_idx,
fragment_mz,
3,
15,
)

assert np.all(valid == np.array([True, True, False, True, False, True]))
Expand Down Expand Up @@ -77,7 +79,21 @@ def test_fragment_competition():
}
)

# when
fragment_competition = FragmentCompetition()
psm_df = fragment_competition(psm_df, frag_df, cycle)

assert len(psm_df) == 4
pd.testing.assert_frame_equal(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a bit nicer than just checking for the length

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"a bit nicer" in the sense "actually a test"? :-D

psm_df.reset_index(drop=True),
pd.DataFrame(
{
"precursor_idx": np.array([0, 1, 3, 5], dtype=np.uint32),
"rt_observed": np.array([10.0, 20.0, 10.0, 20]),
"valid": np.array([True] * 4),
"mz_observed": np.array([100, 100, 200, 200]),
"proba": np.array([0.1, 0.2, 0.4, 0.6]),
"rank": np.array([0, 0, 0, 0], dtype=np.uint8),
"_candidate_idx": np.array([0, 1, 3, 5]),
}
),
)
Loading