Files
OpenBB/tests/conftest.py
Colin Delahunty 141dd68e68 Removed prints (#5367)
* Removed prints

* FIxed typing

* FIxed typing

* Fixed typing

* Fixed typing

* Fixed typing

* Fixed typing

* Fixed typing

* Fix

* Reverted stuff
2023-08-23 18:40:31 +00:00

531 lines
16 KiB
Python

# IMPORTS STANDARD
import json
import os
import pathlib
from typing import Any, Dict, List, Optional, Type
import importlib_metadata
# IMPORTS THIRD-PARTY
import matplotlib
import pandas as pd
import pytest
import yfinance.utils
from _pytest.capture import MultiCapture, SysCapture
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import SubRequest
from _pytest.mark.structures import Mark
# IMPORTS INTERNAL
from openbb_terminal import (
config_terminal,
decorators,
helper_funcs,
)
from openbb_terminal.core.session.current_system import set_system_variable
# pylint: disable=redefined-outer-name
config_terminal.setup_i18n()
DISPLAY_LIMIT: int = 500
EXTENSIONS_ALLOWED: List[str] = ["csv", "json", "txt"]
EXTENSIONS_MATCHING: Dict[str, List[Type]] = {
"csv": [pd.DataFrame, pd.Series],
"json": [bool, dict, float, int, list, tuple],
"txt": [str],
}
set_system_variable("TEST_MODE", True)
set_system_variable("LOG_COLLECT", False)
class Record:
@staticmethod
def extract_string(data: Any, **kwargs) -> str:
if isinstance(data, tuple(EXTENSIONS_MATCHING["txt"])):
string_value = data
elif isinstance(data, tuple(EXTENSIONS_MATCHING["csv"])):
string_value = data.to_csv(
encoding="utf-8",
lineterminator="\n",
# date_format="%Y-%m-%d %H:%M:%S",
**kwargs,
)
elif isinstance(data, tuple(EXTENSIONS_MATCHING["json"])):
string_value = json.dumps(data, **kwargs)
else:
raise AttributeError(f"Unsupported type : {type(data)}")
return string_value.replace("\r\n", "\n")
@staticmethod
def load_string(path: str) -> Optional[str]:
if os.path.exists(path):
with open(
file=path,
encoding="utf-8",
newline="\n", # Windows: newline="\r\n" Which is BAD
) as f:
return f.read()
else:
return None
@property
def captured(self) -> str:
return self.__captured
@property
def strip(self) -> bool:
return self.__strip
@property
def record_changed(self) -> bool:
return (
self.__recorded is None
or (self.__strip and self.__recorded.strip() != self.__captured.strip())
or (not self.__strip and self.__recorded != self.__captured)
)
@property
def record_exists(self) -> bool:
return self.__recorded is not None
@property
def record_path(self) -> str:
return self.__record_path
@property
def recorded(self) -> Optional[str]:
return self.__recorded
def recorded_reload(self):
record_path = self.__record_path
self.__recorded = self.load_string(path=record_path)
def __init__(
self, captured: Any, record_path: str, strip: bool = False, **kwargs
) -> None:
self.__captured = self.extract_string(data=captured, **kwargs)
self.__record_path = record_path
self.__strip = strip
self.__recorded = self.load_string(path=record_path)
def persist(self):
record_path = self.__record_path
captured = self.__captured
record_dir_name = os.path.dirname(record_path)
# CREATE FOLDER
if not os.path.exists(record_dir_name):
pathlib.Path(record_dir_name).mkdir(parents=True, exist_ok=True)
# SAVE FILE
with open(
file=record_path,
mode="w",
encoding="utf-8",
newline="\n", # Windows: newline="\r\n" Which is BAD
) as f:
f.write(captured)
# RELOAD RECORDED CONTENT
self.recorded_reload()
class PathTemplate:
@staticmethod
def find_extension(data: Any):
for extension, type_list in EXTENSIONS_MATCHING.items():
if isinstance(data, tuple(type_list)):
return extension
raise Exception(f"No extension found for this type : {type(data)}")
def __init__(self, module_dir: str, module_name: str, test_name: str) -> None:
self.__module_dir = module_dir
self.__module_name = module_name
self.__test_name = test_name
def build_path_by_extension(self, extension: str, index: int = 0):
if extension not in EXTENSIONS_ALLOWED:
raise Exception(f"Unsupported extension : {extension}")
path = os.path.join(
self.__module_dir, extension, self.__module_name, self.__test_name
)
if index:
path += "_" + str(index)
path += "."
path += extension
return path
def build_path_by_data(self, data: Any, index: int = 0):
extension = self.find_extension(data=data)
return self.build_path_by_extension(extension=extension, index=index)
class Recorder:
@property
def display_limit(self) -> int:
return self.__display_limit
@display_limit.setter
def display_limit(self, display_limit: int):
self.__display_limit = display_limit
@property
def rewrite_expected(self) -> bool:
return self.__rewrite_expected
@rewrite_expected.setter
def rewrite_expected(self, rewrite_expected: bool):
self.__rewrite_expected = rewrite_expected
@property
def path_template(self) -> PathTemplate:
return self.__path_template
@property
def record_mode(self) -> str:
return self.__record_mode
@record_mode.setter
def record_mode(self, record_mode: str):
self.__record_mode = record_mode
def __init__(
self,
path_template: PathTemplate,
record_mode: str,
display_limit: int = DISPLAY_LIMIT,
rewrite_expected: bool = False,
) -> None:
self.__path_template = path_template
self.__record_mode = record_mode
self.__display_limit = display_limit
self.__rewrite_expected = rewrite_expected
self.__record_list: List[Record] = list()
def capture(self, captured: Any, strip: bool = False, **kwargs):
record_list = self.__record_list
record_path = self.__path_template.build_path_by_data(
data=captured, index=len(record_list)
)
record = Record(
captured=captured, record_path=record_path, strip=strip, **kwargs
)
self.__record_list.append(record)
def capture_list(self, captured_list: List[Any], strip: bool = False):
for captured in captured_list:
self.capture(captured=captured, strip=strip)
def assert_equal(self):
record_list = self.__record_list
for record in record_list:
if record.record_changed:
raise AssertionError(
"Change detected\n"
f"Record : {record.record_path}\n"
f"Expected : {record.recorded[:self.display_limit]}\n"
f"Actual : {record.captured[:self.display_limit]}\n"
)
def assert_in_list(self, in_list: List[str]):
record_list = self.__record_list
for record in record_list:
for string_value in in_list:
assert string_value in record.captured
def persist(self):
record_list = self.__record_list
record_mode = self.__record_mode
rewrite_expected = self.__rewrite_expected
for record in record_list:
if record_mode == "all":
save = True
elif record_mode == "new_episodes":
save = record.record_changed
elif record_mode == "none":
if not record.record_exists:
raise Exception("You are using `record-mode=none`.")
save = False
elif record_mode == "once":
save = not record.record_exists
elif record_mode == "rewrite":
save = True
else:
raise Exception(f"Unknown `record-mode` : {record_mode}")
if save or rewrite_expected:
record.persist()
def build_path_by_extension(
request: SubRequest, extension: str, create_folder: bool = False
) -> str:
# SETUP PATH TEMPLATE
module_dir = request.node.fspath.dirname
module_name = request.node.fspath.purebasename
test_name = request.node.name
path_template = PathTemplate(
module_dir=module_dir, module_name=module_name, test_name=test_name
)
# BUILD PATH
path = path_template.build_path_by_extension(extension)
# CREATE FOLDER
if create_folder:
dir_name = os.path.dirname(path)
if not os.path.exists(dir_name):
dir_name = os.path.dirname(path)
pathlib.Path(dir_name).mkdir(parents=True, exist_ok=True)
return path
def merge_markers_kwargs(markers: List[Mark]) -> Dict[str, Any]:
"""Merge all kwargs into a single dictionary."""
kwargs: Dict[str, Any] = dict()
for marker in reversed(markers):
kwargs.update(marker.kwargs)
return kwargs
def record_stdout_format_kwargs(
test_name: str, record_mode: str, record_stdout_markers: List[Mark]
) -> Dict[str, Any]:
kwargs = merge_markers_kwargs(record_stdout_markers)
formatted_fields = dict()
formatted_fields["assert_in_list"] = kwargs.get("assert_in_list", list())
formatted_fields["display_limit"] = kwargs.get("display_limit", DISPLAY_LIMIT)
formatted_fields["record_mode"] = kwargs.get("record_mode", record_mode)
formatted_fields["record_name"] = kwargs.get("record_name", test_name)
formatted_fields["save_record"] = kwargs.get("save_record", True)
formatted_fields["strip"] = kwargs.get("strip", True)
return formatted_fields
def pytest_addoption(parser: Parser):
parser.addoption(
"--forecast",
action="store_true",
help="To run tests with the marker : @pytest.mark.forecast",
)
parser.addoption(
"--optimization",
action="store_true",
help="To run tests with the marker : @pytest.mark.optimization",
)
parser.addoption(
"--session",
action="store_true",
help="To run tests with the marker : @pytest.mark.session",
)
parser.addoption(
"--rewrite-expected",
action="store_true",
help="To force `record_stdout` and `recorder` to rewrite all files.",
)
parser.addoption(
"--autodoc",
action="store_true",
default=False,
help="run auto documentation tests",
)
def brotli_check():
for item in importlib_metadata.packages_distributions():
if "brotli" in str(item).lower():
pytest.exit("Uninstall brotli and brotlipy before running tests")
def disable_rich():
def effect(df, *xargs, **kwargs): # pylint: disable=unused-argument
print(df.to_string()) # noqa: T201
helper_funcs.print_rich_table = effect
def disable_matplotlib():
# We add this to avoid multiple figures being opened
matplotlib.use("Agg")
def disable_check_api():
decorators.disable_check_api()
def enable_debug():
set_system_variable("DEBUG_MODE", True)
def pytest_configure(config: Config) -> None:
config.addinivalue_line("markers", "record_stdout: Mark the test as text record.")
brotli_check()
enable_debug()
disable_rich()
disable_check_api()
disable_matplotlib()
@pytest.fixture(scope="session") # type: ignore
def rewrite_expected(request: SubRequest) -> bool:
"""Force rewriting of all expected data by : `record_stdout` and `recorder`."""
return request.config.getoption("--rewrite-expected")
@pytest.fixture(autouse=True)
def mock_matplotlib(mocker):
mocker.patch("matplotlib.pyplot.show")
@pytest.fixture(autouse=True)
def mock_plotly(mocker):
mocker.patch("plotly.io.show")
# pylint: disable=protected-access
@pytest.fixture(autouse=True)
def mock_yfinance_tzcache(mocker):
mocker.patch.object(
target=yfinance.utils,
attribute="_TzCache",
new=yfinance.utils._TzCacheDummy,
)
@pytest.fixture
def default_csv_path(request: SubRequest) -> str:
return build_path_by_extension(request=request, extension="csv", create_folder=True)
@pytest.fixture
def default_txt_path(request: SubRequest) -> str:
return build_path_by_extension(request=request, extension="txt", create_folder=True)
@pytest.fixture
def default_json_path(request: SubRequest) -> str:
return build_path_by_extension(
request=request, extension="json", create_folder=True
)
@pytest.fixture
def record_stdout_markers(request: SubRequest) -> List[Mark]:
"""All markers applied to the certain test together with cassette names associated with each marker."""
return list(request.node.iter_markers(name="record_stdout"))
@pytest.fixture(autouse=True)
def record_stdout(
disable_recording: bool,
rewrite_expected: bool,
record_stdout_markers: List[Mark],
record_mode: str,
request: SubRequest,
):
marker = request.node.get_closest_marker("record_stdout")
if disable_recording:
yield None
elif marker:
# SETUP TEST DETAILS
module_dir = request.node.fspath.dirname
module_name = request.node.fspath.purebasename
test_name = request.node.name
# FORMAT MARKER'S KEYWORD ARGUMENTS
formatted_kwargs = record_stdout_format_kwargs(
test_name=test_name,
record_mode=record_mode,
record_stdout_markers=record_stdout_markers,
)
# SETUP RECORDER
path_template = PathTemplate(
module_dir=module_dir,
module_name=module_name,
test_name=formatted_kwargs["record_name"],
)
recorder = Recorder(
path_template=path_template,
record_mode=formatted_kwargs["record_mode"],
display_limit=formatted_kwargs["display_limit"],
rewrite_expected=rewrite_expected,
)
# CAPTURE STDOUT
capture = request.config.getoption("--capture")
if capture == "no":
global_capturing = MultiCapture(
in_=SysCapture(0), out=SysCapture(1), err=SysCapture(2)
)
global_capturing.start_capturing()
yield
recorder.capture(
captured=global_capturing.readouterr().out,
strip=formatted_kwargs["strip"],
)
global_capturing.stop_capturing()
else:
capsys = request.getfixturevalue("capsys")
yield
recorder.capture(
captured=capsys.readouterr().out, strip=formatted_kwargs["strip"]
)
# SAVE/CHECK RECORD
if formatted_kwargs["save_record"]:
recorder.persist()
recorder.assert_equal()
recorder.assert_in_list(in_list=formatted_kwargs["assert_in_list"])
else:
recorder.assert_in_list(in_list=formatted_kwargs["assert_in_list"])
else:
yield None
@pytest.fixture
def recorder(
disable_recording: bool,
rewrite_expected: bool,
record_mode: str,
request: SubRequest,
):
marker_record_stdout = request.node.get_closest_marker("record_stdout")
module_dir = request.node.fspath.dirname
module_name = request.node.fspath.purebasename
test_name = request.node.name
path_template = PathTemplate(
module_dir=module_dir, module_name=module_name, test_name=test_name
)
if disable_recording:
yield None
elif marker_record_stdout:
raise Exception(
"You can't combine both of these fixtures : `record_stdout marker`, `recorder`."
)
else:
recorder = Recorder(
path_template, record_mode, rewrite_expected=rewrite_expected
)
yield recorder
recorder.persist()
recorder.assert_equal()