# Source code for django_typer.utils

"""
A collection of useful utilities.
"""

import inspect
import os
import sys
import typing as t
from datetime import timedelta
from functools import partial
from importlib.util import find_spec
from pathlib import Path
from threading import local
from types import MethodType, ModuleType

from django.db.models import Model
from django.db.models.query import QuerySet

from .completers.model import ModelObjectCompleter
from .config import traceback_config
from .parsers.model import ModelObjectParser, ReturnType

# DO NOT IMPORT ANYTHING FROM TYPER HERE - SEE patch.py

# Public names exported by a star-import of this module.
__all__ = [
    "detect_shell",
    "get_usage_script",
    "get_current_command",
    "with_typehint",
    "register_command_plugins",
    "called_from_module",
    "called_from_command_definition",
    "duration_iso_string",
    "parse_iso_duration",
    "model_parser_completer",
    "rich_installed",
]


# True only when the ``rich`` package can be located AND the TYPER_USE_RICH
# environment variable is not set to "0" or "false" (compared case-insensitively).
rich_installed = (
    os.environ.get("TYPER_USE_RICH", "true").lower() not in {"0", "false"}
    if find_spec("rich") is not None
    else False
)


def detect_shell(max_depth: int = 10) -> t.Tuple[str, str]:
    """
    Work out which shell we are running under.

    Detection is delegated to shellingham. When shellingham gives up, the SHELL
    environment variable is used as a fallback if it is set to a non-empty value.

    :param max_depth: forwarded unchanged to shellingham's ``detect_shell``
    :raises ShellDetectionFailure: If shellingham fails and SHELL is unset or empty
    :return: A tuple of the shell name and the shell command
    """
    import shellingham

    try:
        return shellingham.detect_shell(max_depth=max_depth)
    except shellingham.ShellDetectionFailure:
        fallback = os.environ.get("SHELL", "")
        if not fallback:
            # nothing to fall back on - surface shellingham's original failure
            raise
        return (os.path.basename(fallback).lower(), fallback)


def get_usage_script(script: t.Optional[str] = None) -> t.Union[Path, str]:
    """
    Work out how the given script should be shown in usage output.

    The bare script name is returned when an executable of that name is found on
    the PATH and it is either the very same file or the given script path does not
    exist as a file. Otherwise a path is returned: relative to the current working
    directory when the script lives beneath it, absolute when it does not.

    :param script: The script name to check. If None the current script
        (``sys.argv[0]``) is used.
    :return: The script name (str) or a :class:`~pathlib.Path` to the script.
    """
    import shutil

    target = Path(script or sys.argv[0])
    located = shutil.which(target.name)
    if located:
        resolved = Path(located)
        if resolved.is_absolute() and (
            resolved == target.absolute() or not target.is_file()
        ):
            return target.name

    absolute = target.absolute()
    try:
        return absolute.relative_to(Path(os.getcwd()))
    except ValueError:
        # the script does not live under the working directory
        return absolute


# Thread-local storage; get_current_command() reads its ``stack`` attribute.
_command_context = local()


def get_current_command() -> t.Optional["TyperCommand"]:  # type: ignore # noqa: F821
    """
    Returns the current typer command. This can be used as a way to
    access the current command object from anywhere if we are executing
    inside of one from higher on the stack. We primarily need this because certain
    monkey patches are required in typer code - namely for enabling/disabling
    color based on configured parameters.

    This function is thread safe.

    This is analogous to click's get_current_context but for
    command execution.

    :return: The current typer command or None if there is no active command.
    """
    try:
        # the most recently pushed command is the last entry of the stack
        return t.cast("TyperCommand", _command_context.stack[-1])  # type: ignore # noqa: F821
    except (AttributeError, IndexError):
        # no stack on this thread yet (AttributeError) or it is empty (IndexError)
        return None
T = t.TypeVar("T")  # pylint: disable=C0103


