mirror of
https://github.com/OpenBB-finance/OpenBB.git
synced 2026-05-06 22:12:12 +08:00
[Feature] openbb-cookiecutter: Enrich Input Arguments to Filter by Extension Type (#7432)
* enhance cookiecutter input to filter by extension type and only generate what was selected * missed a file * pyproject and lock
This commit is contained in:
1
cookiecutter/cookiecutter.json
vendored
1
cookiecutter/cookiecutter.json
vendored
@@ -4,5 +4,6 @@
|
||||
"project_name": "Super Quant",
|
||||
"project_tag": "{{ cookiecutter.project_name.lower().replace(' ', '-') }}",
|
||||
"package_name": "{{ cookiecutter.project_name.lower().replace(' ', '_') }}",
|
||||
"extension_types": "router",
|
||||
"_template": "{% now 'utc', '%Y%m%d%H%M%S' %}"
|
||||
}
|
||||
|
||||
@@ -6,9 +6,81 @@ import argparse
|
||||
import sys
|
||||
|
||||
from cookiecutter.main import cookiecutter
|
||||
from rich.prompt import Prompt
|
||||
|
||||
from . import get_template_path
|
||||
|
||||
VALID_EXTENSION_TYPES = [
|
||||
"router",
|
||||
"provider",
|
||||
"obbject",
|
||||
"on_command_output",
|
||||
"charting",
|
||||
"all",
|
||||
]
|
||||
|
||||
|
||||
def _parse_extension_types(value: str) -> list[str]:
|
||||
types = [t.strip() for t in value.split(",") if t.strip()]
|
||||
invalid = [t for t in types if t not in VALID_EXTENSION_TYPES]
|
||||
if invalid:
|
||||
raise ValueError(
|
||||
f"Invalid extension type(s): {', '.join(invalid)}. "
|
||||
f"Valid choices: {', '.join(VALID_EXTENSION_TYPES)}"
|
||||
)
|
||||
if not types:
|
||||
raise ValueError("At least one extension type must be selected.")
|
||||
return types
|
||||
|
||||
|
||||
def _prompt_context(preset_extension_types: list[str] | None = None) -> dict:
|
||||
context = {}
|
||||
|
||||
context["full_name"] = Prompt.ask(" full_name", default="Hello World")
|
||||
context["email"] = Prompt.ask(" email", default="hello@world.com")
|
||||
context["project_name"] = Prompt.ask(
|
||||
" project_name", default="OpenBB Python Extension Template"
|
||||
)
|
||||
default_tag = context["project_name"].lower().replace(" ", "-").replace("_", "-")
|
||||
context["project_tag"] = Prompt.ask(" project_tag", default=default_tag)
|
||||
default_pkg = context["project_name"].lower().replace(" ", "_").replace("-", "_")
|
||||
context["package_name"] = Prompt.ask(" package_name", default=default_pkg)
|
||||
|
||||
if preset_extension_types:
|
||||
types = preset_extension_types
|
||||
else:
|
||||
while True:
|
||||
raw = Prompt.ask(
|
||||
" extension_types"
|
||||
" - router | provider | obbject | on_command_output | charting | all",
|
||||
default="router",
|
||||
)
|
||||
try:
|
||||
types = _parse_extension_types(raw)
|
||||
break
|
||||
except ValueError as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
context["extension_types"] = ",".join(types)
|
||||
is_all = "all" in types
|
||||
|
||||
if is_all or "provider" in types:
|
||||
context["provider_name"] = Prompt.ask(" provider_name", default="template")
|
||||
else:
|
||||
context["provider_name"] = "template"
|
||||
|
||||
if is_all or "router" in types or "charting" in types:
|
||||
context["router_name"] = Prompt.ask(" router_name", default="template")
|
||||
else:
|
||||
context["router_name"] = "template"
|
||||
|
||||
if is_all or "obbject" in types or "on_command_output" in types:
|
||||
context["obbject_name"] = Prompt.ask(" obbject_name", default="template")
|
||||
else:
|
||||
context["obbject_name"] = "template"
|
||||
|
||||
return context
|
||||
|
||||
|
||||
def main(argv: list | None = None) -> int:
|
||||
"""Run the OpenBB cookiecutter template.
|
||||
@@ -42,10 +114,18 @@ def main(argv: list | None = None) -> int:
|
||||
metavar="KEY=VALUE",
|
||||
help="Extra context variables (can be used multiple times)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--extension-types",
|
||||
nargs="+",
|
||||
choices=VALID_EXTENSION_TYPES,
|
||||
default=None,
|
||||
help="Extension types to include (default: all). "
|
||||
"Choices: router, provider, obbject, on_command_output, charting, all",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
# Build extra context from arguments
|
||||
extra_context = {}
|
||||
if args.extra_context:
|
||||
for item in args.extra_context:
|
||||
@@ -55,14 +135,21 @@ def main(argv: list | None = None) -> int:
|
||||
key, value = item.split("=", 1)
|
||||
extra_context[key] = value
|
||||
|
||||
# Get the bundled template path
|
||||
if args.no_input:
|
||||
if args.extension_types:
|
||||
extra_context["extension_types"] = ",".join(args.extension_types)
|
||||
else:
|
||||
preset_types = args.extension_types if args.extension_types else None
|
||||
context = _prompt_context(preset_extension_types=preset_types)
|
||||
extra_context.update(context)
|
||||
|
||||
template_path = get_template_path()
|
||||
|
||||
try:
|
||||
cookiecutter(
|
||||
str(template_path),
|
||||
output_dir=args.output_dir,
|
||||
no_input=args.no_input,
|
||||
no_input=True,
|
||||
overwrite_if_exists=args.overwrite_if_exists,
|
||||
extra_context=extra_context if extra_context else None,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"project_name": "OpenBB Python Extension Template",
|
||||
"project_tag": "extension-template",
|
||||
"package_name": "extension_template",
|
||||
"extension_types": "all",
|
||||
"provider_name": "template",
|
||||
"router_name": "template",
|
||||
"obbject_name": "template",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""OpenBB Platform Extension post-generation script."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
MODULE_REGEX = r"^[_a-zA-Z][_a-zA-Z0-9]+$"
|
||||
@@ -9,23 +11,90 @@ MODULE_NAME = "{{ cookiecutter.package_name }}"
|
||||
PROVIDER_NAME = "{{ cookiecutter.provider_name }}" or ""
|
||||
ROUTER_NAME = "{{ cookiecutter.router_name }}" or ""
|
||||
OBBJECT_NAME = "{{ cookiecutter.obbject_name }}" or ""
|
||||
EXTENSION_TYPES_RAW = "{{ cookiecutter.extension_types }}"
|
||||
|
||||
|
||||
def parse_extension_types(raw: str) -> set[str]:
|
||||
types = {t.strip().lower() for t in raw.split(",") if t.strip()}
|
||||
if "all" in types:
|
||||
return {"router", "provider", "obbject", "on_command_output", "charting"}
|
||||
return types
|
||||
|
||||
|
||||
def remove_path(path: str):
|
||||
if os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
elif os.path.isfile(path):
|
||||
os.remove(path)
|
||||
|
||||
|
||||
EXTENSION_TYPES = parse_extension_types(EXTENSION_TYPES_RAW)
|
||||
|
||||
if not re.match(MODULE_REGEX, MODULE_NAME):
|
||||
print(f"ERROR: {MODULE_NAME} is not a valid Python package name.")
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
if PROVIDER_NAME and not re.match(MODULE_REGEX, PROVIDER_NAME):
|
||||
has_router = "router" in EXTENSION_TYPES
|
||||
has_charting = "charting" in EXTENSION_TYPES
|
||||
has_provider = "provider" in EXTENSION_TYPES
|
||||
has_obbject = "obbject" in EXTENSION_TYPES
|
||||
has_on_command_output = "on_command_output" in EXTENSION_TYPES
|
||||
|
||||
if has_provider and PROVIDER_NAME and not re.match(MODULE_REGEX, PROVIDER_NAME):
|
||||
print(f"ERROR: {PROVIDER_NAME} should be in lower snakecase.")
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
if ROUTER_NAME and not re.match(MODULE_REGEX, ROUTER_NAME):
|
||||
if (
|
||||
(has_router or has_charting)
|
||||
and ROUTER_NAME
|
||||
and not re.match(MODULE_REGEX, ROUTER_NAME)
|
||||
):
|
||||
print(f"ERROR: {ROUTER_NAME} should be in lower snakecase.")
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
if OBBJECT_NAME and not re.match(MODULE_REGEX, OBBJECT_NAME):
|
||||
if (
|
||||
(has_obbject or has_on_command_output)
|
||||
and OBBJECT_NAME
|
||||
and not re.match(MODULE_REGEX, OBBJECT_NAME)
|
||||
):
|
||||
print(f"ERROR: {OBBJECT_NAME} should be in lower snakecase.")
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
routers_dir = os.path.join(MODULE_NAME, "routers")
|
||||
providers_dir = os.path.join(MODULE_NAME, "providers")
|
||||
obbject_dir = os.path.join(MODULE_NAME, "obbject")
|
||||
|
||||
if not has_router:
|
||||
remove_path(os.path.join(routers_dir, ROUTER_NAME + ".py"))
|
||||
remove_path(os.path.join(routers_dir, "depends.py"))
|
||||
|
||||
if not has_charting:
|
||||
remove_path(os.path.join(routers_dir, ROUTER_NAME + "_views.py"))
|
||||
|
||||
if not has_router and not has_charting:
|
||||
remove_path(routers_dir)
|
||||
|
||||
if not has_provider:
|
||||
remove_path(providers_dir)
|
||||
|
||||
if not has_obbject and not has_on_command_output:
|
||||
remove_path(obbject_dir)
|
||||
|
||||
if (has_obbject or has_on_command_output) and not has_router:
|
||||
try:
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
eps = entry_points()
|
||||
core_eps = (
|
||||
eps.select(group="openbb_core_extension")
|
||||
if hasattr(eps, "select")
|
||||
else eps.get("openbb_core_extension", [])
|
||||
)
|
||||
if not list(core_eps):
|
||||
print(
|
||||
"\n WARNING: No 'openbb_core_extension' entry points found in the environment."
|
||||
"\n The 'obbject' and 'on_command_output' extension types require at least"
|
||||
"\n one router extension to be installed in order to function.\n"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -15,12 +15,24 @@ With it you can:
|
||||
## Getting Started
|
||||
|
||||
We recommend you check out the files in the following order:
|
||||
|
||||
{% set types = cookiecutter.extension_types.split(',') | map('trim') | list %}
|
||||
{% if 'router' in types or 'all' in types %}
|
||||
* `{{cookiecutter.package_name}}/routers/{{cookiecutter.router_name}}.py`
|
||||
* `{{cookiecutter.package_name}}/prvoviders/{{cookiecutter.provider_name}}/models/example.py`
|
||||
{% endif %}
|
||||
{% if 'provider' in types or 'all' in types %}
|
||||
* `{{cookiecutter.package_name}}/providers/{{cookiecutter.provider_name}}/models/example.py`
|
||||
* `{{cookiecutter.package_name}}/providers/{{cookiecutter.provider_name}}/__init__.py`
|
||||
{% endif %}
|
||||
{% if 'obbject' in types or 'on_command_output' in types or 'all' in types %}
|
||||
* `{{cookiecutter.package_name}}/obbject/{{cookiecutter.obbject_name}}/__init__.py`
|
||||
{% endif %}
|
||||
{% if 'charting' in types or 'all' in types %}
|
||||
* `{{cookiecutter.package_name}}/routers/{{cookiecutter.router_name}}_views.py`
|
||||
{% endif %}
|
||||
{% if 'charting' in types and 'router' not in types and 'all' not in types %}
|
||||
|
||||
> **Note:** You selected charting without a router. The views file references `{{cookiecutter.router_name}}` naming conventions. You will need to pair this with an existing router extension that uses the same name.
|
||||
{% endif %}
|
||||
|
||||
Check out the developer [documentation](https://docs.openbb.co/python/developer) for more information on getting started making OpenBB extensions.
|
||||
|
||||
|
||||
@@ -18,18 +18,26 @@ openbb-devtools = { version = "*" }
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
{% set types = cookiecutter.extension_types.split(',') | map('trim') | list %}
|
||||
{% if 'router' in types or 'all' in types %}
|
||||
[tool.poetry.plugins."openbb_core_extension"]
|
||||
{{ cookiecutter.router_name }} = "{{ cookiecutter.package_name }}.routers.{{ cookiecutter.router_name }}:router"
|
||||
|
||||
{% endif %}
|
||||
{% if 'charting' in types or 'all' in types %}
|
||||
[tool.poetry.plugins."openbb_charting_extension"]
|
||||
{{ cookiecutter.router_name }} = "{{ cookiecutter.package_name }}.routers.{{ cookiecutter.router_name }}_views:{{cookiecutter.router_name.replace('_', ' ').title().replace(' ', '').replace('"', '')}}Views"
|
||||
|
||||
{% endif %}
|
||||
{% if 'provider' in types or 'all' in types %}
|
||||
[tool.poetry.plugins."openbb_provider_extension"]
|
||||
{{ cookiecutter.provider_name }} = "{{ cookiecutter.package_name }}.providers.{{ cookiecutter.provider_name }}:{{ cookiecutter.provider_name }}_provider"
|
||||
|
||||
{% endif %}
|
||||
{% if 'obbject' in types or 'on_command_output' in types or 'all' in types %}
|
||||
[tool.poetry.plugins."openbb_obbject_extension"]
|
||||
{% if 'obbject' in types or 'all' in types %}
|
||||
to_string = "{{ cookiecutter.package_name }}.obbject.{{ cookiecutter.obbject_name }}:ext"
|
||||
{{ cookiecutter.obbject_name }} = "{{ cookiecutter.package_name }}.obbject.{{ cookiecutter.obbject_name }}:class_ext"
|
||||
# Uncomment to use
|
||||
# nonblocking_plugin = "{{ cookiecutter.package_name }}.obbject.{{ cookiecutter.obbject_name }}:nonblocking_plugin"
|
||||
{% endif %}
|
||||
{% if 'on_command_output' in types or 'all' in types %}
|
||||
nonblocking_plugin = "{{ cookiecutter.package_name }}.obbject.{{ cookiecutter.obbject_name }}:nonblocking_plugin"
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
@@ -1,35 +1,36 @@
|
||||
"""{{ cookiecutter.package_name }} OBBject Extension - {{ cookiecutter.obbject_name }}"""
|
||||
|
||||
{% set types = cookiecutter.extension_types.split(',') | map('trim') | list %}
|
||||
{% set has_obbject = 'obbject' in types or 'all' in types %}
|
||||
{% set has_on_command_output = 'on_command_output' in types or 'all' in types %}
|
||||
# pylint: disable=W0613,R0903
|
||||
{% if has_on_command_output %}
|
||||
|
||||
import threading
|
||||
import time
|
||||
{% endif %}
|
||||
|
||||
from openbb_core.app.model.extension import Extension
|
||||
from openbb_core.app.model.obbject import OBBject
|
||||
{% if has_obbject %}
|
||||
|
||||
# Extensions are registered as OBBject accessors.
|
||||
# It can be a class, or it can be a callable method.
|
||||
ext = Extension(
|
||||
name="to_string",
|
||||
description="An OBBject extension that converts the results to a string representation.",
|
||||
)
|
||||
|
||||
# If it is a function, no parameters will be accepted.
|
||||
# The function will execute like a property method.
|
||||
# The accessor is called when the namespace is entered.
|
||||
|
||||
@ext.obbject_accessor
|
||||
def to_string(obbject, **kwargs) -> str:
|
||||
"""OBBject accessor providing a "to_string" method."""
|
||||
return obbject.model_dump_json(exclude_none=True, exclude_unset=True, include="results")
|
||||
|
||||
# We ignore this OpenBBWarning: Skipping '{{ cookiecutter.obbject_name }}', name already in user.
|
||||
|
||||
class_ext = Extension(
|
||||
name="{{ cookiecutter.obbject_name }}",
|
||||
description="An OBBject extension with namespace."
|
||||
)
|
||||
|
||||
|
||||
@class_ext.obbject_accessor
|
||||
class OBBjectExtension:
|
||||
"""OBBject Extension Template."""
|
||||
@@ -41,44 +42,45 @@ class OBBjectExtension:
|
||||
def hello_world(self, **kwargs):
|
||||
"""Say hello from the OBBject extension."""
|
||||
print(f"Hello from the OBBject instance! \n\n{repr(self._obbject)}") # noqa
|
||||
{% endif %}
|
||||
{% if has_on_command_output %}
|
||||
|
||||
## Non-blocking OBBject Extension Example
|
||||
## Uncomment to use
|
||||
#nonblocking_plugin = Extension(
|
||||
# name="nonblocking_plugin",
|
||||
# description="An on-command-output plugin simulating an extensive task performed in a separate thread.",
|
||||
# on_command_output=True, # Must be set as True
|
||||
# command_output_paths=["/{{cookiecutter.router_name}}/candles"],
|
||||
# immutable=True, # Set to `True` for parallel processing.
|
||||
# results_only=False, # Use this as a flag to return only the "results" portion of the OBBject.
|
||||
#)
|
||||
nonblocking_plugin = Extension(
|
||||
name="nonblocking_plugin",
|
||||
description="An on-command-output plugin simulating an extensive task performed in a separate thread.",
|
||||
on_command_output=True,
|
||||
command_output_paths=["/{{cookiecutter.router_name}}/candles"],
|
||||
immutable=True,
|
||||
results_only=False,
|
||||
)
|
||||
|
||||
|
||||
#def _expensive_operation_worker(serialized_obbject: dict):
|
||||
# """Simulate a long-running task without blocking the caller."""
|
||||
# working_copy = OBBject(**serialized_obbject)
|
||||
# print("\nThis is the deserialized OBBject in the non-blocking thread.")
|
||||
# print(working_copy.__repr__())
|
||||
# for i in range(10):
|
||||
# print(str(i) + " seconds remaining...")
|
||||
# time.sleep(1)
|
||||
# print("Expensive operation is now complete.")
|
||||
def _expensive_operation_worker(serialized_obbject: dict):
|
||||
"""Simulate a long-running task without blocking the caller."""
|
||||
working_copy = OBBject(**serialized_obbject)
|
||||
print("\nThis is the deserialized OBBject in the non-blocking thread.")
|
||||
print(working_copy.__repr__())
|
||||
for i in range(10):
|
||||
print(str(i) + " seconds remaining...")
|
||||
time.sleep(1)
|
||||
print("Expensive operation is now complete.")
|
||||
|
||||
|
||||
#@nonblocking_plugin.obbject_accessor
|
||||
#def empty_plugin_function(obbject): # This can also be an async function.
|
||||
# """Simulated on_command_output function that executes an expensive task
|
||||
# in a non-blocking thread."""
|
||||
# print(
|
||||
# "Serializing the obbject and passing to a new thread.\n"
|
||||
# f"Command executed: {obbject.extra['metadata']}\n"
|
||||
# )
|
||||
# print(
|
||||
# "Simulating an expensive task that is non-blocking and allows the function to return."
|
||||
# )
|
||||
# threading.Thread(
|
||||
# target=_expensive_operation_worker,
|
||||
# args=(obbject.model_dump(),),
|
||||
# name="empty-plugin-expensive-operation",
|
||||
# daemon=False,
|
||||
# ).start()
|
||||
@nonblocking_plugin.obbject_accessor
|
||||
def empty_plugin_function(obbject):
|
||||
"""Simulated on_command_output function that executes an expensive task
|
||||
in a non-blocking thread."""
|
||||
print(
|
||||
"Serializing the obbject and passing to a new thread.\n"
|
||||
f"Command executed: {obbject.extra['metadata']}\n"
|
||||
)
|
||||
print(
|
||||
"Simulating an expensive task that is non-blocking and allows the function to return."
|
||||
)
|
||||
threading.Thread(
|
||||
target=_expensive_operation_worker,
|
||||
args=(obbject.model_dump(),),
|
||||
name="empty-plugin-expensive-operation",
|
||||
daemon=False,
|
||||
).start()
|
||||
{% endif %}
|
||||
|
||||
@@ -14,7 +14,7 @@ from openbb_core.provider.standard_models.equity_historical import (
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
|
||||
class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams(EquityHistoricalQueryParams):
|
||||
class {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams(EquityHistoricalQueryParams):
|
||||
"""Example provider query.
|
||||
|
||||
The standard model here comes with parameters for symbol, start_date, and end_date.
|
||||
@@ -25,7 +25,7 @@ class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '
|
||||
)
|
||||
|
||||
|
||||
class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalData(EquityHistoricalData):
|
||||
class {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalData(EquityHistoricalData):
|
||||
"""Sample provider data.
|
||||
|
||||
The standard model has these fields,
|
||||
@@ -51,10 +51,10 @@ class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '
|
||||
return v if v else "Data validator replaced None."
|
||||
|
||||
|
||||
class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalFetcher(
|
||||
class {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalFetcher(
|
||||
Fetcher[
|
||||
{{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams,
|
||||
list[{{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalData],
|
||||
{{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams,
|
||||
list[{{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalData],
|
||||
]
|
||||
):
|
||||
"""Example Fetcher class.
|
||||
@@ -63,18 +63,18 @@ class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def transform_query(params: dict[str, Any]) -> {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams:
|
||||
def transform_query(params: dict[str, Any]) -> {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams:
|
||||
"""Define example transform_query.
|
||||
|
||||
Here we can pre-process the query parameters and add any extra parameters that
|
||||
will be used inside the extract_data method.
|
||||
"""
|
||||
return {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams(**params)
|
||||
return {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams(**params)
|
||||
|
||||
# Note the use of async here. Make the Fetcher async with this small change.
|
||||
@staticmethod
|
||||
async def aextract_data(
|
||||
query: {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams,
|
||||
query: {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams,
|
||||
credentials: dict[str, str] | None,
|
||||
**kwargs: Any,
|
||||
) -> list[dict]:
|
||||
@@ -115,13 +115,13 @@ class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '
|
||||
|
||||
@staticmethod
|
||||
def transform_data(
|
||||
query: {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams,
|
||||
query: {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams,
|
||||
data: list[dict],
|
||||
**kwargs: Any
|
||||
) -> list[{{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalData]:
|
||||
) -> list[{{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalData]:
|
||||
"""Define example transform_data.
|
||||
|
||||
Right now, we're converting the data to fit our desired format.
|
||||
You can apply other transformations to it here.
|
||||
"""
|
||||
return [{{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalData.model_validate(d) for d in data]
|
||||
return [{{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalData.model_validate(d) for d in data]
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
"""{{cookiecutter.router_name}} router command example."""
|
||||
|
||||
{% set types = cookiecutter.extension_types.split(',') | map('trim') | list %}
|
||||
{% set has_provider = 'provider' in types or 'all' in types %}
|
||||
# pylint: disable=unused-argument
|
||||
{% if has_provider %}
|
||||
|
||||
from openbb_core.app.model.command_context import CommandContext
|
||||
{% endif %}
|
||||
from openbb_core.app.model.example import APIEx, PythonEx
|
||||
from openbb_core.app.model.obbject import OBBject
|
||||
{% if has_provider %}
|
||||
from openbb_core.app.provider_interface import ExtraParams, ProviderChoices, StandardParams
|
||||
from openbb_core.app.query import Query
|
||||
{% endif %}
|
||||
from openbb_core.app.router import Router
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Example dependency injection yielding a configured requests.Session object.
|
||||
from {{cookiecutter.package_name}}.routers.depends import Session
|
||||
|
||||
# The prefix extension's prefix is determined by the `pyproject.toml` EntryPoint assignment.
|
||||
# Assign a prefix only if this is a sub-router.
|
||||
router = Router(prefix="")
|
||||
|
||||
|
||||
@@ -52,6 +54,7 @@ async def post_example(
|
||||
spread = ask - bid
|
||||
|
||||
return OBBject(results={"mid": mid, "spread": spread, "flag": flag})
|
||||
{% if has_provider %}
|
||||
|
||||
|
||||
@router.command(model="Example")
|
||||
@@ -65,14 +68,13 @@ async def model_example(
|
||||
return await OBBject.from_query(Query(**locals()))
|
||||
|
||||
|
||||
# If you had another provider installed that mapped to this model - i.e, `openbb-fmp`
|
||||
# they will be added to this endpoint.
|
||||
@router.command(model="EquityHistorical")
|
||||
async def candles(
|
||||
cc: CommandContext,
|
||||
provider_choices: ProviderChoices,
|
||||
standard_params: StandardParams,
|
||||
extra_params: ExtraParams,
|
||||
) -> OBBject: # results type is inferred from Fetcher annotations.
|
||||
) -> OBBject:
|
||||
"""Example Data."""
|
||||
return await OBBject.from_query(Query(**locals()))
|
||||
{% endif %}
|
||||
|
||||
2
cookiecutter/pyproject.toml
vendored
2
cookiecutter/pyproject.toml
vendored
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "openbb-cookiecutter"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
description = "Extensions template for the OpenBB Python Package."
|
||||
license = "AGPL-3.0-only"
|
||||
authors = ["OpenBB Team <hello@openbb.co>"]
|
||||
|
||||
Reference in New Issue
Block a user