import asyncio
import csv
import hashlib
import inspect
import json
import sqlite3
from dataclasses import dataclass
from typing import Callable

import networkx
import typer
from structlog import get_logger

from .beakers import Beaker, SqliteBeaker, TempBeaker

log = get_logger()


def get_sha512(filename: str) -> str:
    """Return the SHA-512 hex digest of the file's contents."""
    with open(filename, "rb") as file:
        return hashlib.sha512(file.read()).hexdigest()


@dataclass(frozen=True, eq=True)
class Transform:
    """An edge in the recipe graph: a named function that moves data between beakers."""

    name: str
    transform_func: Callable
    # maps a tuple of exception types -> name of the beaker that receives
    # items which raised one of those exceptions during the transform
    error_map: dict[tuple, str]


def if_cond_true(data_cond_tup: tuple[dict, bool]) -> dict | None:
    """Pass through the data half of a (data, condition) pair when the condition is true."""
    return data_cond_tup[0] if data_cond_tup[1] else None


def if_cond_false(data_cond_tup: tuple[dict, bool]) -> dict | None:
    """Pass through the data half of a (data, condition) pair when the condition is false."""
    return data_cond_tup[0] if not data_cond_tup[1] else None


class Recipe:
    """A directed graph of beakers (data stores) connected by transforms.

    Data is pushed through the graph in topological order by run_once();
    per-table metadata (e.g. source-file hashes) is persisted in a sqlite
    `_metadata` table alongside the beakers.
    """

    def __init__(self, name: str, db_name: str = "beakers.db"):
        self.name = name
        self.graph = networkx.DiGraph()
        self.beakers: dict[str, Beaker] = {}
        self.db = sqlite3.connect(db_name)
        cursor = self.db.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS _metadata (table_name TEXT PRIMARY KEY, data JSON)"
        )

    def __repr__(self) -> str:
        return f"Recipe({self.name})"

    def add_beaker(self, name: str, temp: bool = False) -> Beaker:
        """Add a named beaker (node) to the graph; temp beakers are not persisted."""
        self.graph.add_node(name)
        if temp:
            self.beakers[name] = TempBeaker(name, self)
        else:
            self.beakers[name] = SqliteBeaker(name, self)
        return self.beakers[name]

    def add_transform(
        self,
        from_beaker: str,
        to_beaker: str,
        transform_func: Callable,
        *,
        name: str | None = None,
        error_map: dict[tuple, str] | None = None,
    ) -> None:
        """Add an edge that applies transform_func to items flowing from one beaker to another.

        error_map routes items whose transform raises a listed exception type
        into a designated error beaker instead of aborting the run.
        """
        if name is None:
            name = transform_func.__name__
        # a lambda's __name__ is "<lambda>"; give it a friendlier display name
        if name == "<lambda>":
            name = "λ"
        transform = Transform(
            name=name,
            transform_func=transform_func,
            error_map=error_map or {},
        )
        self.graph.add_edge(
            from_beaker,
            to_beaker,
            transform=transform,
        )

    def add_conditional(
        self,
        from_beaker: str,
        condition_func: Callable,
        if_true: str,
        if_false: str,
    ) -> None:
        """Route items from from_beaker to if_true or if_false based on condition_func.

        Internally adds a temp beaker holding (data, bool) pairs, then two
        filtering transforms that strip the bool back off.
        """
        # first add a transform to evaluate the conditional
        if condition_func.__name__ == "<lambda>":
            cond_name = f"cond-{from_beaker}"
        else:
            cond_name = f"cond-{from_beaker}-{condition_func.__name__}"
        self.add_beaker(cond_name, temp=True)
        self.add_transform(
            from_beaker,
            cond_name,
            lambda data: (data, condition_func(data)),
            name=cond_name,
        )
        # then add two filtered paths that remove the condition result
        self.add_transform(
            cond_name,
            if_true,
            if_cond_true,
        )
        self.add_transform(
            cond_name,
            if_false,
            if_cond_false,
        )

    def get_metadata(self, table_name: str) -> dict:
        """Return the JSON metadata stored for table_name, or {} if none exists."""
        cursor = self.db.cursor()
        cursor.execute(
            "SELECT data FROM _metadata WHERE table_name = ?",
            (table_name,),
        )
        row = cursor.fetchone()
        if row is None:
            log.debug("get_metadata", table_name=table_name, data={})
            return {}
        # default sqlite3 rows are plain tuples, so index positionally
        # (row["data"] would raise TypeError and metadata would never load)
        data = row[0]
        log.debug("get_metadata", table_name=table_name, data=data)
        return json.loads(data)

    def save_metadata(self, table_name: str, data: dict) -> None:
        """Upsert the JSON metadata for table_name and commit."""
        data_json = json.dumps(data)
        log.info("save_metadata", table_name=table_name, data=data_json)
        # sqlite upsert
        cursor = self.db.cursor()
        cursor.execute(
            "INSERT INTO _metadata (table_name, data) VALUES (?, ?) "
            "ON CONFLICT(table_name) DO UPDATE SET data = ?",
            (table_name, data_json, data_json),
        )
        self.db.commit()

    def csv_to_beaker(self, filename: str, beaker_name: str) -> Beaker:
        """Load a CSV file into a beaker, guarding against re-loading changed files.

        Three cases:
          * beaker empty -> load every row, record the file's sha512
          * sha512 matches previous load -> no-op
          * sha512 differs -> raise (the source file changed under us)

        Raises:
            Exception: if the beaker is non-empty and the file's sha512
                differs from the one recorded at load time.
        """
        beaker = self.beakers[beaker_name]
        lg = log.bind(beaker=beaker, filename=filename)
        # three cases: empty, match, mismatch
        # case 1: empty
        if len(beaker) == 0:
            with open(filename, "r") as file:
                reader = csv.DictReader(file)
                added = 0
                for row in reader:
                    beaker.add_item(row)
                    added += 1
            lg.info("from_csv", case="empty", added=added)
            meta = self.get_metadata(beaker.name)
            meta["sha512"] = get_sha512(filename)
            self.save_metadata(beaker.name, meta)
        else:
            old_sha = self.get_metadata(beaker.name).get("sha512")
            new_sha = get_sha512(filename)
            if old_sha != new_sha:
                # case 3: mismatch
                lg.info("from_csv", case="mismatch", old_sha=old_sha, new_sha=new_sha)
                raise Exception("sha512 mismatch")
            else:
                # case 2: match
                lg.info("from_csv", case="match")
        return beaker

    def show(self) -> None:
        """Print the graph: each beaker (with item count), its outgoing transforms,
        and any error-routing rules, colorized via typer."""
        for node in networkx.topological_sort(self.graph):
            beaker = self.beakers[node]
            temp = isinstance(beaker, TempBeaker)
            if temp:
                typer.secho(node, fg=typer.colors.CYAN)
            else:
                lb = len(beaker)
                typer.secho(
                    f"{node} ({lb})",
                    fg=typer.colors.GREEN if lb else typer.colors.YELLOW,
                )
            for from_b, to_b, edge in self.graph.out_edges(node, data=True):
                name = edge["transform"].name
                print(f"  {from_b} -({name})-> {to_b}")
                for k, v in edge["transform"].error_map.items():
                    typer.secho(
                        f"    {' '.join(c.__name__ for c in k)} -> {v}",
                        fg=typer.colors.RED,
                    )

    def run_once(
        self, start_beaker: str | None = None, end_beaker: str | None = None
    ) -> None:
        """Push all data through the graph once, in topological order.

        start_beaker/end_beaker allow a partial run: nodes before start_beaker
        are skipped, and processing stops when end_beaker is reached.

        Items whose transform raises an exception listed in the edge's
        error_map are diverted to the mapped error beaker; unmapped
        exceptions propagate.
        """
        log.info("run_once", recipe=self)
        loop = asyncio.get_event_loop()
        started = False if start_beaker else True
        # go through each node in forward order, pushing data
        for node in networkx.topological_sort(self.graph):
            # only process nodes between start and end
            if not started:
                if node == start_beaker:
                    started = True
                    log.info("partial run start", node=node)
                else:
                    log.info("partial run skip", node=node, waiting_for=start_beaker)
                    continue
            if end_beaker and node == end_beaker:
                log.info("partial run end", node=node)
                break
            # get outbound edges
            edges = self.graph.out_edges(node, data=True)
            for from_b, to_b, edge in edges:
                transform = edge["transform"]
                from_beaker = self.beakers[from_b]
                to_beaker = self.beakers[to_b]
                log.info(
                    "transform",
                    from_b=from_b,
                    to_b=to_b,
                    items=len(from_beaker),
                    transform=transform.name,
                )
                # convert coroutine to function; bind the func as a default
                # argument so the lambda doesn't close over the loop variable
                if inspect.iscoroutinefunction(transform.transform_func):
                    t_func = lambda x, f=transform.transform_func: loop.run_until_complete(
                        f(x)
                    )
                else:
                    t_func = transform.transform_func
                for id, item in from_beaker.items():
                    try:
                        transformed = t_func(item)
                        if transformed:
                            to_beaker.add_item(transformed, id)
                    except Exception as e:
                        for (
                            error_types,
                            error_beaker_name,
                        ) in transform.error_map.items():
                            if isinstance(e, error_types):
                                error_beaker = self.beakers[error_beaker_name]
                                error_beaker.add_item(
                                    {
                                        "item": item,
                                        "exception": str(e),
                                        "exc_type": str(type(e)),
                                    },
                                    id,
                                )
                                break
                        else:
                            # no error handler, re-raise
                            raise