Skip to content

Commit 2fb7d0b

Browse files
committed
minor fix
1 parent 55e0f1b commit 2fb7d0b

File tree

6 files changed

+16
-19
lines changed

6 files changed

+16
-19
lines changed

examples/example.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -954,9 +954,9 @@
954954
],
955955
"metadata": {
956956
"kernelspec": {
957-
"display_name": "nuggetizer",
957+
"display_name": "Python [conda env:base] *",
958958
"language": "python",
959-
"name": "python3"
959+
"name": "conda-base-py"
960960
},
961961
"language_info": {
962962
"codemirror_mode": {
@@ -968,7 +968,7 @@
968968
"name": "python",
969969
"nbconvert_exporter": "python",
970970
"pygments_lexer": "ipython3",
971-
"version": "3.10.16"
971+
"version": "3.10.13"
972972
}
973973
},
974974
"nbformat": 4,

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ pytest-subtests
33
pytest-cov
44
pytest-json-report
55
ruff
6+
nbformat

src/open_nuggetizer/measure/_measures.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from ir_measures import measures
22

3-
43
class _AllScore(measures.Measure):
54
__name__ = 'AllScore'
65
NAME = __name__
@@ -11,7 +10,6 @@ class _AllScore(measures.Measure):
1110
'strict': measures.ParamInfo(dtype=bool, default=False, desc='Exclude nuggets partially supported in measure'),
1211
}
1312

14-
1513
AllScore = _AllScore()
1614
measures.register(AllScore)
1715

@@ -48,9 +46,9 @@ class _WeightedScore(measures.Measure):
4846

4947

5048
# debug printing measures registry
51-
def print_measures_registry():
52-
print("Registered measures:")
53-
for measure in measures.registry:
54-
print(measure)
55-
print("Registered measures with details:")
56-
print_measures_registry()
49+
#def print_measures_registry():
50+
# print("Registered measures:")
51+
# for measure in measures.registry:
52+
# print(measure)
53+
# print("Registered measures with details:")
54+
#print_measures_registry()

src/open_nuggetizer/measure/_provider.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,18 @@ def __init__(self, nuggetizer_instance, measures, qrels, invocations):
1515
def _unweighted(self, nuggets, partial_rel, strict, partial_weight):
1616
full_support = [n for n in nuggets if n[1] > partial_rel]
1717
partial_support = [n for n in nuggets if 0 < n[1] <= partial_rel]
18-
1918
if not len(full_support) > 0:
2019
return 0.0
2120

2221
value = len(full_support)
2322
if not strict:
2423
value += partial_weight * len(partial_support)
25-
24+
2625
return value / len(nuggets)
2726

2827
def _weighted(self, nuggets, rel, partial_rel, strict, partial_weight):
29-
vital_nuggets = [n for n in nuggets if n[2] > rel]
30-
okay_nuggets = [n for n in nuggets if 0 < n[2] <= rel]
28+
vital_nuggets = [n for n in nuggets if n[1] > rel]
29+
okay_nuggets = [n for n in nuggets if 0 < n[1] <= rel]
3130

3231
vital_score = self._unweighted(vital_nuggets, partial_rel, strict, partial_weight)
3332
okay_score = self._unweighted(okay_nuggets, partial_rel, strict, partial_weight)
@@ -79,8 +78,6 @@ def __init__(self, nuggetizer_instance):
7978
self.nuggetizer_instance = nuggetizer_instance
8079

8180
def supports(self, measure) -> bool:
82-
print(f"Measure: {measure.NAME}")
83-
print(f"Supported measures: {self.SUPPORTED_MEASURES}")
8481
measure.validate_params()
8582
for supported_measure in self.SUPPORTED_MEASURES:
8683
if measure.NAME == supported_measure.NAME:

src/open_nuggetizer/measure/_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def predict_type(self):
8282

8383
def as_dict_of_dict(self):
8484
t, err = self.predict_type()
85+
print(t, err)
8586
if t == 'dict_of_dict':
8687
return self.qrels
8788
else:

src/open_nuggetizer/nuggetizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def __post_init__(self):
473473

474474
if self.mode == NuggetAssignMode.SUPPORT_GRADE_2:
475475
self.mapping = {
476-
"support": 1,
476+
"support": 2,
477477
"not_support": 0,
478478
}
479479
else:
@@ -514,7 +514,7 @@ def transform_by_query(self, inp: Iterable[dict]) -> Iterable[dict]:
514514
output = self.nuggetizer.generate(prompt)[0].text
515515
assignments.extend(self.prompt.answer_extraction(output))
516516
assignments = [self.mapping.get(x.lower(), 0) for x in assignments]
517-
517+
print("Assignments:", assignments)
518518
return [
519519
{
520520
"qid": qid,

0 commit comments

Comments
 (0)