This commit is contained in:
James Turk 2022-11-10 21:26:09 -06:00
parent 806b8873ec
commit a7e3fc63b3
12 changed files with 751 additions and 554 deletions

View File

@ -13,39 +13,39 @@ class OvercookedError(Exception):
""" """
Exception for trying to operate on a Recipe that has been finished. Exception for trying to operate on a Recipe that has been finished.
""" """
pass pass
class Recipe(object): class Recipe(object):
def __init__(self, *filter_args, **kwargs): def __init__(self, *filter_args, **kwargs):
self.finished = False self.finished = False
self.filters = [] self.filters = []
for filter in filter_args: for filter in filter_args:
if hasattr(filter, 'filters'): if hasattr(filter, "filters"):
self.filters.extend(filter.filters) self.filters.extend(filter.filters)
else: else:
self.filters.append(filter) self.filters.append(filter)
self.error_stream = kwargs.get('error_stream') self.error_stream = kwargs.get("error_stream")
if self.error_stream and not isinstance(self.error_stream, Recipe): if self.error_stream and not isinstance(self.error_stream, Recipe):
if isinstance(self.error_stream, filters.Filter): if isinstance(self.error_stream, filters.Filter):
self.error_stream = Recipe(self.error_stream) self.error_stream = Recipe(self.error_stream)
elif hasattr(self.error_stream, '__iter__'): elif hasattr(self.error_stream, "__iter__"):
self.error_stream = Recipe(*self.error_stream) self.error_stream = Recipe(*self.error_stream)
else: else:
raise SaucebrushError('error_stream must be either a filter' raise SaucebrushError(
' or an iterable of filters') "error_stream must be either a filter" " or an iterable of filters"
)
def reject_record(self, record, exception): def reject_record(self, record, exception):
if self.error_stream: if self.error_stream:
self.error_stream.run([{'record': record, self.error_stream.run([{"record": record, "exception": repr(exception)}])
'exception': repr(exception)}])
def run(self, source): def run(self, source):
if self.finished: if self.finished:
raise OvercookedError('run() called on finished recipe') raise OvercookedError("run() called on finished recipe")
# connect datapath # connect datapath
data = source data = source
@ -58,7 +58,7 @@ class Recipe(object):
def done(self): def done(self):
if self.finished: if self.finished:
raise OvercookedError('done() called on finished recipe') raise OvercookedError("done() called on finished recipe")
self.finished = True self.finished = True
@ -74,8 +74,7 @@ class Recipe(object):
def run_recipe(source, *filter_args, **kwargs): def run_recipe(source, *filter_args, **kwargs):
""" Process data, taking it from a source and applying any number of filters """Process data, taking it from a source and applying any number of filters"""
"""
r = Recipe(*filter_args, **kwargs) r = Recipe(*filter_args, **kwargs)
r.run(source) r.run(source)

View File

@ -2,9 +2,9 @@
Saucebrush Emitters are filters that instead of modifying the record, output Saucebrush Emitters are filters that instead of modifying the record, output
it in some manner. it in some manner.
""" """
from __future__ import unicode_literals
from saucebrush.filters import Filter from saucebrush.filters import Filter
class Emitter(Filter): class Emitter(Filter):
"""ABC for emitters """ABC for emitters
@ -15,6 +15,7 @@ class Emitter(Filter):
all records are processed (allowing database flushes, or printing of all records are processed (allowing database flushes, or printing of
aggregate data). aggregate data).
""" """
def process_record(self, record): def process_record(self, record):
self.emit_record(record) self.emit_record(record)
return record return record
@ -24,8 +25,9 @@ class Emitter(Filter):
Called with a single record, should "emit" the record unmodified. Called with a single record, should "emit" the record unmodified.
""" """
raise NotImplementedError('emit_record not defined in ' + raise NotImplementedError(
self.__class__.__name__) "emit_record not defined in " + self.__class__.__name__
)
def done(self): def done(self):
"""No-op Method to be overridden. """No-op Method to be overridden.
@ -41,10 +43,12 @@ class DebugEmitter(Emitter):
DebugEmitter() by default prints to stdout. DebugEmitter() by default prints to stdout.
DebugEmitter(open('test', 'w')) would print to a file named test DebugEmitter(open('test', 'w')) would print to a file named test
""" """
def __init__(self, outfile=None): def __init__(self, outfile=None):
super(DebugEmitter, self).__init__() super(DebugEmitter, self).__init__()
if not outfile: if not outfile:
import sys import sys
self._outfile = sys.stdout self._outfile = sys.stdout
else: else:
self._outfile = outfile self._outfile = outfile
@ -68,6 +72,7 @@ class CountEmitter(Emitter):
if not outfile: if not outfile:
import sys import sys
self._outfile = sys.stdout self._outfile = sys.stdout
else: else:
self._outfile = outfile self._outfile = outfile
@ -84,7 +89,7 @@ class CountEmitter(Emitter):
self.count = 0 self.count = 0
def format(self): def format(self):
return self._format % {'count': self.count, 'of': self._of} return self._format % {"count": self.count, "of": self._of}
def emit_record(self, record): def emit_record(self, record):
self.count += 1 self.count += 1
@ -105,6 +110,7 @@ class CSVEmitter(Emitter):
def __init__(self, csvfile, fieldnames): def __init__(self, csvfile, fieldnames):
super(CSVEmitter, self).__init__() super(CSVEmitter, self).__init__()
import csv import csv
self._dictwriter = csv.DictWriter(csvfile, fieldnames) self._dictwriter = csv.DictWriter(csvfile, fieldnames)
# write header row # write header row
header_row = dict(zip(fieldnames, fieldnames)) header_row = dict(zip(fieldnames, fieldnames))
@ -127,24 +133,31 @@ class SqliteEmitter(Emitter):
def __init__(self, dbname, table_name, fieldnames=None, replace=False, quiet=False): def __init__(self, dbname, table_name, fieldnames=None, replace=False, quiet=False):
super(SqliteEmitter, self).__init__() super(SqliteEmitter, self).__init__()
import sqlite3 import sqlite3
self._conn = sqlite3.connect(dbname) self._conn = sqlite3.connect(dbname)
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._table_name = table_name self._table_name = table_name
self._replace = replace self._replace = replace
self._quiet = quiet self._quiet = quiet
if fieldnames: if fieldnames:
create = "CREATE TABLE IF NOT EXISTS %s (%s)" % (table_name, create = "CREATE TABLE IF NOT EXISTS %s (%s)" % (
', '.join([' '.join((field, 'TEXT')) for field in fieldnames])) table_name,
", ".join([" ".join((field, "TEXT")) for field in fieldnames]),
)
self._cursor.execute(create) self._cursor.execute(create)
def emit_record(self, record): def emit_record(self, record):
import sqlite3 import sqlite3
# input should be escaped with ? if data isn't trusted # input should be escaped with ? if data isn't trusted
qmarks = ','.join(('?',) * len(record)) qmarks = ",".join(("?",) * len(record))
insert = 'INSERT OR REPLACE' if self._replace else 'INSERT' insert = "INSERT OR REPLACE" if self._replace else "INSERT"
insert = '%s INTO %s (%s) VALUES (%s)' % (insert, self._table_name, insert = "%s INTO %s (%s) VALUES (%s)" % (
','.join(record.keys()), insert,
qmarks) self._table_name,
",".join(record.keys()),
qmarks,
)
try: try:
self._cursor.execute(insert, list(record.values())) self._cursor.execute(insert, list(record.values()))
except sqlite3.IntegrityError as ie: except sqlite3.IntegrityError as ie:
@ -173,11 +186,14 @@ class SqlDumpEmitter(Emitter):
self._fieldnames = fieldnames self._fieldnames = fieldnames
if not outfile: if not outfile:
import sys import sys
self._outfile = sys.stderr self._outfile = sys.stderr
else: else:
self._outfile = outfile self._outfile = outfile
self._insert_str = "INSERT INTO `%s` (`%s`) VALUES (%%s);\n" % ( self._insert_str = "INSERT INTO `%s` (`%s`) VALUES (%%s);\n" % (
table_name, '`,`'.join(fieldnames)) table_name,
"`,`".join(fieldnames),
)
def quote(self, item): def quote(self, item):
@ -190,14 +206,14 @@ class SqlDumpEmitter(Emitter):
types = (str,) types = (str,)
if isinstance(item, types): if isinstance(item, types):
item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0') item = item.replace("\\", "\\\\").replace("'", "\\'").replace(chr(0), "0")
return "'%s'" % item return "'%s'" % item
return "%s" % item return "%s" % item
def emit_record(self, record): def emit_record(self, record):
quoted_data = [self.quote(record[field]) for field in self._fieldnames] quoted_data = [self.quote(record[field]) for field in self._fieldnames]
self._outfile.write(self._insert_str % ','.join(quoted_data)) self._outfile.write(self._insert_str % ",".join(quoted_data))
class DjangoModelEmitter(Emitter): class DjangoModelEmitter(Emitter):
@ -210,9 +226,11 @@ class DjangoModelEmitter(Emitter):
records to addressbook.models.friend model using database settings records to addressbook.models.friend model using database settings
from settings.py. from settings.py.
""" """
def __init__(self, dj_settings, app_label, model_name): def __init__(self, dj_settings, app_label, model_name):
super(DjangoModelEmitter, self).__init__() super(DjangoModelEmitter, self).__init__()
from saucebrush.utils import get_django_model from saucebrush.utils import get_django_model
self._dbmodel = get_django_model(dj_settings, app_label, model_name) self._dbmodel = get_django_model(dj_settings, app_label, model_name)
if not self._dbmodel: if not self._dbmodel:
raise Exception("No such model: %s %s" % (app_label, model_name)) raise Exception("No such model: %s %s" % (app_label, model_name))
@ -228,13 +246,24 @@ class MongoDBEmitter(Emitter):
be inserted are required parameters. The host and port are optional, be inserted are required parameters. The host and port are optional,
defaulting to 'localhost' and 27017, repectively. defaulting to 'localhost' and 27017, repectively.
""" """
def __init__(self, database, collection, host='localhost', port=27017, drop_collection=False, conn=None):
def __init__(
self,
database,
collection,
host="localhost",
port=27017,
drop_collection=False,
conn=None,
):
super(MongoDBEmitter, self).__init__() super(MongoDBEmitter, self).__init__()
from pymongo.database import Database from pymongo.database import Database
if not isinstance(database, Database): if not isinstance(database, Database):
if not conn: if not conn:
from pymongo.connection import Connection from pymongo.connection import Connection
conn = Connection(host, port) conn = Connection(host, port)
db = conn[database] db = conn[database]
else: else:
@ -255,6 +284,7 @@ class LoggingEmitter(Emitter):
a format parameter. The resulting message will get logged a format parameter. The resulting message will get logged
at the provided level. at the provided level.
""" """
import logging import logging
def __init__(self, logger, msg_template, level=logging.DEBUG): def __init__(self, logger, msg_template, level=logging.DEBUG):

