Merge branch 'py3'

This commit is contained in:
James Turk 2012-03-13 11:28:03 -07:00
commit 91143a7da0
11 changed files with 279 additions and 201 deletions

View File

@ -3,7 +3,7 @@ from saucebrush.outputs import CSVOutput, DebugOutput
def merge_columns(datasource, mapping, merge_func): def merge_columns(datasource, mapping, merge_func):
for rowdata in datasource: for rowdata in datasource:
for to_col,from_cols in mapping.iteritems(): for to_col,from_cols in mapping.items():
values = [rowdata.pop(col, None) for col in from_cols] values = [rowdata.pop(col, None) for col in from_cols]
rowdata[to_col] = reduce(merge_func, values) rowdata[to_col] = reduce(merge_func, values)
yield rowdata yield rowdata

View File

@ -84,7 +84,7 @@ class FECSource(object):
@staticmethod @staticmethod
def get_form_type(rectype): def get_form_type(rectype):
for type_re, type in FECSource.FORM_MAPPING.iteritems(): for type_re, type in FECSource.FORM_MAPPING.items():
if type_re.match(rectype): if type_re.match(rectype):
return type return type

View File

@ -2,7 +2,7 @@
Saucebrush is a data loading & manipulation framework written in python. Saucebrush is a data loading & manipulation framework written in python.
""" """
import filters, emitters, sources, utils from . import filters, emitters, sources, utils
class SaucebrushError(Exception): class SaucebrushError(Exception):

View File

@ -2,7 +2,7 @@
Saucebrush Emitters are filters that instead of modifying the record, output Saucebrush Emitters are filters that instead of modifying the record, output
it in some manner. it in some manner.
""" """
from __future__ import unicode_literals
from saucebrush.filters import Filter from saucebrush.filters import Filter
class Emitter(Filter): class Emitter(Filter):
@ -50,12 +50,12 @@ class DebugEmitter(Emitter):
self._outfile = outfile self._outfile = outfile
def emit_record(self, record): def emit_record(self, record):
self._outfile.write(str(record) + '\n') self._outfile.write("{0}\n".format(record))
class CountEmitter(Emitter): class CountEmitter(Emitter):
""" Emitter that writes the record count to a file-like object. """ Emitter that writes the record count to a file-like object.
CountEmitter() by default writes to stdout. CountEmitter() by default writes to stdout.
CountEmitter(outfile=open('text', 'w')) would print to a file named test. CountEmitter(outfile=open('text', 'w')) would print to a file named test.
CountEmitter(every=1000000) would write the count every 1,000,000 records. CountEmitter(every=1000000) would write the count every 1,000,000 records.
@ -63,36 +63,36 @@ class CountEmitter(Emitter):
""" """
def __init__(self, every=1000, of=None, outfile=None, format=None): def __init__(self, every=1000, of=None, outfile=None, format=None):
super(CountEmitter, self).__init__() super(CountEmitter, self).__init__()
if not outfile: if not outfile:
import sys import sys
self._outfile = sys.stdout self._outfile = sys.stdout
else: else:
self._outfile = outfile self._outfile = outfile
if format is None: if format is None:
if of is not None: if of is not None:
format = "%(count)s of %(of)s\n" format = "%(count)s of %(of)s\n"
else: else:
format = "%(count)s\n" format = "%(count)s\n"
self._format = format self._format = format
self._every = every self._every = every
self._of = of self._of = of
self.count = 0 self.count = 0
def __str__(self): def format(self):
return self._format % {'count': self.count, 'of': self._of} return self._format % {'count': self.count, 'of': self._of}
def emit_record(self, record): def emit_record(self, record):
self.count += 1 self.count += 1
if self.count % self._every == 0: if self.count % self._every == 0:
self._outfile.write(str(self)) self._outfile.write(self.format())
def done(self): def done(self):
self._outfile.write(str(self)) self._outfile.write(self.format())
class CSVEmitter(Emitter): class CSVEmitter(Emitter):
@ -107,7 +107,8 @@ class CSVEmitter(Emitter):
import csv import csv
self._dictwriter = csv.DictWriter(csvfile, fieldnames) self._dictwriter = csv.DictWriter(csvfile, fieldnames)
# write header row # write header row
self._dictwriter.writerow(dict(zip(fieldnames, fieldnames))) header_row = dict(zip(fieldnames, fieldnames))
self._dictwriter.writerow(header_row)
def emit_record(self, record): def emit_record(self, record):
self._dictwriter.writerow(record) self._dictwriter.writerow(record)
@ -145,8 +146,8 @@ class SqliteEmitter(Emitter):
','.join(record.keys()), ','.join(record.keys()),
qmarks) qmarks)
try: try:
self._cursor.execute(insert, record.values()) self._cursor.execute(insert, list(record.values()))
except sqlite3.IntegrityError, ie: except sqlite3.IntegrityError as ie:
if not self._quiet: if not self._quiet:
raise ie raise ie
self.reject_record(record, ie.message) self.reject_record(record, ie.message)
@ -179,21 +180,25 @@ class SqlDumpEmitter(Emitter):
table_name, '`,`'.join(fieldnames)) table_name, '`,`'.join(fieldnames))
def quote(self, item): def quote(self, item):
if item is None: if item is None:
return "null" return "null"
elif isinstance(item, (unicode, str)):
try:
types = (basestring,)
except NameError:
types = (str,)
if isinstance(item, types):
item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0') item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0')
return "'%s'" % item return "'%s'" % item
else:
return "%s" % item return "%s" % item
def emit_record(self, record): def emit_record(self, record):
quoted_data = [self.quote(record[field]) for field in self._fieldnames] quoted_data = [self.quote(record[field]) for field in self._fieldnames]
self._outfile.write(self._insert_str % ','.join(quoted_data)) self._outfile.write(self._insert_str % ','.join(quoted_data))
def done(self):
self._outfile.close()
class DjangoModelEmitter(Emitter): class DjangoModelEmitter(Emitter):
""" Emitter that populates a table corresponding to a django model. """ Emitter that populates a table corresponding to a django model.

