Merge branch 'py3'

This commit is contained in:
James Turk 2012-03-13 11:28:03 -07:00
commit 91143a7da0
11 changed files with 279 additions and 201 deletions

View File

@ -3,7 +3,7 @@ from saucebrush.outputs import CSVOutput, DebugOutput
def merge_columns(datasource, mapping, merge_func): def merge_columns(datasource, mapping, merge_func):
for rowdata in datasource: for rowdata in datasource:
for to_col,from_cols in mapping.iteritems(): for to_col,from_cols in mapping.items():
values = [rowdata.pop(col, None) for col in from_cols] values = [rowdata.pop(col, None) for col in from_cols]
rowdata[to_col] = reduce(merge_func, values) rowdata[to_col] = reduce(merge_func, values)
yield rowdata yield rowdata

View File

@ -84,7 +84,7 @@ class FECSource(object):
@staticmethod @staticmethod
def get_form_type(rectype): def get_form_type(rectype):
for type_re, type in FECSource.FORM_MAPPING.iteritems(): for type_re, type in FECSource.FORM_MAPPING.items():
if type_re.match(rectype): if type_re.match(rectype):
return type return type

View File

@ -2,7 +2,7 @@
Saucebrush is a data loading & manipulation framework written in python. Saucebrush is a data loading & manipulation framework written in python.
""" """
import filters, emitters, sources, utils from . import filters, emitters, sources, utils
class SaucebrushError(Exception): class SaucebrushError(Exception):

View File

@ -2,7 +2,7 @@
Saucebrush Emitters are filters that instead of modifying the record, output Saucebrush Emitters are filters that instead of modifying the record, output
it in some manner. it in some manner.
""" """
from __future__ import unicode_literals
from saucebrush.filters import Filter from saucebrush.filters import Filter
class Emitter(Filter): class Emitter(Filter):
@ -50,7 +50,7 @@ class DebugEmitter(Emitter):
self._outfile = outfile self._outfile = outfile
def emit_record(self, record): def emit_record(self, record):
self._outfile.write(str(record) + '\n') self._outfile.write("{0}\n".format(record))
class CountEmitter(Emitter): class CountEmitter(Emitter):
@ -83,16 +83,16 @@ class CountEmitter(Emitter):
self._of = of self._of = of
self.count = 0 self.count = 0
def __str__(self): def format(self):
return self._format % {'count': self.count, 'of': self._of} return self._format % {'count': self.count, 'of': self._of}
def emit_record(self, record): def emit_record(self, record):
self.count += 1 self.count += 1
if self.count % self._every == 0: if self.count % self._every == 0:
self._outfile.write(str(self)) self._outfile.write(self.format())
def done(self): def done(self):
self._outfile.write(str(self)) self._outfile.write(self.format())
class CSVEmitter(Emitter): class CSVEmitter(Emitter):
@ -107,7 +107,8 @@ class CSVEmitter(Emitter):
import csv import csv
self._dictwriter = csv.DictWriter(csvfile, fieldnames) self._dictwriter = csv.DictWriter(csvfile, fieldnames)
# write header row # write header row
self._dictwriter.writerow(dict(zip(fieldnames, fieldnames))) header_row = dict(zip(fieldnames, fieldnames))
self._dictwriter.writerow(header_row)
def emit_record(self, record): def emit_record(self, record):
self._dictwriter.writerow(record) self._dictwriter.writerow(record)
@ -145,8 +146,8 @@ class SqliteEmitter(Emitter):
','.join(record.keys()), ','.join(record.keys()),
qmarks) qmarks)
try: try:
self._cursor.execute(insert, record.values()) self._cursor.execute(insert, list(record.values()))
except sqlite3.IntegrityError, ie: except sqlite3.IntegrityError as ie:
if not self._quiet: if not self._quiet:
raise ie raise ie
self.reject_record(record, ie.message) self.reject_record(record, ie.message)
@ -179,21 +180,25 @@ class SqlDumpEmitter(Emitter):
table_name, '`,`'.join(fieldnames)) table_name, '`,`'.join(fieldnames))
def quote(self, item): def quote(self, item):
if item is None: if item is None:
return "null" return "null"
elif isinstance(item, (unicode, str)):
try:
types = (basestring,)
except NameError:
types = (str,)
if isinstance(item, types):
item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0') item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0')
return "'%s'" % item return "'%s'" % item
else:
return "%s" % item return "%s" % item
def emit_record(self, record): def emit_record(self, record):
quoted_data = [self.quote(record[field]) for field in self._fieldnames] quoted_data = [self.quote(record[field]) for field in self._fieldnames]
self._outfile.write(self._insert_str % ','.join(quoted_data)) self._outfile.write(self._insert_str % ','.join(quoted_data))
def done(self):
self._outfile.close()
class DjangoModelEmitter(Emitter): class DjangoModelEmitter(Emitter):
""" Emitter that populates a table corresponding to a django model. """ Emitter that populates a table corresponding to a django model.

