[Feature] openbb-cookiecutter: Enrich Input Arguments to Filter by Extension Type (#7432)

* enhance cookiecutter input to filter by extension type and only generate what was selected

* missed a file

* pyproject and lock
This commit is contained in:
Danglewood
2026-04-20 14:50:46 -07:00
committed by GitHub
parent 203f62e7d2
commit a14208947e
10 changed files with 262 additions and 80 deletions

View File

@@ -4,5 +4,6 @@
"project_name": "Super Quant",
"project_tag": "{{ cookiecutter.project_name.lower().replace(' ', '-') }}",
"package_name": "{{ cookiecutter.project_name.lower().replace(' ', '_') }}",
"extension_types": "router",
"_template": "{% now 'utc', '%Y%m%d%H%M%S' %}"
}

View File

@@ -6,9 +6,81 @@ import argparse
import sys
from cookiecutter.main import cookiecutter
from rich.prompt import Prompt
from . import get_template_path
VALID_EXTENSION_TYPES = [
"router",
"provider",
"obbject",
"on_command_output",
"charting",
"all",
]
def _parse_extension_types(value: str) -> list[str]:
types = [t.strip() for t in value.split(",") if t.strip()]
invalid = [t for t in types if t not in VALID_EXTENSION_TYPES]
if invalid:
raise ValueError(
f"Invalid extension type(s): {', '.join(invalid)}. "
f"Valid choices: {', '.join(VALID_EXTENSION_TYPES)}"
)
if not types:
raise ValueError("At least one extension type must be selected.")
return types
def _prompt_context(preset_extension_types: list[str] | None = None) -> dict:
context = {}
context["full_name"] = Prompt.ask(" full_name", default="Hello World")
context["email"] = Prompt.ask(" email", default="hello@world.com")
context["project_name"] = Prompt.ask(
" project_name", default="OpenBB Python Extension Template"
)
default_tag = context["project_name"].lower().replace(" ", "-").replace("_", "-")
context["project_tag"] = Prompt.ask(" project_tag", default=default_tag)
default_pkg = context["project_name"].lower().replace(" ", "_").replace("-", "_")
context["package_name"] = Prompt.ask(" package_name", default=default_pkg)
if preset_extension_types:
types = preset_extension_types
else:
while True:
raw = Prompt.ask(
" extension_types"
" - router | provider | obbject | on_command_output | charting | all",
default="router",
)
try:
types = _parse_extension_types(raw)
break
except ValueError as e:
print(f" Error: {e}")
context["extension_types"] = ",".join(types)
is_all = "all" in types
if is_all or "provider" in types:
context["provider_name"] = Prompt.ask(" provider_name", default="template")
else:
context["provider_name"] = "template"
if is_all or "router" in types or "charting" in types:
context["router_name"] = Prompt.ask(" router_name", default="template")
else:
context["router_name"] = "template"
if is_all or "obbject" in types or "on_command_output" in types:
context["obbject_name"] = Prompt.ask(" obbject_name", default="template")
else:
context["obbject_name"] = "template"
return context
def main(argv: list | None = None) -> int:
"""Run the OpenBB cookiecutter template.
@@ -42,10 +114,18 @@ def main(argv: list | None = None) -> int:
metavar="KEY=VALUE",
help="Extra context variables (can be used multiple times)",
)
parser.add_argument(
"-e",
"--extension-types",
nargs="+",
choices=VALID_EXTENSION_TYPES,
default=None,
help="Extension types to include (default: all). "
"Choices: router, provider, obbject, on_command_output, charting, all",
)
args = parser.parse_args(argv)
# Build extra context from arguments
extra_context = {}
if args.extra_context:
for item in args.extra_context:
@@ -55,14 +135,21 @@ def main(argv: list | None = None) -> int:
key, value = item.split("=", 1)
extra_context[key] = value
# Get the bundled template path
if args.no_input:
if args.extension_types:
extra_context["extension_types"] = ",".join(args.extension_types)
else:
preset_types = args.extension_types if args.extension_types else None
context = _prompt_context(preset_extension_types=preset_types)
extra_context.update(context)
template_path = get_template_path()
try:
cookiecutter(
str(template_path),
output_dir=args.output_dir,
no_input=args.no_input,
no_input=True,
overwrite_if_exists=args.overwrite_if_exists,
extra_context=extra_context if extra_context else None,
)

View File

@@ -4,6 +4,7 @@
"project_name": "OpenBB Python Extension Template",
"project_tag": "extension-template",
"package_name": "extension_template",
"extension_types": "all",
"provider_name": "template",
"router_name": "template",
"obbject_name": "template",

View File

@@ -1,6 +1,8 @@
"""OpenBB Platform Extension post-generation script."""
import os
import re
import shutil
import sys
MODULE_REGEX = r"^[_a-zA-Z][_a-zA-Z0-9]+$"
@@ -9,23 +11,90 @@ MODULE_NAME = "{{ cookiecutter.package_name }}"
PROVIDER_NAME = "{{ cookiecutter.provider_name }}" or ""
ROUTER_NAME = "{{ cookiecutter.router_name }}" or ""
OBBJECT_NAME = "{{ cookiecutter.obbject_name }}" or ""
EXTENSION_TYPES_RAW = "{{ cookiecutter.extension_types }}"
def parse_extension_types(raw: str) -> set[str]:
types = {t.strip().lower() for t in raw.split(",") if t.strip()}
if "all" in types:
return {"router", "provider", "obbject", "on_command_output", "charting"}
return types
def remove_path(path: str):
if os.path.isdir(path):
shutil.rmtree(path)
elif os.path.isfile(path):
os.remove(path)
EXTENSION_TYPES = parse_extension_types(EXTENSION_TYPES_RAW)
if not re.match(MODULE_REGEX, MODULE_NAME):
print(f"ERROR: {MODULE_NAME} is not a valid Python package name.")
sys.exit(1)
if PROVIDER_NAME and not re.match(MODULE_REGEX, PROVIDER_NAME):
has_router = "router" in EXTENSION_TYPES
has_charting = "charting" in EXTENSION_TYPES
has_provider = "provider" in EXTENSION_TYPES
has_obbject = "obbject" in EXTENSION_TYPES
has_on_command_output = "on_command_output" in EXTENSION_TYPES
if has_provider and PROVIDER_NAME and not re.match(MODULE_REGEX, PROVIDER_NAME):
print(f"ERROR: {PROVIDER_NAME} should be in lower snakecase.")
sys.exit(1)
if ROUTER_NAME and not re.match(MODULE_REGEX, ROUTER_NAME):
if (
(has_router or has_charting)
and ROUTER_NAME
and not re.match(MODULE_REGEX, ROUTER_NAME)
):
print(f"ERROR: {ROUTER_NAME} should be in lower snakecase.")
sys.exit(1)
if OBBJECT_NAME and not re.match(MODULE_REGEX, OBBJECT_NAME):
if (
(has_obbject or has_on_command_output)
and OBBJECT_NAME
and not re.match(MODULE_REGEX, OBBJECT_NAME)
):
print(f"ERROR: {OBBJECT_NAME} should be in lower snakecase.")
sys.exit(1)
routers_dir = os.path.join(MODULE_NAME, "routers")
providers_dir = os.path.join(MODULE_NAME, "providers")
obbject_dir = os.path.join(MODULE_NAME, "obbject")
if not has_router:
remove_path(os.path.join(routers_dir, ROUTER_NAME + ".py"))
remove_path(os.path.join(routers_dir, "depends.py"))
if not has_charting:
remove_path(os.path.join(routers_dir, ROUTER_NAME + "_views.py"))
if not has_router and not has_charting:
remove_path(routers_dir)
if not has_provider:
remove_path(providers_dir)
if not has_obbject and not has_on_command_output:
remove_path(obbject_dir)
if (has_obbject or has_on_command_output) and not has_router:
try:
from importlib.metadata import entry_points
eps = entry_points()
core_eps = (
eps.select(group="openbb_core_extension")
if hasattr(eps, "select")
else eps.get("openbb_core_extension", [])
)
if not list(core_eps):
print(
"\n WARNING: No 'openbb_core_extension' entry points found in the environment."
"\n The 'obbject' and 'on_command_output' extension types require at least"
"\n one router extension to be installed in order to function.\n"
)
except Exception:
pass

View File

@@ -15,12 +15,24 @@ With it you can:
## Getting Started
We recommend you check out the files in the following order:
{% set types = cookiecutter.extension_types.split(',') | map('trim') | list %}
{% if 'router' in types or 'all' in types %}
* `{{cookiecutter.package_name}}/routers/{{cookiecutter.router_name}}.py`
* `{{cookiecutter.package_name}}/prvoviders/{{cookiecutter.provider_name}}/models/example.py`
{% endif %}
{% if 'provider' in types or 'all' in types %}
* `{{cookiecutter.package_name}}/providers/{{cookiecutter.provider_name}}/models/example.py`
* `{{cookiecutter.package_name}}/providers/{{cookiecutter.provider_name}}/__init__.py`
{% endif %}
{% if 'obbject' in types or 'on_command_output' in types or 'all' in types %}
* `{{cookiecutter.package_name}}/obbject/{{cookiecutter.obbject_name}}/__init__.py`
{% endif %}
{% if 'charting' in types or 'all' in types %}
* `{{cookiecutter.package_name}}/routers/{{cookiecutter.router_name}}_views.py`
{% endif %}
{% if 'charting' in types and 'router' not in types and 'all' not in types %}
> **Note:** You selected charting without a router. The views file references `{{cookiecutter.router_name}}` naming conventions. You will need to pair this with an existing router extension that uses the same name.
{% endif %}
Check out the developer [documentation](https://docs.openbb.co/python/developer) for more information on getting started making OpenBB extensions.

View File

@@ -18,18 +18,26 @@ openbb-devtools = { version = "*" }
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
{% set types = cookiecutter.extension_types.split(',') | map('trim') | list %}
{% if 'router' in types or 'all' in types %}
[tool.poetry.plugins."openbb_core_extension"]
{{ cookiecutter.router_name }} = "{{ cookiecutter.package_name }}.routers.{{ cookiecutter.router_name }}:router"
{% endif %}
{% if 'charting' in types or 'all' in types %}
[tool.poetry.plugins."openbb_charting_extension"]
{{ cookiecutter.router_name }} = "{{ cookiecutter.package_name }}.routers.{{ cookiecutter.router_name }}_views:{{cookiecutter.router_name.replace('_', ' ').title().replace(' ', '').replace('"', '')}}Views"
{% endif %}
{% if 'provider' in types or 'all' in types %}
[tool.poetry.plugins."openbb_provider_extension"]
{{ cookiecutter.provider_name }} = "{{ cookiecutter.package_name }}.providers.{{ cookiecutter.provider_name }}:{{ cookiecutter.provider_name }}_provider"
{% endif %}
{% if 'obbject' in types or 'on_command_output' in types or 'all' in types %}
[tool.poetry.plugins."openbb_obbject_extension"]
{% if 'obbject' in types or 'all' in types %}
to_string = "{{ cookiecutter.package_name }}.obbject.{{ cookiecutter.obbject_name }}:ext"
{{ cookiecutter.obbject_name }} = "{{ cookiecutter.package_name }}.obbject.{{ cookiecutter.obbject_name }}:class_ext"
# Uncomment to use
# nonblocking_plugin = "{{ cookiecutter.package_name }}.obbject.{{ cookiecutter.obbject_name }}:nonblocking_plugin"
{% endif %}
{% if 'on_command_output' in types or 'all' in types %}
nonblocking_plugin = "{{ cookiecutter.package_name }}.obbject.{{ cookiecutter.obbject_name }}:nonblocking_plugin"
{% endif %}
{% endif %}

View File

@@ -1,35 +1,36 @@
"""{{ cookiecutter.package_name }} OBBject Extension - {{ cookiecutter.obbject_name }}"""
{% set types = cookiecutter.extension_types.split(',') | map('trim') | list %}
{% set has_obbject = 'obbject' in types or 'all' in types %}
{% set has_on_command_output = 'on_command_output' in types or 'all' in types %}
# pylint: disable=W0613,R0903
{% if has_on_command_output %}
import threading
import time
{% endif %}
from openbb_core.app.model.extension import Extension
from openbb_core.app.model.obbject import OBBject
{% if has_obbject %}
# Extensions are registered as OBBject accessors.
# It can be a class, or it can be a callable method.
ext = Extension(
name="to_string",
description="An OBBject extension that converts the results to a string representation.",
)
# If it is a function, no parameters will be accepted.
# The function will execute like a property method.
# The accessor is called when the namespace is entered.
@ext.obbject_accessor
def to_string(obbject, **kwargs) -> str:
"""OBBject accessor providing a "to_string" method."""
return obbject.model_dump_json(exclude_none=True, exclude_unset=True, include="results")
# We ignore this OpenBBWarning: Skipping '{{ cookiecutter.obbject_name }}', name already in user.
class_ext = Extension(
name="{{ cookiecutter.obbject_name }}",
description="An OBBject extension with namespace."
)
@class_ext.obbject_accessor
class OBBjectExtension:
"""OBBject Extension Template."""
@@ -41,44 +42,45 @@ class OBBjectExtension:
def hello_world(self, **kwargs):
"""Say hello from the OBBject extension."""
print(f"Hello from the OBBject instance! \n\n{repr(self._obbject)}") # noqa
{% endif %}
{% if has_on_command_output %}
## Non-blocking OBBject Extension Example
## Uncomment to use
#nonblocking_plugin = Extension(
# name="nonblocking_plugin",
# description="An on-command-output plugin simulating an extensive task performed in a separate thread.",
# on_command_output=True, # Must be set as True
# command_output_paths=["/{{cookiecutter.router_name}}/candles"],
# immutable=True, # Set to `True` for parallel processing.
# results_only=False, # Use this as a flag to return only the "results" portion of the OBBject.
#)
nonblocking_plugin = Extension(
name="nonblocking_plugin",
description="An on-command-output plugin simulating an extensive task performed in a separate thread.",
on_command_output=True,
command_output_paths=["/{{cookiecutter.router_name}}/candles"],
immutable=True,
results_only=False,
)
#def _expensive_operation_worker(serialized_obbject: dict):
# """Simulate a long-running task without blocking the caller."""
# working_copy = OBBject(**serialized_obbject)
# print("\nThis is the deserialized OBBject in the non-blocking thread.")
# print(working_copy.__repr__())
# for i in range(10):
# print(str(i) + " seconds remaining...")
# time.sleep(1)
# print("Expensive operation is now complete.")
def _expensive_operation_worker(serialized_obbject: dict):
"""Simulate a long-running task without blocking the caller."""
working_copy = OBBject(**serialized_obbject)
print("\nThis is the deserialized OBBject in the non-blocking thread.")
print(working_copy.__repr__())
for i in range(10):
print(str(i) + " seconds remaining...")
time.sleep(1)
print("Expensive operation is now complete.")
#@nonblocking_plugin.obbject_accessor
#def empty_plugin_function(obbject): # This can also be an async function.
# """Simulated on_command_output function that executes an expensive task
# in a non-blocking thread."""
# print(
# "Serializing the obbject and passing to a new thread.\n"
# f"Command executed: {obbject.extra['metadata']}\n"
# )
# print(
# "Simulating an expensive task that is non-blocking and allows the function to return."
# )
# threading.Thread(
# target=_expensive_operation_worker,
# args=(obbject.model_dump(),),
# name="empty-plugin-expensive-operation",
# daemon=False,
# ).start()
@nonblocking_plugin.obbject_accessor
def empty_plugin_function(obbject):
"""Simulated on_command_output function that executes an expensive task
in a non-blocking thread."""
print(
"Serializing the obbject and passing to a new thread.\n"
f"Command executed: {obbject.extra['metadata']}\n"
)
print(
"Simulating an expensive task that is non-blocking and allows the function to return."
)
threading.Thread(
target=_expensive_operation_worker,
args=(obbject.model_dump(),),
name="empty-plugin-expensive-operation",
daemon=False,
).start()
{% endif %}

