more typing sanity

This commit is contained in:
James Turk 2023-07-11 23:39:01 -05:00
parent b49a1914b5
commit d33483f136
4 changed files with 83 additions and 90 deletions

View File

@ -2,77 +2,67 @@ import abc
import json import json
import sqlite3 import sqlite3
import uuid import uuid
from pydantic import BaseModel
from typing import Iterable, Type, TYPE_CHECKING
if TYPE_CHECKING:
from .recipe import Recipe
# Alias for "a pydantic model class" (the class object itself, not an instance).
PydanticModel = Type[BaseModel]


class DataObject:
    """Write-once attribute bag identified by a string id.

    Public attributes are kept in an internal dict and may be assigned exactly
    once.  Names starting with an underscore are ordinary instance attributes.
    """

    def __init__(self, id: str | None = None):
        # Fall back to a random UUID when no (or an empty) id is supplied.
        self._id = id if id else str(uuid.uuid4())
        self._data: dict = {}

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.  Private names are
        # never stored in _data; refusing them here also prevents infinite
        # recursion when _data itself has not been set yet.
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            return self._data[name]
        except KeyError:
            # Raise AttributeError (not KeyError) so hasattr() and
            # getattr(obj, name, default) behave normally.
            raise AttributeError(f"DataObject has no attribute {name}") from None

    def __setattr__(self, name, value):
        if name.startswith("_"):
            super().__setattr__(name, value)
            return
        if name in self._data:
            raise AttributeError(f"DataObject attribute {name} already exists")
        self._data[name] = value
class Beaker(abc.ABC): class Beaker(abc.ABC):
def __init__(self, name: str, model: type, recipe: "Recipe"): def __init__(self, name: str, model: PydanticModel, recipe: Recipe):
self.name = name self.name = name
self.model = model self.model = model
self.recipe = recipe self.recipe = recipe
def __repr__(self): def __repr__(self) -> str:
return f"Beaker({self.name}, {self.model.__name__})" return f"Beaker({self.name}, {self.model.__name__})"
@abc.abstractmethod @abc.abstractmethod
def items(self): def items(self) -> Iterable[tuple[str, BaseModel]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
def __len__(self): def __len__(self) -> int:
pass pass
@abc.abstractmethod @abc.abstractmethod
def add_item(self, item: "T", id: str | None = None) -> None: def add_item(self, item: BaseModel, id: str | None = None) -> None:
pass pass
@abc.abstractmethod @abc.abstractmethod
def reset(self): def reset(self) -> None:
pass pass
def add_items(self, items: list["T"]) -> None: def add_items(self, items: Iterable[BaseModel]) -> None:
for item in items: for item in items:
self.add_item(item) self.add_item(item)
class TempBeaker(Beaker):
    """Beaker that keeps its ``(id, item)`` pairs in an in-memory list."""

    # ``Recipe`` is only imported under TYPE_CHECKING, so the annotations are
    # strings to keep them from being evaluated at runtime.
    def __init__(self, name: str, model: "PydanticModel | None", recipe: "Recipe"):
        super().__init__(name, model, recipe)
        self._items: list[tuple[str, BaseModel]] = []

    def __len__(self) -> int:
        return len(self._items)

    def add_item(self, item: "BaseModel", id: str | None = None) -> None:
        """Append ``item``; a uuid1-based id is generated when none is given."""
        if id is None:
            id = str(uuid.uuid1())
        self._items.append((id, item))

    def items(self) -> "Iterable[tuple[str, BaseModel]]":
        """Yield the stored ``(id, item)`` pairs in insertion order."""
        yield from self._items

    def reset(self) -> None:
        """Drop all stored items by rebinding to a fresh list."""
        self._items = []
class SqliteBeaker(Beaker): class SqliteBeaker(Beaker):
def __init__(self, name: str, model: type, recipe: "Recipe"): def __init__(self, name: str, model: PydanticModel, recipe: Recipe):
super().__init__(name, model, recipe) super().__init__(name, model, recipe)
# create table if it doesn't exist # create table if it doesn't exist
self.cursor = self.recipe.db.cursor() self.cursor = self.recipe.db.cursor()
@ -81,17 +71,17 @@ class SqliteBeaker(Beaker):
f"CREATE TABLE IF NOT EXISTS {self.name} (uuid TEXT PRIMARY KEY, data JSON)" f"CREATE TABLE IF NOT EXISTS {self.name} (uuid TEXT PRIMARY KEY, data JSON)"
) )
def items(self) -> "Iterable[tuple[str, BaseModel]]":
    """Yield ``(uuid, model instance)`` for every row in this beaker's table.

    Each row's ``data`` column is decoded from JSON and passed to
    ``self.model`` as keyword arguments.
    """
    # The table name is interpolated because SQL parameters cannot bind
    # identifiers.
    self.cursor.execute(f"SELECT uuid, data FROM {self.name}")
    # Rows are indexed by column name, so the connection presumably uses a
    # name-addressable row factory such as sqlite3.Row -- TODO confirm.
    for row in self.cursor.fetchall():
        yield row["uuid"], self.model(**json.loads(row["data"]))