View File

@ -12,9 +12,10 @@ import re
import time import time
###################### ######################
## Abstract Filters ## # Abstract Filters #
###################### ######################
class Filter(object): class Filter(object):
"""ABC for filters that operate on records. """ABC for filters that operate on records.
@ -27,11 +28,12 @@ class Filter(object):
Called with a single record, should return modified record. Called with a single record, should return modified record.
""" """
raise NotImplementedError('process_record not defined in ' + raise NotImplementedError(
self.__class__.__name__) "process_record not defined in " + self.__class__.__name__
)
def reject_record(self, record, exception): def reject_record(self, record, exception):
recipe = getattr(self, '_recipe') recipe = getattr(self, "_recipe")
if recipe: if recipe:
recipe.reject_record(record, exception) recipe.reject_record(record, exception)
@ -91,11 +93,13 @@ class FieldFilter(Filter):
def process_field(self, item): def process_field(self, item):
"""Given a value, return the value that it should be replaced with.""" """Given a value, return the value that it should be replaced with."""
raise NotImplementedError('process_field not defined in ' + raise NotImplementedError(
self.__class__.__name__) "process_field not defined in " + self.__class__.__name__
)
def __unicode__(self): def __unicode__(self):
return '%s( %s )' % (self.__class__.__name__, str(self._target_keys)) return "%s( %s )" % (self.__class__.__name__, str(self._target_keys))
class ConditionalFilter(YieldFilter): class ConditionalFilter(YieldFilter):
"""ABC for filters that only pass through records meeting a condition. """ABC for filters that only pass through records meeting a condition.
@ -120,14 +124,17 @@ class ConditionalFilter(YieldFilter):
def test_record(self, record): def test_record(self, record):
"""Given a record, return True iff it should be passed on""" """Given a record, return True iff it should be passed on"""
raise NotImplementedError('test_record not defined in ' + raise NotImplementedError(
self.__class__.__name__) "test_record not defined in " + self.__class__.__name__
)
class ValidationError(Exception): class ValidationError(Exception):
def __init__(self, record): def __init__(self, record):
super(ValidationError, self).__init__(repr(record)) super(ValidationError, self).__init__(repr(record))
self.record = record self.record = record
def _dotted_get(d, path): def _dotted_get(d, path):
""" """
utility function for SubrecordFilter utility function for SubrecordFilter
@ -135,15 +142,16 @@ def _dotted_get(d, path):
dives into a complex nested dictionary with paths like a.b.c dives into a complex nested dictionary with paths like a.b.c
""" """
if path: if path:
key_pieces = path.split('.', 1) key_pieces = path.split(".", 1)
piece = d[key_pieces[0]] piece = d[key_pieces[0]]
if isinstance(piece, (tuple, list)): if isinstance(piece, (tuple, list)):
return [_dotted_get(i, '.'.join(key_pieces[1:])) for i in piece] return [_dotted_get(i, ".".join(key_pieces[1:])) for i in piece]
elif isinstance(piece, (dict)): elif isinstance(piece, (dict)):
return _dotted_get(piece, '.'.join(key_pieces[1:])) return _dotted_get(piece, ".".join(key_pieces[1:]))
else: else:
return d return d
class SubrecordFilter(Filter): class SubrecordFilter(Filter):
"""Filter that calls another filter on subrecord(s) of a record """Filter that calls another filter on subrecord(s) of a record
@ -152,8 +160,8 @@ class SubrecordFilter(Filter):
""" """
def __init__(self, field_path, filter_): def __init__(self, field_path, filter_):
if '.' in field_path: if "." in field_path:
self.field_path, self.key = field_path.rsplit('.', 1) self.field_path, self.key = field_path.rsplit(".", 1)
else: else:
self.field_path = None self.field_path = None
self.key = field_path self.key = field_path
@ -178,6 +186,7 @@ class SubrecordFilter(Filter):
self.process_subrecord(subrecord_parent) self.process_subrecord(subrecord_parent)
return record return record
class ConditionalPathFilter(Filter): class ConditionalPathFilter(Filter):
"""Filter that uses a predicate to split input among two filter paths.""" """Filter that uses a predicate to split input among two filter paths."""
@ -192,10 +201,12 @@ class ConditionalPathFilter(Filter):
else: else:
return self.false_filter.process_record(record) return self.false_filter.process_record(record)
##################### #####################
## Generic Filters ## ## Generic Filters ##
##################### #####################
class FieldModifier(FieldFilter): class FieldModifier(FieldFilter):
"""Filter that calls a given function on a given set of fields. """Filter that calls a given function on a given set of fields.
@ -211,8 +222,11 @@ class FieldModifier(FieldFilter):
return self._filter_func(item) return self._filter_func(item)
def __unicode__(self): def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__, return "%s( %s, %s )" % (
str(self._target_keys), str(self._filter_func)) self.__class__.__name__,
str(self._target_keys),
str(self._filter_func),
)
class FieldKeeper(Filter): class FieldKeeper(Filter):
@ -250,7 +264,7 @@ class FieldRemover(Filter):
return record return record
def __unicode__(self): def __unicode__(self):
return '%s( %s )' % (self.__class__.__name__, str(self._target_keys)) return "%s( %s )" % (self.__class__.__name__, str(self._target_keys))
class FieldMerger(Filter): class FieldMerger(Filter):
@ -278,9 +292,11 @@ class FieldMerger(Filter):
return record return record
def __unicode__(self): def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__, return "%s( %s, %s )" % (
self.__class__.__name__,
str(self._field_mapping), str(self._field_mapping),
str(self._merge_func)) str(self._merge_func),
)
class FieldAdder(Filter): class FieldAdder(Filter):
@ -300,7 +316,7 @@ class FieldAdder(Filter):
super(FieldAdder, self).__init__() super(FieldAdder, self).__init__()
self._field_name = field_name self._field_name = field_name
self._field_value = field_value self._field_value = field_value
if hasattr(self._field_value, '__iter__'): if hasattr(self._field_value, "__iter__"):
value_iter = iter(self._field_value) value_iter = iter(self._field_value)
if hasattr(value_iter, "next"): if hasattr(value_iter, "next"):
self._field_value = value_iter.next self._field_value = value_iter.next
@ -317,8 +333,12 @@ class FieldAdder(Filter):
return record return record
def __unicode__(self): def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__, self._field_name, return "%s( %s, %s )" % (
str(self._field_value)) self.__class__.__name__,
self._field_name,
str(self._field_value),
)
class FieldCopier(Filter): class FieldCopier(Filter):
"""Filter that copies one field to another. """Filter that copies one field to another.
@ -326,6 +346,7 @@ class FieldCopier(Filter):
Takes a dictionary mapping destination keys to source keys. Takes a dictionary mapping destination keys to source keys.
""" """
def __init__(self, copy_mapping): def __init__(self, copy_mapping):
super(FieldCopier, self).__init__() super(FieldCopier, self).__init__()
self._copy_mapping = copy_mapping self._copy_mapping = copy_mapping
@ -336,11 +357,13 @@ class FieldCopier(Filter):
record[dest] = record[source] record[dest] = record[source]
return record return record
class FieldRenamer(Filter): class FieldRenamer(Filter):
"""Filter that renames one field to another. """Filter that renames one field to another.
Takes a dictionary mapping destination keys to source keys. Takes a dictionary mapping destination keys to source keys.
""" """
def __init__(self, rename_mapping): def __init__(self, rename_mapping):
super(FieldRenamer, self).__init__() super(FieldRenamer, self).__init__()
self._rename_mapping = rename_mapping self._rename_mapping = rename_mapping
@ -351,6 +374,7 @@ class FieldRenamer(Filter):
record[dest] = record.pop(source) record[dest] = record.pop(source)
return record return record
class FieldNameModifier(Filter): class FieldNameModifier(Filter):
"""Filter that calls a given function on a given set of fields. """Filter that calls a given function on a given set of fields.
@ -368,6 +392,7 @@ class FieldNameModifier(Filter):
record[dest] = record.pop(source) record[dest] = record.pop(source)
return record return record
class Splitter(Filter): class Splitter(Filter):
"""Filter that splits nested data into different paths. """Filter that splits nested data into different paths.
@ -422,6 +447,7 @@ class Flattener(FieldFilter):
{'addresses': [{'state': 'NC', 'street': '146 shirley drive'}, {'addresses': [{'state': 'NC', 'street': '146 shirley drive'},
{'state': 'NY', 'street': '3000 Winton Rd'}]} {'state': 'NY', 'street': '3000 Winton Rd'}]}
""" """
def __init__(self, keys): def __init__(self, keys):
super(Flattener, self).__init__(keys) super(Flattener, self).__init__(keys)
@ -436,7 +462,7 @@ class Flattener(FieldFilter):
class DictFlattener(Filter): class DictFlattener(Filter):
def __init__(self, keys, separator='_'): def __init__(self, keys, separator="_"):
super(DictFlattener, self).__init__() super(DictFlattener, self).__init__()
self._keys = utils.str_or_list(keys) self._keys = utils.str_or_list(keys)
self._separator = separator self._separator = separator
@ -446,8 +472,7 @@ class DictFlattener(Filter):
class Unique(ConditionalFilter): class Unique(ConditionalFilter):
""" Filter that ensures that all records passing through are unique. """Filter that ensures that all records passing through are unique."""
"""
def __init__(self): def __init__(self):
super(Unique, self).__init__() super(Unique, self).__init__()
@ -461,6 +486,7 @@ class Unique(ConditionalFilter):
else: else:
return False return False
class UniqueValidator(Unique): class UniqueValidator(Unique):
validator = True validator = True
@ -472,7 +498,7 @@ class UniqueID(ConditionalFilter):
of a composite ID. of a composite ID.
""" """
def __init__(self, field='id', *args): def __init__(self, field="id", *args):
super(UniqueID, self).__init__() super(UniqueID, self).__init__()
self._seen = set() self._seen = set()
self._id_fields = [field] self._id_fields = [field]
@ -486,15 +512,15 @@ class UniqueID(ConditionalFilter):
else: else:
return False return False
class UniqueIDValidator(UniqueID): class UniqueIDValidator(UniqueID):
validator = True validator = True
class UnicodeFilter(Filter): class UnicodeFilter(Filter):
""" Convert all str elements in the record to Unicode. """Convert all str elements in the record to Unicode."""
"""
def __init__(self, encoding='utf-8', errors='ignore'): def __init__(self, encoding="utf-8", errors="ignore"):
super(UnicodeFilter, self).__init__() super(UnicodeFilter, self).__init__()
self._encoding = encoding self._encoding = encoding
self._errors = errors self._errors = errors
@ -507,9 +533,9 @@ class UnicodeFilter(Filter):
record[key] = value.decode(self._encoding, self._errors) record[key] = value.decode(self._encoding, self._errors)
return record return record
class StringFilter(Filter):
def __init__(self, encoding='utf-8', errors='ignore'): class StringFilter(Filter):
def __init__(self, encoding="utf-8", errors="ignore"):
super(StringFilter, self).__init__() super(StringFilter, self).__init__()
self._encoding = encoding self._encoding = encoding
self._errors = errors self._errors = errors
@ -525,6 +551,7 @@ class StringFilter(Filter):
## Commonly Used Filters ## ## Commonly Used Filters ##
########################### ###########################
class PhoneNumberCleaner(FieldFilter): class PhoneNumberCleaner(FieldFilter):
"""Filter that cleans phone numbers to match a given format. """Filter that cleans phone numbers to match a given format.
@ -534,10 +561,11 @@ class PhoneNumberCleaner(FieldFilter):
PhoneNumberCleaner( ('phone','fax'), number_format='%s%s%s-%s%s%s-%s%s%s%s') PhoneNumberCleaner( ('phone','fax'), number_format='%s%s%s-%s%s%s-%s%s%s%s')
would format the phone & fax columns to 555-123-4567 format. would format the phone & fax columns to 555-123-4567 format.
""" """
def __init__(self, keys, number_format='%s%s%s.%s%s%s.%s%s%s%s'):
def __init__(self, keys, number_format="%s%s%s.%s%s%s.%s%s%s%s"):
super(PhoneNumberCleaner, self).__init__(keys) super(PhoneNumberCleaner, self).__init__(keys)
self._number_format = number_format self._number_format = number_format
self._num_re = re.compile('\d') self._num_re = re.compile("\d")
def process_field(self, item): def process_field(self, item):
nums = self._num_re.findall(item) nums = self._num_re.findall(item)
@ -545,19 +573,21 @@ class PhoneNumberCleaner(FieldFilter):
item = self._number_format % tuple(nums) item = self._number_format % tuple(nums)
return item return item
class DateCleaner(FieldFilter): class DateCleaner(FieldFilter):
"""Filter that cleans dates to match a given format. """Filter that cleans dates to match a given format.
Takes a list of target keys and to and from formats in strftime format. Takes a list of target keys and to and from formats in strftime format.
""" """
def __init__(self, keys, from_format, to_format): def __init__(self, keys, from_format, to_format):
super(DateCleaner, self).__init__(keys) super(DateCleaner, self).__init__(keys)
self._from_format = from_format self._from_format = from_format
self._to_format = to_format self._to_format = to_format
def process_field(self, item): def process_field(self, item):
return time.strftime(self._to_format, return time.strftime(self._to_format, time.strptime(item, self._from_format))
time.strptime(item, self._from_format))
class NameCleaner(Filter): class NameCleaner(Filter):
"""Filter that splits names into a first, last, and middle name field. """Filter that splits names into a first, last, and middle name field.
@ -570,20 +600,26 @@ class NameCleaner(Filter):
""" """
# first middle? last suffix? # first middle? last suffix?
FIRST_LAST = re.compile('''^\s*(?:(?P<firstname>\w+)(?:\.?) FIRST_LAST = re.compile(
"""^\s*(?:(?P<firstname>\w+)(?:\.?)
\s+(?:(?P<middlename>\w+)\.?\s+)? \s+(?:(?P<middlename>\w+)\.?\s+)?
(?P<lastname>[A-Za-z'-]+)) (?P<lastname>[A-Za-z'-]+))
(?:\s+(?P<suffix>JR\.?|II|III|IV))? (?:\s+(?P<suffix>JR\.?|II|III|IV))?
\s*$''', re.VERBOSE | re.IGNORECASE) \s*$""",
re.VERBOSE | re.IGNORECASE,
)
# last, first middle? suffix? # last, first middle? suffix?
LAST_FIRST = re.compile('''^\s*(?:(?P<lastname>[A-Za-z'-]+), LAST_FIRST = re.compile(
"""^\s*(?:(?P<lastname>[A-Za-z'-]+),
\s+(?P<firstname>\w+)(?:\.?) \s+(?P<firstname>\w+)(?:\.?)
(?:\s+(?P<middlename>\w+)\.?)?) (?:\s+(?P<middlename>\w+)\.?)?)
(?:\s+(?P<suffix>JR\.?|II|III|IV))? (?:\s+(?P<suffix>JR\.?|II|III|IV))?
\s*$''', re.VERBOSE | re.IGNORECASE) \s*$""",
re.VERBOSE | re.IGNORECASE,
)
def __init__(self, keys, prefix='', formats=None, nomatch_name=None): def __init__(self, keys, prefix="", formats=None, nomatch_name=None):
super(NameCleaner, self).__init__() super(NameCleaner, self).__init__()
self._keys = utils.str_or_list(keys) self._keys = utils.str_or_list(keys)
self._name_prefix = prefix self._name_prefix = prefix

View File

@ -9,6 +9,7 @@ import string
from saucebrush import utils from saucebrush import utils
class CSVSource(object): class CSVSource(object):
"""Saucebrush source for reading from CSV files. """Saucebrush source for reading from CSV files.
@ -25,6 +26,7 @@ class CSVSource(object):
def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs): def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs):
import csv import csv
self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs) self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs)
for _ in range(skiprows): for _ in range(skiprows):
next(self._dictreader) next(self._dictreader)
@ -68,8 +70,7 @@ class FixedWidthFileSource(object):
return record return record
def next(self): def next(self):
""" Keep Python 2 next() method that defers to __next__(). """Keep Python 2 next() method that defers to __next__()."""
"""
return self.__next__() return self.__next__()
@ -93,31 +94,33 @@ class HtmlTableSource(object):
# extract the table # extract the table
from lxml.html import parse from lxml.html import parse
doc = parse(htmlfile).getroot() doc = parse(htmlfile).getroot()
if isinstance(id_or_num, int): if isinstance(id_or_num, int):
table = doc.cssselect('table')[id_or_num] table = doc.cssselect("table")[id_or_num]
else: else:
table = doc.cssselect('table#%s' % id_or_num) table = doc.cssselect("table#%s" % id_or_num)
table = table[0] # get the first table table = table[0] # get the first table
# skip the necessary number of rows # skip the necessary number of rows
self._rows = table.cssselect('tr')[skiprows:] self._rows = table.cssselect("tr")[skiprows:]
# determine the fieldnames # determine the fieldnames
if not fieldnames: if not fieldnames:
self._fieldnames = [td.text_content() self._fieldnames = [
for td in self._rows[0].cssselect('td, th')] td.text_content() for td in self._rows[0].cssselect("td, th")
]
skiprows += 1 skiprows += 1
else: else:
self._fieldnames = fieldnames self._fieldnames = fieldnames
# skip the necessary number of rows # skip the necessary number of rows
self._rows = table.cssselect('tr')[skiprows:] self._rows = table.cssselect("tr")[skiprows:]
def process_tr(self): def process_tr(self):
for row in self._rows: for row in self._rows:
strings = [td.text_content() for td in row.cssselect('td')] strings = [td.text_content() for td in row.cssselect("td")]
yield dict(zip(self._fieldnames, strings)) yield dict(zip(self._fieldnames, strings))
def __iter__(self): def __iter__(self):
@ -135,12 +138,12 @@ class DjangoModelSource(object):
friends from the friend model in the phonebook app described in friends from the friend model in the phonebook app described in
settings.py. settings.py.
""" """
def __init__(self, dj_settings, app_label, model_name): def __init__(self, dj_settings, app_label, model_name):
dbmodel = utils.get_django_model(dj_settings, app_label, model_name) dbmodel = utils.get_django_model(dj_settings, app_label, model_name)
# only get values defined in model (no extra fields from custom manager) # only get values defined in model (no extra fields from custom manager)
self._data = dbmodel.objects.values(*[f.name self._data = dbmodel.objects.values(*[f.name for f in dbmodel._meta.fields])
for f in dbmodel._meta.fields])
def __iter__(self): def __iter__(self):
return iter(self._data) return iter(self._data)
@ -152,9 +155,13 @@ class MongoDBSource(object):
The record dict is populated with records matching the spec The record dict is populated with records matching the spec
from the specified database and collection. from the specified database and collection.
""" """
def __init__(self, database, collection, spec=None, host='localhost', port=27017, conn=None):
def __init__(
self, database, collection, spec=None, host="localhost", port=27017, conn=None
):
if not conn: if not conn:
from pymongo.connection import Connection from pymongo.connection import Connection
conn = Connection(host, port) conn = Connection(host, port)
self.collection = conn[database][collection] self.collection = conn[database][collection]
self.spec = spec self.spec = spec
@ -166,6 +173,7 @@ class MongoDBSource(object):
for doc in self.collection.find(self.spec): for doc in self.collection.find(self.spec):
yield dict(doc) yield dict(doc)
# dict_factory for sqlite source # dict_factory for sqlite source
def dict_factory(cursor, row): def dict_factory(cursor, row):
d = {} d = {}
@ -173,6 +181,7 @@ def dict_factory(cursor, row):
d[col[0]] = row[idx] d[col[0]] = row[idx]
return d return d
class SqliteSource(object): class SqliteSource(object):
"""Source that reads from a sqlite database. """Source that reads from a sqlite database.
@ -226,26 +235,28 @@ class FileSource(object):
def __iter__(self): def __iter__(self):
# This method would be a lot cleaner with the proposed # This method would be a lot cleaner with the proposed
# 'yield from' expression (PEP 380) # 'yield from' expression (PEP 380)
if hasattr(self._input, '__read__') or hasattr(self._input, 'read'): if hasattr(self._input, "__read__") or hasattr(self._input, "read"):
for record in self._process_file(self._input): for record in self._process_file(self._input):
yield record yield record
elif isinstance(self._input, str): elif isinstance(self._input, str):
with open(self._input) as f: with open(self._input) as f:
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
elif hasattr(self._input, '__iter__'): elif hasattr(self._input, "__iter__"):
for el in self._input: for el in self._input:
if isinstance(el, str): if isinstance(el, str):
with open(el) as f: with open(el) as f:
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
elif hasattr(el, '__read__') or hasattr(el, 'read'): elif hasattr(el, "__read__") or hasattr(el, "read"):
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
def _process_file(self, file): def _process_file(self, file):
raise NotImplementedError('Descendants of FileSource should implement' raise NotImplementedError(
' a custom _process_file method.') "Descendants of FileSource should implement"
" a custom _process_file method."
)
class JSONSource(FileSource): class JSONSource(FileSource):
@ -271,6 +282,7 @@ class JSONSource(FileSource):
else: else:
yield obj yield obj
class XMLSource(FileSource): class XMLSource(FileSource):
"""Source for reading from XML files. Use with the same kind of caution """Source for reading from XML files. Use with the same kind of caution
that you use to approach anything written in XML. that you use to approach anything written in XML.
@ -281,14 +293,13 @@ class XMLSource(FileSource):
almost never going to be useful at the top level. almost never going to be useful at the top level.
""" """
def __init__(self, input, node_path=None, attr_prefix='ATTR_', def __init__(self, input, node_path=None, attr_prefix="ATTR_", postprocessor=None):
postprocessor=None):
super(XMLSource, self).__init__(input) super(XMLSource, self).__init__(input)
self.node_list = node_path.split('.') self.node_list = node_path.split(".")
self.attr_prefix = attr_prefix self.attr_prefix = attr_prefix
self.postprocessor = postprocessor self.postprocessor = postprocessor
def _process_file(self, f, attr_prefix='ATTR_'): def _process_file(self, f, attr_prefix="ATTR_"):
"""xmltodict can either return attributes of nodes as prefixed fields """xmltodict can either return attributes of nodes as prefixed fields
(prefixes to avoid key collisions), or ignore them altogether. (prefixes to avoid key collisions), or ignore them altogether.
@ -299,8 +310,9 @@ class XMLSource(FileSource):
import xmltodict import xmltodict
if self.postprocessor: if self.postprocessor:
obj = xmltodict.parse(f, attr_prefix=self.attr_prefix, obj = xmltodict.parse(
postprocessor=self.postprocessor) f, attr_prefix=self.attr_prefix, postprocessor=self.postprocessor
)
else: else:
obj = xmltodict.parse(f, attr_prefix=self.attr_prefix) obj = xmltodict.parse(f, attr_prefix=self.attr_prefix)

View File

@ -1,9 +1,9 @@
from saucebrush.filters import Filter from saucebrush.filters import Filter
from saucebrush.utils import FallbackCounter from saucebrush.utils import FallbackCounter
import collections import collections
import itertools
import math import math
def _average(values): def _average(values):
"""Calculate the average of a list of values. """Calculate the average of a list of values.
@ -13,6 +13,7 @@ def _average(values):
if len(values) > 0: if len(values) > 0:
return sum(values) / float(value_count) return sum(values) / float(value_count)
def _median(values): def _median(values):
"""Calculate the median of a list of values. """Calculate the median of a list of values.
@ -37,6 +38,7 @@ def _median(values):
mid = int(count / 2) mid = int(count / 2)
return sum(values[mid - 1 : mid + 1]) / 2.0 return sum(values[mid - 1 : mid + 1]) / 2.0
def _stddev(values, population=False): def _stddev(values, population=False):
"""Calculate the standard deviation and variance of a list of values. """Calculate the standard deviation and variance of a list of values.
@ -56,9 +58,9 @@ def _stddev(values, population=False):
return (math.sqrt(variance), variance) # stddev is sqrt of variance return (math.sqrt(variance), variance) # stddev is sqrt of variance
class StatsFilter(Filter): class StatsFilter(Filter):
""" Base for all stats filters. """Base for all stats filters."""
"""
def __init__(self, field, test=None): def __init__(self, field, test=None):
self._field = field self._field = field
@ -70,12 +72,13 @@ class StatsFilter(Filter):
return record return record
def process_field(self, record): def process_field(self, record):
raise NotImplementedError('process_field not defined in ' + raise NotImplementedError(
self.__class__.__name__) "process_field not defined in " + self.__class__.__name__
)
def value(self): def value(self):
raise NotImplementedError('value not defined in ' + raise NotImplementedError("value not defined in " + self.__class__.__name__)
self.__class__.__name__)
class Sum(StatsFilter): class Sum(StatsFilter):
"""Calculate the sum of the values in a field. Field must contain either """Calculate the sum of the values in a field. Field must contain either
@ -92,6 +95,7 @@ class Sum(StatsFilter):
def value(self): def value(self):
return self._value return self._value
class Average(StatsFilter): class Average(StatsFilter):
"""Calculate the average (mean) of the values in a field. Field must """Calculate the average (mean) of the values in a field. Field must
contain either int or float values. contain either int or float values.
@ -110,6 +114,7 @@ class Average(StatsFilter):
def value(self): def value(self):
return self._value / float(self._count) return self._value / float(self._count)
class Median(StatsFilter): class Median(StatsFilter):
"""Calculate the median of the values in a field. Field must contain """Calculate the median of the values in a field. Field must contain
either int or float values. either int or float values.
@ -128,6 +133,7 @@ class Median(StatsFilter):
def value(self): def value(self):
return _median(self._values) return _median(self._values)
class MinMax(StatsFilter): class MinMax(StatsFilter):
"""Find the minimum and maximum values in a field. Field must contain """Find the minimum and maximum values in a field. Field must contain
either int or float values. either int or float values.
@ -148,6 +154,7 @@ class MinMax(StatsFilter):
def value(self): def value(self):
return (self._min, self._max) return (self._min, self._max)
class StandardDeviation(StatsFilter): class StandardDeviation(StatsFilter):
"""Calculate the standard deviation of the values in a field. Calling """Calculate the standard deviation of the values in a field. Calling
value() will return a standard deviation for the sample. Pass value() will return a standard deviation for the sample. Pass
@ -180,6 +187,7 @@ class StandardDeviation(StatsFilter):
""" """
return _stddev(self._values, population) return _stddev(self._values, population)
class Histogram(StatsFilter): class Histogram(StatsFilter):
"""Generate a basic histogram of the specified field. The value() method """Generate a basic histogram of the specified field. The value() method
returns a dict of value to occurance count mappings. The __str__ method returns a dict of value to occurance count mappings. The __str__ method
@ -194,7 +202,7 @@ class Histogram(StatsFilter):
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
super(Histogram, self).__init__(field, **kwargs) super(Histogram, self).__init__(field, **kwargs)
if hasattr(collections, 'Counter'): if hasattr(collections, "Counter"):
self._counter = collections.Counter() self._counter = collections.Counter()
else: else:
self._counter = FallbackCounter() self._counter = FallbackCounter()

View File

@ -11,5 +11,5 @@ emitter_suite = unittest.TestLoader().loadTestsFromTestCase(EmitterTestCase)
recipe_suite = unittest.TestLoader().loadTestsFromTestCase(RecipeTestCase) recipe_suite = unittest.TestLoader().loadTestsFromTestCase(RecipeTestCase)
stats_suite = unittest.TestLoader().loadTestsFromTestCase(StatsTestCase) stats_suite = unittest.TestLoader().loadTestsFromTestCase(StatsTestCase)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -5,61 +5,90 @@ import os
import unittest import unittest
from saucebrush.emitters import ( from saucebrush.emitters import (
DebugEmitter, CSVEmitter, CountEmitter, SqliteEmitter, SqlDumpEmitter) DebugEmitter,
CSVEmitter,
CountEmitter,
SqliteEmitter,
SqlDumpEmitter,
)
class EmitterTestCase(unittest.TestCase): class EmitterTestCase(unittest.TestCase):
def test_debug_emitter(self): def test_debug_emitter(self):
with closing(StringIO()) as output: with closing(StringIO()) as output:
de = DebugEmitter(output) de = DebugEmitter(output)
list(de.attach([1, 2, 3])) list(de.attach([1, 2, 3]))
self.assertEqual(output.getvalue(), '1\n2\n3\n') self.assertEqual(output.getvalue(), "1\n2\n3\n")
def test_count_emitter(self): def test_count_emitter(self):
# values for test # values for test
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22] values = [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
]
with closing(StringIO()) as output: with closing(StringIO()) as output:
# test without of parameter # test without of parameter
ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n") ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n")
list(ce.attach(values)) list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 records\n20 records\n') self.assertEqual(output.getvalue(), "10 records\n20 records\n")
ce.done() ce.done()
self.assertEqual(output.getvalue(), '10 records\n20 records\n22 records\n') self.assertEqual(output.getvalue(), "10 records\n20 records\n22 records\n")
with closing(StringIO()) as output: with closing(StringIO()) as output:
# test with of parameter # test with of parameter
ce = CountEmitter(every=10, outfile=output, of=len(values)) ce = CountEmitter(every=10, outfile=output, of=len(values))
list(ce.attach(values)) list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n') self.assertEqual(output.getvalue(), "10 of 22\n20 of 22\n")
ce.done() ce.done()
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n') self.assertEqual(output.getvalue(), "10 of 22\n20 of 22\n22 of 22\n")
def test_csv_emitter(self): def test_csv_emitter(self):
try: try:
import cStringIO # if Python 2.x then use old cStringIO import cStringIO # if Python 2.x then use old cStringIO
io = cStringIO.StringIO() io = cStringIO.StringIO()
except: except:
io = StringIO() # if Python 3.x then use StringIO io = StringIO() # if Python 3.x then use StringIO
with closing(io) as output: with closing(io) as output:
ce = CSVEmitter(output, ('x','y','z')) ce = CSVEmitter(output, ("x", "y", "z"))
list(ce.attach([{'x':1, 'y':2, 'z':3}, {'x':5, 'y':5, 'z':5}])) list(ce.attach([{"x": 1, "y": 2, "z": 3}, {"x": 5, "y": 5, "z": 5}]))
self.assertEqual(output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n') self.assertEqual(output.getvalue(), "x,y,z\r\n1,2,3\r\n5,5,5\r\n")
def test_sqlite_emitter(self): def test_sqlite_emitter(self):
import sqlite3, tempfile import sqlite3, tempfile
with closing(tempfile.NamedTemporaryFile(suffix='.db')) as f: with closing(tempfile.NamedTemporaryFile(suffix=".db")) as f:
db_path = f.name db_path = f.name
sle = SqliteEmitter(db_path, 'testtable', fieldnames=('a','b','c')) sle = SqliteEmitter(db_path, "testtable", fieldnames=("a", "b", "c"))
list(sle.attach([{'a': '1', 'b': '2', 'c': '3'}])) list(sle.attach([{"a": "1", "b": "2", "c": "3"}]))
sle.done() sle.done()
with closing(sqlite3.connect(db_path)) as conn: with closing(sqlite3.connect(db_path)) as conn:
@ -69,18 +98,20 @@ class EmitterTestCase(unittest.TestCase):
os.unlink(db_path) os.unlink(db_path)
self.assertEqual(results, [('1', '2', '3')]) self.assertEqual(results, [("1", "2", "3")])
def test_sql_dump_emitter(self): def test_sql_dump_emitter(self):
with closing(StringIO()) as bffr: with closing(StringIO()) as bffr:
sde = SqlDumpEmitter(bffr, 'testtable', ('a', 'b')) sde = SqlDumpEmitter(bffr, "testtable", ("a", "b"))
list(sde.attach([{'a': 1, 'b': '2'}])) list(sde.attach([{"a": 1, "b": "2"}]))
sde.done() sde.done()
self.assertEqual(bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n") self.assertEqual(
bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n"
)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,23 +1,38 @@
import unittest import unittest
import operator import operator
import types import types
from saucebrush.filters import (Filter, YieldFilter, FieldFilter, from saucebrush.filters import (
SubrecordFilter, ConditionalPathFilter, Filter,
ConditionalFilter, FieldModifier, FieldKeeper, YieldFilter,
FieldRemover, FieldMerger, FieldAdder, FieldFilter,
FieldCopier, FieldRenamer, Unique) SubrecordFilter,
ConditionalPathFilter,
ConditionalFilter,
FieldModifier,
FieldKeeper,
FieldRemover,
FieldMerger,
FieldAdder,
FieldCopier,
FieldRenamer,
Unique,
)
class DummyRecipe(object): class DummyRecipe(object):
rejected_record = None rejected_record = None
rejected_msg = None rejected_msg = None
def reject_record(self, record, msg): def reject_record(self, record, msg):
self.rejected_record = record self.rejected_record = record
self.rejected_msg = msg self.rejected_msg = msg
class Doubler(Filter): class Doubler(Filter):
def process_record(self, record): def process_record(self, record):
return record * 2 return record * 2
class OddRemover(Filter): class OddRemover(Filter):
def process_record(self, record): def process_record(self, record):
if record % 2 == 0: if record % 2 == 0:
@ -25,15 +40,18 @@ class OddRemover(Filter):
else: else:
return None # explicitly return None return None # explicitly return None
class ListFlattener(YieldFilter): class ListFlattener(YieldFilter):
def process_record(self, record): def process_record(self, record):
for item in record: for item in record:
yield item yield item
class FieldDoubler(FieldFilter): class FieldDoubler(FieldFilter):
def process_field(self, item): def process_field(self, item):
return item * 2 return item * 2
class NonModifyingFieldDoubler(Filter): class NonModifyingFieldDoubler(Filter):
def __init__(self, key): def __init__(self, key):
self.key = key self.key = key
@ -43,17 +61,20 @@ class NonModifyingFieldDoubler(Filter):
record[self.key] *= 2 record[self.key] *= 2
return record return record
class ConditionalOddRemover(ConditionalFilter): class ConditionalOddRemover(ConditionalFilter):
def test_record(self, record): def test_record(self, record):
# return True for even values # return True for even values
return record % 2 == 0 return record % 2 == 0
class FilterTestCase(unittest.TestCase):
class FilterTestCase(unittest.TestCase):
def _simple_data(self): def _simple_data(self):
return [{'a':1, 'b':2, 'c':3}, return [
{'a':5, 'b':5, 'c':5}, {"a": 1, "b": 2, "c": 3},
{'a':1, 'b':10, 'c':100}] {"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
def assert_filter_result(self, filter_obj, expected_data): def assert_filter_result(self, filter_obj, expected_data):
result = filter_obj.attach(self._simple_data()) result = filter_obj.attach(self._simple_data())
@ -65,11 +86,11 @@ class FilterTestCase(unittest.TestCase):
result = f.attach([1, 2, 3], recipe=recipe) result = f.attach([1, 2, 3], recipe=recipe)
# next has to be called for attach to take effect # next has to be called for attach to take effect
next(result) next(result)
f.reject_record('bad', 'this one was bad') f.reject_record("bad", "this one was bad")
# ensure that the rejection propagated to the recipe # ensure that the rejection propagated to the recipe
self.assertEqual('bad', recipe.rejected_record) self.assertEqual("bad", recipe.rejected_record)
self.assertEqual('this one was bad', recipe.rejected_msg) self.assertEqual("this one was bad", recipe.rejected_msg)
def test_simple_filter(self): def test_simple_filter(self):
df = Doubler() df = Doubler()
@ -95,12 +116,14 @@ class FilterTestCase(unittest.TestCase):
self.assertEqual(list(result), [1, 2, 3, 4, 5, 6]) self.assertEqual(list(result), [1, 2, 3, 4, 5, 6])
def test_simple_field_filter(self): def test_simple_field_filter(self):
ff = FieldDoubler(['a', 'c']) ff = FieldDoubler(["a", "c"])
# check against expected data # check against expected data
expected_data = [{'a':2, 'b':2, 'c':6}, expected_data = [
{'a':10, 'b':5, 'c':10}, {"a": 2, "b": 2, "c": 6},
{'a':2, 'b':10, 'c':200}] {"a": 10, "b": 5, "c": 10},
{"a": 2, "b": 10, "c": 200},
]
self.assert_filter_result(ff, expected_data) self.assert_filter_result(ff, expected_data)
def test_conditional_filter(self): def test_conditional_filter(self):
@ -113,79 +136,88 @@ class FilterTestCase(unittest.TestCase):
### Tests for Subrecord ### Tests for Subrecord
def test_subrecord_filter_list(self): def test_subrecord_filter_list(self):
data = [{'a': [{'b': 2}, {'b': 4}]}, data = [
{'a': [{'b': 5}]}, {"a": [{"b": 2}, {"b": 4}]},
{'a': [{'b': 8}, {'b':2}, {'b':1}]}] {"a": [{"b": 5}]},
{"a": [{"b": 8}, {"b": 2}, {"b": 1}]},
]
expected = [{'a': [{'b': 4}, {'b': 8}]}, expected = [
{'a': [{'b': 10}]}, {"a": [{"b": 4}, {"b": 8}]},
{'a': [{'b': 16}, {'b':4}, {'b':2}]}] {"a": [{"b": 10}]},
{"a": [{"b": 16}, {"b": 4}, {"b": 2}]},
]
sf = SubrecordFilter('a', NonModifyingFieldDoubler('b')) sf = SubrecordFilter("a", NonModifyingFieldDoubler("b"))
result = sf.attach(data) result = sf.attach(data)
self.assertEqual(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_deep(self): def test_subrecord_filter_deep(self):
data = [{'a': {'d':[{'b': 2}, {'b': 4}]}}, data = [
{'a': {'d':[{'b': 5}]}}, {"a": {"d": [{"b": 2}, {"b": 4}]}},
{'a': {'d':[{'b': 8}, {'b':2}, {'b':1}]}}] {"a": {"d": [{"b": 5}]}},
{"a": {"d": [{"b": 8}, {"b": 2}, {"b": 1}]}},
]
expected = [{'a': {'d':[{'b': 4}, {'b': 8}]}}, expected = [
{'a': {'d':[{'b': 10}]}}, {"a": {"d": [{"b": 4}, {"b": 8}]}},
{'a': {'d':[{'b': 16}, {'b':4}, {'b':2}]}}] {"a": {"d": [{"b": 10}]}},
{"a": {"d": [{"b": 16}, {"b": 4}, {"b": 2}]}},
]
sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b')) sf = SubrecordFilter("a.d", NonModifyingFieldDoubler("b"))
result = sf.attach(data) result = sf.attach(data)
self.assertEqual(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_nonlist(self): def test_subrecord_filter_nonlist(self):
data = [ data = [
{'a':{'b':{'c':1}}}, {"a": {"b": {"c": 1}}},
{'a':{'b':{'c':2}}}, {"a": {"b": {"c": 2}}},
{'a':{'b':{'c':3}}}, {"a": {"b": {"c": 3}}},
] ]
expected = [ expected = [
{'a':{'b':{'c':2}}}, {"a": {"b": {"c": 2}}},
{'a':{'b':{'c':4}}}, {"a": {"b": {"c": 4}}},
{'a':{'b':{'c':6}}}, {"a": {"b": {"c": 6}}},
] ]
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter("a.b", NonModifyingFieldDoubler("c"))
result = sf.attach(data) result = sf.attach(data)
self.assertEqual(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_list_in_path(self): def test_subrecord_filter_list_in_path(self):
data = [ data = [
{'a': [{'b': {'c': 5}}, {'b': {'c': 6}}]}, {"a": [{"b": {"c": 5}}, {"b": {"c": 6}}]},
{'a': [{'b': {'c': 1}}, {'b': {'c': 2}}, {'b': {'c': 3}}]}, {"a": [{"b": {"c": 1}}, {"b": {"c": 2}}, {"b": {"c": 3}}]},
{'a': [{'b': {'c': 2}} ]} {"a": [{"b": {"c": 2}}]},
] ]
expected = [ expected = [
{'a': [{'b': {'c': 10}}, {'b': {'c': 12}}]}, {"a": [{"b": {"c": 10}}, {"b": {"c": 12}}]},
{'a': [{'b': {'c': 2}}, {'b': {'c': 4}}, {'b': {'c': 6}}]}, {"a": [{"b": {"c": 2}}, {"b": {"c": 4}}, {"b": {"c": 6}}]},
{'a': [{'b': {'c': 4}} ]} {"a": [{"b": {"c": 4}}]},
] ]
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter("a.b", NonModifyingFieldDoubler("c"))
result = sf.attach(data) result = sf.attach(data)
self.assertEqual(list(result), expected) self.assertEqual(list(result), expected)
def test_conditional_path(self): def test_conditional_path(self):
predicate = lambda r: r['a'] == 1 predicate = lambda r: r["a"] == 1
# double b if a == 1, otherwise double c # double b if a == 1, otherwise double c
cpf = ConditionalPathFilter(predicate, FieldDoubler('b'), cpf = ConditionalPathFilter(predicate, FieldDoubler("b"), FieldDoubler("c"))
FieldDoubler('c')) expected_data = [
expected_data = [{'a':1, 'b':4, 'c':3}, {"a": 1, "b": 4, "c": 3},
{'a':5, 'b':5, 'c':10}, {"a": 5, "b": 5, "c": 10},
{'a':1, 'b':20, 'c':100}] {"a": 1, "b": 20, "c": 100},
]
self.assert_filter_result(cpf, expected_data) self.assert_filter_result(cpf, expected_data)
@ -193,112 +225,132 @@ class FilterTestCase(unittest.TestCase):
def test_field_modifier(self): def test_field_modifier(self):
# another version of FieldDoubler # another version of FieldDoubler
fm = FieldModifier(['a', 'c'], lambda x: x*2) fm = FieldModifier(["a", "c"], lambda x: x * 2)
# check against expected data # check against expected data
expected_data = [{'a':2, 'b':2, 'c':6}, expected_data = [
{'a':10, 'b':5, 'c':10}, {"a": 2, "b": 2, "c": 6},
{'a':2, 'b':10, 'c':200}] {"a": 10, "b": 5, "c": 10},
{"a": 2, "b": 10, "c": 200},
]
self.assert_filter_result(fm, expected_data) self.assert_filter_result(fm, expected_data)
def test_field_keeper(self): def test_field_keeper(self):
fk = FieldKeeper(['c']) fk = FieldKeeper(["c"])
# check against expected results # check against expected results
expected_data = [{'c':3}, {'c':5}, {'c':100}] expected_data = [{"c": 3}, {"c": 5}, {"c": 100}]
self.assert_filter_result(fk, expected_data) self.assert_filter_result(fk, expected_data)
def test_field_remover(self): def test_field_remover(self):
fr = FieldRemover(['a', 'b']) fr = FieldRemover(["a", "b"])
# check against expected results # check against expected results
expected_data = [{'c':3}, {'c':5}, {'c':100}] expected_data = [{"c": 3}, {"c": 5}, {"c": 100}]
self.assert_filter_result(fr, expected_data) self.assert_filter_result(fr, expected_data)
def test_field_merger(self): def test_field_merger(self):
fm = FieldMerger({'sum':('a','b','c')}, lambda x,y,z: x+y+z) fm = FieldMerger({"sum": ("a", "b", "c")}, lambda x, y, z: x + y + z)
# check against expected results # check against expected results
expected_data = [{'sum':6}, {'sum':15}, {'sum':111}] expected_data = [{"sum": 6}, {"sum": 15}, {"sum": 111}]
self.assert_filter_result(fm, expected_data) self.assert_filter_result(fm, expected_data)
def test_field_merger_keep_fields(self): def test_field_merger_keep_fields(self):
fm = FieldMerger({'sum':('a','b','c')}, lambda x,y,z: x+y+z, fm = FieldMerger(
keep_fields=True) {"sum": ("a", "b", "c")}, lambda x, y, z: x + y + z, keep_fields=True
)
# check against expected results # check against expected results
expected_data = [{'a':1, 'b':2, 'c':3, 'sum':6}, expected_data = [
{'a':5, 'b':5, 'c':5, 'sum':15}, {"a": 1, "b": 2, "c": 3, "sum": 6},
{'a':1, 'b':10, 'c':100, 'sum': 111}] {"a": 5, "b": 5, "c": 5, "sum": 15},
{"a": 1, "b": 10, "c": 100, "sum": 111},
]
self.assert_filter_result(fm, expected_data) self.assert_filter_result(fm, expected_data)
def test_field_adder_scalar(self): def test_field_adder_scalar(self):
fa = FieldAdder('x', 7) fa = FieldAdder("x", 7)
expected_data = [{'a':1, 'b':2, 'c':3, 'x':7}, expected_data = [
{'a':5, 'b':5, 'c':5, 'x':7}, {"a": 1, "b": 2, "c": 3, "x": 7},
{'a':1, 'b':10, 'c':100, 'x': 7}] {"a": 5, "b": 5, "c": 5, "x": 7},
{"a": 1, "b": 10, "c": 100, "x": 7},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_adder_callable(self): def test_field_adder_callable(self):
fa = FieldAdder('x', lambda: 7) fa = FieldAdder("x", lambda: 7)
expected_data = [{'a':1, 'b':2, 'c':3, 'x':7}, expected_data = [
{'a':5, 'b':5, 'c':5, 'x':7}, {"a": 1, "b": 2, "c": 3, "x": 7},
{'a':1, 'b':10, 'c':100, 'x': 7}] {"a": 5, "b": 5, "c": 5, "x": 7},
{"a": 1, "b": 10, "c": 100, "x": 7},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_adder_iterable(self): def test_field_adder_iterable(self):
fa = FieldAdder('x', [1,2,3]) fa = FieldAdder("x", [1, 2, 3])
expected_data = [{'a':1, 'b':2, 'c':3, 'x':1}, expected_data = [
{'a':5, 'b':5, 'c':5, 'x':2}, {"a": 1, "b": 2, "c": 3, "x": 1},
{'a':1, 'b':10, 'c':100, 'x': 3}] {"a": 5, "b": 5, "c": 5, "x": 2},
{"a": 1, "b": 10, "c": 100, "x": 3},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_adder_replace(self): def test_field_adder_replace(self):
fa = FieldAdder('b', lambda: 7) fa = FieldAdder("b", lambda: 7)
expected_data = [{'a':1, 'b':7, 'c':3}, expected_data = [
{'a':5, 'b':7, 'c':5}, {"a": 1, "b": 7, "c": 3},
{'a':1, 'b':7, 'c':100}] {"a": 5, "b": 7, "c": 5},
{"a": 1, "b": 7, "c": 100},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_adder_no_replace(self): def test_field_adder_no_replace(self):
fa = FieldAdder('b', lambda: 7, replace=False) fa = FieldAdder("b", lambda: 7, replace=False)
expected_data = [{'a':1, 'b':2, 'c':3}, expected_data = [
{'a':5, 'b':5, 'c':5}, {"a": 1, "b": 2, "c": 3},
{'a':1, 'b':10, 'c':100}] {"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
self.assert_filter_result(fa, expected_data) self.assert_filter_result(fa, expected_data)
def test_field_copier(self): def test_field_copier(self):
fc = FieldCopier({'a2':'a', 'b2':'b'}) fc = FieldCopier({"a2": "a", "b2": "b"})
expected_data = [{'a':1, 'b':2, 'c':3, 'a2':1, 'b2':2}, expected_data = [
{'a':5, 'b':5, 'c':5, 'a2':5, 'b2':5}, {"a": 1, "b": 2, "c": 3, "a2": 1, "b2": 2},
{'a':1, 'b':10, 'c':100, 'a2': 1, 'b2': 10}] {"a": 5, "b": 5, "c": 5, "a2": 5, "b2": 5},
{"a": 1, "b": 10, "c": 100, "a2": 1, "b2": 10},
]
self.assert_filter_result(fc, expected_data) self.assert_filter_result(fc, expected_data)
def test_field_renamer(self): def test_field_renamer(self):
fr = FieldRenamer({'x':'a', 'y':'b'}) fr = FieldRenamer({"x": "a", "y": "b"})
expected_data = [{'x':1, 'y':2, 'c':3}, expected_data = [
{'x':5, 'y':5, 'c':5}, {"x": 1, "y": 2, "c": 3},
{'x':1, 'y':10, 'c':100}] {"x": 5, "y": 5, "c": 5},
{"x": 1, "y": 10, "c": 100},
]
self.assert_filter_result(fr, expected_data) self.assert_filter_result(fr, expected_data)
# TODO: splitter & flattner tests? # TODO: splitter & flattner tests?
def test_unique_filter(self): def test_unique_filter(self):
u = Unique() u = Unique()
in_data = [{'a': 77}, {'a':33}, {'a': 77}] in_data = [{"a": 77}, {"a": 33}, {"a": 77}]
expected_data = [{'a': 77}, {'a':33}] expected_data = [{"a": 77}, {"a": 33}]
result = u.attach(in_data) result = u.attach(in_data)
self.assertEqual(list(result), expected_data) self.assertEqual(list(result), expected_data)
# TODO: unicode & string filter tests # TODO: unicode & string filter tests
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -22,11 +22,11 @@ class RecipeTestCase(unittest.TestCase):
def test_error_stream(self): def test_error_stream(self):
saver = Saver() saver = Saver()
recipe = Recipe(Raiser(), error_stream=saver) recipe = Recipe(Raiser(), error_stream=saver)
recipe.run([{'a': 1}, {'b': 2}]) recipe.run([{"a": 1}, {"b": 2}])
recipe.done() recipe.done()
self.assertEqual(saver.saved[0]['record'], {'a': 1}) self.assertEqual(saver.saved[0]["record"], {"a": 1})
self.assertEqual(saver.saved[1]['record'], {'b': 2}) self.assertEqual(saver.saved[1]["record"], {"b": 2})
# Must pass either a Recipe, a Filter or an iterable of Filters # Must pass either a Recipe, a Filter or an iterable of Filters
# as the error_stream argument # as the error_stream argument
@ -49,5 +49,5 @@ class RecipeTestCase(unittest.TestCase):
self.assertEqual(saver.saved, [1]) self.assertEqual(saver.saved, [1])
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -3,46 +3,56 @@ from io import BytesIO, StringIO
import unittest import unittest
from saucebrush.sources import ( from saucebrush.sources import (
CSVSource, FixedWidthFileSource, HtmlTableSource, JSONSource) CSVSource,
FixedWidthFileSource,
HtmlTableSource,
JSONSource,
)
class SourceTestCase(unittest.TestCase): class SourceTestCase(unittest.TestCase):
def _get_csv(self): def _get_csv(self):
data = '''a,b,c data = """a,b,c
1,2,3 1,2,3
5,5,5 5,5,5
1,10,100''' 1,10,100"""
return StringIO(data) return StringIO(data)
def test_csv_source_basic(self): def test_csv_source_basic(self):
source = CSVSource(self._get_csv()) source = CSVSource(self._get_csv())
expected_data = [{'a':'1', 'b':'2', 'c':'3'}, expected_data = [
{'a':'5', 'b':'5', 'c':'5'}, {"a": "1", "b": "2", "c": "3"},
{'a':'1', 'b':'10', 'c':'100'}] {"a": "5", "b": "5", "c": "5"},
{"a": "1", "b": "10", "c": "100"},
]
self.assertEqual(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_fieldnames(self): def test_csv_source_fieldnames(self):
source = CSVSource(self._get_csv(), ['x','y','z']) source = CSVSource(self._get_csv(), ["x", "y", "z"])
expected_data = [{'x':'a', 'y':'b', 'z':'c'}, expected_data = [
{'x':'1', 'y':'2', 'z':'3'}, {"x": "a", "y": "b", "z": "c"},
{'x':'5', 'y':'5', 'z':'5'}, {"x": "1", "y": "2", "z": "3"},
{'x':'1', 'y':'10', 'z':'100'}] {"x": "5", "y": "5", "z": "5"},
{"x": "1", "y": "10", "z": "100"},
]
self.assertEqual(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_skiprows(self): def test_csv_source_skiprows(self):
source = CSVSource(self._get_csv(), skiprows=1) source = CSVSource(self._get_csv(), skiprows=1)
expected_data = [{'a':'5', 'b':'5', 'c':'5'}, expected_data = [
{'a':'1', 'b':'10', 'c':'100'}] {"a": "5", "b": "5", "c": "5"},
{"a": "1", "b": "10", "c": "100"},
]
self.assertEqual(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_fixed_width_source(self): def test_fixed_width_source(self):
data = StringIO('JamesNovember 3 1986\nTim September151999') data = StringIO("JamesNovember 3 1986\nTim September151999")
fields = (('name',5), ('month',9), ('day',2), ('year',4)) fields = (("name", 5), ("month", 9), ("day", 2), ("year", 4))
source = FixedWidthFileSource(data, fields) source = FixedWidthFileSource(data, fields)
expected_data = [{'name':'James', 'month':'November', 'day':'3', expected_data = [
'year':'1986'}, {"name": "James", "month": "November", "day": "3", "year": "1986"},
{'name':'Tim', 'month':'September', 'day':'15', {"name": "Tim", "month": "September", "day": "15", "year": "1999"},
'year':'1999'}] ]
self.assertEqual(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_json_source(self): def test_json_source(self):
@ -50,11 +60,12 @@ class SourceTestCase(unittest.TestCase):
content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""") content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""")
js = JSONSource(content) js = JSONSource(content)
self.assertEqual(list(js), [{'a': 1, 'b': '2', 'c': 3}]) self.assertEqual(list(js), [{"a": 1, "b": "2", "c": 3}])
def test_html_table_source(self): def test_html_table_source(self):
content = StringIO(""" content = StringIO(
"""
<html> <html>
<table id="thetable"> <table id="thetable">
<tr> <tr>
@ -69,19 +80,21 @@ class SourceTestCase(unittest.TestCase):
</tr> </tr>
</table> </table>
</html> </html>
""") """
)
try: try:
import lxml import lxml
hts = HtmlTableSource(content, 'thetable') hts = HtmlTableSource(content, "thetable")
self.assertEqual(list(hts), [{'a': '1', 'b': '2', 'c': '3'}]) self.assertEqual(list(hts), [{"a": "1", "b": "2", "c": "3"}])
except ImportError: except ImportError:
# Python 2.6 doesn't have skipTest. We'll just suffer without it. # Python 2.6 doesn't have skipTest. We'll just suffer without it.
if hasattr(self, 'skipTest'): if hasattr(self, "skipTest"):
self.skipTest("lxml is not installed") self.skipTest("lxml is not installed")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,41 +1,43 @@
import unittest import unittest
from saucebrush.stats import Sum, Average, Median, MinMax, StandardDeviation, Histogram from saucebrush.stats import Sum, Average, Median, MinMax, StandardDeviation, Histogram
class StatsTestCase(unittest.TestCase):
class StatsTestCase(unittest.TestCase):
def _simple_data(self): def _simple_data(self):
return [{'a':1, 'b':2, 'c':3}, return [
{'a':5, 'b':5, 'c':5}, {"a": 1, "b": 2, "c": 3},
{'a':1, 'b':10, 'c':100}] {"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
def test_sum(self): def test_sum(self):
fltr = Sum('b') fltr = Sum("b")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 17) self.assertEqual(fltr.value(), 17)
def test_average(self): def test_average(self):
fltr = Average('c') fltr = Average("c")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 36.0) self.assertEqual(fltr.value(), 36.0)
def test_median(self): def test_median(self):
# odd number of values # odd number of values
fltr = Median('a') fltr = Median("a")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 1) self.assertEqual(fltr.value(), 1)
# even number of values # even number of values
fltr = Median('a') fltr = Median("a")
list(fltr.attach(self._simple_data()[:2])) list(fltr.attach(self._simple_data()[:2]))
self.assertEqual(fltr.value(), 3) self.assertEqual(fltr.value(), 3)
def test_minmax(self): def test_minmax(self):
fltr = MinMax('b') fltr = MinMax("b")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), (2, 10)) self.assertEqual(fltr.value(), (2, 10))
def test_standard_deviation(self): def test_standard_deviation(self):
fltr = StandardDeviation('c') fltr = StandardDeviation("c")
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.average(), 36.0) self.assertEqual(fltr.average(), 36.0)
self.assertEqual(fltr.median(), 5) self.assertEqual(fltr.median(), 5)
@ -43,10 +45,11 @@ class StatsTestCase(unittest.TestCase):
self.assertEqual(fltr.value(True), (45.2621990922521, 2048.6666666666665)) self.assertEqual(fltr.value(True), (45.2621990922521, 2048.6666666666665))
def test_histogram(self): def test_histogram(self):
fltr = Histogram('a') fltr = Histogram("a")
fltr.label_length = 1 fltr.label_length = 1
list(fltr.attach(self._simple_data())) list(fltr.attach(self._simple_data()))
self.assertEqual(str(fltr), "\n1 **\n5 *\n") self.assertEqual(str(fltr), "\n1 **\n5 *\n")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -10,23 +10,29 @@ except ImportError:
General utilities used within saucebrush that may be useful elsewhere. General utilities used within saucebrush that may be useful elsewhere.
""" """
def get_django_model(dj_settings, app_label, model_name): def get_django_model(dj_settings, app_label, model_name):
""" """
Get a django model given a settings file, app label, and model name. Get a django model given a settings file, app label, and model name.
""" """
from django.conf import settings from django.conf import settings
if not settings.configured: if not settings.configured:
settings.configure(DATABASE_ENGINE=dj_settings.DATABASE_ENGINE, settings.configure(
DATABASE_ENGINE=dj_settings.DATABASE_ENGINE,
DATABASE_NAME=dj_settings.DATABASE_NAME, DATABASE_NAME=dj_settings.DATABASE_NAME,
DATABASE_USER=dj_settings.DATABASE_USER, DATABASE_USER=dj_settings.DATABASE_USER,
DATABASE_PASSWORD=dj_settings.DATABASE_PASSWORD, DATABASE_PASSWORD=dj_settings.DATABASE_PASSWORD,
DATABASE_HOST=dj_settings.DATABASE_HOST, DATABASE_HOST=dj_settings.DATABASE_HOST,
INSTALLED_APPS=dj_settings.INSTALLED_APPS) INSTALLED_APPS=dj_settings.INSTALLED_APPS,
)
from django.db.models import get_model from django.db.models import get_model
return get_model(app_label, model_name) return get_model(app_label, model_name)
def flatten(item, prefix='', separator='_', keys=None):
def flatten(item, prefix="", separator="_", keys=None):
""" """
Flatten nested dictionary into one with its keys concatenated together. Flatten nested dictionary into one with its keys concatenated together.
@ -39,7 +45,7 @@ def flatten(item, prefix='', separator='_', keys=None):
if isinstance(item, dict): if isinstance(item, dict):
# don't prepend a leading _ # don't prepend a leading _
if prefix != '': if prefix != "":
prefix += separator prefix += separator
retval = {} retval = {}
for key, value in item.items(): for key, value in item.items():
@ -53,16 +59,19 @@ def flatten(item, prefix='', separator='_', keys=None):
else: else:
return {prefix: item} return {prefix: item}
def str_or_list(obj): def str_or_list(obj):
if isinstance(obj, str): if isinstance(obj, str):
return [obj] return [obj]
else: else:
return obj return obj
# #
# utility classes # utility classes
# #
class FallbackCounter(collections.defaultdict): class FallbackCounter(collections.defaultdict):
"""Python 2.6 does not have collections.Counter. """Python 2.6 does not have collections.Counter.
This is class that does the basics of what we need from Counter. This is class that does the basics of what we need from Counter.
@ -73,14 +82,14 @@ class FallbackCounter(collections.defaultdict):
def most_common(n=None): def most_common(n=None):
l = sorted(self.items(), l = sorted(self.items(), cmp=lambda x, y: cmp(x[1], y[1]))
cmp=lambda x,y: cmp(x[1], y[1]))
if n is not None: if n is not None:
l = l[:n] l = l[:n]
return l return l
class Files(object): class Files(object):
"""Iterate over multiple files as a single file. Pass the paths of the """Iterate over multiple files as a single file. Pass the paths of the
files as arguments to the class constructor: files as arguments to the class constructor:
@ -111,6 +120,7 @@ class Files(object):
yield line yield line
f.close() f.close()
class RemoteFile(object): class RemoteFile(object):
"""Stream data from a remote file. """Stream data from a remote file.
@ -126,6 +136,7 @@ class RemoteFile(object):
yield line.rstrip() yield line.rstrip()
resp.close() resp.close()
class ZippedFiles(object): class ZippedFiles(object):
"""unpack a zipped collection of files on init. """unpack a zipped collection of files on init.
@ -137,10 +148,12 @@ class ZippedFiles(object):
if using a ZipFile object, make sure to set mode to 'a' or 'w' in order if using a ZipFile object, make sure to set mode to 'a' or 'w' in order
to use the add() function. to use the add() function.
""" """
def __init__(self, zippedfile): def __init__(self, zippedfile):
import zipfile import zipfile
if type(zippedfile) == str: if type(zippedfile) == str:
self._zipfile = zipfile.ZipFile(zippedfile,'a') self._zipfile = zipfile.ZipFile(zippedfile, "a")
else: else:
self._zipfile = zippedfile self._zipfile = zippedfile
self.paths = self._zipfile.namelist() self.paths = self._zipfile.namelist()