diff --git a/changelog.md b/changelog.md index 4c956ac9..7af0f851 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ Upcoming Release (TBD) ====================== +Features +-------- +* Add LLM support. + + Bug Fixes -------- * Improve missing ssh-extras message. @@ -9,6 +14,7 @@ Bug Fixes Internal -------- * Improve pull request template lint commands. +* Continue typehinting the non-test codebase. 1.37.1 (2025/07/28) diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py index cf2c03cc..1d22c095 100644 --- a/mycli/clibuffer.py +++ b/mycli/clibuffer.py @@ -1,13 +1,13 @@ -from typing import Callable +from __future__ import annotations from prompt_toolkit.application import get_app from prompt_toolkit.enums import DEFAULT_BUFFER -from prompt_toolkit.filters import Condition +from prompt_toolkit.filters import Condition, Filter from mycli.packages.special import iocommands -def cli_is_multiline(mycli) -> Callable: +def cli_is_multiline(mycli) -> Filter: @Condition def cond(): doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document diff --git a/mycli/config.py b/mycli/config.py index 07f57236..390373bd 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -8,7 +8,7 @@ from os.path import exists import struct import sys -from typing import IO, BinaryIO, Literal +from typing import IO, BinaryIO, Literal, TextIO from configobj import ConfigObj, ConfigObjError import pyaes @@ -25,7 +25,7 @@ def log(logger: logging.Logger, level: int, message: str) -> None: logger.log(level, message) -def read_config_file(f: str | TextIOWrapper, list_values: bool = True) -> ConfigObj | None: +def read_config_file(f: str | TextIO | TextIOWrapper, list_values: bool = True) -> ConfigObj | None: """Read a config file. *list_values* set to `True` is the default behavior of ConfigObj. @@ -52,7 +52,7 @@ def read_config_file(f: str | TextIOWrapper, list_values: bool = True) -> Config return config -def get_included_configs(config_file: str | TextIOWrapper) -> list[str]: +def get_included_configs(config_file: str | TextIOWrapper) -> list[str | TextIOWrapper]: """Get a list of configuration files that are included into config_path with !includedir directive. @@ -64,7 +64,7 @@ def get_included_configs(config_file: str | TextIOWrapper) -> list[str]: """ if not isinstance(config_file, str) or not os.path.isfile(config_file): return [] - included_configs = [] + included_configs: list[str | TextIOWrapper] = [] try: with open(config_file) as f: @@ -80,7 +80,7 @@ def get_included_configs(config_file: str | TextIOWrapper) -> list[str]: return included_configs -def read_config_files(files: list[str], list_values: bool = True) -> ConfigObj: +def read_config_files(files: list[str | TextIOWrapper], list_values: bool = True) -> ConfigObj: """Read and merge a list of config files.""" config = create_default_config(list_values=list_values) diff --git a/mycli/main.py b/mycli/main.py index 8627c8c5..1ef2f7ed 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,6 +1,7 @@ -# type: ignore +from __future__ import annotations from collections import defaultdict, namedtuple +from io import TextIOWrapper import logging import os import re @@ -8,6 +9,7 @@ import sys import threading import traceback +from typing import Any, Generator, Iterable, Literal try: from pwd import getpwuid @@ -24,22 +26,24 @@ from cli_helpers.utils import strip_ansi import click from prompt_toolkit.auto_suggest import AutoSuggestFromHistory -from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.completion import Completion, DynamicCompleter from prompt_toolkit.document import Document from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode from prompt_toolkit.filters import HasFocus, IsDone -from prompt_toolkit.formatted_text import ANSI +from prompt_toolkit.formatted_text import ANSI, AnyFormattedText from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register +from prompt_toolkit.key_binding.key_processor import KeyPressEvent from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor from prompt_toolkit.lexers import PygmentsLexer from prompt_toolkit.shortcuts import CompleteStyle, PromptSession from pymysql import OperationalError +from pymysql.cursors import Cursor import sqlglot import sqlparse from mycli import __version__ from mycli.clibuffer import cli_is_multiline -from mycli.clistyle import style_factory, style_factory_output +from mycli.clistyle import style_factory, style_factory_output # type: ignore[attr-defined] from mycli.clitoolbar import create_toolbar_tokens_func from mycli.compat import WIN from mycli.completion_refresher import CompletionRefresher @@ -60,14 +64,15 @@ try: import paramiko except ImportError: - from mycli.packages.paramiko_stub import paramiko + from mycli.packages.paramiko_stub import paramiko # type: ignore[no-redef] -click.disable_unicode_literals_warning = True # Query tuples are used for maintaining history Query = namedtuple("Query", ["query", "successful", "mutating"]) SUPPORT_INFO = "Home: http://mycli.net\nBug tracker: https://github.com/dbcli/mycli/issues" +DEFAULT_WIDTH = 80 +DEFAULT_HEIGHT = 25 class PasswordFileError(Exception): @@ -81,7 +86,7 @@ class MyCli: defaults_suffix = None # In order of being loaded. Files lower in list override earlier ones. - cnf_files = [ + cnf_files: list[str | TextIOWrapper] = [ "/etc/my.cnf", "/etc/mysql/my.cnf", "/usr/local/etc/my.cnf", @@ -90,27 +95,31 @@ class MyCli: # check XDG_CONFIG_HOME exists and not an empty string xdg_config_home = os.environ.get("XDG_CONFIG_HOME", "~/.config") - system_config_files = ["/etc/myclirc", os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")] + system_config_files: list[str | TextIOWrapper] = [ + "/etc/myclirc", + os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc"), + ] pwd_config_file = os.path.join(os.getcwd(), ".myclirc") def __init__( self, - sqlexecute=None, - prompt=None, - logfile=None, - defaults_suffix=None, - defaults_file=None, - login_path=None, - auto_vertical_output=False, - warn=None, - myclirc="~/.myclirc", - ): + sqlexecute: SQLExecute | None = None, + prompt: str | None = None, + logfile: TextIOWrapper | Literal[False] | None = None, + defaults_suffix: str | None = None, + defaults_file: str | None = None, + login_path: str | None = None, + auto_vertical_output: bool = False, + warn: bool | None = None, + myclirc: str = "~/.myclirc", + ) -> None: self.sqlexecute = sqlexecute self.logfile = logfile self.defaults_suffix = defaults_suffix self.login_path = login_path - self.toolbar_error_message = None + self.toolbar_error_message: str | None = None + self.prompt_app: PromptSession | None = None # self.cnf_files is a class variable that stores the list of mysql # config files to read in at launch. @@ -120,7 +129,7 @@ def __init__( self.cnf_files = [defaults_file] # Load config. - config_files = self.system_config_files + [myclirc] + [self.pwd_config_file] + config_files: list[str | TextIOWrapper] = self.system_config_files + [myclirc] + [self.pwd_config_file] c = self.config = read_config_files(config_files) self.multi_line = c["main"].as_bool("multi_line") self.key_bindings = c["main"]["key_bindings"] @@ -169,7 +178,7 @@ def __init__( self.multiline_continuation_char = c["main"]["prompt_continuation"] keyword_casing = c["main"].get("keyword_casing", "auto") - self.query_history = [] + self.query_history: list[Query] = [] # Initialize completer. self.smart_completion = c["main"].as_bool("smart_completion") @@ -194,7 +203,7 @@ def __init__( self.prompt_app = None - def register_special_commands(self): + def register_special_commands(self) -> None: special.register_special_command(self.change_db, "use", "\\u", "Change to a new database.", aliases=["\\u"]) special.register_special_command( self.change_db, @@ -228,7 +237,7 @@ def register_special_commands(self): self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True ) - def change_table_format(self, arg, **_): + def change_table_format(self, arg: str, **_) -> Generator[tuple, None, None]: try: self.main_formatter.format_name = arg yield (None, None, None, "Changed table format to {}".format(arg)) @@ -238,7 +247,7 @@ def change_table_format(self, arg, **_): msg += "\n\t{}".format(table_type) yield (None, None, None, msg) - def change_redirect_format(self, arg, **_): + def change_redirect_format(self, arg: str, **_) -> Generator[tuple, None, None]: try: self.redirect_formatter.format_name = arg yield (None, None, None, "Changed redirect format to {}".format(arg)) @@ -248,21 +257,23 @@ def change_redirect_format(self, arg, **_): msg += "\n\t{}".format(table_type) yield (None, None, None, msg) - def change_db(self, arg, **_): + def change_db(self, arg: str, **_) -> Generator[tuple, None, None]: + if arg.startswith("`") and arg.endswith("`"): + arg = re.sub(r"^`(.*)`$", r"\1", arg) + arg = re.sub(r"``", r"`", arg) + if not arg: click.secho("No database selected", err=True, fg="red") return - if arg.startswith("`") and arg.endswith("`"): - arg = re.sub(r"^`(.*)`$", r"\1", arg) - arg = re.sub(r"``", r"`", arg) + assert isinstance(self.sqlexecute, SQLExecute) self.sqlexecute.change_db(arg) yield (None, None, None, 'You are now connected to database "%s" as user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user)) - def execute_from_file(self, arg, **_): + def execute_from_file(self, arg: str, **_) -> Iterable[tuple]: if not arg: - message = "Missing required argument, filename." + message = "Missing required argument: filename." return [(None, None, None, message)] try: with open(os.path.expanduser(arg)) as f: @@ -274,9 +285,10 @@ def execute_from_file(self, arg, **_): message = "Wise choice. Command execution stopped." return [(None, None, None, message)] + assert isinstance(self.sqlexecute, SQLExecute) return self.sqlexecute.run(query) - def change_prompt_format(self, arg, **_): + def change_prompt_format(self, arg: str, **_) -> list[tuple]: """ Change the prompt format. """ @@ -287,7 +299,7 @@ def change_prompt_format(self, arg, **_): self.prompt_format = self.get_prompt(arg) return [(None, None, None, "Changed prompt format to %s" % arg)] - def initialize_logging(self): + def initialize_logging(self) -> None: log_file = os.path.expanduser(self.config["main"]["log_file"]) log_level = self.config["main"]["log_level"] @@ -302,7 +314,7 @@ def initialize_logging(self): # Disable logging if value is NONE by switching to a no-op handler # Set log level to a high value so it doesn't even waste cycles getting called. if log_level.upper() == "NONE": - handler = logging.NullHandler() + handler: logging.Handler = logging.NullHandler() log_level = "CRITICAL" elif dir_path_exists(log_file): handler = logging.FileHandler(log_file) @@ -323,7 +335,7 @@ def initialize_logging(self): root_logger.debug("Initializing mycli logging.") root_logger.debug("Log file %r.", log_file) - def read_my_cnf_files(self, files, keys): + def read_my_cnf_files(self, files: list[str | TextIOWrapper], keys: list[str]) -> dict[str, Any]: """ Reads a list of config files and merges them. The last one will win. :param files: list of files to read @@ -347,7 +359,7 @@ def read_my_cnf_files(self, files, keys): if self.defaults_suffix: sections.extend([sect + self.defaults_suffix for sect in sections]) - configuration = defaultdict(lambda: None) + configuration: dict[str, Any] = defaultdict(lambda: None) for key in keys: for section in cnf: if section not in sections or key not in cnf[section]: @@ -357,7 +369,7 @@ def read_my_cnf_files(self, files, keys): return configuration - def merge_ssl_with_cnf(self, ssl, cnf): + def merge_ssl_with_cnf(self, ssl: dict[str, Any], cnf: dict[str, Any]) -> dict[str, Any]: """Merge SSL configuration dict with cnf dict""" merged = {} @@ -382,23 +394,23 @@ def merge_ssl_with_cnf(self, ssl, cnf): def connect( self, - database="", - user="", - passwd="", - host="", - port="", - socket="", - charset="", - local_infile="", - ssl="", - ssh_user="", - ssh_host="", - ssh_port="", - ssh_password="", - ssh_key_filename="", - init_command="", - password_file="", - ): + database: str | None = "", + user: str | None = "", + passwd: str = "", + host: str | None = "", + port: str | int | None = "", + socket: str | None = "", + charset: str = "", + local_infile: str = "", + ssl: dict[str, Any] | None = {}, + ssh_user: str = "", + ssh_host: str = "", + ssh_port: str = "", + ssh_password: str = "", + ssh_key_filename: str = "", + init_command: str = "", + password_file: str = "", + ) -> None: cnf = { "database": None, "user": None, @@ -417,18 +429,18 @@ def connect( "ssl-verify-serer-cert": None, } - cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) + cnf = self.read_my_cnf_files(self.cnf_files, list(cnf.keys())) # Fall back to config values only if user did not specify a value. database = database or cnf["database"] user = user or cnf["user"] or os.getenv("USER") host = host or cnf["host"] port = port or cnf["port"] - ssl = ssl or {} + ssl_config: dict[str, Any] = ssl or {} - port = port and int(port) - if not port: - port = 3306 + int_port = port and int(port) + if not int_port: + int_port = 3306 if not host or host == "localhost": socket = socket or cnf["socket"] or cnf["default_socket"] or guess_socket_location() @@ -436,17 +448,18 @@ def connect( charset = charset or cnf["default-character-set"] or "utf8" # Favor whichever local_infile option is set. + use_local_infile = False for local_infile_option in (local_infile, cnf["local-infile"], cnf["loose-local-infile"], False): try: - local_infile = str_to_bool(local_infile_option) + use_local_infile = str_to_bool(local_infile_option or '') break except (TypeError, ValueError): pass - ssl = self.merge_ssl_with_cnf(ssl, cnf) + ssl_config_or_none: dict[str, Any] | None = self.merge_ssl_with_cnf(ssl_config, cnf) # prune lone check_hostname=False - if not any(v for v in ssl.values()): - ssl = None + if not any(v for v in ssl_config.values()): + ssl_config_or_none = None # if the passwd is not specified try to set it using the password_file option password_from_file = self.get_password_from_file(password_file) @@ -454,21 +467,21 @@ def connect( # Connect to the database. - def _connect(): + def _connect() -> None: try: self.sqlexecute = SQLExecute( database, user, passwd, host, - port, + int_port, socket, charset, - local_infile, - ssl, + use_local_infile, + ssl_config_or_none, ssh_user, ssh_host, - ssh_port, + int(ssh_port) if ssh_port else None, ssh_password, ssh_key_filename, init_command, @@ -484,14 +497,14 @@ def _connect(): user, new_passwd, host, - port, + int_port, socket, charset, - local_infile, - ssl, + use_local_infile, + ssl_config, ssh_user, ssh_host, - ssh_port, + int(ssh_port) if ssh_port else None, ssh_password, ssh_key_filename, init_command, @@ -540,22 +553,23 @@ def _connect(): self.echo(str(e), err=True, fg="red") sys.exit(1) - def get_password_from_file(self, password_file): - if password_file: - try: - with open(password_file) as fp: - password = fp.readline().strip() - return password - except FileNotFoundError: - raise PasswordFileError(f"Password file '{password_file}' not found") from None - except PermissionError: - raise PasswordFileError(f"Permission denied reading password file '{password_file}'") from None - except IsADirectoryError: - raise PasswordFileError(f"Path '{password_file}' is a directory, not a file") from None - except Exception as e: - raise PasswordFileError(f"Error reading password file '{password_file}': {str(e)}") from None + def get_password_from_file(self, password_file: str) -> str: + if not password_file: + return '' + try: + with open(password_file) as fp: + password = fp.readline().strip() + return password + except FileNotFoundError: + raise PasswordFileError(f"Password file '{password_file}' not found") from None + except PermissionError: + raise PasswordFileError(f"Permission denied reading password file '{password_file}'") from None + except IsADirectoryError: + raise PasswordFileError(f"Path '{password_file}' is a directory, not a file") from None + except Exception as e: + raise PasswordFileError(f"Error reading password file '{password_file}': {str(e)}") from None - def handle_editor_command(self, text): + def handle_editor_command(self, text: str) -> str: r"""Editor command is any query that is prefixed or suffixed by a '\e'. The reason for a while loop is because a user might edit a query multiple times. For eg: @@ -577,6 +591,7 @@ def handle_editor_command(self, text): raise RuntimeError(message) while True: try: + assert isinstance(self.prompt_app, PromptSession) text = self.prompt_app.prompt(default=sql) break except KeyboardInterrupt: @@ -585,7 +600,7 @@ def handle_editor_command(self, text): continue return text - def handle_clip_command(self, text): + def handle_clip_command(self, text: str) -> bool: r"""A clip command is any query that is prefixed or suffixed by a '\clip'. @@ -602,7 +617,7 @@ def handle_clip_command(self, text): return True return False - def handle_prettify_binding(self, text): + def handle_prettify_binding(self, text: str) -> str: try: statements = sqlglot.parse(text, read="mysql") except Exception: @@ -616,7 +631,7 @@ def handle_prettify_binding(self, text): pretty_text = pretty_text + ";" return pretty_text - def handle_unprettify_binding(self, text): + def handle_unprettify_binding(self, text: str) -> str: try: statements = sqlglot.parse(text, read="mysql") except Exception: @@ -630,9 +645,10 @@ def handle_unprettify_binding(self, text): unpretty_text = unpretty_text + ";" return unpretty_text - def run_cli(self): + def run_cli(self) -> None: iterations = 0 sqlexecute = self.sqlexecute + assert isinstance(sqlexecute, SQLExecute) logger = self.logger self.configure_pager() @@ -658,14 +674,14 @@ def run_cli(self): print(SUPPORT_INFO) print("Thanks to the contributor -", thanks_picker()) - def get_message(): + def get_message() -> ANSI: prompt = self.get_prompt(self.prompt_format) if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: prompt = self.get_prompt(self.default_prompt_splitln) prompt = prompt.replace("\\x1b", "\x1b") return ANSI(prompt) - def get_continuation(width, *_): + def get_continuation(width: int, _two: int, _three: int) -> AnyFormattedText: if self.multiline_continuation_char == "": continuation = "" elif self.multiline_continuation_char: @@ -675,7 +691,7 @@ def get_continuation(width, *_): continuation = " " return [("class:continuation", continuation)] - def show_suggestion_tip(): + def show_suggestion_tip() -> bool: return iterations < 2 # Keep track of whether or not the query is mutating. In case @@ -735,9 +751,10 @@ def output_res(res, start): mutating = mutating or is_mutating(status) return - def one_iteration(text=None): + def one_iteration(text: str | None = None) -> None: if text is None: try: + assert self.prompt_app is not None text = self.prompt_app.prompt() except KeyboardInterrupt: return @@ -763,8 +780,9 @@ def one_iteration(text=None): return # LLM command support while special.is_llm_command(text): + start = time() try: - start = time() + assert sqlexecute.conn is not None cur = sqlexecute.conn.cursor() context, sql, duration = special.handle_llm(text, cur) if context: @@ -772,23 +790,26 @@ def one_iteration(text=None): click.echo(context) click.echo("---") click.echo(f"Time: {duration:.2f} seconds") - text = self.prompt_app.prompt(default=sql) + text = self.prompt_app.prompt(default=sql or '') except KeyboardInterrupt: return except special.FinishIteration as e: - return output_res(e.results, start) if e.results else None + if e.results: + output_res(e.results, start) except RuntimeError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg="red") return - if not text.strip(): + text = text.strip() + + if not text: return if is_redirect_command(text): sql_part, command_part, file_operator_part, file_part = get_redirect_components(text) - text = sql_part + text = sql_part or '' try: special.set_redirect(command_part, file_operator_part, file_part) except (FileNotFoundError, OSError, RuntimeError) as e: @@ -831,7 +852,7 @@ def one_iteration(text=None): raise e except KeyboardInterrupt: # get last connection id - connection_id_to_kill = sqlexecute.connection_id + connection_id_to_kill = sqlexecute.connection_id or 0 # some mysql compatible databases may not implemente connection_id() if connection_id_to_kill > 0: logger.debug("connection id to kill: %r", connection_id_to_kill) @@ -857,9 +878,9 @@ def one_iteration(text=None): self.echo("Did not get a connection id, skip cancelling query", err=True, fg="red") except NotImplementedError: self.echo("Not Yet Implemented.", fg="yellow") - except OperationalError as e: - logger.debug("Exception: %r", e) - if e.args[0] in (2003, 2006, 2013): + except OperationalError as e1: + logger.debug("Exception: %r", e1) + if e1.args[0] in (2003, 2006, 2013): logger.debug("Attempting to reconnect.") self.echo("Reconnecting...", fg="yellow") try: @@ -867,23 +888,23 @@ def one_iteration(text=None): logger.debug("Reconnected successfully.") one_iteration(text) return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e: - logger.debug("Reconnect failed. e: %r", e) - self.echo(str(e), err=True, fg="red") + except OperationalError as e2: + logger.debug("Reconnect failed. e: %r", e2) + self.echo(str(e2), err=True, fg="red") # If reconnection failed, don't proceed further. return else: - logger.error("sql: %r, error: %r", text, e) + logger.error("sql: %r, error: %r", text, e1) logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") + self.echo(str(e1), err=True, fg="red") except Exception as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg="red") else: - if is_dropping_database(text, self.sqlexecute.dbname): - self.sqlexecute.dbname = None - self.sqlexecute.connect() + if is_dropping_database(text, sqlexecute.dbname): + sqlexecute.dbname = None + sqlexecute.connect() # Refresh the table names and column names if necessary. if need_completion_refresh(text): @@ -943,12 +964,12 @@ def one_iteration(text=None): if not self.less_chatty: self.echo("Goodbye!") - def log_output(self, output): + def log_output(self, output: str) -> None: """Log the output in the audit log, if it's enabled.""" - if self.logfile: + if isinstance(self.logfile, TextIOWrapper): click.echo(output, file=self.logfile) - def echo(self, s, **kwargs): + def echo(self, s: str, **kwargs) -> None: """Print a message to stdout. The message will be logged in the audit log, if enabled. @@ -959,11 +980,11 @@ def echo(self, s, **kwargs): self.log_output(s) click.secho(s, **kwargs) - def bell(self): + def bell(self) -> None: """Print a bell on the stderr.""" click.secho("\a", err=True, nl=False) - def get_output_margin(self, status=None): + def get_output_margin(self, status: str | None = None) -> int: """Get the output margin (number of rows for the prompt, footer and timing message.""" margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count("\n") + 1 @@ -974,7 +995,7 @@ def get_output_margin(self, status=None): return margin - def output(self, output, status=None): + def output(self, output: itertools.chain[str], status: str | None = None) -> None: """Output text to stdout or a pager command. The status text is not outputted to pager or files. @@ -985,7 +1006,13 @@ def output(self, output, status=None): """ if output: - size = self.prompt_app.output.get_size() + if self.prompt_app is not None: + size = self.prompt_app.output.get_size() + size_columns = size.columns + size_rows = size.rows + else: + size_columns = DEFAULT_WIDTH + size_rows = DEFAULT_HEIGHT margin = self.get_output_margin(status) @@ -1003,7 +1030,7 @@ def output(self, output, status=None): elif fits or output_via_pager: # buffering buf.append(line) - if len(line) > size.columns or i > (size.rows - margin): + if len(line) > size_columns or i > (size_rows - margin): fits = False if not self.explicit_pager and special.is_pager_enabled(): # doesn't fit, use pager @@ -1020,7 +1047,7 @@ def output(self, output, status=None): if buf: if output_via_pager: - def newlinewrapper(text): + def newlinewrapper(text: list[str]) -> Generator[str, None, None]: for line in text: yield line + "\n" @@ -1033,7 +1060,7 @@ def newlinewrapper(text): self.log_output(status) click.secho(status) - def configure_pager(self): + def configure_pager(self) -> None: # Provide sane defaults for less if they are empty. if not os.environ.get("LESS"): os.environ["LESS"] = "-RXF" @@ -1054,10 +1081,11 @@ def configure_pager(self): if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"): special.disable_pager() - def refresh_completions(self, reset=False): + def refresh_completions(self, reset: bool = False) -> list[tuple]: if reset: with self._completer_lock: self.completer.reset_completions() + assert self.sqlexecute is not None self.completion_refresher.refresh( self.sqlexecute, self._on_completions_refreshed, @@ -1070,7 +1098,7 @@ def refresh_completions(self, reset=False): return [(None, None, None, "Auto-completion refresh started in the background.")] - def _on_completions_refreshed(self, new_completer): + def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None: """Swap the completer object in cli with the newly created completer.""" with self._completer_lock: self.completer = new_completer @@ -1080,12 +1108,15 @@ def _on_completions_refreshed(self, new_completer): # "Refreshing completions..." indicator self.prompt_app.app.invalidate() - def get_completions(self, text, cursor_positition): + def get_completions(self, text: str, cursor_positition: int) -> Iterable[Completion]: with self._completer_lock: return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None) - def get_prompt(self, string): + def get_prompt(self, string: str) -> str: sqlexecute = self.sqlexecute + assert sqlexecute is not None + assert sqlexecute.server_info is not None + assert sqlexecute.server_info.species is not None host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host now = datetime.now() string = string.replace("\\u", sqlexecute.user or "(none)") @@ -1104,8 +1135,9 @@ def get_prompt(self, string): string = string.replace("\\_", " ") return string - def run_query(self, query, new_line=True): + def run_query(self, query: str, new_line: bool = True) -> None: """Runs *query*.""" + assert self.sqlexecute is not None results = self.sqlexecute.run(query) for result in results: title, cur, headers, status = result @@ -1123,20 +1155,20 @@ def run_query(self, query, new_line=True): def format_output( self, - title, - cur, - headers, - expanded=False, - is_redirected=False, - max_width=None, - ): + title: str | None, + cur: Cursor | list[tuple] | None, + headers: list[str] | None, + expanded: bool = False, + is_redirected: bool = False, + max_width: int | None = None, + ) -> itertools.chain[str]: if is_redirected: use_formatter = self.redirect_formatter else: use_formatter = self.main_formatter expanded = expanded or use_formatter.format_name == "vertical" - output = [] + output: itertools.chain[str] = itertools.chain() output_kwargs = {"dialect": "unix", "disable_numparse": True, "preserve_whitespace": True, "style": self.output_style} @@ -1148,13 +1180,13 @@ def format_output( if cur: column_types = None - if hasattr(cur, "description"): + if isinstance(cur, Cursor): - def get_col_type(col): + def get_col_type(col) -> type: col_type = FIELD_TYPES.get(col[1], str) return col_type if type(col_type) is type else str - column_types = [get_col_type(col) for col in cur.description] + column_types = [get_col_type(tup) for tup in cur.description] if max_width is not None: cur = list(cur) @@ -1190,14 +1222,14 @@ def get_col_type(col): return output - def get_reserved_space(self): + def get_reserved_space(self) -> int: """Get the number of lines to reserve for the completion menu.""" reserved_space_ratio = 0.45 max_reserved_space = 8 _, height = shutil.get_terminal_size() return min(int(round(height * reserved_space_ratio)), max_reserved_space) - def get_last_query(self): + def get_last_query(self) -> str | None: """Get the last query executed or None.""" return self.query_history[-1][0] if self.query_history else None @@ -1547,7 +1579,7 @@ def cli( sys.exit(1) -def need_completion_refresh(queries): +def need_completion_refresh(queries: str) -> bool: """Determines if the completion needs a refresh by checking if the sql statement is an alter, create, drop or change db.""" for query in sqlparse.split(queries): @@ -1557,9 +1589,10 @@ def need_completion_refresh(queries): return True except Exception: return False + return False -def need_completion_reset(queries): +def need_completion_reset(queries: str) -> bool: """Determines if the statement is a database switch such as 'use' or '\\u'. When a database is changed the existing completions must be reset before we start the completion refresh for the new database. @@ -1571,9 +1604,10 @@ def need_completion_reset(queries): return True except Exception: return False + return False -def is_mutating(status): +def is_mutating(status: str | None) -> bool: """Determines if the statement is mutating based on the status.""" if not status: return False @@ -1582,14 +1616,14 @@ def is_mutating(status): return status.split(None, 1)[0].lower() in mutating -def is_select(status): +def is_select(status: str | None) -> bool: """Returns true if the first word in status is 'select'.""" if not status: return False return status.split(None, 1)[0].lower() == "select" -def thanks_picker(): +def thanks_picker() -> str: import mycli lines = (resources.read_text(mycli, "AUTHORS") + resources.read_text(mycli, "SPONSORS")).split("\n") @@ -1603,14 +1637,14 @@ def thanks_picker(): @prompt_register("edit-and-execute-command") -def edit_and_execute(event): +def edit_and_execute(event: KeyPressEvent) -> None: """Different from the prompt-toolkit default, we want to have a choice not to execute a query after editing, hence validate_and_handle=False.""" buff = event.current_buffer buff.open_in_editor(validate_and_handle=False) -def read_ssh_config(ssh_config_path): +def read_ssh_config(ssh_config_path: str): ssh_config = paramiko.config.SSHConfig() try: with open(ssh_config_path) as f: diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index 4516f8b5..aae7e790 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -274,13 +274,13 @@ def is_destructive(queries: str) -> bool: return False -def is_dropping_database(queries: list[str], dbname: str | None) -> bool: +def is_dropping_database(queries: str, dbname: str | None) -> bool: """Determine if the query is dropping a specific database.""" result = False if dbname is None: return False - def normalize_db_name(db): + def normalize_db_name(db: str) -> str: return db.lower().strip('`"') dbname = normalize_db_name(dbname) diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py index 737dc9df..1c432b55 100644 --- a/mycli/packages/special/__init__.py +++ b/mycli/packages/special/__init__.py @@ -1,19 +1,95 @@ from __future__ import annotations -from typing import Callable - -__all__: list[str] = [] - - -def export(defn: Callable): - """Decorator to explicitly mark functions that are exposed in a lib.""" - globals()[defn.__name__] = defn - __all__.append(defn.__name__) - return defn - - -from mycli.packages.special import ( - dbcommands, # noqa: E402 F401 - iocommands, # noqa: E402 F401 - llm, # noqa: E402 F401 +from mycli.packages.special.dbcommands import ( + list_databases, + list_tables, + status, ) +from mycli.packages.special.iocommands import ( + clip_command, + close_tee, + copy_query_to_clipboard, + disable_pager, + editor_command, + flush_pipe_once_if_written, + forced_horizontal, + get_clip_query, + get_current_delimiter, + get_editor_query, + get_filename, + is_expanded_output, + is_pager_enabled, + is_redirected, + is_timing_enabled, + open_external_editor, + set_delimiter, + set_expanded_output, + set_favorite_queries, + set_forced_horizontal_output, + set_pager, + set_pager_enabled, + set_redirect, + set_timing_enabled, + split_queries, + unset_once_if_written, + write_once, + write_pipe_once, + write_tee, +) +from mycli.packages.special.llm import ( + FinishIteration, + handle_llm, + is_llm_command, + sql_using_llm, +) +from mycli.packages.special.main import ( + CommandNotFound, + execute, + parse_special_command, + register_special_command, + special_command, +) + +__all__: list[str] = [ + 'CommandNotFound', + 'FinishIteration', + 'clip_command', + 'close_tee', + 'copy_query_to_clipboard', + 'disable_pager', + 'editor_command', + 'execute', + 'flush_pipe_once_if_written', + 'forced_horizontal', + 'get_clip_query', + 'get_current_delimiter', + 'get_editor_query', + 'get_filename', + 'handle_llm', + 'is_expanded_output', + 'is_llm_command', + 'is_pager_enabled', + 'is_redirected', + 'is_timing_enabled', + 'list_databases', + 'list_tables', + 'open_external_editor', + 'parse_special_command', + 'register_special_command', + 'set_delimiter', + 'set_expanded_output', + 'set_favorite_queries', + 'set_forced_horizontal_output', + 'set_pager', + 'set_pager_enabled', + 'set_redirect', + 'set_timing_enabled', + 'special_command', + 'split_queries', + 'sql_using_llm', + 'status', + 'unset_once_if_written', + 'write_once', + 'write_pipe_once', + 'write_tee', +] diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 8a0cda99..6c9f8023 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -17,7 +17,6 @@ from mycli.compat import WIN from mycli.packages.prompt_utils import confirm_destructive_query -from mycli.packages.special import export from mycli.packages.special.delimitercommand import DelimiterCommand from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import ArgType, special_command @@ -40,30 +39,25 @@ favoritequeries = FavoriteQueries(ConfigObj()) -@export def set_favorite_queries(config): global favoritequeries favoritequeries = FavoriteQueries(config) -@export def set_timing_enabled(val: bool) -> None: global TIMING_ENABLED TIMING_ENABLED = val -@export def set_pager_enabled(val: bool) -> None: global PAGER_ENABLED PAGER_ENABLED = val -@export def is_pager_enabled() -> bool: return PAGER_ENABLED -@export @special_command( "pager", "\\P [command]", @@ -88,7 +82,6 @@ def set_pager(arg: str, **_) -> list[tuple]: return [(None, None, None, msg)] -@export @special_command("nopager", "\\n", "Disable pager, print to stdout.", arg_type=ArgType.NO_QUERY, aliases=["\\n"], case_sensitive=True) def disable_pager() -> list[tuple]: set_pager_enabled(False) @@ -104,29 +97,24 @@ def toggle_timing() -> list[tuple]: return [(None, None, None, message)] -@export def is_timing_enabled() -> bool: return TIMING_ENABLED -@export def set_expanded_output(val: bool) -> None: global use_expanded_output use_expanded_output = val -@export def is_expanded_output() -> bool: return use_expanded_output -@export def set_forced_horizontal_output(val: bool) -> None: global force_horizontal_output force_horizontal_output = val -@export def forced_horizontal() -> bool: return force_horizontal_output @@ -134,7 +122,6 @@ def forced_horizontal() -> bool: _logger = logging.getLogger(__name__) -@export def editor_command(command: str) -> bool: """ Is this an external editor command? @@ -145,7 +132,6 @@ def editor_command(command: str) -> bool: return command.strip().endswith("\\e") or command.strip().startswith("\\e") -@export def get_filename(sql: str) -> str | None: if sql.strip().startswith("\\e"): command, _, filename = sql.partition(" ") @@ -154,7 +140,6 @@ def get_filename(sql: str) -> str | None: return None -@export def get_editor_query(sql: str) -> str: """Get the query part of an editor command.""" sql = sql.strip() @@ -169,7 +154,6 @@ def get_editor_query(sql: str) -> str: return sql -@export def open_external_editor(filename: str | None = None, sql: str | None = None) -> tuple[str, str | None]: """Open external editor, wait for the user to type in their query, return the query. @@ -204,7 +188,6 @@ def open_external_editor(filename: str | None = None, sql: str | None = None) -> return (query, None) -@export def clip_command(command: str) -> bool: """Is this a clip command? @@ -216,7 +199,6 @@ def clip_command(command: str) -> bool: return command.strip().endswith("\\clip") or command.strip().startswith("\\clip") -@export def get_clip_query(sql: str) -> str: """Get the query part of a clip command.""" sql = sql.strip() @@ -230,7 +212,6 @@ def get_clip_query(sql: str) -> str: return sql -@export def copy_query_to_clipboard(sql: str | None = None) -> str | None: """Send query to the clipboard.""" @@ -245,7 +226,6 @@ def copy_query_to_clipboard(sql: str | None = None) -> str | None: return message -@export def set_redirect(command_part: str | None, file_operator_part: str | None, file_part: str | None) -> list[tuple]: if command_part: if file_part: @@ -405,7 +385,6 @@ def set_tee(arg: str, **_) -> list[tuple]: return [(None, None, None, "")] -@export def close_tee() -> None: global tee_file if tee_file: @@ -419,7 +398,6 @@ def no_tee(arg: str, **_) -> list[tuple]: return [(None, None, None, "")] -@export def write_tee(output: str) -> None: global tee_file if tee_file: @@ -441,12 +419,10 @@ def set_once(arg: str, **_) -> list[tuple]: return [(None, None, None, "")] -@export def is_redirected() -> bool: return bool(once_file or PIPE_ONCE['process']) -@export def write_once(output: str) -> None: global once_file, written_to_once_file if output and once_file: @@ -456,7 +432,6 @@ def write_once(output: str) -> None: written_to_once_file = True -@export def unset_once_if_written(post_redirect_command: str) -> None: """Unset the once file, if it has been written to.""" global once_file, written_to_once_file @@ -506,13 +481,11 @@ def set_pipe_once(arg: str, **_) -> list[tuple]: return [(None, None, None, "")] -@export def write_pipe_once(line: str) -> None: if line and PIPE_ONCE['process']: PIPE_ONCE['stdin'].append(line) -@export def flush_pipe_once_if_written(post_redirect_command: str) -> None: """Flush the pipe_once cmd, if lines have been written.""" if not PIPE_ONCE['process']: @@ -608,18 +581,15 @@ def watch_query(arg: str, **kwargs) -> Generator[tuple, None, None]: set_pager_enabled(old_pager_enabled) -@export @special_command("delimiter", None, "Change SQL delimiter.") def set_delimiter(arg: str, **_) -> list[tuple]: return delimiter_command.set(arg) -@export def get_current_delimiter() -> str: return delimiter_command.current -@export def split_queries(input_str: str) -> Generator[str, None, None]: for query in delimiter_command.queries_iter(input_str): yield query diff --git a/mycli/packages/special/llm.py b/mycli/packages/special/llm.py index 56dcfff1..4bce0980 100644 --- a/mycli/packages/special/llm.py +++ b/mycli/packages/special/llm.py @@ -13,7 +13,6 @@ import llm from llm.cli import cli -from mycli.packages.special import export from mycli.packages.special.main import Verbosity, parse_special_command log = logging.getLogger(__name__) @@ -91,7 +90,6 @@ def get_completions(tokens, tree=COMMAND_TREE): return list(tree.keys()) if tree else [] -@export class FinishIteration(Exception): def __init__(self, results=None): self.results = results @@ -161,7 +159,6 @@ def ensure_mycli_template(replace=False): return -@export def handle_llm(text, cur) -> Tuple[str, Optional[str], float]: _, verbosity, arg = parse_special_command(text) if not arg.strip(): @@ -217,13 +214,11 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str], float]: raise RuntimeError(e) -@export def is_llm_command(command) -> bool: cmd, _, _ = parse_special_command(command) return cmd in ("\\llm", "\\ai") -@export def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]: if cur is None: raise RuntimeError("Connect to a database and try again.") diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py index 71e3269a..76b8677d 100644 --- a/mycli/packages/special/main.py +++ b/mycli/packages/special/main.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import namedtuple from enum import Enum import logging @@ -5,8 +7,6 @@ from pymysql.cursors import Cursor -from mycli.packages.special import export - logger = logging.getLogger(__name__) COMMANDS = {} @@ -31,7 +31,6 @@ class ArgType(Enum): RAW_QUERY = 2 -@export class CommandNotFound(Exception): pass @@ -42,7 +41,6 @@ class Verbosity(Enum): VERBOSE = "verbose" -@export def parse_special_command(sql: str) -> tuple[str, Verbosity, str]: command, _, arg = sql.partition(" ") verbosity = Verbosity.NORMAL @@ -54,10 +52,9 @@ def parse_special_command(sql: str) -> tuple[str, Verbosity, str]: return (command, verbosity, arg.strip()) -@export def special_command( command: str, - shortcut: str, + shortcut: str | None, description: str, arg_type: ArgType = ArgType.PARSED_QUERY, hidden: bool = False, @@ -80,11 +77,10 @@ def wrapper(wrapped): return wrapper -@export def register_special_command( handler: Callable, command: str, - shortcut: str, + shortcut: str | None, description: str, arg_type: ArgType = ArgType.PARSED_QUERY, hidden: bool = False, @@ -114,7 +110,6 @@ def register_special_command( ) -@export def execute(cur: Cursor, sql: str) -> list[tuple]: """Execute a special command and return the results. If the special command is not supported a CommandNotFound will be raised. diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index a884565a..04479ecb 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -1104,7 +1104,7 @@ def apply_case(kw: str) -> str: def get_completions( self, document: Document, - complete_event: CompleteEvent, + complete_event: CompleteEvent | None, smart_completion: bool | None = None, ) -> Iterable[Completion]: word_before_cursor = document.get_word_before_cursor(WORD=True) diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index a19ac53c..4562354f 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -5,9 +5,10 @@ import logging import re import ssl -from typing import Any, Generator +from typing import Any, Generator, Iterable import pymysql +from pymysql.connections import Connection from pymysql.constants import FIELD_TYPE from pymysql.converters import conversions, convert_date, convert_datetime, convert_timedelta, decoders from pymysql.cursors import Cursor @@ -112,7 +113,7 @@ def __init__( port: int | None, socket: str | None, charset: str | None, - local_infile: str | None, + local_infile: bool | None, ssl: dict[str, Any] | None, ssh_user: str | None, ssh_host: str | None, @@ -138,41 +139,42 @@ def __init__( self.ssh_password = ssh_password self.ssh_key_filename = ssh_key_filename self.init_command = init_command + self.conn: Connection | None = None self.connect() def connect( self, - database=None, - user=None, - password=None, - host=None, - port=None, - socket=None, - charset=None, - local_infile=None, - ssl=None, - ssh_host=None, - ssh_port=None, - ssh_user=None, - ssh_password=None, - ssh_key_filename=None, - init_command=None, + database: str | None = None, + user: str | None = None, + password: str | None = None, + host: str | None = None, + port: int | None = None, + socket: str | None = None, + charset: str | None = None, + local_infile: bool | None = None, + ssl: dict[str, Any] | None = None, + ssh_host: str | None = None, + ssh_port: int | None = None, + ssh_user: str | None = None, + ssh_password: str | None = None, + ssh_key_filename: str | None = None, + init_command: str | None = None, ): - db = database or self.dbname - user = user or self.user - password = password or self.password - host = host or self.host - port = port or self.port - socket = socket or self.socket - charset = charset or self.charset - local_infile = local_infile or self.local_infile - ssl = ssl or self.ssl - ssh_user = ssh_user or self.ssh_user - ssh_host = ssh_host or self.ssh_host - ssh_port = ssh_port or self.ssh_port - ssh_password = ssh_password or self.ssh_password - ssh_key_filename = ssh_key_filename or self.ssh_key_filename - init_command = init_command or self.init_command + db = database if database is not None else self.dbname + user = user if user is not None else self.user + password = password if password is not None else self.password + host = host if host is not None else self.host + port = port if port is not None else self.port + socket = socket if socket is not None else self.socket + charset = charset if charset is not None else self.charset + local_infile = local_infile if local_infile is not None else self.local_infile + ssl = ssl if ssl is not None else self.ssl + ssh_user = ssh_user if ssh_user is not None else self.ssh_user + ssh_host = ssh_host if ssh_host is not None else self.ssh_host + ssh_port = ssh_port if ssh_port is not None else self.ssh_port + ssh_password = ssh_password if ssh_password is not None else self.ssh_password + ssh_key_filename = ssh_key_filename if ssh_key_filename is not None else self.ssh_key_filename + init_command = init_command if init_command is not None else self.init_command _logger.debug( "Connection DB Params: \n" "\tdatabase: %r" @@ -228,21 +230,21 @@ def connect( conn = pymysql.connect( database=db, user=user, - password=password, + password=password or '', host=host, - port=port, + port=port or 0, unix_socket=socket, use_unicode=True, - charset=charset, + charset=charset or '', autocommit=True, client_flag=client_flag, local_infile=local_infile, conv=conv, - ssl=ssl_context, + ssl=ssl_context, # type: ignore[arg-type] program_name="mycli", defer_connect=defer_connect, init_command=init_command or None, - ) + ) # type: ignore[misc] if ssh_host: ##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel @@ -264,8 +266,11 @@ def connect( except Exception as e: raise e - if hasattr(self, "conn"): - self.conn.close() + if self.conn is not None: + try: + self.conn.close() + except pymysql.err.Error: + pass self.conn = conn # Update them after the connection is made to ensure that it was a # successful connection. @@ -280,7 +285,7 @@ def connect( self.init_command = init_command # retrieve connection id self.reset_connection_id() - self.server_info = ServerInfo.from_version_string(conn.server_version) + self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined] def run(self, statement: str) -> Generator[tuple, None, None]: """Execute the sql in the database and return the results. The results @@ -297,7 +302,7 @@ def run(self, statement: str) -> Generator[tuple, None, None]: # Unless it's saving a favorite query, in which case we # want to save them all together. if statement.startswith("\\fs"): - components = [statement] + components: Iterable[str] = [statement] else: components = iocommands.split_queries(statement) @@ -313,6 +318,7 @@ def run(self, statement: str) -> Generator[tuple, None, None]: iocommands.set_forced_horizontal_output(True) sql = sql[:-2].strip() + assert isinstance(self.conn, Connection) cur = self.conn.cursor() try: # Special command _logger.debug("Trying a dbspecial command. sql: %r", sql) @@ -350,6 +356,7 @@ def get_result(self, cursor: Cursor) -> tuple: def tables(self) -> Generator[tuple[str], None, None]: """Yields table names""" + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: _logger.debug("Tables Query. sql: %r", self.tables_query) cur.execute(self.tables_query) @@ -358,6 +365,7 @@ def tables(self) -> Generator[tuple[str], None, None]: def table_columns(self) -> Generator[tuple[str, str], None, None]: """Yields (table name, column name) pairs""" + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: _logger.debug("Columns Query. sql: %r", self.table_columns_query) cur.execute(self.table_columns_query % self.dbname) @@ -365,6 +373,7 @@ def table_columns(self) -> Generator[tuple[str, str], None, None]: yield row def databases(self) -> list[str]: + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: _logger.debug("Databases Query. sql: %r", self.databases_query) cur.execute(self.databases_query) @@ -373,6 +382,7 @@ def databases(self) -> list[str]: def functions(self) -> Generator[tuple[str, str], None, None]: """Yields tuples of (schema_name, function_name)""" + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: _logger.debug("Functions Query. sql: %r", self.functions_query) cur.execute(self.functions_query % self.dbname) @@ -380,6 +390,7 @@ def functions(self) -> Generator[tuple[str, str], None, None]: yield row def show_candidates(self) -> Generator[tuple, None, None]: + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: _logger.debug("Show Query. sql: %r", self.show_candidates_query) try: @@ -392,6 +403,7 @@ def show_candidates(self) -> Generator[tuple, None, None]: yield (row[0].split(None, 1)[-1],) def users(self) -> Generator[tuple, None, None]: + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: _logger.debug("Users Query. sql: %r", self.users_query) try: @@ -404,6 +416,7 @@ def users(self) -> Generator[tuple, None, None]: yield row def now(self) -> datetime.datetime: + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: _logger.debug("Now Query. sql: %r", self.now_query) cur.execute(self.now_query) @@ -432,6 +445,7 @@ def reset_connection_id(self) -> None: _logger.debug("Current connection id: %s", self.connection_id) def change_db(self, db: str) -> None: + assert isinstance(self.conn, Connection) self.conn.select_db(db) self.dbname = db