def with_typehint(baseclass: t.Type[T]) -> t.Type[T]:
    """
    Type hinting mixin inheritance is really annoying. The current
    canonical way is to use Protocols but this is prohibitive when
    the super classes already exist and are extensive. All we're
    trying to do is let our type checker know about super() methods
    etc - this is a simple way to do that.

    :param baseclass: the class the type checker should treat as the base
    :return: ``baseclass`` while type checking, plain ``object`` at runtime
    """
    if t.TYPE_CHECKING:
        return baseclass  # pragma: no cover
    return object  # type: ignore


# Registry of plugin packages keyed by command name.
_command_plugins: t.Dict[str, t.List[ModuleType]] = {}
def register_command_plugins(
    package: ModuleType, commands: t.Optional[t.List[str]] = None
):
    """
    Register a command plugin for the given command within the given package.

    For example, use this in your AppConfig's ready() method:

    .. code-block:: python

        from django.apps import AppConfig
        from django_typer.utils import register_command_plugins


        class MyAppConfig(AppConfig):
            name = "myapp"

            def ready(self):
                from .management import plugins

                register_command_plugins(plugins)


    :param package: The package the command extension module resides in
    :param commands: The names of the commands/modules, if not provided, all modules
        in the package will be registered as plugins
    """
    import pkgutil

    commands = commands or [
        module.name.split(".")[-1]
        for module in pkgutil.iter_modules(package.__path__, f"{package.__name__}.")
    ]
    for command in commands:
        registered = _command_plugins.setdefault(command, [])
        # registering the same package twice for a command is a no-op
        if package not in registered:
            registered.append(package)
def _load_command_plugins(command: str) -> int:
    """
    Load any plugins for the given command by importing the ``<package>.<command>``
    module of every package registered for it. Packages are imported in reverse
    registration order (the most recently registered package first).

    The command's registry entry is removed afterwards so the import work happens
    only once per command.

    :param command: The name of the command
    :raises ValueError: if a registered package has no importable module for the
        command
    :return: The number of plugins loaded.
    """
    plugins = _command_plugins.get(command, [])
    if plugins:
        import importlib

        for ext_pkg in reversed(plugins):
            try:
                importlib.import_module(f"{ext_pkg.__name__}.{command}")
            except ImportError as err:  # also covers ModuleNotFoundError
                raise ValueError(
                    f"No extension module was found for command {command} in "
                    f"{ext_pkg.__path__}."
                ) from err
        # we only want to do this once
        del _command_plugins[command]
    return len(plugins)


def _check_call_frame(frame_name: str, look_back: int = 1) -> bool:
    """
    Returns True if the stack frame ``look_back`` frames above the caller of this
    function has the given name. With the default of 1 that is the frame that
    called our caller; with 0 it is our caller itself.

    :param frame_name: The code object name to check for (e.g. ``<module>``)
    :param look_back: How many frames above our caller to inspect
    :return: True on a name match, False on a mismatch or if the stack is not
        deep enough
    """
    frame = inspect.currentframe()
    # one step to reach our caller, then look_back further steps up the stack
    for _ in range(0, look_back + 1):
        if not frame:
            break
        frame = frame.f_back
    if frame:
        return frame.f_code.co_name == frame_name
    return False


called_from_module = partial(_check_call_frame, "<module>")
called_from_command_definition = partial(_check_call_frame, "Command")
def is_method(
    func_or_params: t.Optional[t.Union[t.Callable[..., t.Any], t.List[str]]],
) -> t.Optional[bool]:
    """
    This logic is used to determine if a function should be bound as a method
    or not. Right now django-typer will treat module scope functions as methods when
    binding to command classes if they have a first argument named 'self'.

    :param func_or_params: The function to check or a list of parameter names, or None
    :return: True if the function should be considered a method, False if not and None
        if undetermined (None or an empty parameter list was given).
    """
    # ##############
    # Remove when python 3.9 support is dropped
    func_or_params = getattr(func_or_params, "__func__", func_or_params)
    ##############
    if func_or_params:
        params = (
            list(inspect.signature(func_or_params).parameters)
            if callable(func_or_params)
            else func_or_params
        )
        if params:
            return params[0] == "self"
        # a callable that takes no parameters at all
        return isinstance(func_or_params, MethodType)
    return None
