This commit is contained in:
James Turk 2023-07-11 12:41:19 -05:00
parent 11003ef872
commit 84c2f1a641
5 changed files with 37 additions and 15 deletions

View File

@ -2,7 +2,6 @@ import datetime
from pydantic import BaseModel
import lxml
from beakers import Recipe
from beakers.filters import ConditionalFilter
from beakers.http import HttpRequest
@ -65,7 +64,7 @@ other = [
ArticleURL(url="https://nytimes.com", source="nytimes"),
]
# recipe.add_seed(
# "article_url",
# npr_examples + other,
# )
recipe.add_seed(
"url",
npr_examples + other,
)

View File

@ -60,6 +60,9 @@ def run(
start: Optional[str] = typer.Option(None),
end: Optional[str] = typer.Option(None),
):
if ctx.obj.seeds:
typer.secho("Seeding beakers", fg=typer.colors.GREEN)
ctx.obj.process_seeds()
has_data = any(ctx.obj.beakers.values())
if not has_data and not input:
typer.secho("No data; pass --input to seed beaker(s)", fg=typer.colors.RED)

View File

@ -1,9 +0,0 @@
class ConditionalFilter:
    """Filter that lets an item through only when a predicate accepts it."""

    def __init__(self, condition):
        # Callable evaluated against every item handed to the filter.
        self.condition = condition

    def __call__(self, item):
        """Return ``item`` if ``self.condition(item)`` is truthy, else ``None``."""
        accepted = self.condition(item)
        return item if accepted else None

View File

@ -4,6 +4,10 @@ import datetime
class HttpResponse(pydantic.BaseModel):
"""
Beaker data type that represents an HTTP response.
"""
url: str
status_code: int
response_body: str
@ -13,7 +17,16 @@ class HttpResponse(pydantic.BaseModel):
class HttpRequest:
"""
Filter that converts from a beaker with a URL to a beaker with an HTTP response.
"""
def __init__(self, beaker: str, field: str):
    """
    Record which beaker and which field supply the URL.

    Args:
        beaker: The name of the beaker that contains the URL.
        field: The name of the field in the beaker that contains the URL.
    """
    self.beaker, self.field = beaker, field

View File

@ -5,8 +5,10 @@ import inspect
import sqlite3
import hashlib
import asyncio
from dataclasses import dataclass
import networkx
from collections import defaultdict, Counter
from dataclasses import dataclass # TODO: pydantic?
from typing import Iterable
from structlog import get_logger
from .beakers import Beaker, SqliteBeaker, TempBeaker
@ -39,6 +41,7 @@ class Recipe:
self.name = name
self.graph = networkx.DiGraph()
self.beakers = {}
self.seeds = defaultdict(list)
self.db = sqlite3.connect(db_name)
cursor = self.db.cursor()
cursor.execute(
@ -114,6 +117,15 @@ class Recipe:
if_cond_false,
)
def add_seed(self, beaker_name: str, data: Iterable) -> None:
    """
    Queue seed items for a beaker.

    Items are only recorded in ``self.seeds`` here; ``process_seeds`` later
    hands each queued entry to the matching beaker's ``add_item``.

    Args:
        beaker_name: Key under which the items are queued in ``self.seeds``.
        data: Iterable of individual seed items. It is consumed immediately,
            so a one-shot generator is exhausted by this call.
    """
    # extend (not append): every element of ``data`` must become its own
    # queued seed. Appending would store the whole iterable as a single
    # entry, and process_seeds would then pass that container to add_item.
    self.seeds[beaker_name].extend(data)
def process_seeds(self) -> None:
    """Hand every queued seed to the beaker it was registered for."""
    log.info("process_seeds", recipe=self.name)
    # Lazily pair each queued entry with the name of its target beaker; the
    # beaker itself is looked up per entry, exactly when the entry is added.
    queued = (
        (target, entry)
        for target, entries in self.seeds.items()
        for entry in entries
    )
    for target, entry in queued:
        self.beakers[target].add_item(entry)
def get_metadata(self, table_name) -> dict:
cursor = self.db.cursor()
cursor.execute(
@ -168,6 +180,10 @@ class Recipe:
return beaker
def show(self):
seed_count = Counter(self.seeds.keys())
typer.secho("Seeds", fg=typer.colors.GREEN)
for beaker, count in seed_count.items():
typer.secho(f" {beaker} ({count})", fg=typer.colors.GREEN)
graph_data = self.graph_data()
for node in graph_data:
if node["temp"]: