This commit is contained in:
James Turk 2022-11-10 21:26:09 -06:00
parent 806b8873ec
commit a7e3fc63b3
12 changed files with 751 additions and 554 deletions

View File

@ -13,39 +13,39 @@ class OvercookedError(Exception):
""" """
Exception for trying to operate on a Recipe that has been finished. Exception for trying to operate on a Recipe that has been finished.
""" """
pass pass
class Recipe(object): class Recipe(object):
def __init__(self, *filter_args, **kwargs): def __init__(self, *filter_args, **kwargs):
self.finished = False self.finished = False
self.filters = [] self.filters = []
for filter in filter_args: for filter in filter_args:
if hasattr(filter, 'filters'): if hasattr(filter, "filters"):
self.filters.extend(filter.filters) self.filters.extend(filter.filters)
else: else:
self.filters.append(filter) self.filters.append(filter)
self.error_stream = kwargs.get('error_stream') self.error_stream = kwargs.get("error_stream")
if self.error_stream and not isinstance(self.error_stream, Recipe): if self.error_stream and not isinstance(self.error_stream, Recipe):
if isinstance(self.error_stream, filters.Filter): if isinstance(self.error_stream, filters.Filter):
self.error_stream = Recipe(self.error_stream) self.error_stream = Recipe(self.error_stream)
elif hasattr(self.error_stream, '__iter__'): elif hasattr(self.error_stream, "__iter__"):
self.error_stream = Recipe(*self.error_stream) self.error_stream = Recipe(*self.error_stream)
else: else:
raise SaucebrushError('error_stream must be either a filter' raise SaucebrushError(
' or an iterable of filters') "error_stream must be either a filter" " or an iterable of filters"
)
def reject_record(self, record, exception): def reject_record(self, record, exception):
if self.error_stream: if self.error_stream:
self.error_stream.run([{'record': record, self.error_stream.run([{"record": record, "exception": repr(exception)}])
'exception': repr(exception)}])
def run(self, source): def run(self, source):
if self.finished: if self.finished:
raise OvercookedError('run() called on finished recipe') raise OvercookedError("run() called on finished recipe")
# connect datapath # connect datapath
data = source data = source
@ -58,7 +58,7 @@ class Recipe(object):
def done(self): def done(self):
if self.finished: if self.finished:
raise OvercookedError('done() called on finished recipe') raise OvercookedError("done() called on finished recipe")
self.finished = True self.finished = True
@ -70,12 +70,11 @@ class Recipe(object):
try: try:
filter_.done() filter_.done()
except AttributeError: except AttributeError:
pass # don't care if there isn't a done method pass # don't care if there isn't a done method
def run_recipe(source, *filter_args, **kwargs): def run_recipe(source, *filter_args, **kwargs):
""" Process data, taking it from a source and applying any number of filters """Process data, taking it from a source and applying any number of filters"""
"""
r = Recipe(*filter_args, **kwargs) r = Recipe(*filter_args, **kwargs)
r.run(source) r.run(source)

View File

@ -2,49 +2,53 @@
Saucebrush Emitters are filters that instead of modifying the record, output Saucebrush Emitters are filters that instead of modifying the record, output
it in some manner. it in some manner.
""" """
from __future__ import unicode_literals
from saucebrush.filters import Filter from saucebrush.filters import Filter
class Emitter(Filter): class Emitter(Filter):
""" ABC for emitters """ABC for emitters
All derived emitters must provide an emit_record(self, record) that All derived emitters must provide an emit_record(self, record) that
takes a single record (python dictionary). takes a single record (python dictionary).
Emitters can optionally define a done() method that is called after Emitters can optionally define a done() method that is called after
all records are processed (allowing database flushes, or printing of all records are processed (allowing database flushes, or printing of
aggregate data). aggregate data).
""" """
def process_record(self, record): def process_record(self, record):
self.emit_record(record) self.emit_record(record)
return record return record
def emit_record(self, record): def emit_record(self, record):
""" Abstract method to be overridden. """Abstract method to be overridden.
Called with a single record, should "emit" the record unmodified. Called with a single record, should "emit" the record unmodified.
""" """
raise NotImplementedError('emit_record not defined in ' + raise NotImplementedError(
self.__class__.__name__) "emit_record not defined in " + self.__class__.__name__
)
def done(self): def done(self):
""" No-op Method to be overridden. """No-op Method to be overridden.
Called when all processing is complete Called when all processing is complete
""" """
pass pass
class DebugEmitter(Emitter): class DebugEmitter(Emitter):
""" Emitter that prints raw records to a file, useful for debugging. """Emitter that prints raw records to a file, useful for debugging.
DebugEmitter() by default prints to stdout. DebugEmitter() by default prints to stdout.
DebugEmitter(open('test', 'w')) would print to a file named test DebugEmitter(open('test', 'w')) would print to a file named test
""" """
def __init__(self, outfile=None): def __init__(self, outfile=None):
super(DebugEmitter, self).__init__() super(DebugEmitter, self).__init__()
if not outfile: if not outfile:
import sys import sys
self._outfile = sys.stdout self._outfile = sys.stdout
else: else:
self._outfile = outfile self._outfile = outfile
@ -54,12 +58,12 @@ class DebugEmitter(Emitter):
class CountEmitter(Emitter): class CountEmitter(Emitter):
""" Emitter that writes the record count to a file-like object. """Emitter that writes the record count to a file-like object.
CountEmitter() by default writes to stdout. CountEmitter() by default writes to stdout.
CountEmitter(outfile=open('text', 'w')) would print to a file name test. CountEmitter(outfile=open('text', 'w')) would print to a file name test.
CountEmitter(every=1000000) would write the count every 1,000,000 records. CountEmitter(every=1000000) would write the count every 1,000,000 records.
CountEmitter(every=100, of=2000) would write "<count> of 2000" every 100 records. CountEmitter(every=100, of=2000) would write "<count> of 2000" every 100 records.
""" """
def __init__(self, every=1000, of=None, outfile=None, format=None): def __init__(self, every=1000, of=None, outfile=None, format=None):
@ -68,6 +72,7 @@ class CountEmitter(Emitter):
if not outfile: if not outfile:
import sys import sys
self._outfile = sys.stdout self._outfile = sys.stdout
else: else:
self._outfile = outfile self._outfile = outfile
@ -84,7 +89,7 @@ class CountEmitter(Emitter):
self.count = 0 self.count = 0
def format(self): def format(self):
return self._format % {'count': self.count, 'of': self._of} return self._format % {"count": self.count, "of": self._of}
def emit_record(self, record): def emit_record(self, record):
self.count += 1 self.count += 1
@ -96,15 +101,16 @@ class CountEmitter(Emitter):
class CSVEmitter(Emitter): class CSVEmitter(Emitter):
""" Emitter that writes records to a CSV file. """Emitter that writes records to a CSV file.
CSVEmitter(open('output.csv','w'), ('id', 'name', 'phone')) writes all CSVEmitter(open('output.csv','w'), ('id', 'name', 'phone')) writes all
records to a csvfile with the columns in the order specified. records to a csvfile with the columns in the order specified.
""" """
def __init__(self, csvfile, fieldnames): def __init__(self, csvfile, fieldnames):
super(CSVEmitter, self).__init__() super(CSVEmitter, self).__init__()
import csv import csv
self._dictwriter = csv.DictWriter(csvfile, fieldnames) self._dictwriter = csv.DictWriter(csvfile, fieldnames)
# write header row # write header row
header_row = dict(zip(fieldnames, fieldnames)) header_row = dict(zip(fieldnames, fieldnames))
@ -115,36 +121,43 @@ class CSVEmitter(Emitter):
class SqliteEmitter(Emitter): class SqliteEmitter(Emitter):
""" Emitter that writes records to a SQLite database. """Emitter that writes records to a SQLite database.
SqliteEmitter('addressbook.db', 'friend') writes all records to the SqliteEmitter('addressbook.db', 'friend') writes all records to the
friends table in the SQLite database named addressbook.db friends table in the SQLite database named addressbook.db
(To have the emitter create the table, the fieldnames should be passed (To have the emitter create the table, the fieldnames should be passed
as a third parameter to SqliteEmitter.) as a third parameter to SqliteEmitter.)
""" """
def __init__(self, dbname, table_name, fieldnames=None, replace=False, quiet=False): def __init__(self, dbname, table_name, fieldnames=None, replace=False, quiet=False):
super(SqliteEmitter, self).__init__() super(SqliteEmitter, self).__init__()
import sqlite3 import sqlite3
self._conn = sqlite3.connect(dbname) self._conn = sqlite3.connect(dbname)
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._table_name = table_name self._table_name = table_name
self._replace = replace self._replace = replace
self._quiet = quiet self._quiet = quiet
if fieldnames: if fieldnames:
create = "CREATE TABLE IF NOT EXISTS %s (%s)" % (table_name, create = "CREATE TABLE IF NOT EXISTS %s (%s)" % (
', '.join([' '.join((field, 'TEXT')) for field in fieldnames])) table_name,
", ".join([" ".join((field, "TEXT")) for field in fieldnames]),
)
self._cursor.execute(create) self._cursor.execute(create)
def emit_record(self, record): def emit_record(self, record):
import sqlite3 import sqlite3
# input should be escaped with ? if data isn't trusted # input should be escaped with ? if data isn't trusted
qmarks = ','.join(('?',) * len(record)) qmarks = ",".join(("?",) * len(record))
insert = 'INSERT OR REPLACE' if self._replace else 'INSERT' insert = "INSERT OR REPLACE" if self._replace else "INSERT"
insert = '%s INTO %s (%s) VALUES (%s)' % (insert, self._table_name, insert = "%s INTO %s (%s) VALUES (%s)" % (
','.join(record.keys()), insert,
qmarks) self._table_name,
",".join(record.keys()),
qmarks,
)
try: try:
self._cursor.execute(insert, list(record.values())) self._cursor.execute(insert, list(record.values()))
except sqlite3.IntegrityError as ie: except sqlite3.IntegrityError as ie:
@ -158,14 +171,14 @@ class SqliteEmitter(Emitter):
class SqlDumpEmitter(Emitter): class SqlDumpEmitter(Emitter):
""" Emitter that writes SQL INSERT statements. """Emitter that writes SQL INSERT statements.
The output generated by the SqlDumpEmitter is intended to be used to The output generated by the SqlDumpEmitter is intended to be used to
populate a mySQL database. populate a mySQL database.
SqlDumpEmitter(open('addresses.sql', 'w'), 'friend', ('name', 'phone')) SqlDumpEmitter(open('addresses.sql', 'w'), 'friend', ('name', 'phone'))
writes statements to addresses.sql to insert the data writes statements to addresses.sql to insert the data
into the friends table. into the friends table.
""" """
def __init__(self, outfile, table_name, fieldnames): def __init__(self, outfile, table_name, fieldnames):
@ -173,11 +186,14 @@ class SqlDumpEmitter(Emitter):
self._fieldnames = fieldnames self._fieldnames = fieldnames
if not outfile: if not outfile:
import sys import sys
self._outfile = sys.stderr self._outfile = sys.stderr
else: else:
self._outfile = outfile self._outfile = outfile
self._insert_str = "INSERT INTO `%s` (`%s`) VALUES (%%s);\n" % ( self._insert_str = "INSERT INTO `%s` (`%s`) VALUES (%%s);\n" % (
table_name, '`,`'.join(fieldnames)) table_name,
"`,`".join(fieldnames),
)
def quote(self, item): def quote(self, item):
@ -190,29 +206,31 @@ class SqlDumpEmitter(Emitter):
types = (str,) types = (str,)
if isinstance(item, types): if isinstance(item, types):
item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0') item = item.replace("\\", "\\\\").replace("'", "\\'").replace(chr(0), "0")
return "'%s'" % item return "'%s'" % item
return "%s" % item return "%s" % item
def emit_record(self, record): def emit_record(self, record):
quoted_data = [self.quote(record[field]) for field in self._fieldnames] quoted_data = [self.quote(record[field]) for field in self._fieldnames]
self._outfile.write(self._insert_str % ','.join(quoted_data)) self._outfile.write(self._insert_str % ",".join(quoted_data))
class DjangoModelEmitter(Emitter): class DjangoModelEmitter(Emitter):
""" Emitter that populates a table corresponding to a django model. """Emitter that populates a table corresponding to a django model.
Takes a django settings file, app label and model name and uses django Takes a django settings file, app label and model name and uses django
to insert the records into the appropriate table. to insert the records into the appropriate table.
DjangoModelEmitter('settings.py', 'addressbook', 'friend') writes DjangoModelEmitter('settings.py', 'addressbook', 'friend') writes
records to addressbook.models.friend model using database settings records to addressbook.models.friend model using database settings
from settings.py. from settings.py.
""" """
def __init__(self, dj_settings, app_label, model_name): def __init__(self, dj_settings, app_label, model_name):
super(DjangoModelEmitter, self).__init__() super(DjangoModelEmitter, self).__init__()
from saucebrush.utils import get_django_model from saucebrush.utils import get_django_model
self._dbmodel = get_django_model(dj_settings, app_label, model_name) self._dbmodel = get_django_model(dj_settings, app_label, model_name)
if not self._dbmodel: if not self._dbmodel:
raise Exception("No such model: %s %s" % (app_label, model_name)) raise Exception("No such model: %s %s" % (app_label, model_name))
@ -222,19 +240,30 @@ class DjangoModelEmitter(Emitter):
class MongoDBEmitter(Emitter): class MongoDBEmitter(Emitter):
""" Emitter that creates a document in a MongoDB datastore """Emitter that creates a document in a MongoDB datastore
The names of the database and collection in which the records will The names of the database and collection in which the records will
be inserted are required parameters. The host and port are optional, be inserted are required parameters. The host and port are optional,
defaulting to 'localhost' and 27017, repectively. defaulting to 'localhost' and 27017, repectively.
""" """
def __init__(self, database, collection, host='localhost', port=27017, drop_collection=False, conn=None):
def __init__(
self,
database,
collection,
host="localhost",
port=27017,
drop_collection=False,
conn=None,
):
super(MongoDBEmitter, self).__init__() super(MongoDBEmitter, self).__init__()
from pymongo.database import Database from pymongo.database import Database
if not isinstance(database, Database): if not isinstance(database, Database):
if not conn: if not conn:
from pymongo.connection import Connection from pymongo.connection import Connection
conn = Connection(host, port) conn = Connection(host, port)
db = conn[database] db = conn[database]
else: else:
@ -249,12 +278,13 @@ class MongoDBEmitter(Emitter):
class LoggingEmitter(Emitter): class LoggingEmitter(Emitter):
""" Emitter that logs to a Python logging.Logger instance. """Emitter that logs to a Python logging.Logger instance.
The msg_template will be passed the record being emitted as The msg_template will be passed the record being emitted as
a format parameter. The resulting message will get logged a format parameter. The resulting message will get logged
at the provided level. at the provided level.
""" """
import logging import logging
def __init__(self, logger, msg_template, level=logging.DEBUG): def __init__(self, logger, msg_template, level=logging.DEBUG):

View File

@ -12,26 +12,28 @@ import re
import time import time
###################### ######################
## Abstract Filters ## # Abstract Filters #
###################### ######################
class Filter(object):
""" ABC for filters that operate on records.
All derived filters must provide a process_record(self, record) that class Filter(object):
takes a single record (python dictionary) and returns a result. """ABC for filters that operate on records.
All derived filters must provide a process_record(self, record) that
takes a single record (python dictionary) and returns a result.
""" """
def process_record(self, record): def process_record(self, record):
""" Abstract method to be overridden. """Abstract method to be overridden.
Called with a single record, should return modified record. Called with a single record, should return modified record.
""" """
raise NotImplementedError('process_record not defined in ' + raise NotImplementedError(
self.__class__.__name__) "process_record not defined in " + self.__class__.__name__
)
def reject_record(self, record, exception): def reject_record(self, record, exception):
recipe = getattr(self, '_recipe') recipe = getattr(self, "_recipe")
if recipe: if recipe:
recipe.reject_record(record, exception) recipe.reject_record(record, exception)
@ -47,11 +49,11 @@ class Filter(object):
class YieldFilter(Filter): class YieldFilter(Filter):
""" ABC for defining filters where process_record yields. """ABC for defining filters where process_record yields.
If process_record cannot return exactly one result for every record If process_record cannot return exactly one result for every record
it is passed, it should yield back as many records as needed and the it is passed, it should yield back as many records as needed and the
filter must derive from YieldFilter. filter must derive from YieldFilter.
""" """
def attach(self, source, recipe=None): def attach(self, source, recipe=None):
@ -65,11 +67,11 @@ class YieldFilter(Filter):
class FieldFilter(Filter): class FieldFilter(Filter):
""" ABC for filters that do a single operation on individual fields. """ABC for filters that do a single operation on individual fields.
All derived filters must provide a process_field(self, item) that All derived filters must provide a process_field(self, item) that
returns a modified item. process_field is called on one or more keys returns a modified item. process_field is called on one or more keys
passed into __init__. passed into __init__.
""" """
def __init__(self, keys): def __init__(self, keys):
@ -77,7 +79,7 @@ class FieldFilter(Filter):
self._target_keys = utils.str_or_list(keys) self._target_keys = utils.str_or_list(keys)
def process_record(self, record): def process_record(self, record):
""" Calls process_field on all keys passed to __init__. """ """Calls process_field on all keys passed to __init__."""
for key in self._target_keys: for key in self._target_keys:
try: try:
@ -89,29 +91,31 @@ class FieldFilter(Filter):
return record return record
def process_field(self, item): def process_field(self, item):
""" Given a value, return the value that it should be replaced with. """ """Given a value, return the value that it should be replaced with."""
raise NotImplementedError('process_field not defined in ' + raise NotImplementedError(
self.__class__.__name__) "process_field not defined in " + self.__class__.__name__
)
def __unicode__(self): def __unicode__(self):
return '%s( %s )' % (self.__class__.__name__, str(self._target_keys)) return "%s( %s )" % (self.__class__.__name__, str(self._target_keys))
class ConditionalFilter(YieldFilter): class ConditionalFilter(YieldFilter):
""" ABC for filters that only pass through records meeting a condition. """ABC for filters that only pass through records meeting a condition.
All derived filters must provide a test_record(self, record) that All derived filters must provide a test_record(self, record) that
returns True or False -- True indicating that the record should be returns True or False -- True indicating that the record should be
passed through, and False preventing pass through. passed through, and False preventing pass through.
If validator is True then raises a ValidationError instead of If validator is True then raises a ValidationError instead of
silently dropping records that fail test_record. silently dropping records that fail test_record.
""" """
validator = False validator = False
def process_record(self, record): def process_record(self, record):
""" Yields all records for which self.test_record is true """ """Yields all records for which self.test_record is true"""
if self.test_record(record): if self.test_record(record):
yield record yield record
@ -119,41 +123,45 @@ class ConditionalFilter(YieldFilter):
raise ValidationError(record) raise ValidationError(record)
def test_record(self, record): def test_record(self, record):
""" Given a record, return True iff it should be passed on """ """Given a record, return True iff it should be passed on"""
raise NotImplementedError('test_record not defined in ' + raise NotImplementedError(
self.__class__.__name__) "test_record not defined in " + self.__class__.__name__
)
class ValidationError(Exception): class ValidationError(Exception):
def __init__(self, record): def __init__(self, record):
super(ValidationError, self).__init__(repr(record)) super(ValidationError, self).__init__(repr(record))
self.record = record self.record = record
def _dotted_get(d, path): def _dotted_get(d, path):
""" """
utility function for SubrecordFilter utility function for SubrecordFilter
dives into a complex nested dictionary with paths like a.b.c dives into a complex nested dictionary with paths like a.b.c
""" """
if path: if path:
key_pieces = path.split('.', 1) key_pieces = path.split(".", 1)
piece = d[key_pieces[0]] piece = d[key_pieces[0]]
if isinstance(piece, (tuple, list)): if isinstance(piece, (tuple, list)):
return [_dotted_get(i, '.'.join(key_pieces[1:])) for i in piece] return [_dotted_get(i, ".".join(key_pieces[1:])) for i in piece]
elif isinstance(piece, (dict)): elif isinstance(piece, (dict)):
return _dotted_get(piece, '.'.join(key_pieces[1:])) return _dotted_get(piece, ".".join(key_pieces[1:]))
else: else:
return d return d
class SubrecordFilter(Filter):
""" Filter that calls another filter on subrecord(s) of a record
Takes a dotted path (eg. a.b.c) and instantiated filter and runs that class SubrecordFilter(Filter):
filter on all subrecords found at the path. """Filter that calls another filter on subrecord(s) of a record
Takes a dotted path (eg. a.b.c) and instantiated filter and runs that
filter on all subrecords found at the path.
""" """
def __init__(self, field_path, filter_): def __init__(self, field_path, filter_):
if '.' in field_path: if "." in field_path:
self.field_path, self.key = field_path.rsplit('.', 1) self.field_path, self.key = field_path.rsplit(".", 1)
else: else:
self.field_path = None self.field_path = None
self.key = field_path self.key = field_path
@ -178,8 +186,9 @@ class SubrecordFilter(Filter):
self.process_subrecord(subrecord_parent) self.process_subrecord(subrecord_parent)
return record return record
class ConditionalPathFilter(Filter): class ConditionalPathFilter(Filter):
""" Filter that uses a predicate to split input among two filter paths. """ """Filter that uses a predicate to split input among two filter paths."""
def __init__(self, predicate_func, true_filter, false_filter): def __init__(self, predicate_func, true_filter, false_filter):
self.predicate_func = predicate_func self.predicate_func = predicate_func
@ -192,15 +201,17 @@ class ConditionalPathFilter(Filter):
else: else:
return self.false_filter.process_record(record) return self.false_filter.process_record(record)
##################### #####################
## Generic Filters ## ## Generic Filters ##
##################### #####################
class FieldModifier(FieldFilter):
""" Filter that calls a given function on a given set of fields.
FieldModifier(('spam','eggs'), abs) to call the abs method on the spam class FieldModifier(FieldFilter):
and eggs fields in each record filtered. """Filter that calls a given function on a given set of fields.
FieldModifier(('spam','eggs'), abs) to call the abs method on the spam
and eggs fields in each record filtered.
""" """
def __init__(self, keys, func): def __init__(self, keys, func):
@ -211,15 +222,18 @@ class FieldModifier(FieldFilter):
return self._filter_func(item) return self._filter_func(item)
def __unicode__(self): def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__, return "%s( %s, %s )" % (
str(self._target_keys), str(self._filter_func)) self.__class__.__name__,
str(self._target_keys),
str(self._filter_func),
)
class FieldKeeper(Filter): class FieldKeeper(Filter):
""" Filter that removes all but the given set of fields. """Filter that removes all but the given set of fields.
FieldKeeper(('spam', 'eggs')) removes all bu tthe spam and eggs FieldKeeper(('spam', 'eggs')) removes all bu tthe spam and eggs
fields from every record filtered. fields from every record filtered.
""" """
def __init__(self, keys): def __init__(self, keys):
@ -234,10 +248,10 @@ class FieldKeeper(Filter):
class FieldRemover(Filter): class FieldRemover(Filter):
""" Filter that removes a given set of fields. """Filter that removes a given set of fields.
FieldRemover(('spam', 'eggs')) removes the spam and eggs fields from FieldRemover(('spam', 'eggs')) removes the spam and eggs fields from
every record filtered. every record filtered.
""" """
def __init__(self, keys): def __init__(self, keys):
@ -250,16 +264,16 @@ class FieldRemover(Filter):
return record return record
def __unicode__(self): def __unicode__(self):
return '%s( %s )' % (self.__class__.__name__, str(self._target_keys)) return "%s( %s )" % (self.__class__.__name__, str(self._target_keys))
class FieldMerger(Filter): class FieldMerger(Filter):
""" Filter that merges a given set of fields using a supplied merge_func. """Filter that merges a given set of fields using a supplied merge_func.
Takes a mapping (dictionary of new_column:(from_col1,from_col2)) Takes a mapping (dictionary of new_column:(from_col1,from_col2))
FieldMerger({"bacon": ("spam", "eggs")}, operator.add) creates a new FieldMerger({"bacon": ("spam", "eggs")}, operator.add) creates a new
column bacon that is the result of spam+eggs column bacon that is the result of spam+eggs
""" """
def __init__(self, mapping, merge_func, keep_fields=False): def __init__(self, mapping, merge_func, keep_fields=False):
@ -278,29 +292,31 @@ class FieldMerger(Filter):
return record return record
def __unicode__(self): def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__, return "%s( %s, %s )" % (
str(self._field_mapping), self.__class__.__name__,
str(self._merge_func)) str(self._field_mapping),
str(self._merge_func),
)
class FieldAdder(Filter): class FieldAdder(Filter):
""" Filter that adds a new field. """Filter that adds a new field.
Takes a name for the new field and a value, field_value can be an Takes a name for the new field and a value, field_value can be an
iterable, a function, or a static value. iterable, a function, or a static value.
from itertools import count from itertools import count
FieldAdder('id', count) FieldAdder('id', count)
would yield a new column named id that uses the itertools count iterable would yield a new column named id that uses the itertools count iterable
to create sequentially numbered ids. to create sequentially numbered ids.
""" """
def __init__(self, field_name, field_value, replace=True): def __init__(self, field_name, field_value, replace=True):
super(FieldAdder, self).__init__() super(FieldAdder, self).__init__()
self._field_name = field_name self._field_name = field_name
self._field_value = field_value self._field_value = field_value
if hasattr(self._field_value, '__iter__'): if hasattr(self._field_value, "__iter__"):
value_iter = iter(self._field_value) value_iter = iter(self._field_value)
if hasattr(value_iter, "next"): if hasattr(value_iter, "next"):
self._field_value = value_iter.next self._field_value = value_iter.next
@ -317,15 +333,20 @@ class FieldAdder(Filter):
return record return record
def __unicode__(self): def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__, self._field_name, return "%s( %s, %s )" % (
str(self._field_value)) self.__class__.__name__,
self._field_name,
str(self._field_value),
)
class FieldCopier(Filter): class FieldCopier(Filter):
""" Filter that copies one field to another. """Filter that copies one field to another.
Takes a dictionary mapping destination keys to source keys. Takes a dictionary mapping destination keys to source keys.
""" """
def __init__(self, copy_mapping): def __init__(self, copy_mapping):
super(FieldCopier, self).__init__() super(FieldCopier, self).__init__()
self._copy_mapping = copy_mapping self._copy_mapping = copy_mapping
@ -336,11 +357,13 @@ class FieldCopier(Filter):
record[dest] = record[source] record[dest] = record[source]
return record return record
class FieldRenamer(Filter):
""" Filter that renames one field to another.
Takes a dictionary mapping destination keys to source keys. class FieldRenamer(Filter):
"""Filter that renames one field to another.
Takes a dictionary mapping destination keys to source keys.
""" """
def __init__(self, rename_mapping): def __init__(self, rename_mapping):
super(FieldRenamer, self).__init__() super(FieldRenamer, self).__init__()
self._rename_mapping = rename_mapping self._rename_mapping = rename_mapping
@ -351,11 +374,12 @@ class FieldRenamer(Filter):
record[dest] = record.pop(source) record[dest] = record.pop(source)
return record return record
class FieldNameModifier(Filter):
""" Filter that calls a given function on a given set of fields.
FieldNameModifier(('spam','eggs'), abs) to call the abs method on the spam class FieldNameModifier(Filter):
and eggs field names in each record filtered. """Filter that calls a given function on a given set of fields.
FieldNameModifier(('spam','eggs'), abs) to call the abs method on the spam
and eggs field names in each record filtered.
""" """
def __init__(self, func): def __init__(self, func):
@ -368,15 +392,16 @@ class FieldNameModifier(Filter):
record[dest] = record.pop(source) record[dest] = record.pop(source)
return record return record
class Splitter(Filter): class Splitter(Filter):
""" Filter that splits nested data into different paths. """Filter that splits nested data into different paths.
Takes a dictionary of keys and a series of filters to run against the Takes a dictionary of keys and a series of filters to run against the
associated dictionaries. associated dictionaries.
{'person': {'firstname': 'James', 'lastname': 'Turk'}, {'person': {'firstname': 'James', 'lastname': 'Turk'},
'phones': [{'phone': '222-222-2222'}, {'phone': '335-333-3321'}] 'phones': [{'phone': '222-222-2222'}, {'phone': '335-333-3321'}]
} }
""" """
def __init__(self, split_mapping): def __init__(self, split_mapping):
@ -409,19 +434,20 @@ class Splitter(Filter):
class Flattener(FieldFilter): class Flattener(FieldFilter):
""" Collapse a set of similar dictionaries into a list. """Collapse a set of similar dictionaries into a list.
Takes a dictionary of keys and flattens the key names: Takes a dictionary of keys and flattens the key names:
addresses = [{'addresses': [{'address': {'state':'NC', 'street':'146 shirley drive'}}, addresses = [{'addresses': [{'address': {'state':'NC', 'street':'146 shirley drive'}},
{'address': {'state':'NY', 'street':'3000 Winton Rd'}}]}] {'address': {'state':'NY', 'street':'3000 Winton Rd'}}]}]
flattener = Flattener(['addresses']) flattener = Flattener(['addresses'])
would yield: would yield:
{'addresses': [{'state': 'NC', 'street': '146 shirley drive'}, {'addresses': [{'state': 'NC', 'street': '146 shirley drive'},
{'state': 'NY', 'street': '3000 Winton Rd'}]} {'state': 'NY', 'street': '3000 Winton Rd'}]}
""" """
def __init__(self, keys): def __init__(self, keys):
super(Flattener, self).__init__(keys) super(Flattener, self).__init__(keys)
@ -436,7 +462,7 @@ class Flattener(FieldFilter):
class DictFlattener(Filter): class DictFlattener(Filter):
def __init__(self, keys, separator='_'): def __init__(self, keys, separator="_"):
super(DictFlattener, self).__init__() super(DictFlattener, self).__init__()
self._keys = utils.str_or_list(keys) self._keys = utils.str_or_list(keys)
self._separator = separator self._separator = separator
@ -446,8 +472,7 @@ class DictFlattener(Filter):
class Unique(ConditionalFilter): class Unique(ConditionalFilter):
""" Filter that ensures that all records passing through are unique. """Filter that ensures that all records passing through are unique."""
"""
def __init__(self): def __init__(self):
super(Unique, self).__init__() super(Unique, self).__init__()
@ -461,18 +486,19 @@ class Unique(ConditionalFilter):
else: else:
return False return False
class UniqueValidator(Unique): class UniqueValidator(Unique):
validator = True validator = True
class UniqueID(ConditionalFilter): class UniqueID(ConditionalFilter):
""" Filter that ensures that all records through have a unique ID. """Filter that ensures that all records through have a unique ID.
Takes the name of an ID field, or multiple field names in the case Takes the name of an ID field, or multiple field names in the case
of a composite ID. of a composite ID.
""" """
def __init__(self, field='id', *args): def __init__(self, field="id", *args):
super(UniqueID, self).__init__() super(UniqueID, self).__init__()
self._seen = set() self._seen = set()
self._id_fields = [field] self._id_fields = [field]
@ -486,15 +512,15 @@ class UniqueID(ConditionalFilter):
else: else:
return False return False
class UniqueIDValidator(UniqueID): class UniqueIDValidator(UniqueID):
validator = True validator = True
class UnicodeFilter(Filter): class UnicodeFilter(Filter):
""" Convert all str elements in the record to Unicode. """Convert all str elements in the record to Unicode."""
"""
def __init__(self, encoding='utf-8', errors='ignore'): def __init__(self, encoding="utf-8", errors="ignore"):
super(UnicodeFilter, self).__init__() super(UnicodeFilter, self).__init__()
self._encoding = encoding self._encoding = encoding
self._errors = errors self._errors = errors
@ -507,9 +533,9 @@ class UnicodeFilter(Filter):
record[key] = value.decode(self._encoding, self._errors) record[key] = value.decode(self._encoding, self._errors)
return record return record
class StringFilter(Filter):
def __init__(self, encoding='utf-8', errors='ignore'): class StringFilter(Filter):
def __init__(self, encoding="utf-8", errors="ignore"):
super(StringFilter, self).__init__() super(StringFilter, self).__init__()
self._encoding = encoding self._encoding = encoding
self._errors = errors self._errors = errors
@ -525,19 +551,21 @@ class StringFilter(Filter):
## Commonly Used Filters ## ## Commonly Used Filters ##
########################### ###########################
class PhoneNumberCleaner(FieldFilter): class PhoneNumberCleaner(FieldFilter):
""" Filter that cleans phone numbers to match a given format. """Filter that cleans phone numbers to match a given format.
Takes a list of target keys and an optional phone # format that has Takes a list of target keys and an optional phone # format that has
10 %s placeholders. 10 %s placeholders.
PhoneNumberCleaner( ('phone','fax'), number_format='%s%s%s-%s%s%s-%s%s%s%s') PhoneNumberCleaner( ('phone','fax'), number_format='%s%s%s-%s%s%s-%s%s%s%s')
would format the phone & fax columns to 555-123-4567 format. would format the phone & fax columns to 555-123-4567 format.
""" """
def __init__(self, keys, number_format='%s%s%s.%s%s%s.%s%s%s%s'):
def __init__(self, keys, number_format="%s%s%s.%s%s%s.%s%s%s%s"):
super(PhoneNumberCleaner, self).__init__(keys) super(PhoneNumberCleaner, self).__init__(keys)
self._number_format = number_format self._number_format = number_format
self._num_re = re.compile('\d') self._num_re = re.compile("\d")
def process_field(self, item): def process_field(self, item):
nums = self._num_re.findall(item) nums = self._num_re.findall(item)
@ -545,45 +573,53 @@ class PhoneNumberCleaner(FieldFilter):
item = self._number_format % tuple(nums) item = self._number_format % tuple(nums)
return item return item
class DateCleaner(FieldFilter):
""" Filter that cleans dates to match a given format.
Takes a list of target keys and to and from formats in strftime format. class DateCleaner(FieldFilter):
"""Filter that cleans dates to match a given format.
Takes a list of target keys and to and from formats in strftime format.
""" """
def __init__(self, keys, from_format, to_format): def __init__(self, keys, from_format, to_format):
super(DateCleaner, self).__init__(keys) super(DateCleaner, self).__init__(keys)
self._from_format = from_format self._from_format = from_format
self._to_format = to_format self._to_format = to_format
def process_field(self, item): def process_field(self, item):
return time.strftime(self._to_format, return time.strftime(self._to_format, time.strptime(item, self._from_format))
time.strptime(item, self._from_format))
class NameCleaner(Filter): class NameCleaner(Filter):
""" Filter that splits names into a first, last, and middle name field. """Filter that splits names into a first, last, and middle name field.
Takes a list of target keys. Takes a list of target keys.
NameCleaner( ('name', ), nomatch_name='raw_name') NameCleaner( ('name', ), nomatch_name='raw_name')
would attempt to split 'name' into firstname, middlename, lastname, would attempt to split 'name' into firstname, middlename, lastname,
and suffix columns, and if it did not fit would place it in raw_name and suffix columns, and if it did not fit would place it in raw_name
""" """
# first middle? last suffix? # first middle? last suffix?
FIRST_LAST = re.compile('''^\s*(?:(?P<firstname>\w+)(?:\.?) FIRST_LAST = re.compile(
"""^\s*(?:(?P<firstname>\w+)(?:\.?)
\s+(?:(?P<middlename>\w+)\.?\s+)? \s+(?:(?P<middlename>\w+)\.?\s+)?
(?P<lastname>[A-Za-z'-]+)) (?P<lastname>[A-Za-z'-]+))
(?:\s+(?P<suffix>JR\.?|II|III|IV))? (?:\s+(?P<suffix>JR\.?|II|III|IV))?
\s*$''', re.VERBOSE | re.IGNORECASE) \s*$""",
re.VERBOSE | re.IGNORECASE,
)
# last, first middle? suffix? # last, first middle? suffix?
LAST_FIRST = re.compile('''^\s*(?:(?P<lastname>[A-Za-z'-]+), LAST_FIRST = re.compile(
"""^\s*(?:(?P<lastname>[A-Za-z'-]+),
\s+(?P<firstname>\w+)(?:\.?) \s+(?P<firstname>\w+)(?:\.?)
(?:\s+(?P<middlename>\w+)\.?)?) (?:\s+(?P<middlename>\w+)\.?)?)
(?:\s+(?P<suffix>JR\.?|II|III|IV))? (?:\s+(?P<suffix>JR\.?|II|III|IV))?
\s*$''', re.VERBOSE | re.IGNORECASE) \s*$""",
re.VERBOSE | re.IGNORECASE,
)
def __init__(self, keys, prefix='', formats=None, nomatch_name=None): def __init__(self, keys, prefix="", formats=None, nomatch_name=None):
super(NameCleaner, self).__init__() super(NameCleaner, self).__init__()
self._keys = utils.str_or_list(keys) self._keys = utils.str_or_list(keys)
self._name_prefix = prefix self._name_prefix = prefix
@ -605,7 +641,7 @@ class NameCleaner(Filter):
# if there is a match, remove original name and add pieces # if there is a match, remove original name and add pieces
if match: if match:
record.pop(key) record.pop(key)
for k,v in match.groupdict().items(): for k, v in match.groupdict().items():
record[self._name_prefix + k] = v record[self._name_prefix + k] = v
break break

View File

@ -9,22 +9,24 @@ import string
from saucebrush import utils from saucebrush import utils
class CSVSource(object): class CSVSource(object):
""" Saucebrush source for reading from CSV files. """Saucebrush source for reading from CSV files.
Takes an open csvfile, an optional set of fieldnames and optional number Takes an open csvfile, an optional set of fieldnames and optional number
of rows to skip. of rows to skip.
CSVSource(open('test.csv')) will read a csvfile, using the first row as CSVSource(open('test.csv')) will read a csvfile, using the first row as
the field names. the field names.
CSVSource(open('test.csv'), ('name', 'phone', 'address'), 1) will read CSVSource(open('test.csv'), ('name', 'phone', 'address'), 1) will read
in a CSV file and treat the three columns as name, phone, and address, in a CSV file and treat the three columns as name, phone, and address,
ignoring the first row (presumed to be column names). ignoring the first row (presumed to be column names).
""" """
def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs): def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs):
import csv import csv
self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs) self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs)
for _ in range(skiprows): for _ in range(skiprows):
next(self._dictreader) next(self._dictreader)
@ -34,16 +36,16 @@ class CSVSource(object):
class FixedWidthFileSource(object): class FixedWidthFileSource(object):
""" Saucebrush source for reading from fixed width field files. """Saucebrush source for reading from fixed width field files.
FixedWidthFileSource expects an open fixed width file and a tuple FixedWidthFileSource expects an open fixed width file and a tuple
of fields with their lengths. There is also an optional fillchars of fields with their lengths. There is also an optional fillchars
command that is the filler characters to strip from the end of each command that is the filler characters to strip from the end of each
field. (defaults to whitespace) field. (defaults to whitespace)
FixedWidthFileSource(open('testfile'), (('name',30), ('phone',12))) FixedWidthFileSource(open('testfile'), (('name',30), ('phone',12)))
will read in a fixed width file where the first 30 characters of each will read in a fixed width file where the first 30 characters of each
line are part of a name and the characters 31-42 are a phone number. line are part of a name and the characters 31-42 are a phone number.
""" """
def __init__(self, fwfile, fields, fillchars=string.whitespace): def __init__(self, fwfile, fields, fillchars=string.whitespace):
@ -64,60 +66,61 @@ class FixedWidthFileSource(object):
line = next(self._fwfile) line = next(self._fwfile)
record = {} record = {}
for name, range_ in self._fields_dict.items(): for name, range_ in self._fields_dict.items():
record[name] = line[range_[0]:range_[1]].rstrip(self._fillchars) record[name] = line[range_[0] : range_[1]].rstrip(self._fillchars)
return record return record
def next(self): def next(self):
""" Keep Python 2 next() method that defers to __next__(). """Keep Python 2 next() method that defers to __next__()."""
"""
return self.__next__() return self.__next__()
class HtmlTableSource(object): class HtmlTableSource(object):
""" Saucebrush source for reading data from an HTML table. """Saucebrush source for reading data from an HTML table.
HtmlTableSource expects an open html file, the id of the table or a HtmlTableSource expects an open html file, the id of the table or a
number indicating which table on the page to use, an optional fieldnames number indicating which table on the page to use, an optional fieldnames
tuple, and an optional number of rows to skip. tuple, and an optional number of rows to skip.
HtmlTableSource(open('test.html'), 0) opens the first HTML table and HtmlTableSource(open('test.html'), 0) opens the first HTML table and
uses the first row as the names of the columns. uses the first row as the names of the columns.
HtmlTableSource(open('test.html'), 'people', ('name','phone'), 1) opens HtmlTableSource(open('test.html'), 'people', ('name','phone'), 1) opens
the HTML table with an id of 'people' and names the two columns the HTML table with an id of 'people' and names the two columns
name and phone, skipping the first row where alternate names are name and phone, skipping the first row where alternate names are
stored. stored.
""" """
def __init__(self, htmlfile, id_or_num, fieldnames=None, skiprows=0): def __init__(self, htmlfile, id_or_num, fieldnames=None, skiprows=0):
# extract the table # extract the table
from lxml.html import parse from lxml.html import parse
doc = parse(htmlfile).getroot() doc = parse(htmlfile).getroot()
if isinstance(id_or_num, int): if isinstance(id_or_num, int):
table = doc.cssselect('table')[id_or_num] table = doc.cssselect("table")[id_or_num]
else: else:
table = doc.cssselect('table#%s' % id_or_num) table = doc.cssselect("table#%s" % id_or_num)
table = table[0] # get the first table table = table[0] # get the first table
# skip the necessary number of rows # skip the necessary number of rows
self._rows = table.cssselect('tr')[skiprows:] self._rows = table.cssselect("tr")[skiprows:]
# determine the fieldnames # determine the fieldnames
if not fieldnames: if not fieldnames:
self._fieldnames = [td.text_content() self._fieldnames = [
for td in self._rows[0].cssselect('td, th')] td.text_content() for td in self._rows[0].cssselect("td, th")
]
skiprows += 1 skiprows += 1
else: else:
self._fieldnames = fieldnames self._fieldnames = fieldnames
# skip the necessary number of rows # skip the necessary number of rows
self._rows = table.cssselect('tr')[skiprows:] self._rows = table.cssselect("tr")[skiprows:]
def process_tr(self): def process_tr(self):
for row in self._rows: for row in self._rows:
strings = [td.text_content() for td in row.cssselect('td')] strings = [td.text_content() for td in row.cssselect("td")]
yield dict(zip(self._fieldnames, strings)) yield dict(zip(self._fieldnames, strings))
def __iter__(self): def __iter__(self):
@ -125,36 +128,40 @@ class HtmlTableSource(object):
class DjangoModelSource(object): class DjangoModelSource(object):
""" Saucebrush source for reading data from django models. """Saucebrush source for reading data from django models.
DjangoModelSource expects a django settings file, app label, and model DjangoModelSource expects a django settings file, app label, and model
name. The resulting records contain all columns in the table for the name. The resulting records contain all columns in the table for the
specified model. specified model.
DjangoModelSource('settings.py', 'phonebook', 'friend') would read all DjangoModelSource('settings.py', 'phonebook', 'friend') would read all
friends from the friend model in the phonebook app described in friends from the friend model in the phonebook app described in
settings.py. settings.py.
""" """
def __init__(self, dj_settings, app_label, model_name): def __init__(self, dj_settings, app_label, model_name):
dbmodel = utils.get_django_model(dj_settings, app_label, model_name) dbmodel = utils.get_django_model(dj_settings, app_label, model_name)
# only get values defined in model (no extra fields from custom manager) # only get values defined in model (no extra fields from custom manager)
self._data = dbmodel.objects.values(*[f.name self._data = dbmodel.objects.values(*[f.name for f in dbmodel._meta.fields])
for f in dbmodel._meta.fields])
def __iter__(self): def __iter__(self):
return iter(self._data) return iter(self._data)
class MongoDBSource(object): class MongoDBSource(object):
""" Source for reading from a MongoDB database. """Source for reading from a MongoDB database.
The record dict is populated with records matching the spec The record dict is populated with records matching the spec
from the specified database and collection. from the specified database and collection.
""" """
def __init__(self, database, collection, spec=None, host='localhost', port=27017, conn=None):
def __init__(
self, database, collection, spec=None, host="localhost", port=27017, conn=None
):
if not conn: if not conn:
from pymongo.connection import Connection from pymongo.connection import Connection
conn = Connection(host, port) conn = Connection(host, port)
self.collection = conn[database][collection] self.collection = conn[database][collection]
self.spec = spec self.spec = spec
@ -166,19 +173,21 @@ class MongoDBSource(object):
for doc in self.collection.find(self.spec): for doc in self.collection.find(self.spec):
yield dict(doc) yield dict(doc)
# dict_factory for sqlite source # dict_factory for sqlite source
def dict_factory(cursor, row): def dict_factory(cursor, row):
d = { } d = {}
for idx, col in enumerate(cursor.description): for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx] d[col[0]] = row[idx]
return d return d
class SqliteSource(object):
""" Source that reads from a sqlite database.
The record dict is populated with the results from the class SqliteSource(object):
query argument. If given, args will be passed to the query """Source that reads from a sqlite database.
when executed.
The record dict is populated with the results from the
query argument. If given, args will be passed to the query
when executed.
""" """
def __init__(self, dbpath, query, args=None, conn_params=None): def __init__(self, dbpath, query, args=None, conn_params=None):
@ -214,10 +223,10 @@ class SqliteSource(object):
class FileSource(object): class FileSource(object):
""" Base class for sources which read from one or more files. """Base class for sources which read from one or more files.
Takes as input a file-like, a file path, a list of file-likes, Takes as input a file-like, a file path, a list of file-likes,
or a list of file paths. or a list of file paths.
""" """
def __init__(self, input): def __init__(self, input):
@ -226,34 +235,36 @@ class FileSource(object):
def __iter__(self): def __iter__(self):
# This method would be a lot cleaner with the proposed # This method would be a lot cleaner with the proposed
# 'yield from' expression (PEP 380) # 'yield from' expression (PEP 380)
if hasattr(self._input, '__read__') or hasattr(self._input, 'read'): if hasattr(self._input, "__read__") or hasattr(self._input, "read"):
for record in self._process_file(self._input): for record in self._process_file(self._input):
yield record yield record
elif isinstance(self._input, str): elif isinstance(self._input, str):
with open(self._input) as f: with open(self._input) as f:
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
elif hasattr(self._input, '__iter__'): elif hasattr(self._input, "__iter__"):
for el in self._input: for el in self._input:
if isinstance(el, str): if isinstance(el, str):
with open(el) as f: with open(el) as f:
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
elif hasattr(el, '__read__') or hasattr(el, 'read'): elif hasattr(el, "__read__") or hasattr(el, "read"):
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
def _process_file(self, file): def _process_file(self, file):
raise NotImplementedError('Descendants of FileSource should implement' raise NotImplementedError(
' a custom _process_file method.') "Descendants of FileSource should implement"
" a custom _process_file method."
)
class JSONSource(FileSource): class JSONSource(FileSource):
""" Source for reading from JSON files. """Source for reading from JSON files.
When processing JSON files, if the top-level object is a list, will When processing JSON files, if the top-level object is a list, will
yield each member separately. Otherwise, yields the top-level yield each member separately. Otherwise, yields the top-level
object. object.
""" """
def _process_file(self, f): def _process_file(self, f):
@ -271,36 +282,37 @@ class JSONSource(FileSource):
else: else:
yield obj yield obj
class XMLSource(FileSource):
""" Source for reading from XML files. Use with the same kind of caution
that you use to approach anything written in XML.
When processing XML files, if the top-level object is a list, will class XMLSource(FileSource):
yield each member separately, unless the dotted path to a list is """Source for reading from XML files. Use with the same kind of caution
included. you can also do this with a SubrecordFilter, but XML is that you use to approach anything written in XML.
almost never going to be useful at the top level.
When processing XML files, if the top-level object is a list, will
yield each member separately, unless the dotted path to a list is
included. you can also do this with a SubrecordFilter, but XML is
almost never going to be useful at the top level.
""" """
def __init__(self, input, node_path=None, attr_prefix='ATTR_', def __init__(self, input, node_path=None, attr_prefix="ATTR_", postprocessor=None):
postprocessor=None):
super(XMLSource, self).__init__(input) super(XMLSource, self).__init__(input)
self.node_list = node_path.split('.') self.node_list = node_path.split(".")
self.attr_prefix = attr_prefix self.attr_prefix = attr_prefix
self.postprocessor = postprocessor self.postprocessor = postprocessor
def _process_file(self, f, attr_prefix='ATTR_'): def _process_file(self, f, attr_prefix="ATTR_"):
"""xmltodict can either return attributes of nodes as prefixed fields """xmltodict can either return attributes of nodes as prefixed fields
(prefixes to avoid key collisions), or ignore them altogether. (prefixes to avoid key collisions), or ignore them altogether.
set attr prefix to whatever you want. Setting it to False ignores set attr prefix to whatever you want. Setting it to False ignores
attributes. attributes.
""" """
import xmltodict import xmltodict
if self.postprocessor: if self.postprocessor:
obj = xmltodict.parse(f, attr_prefix=self.attr_prefix, obj = xmltodict.parse(
postprocessor=self.postprocessor) f, attr_prefix=self.attr_prefix, postprocessor=self.postprocessor
)
else: else:
obj = xmltodict.parse(f, attr_prefix=self.attr_prefix) obj = xmltodict.parse(f, attr_prefix=self.attr_prefix)
@ -308,7 +320,7 @@ class XMLSource(FileSource):
if self.node_list: if self.node_list:
for node in self.node_list: for node in self.node_list:
obj = obj[node] obj = obj[node]
# If the top-level XML object in the file is a list # If the top-level XML object in the file is a list
# then yield each element separately; otherwise, yield # then yield each element separately; otherwise, yield

View File

@ -1,22 +1,23 @@
from saucebrush.filters import Filter from saucebrush.filters import Filter
from saucebrush.utils import FallbackCounter from saucebrush.utils import FallbackCounter
import collections import collections
import itertools
import math import math
def _average(values):
""" Calculate the average of a list of values.
:param values: an iterable of ints or floats to average def _average(values):
"""Calculate the average of a list of values.
:param values: an iterable of ints or floats to average
""" """
value_count = len(values) value_count = len(values)
if len(values) > 0: if len(values) > 0:
return sum(values) / float(value_count) return sum(values) / float(value_count)
def _median(values):
""" Calculate the median of a list of values.
:param values: an iterable of ints or floats to calculate def _median(values):
"""Calculate the median of a list of values.
:param values: an iterable of ints or floats to calculate
""" """
count = len(values) count = len(values)
@ -35,14 +36,15 @@ def _median(values):
else: else:
# even number of items, return average of middle two items # even number of items, return average of middle two items
mid = int(count / 2) mid = int(count / 2)
return sum(values[mid - 1:mid + 1]) / 2.0 return sum(values[mid - 1 : mid + 1]) / 2.0
def _stddev(values, population=False): def _stddev(values, population=False):
""" Calculate the standard deviation and variance of a list of values. """Calculate the standard deviation and variance of a list of values.
:param values: an iterable of ints or floats to calculate :param values: an iterable of ints or floats to calculate
:param population: True if values represents entire population, :param population: True if values represents entire population,
False if it is a sample of the population False if it is a sample of the population
""" """
avg = _average(values) avg = _average(values)
@ -54,11 +56,11 @@ def _stddev(values, population=False):
# the average of the squared differences # the average of the squared differences
variance = sum(diffsq) / float(count) variance = sum(diffsq) / float(count)
return (math.sqrt(variance), variance) # stddev is sqrt of variance return (math.sqrt(variance), variance) # stddev is sqrt of variance
class StatsFilter(Filter): class StatsFilter(Filter):
""" Base for all stats filters. """Base for all stats filters."""
"""
def __init__(self, field, test=None): def __init__(self, field, test=None):
self._field = field self._field = field
@ -70,16 +72,17 @@ class StatsFilter(Filter):
return record return record
def process_field(self, record): def process_field(self, record):
raise NotImplementedError('process_field not defined in ' + raise NotImplementedError(
self.__class__.__name__) "process_field not defined in " + self.__class__.__name__
)
def value(self): def value(self):
raise NotImplementedError('value not defined in ' + raise NotImplementedError("value not defined in " + self.__class__.__name__)
self.__class__.__name__)
class Sum(StatsFilter): class Sum(StatsFilter):
""" Calculate the sum of the values in a field. Field must contain either """Calculate the sum of the values in a field. Field must contain either
int or float values. int or float values.
""" """
def __init__(self, field, initial=0, **kwargs): def __init__(self, field, initial=0, **kwargs):
@ -92,9 +95,10 @@ class Sum(StatsFilter):
def value(self): def value(self):
return self._value return self._value
class Average(StatsFilter): class Average(StatsFilter):
""" Calculate the average (mean) of the values in a field. Field must """Calculate the average (mean) of the values in a field. Field must
contain either int or float values. contain either int or float values.
""" """
def __init__(self, field, initial=0, **kwargs): def __init__(self, field, initial=0, **kwargs):
@ -110,11 +114,12 @@ class Average(StatsFilter):
def value(self): def value(self):
return self._value / float(self._count) return self._value / float(self._count)
class Median(StatsFilter):
""" Calculate the median of the values in a field. Field must contain
either int or float values.
**This filter keeps a list of field values in memory.** class Median(StatsFilter):
"""Calculate the median of the values in a field. Field must contain
either int or float values.
**This filter keeps a list of field values in memory.**
""" """
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
@ -128,9 +133,10 @@ class Median(StatsFilter):
def value(self): def value(self):
return _median(self._values) return _median(self._values)
class MinMax(StatsFilter): class MinMax(StatsFilter):
""" Find the minimum and maximum values in a field. Field must contain """Find the minimum and maximum values in a field. Field must contain
either int or float values. either int or float values.
""" """
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
@ -148,14 +154,15 @@ class MinMax(StatsFilter):
def value(self): def value(self):
return (self._min, self._max) return (self._min, self._max)
class StandardDeviation(StatsFilter):
""" Calculate the standard deviation of the values in a field. Calling
value() will return a standard deviation for the sample. Pass
population=True to value() for the standard deviation of the
population. Convenience methods are provided for average() and
median(). Field must contain either int or float values.
**This filter keeps a list of field values in memory.** class StandardDeviation(StatsFilter):
"""Calculate the standard deviation of the values in a field. Calling
value() will return a standard deviation for the sample. Pass
population=True to value() for the standard deviation of the
population. Convenience methods are provided for average() and
median(). Field must contain either int or float values.
**This filter keeps a list of field values in memory.**
""" """
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
@ -173,28 +180,29 @@ class StandardDeviation(StatsFilter):
return _median(self._values) return _median(self._values)
def value(self, population=False): def value(self, population=False):
""" Return a tuple of (standard_deviation, variance). """Return a tuple of (standard_deviation, variance).
:param population: True if values represents entire population, :param population: True if values represents entire population,
False if values is a sample. Default: False False if values is a sample. Default: False
""" """
return _stddev(self._values, population) return _stddev(self._values, population)
class Histogram(StatsFilter):
""" Generate a basic histogram of the specified field. The value() method
returns a dict of value to occurance count mappings. The __str__ method
generates a basic and limited histogram useful for printing to the
command line. The label_length attribute determines the padding and
cut-off of the basic histogram labels.
**This filters maintains a dict of unique field values in memory.** class Histogram(StatsFilter):
"""Generate a basic histogram of the specified field. The value() method
returns a dict of value to occurance count mappings. The __str__ method
generates a basic and limited histogram useful for printing to the
command line. The label_length attribute determines the padding and
cut-off of the basic histogram labels.
**This filters maintains a dict of unique field values in memory.**
""" """
label_length = 6 label_length = 6
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
super(Histogram, self).__init__(field, **kwargs) super(Histogram, self).__init__(field, **kwargs)
if hasattr(collections, 'Counter'): if hasattr(collections, "Counter"):
self._counter = collections.Counter() self._counter = collections.Counter()
else: else:
self._counter = FallbackCounter() self._counter = FallbackCounter()

View File

@ -11,5 +11,5 @@ emitter_suite = unittest.TestLoader().loadTestsFromTestCase(EmitterTestCase)
recipe_suite = unittest.TestLoader().loadTestsFromTestCase(RecipeTestCase) recipe_suite = unittest.TestLoader().loadTestsFromTestCase(RecipeTestCase)
stats_suite = unittest.TestLoader().loadTestsFromTestCase(StatsTestCase) stats_suite = unittest.TestLoader().loadTestsFromTestCase(StatsTestCase)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -5,61 +5,90 @@ import os
import unittest import unittest
from saucebrush.emitters import ( from saucebrush.emitters import (
DebugEmitter, CSVEmitter, CountEmitter, SqliteEmitter, SqlDumpEmitter) DebugEmitter,
CSVEmitter,
CountEmitter,
SqliteEmitter,
SqlDumpEmitter,
)
class EmitterTestCase(unittest.TestCase): class EmitterTestCase(unittest.TestCase):
def test_debug_emitter(self): def test_debug_emitter(self):
with closing(StringIO()) as output: with closing(StringIO()) as output:
de = DebugEmitter(output) de = DebugEmitter(output)
list(de.attach([1,2,3])) list(de.attach([1, 2, 3]))
self.assertEqual(output.getvalue(), '1\n2\n3\n') self.assertEqual(output.getvalue(), "1\n2\n3\n")
def test_count_emitter(self): def test_count_emitter(self):
# values for test # values for test
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22] values = [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
]
with closing(StringIO()) as output: with closing(StringIO()) as output:
# test without of parameter # test without of parameter
ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n") ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n")
list(ce.attach(values)) list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 records\n20 records\n') self.assertEqual(output.getvalue(), "10 records\n20 records\n")
ce.done() ce.done()
self.assertEqual(output.getvalue(), '10 records\n20 records\n22 records\n') self.assertEqual(output.getvalue(), "10 records\n20 records\n22 records\n")
with closing(StringIO()) as output: with closing(StringIO()) as output:
# test with of parameter # test with of parameter
ce = CountEmitter(every=10, outfile=output, of=len(values)) ce = CountEmitter(every=10, outfile=output, of=len(values))
list(ce.attach(values)) list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n') self.assertEqual(output.getvalue(), "10 of 22\n20 of 22\n")
ce.done() ce.done()
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n') self.assertEqual(output.getvalue(), "10 of 22\n20 of 22\n22 of 22\n")
def test_csv_emitter(self): def test_csv_emitter(self):
try: try:
import cStringIO # if Python 2.x then use old cStringIO import cStringIO # if Python 2.x then use old cStringIO
io = cStringIO.StringIO() io = cStringIO.StringIO()
except: except:
io = StringIO() # if Python 3.x then use StringIO io = StringIO() # if Python 3.x then use StringIO
with closing(io) as output: with closing(io) as output:
ce = CSVEmitter(output, ('x','y','z')) ce = CSVEmitter(output, ("x", "y", "z"))
list(ce.attach([{'x':1, 'y':2, 'z':3}, {'x':5, 'y':5, 'z':5}])) list(ce.attach([{"x": 1, "y": 2, "z": 3}, {"x": 5, "y": 5, "z": 5}]))
self.assertEqual(output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n') self.assertEqual(output.getvalue(), "x,y,z\r\n1,2,3\r\n5,5,5\r\n")
def test_sqlite_emitter(self): def test_sqlite_emitter(self):
import sqlite3, tempfile import sqlite3, tempfile
with closing(tempfile.NamedTemporaryFile(suffix='.db')) as f: with closing(tempfile.NamedTemporaryFile(suffix=".db")) as f:
db_path = f.name db_path = f.name
sle = SqliteEmitter(db_path, 'testtable', fieldnames=('a','b','c')) sle = SqliteEmitter(db_path, "testtable", fieldnames=("a", "b", "c"))
list(sle.attach([{'a': '1', 'b': '2', 'c': '3'}])) list(sle.attach([{"a": "1", "b": "2", "c": "3"}]))
sle.done() sle.done()
with closing(sqlite3.connect(db_path)) as conn: with closing(sqlite3.connect(db_path)) as conn:
@ -69,18 +98,20 @@ class EmitterTestCase(unittest.TestCase):
os.unlink(db_path) os.unlink(db_path)
self.assertEqual(results, [('1', '2', '3')]) self.assertEqual(results, [("1", "2", "3")])
def test_sql_dump_emitter(self): def test_sql_dump_emitter(self):
with closing(StringIO()) as bffr: with closing(StringIO()) as bffr:
sde = SqlDumpEmitter(bffr, 'testtable', ('a', 'b')) sde = SqlDumpEmitter(bffr, "testtable", ("a", "b"))
list(sde.attach([{'a': 1, 'b': '2'}])) list(sde.attach([{"a": 1, "b": "2"}]))
sde.done() sde.done()
self.assertEqual(bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n") self.assertEqual(
bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n"
)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,38 +1,56 @@
import unittest import unittest
import operator import operator
import types import types
from saucebrush.filters import (Filter, YieldFilter, FieldFilter, from saucebrush.filters import (
SubrecordFilter, ConditionalPathFilter, Filter,
ConditionalFilter, FieldModifier, FieldKeeper, YieldFilter,
FieldRemover, FieldMerger, FieldAdder, FieldFilter,
FieldCopier, FieldRenamer, Unique) SubrecordFilter,
ConditionalPathFilter,
ConditionalFilter,
FieldModifier,
FieldKeeper,
FieldRemover,
FieldMerger,
FieldAdder,
FieldCopier,
FieldRenamer,
Unique,
)
class DummyRecipe(object): class DummyRecipe(object):
rejected_record = None rejected_record = None
rejected_msg = None rejected_msg = None
def reject_record(self, record, msg): def reject_record(self, record, msg):
self.rejected_record = record self.rejected_record = record
self.rejected_msg = msg self.rejected_msg = msg
class Doubler(Filter): class Doubler(Filter):
def process_record(self, record): def process_record(self, record):
return record*2 return record * 2
class OddRemover(Filter): class OddRemover(Filter):
def process_record(self, record): def process_record(self, record):
if record % 2 == 0: if record % 2 == 0:
return record return record
else: else:
return None # explicitly return None return None # explicitly return None
class ListFlattener(YieldFilter): class ListFlattener(YieldFilter):
def process_record(self, record): def process_record(self, record):
for item in record: for item in record:
yield item yield item
class FieldDoubler(FieldFilter): class FieldDoubler(FieldFilter):
def process_field(self, item): def process_field(self, item):
return item*2 return item * 2
class NonModifyingFieldDoubler(Filter): class NonModifyingFieldDoubler(Filter):
def __init__(self, key): def __init__(self, key):
@ -43,17 +61,20 @@ class NonModifyingFieldDoubler(Filter):
record[self.key] *= 2 record[self.key] *= 2
return record return record
class ConditionalOddRemover(ConditionalFilter): class ConditionalOddRemover(ConditionalFilter):
def test_record(self, record): def test_record(self, record):
# return True for even values # return True for even values
return record % 2 == 0 return record % 2 == 0
class FilterTestCase(unittest.TestCase):
class FilterTestCase(unittest.TestCase):
def _simple_data(self): def _simple_data(self):
return [{'a':1, 'b':2, 'c':3}, return [
{'a':5, 'b':5, 'c':5}, {"a": 1, "b": 2, "c": 3},
{'a':1, 'b':10, 'c':100}] {"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
def assert_filter_result(self, filter_obj, expected_data): def assert_filter_result(self, filter_obj, expected_data):
result = filter_obj.attach(self._simple_data()) result = filter_obj.attach(self._simple_data())
@ -62,45 +83,47 @@ class FilterTestCase(unittest.TestCase):
def test_reject_record(self): def test_reject_record(self):
recipe = DummyRecipe() recipe = DummyRecipe()
f = Doubler() f = Doubler()
result = f.attach([1,2,3], recipe=recipe) result = f.attach([1, 2, 3], recipe=recipe)
# next has to be called for attach to take effect # next has to be called for attach to take effect
next(result) next(result)
f.reject_record('bad', 'this one was bad') f.reject_record("bad", "this one was bad")
# ensure that the rejection propagated to the recipe # ensure that the rejection propagated to the recipe
self.assertEqual('bad', recipe.rejected_record) self.assertEqual("bad", recipe.rejected_record)
self.assertEqual('this one was bad', recipe.rejected_msg) self.assertEqual("this one was bad", recipe.rejected_msg)
def test_simple_filter(self): def test_simple_filter(self):
df = Doubler() df = Doubler()
result = df.attach([1,2,3]) result = df.attach([1, 2, 3])
# ensure we got a generator that yields 2,4,6 # ensure we got a generator that yields 2,4,6
self.assertEqual(type(result), types.GeneratorType) self.assertEqual(type(result), types.GeneratorType)
self.assertEqual(list(result), [2,4,6]) self.assertEqual(list(result), [2, 4, 6])
def test_simple_filter_return_none(self): def test_simple_filter_return_none(self):
cf = OddRemover() cf = OddRemover()
result = cf.attach(range(10)) result = cf.attach(range(10))
# ensure only even numbers remain # ensure only even numbers remain
self.assertEqual(list(result), [0,2,4,6,8]) self.assertEqual(list(result), [0, 2, 4, 6, 8])
def test_simple_yield_filter(self): def test_simple_yield_filter(self):
lf = ListFlattener() lf = ListFlattener()
result = lf.attach([[1],[2,3],[4,5,6]]) result = lf.attach([[1], [2, 3], [4, 5, 6]])
# ensure we got a generator that yields 1,2,3,4,5,6 # ensure we got a generator that yields 1,2,3,4,5,6
self.assertEqual(type(result), types.GeneratorType) self.assertEqual(type(result), types.GeneratorType)
self.assertEqual(list(result), [1,2,3,4,5,6]) self.assertEqual(list(result), [1, 2, 3, 4, 5, 6])
def test_simple_field_filter(self): def test_simple_field_filter(self):
ff = FieldDoubler(['a', 'c']) ff = FieldDoubler(["a", "c"])
# check against expected data # check against expected data
expected_data = [{'a':2, 'b':2, 'c':6}, expected_data = [
{'a':10, 'b':5, 'c':10}, {"a": 2, "b": 2, "c": 6},
{'a':2, 'b':10, 'c':200}] {"a": 10, "b": 5, "c": 10},
{"a": 2, "b": 10, "c": 200},
]
self.assert_filter_result(ff, expected_data) self.assert_filter_result(ff, expected_data)
def test_conditional_filter(self): def test_conditional_filter(self):
@ -108,84 +131,93 @@ class FilterTestCase(unittest.TestCase):
result = cf.attach(range(10)) result = cf.attach(range(10))
# ensure only even numbers remain # ensure only even numbers remain
self.assertEqual(list(result), [0,2,4,6,8]) self.assertEqual(list(result), [0, 2, 4, 6, 8])
### Tests for Subrecord ### Tests for Subrecord
def test_subrecord_filter_list(self): def test_subrecord_filter_list(self):
data = [{'a': [{'b': 2}, {'b': 4}]}, data = [
{'a': [{'b': 5}]}, {"a": [{"b": 2}, {"b": 4}]},
{'a': [{'b': 8}, {'b':2}, {'b':1}]}] {"a": [{"b": 5}]},
{"a": [{"b": 8}, {"b": 2}, {"b": 1}]},
]
expected = [{'a': [{'b': 4}, {'b': 8}]}, expected = [
{'a': [{'b': 10}]}, {"a": [{"b": 4}, {"b": 8}]},
{'a': [{'b': 16}, {'b':4}, {'b':2}]}] {"a": [{"b": 10}]},
{"a": [{"b": 16}, {"b": 4}, {"b": 2}]},
]
sf = SubrecordFilter('a', NonModifyingFieldDoubler('b')) sf = SubrecordFilter("a", NonModifyingFieldDoubler("b"))
result = sf.attach(data) result = sf.attach(data)
self.assertEqual(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_deep(self): def test_subrecord_filter_deep(self):
data = [{'a': {'d':[{'b': 2}, {'b': 4}]}}, data = [
{'a': {'d':[{'b': 5}]}}, {"a": {"d": [{"b": 2}, {"b": 4}]}},
{'a': {'d':[{'b': 8}, {'b':2}, {'b':1}]}}] {"a": {"d": [{"b": 5}]}},
{"a": {"d": [{"b": 8}, {"b": 2}, {"b": 1}]}},
]
expected = [{'a': {'d':[{'b': 4}, {'b': 8}]}}, expected = [
{'a': {'d':[{'b': 10}]}}, {"a": {"d": [{"b": 4}, {"b": 8}]}},
{'a': {'d':[{'b': 16}, {'b':4}, {'b':2}]}}] {"a": {"d": [{"b": 10}]}},
{"a": {"d": [{"b": 16}, {"b": 4}, {"b": 2}]}},
]
sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b')) sf = SubrecordFilter("a.d", NonModifyingFieldDoubler("b"))
result = sf.attach(data) result = sf.attach(data)
self.assertEqual(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_nonlist(self): def test_subrecord_filter_nonlist(self):
data = [ data = [
{'a':{'b':{'c':1}}}, {"a": {"b": {"c": 1}}},
{'a':{'b':{'c':2}}}, {"a": {"b": {"c": 2}}},
{'a':{'b':{'c':3}}}, {"a": {"b": {"c": 3}}},
] ]
expected = [ expected = [
{'a':{'b':{'c':2}}}, {"a": {"b": {"c": 2}}},
{'a':{'b':{'c':4}}}, {"a": {"b": {"c": 4}}},
{'a':{'b':{'c':6}}}, {"a": {"b": {"c": 6}}},
] ]
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter("a.b", NonModifyingFieldDoubler("c"))
result = sf.attach(data) result = sf.attach(data)
self.assertEqual(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_list_in_path(self): def test_subrecord_filter_list_in_path(self):
data = [ data = [
{'a': [{'b': {'c': 5}}, {'b': {'c': 6}}]}, {"a": [{"b": {"c": 5}}, {"b": {"c": 6}}]},
{'a': [{'b': {'c': 1}}, {'b': {'c': 2}}, {'b': {'c': 3}}]}, {"a": [{"b": {"c": 1}}, {"b": {"c": 2}}, {"b": {"c": 3}}]},
{'a': [{'b': {'c': 2}} ]} {"a": [{"b": {"c": 2}}]},
] ]
expected = [ expected = [
{'a': [{'b': {'c': 10}}, {'b': {'c': 12}}]}, {"a": [{"b": {"c": 10}}, {"b": {"c": 12}}]},
{'a': [{'b': {'c': 2}}, {'b': {'c': 4}}, {'b': {'c': 6}}]}, {"a": [{"b": {"c": 2}}, {"b": {"c": 4}}, {"b": {"c": 6}}]},
{'a': [{'b': {'c': 4}} ]} {"a": [{"b": {"c": 4}}]},
] ]
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter("a.b", NonModifyingFieldDoubler("c"))
result = sf.attach(data) result = sf.attach(data)
self.assertEqual(list(result), expected) self.assertEqual(list(result), expected)
def test_conditional_path(self): def test_conditional_path(self):
predicate = lambda r: r['a'] == 1 predicate = lambda r: r["a"] == 1
# double b if a == 1, otherwise double c # double b if a == 1, otherwise double c
cpf = ConditionalPathFilter(predicate, FieldDoubler('b'), cpf = ConditionalPathFilter(predicate, FieldDoubler("b"), FieldDoubler("c"))
FieldDoubler('c')) expected_data = [
expected_data = [{'a':1, 'b':4, 'c':3}, {"a": 1, "b": 4, "c": 3},
{'a':5, 'b':5, 'c':10}, {"a": 5, "b": 5, "c": 10},
{'a':1, 'b':20, 'c':100}] {"a": 1, "b": 20, "c": 100},
]
self.assert_filter_result(cpf, expected_data) self.assert_filter_result(cpf, expected_data)
@ -193,112 +225,132 @@ class FilterTestCase(unittest.TestCase):
def test_field_modifier(self): def test_field_modifier(self):
# another version of FieldDoubler # another version of FieldDoubler
fm = FieldModifier(['a', 'c'], lambda x: x*2) fm = FieldModifier(["a", "c"], lambda x: x * 2)
# check against expected data # check against expected data
expected_data = [{'a':2, 'b':2, 'c':6}, expected_data = [
{'a':10, 'b':5, 'c':10}, {"a": 2, "b": 2, "c": 6},
{'a':2, 'b':10, 'c':200}] {"a": 10, "b": 5, "c": 10},
{"a": 2, "b": 10, "c": 200},
]
self.assert_filter_result(fm, expected_data) self.assert_filter_result(fm, expected_data)
def test_field_keeper(self): def test_field_keeper(self):
fk = FieldKeeper(['c']) fk = FieldKeeper(["c"])
# check against expected results # check against expected results
expected_data = [{'c':3}, {'c':5}, {'c':100}] expected_data = [{"c": 3}, {"c": 5}, {"c": 100}]
self.assert_filter_result(fk, expected_data) self.assert_filter_result(fk, expected_data)
def test_field_remover(self): def test_field_remover(self):
fr = FieldRemover(['a', 'b']) fr = FieldRemover(["a", "b"])
# check against expected results # check against expected results
expected_data = [{'c':3}, {'c':5}, {'c':100}] expected_data = [{"c": 3}, {"c": 5}, {"c": 100}]
self.assert_filter_result(fr, expected_data) self.assert_filter_result(fr, expected_data)
def test_field_merger(self): def test_field_merger(self):
fm = FieldMerger({'sum':('a','b','c')}, lambda x,y,z: x+y+z) fm = FieldMerger({"sum": ("a", "b", "c")}, lambda x, y, z: x + y + z)
# check against expected results # check against expected results
expected_data = [{'sum':6}, {'sum':15}, {'sum':111}] expected_data = [{"sum": 6}, {"sum": 15}, {"sum": 111}]
self.assert_filter_result(fm, expected_data) self.assert_filter_result(fm, expected_data)
def test_field_merger_keep_fields(self): def test_field_merger_keep_fields(self):
fm = FieldMerger({'sum':('a','b','c')}, lambda x,y,z: x+y+z, fm = FieldMerger(
keep_fields=True) {"sum": ("a", "b", "c")}, lambda x, y, z: x + y + z, keep_fields=True
)
# check against expected results # check against expected results
expected_data = [{'a':1, 'b':2, 'c':3, 'sum':6}, expected_data = [
{'a':5, 'b':5, 'c':5, 'sum':15}, {"a": 1, "b": 2, "c": 3, "sum": 6},
{'a':1, 'b':10, 'c':100, 'sum': 111}] {"a": 5, "b": 5, "c": 5, "sum": 15},
{"a": 1, "b": 10, "c": 100, "sum": 111},
]
self.assert_filter_result(fm, expected_data) self.assert_filter_result(fm, expected_data)
def test_field_adder_scalar(self): def test_field_adder_scalar(self):
fa = FieldAdder('x', 7) fa = FieldAdder("x", 7)
expected_data = [{'a':1, 'b':2, 'c':3, 'x':7}, expected_data = [
{'a':5, 'b':5, 'c':5, 'x':7}, {"a": 1, "b": 2, "c": 3, "x": 7},
{'a':1, 'b':10, 'c':100, 'x': 7}] {"a": 5, "b": 5, "c": 5, "x": 7},
{"a": 1, "b": 10, "c": 100, "x": 7},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_adder_callable(self): def test_field_adder_callable(self):
fa = FieldAdder('x', lambda: 7) fa = FieldAdder("x", lambda: 7)
expected_data = [{'a':1, 'b':2, 'c':3, 'x':7}, expected_data = [
{'a':5, 'b':5, 'c':5, 'x':7}, {"a": 1, "b": 2, "c": 3, "x": 7},
{'a':1, 'b':10, 'c':100, 'x': 7}] {"a": 5, "b": 5, "c": 5, "x": 7},
{"a": 1, "b": 10, "c": 100, "x": 7},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_adder_iterable(self): def test_field_adder_iterable(self):
fa = FieldAdder('x', [1,2,3]) fa = FieldAdder("x", [1, 2, 3])
expected_data = [{'a':1, 'b':2, 'c':3, 'x':1}, expected_data = [
{'a':5, 'b':5, 'c':5, 'x':2}, {"a": 1, "b": 2, "c": 3, "x": 1},
{'a':1, 'b':10, 'c':100, 'x': 3}] {"a": 5, "b": 5, "c": 5, "x": 2},
{"a": 1, "b": 10, "c": 100, "x": 3},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_adder_replace(self): def test_field_adder_replace(self):
fa = FieldAdder('b', lambda: 7) fa = FieldAdder("b", lambda: 7)
expected_data = [{'a':1, 'b':7, 'c':3}, expected_data = [
{'a':5, 'b':7, 'c':5}, {"a": 1, "b": 7, "c": 3},
{'a':1, 'b':7, 'c':100}] {"a": 5, "b": 7, "c": 5},
{"a": 1, "b": 7, "c": 100},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_adder_no_replace(self): def test_field_adder_no_replace(self):
fa = FieldAdder('b', lambda: 7, replace=False) fa = FieldAdder("b", lambda: 7, replace=False)
expected_data = [{'a':1, 'b':2, 'c':3}, expected_data = [
{'a':5, 'b':5, 'c':5}, {"a": 1, "b": 2, "c": 3},
{'a':1, 'b':10, 'c':100}] {"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_copier(self): def test_field_copier(self):
fc = FieldCopier({'a2':'a', 'b2':'b'}) fc = FieldCopier({"a2": "a", "b2": "b"})
expected_data = [{'a':1, 'b':2, 'c':3, 'a2':1, 'b2':2}, expected_data = [
{'a':5, 'b':5, 'c':5, 'a2':5, 'b2':5}, {"a": 1, "b": 2, "c": 3, "a2": 1, "b2": 2},
{'a':1, 'b':10, 'c':100, 'a2': 1, 'b2': 10}] {"a": 5, "b": 5, "c": 5, "a2": 5, "b2": 5},
{"a": 1, "b": 10, "c": 100, "a2": 1, "b2": 10},
]
self.assert_filter_result(fc, expected_data) self.assert_filter_result(fc, expected_data)
def test_field_renamer(self): def test_field_renamer(self):
fr = FieldRenamer({'x':'a', 'y':'b'}) fr = FieldRenamer({"x": "a", "y": "b"})
expected_data = [{'x':1, 'y':2, 'c':3}, expected_data = [
{'x':5, 'y':5, 'c':5}, {"x": 1, "y": 2, "c": 3},
{'x':1, 'y':10, 'c':100}] {"x": 5, "y": 5, "c": 5},
{"x": 1, "y": 10, "c": 100},
]
self.assert_filter_result(fr, expected_data) self.assert_filter_result(fr, expected_data)
# TODO: splitter & flattner tests? # TODO: splitter & flattner tests?
def test_unique_filter(self): def test_unique_filter(self):
u = Unique() u = Unique()
in_data = [{'a': 77}, {'a':33}, {'a': 77}] in_data = [{"a": 77}, {"a": 33}, {"a": 77}]
expected_data = [{'a': 77}, {'a':33}] expected_data = [{"a": 77}, {"a": 33}]
result = u.attach(in_data) result = u.attach(in_data)
self.assertEqual(list(result), expected_data) self.assertEqual(list(result), expected_data)
# TODO: unicode & string filter tests # TODO: unicode & string filter tests
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -22,11 +22,11 @@ class RecipeTestCase(unittest.TestCase):
def test_error_stream(self): def test_error_stream(self):
saver = Saver() saver = Saver()
recipe = Recipe(Raiser(), error_stream=saver) recipe = Recipe(Raiser(), error_stream=saver)
recipe.run([{'a': 1}, {'b': 2}]) recipe.run([{"a": 1}, {"b": 2}])
recipe.done() recipe.done()
self.assertEqual(saver.saved[0]['record'], {'a': 1}) self.assertEqual(saver.saved[0]["record"], {"a": 1})
self.assertEqual(saver.saved[1]['record'], {'b': 2}) self.assertEqual(saver.saved[1]["record"], {"b": 2})
# Must pass either a Recipe, a Filter or an iterable of Filters # Must pass either a Recipe, a Filter or an iterable of Filters
# as the error_stream argument # as the error_stream argument
@ -49,5 +49,5 @@ class RecipeTestCase(unittest.TestCase):
self.assertEqual(saver.saved, [1]) self.assertEqual(saver.saved, [1])
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -3,46 +3,56 @@ from io import BytesIO, StringIO
import unittest import unittest
from saucebrush.sources import ( from saucebrush.sources import (
CSVSource, FixedWidthFileSource, HtmlTableSource, JSONSource) CSVSource,
FixedWidthFileSource,
HtmlTableSource,
JSONSource,
)
class SourceTestCase(unittest.TestCase): class SourceTestCase(unittest.TestCase):
def _get_csv(self): def _get_csv(self):
data = '''a,b,c data = """a,b,c
1,2,3 1,2,3
5,5,5 5,5,5
1,10,100''' 1,10,100"""
return StringIO(data) return StringIO(data)
def test_csv_source_basic(self): def test_csv_source_basic(self):
source = CSVSource(self._get_csv()) source = CSVSource(self._get_csv())
expected_data = [{'a':'1', 'b':'2', 'c':'3'}, expected_data = [
{'a':'5', 'b':'5', 'c':'5'}, {"a": "1", "b": "2", "c": "3"},
{'a':'1', 'b':'10', 'c':'100'}] {"a": "5", "b": "5", "c": "5"},
{"a": "1", "b": "10", "c": "100"},
]
self.assertEqual(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_fieldnames(self): def test_csv_source_fieldnames(self):
source = CSVSource(self._get_csv(), ['x','y','z']) source = CSVSource(self._get_csv(), ["x", "y", "z"])
expected_data = [{'x':'a', 'y':'b', 'z':'c'}, expected_data = [
{'x':'1', 'y':'2', 'z':'3'}, {"x": "a", "y": "b", "z": "c"},
{'x':'5', 'y':'5', 'z':'5'}, {"x": "1", "y": "2", "z": "3"},
{'x':'1', 'y':'10', 'z':'100'}] {"x": "5", "y": "5", "z": "5"},
{"x": "1", "y": "10", "z": "100"},
]
self.assertEqual(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_skiprows(self): def test_csv_source_skiprows(self):
source = CSVSource(self._get_csv(), skiprows=1) source = CSVSource(self._get_csv(), skiprows=1)
expected_data = [{'a':'5', 'b':'5', 'c':'5'}, expected_data = [
{'a':'1', 'b':'10', 'c':'100'}] {"a": "5", "b": "5", "c": "5"},
{"a": "1", "b": "10", "c": "100"},
]
self.assertEqual(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_fixed_width_source(self): def test_fixed_width_source(self):
data = StringIO('JamesNovember 3 1986\nTim September151999') data = StringIO("JamesNovember 3 1986\nTim September151999")
fields = (('name',5), ('month',9), ('day',2), ('year',4)) fields = (("name", 5), ("month", 9), ("day", 2), ("year", 4))
source = FixedWidthFileSource(data, fields) source = FixedWidthFileSource(data, fields)
expected_data = [{'name':'James', 'month':'November', 'day':'3', expected_data = [
'year':'1986'}, {"name": "James", "month": "November", "day": "3", "year": "1986"},
{'name':'Tim', 'month':'September', 'day':'15', {"name": "Tim", "month": "September", "day": "15", "year": "1999"},
'year':'1999'}] ]
self.assertEqual(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_json_source(self): def test_json_source(self):
@ -50,11 +60,12 @@ class SourceTestCase(unittest.TestCase):
content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""") content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""")
js = JSONSource(content) js = JSONSource(content)
self.assertEqual(list(js), [{'a': 1, 'b': '2', 'c': 3}]) self.assertEqual(list(js), [{"a": 1, "b": "2", "c": 3}])
def test_html_table_source(self): def test_html_table_source(self):
content = StringIO(""" content = StringIO(
"""
<html> <html>
<table id="thetable"> <table id="thetable">
<tr> <tr>
@ -69,19 +80,21 @@ class SourceTestCase(unittest.TestCase):
</tr> </tr>
</table> </table>
</html> </html>
""") """
)
try: try:
import lxml import lxml
hts = HtmlTableSource(content, 'thetable') hts = HtmlTableSource(content, "thetable")
self.assertEqual(list(hts), [{'a': '1', 'b': '2', 'c': '3'}]) self.assertEqual(list(hts), [{"a": "1", "b": "2", "c": "3"}])
except ImportError: except ImportError:
# Python 2.6 doesn't have skipTest. We'll just suffer without it. # Python 2.6 doesn't have skipTest. We'll just suffer without it.
if hasattr(self, 'skipTest'): if hasattr(self, "skipTest"):
self.skipTest("lxml is not installed") self.skipTest("lxml is not installed")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,52 +1,55 @@
import unittest import unittest
from saucebrush.stats import Sum, Average, Median, MinMax, StandardDeviation, Histogram from saucebrush.stats import Sum, Average, Median, MinMax, StandardDeviation, Histogram
class StatsTestCase(unittest.TestCase):
class StatsTestCase(unittest.TestCase):
def _simple_data(self): def _simple_data(self):
return [{'a':1, 'b':2, 'c':3}, return [
{'a':5, 'b':5, 'c':5}, {"a": 1, "b": 2, "c": 3},
{'a':1, 'b':10, 'c':100}] {"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
def test_sum(self): def test_sum(self):
fltr = Sum('b') fltr = Sum("b")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 17) self.assertEqual(fltr.value(), 17)
def test_average(self): def test_average(self):
fltr = Average('c') fltr = Average("c")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 36.0) self.assertEqual(fltr.value(), 36.0)
def test_median(self): def test_median(self):
# odd number of values # odd number of values
fltr = Median('a') fltr = Median("a")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 1) self.assertEqual(fltr.value(), 1)
# even number of values # even number of values
fltr = Median('a') fltr = Median("a")
list(fltr.attach(self._simple_data()[:2])) list(fltr.attach(self._simple_data()[:2]))
self.assertEqual(fltr.value(), 3) self.assertEqual(fltr.value(), 3)
def test_minmax(self): def test_minmax(self):
fltr = MinMax('b') fltr = MinMax("b")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), (2, 10)) self.assertEqual(fltr.value(), (2, 10))
def test_standard_deviation(self): def test_standard_deviation(self):
fltr = StandardDeviation('c') fltr = StandardDeviation("c")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.average(), 36.0) self.assertEqual(fltr.average(), 36.0)
self.assertEqual(fltr.median(), 5) self.assertEqual(fltr.median(), 5)
self.assertEqual(fltr.value(), (55.4346462061408, 3073.0)) self.assertEqual(fltr.value(), (55.4346462061408, 3073.0))
self.assertEqual(fltr.value(True), (45.2621990922521, 2048.6666666666665)) self.assertEqual(fltr.value(True), (45.2621990922521, 2048.6666666666665))
def test_histogram(self): def test_histogram(self):
fltr = Histogram('a') fltr = Histogram("a")
fltr.label_length = 1 fltr.label_length = 1
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(str(fltr), "\n1 **\n5 *\n") self.assertEqual(str(fltr), "\n1 **\n5 *\n")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -4,42 +4,48 @@ import os
try: try:
from urllib.request import urlopen # attemp py3 first from urllib.request import urlopen # attemp py3 first
except ImportError: except ImportError:
from urllib2 import urlopen # fallback to py2 from urllib2 import urlopen # fallback to py2
""" """
General utilities used within saucebrush that may be useful elsewhere. General utilities used within saucebrush that may be useful elsewhere.
""" """
def get_django_model(dj_settings, app_label, model_name): def get_django_model(dj_settings, app_label, model_name):
""" """
Get a django model given a settings file, app label, and model name. Get a django model given a settings file, app label, and model name.
""" """
from django.conf import settings from django.conf import settings
if not settings.configured: if not settings.configured:
settings.configure(DATABASE_ENGINE=dj_settings.DATABASE_ENGINE, settings.configure(
DATABASE_NAME=dj_settings.DATABASE_NAME, DATABASE_ENGINE=dj_settings.DATABASE_ENGINE,
DATABASE_USER=dj_settings.DATABASE_USER, DATABASE_NAME=dj_settings.DATABASE_NAME,
DATABASE_PASSWORD=dj_settings.DATABASE_PASSWORD, DATABASE_USER=dj_settings.DATABASE_USER,
DATABASE_HOST=dj_settings.DATABASE_HOST, DATABASE_PASSWORD=dj_settings.DATABASE_PASSWORD,
INSTALLED_APPS=dj_settings.INSTALLED_APPS) DATABASE_HOST=dj_settings.DATABASE_HOST,
INSTALLED_APPS=dj_settings.INSTALLED_APPS,
)
from django.db.models import get_model from django.db.models import get_model
return get_model(app_label, model_name) return get_model(app_label, model_name)
def flatten(item, prefix='', separator='_', keys=None):
"""
Flatten nested dictionary into one with its keys concatenated together.
>>> flatten({'a':1, 'b':{'c':2}, 'd':[{'e':{'r':7}}, {'e':5}], def flatten(item, prefix="", separator="_", keys=None):
'f':{'g':{'h':6}}}) """
{'a': 1, 'b_c': 2, 'd': [{'e_r': 7}, {'e': 5}], 'f_g_h': 6} Flatten nested dictionary into one with its keys concatenated together.
>>> flatten({'a':1, 'b':{'c':2}, 'd':[{'e':{'r':7}}, {'e':5}],
'f':{'g':{'h':6}}})
{'a': 1, 'b_c': 2, 'd': [{'e_r': 7}, {'e': 5}], 'f_g_h': 6}
""" """
# update dictionaries recursively # update dictionaries recursively
if isinstance(item, dict): if isinstance(item, dict):
# don't prepend a leading _ # don't prepend a leading _
if prefix != '': if prefix != "":
prefix += separator prefix += separator
retval = {} retval = {}
for key, value in item.items(): for key, value in item.items():
@ -48,24 +54,27 @@ def flatten(item, prefix='', separator='_', keys=None):
else: else:
retval[prefix + key] = value retval[prefix + key] = value
return retval return retval
#elif isinstance(item, (tuple, list)): # elif isinstance(item, (tuple, list)):
# return {prefix: [flatten(i, prefix, separator, keys) for i in item]} # return {prefix: [flatten(i, prefix, separator, keys) for i in item]}
else: else:
return {prefix: item} return {prefix: item}
def str_or_list(obj): def str_or_list(obj):
if isinstance(obj, str): if isinstance(obj, str):
return [obj] return [obj]
else: else:
return obj return obj
# #
# utility classes # utility classes
# #
class FallbackCounter(collections.defaultdict): class FallbackCounter(collections.defaultdict):
""" Python 2.6 does not have collections.Counter. """Python 2.6 does not have collections.Counter.
This is class that does the basics of what we need from Counter. This is class that does the basics of what we need from Counter.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -73,20 +82,20 @@ class FallbackCounter(collections.defaultdict):
def most_common(n=None): def most_common(n=None):
l = sorted(self.items(), l = sorted(self.items(), cmp=lambda x, y: cmp(x[1], y[1]))
cmp=lambda x,y: cmp(x[1], y[1]))
if n is not None: if n is not None:
l = l[:n] l = l[:n]
return l return l
class Files(object):
""" Iterate over multiple files as a single file. Pass the paths of the
files as arguments to the class constructor:
for line in Files('/path/to/file/a', '/path/to/file/b'): class Files(object):
pass """Iterate over multiple files as a single file. Pass the paths of the
files as arguments to the class constructor:
for line in Files('/path/to/file/a', '/path/to/file/b'):
pass
""" """
def __init__(self, *args): def __init__(self, *args):
@ -111,10 +120,11 @@ class Files(object):
yield line yield line
f.close() f.close()
class RemoteFile(object):
""" Stream data from a remote file.
:param url: URL to remote file class RemoteFile(object):
"""Stream data from a remote file.
:param url: URL to remote file
""" """
def __init__(self, url): def __init__(self, url):
@ -126,21 +136,24 @@ class RemoteFile(object):
yield line.rstrip() yield line.rstrip()
resp.close() resp.close()
class ZippedFiles(object): class ZippedFiles(object):
""" unpack a zipped collection of files on init. """unpack a zipped collection of files on init.
Takes a string with file location or zipfile.ZipFile object Takes a string with file location or zipfile.ZipFile object
Best to wrap this in a Files() object, if the goal is to have a Best to wrap this in a Files() object, if the goal is to have a
linereader, as this only returns filelike objects. linereader, as this only returns filelike objects.
if using a ZipFile object, make sure to set mode to 'a' or 'w' in order if using a ZipFile object, make sure to set mode to 'a' or 'w' in order
to use the add() function. to use the add() function.
""" """
def __init__(self, zippedfile): def __init__(self, zippedfile):
import zipfile import zipfile
if type(zippedfile) == str: if type(zippedfile) == str:
self._zipfile = zipfile.ZipFile(zippedfile,'a') self._zipfile = zipfile.ZipFile(zippedfile, "a")
else: else:
self._zipfile = zippedfile self._zipfile = zippedfile
self.paths = self._zipfile.namelist() self.paths = self._zipfile.namelist()
@ -152,10 +165,10 @@ class ZippedFiles(object):
def add(self, path, dirname=None, arcname=None): def add(self, path, dirname=None, arcname=None):
path_base = os.path.basename(path) path_base = os.path.basename(path)
if dirname: if dirname:
arcname = os.path.join(dirname,path_base) arcname = os.path.join(dirname, path_base)
if not arcname: if not arcname:
arcname = path_base arcname = path_base
self._zipfile.write(path,arcname) self._zipfile.write(path, arcname)
self.paths.append(path) self.paths.append(path)
def filereader(self): def filereader(self):