A whole bunch of changes to support the new Unicode-based strings in Python 3

This commit is contained in:
Jeremy Carbaugh 2012-03-11 23:14:39 -07:00
parent 46c8ab3ae3
commit d5b56b931b
4 changed files with 77 additions and 67 deletions

View File

@ -2,7 +2,7 @@
Saucebrush Emitters are filters that instead of modifying the record, output Saucebrush Emitters are filters that instead of modifying the record, output
it in some manner. it in some manner.
""" """
from __future__ import unicode_literals
from saucebrush.filters import Filter from saucebrush.filters import Filter
class Emitter(Filter): class Emitter(Filter):
@ -50,7 +50,7 @@ class DebugEmitter(Emitter):
self._outfile = outfile self._outfile = outfile
def emit_record(self, record): def emit_record(self, record):
self._outfile.write(str(record) + '\n') self._outfile.write("{0}\n".format(record))
class CountEmitter(Emitter): class CountEmitter(Emitter):
@ -83,16 +83,16 @@ class CountEmitter(Emitter):
self._of = of self._of = of
self.count = 0 self.count = 0
def __str__(self): def format(self):
return self._format % {'count': self.count, 'of': self._of} return self._format % {'count': self.count, 'of': self._of}
def emit_record(self, record): def emit_record(self, record):
self.count += 1 self.count += 1
if self.count % self._every == 0: if self.count % self._every == 0:
self._outfile.write(str(self)) self._outfile.write(self.format())
def done(self): def done(self):
self._outfile.write(str(self)) self._outfile.write(self.format())
class CSVEmitter(Emitter): class CSVEmitter(Emitter):
@ -107,7 +107,9 @@ class CSVEmitter(Emitter):
import csv import csv
self._dictwriter = csv.DictWriter(csvfile, fieldnames) self._dictwriter = csv.DictWriter(csvfile, fieldnames)
# write header row # write header row
self._dictwriter.writerow(dict(zip(fieldnames, fieldnames))) header_row = dict(zip(fieldnames, fieldnames))
print(header_row)
self._dictwriter.writerow(header_row)
def emit_record(self, record): def emit_record(self, record):
self._dictwriter.writerow(record) self._dictwriter.writerow(record)

View File