def accepts_var_kwargs(func: t.Callable[..., t.Any]) -> bool:
    """
    Determines if the given function accepts variable keyword arguments.

    :param func: the callable whose signature is inspected
    :return: True if the signature has a ``**kwargs`` style parameter
    """
    return any(
        param.kind is inspect.Parameter.VAR_KEYWORD
        for param in inspect.signature(func).parameters.values()
    )


def accepted_kwargs(
    func: t.Callable[..., t.Any], kwargs: t.Dict[str, t.Any]
) -> t.Dict[str, t.Any]:
    """
    Return the named keyword arguments that are accepted by the given function.

    :param func: the callable whose signature is inspected
    :param kwargs: the candidate keyword arguments
    :return: ``kwargs`` itself if the function takes variable keyword arguments,
        otherwise a new dict holding only the entries named in its signature
    """
    if accepts_var_kwargs(func):
        return kwargs
    param_names = set(inspect.signature(func).parameters.keys())
    return {k: v for k, v in kwargs.items() if k in param_names}


def get_win_shell() -> str:
    """
    The way installed python scripts are wrapped on Windows means shellingham will
    detect cmd.exe as the shell. This function will attempt to detect the correct shell,
    usually either powershell (<=v5) or pwsh (>=v6).

    The ancestors of this process are walked one parent at a time, each step
    querying Win32_Process through a PowerShell subprocess, until a process whose
    name contains "pwsh" or "powershell" is found.

    :raises ShellDetectionFailure: If the shell cannot be detected
    :return: The name of the shell, either 'powershell' or 'pwsh'
    """
    import json
    import shutil
    import subprocess  # nosec B404

    from shellingham import ShellDetectionFailure

    pwsh = shutil.which("pwsh") or shutil.which("powershell")
    if pwsh:
        try:
            ps_command = """
            $parent = Get-CimInstance -Query "SELECT * FROM Win32_Process WHERE ProcessId = {pid}";
            $parentPid = $parent.ParentProcessId;
            $parentInfo = Get-CimInstance -Query "SELECT * FROM Win32_Process WHERE ProcessId = $parentPid";
            $parentInfo | Select-Object Name, ProcessId | ConvertTo-Json -Depth 1
            """
            pid = os.getpid()
            # remember visited pids so a process reported as its own ancestor
            # cannot keep the walk going forever
            visited = set()
            while pid not in visited:
                visited.add(pid)
                result = subprocess.run(  # nosec B603
                    [pwsh, "-NoProfile", "-Command", ps_command.format(pid=pid)],
                    capture_output=True,
                    text=True,
                ).stdout.strip()
                if not result:
                    # no parent information available - top of the chain
                    break
                process = json.loads(result)
                if "pwsh" in process.get("Name", ""):
                    return "pwsh"
                if "powershell" in process.get("Name", ""):
                    return "powershell"
                pid = process["ProcessId"]
        except Exception as e:  # pragma: no cover
            raise ShellDetectionFailure("Unable to detect windows shell") from e
    raise ShellDetectionFailure("Unable to detect windows shell")
