more typing sanity

This commit is contained in:
James Turk 2023-07-11 23:39:01 -05:00
parent b49a1914b5
commit d33483f136
4 changed files with 83 additions and 90 deletions

View File

@ -2,77 +2,67 @@ import abc
import json
import sqlite3
import uuid
from pydantic import BaseModel
from typing import Iterable, Type, TYPE_CHECKING
if TYPE_CHECKING:
from .recipe import Recipe
class DataObject:
def __init__(self, id: str | None = None):
self._id = id if id else str(uuid.uuid4())
self._data = {}
def __getattr__(self, name):
return self._data[name]
def __setattr__(self, name, value):
if name.startswith("_"):
super().__setattr__(name, value)
elif name not in self._data:
self._data[name] = value
else:
raise AttributeError(f"DataObject attribute {name} already exists")
PydanticModel = Type[BaseModel]
class Beaker(abc.ABC):
def __init__(self, name: str, model: type, recipe: "Recipe"):
def __init__(self, name: str, model: PydanticModel, recipe: Recipe):
self.name = name
self.model = model
self.recipe = recipe
def __repr__(self):
def __repr__(self) -> str:
return f"Beaker({self.name}, {self.model.__name__})"
@abc.abstractmethod
def items(self):
def items(self) -> Iterable[tuple[str, BaseModel]]:
pass
@abc.abstractmethod
def __len__(self):
def __len__(self) -> int:
pass
@abc.abstractmethod
def add_item(self, item: "T", id: str | None = None) -> None:
def add_item(self, item: BaseModel, id: str | None = None) -> None:
pass
@abc.abstractmethod
def reset(self):
def reset(self) -> None:
pass
def add_items(self, items: list["T"]) -> None:
def add_items(self, items: Iterable[BaseModel]) -> None:
for item in items:
self.add_item(item)
class TempBeaker(Beaker):
def __init__(self, name: str, model: type, recipe: "Recipe"):
def __init__(self, name: str, model: PydanticModel | None, recipe: Recipe):
super().__init__(name, model, recipe)
self._items = []
self._items: list[tuple[str, BaseModel]] = []
def __len__(self):
def __len__(self) -> int:
return len(self._items)
def add_item(self, item: "T", id=None) -> None:
def add_item(self, item: BaseModel, id: str | None = None) -> None:
if id is None:
id = str(uuid.uuid1())
self._items.append((id, item))
def items(self):
def items(self) -> Iterable[tuple[str, BaseModel]]:
yield from self._items
def reset(self):
def reset(self) -> None:
self._items = []
class SqliteBeaker(Beaker):
def __init__(self, name: str, model: type, recipe: "Recipe"):
def __init__(self, name: str, model: PydanticModel, recipe: Recipe):
super().__init__(name, model, recipe)
# create table if it doesn't exist
self.cursor = self.recipe.db.cursor()
@ -81,17 +71,17 @@ class SqliteBeaker(Beaker):
f"CREATE TABLE IF NOT EXISTS {self.name} (uuid TEXT PRIMARY KEY, data JSON)"
)
def items(self):
def items(self) -> Iterable[tuple[str, BaseModel]]:
self.cursor.execute(f"SELECT uuid, data FROM {self.name}")
data = self.cursor.fetchall()
for item in data:
yield item["uuid"], self.model(**json.loads(item["data"]))
def __len__(self):
def __len__(self) -> int:
self.cursor.execute(f"SELECT COUNT(*) FROM {self.name}")
return self.cursor.fetchone()[0]
def add_item(self, item: "T", id: str | None = None) -> None:
def add_item(self, item: BaseModel, id: str | None = None) -> None:
if id is None:
id = str(uuid.uuid1())
print("UUID", id, item)
@ -101,6 +91,6 @@ class SqliteBeaker(Beaker):
)
self.recipe.db.commit()
def reset(self):
def reset(self) -> None:
self.cursor.execute(f"DELETE FROM {self.name}")
self.recipe.db.commit()

View File