View File

@ -217,17 +217,17 @@ class FieldModifier(FieldFilter):
class FieldKeeper(Filter): class FieldKeeper(Filter):
""" Filter that removes all but the given set of fields. """ Filter that removes all but the given set of fields.
FieldKeeper(('spam', 'eggs')) removes all but the spam and eggs FieldKeeper(('spam', 'eggs')) removes all but the spam and eggs
fields from every record filtered. fields from every record filtered.
""" """
def __init__(self, keys): def __init__(self, keys):
super(FieldKeeper, self).__init__() super(FieldKeeper, self).__init__()
self._target_keys = utils.str_or_list(keys) self._target_keys = utils.str_or_list(keys)
def process_record(self, record): def process_record(self, record):
for key in record.keys(): for key in list(record.keys()):
if key not in self._target_keys: if key not in self._target_keys:
del record[key] del record[key]
return record return record
@ -269,7 +269,7 @@ class FieldMerger(Filter):
self._keep_fields = keep_fields self._keep_fields = keep_fields
def process_record(self, record): def process_record(self, record):
for to_col, from_cols in self._field_mapping.iteritems(): for to_col, from_cols in self._field_mapping.items():
if self._keep_fields: if self._keep_fields:
values = [record.get(col, None) for col in from_cols] values = [record.get(col, None) for col in from_cols]
else: else:
@ -301,7 +301,11 @@ class FieldAdder(Filter):
self._field_name = field_name self._field_name = field_name
self._field_value = field_value self._field_value = field_value
if hasattr(self._field_value, '__iter__'): if hasattr(self._field_value, '__iter__'):
self._field_value = iter(self._field_value).next value_iter = iter(self._field_value)
if hasattr(value_iter, "next"):
self._field_value = value_iter.next
else:
self._field_value = value_iter.__next__
self._replace = replace self._replace = replace
def process_record(self, record): def process_record(self, record):
@ -328,7 +332,7 @@ class FieldCopier(Filter):
def process_record(self, record): def process_record(self, record):
# mapping is dest:source # mapping is dest:source
for dest, source in self._copy_mapping.iteritems(): for dest, source in self._copy_mapping.items():
record[dest] = record[source] record[dest] = record[source]
return record return record
@ -343,7 +347,7 @@ class FieldRenamer(Filter):
def process_record(self, record): def process_record(self, record):
# mapping is dest:source # mapping is dest:source
for dest, source in self._rename_mapping.iteritems(): for dest, source in self._rename_mapping.items():
record[dest] = record.pop(source) record[dest] = record.pop(source)
return record return record
@ -363,7 +367,7 @@ class Splitter(Filter):
self._split_mapping = split_mapping self._split_mapping = split_mapping
def process_record(self, record): def process_record(self, record):
for key, filters in self._split_mapping.iteritems(): for key, filters in self._split_mapping.items():
# if the key doesn't exist -- move on to next key # if the key doesn't exist -- move on to next key
try: try:
@ -479,7 +483,7 @@ class UnicodeFilter(Filter):
self._errors = errors self._errors = errors
def process_record(self, record): def process_record(self, record):
for key, value in record.iteritems(): for key, value in record.items():
if isinstance(value, str): if isinstance(value, str):
record[key] = unicode(value, self._encoding, self._errors) record[key] = unicode(value, self._encoding, self._errors)
elif isinstance(value, unicode): elif isinstance(value, unicode):
@ -494,7 +498,7 @@ class StringFilter(Filter):
self._errors = errors self._errors = errors
def process_record(self, record): def process_record(self, record):
for key, value in record.iteritems(): for key, value in record.items():
if isinstance(value, unicode): if isinstance(value, unicode):
record[key] = value.encode(self._encoding, self._errors) record[key] = value.encode(self._encoding, self._errors)
return record return record
@ -584,7 +588,7 @@ class NameCleaner(Filter):
# if there is a match, remove original name and add pieces # if there is a match, remove original name and add pieces
if match: if match:
record.pop(key) record.pop(key)
for k,v in match.groupdict().iteritems(): for k,v in match.groupdict().items():
record[self._name_prefix + k] = v record[self._name_prefix + k] = v
break break

