From 0998fe1e6a05496653b85fa66fd559b3de6de447 Mon Sep 17 00:00:00 2001 From: James Turk Date: Tue, 11 Jul 2023 23:53:18 -0500 Subject: [PATCH] getting close on types --- src/beakers/beakers.py | 8 +++--- src/beakers/cli.py | 2 +- src/beakers/recipe.py | 65 +++++++++++++++++++++++------------------- 3 files changed, 40 insertions(+), 35 deletions(-) diff --git a/src/beakers/beakers.py b/src/beakers/beakers.py index 8acbadd..07ae9e2 100644 --- a/src/beakers/beakers.py +++ b/src/beakers/beakers.py @@ -12,7 +12,7 @@ PydanticModel = Type[BaseModel] class Beaker(abc.ABC): - def __init__(self, name: str, model: PydanticModel, recipe: Recipe): + def __init__(self, name: str, model: PydanticModel, recipe: "Recipe"): self.name = name self.model = model self.recipe = recipe @@ -42,7 +42,7 @@ class Beaker(abc.ABC): class TempBeaker(Beaker): - def __init__(self, name: str, model: PydanticModel | None, recipe: Recipe): + def __init__(self, name: str, model: PydanticModel, recipe: "Recipe"): super().__init__(name, model, recipe) self._items: list[tuple[str, BaseModel]] = [] @@ -62,11 +62,11 @@ class TempBeaker(Beaker): class SqliteBeaker(Beaker): - def __init__(self, name: str, model: PydanticModel, recipe: Recipe): + def __init__(self, name: str, model: PydanticModel, recipe: "Recipe"): super().__init__(name, model, recipe) # create table if it doesn't exist self.cursor = self.recipe.db.cursor() - self.cursor.row_factory = sqlite3.Row + self.cursor.row_factory = sqlite3.Row # type: ignore self.cursor.execute( f"CREATE TABLE IF NOT EXISTS {self.name} (uuid TEXT PRIMARY KEY, data JSON)" ) diff --git a/src/beakers/cli.py b/src/beakers/cli.py index 03aa68b..94bdde1 100644 --- a/src/beakers/cli.py +++ b/src/beakers/cli.py @@ -67,7 +67,7 @@ def run( if not input and not has_data: typer.secho("No data; pass --input to seed beaker(s)", fg=typer.colors.RED) raise typer.Exit(1) - for input_str in input: # type: ignore + for input_str in input or []: beaker, filename = input_str.split("=") ctx.obj.csv_to_beaker(filename, beaker) ctx.obj.run_once(start, end) diff --git a/src/beakers/recipe.py b/src/beakers/recipe.py index a8779e9..cef4232 100644 --- a/src/beakers/recipe.py +++ b/src/beakers/recipe.py @@ -5,11 +5,10 @@ import inspect import sqlite3 import hashlib import asyncio -import networkx +import networkx # type: ignore from collections import defaultdict, Counter -from dataclasses import dataclass # TODO: pydantic? -from pydantic import BaseModel -from typing import Iterable, Callable +from typing import Iterable, Callable, Type +from pydantic import BaseModel, ConfigDict from structlog import get_logger from .beakers import Beaker, SqliteBeaker, TempBeaker @@ -22,8 +21,9 @@ def get_sha512(filename: str) -> str: return hashlib.sha512(file.read()).hexdigest() -@dataclass(frozen=True, eq=True) -class Transform: +class Transform(BaseModel): + model_config = ConfigDict(frozen=True) + name: str transform_func: Callable error_map: dict[tuple, str] @@ -58,7 +58,12 @@ class Recipe: def __repr__(self) -> str: return f"Recipe({self.name})" - def add_beaker(self, name: str, datatype: type | None) -> Beaker: + def add_beaker( + self, + name: str, + datatype: Type[BaseModel], + beaker_type: Type[Beaker] = SqliteBeaker, + ) -> Beaker: self.graph.add_node(name, datatype=datatype) if datatype is None: self.beakers[name] = TempBeaker(name, datatype, self) @@ -133,30 +138,30 @@ class Recipe: for seed in seeds: self.beakers[beaker_name].add_items(seed) - # def get_metadata(self, table_name: str) -> dict: - # cursor = self.db.cursor() - # cursor.execute( - # "SELECT data FROM _metadata WHERE table_name = ?", - # (table_name,), - # ) - # try: - # data = cursor.fetchone()["data"] - # log.debug("get_metadata", table_name=table_name, data=data) - # return json.loads(data) - # except TypeError: - # log.debug("get_metadata", table_name=table_name, data={}) - # return {} + def get_metadata(self, table_name: str) -> dict: + cursor = self.db.cursor() + cursor.execute( + "SELECT data FROM _metadata WHERE table_name = ?", + (table_name,), + ) + try: + data = cursor.fetchone()["data"] + log.debug("get_metadata", table_name=table_name, data=data) + return json.loads(data) + except TypeError: + log.debug("get_metadata", table_name=table_name, data={}) + return {} - # def save_metadata(self, table_name: str, data: dict) -> None: - # data_json = json.dumps(data) - # log.info("save_metadata", table_name=table_name, data=data_json) - # # sqlite upsert - # cursor = self.db.cursor() - # cursor.execute( - # "INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?", - # (table_name, data_json, data_json), - # ) - # self.db.commit() + def save_metadata(self, table_name: str, data: dict) -> None: + data_json = json.dumps(data) + log.info("save_metadata", table_name=table_name, data=data_json) + # sqlite upsert + cursor = self.db.cursor() + cursor.execute( + "INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?", + (table_name, data_json, data_json), + ) + self.db.commit() def csv_to_beaker(self, filename: str, beaker_name: str) -> None: beaker = self.beakers[beaker_name]