big update in august
This commit is contained in:
parent
c510a4d5cc
commit
30f0fce57a
28
src/foiaghost/ghost.py
Normal file
28
src/foiaghost/ghost.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from scrapeghost import SchemaScraper
|
||||||
|
|
||||||
|
schema = {
|
||||||
|
"public_records_email": "email",
|
||||||
|
"public_records_address": "str",
|
||||||
|
"public_records_phone": "555-555-5555",
|
||||||
|
"public_records_fax": "555-555-5555",
|
||||||
|
"public_records_web": "url",
|
||||||
|
"general_contact_phone": "555-555-5555",
|
||||||
|
"general_contact_address": "str",
|
||||||
|
"foia_guide": "url",
|
||||||
|
"public_reading_room": "url",
|
||||||
|
"agency_logo": "url",
|
||||||
|
}
|
||||||
|
extra_instructions = """
|
||||||
|
The fields that begin with public_records should refer to contact information specific to FOIA/Public Information/Freedom of Information requests.
|
||||||
|
The fields that begin with general_contact should refer to contact information for the agency in general.
|
||||||
|
If a field is not found in the HTML, leave it as null in the JSON.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# create a scrapeghost
|
||||||
|
ghost = SchemaScraper(
|
||||||
|
schema=schema,
|
||||||
|
models=["gpt-3.5-turbo-16k"],
|
||||||
|
# extra_preprocessors=[],
|
||||||
|
extra_instructions=extra_instructions,
|
||||||
|
max_cost=10,
|
||||||
|
)
|
@ -13,3 +13,17 @@ class URL(BaseModel):
|
|||||||
|
|
||||||
class Int(BaseModel):
|
class Int(BaseModel):
|
||||||
int: int
|
int: int
|
||||||
|
|
||||||
|
|
||||||
|
class IdOnly(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JSON(BaseModel):
|
||||||
|
scraped_json: dict | list
|
||||||
|
|
||||||
|
|
||||||
|
class ScrapeghostResponse(BaseModel):
|
||||||
|
total_cost: float
|
||||||
|
api_time: float
|
||||||
|
data: dict
|
||||||
|
@ -1,14 +1,19 @@
|
|||||||
from ssl import SSLCertVerificationError, SSLError
|
import csv
|
||||||
import httpx
|
import httpx
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import lxml.html
|
import lxml.html
|
||||||
|
import lxml.etree
|
||||||
from lxml.etree import ParserError
|
from lxml.etree import ParserError
|
||||||
from databeakers import Pipeline
|
from ssl import SSLCertVerificationError, SSLError
|
||||||
|
from databeakers.pipeline import Pipeline, EdgeType, ErrorType
|
||||||
|
from databeakers.beakers import TempBeaker
|
||||||
from databeakers.http import HttpRequest, HttpResponse
|
from databeakers.http import HttpRequest, HttpResponse
|
||||||
from scrapeghost import SchemaScraper
|
|
||||||
from scrapeghost.preprocessors import CleanHTML
|
from scrapeghost.preprocessors import CleanHTML
|
||||||
from .models import Agency, URL, Int
|
from scrapeghost.errors import TooManyTokens
|
||||||
import csv
|
from openai.error import InvalidRequestError
|
||||||
|
from scrapeghost.errors import BadStop
|
||||||
|
from .models import Agency, URL, Int, IdOnly, ScrapeghostResponse
|
||||||
|
from .ghost import ghost
|
||||||
|
|
||||||
|
|
||||||
class CSVSource:
|
class CSVSource:
|
||||||
@ -24,10 +29,10 @@ class CSVSource:
|
|||||||
|
|
||||||
|
|
||||||
def tiktoken_count(response):
|
def tiktoken_count(response):
|
||||||
if response["status_code"] != 200:
|
if response.status_code != 200:
|
||||||
raise ValueError("response status code is not 200")
|
raise ValueError("response status code is not 200")
|
||||||
|
|
||||||
html = response["response_body"]
|
html = response.response_body
|
||||||
|
|
||||||
# clean the html
|
# clean the html
|
||||||
cleaner = CleanHTML()
|
cleaner = CleanHTML()
|
||||||
@ -37,50 +42,69 @@ def tiktoken_count(response):
|
|||||||
html_again = lxml.html.tostring(doc, encoding="unicode")
|
html_again = lxml.html.tostring(doc, encoding="unicode")
|
||||||
tokens = len(encoding.encode(html_again))
|
tokens = len(encoding.encode(html_again))
|
||||||
|
|
||||||
response["tiktoken_count"] = tokens
|
return Int(int=tokens)
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
# current thinking, beakers exist within a recipe
|
# current thinking, beakers exist within a recipe
|
||||||
recipe = Pipeline("foiaghost", "foiaghost.db")
|
recipe = Pipeline("foiaghost", "foiaghost.db")
|
||||||
recipe.add_beaker("agency", Agency)
|
recipe.add_beaker("agency", Agency)
|
||||||
recipe.add_beaker("good_urls", URL)
|
|
||||||
recipe.add_transform("agency", "good_urls", lambda x: x["url"].startswith("http"))
|
|
||||||
recipe.add_beaker("responses", HttpResponse)
|
|
||||||
recipe.add_transform("good_urls", "responses", HttpRequest)
|
|
||||||
recipe.add_beaker("tiktoken_count", Int)
|
|
||||||
recipe.add_transform(
|
|
||||||
"responses",
|
|
||||||
"tiktoken_count",
|
|
||||||
tiktoken_count,
|
|
||||||
error_map={(ValueError, ParserError): "no_tiktoken_count"},
|
|
||||||
)
|
|
||||||
recipe.add_seed(
|
recipe.add_seed(
|
||||||
"agencies",
|
"agencies",
|
||||||
"agency",
|
"agency",
|
||||||
CSVSource("agencies.csv", Agency),
|
CSVSource("agencies.csv", Agency),
|
||||||
)
|
)
|
||||||
|
recipe.add_beaker("good_urls", URL)
|
||||||
# recipe.add_beaker("token_lt_8k", temp=True)
|
recipe.add_transform(
|
||||||
# recipe.add_beaker("token_gt_8k", temp=True)
|
"agency",
|
||||||
|
"good_urls",
|
||||||
|
lambda a: a.url.startswith("http"),
|
||||||
|
edge_type=EdgeType.conditional,
|
||||||
|
)
|
||||||
|
recipe.add_beaker("responses", HttpResponse)
|
||||||
|
recipe.add_transform(
|
||||||
|
"good_urls",
|
||||||
|
"responses",
|
||||||
|
HttpRequest(),
|
||||||
|
error_map={
|
||||||
|
(
|
||||||
|
httpx.HTTPError,
|
||||||
|
SSLCertVerificationError,
|
||||||
|
SSLError,
|
||||||
|
): "bad_requests"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# recipe.add_transform(
|
class ProcessRecord:
|
||||||
# "good_urls",
|
def __init__(self, func, params_map):
|
||||||
# "responses",
|
self.func = func
|
||||||
# add_response,
|
self.params_map = params_map
|
||||||
# error_map={
|
|
||||||
# (
|
def __call__(self, record):
|
||||||
# httpx.HTTPError,
|
kwargs = {}
|
||||||
# SSLCertVerificationError,
|
for param, (beaker_name, field_name) in self.params_map.items():
|
||||||
# SSLError,
|
kwargs[param] = getattr(record[beaker_name], field_name)
|
||||||
# ): "bad_requests"
|
return self.func(**kwargs)
|
||||||
# },
|
|
||||||
# )
|
def __repr__(self):
|
||||||
# recipe.add_conditional(
|
return f"ProcessRecord({self.func.__name__}, {self.params_map})"
|
||||||
# "with_tiktoken_count",
|
|
||||||
# lambda x: x["tiktoken_count"] < 8000,
|
|
||||||
# if_true="token_lt_8k",
|
def scrapeghost_response(record) -> ScrapeghostResponse:
|
||||||
# if_false="token_gt_8k",
|
sg = ghost.scrape(url_or_html=record["responses"].response_body)
|
||||||
# )
|
return ScrapeghostResponse(
|
||||||
|
total_cost=sg.total_cost, api_time=sg.api_time, data=sg.data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
recipe.add_transform(
|
||||||
|
"responses",
|
||||||
|
"scrapeghost_response",
|
||||||
|
scrapeghost_response,
|
||||||
|
whole_record=True,
|
||||||
|
error_map={
|
||||||
|
(TooManyTokens,): "scrapeghost_too_many_tokens",
|
||||||
|
(BadStop,): "scrapeghost_bad_stop",
|
||||||
|
(InvalidRequestError, ValueError, ParserError): "scrapeghost_invalid_request",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user