Files
OpenBB/cli/openbb_cli/argparse_translator/argparse_translator.py
montezdesousa cc15a770da [Feature] - Rename terminal folder (#6349)
* rename terminal folder

* docstring

* Rename to openbb_cli

* rename constant

* rename .his file

* pyproject.toml

* Update pre-commit and remove docker-build.yml
2024-04-30 12:42:31 +00:00

531 lines
18 KiB
Python

import argparse
import inspect
import re
from copy import deepcopy
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

from openbb_core.app.model.field import OpenBBField
from pydantic import BaseModel, model_validator
from typing_extensions import Annotated
# Separator placed between a parameter name and a nested model field name when
# a pydantic-model parameter is flattened into individual CLI arguments
# ("<param>__<field>"). `ArgparseTranslator._unflatten_args` splits on it to
# rebuild the nested dictionaries.
SEP = "__"
class ArgparseActionType(Enum):
    """Names of the argparse ``action`` values produced by this module.

    Each member's value is the literal string handed to
    ``ArgumentParser.add_argument(action=...)``.
    """

    # Regular option that consumes a value from the command line.
    store = "store"
    # Flag option: present on the command line means ``True``.
    store_true = "store_true"
class CustomArgument(BaseModel):
    """Description of one extra command-line argument.

    Apart from ``name``, the fields mirror the keyword arguments of
    ``argparse``'s ``add_argument``: ``ArgparseTranslator`` dumps the model
    (excluding ``name``) and passes the result straight to ``add_argument``.
    """

    # Argument name, used to build the ``--<name>`` option string.
    name: str
    # Callable used by argparse to convert the raw string (None for flags).
    type: Optional[Any]
    # Attribute name under which the parsed value is stored.
    dest: str
    # Value used when the argument is not supplied.
    default: Any
    # Whether the argument must be supplied.
    required: bool
    # argparse action; only plain values and boolean flags are supported.
    action: Literal["store_true", "store"]
    # Help text shown for the argument.
    help: str
    # "+" when the argument accepts one or more values, otherwise None.
    nargs: Optional[Literal["+"]]
    # Allowed values, if restricted.
    choices: Optional[Any]

    # NOTE: "after" validators receive the already-built model, so pydantic
    # documents them as instance methods; they run in definition order.
    @model_validator(mode="after")
    def validate_action(self) -> "CustomArgument":
        """Reject a boolean argument that is not declared as a flag.

        Raises:
            ValueError: If ``type`` is ``bool`` but ``action`` is not
                ``"store_true"``.
        """
        if self.type is bool and self.action != "store_true":
            raise ValueError('If type is bool, action must be "store_true"')
        return self

    @model_validator(mode="after")
    def remove_props_on_store_true(self) -> "CustomArgument":
        """Clear the fields that argparse does not accept for ``store_true``."""
        if self.action == "store_true":
            self.type = None
            self.nargs = None
            self.choices = None
        return self

    # override
    def model_dump(self, **kwargs):
        """Dump the model, dropping ``choices`` when it is empty.

        An empty ``choices`` (for example an empty tuple) would make argparse
        reject every value, so the key is removed from the result instead.
        """
        res = super().model_dump(**kwargs)
        if "choices" in res and not res["choices"]:
            del res["choices"]
        return res
class CustomArgumentGroup(BaseModel):
    """A named collection of custom arguments.

    ``ArgparseTranslator`` creates one argparse argument group per instance,
    titled with ``name``, and adds each entry of ``arguments`` to it.
    """

    # Title of the argparse argument group.
    name: str
    # Arguments that belong to the group.
    arguments: List[CustomArgument]
class ReferenceToCustomArgumentsProcessor:
    """Build per-route groups of custom arguments from a reference mapping.

    The reference is read as
    ``{route: {"parameters": {provider: [argument_dict, ...]}}}`` where each
    argument dict provides ``name``, ``type`` (a string), ``default``,
    ``optional`` and ``description``, plus an optional ``standard`` flag.
    The result is stored in ``custom_groups``, keyed by route.
    """

    def __init__(self, reference: Dict[str, Dict]):
        """Keep the reference and build ``custom_groups`` right away."""
        self.reference = reference
        self.custom_groups: Dict[str, List[CustomArgumentGroup]] = {}
        self.build_custom_groups()

    @staticmethod
    def _make_type_parsable(type_: str) -> type:
        """Convert a type's string representation into a usable type object.

        Unions mentioning ``str`` collapse to ``str`` and unions mentioning
        ``int`` collapse to ``int``; date/time names become ``str``; constrained
        types (``gt=``/``ge=``/``lt=``/``le=``) are cut down to their first
        component. Whatever string remains is evaluated.
        """
        if "Union" in type_:
            if "str" in type_:
                return str
            if "int" in type_:
                return int

        if type_ in ("date", "datetime.time", "time"):
            return str

        if any(marker in type_ for marker in ("gt=", "ge=", "lt=", "le=")):
            if "Annotated" in type_:
                type_ = type_.replace("Annotated[", "").replace("]", "")
            # keep only the base type, dropping the constraint metadata
            type_ = type_.split(",")[0]

        # NOTE(review): eval executes the string as Python code, so the
        # reference must come from a trusted source.
        return eval(type_)  # noqa: S307, E501 pylint: disable=eval-used

    def _parse_type(self, type_: str) -> type:
        """Return the type object for ``type_``; a Literal yields the type of its first value."""
        parsed = self._make_type_parsable(type_)
        if get_origin(parsed) is not Literal:
            return parsed
        return type(get_args(parsed)[0])

    def _get_nargs(self, type_: type) -> Optional[Union[int, str]]:
        """Return ``"+"`` for list types and ``None`` for everything else."""
        return "+" if get_origin(type_) is list else None

    def _get_choices(self, type_: str) -> Tuple:
        """Return the Literal values allowed by ``type_``, or an empty tuple.

        A Literal is looked for directly, inside a list, or inside an Optional.

        Raises:
            ValueError: If a union containing ``None`` has more than one
                other member.
        """
        parsed = self._make_type_parsable(type_)
        origin = get_origin(parsed)
        choices: Tuple = ()

        # The three origins are mutually exclusive, so one branch at most runs.
        if origin is Literal:
            choices = get_args(parsed)
        elif origin is list:
            element = get_args(parsed)[0]
            if get_origin(element) is Literal:
                choices = get_args(element)
        elif origin is Union and type(None) in get_args(parsed):
            # drop NoneType and look at what is left
            remaining = [
                member for member in get_args(parsed) if member != type(None)
            ]
            if len(remaining) > 1:
                raise ValueError("Union with NoneType should have only one type left")
            inner = remaining[0]
            if get_origin(inner) is Literal:
                choices = get_args(inner)

        return choices

    def _build_custom_argument(self, arg: Dict) -> "CustomArgument":
        """Create the ``CustomArgument`` described by one reference entry."""
        parsed_type = self._parse_type(arg["type"])
        return CustomArgument(
            name=arg["name"],
            type=parsed_type,
            dest=arg["name"],
            default=arg["default"],
            required=not (arg["optional"]),
            action="store_true" if parsed_type == bool else "store",
            help=arg["description"],
            nargs=self._get_nargs(parsed_type),  # type: ignore
            choices=self._get_choices(arg["type"]),
        )

    def build_custom_groups(self):
        """Populate ``custom_groups`` with one group per non-standard provider.

        The ``"standard"`` provider and arguments flagged as ``standard`` are
        skipped.
        """
        for route, route_info in self.reference.items():
            for provider, provider_args in route_info["parameters"].items():
                if provider == "standard":
                    continue

                arguments = [
                    self._build_custom_argument(arg)
                    for arg in provider_args
                    if not arg.get("standard")
                ]
                provider_group = CustomArgumentGroup(
                    name=provider, arguments=arguments
                )
                self.custom_groups.setdefault(route, []).append(provider_group)
class ArgparseTranslator:
    """Expose a type-annotated function as an argparse command-line program.

    Every annotated parameter of the function becomes a ``--<name>`` option.
    Parameters whose type is a pydantic model are flattened into one option
    per model field, named ``<param><SEP><field>``, and rebuilt into model
    instances before the function is called. Additional argument groups can
    be appended through ``custom_argument_groups``.
    """

    def __init__(
        self,
        func: Callable,
        custom_argument_groups: Optional[List[CustomArgumentGroup]] = None,
        add_help: Optional[bool] = True,
    ):
        """
        Initializes the ArgparseTranslator.

        Args:
            func (Callable): The function to translate into an argparse program.
            custom_argument_groups (Optional[List[CustomArgumentGroup]], optional):
                Extra argument groups to add to the parser. An argument whose
                option string is already defined is skipped. Defaults to None.
            add_help (Optional[bool], optional): Whether to add the help argument.
                Defaults to True.
        """
        self.func = func
        self.signature = inspect.signature(func)
        self.type_hints = get_type_hints(func)
        # names of the custom arguments that were actually added to the parser
        self.provider_parameters: List[str] = []

        self._parser = argparse.ArgumentParser(
            prog=func.__name__,
            description=self._build_description(func.__doc__),
            formatter_class=argparse.RawTextHelpFormatter,
            add_help=bool(add_help),
        )
        self._required = self._parser.add_argument_group("required arguments")

        if any(param in self.type_hints for param in self.signature.parameters):
            self._generate_argparse_arguments(self.signature.parameters)

        if custom_argument_groups:
            # Scan the parser once and keep the set up to date while adding,
            # instead of re-scanning every group for every argument.
            options_in_use = set(self._parser_arguments())
            for group in custom_argument_groups:
                argparse_group = self._parser.add_argument_group(group.name)
                for argument in group.arguments:
                    option = f"--{argument.name}"
                    # If the argument is already in use, we can't repeat it
                    if option in options_in_use:
                        continue
                    kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
                    argparse_group.add_argument(option, **kwargs)
                    options_in_use.add(option)
                    self.provider_parameters.append(argument.name)

    def _parser_arguments(self) -> List[str]:
        """Get all the option strings from all groups currently defined on the parser."""
        # pylint: disable=protected-access
        return [
            option
            for action_group in self._parser._action_groups
            for action in action_group._group_actions
            for option in action.option_strings
        ]

    @property
    def parser(self) -> argparse.ArgumentParser:
        """Return a deep copy of the parser, so callers cannot alter the original."""
        return deepcopy(self._parser)

    @staticmethod
    def _build_description(func_doc: Optional[str]) -> Optional[str]:
        """Builds the description of the argparse program from the function docstring.

        The docstring is cut at the first section heading found, trying the
        ``openbb`` heading before the ``Parameters`` heading. The underline of
        a heading may be indented by any number of spaces or tabs. An empty or
        missing docstring is returned unchanged.
        """
        heading_patterns = (
            r"openbb\n[ \t]*======",
            r"Parameters\n[ \t]*----------",
        )
        if func_doc:
            for heading in heading_patterns:
                match = re.search(heading, func_doc)
                if match:
                    func_doc = func_doc[: match.start()].strip()
                    break
        return func_doc

    @staticmethod
    def _param_is_default(param: inspect.Parameter) -> bool:
        """Returns True if the parameter has a default value."""
        # ``empty`` is a sentinel: compare by identity so that defaults with a
        # custom ``__ne__`` cannot distort the result.
        return param.default is not inspect.Parameter.empty

    def _get_action_type(self, param: inspect.Parameter) -> str:
        """Returns the argparse action type for the given parameter."""
        if self.type_hints[param.name] is bool:
            return ArgparseActionType.store_true.value
        return ArgparseActionType.store.value

    def _get_type_and_choices(
        self, param: inspect.Parameter
    ) -> Tuple[Type[Any], Tuple[Any, ...]]:
        """Returns the type and choices for the given parameter.

        ``Literal`` values become the choices and the type of the first value
        becomes the type. ``List[X]`` and ``Optional[X]`` are unwrapped to
        ``X``; a union containing ``str`` is treated as ``str``. Choices
        declared through an ``OpenBBField`` annotation take precedence.

        Raises:
            ValueError: If a union containing ``None`` (and no ``str``) has
                more than one other member.
        """
        param_type = self.type_hints[param.name]
        type_origin = get_origin(param_type)
        choices: Tuple[Any, ...] = ()

        if type_origin is Literal:
            choices = get_args(param_type)
            param_type = type(choices[0])  # type: ignore

        if type_origin is list:  # TODO: dict should also go here
            param_type = get_args(param_type)[0]

            if get_origin(param_type) is Literal:
                choices = get_args(param_type)
                param_type = type(choices[0])  # type: ignore

        if type_origin is Union:
            union_args = get_args(param_type)
            if str in union_args:
                param_type = str

            # check if it's an Optional, which would be a Union with NoneType
            if type(None) in get_args(param_type):
                # remove NoneType from the args
                args = [arg for arg in get_args(param_type) if arg is not type(None)]
                # if there is only one arg left, use it
                if len(args) > 1:
                    raise ValueError(
                        "Union with NoneType should have only one type left"
                    )
                param_type = args[0]

                # Optional[List[X]]: `_get_nargs` already reports "+", so the
                # converter must be the element type, exactly as for a bare
                # List[X]; the list alias itself is not a usable converter.
                if get_origin(param_type) is list and get_args(param_type):
                    param_type = get_args(param_type)[0]

                if get_origin(param_type) is Literal:
                    choices = get_args(param_type)
                    param_type = type(choices[0])  # type: ignore

        # if there are custom choices, override
        choices = self._get_argument_custom_choices(param) or choices  # type: ignore

        return param_type, choices

    @staticmethod
    def _split_annotation(
        base_annotation: Type[Any], custom_annotation_type: Type
    ) -> Tuple[Type[Any], List[Any]]:
        """Find the base annotation and the custom annotations, namely the OpenBBField.

        Returns the annotation unchanged with an empty list when it is not
        ``Annotated``; otherwise the wrapped type and the metadata entries that
        are instances of ``custom_annotation_type``.
        """
        if get_origin(base_annotation) is not Annotated:
            return base_annotation, []
        base_annotation, *maybe_custom_annotations = get_args(base_annotation)
        return base_annotation, [
            annotation
            for annotation in maybe_custom_annotations
            if isinstance(annotation, custom_annotation_type)
        ]

    @classmethod
    def _get_argument_custom_help(cls, param: inspect.Parameter) -> Optional[str]:
        """Returns the help annotation for the given parameter.

        The text is the ``description`` of the first ``OpenBBField`` found in
        the parameter annotation, or None when there is none.
        """
        _, custom_annotations = cls._split_annotation(param.annotation, OpenBBField)
        # TODO: fall back to the docstring when no description is annotated
        return custom_annotations[0].description if custom_annotations else None

    @classmethod
    def _get_argument_custom_choices(cls, param: inspect.Parameter) -> Optional[Any]:
        """Returns the choices annotation for the given parameter.

        The value is the ``choices`` of the first ``OpenBBField`` found in the
        parameter annotation, or None when there is none.
        """
        _, custom_annotations = cls._split_annotation(param.annotation, OpenBBField)
        # TODO: fall back to the docstring when no choices are annotated
        return custom_annotations[0].choices if custom_annotations else None

    def _get_nargs(self, param: inspect.Parameter) -> Optional[str]:
        """Returns the nargs annotation for the given parameter.

        ``"+"`` is returned for a list type, or a union with a list member;
        None otherwise.
        """
        param_type = self.type_hints[param.name]
        origin = get_origin(param_type)

        if origin is list:
            return "+"

        if origin is Union and any(
            get_origin(arg) is list for arg in get_args(param_type)
        ):
            return "+"

        return None

    def _generate_argparse_arguments(self, parameters) -> None:
        """Generates the argparse arguments from the function parameters.

        Parameters typed as a pydantic model are not added themselves; their
        fields are added recursively as ``<param><SEP><field>`` options.
        Parameters without a default go to the "required arguments" group.
        """
        for param in parameters.values():
            # TODO : how to handle kwargs?
            # it's possible to add unknown arguments when parsing as follows:
            # args, unknown_args = parser.parse_known_args()
            if param.name == "kwargs":
                continue

            param_type, choices = self._get_type_and_choices(param)

            # if the param is a custom type, we need to flatten it
            if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
                # register the field type hints, prefixed with the param name
                self.type_hints.update(
                    {
                        f"{param.name}{SEP}{key}": value
                        for key, value in get_type_hints(param_type).items()
                    }
                )

                # create a signature from the custom type
                sig = inspect.signature(param_type)
                # the schema does not change between fields: compute it once
                properties = param_type.model_json_schema().get("properties", {})

                # rename each field and attach its description as help
                annotated_parameters: List[inspect.Parameter] = [
                    child_param.replace(
                        name=f"{param.name}{SEP}{child_param.name}",
                        annotation=Annotated[
                            child_param.annotation,
                            OpenBBField(
                                description=properties[child_param.name].get(
                                    "description", None
                                )
                            ),
                        ],
                        kind=inspect.Parameter.KEYWORD_ONLY,
                    )
                    for child_param in sig.parameters.values()
                ]

                new_signature = inspect.Signature(
                    parameters=annotated_parameters,
                    return_annotation=sig.return_annotation,
                )
                self._generate_argparse_arguments(new_signature.parameters)

                # the custom type itself should not be added as an argument
                continue

            required = not self._param_is_default(param)

            kwargs = {
                "type": param_type,
                "dest": param.name,
                "default": param.default,
                "required": required,
                "action": self._get_action_type(param),
                "help": self._get_argument_custom_help(param),
                "nargs": self._get_nargs(param),
            }

            if choices:
                kwargs["choices"] = choices

            if param_type == bool:
                # store_true action does not accept the below kwargs
                kwargs.pop("type")
                kwargs.pop("nargs")

            target = self._required if required else self._parser
            target.add_argument(f"--{param.name}", **kwargs)

    @staticmethod
    def _unflatten_args(args: dict) -> Dict[str, Any]:
        """Unflatten the args that were flattened by the custom types.

        A key such as ``a<SEP>b<SEP>c`` ends up as ``result["a"]["b"]["c"]``;
        keys without the separator are copied as they are.
        """
        result: Dict[str, Any] = {}
        for key, value in args.items():
            *parents, leaf = key.split(SEP)
            node = result
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        return result

    def _update_with_custom_types(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Update the kwargs with the custom types.

        For each signature parameter typed as a pydantic model, the nested
        dictionary stored under its name is replaced with a model instance.
        """
        for param in self.signature.parameters.values():
            # TODO : how to handle kwargs?
            if param.name == "kwargs":
                continue
            param_type, _ = self._get_type_and_choices(param)
            if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
                custom_type_kwargs = kwargs[param.name]
                kwargs[param.name] = param_type(**custom_type_kwargs)

        return kwargs

    def execute_func(
        self,
        parsed_args: Optional[argparse.Namespace] = None,
    ) -> Any:
        """
        Executes the original function with the parsed arguments.

        Args:
            parsed_args (Optional[argparse.Namespace], optional): The parsed arguments.
                When None, the arguments are parsed from the command line.
                Defaults to None.

        Returns:
            Any: The return value of the original function.
        """
        if parsed_args is None:
            # the documented default used to crash in ``vars(None)``
            parsed_args = self._parser.parse_args()

        kwargs = self._unflatten_args(vars(parsed_args))
        kwargs = self._update_with_custom_types(kwargs)

        # remove kwargs that doesn't match the signature or provider parameters
        kwargs = {
            key: value
            for key, value in kwargs.items()
            if key in self.signature.parameters or key in self.provider_parameters
        }

        return self.func(**kwargs)

    def parse_args_and_execute(self) -> Any:
        """
        Parses the arguments and executes the original function.

        Returns:
            Any: The return value of the original function.
        """
        parsed_args = self._parser.parse_args()
        return self.execute_func(parsed_args)

    def translate(self) -> Callable:
        """
        Wraps the original function with an argparse program.

        Returns:
            Callable: The original function wrapped with an argparse program.
        """

        def wrapper_func():
            return self.parse_args_and_execute()

        return wrapper_func