diff --git a/src/beakers/recipe.py b/src/beakers/recipe.py index c8e8819..eb45b7c 100644 --- a/src/beakers/recipe.py +++ b/src/beakers/recipe.py @@ -7,6 +7,8 @@ from dataclasses import dataclass import networkx from structlog import get_logger +from .beakers import Beaker, SqliteBeaker, TempBeaker + log = get_logger() """ @@ -41,54 +43,11 @@ pours are edges. The recipe is the graph. """ -# class Beaker: -# def __init__(self, table_name: str, recipe): -# self.table_name = table_name -# self.recipe = recipe - -# # create table if it doesn't exist -# self.recipe.cursor.execute( -# f"CREATE TABLE IF NOT EXISTS {self.table_name} (id INTEGER PRIMARY KEY, data JSON, from_table TEXT NULL, from_id INTEGER NULL)" -# ) - -# def __repr__(self): -# return f"Beaker({self.table_name})" - -# def items(self): -# self.recipe.cursor.execute(f"SELECT id, data FROM {self.table_name}") -# data = self.recipe.cursor.fetchall() -# for item in data: -# yield item["id"], json.loads(item["data"]) - -# def __len__(self): -# self.recipe.cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") -# return self.recipe.cursor.fetchone()[0] - -# def add_item(self, item: dict, from_table=None, from_id=None) -> None: -# self.recipe.cursor.execute( -# f"INSERT INTO {self.table_name} (data) VALUES (?)", (json.dumps(item),) -# ) -# self.recipe.cursor.commit() - - def get_sha512(filename: str) -> str: with open(filename, "rb") as file: return hashlib.sha512(file.read()).hexdigest() -@dataclass(eq=True, frozen=True) -class Beaker: - """ - A beaker is a node in the graph. - - They can correspond to tables in the database, - or they can be temporary. - """ - - name: str - temporary: bool = False - - @dataclass(eq=True, frozen=True) class Pour: """ @@ -130,17 +89,17 @@ class Recipe: return f"Recipe({self.name})" def add_beaker(self, name: str, temp: bool = False) -> Beaker: - beaker = Beaker(name, temporary=temp) - self.graph.add_node(beaker) - self.beakers[name] = beaker - return beaker + self.graph.add_node(name) + if temp: + self.beakers[name] = TempBeaker(name, self) + else: + self.beakers[name] = SqliteBeaker(name, self) + return self.beakers[name] def add_pour( self, from_beaker: str, to_beaker: str, transform_func: callable ) -> None: - self.graph.add_edge( - self.beakers[from_beaker], self.beakers[to_beaker], transform=transform_func - ) + self.graph.add_edge(from_beaker, to_beaker, transform=transform_func) def add_conditional( self, @@ -151,7 +110,7 @@ class Recipe: ) -> None: # first add a transform to evaluate the conditional cond_name = f"cond-{from_beaker}-{condition_func.__name__}" - cond = self.add_beaker(cond_name, temp=True) + self.add_beaker(cond_name, temp=True) self.add_pour( from_beaker, cond_name, @@ -160,13 +119,13 @@ class Recipe: # then add two filtered paths that remove the condition result self.graph.add_edge( - cond, - self.beakers[if_true], + cond_name, + if_true, filter_func=lambda data, condition: data if condition else None, ) self.graph.add_edge( - cond, - self.beakers[if_false], + cond_name, + if_false, filter_func=lambda data, condition: data if not condition else None, ) @@ -206,11 +165,11 @@ class Recipe: beaker.add_item(row) added += 1 lg.info("from_csv", case="empty", added=added) - meta = self.get_metadata(beaker.table_name) + meta = self.get_metadata(beaker.name) meta["sha512"] = get_sha512(filename) - self.save_metadata(beaker.table_name, meta) + self.save_metadata(beaker.name, meta) else: - old_sha = self.get_metadata(beaker).get("sha512") + old_sha = self.get_metadata(beaker.name).get("sha512") new_sha = get_sha512(filename) if old_sha != new_sha: # case 3: mismatch @@ -226,12 +185,7 @@ class Recipe: Solve the DAG by topological sort. """ for node in networkx.topological_sort(self.graph): - if isinstance(node, Beaker): - print(node) - elif isinstance(node, Conditional): - print(node) - else: - raise Exception("unknown node type") + print(node) def run_linearly(self): log.info("recipe", recipe=self) @@ -247,4 +201,4 @@ class Recipe: for id, item in pour.from_beaker.items(): log.info("pour_item", id=id, item=item) transformed = loop.run_until_complete(pour.transform(item)) - pour.to_beaker.add_item(transformed, pour.from_beaker.table_name, id) + pour.to_beaker.add_item(transformed, pour.from_beaker.name, id) diff --git a/src/example.py b/src/example.py index 41f5c4d..ec4ce74 100644 --- a/src/example.py +++ b/src/example.py @@ -19,7 +19,6 @@ recipe.add_beaker("agencies") recipe.add_beaker("responses") recipe.add_beaker("good_urls", temp=True) recipe.add_beaker("missing_urls", temp=True) -# recipe.csv_to_beaker("agencies.csv", "agencies") recipe.add_conditional( "agencies", lambda x: x["url"].startswith("http"), @@ -28,4 +27,5 @@ recipe.add_conditional( ) recipe.add_pour("good_urls", "responses", add_response) +recipe.csv_to_beaker("agencies.csv", "agencies") recipe.solve_dag()