@ -1,47 +1,53 @@
from __future__ import unicode_literals
from contextlib import closing
from io import BytesIO, StringIO
import unittest import unittest
from cStringIO import StringIO
from saucebrush.emitters import DebugEmitter, CSVEmitter, CountEmitter from saucebrush.emitters import DebugEmitter, CSVEmitter, CountEmitter
class EmitterTestCase(unittest.TestCase): class EmitterTestCase(unittest.TestCase):
def setUp(self):
self.output = StringIO()
def test_debug_emitter(self): def test_debug_emitter(self):
de = DebugEmitter(self.output) with closing(StringIO()) as output:
data = de.attach([1,2,3]) de = DebugEmitter(output)
for _ in data: list(de.attach([1,2,3]))
pass self.assertEqual(output.getvalue(), '1\n2\n3\n')
self.assertEquals(self.output.getvalue(), '1\n2\n3\n')
def test_csv_emitter(self): def test_csv_emitter(self):
ce = CSVEmitter(self.output, ('x','y','z'))
data = ce.attach([{'x':1,'y':2,'z':3}, {'x':5, 'y':5, 'z':5}]) try:
for _ in data: import cStringIO # if Python 2.x then use BytesIO
pass io = BytesIO()
self.assertEquals(self.output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n') except:
io = StringIO() # if Python 3.x then use StringIO
with closing(io) as output:
ce = CSVEmitter(output, ('x','y','z'))
list(ce.attach([{'x':1, 'y':2, 'z':3}, {'x':5, 'y':5, 'z':5}]))
self.assertEqual(output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
def test_count_emitter(self): def test_count_emitter(self):
# values for test # values for test
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22] values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
# test without of parameter with closing(StringIO()) as output:
ce = CountEmitter(every=10, outfile=self.output, format="%(count)s records\n")
list(ce.attach(values))
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n')
ce.done()
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n22 records\n')
# reset output # test without of parameter
self.output.truncate(0) ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n")
list(ce.attach(values))
self.assertEqual(output.getvalue(), '10 records\n20 records\n')
ce.done()
self.assertEqual(output.getvalue(), '10 records\n20 records\n22 records\n')
# test with of parameter with closing(StringIO()) as output:
ce = CountEmitter(every=10, outfile=self.output, of=len(values))
list(ce.attach(values)) # test with of parameter
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n') ce = CountEmitter(every=10, outfile=output, of=len(values))
ce.done() list(ce.attach(values))
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n') self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n')
ce.done()
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -57,7 +57,7 @@ class FilterTestCase(unittest.TestCase):
def assert_filter_result(self, filter_obj, expected_data): def assert_filter_result(self, filter_obj, expected_data):
result = filter_obj.attach(self._simple_data()) result = filter_obj.attach(self._simple_data())
self.assertEquals(list(result), expected_data) self.assertEqual(list(result), expected_data)
def test_reject_record(self): def test_reject_record(self):
recipe = DummyRecipe() recipe = DummyRecipe()
@ -68,31 +68,31 @@ class FilterTestCase(unittest.TestCase):
f.reject_record('bad', 'this one was bad') f.reject_record('bad', 'this one was bad')
# ensure that the rejection propagated to the recipe # ensure that the rejection propagated to the recipe
self.assertEquals('bad', recipe.rejected_record) self.assertEqual('bad', recipe.rejected_record)
self.assertEquals('this one was bad', recipe.rejected_msg) self.assertEqual('this one was bad', recipe.rejected_msg)
def test_simple_filter(self): def test_simple_filter(self):
df = Doubler() df = Doubler()
result = df.attach([1,2,3]) result = df.attach([1,2,3])
# ensure we got a generator that yields 2,4,6 # ensure we got a generator that yields 2,4,6
self.assertEquals(type(result), types.GeneratorType) self.assertEqual(type(result), types.GeneratorType)
self.assertEquals(list(result), [2,4,6]) self.assertEqual(list(result), [2,4,6])
def test_simple_filter_return_none(self): def test_simple_filter_return_none(self):
cf = OddRemover() cf = OddRemover()
result = cf.attach(range(10)) result = cf.attach(range(10))
# ensure only even numbers remain # ensure only even numbers remain
self.assertEquals(list(result), [0,2,4,6,8]) self.assertEqual(list(result), [0,2,4,6,8])
def test_simple_yield_filter(self): def test_simple_yield_filter(self):
lf = ListFlattener() lf = ListFlattener()
result = lf.attach([[1],[2,3],[4,5,6]]) result = lf.attach([[1],[2,3],[4,5,6]])
# ensure we got a generator that yields 1,2,3,4,5,6 # ensure we got a generator that yields 1,2,3,4,5,6
self.assertEquals(type(result), types.GeneratorType) self.assertEqual(type(result), types.GeneratorType)
self.assertEquals(list(result), [1,2,3,4,5,6]) self.assertEqual(list(result), [1,2,3,4,5,6])
def test_simple_field_filter(self): def test_simple_field_filter(self):
ff = FieldDoubler(['a', 'c']) ff = FieldDoubler(['a', 'c'])
@ -108,7 +108,7 @@ class FilterTestCase(unittest.TestCase):
result = cf.attach(range(10)) result = cf.attach(range(10))
# ensure only even numbers remain # ensure only even numbers remain
self.assertEquals(list(result), [0,2,4,6,8]) self.assertEqual(list(result), [0,2,4,6,8])
### Tests for Subrecord ### Tests for Subrecord
@ -124,7 +124,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a', NonModifyingFieldDoubler('b')) sf = SubrecordFilter('a', NonModifyingFieldDoubler('b'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_deep(self): def test_subrecord_filter_deep(self):
data = [{'a': {'d':[{'b': 2}, {'b': 4}]}}, data = [{'a': {'d':[{'b': 2}, {'b': 4}]}},
@ -138,7 +138,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b')) sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_nonlist(self): def test_subrecord_filter_nonlist(self):
data = [ data = [
@ -156,7 +156,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_subrecord_filter_list_in_path(self): def test_subrecord_filter_list_in_path(self):
data = [ data = [
@ -174,7 +174,7 @@ class FilterTestCase(unittest.TestCase):
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c')) sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
result = sf.attach(data) result = sf.attach(data)
self.assertEquals(list(result), expected) self.assertEqual(list(result), expected)
def test_conditional_path(self): def test_conditional_path(self):
@ -296,7 +296,7 @@ class FilterTestCase(unittest.TestCase):
expected_data = [{'a': 77}, {'a':33}] expected_data = [{'a': 77}, {'a':33}]
result = u.attach(in_data) result = u.attach(in_data)
self.assertEquals(list(result), expected_data) self.assertEqual(list(result), expected_data)
# TODO: unicode & string filter tests # TODO: unicode & string filter tests

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals
from io import BytesIO, StringIO
import unittest import unittest
import cStringIO
from saucebrush.sources import CSVSource, FixedWidthFileSource from saucebrush.sources import CSVSource, FixedWidthFileSource
class SourceTestCase(unittest.TestCase): class SourceTestCase(unittest.TestCase):
@ -9,14 +11,14 @@ class SourceTestCase(unittest.TestCase):
1,2,3 1,2,3
5,5,5 5,5,5
1,10,100''' 1,10,100'''
return cStringIO.StringIO(data) return StringIO(data)
def test_csv_source_basic(self): def test_csv_source_basic(self):
source = CSVSource(self._get_csv()) source = CSVSource(self._get_csv())
expected_data = [{'a':'1', 'b':'2', 'c':'3'}, expected_data = [{'a':'1', 'b':'2', 'c':'3'},
{'a':'5', 'b':'5', 'c':'5'}, {'a':'5', 'b':'5', 'c':'5'},
{'a':'1', 'b':'10', 'c':'100'}] {'a':'1', 'b':'10', 'c':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_fieldnames(self): def test_csv_source_fieldnames(self):
source = CSVSource(self._get_csv(), ['x','y','z']) source = CSVSource(self._get_csv(), ['x','y','z'])
@ -24,23 +26,23 @@ class SourceTestCase(unittest.TestCase):
{'x':'1', 'y':'2', 'z':'3'}, {'x':'1', 'y':'2', 'z':'3'},
{'x':'5', 'y':'5', 'z':'5'}, {'x':'5', 'y':'5', 'z':'5'},
{'x':'1', 'y':'10', 'z':'100'}] {'x':'1', 'y':'10', 'z':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_csv_source_skiprows(self): def test_csv_source_skiprows(self):
source = CSVSource(self._get_csv(), skiprows=1) source = CSVSource(self._get_csv(), skiprows=1)
expected_data = [{'a':'5', 'b':'5', 'c':'5'}, expected_data = [{'a':'5', 'b':'5', 'c':'5'},
{'a':'1', 'b':'10', 'c':'100'}] {'a':'1', 'b':'10', 'c':'100'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
def test_fixed_width_source(self): def test_fixed_width_source(self):
data = cStringIO.StringIO('JamesNovember 3 1986\nTim September151999') data = StringIO('JamesNovember 3 1986\nTim September151999')
fields = (('name',5), ('month',9), ('day',2), ('year',4)) fields = (('name',5), ('month',9), ('day',2), ('year',4))
source = FixedWidthFileSource(data, fields) source = FixedWidthFileSource(data, fields)
expected_data = [{'name':'James', 'month':'November', 'day':'3', expected_data = [{'name':'James', 'month':'November', 'day':'3',
'year':'1986'}, 'year':'1986'},
{'name':'Tim', 'month':'September', 'day':'15', {'name':'Tim', 'month':'September', 'day':'15',
'year':'1999'}] 'year':'1999'}]
self.assertEquals(list(source), expected_data) self.assertEqual(list(source), expected_data)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()