diff --git a/saucebrush/emitters.py b/saucebrush/emitters.py index aa3665c..a66ac34 100644 --- a/saucebrush/emitters.py +++ b/saucebrush/emitters.py @@ -53,6 +53,34 @@ class DebugEmitter(Emitter): self._outfile.write(str(record) + '\n') +class CountEmitter(Emitter): + """ Emitter that writes the record count to a file-like object. + + CountEmitter() by default writes to stdout. + CountEmitter(outfile=open('text', 'w')) would print to a file name test. + CountEmitter(every=1000000) would write the count every 1,000,000 records. + """ + + def __init__(self, every=1000, outfile=None, format=None): + super(CountEmitter, self).__init__() + if not outfile: + import sys + self._outfile = sys.stdout + else: + self._outfile = outfile + self._format = "%s\n" if format is None else format + self._every = every + self.count = 0 + + def emit_record(self, record): + self.count += 1 + if self.count % self._every == 0: + self._outfile.write(self._format % self.count) + + def done(self): + self._outfile.write(self._format % self.count) + + class CSVEmitter(Emitter): """ Emitter that writes records to a CSV file. diff --git a/saucebrush/tests/emitters.py b/saucebrush/tests/emitters.py index 9d6496d..1c8b9d8 100644 --- a/saucebrush/tests/emitters.py +++ b/saucebrush/tests/emitters.py @@ -1,6 +1,6 @@ import unittest from cStringIO import StringIO -from saucebrush.emitters import DebugEmitter, CSVEmitter +from saucebrush.emitters import DebugEmitter, CSVEmitter, CountEmitter class EmitterTestCase(unittest.TestCase): @@ -20,6 +20,15 @@ class EmitterTestCase(unittest.TestCase): for _ in data: pass self.assertEquals(self.output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n') + + def test_count_emitter(self): + ce = CountEmitter(every=10, outfile=self.output, format="%s records\n") + data = ce.attach([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]) + for _ in data: + pass + self.assertEquals(self.output.getvalue(), '10 records\n20 records\n') + ce.done() + self.assertEquals(self.output.getvalue(), '10 records\n20 records\n22 records\n') if __name__ == '__main__': unittest.main()