Skip to content

Commit ecc1ada

Browse files
authored
fix: target directories retrieve (#227)
* perf: reuse query embeddings in hierarchical retriever * fix(retrieve): honor target_directories filtering
1 parent a062199 commit ecc1ada

File tree

2 files changed

+86
-2
lines changed

2 files changed

+86
-2
lines changed

openviking/retrieve/hierarchical_retriever.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,22 @@ async def retrieve(
9898

9999
collection = self._type_to_collection(query.context_type)
100100

101+
target_dirs = [d for d in (query.target_directories or []) if d]
102+
101103
# Create context_type filter
102104
type_filter = {"op": "must", "field": "context_type", "conds": [query.context_type.value]}
103105

104106
# Merge all filters
105107
filters_to_merge = [type_filter]
108+
if target_dirs:
109+
target_filter = {
110+
"op": "or",
111+
"conds": [
112+
{"op": "prefix", "field": "uri", "prefix": target_dir}
113+
for target_dir in target_dirs
114+
],
115+
}
116+
filters_to_merge.append(target_filter)
106117
if metadata_filter:
107118
filters_to_merge.append(metadata_filter)
108119

@@ -124,8 +135,11 @@ async def retrieve(
124135
query_vector = result.dense_vector
125136
sparse_query_vector = result.sparse_vector
126137

127-
# Step 1: Determine starting directories based on context_type
128-
root_uris = self._get_root_uris_for_type(query.context_type)
138+
# Step 1: Determine starting directories based on target_directories or context_type
139+
if target_dirs:
140+
root_uris = target_dirs
141+
else:
142+
root_uris = self._get_root_uris_for_type(query.context_type)
129143

130144
# Step 2: Global vector search to supplement starting points
131145
global_results = await self._global_vector_search(
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Hierarchical retriever target_directories tests."""
5+
6+
import pytest
7+
8+
from openviking.retrieve.hierarchical_retriever import HierarchicalRetriever
9+
from openviking_cli.retrieve.types import ContextType, TypedQuery
10+
11+
12+
class DummyStorage:
13+
"""Minimal storage stub to capture search filters."""
14+
15+
def __init__(self) -> None:
16+
self.search_calls = []
17+
18+
async def collection_exists(self, _name: str) -> bool:
19+
return True
20+
21+
async def search(
22+
self,
23+
collection: str,
24+
query_vector=None,
25+
sparse_query_vector=None,
26+
filter=None,
27+
limit: int = 10,
28+
offset: int = 0,
29+
output_fields=None,
30+
with_vector: bool = False,
31+
):
32+
self.search_calls.append(
33+
{
34+
"collection": collection,
35+
"filter": filter,
36+
"limit": limit,
37+
"offset": offset,
38+
}
39+
)
40+
return []
41+
42+
43+
def _contains_prefix_filter(obj, prefix: str) -> bool:
44+
if isinstance(obj, dict):
45+
if obj.get("op") == "prefix" and obj.get("field") == "uri" and obj.get("prefix") == prefix:
46+
return True
47+
return any(_contains_prefix_filter(v, prefix) for v in obj.values())
48+
if isinstance(obj, list):
49+
return any(_contains_prefix_filter(v, prefix) for v in obj)
50+
return False
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_retrieve_honors_target_directories_prefix_filter():
55+
target_uri = "viking://resources/foo"
56+
storage = DummyStorage()
57+
retriever = HierarchicalRetriever(storage=storage, embedder=None, rerank_config=None)
58+
59+
query = TypedQuery(
60+
query="test",
61+
context_type=ContextType.RESOURCE,
62+
intent="",
63+
target_directories=[target_uri],
64+
)
65+
66+
result = await retriever.retrieve(query, limit=3)
67+
68+
assert result.searched_directories == [target_uri]
69+
assert storage.search_calls
70+
assert _contains_prefix_filter(storage.search_calls[0]["filter"], target_uri)

0 commit comments

Comments
 (0)