update CountEmitter to take an optional of argument to display 'x of y'
This commit is contained in:
parent
42213ff106
commit
ddd81c96b1
@ -61,24 +61,37 @@ class CountEmitter(Emitter):
|
|||||||
CountEmitter(every=1000000) would write the count every 1,000,000 records.
|
CountEmitter(every=1000000) would write the count every 1,000,000 records.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, every=1000, outfile=None, format=None):
|
def __init__(self, every=1000, of=None, outfile=None, format=None):
|
||||||
|
|
||||||
super(CountEmitter, self).__init__()
|
super(CountEmitter, self).__init__()
|
||||||
|
|
||||||
if not outfile:
|
if not outfile:
|
||||||
import sys
|
import sys
|
||||||
self._outfile = sys.stdout
|
self._outfile = sys.stdout
|
||||||
else:
|
else:
|
||||||
self._outfile = outfile
|
self._outfile = outfile
|
||||||
self._format = "%s\n" if format is None else format
|
|
||||||
|
if format is None:
|
||||||
|
if of is not None:
|
||||||
|
format = "%(count)s of %(of)s\n"
|
||||||
|
else:
|
||||||
|
format = "%(count)s\n"
|
||||||
|
|
||||||
|
self._format = format
|
||||||
self._every = every
|
self._every = every
|
||||||
|
self._of = of
|
||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self._format % {'count': self.count, 'of': self._of}
|
||||||
|
|
||||||
def emit_record(self, record):
|
def emit_record(self, record):
|
||||||
self.count += 1
|
self.count += 1
|
||||||
if self.count % self._every == 0:
|
if self.count % self._every == 0:
|
||||||
self._outfile.write(self._format % self.count)
|
self._outfile.write(str(self))
|
||||||
|
|
||||||
def done(self):
|
def done(self):
|
||||||
self._outfile.write(self._format % self.count)
|
self._outfile.write(str(self))
|
||||||
|
|
||||||
|
|
||||||
class CSVEmitter(Emitter):
|
class CSVEmitter(Emitter):
|
||||||
|
@ -22,13 +22,26 @@ class EmitterTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(self.output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
|
self.assertEquals(self.output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
|
||||||
|
|
||||||
def test_count_emitter(self):
|
def test_count_emitter(self):
|
||||||
ce = CountEmitter(every=10, outfile=self.output, format="%s records\n")
|
|
||||||
data = ce.attach([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22])
|
# values for test
|
||||||
for _ in data:
|
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
|
||||||
pass
|
|
||||||
|
# test without of parameter
|
||||||
|
ce = CountEmitter(every=10, outfile=self.output, format="%(count)s records\n")
|
||||||
|
list(ce.attach(values))
|
||||||
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n')
|
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n')
|
||||||
ce.done()
|
ce.done()
|
||||||
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n22 records\n')
|
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n22 records\n')
|
||||||
|
|
||||||
|
# reset output
|
||||||
|
self.output.truncate(0)
|
||||||
|
|
||||||
|
# test with of parameter
|
||||||
|
ce = CountEmitter(every=10, outfile=self.output, of=len(values))
|
||||||
|
list(ce.attach(values))
|
||||||
|
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n')
|
||||||
|
ce.done()
|
||||||
|
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user