From d33483f1363561628d5b5d6129b6a6e95a25b71b Mon Sep 17 00:00:00 2001 From: James Turk Date: Tue, 11 Jul 2023 23:39:01 -0500 Subject: [PATCH] more typing sanity --- src/beakers/beakers.py | 56 +++++++++++---------------- src/beakers/cli.py | 16 ++++---- src/beakers/http.py | 13 +++---- src/beakers/recipe.py | 88 ++++++++++++++++++++++-------------------- 4 files changed, 83 insertions(+), 90 deletions(-) diff --git a/src/beakers/beakers.py b/src/beakers/beakers.py index fdc9174..8acbadd 100644 --- a/src/beakers/beakers.py +++ b/src/beakers/beakers.py @@ -2,77 +2,67 @@ import abc import json import sqlite3 import uuid +from pydantic import BaseModel +from typing import Iterable, Type, TYPE_CHECKING +if TYPE_CHECKING: + from .recipe import Recipe -class DataObject: - def __init__(self, id: str | None = None): - self._id = id if id else str(uuid.uuid4()) - self._data = {} - - def __getattr__(self, name): - return self._data[name] - - def __setattr__(self, name, value): - if name.startswith("_"): - super().__setattr__(name, value) - elif name not in self._data: - self._data[name] = value - else: - raise AttributeError(f"DataObject attribute {name} already exists") +PydanticModel = Type[BaseModel] class Beaker(abc.ABC): - def __init__(self, name: str, model: type, recipe: "Recipe"): + def __init__(self, name: str, model: PydanticModel, recipe: Recipe): self.name = name self.model = model self.recipe = recipe - def __repr__(self): + def __repr__(self) -> str: return f"Beaker({self.name}, {self.model.__name__})" @abc.abstractmethod - def items(self): + def items(self) -> Iterable[tuple[str, BaseModel]]: pass @abc.abstractmethod - def __len__(self): + def __len__(self) -> int: pass @abc.abstractmethod - def add_item(self, item: "T", id: str | None = None) -> None: + def add_item(self, item: BaseModel, id: str | None = None) -> None: pass @abc.abstractmethod - def reset(self): + def reset(self) -> None: pass - def add_items(self, items: list["T"]) -> None: + def add_items(self, items: Iterable[BaseModel]) -> None: for item in items: self.add_item(item) class TempBeaker(Beaker): - def __init__(self, name: str, model: type, recipe: "Recipe"): + def __init__(self, name: str, model: PydanticModel | None, recipe: Recipe): super().__init__(name, model, recipe) - self._items = [] + self._items: list[tuple[str, BaseModel]] = [] - def __len__(self): + def __len__(self) -> int: return len(self._items) - def add_item(self, item: "T", id=None) -> None: + def add_item(self, item: BaseModel, id: str | None = None) -> None: if id is None: id = str(uuid.uuid1()) self._items.append((id, item)) - def items(self): + def items(self) -> Iterable[tuple[str, BaseModel]]: yield from self._items - def reset(self): + def reset(self) -> None: self._items = [] class SqliteBeaker(Beaker): - def __init__(self, name: str, model: type, recipe: "Recipe"): + def __init__(self, name: str, model: PydanticModel, recipe: Recipe): super().__init__(name, model, recipe) # create table if it doesn't exist self.cursor = self.recipe.db.cursor() @@ -81,17 +71,17 @@ class SqliteBeaker(Beaker): f"CREATE TABLE IF NOT EXISTS {self.name} (uuid TEXT PRIMARY KEY, data JSON)" ) - def items(self): + def items(self) -> Iterable[tuple[str, BaseModel]]: self.cursor.execute(f"SELECT uuid, data FROM {self.name}") data = self.cursor.fetchall() for item in data: yield item["uuid"], self.model(**json.loads(item["data"])) - def __len__(self): + def __len__(self) -> int: self.cursor.execute(f"SELECT COUNT(*) FROM {self.name}") return self.cursor.fetchone()[0] - def add_item(self, item: "T", id: str | None = None) -> None: + def add_item(self, item: BaseModel, id: str | None = None) -> None: if id is None: id = str(uuid.uuid1()) print("UUID", id, item) @@ -101,6 +91,6 @@ class SqliteBeaker(Beaker): ) self.recipe.db.commit() - def reset(self): + def reset(self) -> None: self.cursor.execute(f"DELETE FROM {self.name}") self.recipe.db.commit() diff --git a/src/beakers/cli.py b/src/beakers/cli.py index 1b753e0..03aa68b 100644 --- a/src/beakers/cli.py +++ b/src/beakers/cli.py @@ -11,7 +11,7 @@ from beakers.beakers import SqliteBeaker app = typer.Typer() -def _load_recipe(dotted_path: str): +def _load_recipe(dotted_path: str) -> SimpleNamespace: sys.path.append(".") path, name = dotted_path.rsplit(".", 1) mod = importlib.import_module(path) @@ -22,7 +22,7 @@ def _load_recipe(dotted_path: str): def main( ctx: typer.Context, recipe: str = typer.Option(None, envvar="BEAKER_RECIPE"), -): +) -> None: if not recipe: typer.secho( "Missing recipe; pass --recipe or set env[BEAKER_RECIPE]", @@ -33,7 +33,7 @@ def main( @app.command() -def reset(ctx: typer.Context): +def reset(ctx: typer.Context) -> None: for beaker in ctx.obj.beakers.values(): if isinstance(beaker, SqliteBeaker): if bl := len(beaker): @@ -44,12 +44,12 @@ def reset(ctx: typer.Context): @app.command() -def show(ctx: typer.Context): +def show(ctx: typer.Context) -> None: ctx.obj.show() @app.command() -def graph(ctx: typer.Context): +def graph(ctx: typer.Context) -> None: pprint(ctx.obj.graph_data()) @@ -59,15 +59,15 @@ def run( input: Annotated[Optional[List[str]], typer.Option(...)] = None, start: Optional[str] = typer.Option(None), end: Optional[str] = typer.Option(None), -): +) -> None: if ctx.obj.seeds: typer.secho("Seeding beakers", fg=typer.colors.GREEN) ctx.obj.process_seeds() has_data = any(ctx.obj.beakers.values()) - if not has_data and not input: + if not input and not has_data: typer.secho("No data; pass --input to seed beaker(s)", fg=typer.colors.RED) raise typer.Exit(1) - for input_str in input: + for input_str in input: # type: ignore beaker, filename = input_str.split("=") ctx.obj.csv_to_beaker(filename, beaker) ctx.obj.run_once(start, end) diff --git a/src/beakers/http.py b/src/beakers/http.py index b313eb9..1f518b5 100644 --- a/src/beakers/http.py +++ b/src/beakers/http.py @@ -1,9 +1,9 @@ import httpx -import pydantic +from pydantic import BaseModel, Field import datetime -class HttpResponse(pydantic.BaseModel): +class HttpResponse(BaseModel): """ Beaker data type that represents an HTTP response. """ @@ -11,9 +11,7 @@ class HttpResponse(pydantic.BaseModel): url: str status_code: int response_body: str - retrieved_at: datetime.datetime = pydantic.Field( - default_factory=datetime.datetime.now - ) + retrieved_at: datetime.datetime = Field(default_factory=datetime.datetime.now) class HttpRequest: @@ -30,9 +28,8 @@ class HttpRequest: self.beaker = beaker self.field = field - async def __call__(self, item) -> HttpResponse: - bkr = getattr(item, self.beaker) - url = getattr(bkr, self.field) + async def __call__(self, item: BaseModel) -> HttpResponse: + url = getattr(item, self.field) async with httpx.AsyncClient() as client: response = await client.get(url) diff --git a/src/beakers/recipe.py b/src/beakers/recipe.py index a8247f4..a8779e9 100644 --- a/src/beakers/recipe.py +++ b/src/beakers/recipe.py @@ -8,7 +8,8 @@ import asyncio import networkx from collections import defaultdict, Counter from dataclasses import dataclass # TODO: pydantic? -from typing import Iterable +from pydantic import BaseModel +from typing import Iterable, Callable from structlog import get_logger from .beakers import Beaker, SqliteBeaker, TempBeaker @@ -24,10 +25,16 @@ def get_sha512(filename: str) -> str: @dataclass(frozen=True, eq=True) class Transform: name: str - transform_func: callable + transform_func: Callable error_map: dict[tuple, str] +class ErrorType(BaseModel): + item: BaseModel + exception: str + exc_type: str + + def if_cond_true(data_cond_tup: tuple[dict, bool]) -> dict | None: return data_cond_tup[0] if data_cond_tup[1] else None @@ -37,11 +44,11 @@ def if_cond_false(data_cond_tup: tuple[dict, bool]) -> dict | None: class Recipe: - def __init__(self, name, db_name="beakers.db"): + def __init__(self, name: str, db_name: str = "beakers.db"): self.name = name self.graph = networkx.DiGraph() - self.beakers = {} - self.seeds = defaultdict(list) + self.beakers: dict[str, Beaker] = {} + self.seeds: defaultdict[str, list[Iterable[BaseModel]]] = defaultdict(list) self.db = sqlite3.connect(db_name) cursor = self.db.cursor() cursor.execute( @@ -63,9 +70,9 @@ class Recipe: self, from_beaker: str, to_beaker: str, - transform_func: callable, + transform_func: Callable, *, - name=None, + name: str | None = None, error_map: dict[tuple, str] | None = None, ) -> None: if name is None: @@ -86,7 +93,7 @@ class Recipe: def add_conditional( self, from_beaker: str, - condition_func: callable, + condition_func: Callable, if_true: str, if_false: str = "", ) -> None: @@ -117,7 +124,7 @@ class Recipe: if_cond_false, ) - def add_seed(self, beaker_name: str, data: Iterable) -> None: + def add_seed(self, beaker_name: str, data: Iterable[BaseModel]) -> None: self.seeds[beaker_name].append(data) def process_seeds(self) -> None: @@ -126,30 +133,30 @@ class Recipe: for seed in seeds: self.beakers[beaker_name].add_items(seed) - def get_metadata(self, table_name) -> dict: - cursor = self.db.cursor() - cursor.execute( - "SELECT data FROM _metadata WHERE table_name = ?", - (table_name,), - ) - try: - data = cursor.fetchone()["data"] - log.debug("get_metadata", table_name=table_name, data=data) - return json.loads(data) - except TypeError: - log.debug("get_metadata", table_name=table_name, data={}) - return {} + # def get_metadata(self, table_name: str) -> dict: + # cursor = self.db.cursor() + # cursor.execute( + # "SELECT data FROM _metadata WHERE table_name = ?", + # (table_name,), + # ) + # try: + # data = cursor.fetchone()["data"] + # log.debug("get_metadata", table_name=table_name, data=data) + # return json.loads(data) + # except TypeError: + # log.debug("get_metadata", table_name=table_name, data={}) + # return {} - def save_metadata(self, table_name: str, data: dict) -> None: - data_json = json.dumps(data) - log.info("save_metadata", table_name=table_name, data=data_json) - # sqlite upsert - cursor = self.db.cursor() - cursor.execute( - "INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?", - (table_name, data_json, data_json), - ) - self.db.commit() + # def save_metadata(self, table_name: str, data: dict) -> None: + # data_json = json.dumps(data) + # log.info("save_metadata", table_name=table_name, data=data_json) + # # sqlite upsert + # cursor = self.db.cursor() + # cursor.execute( + # "INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?", + # (table_name, data_json, data_json), + # ) + # self.db.commit() def csv_to_beaker(self, filename: str, beaker_name: str) -> None: beaker = self.beakers[beaker_name] @@ -161,7 +168,7 @@ class Recipe: reader = csv.DictReader(file) added = 0 for row in reader: - beaker.add_item(row) + beaker.add_item(beaker.model(**row)) added += 1 lg.info("from_csv", case="empty", added=added) meta = self.get_metadata(beaker.name) @@ -177,9 +184,8 @@ class Recipe: else: # case 2: match lg.info("from_csv", case="match") - return beaker - def show(self): + def show(self) -> None: seed_count = Counter(self.seeds.keys()) typer.secho("Seeds", fg=typer.colors.GREEN) for beaker, count in seed_count.items(): @@ -204,7 +210,7 @@ class Recipe: else: typer.secho(f" {k.__name__} -> {v}", fg=typer.colors.RED) - def graph_data(self): + def graph_data(self) -> list[dict]: nodes = {} for node in networkx.topological_sort(self.graph): @@ -290,11 +296,11 @@ class Recipe: if isinstance(e, error_types): error_beaker = self.beakers[error_beaker_name] error_beaker.add_item( - { - "item": item, - "exception": str(e), - "exc_type": str(type(e)), - }, + ErrorType( + item=item, + exception=str(e), + exc_type=str(type(e)), + ), id, ) break