View File

@ -4,8 +4,9 @@
All sources must implement the iterable interface and return python All sources must implement the iterable interface and return python
dictionaries. dictionaries.
""" """
from __future__ import unicode_literals
import string import string
from saucebrush import utils from saucebrush import utils
class CSVSource(object): class CSVSource(object):
@ -25,8 +26,8 @@ class CSVSource(object):
def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs): def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs):
import csv import csv
self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs) self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs)
for _ in xrange(skiprows): for _ in range(skiprows):
self._dictreader.next() next(self._dictreader)
def __iter__(self): def __iter__(self):
return self._dictreader return self._dictreader
@ -59,13 +60,18 @@ class FixedWidthFileSource(object):
def __iter__(self): def __iter__(self):
return self return self
def next(self): def __next__(self):
line = self._fwfile.next() line = next(self._fwfile)
record = {} record = {}
for name, range_ in self._fields_dict.iteritems(): for name, range_ in self._fields_dict.items():
record[name] = line[range_[0]:range_[1]].rstrip(self._fillchars) record[name] = line[range_[0]:range_[1]].rstrip(self._fillchars)
return record return record
def next(self):
""" Keep Python 2 next() method that defers to __next__().
"""
return self.__next__()
class HtmlTableSource(object): class HtmlTableSource(object):
""" Saucebrush source for reading data from an HTML table. """ Saucebrush source for reading data from an HTML table.
@ -86,26 +92,32 @@ class HtmlTableSource(object):
def __init__(self, htmlfile, id_or_num, fieldnames=None, skiprows=0): def __init__(self, htmlfile, id_or_num, fieldnames=None, skiprows=0):
# extract the table # extract the table
from BeautifulSoup import BeautifulSoup from lxml.html import parse
soup = BeautifulSoup(htmlfile.read()) doc = parse(htmlfile).getroot()
if isinstance(id_or_num, int): if isinstance(id_or_num, int):
table = soup.findAll('table')[id_or_num] table = doc.cssselect('table')[id_or_num]
elif isinstance(id_or_num, str): else:
table = soup.find('table', id=id_or_num) table = doc.cssselect('table#%s' % id_or_num)
table = table[0] # get the first table
# skip the necessary number of rows # skip the necessary number of rows
self._rows = table.findAll('tr')[skiprows:] self._rows = table.cssselect('tr')[skiprows:]
# determine the fieldnames # determine the fieldnames
if not fieldnames: if not fieldnames:
self._fieldnames = [td.string self._fieldnames = [td.text_content()
for td in self._rows[0].findAll(('td','th'))] for td in self._rows[0].cssselect('td, th')]
skiprows += 1
else: else:
self._fieldnames = fieldnames self._fieldnames = fieldnames
# skip the necessary number of rows
self._rows = table.cssselect('tr')[skiprows:]
def process_tr(self): def process_tr(self):
for row in self._rows: for row in self._rows:
strings = [utils.string_dig(td) for td in row.findAll('td')] strings = [td.text_content() for td in row.cssselect('td')]
yield dict(zip(self._fieldnames, strings)) yield dict(zip(self._fieldnames, strings))
def __iter__(self): def __iter__(self):
@ -182,7 +194,7 @@ class SqliteSource(object):
self._conn = sqlite3.connect(self._dbpath) self._conn = sqlite3.connect(self._dbpath)
self._conn.row_factory = dict_factory self._conn.row_factory = dict_factory
if self._conn_params: if self._conn_params:
for param, value in self._conn_params.iteritems(): for param, value in self._conn_params.items():
setattr(self._conn, param, value) setattr(self._conn, param, value)
def _process_query(self): def _process_query(self):
@ -214,20 +226,20 @@ class FileSource(object):
def __iter__(self): def __iter__(self):
# This method would be a lot cleaner with the proposed # This method would be a lot cleaner with the proposed
# 'yield from' expression (PEP 380) # 'yield from' expression (PEP 380)
if hasattr(self._input, '__read__'): if hasattr(self._input, '__read__') or hasattr(self._input, 'read'):
for record in self._process_file(input): for record in self._process_file(self._input):
yield record yield record
elif isinstance(self._input, basestring): elif isinstance(self._input, str):
with open(self._input) as f: with open(self._input) as f:
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
elif hasattr(self._input, '__iter__'): elif hasattr(self._input, '__iter__'):
for el in self._input: for el in self._input:
if isinstance(el, basestring): if isinstance(el, str):
with open(el) as f: with open(el) as f:
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
elif hasattr(el, '__read__'): elif hasattr(el, '__read__') or hasattr(el, 'read'):
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
@ -244,10 +256,11 @@ class JSONSource(FileSource):
object. object.
""" """
def _process_file(self, file): def _process_file(self, f):
import json import json
obj = json.load(file) obj = json.load(f)
# If the top-level JSON object in the file is a list # If the top-level JSON object in the file is a list
# then yield each element separately; otherwise, yield # then yield each element separately; otherwise, yield