View File

@ -227,7 +227,7 @@ class FieldKeeper(Filter):
self._target_keys = utils.str_or_list(keys) self._target_keys = utils.str_or_list(keys)
def process_record(self, record): def process_record(self, record):
for key in record.keys(): for key in list(record.keys()):
if key not in self._target_keys: if key not in self._target_keys:
del record[key] del record[key]
return record return record
@ -269,7 +269,7 @@ class FieldMerger(Filter):
self._keep_fields = keep_fields self._keep_fields = keep_fields
def process_record(self, record): def process_record(self, record):
for to_col, from_cols in self._field_mapping.iteritems(): for to_col, from_cols in self._field_mapping.items():
if self._keep_fields: if self._keep_fields:
values = [record.get(col, None) for col in from_cols] values = [record.get(col, None) for col in from_cols]
else: else:
@ -301,7 +301,11 @@ class FieldAdder(Filter):
self._field_name = field_name self._field_name = field_name
self._field_value = field_value self._field_value = field_value
if hasattr(self._field_value, '__iter__'): if hasattr(self._field_value, '__iter__'):
self._field_value = iter(self._field_value).next value_iter = iter(self._field_value)
if hasattr(value_iter, "next"):
self._field_value = value_iter.next
else:
self._field_value = value_iter.__next__
self._replace = replace self._replace = replace
def process_record(self, record): def process_record(self, record):
@ -328,7 +332,7 @@ class FieldCopier(Filter):
def process_record(self, record): def process_record(self, record):
# mapping is dest:source # mapping is dest:source
for dest, source in self._copy_mapping.iteritems(): for dest, source in self._copy_mapping.items():
record[dest] = record[source] record[dest] = record[source]
return record return record
@ -343,7 +347,7 @@ class FieldRenamer(Filter):
def process_record(self, record): def process_record(self, record):
# mapping is dest:source # mapping is dest:source
for dest, source in self._rename_mapping.iteritems(): for dest, source in self._rename_mapping.items():
record[dest] = record.pop(source) record[dest] = record.pop(source)
return record return record
@ -363,7 +367,7 @@ class Splitter(Filter):
self._split_mapping = split_mapping self._split_mapping = split_mapping
def process_record(self, record): def process_record(self, record):
for key, filters in self._split_mapping.iteritems(): for key, filters in self._split_mapping.items():
# if the key doesn't exist -- move on to next key # if the key doesn't exist -- move on to next key
try: try:
@ -479,7 +483,7 @@ class UnicodeFilter(Filter):
self._errors = errors self._errors = errors
def process_record(self, record): def process_record(self, record):
for key, value in record.iteritems(): for key, value in record.items():
if isinstance(value, str): if isinstance(value, str):
record[key] = unicode(value, self._encoding, self._errors) record[key] = unicode(value, self._encoding, self._errors)
elif isinstance(value, unicode): elif isinstance(value, unicode):
@ -494,7 +498,7 @@ class StringFilter(Filter):
self._errors = errors self._errors = errors
def process_record(self, record): def process_record(self, record):
for key, value in record.iteritems(): for key, value in record.items():
if isinstance(value, unicode): if isinstance(value, unicode):
record[key] = value.encode(self._encoding, self._errors) record[key] = value.encode(self._encoding, self._errors)
return record return record
@ -584,7 +588,7 @@ class NameCleaner(Filter):
# if there is a match, remove original name and add pieces # if there is a match, remove original name and add pieces
if match: if match:
record.pop(key) record.pop(key)
for k,v in match.groupdict().iteritems(): for k,v in match.groupdict().items():
record[self._name_prefix + k] = v record[self._name_prefix + k] = v
break break

