Skip to content

Commit d3a2a2b

Browse files
[REFACTOR] Continue refactor
1 parent 3b03851 commit d3a2a2b

File tree

10 files changed

+63
-637
lines changed

10 files changed

+63
-637
lines changed

src/flexrag/datasets/__init__.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +0,0 @@
1-
# datasets
2-
from .corpora import IterableCorpus, MappingCorpus
3-
from .dataset import ChainDataset, ConcatDataset, IterableDataset, MappingDataset
4-
from .hf_dataset import HFDataset, HFDatasetConfig
5-
from .qa_dataset import (
6-
QA_DATASETS,
7-
FlashQADataset,
8-
FlashQADatasetConfig,
9-
QADataset,
10-
QADatasetConfig,
11-
QAEvalData,
12-
)
13-
from .retrieval_datasets import (
14-
RETRIEVAL_DATASETS,
15-
IREvalData,
16-
MSMARCODataset,
17-
MSMARCODatasetConfig,
18-
MTEBDataset,
19-
MTEBDatasetConfig,
20-
MultiLongDocRetrievalDataset,
21-
MultiLongDocRetrievalDatasetConfig,
22-
RetrievalDatasetBase,
23-
)
24-
25-
__all__ = [
26-
"ChainDataset",
27-
"IterableDataset",
28-
"MappingDataset",
29-
"ConcatDataset",
30-
"HFDataset",
31-
"HFDatasetConfig",
32-
"IterableCorpus",
33-
"MappingCorpus",
34-
"QA_DATASETS",
35-
"FlashQADataset",
36-
"FlashQADatasetConfig",
37-
"QADataset",
38-
"QAEvalData",
39-
"QADatasetConfig",
40-
"MSMARCODatasetConfig",
41-
"MSMARCODataset",
42-
"MTEBDataset",
43-
"MTEBDatasetConfig",
44-
"RETRIEVAL_DATASETS",
45-
"IREvalData",
46-
"MultiLongDocRetrievalDataset",
47-
"MultiLongDocRetrievalDatasetConfig",
48-
"RetrievalDatasetBase",
49-
]

src/flexrag/datasets/benchmarks/__init__.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from .multiple_choice import (
22
LongBenchV2Dataset,
33
LongBenchV2DatasetConfig,
4-
NovelQADataset,
54
NovelQAConfig,
5+
NovelQADataset,
66
QuALITYDataset,
77
QuALITYDatasetConfig,
88
)
@@ -30,6 +30,15 @@
3030
SQuADDataset,
3131
SQuADDatasetConfig,
3232
)
33+
from .retrieval import (
34+
MSMARCODataset,
35+
MSMARCODatasetConfig,
36+
MTEBDataset,
37+
MTEBDatasetConfig,
38+
MultiLongDocRetrievalDataset,
39+
MultiLongDocRetrievalDatasetConfig,
40+
RetrievalDatasetBase,
41+
)
3342
from .suites import KiltDataset, KiltDatasetConfig
3443

3544
__all__ = [
@@ -61,6 +70,13 @@
6170
"SimpleQADatasetConfig",
6271
"SQuADDataset",
6372
"SQuADDatasetConfig",
73+
"MSMARCODataset",
74+
"MSMARCODatasetConfig",
75+
"MTEBDataset",
76+
"MTEBDatasetConfig",
77+
"MultiLongDocRetrievalDataset",
78+
"MultiLongDocRetrievalDatasetConfig",
79+
"RetrievalDatasetBase",
6480
"KiltDataset",
6581
"KiltDatasetConfig",
6682
]

src/flexrag/datasets/retrieval_datasets/__init__.py renamed to src/flexrag/datasets/benchmarks/retrieval/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44
)
55
from .msmarco_dataset import MSMARCODataset, MSMARCODatasetConfig
66
from .mteb_dataset import MTEBDataset, MTEBDatasetConfig
7-
from .retrieval_dataset import RETRIEVAL_DATASETS, IREvalData, RetrievalDatasetBase
7+
from .retrieval_dataset_base import RetrievalDatasetBase
88

