diff --git a/.gitignore b/.gitignore
index 0d20b64..ba9db1a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 *.pyc
+dist/
diff --git a/src/whsk/__init__.py b/src/whsk/__init__.py
index fdf61dc..3fd687b 100644
--- a/src/whsk/__init__.py
+++ b/src/whsk/__init__.py
@@ -41,11 +41,29 @@ def parse_headers(headers: list[str]) -> dict:
     return header_dict
 
 
-def make_request(url, headers, postdata):
+def make_request(url, *, headers, user_agent, postdata):
+    """
+    Helper function to consolidate request code shared between commands.
+
+    Takes all request-related parameters and returns a 2-tuple of the
+    raw response and the parsed response.
+
+    The parsed response may be a:
+    - JSON dict
+    - lxml.html.HtmlElement
+    - lxml.etree.Element
+    """
     header_dict = parse_headers(headers)
-    resp = httpx.request("GET", url, headers=header_dict, data=postdata)
-    # if resp.headers["content-type"] == "text/html":
-    root = lxml.html.fromstring(resp.text)
+    method = "GET"
+    if postdata:
+        method = "POST"
+    resp = httpx.request(method, url, headers=header_dict, data=postdata)
+    if resp.headers["content-type"] == "application/json":
+        root = resp.json()
+    elif "xml" in resp.headers["content-type"]:
+        root = lxml.etree.fromstring(resp.content)
+    else:
+        root = lxml.html.fromstring(resp.text)
     return resp, root
 
 
@@ -95,9 +113,18 @@ def query(
     xpath: Annotated[str, opt["xpath"]] = "",
 ):
     """Run a one-off query against the URL"""
-    resp, root = make_request(url, headers, postdata)
+    resp, root = make_request(
+        url, headers=headers, user_agent=user_agent, postdata=postdata
+    )
+    if not isinstance(root, lxml.html.HtmlElement):
+        typer.secho(f"Expecting HTML response, got:\n{root}", fg="red")
+        raise typer.Exit(1)
     selector, selected = parse_selectors(root, css, xpath)
+    if selector is None:
+        typer.secho("Must provide either --css or --xpath to query", fg="red")
+        raise typer.Exit(1)
+
     for s in selected:
         print(s)
 
 
@@ -113,7 +140,9 @@ def shell(
 ):
     """Launch an interactive Python shell for scraping"""
 
-    resp, root = make_request(url, headers, postdata)
+    resp, root = make_request(
+        url, headers=headers, user_agent=user_agent, postdata=postdata
+    )
     selector, selected = parse_selectors(root, css, xpath)
 
     console = Console()
@@ -125,7 +154,13 @@ def shell(
     )
     table.add_row("[green]url[/green]", url)
    table.add_row("[green]resp[/green]", str(resp))
-    table.add_row("[green]root[/green]", "lxml.html.Element")
+    # if this list of parsed types expands, there's probably a better way
+    if isinstance(root, lxml.html.HtmlElement):
+        table.add_row("[green]root[/green]", "lxml.html.HtmlElement")
+    elif isinstance(root, lxml.etree._Element):
+        table.add_row("[green]root[/green]", "lxml.etree.Element (XML)")
+    elif isinstance(root, dict):
+        table.add_row("[green]root[/green]", "dict (JSON)")
     if selector:
         table.add_row("[green]selector[/green]", selector)
         table.add_row("[green]selected[/green]", f"{len(selected)} elements")
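
A minimal usage sketch (not part of the patch above) of how the reworked make_request helper is expected to behave from a caller's perspective, given the content-type dispatch it now performs. The URLs are hypothetical placeholders; httpx and lxml are assumed to be installed and the package importable as whsk.

    from whsk import make_request  # defined in src/whsk/__init__.py

    # A response served with content-type "application/json" comes back as a dict.
    resp, root = make_request(
        "https://example.com/api/items",  # hypothetical endpoint
        headers=[], user_agent="", postdata=None,
    )
    print(type(root))  # dict when the server reports application/json

    # An HTML page is parsed into an lxml.html.HtmlElement, ready for css/xpath.
    resp, root = make_request(
        "https://example.com/",  # hypothetical page
        headers=[], user_agent="", postdata=None,
    )
    print(root.xpath("//title/text()"))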