def __len__(self) -> int:
    """Return the number of rows in this beaker's table."""
    # The table name is interpolated because SQL parameters cannot bind
    # identifiers.
    self.cursor.execute(f"SELECT COUNT(*) FROM {self.name}")
    # COUNT(*) always yields exactly one row; its first column is the count.
    return self.cursor.fetchone()[0]
def add_item(self, item: "T", id: str | None = None) -> None: def add_item(self, item: BaseModel, id: str | None = None) -> None:
if id is None: if id is None:
id = str(uuid.uuid1()) id = str(uuid.uuid1())
print("UUID", id, item) print("UUID", id, item)
@ -101,6 +91,6 @@ class SqliteBeaker(Beaker):
) )
self.recipe.db.commit() self.recipe.db.commit()
def reset(self) -> None:
    """Delete every row from this beaker's table and commit the change."""
    # The table name is interpolated because SQL parameters cannot bind
    # identifiers.
    self.cursor.execute(f"DELETE FROM {self.name}")
    self.recipe.db.commit()

View File

@ -11,7 +11,7 @@ from beakers.beakers import SqliteBeaker
app = typer.Typer() app = typer.Typer()
def _load_recipe(dotted_path: str): def _load_recipe(dotted_path: str) -> SimpleNamespace:
sys.path.append(".") sys.path.append(".")
path, name = dotted_path.rsplit(".", 1) path, name = dotted_path.rsplit(".", 1)
mod = importlib.import_module(path) mod = importlib.import_module(path)
@ -22,7 +22,7 @@ def _load_recipe(dotted_path: str):
def main( def main(
ctx: typer.Context, ctx: typer.Context,
recipe: str = typer.Option(None, envvar="BEAKER_RECIPE"), recipe: str = typer.Option(None, envvar="BEAKER_RECIPE"),
): ) -> None:
if not recipe: if not recipe:
typer.secho( typer.secho(
"Missing recipe; pass --recipe or set env[BEAKER_RECIPE]", "Missing recipe; pass --recipe or set env[BEAKER_RECIPE]",
@ -33,7 +33,7 @@ def main(
@app.command() @app.command()
def reset(ctx: typer.Context): def reset(ctx: typer.Context) -> None:
for beaker in ctx.obj.beakers.values(): for beaker in ctx.obj.beakers.values():
if isinstance(beaker, SqliteBeaker): if isinstance(beaker, SqliteBeaker):
if bl := len(beaker): if bl := len(beaker):
@ -44,12 +44,12 @@ def reset(ctx: typer.Context):
@app.command()
def show(ctx: typer.Context) -> None:
    """CLI command: call ``show()`` on the object stored in ``ctx.obj``."""
    ctx.obj.show()
@app.command()
def graph(ctx: typer.Context) -> None:
    """CLI command: pretty-print ``graph_data()`` of the object in ``ctx.obj``."""
    pprint(ctx.obj.graph_data())
@ -59,15 +59,15 @@ def run(
input: Annotated[Optional[List[str]], typer.Option(...)] = None, input: Annotated[Optional[List[str]], typer.Option(...)] = None,
start: Optional[str] = typer.Option(None), start: Optional[str] = typer.Option(None),
end: Optional[str] = typer.Option(None), end: Optional[str] = typer.Option(None),
): ) -> None:
if ctx.obj.seeds: if ctx.obj.seeds:
typer.secho("Seeding beakers", fg=typer.colors.GREEN) typer.secho("Seeding beakers", fg=typer.colors.GREEN)
ctx.obj.process_seeds() ctx.obj.process_seeds()
has_data = any(ctx.obj.beakers.values()) has_data = any(ctx.obj.beakers.values())
if not has_data and not input: if not input and not has_data:
typer.secho("No data; pass --input to seed beaker(s)", fg=typer.colors.RED) typer.secho("No data; pass --input to seed beaker(s)", fg=typer.colors.RED)
raise typer.Exit(1) raise typer.Exit(1)
for input_str in input: for input_str in input: # type: ignore
beaker, filename = input_str.split("=") beaker, filename = input_str.split("=")
ctx.obj.csv_to_beaker(filename, beaker) ctx.obj.csv_to_beaker(filename, beaker)
ctx.obj.run_once(start, end) ctx.obj.run_once(start, end)

View File

@ -1,9 +1,9 @@
import httpx import httpx
import pydantic from pydantic import BaseModel, Field
import datetime import datetime
class HttpResponse(pydantic.BaseModel): class HttpResponse(BaseModel):
""" """
Beaker data type that represents an HTTP response. Beaker data type that represents an HTTP response.
""" """
@ -11,9 +11,7 @@ class HttpResponse(pydantic.BaseModel):
url: str url: str
status_code: int status_code: int
response_body: str response_body: str
retrieved_at: datetime.datetime = pydantic.Field( retrieved_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
default_factory=datetime.datetime.now
)
class HttpRequest: class HttpRequest:
@ -30,9 +28,8 @@ class HttpRequest:
self.beaker = beaker self.beaker = beaker
self.field = field self.field = field
async def __call__(self, item) -> HttpResponse: async def __call__(self, item: BaseModel) -> HttpResponse:
bkr = getattr(item, self.beaker) url = getattr(item, self.field)
url = getattr(bkr, self.field)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(url) response = await client.get(url)

View File

@ -8,7 +8,8 @@ import asyncio
import networkx import networkx
from collections import defaultdict, Counter from collections import defaultdict, Counter
from dataclasses import dataclass # TODO: pydantic? from dataclasses import dataclass # TODO: pydantic?
from typing import Iterable from pydantic import BaseModel
from typing import Iterable, Callable
from structlog import get_logger from structlog import get_logger
from .beakers import Beaker, SqliteBeaker, TempBeaker from .beakers import Beaker, SqliteBeaker, TempBeaker
@ -24,10 +25,16 @@ def get_sha512(filename: str) -> str:
@dataclass(frozen=True, eq=True)
class Transform:
    """Immutable record describing one named transform.

    NOTE: ``frozen=True`` with ``eq=True`` generates a field-based
    ``__hash__``; because ``error_map`` is a dict, calling ``hash()`` on an
    instance raises ``TypeError``.
    """

    # Name of the transform.
    name: str
    # The callable itself (``typing.Callable``; the builtin ``callable`` is a
    # function, not a type).
    transform_func: Callable
    # Maps a tuple key to a string; presumably exception types -> name of the
    # beaker that receives failures -- TODO confirm against the caller.
    error_map: dict[tuple, str]
class ErrorType(BaseModel):
    """Pydantic record pairing an item with details of an exception."""

    # The item associated with the error.
    item: BaseModel
    # The exception rendered as a string.
    exception: str
    # String form of the exception's type.
    exc_type: str
def if_cond_true(data_cond_tup: tuple[dict, bool]) -> dict | None: def if_cond_true(data_cond_tup: tuple[dict, bool]) -> dict | None:
return data_cond_tup[0] if data_cond_tup[1] else None return data_cond_tup[0] if data_cond_tup[1] else None
@ -37,11 +44,11 @@ def if_cond_false(data_cond_tup: tuple[dict, bool]) -> dict | None:
class Recipe: class Recipe:
def __init__(self, name, db_name="beakers.db"): def __init__(self, name: str, db_name: str = "beakers.db"):
self.name = name self.name = name
self.graph = networkx.DiGraph() self.graph = networkx.DiGraph()
self.beakers = {} self.beakers: dict[str, Beaker] = {}
self.seeds = defaultdict(list) self.seeds: defaultdict[str, list[Iterable[BaseModel]]] = defaultdict(list)
self.db = sqlite3.connect(db_name) self.db = sqlite3.connect(db_name)
cursor = self.db.cursor() cursor = self.db.cursor()
cursor.execute( cursor.execute(
@ -63,9 +70,9 @@ class Recipe:
self, self,
from_beaker: str, from_beaker: str,
to_beaker: str, to_beaker: str,
transform_func: callable, transform_func: Callable,
*, *,
name=None, name: str | None = None,
error_map: dict[tuple, str] | None = None, error_map: dict[tuple, str] | None = None,
) -> None: ) -> None:
if name is None: if name is None:
@ -86,7 +93,7 @@ class Recipe:
def add_conditional( def add_conditional(
self, self,
from_beaker: str, from_beaker: str,
condition_func: callable, condition_func: Callable,
if_true: str, if_true: str,
if_false: str = "", if_false: str = "",
) -> None: ) -> None:
@ -117,7 +124,7 @@ class Recipe:
if_cond_false, if_cond_false,
) )
def add_seed(self, beaker_name: str, data: "Iterable[BaseModel]") -> None:
    """Register ``data`` as one seed iterable for the beaker ``beaker_name``.

    The iterable is stored as-is (not consumed) and appended to the list of
    seeds already registered under that name.
    """
    self.seeds[beaker_name].append(data)