99
__all__ = [
10+
"RetrievalDatasetBase",
1011
"MultiLongDocRetrievalDataset",
1112
"MultiLongDocRetrievalDatasetConfig",
12-
"RETRIEVAL_DATASETS",
13-
"IREvalData",
14-
"RetrievalDatasetBase",
13+
"MSMARCODataset",
14+
"MSMARCODatasetConfig",
1515
"MTEBDataset",
1616
"MTEBDatasetConfig",
17-
"MSMARCODatasetConfig",
18-
"MSMARCODataset",
1917
]

src/flexrag/datasets/retrieval_datasets/mldr_dataset.py renamed to src/flexrag/datasets/benchmarks/retrieval/mldr_dataset.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from collections import defaultdict
22
from pathlib import Path
3-
from typing import Mapping
3+
from typing import Annotated, Mapping, Optional
44

55
from huggingface_hub import snapshot_download
66

7-
from flexrag.common import FLEXRAG_CACHE_DIR, LOGGER_MANAGER, configure
7+
from flexrag.common import FLEXRAG_CACHE_DIR, LOGGER_MANAGER, Choices, configure
88
from flexrag.common.dataclasses import Context
99

10-
from ..reader import LineDelimitedReader
11-
from .retrieval_dataset import RETRIEVAL_DATASETS, RetrievalDatasetBase
10+
from ...core import DATASETS
11+
from ...reader import LineDelimitedReader
12+
from .retrieval_dataset_base import RetrievalDatasetBase
1213

1314
logger = LOGGER_MANAGER.get_logger("flexrag.datasets.mldr_dataset")
1415

