convert to beakers 0.1
This commit is contained in:

parent b357a6d1d5
commit c510a4d5cc

7 changed files with 183 additions and 236 deletions

@@ -1,71 +0,0 @@
import datetime
from pydantic import BaseModel
import lxml.html
from beakers import Recipe
from beakers.http import HttpRequest


class ArticleURL(BaseModel):
    url: str
    source: str


class HttpResponse(BaseModel):
    url: str
    status: int
    content: str
    retrieved_at: datetime.datetime


class Article(BaseModel):
    title: str
    text: str
    image_urls: list[str]


def is_npr(item) -> bool:
    return item.url.source == "npr"


def extract_npr_article(item) -> Article:
    doc = lxml.html.fromstring(item.response.content)
    title = doc.cssselect(".story-title")[0].text_content()
    text = doc.cssselect(".paragraphs-container")[0].text_content()
    return Article(
        title=title,
        text=text,
        image_urls=[],
    )


recipe = Recipe("newsface", "newsface.db")
recipe.add_beaker("url", ArticleURL)
recipe.add_beaker("response", HttpResponse)
recipe.add_beaker("article", Article)
recipe.add_transform("url", "response", HttpRequest)
recipe.add_conditional(
    "response",
    is_npr,
    "npr_article",
)
recipe.add_transform(
    "npr_article",
    "article",
    extract_npr_article,
)
recipe.add_transform("archived_article")


npr_examples = [
    ArticleURL(url="https://text.npr.org/1186770075", source="npr"),
    ArticleURL(url="https://text.npr.org/1186525577", source="npr"),
    ArticleURL(url="https://text.npr.org/1185780577", source="npr"),
]
other = [
    ArticleURL(url="https://nytimes.com", source="nytimes"),
]

recipe.add_seed(
    "url",
    npr_examples + other,
)
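As a quick check of extract_npr_article's parsing, here is a minimal sketch that feeds it canned HTML instead of a real fetch; the item shape (an object exposing .response.content) is assumed from how the transform above is wired:

from types import SimpleNamespace

fake_html = """
<html><body>
  <h1 class="story-title">Example headline</h1>
  <div class="paragraphs-container"><p>Body text.</p></div>
</body></html>
"""

item = SimpleNamespace(response=SimpleNamespace(content=fake_html))
article = extract_npr_article(item)
assert article.title == "Example headline"
assert "Body text." in article.text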
@@ -1,83 +0,0 @@
from ssl import SSLCertVerificationError, SSLError
import httpx
import tiktoken
import lxml.html
from lxml.etree import ParserError
from beakers.beakers import Beaker
from beakers.recipe import Recipe
from scrapeghost import SchemaScraper
from scrapeghost.preprocessors import CleanHTML


async def add_response(obj_with_url):
    url = obj_with_url["url"]
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
    return {
        "url": url,
        "status_code": response.status_code,
        "response_body": response.text,
    }


def tiktoken_count(response):
    if response["status_code"] != 200:
        raise ValueError("response status code is not 200")

    html = response["response_body"]

    # clean the html
    cleaner = CleanHTML()
    encoding = tiktoken.get_encoding("cl100k_base")
    doc = lxml.html.fromstring(html)
    (doc,) = cleaner(doc)  # returns a 1-item list
    html_again = lxml.html.tostring(doc, encoding="unicode")
    tokens = len(encoding.encode(html_again))

    response["tiktoken_count"] = tokens

    return response


# current thinking, beakers exist within a recipe
recipe = Recipe("fetch urls", "url_example.db")
recipe.add_beaker("agencies")
recipe.add_beaker("responses")
recipe.add_beaker("bad_requests")
recipe.add_beaker("good_urls", temp=True)
recipe.add_beaker("missing_urls", temp=True)
recipe.add_beaker("with_tiktoken_count")
recipe.add_beaker("no_tiktoken_count", temp=True)
recipe.add_beaker("token_lt_8k", temp=True)
recipe.add_beaker("token_gt_8k", temp=True)

recipe.add_conditional(
    "agencies",
    lambda x: x["url"].startswith("http"),
    if_true="good_urls",
    if_false="missing_urls",
)
recipe.add_transform(
    "good_urls",
    "responses",
    add_response,
    error_map={
        (
            httpx.HTTPError,
            SSLCertVerificationError,
            SSLError,
        ): "bad_requests"
    },
)
recipe.add_transform(
    "responses",
    "with_tiktoken_count",
    tiktoken_count,
    error_map={(ValueError, ParserError): "no_tiktoken_count"},
)
recipe.add_conditional(
    "with_tiktoken_count",
    lambda x: x["tiktoken_count"] < 8000,
    if_true="token_lt_8k",
    if_false="token_gt_8k",
)
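A quick sanity check of tiktoken_count with a stubbed response dict; the keys match what add_response returns above, and for a page this small the count should be just a handful of tokens:

sample = {
    "url": "https://example.com",
    "status_code": 200,
    "response_body": "<html><body><p>hello world</p></body></html>",
}
print(tiktoken_count(sample)["tiktoken_count"])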
							
								
								
									
82  foiaghost.py

@@ -1,82 +0,0 @@
import asyncio
import csv
import httpx

# assumption: Scraper and SQLiteCache come from scrapelib
from scrapelib import Scraper
from scrapelib.cache import SQLiteCache
from scrapeghost import SchemaScraper

schema = {
    "public_records_email": "email",
    "public_records_address": "str",
    "public_records_phone": "555-555-5555",
    "public_records_fax": "555-555-5555",
    "public_records_web": "url",
    "general_contact_phone": "555-555-5555",
    "general_contact_address": "str",
    "foia_guide": "url",
    "public_reading_room": "url",
    "agency_logo": "url",
}
extra_instructions = """
The fields that begin with public_records should refer to contact information specific to FOIA/Public Information/Freedom of Information requests.
The fields that begin with general_contact should refer to contact information for the agency in general.
If a field is not found in the HTML, leave it as null in the JSON.
"""
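# Illustrative only: the shape of one record the schema above asks
# scrapeghost to extract (every value here is invented):
#
# {
#     "public_records_email": "foia@agency.example.gov",
#     "public_records_address": "123 Main St, Springfield",
#     "public_records_phone": "555-555-5555",
#     "public_records_fax": null,
#     "public_records_web": "https://agency.example.gov/foia",
#     ...
# }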

# create a scraper w/ a sqlite cache
scraper = Scraper(requests_per_minute=600)
scraper.cache_storage = SQLiteCache("cache.sqlite")

# create a scrapeghost
ghost = SchemaScraper(
    schema=schema,
    extra_preprocessors=[],
)


agencies = []


async def fetch_urls(urls):
    async with httpx.AsyncClient() as client:
        tasks = [client.get(url) for url in urls]
        responses = await asyncio.gather(*tasks)
        return responses


async def worker(queue, batch_size):
    with open("results.csv", "w") as outf:
        out = csv.DictWriter(
            outf, fieldnames=["id", "url", "status"] + list(schema.keys())
        )
        while True:
            urls = []
            for _ in range(batch_size):
                try:
                    url = queue.get_nowait()
                    urls.append(url)
                except asyncio.QueueEmpty:
                    break
            if len(urls) > 0:
                responses = await fetch_urls(urls)
                yield responses


async def main():
    batch_size = 5

    with open("agencies.csv", "r") as inf:
        agencies = csv.DictReader(inf)
        # grouper -> https://docs.python.org/3/library/itertools.html#itertools-recipes
                    except Exception as e:
                        print(e)
                        out.writerow(
                            {
                                "id": agency["id"],
                                "url": agency["url"],
                                "status": "ERROR",
                            }
                        )
                        continue
                    result = ghost.scrape(page.text)
                    out.writerow(
                        result
                        | {"id": agency["id"], "url": agency["url"], "status": "OK"}
                    )


if __name__ == "__main__":
    asyncio.run(main())
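For reference, this is the grouper recipe the comment in main() points to, copied from the itertools docs; presumably agencies were to be processed in batch_size groups with it:

from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    "grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)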
							
								
								
									
82  old.py  Normal file

@@ -0,0 +1,82 @@
# schema = {
#     "public_records_email": "email",
#     "public_records_address": "str",
#     "public_records_phone": "555-555-5555",
#     "public_records_fax": "555-555-5555",
#     "public_records_web": "url",
#     "general_contact_phone": "555-555-5555",
#     "general_contact_address": "str",
#     "foia_guide": "url",
#     "public_reading_room": "url",
#     "agency_logo": "url",
# }
# extra_instructions = """
# The fields that begin with public_records should refer to contact information specific to FOIA/Public Information/Freedom of Information requests.
# The fields that begin with general_contact should refer to contact information for the agency in general.
# If a field is not found in the HTML, leave it as null in the JSON.
# """

# # create a scraper w/ a sqlite cache
# scraper = Scraper(requests_per_minute=600)
# scraper.cache_storage = SQLiteCache("cache.sqlite")

# # create a scrapeghost
# ghost = SchemaScraper(
#     schema=schema,
#     extra_preprocessors=[],
# )


# agencies = []


# async def fetch_urls(urls):
#     async with httpx.AsyncClient() as client:
#         tasks = [client.get(url) for url in urls]
#         responses = await asyncio.gather(*tasks)
#         return responses


# async def worker(queue, batch_size):
#     with open("results.csv", "w") as outf:
#         out = csv.DictWriter(
#             outf, fieldnames=["id", "url", "status"] + list(schema.keys())
#         )
#         while True:
#             urls = []
#             for _ in range(batch_size):
#                 try:
#                     url = await queue.get()
#                     urls.append(url)
#                 except asyncio.QueueEmpty:
#                     break
#             if len(urls) > 0:
#                 responses = await fetch_urls(urls, batch_size)
#                 async yield responses


# async def main():
#     batch_size = 5

#     with open("agencies.csv", "r") as inf,
#         agencies = csv.DictReader(inf)
#         # grouper -> https://docs.python.org/3/library/itertools.html#itertools-recipes
#                     except Exception as e:
#                         print(e)
#                         out.writerow(
#                             {
#                                 "id": agency["id"],
#                                 "url": agency["url"],
#                                 "status": "ERROR",
#                             }
#                         )
#                         continue
#                     result = ghost.scrape(page.text)
#                     out.writerow(
#                         result
#                         + {"id": agency["id"], "url": agency["url"], "status": "OK"}
#                     )


# if __name__ == "__main__":
#     main()
							
								
								
									
0  src/foiaghost/__init__.py  Normal file
							
								
								
									
15  src/foiaghost/models.py  Normal file

@@ -0,0 +1,15 @@
from pydantic import BaseModel


class Agency(BaseModel):
    id: str
    url: str
    name: str


class URL(BaseModel):
    url: str


class Int(BaseModel):
    int: int
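URL and Int presumably wrap bare scalars so that everything flowing between beakers is a pydantic model; for example:

assert Int(int=1234).int == 1234
assert URL(url="https://example.gov").url == "https://example.gov"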
							
								
								
									
86  src/foiaghost/pipeline.py  Normal file

@@ -0,0 +1,86 @@
from ssl import SSLCertVerificationError, SSLError
import httpx
import tiktoken
import lxml.html
from lxml.etree import ParserError
from databeakers import Pipeline
from databeakers.http import HttpRequest, HttpResponse
from scrapeghost import SchemaScraper
from scrapeghost.preprocessors import CleanHTML
from .models import Agency, URL, Int
import csv


class CSVSource:
    def __init__(self, filename, datatype):
        self.filename = filename
        self.datatype = datatype

    def __call__(self):
        with open(self.filename) as inf:
            reader = csv.DictReader(inf)
            for line in reader:
                yield self.datatype(**line)


def tiktoken_count(response):
    if response["status_code"] != 200:
        raise ValueError("response status code is not 200")

    html = response["response_body"]

    # clean the html
    cleaner = CleanHTML()
    encoding = tiktoken.get_encoding("cl100k_base")
    doc = lxml.html.fromstring(html)
    (doc,) = cleaner(doc)  # returns a 1-item list
    html_again = lxml.html.tostring(doc, encoding="unicode")
    tokens = len(encoding.encode(html_again))

    response["tiktoken_count"] = tokens

    return response


# current thinking, beakers exist within a recipe
recipe = Pipeline("foiaghost", "foiaghost.db")
recipe.add_beaker("agency", Agency)
recipe.add_beaker("good_urls", URL)
recipe.add_transform("agency", "good_urls", lambda x: x["url"].startswith("http"))
recipe.add_beaker("responses", HttpResponse)
recipe.add_transform("good_urls", "responses", HttpRequest)
recipe.add_beaker("tiktoken_count", Int)
recipe.add_transform(
    "responses",
    "tiktoken_count",
    tiktoken_count,
    error_map={(ValueError, ParserError): "no_tiktoken_count"},
)
recipe.add_seed(
    "agencies",
    "agency",
    CSVSource("agencies.csv", Agency),
)

# recipe.add_beaker("token_lt_8k", temp=True)
# recipe.add_beaker("token_gt_8k", temp=True)


# recipe.add_transform(
#     "good_urls",
#     "responses",
#     add_response,
#     error_map={
#         (
#             httpx.HTTPError,
#             SSLCertVerificationError,
#             SSLError,
#         ): "bad_requests"
#     },
# )
# recipe.add_conditional(
#     "with_tiktoken_count",
#     lambda x: x["tiktoken_count"] < 8000,
#     if_true="token_lt_8k",
#     if_false="token_gt_8k",
# )
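A hypothetical smoke test for CSVSource, assuming CSVSource and Agency from the files above are in scope and that agencies.csv carries the id,url,name columns Agency declares:

import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as tmp:
    tmp.write("id,url,name\n1,https://example.gov,Example Agency\n")

source = CSVSource(tmp.name, Agency)
for agency in source():
    print(agency.id, agency.url, agency.name)  # 1 https://example.gov Example Agency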