def parse_iso_duration(duration: str) -> t.Tuple[timedelta, t.Optional[str]]:
    """
    Progressively parse an ISO8601 duration type - can be a partial duration string.
    If it is a partial duration string with an ambiguous trailing number, the number
    will be returned as the second value of the tuple.

    .. note::

        We use a subset of ISO8601, the supported markers are Y, M, W, D, H, M, S.
        Years are approximated as 365 days and months as 30 days.

    Parsing is case insensitive. Fractional seconds beyond six digits are truncated
    to microsecond resolution.

    :param duration: the full or partial duration string
    :raises ValueError: if characters remain that are not part of a valid full or
        partial duration
    :return: A tuple of the parsed duration and the ambiguous trailing number
    """
    import re

    original = duration
    duration = duration.upper()
    sign = -1 if duration.startswith("-") else 1
    duration = duration.lstrip("-").lstrip("+").lstrip("P")
    ambiguous: t.Optional[str] = None
    # a run of digits optionally followed by one marker character
    component = re.compile(r"(\d+)(.)?")

    class Incomplete(Exception):
        """Raised when a trailing number has no marker to say what it measures."""

        value: str

        def __init__(self, value: str):
            self.value = value

    def eat(markers: t.Sequence[str], interpret=lambda x: (int(x), True)) -> int:
        """
        Consume the leading ``<digits><marker>`` component if its marker is one
        of ``markers`` and return its value; return 0 if the next component
        belongs to someone else. ``interpret`` maps the digits to
        ``(value, is_ambiguous_without_marker)``.
        """
        nonlocal duration
        if duration:
            match = component.match(duration)
            if match and match.group(2) in markers:
                duration = duration[match.end() :]
                return interpret(match.group(1))[0]
            if match and not match.group(2):
                # digits ran to the end of the string without a marker
                duration = duration[match.end() :]
                value, ambig = interpret(match.group(1))
                if not ambig:
                    return value
                raise Incomplete(match.group(1))
        return 0

    years = 0
    months = 0
    weeks = 0
    days = 0
    hours = 0
    minutes = 0
    seconds = 0
    microseconds = 0

    # date portion: PnYnMnWnD
    try:
        years = eat(("Y",))
        months = eat(("M",))
        weeks = eat(("W",))
        days = eat(("D",))
    except Incomplete as incomplete:
        ambiguous = incomplete.value

    duration = duration.lstrip("T")

    # time portion: nHnMn.nS
    try:
        hours = eat(("H",))
        minutes = eat(("M",))
        seconds = eat((".", "S"))
        # fraction digits: right-pad to six places, truncate anything beyond
        # microsecond resolution; fewer than six digits may still be extended
        # by further typing so they are ambiguous without the trailing S
        microseconds = eat(
            ("S",), lambda x: (int(x[:6].ljust(6, "0")), len(x) < 6)
        )
    except Incomplete as incomplete:
        ambiguous = incomplete.value

    if duration:
        # if the string was a valid full or partial duration all characters
        # should have been consumed
        raise ValueError(f"Invalid ISO 8601 duration format: {original}")

    total_days = years * 365 + months * 30 + weeks * 7 + days

    return sign * timedelta(
        days=total_days,
        hours=hours,
        minutes=minutes,
        seconds=seconds,
        microseconds=microseconds,
    ), ambiguous
def duration_iso_string(duration: timedelta) -> str:
    """
    Return an ISO8601 duration string from a timedelta. This differs from the Django
    implementation in that zeros are elided.

    :param duration: the timedelta to format
    :return: the duration string, e.g. ``P1DT2H`` or ``-PT0.500000S``; a zero
        duration is rendered as ``PT0S``
    """
    if not duration:
        return "PT0S"

    sign = "-" if duration < timedelta() else ""
    if sign:
        # format the magnitude; the sign is emitted once at the front
        duration *= -1

    days = duration.days
    hours, seconds = divmod(duration.seconds, 3600)
    minutes, seconds = divmod(seconds, 60)

    time_parts = []
    if hours:
        time_parts.append(f"{hours}H")
    if minutes:
        time_parts.append(f"{minutes}M")
    if duration.microseconds:
        # fractional seconds are always written with six digits
        time_parts.append(f"{seconds}.{duration.microseconds:0>6}S")
    elif seconds:
        time_parts.append(f"{seconds}S")

    day_str = f"{days}D" if days else ""
    time_str = "T" + "".join(time_parts) if time_parts else ""

    return f"{sign}P{day_str}{time_str}"
