add CountEmitter and test
This commit is contained in:
parent
4c025493b1
commit
2123977895
@ -53,6 +53,34 @@ class DebugEmitter(Emitter):
|
||||
self._outfile.write(str(record) + '\n')
|
||||
|
||||
|
||||
class CountEmitter(Emitter):
|
||||
""" Emitter that writes the record count to a file-like object.
|
||||
|
||||
CountEmitter() by default writes to stdout.
|
||||
CountEmitter(outfile=open('text', 'w')) would print to a file name test.
|
||||
CountEmitter(every=1000000) would write the count every 1,000,000 records.
|
||||
"""
|
||||
|
||||
def __init__(self, every=1000, outfile=None, format=None):
|
||||
super(CountEmitter, self).__init__()
|
||||
if not outfile:
|
||||
import sys
|
||||
self._outfile = sys.stdout
|
||||
else:
|
||||
self._outfile = outfile
|
||||
self._format = "%s\n" if format is None else format
|
||||
self._every = every
|
||||
self.count = 0
|
||||
|
||||
def emit_record(self, record):
|
||||
self.count += 1
|
||||
if self.count % self._every == 0:
|
||||
self._outfile.write(self._format % self.count)
|
||||
|
||||
def done(self):
|
||||
self._outfile.write(self._format % self.count)
|
||||
|
||||
|
||||
class CSVEmitter(Emitter):
|
||||
""" Emitter that writes records to a CSV file.
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from cStringIO import StringIO
|
||||
from saucebrush.emitters import DebugEmitter, CSVEmitter
|
||||
from saucebrush.emitters import DebugEmitter, CSVEmitter, CountEmitter
|
||||
|
||||
class EmitterTestCase(unittest.TestCase):
|
||||
|
||||
@ -20,6 +20,15 @@ class EmitterTestCase(unittest.TestCase):
|
||||
for _ in data:
|
||||
pass
|
||||
self.assertEquals(self.output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
|
||||
|
||||
def test_count_emitter(self):
|
||||
ce = CountEmitter(every=10, outfile=self.output, format="%s records\n")
|
||||
data = ce.attach([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22])
|
||||
for _ in data:
|
||||
pass
|
||||
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n')
|
||||
ce.done()
|
||||
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n22 records\n')
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user