@ -11,7 +11,7 @@ from beakers.beakers import SqliteBeaker
app = typer.Typer()
def _load_recipe(dotted_path: str):
def _load_recipe(dotted_path: str) -> SimpleNamespace:
sys.path.append(".")
path, name = dotted_path.rsplit(".", 1)
mod = importlib.import_module(path)
@ -22,7 +22,7 @@ def _load_recipe(dotted_path: str):
def main(
ctx: typer.Context,
recipe: str = typer.Option(None, envvar="BEAKER_RECIPE"),
):
) -> None:
if not recipe:
typer.secho(
"Missing recipe; pass --recipe or set env[BEAKER_RECIPE]",
@ -33,7 +33,7 @@ def main(
@app.command()
def reset(ctx: typer.Context):
def reset(ctx: typer.Context) -> None:
for beaker in ctx.obj.beakers.values():
if isinstance(beaker, SqliteBeaker):
if bl := len(beaker):
@ -44,12 +44,12 @@ def reset(ctx: typer.Context):
@app.command()
def show(ctx: typer.Context):
def show(ctx: typer.Context) -> None:
ctx.obj.show()
@app.command()
def graph(ctx: typer.Context):
def graph(ctx: typer.Context) -> None:
pprint(ctx.obj.graph_data())
@ -59,15 +59,15 @@ def run(
input: Annotated[Optional[List[str]], typer.Option(...)] = None,
start: Optional[str] = typer.Option(None),
end: Optional[str] = typer.Option(None),
):
) -> None:
if ctx.obj.seeds:
typer.secho("Seeding beakers", fg=typer.colors.GREEN)
ctx.obj.process_seeds()
has_data = any(ctx.obj.beakers.values())
if not has_data and not input:
if not input and not has_data:
typer.secho("No data; pass --input to seed beaker(s)", fg=typer.colors.RED)
raise typer.Exit(1)
for input_str in input:
for input_str in input: # type: ignore
beaker, filename = input_str.split("=")
ctx.obj.csv_to_beaker(filename, beaker)
ctx.obj.run_once(start, end)

View File

@ -1,9 +1,9 @@
import httpx
import pydantic
from pydantic import BaseModel, Field
import datetime
class HttpResponse(pydantic.BaseModel):
class HttpResponse(BaseModel):
"""
Beaker data type that represents an HTTP response.
"""
@ -11,9 +11,7 @@ class HttpResponse(pydantic.BaseModel):
url: str
status_code: int
response_body: str
retrieved_at: datetime.datetime = pydantic.Field(
default_factory=datetime.datetime.now
)
retrieved_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
class HttpRequest:
@ -30,9 +28,8 @@ class HttpRequest:
self.beaker = beaker
self.field = field
async def __call__(self, item) -> HttpResponse:
bkr = getattr(item, self.beaker)
url = getattr(bkr, self.field)
async def __call__(self, item: BaseModel) -> HttpResponse:
url = getattr(item, self.field)
async with httpx.AsyncClient() as client:
response = await client.get(url)

View File

@ -8,7 +8,8 @@ import asyncio
import networkx
from collections import defaultdict, Counter
from dataclasses import dataclass # TODO: pydantic?
from typing import Iterable
from pydantic import BaseModel
from typing import Iterable, Callable
from structlog import get_logger
from .beakers import Beaker, SqliteBeaker, TempBeaker
@ -24,10 +25,16 @@ def get_sha512(filename: str) -> str:
@dataclass(frozen=True, eq=True)
class Transform:
name: str
transform_func: callable
transform_func: Callable
error_map: dict[tuple, str]
class ErrorType(BaseModel):
item: BaseModel
exception: str
exc_type: str
def if_cond_true(data_cond_tup: tuple[dict, bool]) -> dict | None:
return data_cond_tup[0] if data_cond_tup[1] else None
@ -37,11 +44,11 @@ def if_cond_false(data_cond_tup: tuple[dict, bool]) -> dict | None:
class Recipe:
def __init__(self, name, db_name="beakers.db"):
def __init__(self, name: str, db_name: str = "beakers.db"):
self.name = name
self.graph = networkx.DiGraph()
self.beakers = {}
self.seeds = defaultdict(list)
self.beakers: dict[str, Beaker] = {}
self.seeds: defaultdict[str, list[Iterable[BaseModel]]] = defaultdict(list)
self.db = sqlite3.connect(db_name)
cursor = self.db.cursor()
cursor.execute(
@ -63,9 +70,9 @@ class Recipe:
self,
from_beaker: str,
to_beaker: str,
transform_func: callable,
transform_func: Callable,
*,
name=None,
name: str | None = None,
error_map: dict[tuple, str] | None = None,
) -> None:
if name is None:
@ -86,7 +93,7 @@ class Recipe:
def add_conditional(
self,
from_beaker: str,
condition_func: callable,
condition_func: Callable,
if_true: str,
if_false: str = "",
) -> None:
@ -117,7 +124,7 @@ class Recipe:
if_cond_false,
)
def add_seed(self, beaker_name: str, data: Iterable) -> None:
def add_seed(self, beaker_name: str, data: Iterable[BaseModel]) -> None:
self.seeds[beaker_name].append(data)
def process_seeds(self) -> None:
@ -126,30 +133,30 @@ class Recipe:
for seed in seeds:
self.beakers[beaker_name].add_items(seed)
def get_metadata(self, table_name) -> dict:
cursor = self.db.cursor()
cursor.execute(
"SELECT data FROM _metadata WHERE table_name = ?",
(table_name,),
)
try:
data = cursor.fetchone()["data"]
log.debug("get_metadata", table_name=table_name, data=data)
return json.loads(data)
except TypeError:
log.debug("get_metadata", table_name=table_name, data={})
return {}
# def get_metadata(self, table_name: str) -> dict:
# cursor = self.db.cursor()
# cursor.execute(
# "SELECT data FROM _metadata WHERE table_name = ?",
# (table_name,),
# )
# try:
# data = cursor.fetchone()["data"]
# log.debug("get_metadata", table_name=table_name, data=data)
# return json.loads(data)
# except TypeError:
# log.debug("get_metadata", table_name=table_name, data={})
# return {}
def save_metadata(self, table_name: str, data: dict) -> None:
data_json = json.dumps(data)
log.info("save_metadata", table_name=table_name, data=data_json)
# sqlite upsert
cursor = self.db.cursor()
cursor.execute(
"INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?",
(table_name, data_json, data_json),
)
self.db.commit()
# def save_metadata(self, table_name: str, data: dict) -> None:
# data_json = json.dumps(data)
# log.info("save_metadata", table_name=table_name, data=data_json)
# # sqlite upsert
# cursor = self.db.cursor()
# cursor.execute(
# "INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?",
# (table_name, data_json, data_json),
# )
# self.db.commit()
def csv_to_beaker(self, filename: str, beaker_name: str) -> None:
beaker = self.beakers[beaker_name]
@ -161,7 +168,7 @@ class Recipe:
reader = csv.DictReader(file)
added = 0
for row in reader:
beaker.add_item(row)
beaker.add_item(beaker.model(**row))
added += 1
lg.info("from_csv", case="empty", added=added)
meta = self.get_metadata(beaker.name)
@ -177,9 +184,8 @@ class Recipe:
else:
# case 2: match
lg.info("from_csv", case="match")
return beaker
def show(self):
def show(self) -> None:
seed_count = Counter(self.seeds.keys())
typer.secho("Seeds", fg=typer.colors.GREEN)
for beaker, count in seed_count.items():
@ -204,7 +210,7 @@ class Recipe:
else:
typer.secho(f" {k.__name__} -> {v}", fg=typer.colors.RED)
def graph_data(self):
def graph_data(self) -> list[dict]:
nodes = {}
for node in networkx.topological_sort(self.graph):
@ -290,11 +296,11 @@ class Recipe:
if isinstance(e, error_types):
error_beaker = self.beakers[error_beaker_name]
error_beaker.add_item(
{
"item": item,
"exception": str(e),
"exc_type": str(type(e)),
},
ErrorType(
item=item,
exception=str(e),
exc_type=str(type(e)),
),
id,
)
break