def model_parser_completer(
    model_or_qry: t.Union[t.Type[Model], QuerySet],
    lookup_field: t.Optional[str] = None,
    case_insensitive: bool = False,
    help_field: t.Optional[str] = ModelObjectCompleter.help_field,
    query: t.Optional[ModelObjectCompleter.QueryBuilder] = None,
    limit: t.Optional[int] = ModelObjectCompleter.limit,
    distinct: bool = ModelObjectCompleter.distinct,
    on_error: t.Optional[ModelObjectParser.error_handler] = ModelObjectParser.on_error,
    order_by: t.Optional[t.Union[str, t.Sequence[str]]] = None,
    return_type: ReturnType = ModelObjectParser.return_type,
) -> t.Dict[str, t.Any]:
    """
    A factory function that returns a dictionary that can be used to specify
    a parser and completer for a typer.Option or typer.Argument. This is a
    convenience function that can be used to specify the parser and completer
    for a model object in one go.

    .. code-block:: python

        def handle(
            self,
            obj: t.Annotated[
                ModelClass,
                typer.Argument(
                    **model_parser_completer(ModelClass, 'field_name'),
                    help=_("Fetch objects by their field_names.")
                ),
            ]
        ):
            ...

    :param model_or_qry: the model class or QuerySet to use for lookup
    :param lookup_field: the field to use for lookup, by default the primary key
    :param case_insensitive: whether to perform case insensitive lookups and
        completions, default: False
    :param help_field: the field to use for help output in completion suggestions,
        by default no help will be provided
    :param query: a callable that will be used to build the query for completions,
        by default the query will be reasonably determined by the field type
    :param limit: the maximum number of completions to return, default: 50
    :param distinct: whether to filter out already provided parameters in the
        completion suggestions, True by default
    :param on_error: a callable that will be called if the parser lookup fails
        to produce a matching object - by default a CommandError will be raised
    :param order_by: passed through unchanged to the completer as its ``order_by``
        argument; it is not given to the parser
    :param return_type: An enumeration switch to return either a model instance,
        queryset or model field value type.
    :return: a dict with the keys ``parser`` and ``shell_complete``
    """
    return {
        "parser": ModelObjectParser(
            # the parser is always built from the model class, even when a
            # QuerySet was given
            model_or_qry if inspect.isclass(model_or_qry) else model_or_qry.model,  # type: ignore
            lookup_field,
            case_insensitive=case_insensitive,
            on_error=on_error,
            return_type=return_type,
        ),
        "shell_complete": ModelObjectCompleter(
            model_or_qry,
            lookup_field,
            case_insensitive=case_insensitive,
            help_field=help_field,
            query=query,
            limit=limit,
            distinct=distinct,
            order_by=order_by,
        ),
    }
def install_traceback(tb_config: t.Optional[t.Dict[str, t.Any]] = None):
    """
    Install rich's traceback handler, unless use_rich_tracebacks() says not to,
    and point typer's fallback exception hook at the resulting ``sys.excepthook``.

    :param tb_config: keyword arguments for ``rich.traceback.install``; keys that
        function does not accept are ignored. A ``console`` entry, if present, is
        used as the console. When omitted or empty, traceback_config() is used.
        The mapping passed in is not modified.
    """
    from .config import use_rich_tracebacks

    if not use_rich_tracebacks():
        return

    import rich
    from rich import traceback
    from typer import main as typer_main

    # work on a copy - "console" is popped below and neither the caller's dict
    # nor whatever traceback_config() returned should be altered by that
    tb_config = dict(tb_config or traceback_config())

    # install rich tracebacks if we've been configured to do so (default)
    no_color = "NO_COLOR" in os.environ
    force_color = "FORCE_COLOR" in os.environ
    # only build our own console when one of the color variables is set,
    # otherwise leave the console selection to rich (None)
    default_console = (
        rich.console.Console(
            stderr=True,
            no_color=no_color,
            force_terminal=(
                False if no_color else force_color if force_color else None
            ),
        )
        if no_color or force_color
        else None
    )
    console = tb_config.pop("console", default_console)
    supported = set(inspect.signature(traceback.install).parameters.keys())
    traceback.install(
        console=console,
        **{param: value for param, value in tb_config.items() if param in supported},
    )
    # typer installs its own exception hook and it falls back to the sys hook -
    # depending on when typer was imported it may have the original fallback system hook
    # or our installed rich one - we patch it here to make sure!
    typer_main._original_except_hook = sys.excepthook