View File

@ -4,8 +4,9 @@
All sources must implement the iterable interface and return python All sources must implement the iterable interface and return python
dictionaries. dictionaries.
""" """
from __future__ import unicode_literals
import string import string
from saucebrush import utils from saucebrush import utils
class CSVSource(object): class CSVSource(object):
@ -25,8 +26,8 @@ class CSVSource(object):
def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs): def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs):
import csv import csv
self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs) self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs)
for _ in xrange(skiprows): for _ in range(skiprows):
self._dictreader.next() next(self._dictreader)
def __iter__(self): def __iter__(self):
return self._dictreader return self._dictreader
@ -59,13 +60,18 @@ class FixedWidthFileSource(object):
def __iter__(self): def __iter__(self):
return self return self
def next(self): def __next__(self):
line = self._fwfile.next() line = next(self._fwfile)
record = {} record = {}
for name, range_ in self._fields_dict.iteritems(): for name, range_ in self._fields_dict.items():
record[name] = line[range_[0]:range_[1]].rstrip(self._fillchars) record[name] = line[range_[0]:range_[1]].rstrip(self._fillchars)
return record return record
def next(self):
""" Keep Python 2 next() method that defers to __next__().
"""
return self.__next__()
class HtmlTableSource(object): class HtmlTableSource(object):
""" Saucebrush source for reading data from an HTML table. """ Saucebrush source for reading data from an HTML table.
@ -86,26 +92,32 @@ class HtmlTableSource(object):
def __init__(self, htmlfile, id_or_num, fieldnames=None, skiprows=0): def __init__(self, htmlfile, id_or_num, fieldnames=None, skiprows=0):
# extract the table # extract the table
from BeautifulSoup import BeautifulSoup from lxml.html import parse
soup = BeautifulSoup(htmlfile.read()) doc = parse(htmlfile).getroot()
if isinstance(id_or_num, int): if isinstance(id_or_num, int):
table = soup.findAll('table')[id_or_num] table = doc.cssselect('table')[id_or_num]
elif isinstance(id_or_num, str): else:
table = soup.find('table', id=id_or_num) table = doc.cssselect('table#%s' % id_or_num)
table = table[0] # get the first table
# skip the necessary number of rows # skip the necessary number of rows
self._rows = table.findAll('tr')[skiprows:] self._rows = table.cssselect('tr')[skiprows:]
# determine the fieldnames # determine the fieldnames
if not fieldnames: if not fieldnames:
self._fieldnames = [td.string self._fieldnames = [td.text_content()
for td in self._rows[0].findAll(('td','th'))] for td in self._rows[0].cssselect('td, th')]
skiprows += 1
else: else:
self._fieldnames = fieldnames self._fieldnames = fieldnames
# skip the necessary number of rows
self._rows = table.cssselect('tr')[skiprows:]
def process_tr(self): def process_tr(self):
for row in self._rows: for row in self._rows:
strings = [utils.string_dig(td) for td in row.findAll('td')] strings = [td.text_content() for td in row.cssselect('td')]
yield dict(zip(self._fieldnames, strings)) yield dict(zip(self._fieldnames, strings))
def __iter__(self): def __iter__(self):
@ -182,7 +194,7 @@ class SqliteSource(object):
self._conn = sqlite3.connect(self._dbpath) self._conn = sqlite3.connect(self._dbpath)
self._conn.row_factory = dict_factory self._conn.row_factory = dict_factory
if self._conn_params: if self._conn_params:
for param, value in self._conn_params.iteritems(): for param, value in self._conn_params.items():
setattr(self._conn, param, value) setattr(self._conn, param, value)
def _process_query(self): def _process_query(self):
@ -214,20 +226,20 @@ class FileSource(object):
def __iter__(self): def __iter__(self):
# This method would be a lot cleaner with the proposed # This method would be a lot cleaner with the proposed
# 'yield from' expression (PEP 380) # 'yield from' expression (PEP 380)
if hasattr(self._input, '__read__'): if hasattr(self._input, '__read__') or hasattr(self._input, 'read'):
for record in self._process_file(input): for record in self._process_file(self._input):
yield record yield record
elif isinstance(self._input, basestring): elif isinstance(self._input, str):
with open(self._input) as f: with open(self._input) as f:
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
elif hasattr(self._input, '__iter__'): elif hasattr(self._input, '__iter__'):
for el in self._input: for el in self._input:
if isinstance(el, basestring): if isinstance(el, str):
with open(el) as f: with open(el) as f:
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
elif hasattr(el, '__read__'): elif hasattr(el, '__read__') or hasattr(el, 'read'):
for record in self._process_file(f): for record in self._process_file(f):
yield record yield record
@ -244,10 +256,11 @@ class JSONSource(FileSource):
object. object.
""" """
def _process_file(self, file): def _process_file(self, f):
import json import json
obj = json.load(file) obj = json.load(f)
# If the top-level JSON object in the file is a list # If the top-level JSON object in the file is a list
# then yield each element separately; otherwise, yield # then yield each element separately; otherwise, yield

