more typing sanity
This commit is contained in:
parent
b49a1914b5
commit
d33483f136
@ -2,77 +2,67 @@ import abc
|
||||
import json
|
||||
import sqlite3
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from typing import Iterable, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .recipe import Recipe
|
||||
|
||||
class DataObject:
|
||||
def __init__(self, id: str | None = None):
|
||||
self._id = id if id else str(uuid.uuid4())
|
||||
self._data = {}
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self._data[name]
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith("_"):
|
||||
super().__setattr__(name, value)
|
||||
elif name not in self._data:
|
||||
self._data[name] = value
|
||||
else:
|
||||
raise AttributeError(f"DataObject attribute {name} already exists")
|
||||
PydanticModel = Type[BaseModel]
|
||||
|
||||
|
||||
class Beaker(abc.ABC):
|
||||
def __init__(self, name: str, model: type, recipe: "Recipe"):
|
||||
def __init__(self, name: str, model: PydanticModel, recipe: Recipe):
|
||||
self.name = name
|
||||
self.model = model
|
||||
self.recipe = recipe
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Beaker({self.name}, {self.model.__name__})"
|
||||
|
||||
@abc.abstractmethod
|
||||
def items(self):
|
||||
def items(self) -> Iterable[tuple[str, BaseModel]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_item(self, item: "T", id: str | None = None) -> None:
|
||||
def add_item(self, item: BaseModel, id: str | None = None) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def add_items(self, items: list["T"]) -> None:
|
||||
def add_items(self, items: Iterable[BaseModel]) -> None:
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
|
||||
|
||||
class TempBeaker(Beaker):
|
||||
def __init__(self, name: str, model: type, recipe: "Recipe"):
|
||||
def __init__(self, name: str, model: PydanticModel | None, recipe: Recipe):
|
||||
super().__init__(name, model, recipe)
|
||||
self._items = []
|
||||
self._items: list[tuple[str, BaseModel]] = []
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self._items)
|
||||
|
||||
def add_item(self, item: "T", id=None) -> None:
|
||||
def add_item(self, item: BaseModel, id: str | None = None) -> None:
|
||||
if id is None:
|
||||
id = str(uuid.uuid1())
|
||||
self._items.append((id, item))
|
||||
|
||||
def items(self):
|
||||
def items(self) -> Iterable[tuple[str, BaseModel]]:
|
||||
yield from self._items
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self._items = []
|
||||
|
||||
|
||||
class SqliteBeaker(Beaker):
|
||||
def __init__(self, name: str, model: type, recipe: "Recipe"):
|
||||
def __init__(self, name: str, model: PydanticModel, recipe: Recipe):
|
||||
super().__init__(name, model, recipe)
|
||||
# create table if it doesn't exist
|
||||
self.cursor = self.recipe.db.cursor()
|
||||
@ -81,17 +71,17 @@ class SqliteBeaker(Beaker):
|
||||
f"CREATE TABLE IF NOT EXISTS {self.name} (uuid TEXT PRIMARY KEY, data JSON)"
|
||||
)
|
||||
|
||||
def items(self):
|
||||
def items(self) -> Iterable[tuple[str, BaseModel]]:
|
||||
self.cursor.execute(f"SELECT uuid, data FROM {self.name}")
|
||||
data = self.cursor.fetchall()
|
||||
for item in data:
|
||||
yield item["uuid"], self.model(**json.loads(item["data"]))
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
self.cursor.execute(f"SELECT COUNT(*) FROM {self.name}")
|
||||
return self.cursor.fetchone()[0]
|
||||
|
||||
def add_item(self, item: "T", id: str | None = None) -> None:
|
||||
def add_item(self, item: BaseModel, id: str | None = None) -> None:
|
||||
if id is None:
|
||||
id = str(uuid.uuid1())
|
||||
print("UUID", id, item)
|
||||
@ -101,6 +91,6 @@ class SqliteBeaker(Beaker):
|
||||
)
|
||||
self.recipe.db.commit()
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.cursor.execute(f"DELETE FROM {self.name}")
|
||||
self.recipe.db.commit()
|
||||
|
@ -11,7 +11,7 @@ from beakers.beakers import SqliteBeaker
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def _load_recipe(dotted_path: str):
|
||||
def _load_recipe(dotted_path: str) -> SimpleNamespace:
|
||||
sys.path.append(".")
|
||||
path, name = dotted_path.rsplit(".", 1)
|
||||
mod = importlib.import_module(path)
|
||||
@ -22,7 +22,7 @@ def _load_recipe(dotted_path: str):
|
||||
def main(
|
||||
ctx: typer.Context,
|
||||
recipe: str = typer.Option(None, envvar="BEAKER_RECIPE"),
|
||||
):
|
||||
) -> None:
|
||||
if not recipe:
|
||||
typer.secho(
|
||||
"Missing recipe; pass --recipe or set env[BEAKER_RECIPE]",
|
||||
@ -33,7 +33,7 @@ def main(
|
||||
|
||||
|
||||
@app.command()
|
||||
def reset(ctx: typer.Context):
|
||||
def reset(ctx: typer.Context) -> None:
|
||||
for beaker in ctx.obj.beakers.values():
|
||||
if isinstance(beaker, SqliteBeaker):
|
||||
if bl := len(beaker):
|
||||
@ -44,12 +44,12 @@ def reset(ctx: typer.Context):
|
||||
|
||||
|
||||
@app.command()
|
||||
def show(ctx: typer.Context):
|
||||
def show(ctx: typer.Context) -> None:
|
||||
ctx.obj.show()
|
||||
|
||||
|
||||
@app.command()
|
||||
def graph(ctx: typer.Context):
|
||||
def graph(ctx: typer.Context) -> None:
|
||||
pprint(ctx.obj.graph_data())
|
||||
|
||||
|
||||
@ -59,15 +59,15 @@ def run(
|
||||
input: Annotated[Optional[List[str]], typer.Option(...)] = None,
|
||||
start: Optional[str] = typer.Option(None),
|
||||
end: Optional[str] = typer.Option(None),
|
||||
):
|
||||
) -> None:
|
||||
if ctx.obj.seeds:
|
||||
typer.secho("Seeding beakers", fg=typer.colors.GREEN)
|
||||
ctx.obj.process_seeds()
|
||||
has_data = any(ctx.obj.beakers.values())
|
||||
if not has_data and not input:
|
||||
if not input and not has_data:
|
||||
typer.secho("No data; pass --input to seed beaker(s)", fg=typer.colors.RED)
|
||||
raise typer.Exit(1)
|
||||
for input_str in input:
|
||||
for input_str in input: # type: ignore
|
||||
beaker, filename = input_str.split("=")
|
||||
ctx.obj.csv_to_beaker(filename, beaker)
|
||||
ctx.obj.run_once(start, end)
|
||||
|
@ -1,9 +1,9 @@
|
||||
import httpx
|
||||
import pydantic
|
||||
from pydantic import BaseModel, Field
|
||||
import datetime
|
||||
|
||||
|
||||
class HttpResponse(pydantic.BaseModel):
|
||||
class HttpResponse(BaseModel):
|
||||
"""
|
||||
Beaker data type that represents an HTTP response.
|
||||
"""
|
||||
@ -11,9 +11,7 @@ class HttpResponse(pydantic.BaseModel):
|
||||
url: str
|
||||
status_code: int
|
||||
response_body: str
|
||||
retrieved_at: datetime.datetime = pydantic.Field(
|
||||
default_factory=datetime.datetime.now
|
||||
)
|
||||
retrieved_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||
|
||||
|
||||
class HttpRequest:
|
||||
@ -30,9 +28,8 @@ class HttpRequest:
|
||||
self.beaker = beaker
|
||||
self.field = field
|
||||
|
||||
async def __call__(self, item) -> HttpResponse:
|
||||
bkr = getattr(item, self.beaker)
|
||||
url = getattr(bkr, self.field)
|
||||
async def __call__(self, item: BaseModel) -> HttpResponse:
|
||||
url = getattr(item, self.field)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
|
@ -8,7 +8,8 @@ import asyncio
|
||||
import networkx
|
||||
from collections import defaultdict, Counter
|
||||
from dataclasses import dataclass # TODO: pydantic?
|
||||
from typing import Iterable
|
||||
from pydantic import BaseModel
|
||||
from typing import Iterable, Callable
|
||||
from structlog import get_logger
|
||||
|
||||
from .beakers import Beaker, SqliteBeaker, TempBeaker
|
||||
@ -24,10 +25,16 @@ def get_sha512(filename: str) -> str:
|
||||
@dataclass(frozen=True, eq=True)
|
||||
class Transform:
|
||||
name: str
|
||||
transform_func: callable
|
||||
transform_func: Callable
|
||||
error_map: dict[tuple, str]
|
||||
|
||||
|
||||
class ErrorType(BaseModel):
|
||||
item: BaseModel
|
||||
exception: str
|
||||
exc_type: str
|
||||
|
||||
|
||||
def if_cond_true(data_cond_tup: tuple[dict, bool]) -> dict | None:
|
||||
return data_cond_tup[0] if data_cond_tup[1] else None
|
||||
|
||||
@ -37,11 +44,11 @@ def if_cond_false(data_cond_tup: tuple[dict, bool]) -> dict | None:
|
||||
|
||||
|
||||
class Recipe:
|
||||
def __init__(self, name, db_name="beakers.db"):
|
||||
def __init__(self, name: str, db_name: str = "beakers.db"):
|
||||
self.name = name
|
||||
self.graph = networkx.DiGraph()
|
||||
self.beakers = {}
|
||||
self.seeds = defaultdict(list)
|
||||
self.beakers: dict[str, Beaker] = {}
|
||||
self.seeds: defaultdict[str, list[Iterable[BaseModel]]] = defaultdict(list)
|
||||
self.db = sqlite3.connect(db_name)
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
@ -63,9 +70,9 @@ class Recipe:
|
||||
self,
|
||||
from_beaker: str,
|
||||
to_beaker: str,
|
||||
transform_func: callable,
|
||||
transform_func: Callable,
|
||||
*,
|
||||
name=None,
|
||||
name: str | None = None,
|
||||
error_map: dict[tuple, str] | None = None,
|
||||
) -> None:
|
||||
if name is None:
|
||||
@ -86,7 +93,7 @@ class Recipe:
|
||||
def add_conditional(
|
||||
self,
|
||||
from_beaker: str,
|
||||
condition_func: callable,
|
||||
condition_func: Callable,
|
||||
if_true: str,
|
||||
if_false: str = "",
|
||||
) -> None:
|
||||
@ -117,7 +124,7 @@ class Recipe:
|
||||
if_cond_false,
|
||||
)
|
||||
|
||||
def add_seed(self, beaker_name: str, data: Iterable) -> None:
|
||||
def add_seed(self, beaker_name: str, data: Iterable[BaseModel]) -> None:
|
||||
self.seeds[beaker_name].append(data)
|
||||
|
||||
def process_seeds(self) -> None:
|
||||
@ -126,30 +133,30 @@ class Recipe:
|
||||
for seed in seeds:
|
||||
self.beakers[beaker_name].add_items(seed)
|
||||
|
||||
def get_metadata(self, table_name) -> dict:
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
"SELECT data FROM _metadata WHERE table_name = ?",
|
||||
(table_name,),
|
||||
)
|
||||
try:
|
||||
data = cursor.fetchone()["data"]
|
||||
log.debug("get_metadata", table_name=table_name, data=data)
|
||||
return json.loads(data)
|
||||
except TypeError:
|
||||
log.debug("get_metadata", table_name=table_name, data={})
|
||||
return {}
|
||||
# def get_metadata(self, table_name: str) -> dict:
|
||||
# cursor = self.db.cursor()
|
||||
# cursor.execute(
|
||||
# "SELECT data FROM _metadata WHERE table_name = ?",
|
||||
# (table_name,),
|
||||
# )
|
||||
# try:
|
||||
# data = cursor.fetchone()["data"]
|
||||
# log.debug("get_metadata", table_name=table_name, data=data)
|
||||
# return json.loads(data)
|
||||
# except TypeError:
|
||||
# log.debug("get_metadata", table_name=table_name, data={})
|
||||
# return {}
|
||||
|
||||
def save_metadata(self, table_name: str, data: dict) -> None:
|
||||
data_json = json.dumps(data)
|
||||
log.info("save_metadata", table_name=table_name, data=data_json)
|
||||
# sqlite upsert
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?",
|
||||
(table_name, data_json, data_json),
|
||||
)
|
||||
self.db.commit()
|
||||
# def save_metadata(self, table_name: str, data: dict) -> None:
|
||||
# data_json = json.dumps(data)
|
||||
# log.info("save_metadata", table_name=table_name, data=data_json)
|
||||
# # sqlite upsert
|
||||
# cursor = self.db.cursor()
|
||||
# cursor.execute(
|
||||
# "INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?",
|
||||
# (table_name, data_json, data_json),
|
||||
# )
|
||||
# self.db.commit()
|
||||
|
||||
def csv_to_beaker(self, filename: str, beaker_name: str) -> None:
|
||||
beaker = self.beakers[beaker_name]
|
||||
@ -161,7 +168,7 @@ class Recipe:
|
||||
reader = csv.DictReader(file)
|
||||
added = 0
|
||||
for row in reader:
|
||||
beaker.add_item(row)
|
||||
beaker.add_item(beaker.model(**row))
|
||||
added += 1
|
||||
lg.info("from_csv", case="empty", added=added)
|
||||
meta = self.get_metadata(beaker.name)
|
||||
@ -177,9 +184,8 @@ class Recipe:
|
||||
else:
|
||||
# case 2: match
|
||||
lg.info("from_csv", case="match")
|
||||
return beaker
|
||||
|
||||
def show(self):
|
||||
def show(self) -> None:
|
||||
seed_count = Counter(self.seeds.keys())
|
||||
typer.secho("Seeds", fg=typer.colors.GREEN)
|
||||
for beaker, count in seed_count.items():
|
||||
@ -204,7 +210,7 @@ class Recipe:
|
||||
else:
|
||||
typer.secho(f" {k.__name__} -> {v}", fg=typer.colors.RED)
|
||||
|
||||
def graph_data(self):
|
||||
def graph_data(self) -> list[dict]:
|
||||
nodes = {}
|
||||
|
||||
for node in networkx.topological_sort(self.graph):
|
||||
@ -290,11 +296,11 @@ class Recipe:
|
||||
if isinstance(e, error_types):
|
||||
error_beaker = self.beakers[error_beaker_name]
|
||||
error_beaker.add_item(
|
||||
{
|
||||
"item": item,
|
||||
"exception": str(e),
|
||||
"exc_type": str(type(e)),
|
||||
},
|
||||
ErrorType(
|
||||
item=item,
|
||||
exception=str(e),
|
||||
exc_type=str(type(e)),
|
||||
),
|
||||
id,
|
||||
)
|
||||
break
|
||||
|
Loading…
Reference in New Issue
Block a user