From c6e82fd648ecfd83687e84ac25db78aec44e1004 Mon Sep 17 00:00:00 2001 From: Nikolai Dorofeev Date: Sun, 21 Sep 2025 14:32:48 +0200 Subject: [PATCH 1/8] [Req report] refactor requirements filtering --- test2text/pages/reports/report_by_req.py | 322 ++++++++---------- test2text/services/db/client.py | 27 -- test2text/services/repositories/__init__.py | 0 .../repositories/requirements/__init__.py | 2 + .../requirements/fetch_filtered.py | 40 +++ 5 files changed, 186 insertions(+), 205 deletions(-) create mode 100644 test2text/services/repositories/__init__.py create mode 100644 test2text/services/repositories/requirements/__init__.py create mode 100644 test2text/services/repositories/requirements/fetch_filtered.py diff --git a/test2text/pages/reports/report_by_req.py b/test2text/pages/reports/report_by_req.py index 4880a12..01aaa4f 100644 --- a/test2text/pages/reports/report_by_req.py +++ b/test2text/pages/reports/report_by_req.py @@ -1,34 +1,150 @@ from itertools import groupby import numpy as np import streamlit as st -from sqlite_vec import serialize_float32 from test2text.services.utils.math_utils import round_distance +from test2text.services.repositories import requirements SUMMARY_LENGTH = 100 LABELS_SUMMARY_LENGTH = 15 +def display_found_details(data: list): + from test2text.services.utils import unpack_float32 + from test2text.services.visualisation.visualize_vectors import ( + minifold_vectors_2d, + plot_2_sets_in_one_2d, + minifold_vectors_3d, + plot_2_sets_in_one_3d, + ) + def write_annotations(current_annotations: set[tuple]): + st.write("id,", "Summary,", "Distance") + for anno_id, anno_summary, _, distance in current_annotations: + st.write(anno_id, anno_summary, round_distance(distance)) + + + for ( + req_id, + req_external_id, + req_summary, + req_embedding, + ), group in groupby(data, lambda x: x[0:4]): + st.divider() + with st.container(): + st.subheader(f" Inspect Requirement {req_external_id}") + st.write(req_summary) + current_test_cases = dict() + for ( + _, + _, + _, + _, + anno_id, + anno_summary, + anno_embedding, + distance, + case_id, + test_script, + test_case, + ) in group: + current_annotation = current_test_cases.get( + test_case, set() + ) + current_test_cases.update({test_case: current_annotation}) + current_test_cases[test_case].add( + (anno_id, anno_summary, anno_embedding, distance) + ) + + t_cs, anno, viz = st.columns(3) + with t_cs: + with st.container(border=True): + st.write("Test Cases") + st.info("Test cases of chosen Requirement") + st.radio( + "Test cases name", + current_test_cases.keys(), + key="radio_choice", + ) + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + + if st.session_state["radio_choice"]: + with anno: + with st.container(border=True): + st.write("Annotations") + st.info( + "List of Annotations for chosen Test case" + ) + write_annotations( + current_annotations=current_test_cases[ + st.session_state["radio_choice"] + ] + ) + with viz: + with st.container(border=True): + st.write("Visualization") + select = st.selectbox( + "Choose type of visualization", ["2D", "3D"] + ) + anno_embeddings = [ + unpack_float32(anno_emb) + for _, _, anno_emb, _ in current_test_cases[ + st.session_state["radio_choice"] + ] + ] + anno_labels = [ + f"{anno_id}" + for anno_id, _, _, _ in current_test_cases[ + st.session_state["radio_choice"] + ] + ] + requirement_vectors = np.array( + [np.array(unpack_float32(req_embedding))] + ) + annotation_vectors = np.array(anno_embeddings) + if select == "2D": + plot_2_sets_in_one_2d( + minifold_vectors_2d( + requirement_vectors + ), + minifold_vectors_2d(annotation_vectors), + "Requirement", + "Annotations", + first_labels=[f"{req_external_id}"], + second_labels=anno_labels, + ) + else: + reqs_vectors_3d = minifold_vectors_3d( + requirement_vectors + ) + anno_vectors_3d = minifold_vectors_3d( + annotation_vectors + ) + plot_2_sets_in_one_3d( + reqs_vectors_3d, + anno_vectors_3d, + "Requirement", + "Annotations", + first_labels=[f"{req_external_id}"], + second_labels=anno_labels, + ) + def make_a_report(): from test2text.services.db import get_db_client - with get_db_client() as db: - from test2text.services.embeddings.embed import embed_requirement - from test2text.services.utils import unpack_float32 - from test2text.services.visualisation.visualize_vectors import ( - minifold_vectors_2d, - plot_2_sets_in_one_2d, - minifold_vectors_3d, - plot_2_sets_in_one_3d, - ) - + with (get_db_client() as db): st.header("Test2Text Report") - def write_annotations(current_annotations: set[tuple]): - st.write("id,", "Summary,", "Distance") - for anno_id, anno_summary, _, distance in current_annotations: - st.write(anno_id, anno_summary, round_distance(distance)) - with st.container(border=True): st.subheader("Filter requirements") with st.expander("🔍 Filters"): @@ -47,62 +163,26 @@ def write_annotations(current_annotations: set[tuple]): ) st.info("Search using embeddings") - where_clauses = [] - params = [] - - if filter_id.strip(): - where_clauses.append("Requirements.id = ?") - params.append(filter_id.strip()) - - if filter_summary.strip(): - where_clauses.append("Requirements.summary LIKE ?") - params.append(f"%{filter_summary.strip()}%") - - distance_sql = "" - distance_order_sql = "" - query_embedding_bytes = None - if filter_embedding.strip(): - query_embedding = embed_requirement(filter_embedding.strip()) - query_embedding_bytes = serialize_float32(query_embedding) - distance_sql = ", vec_distance_L2(embedding, ?) AS distance" - distance_order_sql = "distance ASC, " - with st.container(border=True): st.session_state.update({"req_form_submitting": True}) - data = db.get_ordered_values_from_requirements( - distance_sql, - where_clauses, - distance_order_sql, - params + [query_embedding_bytes] if distance_sql else params, - ) + data = requirements.fetch_filtered_requirements(db, + external_id=filter_id, + text_content=filter_summary, + smart_search_query=filter_embedding) - if distance_sql: - requirements_dict = { - f"{req_external_id} {summary[:SUMMARY_LENGTH]}... [smart search d={round_distance(distance)}]": req_id - for (req_id, req_external_id, summary, distance) in data - } - else: - requirements_dict = { - f"{req_external_id} {summary[:SUMMARY_LENGTH]}...": req_id - for (req_id, req_external_id, summary) in data - } + requirements_dict = { + f"{req_external_id} {summary[:SUMMARY_LENGTH]}...": req_id + for (req_id, req_external_id, summary) in data + } st.subheader("Choose 1 of filtered requirements") - option = st.selectbox( + selected_requirement = st.selectbox( "Choose a requirement to work with", requirements_dict.keys(), key="filter_req_id", ) - if option: - clause = "Requirements.id = ?" - if clause in where_clauses: - idx = where_clauses.index(clause) - params.insert(idx, requirements_dict[option]) - else: - where_clauses.append(clause) - params.append(requirements_dict[option]) - + if selected_requirement: st.subheader("Filter Test cases") with st.expander("🔍 Filters"): @@ -140,123 +220,9 @@ def write_annotations(current_annotations: set[tuple]): "There is no requested data to inspect.\n" "Please check filters, completeness of the data or upload new annotations and requirements." ) - return None - - for ( - req_id, - req_external_id, - req_summary, - req_embedding, - ), group in groupby(rows, lambda x: x[0:4]): - st.divider() - with st.container(): - st.subheader(f" Inspect Requirement {req_external_id}") - st.write(req_summary) - current_test_cases = dict() - for ( - _, - _, - _, - _, - anno_id, - anno_summary, - anno_embedding, - distance, - case_id, - test_script, - test_case, - ) in group: - current_annotation = current_test_cases.get( - test_case, set() - ) - current_test_cases.update({test_case: current_annotation}) - current_test_cases[test_case].add( - (anno_id, anno_summary, anno_embedding, distance) - ) - - t_cs, anno, viz = st.columns(3) - with t_cs: - with st.container(border=True): - st.write("Test Cases") - st.info("Test cases of chosen Requirement") - st.radio( - "Test cases name", - current_test_cases.keys(), - key="radio_choice", - ) - st.markdown( - """ - - """, - unsafe_allow_html=True, - ) + else: + display_found_details(rows) - if st.session_state["radio_choice"]: - with anno: - with st.container(border=True): - st.write("Annotations") - st.info( - "List of Annotations for chosen Test case" - ) - write_annotations( - current_annotations=current_test_cases[ - st.session_state["radio_choice"] - ] - ) - with viz: - with st.container(border=True): - st.write("Visualization") - select = st.selectbox( - "Choose type of visualization", ["2D", "3D"] - ) - anno_embeddings = [ - unpack_float32(anno_emb) - for _, _, anno_emb, _ in current_test_cases[ - st.session_state["radio_choice"] - ] - ] - anno_labels = [ - f"{anno_id}" - for anno_id, _, _, _ in current_test_cases[ - st.session_state["radio_choice"] - ] - ] - requirement_vectors = np.array( - [np.array(unpack_float32(req_embedding))] - ) - annotation_vectors = np.array(anno_embeddings) - if select == "2D": - plot_2_sets_in_one_2d( - minifold_vectors_2d( - requirement_vectors - ), - minifold_vectors_2d(annotation_vectors), - "Requirement", - "Annotations", - first_labels=[f"{req_external_id}"], - second_labels=anno_labels, - ) - else: - reqs_vectors_3d = minifold_vectors_3d( - requirement_vectors - ) - anno_vectors_3d = minifold_vectors_3d( - annotation_vectors - ) - plot_2_sets_in_one_3d( - reqs_vectors_3d, - anno_vectors_3d, - "Requirement", - "Annotations", - first_labels=[f"{req_external_id}"], - second_labels=anno_labels, - ) if __name__ == "__main__": diff --git a/test2text/services/db/client.py b/test2text/services/db/client.py index 3624e61..9ccd6c0 100644 --- a/test2text/services/db/client.py +++ b/test2text/services/db/client.py @@ -233,33 +233,6 @@ def join_all_tables_by_requirements( data = self.conn.execute(sql, params) return data.fetchall() - def get_ordered_values_from_requirements( - self, distance_sql="", where_clauses="", distance_order_sql="", params=None - ) -> list[tuple]: - """ - Extracted values from Requirements table based on the provided where clauses and specified parameters ordered by distance and id. - Return a list of tuples containing : - req_id, - req_external_id, - req_summary, - distance between annotation and requirement embeddings, - """ - where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - sql = f""" - SELECT - Requirements.id as req_id, - Requirements.external_id as req_external_id, - Requirements.summary as req_summary - {distance_sql} - FROM - Requirements - {where_sql} - ORDER BY - {distance_order_sql}Requirements.id - """ - data = self.conn.execute(sql, params) - return data.fetchall() - def get_ordered_values_from_test_cases( self, distance_sql="", where_clauses="", distance_order_sql="", params=None ) -> list[tuple]: diff --git a/test2text/services/repositories/__init__.py b/test2text/services/repositories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test2text/services/repositories/requirements/__init__.py b/test2text/services/repositories/requirements/__init__.py new file mode 100644 index 0000000..f9afd8f --- /dev/null +++ b/test2text/services/repositories/requirements/__init__.py @@ -0,0 +1,2 @@ +__all__ = ["fetch_filtered_requirements",] +from .fetch_filtered import fetch_filtered_requirements \ No newline at end of file diff --git a/test2text/services/repositories/requirements/fetch_filtered.py b/test2text/services/repositories/requirements/fetch_filtered.py new file mode 100644 index 0000000..efbb856 --- /dev/null +++ b/test2text/services/repositories/requirements/fetch_filtered.py @@ -0,0 +1,40 @@ +from typing import Optional + +from sqlite_vec import serialize_float32 + +from test2text.services.db import DbClient + + +def fetch_filtered_requirements(db: DbClient, + *_, + external_id: Optional[str] = None, + text_content: Optional[str] = None, + smart_search_query: Optional[str] = None) -> list[tuple[int, str, str]]: + sql = f""" + SELECT + Requirements.id as req_id, + Requirements.external_id as req_external_id, + Requirements.summary as req_summary + FROM + Requirements + """ + options = [] + if external_id or text_content or smart_search_query: + sql += " WHERE " + conditions = [] + if external_id: + conditions.append("Requirements.external_id LIKE ?") + options.append(f"%{external_id.strip()}%") + if text_content: + conditions.append("Requirements.summary LIKE ?") + options.append(f"%{text_content.strip()}%") + if smart_search_query: + from test2text.services.embeddings.embed import embed_requirement + embedding = embed_requirement(smart_search_query.strip()) + conditions.append("vec_distance(Requirements.embedding, ?) < 0.7") + options.append(serialize_float32(embedding)) + sql += " AND ".join(conditions) + sql += " ORDER BY Requirements.id ASC" + + + return db.conn.execute(sql, options).fetchall() From d72b7a9f96d744f20ffe4584c2829ede125071ec Mon Sep 17 00:00:00 2001 From: Nikolai Dorofeev Date: Sun, 21 Sep 2025 15:25:35 +0200 Subject: [PATCH 2/8] [Req report] Fix data fetching --- test2text/pages/reports/report_by_req.py | 13 ++--- test2text/services/db/client.py | 49 ------------------- .../requirements/fetch_filtered.py | 2 +- .../repositories/test_cases/__init__.py | 3 ++ .../test_cases/fetch_by_requirement.py | 31 ++++++++++++ 5 files changed, 39 insertions(+), 59 deletions(-) create mode 100644 test2text/services/repositories/test_cases/__init__.py create mode 100644 test2text/services/repositories/test_cases/fetch_by_requirement.py diff --git a/test2text/pages/reports/report_by_req.py b/test2text/pages/reports/report_by_req.py index 01aaa4f..ed5a811 100644 --- a/test2text/pages/reports/report_by_req.py +++ b/test2text/pages/reports/report_by_req.py @@ -4,6 +4,7 @@ from test2text.services.utils.math_utils import round_distance from test2text.services.repositories import requirements +from test2text.services.repositories import test_cases SUMMARY_LENGTH = 100 LABELS_SUMMARY_LENGTH = 15 @@ -171,7 +172,7 @@ def make_a_report(): smart_search_query=filter_embedding) requirements_dict = { - f"{req_external_id} {summary[:SUMMARY_LENGTH]}...": req_id + req_id: f"{req_external_id} {summary[:SUMMARY_LENGTH]}..." for (req_id, req_external_id, summary) in data } @@ -180,6 +181,7 @@ def make_a_report(): "Choose a requirement to work with", requirements_dict.keys(), key="filter_req_id", + format_func=lambda x: requirements_dict[x], ) if selected_requirement: @@ -206,14 +208,7 @@ def make_a_report(): ) st.info("Limit of selected test cases") - if filter_radius: - where_clauses.append("distance <= ?") - params.append(f"{filter_radius}") - - if filter_limit: - params.append(f"{filter_limit}") - - rows = db.join_all_tables_by_requirements(where_clauses, params) + rows = test_cases.fetch_test_cases_by_requirement(db, selected_requirement, filter_radius, filter_limit) if not rows: st.error( diff --git a/test2text/services/db/client.py b/test2text/services/db/client.py index 9ccd6c0..bdacf5b 100644 --- a/test2text/services/db/client.py +++ b/test2text/services/db/client.py @@ -184,55 +184,6 @@ def get_embeddings_from_annotations_to_requirements_table(self): """) return cursor.fetchall() - def join_all_tables_by_requirements( - self, where_clauses="", params=None - ) -> list[tuple]: - """ - Extract values from requirements with related annotations and their test cases based on the provided where clauses and parameters. - Return a list of tuples containing : - req_id, - req_external_id, - req_summary, - req_embedding, - anno_id, - anno_summary, - anno_embedding, - distance, - case_id, - test_script, - test_case - """ - where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - sql = f""" - SELECT - Requirements.id as req_id, - Requirements.external_id as req_external_id, - Requirements.summary as req_summary, - Requirements.embedding as req_embedding, - - Annotations.id as anno_id, - Annotations.summary as anno_summary, - Annotations.embedding as anno_embedding, - - AnnotationsToRequirements.cached_distance as distance, - - TestCases.id as case_id, - TestCases.test_script as test_script, - TestCases.test_case as test_case - FROM - Requirements - JOIN AnnotationsToRequirements ON Requirements.id = AnnotationsToRequirements.requirement_id - JOIN Annotations ON Annotations.id = AnnotationsToRequirements.annotation_id - JOIN CasesToAnnos ON Annotations.id = CasesToAnnos.annotation_id - JOIN TestCases ON TestCases.id = CasesToAnnos.case_id - {where_sql} - ORDER BY - Requirements.id, AnnotationsToRequirements.cached_distance, TestCases.id - LIMIT ? - """ - data = self.conn.execute(sql, params) - return data.fetchall() - def get_ordered_values_from_test_cases( self, distance_sql="", where_clauses="", distance_order_sql="", params=None ) -> list[tuple]: diff --git a/test2text/services/repositories/requirements/fetch_filtered.py b/test2text/services/repositories/requirements/fetch_filtered.py index efbb856..4363ef6 100644 --- a/test2text/services/repositories/requirements/fetch_filtered.py +++ b/test2text/services/repositories/requirements/fetch_filtered.py @@ -31,7 +31,7 @@ def fetch_filtered_requirements(db: DbClient, if smart_search_query: from test2text.services.embeddings.embed import embed_requirement embedding = embed_requirement(smart_search_query.strip()) - conditions.append("vec_distance(Requirements.embedding, ?) < 0.7") + conditions.append("vec_distance_L2(Requirements.embedding, ?) < 0.7") options.append(serialize_float32(embedding)) sql += " AND ".join(conditions) sql += " ORDER BY Requirements.id ASC" diff --git a/test2text/services/repositories/test_cases/__init__.py b/test2text/services/repositories/test_cases/__init__.py new file mode 100644 index 0000000..e3e8e4d --- /dev/null +++ b/test2text/services/repositories/test_cases/__init__.py @@ -0,0 +1,3 @@ +__all__=['fetch_test_cases_by_requirement'] + +from .fetch_by_requirement import fetch_test_cases_by_requirement \ No newline at end of file diff --git a/test2text/services/repositories/test_cases/fetch_by_requirement.py b/test2text/services/repositories/test_cases/fetch_by_requirement.py new file mode 100644 index 0000000..ebb42c9 --- /dev/null +++ b/test2text/services/repositories/test_cases/fetch_by_requirement.py @@ -0,0 +1,31 @@ +from test2text.services.db import DbClient + + +def fetch_test_cases_by_requirement(db: DbClient, requirement_id: int, radius: float, limit: int) -> list: + sql = f""" + SELECT + Requirements.id as req_id, + Requirements.external_id as req_external_id, + Requirements.summary as req_summary, + Requirements.embedding as req_embedding, + + Annotations.id as anno_id, + Annotations.summary as anno_summary, + Annotations.embedding as anno_embedding, + + vec_distance_L2(Requirements.embedding, Annotations.embedding) as distance, + + TestCases.id as case_id, + TestCases.test_script as test_script, + TestCases.test_case as test_case + FROM + Requirements + JOIN Annotations ON vec_distance_L2(Requirements.embedding, Annotations.embedding) <= ? + JOIN CasesToAnnos ON Annotations.id = CasesToAnnos.annotation_id + JOIN TestCases ON TestCases.id = CasesToAnnos.case_id + WHERE Requirements.id = ? + ORDER BY + Requirements.id, distance, TestCases.id + LIMIT ? + """ + return db.conn.execute(sql, (radius, requirement_id, limit)).fetchall() \ No newline at end of file From 9ea338f031d834c3d4caf26d7dca2e775ec0a0fb Mon Sep 17 00:00:00 2001 From: Nikolai Dorofeev Date: Sun, 21 Sep 2025 15:26:16 +0200 Subject: [PATCH 3/8] [Req report] Lintfix --- test2text/pages/reports/report_by_req.py | 62 +++++++++---------- .../repositories/requirements/__init__.py | 6 +- .../requirements/fetch_filtered.py | 14 +++-- .../repositories/test_cases/__init__.py | 4 +- .../test_cases/fetch_by_requirement.py | 6 +- 5 files changed, 48 insertions(+), 44 deletions(-) diff --git a/test2text/pages/reports/report_by_req.py b/test2text/pages/reports/report_by_req.py index ed5a811..0567385 100644 --- a/test2text/pages/reports/report_by_req.py +++ b/test2text/pages/reports/report_by_req.py @@ -9,6 +9,7 @@ SUMMARY_LENGTH = 100 LABELS_SUMMARY_LENGTH = 15 + def display_found_details(data: list): from test2text.services.utils import unpack_float32 from test2text.services.visualisation.visualize_vectors import ( @@ -17,17 +18,17 @@ def display_found_details(data: list): minifold_vectors_3d, plot_2_sets_in_one_3d, ) + def write_annotations(current_annotations: set[tuple]): st.write("id,", "Summary,", "Distance") for anno_id, anno_summary, _, distance in current_annotations: st.write(anno_id, anno_summary, round_distance(distance)) - for ( - req_id, - req_external_id, - req_summary, - req_embedding, + req_id, + req_external_id, + req_summary, + req_embedding, ), group in groupby(data, lambda x: x[0:4]): st.divider() with st.container(): @@ -35,21 +36,19 @@ def write_annotations(current_annotations: set[tuple]): st.write(req_summary) current_test_cases = dict() for ( - _, - _, - _, - _, - anno_id, - anno_summary, - anno_embedding, - distance, - case_id, - test_script, - test_case, + _, + _, + _, + _, + anno_id, + anno_summary, + anno_embedding, + distance, + case_id, + test_script, + test_case, ) in group: - current_annotation = current_test_cases.get( - test_case, set() - ) + current_annotation = current_test_cases.get(test_case, set()) current_test_cases.update({test_case: current_annotation}) current_test_cases[test_case].add( (anno_id, anno_summary, anno_embedding, distance) @@ -82,9 +81,7 @@ def write_annotations(current_annotations: set[tuple]): with anno: with st.container(border=True): st.write("Annotations") - st.info( - "List of Annotations for chosen Test case" - ) + st.info("List of Annotations for chosen Test case") write_annotations( current_annotations=current_test_cases[ st.session_state["radio_choice"] @@ -114,9 +111,7 @@ def write_annotations(current_annotations: set[tuple]): annotation_vectors = np.array(anno_embeddings) if select == "2D": plot_2_sets_in_one_2d( - minifold_vectors_2d( - requirement_vectors - ), + minifold_vectors_2d(requirement_vectors), minifold_vectors_2d(annotation_vectors), "Requirement", "Annotations", @@ -143,7 +138,7 @@ def write_annotations(current_annotations: set[tuple]): def make_a_report(): from test2text.services.db import get_db_client - with (get_db_client() as db): + with get_db_client() as db: st.header("Test2Text Report") with st.container(border=True): @@ -166,10 +161,12 @@ def make_a_report(): with st.container(border=True): st.session_state.update({"req_form_submitting": True}) - data = requirements.fetch_filtered_requirements(db, - external_id=filter_id, - text_content=filter_summary, - smart_search_query=filter_embedding) + data = requirements.fetch_filtered_requirements( + db, + external_id=filter_id, + text_content=filter_summary, + smart_search_query=filter_embedding, + ) requirements_dict = { req_id: f"{req_external_id} {summary[:SUMMARY_LENGTH]}..." @@ -208,7 +205,9 @@ def make_a_report(): ) st.info("Limit of selected test cases") - rows = test_cases.fetch_test_cases_by_requirement(db, selected_requirement, filter_radius, filter_limit) + rows = test_cases.fetch_test_cases_by_requirement( + db, selected_requirement, filter_radius, filter_limit + ) if not rows: st.error( @@ -219,6 +218,5 @@ def make_a_report(): display_found_details(rows) - if __name__ == "__main__": make_a_report() diff --git a/test2text/services/repositories/requirements/__init__.py b/test2text/services/repositories/requirements/__init__.py index f9afd8f..69c1e5b 100644 --- a/test2text/services/repositories/requirements/__init__.py +++ b/test2text/services/repositories/requirements/__init__.py @@ -1,2 +1,4 @@ -__all__ = ["fetch_filtered_requirements",] -from .fetch_filtered import fetch_filtered_requirements \ No newline at end of file +__all__ = [ + "fetch_filtered_requirements", +] +from .fetch_filtered import fetch_filtered_requirements diff --git a/test2text/services/repositories/requirements/fetch_filtered.py b/test2text/services/repositories/requirements/fetch_filtered.py index 4363ef6..cd9f9dd 100644 --- a/test2text/services/repositories/requirements/fetch_filtered.py +++ b/test2text/services/repositories/requirements/fetch_filtered.py @@ -5,11 +5,13 @@ from test2text.services.db import DbClient -def fetch_filtered_requirements(db: DbClient, - *_, - external_id: Optional[str] = None, - text_content: Optional[str] = None, - smart_search_query: Optional[str] = None) -> list[tuple[int, str, str]]: +def fetch_filtered_requirements( + db: DbClient, + *_, + external_id: Optional[str] = None, + text_content: Optional[str] = None, + smart_search_query: Optional[str] = None, +) -> list[tuple[int, str, str]]: sql = f""" SELECT Requirements.id as req_id, @@ -30,11 +32,11 @@ def fetch_filtered_requirements(db: DbClient, options.append(f"%{text_content.strip()}%") if smart_search_query: from test2text.services.embeddings.embed import embed_requirement + embedding = embed_requirement(smart_search_query.strip()) conditions.append("vec_distance_L2(Requirements.embedding, ?) < 0.7") options.append(serialize_float32(embedding)) sql += " AND ".join(conditions) sql += " ORDER BY Requirements.id ASC" - return db.conn.execute(sql, options).fetchall() diff --git a/test2text/services/repositories/test_cases/__init__.py b/test2text/services/repositories/test_cases/__init__.py index e3e8e4d..5075b63 100644 --- a/test2text/services/repositories/test_cases/__init__.py +++ b/test2text/services/repositories/test_cases/__init__.py @@ -1,3 +1,3 @@ -__all__=['fetch_test_cases_by_requirement'] +__all__ = ["fetch_test_cases_by_requirement"] -from .fetch_by_requirement import fetch_test_cases_by_requirement \ No newline at end of file +from .fetch_by_requirement import fetch_test_cases_by_requirement diff --git a/test2text/services/repositories/test_cases/fetch_by_requirement.py b/test2text/services/repositories/test_cases/fetch_by_requirement.py index ebb42c9..cc64244 100644 --- a/test2text/services/repositories/test_cases/fetch_by_requirement.py +++ b/test2text/services/repositories/test_cases/fetch_by_requirement.py @@ -1,7 +1,9 @@ from test2text.services.db import DbClient -def fetch_test_cases_by_requirement(db: DbClient, requirement_id: int, radius: float, limit: int) -> list: +def fetch_test_cases_by_requirement( + db: DbClient, requirement_id: int, radius: float, limit: int +) -> list: sql = f""" SELECT Requirements.id as req_id, @@ -28,4 +30,4 @@ def fetch_test_cases_by_requirement(db: DbClient, requirement_id: int, radius: f Requirements.id, distance, TestCases.id LIMIT ? """ - return db.conn.execute(sql, (radius, requirement_id, limit)).fetchall() \ No newline at end of file + return db.conn.execute(sql, (radius, requirement_id, limit)).fetchall() From ad874bfae9ae8ca5cca41e00acb3897e6aaa10e7 Mon Sep 17 00:00:00 2001 From: Nikolai Dorofeev Date: Sun, 21 Sep 2025 16:46:51 +0200 Subject: [PATCH 4/8] [Req report] Fix string template --- .../services/repositories/test_cases/fetch_by_requirement.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test2text/services/repositories/test_cases/fetch_by_requirement.py b/test2text/services/repositories/test_cases/fetch_by_requirement.py index cc64244..894438c 100644 --- a/test2text/services/repositories/test_cases/fetch_by_requirement.py +++ b/test2text/services/repositories/test_cases/fetch_by_requirement.py @@ -4,7 +4,7 @@ def fetch_test_cases_by_requirement( db: DbClient, requirement_id: int, radius: float, limit: int ) -> list: - sql = f""" + sql = """ SELECT Requirements.id as req_id, Requirements.external_id as req_external_id, From 96b7888227c2704e6de428e9a3938948421b22cb Mon Sep 17 00:00:00 2001 From: Nikolai Dorofeev Date: Sun, 21 Sep 2025 17:34:54 +0200 Subject: [PATCH 5/8] [Req report] Display all annotations per requirement --- test2text/pages/reports/report_by_req.py | 264 +++++++++--------- test2text/services/db/tables/requirements.py | 18 ++ test2text/services/db/tables/test_case.py | 20 ++ .../repositories/annotations/__init__.py | 5 + .../annotations/fetch_by_test_case.py | 20 ++ .../test_cases/fetch_by_requirement.py | 17 +- .../test_db/test_tables/test_requirements.py | 7 + tests/test_db/test_tables/test_test_cases.py | 9 + 8 files changed, 210 insertions(+), 150 deletions(-) create mode 100644 test2text/services/repositories/annotations/__init__.py create mode 100644 test2text/services/repositories/annotations/fetch_by_test_case.py diff --git a/test2text/pages/reports/report_by_req.py b/test2text/pages/reports/report_by_req.py index 0567385..bf26519 100644 --- a/test2text/pages/reports/report_by_req.py +++ b/test2text/pages/reports/report_by_req.py @@ -3,138 +3,16 @@ import streamlit as st from test2text.services.utils.math_utils import round_distance -from test2text.services.repositories import requirements -from test2text.services.repositories import test_cases +from test2text.services.repositories import ( + requirements as requirements_repo, + test_cases as test_cases_repo, + annotations as annotations_repo, +) SUMMARY_LENGTH = 100 LABELS_SUMMARY_LENGTH = 15 -def display_found_details(data: list): - from test2text.services.utils import unpack_float32 - from test2text.services.visualisation.visualize_vectors import ( - minifold_vectors_2d, - plot_2_sets_in_one_2d, - minifold_vectors_3d, - plot_2_sets_in_one_3d, - ) - - def write_annotations(current_annotations: set[tuple]): - st.write("id,", "Summary,", "Distance") - for anno_id, anno_summary, _, distance in current_annotations: - st.write(anno_id, anno_summary, round_distance(distance)) - - for ( - req_id, - req_external_id, - req_summary, - req_embedding, - ), group in groupby(data, lambda x: x[0:4]): - st.divider() - with st.container(): - st.subheader(f" Inspect Requirement {req_external_id}") - st.write(req_summary) - current_test_cases = dict() - for ( - _, - _, - _, - _, - anno_id, - anno_summary, - anno_embedding, - distance, - case_id, - test_script, - test_case, - ) in group: - current_annotation = current_test_cases.get(test_case, set()) - current_test_cases.update({test_case: current_annotation}) - current_test_cases[test_case].add( - (anno_id, anno_summary, anno_embedding, distance) - ) - - t_cs, anno, viz = st.columns(3) - with t_cs: - with st.container(border=True): - st.write("Test Cases") - st.info("Test cases of chosen Requirement") - st.radio( - "Test cases name", - current_test_cases.keys(), - key="radio_choice", - ) - st.markdown( - """ - - """, - unsafe_allow_html=True, - ) - - if st.session_state["radio_choice"]: - with anno: - with st.container(border=True): - st.write("Annotations") - st.info("List of Annotations for chosen Test case") - write_annotations( - current_annotations=current_test_cases[ - st.session_state["radio_choice"] - ] - ) - with viz: - with st.container(border=True): - st.write("Visualization") - select = st.selectbox( - "Choose type of visualization", ["2D", "3D"] - ) - anno_embeddings = [ - unpack_float32(anno_emb) - for _, _, anno_emb, _ in current_test_cases[ - st.session_state["radio_choice"] - ] - ] - anno_labels = [ - f"{anno_id}" - for anno_id, _, _, _ in current_test_cases[ - st.session_state["radio_choice"] - ] - ] - requirement_vectors = np.array( - [np.array(unpack_float32(req_embedding))] - ) - annotation_vectors = np.array(anno_embeddings) - if select == "2D": - plot_2_sets_in_one_2d( - minifold_vectors_2d(requirement_vectors), - minifold_vectors_2d(annotation_vectors), - "Requirement", - "Annotations", - first_labels=[f"{req_external_id}"], - second_labels=anno_labels, - ) - else: - reqs_vectors_3d = minifold_vectors_3d( - requirement_vectors - ) - anno_vectors_3d = minifold_vectors_3d( - annotation_vectors - ) - plot_2_sets_in_one_3d( - reqs_vectors_3d, - anno_vectors_3d, - "Requirement", - "Annotations", - first_labels=[f"{req_external_id}"], - second_labels=anno_labels, - ) - - def make_a_report(): from test2text.services.db import get_db_client @@ -161,27 +39,28 @@ def make_a_report(): with st.container(border=True): st.session_state.update({"req_form_submitting": True}) - data = requirements.fetch_filtered_requirements( + requirements = requirements_repo.fetch_filtered_requirements( db, external_id=filter_id, text_content=filter_summary, smart_search_query=filter_embedding, ) - requirements_dict = { - req_id: f"{req_external_id} {summary[:SUMMARY_LENGTH]}..." - for (req_id, req_external_id, summary) in data + requirements = { + req_id: (req_external_id, summary) + for (req_id, req_external_id, summary) in requirements } st.subheader("Choose 1 of filtered requirements") selected_requirement = st.selectbox( "Choose a requirement to work with", - requirements_dict.keys(), + requirements.keys(), key="filter_req_id", - format_func=lambda x: requirements_dict[x], + format_func=lambda x: f"{requirements[x][0]} {requirements[x][1][:SUMMARY_LENGTH]}...", ) if selected_requirement: + requirement = db.requirements.get_by_id_raw(selected_requirement) st.subheader("Filter Test cases") with st.expander("🔍 Filters"): @@ -205,17 +84,130 @@ def make_a_report(): ) st.info("Limit of selected test cases") - rows = test_cases.fetch_test_cases_by_requirement( + test_cases = test_cases_repo.fetch_test_cases_by_requirement( db, selected_requirement, filter_radius, filter_limit ) + test_cases = { + tc_id: (test_script, test_case) + for (tc_id, test_script, test_case) in test_cases + } - if not rows: + if not test_cases: st.error( "There is no requested data to inspect.\n" "Please check filters, completeness of the data or upload new annotations and requirements." ) else: - display_found_details(rows) + from test2text.services.utils import unpack_float32 + from test2text.services.visualisation.visualize_vectors import ( + minifold_vectors_2d, + plot_2_sets_in_one_2d, + minifold_vectors_3d, + plot_2_sets_in_one_3d, + ) + + st.divider() + with st.container(): + st.subheader( + f" Inspect Requirement {requirements[selected_requirement][0]}" + ) + st.write(requirements[selected_requirement][1]) + + t_cs, anno, viz = st.columns(3) + with t_cs: + with st.container(border=True): + st.write("Test Cases") + st.info("Test cases of chosen Requirement") + st.radio( + "Test cases name", + test_cases.keys(), + key="chosen_test_case", + format_func=lambda tc_id: test_cases[tc_id][1], + ) + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + + if st.session_state["chosen_test_case"]: + test_case = db.test_cases.get_by_id_raw( + st.session_state["chosen_test_case"] + ) + annotations = annotations_repo.fetch_annotations_by_test_case_with_distance_to_requirement( + db, + st.session_state["chosen_test_case"], + requirement[3], # embedding + ) + with anno: + with st.container(border=True): + st.write("Annotations") + st.info( + "List of Annotations for chosen Test case" + ) + st.write("id,", "Summary,", "Distance") + for ( + anno_id, + anno_summary, + _, + distance, + ) in annotations: + st.write( + anno_id, + anno_summary, + round_distance(distance), + ) + with viz: + with st.container(border=True): + st.write("Visualization") + select = st.selectbox( + "Choose type of visualization", ["2D", "3D"] + ) + anno_embeddings = [ + unpack_float32(anno_emb) + for _, _, anno_emb, _ in annotations + ] + anno_labels = [ + f"{anno_id}" + for anno_id, _, _, _ in annotations + ] + requirement_vectors = np.array( + [np.array(unpack_float32(requirement[3]))] + ) + annotation_vectors = np.array(anno_embeddings) + if select == "2D": + plot_2_sets_in_one_2d( + minifold_vectors_2d( + requirement_vectors + ), + minifold_vectors_2d(annotation_vectors), + "Requirement", + "Annotations", + first_labels=[f"{requirement[1]}"], + second_labels=anno_labels, + ) + else: + reqs_vectors_3d = minifold_vectors_3d( + requirement_vectors + ) + anno_vectors_3d = minifold_vectors_3d( + annotation_vectors + ) + plot_2_sets_in_one_3d( + reqs_vectors_3d, + anno_vectors_3d, + "Requirement", + "Annotations", + first_labels=[f"{requirement[1]}"], + second_labels=anno_labels, + ) if __name__ == "__main__": diff --git a/test2text/services/db/tables/requirements.py b/test2text/services/db/tables/requirements.py index fb4cdd4..2012512 100644 --- a/test2text/services/db/tables/requirements.py +++ b/test2text/services/db/tables/requirements.py @@ -72,3 +72,21 @@ def count(self) -> int: """ cursor = self.connection.execute("SELECT COUNT(*) FROM Requirements") return cursor.fetchone()[0] + + def get_by_id_raw( + self, req_id: int + ) -> Optional[tuple[int, str, str, Optional[bytes]]]: + """ + Retrieves a requirement by its ID. + :param req_id: The ID of the requirement to retrieve. + :return: A tuple containing the requirement's ID, external ID, summary, and embedding, or None if not found. + """ + cursor = self.connection.execute( + """ + SELECT id, external_id, summary, embedding + FROM Requirements + WHERE id = ? + """, + (req_id,), + ) + return cursor.fetchone() diff --git a/test2text/services/db/tables/test_case.py b/test2text/services/db/tables/test_case.py index fd3c99c..2139f9f 100644 --- a/test2text/services/db/tables/test_case.py +++ b/test2text/services/db/tables/test_case.py @@ -93,3 +93,23 @@ def count(self) -> int: """ cursor = self.connection.execute("SELECT COUNT(*) FROM TestCases") return cursor.fetchone()[0] + + def get_by_id_raw( + self, case_id: int + ) -> Optional[tuple[int, str, str, Optional[bytes]]]: + """ + Fetches a test case by its ID. + :param case_id: The ID of the test case to fetch. + :return: A tuple containing the test case's ID, test script, test case, and embedding (if available), or None if not found. + """ + cursor = self.connection.execute( + """ + SELECT id, test_script, test_case, embedding + FROM TestCases + WHERE id = ? + """, + (case_id,), + ) + result = cursor.fetchone() + cursor.close() + return result diff --git a/test2text/services/repositories/annotations/__init__.py b/test2text/services/repositories/annotations/__init__.py new file mode 100644 index 0000000..47a4756 --- /dev/null +++ b/test2text/services/repositories/annotations/__init__.py @@ -0,0 +1,5 @@ +__all__ = ["fetch_annotations_by_test_case_with_distance_to_requirement"] + +from .fetch_by_test_case import ( + fetch_annotations_by_test_case_with_distance_to_requirement, +) diff --git a/test2text/services/repositories/annotations/fetch_by_test_case.py b/test2text/services/repositories/annotations/fetch_by_test_case.py new file mode 100644 index 0000000..bb59cf0 --- /dev/null +++ b/test2text/services/repositories/annotations/fetch_by_test_case.py @@ -0,0 +1,20 @@ +from test2text.services.db import DbClient + + +def fetch_annotations_by_test_case_with_distance_to_requirement( + db: DbClient, test_case_id: int, requirement_embedding: bytes +) -> list[tuple[int, str, bytes, float]]: + sql = """ + SELECT + Annotations.id as anno_id, + Annotations.summary as anno_summary, + Annotations.embedding as anno_embedding, + vec_distance_L2(?, Annotations.embedding) as distance + FROM + Annotations + JOIN CasesToAnnos ON Annotations.id = CasesToAnnos.annotation_id + WHERE CasesToAnnos.case_id = ? + ORDER BY + distance ASC + """ + return db.conn.execute(sql, (requirement_embedding, test_case_id)).fetchall() diff --git a/test2text/services/repositories/test_cases/fetch_by_requirement.py b/test2text/services/repositories/test_cases/fetch_by_requirement.py index 894438c..43d44cc 100644 --- a/test2text/services/repositories/test_cases/fetch_by_requirement.py +++ b/test2text/services/repositories/test_cases/fetch_by_requirement.py @@ -3,20 +3,9 @@ def fetch_test_cases_by_requirement( db: DbClient, requirement_id: int, radius: float, limit: int -) -> list: +) -> list[tuple[int, str, str]]: sql = """ - SELECT - Requirements.id as req_id, - Requirements.external_id as req_external_id, - Requirements.summary as req_summary, - Requirements.embedding as req_embedding, - - Annotations.id as anno_id, - Annotations.summary as anno_summary, - Annotations.embedding as anno_embedding, - - vec_distance_L2(Requirements.embedding, Annotations.embedding) as distance, - + SELECT DISTINCT TestCases.id as case_id, TestCases.test_script as test_script, TestCases.test_case as test_case @@ -27,7 +16,7 @@ def fetch_test_cases_by_requirement( JOIN TestCases ON TestCases.id = CasesToAnnos.case_id WHERE Requirements.id = ? ORDER BY - Requirements.id, distance, TestCases.id + Requirements.id, TestCases.id LIMIT ? """ return db.conn.execute(sql, (radius, requirement_id, limit)).fetchall() diff --git a/tests/test_db/test_tables/test_requirements.py b/tests/test_db/test_tables/test_requirements.py index 125e9ce..da42e8c 100644 --- a/tests/test_db/test_tables/test_requirements.py +++ b/tests/test_db/test_tables/test_requirements.py @@ -66,3 +66,10 @@ def test_count(self): self.db.requirements.insert("Test Requirement 8") count_after = self.db.requirements.count self.assertEqual(count_after, count_before + 1) + + def test_get_by_id_raw(self): + id1 = self.db.requirements.insert("Test Requirement 9") + requirement = self.db.requirements.get_by_id(id1) + self.assertIsNotNone(requirement) + self.assertEqual(requirement[0], id1) + self.assertEqual(requirement[1], "Test Requirement 9") diff --git a/tests/test_db/test_tables/test_test_cases.py b/tests/test_db/test_tables/test_test_cases.py index 9009dfa..4b92438 100644 --- a/tests/test_db/test_tables/test_test_cases.py +++ b/tests/test_db/test_tables/test_test_cases.py @@ -67,3 +67,12 @@ def test_count(self): self.db.test_cases.insert("Test Script 15", "Test Case 15") count_after = self.db.test_cases.count self.assertEqual(count_after, count_before + 1) + + def test_get_by_id_raw(self): + id1 = self.db.test_cases.insert("Test Script 16", "Test Case 16") + record = self.db.test_cases.get_by_id_raw(id1) + self.assertIsNotNone(record) + self.assertEqual(record[0], id1) + self.assertEqual(record[1], "Test Script 16") + self.assertEqual(record[2], "Test Case 16") + self.assertIsNone(record[3]) From 9b2a9e31f07440bd760bd7cf91a66c4489b2add7 Mon Sep 17 00:00:00 2001 From: Nikolai Dorofeev Date: Sun, 21 Sep 2025 17:37:45 +0200 Subject: [PATCH 6/8] [Req report] Fix tests --- tests/test_db/test_tables/test_requirements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_db/test_tables/test_requirements.py b/tests/test_db/test_tables/test_requirements.py index da42e8c..cc17489 100644 --- a/tests/test_db/test_tables/test_requirements.py +++ b/tests/test_db/test_tables/test_requirements.py @@ -69,7 +69,7 @@ def test_count(self): def test_get_by_id_raw(self): id1 = self.db.requirements.insert("Test Requirement 9") - requirement = self.db.requirements.get_by_id(id1) + requirement = self.db.requirements.get_by_id_raw(id1) self.assertIsNotNone(requirement) self.assertEqual(requirement[0], id1) - self.assertEqual(requirement[1], "Test Requirement 9") + self.assertEqual(requirement[2], "Test Requirement 9") From 0650d0f0358f6ff0310275356a8d30495b2a2904 Mon Sep 17 00:00:00 2001 From: Nikolai Dorofeev Date: Sun, 21 Sep 2025 17:40:37 +0200 Subject: [PATCH 7/8] [Req report] Lintfix --- test2text/pages/reports/report_by_req.py | 4 ---- .../services/repositories/requirements/fetch_filtered.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/test2text/pages/reports/report_by_req.py b/test2text/pages/reports/report_by_req.py index bf26519..4ece6e6 100644 --- a/test2text/pages/reports/report_by_req.py +++ b/test2text/pages/reports/report_by_req.py @@ -1,4 +1,3 @@ -from itertools import groupby import numpy as np import streamlit as st @@ -138,9 +137,6 @@ def make_a_report(): ) if st.session_state["chosen_test_case"]: - test_case = db.test_cases.get_by_id_raw( - st.session_state["chosen_test_case"] - ) annotations = annotations_repo.fetch_annotations_by_test_case_with_distance_to_requirement( db, st.session_state["chosen_test_case"], diff --git a/test2text/services/repositories/requirements/fetch_filtered.py b/test2text/services/repositories/requirements/fetch_filtered.py index cd9f9dd..c42084e 100644 --- a/test2text/services/repositories/requirements/fetch_filtered.py +++ b/test2text/services/repositories/requirements/fetch_filtered.py @@ -12,7 +12,7 @@ def fetch_filtered_requirements( text_content: Optional[str] = None, smart_search_query: Optional[str] = None, ) -> list[tuple[int, str, str]]: - sql = f""" + sql = """ SELECT Requirements.id as req_id, Requirements.external_id as req_external_id, From 80b23244d240697002e9d3a943bdc22a81ae837b Mon Sep 17 00:00:00 2001 From: Nikolai Dorofeev Date: Sun, 21 Sep 2025 19:45:34 +0200 Subject: [PATCH 8/8] [TC report] Fix TC report --- test2text/pages/reports/report_by_tc.py | 210 +++++++----------- test2text/services/db/client.py | 79 ------- .../repositories/annotations/__init__.py | 6 +- .../annotations/fetch_by_test_case.py | 16 ++ .../repositories/requirements/__init__.py | 4 + .../requirements/fetch_by_annotation.py | 22 ++ .../requirements/fetch_by_test_case.py | 24 ++ .../repositories/test_cases/__init__.py | 3 +- .../repositories/test_cases/fetch_filtered.py | 36 +++ 9 files changed, 194 insertions(+), 206 deletions(-) create mode 100644 test2text/services/repositories/requirements/fetch_by_annotation.py create mode 100644 test2text/services/repositories/requirements/fetch_by_test_case.py create mode 100644 test2text/services/repositories/test_cases/fetch_filtered.py diff --git a/test2text/pages/reports/report_by_tc.py b/test2text/pages/reports/report_by_tc.py index 94c5cb9..e0c3d81 100644 --- a/test2text/pages/reports/report_by_tc.py +++ b/test2text/pages/reports/report_by_tc.py @@ -1,9 +1,12 @@ -from itertools import groupby import numpy as np import streamlit as st -from sqlite_vec import serialize_float32 from test2text.services.utils.math_utils import round_distance +from test2text.services.repositories import ( + test_cases as tc_repo, + requirements as req_repo, + annotations as an_repo, +) SUMMARY_LENGTH = 100 @@ -13,7 +16,6 @@ def make_a_tc_report(): from test2text.services.db import get_db_client with get_db_client() as db: - from test2text.services.embeddings.embed import embed_requirement from test2text.services.utils import unpack_float32 from test2text.services.visualisation.visualize_vectors import ( minifold_vectors_2d, @@ -24,17 +26,6 @@ def make_a_tc_report(): st.header("Test2Text Report") - def write_requirements(current_requirements: set[tuple]): - st.write("External id,", "Summary,", "Distance") - for ( - _, - req_external_id, - req_summary, - _, - distance, - ) in current_requirements: - st.write(req_external_id, req_summary, round_distance(distance)) - with st.container(border=True): st.subheader("Filter test cases") with st.expander("🔍 Filters"): @@ -50,47 +41,25 @@ def write_requirements(current_requirements: set[tuple]): ) st.info("Search using embeddings") - where_clauses = [] - params = [] - - if filter_summary.strip(): - where_clauses.append("Testcases.test_case LIKE ?") - params.append(f"%{filter_summary.strip()}%") - - distance_sql = "" - distance_order_sql = "" - query_embedding_bytes = None - if filter_embedding.strip(): - query_embedding = embed_requirement(filter_embedding.strip()) - query_embedding_bytes = serialize_float32(query_embedding) - distance_sql = ", vec_distance_L2(embedding, ?) AS distance" - distance_order_sql = "distance ASC, " - with st.container(border=True): st.session_state.update({"tc_form_submitting": True}) - data = db.get_ordered_values_from_test_cases( - distance_sql, - where_clauses, - distance_order_sql, - params + [query_embedding_bytes] if distance_sql else params, + test_cases = tc_repo.fetch_filtered_test_cases( + db, text_content=filter_summary, smart_search_query=filter_embedding ) - if distance_sql: - tc_dict = { - f"{test_case} [smart search d={round_distance(distance)}]": tc_id - for (tc_id, _, test_case, distance) in data - } - else: - tc_dict = {test_case: tc_id for (tc_id, _, test_case) in data} + test_cases = { + tc_id: (test_script, test_case) + for tc_id, test_script, test_case in test_cases + } st.subheader("Choose ONE of filtered test cases") - option = st.selectbox( - "Choose a requirement to work with", tc_dict.keys(), key="filter_tc_id" + selected_test_case = st.selectbox( + "Choose a requirement to work with", + test_cases.keys(), + key="filter_tc_id", + format_func=lambda x: test_cases[x][1], ) - if option: - where_clauses.append("Testcases.id = ?") - params.append(tc_dict[option]) - + if selected_test_case: st.subheader("Filter Requirements") with st.expander("🔍 Filters"): @@ -114,81 +83,39 @@ def write_requirements(current_requirements: set[tuple]): ) st.info("Limit of selected requirements") - if filter_radius: - where_clauses.append("distance <= ?") - params.append(f"{filter_radius}") - - if filter_limit: - params.append(f"{filter_limit}") - - rows = db.join_all_tables_by_test_cases(where_clauses, params) + annotations = an_repo.fetch_annotations_by_test_case( + db, selected_test_case + ) + annotations_dict = { + anno_id: (anno_summary, anno_embedding) + for anno_id, anno_summary, anno_embedding in annotations + } - if not rows: + if not annotations_dict: st.error( "There is no requested data to inspect.\n" "Please check filters, completeness of the data or upload new annotations and requirements." ) - return None - - for (tc_id, test_script, test_case), group in groupby( - rows, lambda x: x[0:3] - ): + else: st.divider() with st.container(): - st.subheader(f"Inspect #{tc_id} Test case '{test_case}'") - st.write(f"From test script {test_script}") - current_annotations = dict() - for ( - _, - _, - _, - anno_id, - anno_summary, - anno_embedding, - distance, - req_id, - req_external_id, - req_summary, - req_embedding, - ) in group: - current_annotation = (anno_id, anno_summary, anno_embedding) - current_reqs = current_annotations.get( - current_annotation, set() - ) - current_annotations.update( - {current_annotation: current_reqs} - ) - current_annotations[current_annotation].add( - ( - req_id, - req_external_id, - req_summary, - req_embedding, - distance, - ) - ) + st.subheader( + f"Inspect #{selected_test_case} Test case '{test_cases[selected_test_case][1]}'" + ) + st.write( + f"From test script {test_cases[selected_test_case][0]}" + ) t_cs, anno, viz = st.columns(3) with t_cs: with st.container(border=True): st.write("Annotations") st.info("Annotations linked to chosen Test case") - reqs_by_anno = { - f"#{anno_id} {anno_summary}": ( - anno_id, - anno_summary, - anno_embedding, - ) - for ( - anno_id, - anno_summary, - anno_embedding, - ) in current_annotations.keys() - } - radio_choice = st.radio( + chosen_annotation = st.radio( "Annotation's id + summary", - reqs_by_anno.keys(), - key="radio_choice", + annotations_dict.keys(), + key="chosen_annotation", + format_func=lambda x: f"[{x}] {annotations_dict[x][0][:SUMMARY_LENGTH]}", ) st.markdown( """ @@ -203,18 +130,42 @@ def write_requirements(current_requirements: set[tuple]): unsafe_allow_html=True, ) - if radio_choice: + if chosen_annotation: + requirements = ( + req_repo.fetch_requirements_by_annotation( + db, + annotation_id=chosen_annotation, + radius=filter_radius, + limit=filter_limit, + ) + ) + reqs_dict = { + req_id: ( + req_external_id, + req_summary, + req_emb, + distance, + ) + for req_id, req_external_id, req_summary, req_emb, distance in requirements + } with anno: with st.container(border=True): st.write("Requirements") st.info( "Found Requirements for chosen annotation" ) - write_requirements( - current_annotations[ - reqs_by_anno[radio_choice] - ] - ) + st.write("External id,", "Summary,", "Distance") + for ( + req_external_id, + req_summary, + _, + distance, + ) in reqs_dict.values(): + st.write( + req_external_id, + req_summary, + round_distance(distance), + ) with viz: with st.container(border=True): st.write("Visualization") @@ -223,18 +174,27 @@ def write_requirements(current_requirements: set[tuple]): ) req_embeddings = [ unpack_float32(req_emb) - for _, _, _, req_emb, _ in current_annotations[ - reqs_by_anno[radio_choice] - ] + for _, _, req_emb, _ in reqs_dict.values() ] req_labels = [ - f"{ext_id}" - for _, ext_id, req_sum, _, _ in current_annotations[ - reqs_by_anno[radio_choice] - ] + req_ext_id or req_id + for req_id, ( + req_ext_id, + _, + _, + _, + ) in reqs_dict.items() ] annotation_vectors = np.array( - [np.array(unpack_float32(anno_embedding))] + [ + np.array( + unpack_float32( + annotations_dict[ + chosen_annotation + ][1] + ) + ) + ] ) requirement_vectors = np.array(req_embeddings) if select == "2D": @@ -245,7 +205,7 @@ def write_requirements(current_requirements: set[tuple]): ), first_title="Annotation", second_title="Requirements", - first_labels=radio_choice, + first_labels=chosen_annotation, second_labels=req_labels, ) else: @@ -260,7 +220,7 @@ def write_requirements(current_requirements: set[tuple]): reqs_vectors_3d, first_title="Annotation", second_title="Requirements", - first_labels=radio_choice, + first_labels=chosen_annotation, second_labels=req_labels, ) diff --git a/test2text/services/db/client.py b/test2text/services/db/client.py index bdacf5b..9b5f6c6 100644 --- a/test2text/services/db/client.py +++ b/test2text/services/db/client.py @@ -184,85 +184,6 @@ def get_embeddings_from_annotations_to_requirements_table(self): """) return cursor.fetchall() - def get_ordered_values_from_test_cases( - self, distance_sql="", where_clauses="", distance_order_sql="", params=None - ) -> list[tuple]: - """ - Extracted values from TestCases table based on the provided where clauses and specified parameters ordered by distance and id. - Return a list of tuples containing : - case_id, - test_script, - test_case, - distance between test case and typed by user text embeddings if it is specified, - """ - where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - sql = f""" - SELECT - TestCases.id as case_id, - TestCases.test_script as test_script, - TestCases.test_case as test_case - {distance_sql} - FROM - TestCases - {where_sql} - ORDER BY - {distance_order_sql}TestCases.id - """ - data = self.conn.execute(sql, params) - return data.fetchall() - - def join_all_tables_by_test_cases( - self, where_clauses="", params=None - ) -> list[tuple]: - """ - Join all tables related to test cases based on the provided where clauses and specified parameters. - Return a list of tuples containing : - case_id, - test_script, - test_case, - anno_id, - anno_summary, - anno_embedding, - distance between annotation and requirement embeddings, - req_id, - req_external_id, - req_summary, - req_embedding - """ - where_sql = "" - if where_clauses: - where_sql = f"WHERE {' AND '.join(where_clauses)}" - - sql = f""" - SELECT - TestCases.id as case_id, - TestCases.test_script as test_script, - TestCases.test_case as test_case, - - Annotations.id as anno_id, - Annotations.summary as anno_summary, - Annotations.embedding as anno_embedding, - - AnnotationsToRequirements.cached_distance as distance, - - Requirements.id as req_id, - Requirements.external_id as req_external_id, - Requirements.summary as req_summary, - Requirements.embedding as req_embedding - FROM - TestCases - JOIN CasesToAnnos ON TestCases.id = CasesToAnnos.case_id - JOIN Annotations ON Annotations.id = CasesToAnnos.annotation_id - JOIN AnnotationsToRequirements ON Annotations.id = AnnotationsToRequirements.annotation_id - JOIN Requirements ON Requirements.id = AnnotationsToRequirements.requirement_id - {where_sql} - ORDER BY - case_id, distance, req_id - LIMIT ? - """ - data = self.conn.execute(sql, params) - return data.fetchall() - def get_embeddings_by_id(self, id1: int, from_table: str) -> float: """ Returns the embedding of the specified id from the specified table. diff --git a/test2text/services/repositories/annotations/__init__.py b/test2text/services/repositories/annotations/__init__.py index 47a4756..596d1d9 100644 --- a/test2text/services/repositories/annotations/__init__.py +++ b/test2text/services/repositories/annotations/__init__.py @@ -1,5 +1,9 @@ -__all__ = ["fetch_annotations_by_test_case_with_distance_to_requirement"] +__all__ = [ + "fetch_annotations_by_test_case_with_distance_to_requirement", + "fetch_annotations_by_test_case", +] from .fetch_by_test_case import ( fetch_annotations_by_test_case_with_distance_to_requirement, + fetch_annotations_by_test_case, ) diff --git a/test2text/services/repositories/annotations/fetch_by_test_case.py b/test2text/services/repositories/annotations/fetch_by_test_case.py index bb59cf0..df139a9 100644 --- a/test2text/services/repositories/annotations/fetch_by_test_case.py +++ b/test2text/services/repositories/annotations/fetch_by_test_case.py @@ -1,6 +1,22 @@ from test2text.services.db import DbClient +def fetch_annotations_by_test_case( + db: DbClient, test_case_id: int +) -> list[tuple[int, str, bytes]]: + sql = """ + SELECT + Annotations.id as anno_id, + Annotations.summary as anno_summary, + Annotations.embedding as anno_embedding + FROM + Annotations + JOIN CasesToAnnos ON Annotations.id = CasesToAnnos.annotation_id + WHERE CasesToAnnos.case_id = ? + """ + return db.conn.execute(sql, (test_case_id,)).fetchall() + + def fetch_annotations_by_test_case_with_distance_to_requirement( db: DbClient, test_case_id: int, requirement_embedding: bytes ) -> list[tuple[int, str, bytes, float]]: diff --git a/test2text/services/repositories/requirements/__init__.py b/test2text/services/repositories/requirements/__init__.py index 69c1e5b..21160fe 100644 --- a/test2text/services/repositories/requirements/__init__.py +++ b/test2text/services/repositories/requirements/__init__.py @@ -1,4 +1,8 @@ __all__ = [ "fetch_filtered_requirements", + "fetch_requirements_by_test_case", + "fetch_requirements_by_annotation", ] from .fetch_filtered import fetch_filtered_requirements +from .fetch_by_test_case import fetch_requirements_by_test_case +from .fetch_by_annotation import fetch_requirements_by_annotation diff --git a/test2text/services/repositories/requirements/fetch_by_annotation.py b/test2text/services/repositories/requirements/fetch_by_annotation.py new file mode 100644 index 0000000..0b281b4 --- /dev/null +++ b/test2text/services/repositories/requirements/fetch_by_annotation.py @@ -0,0 +1,22 @@ +from test2text.services.db import DbClient + + +def fetch_requirements_by_annotation( + db: DbClient, *, annotation_id: int, radius: float, limit: int +) -> list[tuple[int, str, str, bytes, float]]: + sql = """ + SELECT DISTINCT + Requirements.id as req_id, + Requirements.external_id as req_external_id, + Requirements.summary as req_summary, + Requirements.embedding as req_embedding, + vec_distance_L2(Requirements.embedding, Annotations.embedding) as distance + FROM + Annotations + JOIN Requirements ON vec_distance_L2(Requirements.embedding, Annotations.embedding) <= ? + WHERE Annotations.id = ? + ORDER BY + distance + LIMIT ? + """ + return db.conn.execute(sql, (radius, annotation_id, limit)).fetchall() diff --git a/test2text/services/repositories/requirements/fetch_by_test_case.py b/test2text/services/repositories/requirements/fetch_by_test_case.py new file mode 100644 index 0000000..d695bb3 --- /dev/null +++ b/test2text/services/repositories/requirements/fetch_by_test_case.py @@ -0,0 +1,24 @@ +from test2text.services.db import DbClient + + +def fetch_requirements_by_test_case( + db: DbClient, *, test_case_id: int, radius: float, limit: int +) -> list[tuple[int, str, str]]: + sql = """ + SELECT DISTINCT + Requirements.id as req_id, + Requirements.external_id as req_external_id, + Requirements.summary as req_summary, + MIN(vec_distance_L2(Requirements.embedding, Annotations.embedding)) as min_distance + FROM + TestCases + JOIN CasesToAnnos ON TestCases.id = CasesToAnnos.case_id + JOIN Annotations ON Annotations.id = CasesToAnnos.annotation_id + JOIN Requirements ON vec_distance_L2(Requirements.embedding, Annotations.embedding) <= ? + WHERE TestCases.id = ? + GROUP BY Requirements.id + ORDER BY + min_distance + LIMIT ? + """ + return db.conn.execute(sql, (test_case_id, radius, limit)).fetchall() diff --git a/test2text/services/repositories/test_cases/__init__.py b/test2text/services/repositories/test_cases/__init__.py index 5075b63..2fa3d3a 100644 --- a/test2text/services/repositories/test_cases/__init__.py +++ b/test2text/services/repositories/test_cases/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["fetch_test_cases_by_requirement"] +__all__ = ["fetch_test_cases_by_requirement", "fetch_filtered_test_cases"] from .fetch_by_requirement import fetch_test_cases_by_requirement +from .fetch_filtered import fetch_filtered_test_cases diff --git a/test2text/services/repositories/test_cases/fetch_filtered.py b/test2text/services/repositories/test_cases/fetch_filtered.py new file mode 100644 index 0000000..d574814 --- /dev/null +++ b/test2text/services/repositories/test_cases/fetch_filtered.py @@ -0,0 +1,36 @@ +from typing import Optional + +from sqlite_vec import serialize_float32 + +from test2text.services.db import DbClient + + +def fetch_filtered_test_cases( + db: DbClient, + *_, + text_content: Optional[str] = None, + smart_search_query: Optional[str] = None, +) -> list[tuple[int, str, str]]: + sql = """ + SELECT + TestCases.id as case_id, + TestCases.test_script as test_script, + TestCases.test_case as test_case + FROM + TestCases + """ + params = [] + if text_content or smart_search_query: + sql += " WHERE " + conditions = [] + if text_content: + conditions.append("TestCases.test_case LIKE ?") + params.append(f"%{text_content.strip()}%") + if smart_search_query: + from test2text.services.embeddings.embed import embed_requirement + + embedding = embed_requirement(smart_search_query.strip()) + conditions.append("vec_distance_L2(TestCases.embedding, ?) < 0.7") + params.append(serialize_float32(embedding)) + sql += " AND ".join(conditions) + return db.conn.execute(sql, params).fetchall()