View File

@ -30,10 +30,10 @@ def _median(values):
if count % 2 == 1: if count % 2 == 1:
# odd number of items, return middle value # odd number of items, return middle value
return float(values[count / 2]) return float(values[int(count / 2)])
else: else:
# even number of items, return average of middle two items # even number of items, return average of middle two items
mid = count / 2 mid = int(count / 2)
return sum(values[mid - 1:mid + 1]) / 2.0 return sum(values[mid - 1:mid + 1]) / 2.0
def _stddev(values, population=False): def _stddev(values, population=False):

View File

@ -1,47 +1,86 @@
from __future__ import unicode_literals
from contextlib import closing
from io import StringIO
import os
import unittest import unittest
from cStringIO import StringIO
from saucebrush.emitters import DebugEmitter, CSVEmitter, CountEmitter from saucebrush.emitters import (
DebugEmitter, CSVEmitter, CountEmitter, SqliteEmitter, SqlDumpEmitter)
class EmitterTestCase(unittest.TestCase): class EmitterTestCase(unittest.TestCase):
def setUp(self):
self.output = StringIO()
def test_debug_emitter(self): def test_debug_emitter(self):
de = DebugEmitter(self.output) with closing(StringIO()) as output:
data = de.attach([1,2,3]) de = DebugEmitter(output)
for _ in data: list(de.attach([1,2,3]))
pass self.assertEqual(output.getvalue(), '1\n2\n3\n')
self.assertEquals(self.output.getvalue(), '1\n2\n3\n')
def test_csv_emitter(self):
ce = CSVEmitter(self.output, ('x','y','z'))
data = ce.attach([{'x':1,'y':2,'z':3}, {'x':5, 'y':5, 'z':5}])
for _ in data:
pass
self.assertEquals(self.output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
def test_count_emitter(self): def test_count_emitter(self):
# values for test # values for test
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22] values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
# test without of parameter with closing(StringIO()) as output:
ce = CountEmitter(every=10, outfile=self.output, format="%(count)s records\n")
list(ce.attach(values))
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n')
ce.done()
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n22 records\n')
# reset output # test without of parameter
self.output.truncate(0) ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n")
list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 records\n20 records\n')
ce.done()
self.assertEqual(output.getvalue(), '10 records\n20 records\n22 records\n')
with closing(StringIO()) as output:
# test with of parameter
ce = CountEmitter(every=10, outfile=output, of=len(values))
list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n')
ce.done()
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n')
def test_csv_emitter(self):
try:
import cStringIO # if Python 2.x then use old cStringIO
io = cStringIO.StringIO()
except:
io = StringIO() # if Python 3.x then use StringIO
with closing(io) as output:
ce = CSVEmitter(output, ('x','y','z'))
list(ce.attach([{'x':1, 'y':2, 'z':3}, {'x':5, 'y':5, 'z':5}]))
self.assertEqual(output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
def test_sqlite_emitter(self):
import sqlite3, tempfile
with closing(tempfile.NamedTemporaryFile(suffix='.db')) as f:
db_path = f.name
sle = SqliteEmitter(db_path, 'testtable', fieldnames=('a','b','c'))
list(sle.attach([{'a': '1', 'b': '2', 'c': '3'}]))
sle.done()
with closing(sqlite3.connect(db_path)) as conn:
cur = conn.cursor()
cur.execute("""SELECT a, b, c FROM testtable""")
results = cur.fetchall()
os.unlink(db_path)
self.assertEqual(results, [('1', '2', '3')])
def test_sql_dump_emitter(self):
with closing(StringIO()) as bffr:
sde = SqlDumpEmitter(bffr, 'testtable', ('a', 'b'))
list(sde.attach([{'a': 1, 'b': '2'}]))
sde.done()
self.assertEqual(bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n")
# test with of parameter
ce = CountEmitter(every=10, outfile=self.output, of=len(values))
list(ce.attach(values))
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n')
ce.done()
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -57,41 +57,42 @@ class FilterTestCase(unittest.TestCase):
def assert_filter_result(self, filter_obj, expected_data): def assert_filter_result(self, filter_obj, expected_data):
result = filter_obj.attach(self._simple_data()) result = filter_obj.attach(self._simple_data())
self.assertEquals(list(result), expected_data) self.assertEqual(list(result), expected_data)
def test_reject_record(self): def test_reject_record(self):
recipe = DummyRecipe() recipe = DummyRecipe()
f = Doubler() f = Doubler()
result = f.attach([1,2,3], recipe=recipe) result = f.attach([1,2,3], recipe=recipe)
result.next() # next has to be called for attach to take effect # next has to be called for attach to take effect
next(result)
f.reject_record('bad', 'this one was bad') f.reject_record('bad', 'this one was bad')
# ensure that the rejection propagated to the recipe # ensure that the rejection propagated to the recipe
self.assertEquals('bad', recipe.rejected_record) self.assertEqual('bad', recipe.rejected_record)
self.assertEquals('this one was bad', recipe.rejected_msg) self.assertEqual('this one was bad', recipe.rejected_msg)
def test_simple_filter(self): def test_simple_filter(self):
df = Doubler() df = Doubler()
result = df.attach([1,2,3]) result = df.attach([1,2,3])
# ensure we got a generator that yields 2,4,6 # ensure we got a generator that yields 2,4,6
self.assertEquals(type(result), types.GeneratorType) self.assertEqual(type(result), types.GeneratorType)
self.assertEquals(list(result), [2,4,6]) self.assertEqual(list(result), [2,4,6])
def test_simple_filter_return_none(self): def test_simple_filter_return_none(self):
cf = OddRemover() cf = OddRemover()
result = cf.attach(range(10)) result = cf.attach(range(10))
# ensure only even numbers remain # ensure only even numbers remain
self.assertEquals(list(result), [0,2,4,6,8]) self.assertEqual(list(result), [0,2,4,6,8])
def test_simple_yield_filter(self): def test_simple_yield_filter(self):
lf = ListFlattener() lf = ListFlattener()
result = lf.attach([[1],[2,3],[4,5,6]]) result = lf.attach([[1],[2,3],[4,5,6]])
# ensure we got a generator that yields 1,2,3,4,5,6 # ensure we got a generator that yields 1,2,3,4,5,6
self.assertEquals(type(result), types.GeneratorType) self.assertEqual(type(result), types.GeneratorType)
self.assertEquals(list(result), [1,2,3,4,5,6]) self.assertEqual(list(result), [1,2,3,4,5,6])
def test_simple_field_filter(self): def test_simple_field_filter(self):
ff = FieldDoubler(['a', 'c']) ff = FieldDoubler(['a', 'c'])
@ -107,7 +108,7 @@ class FilterTestCase(unittest.TestCase):
result = cf.attach(range(10)) result = cf.attach(range(10))
# ensure only even numbers remain # ensure only even numbers remain
self.assertEquals(list(result), [0,2,4,6,8]) self.assertEqual(list(result), [0,2,4,6,8])
### Tests for Subrecord ### Tests for Subrecord
@ -123,7 +124,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a', NonModifyingFieldDoubler('b')) sf = SubrecordFilter('a', NonModifyingFieldDoubler('b'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_deep(self): def test_subrecord_filter_deep(self):
data = [{'a': {'d':[{'b': 2}, {'b': 4}]}}, data = [{'a': {'d':[{'b': 2}, {'b': 4}]}},
@ -137,7 +138,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b')) sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_nonlist(self): def test_subrecord_filter_nonlist(self):
data = [ data = [
@ -155,7 +156,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_list_in_path(self): def test_subrecord_filter_list_in_path(self):
data = [ data = [
@ -173,7 +174,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_conditional_path(self): def test_conditional_path(self):
@ -295,7 +296,7 @@ class FilterTestCase(unittest.TestCase):
expected_data = [{'a': 77}, {'a':33}] expected_data = [{'a': 77}, {'a':33}]
result = u.attach(in_data) result = u.attach(in_data)
self.assertEquals(list(result), expected_data) self.assertEqual(list(result), expected_data)
# TODO: unicode & string filter tests # TODO: unicode & string filter tests

View File

@ -1,6 +1,9 @@
from __future__ import unicode_literals
from io import BytesIO, StringIO
import unittest import unittest
import cStringIO
from saucebrush.sources import CSVSource, FixedWidthFileSource from saucebrush.sources import (
CSVSource, FixedWidthFileSource, HtmlTableSource, JSONSource)
class SourceTestCase(unittest.TestCase): class SourceTestCase(unittest.TestCase):
@ -9,14 +12,14 @@ class SourceTestCase(unittest.TestCase):
1,2,3 1,2,3
5,5,5 5,5,5
1,10,100''' 1,10,100'''
return cStringIO.StringIO(data) return StringIO(data)
def test_csv_source_basic(self): def test_csv_source_basic(self):
source = CSVSource(self._get_csv()) source = CSVSource(self._get_csv())
expected_data = [{'a':'1', 'b':'2', 'c':'3'}, expected_data = [{'a':'1', 'b':'2', 'c':'3'},
{'a':'5', 'b':'5', 'c':'5'}, {'a':'5', 'b':'5', 'c':'5'},
{'a':'1', 'b':'10', 'c':'100'}] {'a':'1', 'b':'10', 'c':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_fieldnames(self): def test_csv_source_fieldnames(self):
source = CSVSource(self._get_csv(), ['x','y','z']) source = CSVSource(self._get_csv(), ['x','y','z'])
@ -24,35 +27,59 @@ class SourceTestCase(unittest.TestCase):
{'x':'1', 'y':'2', 'z':'3'}, {'x':'1', 'y':'2', 'z':'3'},
{'x':'5', 'y':'5', 'z':'5'}, {'x':'5', 'y':'5', 'z':'5'},
{'x':'1', 'y':'10', 'z':'100'}] {'x':'1', 'y':'10', 'z':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_skiprows(self): def test_csv_source_skiprows(self):
source = CSVSource(self._get_csv(), skiprows=1) source = CSVSource(self._get_csv(), skiprows=1)
expected_data = [{'a':'5', 'b':'5', 'c':'5'}, expected_data = [{'a':'5', 'b':'5', 'c':'5'},
{'a':'1', 'b':'10', 'c':'100'}] {'a':'1', 'b':'10', 'c':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_fixed_width_source(self): def test_fixed_width_source(self):
data = cStringIO.StringIO('JamesNovember 3 1986\nTim September151999') data = StringIO('JamesNovember 3 1986\nTim September151999')
fields = (('name',5), ('month',9), ('day',2), ('year',4)) fields = (('name',5), ('month',9), ('day',2), ('year',4))
source = FixedWidthFileSource(data, fields) source = FixedWidthFileSource(data, fields)
expected_data = [{'name':'James', 'month':'November', 'day':'3', expected_data = [{'name':'James', 'month':'November', 'day':'3',
'year':'1986'}, 'year':'1986'},
{'name':'Tim', 'month':'September', 'day':'15', {'name':'Tim', 'month':'September', 'day':'15',
'year':'1999'}] 'year':'1999'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_fixed_width_source(self): def test_json_source(self):
data = cStringIO.StringIO('JamesNovember.3.1986\nTim..September151999')
fields = (('name',5), ('month',9), ('day',2), ('year',4))
source = FixedWidthFileSource(data, fields, fillchars='.')
expected_data = [{'name':'James', 'month':'November', 'day':'3',
'year':'1986'},
{'name':'Tim', 'month':'September', 'day':'15',
'year':'1999'}]
self.assertEquals(list(source), expected_data)
content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""")
js = JSONSource(content)
self.assertEqual(list(js), [{'a': 1, 'b': '2', 'c': 3}])
def test_html_table_source(self):
content = StringIO("""
<html>
<table id="thetable">
<tr>
<th>a</th>
<th>b</th>
<th>c</th>
</tr>
<tr>
<td>1</td>
<td>2</td>
<td>3</td>
</tr>
</table>
</html>
""")
try:
import lxml
hts = HtmlTableSource(content, 'thetable')
self.assertEqual(list(hts), [{'a': '1', 'b': '2', 'c': '3'}])
except ImportError:
self.skipTest("lxml is not installed")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,5 +1,10 @@
import os import os
import urllib2
try:
from urllib.request import urlopen # attempt py3 first
except ImportError:
from urllib2 import urlopen # fallback to py2
""" """
General utilities used within saucebrush that may be useful elsewhere. General utilities used within saucebrush that may be useful elsewhere.
""" """
@ -20,21 +25,6 @@ def get_django_model(dj_settings, app_label, model_name):
from django.db.models import get_model from django.db.models import get_model
return get_model(app_label, model_name) return get_model(app_label, model_name)
def string_dig(element, separator=''):
"""
Dig into BeautifulSoup HTML elements looking for inner strings.
If element resembled: <p><b>test</b><em>test</em></p>
then string_dig(element, '~') would return test~test
"""
if element.string:
return element.string
else:
return separator.join([string_dig(child)
for child in element.findAll(True)])
def flatten(item, prefix='', separator='_', keys=None): def flatten(item, prefix='', separator='_', keys=None):
""" """
Flatten nested dictionary into one with its keys concatenated together. Flatten nested dictionary into one with its keys concatenated together.
@ -51,7 +41,7 @@ def flatten(item, prefix='', separator='_', keys=None):
if prefix != '': if prefix != '':
prefix += separator prefix += separator
retval = {} retval = {}
for key, value in item.iteritems(): for key, value in item.items():
if (not keys) or (key in keys): if (not keys) or (key in keys):
retval.update(flatten(value, prefix + key, separator, keys)) retval.update(flatten(value, prefix + key, separator, keys))
else: else:
@ -60,7 +50,6 @@ def flatten(item, prefix='', separator='_', keys=None):
#elif isinstance(item, (tuple, list)): #elif isinstance(item, (tuple, list)):
# return {prefix: [flatten(i, prefix, separator, keys) for i in item]} # return {prefix: [flatten(i, prefix, separator, keys) for i in item]}
else: else:
print item, prefix
return {prefix: item} return {prefix: item}
def str_or_list(obj): def str_or_list(obj):
@ -113,7 +102,7 @@ class RemoteFile(object):
self._url = url self._url = url
def __iter__(self): def __iter__(self):
resp = urllib2.urlopen(self._url) resp = urlopen(self._url)
for line in resp: for line in resp:
yield line.rstrip() yield line.rstrip()
resp.close() resp.close()