Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 65 additions & 118 deletions test2text/pages/reports/report_by_req.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from itertools import groupby
import numpy as np
import streamlit as st
from sqlite_vec import serialize_float32

from test2text.services.utils.math_utils import round_distance
from test2text.services.repositories import (
requirements as requirements_repo,
test_cases as test_cases_repo,
annotations as annotations_repo,
)

SUMMARY_LENGTH = 100
LABELS_SUMMARY_LENGTH = 15
Expand All @@ -13,22 +16,8 @@ def make_a_report():
from test2text.services.db import get_db_client

with get_db_client() as db:
from test2text.services.embeddings.embed import embed_requirement
from test2text.services.utils import unpack_float32
from test2text.services.visualisation.visualize_vectors import (
minifold_vectors_2d,
plot_2_sets_in_one_2d,
minifold_vectors_3d,
plot_2_sets_in_one_3d,
)

st.header("Test2Text Report")

def write_annotations(current_annotations: set[tuple]):
st.write("id,", "Summary,", "Distance")
for anno_id, anno_summary, _, distance in current_annotations:
st.write(anno_id, anno_summary, round_distance(distance))

with st.container(border=True):
st.subheader("Filter requirements")
with st.expander("🔍 Filters"):
Expand All @@ -47,62 +36,30 @@ def write_annotations(current_annotations: set[tuple]):
)
st.info("Search using embeddings")

where_clauses = []
params = []

if filter_id.strip():
where_clauses.append("Requirements.id = ?")
params.append(filter_id.strip())

if filter_summary.strip():
where_clauses.append("Requirements.summary LIKE ?")
params.append(f"%{filter_summary.strip()}%")

distance_sql = ""
distance_order_sql = ""
query_embedding_bytes = None
if filter_embedding.strip():
query_embedding = embed_requirement(filter_embedding.strip())
query_embedding_bytes = serialize_float32(query_embedding)
distance_sql = ", vec_distance_L2(embedding, ?) AS distance"
distance_order_sql = "distance ASC, "

