Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20260129-174421.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Include metric filters in lineage graphs
time: 2026-01-29T17:44:21.092837-05:00
custom:
Author: will.tremml
Issue: "12411"
14 changes: 14 additions & 0 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,10 @@ class Manifest(MacroMethods, dbtClassMixin):
default=None,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
_custom_granularity_names: Optional[Set[str]] = field(
default=None,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)

def __pre_serialize__(self, context: Optional[Dict] = None):
# serialization won't work with anything except an empty source_patches because
Expand Down Expand Up @@ -1017,6 +1021,16 @@ def build_flat_graph(self):
},
}

def get_custom_granularity_names(self) -> Set[str]:
if self._custom_granularity_names is None:
names: Set[str] = set()
for node in self.nodes.values():
if isinstance(node, ModelNode) and node.time_spine:
for custom_granularity in node.time_spine.custom_granularities:
names.add(custom_granularity.name)
self._custom_granularity_names = names
return self._custom_granularity_names

def build_disabled_by_file_id(self):
disabled_by_file_id = {}
for node_list in self.disabled.values():
Expand Down
121 changes: 120 additions & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass, field
from datetime import date, datetime, timezone
from itertools import chain
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union

import msgpack
from jinja2.nodes import Call
Expand All @@ -32,6 +32,8 @@
NodeRelation,
NodeVersion,
)
from dbt.artifacts.resources.v1.metric import MetricInputMeasure
from dbt.artifacts.resources.v1.semantic_layer_components import WhereFilterIntersection
from dbt.artifacts.resources.types import BatchSize
from dbt.artifacts.schemas.base import Writable
from dbt.clients.jinja import MacroStack, get_rendered
Expand Down Expand Up @@ -85,6 +87,7 @@
PartialParsingErrorProcessingFile,
PartialParsingNotEnabled,
PartialParsingSkipParsing,
SemanticValidationFailure,
SpacesInResourceNameDeprecation,
StateCheckVarsHash,
UnableToPartialParse,
Expand Down Expand Up @@ -135,6 +138,7 @@
from dbt_common.events.types import Note
from dbt_common.exceptions.base import DbtValidationError
from dbt_common.helper_types import PathSet
from dbt_semantic_interfaces.call_parameter_sets import ParseJinjaObjectException
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums import MetricType

