259 lines
8.6 KiB
Python
259 lines
8.6 KiB
Python
import csv
|
|
import json
|
|
import typer
|
|
import inspect
|
|
import sqlite3
|
|
import hashlib
|
|
import asyncio
|
|
from dataclasses import dataclass
|
|
import networkx
|
|
from structlog import get_logger
|
|
|
|
from .beakers import Beaker, SqliteBeaker, TempBeaker
|
|
|
|
log = get_logger()
|
|
|
|
|
|
def get_sha512(filename: str) -> str:
    """Return the SHA-512 hex digest of the file at *filename*.

    Hashes in fixed-size chunks so arbitrarily large files never need to
    fit in memory (the previous version read the whole file at once).
    """
    digest = hashlib.sha512()
    with open(filename, "rb") as file:
        # 64 KiB chunks: big enough to amortize call overhead, small
        # enough to keep peak memory flat. iter(..., b"") stops at EOF.
        for chunk in iter(lambda: file.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
|
|
@dataclass(frozen=True, eq=True)
class Transform:
    """An edge in the recipe graph: a named callable that moves items
    between beakers, plus a mapping from exception types to the beaker
    that should receive items which raise them (see Recipe.run_once)."""

    # Display name: the function's __name__, or "λ" for lambdas.
    name: str
    # Callable applied to each item; may be sync or a coroutine function
    # (run_once wraps coroutines with loop.run_until_complete).
    transform_func: callable
    # Maps a tuple of exception classes -> destination beaker name.
    # NOTE(review): a dict field makes this frozen dataclass unhashable
    # in practice (hash() would raise TypeError) — confirm nothing hashes
    # Transform instances.
    error_map: dict[tuple, str]
|
|
|
|
|
|
def if_cond_true(data_cond_tup: tuple[dict, bool]) -> dict | None:
|
|
return data_cond_tup[0] if data_cond_tup[1] else None
|
|
|
|
|
|
def if_cond_false(data_cond_tup: tuple[dict, bool]) -> dict | None:
|
|
return data_cond_tup[0] if not data_cond_tup[1] else None
|
|
|
|
|
|
class Recipe:
    """A directed acyclic graph of beakers (item containers) connected by
    transforms, backed by a SQLite database for persistent beakers and
    per-table metadata (e.g. source-file hashes)."""

    def __init__(self, name, db_name="beakers.db"):
        self.name = name
        self.graph = networkx.DiGraph()
        self.beakers = {}
        self.db = sqlite3.connect(db_name)
        cursor = self.db.cursor()
        # One JSON blob per table; used by get_metadata/save_metadata.
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS _metadata (table_name TEXT PRIMARY KEY, data JSON)"
        )

    def __repr__(self) -> str:
        return f"Recipe({self.name})"

    def add_beaker(self, name: str, temp: bool = False) -> Beaker:
        """Create and register a beaker node named *name*.

        temp beakers are in-memory (TempBeaker); others persist to SQLite.
        """
        self.graph.add_node(name)
        if temp:
            self.beakers[name] = TempBeaker(name, self)
        else:
            self.beakers[name] = SqliteBeaker(name, self)
        return self.beakers[name]

    def add_transform(
        self,
        from_beaker: str,
        to_beaker: str,
        transform_func: callable,
        *,
        name=None,
        error_map: dict[tuple, str] | None = None,
    ) -> None:
        """Connect two beakers with an edge running *transform_func*.

        error_map maps tuples of exception classes to the name of the
        beaker that should receive items raising those exceptions
        (consumed in run_once).
        """
        if name is None:
            name = transform_func.__name__
            # lambdas all share the name "<lambda>"; use something readable
            if name == "<lambda>":
                name = "λ"
        transform = Transform(
            name=name,
            transform_func=transform_func,
            error_map=error_map or {},
        )
        self.graph.add_edge(
            from_beaker,
            to_beaker,
            transform=transform,
        )

    def add_conditional(
        self,
        from_beaker: str,
        condition_func: callable,
        if_true: str,
        if_false: str,
    ) -> None:
        """Route items from *from_beaker* to *if_true* or *if_false*
        depending on the truthiness of condition_func(item)."""
        # first add a transform to evaluate the conditional
        if condition_func.__name__ == "<lambda>":
            cond_name = f"cond-{from_beaker}"
        else:
            cond_name = f"cond-{from_beaker}-{condition_func.__name__}"
        self.add_beaker(cond_name, temp=True)
        # stash (data, result) pairs in a temp beaker...
        self.add_transform(
            from_beaker,
            cond_name,
            lambda data: (data, condition_func(data)),
            name=cond_name,
        )

        # ...then add two filtered paths that strip the condition result
        self.add_transform(
            cond_name,
            if_true,
            if_cond_true,
        )
        self.add_transform(
            cond_name,
            if_false,
            if_cond_false,
        )

    def get_metadata(self, table_name) -> dict:
        """Return the stored metadata dict for *table_name*, or {} if none."""
        cursor = self.db.cursor()
        cursor.execute(
            "SELECT data FROM _metadata WHERE table_name = ?",
            (table_name,),
        )
        row = cursor.fetchone()
        # BUG FIX: no row_factory is set on the connection, so rows are
        # plain tuples; the old `fetchone()["data"]` always raised
        # TypeError (swallowed below as the "missing row" case), meaning
        # this method ALWAYS returned {}. Index positionally instead and
        # handle the missing row explicitly.
        if row is None:
            log.debug("get_metadata", table_name=table_name, data={})
            return {}
        data = row[0]
        log.debug("get_metadata", table_name=table_name, data=data)
        return json.loads(data)

    def save_metadata(self, table_name: str, data: dict) -> None:
        """Upsert *data* (JSON-serialized) as the metadata for *table_name*."""
        data_json = json.dumps(data)
        log.info("save_metadata", table_name=table_name, data=data_json)
        # sqlite upsert
        cursor = self.db.cursor()
        cursor.execute(
            "INSERT INTO _metadata (table_name, data) VALUES (?, ?) ON CONFLICT(table_name) DO UPDATE SET data = ?",
            (table_name, data_json, data_json),
        )
        self.db.commit()

    def csv_to_beaker(self, filename: str, beaker_name: str) -> Beaker:
        """Load the rows of a CSV file into the named beaker.

        Three cases:
        - beaker empty: load every row, record the file's sha512.
        - beaker populated & sha512 matches: no-op (already loaded).
        - beaker populated & sha512 differs: raise, the source changed.

        Returns the beaker in the non-error cases. (Previously the
        "empty" path implicitly returned None; callers now get the
        beaker on every successful path.)
        """
        beaker = self.beakers[beaker_name]
        lg = log.bind(beaker=beaker, filename=filename)
        # case 1: empty
        if len(beaker) == 0:
            with open(filename, "r") as file:
                reader = csv.DictReader(file)
                added = 0
                for row in reader:
                    beaker.add_item(row)
                    added += 1
            lg.info("from_csv", case="empty", added=added)
            meta = self.get_metadata(beaker.name)
            meta["sha512"] = get_sha512(filename)
            self.save_metadata(beaker.name, meta)
        else:
            old_sha = self.get_metadata(beaker.name).get("sha512")
            new_sha = get_sha512(filename)
            if old_sha != new_sha:
                # case 3: mismatch — refuse to mix data from two versions
                lg.info("from_csv", case="mismatch", old_sha=old_sha, new_sha=new_sha)
                raise Exception("sha512 mismatch")
            else:
                # case 2: match — already loaded
                lg.info("from_csv", case="match")
        return beaker

    def show(self):
        """Print the graph: beakers (cyan=temp, green/yellow=persistent by
        fill state), transforms, and their error routes (red)."""
        for node in networkx.topological_sort(self.graph):
            beaker = self.beakers[node]
            temp = isinstance(beaker, TempBeaker)
            if temp:
                typer.secho(node, fg=typer.colors.CYAN)
            else:
                lb = len(beaker)
                typer.secho(
                    f"{node} ({lb})",
                    fg=typer.colors.GREEN if lb else typer.colors.YELLOW,
                )
            for from_b, to_b, edge in self.graph.out_edges(node, data=True):
                name = edge["transform"].name
                print(f"  {from_b} -({name})-> {to_b}")
                for k, v in edge["transform"].error_map.items():
                    typer.secho(
                        f"    {' '.join(c.__name__ for c in k)} -> {v}",
                        fg=typer.colors.RED,
                    )

    def run_once(
        self, start_beaker: str | None = None, end_beaker: str | None = None
    ) -> None:
        """Run the pipeline once, pushing items through every transform.

        start_beaker/end_beaker restrict the run to the span between those
        nodes (start inclusive, end exclusive). Items whose transform
        raises are routed to the first matching error_map beaker; with no
        match the exception propagates.
        """
        log.info("run_once", recipe=self)
        loop = asyncio.get_event_loop()

        started = start_beaker is None

        # go through each node in forward order, pushing data
        for node in networkx.topological_sort(self.graph):
            # only process nodes between start and end
            if not started:
                if node == start_beaker:
                    started = True
                    log.info("partial run start", node=node)
                else:
                    log.info("partial run skip", node=node, waiting_for=start_beaker)
                    continue
            if end_beaker and node == end_beaker:
                log.info("partial run end", node=node)
                break

            # get outbound edges
            edges = self.graph.out_edges(node, data=True)

            for from_b, to_b, edge in edges:
                transform = edge["transform"]

                from_beaker = self.beakers[from_b]
                to_beaker = self.beakers[to_b]
                log.info(
                    "transform",
                    from_b=from_b,
                    to_b=to_b,
                    items=len(from_beaker),
                    transform=transform.name,
                )

                # allow coroutine transforms to be called synchronously;
                # bind the function as a default arg so the wrapper cannot
                # pick up a later loop iteration's transform (late binding)
                if inspect.iscoroutinefunction(transform.transform_func):
                    t_func = lambda x, _f=transform.transform_func: (
                        loop.run_until_complete(_f(x))
                    )
                else:
                    t_func = transform.transform_func

                for item_id, item in from_beaker.items():
                    try:
                        transformed = t_func(item)
                        # falsy results (e.g. None from the if_cond_*
                        # filters) are dropped, not forwarded
                        if transformed:
                            to_beaker.add_item(transformed, item_id)
                    except Exception as e:
                        # route to the first matching error beaker
                        for (
                            error_types,
                            error_beaker_name,
                        ) in transform.error_map.items():
                            if isinstance(e, error_types):
                                error_beaker = self.beakers[error_beaker_name]
                                error_beaker.add_item(
                                    {
                                        "item": item,
                                        "exception": str(e),
                                        "exc_type": str(type(e)),
                                    },
                                    item_id,
                                )
                                break
                        else:
                            # no error handler, re-raise
                            raise
|