with st.container(border=True):
st.session_state.update({"req_form_submitting": True})
data = db.get_ordered_values_from_requirements(
distance_sql,
where_clauses,
distance_order_sql,
params + [query_embedding_bytes] if distance_sql else params,
requirements = requirements_repo.fetch_filtered_requirements(
db,
external_id=filter_id,
text_content=filter_summary,
smart_search_query=filter_embedding,
)

if distance_sql:
requirements_dict = {
f"{req_external_id} {summary[:SUMMARY_LENGTH]}... [smart search d={round_distance(distance)}]": req_id
for (req_id, req_external_id, summary, distance) in data
}
else:
requirements_dict = {
f"{req_external_id} {summary[:SUMMARY_LENGTH]}...": req_id
for (req_id, req_external_id, summary) in data
}
requirements = {
req_id: (req_external_id, summary)
for (req_id, req_external_id, summary) in requirements
}

st.subheader("Choose 1 of filtered requirements")
option = st.selectbox(
selected_requirement = st.selectbox(
"Choose a requirement to work with",
requirements_dict.keys(),
requirements.keys(),
key="filter_req_id",
format_func=lambda x: f"{requirements[x][0]} {requirements[x][1][:SUMMARY_LENGTH]}...",
)

if option:
clause = "Requirements.id = ?"
if clause in where_clauses:
idx = where_clauses.index(clause)
params.insert(idx, requirements_dict[option])
else:
where_clauses.append(clause)
params.append(requirements_dict[option])

if selected_requirement:
requirement = db.requirements.get_by_id_raw(selected_requirement)
st.subheader("Filter Test cases")

with st.expander("🔍 Filters"):
Expand All @@ -126,53 +83,34 @@ def write_annotations(current_annotations: set[tuple]):
)
st.info("Limit of selected test cases")

if filter_radius:
where_clauses.append("distance <= ?")
params.append(f"{filter_radius}")

if filter_limit:
params.append(f"{filter_limit}")

rows = db.join_all_tables_by_requirements(where_clauses, params)
test_cases = test_cases_repo.fetch_test_cases_by_requirement(
db, selected_requirement, filter_radius, filter_limit
)
test_cases = {
tc_id: (test_script, test_case)
for (tc_id, test_script, test_case) in test_cases
}

if not rows:
if not test_cases:
st.error(
"There is no requested data to inspect.\n"
"Please check filters, completeness of the data or upload new annotations and requirements."
)
return None
else:
from test2text.services.utils import unpack_float32
from test2text.services.visualisation.visualize_vectors import (
minifold_vectors_2d,
plot_2_sets_in_one_2d,
minifold_vectors_3d,
plot_2_sets_in_one_3d,
)

for (
req_id,
req_external_id,
req_summary,
req_embedding,
), group in groupby(rows, lambda x: x[0:4]):
st.divider()
with st.container():
st.subheader(f" Inspect Requirement {req_external_id}")
st.write(req_summary)
current_test_cases = dict()
for (
_,
_,
_,
_,
anno_id,
anno_summary,
anno_embedding,
distance,
case_id,
test_script,
test_case,
) in group:
current_annotation = current_test_cases.get(
test_case, set()
)
current_test_cases.update({test_case: current_annotation})
current_test_cases[test_case].add(
(anno_id, anno_summary, anno_embedding, distance)
)
st.subheader(
f" Inspect Requirement {requirements[selected_requirement][0]}"
)
st.write(requirements[selected_requirement][1])

t_cs, anno, viz = st.columns(3)
with t_cs:
Expand All @@ -181,8 +119,9 @@ def write_annotations(current_annotations: set[tuple]):
st.info("Test cases of chosen Requirement")
st.radio(
"Test cases name",
current_test_cases.keys(),
key="radio_choice",
test_cases.keys(),
key="chosen_test_case",
format_func=lambda tc_id: test_cases[tc_id][1],
)
st.markdown(
"""
Expand All @@ -197,18 +136,30 @@ def write_annotations(current_annotations: set[tuple]):
unsafe_allow_html=True,
)

if st.session_state["radio_choice"]:
if st.session_state["chosen_test_case"]:
annotations = annotations_repo.fetch_annotations_by_test_case_with_distance_to_requirement(
db,
st.session_state["chosen_test_case"],
requirement[3], # embedding
)
with anno:
with st.container(border=True):
st.write("Annotations")
st.info(
"List of Annotations for chosen Test case"
)
write_annotations(
current_annotations=current_test_cases[
st.session_state["radio_choice"]
]
)
st.write("id,", "Summary,", "Distance")
for (
anno_id,
anno_summary,
_,
distance,
) in annotations:
st.write(
anno_id,
anno_summary,
round_distance(distance),
)
with viz:
with st.container(border=True):
st.write("Visualization")
Expand All @@ -217,18 +168,14 @@ def write_annotations(current_annotations: set[tuple]):
)
anno_embeddings = [
unpack_float32(anno_emb)
for _, _, anno_emb, _ in current_test_cases[
st.session_state["radio_choice"]
]
for _, _, anno_emb, _ in annotations
]
anno_labels = [
f"{anno_id}"
for anno_id, _, _, _ in current_test_cases[
st.session_state["radio_choice"]
]
for anno_id, _, _, _ in annotations
]
requirement_vectors = np.array(
[np.array(unpack_float32(req_embedding))]
[np.array(unpack_float32(requirement[3]))]
)
annotation_vectors = np.array(anno_embeddings)
if select == "2D":
Expand All @@ -239,7 +186,7 @@ def write_annotations(current_annotations: set[tuple]):
minifold_vectors_2d(annotation_vectors),
"Requirement",
"Annotations",
first_labels=[f"{req_external_id}"],
first_labels=[f"{requirement[1]}"],
second_labels=anno_labels,
)
else:
Expand All @@ -254,7 +201,7 @@ def write_annotations(current_annotations: set[tuple]):
anno_vectors_3d,
"Requirement",
"Annotations",
first_labels=[f"{req_external_id}"],
first_labels=[f"{requirement[1]}"],
second_labels=anno_labels,
)

Expand Down
Loading