View File

@@ -14,7 +14,7 @@ from openbb_core.provider.standard_models.equity_historical import (
from pydantic import Field, field_validator
class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams(EquityHistoricalQueryParams):
class {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams(EquityHistoricalQueryParams):
"""Example provider query.
The standard model here comes with parameters for symbol, start_date, and end_date.
@@ -25,7 +25,7 @@ class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '
)
class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalData(EquityHistoricalData):
class {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalData(EquityHistoricalData):
"""Sample provider data.
The standard model has these fields,
@@ -51,10 +51,10 @@ class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '
return v if v else "Data validator replaced None."
class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalFetcher(
class {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalFetcher(
Fetcher[
{{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams,
list[{{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalData],
{{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams,
list[{{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalData],
]
):
"""Example Fetcher class.
@@ -63,18 +63,18 @@ class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '
"""
@staticmethod
def transform_query(params: dict[str, Any]) -> {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams:
def transform_query(params: dict[str, Any]) -> {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams:
"""Define example transform_query.
Here we can pre-process the query parameters and add any extra parameters that
will be used inside the extract_data method.
"""
return {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams(**params)
return {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams(**params)
# Note the use of async here. Make the Fetcher async with this small change.
@staticmethod
async def aextract_data(
query: {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams,
query: {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams,
credentials: dict[str, str] | None,
**kwargs: Any,
) -> list[dict]:
@@ -115,13 +115,13 @@ class {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '
@staticmethod
def transform_data(
query: {{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalQueryParams,
query: {{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalQueryParams,
data: list[dict],
**kwargs: Any
) -> list[{{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalData]:
) -> list[{{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalData]:
"""Define example transform_data.
Right now, we're converting the data to fit our desired format.
You can apply other transformations to it here.
"""
return [{{cookiecutter.provider_name.replace('_', ' ').capitalize().replace(' ', '')}}EquityHistoricalData.model_validate(d) for d in data]
return [{{cookiecutter.provider_name.replace('_', ' ').title().replace(' ', '')}}EquityHistoricalData.model_validate(d) for d in data]

View File

@@ -1,20 +1,22 @@
"""{{cookiecutter.router_name}} router command example."""
{% set types = cookiecutter.extension_types.split(',') | map('trim') | list %}
{% set has_provider = 'provider' in types or 'all' in types %}
# pylint: disable=unused-argument
{% if has_provider %}
from openbb_core.app.model.command_context import CommandContext
{% endif %}
from openbb_core.app.model.example import APIEx, PythonEx
from openbb_core.app.model.obbject import OBBject
{% if has_provider %}
from openbb_core.app.provider_interface import ExtraParams, ProviderChoices, StandardParams
from openbb_core.app.query import Query
{% endif %}
from openbb_core.app.router import Router
from pydantic import BaseModel
# Example dependency injection yielding a configured requests.Session object.
from {{cookiecutter.package_name}}.routers.depends import Session
# The prefix extension's prefix is determined by the `pyproject.toml` EntryPoint assignment.
# Assign a prefix only if this is a sub-router.
router = Router(prefix="")
@@ -52,6 +54,7 @@ async def post_example(
spread = ask - bid
return OBBject(results={"mid": mid, "spread": spread, "flag": flag})
{% if has_provider %}
@router.command(model="Example")
@@ -65,14 +68,13 @@ async def model_example(
return await OBBject.from_query(Query(**locals()))
# If you had another provider installed that mapped to this model - i.e, `openbb-fmp`
# they will be added to this endpoint.
@router.command(model="EquityHistorical")
async def candles(
cc: CommandContext,
provider_choices: ProviderChoices,
standard_params: StandardParams,
extra_params: ExtraParams,
) -> OBBject: # results type is inferred from Fetcher annotations.
) -> OBBject:
"""Example Data."""
return await OBBject.from_query(Query(**locals()))
{% endif %}

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "openbb-cookiecutter"
version = "0.5.0"
version = "0.6.0"
description = "Extensions template for the OpenBB Python Package."
license = "AGPL-3.0-only"
authors = ["OpenBB Team <hello@openbb.co>"]