diff --git a/src/foiaghost/ghost.py b/src/foiaghost/ghost.py
new file mode 100644
index 0000000..e167158
--- /dev/null
+++ b/src/foiaghost/ghost.py
@@ -0,0 +1,28 @@
+from scrapeghost import SchemaScraper
+
+schema = {
+    "public_records_email": "email",
+    "public_records_address": "str",
+    "public_records_phone": "555-555-5555",
+    "public_records_fax": "555-555-5555",
+    "public_records_web": "url",
+    "general_contact_phone": "555-555-5555",
+    "general_contact_address": "str",
+    "foia_guide": "url",
+    "public_reading_room": "url",
+    "agency_logo": "url",
+}
+extra_instructions = """
+The fields that begin with public_records should refer to contact information specific to FOIA/Public Information/Freedom of Information requests.
+The fields that begin with general_contact should refer to contact information for the agency in general.
+If a field is not found in the HTML, leave it as null in the JSON.
+"""
+
+# create a scrapeghost
+ghost = SchemaScraper(
+    schema=schema,
+    models=["gpt-3.5-turbo-16k"],
+    # extra_preprocessors=[],
+    extra_instructions=extra_instructions,
+    max_cost=10,
+)
diff --git a/src/foiaghost/models.py b/src/foiaghost/models.py
index 61363e6..2d8cdce 100644
--- a/src/foiaghost/models.py
+++ b/src/foiaghost/models.py
@@ -13,3 +13,17 @@ class URL(BaseModel):
 
 class Int(BaseModel):
     int: int
+
+
+class IdOnly(BaseModel):
+    pass
+
+
+class JSON(BaseModel):
+    scraped_json: dict | list
+
+
+class ScrapeghostResponse(BaseModel):
+    total_cost: float
+    api_time: float
+    data: dict
diff --git a/src/foiaghost/pipeline.py b/src/foiaghost/pipeline.py
index 22fbf35..9864574 100644
--- a/src/foiaghost/pipeline.py
+++ b/src/foiaghost/pipeline.py
@@ -1,14 +1,19 @@
-from ssl import SSLCertVerificationError, SSLError
+import csv
 import httpx
 import tiktoken
 import lxml.html
+import lxml.etree
 from lxml.etree import ParserError
-from databeakers import Pipeline
+from ssl import SSLCertVerificationError, SSLError
+from databeakers.pipeline import Pipeline, EdgeType, ErrorType
+from databeakers.beakers import TempBeaker
 from databeakers.http import HttpRequest, HttpResponse
-from scrapeghost import SchemaScraper
 from scrapeghost.preprocessors import CleanHTML
-from .models import Agency, URL, Int
-import csv
+from scrapeghost.errors import TooManyTokens
+from openai.error import InvalidRequestError
+from scrapeghost.errors import BadStop
+from .models import Agency, URL, Int, IdOnly, ScrapeghostResponse
+from .ghost import ghost
 
 
 class CSVSource:
@@ -24,10 +29,10 @@ class CSVSource:
 
 
 def tiktoken_count(response):
-    if response["status_code"] != 200:
+    if response.status_code != 200:
         raise ValueError("response status code is not 200")
 
-    html = response["response_body"]
+    html = response.response_body
 
     # clean the html
     cleaner = CleanHTML()
@@ -37,50 +42,69 @@ def tiktoken_count(response):
     html_again = lxml.html.tostring(doc, encoding="unicode")
     tokens = len(encoding.encode(html_again))
 
-    response["tiktoken_count"] = tokens
-
-    return response
+    return Int(int=tokens)
 
 
 # current thinking, beakers exist within a recipe
 recipe = Pipeline("foiaghost", "foiaghost.db")
 recipe.add_beaker("agency", Agency)
-recipe.add_beaker("good_urls", URL)
-recipe.add_transform("agency", "good_urls", lambda x: x["url"].startswith("http"))
-recipe.add_beaker("responses", HttpResponse)
-recipe.add_transform("good_urls", "responses", HttpRequest)
-recipe.add_beaker("tiktoken_count", Int)
-recipe.add_transform(
-    "responses",
-    "tiktoken_count",
-    tiktoken_count,
-    error_map={(ValueError, ParserError): "no_tiktoken_count"},
-)
 recipe.add_seed(
     "agencies",
     "agency",
    CSVSource("agencies.csv", Agency),
 )
-
-# recipe.add_beaker("token_lt_8k", temp=True)
-# recipe.add_beaker("token_gt_8k", temp=True)
+recipe.add_beaker("good_urls", URL)
+recipe.add_transform(
+    "agency",
+    "good_urls",
+    lambda a: a.url.startswith("http"),
+    edge_type=EdgeType.conditional,
+)
+recipe.add_beaker("responses", HttpResponse)
+recipe.add_transform(
+    "good_urls",
+    "responses",
+    HttpRequest(),
+    error_map={
+        (
+            httpx.HTTPError,
+            SSLCertVerificationError,
+            SSLError,
+        ): "bad_requests"
+    },
+)
 
 
-# recipe.add_transform(
-#     "good_urls",
-#     "responses",
-#     add_response,
-#     error_map={
-#         (
-#             httpx.HTTPError,
-#             SSLCertVerificationError,
-#             SSLError,
-#         ): "bad_requests"
-#     },
-# )
-# recipe.add_conditional(
-#     "with_tiktoken_count",
-#     lambda x: x["tiktoken_count"] < 8000,
-#     if_true="token_lt_8k",
-#     if_false="token_gt_8k",
-# )
+class ProcessRecord:
+    def __init__(self, func, params_map):
+        self.func = func
+        self.params_map = params_map
+
+    def __call__(self, record):
+        kwargs = {}
+        for param, (beaker_name, field_name) in self.params_map.items():
+            kwargs[param] = getattr(record[beaker_name], field_name)
+        return self.func(**kwargs)
+
+    def __repr__(self):
+        return f"ProcessRecord({self.func.__name__}, {self.params_map})"
+
+
+def scrapeghost_response(record) -> ScrapeghostResponse:
+    sg = ghost.scrape(url_or_html=record["responses"].response_body)
+    return ScrapeghostResponse(
+        total_cost=sg.total_cost, api_time=sg.api_time, data=sg.data
+    )
+
+
+recipe.add_transform(
+    "responses",
+    "scrapeghost_response",
+    scrapeghost_response,
+    whole_record=True,
+    error_map={
+        (TooManyTokens,): "scrapeghost_too_many_tokens",
+        (BadStop,): "scrapeghost_bad_stop",
+        (InvalidRequestError, ValueError, ParserError): "scrapeghost_invalid_request",
+    },
+)