View File

@ -5,7 +5,7 @@ import math
def _average(values): def _average(values):
""" Calculate the average of a list of values. """ Calculate the average of a list of values.
:param values: an iterable of ints or floats to average :param values: an iterable of ints or floats to average
""" """
value_count = len(values) value_count = len(values)
@ -14,64 +14,64 @@ def _average(values):
def _median(values): def _median(values):
""" Calculate the median of a list of values. """ Calculate the median of a list of values.
:param values: an iterable of ints or floats to calculate :param values: an iterable of ints or floats to calculate
""" """
count = len(values) count = len(values)
# bail early before sorting if 0 or 1 values in list # bail early before sorting if 0 or 1 values in list
if count == 0: if count == 0:
return None return None
elif count == 1: elif count == 1:
return values[0] return values[0]
values = sorted(values) values = sorted(values)
if count % 2 == 1: if count % 2 == 1:
# odd number of items, return middle value # odd number of items, return middle value
return float(values[count / 2]) return float(values[int(count / 2)])
else: else:
# even number of items, return average of middle two items # even number of items, return average of middle two items
mid = count / 2 mid = int(count / 2)
return sum(values[mid - 1:mid + 1]) / 2.0 return sum(values[mid - 1:mid + 1]) / 2.0
def _stddev(values, population=False): def _stddev(values, population=False):
""" Calculate the standard deviation and variance of a list of values. """ Calculate the standard deviation and variance of a list of values.
:param values: an iterable of ints or floats to calculate :param values: an iterable of ints or floats to calculate
:param population: True if values represents entire population, :param population: True if values represents entire population,
False if it is a sample of the population False if it is a sample of the population
""" """
avg = _average(values) avg = _average(values)
count = len(values) if population else len(values) - 1 count = len(values) if population else len(values) - 1
# square the difference between each value and the average # square the difference between each value and the average
diffsq = ((i - avg) ** 2 for i in values) diffsq = ((i - avg) ** 2 for i in values)
# the average of the squared differences # the average of the squared differences
variance = sum(diffsq) / float(count) variance = sum(diffsq) / float(count)
return (math.sqrt(variance), variance) # stddev is sqrt of variance return (math.sqrt(variance), variance) # stddev is sqrt of variance
class StatsFilter(Filter): class StatsFilter(Filter):
""" Base for all stats filters. """ Base for all stats filters.
""" """
def __init__(self, field, test=None): def __init__(self, field, test=None):
self._field = field self._field = field
self._test = test self._test = test
def process_record(self, record): def process_record(self, record):
if self._test is None or self._test(record): if self._test is None or self._test(record):
self.process_field(record[self._field]) self.process_field(record[self._field])
return record return record
def process_field(self, record): def process_field(self, record):
raise NotImplementedError('process_field not defined in ' + raise NotImplementedError('process_field not defined in ' +
self.__class__.__name__) self.__class__.__name__)
def value(self): def value(self):
raise NotImplementedError('value not defined in ' + raise NotImplementedError('value not defined in ' +
self.__class__.__name__) self.__class__.__name__)
@ -80,14 +80,14 @@ class Sum(StatsFilter):
""" Calculate the sum of the values in a field. Field must contain either """ Calculate the sum of the values in a field. Field must contain either
int or float values. int or float values.
""" """
def __init__(self, field, initial=0, **kwargs): def __init__(self, field, initial=0, **kwargs):
super(Sum, self).__init__(field, **kwargs) super(Sum, self).__init__(field, **kwargs)
self._value = initial self._value = initial
def process_field(self, item): def process_field(self, item):
self._value += item or 0 self._value += item or 0
def value(self): def value(self):
return self._value return self._value
@ -95,35 +95,35 @@ class Average(StatsFilter):
""" Calculate the average (mean) of the values in a field. Field must """ Calculate the average (mean) of the values in a field. Field must
contain either int or float values. contain either int or float values.
""" """
def __init__(self, field, initial=0, **kwargs): def __init__(self, field, initial=0, **kwargs):
super(Average, self).__init__(field, **kwargs) super(Average, self).__init__(field, **kwargs)
self._value = initial self._value = initial
self._count = 0 self._count = 0
def process_field(self, item): def process_field(self, item):
if item is not None: if item is not None:
self._value += item self._value += item
self._count += 1 self._count += 1
def value(self): def value(self):
return self._value / float(self._count) return self._value / float(self._count)
class Median(StatsFilter): class Median(StatsFilter):
""" Calculate the median of the values in a field. Field must contain """ Calculate the median of the values in a field. Field must contain
either int or float values. either int or float values.
**This filter keeps a list of field values in memory.** **This filter keeps a list of field values in memory.**
""" """
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
super(Median, self).__init__(field, **kwargs) super(Median, self).__init__(field, **kwargs)
self._values = [] self._values = []
def process_field(self, item): def process_field(self, item):
if item is not None: if item is not None:
self._values.append(item) self._values.append(item)
def value(self): def value(self):
return _median(self._values) return _median(self._values)
@ -131,19 +131,19 @@ class MinMax(StatsFilter):
""" Find the minimum and maximum values in a field. Field must contain """ Find the minimum and maximum values in a field. Field must contain
either int or float values. either int or float values.
""" """
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
super(MinMax, self).__init__(field, **kwargs) super(MinMax, self).__init__(field, **kwargs)
self._max = None self._max = None
self._min = None self._min = None
def process_field(self, item): def process_field(self, item):
if item is not None: if item is not None:
if self._max is None or item > self._max: if self._max is None or item > self._max:
self._max = item self._max = item
if self._min is None or item < self._min: if self._min is None or item < self._min:
self._min = item self._min = item
def value(self): def value(self):
return (self._min, self._max) return (self._min, self._max)
@ -156,24 +156,24 @@ class StandardDeviation(StatsFilter):
**This filter keeps a list of field values in memory.** **This filter keeps a list of field values in memory.**
""" """
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
super(StandardDeviation, self).__init__(field, **kwargs) super(StandardDeviation, self).__init__(field, **kwargs)
self._values = [] self._values = []
def process_field(self, item): def process_field(self, item):
if item is not None: if item is not None:
self._values.append(item) self._values.append(item)
def average(self): def average(self):
return _average(self._values) return _average(self._values)
def median(self): def median(self):
return _median(self._values) return _median(self._values)
def value(self, population=False): def value(self, population=False):
""" Return a tuple of (standard_deviation, variance). """ Return a tuple of (standard_deviation, variance).
:param population: True if values represents entire population, :param population: True if values represents entire population,
False if values is a sample. Default: False False if values is a sample. Default: False
""" """
@ -185,34 +185,34 @@ class Histogram(StatsFilter):
generates a basic and limited histogram useful for printing to the generates a basic and limited histogram useful for printing to the
command line. The label_length attribute determines the padding and command line. The label_length attribute determines the padding and
cut-off of the basic histogram labels. cut-off of the basic histogram labels.
**This filter maintains a dict of unique field values in memory.** **This filter maintains a dict of unique field values in memory.**
""" """
label_length = 6 label_length = 6
def __init__(self, field, **kwargs): def __init__(self, field, **kwargs):
super(Histogram, self).__init__(field, **kwargs) super(Histogram, self).__init__(field, **kwargs)
self._counter = collections.Counter() self._counter = collections.Counter()
def process_field(self, item): def process_field(self, item):
self._counter[self.prep_field(item)] += 1 self._counter[self.prep_field(item)] += 1
def prep_field(self, item): def prep_field(self, item):
return item return item
def value(self): def value(self):
return self._counter.copy() return self._counter.copy()
def in_order(self): def in_order(self):
ordered = [] ordered = []
for key in sorted(self._counter.keys()): for key in sorted(self._counter.keys()):
ordered.append((key, self._counter[key])) ordered.append((key, self._counter[key]))
return ordered return ordered
def most_common(self, n=None): def most_common(self, n=None):
return self._counter.most_common(n) return self._counter.most_common(n)
@classmethod @classmethod
def as_string(self, occurences, label_length): def as_string(self, occurences, label_length):
output = "\n" output = "\n"
@ -220,6 +220,6 @@ class Histogram(StatsFilter):
key_str = str(key).ljust(label_length)[:label_length] key_str = str(key).ljust(label_length)[:label_length]
output += "%s %s\n" % (key_str, "*" * count) output += "%s %s\n" % (key_str, "*" * count)
return output return output
def __str__(self): def __str__(self):
return Histogram.as_string(self.in_order(), label_length=self.label_length) return Histogram.as_string(self.in_order(), label_length=self.label_length)

