James Turk 2022-11-10 21:26:09 -06:00
parent 806b8873ec
commit a7e3fc63b3
12 changed files with 751 additions and 554 deletions


@ -13,39 +13,39 @@ class OvercookedError(Exception):
"""
Exception for trying to operate on a Recipe that has been finished.
"""
pass
class Recipe(object):
def __init__(self, *filter_args, **kwargs):
self.finished = False
self.filters = []
for filter in filter_args:
if hasattr(filter, 'filters'):
if hasattr(filter, "filters"):
self.filters.extend(filter.filters)
else:
self.filters.append(filter)
self.error_stream = kwargs.get('error_stream')
self.error_stream = kwargs.get("error_stream")
if self.error_stream and not isinstance(self.error_stream, Recipe):
if isinstance(self.error_stream, filters.Filter):
self.error_stream = Recipe(self.error_stream)
elif hasattr(self.error_stream, '__iter__'):
elif hasattr(self.error_stream, "__iter__"):
self.error_stream = Recipe(*self.error_stream)
else:
raise SaucebrushError('error_stream must be either a filter'
' or an iterable of filters')
raise SaucebrushError(
"error_stream must be either a filter" " or an iterable of filters"
)
def reject_record(self, record, exception):
if self.error_stream:
self.error_stream.run([{'record': record,
'exception': repr(exception)}])
self.error_stream.run([{"record": record, "exception": repr(exception)}])
def run(self, source):
if self.finished:
raise OvercookedError('run() called on finished recipe')
raise OvercookedError("run() called on finished recipe")
# connect datapath
data = source
@ -58,7 +58,7 @@ class Recipe(object):
def done(self):
if self.finished:
raise OvercookedError('done() called on finished recipe')
raise OvercookedError("done() called on finished recipe")
self.finished = True
@ -74,8 +74,7 @@ class Recipe(object):
def run_recipe(source, *filter_args, **kwargs):
""" Process data, taking it from a source and applying any number of filters
"""
"""Process data, taking it from a source and applying any number of filters"""
r = Recipe(*filter_args, **kwargs)
r.run(source)
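
A minimal sketch of the error_stream contract, mirroring the RecipeTestCase later in this commit (Raiser and Saver are illustrative filters, and the import assumes Recipe is exported from the package root, as the tests use it):

from saucebrush import Recipe
from saucebrush.filters import Filter

class Raiser(Filter):
    # a filter that rejects every record it sees
    def process_record(self, record):
        raise ValueError("boom")

class Saver(Filter):
    # a filter that remembers every record passed to it
    def __init__(self):
        super(Saver, self).__init__()
        self.saved = []

    def process_record(self, record):
        self.saved.append(record)
        return record

saver = Saver()
recipe = Recipe(Raiser(), error_stream=saver)
recipe.run([{"a": 1}])
recipe.done()
# saver.saved[0] == {"record": {"a": 1}, "exception": "ValueError('boom')"}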


@ -2,9 +2,9 @@
Saucebrush Emitters are filters that instead of modifying the record, output
it in some manner.
"""
from __future__ import unicode_literals
from saucebrush.filters import Filter
class Emitter(Filter):
"""ABC for emitters
@ -15,6 +15,7 @@ class Emitter(Filter):
all records are processed (allowing database flushes, or printing of
aggregate data).
"""
def process_record(self, record):
self.emit_record(record)
return record
@ -24,8 +25,9 @@ class Emitter(Filter):
Called with a single record, should "emit" the record unmodified.
"""
raise NotImplementedError('emit_record not defined in ' +
self.__class__.__name__)
raise NotImplementedError(
"emit_record not defined in " + self.__class__.__name__
)
def done(self):
"""No-op Method to be overridden.
@ -41,10 +43,12 @@ class DebugEmitter(Emitter):
DebugEmitter() by default prints to stdout.
DebugEmitter(open('test', 'w')) would print to a file named test
"""
def __init__(self, outfile=None):
super(DebugEmitter, self).__init__()
if not outfile:
import sys
self._outfile = sys.stdout
else:
self._outfile = outfile
@ -68,6 +72,7 @@ class CountEmitter(Emitter):
if not outfile:
import sys
self._outfile = sys.stdout
else:
self._outfile = outfile
@ -84,7 +89,7 @@ class CountEmitter(Emitter):
self.count = 0
def format(self):
return self._format % {'count': self.count, 'of': self._of}
return self._format % {"count": self.count, "of": self._of}
def emit_record(self, record):
self.count += 1
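
A short usage sketch, based on the CountEmitter tests later in this commit (which show that passing of= selects a "%(count)s of %(of)s\n" style template):

import sys
from saucebrush.emitters import CountEmitter

ce = CountEmitter(every=10, of=100, outfile=sys.stdout)
list(ce.attach(range(25)))  # prints "10 of 100" then "20 of 100"
ce.done()                   # prints the final count, "25 of 100"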
@ -105,6 +110,7 @@ class CSVEmitter(Emitter):
def __init__(self, csvfile, fieldnames):
super(CSVEmitter, self).__init__()
import csv
self._dictwriter = csv.DictWriter(csvfile, fieldnames)
# write header row
header_row = dict(zip(fieldnames, fieldnames))
@ -127,24 +133,31 @@ class SqliteEmitter(Emitter):
def __init__(self, dbname, table_name, fieldnames=None, replace=False, quiet=False):
super(SqliteEmitter, self).__init__()
import sqlite3
self._conn = sqlite3.connect(dbname)
self._cursor = self._conn.cursor()
self._table_name = table_name
self._replace = replace
self._quiet = quiet
if fieldnames:
create = "CREATE TABLE IF NOT EXISTS %s (%s)" % (table_name,
', '.join([' '.join((field, 'TEXT')) for field in fieldnames]))
create = "CREATE TABLE IF NOT EXISTS %s (%s)" % (
table_name,
", ".join([" ".join((field, "TEXT")) for field in fieldnames]),
)
self._cursor.execute(create)
def emit_record(self, record):
import sqlite3
# input should be escaped with ? if data isn't trusted
qmarks = ','.join(('?',) * len(record))
insert = 'INSERT OR REPLACE' if self._replace else 'INSERT'
insert = '%s INTO %s (%s) VALUES (%s)' % (insert, self._table_name,
','.join(record.keys()),
qmarks)
qmarks = ",".join(("?",) * len(record))
insert = "INSERT OR REPLACE" if self._replace else "INSERT"
insert = "%s INTO %s (%s) VALUES (%s)" % (
insert,
self._table_name,
",".join(record.keys()),
qmarks,
)
try:
self._cursor.execute(insert, list(record.values()))
except sqlite3.IntegrityError as ie:
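
Given the statement-building code above, emitting a record like {"a": "1", "b": "2"} to a table named testtable (placeholder names; a sketch, not part of this commit) executes:

from saucebrush.emitters import SqliteEmitter

sle = SqliteEmitter("example.db", "testtable", fieldnames=("a", "b"))
list(sle.attach([{"a": "1", "b": "2"}]))
sle.done()
# SQL executed:
#   CREATE TABLE IF NOT EXISTS testtable (a TEXT, b TEXT)
#   INSERT INTO testtable (a,b) VALUES (?,?)   -- with parameters ("1", "2")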
@ -173,11 +186,14 @@ class SqlDumpEmitter(Emitter):
self._fieldnames = fieldnames
if not outfile:
import sys
self._outfile = sys.stderr
else:
self._outfile = outfile
self._insert_str = "INSERT INTO `%s` (`%s`) VALUES (%%s);\n" % (
table_name, '`,`'.join(fieldnames))
table_name,
"`,`".join(fieldnames),
)
def quote(self, item):
@ -190,14 +206,14 @@ class SqlDumpEmitter(Emitter):
types = (str,)
if isinstance(item, types):
item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0')
item = item.replace("\\", "\\\\").replace("'", "\\'").replace(chr(0), "0")
return "'%s'" % item
return "%s" % item
def emit_record(self, record):
quoted_data = [self.quote(record[field]) for field in self._fieldnames]
self._outfile.write(self._insert_str % ','.join(quoted_data))
self._outfile.write(self._insert_str % ",".join(quoted_data))
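
The quote() escaping above produces MySQL-style string literals while letting numbers through bare; a sketch of the resulting dump line (people is a placeholder table name):

from io import StringIO
from saucebrush.emitters import SqlDumpEmitter

out = StringIO()
sde = SqlDumpEmitter(out, "people", ("id", "name"))
list(sde.attach([{"id": 1, "name": "O'Brien"}]))
# out.getvalue() is:
#   INSERT INTO `people` (`id`,`name`) VALUES (1,'O\'Brien');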
class DjangoModelEmitter(Emitter):
@ -210,9 +226,11 @@ class DjangoModelEmitter(Emitter):
records to addressbook.models.friend model using database settings
from settings.py.
"""
def __init__(self, dj_settings, app_label, model_name):
super(DjangoModelEmitter, self).__init__()
from saucebrush.utils import get_django_model
self._dbmodel = get_django_model(dj_settings, app_label, model_name)
if not self._dbmodel:
raise Exception("No such model: %s %s" % (app_label, model_name))
@ -228,13 +246,24 @@ class MongoDBEmitter(Emitter):
be inserted are required parameters. The host and port are optional,
defaulting to 'localhost' and 27017, respectively.
"""
def __init__(self, database, collection, host='localhost', port=27017, drop_collection=False, conn=None):
def __init__(
self,
database,
collection,
host="localhost",
port=27017,
drop_collection=False,
conn=None,
):
super(MongoDBEmitter, self).__init__()
from pymongo.database import Database
if not isinstance(database, Database):
if not conn:
from pymongo.connection import Connection
conn = Connection(host, port)
db = conn[database]
else:
@ -255,6 +284,7 @@ class LoggingEmitter(Emitter):
a format parameter. The resulting message will get logged
at the provided level.
"""
import logging
def __init__(self, logger, msg_template, level=logging.DEBUG):


@ -12,9 +12,10 @@ import re
import time
######################
## Abstract Filters ##
# Abstract Filters #
######################
class Filter(object):
"""ABC for filters that operate on records.
@ -27,11 +28,12 @@ class Filter(object):
Called with a single record, should return modified record.
"""
raise NotImplementedError('process_record not defined in ' +
self.__class__.__name__)
raise NotImplementedError(
"process_record not defined in " + self.__class__.__name__
)
def reject_record(self, record, exception):
recipe = getattr(self, '_recipe')
recipe = getattr(self, "_recipe")
if recipe:
recipe.reject_record(record, exception)
@ -91,11 +93,13 @@ class FieldFilter(Filter):
def process_field(self, item):
"""Given a value, return the value that it should be replaced with."""
raise NotImplementedError('process_field not defined in ' +
self.__class__.__name__)
raise NotImplementedError(
"process_field not defined in " + self.__class__.__name__
)
def __unicode__(self):
return '%s( %s )' % (self.__class__.__name__, str(self._target_keys))
return "%s( %s )" % (self.__class__.__name__, str(self._target_keys))
class ConditionalFilter(YieldFilter):
"""ABC for filters that only pass through records meeting a condition.
@ -120,14 +124,17 @@ class ConditionalFilter(YieldFilter):
def test_record(self, record):
"""Given a record, return True iff it should be passed on"""
raise NotImplementedError('test_record not defined in ' +
self.__class__.__name__)
raise NotImplementedError(
"test_record not defined in " + self.__class__.__name__
)
class ValidationError(Exception):
def __init__(self, record):
super(ValidationError, self).__init__(repr(record))
self.record = record
def _dotted_get(d, path):
"""
utility function for SubrecordFilter
@ -135,15 +142,16 @@ def _dotted_get(d, path):
dives into a complex nested dictionary with paths like a.b.c
"""
if path:
key_pieces = path.split('.', 1)
key_pieces = path.split(".", 1)
piece = d[key_pieces[0]]
if isinstance(piece, (tuple, list)):
return [_dotted_get(i, '.'.join(key_pieces[1:])) for i in piece]
return [_dotted_get(i, ".".join(key_pieces[1:])) for i in piece]
elif isinstance(piece, (dict)):
return _dotted_get(piece, '.'.join(key_pieces[1:]))
return _dotted_get(piece, ".".join(key_pieces[1:]))
else:
return d
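
Two worked examples of the traversal above; note the subtlety that when the leaf value is neither a dict nor a list, the parent dict itself is returned, which is exactly what SubrecordFilter below relies on:

_dotted_get({"a": {"b": {"c": 1}}}, "a.b")      # -> {"c": 1}
_dotted_get({"a": [{"b": 2}, {"b": 4}]}, "a")   # -> [{"b": 2}, {"b": 4}]
_dotted_get({"x": {"y": 5}}, "x.y")             # -> {"y": 5}, the parent of the leaf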
class SubrecordFilter(Filter):
"""Filter that calls another filter on subrecord(s) of a record
@ -152,8 +160,8 @@ class SubrecordFilter(Filter):
"""
def __init__(self, field_path, filter_):
if '.' in field_path:
self.field_path, self.key = field_path.rsplit('.', 1)
if "." in field_path:
self.field_path, self.key = field_path.rsplit(".", 1)
else:
self.field_path = None
self.key = field_path
@ -178,6 +186,7 @@ class SubrecordFilter(Filter):
self.process_subrecord(subrecord_parent)
return record
class ConditionalPathFilter(Filter):
"""Filter that uses a predicate to split input among two filter paths."""
@ -192,10 +201,12 @@ class ConditionalPathFilter(Filter):
else:
return self.false_filter.process_record(record)
#####################
## Generic Filters ##
#####################
class FieldModifier(FieldFilter):
"""Filter that calls a given function on a given set of fields.
@ -211,8 +222,11 @@ class FieldModifier(FieldFilter):
return self._filter_func(item)
def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__,
str(self._target_keys), str(self._filter_func))
return "%s( %s, %s )" % (
self.__class__.__name__,
str(self._target_keys),
str(self._filter_func),
)
class FieldKeeper(Filter):
@ -250,7 +264,7 @@ class FieldRemover(Filter):
return record
def __unicode__(self):
return '%s( %s )' % (self.__class__.__name__, str(self._target_keys))
return "%s( %s )" % (self.__class__.__name__, str(self._target_keys))
class FieldMerger(Filter):
@ -278,9 +292,11 @@ class FieldMerger(Filter):
return record
def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__,
return "%s( %s, %s )" % (
self.__class__.__name__,
str(self._field_mapping),
str(self._merge_func))
str(self._merge_func),
)
class FieldAdder(Filter):
@ -300,7 +316,7 @@ class FieldAdder(Filter):
super(FieldAdder, self).__init__()
self._field_name = field_name
self._field_value = field_value
if hasattr(self._field_value, '__iter__'):
if hasattr(self._field_value, "__iter__"):
value_iter = iter(self._field_value)
if hasattr(value_iter, "next"):
self._field_value = value_iter.next
@ -317,8 +333,12 @@ class FieldAdder(Filter):
return record
def __unicode__(self):
return '%s( %s, %s )' % (self.__class__.__name__, self._field_name,
str(self._field_value))
return "%s( %s, %s )" % (
self.__class__.__name__,
self._field_name,
str(self._field_value),
)
class FieldCopier(Filter):
"""Filter that copies one field to another.
@ -326,6 +346,7 @@ class FieldCopier(Filter):
Takes a dictionary mapping destination keys to source keys.
"""
def __init__(self, copy_mapping):
super(FieldCopier, self).__init__()
self._copy_mapping = copy_mapping
@ -336,11 +357,13 @@ class FieldCopier(Filter):
record[dest] = record[source]
return record
class FieldRenamer(Filter):
"""Filter that renames one field to another.
Takes a dictionary mapping destination keys to source keys.
"""
def __init__(self, rename_mapping):
super(FieldRenamer, self).__init__()
self._rename_mapping = rename_mapping
@ -351,6 +374,7 @@ class FieldRenamer(Filter):
record[dest] = record.pop(source)
return record
class FieldNameModifier(Filter):
"""Filter that calls a given function on a given set of fields.
@ -368,6 +392,7 @@ class FieldNameModifier(Filter):
record[dest] = record.pop(source)
return record
class Splitter(Filter):
"""Filter that splits nested data into different paths.
@ -422,6 +447,7 @@ class Flattener(FieldFilter):
{'addresses': [{'state': 'NC', 'street': '146 shirley drive'},
{'state': 'NY', 'street': '3000 Winton Rd'}]}
"""
def __init__(self, keys):
super(Flattener, self).__init__(keys)
@ -436,7 +462,7 @@ class Flattener(FieldFilter):
class DictFlattener(Filter):
def __init__(self, keys, separator='_'):
def __init__(self, keys, separator="_"):
super(DictFlattener, self).__init__()
self._keys = utils.str_or_list(keys)
self._separator = separator
@ -446,8 +472,7 @@ class DictFlattener(Filter):
class Unique(ConditionalFilter):
""" Filter that ensures that all records passing through are unique.
"""
"""Filter that ensures that all records passing through are unique."""
def __init__(self):
super(Unique, self).__init__()
@ -461,6 +486,7 @@ class Unique(ConditionalFilter):
else:
return False
class UniqueValidator(Unique):
validator = True
@ -472,7 +498,7 @@ class UniqueID(ConditionalFilter):
of a composite ID.
"""
def __init__(self, field='id', *args):
def __init__(self, field="id", *args):
super(UniqueID, self).__init__()
self._seen = set()
self._id_fields = [field]
@ -486,15 +512,15 @@ class UniqueID(ConditionalFilter):
else:
return False
class UniqueIDValidator(UniqueID):
validator = True
class UnicodeFilter(Filter):
""" Convert all str elements in the record to Unicode.
"""
"""Convert all str elements in the record to Unicode."""
def __init__(self, encoding='utf-8', errors='ignore'):
def __init__(self, encoding="utf-8", errors="ignore"):
super(UnicodeFilter, self).__init__()
self._encoding = encoding
self._errors = errors
@ -507,9 +533,9 @@ class UnicodeFilter(Filter):
record[key] = value.decode(self._encoding, self._errors)
return record
class StringFilter(Filter):
def __init__(self, encoding='utf-8', errors='ignore'):
class StringFilter(Filter):
def __init__(self, encoding="utf-8", errors="ignore"):
super(StringFilter, self).__init__()
self._encoding = encoding
self._errors = errors
@ -525,6 +551,7 @@ class StringFilter(Filter):
## Commonly Used Filters ##
###########################
class PhoneNumberCleaner(FieldFilter):
"""Filter that cleans phone numbers to match a given format.
@ -534,10 +561,11 @@ class PhoneNumberCleaner(FieldFilter):
PhoneNumberCleaner( ('phone','fax'), number_format='%s%s%s-%s%s%s-%s%s%s%s')
would format the phone & fax columns to 555-123-4567 format.
"""
def __init__(self, keys, number_format='%s%s%s.%s%s%s.%s%s%s%s'):
def __init__(self, keys, number_format="%s%s%s.%s%s%s.%s%s%s%s"):
super(PhoneNumberCleaner, self).__init__(keys)
self._number_format = number_format
self._num_re = re.compile('\d')
self._num_re = re.compile("\d")
def process_field(self, item):
nums = self._num_re.findall(item)
@ -545,19 +573,21 @@ class PhoneNumberCleaner(FieldFilter):
item = self._number_format % tuple(nums)
return item
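
In action (a sketch): the regex collects the digits and the template reassembles them.

pnc = PhoneNumberCleaner(("phone",), number_format="%s%s%s-%s%s%s-%s%s%s%s")
pnc.process_field("(555) 123-4567")   # -> "555-123-4567"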
class DateCleaner(FieldFilter):
"""Filter that cleans dates to match a given format.
Takes a list of target keys and to and from formats in strftime format.
"""
def __init__(self, keys, from_format, to_format):
super(DateCleaner, self).__init__(keys)
self._from_format = from_format
self._to_format = to_format
def process_field(self, item):
return time.strftime(self._to_format,
time.strptime(item, self._from_format))
return time.strftime(self._to_format, time.strptime(item, self._from_format))
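
For example, converting US-style dates to ISO (a sketch):

dc = DateCleaner(["when"], from_format="%m/%d/%Y", to_format="%Y-%m-%d")
dc.process_field("11/10/2022")   # -> "2022-11-10"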
class NameCleaner(Filter):
"""Filter that splits names into a first, last, and middle name field.
@ -570,20 +600,26 @@ class NameCleaner(Filter):
"""
# first middle? last suffix?
FIRST_LAST = re.compile('''^\s*(?:(?P<firstname>\w+)(?:\.?)
FIRST_LAST = re.compile(
"""^\s*(?:(?P<firstname>\w+)(?:\.?)
\s+(?:(?P<middlename>\w+)\.?\s+)?
(?P<lastname>[A-Za-z'-]+))
(?:\s+(?P<suffix>JR\.?|II|III|IV))?
\s*$''', re.VERBOSE | re.IGNORECASE)
\s*$""",
re.VERBOSE | re.IGNORECASE,
)
# last, first middle? suffix?
LAST_FIRST = re.compile('''^\s*(?:(?P<lastname>[A-Za-z'-]+),
LAST_FIRST = re.compile(
"""^\s*(?:(?P<lastname>[A-Za-z'-]+),
\s+(?P<firstname>\w+)(?:\.?)
(?:\s+(?P<middlename>\w+)\.?)?)
(?:\s+(?P<suffix>JR\.?|II|III|IV))?
\s*$''', re.VERBOSE | re.IGNORECASE)
\s*$""",
re.VERBOSE | re.IGNORECASE,
)
def __init__(self, keys, prefix='', formats=None, nomatch_name=None):
def __init__(self, keys, prefix="", formats=None, nomatch_name=None):
super(NameCleaner, self).__init__()
self._keys = utils.str_or_list(keys)
self._name_prefix = prefix
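
A sketch of what the two patterns capture (the middle initial's period and the suffix are consumed by the optional groups):

NameCleaner.FIRST_LAST.match("James T. Kirk Jr.").groupdict()
# -> {'firstname': 'James', 'middlename': 'T', 'lastname': 'Kirk', 'suffix': 'Jr.'}
NameCleaner.LAST_FIRST.match("Kirk, James T.").groupdict()
# -> {'lastname': 'Kirk', 'firstname': 'James', 'middlename': 'T', 'suffix': None}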


@ -9,6 +9,7 @@ import string
from saucebrush import utils
class CSVSource(object):
"""Saucebrush source for reading from CSV files.
@ -25,6 +26,7 @@ class CSVSource(object):
def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs):
import csv
self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs)
for _ in range(skiprows):
next(self._dictreader)
@ -68,8 +70,7 @@ class FixedWidthFileSource(object):
return record
def next(self):
""" Keep Python 2 next() method that defers to __next__().
"""
"""Keep Python 2 next() method that defers to __next__()."""
return self.__next__()
@ -93,31 +94,33 @@ class HtmlTableSource(object):
# extract the table
from lxml.html import parse
doc = parse(htmlfile).getroot()
if isinstance(id_or_num, int):
table = doc.cssselect('table')[id_or_num]
table = doc.cssselect("table")[id_or_num]
else:
table = doc.cssselect('table#%s' % id_or_num)
table = doc.cssselect("table#%s" % id_or_num)
table = table[0] # get the first table
# skip the necessary number of rows
self._rows = table.cssselect('tr')[skiprows:]
self._rows = table.cssselect("tr")[skiprows:]
# determine the fieldnames
if not fieldnames:
self._fieldnames = [td.text_content()
for td in self._rows[0].cssselect('td, th')]
self._fieldnames = [
td.text_content() for td in self._rows[0].cssselect("td, th")
]
skiprows += 1
else:
self._fieldnames = fieldnames
# skip the necessary number of rows
self._rows = table.cssselect('tr')[skiprows:]
self._rows = table.cssselect("tr")[skiprows:]
def process_tr(self):
for row in self._rows:
strings = [td.text_content() for td in row.cssselect('td')]
strings = [td.text_content() for td in row.cssselect("td")]
yield dict(zip(self._fieldnames, strings))
def __iter__(self):
@ -135,12 +138,12 @@ class DjangoModelSource(object):
friends from the friend model in the phonebook app described in
settings.py.
"""
def __init__(self, dj_settings, app_label, model_name):
dbmodel = utils.get_django_model(dj_settings, app_label, model_name)
# only get values defined in model (no extra fields from custom manager)
self._data = dbmodel.objects.values(*[f.name
for f in dbmodel._meta.fields])
self._data = dbmodel.objects.values(*[f.name for f in dbmodel._meta.fields])
def __iter__(self):
return iter(self._data)
@ -152,9 +155,13 @@ class MongoDBSource(object):
The record dict is populated with records matching the spec
from the specified database and collection.
"""
def __init__(self, database, collection, spec=None, host='localhost', port=27017, conn=None):
def __init__(
self, database, collection, spec=None, host="localhost", port=27017, conn=None
):
if not conn:
from pymongo.connection import Connection
conn = Connection(host, port)
self.collection = conn[database][collection]
self.spec = spec
@ -166,6 +173,7 @@ class MongoDBSource(object):
for doc in self.collection.find(self.spec):
yield dict(doc)
# dict_factory for sqlite source
def dict_factory(cursor, row):
d = {}
@ -173,6 +181,7 @@ def dict_factory(cursor, row):
d[col[0]] = row[idx]
return d
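
dict_factory is meant to be installed as sqlite3's row_factory hook, so each fetched row comes back keyed by column name:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = dict_factory
conn.execute("CREATE TABLE t (a, b)")
conn.execute("INSERT INTO t VALUES (1, 2)")
conn.execute("SELECT * FROM t").fetchone()   # -> {'a': 1, 'b': 2}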
class SqliteSource(object):
"""Source that reads from a sqlite database.
@ -226,26 +235,28 @@ class FileSource(object):
def __iter__(self):
# This method would be a lot cleaner with the proposed
# 'yield from' expression (PEP 380)
if hasattr(self._input, '__read__') or hasattr(self._input, 'read'):
if hasattr(self._input, "__read__") or hasattr(self._input, "read"):
for record in self._process_file(self._input):
yield record
elif isinstance(self._input, str):
with open(self._input) as f:
for record in self._process_file(f):
yield record
elif hasattr(self._input, '__iter__'):
elif hasattr(self._input, "__iter__"):
for el in self._input:
if isinstance(el, str):
with open(el) as f:
for record in self._process_file(f):
yield record
elif hasattr(el, '__read__') or hasattr(el, 'read'):
elif hasattr(el, "__read__") or hasattr(el, "read"):
for record in self._process_file(f):
yield record
def _process_file(self, file):
raise NotImplementedError('Descendants of FileSource should implement'
' a custom _process_file method.')
raise NotImplementedError(
"Descendants of FileSource should implement"
" a custom _process_file method."
)
class JSONSource(FileSource):
@ -271,6 +282,7 @@ class JSONSource(FileSource):
else:
yield obj
class XMLSource(FileSource):
"""Source for reading from XML files. Use with the same kind of caution
that you use to approach anything written in XML.
@ -281,14 +293,13 @@ class XMLSource(FileSource):
almost never going to be useful at the top level.
"""
def __init__(self, input, node_path=None, attr_prefix='ATTR_',
postprocessor=None):
def __init__(self, input, node_path=None, attr_prefix="ATTR_", postprocessor=None):
super(XMLSource, self).__init__(input)
self.node_list = node_path.split('.')
self.node_list = node_path.split(".")
self.attr_prefix = attr_prefix
self.postprocessor = postprocessor
def _process_file(self, f, attr_prefix='ATTR_'):
def _process_file(self, f, attr_prefix="ATTR_"):
"""xmltodict can either return attributes of nodes as prefixed fields
(prefixes to avoid key collisions), or ignore them altogether.
@ -299,8 +310,9 @@ class XMLSource(FileSource):
import xmltodict
if self.postprocessor:
obj = xmltodict.parse(f, attr_prefix=self.attr_prefix,
postprocessor=self.postprocessor)
obj = xmltodict.parse(
f, attr_prefix=self.attr_prefix, postprocessor=self.postprocessor
)
else:
obj = xmltodict.parse(f, attr_prefix=self.attr_prefix)
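
A sketch of the difference: xmltodict's default attribute prefix is '@', which this source swaps for ATTR_ so keys stay identifier-friendly.

import xmltodict

xmltodict.parse('<row id="7"><name>x</name></row>', attr_prefix="ATTR_")
# -> {'row': {'ATTR_id': '7', 'name': 'x'}}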


@ -1,9 +1,9 @@
from saucebrush.filters import Filter
from saucebrush.utils import FallbackCounter
import collections
import itertools
import math
def _average(values):
"""Calculate the average of a list of values.
@ -13,6 +13,7 @@ def _average(values):
if len(values) > 0:
return sum(values) / float(value_count)
def _median(values):
"""Calculate the median of a list of values.
@ -37,6 +38,7 @@ def _median(values):
mid = int(count / 2)
return sum(values[mid - 1 : mid + 1]) / 2.0
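
Worked examples of the two branches (assuming the elided lines sort the values and return the middle element for an odd count):

_median([1, 2, 3, 4])   # count=4, mid=2 -> (2 + 3) / 2.0 == 2.5
_median([5, 1, 1])      # odd count -> middle of sorted values, 1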
def _stddev(values, population=False):
"""Calculate the standard deviation and variance of a list of values.
@ -56,9 +58,9 @@ def _stddev(values, population=False):
return (math.sqrt(variance), variance) # stddev is sqrt of variance
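
A worked example of the return value above, using the classic population [2, 4, 4, 4, 5, 5, 7, 9]: the mean is 5 and the squared deviations sum to 32, so the population variance is 32/8 = 4; the sample variance (the default) divides by n - 1 instead, giving 32/7.

_stddev([2, 4, 4, 4, 5, 5, 7, 9], population=True)   # -> (2.0, 4.0)
_stddev([2, 4, 4, 4, 5, 5, 7, 9])                    # -> (~2.138, 32/7)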
class StatsFilter(Filter):
""" Base for all stats filters.
"""
"""Base for all stats filters."""
def __init__(self, field, test=None):
self._field = field
@ -70,12 +72,13 @@ class StatsFilter(Filter):
return record
def process_field(self, record):
raise NotImplementedError('process_field not defined in ' +
self.__class__.__name__)
raise NotImplementedError(
"process_field not defined in " + self.__class__.__name__
)
def value(self):
raise NotImplementedError('value not defined in ' +
self.__class__.__name__)
raise NotImplementedError("value not defined in " + self.__class__.__name__)
class Sum(StatsFilter):
"""Calculate the sum of the values in a field. Field must contain either
@ -92,6 +95,7 @@ class Sum(StatsFilter):
def value(self):
return self._value
class Average(StatsFilter):
"""Calculate the average (mean) of the values in a field. Field must
contain either int or float values.
@ -110,6 +114,7 @@ class Average(StatsFilter):
def value(self):
return self._value / float(self._count)
class Median(StatsFilter):
"""Calculate the median of the values in a field. Field must contain
either int or float values.
@ -128,6 +133,7 @@ class Median(StatsFilter):
def value(self):
return _median(self._values)
class MinMax(StatsFilter):
"""Find the minimum and maximum values in a field. Field must contain
either int or float values.
@ -148,6 +154,7 @@ class MinMax(StatsFilter):
def value(self):
return (self._min, self._max)
class StandardDeviation(StatsFilter):
"""Calculate the standard deviation of the values in a field. Calling
value() will return a standard deviation for the sample. Pass
@ -180,6 +187,7 @@ class StandardDeviation(StatsFilter):
"""
return _stddev(self._values, population)
class Histogram(StatsFilter):
"""Generate a basic histogram of the specified field. The value() method
returns a dict of value to occurrence count mappings. The __str__ method
@ -194,7 +202,7 @@ class Histogram(StatsFilter):
def __init__(self, field, **kwargs):
super(Histogram, self).__init__(field, **kwargs)
if hasattr(collections, 'Counter'):
if hasattr(collections, "Counter"):
self._counter = collections.Counter()
else:
self._counter = FallbackCounter()


@ -11,5 +11,5 @@ emitter_suite = unittest.TestLoader().loadTestsFromTestCase(EmitterTestCase)
recipe_suite = unittest.TestLoader().loadTestsFromTestCase(RecipeTestCase)
stats_suite = unittest.TestLoader().loadTestsFromTestCase(StatsTestCase)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()


@ -5,61 +5,90 @@ import os
import unittest
from saucebrush.emitters import (
DebugEmitter, CSVEmitter, CountEmitter, SqliteEmitter, SqlDumpEmitter)
DebugEmitter,
CSVEmitter,
CountEmitter,
SqliteEmitter,
SqlDumpEmitter,
)
class EmitterTestCase(unittest.TestCase):
def test_debug_emitter(self):
with closing(StringIO()) as output:
de = DebugEmitter(output)
list(de.attach([1, 2, 3]))
self.assertEqual(output.getvalue(), '1\n2\n3\n')
self.assertEqual(output.getvalue(), "1\n2\n3\n")
def test_count_emitter(self):
# values for test
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
values = [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
]
with closing(StringIO()) as output:
# test without of parameter
ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n")
list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 records\n20 records\n')
self.assertEqual(output.getvalue(), "10 records\n20 records\n")
ce.done()
self.assertEqual(output.getvalue(), '10 records\n20 records\n22 records\n')
self.assertEqual(output.getvalue(), "10 records\n20 records\n22 records\n")
with closing(StringIO()) as output:
# test with of parameter
ce = CountEmitter(every=10, outfile=output, of=len(values))
list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n')
self.assertEqual(output.getvalue(), "10 of 22\n20 of 22\n")
ce.done()
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n')
self.assertEqual(output.getvalue(), "10 of 22\n20 of 22\n22 of 22\n")
def test_csv_emitter(self):
try:
import cStringIO # if Python 2.x then use old cStringIO
io = cStringIO.StringIO()
except:
io = StringIO() # if Python 3.x then use StringIO
with closing(io) as output:
ce = CSVEmitter(output, ('x','y','z'))
list(ce.attach([{'x':1, 'y':2, 'z':3}, {'x':5, 'y':5, 'z':5}]))
self.assertEqual(output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
ce = CSVEmitter(output, ("x", "y", "z"))
list(ce.attach([{"x": 1, "y": 2, "z": 3}, {"x": 5, "y": 5, "z": 5}]))
self.assertEqual(output.getvalue(), "x,y,z\r\n1,2,3\r\n5,5,5\r\n")
def test_sqlite_emitter(self):
import sqlite3, tempfile
with closing(tempfile.NamedTemporaryFile(suffix='.db')) as f:
with closing(tempfile.NamedTemporaryFile(suffix=".db")) as f:
db_path = f.name
sle = SqliteEmitter(db_path, 'testtable', fieldnames=('a','b','c'))
list(sle.attach([{'a': '1', 'b': '2', 'c': '3'}]))
sle = SqliteEmitter(db_path, "testtable", fieldnames=("a", "b", "c"))
list(sle.attach([{"a": "1", "b": "2", "c": "3"}]))
sle.done()
with closing(sqlite3.connect(db_path)) as conn:
@ -69,18 +98,20 @@ class EmitterTestCase(unittest.TestCase):
os.unlink(db_path)
self.assertEqual(results, [('1', '2', '3')])
self.assertEqual(results, [("1", "2", "3")])
def test_sql_dump_emitter(self):
with closing(StringIO()) as bffr:
sde = SqlDumpEmitter(bffr, 'testtable', ('a', 'b'))
list(sde.attach([{'a': 1, 'b': '2'}]))
sde = SqlDumpEmitter(bffr, "testtable", ("a", "b"))
list(sde.attach([{"a": 1, "b": "2"}]))
sde.done()
self.assertEqual(bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n")
self.assertEqual(
bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n"
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()


@ -1,23 +1,38 @@
import unittest
import operator
import types
from saucebrush.filters import (Filter, YieldFilter, FieldFilter,
SubrecordFilter, ConditionalPathFilter,
ConditionalFilter, FieldModifier, FieldKeeper,
FieldRemover, FieldMerger, FieldAdder,
FieldCopier, FieldRenamer, Unique)
from saucebrush.filters import (
Filter,
YieldFilter,
FieldFilter,
SubrecordFilter,
ConditionalPathFilter,
ConditionalFilter,
FieldModifier,
FieldKeeper,
FieldRemover,
FieldMerger,
FieldAdder,
FieldCopier,
FieldRenamer,
Unique,
)
class DummyRecipe(object):
rejected_record = None
rejected_msg = None
def reject_record(self, record, msg):
self.rejected_record = record
self.rejected_msg = msg
class Doubler(Filter):
def process_record(self, record):
return record * 2
class OddRemover(Filter):
def process_record(self, record):
if record % 2 == 0:
@ -25,15 +40,18 @@ class OddRemover(Filter):
else:
return None # explicitly return None
class ListFlattener(YieldFilter):
def process_record(self, record):
for item in record:
yield item
class FieldDoubler(FieldFilter):
def process_field(self, item):
return item * 2
class NonModifyingFieldDoubler(Filter):
def __init__(self, key):
self.key = key
@ -43,17 +61,20 @@ class NonModifyingFieldDoubler(Filter):
record[self.key] *= 2
return record
class ConditionalOddRemover(ConditionalFilter):
def test_record(self, record):
# return True for even values
return record % 2 == 0
class FilterTestCase(unittest.TestCase):
class FilterTestCase(unittest.TestCase):
def _simple_data(self):
return [{'a':1, 'b':2, 'c':3},
{'a':5, 'b':5, 'c':5},
{'a':1, 'b':10, 'c':100}]
return [
{"a": 1, "b": 2, "c": 3},
{"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
def assert_filter_result(self, filter_obj, expected_data):
result = filter_obj.attach(self._simple_data())
@ -65,11 +86,11 @@ class FilterTestCase(unittest.TestCase):
result = f.attach([1, 2, 3], recipe=recipe)
# next has to be called for attach to take effect
next(result)
f.reject_record('bad', 'this one was bad')
f.reject_record("bad", "this one was bad")
# ensure that the rejection propagated to the recipe
self.assertEqual('bad', recipe.rejected_record)
self.assertEqual('this one was bad', recipe.rejected_msg)
self.assertEqual("bad", recipe.rejected_record)
self.assertEqual("this one was bad", recipe.rejected_msg)
def test_simple_filter(self):
df = Doubler()
@ -95,12 +116,14 @@ class FilterTestCase(unittest.TestCase):
self.assertEqual(list(result), [1, 2, 3, 4, 5, 6])
def test_simple_field_filter(self):
ff = FieldDoubler(['a', 'c'])
ff = FieldDoubler(["a", "c"])
# check against expected data
expected_data = [{'a':2, 'b':2, 'c':6},
{'a':10, 'b':5, 'c':10},
{'a':2, 'b':10, 'c':200}]
expected_data = [
{"a": 2, "b": 2, "c": 6},
{"a": 10, "b": 5, "c": 10},
{"a": 2, "b": 10, "c": 200},
]
self.assert_filter_result(ff, expected_data)
def test_conditional_filter(self):
@ -113,79 +136,88 @@ class FilterTestCase(unittest.TestCase):
### Tests for Subrecord
def test_subrecord_filter_list(self):
data = [{'a': [{'b': 2}, {'b': 4}]},
{'a': [{'b': 5}]},
{'a': [{'b': 8}, {'b':2}, {'b':1}]}]
data = [
{"a": [{"b": 2}, {"b": 4}]},
{"a": [{"b": 5}]},
{"a": [{"b": 8}, {"b": 2}, {"b": 1}]},
]
expected = [{'a': [{'b': 4}, {'b': 8}]},
{'a': [{'b': 10}]},
{'a': [{'b': 16}, {'b':4}, {'b':2}]}]
expected = [
{"a": [{"b": 4}, {"b": 8}]},
{"a": [{"b": 10}]},
{"a": [{"b": 16}, {"b": 4}, {"b": 2}]},
]
sf = SubrecordFilter('a', NonModifyingFieldDoubler('b'))
sf = SubrecordFilter("a", NonModifyingFieldDoubler("b"))
result = sf.attach(data)
self.assertEqual(list(result), expected)
def test_subrecord_filter_deep(self):
data = [{'a': {'d':[{'b': 2}, {'b': 4}]}},
{'a': {'d':[{'b': 5}]}},
{'a': {'d':[{'b': 8}, {'b':2}, {'b':1}]}}]
data = [
{"a": {"d": [{"b": 2}, {"b": 4}]}},
{"a": {"d": [{"b": 5}]}},
{"a": {"d": [{"b": 8}, {"b": 2}, {"b": 1}]}},
]
expected = [{'a': {'d':[{'b': 4}, {'b': 8}]}},
{'a': {'d':[{'b': 10}]}},
{'a': {'d':[{'b': 16}, {'b':4}, {'b':2}]}}]
expected = [
{"a": {"d": [{"b": 4}, {"b": 8}]}},
{"a": {"d": [{"b": 10}]}},
{"a": {"d": [{"b": 16}, {"b": 4}, {"b": 2}]}},
]
sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b'))
sf = SubrecordFilter("a.d", NonModifyingFieldDoubler("b"))
result = sf.attach(data)
self.assertEqual(list(result), expected)
def test_subrecord_filter_nonlist(self):
data = [
{'a':{'b':{'c':1}}},
{'a':{'b':{'c':2}}},
{'a':{'b':{'c':3}}},
{"a": {"b": {"c": 1}}},
{"a": {"b": {"c": 2}}},
{"a": {"b": {"c": 3}}},
]
expected = [
{'a':{'b':{'c':2}}},
{'a':{'b':{'c':4}}},
{'a':{'b':{'c':6}}},
{"a": {"b": {"c": 2}}},
{"a": {"b": {"c": 4}}},
{"a": {"b": {"c": 6}}},
]
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
sf = SubrecordFilter("a.b", NonModifyingFieldDoubler("c"))
result = sf.attach(data)
self.assertEqual(list(result), expected)
def test_subrecord_filter_list_in_path(self):
data = [
{'a': [{'b': {'c': 5}}, {'b': {'c': 6}}]},
{'a': [{'b': {'c': 1}}, {'b': {'c': 2}}, {'b': {'c': 3}}]},
{'a': [{'b': {'c': 2}} ]}
{"a": [{"b": {"c": 5}}, {"b": {"c": 6}}]},
{"a": [{"b": {"c": 1}}, {"b": {"c": 2}}, {"b": {"c": 3}}]},
{"a": [{"b": {"c": 2}}]},
]
expected = [
{'a': [{'b': {'c': 10}}, {'b': {'c': 12}}]},
{'a': [{'b': {'c': 2}}, {'b': {'c': 4}}, {'b': {'c': 6}}]},
{'a': [{'b': {'c': 4}} ]}
{"a": [{"b": {"c": 10}}, {"b": {"c": 12}}]},
{"a": [{"b": {"c": 2}}, {"b": {"c": 4}}, {"b": {"c": 6}}]},
{"a": [{"b": {"c": 4}}]},
]
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
sf = SubrecordFilter("a.b", NonModifyingFieldDoubler("c"))
result = sf.attach(data)
self.assertEqual(list(result), expected)
def test_conditional_path(self):
predicate = lambda r: r['a'] == 1
predicate = lambda r: r["a"] == 1
# double b if a == 1, otherwise double c
cpf = ConditionalPathFilter(predicate, FieldDoubler('b'),
FieldDoubler('c'))
expected_data = [{'a':1, 'b':4, 'c':3},
{'a':5, 'b':5, 'c':10},
{'a':1, 'b':20, 'c':100}]
cpf = ConditionalPathFilter(predicate, FieldDoubler("b"), FieldDoubler("c"))
expected_data = [
{"a": 1, "b": 4, "c": 3},
{"a": 5, "b": 5, "c": 10},
{"a": 1, "b": 20, "c": 100},
]
self.assert_filter_result(cpf, expected_data)
@ -193,112 +225,132 @@ class FilterTestCase(unittest.TestCase):
def test_field_modifier(self):
# another version of FieldDoubler
fm = FieldModifier(['a', 'c'], lambda x: x*2)
fm = FieldModifier(["a", "c"], lambda x: x * 2)
# check against expected data
expected_data = [{'a':2, 'b':2, 'c':6},
{'a':10, 'b':5, 'c':10},
{'a':2, 'b':10, 'c':200}]
expected_data = [
{"a": 2, "b": 2, "c": 6},
{"a": 10, "b": 5, "c": 10},
{"a": 2, "b": 10, "c": 200},
]
self.assert_filter_result(fm, expected_data)
def test_field_keeper(self):
fk = FieldKeeper(['c'])
fk = FieldKeeper(["c"])
# check against expected results
expected_data = [{'c':3}, {'c':5}, {'c':100}]
expected_data = [{"c": 3}, {"c": 5}, {"c": 100}]
self.assert_filter_result(fk, expected_data)
def test_field_remover(self):
fr = FieldRemover(['a', 'b'])
fr = FieldRemover(["a", "b"])
# check against expected results
expected_data = [{'c':3}, {'c':5}, {'c':100}]
expected_data = [{"c": 3}, {"c": 5}, {"c": 100}]
self.assert_filter_result(fr, expected_data)
def test_field_merger(self):
fm = FieldMerger({'sum':('a','b','c')}, lambda x,y,z: x+y+z)
fm = FieldMerger({"sum": ("a", "b", "c")}, lambda x, y, z: x + y + z)
# check against expected results
expected_data = [{'sum':6}, {'sum':15}, {'sum':111}]
expected_data = [{"sum": 6}, {"sum": 15}, {"sum": 111}]
self.assert_filter_result(fm, expected_data)
def test_field_merger_keep_fields(self):
fm = FieldMerger({'sum':('a','b','c')}, lambda x,y,z: x+y+z,
keep_fields=True)
fm = FieldMerger(
{"sum": ("a", "b", "c")}, lambda x, y, z: x + y + z, keep_fields=True
)
# check against expected results
expected_data = [{'a':1, 'b':2, 'c':3, 'sum':6},
{'a':5, 'b':5, 'c':5, 'sum':15},
{'a':1, 'b':10, 'c':100, 'sum': 111}]
expected_data = [
{"a": 1, "b": 2, "c": 3, "sum": 6},
{"a": 5, "b": 5, "c": 5, "sum": 15},
{"a": 1, "b": 10, "c": 100, "sum": 111},
]
self.assert_filter_result(fm, expected_data)
def test_field_adder_scalar(self):
fa = FieldAdder('x', 7)
fa = FieldAdder("x", 7)
expected_data = [{'a':1, 'b':2, 'c':3, 'x':7},
{'a':5, 'b':5, 'c':5, 'x':7},
{'a':1, 'b':10, 'c':100, 'x': 7}]
expected_data = [
{"a": 1, "b": 2, "c": 3, "x": 7},
{"a": 5, "b": 5, "c": 5, "x": 7},
{"a": 1, "b": 10, "c": 100, "x": 7},
]
self.assert_filter_result(fa, expected_data)
def test_field_adder_callable(self):
fa = FieldAdder('x', lambda: 7)
fa = FieldAdder("x", lambda: 7)
expected_data = [{'a':1, 'b':2, 'c':3, 'x':7},
{'a':5, 'b':5, 'c':5, 'x':7},
{'a':1, 'b':10, 'c':100, 'x': 7}]
expected_data = [
{"a": 1, "b": 2, "c": 3, "x": 7},
{"a": 5, "b": 5, "c": 5, "x": 7},
{"a": 1, "b": 10, "c": 100, "x": 7},
]
self.assert_filter_result(fa, expected_data)
def test_field_adder_iterable(self):
fa = FieldAdder('x', [1,2,3])
fa = FieldAdder("x", [1, 2, 3])
expected_data = [{'a':1, 'b':2, 'c':3, 'x':1},
{'a':5, 'b':5, 'c':5, 'x':2},
{'a':1, 'b':10, 'c':100, 'x': 3}]
expected_data = [
{"a": 1, "b": 2, "c": 3, "x": 1},
{"a": 5, "b": 5, "c": 5, "x": 2},
{"a": 1, "b": 10, "c": 100, "x": 3},
]
self.assert_filter_result(fa, expected_data)
def test_field_adder_replace(self):
fa = FieldAdder('b', lambda: 7)
fa = FieldAdder("b", lambda: 7)
expected_data = [{'a':1, 'b':7, 'c':3},
{'a':5, 'b':7, 'c':5},
{'a':1, 'b':7, 'c':100}]
expected_data = [
{"a": 1, "b": 7, "c": 3},
{"a": 5, "b": 7, "c": 5},
{"a": 1, "b": 7, "c": 100},
]
self.assert_filter_result(fa, expected_data)
def test_field_adder_no_replace(self):
fa = FieldAdder('b', lambda: 7, replace=False)
fa = FieldAdder("b", lambda: 7, replace=False)
expected_data = [{'a':1, 'b':2, 'c':3},
{'a':5, 'b':5, 'c':5},
{'a':1, 'b':10, 'c':100}]
expected_data = [
{"a": 1, "b": 2, "c": 3},
{"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
self.assert_filter_result(fa, expected_data)
def test_field_copier(self):
fc = FieldCopier({'a2':'a', 'b2':'b'})
fc = FieldCopier({"a2": "a", "b2": "b"})
expected_data = [{'a':1, 'b':2, 'c':3, 'a2':1, 'b2':2},
{'a':5, 'b':5, 'c':5, 'a2':5, 'b2':5},
{'a':1, 'b':10, 'c':100, 'a2': 1, 'b2': 10}]
expected_data = [
{"a": 1, "b": 2, "c": 3, "a2": 1, "b2": 2},
{"a": 5, "b": 5, "c": 5, "a2": 5, "b2": 5},
{"a": 1, "b": 10, "c": 100, "a2": 1, "b2": 10},
]
self.assert_filter_result(fc, expected_data)
def test_field_renamer(self):
fr = FieldRenamer({'x':'a', 'y':'b'})
fr = FieldRenamer({"x": "a", "y": "b"})
expected_data = [{'x':1, 'y':2, 'c':3},
{'x':5, 'y':5, 'c':5},
{'x':1, 'y':10, 'c':100}]
expected_data = [
{"x": 1, "y": 2, "c": 3},
{"x": 5, "y": 5, "c": 5},
{"x": 1, "y": 10, "c": 100},
]
self.assert_filter_result(fr, expected_data)
# TODO: splitter & flattener tests?
def test_unique_filter(self):
u = Unique()
in_data = [{'a': 77}, {'a':33}, {'a': 77}]
expected_data = [{'a': 77}, {'a':33}]
in_data = [{"a": 77}, {"a": 33}, {"a": 77}]
expected_data = [{"a": 77}, {"a": 33}]
result = u.attach(in_data)
self.assertEqual(list(result), expected_data)
# TODO: unicode & string filter tests
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()


@ -22,11 +22,11 @@ class RecipeTestCase(unittest.TestCase):
def test_error_stream(self):
saver = Saver()
recipe = Recipe(Raiser(), error_stream=saver)
recipe.run([{'a': 1}, {'b': 2}])
recipe.run([{"a": 1}, {"b": 2}])
recipe.done()
self.assertEqual(saver.saved[0]['record'], {'a': 1})
self.assertEqual(saver.saved[1]['record'], {'b': 2})
self.assertEqual(saver.saved[0]["record"], {"a": 1})
self.assertEqual(saver.saved[1]["record"], {"b": 2})
# Must pass either a Recipe, a Filter or an iterable of Filters
# as the error_stream argument
@ -49,5 +49,5 @@ class RecipeTestCase(unittest.TestCase):
self.assertEqual(saver.saved, [1])
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()


@ -3,46 +3,56 @@ from io import BytesIO, StringIO
import unittest
from saucebrush.sources import (
CSVSource, FixedWidthFileSource, HtmlTableSource, JSONSource)
CSVSource,
FixedWidthFileSource,
HtmlTableSource,
JSONSource,
)
class SourceTestCase(unittest.TestCase):
def _get_csv(self):
data = '''a,b,c
data = """a,b,c
1,2,3
5,5,5
1,10,100'''
1,10,100"""
return StringIO(data)
def test_csv_source_basic(self):
source = CSVSource(self._get_csv())
expected_data = [{'a':'1', 'b':'2', 'c':'3'},
{'a':'5', 'b':'5', 'c':'5'},
{'a':'1', 'b':'10', 'c':'100'}]
expected_data = [
{"a": "1", "b": "2", "c": "3"},
{"a": "5", "b": "5", "c": "5"},
{"a": "1", "b": "10", "c": "100"},
]
self.assertEqual(list(source), expected_data)
def test_csv_source_fieldnames(self):
source = CSVSource(self._get_csv(), ['x','y','z'])
expected_data = [{'x':'a', 'y':'b', 'z':'c'},
{'x':'1', 'y':'2', 'z':'3'},
{'x':'5', 'y':'5', 'z':'5'},
{'x':'1', 'y':'10', 'z':'100'}]
source = CSVSource(self._get_csv(), ["x", "y", "z"])
expected_data = [
{"x": "a", "y": "b", "z": "c"},
{"x": "1", "y": "2", "z": "3"},
{"x": "5", "y": "5", "z": "5"},
{"x": "1", "y": "10", "z": "100"},
]
self.assertEqual(list(source), expected_data)
def test_csv_source_skiprows(self):
source = CSVSource(self._get_csv(), skiprows=1)
expected_data = [{'a':'5', 'b':'5', 'c':'5'},
{'a':'1', 'b':'10', 'c':'100'}]
expected_data = [
{"a": "5", "b": "5", "c": "5"},
{"a": "1", "b": "10", "c": "100"},
]
self.assertEqual(list(source), expected_data)
def test_fixed_width_source(self):
data = StringIO('JamesNovember 3 1986\nTim September151999')
fields = (('name',5), ('month',9), ('day',2), ('year',4))
data = StringIO("JamesNovember 3 1986\nTim September151999")
fields = (("name", 5), ("month", 9), ("day", 2), ("year", 4))
source = FixedWidthFileSource(data, fields)
expected_data = [{'name':'James', 'month':'November', 'day':'3',
'year':'1986'},
{'name':'Tim', 'month':'September', 'day':'15',
'year':'1999'}]
expected_data = [
{"name": "James", "month": "November", "day": "3", "year": "1986"},
{"name": "Tim", "month": "September", "day": "15", "year": "1999"},
]
self.assertEqual(list(source), expected_data)
def test_json_source(self):
@ -50,11 +60,12 @@ class SourceTestCase(unittest.TestCase):
content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""")
js = JSONSource(content)
self.assertEqual(list(js), [{'a': 1, 'b': '2', 'c': 3}])
self.assertEqual(list(js), [{"a": 1, "b": "2", "c": 3}])
def test_html_table_source(self):
content = StringIO("""
content = StringIO(
"""
<html>
<table id="thetable">
<tr>
@ -69,19 +80,21 @@ class SourceTestCase(unittest.TestCase):
</tr>
</table>
</html>
""")
"""
)
try:
import lxml
hts = HtmlTableSource(content, 'thetable')
self.assertEqual(list(hts), [{'a': '1', 'b': '2', 'c': '3'}])
hts = HtmlTableSource(content, "thetable")
self.assertEqual(list(hts), [{"a": "1", "b": "2", "c": "3"}])
except ImportError:
# Python 2.6 doesn't have skipTest. We'll just suffer without it.
if hasattr(self, 'skipTest'):
if hasattr(self, "skipTest"):
self.skipTest("lxml is not installed")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()


@ -1,41 +1,43 @@
import unittest
from saucebrush.stats import Sum, Average, Median, MinMax, StandardDeviation, Histogram
class StatsTestCase(unittest.TestCase):
class StatsTestCase(unittest.TestCase):
def _simple_data(self):
return [{'a':1, 'b':2, 'c':3},
{'a':5, 'b':5, 'c':5},
{'a':1, 'b':10, 'c':100}]
return [
{"a": 1, "b": 2, "c": 3},
{"a": 5, "b": 5, "c": 5},
{"a": 1, "b": 10, "c": 100},
]
def test_sum(self):
fltr = Sum('b')
fltr = Sum("b")
list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 17)
def test_average(self):
fltr = Average('c')
fltr = Average("c")
list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 36.0)
def test_median(self):
# odd number of values
fltr = Median('a')
fltr = Median("a")
list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), 1)
# even number of values
fltr = Median('a')
fltr = Median("a")
list(fltr.attach(self._simple_data()[:2]))
self.assertEqual(fltr.value(), 3)
def test_minmax(self):
fltr = MinMax('b')
fltr = MinMax("b")
list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.value(), (2, 10))
def test_standard_deviation(self):
fltr = StandardDeviation('c')
fltr = StandardDeviation("c")
list(fltr.attach(self._simple_data()))
self.assertEqual(fltr.average(), 36.0)
self.assertEqual(fltr.median(), 5)
@ -43,10 +45,11 @@ class StatsTestCase(unittest.TestCase):
self.assertEqual(fltr.value(True), (45.2621990922521, 2048.6666666666665))
def test_histogram(self):
fltr = Histogram('a')
fltr = Histogram("a")
fltr.label_length = 1
list(fltr.attach(self._simple_data()))
self.assertEqual(str(fltr), "\n1 **\n5 *\n")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()


@ -10,23 +10,29 @@ except ImportError:
General utilities used within saucebrush that may be useful elsewhere.
"""
def get_django_model(dj_settings, app_label, model_name):
"""
Get a django model given a settings file, app label, and model name.
"""
from django.conf import settings
if not settings.configured:
settings.configure(DATABASE_ENGINE=dj_settings.DATABASE_ENGINE,
settings.configure(
DATABASE_ENGINE=dj_settings.DATABASE_ENGINE,
DATABASE_NAME=dj_settings.DATABASE_NAME,
DATABASE_USER=dj_settings.DATABASE_USER,
DATABASE_PASSWORD=dj_settings.DATABASE_PASSWORD,
DATABASE_HOST=dj_settings.DATABASE_HOST,
INSTALLED_APPS=dj_settings.INSTALLED_APPS)
INSTALLED_APPS=dj_settings.INSTALLED_APPS,
)
from django.db.models import get_model
return get_model(app_label, model_name)
def flatten(item, prefix='', separator='_', keys=None):
def flatten(item, prefix="", separator="_", keys=None):
"""
Flatten nested dictionary into one with its keys concatenated together.
@ -39,7 +45,7 @@ def flatten(item, prefix='', separator='_', keys=None):
if isinstance(item, dict):
# don't prepend a leading _
if prefix != '':
if prefix != "":
prefix += separator
retval = {}
for key, value in item.items():
@ -53,16 +59,19 @@ def flatten(item, prefix='', separator='_', keys=None):
else:
return {prefix: item}
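
A sketch of the recursion (assuming the elided loop body recurses with prefix + key, as the prefix handling above implies):

flatten({"a": 1, "b": {"c": 2, "d": {"e": 3}}})
# -> {"a": 1, "b_c": 2, "b_d_e": 3}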
def str_or_list(obj):
if isinstance(obj, str):
return [obj]
else:
return obj
#
# utility classes
#
class FallbackCounter(collections.defaultdict):
"""Python 2.6 does not have collections.Counter.
This is a class that does the basics of what we need from Counter.
@ -73,14 +82,14 @@ class FallbackCounter(collections.defaultdict):
def most_common(n=None):
l = sorted(self.items(),
cmp=lambda x,y: cmp(x[1], y[1]))
l = sorted(self.items(), cmp=lambda x, y: cmp(x[1], y[1]))
if n is not None:
l = l[:n]
return l
class Files(object):
"""Iterate over multiple files as a single file. Pass the paths of the
files as arguments to the class constructor:
@ -111,6 +120,7 @@ class Files(object):
yield line
f.close()
class RemoteFile(object):
"""Stream data from a remote file.
@ -126,6 +136,7 @@ class RemoteFile(object):
yield line.rstrip()
resp.close()
class ZippedFiles(object):
"""unpack a zipped collection of files on init.
@ -137,10 +148,12 @@ class ZippedFiles(object):
if using a ZipFile object, make sure to set mode to 'a' or 'w' in order
to use the add() function.
"""
def __init__(self, zippedfile):
import zipfile
if type(zippedfile) == str:
self._zipfile = zipfile.ZipFile(zippedfile,'a')
self._zipfile = zipfile.ZipFile(zippedfile, "a")
else:
self._zipfile = zippedfile
self.paths = self._zipfile.namelist()