Expand Down Expand Up @@ -2029,6 +2033,84 @@ def _process_multiple_metric_inputs(
metric.depends_on.add_node(target_metric.unique_id)


def _maybe_append_where_filter(
    where_filters: List[WhereFilterIntersection],
    filter_value: Optional[WhereFilterIntersection],
) -> None:
    """Append ``filter_value`` to ``where_filters`` when it is set (truthy)."""
    if not filter_value:
        return
    where_filters.append(filter_value)


def _maybe_append_metric_input_filter(
    where_filters: List[WhereFilterIntersection],
    metric_input: Optional[MetricInput],
) -> None:
    """Append the filter of ``metric_input`` to ``where_filters``, if both exist."""
    if metric_input is None:
        return
    if metric_input.filter is None:
        return
    where_filters.append(metric_input.filter)


def _maybe_append_metric_input_measure_filter(
    where_filters: List[WhereFilterIntersection],
    metric_input_measure: Optional[MetricInputMeasure],
) -> None:
    """Append the filter of ``metric_input_measure`` to ``where_filters``, if both exist."""
    if metric_input_measure is None:
        return
    if metric_input_measure.filter is None:
        return
    where_filters.append(metric_input_measure.filter)


def _collect_metric_where_filters(metric: Metric) -> List[WhereFilterIntersection]:
    """Gather every where-filter attached to ``metric``.

    Walks the metric's own filter plus each place a filter can be nested in
    its type params: the measure and input measures, ratio numerator and
    denominator, cumulative and conversion sub-metrics, and all input
    metrics. Filters are returned in that traversal order.
    """
    collected: List[WhereFilterIntersection] = []
    type_params = metric.type_params

    _maybe_append_where_filter(collected, metric.filter)
    _maybe_append_metric_input_measure_filter(collected, type_params.measure)
    for measure in type_params.input_measures:
        _maybe_append_metric_input_measure_filter(collected, measure)
    _maybe_append_metric_input_filter(collected, type_params.numerator)
    _maybe_append_metric_input_filter(collected, type_params.denominator)

    cumulative = type_params.cumulative_type_params
    if cumulative is not None:
        _maybe_append_metric_input_filter(collected, cumulative.metric)

    conversion = type_params.conversion_type_params
    if conversion is not None:
        _maybe_append_metric_input_filter(collected, conversion.base_metric)
        _maybe_append_metric_input_filter(collected, conversion.conversion_metric)

    for metric_input in metric.input_metrics:
        _maybe_append_metric_input_filter(collected, metric_input)

    return collected


def _metric_dependency_names_from_filters(
    manifest: Manifest,
    where_filters: Sequence[WhereFilterIntersection],
    node: Union[Metric, SavedQuery],
) -> Set[str]:
    """Extract the names of metrics referenced inside ``where_filters``.

    Each filter clause is parsed by dbt-semantic-interfaces to find
    ``Metric(...)`` references. A clause that fails to parse only emits a
    warning event (naming ``node``) and is skipped — parsing problems are
    surfaced, not fatal, at this stage.
    """
    referenced: Set[str] = set()
    if not where_filters:
        return referenced

    granularity_names = manifest.get_custom_granularity_names()
    for intersection in where_filters:
        for where_filter in intersection.where_filters:
            try:
                parameter_sets = where_filter.call_parameter_sets(
                    custom_granularity_names=granularity_names
                )
            except ParseJinjaObjectException as exc:
                fire_event(
                    SemanticValidationFailure(
                        msg=f"Unable to parse semantic filter on {node.unique_id}: {exc}"
                    ),
                    EventLevel.WARN,
                )
                continue
            referenced.update(
                call.metric_reference.element_name
                for call in parameter_sets.metric_call_parameter_sets
            )
    return referenced


def _metric_inputs_from_filters(manifest: Manifest, metric: Metric) -> List[MetricInput]:
    """Build MetricInputs for every metric referenced in ``metric``'s filters.

    Names are sorted so the resulting dependency list is deterministic
    across parses.
    """
    referenced_names = _metric_dependency_names_from_filters(
        manifest, _collect_metric_where_filters(metric), metric
    )
    return [MetricInput(name=metric_name) for metric_name in sorted(referenced_names)]


def _process_metric_node(
manifest: Manifest,
current_project: str,
Expand Down Expand Up @@ -2177,6 +2259,15 @@ def _process_metric_node(
else:
assert_values_exhausted(metric.type)

filter_metric_inputs = _metric_inputs_from_filters(manifest, metric)
if filter_metric_inputs:
_process_multiple_metric_inputs(
manifest=manifest,
current_project=current_project,
metric=metric,
metric_inputs=filter_metric_inputs,
)


def _process_metrics_for_node(
manifest: Manifest,
Expand Down Expand Up @@ -2231,6 +2322,34 @@ def _process_metrics_for_node(

node.depends_on.add_node(target_metric_id)

if isinstance(node, SavedQuery) and node.query_params.where is not None:
referenced_metric_names = _metric_dependency_names_from_filters(
manifest, [node.query_params.where], node
)
existing_metric_names = set(node.metrics)
for metric_name in sorted(referenced_metric_names):
if metric_name in existing_metric_names:
# Already processed above via explicit metrics list.
continue
target_metric = manifest.resolve_metric(
metric_name,
None,
current_project,
node.package_name,
)

if target_metric is None or isinstance(target_metric, Disabled):
node.config.enabled = False
invalid_target_fail_unless_test(
node=node,
target_name=metric_name,
target_kind="metric",
disabled=(isinstance(target_metric, Disabled)),
)
continue

node.depends_on.add_node(target_metric.unique_id)


def remove_dependent_project_references(manifest, external_node_unique_id):
for child_id in manifest.child_map[external_node_unique_id]:
Expand Down
25 changes: 25 additions & 0 deletions tests/functional/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,31 @@ def test_filter_parsing(
assert filters6[0].where_sql_template == "{{ Dimension('id__loves_dbt') }} is true"


class TestMetricFilterLineage:
    """Metrics referenced only inside another metric's filter still appear in lineage."""

    @pytest.fixture(scope="class")
    def models(self):
        # Minimal project: a time spine, a people model + semantic model,
        # and metrics whose filters reference other metrics.
        return {
            "basic_metrics.yml": basic_metrics_yml,
            "metricflow_time_spine.sql": metricflow_time_spine_sql,
            "semantic_model_people.yml": semantic_model_people_yml,
            "people.sql": models_people_sql,
        }

    def test_metric_filter_lineage(self, project):
        parse_result = dbtRunner().invoke(["parse"])
        assert parse_result.success

        manifest = get_manifest(project.project_root)
        manifest.build_parent_and_child_maps()

        tenured_metric = manifest.metrics["metric.test.tenured_people"]
        # The filter-referenced metric shows up as an upstream dependency...
        assert "metric.test.collective_tenure" in tenured_metric.depends_on.nodes
        # ...and the reverse edge exists in the child map.
        assert (
            "metric.test.tenured_people"
            in manifest.child_map["metric.test.collective_tenure"]
        )


class TestDuplicateInputMeasures:
@pytest.fixture(scope="class")
def models(self):
Expand Down
10 changes: 8 additions & 2 deletions tests/functional/saved_queries/test_saved_query_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ def test_semantic_model_parsing(self, project):
assert len(saved_query.query_params.metrics) == 1
assert len(saved_query.query_params.group_by) == 1
assert len(saved_query.query_params.where.where_filters) == 3
assert len(saved_query.depends_on.nodes) == 1
assert set(saved_query.depends_on.nodes) == {
"metric.test.simple_metric",
"metric.test.txn_revenue",
}

assert len(saved_query.query_params.order_by) == 2
assert saved_query.query_params.limit is not None
Expand Down Expand Up @@ -163,7 +166,10 @@ def test_semantic_model_parsing_with_default_schema(self, project, other_schema)
assert len(saved_query.query_params.metrics) == 1
assert len(saved_query.query_params.group_by) == 1
assert len(saved_query.query_params.where.where_filters) == 3
assert len(saved_query.depends_on.nodes) == 1
assert set(saved_query.depends_on.nodes) == {
"metric.test.simple_metric",
"metric.test.txn_revenue",
}
assert saved_query.description == "My SavedQuery Description"
assert len(saved_query.exports) == 1
assert saved_query.exports[0].name == "my_export"
Expand Down