def process_seeds(self) -> None: def process_seeds(self) -> None:
@ -126,30 +133,30 @@ class Recipe:
for seed in seeds: for seed in seeds:
self.beakers[beaker_name].add_items(seed) self.beakers[beaker_name].add_items(seed)
def get_metadata(self, table_name: str) -> dict:
    """Return the metadata stored for ``table_name``, or ``{}`` if unavailable.

    Kept as live code rather than commented out because ``csv_to_beaker``
    still calls ``self.get_metadata``.
    """
    cursor = self.db.cursor()
    cursor.execute(
        "SELECT data FROM _metadata WHERE table_name = ?",
        (table_name,),
    )
    try:
        # fetchone() returns None when no row matches; subscripting None
        # raises TypeError, which is the "no metadata yet" path below.
        data = cursor.fetchone()["data"]
        log.debug("get_metadata", table_name=table_name, data=data)
        return json.loads(data)
    except TypeError:
        log.debug("get_metadata", table_name=table_name, data={})
        return {}
def save_metadata(self, table_name: str, data: dict) -> None:
    """Store ``data`` as JSON metadata for ``table_name`` and commit.

    Inserts a new ``_metadata`` row, or overwrites the ``data`` column when a
    row for ``table_name`` already exists.
    """
    data_json = json.dumps(data)
    log.info("save_metadata", table_name=table_name, data=data_json)
    # sqlite upsert: data_json is bound twice, once for the INSERT values and
    # once for the UPDATE branch.
    cursor = self.db.cursor()
    cursor.execute(
        "INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?",
        (table_name, data_json, data_json),
    )
    self.db.commit()
