mirror of
https://github.com/OpenBB-finance/OpenBB.git
synced 2026-05-31 23:13:26 +08:00
* rename terminal folder * docstring * Rename to openbb_cli * rename constant * rename .his file * pyproject.toml * Update pre-commit and remove docker-build.yml
531 lines
18 KiB
Python
531 lines
18 KiB
Python
import argparse
|
|
import inspect
|
|
from copy import deepcopy
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
get_args,
|
|
get_origin,
|
|
get_type_hints,
|
|
)
|
|
|
|
from openbb_core.app.model.field import OpenBBField
|
|
from pydantic import BaseModel, model_validator
|
|
from typing_extensions import Annotated
|
|
|
|
# Joins a parent parameter name and a nested field name when a parameter typed
# as a pydantic model is flattened into individual CLI options
# (``param__field``); the same marker is used to split them back apart.
SEP: str = "__"
|
|
|
|
|
|
class ArgparseActionType(Enum):
    """The argparse ``action`` names this module hands to ``add_argument``.

    A member's ``value`` is the exact string passed as ``action=...``.
    """

    # Consume one value from the command line and store it.
    store = "store"
    # Boolean flag: the option takes no value and stores True when present.
    store_true = "store_true"
|
|
|
|
|
|
class CustomArgument(BaseModel):
    """Keyword arguments describing one provider-specific ``argparse`` option.

    ``name`` becomes the ``--name`` flag; the remaining fields mirror the
    keyword arguments of ``ArgumentParser.add_argument`` and are obtained
    through :meth:`model_dump`.
    """

    name: str
    type: Optional[Any]
    dest: str
    default: Any
    required: bool
    action: Literal["store_true", "store"]
    help: str
    nargs: Optional[Literal["+"]]
    choices: Optional[Any]

    # ``mode="after"`` validators receive the already-validated instance, so
    # they are written as instance methods -- the form documented by pydantic
    # v2 (stacking ``@classmethod`` here is deprecated in recent releases).
    @model_validator(mode="after")
    def validate_action(self) -> "CustomArgument":
        """Require boolean arguments to be declared as ``store_true`` flags.

        Raises:
            ValueError: If ``type`` is ``bool`` but ``action`` is not
                ``"store_true"`` (pydantic surfaces it as a ValidationError).
        """
        if self.type is bool and self.action != "store_true":
            raise ValueError('If type is bool, action must be "store_true"')
        return self

    @model_validator(mode="after")
    def remove_props_on_store_true(self) -> "CustomArgument":
        """Reset ``type``, ``nargs`` and ``choices`` for ``store_true`` flags.

        argparse's ``store_true`` action does not accept these keywords; once
        they are ``None`` a ``model_dump(exclude_none=True)`` omits them.
        """
        if self.action == "store_true":
            self.type = None
            self.nargs = None
            self.choices = None
        return self

    # override
    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Dump the model like ``BaseModel.model_dump``, dropping empty ``choices``.

        argparse rejects every value when ``choices`` is an empty container,
        so an empty tuple must never reach ``add_argument``.
        """
        res = super().model_dump(**kwargs)

        # Check if choices is present and if it's an empty tuple remove it
        if "choices" in res and not res["choices"]:
            del res["choices"]

        return res
|
|
|
|
|
|
class CustomArgumentGroup(BaseModel):
    """A titled collection of ``CustomArgument`` definitions.

    ``name`` is used as the title of the argparse argument group the
    arguments are added to.
    """

    # Group title; the provider name when the group is built from a reference.
    name: str
    # The group's arguments, kept in declaration order.
    arguments: List[CustomArgument]
|
|
|
|
|
|
class ReferenceToCustomArgumentsProcessor:
    """Turn the provider parameters of a command reference into argument groups.

    ``reference`` maps each route to a dict whose ``"parameters"`` entry maps
    a provider name to a list of argument specs. Each spec is read for the
    keys ``name``, ``type`` (a type expression as a string, e.g.
    ``"Literal['a', 'b']"``), ``default``, ``optional``, ``description`` and
    an optional ``standard`` flag. The resulting groups are stored in
    ``custom_groups`` keyed by route.
    """

    def __init__(self, reference: Dict[str, Dict]):
        """Initialize the processor and build the groups for ``reference``."""
        self.reference = reference
        self.custom_groups: Dict[str, List[CustomArgumentGroup]] = {}

        self.build_custom_groups()

    @staticmethod
    def _make_type_parsable(type_: str) -> type:
        """Turn a reference type string into a runtime type.

        Unions mentioning ``str`` collapse to ``str`` (otherwise unions
        mentioning ``int`` collapse to ``int``) and date/time names become
        ``str``. For constrained types (``gt=``, ``ge=``, ``lt=``, ``le=``)
        the ``Annotated`` wrapper and constraint metadata are stripped so only
        the base type name remains. What is left is evaluated in this
        module's namespace.
        """
        if "Union" in type_ and "str" in type_:
            return str
        if "Union" in type_ and "int" in type_:
            return int
        if type_ in ["date", "datetime.time", "time"]:
            return str

        if any(x in type_ for x in ["gt=", "ge=", "lt=", "le="]):
            if "Annotated" in type_:
                type_ = type_.replace("Annotated[", "").replace("]", "")
            type_ = type_.split(",")[0]

        # NOTE(review): ``eval`` executes the type string verbatim. That is only
        # acceptable if the reference is trusted (presumably generated by
        # OpenBB itself -- confirm); never route user-controlled text here.
        return eval(type_)  # noqa: S307, E501 pylint: disable=eval-used

    def _parse_type(self, type_: str) -> type:
        """Return the per-token ``type`` callable argparse should use.

        ``List[X]`` is reduced to ``X``: argparse applies ``type`` to every
        token (the tokens are collected through ``nargs``) and ``typing.List``
        itself cannot be instantiated. ``Literal[...]`` is reduced to the
        type of its first value.
        """
        parsed = self._make_type_parsable(type_)

        if get_origin(parsed) is list:
            element_types = get_args(parsed)
            # a bare ``List`` carries no element type; argparse's default is str
            parsed = element_types[0] if element_types else str

        if get_origin(parsed) is Literal:
            parsed = type(get_args(parsed)[0])

        return parsed  # type: ignore

    def _get_nargs(self, type_: type) -> Optional[Union[int, str]]:
        """Return ``"+"`` for list types (one or more values), else ``None``."""
        if get_origin(type_) is list:
            return "+"
        return None

    def _get_choices(self, type_: str) -> Tuple:
        """Return the allowed values encoded in ``type_`` as a ``Literal``.

        ``Literal[...]``, ``List[Literal[...]]`` and ``Optional[Literal[...]]``
        are recognised; any other type yields an empty tuple.

        Raises:
            ValueError: If ``type_`` is a union with ``None`` that has more
                than one other member.
        """
        parsed = self._make_type_parsable(type_)
        origin = get_origin(parsed)

        if origin is list:
            element_types = get_args(parsed)
            if element_types:
                parsed = element_types[0]
        elif origin is Union and type(None) in get_args(parsed):
            # remove NoneType from the args; exactly one type may remain
            non_none = [arg for arg in get_args(parsed) if arg is not type(None)]
            if len(non_none) > 1:
                raise ValueError("Union with NoneType should have only one type left")
            parsed = non_none[0]

        if get_origin(parsed) is Literal:
            return get_args(parsed)
        return ()

    def build_custom_groups(self):
        """Populate ``self.custom_groups`` with one group per non-standard provider.

        The ``"standard"`` provider and every argument whose spec has a truthy
        ``standard`` entry are skipped.
        """
        for route, route_info in self.reference.items():
            for provider, args in route_info["parameters"].items():
                if provider == "standard":
                    continue

                custom_arguments = []
                for arg in args:
                    if arg.get("standard"):
                        continue

                    type_ = self._parse_type(arg["type"])
                    # nargs must be derived from the full type: ``_parse_type``
                    # has already reduced ``List[X]`` to ``X``.
                    nargs = self._get_nargs(self._make_type_parsable(arg["type"]))

                    custom_arguments.append(
                        CustomArgument(
                            name=arg["name"],
                            type=type_,
                            dest=arg["name"],
                            default=arg["default"],
                            required=not arg["optional"],
                            action="store_true" if type_ is bool else "store",
                            help=arg["description"],
                            nargs=nargs,  # type: ignore
                            choices=self._get_choices(arg["type"]),
                        )
                    )

                group = CustomArgumentGroup(name=provider, arguments=custom_arguments)
                self.custom_groups.setdefault(route, []).append(group)
|
|
|
|
|
|
class ArgparseTranslator:
    """Expose a Python callable as an ``argparse`` command-line program.

    Each annotated parameter of the callable becomes a ``--<param>`` option.
    Parameters whose resolved type is a pydantic model are flattened into
    ``--<param>__<field>`` options and rebuilt into the model before the call.
    Extra ``CustomArgumentGroup`` objects can contribute further options.
    """

    def __init__(
        self,
        func: Callable,
        custom_argument_groups: Optional[List[CustomArgumentGroup]] = None,
        add_help: Optional[bool] = True,
    ):
        """
        Initializes the ArgparseTranslator.

        Args:
            func (Callable): The function to translate into an argparse program.
            custom_argument_groups (Optional[List[CustomArgumentGroup]], optional):
                Extra argument groups to add; an argument whose ``--name`` flag
                is already defined on the parser is skipped. Defaults to None.
            add_help (Optional[bool], optional): Whether to add the help argument.
                Defaults to True.
        """
        self.func = func
        self.signature = inspect.signature(func)
        self.type_hints = get_type_hints(func)
        # Names of custom-group arguments actually added to the parser;
        # execute_func forwards these in addition to the signature parameters.
        self.provider_parameters: List[str] = []

        self._parser = argparse.ArgumentParser(
            prog=func.__name__,
            description=self._build_description(func.__doc__),  # type: ignore
            formatter_class=argparse.RawTextHelpFormatter,
            add_help=bool(add_help),
        )
        self._required = self._parser.add_argument_group("required arguments")

        if any(param in self.type_hints for param in self.signature.parameters):
            self._generate_argparse_arguments(self.signature.parameters)

        for group in custom_argument_groups or []:
            argparse_group = self._parser.add_argument_group(group.name)
            for argument in group.arguments:
                # If the argument is already in use, we can't repeat it
                if f"--{argument.name}" in self._parser_arguments():
                    continue
                kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
                argparse_group.add_argument(f"--{argument.name}", **kwargs)
                self.provider_parameters.append(argument.name)

    def _parser_arguments(self) -> List[str]:
        """Get all the arguments from all groups currently defined on the parser."""
        arguments_in_use: List[str] = []

        # pylint: disable=protected-access
        for action_group in self._parser._action_groups:
            for action in action_group._group_actions:
                arguments_in_use.extend(action.option_strings)

        return arguments_in_use

    @property
    def parser(self) -> argparse.ArgumentParser:
        """Return a deep copy of the parser, leaving the internal one untouched."""
        return deepcopy(self._parser)

    @staticmethod
    def _build_description(func_doc: str) -> str:
        """Builds the description of the argparse program from the function docstring.

        Only the text before the first marker found (checked in order) is kept.
        """
        # NOTE(review): these markers must match the docstring layout exactly,
        # including the whitespace before the underline -- confirm against the
        # generated command docstrings.
        patterns = ["openbb\n ======", "Parameters\n ----------"]

        if func_doc:
            for pattern in patterns:
                if pattern in func_doc:
                    func_doc = func_doc[: func_doc.index(pattern)].strip()
                    break

        return func_doc

    @staticmethod
    def _param_is_default(param: inspect.Parameter) -> bool:
        """Returns True if the parameter has a default value."""
        # identity check: ``!=`` would invoke the default's own __ne__, which
        # is ambiguous for array-like defaults
        return param.default is not inspect.Parameter.empty

    def _get_action_type(self, param: inspect.Parameter) -> str:
        """Returns the argparse action type for the given parameter.

        ``store_true`` is used when the *resolved* type is ``bool`` (so
        ``Optional[bool]`` is a flag as well), matching the bool handling in
        ``_generate_argparse_arguments``.
        """
        param_type, _ = self._get_type_and_choices(param)

        if param_type is bool:
            return ArgparseActionType.store_true.value
        return ArgparseActionType.store.value

    def _get_type_and_choices(
        self, param: inspect.Parameter
    ) -> Tuple[Type[Any], Tuple[Any, ...]]:
        """Returns the argparse ``type`` and ``choices`` for the given parameter.

        The type hint is reduced to a per-token type:

        * a ``Union`` containing ``str`` becomes ``str``;
        * ``Optional[X]`` becomes ``X``;
        * ``List[X]`` becomes ``X`` (the values are collected via ``nargs``);
        * ``Literal[...]`` becomes the type of its first value, and its
          values become the choices.

        Choices declared on an ``OpenBBField`` annotation take precedence.

        Raises:
            ValueError: If an optional union (without ``str``) has more than
                one non-None member.
        """
        param_type = self.type_hints[param.name]
        choices: Tuple[Any, ...] = ()

        if get_origin(param_type) is Union:
            union_args = get_args(param_type)
            if str in union_args:
                param_type = str
            elif type(None) in union_args:
                # Optional[X]: remove NoneType, exactly one type may remain
                args = [arg for arg in union_args if arg is not type(None)]
                if len(args) > 1:
                    raise ValueError(
                        "Union with NoneType should have only one type left"
                    )
                param_type = args[0]

        if get_origin(param_type) is list:  # TODO: dict should also go here
            element_types = get_args(param_type)
            # a bare ``List`` has no element type; argparse's default is str
            param_type = element_types[0] if element_types else str

        if get_origin(param_type) is Literal:
            choices = get_args(param_type)
            param_type = type(choices[0])  # type: ignore

        # if there are custom choices, override
        choices = self._get_argument_custom_choices(param) or choices  # type: ignore

        return param_type, choices

    @staticmethod
    def _split_annotation(
        base_annotation: Type[Any], custom_annotation_type: Type
    ) -> Tuple[Type[Any], List[Any]]:
        """Find the base annotation and the custom annotations, namely the OpenBBField."""
        if get_origin(base_annotation) is not Annotated:
            return base_annotation, []
        base_annotation, *maybe_custom_annotations = get_args(base_annotation)
        return base_annotation, [
            annotation
            for annotation in maybe_custom_annotations
            if isinstance(annotation, custom_annotation_type)
        ]

    @classmethod
    def _get_argument_custom_help(cls, param: inspect.Parameter) -> Optional[str]:
        """Returns the ``description`` of the parameter's first OpenBBField, if any."""
        _, custom_annotations = cls._split_annotation(param.annotation, OpenBBField)
        return custom_annotations[0].description if custom_annotations else None

    @classmethod
    def _get_argument_custom_choices(cls, param: inspect.Parameter) -> Optional[Any]:
        """Returns the ``choices`` of the parameter's first OpenBBField, if any."""
        _, custom_annotations = cls._split_annotation(param.annotation, OpenBBField)
        return custom_annotations[0].choices if custom_annotations else None

    def _get_nargs(self, param: inspect.Parameter) -> Optional[str]:
        """Returns ``"+"`` when the hint is a list (or a union containing one)."""
        param_type = self.type_hints[param.name]
        origin = get_origin(param_type)

        if origin is list:
            return "+"

        if origin is Union and any(
            get_origin(arg) is list for arg in get_args(param_type)
        ):
            return "+"

        return None

    def _generate_argparse_arguments(self, parameters) -> None:
        """Generates the argparse arguments from the function parameters.

        Parameters whose resolved type is a pydantic model get no option of
        their own; their fields are flattened instead.
        """
        for param in parameters.values():
            # TODO : how to handle kwargs?
            # it's possible to add unknown arguments when parsing as follows:
            # args, unknown_args = parser.parse_known_args()
            if param.name == "kwargs":
                continue

            param_type, choices = self._get_type_and_choices(param)

            # if the param is a custom type, we need to flatten it
            if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
                self._flatten_model_parameter(param, param_type)
                continue

            required = not self._param_is_default(param)

            kwargs: Dict[str, Any] = {
                "type": param_type,
                "dest": param.name,
                "default": param.default,
                "required": required,
                "action": self._get_action_type(param),
                "help": self._get_argument_custom_help(param),
                "nargs": self._get_nargs(param),
            }

            if choices:
                kwargs["choices"] = choices

            if param_type is bool:
                # store_true action does not accept the below kwargs
                kwargs.pop("type")
                kwargs.pop("nargs")
                kwargs.pop("choices", None)

            target = self._required if required else self._parser
            target.add_argument(f"--{param.name}", **kwargs)

    def _flatten_model_parameter(
        self, param: inspect.Parameter, model: Type[BaseModel]
    ) -> None:
        """Add one ``<param>__<field>`` option per field of ``model``.

        The prefixed field hints are registered in ``self.type_hints`` and the
        fields are processed recursively as keyword-only parameters, each
        annotated with an OpenBBField carrying the field's JSON-schema
        description.
        """
        # update type hints with the custom type fields, prefixed with the param name
        self.type_hints.update(
            {
                f"{param.name}{SEP}{key}": value
                for key, value in get_type_hints(model).items()
            }
        )

        # computed once instead of once per field
        properties = model.model_json_schema().get("properties", {})
        sig = inspect.signature(model)

        annotated_parameters: List[inspect.Parameter] = [
            child_param.replace(
                name=f"{param.name}{SEP}{child_param.name}",
                annotation=Annotated[
                    child_param.annotation,
                    OpenBBField(
                        description=properties.get(child_param.name, {}).get(
                            "description", None
                        )
                    ),
                ],
                kind=inspect.Parameter.KEYWORD_ONLY,
            )
            for child_param in sig.parameters.values()
        ]

        new_signature = inspect.Signature(
            parameters=annotated_parameters,
            return_annotation=sig.return_annotation,
        )
        self._generate_argparse_arguments(new_signature.parameters)

    @staticmethod
    def _unflatten_args(args: dict) -> Dict[str, Any]:
        """Unflatten the args that were flattened by the custom types.

        ``{"a__b": 1, "c": 2}`` becomes ``{"a": {"b": 1}, "c": 2}``.
        """
        result: Dict[str, Any] = {}
        for key, value in args.items():
            *parents, leaf = key.split(SEP)
            nested_dict = result
            for part in parents:
                nested_dict = nested_dict.setdefault(part, {})
            nested_dict[leaf] = value
        return result

    def _update_with_custom_types(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Replace the nested dicts of model-typed parameters with model instances."""
        for param in self.signature.parameters.values():
            # TODO : how to handle kwargs?
            if param.name == "kwargs":
                continue
            param_type, _ = self._get_type_and_choices(param)
            if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
                # a model without fields produced no flattened options at all
                custom_type_kwargs = kwargs.get(param.name, {})
                kwargs[param.name] = param_type(**custom_type_kwargs)

        return kwargs

    def execute_func(
        self,
        parsed_args: Optional[argparse.Namespace] = None,
    ) -> Any:
        """
        Executes the original function with the parsed arguments.

        Args:
            parsed_args (Optional[argparse.Namespace], optional): The parsed arguments.
                When None, the command line is parsed, as ``parse_args(None)``
                does in argparse. Defaults to None.

        Returns:
            Any: The return value of the original function.
        """
        if parsed_args is None:
            parsed_args = self._parser.parse_args()

        kwargs = self._unflatten_args(vars(parsed_args))
        kwargs = self._update_with_custom_types(kwargs)

        # remove kwargs that doesn't match the signature or provider parameters
        kwargs = {
            key: value
            for key, value in kwargs.items()
            if key in self.signature.parameters or key in self.provider_parameters
        }

        return self.func(**kwargs)

    def parse_args_and_execute(self) -> Any:
        """
        Parses the arguments and executes the original function.

        Returns:
            Any: The return value of the original function.
        """
        parsed_args = self._parser.parse_args()
        return self.execute_func(parsed_args)

    def translate(self) -> Callable:
        """
        Wraps the original function with an argparse program.

        Returns:
            Callable: The original function wrapped with an argparse program.
        """

        def wrapper_func():
            return self.parse_args_and_execute()

        return wrapper_func
|