@@ -45,11 +46,28 @@ class MultiLongDocRetrievalDatasetConfig:
4546
"""
4647

4748
split: str
48-
lang: str = "en"
49-
data_path: str | None = None
50-
51-
52-
@RETRIEVAL_DATASETS("mldr", config_class=MultiLongDocRetrievalDatasetConfig)
49+
lang: Annotated[
50+
str,
51+
Choices(
52+
"ar",
53+
"de",
54+
"en",
55+
"es",
56+
"fr",
57+
"hi",
58+
"it",
59+
"ja",
60+
"ko",
61+
"pt",
62+
"ru",
63+
"th",
64+
"zh",
65+
),
66+
] = "en"
67+
data_path: Optional[str] = None
68+
69+
70+
@DATASETS("mldr", config_class=MultiLongDocRetrievalDatasetConfig)
5371
class MultiLongDocRetrievalDataset(RetrievalDatasetBase):
5472
def __init__(self, config: MultiLongDocRetrievalDatasetConfig) -> None:
5573
# prepare dataset path
@@ -112,16 +130,13 @@ def __init__(self, config: MultiLongDocRetrievalDatasetConfig) -> None:
112130
return
113131

114132
@property
115-
def _contexts(self) -> Mapping[str, Context]:
116-
"""Return a mapping from context_id to Context object."""
117-
return self._context_data
133+
def _qrels(self) -> Mapping[str, Mapping[str, float]]:
134+
return self._qrels_data
118135

119136
@property
120137
def _queries(self) -> Mapping[str, str]:
121-
"""Return a mapping from query_id to query text."""
122138
return self._queries_data
123139

124140
@property
125-
def _qrels(self) -> Mapping[str, set[str]]:
126-
"""Return a mapping from query_id to a set of relevant context_ids."""
127-
return self._qrels_data
141+
def _contexts(self) -> Mapping[str, Context]:
142+
return self._context_data

src/flexrag/datasets/retrieval_datasets/msmarco_dataset.py renamed to src/flexrag/datasets/benchmarks/retrieval/msmarco_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
)
1414
from flexrag.common.dataclasses import Context
1515

16-
from ..reader import LineDelimitedReader
17-
from .retrieval_dataset import RETRIEVAL_DATASETS, RetrievalDatasetBase
16+
from ...core import DATASETS
17+
from ...reader import LineDelimitedReader
18+
from .retrieval_dataset_base import RetrievalDatasetBase
1819

1920
logger = LOGGER_MANAGER.get_logger("flexrag.datasets.msmarco_dataset")
2021

@@ -62,7 +63,7 @@ class MSMARCODatasetConfig:
6263
load_corpus: bool = False
6364

6465

65-
@RETRIEVAL_DATASETS("msmarco", MSMARCODatasetConfig)
66+
@DATASETS("msmarco", MSMARCODatasetConfig)
6667
class MSMARCODataset(RetrievalDatasetBase):
6768
"""Dataset for loading MSMARCO Retrieval Dataset."""
6869

src/flexrag/datasets/retrieval_datasets/mteb_dataset.py renamed to src/flexrag/datasets/benchmarks/retrieval/mteb_dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from flexrag.common import configure
77
from flexrag.common.dataclasses import Context
88

9-
from .retrieval_dataset import RETRIEVAL_DATASETS, RetrievalDatasetBase
9+
from ...core import DATASETS
10+
from .retrieval_dataset_base import RetrievalDatasetBase
1011

1112

1213
@configure
@@ -41,7 +42,7 @@ class MTEBDatasetConfig:
4142
load_corpus: bool = False
4243

4344

44-
@RETRIEVAL_DATASETS("mteb", MTEBDatasetConfig)
45+
@DATASETS("mteb", MTEBDatasetConfig)
4546
class MTEBDataset(RetrievalDatasetBase):
4647
"""Dataset for loading MTEB Retrieval Dataset."""
4748

src/flexrag/datasets/retrieval_datasets/retrieval_dataset.py renamed to src/flexrag/datasets/benchmarks/retrieval/retrieval_dataset_base.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,13 @@
11
from abc import abstractmethod
22
from collections.abc import Iterator, Mapping
3-
from dataclasses import field
43
from functools import cached_property
5-
from typing import Optional
64

7-
from flexrag.common import Register, data
85
from flexrag.common.dataclasses import Context
96

10-
from ..dataset import MappingDataset
7+
from ...core import IRSample, MappingDataset
118

129

13-
@data
14-
class IREvalData:
15-
"""The dataclass for Information Retrieval evaluation data.
16-
17-
:param question: The question for evaluation. Required.
18-
:type question: str
19-
:param question_id: The unique identifier for the question. Default: None.
20-
:type question_id: Optional[str]
21-
:param contexts: The contexts related to the question. Default: None.
22-
:type contexts: Optional[list[Context]]
23-
:param hard_negatives: The hard negatives related to the question. Default: None.
24-
:type hard_negatives: Optional[list[Context]]
25-
:param meta_data: The metadata of the evaluation data. Default: {}.
26-
:type meta_data: dict
27-
"""
28-
29-
question: str
30-
question_id: Optional[str] = None
31-
contexts: Optional[list[Context]] = None
32-
hard_negatives: Optional[list[Context]] = None
33-
meta_data: dict = field(default_factory=dict)
34-
35-
36-
class RetrievalDatasetBase(MappingDataset[IREvalData]):
10+
class RetrievalDatasetBase(MappingDataset[IRSample]):
3711
"""Base class for Information Retrieval (IR) datasets.
3812
3913
This class provides a unified interface for accessing IR datasets, which typically consist of:
@@ -162,7 +136,7 @@ def __len__(self) -> int:
162136
"""The number of queries in the qrels."""
163137
return len(self._qids)
164138

165-
def get_item(self, index: int) -> IREvalData:
139+
def get_item(self, index: int) -> IRSample:
166140
qid = self._qids[index]
167141
query = self._queries[qid]
168142
relevant_ctxs = []
@@ -180,12 +154,9 @@ def get_item(self, index: int) -> IREvalData:
180154
self._contexts.get(ctx_id, Context(context_id=ctx_id))
181155
for ctx_id in hard_negatives
182156
]
183-
return IREvalData(
157+
return IRSample(
184158
question=query,
185159
question_id=qid,
186160
contexts=rels,
187161
hard_negatives=negs,
188162
)
189-
190-
191-
RETRIEVAL_DATASETS = Register[RetrievalDatasetBase]("retrieval_dataset")

0 commit comments

Comments
 (0)