Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions questionary/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict
from typing import NamedTuple
from typing import Sequence
from typing import Union

from questionary.constants import DEFAULT_KBI_MESSAGE
from questionary.question import Question
Expand Down Expand Up @@ -77,7 +78,9 @@ async def unsafe_ask_async(self, patch_stdout: bool = False) -> Dict[str, Any]:
}

def ask(
self, patch_stdout: bool = False, kbi_msg: str = DEFAULT_KBI_MESSAGE
self,
patch_stdout: bool = False,
kbi_msg: Union[str, None] = DEFAULT_KBI_MESSAGE,
) -> Dict[str, Any]:
"""Ask the questions synchronously and return user response.

Expand All @@ -93,11 +96,14 @@ def ask(
try:
return self.unsafe_ask(patch_stdout)
except KeyboardInterrupt:
print(kbi_msg)
if kbi_msg is not None:
print(kbi_msg)
return {}

async def ask_async(
self, patch_stdout: bool = False, kbi_msg: str = DEFAULT_KBI_MESSAGE
self,
patch_stdout: bool = False,
kbi_msg: Union[str, None] = DEFAULT_KBI_MESSAGE,
) -> Dict[str, Any]:
"""Ask the questions using asyncio and return user response.

Expand All @@ -113,5 +119,6 @@ async def ask_async(
try:
return await self.unsafe_ask_async(patch_stdout)
except KeyboardInterrupt:
print(kbi_msg)
if kbi_msg is not None:
print(kbi_msg)
return {}
17 changes: 10 additions & 7 deletions questionary/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def parse_question_config(
return None
except Exception as exception:
raise ValueError(
f"Problem in 'when' check of " f"{name} question: {exception}"
f"Problem in 'when' check of {name} question: {exception}"
) from exception
else:
raise ValueError("'when' needs to be function that accepts a dict argument")
Expand Down Expand Up @@ -124,8 +124,7 @@ def on_answer(answer):
answer = _filter(answer)
except Exception as exception:
raise ValueError(
f"Problem processing 'filter' of {name} "
f"question: {exception}"
f"Problem processing 'filter' of {name} question: {exception}"
) from exception
answers[name] = answer

Expand All @@ -137,7 +136,7 @@ async def prompt_async(
answers: Optional[Mapping[str, Any]] = None,
patch_stdout: bool = False,
true_color: bool = False,
kbi_msg: str = DEFAULT_KBI_MESSAGE,
kbi_msg: Union[str, None] = DEFAULT_KBI_MESSAGE,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prompt the user for input on all the questions using asyncio.
Expand Down Expand Up @@ -168,6 +167,7 @@ async def prompt_async(
are printing to stdout.

kbi_msg: The message to be printed on a keyboard interrupt.

true_color: Use true color output.

color_depth: Color depth to use. If ``true_color`` is set to true then this
Expand All @@ -189,7 +189,8 @@ async def prompt_async(
questions, answers, patch_stdout, true_color, **kwargs
)
except KeyboardInterrupt:
print(kbi_msg)
if kbi_msg is not None:
print(kbi_msg)
return {}


Expand All @@ -198,7 +199,7 @@ def prompt(
answers: Optional[Mapping[str, Any]] = None,
patch_stdout: bool = False,
true_color: bool = False,
kbi_msg: str = DEFAULT_KBI_MESSAGE,
kbi_msg: Union[str, None] = DEFAULT_KBI_MESSAGE,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prompt the user for input on all the questions.
Expand Down Expand Up @@ -229,6 +230,7 @@ def prompt(
are printing to stdout.

kbi_msg: The message to be printed on a keyboard interrupt.

true_color: Use true color output.

color_depth: Color depth to use. If ``true_color`` is set to true then this
Expand All @@ -248,7 +250,8 @@ def prompt(
try:
return unsafe_prompt(questions, answers, patch_stdout, true_color, **kwargs)
except KeyboardInterrupt:
print(kbi_msg)
if kbi_msg is not None:
print(kbi_msg)
return {}


Expand Down
15 changes: 11 additions & 4 deletions questionary/question.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from typing import Any
from typing import Union

import prompt_toolkit.patch_stdout
from prompt_toolkit import Application
Expand All @@ -24,7 +25,9 @@ def __init__(self, application: "Application[Any]") -> None:
self.default = None

async def ask_async(
self, patch_stdout: bool = False, kbi_msg: str = DEFAULT_KBI_MESSAGE
self,
patch_stdout: bool = False,
kbi_msg: Union[str, None] = DEFAULT_KBI_MESSAGE,
) -> Any:
"""Ask the question using asyncio and return user response.

Expand All @@ -42,11 +45,14 @@ async def ask_async(
sys.stdout.flush()
return await self.unsafe_ask_async(patch_stdout)
except KeyboardInterrupt:
print("{}".format(kbi_msg))
if kbi_msg is not None:
print(f"{kbi_msg}")
return None

def ask(
self, patch_stdout: bool = False, kbi_msg: str = DEFAULT_KBI_MESSAGE
self,
patch_stdout: bool = False,
kbi_msg: Union[str, None] = DEFAULT_KBI_MESSAGE,
) -> Any:
"""Ask the question synchronously and return user response.

Expand All @@ -63,7 +69,8 @@ def ask(
try:
return self.unsafe_ask(patch_stdout)
except KeyboardInterrupt:
print("{}".format(kbi_msg))
if kbi_msg is not None:
print(f"{kbi_msg}")
return None

def unsafe_ask(self, patch_stdout: bool = False) -> Any:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_form.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import asyncio
from unittest.mock import patch

from prompt_toolkit.output import DummyOutput
from pytest import fail

Expand Down Expand Up @@ -77,3 +80,39 @@ def run(inp):
fail("Keyboard Interrupt should be caught by `ask()`")

execute_with_input_pipe(run)


def test_no_keyboard_interrupt_message_in_ask() -> None:
"""
Test no message printed when `kbi_msg` is None in `ask()`.
"""

def run(inp):
inp.send_text(KeyInputs.CONTROLC)
f = example_form(inp)

with patch("builtins.print") as mock_print:
result = f.ask(kbi_msg=None)

mock_print.assert_not_called()
assert result == {}

execute_with_input_pipe(run)


def test_no_keyboard_interrupt_message_in_ask_async() -> None:
"""
Test no message printed when `kbi_msg` is None in `ask_async()`.
"""

def run(inp):
inp.send_text(KeyInputs.CONTROLC)
f = example_form(inp)

with patch("builtins.print") as mock_print:
result = asyncio.run(f.ask_async(kbi_msg=None))

mock_print.assert_not_called()
assert result == {}

execute_with_input_pipe(run)
54 changes: 54 additions & 0 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import asyncio
from unittest.mock import patch

import pytest

from questionary.prompt import PromptParameterException
from questionary.prompt import prompt
from questionary.prompt import prompt_async
from tests.utils import patched_prompt


Expand Down Expand Up @@ -76,3 +80,53 @@ def test_print_with_name():
questions = [{"name": "hello", "type": "print", "message": "Hello World"}]
result = patched_prompt(questions, "")
assert result == {"hello": None}


@patch("builtins.print")
@patch("questionary.prompt.unsafe_prompt", side_effect=KeyboardInterrupt)
def test_no_keyboard_interrupt_message_in_prompt(
mock_unsafe_prompt, mock_print
) -> None:
"""
Test no message printed when `kbi_msg` is None in `prompt()`.

Args:
mock_unsafe_prompt: A mock of the internal `unsafe_prompt()` call.
raises a KeyboardInterrupt.

mock_print: A mock of Python's builtin `print()` function.
"""
# Act
result = prompt(questions={}, kbi_msg=None)

# Verify internal functions were properly called
mock_unsafe_prompt.assert_called_once() # Raises KeyboardInterrupt
mock_print.assert_not_called()

# Verify result
assert result == {}


@patch("builtins.print")
@patch("questionary.prompt.unsafe_prompt_async", side_effect=KeyboardInterrupt)
def test_no_keyboard_interrupt_message_in_prompt_async(
mock_unsafe_prompt_async, mock_print
) -> None:
"""
Test no message printed when `kbi_msg` is None in `prompt_async()`.

Args:
mock_unsafe_prompt_async: A mock of the internal `unsafe_prompt_async()` call.
raises a KeyboardInterrupt.

mock_print: A mock of Python's builtin `print()` function.
"""
# Act
result = asyncio.run(prompt_async(questions={}, kbi_msg=None))

# Verify internal functions were properly called
mock_unsafe_prompt_async.assert_called_once() # Raises KeyboardInterrupt
mock_print.assert_not_called()

# Verify result
assert result == {}
37 changes: 37 additions & 0 deletions tests/test_question.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import platform
from unittest.mock import patch

import pytest
from prompt_toolkit.output import DummyOutput
Expand Down Expand Up @@ -94,3 +95,39 @@ def run(inp):
assert response == "Hello\nworld"

execute_with_input_pipe(run)


def test_no_keyboard_interrupt_message_in_ask():
"""
Test no message printed when `kbi_msg` is None in `ask()`.
"""

def run(inp):
inp.send_text(KeyInputs.CONTROLC)
question = text("Hello?", input=inp, output=DummyOutput())

with patch("builtins.print") as mock_print:
result = question.ask(kbi_msg=None)

mock_print.assert_not_called()
assert result is None

execute_with_input_pipe(run)


def test_no_keyboard_interrupt_message_in_ask_async():
"""
Test no message printed when `kbi_msg` is None in `ask_async()`.
"""

def run(inp):
inp.send_text(KeyInputs.CONTROLC)
question = text("Hello?", input=inp, output=DummyOutput())

with patch("builtins.print") as mock_print:
result = asyncio.run(question.ask_async(kbi_msg=None))

mock_print.assert_not_called()
assert result is None

execute_with_input_pipe(run)
Loading