def csv_to_beaker(self, filename: str, beaker_name: str) -> None: def csv_to_beaker(self, filename: str, beaker_name: str) -> None:
beaker = self.beakers[beaker_name] beaker = self.beakers[beaker_name]
@ -161,7 +168,7 @@ class Recipe:
reader = csv.DictReader(file) reader = csv.DictReader(file)
added = 0 added = 0
for row in reader: for row in reader:
beaker.add_item(row) beaker.add_item(beaker.model(**row))
added += 1 added += 1
lg.info("from_csv", case="empty", added=added) lg.info("from_csv", case="empty", added=added)
meta = self.get_metadata(beaker.name) meta = self.get_metadata(beaker.name)
@ -177,9 +184,8 @@ class Recipe:
else: else:
# case 2: match # case 2: match
lg.info("from_csv", case="match") lg.info("from_csv", case="match")
return beaker
def show(self): def show(self) -> None:
seed_count = Counter(self.seeds.keys()) seed_count = Counter(self.seeds.keys())
typer.secho("Seeds", fg=typer.colors.GREEN) typer.secho("Seeds", fg=typer.colors.GREEN)
for beaker, count in seed_count.items(): for beaker, count in seed_count.items():
@ -204,7 +210,7 @@ class Recipe:
else: else:
typer.secho(f" {k.__name__} -> {v}", fg=typer.colors.RED) typer.secho(f" {k.__name__} -> {v}", fg=typer.colors.RED)
def graph_data(self): def graph_data(self) -> list[dict]:
nodes = {} nodes = {}
for node in networkx.topological_sort(self.graph): for node in networkx.topological_sort(self.graph):
@ -290,11 +296,11 @@ class Recipe:
if isinstance(e, error_types): if isinstance(e, error_types):
error_beaker = self.beakers[error_beaker_name] error_beaker = self.beakers[error_beaker_name]
error_beaker.add_item( error_beaker.add_item(
{ ErrorType(
"item": item, item=item,
"exception": str(e), exception=str(e),
"exc_type": str(type(e)), exc_type=str(type(e)),
}, ),
id, id,
) )
break break