Skip to content

Commit 401fb38

Browse files
committed
Optionally detect parameter collisions - fix #2566
1 parent 62d1672 commit 401fb38

File tree

3 files changed

+148
-6
lines changed

3 files changed

+148
-6
lines changed

doc/configuration.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,16 @@ check_complete_on_run
356356
missing.
357357
Defaults to false.
358358

359+
prevent_parameter_collision
360+
In complex pipelines, especially when tasks are inherited, it can happen that
361+
different tasks define parameters with the same name. Luigi would normally use
362+
the same value for both parameter instances, which might not be desired.
363+
When set to ``true``, luigi will check for parameter collisions and refuse to
364+
run if a parameter is defined multiple times. Optionally, an allow-list of
365+
parameters called ``collisions_to_ignore`` can be passed to ``inherits/requires``,
366+
to be ignored when checking for duplicate parameters.
367+
Defaults to false.
368+
359369

360370
[elasticsearch]
361371
---------------

luigi/util.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,10 @@ class TaskB(luigi.Task):
219219

220220
import datetime
221221
import logging
222+
from configparser import NoOptionError, NoSectionError
222223

223-
from luigi import task
224-
from luigi import parameter
224+
from luigi import parameter, task
225+
from luigi.configuration import get_config
225226

226227

227228
logger = logging.getLogger('luigi-interface')
@@ -277,18 +278,36 @@ def requires(self):
277278
def run(self):
278279
print self.n # this will be defined
279280
# ...
281+
282+
inherits/requires decorator optionally takes an argument called
283+
`collisions_to_ignore` with an iterable of parameters that are
284+
allowed to overwrite parameters in upstream tasks.
285+
In complex pipelines, it can happen that different tasks define parameters
286+
with the same name.
287+
If `prevent_parameter_collision` in the `[worker]` section of the config
288+
is true, luigi will raise an exception in case of parameter conflicts -
289+
unless the parameter is explicitly allowed in `collisions_to_ignore`.
280290
"""
281291

282-
def __init__(self, *tasks_to_inherit, **kw_tasks_to_inherit):
292+
def __init__(self, *tasks_to_inherit, collisions_to_ignore=(), **kw_tasks_to_inherit):
    """
    Record which tasks the decorated class should take its parameters from.

    Tasks are supplied either positionally or by keyword, never both.

    :param tasks_to_inherit: tasks given positionally.
    :param collisions_to_ignore: iterable of parameter names that are
        exempt from the parameter-collision check.
    :param kw_tasks_to_inherit: tasks given by keyword.
    :raises TypeError: if no task is given, or if positional and keyword
        tasks are mixed.
    """
    super(inherits, self).__init__()

    has_positional = bool(tasks_to_inherit)
    has_keyword = bool(kw_tasks_to_inherit)
    if not (has_positional or has_keyword):
        raise TypeError("tasks_to_inherit or kw_tasks_to_inherit must contain at least one task")
    if has_positional and has_keyword:
        raise TypeError("Only one of tasks_to_inherit or kw_tasks_to_inherit may be present")

    # Stored untouched; they are consumed when the decorator is applied.
    self.collisions_to_ignore = collisions_to_ignore
    self.kw_tasks_to_inherit = kw_tasks_to_inherit
    self.tasks_to_inherit = tasks_to_inherit
290306

291307
def __call__(self, task_that_inherits):
308+
# Check for parameter collisions and raise an exception if found
309+
self._check_for_parameter_collisions(task_that_inherits)
310+
292311
# Get all parameter objects from each of the underlying tasks
293312
task_iterator = self.tasks_to_inherit or self.kw_tasks_to_inherit.values()
294313
for task_to_inherit in task_iterator:
@@ -323,6 +342,63 @@ def clone_parents(_self, **kwargs):
323342

324343
return task_that_inherits
325344

345+
def _check_for_parameter_collisions(self, task_that_inherits):
    """
    Check that parameters of the inherited tasks are not silently masked,
    either by one another or by attributes of the inheriting task.

    The check only runs when ``prevent_parameter_collision`` in the
    ``[worker]`` config section is true; a missing section or option
    leaves it disabled. Tasks passed positionally and tasks passed by
    keyword are both examined.

    Collisions can be ignored by passing `collisions_to_ignore` with
    an iterable of allowed parameter names to `inherits/requires`.

    :param task_that_inherits: the task class being decorated.
    :raises ValueError: as soon as the first parameter collision that is
        not listed in `collisions_to_ignore` is encountered.
    """
    # only check for parameter collisions when enabled in config
    config = get_config()
    try:
        enabled = config.getboolean("worker", "prevent_parameter_collision")
    except (NoSectionError, NoOptionError, KeyError):
        return
    # Identity test kept on purpose: only a real boolean True enables the check.
    if enabled is not True:
        return

    # Materialise the allow-list once per call so membership tests are exact:
    # a bare string would otherwise match substrings, and a one-shot iterator
    # would be exhausted by the first ``in`` test.
    ignored = self.collisions_to_ignore
    if isinstance(ignored, str):
        ignored = (ignored,)
    ignored = frozenset(ignored)

    # Use the same selection as __call__ so that tasks passed by keyword
    # are checked too, not only the positional ones.
    tasks = tuple(self.tasks_to_inherit or self.kw_tasks_to_inherit.values())

    error_msg = (
        'Parameter "{param}" in "{task}" is duplicated in "{another_task}" '
        "(or an ancestor). Either rename one of the parameters or include "
        '"{param}" in `collisions_to_ignore`.'
    )

    for task_to_inherit in tasks:
        for param_name, _ in task_to_inherit.get_params():
            if param_name in ignored:
                continue

            # Check that the inheriting task doesn't mask a parameter
            # from the inherited task.
            if hasattr(task_that_inherits, param_name):
                raise ValueError(
                    error_msg.format(
                        param=param_name,
                        task=task_that_inherits.task_family,
                        another_task=task_to_inherit.task_family,
                    )
                )

            # Check that a parameter from one inherited task doesn't mask
            # the same-named parameter from another inherited task.
            for another_task_to_inherit in tasks:
                if another_task_to_inherit is task_to_inherit:
                    continue
                if hasattr(another_task_to_inherit, param_name):
                    raise ValueError(
                        error_msg.format(
                            param=param_name,
                            task=task_to_inherit.task_family,
                            another_task=another_task_to_inherit.task_family,
                        )
                    )
401+
326402

327403
class requires:
328404
"""
@@ -332,14 +408,21 @@ class requires:
332408
333409
"""
334410

335-
def __init__(self, *tasks_to_require, **kw_tasks_to_require):
411+
def __init__(
412+
self, *tasks_to_require, collisions_to_ignore=(), **kw_tasks_to_require
413+
):
336414
super(requires, self).__init__()
337415

338416
self.tasks_to_require = tasks_to_require
339417
self.kw_tasks_to_require = kw_tasks_to_require
418+
self.collisions_to_ignore = collisions_to_ignore
340419

341420
def __call__(self, task_that_requires):
342-
task_that_requires = inherits(*self.tasks_to_require, **self.kw_tasks_to_require)(task_that_requires)
421+
task_that_requires = inherits(
422+
*self.tasks_to_require,
423+
collisions_to_ignore=self.collisions_to_ignore,
424+
**self.kw_tasks_to_require,
425+
)(task_that_requires)
343426

344427
# Modify task_that_requires by adding requires method.
345428
# If only one task is required, this single task is returned.
@@ -387,7 +470,7 @@ def run(_self):
387470

388471

389472
def delegates(task_that_delegates):
390-
""" Lets a task call methods on subtask(s).
473+
"""Lets a task call methods on subtask(s).
391474
392475
The way this works is that the subtask is run as a part of the task, but
393476
the task itself doesn't have to care about the requirements of the subtasks.

test/parameter_collision_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import unittest
2+
3+
import luigi
4+
from luigi.util import requires
5+
6+
from helpers import with_config
7+
8+
9+
class A(luigi.Task):
    """Fixture task declaring an integer parameter named ``num``."""

    num = luigi.IntParameter()
11+
12+
13+
class B(luigi.Task):
    """Second fixture task that also declares an integer parameter named ``num``."""

    num = luigi.IntParameter()
15+
16+
17+
class ParameterCollisionDetectionTest(unittest.TestCase):
    """Tests for the optional parameter-collision check applied by ``requires``."""

    @with_config({"worker": {"prevent_parameter_collision": "true"}})
    def test_parameter_collision_with_inherited_task(self):
        # The requiring task redefines ``num``, which the required task
        # A already declares, so decorating must fail.
        def declare_task():
            @requires(A)
            class T(luigi.Task):
                num = luigi.IntParameter()

        self.assertRaises(ValueError, declare_task)

    @with_config({"worker": {"prevent_parameter_collision": "true"}})
    def test_parameter_collision_in_inheriting_tasks(self):
        # A and B both declare ``num``; requiring both must fail.
        def declare_task():
            @requires(A, B)
            class T(luigi.Task):
                pass

        self.assertRaises(ValueError, declare_task)

    def test_no_parameter_collision_when_disabled_in_config(self):
        # No config override here: the declaration is expected to go through.
        @requires(A, B)
        class T(luigi.Task):
            pass

    @with_config({"worker": {"prevent_parameter_collision": "true"}})
    def test_parameter_collision_with_inherited_task_ignored_by_allowlist(self):
        # ``num`` is allow-listed, so redefining it is accepted.
        @requires(A, collisions_to_ignore=["num"])
        class T(luigi.Task):
            num = luigi.IntParameter()

    @with_config({"worker": {"prevent_parameter_collision": "true"}})
    def test_parameter_collision_in_inheriting_tasks_ignored_by_allowlist(self):
        # ``num`` is allow-listed, so the A/B overlap is accepted.
        @requires(A, B, collisions_to_ignore=["num"])
        class T(luigi.Task):
            pass

0 commit comments

Comments
 (0)