View File

@ -1,47 +1,86 @@
from __future__ import unicode_literals
from contextlib import closing
from io import StringIO
import os
import unittest import unittest
from cStringIO import StringIO
from saucebrush.emitters import DebugEmitter, CSVEmitter, CountEmitter from saucebrush.emitters import (
DebugEmitter, CSVEmitter, CountEmitter, SqliteEmitter, SqlDumpEmitter)
class EmitterTestCase(unittest.TestCase): class EmitterTestCase(unittest.TestCase):
def setUp(self):
self.output = StringIO()
def test_debug_emitter(self): def test_debug_emitter(self):
de = DebugEmitter(self.output) with closing(StringIO()) as output:
data = de.attach([1,2,3]) de = DebugEmitter(output)
for _ in data: list(de.attach([1,2,3]))
pass self.assertEqual(output.getvalue(), '1\n2\n3\n')
self.assertEquals(self.output.getvalue(), '1\n2\n3\n')
def test_csv_emitter(self):
ce = CSVEmitter(self.output, ('x','y','z'))
data = ce.attach([{'x':1,'y':2,'z':3}, {'x':5, 'y':5, 'z':5}])
for _ in data:
pass
self.assertEquals(self.output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
def test_count_emitter(self): def test_count_emitter(self):
# values for test # values for test
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22] values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
# test without of parameter with closing(StringIO()) as output:
ce = CountEmitter(every=10, outfile=self.output, format="%(count)s records\n")
list(ce.attach(values)) # test without of parameter
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n') ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n")
ce.done() list(ce.attach(values))
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n22 records\n') self.assertEqual(output.getvalue(), '10 records\n20 records\n')
ce.done()
# reset output self.assertEqual(output.getvalue(), '10 records\n20 records\n22 records\n')
self.output.truncate(0)
with closing(StringIO()) as output:
# test with of parameter
ce = CountEmitter(every=10, outfile=self.output, of=len(values)) # test with of parameter
list(ce.attach(values)) ce = CountEmitter(every=10, outfile=output, of=len(values))
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n') list(ce.attach(values))
ce.done() self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n')
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n') ce.done()
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n')
def test_csv_emitter(self):
try:
import cStringIO # if Python 2.x then use old cStringIO
io = cStringIO.StringIO()
except:
io = StringIO() # if Python 3.x then use StringIO
with closing(io) as output:
ce = CSVEmitter(output, ('x','y','z'))
list(ce.attach([{'x':1, 'y':2, 'z':3}, {'x':5, 'y':5, 'z':5}]))
self.assertEqual(output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
def test_sqlite_emitter(self):
import sqlite3, tempfile
with closing(tempfile.NamedTemporaryFile(suffix='.db')) as f:
db_path = f.name
sle = SqliteEmitter(db_path, 'testtable', fieldnames=('a','b','c'))
list(sle.attach([{'a': '1', 'b': '2', 'c': '3'}]))
sle.done()
with closing(sqlite3.connect(db_path)) as conn:
cur = conn.cursor()
cur.execute("""SELECT a, b, c FROM testtable""")
results = cur.fetchall()
os.unlink(db_path)
self.assertEqual(results, [('1', '2', '3')])
def test_sql_dump_emitter(self):
with closing(StringIO()) as bffr:
sde = SqlDumpEmitter(bffr, 'testtable', ('a', 'b'))
list(sde.attach([{'a': 1, 'b': '2'}]))
sde.done()
self.assertEqual(bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -57,41 +57,42 @@ class FilterTestCase(unittest.TestCase):
def assert_filter_result(self, filter_obj, expected_data): def assert_filter_result(self, filter_obj, expected_data):
result = filter_obj.attach(self._simple_data()) result = filter_obj.attach(self._simple_data())
self.assertEquals(list(result), expected_data) self.assertEqual(list(result), expected_data)
def test_reject_record(self): def test_reject_record(self):
recipe = DummyRecipe() recipe = DummyRecipe()
f = Doubler() f = Doubler()
result = f.attach([1,2,3], recipe=recipe) result = f.attach([1,2,3], recipe=recipe)
result.next() # next has to be called for attach to take effect # next has to be called for attach to take effect
next(result)
f.reject_record('bad', 'this one was bad') f.reject_record('bad', 'this one was bad')
# ensure that the rejection propagated to the recipe # ensure that the rejection propagated to the recipe
self.assertEquals('bad', recipe.rejected_record) self.assertEqual('bad', recipe.rejected_record)
self.assertEquals('this one was bad', recipe.rejected_msg) self.assertEqual('this one was bad', recipe.rejected_msg)
def test_simple_filter(self): def test_simple_filter(self):
df = Doubler() df = Doubler()
result = df.attach([1,2,3]) result = df.attach([1,2,3])
# ensure we got a generator that yields 2,4,6 # ensure we got a generator that yields 2,4,6
self.assertEquals(type(result), types.GeneratorType) self.assertEqual(type(result), types.GeneratorType)
self.assertEquals(list(result), [2,4,6]) self.assertEqual(list(result), [2,4,6])
def test_simple_filter_return_none(self): def test_simple_filter_return_none(self):
cf = OddRemover() cf = OddRemover()
result = cf.attach(range(10)) result = cf.attach(range(10))
# ensure only even numbers remain # ensure only even numbers remain
self.assertEquals(list(result), [0,2,4,6,8]) self.assertEqual(list(result), [0,2,4,6,8])
def test_simple_yield_filter(self): def test_simple_yield_filter(self):
lf = ListFlattener() lf = ListFlattener()
result = lf.attach([[1],[2,3],[4,5,6]]) result = lf.attach([[1],[2,3],[4,5,6]])
# ensure we got a generator that yields 1,2,3,4,5,6 # ensure we got a generator that yields 1,2,3,4,5,6
self.assertEquals(type(result), types.GeneratorType) self.assertEqual(type(result), types.GeneratorType)
self.assertEquals(list(result), [1,2,3,4,5,6]) self.assertEqual(list(result), [1,2,3,4,5,6])
def test_simple_field_filter(self): def test_simple_field_filter(self):
ff = FieldDoubler(['a', 'c']) ff = FieldDoubler(['a', 'c'])
@ -107,7 +108,7 @@ class FilterTestCase(unittest.TestCase):
result = cf.attach(range(10)) result = cf.attach(range(10))
# ensure only even numbers remain # ensure only even numbers remain
self.assertEquals(list(result), [0,2,4,6,8]) self.assertEqual(list(result), [0,2,4,6,8])
### Tests for Subrecord ### Tests for Subrecord
@ -123,7 +124,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a', NonModifyingFieldDoubler('b')) sf = SubrecordFilter('a', NonModifyingFieldDoubler('b'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_deep(self): def test_subrecord_filter_deep(self):
data = [{'a': {'d':[{'b': 2}, {'b': 4}]}}, data = [{'a': {'d':[{'b': 2}, {'b': 4}]}},
@ -137,7 +138,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b')) sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_nonlist(self): def test_subrecord_filter_nonlist(self):
data = [ data = [
@ -155,7 +156,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_list_in_path(self): def test_subrecord_filter_list_in_path(self):
data = [ data = [
@ -173,7 +174,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_conditional_path(self): def test_conditional_path(self):
@ -202,7 +203,7 @@ class FilterTestCase(unittest.TestCase):
def test_field_keeper(self): def test_field_keeper(self):
fk = FieldKeeper(['c']) fk = FieldKeeper(['c'])
# check against expected results # check against expected results
expected_data = [{'c':3}, {'c':5}, {'c':100}] expected_data = [{'c':3}, {'c':5}, {'c':100}]
self.assert_filter_result(fk, expected_data) self.assert_filter_result(fk, expected_data)
@ -295,7 +296,7 @@ class FilterTestCase(unittest.TestCase):
expected_data = [{'a': 77}, {'a':33}] expected_data = [{'a': 77}, {'a':33}]
result = u.attach(in_data) result = u.attach(in_data)
self.assertEquals(list(result), expected_data) self.assertEqual(list(result), expected_data)
# TODO: unicode & string filter tests # TODO: unicode & string filter tests

View File

@ -1,6 +1,9 @@
from __future__ import unicode_literals
from io import BytesIO, StringIO
import unittest import unittest
import cStringIO
from saucebrush.sources import CSVSource, FixedWidthFileSource from saucebrush.sources import (
CSVSource, FixedWidthFileSource, HtmlTableSource, JSONSource)
class SourceTestCase(unittest.TestCase): class SourceTestCase(unittest.TestCase):
@ -9,14 +12,14 @@ class SourceTestCase(unittest.TestCase):
1,2,3 1,2,3
5,5,5 5,5,5
1,10,100''' 1,10,100'''
return cStringIO.StringIO(data) return StringIO(data)
def test_csv_source_basic(self): def test_csv_source_basic(self):
source = CSVSource(self._get_csv()) source = CSVSource(self._get_csv())
expected_data = [{'a':'1', 'b':'2', 'c':'3'}, expected_data = [{'a':'1', 'b':'2', 'c':'3'},
{'a':'5', 'b':'5', 'c':'5'}, {'a':'5', 'b':'5', 'c':'5'},
{'a':'1', 'b':'10', 'c':'100'}] {'a':'1', 'b':'10', 'c':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_fieldnames(self): def test_csv_source_fieldnames(self):
source = CSVSource(self._get_csv(), ['x','y','z']) source = CSVSource(self._get_csv(), ['x','y','z'])
@ -24,35 +27,59 @@ class SourceTestCase(unittest.TestCase):
{'x':'1', 'y':'2', 'z':'3'}, {'x':'1', 'y':'2', 'z':'3'},
{'x':'5', 'y':'5', 'z':'5'}, {'x':'5', 'y':'5', 'z':'5'},
{'x':'1', 'y':'10', 'z':'100'}] {'x':'1', 'y':'10', 'z':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_skiprows(self): def test_csv_source_skiprows(self):
source = CSVSource(self._get_csv(), skiprows=1) source = CSVSource(self._get_csv(), skiprows=1)
expected_data = [{'a':'5', 'b':'5', 'c':'5'}, expected_data = [{'a':'5', 'b':'5', 'c':'5'},
{'a':'1', 'b':'10', 'c':'100'}] {'a':'1', 'b':'10', 'c':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_fixed_width_source(self): def test_fixed_width_source(self):
data = cStringIO.StringIO('JamesNovember 3 1986\nTim September151999') data = StringIO('JamesNovember 3 1986\nTim September151999')
fields = (('name',5), ('month',9), ('day',2), ('year',4)) fields = (('name',5), ('month',9), ('day',2), ('year',4))
source = FixedWidthFileSource(data, fields) source = FixedWidthFileSource(data, fields)
expected_data = [{'name':'James', 'month':'November', 'day':'3', expected_data = [{'name':'James', 'month':'November', 'day':'3',
'year':'1986'}, 'year':'1986'},
{'name':'Tim', 'month':'September', 'day':'15', {'name':'Tim', 'month':'September', 'day':'15',
'year':'1999'}] 'year':'1999'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_fixed_width_source(self): def test_json_source(self):
data = cStringIO.StringIO('JamesNovember.3.1986\nTim..September151999')
fields = (('name',5), ('month',9), ('day',2), ('year',4))
source = FixedWidthFileSource(data, fields, fillchars='.')
expected_data = [{'name':'James', 'month':'November', 'day':'3',
'year':'1986'},
{'name':'Tim', 'month':'September', 'day':'15',
'year':'1999'}]
self.assertEquals(list(source), expected_data)
content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""")
js = JSONSource(content)
self.assertEqual(list(js), [{'a': 1, 'b': '2', 'c': 3}])
def test_html_table_source(self):
content = StringIO("""
<html>
<table id="thetable">
<tr>
<th>a</th>
<th>b</th>
<th>c</th>
</tr>
<tr>
<td>1</td>
<td>2</td>
<td>3</td>
</tr>
</table>
</html>
""")
try:
import lxml
hts = HtmlTableSource(content, 'thetable')
self.assertEqual(list(hts), [{'a': '1', 'b': '2', 'c': '3'}])
except ImportError:
self.skipTest("lxml is not installed")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,5 +1,10 @@
import os import os
import urllib2
try:
from urllib.request import urlopen # attempt py3 first
except ImportError:
from urllib2 import urlopen # fallback to py2
""" """
General utilities used within saucebrush that may be useful elsewhere. General utilities used within saucebrush that may be useful elsewhere.
""" """
@ -20,38 +25,23 @@ def get_django_model(dj_settings, app_label, model_name):
from django.db.models import get_model from django.db.models import get_model
return get_model(app_label, model_name) return get_model(app_label, model_name)
def string_dig(element, separator=''):
"""
Dig into BeautifulSoup HTML elements looking for inner strings.
If element resembled: <p><b>test</b><em>test</em></p>
then string_dig(element, '~') would return test~test
"""
if element.string:
return element.string
else:
return separator.join([string_dig(child)
for child in element.findAll(True)])
def flatten(item, prefix='', separator='_', keys=None): def flatten(item, prefix='', separator='_', keys=None):
""" """
Flatten nested dictionary into one with its keys concatenated together. Flatten nested dictionary into one with its keys concatenated together.
>>> flatten({'a':1, 'b':{'c':2}, 'd':[{'e':{'r':7}}, {'e':5}], >>> flatten({'a':1, 'b':{'c':2}, 'd':[{'e':{'r':7}}, {'e':5}],
'f':{'g':{'h':6}}}) 'f':{'g':{'h':6}}})
{'a': 1, 'b_c': 2, 'd': [{'e_r': 7}, {'e': 5}], 'f_g_h': 6} {'a': 1, 'b_c': 2, 'd': [{'e_r': 7}, {'e': 5}], 'f_g_h': 6}
""" """
# update dictionaries recursively # update dictionaries recursively
if isinstance(item, dict): if isinstance(item, dict):
# don't prepend a leading _ # don't prepend a leading _
if prefix != '': if prefix != '':
prefix += separator prefix += separator
retval = {} retval = {}
for key, value in item.iteritems(): for key, value in item.items():
if (not keys) or (key in keys): if (not keys) or (key in keys):
retval.update(flatten(value, prefix + key, separator, keys)) retval.update(flatten(value, prefix + key, separator, keys))
else: else:
@ -60,9 +50,8 @@ def flatten(item, prefix='', separator='_', keys=None):
#elif isinstance(item, (tuple, list)): #elif isinstance(item, (tuple, list)):
# return {prefix: [flatten(i, prefix, separator, keys) for i in item]} # return {prefix: [flatten(i, prefix, separator, keys) for i in item]}
else: else:
print item, prefix
return {prefix: item} return {prefix: item}
def str_or_list(obj): def str_or_list(obj):
if isinstance(obj, str): if isinstance(obj, str):
return [obj] return [obj]
@ -76,7 +65,7 @@ def str_or_list(obj):
class Files(object): class Files(object):
""" Iterate over multiple files as a single file. Pass the paths of the """ Iterate over multiple files as a single file. Pass the paths of the
files as arguments to the class constructor: files as arguments to the class constructor:
for line in Files('/path/to/file/a', '/path/to/file/b'): for line in Files('/path/to/file/a', '/path/to/file/b'):
pass pass
""" """
@ -105,15 +94,15 @@ class Files(object):
class RemoteFile(object): class RemoteFile(object):
""" Stream data from a remote file. """ Stream data from a remote file.
:param url: URL to remote file :param url: URL to remote file
""" """
def __init__(self, url): def __init__(self, url):
self._url = url self._url = url
def __iter__(self): def __iter__(self):
resp = urllib2.urlopen(self._url) resp = urlopen(self._url)
for line in resp: for line in resp:
yield line.rstrip() yield line.rstrip()
resp.close() resp.close()