Merge branch 'py3'
This commit is contained in:
commit
91143a7da0
@ -3,7 +3,7 @@ from saucebrush.outputs import CSVOutput, DebugOutput
|
|||||||
|
|
||||||
def merge_columns(datasource, mapping, merge_func):
|
def merge_columns(datasource, mapping, merge_func):
|
||||||
for rowdata in datasource:
|
for rowdata in datasource:
|
||||||
for to_col,from_cols in mapping.iteritems():
|
for to_col,from_cols in mapping.items():
|
||||||
values = [rowdata.pop(col, None) for col in from_cols]
|
values = [rowdata.pop(col, None) for col in from_cols]
|
||||||
rowdata[to_col] = reduce(merge_func, values)
|
rowdata[to_col] = reduce(merge_func, values)
|
||||||
yield rowdata
|
yield rowdata
|
||||||
|
@ -84,7 +84,7 @@ class FECSource(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_form_type(rectype):
|
def get_form_type(rectype):
|
||||||
for type_re, type in FECSource.FORM_MAPPING.iteritems():
|
for type_re, type in FECSource.FORM_MAPPING.items():
|
||||||
if type_re.match(rectype):
|
if type_re.match(rectype):
|
||||||
return type
|
return type
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
Saucebrush is a data loading & manipulation framework written in python.
|
Saucebrush is a data loading & manipulation framework written in python.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import filters, emitters, sources, utils
|
from . import filters, emitters, sources, utils
|
||||||
|
|
||||||
|
|
||||||
class SaucebrushError(Exception):
|
class SaucebrushError(Exception):
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
Saucebrush Emitters are filters that instead of modifying the record, output
|
Saucebrush Emitters are filters that instead of modifying the record, output
|
||||||
it in some manner.
|
it in some manner.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import unicode_literals
|
||||||
from saucebrush.filters import Filter
|
from saucebrush.filters import Filter
|
||||||
|
|
||||||
class Emitter(Filter):
|
class Emitter(Filter):
|
||||||
@ -50,12 +50,12 @@ class DebugEmitter(Emitter):
|
|||||||
self._outfile = outfile
|
self._outfile = outfile
|
||||||
|
|
||||||
def emit_record(self, record):
|
def emit_record(self, record):
|
||||||
self._outfile.write(str(record) + '\n')
|
self._outfile.write("{0}\n".format(record))
|
||||||
|
|
||||||
|
|
||||||
class CountEmitter(Emitter):
|
class CountEmitter(Emitter):
|
||||||
""" Emitter that writes the record count to a file-like object.
|
""" Emitter that writes the record count to a file-like object.
|
||||||
|
|
||||||
CountEmitter() by default writes to stdout.
|
CountEmitter() by default writes to stdout.
|
||||||
CountEmitter(outfile=open('text', 'w')) would print to a file name test.
|
CountEmitter(outfile=open('text', 'w')) would print to a file name test.
|
||||||
CountEmitter(every=1000000) would write the count every 1,000,000 records.
|
CountEmitter(every=1000000) would write the count every 1,000,000 records.
|
||||||
@ -63,36 +63,36 @@ class CountEmitter(Emitter):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, every=1000, of=None, outfile=None, format=None):
|
def __init__(self, every=1000, of=None, outfile=None, format=None):
|
||||||
|
|
||||||
super(CountEmitter, self).__init__()
|
super(CountEmitter, self).__init__()
|
||||||
|
|
||||||
if not outfile:
|
if not outfile:
|
||||||
import sys
|
import sys
|
||||||
self._outfile = sys.stdout
|
self._outfile = sys.stdout
|
||||||
else:
|
else:
|
||||||
self._outfile = outfile
|
self._outfile = outfile
|
||||||
|
|
||||||
if format is None:
|
if format is None:
|
||||||
if of is not None:
|
if of is not None:
|
||||||
format = "%(count)s of %(of)s\n"
|
format = "%(count)s of %(of)s\n"
|
||||||
else:
|
else:
|
||||||
format = "%(count)s\n"
|
format = "%(count)s\n"
|
||||||
|
|
||||||
self._format = format
|
self._format = format
|
||||||
self._every = every
|
self._every = every
|
||||||
self._of = of
|
self._of = of
|
||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
def __str__(self):
|
def format(self):
|
||||||
return self._format % {'count': self.count, 'of': self._of}
|
return self._format % {'count': self.count, 'of': self._of}
|
||||||
|
|
||||||
def emit_record(self, record):
|
def emit_record(self, record):
|
||||||
self.count += 1
|
self.count += 1
|
||||||
if self.count % self._every == 0:
|
if self.count % self._every == 0:
|
||||||
self._outfile.write(str(self))
|
self._outfile.write(self.format())
|
||||||
|
|
||||||
def done(self):
|
def done(self):
|
||||||
self._outfile.write(str(self))
|
self._outfile.write(self.format())
|
||||||
|
|
||||||
|
|
||||||
class CSVEmitter(Emitter):
|
class CSVEmitter(Emitter):
|
||||||
@ -107,7 +107,8 @@ class CSVEmitter(Emitter):
|
|||||||
import csv
|
import csv
|
||||||
self._dictwriter = csv.DictWriter(csvfile, fieldnames)
|
self._dictwriter = csv.DictWriter(csvfile, fieldnames)
|
||||||
# write header row
|
# write header row
|
||||||
self._dictwriter.writerow(dict(zip(fieldnames, fieldnames)))
|
header_row = dict(zip(fieldnames, fieldnames))
|
||||||
|
self._dictwriter.writerow(header_row)
|
||||||
|
|
||||||
def emit_record(self, record):
|
def emit_record(self, record):
|
||||||
self._dictwriter.writerow(record)
|
self._dictwriter.writerow(record)
|
||||||
@ -145,8 +146,8 @@ class SqliteEmitter(Emitter):
|
|||||||
','.join(record.keys()),
|
','.join(record.keys()),
|
||||||
qmarks)
|
qmarks)
|
||||||
try:
|
try:
|
||||||
self._cursor.execute(insert, record.values())
|
self._cursor.execute(insert, list(record.values()))
|
||||||
except sqlite3.IntegrityError, ie:
|
except sqlite3.IntegrityError as ie:
|
||||||
if not self._quiet:
|
if not self._quiet:
|
||||||
raise ie
|
raise ie
|
||||||
self.reject_record(record, ie.message)
|
self.reject_record(record, ie.message)
|
||||||
@ -179,21 +180,25 @@ class SqlDumpEmitter(Emitter):
|
|||||||
table_name, '`,`'.join(fieldnames))
|
table_name, '`,`'.join(fieldnames))
|
||||||
|
|
||||||
def quote(self, item):
|
def quote(self, item):
|
||||||
|
|
||||||
if item is None:
|
if item is None:
|
||||||
return "null"
|
return "null"
|
||||||
elif isinstance(item, (unicode, str)):
|
|
||||||
|
try:
|
||||||
|
types = (basestring,)
|
||||||
|
except NameError:
|
||||||
|
types = (str,)
|
||||||
|
|
||||||
|
if isinstance(item, types):
|
||||||
item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0')
|
item = item.replace("\\","\\\\").replace("'","\\'").replace(chr(0),'0')
|
||||||
return "'%s'" % item
|
return "'%s'" % item
|
||||||
else:
|
|
||||||
return "%s" % item
|
return "%s" % item
|
||||||
|
|
||||||
def emit_record(self, record):
|
def emit_record(self, record):
|
||||||
quoted_data = [self.quote(record[field]) for field in self._fieldnames]
|
quoted_data = [self.quote(record[field]) for field in self._fieldnames]
|
||||||
self._outfile.write(self._insert_str % ','.join(quoted_data))
|
self._outfile.write(self._insert_str % ','.join(quoted_data))
|
||||||
|
|
||||||
def done(self):
|
|
||||||
self._outfile.close()
|
|
||||||
|
|
||||||
|
|
||||||
class DjangoModelEmitter(Emitter):
|
class DjangoModelEmitter(Emitter):
|
||||||
""" Emitter that populates a table corresponding to a django model.
|
""" Emitter that populates a table corresponding to a django model.
|
||||||
|
@ -217,17 +217,17 @@ class FieldModifier(FieldFilter):
|
|||||||
|
|
||||||
class FieldKeeper(Filter):
|
class FieldKeeper(Filter):
|
||||||
""" Filter that removes all but the given set of fields.
|
""" Filter that removes all but the given set of fields.
|
||||||
|
|
||||||
FieldKeeper(('spam', 'eggs')) removes all bu tthe spam and eggs
|
FieldKeeper(('spam', 'eggs')) removes all bu tthe spam and eggs
|
||||||
fields from every record filtered.
|
fields from every record filtered.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, keys):
|
def __init__(self, keys):
|
||||||
super(FieldKeeper, self).__init__()
|
super(FieldKeeper, self).__init__()
|
||||||
self._target_keys = utils.str_or_list(keys)
|
self._target_keys = utils.str_or_list(keys)
|
||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
for key in record.keys():
|
for key in list(record.keys()):
|
||||||
if key not in self._target_keys:
|
if key not in self._target_keys:
|
||||||
del record[key]
|
del record[key]
|
||||||
return record
|
return record
|
||||||
@ -269,7 +269,7 @@ class FieldMerger(Filter):
|
|||||||
self._keep_fields = keep_fields
|
self._keep_fields = keep_fields
|
||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
for to_col, from_cols in self._field_mapping.iteritems():
|
for to_col, from_cols in self._field_mapping.items():
|
||||||
if self._keep_fields:
|
if self._keep_fields:
|
||||||
values = [record.get(col, None) for col in from_cols]
|
values = [record.get(col, None) for col in from_cols]
|
||||||
else:
|
else:
|
||||||
@ -301,7 +301,11 @@ class FieldAdder(Filter):
|
|||||||
self._field_name = field_name
|
self._field_name = field_name
|
||||||
self._field_value = field_value
|
self._field_value = field_value
|
||||||
if hasattr(self._field_value, '__iter__'):
|
if hasattr(self._field_value, '__iter__'):
|
||||||
self._field_value = iter(self._field_value).next
|
value_iter = iter(self._field_value)
|
||||||
|
if hasattr(value_iter, "next"):
|
||||||
|
self._field_value = value_iter.next
|
||||||
|
else:
|
||||||
|
self._field_value = value_iter.__next__
|
||||||
self._replace = replace
|
self._replace = replace
|
||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
@ -328,7 +332,7 @@ class FieldCopier(Filter):
|
|||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
# mapping is dest:source
|
# mapping is dest:source
|
||||||
for dest, source in self._copy_mapping.iteritems():
|
for dest, source in self._copy_mapping.items():
|
||||||
record[dest] = record[source]
|
record[dest] = record[source]
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@ -343,7 +347,7 @@ class FieldRenamer(Filter):
|
|||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
# mapping is dest:source
|
# mapping is dest:source
|
||||||
for dest, source in self._rename_mapping.iteritems():
|
for dest, source in self._rename_mapping.items():
|
||||||
record[dest] = record.pop(source)
|
record[dest] = record.pop(source)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@ -363,7 +367,7 @@ class Splitter(Filter):
|
|||||||
self._split_mapping = split_mapping
|
self._split_mapping = split_mapping
|
||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
for key, filters in self._split_mapping.iteritems():
|
for key, filters in self._split_mapping.items():
|
||||||
|
|
||||||
# if the key doesn't exist -- move on to next key
|
# if the key doesn't exist -- move on to next key
|
||||||
try:
|
try:
|
||||||
@ -479,7 +483,7 @@ class UnicodeFilter(Filter):
|
|||||||
self._errors = errors
|
self._errors = errors
|
||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
for key, value in record.iteritems():
|
for key, value in record.items():
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
record[key] = unicode(value, self._encoding, self._errors)
|
record[key] = unicode(value, self._encoding, self._errors)
|
||||||
elif isinstance(value, unicode):
|
elif isinstance(value, unicode):
|
||||||
@ -494,7 +498,7 @@ class StringFilter(Filter):
|
|||||||
self._errors = errors
|
self._errors = errors
|
||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
for key, value in record.iteritems():
|
for key, value in record.items():
|
||||||
if isinstance(value, unicode):
|
if isinstance(value, unicode):
|
||||||
record[key] = value.encode(self._encoding, self._errors)
|
record[key] = value.encode(self._encoding, self._errors)
|
||||||
return record
|
return record
|
||||||
@ -584,7 +588,7 @@ class NameCleaner(Filter):
|
|||||||
# if there is a match, remove original name and add pieces
|
# if there is a match, remove original name and add pieces
|
||||||
if match:
|
if match:
|
||||||
record.pop(key)
|
record.pop(key)
|
||||||
for k,v in match.groupdict().iteritems():
|
for k,v in match.groupdict().items():
|
||||||
record[self._name_prefix + k] = v
|
record[self._name_prefix + k] = v
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -4,8 +4,9 @@
|
|||||||
All sources must implement the iterable interface and return python
|
All sources must implement the iterable interface and return python
|
||||||
dictionaries.
|
dictionaries.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import unicode_literals
|
||||||
import string
|
import string
|
||||||
|
|
||||||
from saucebrush import utils
|
from saucebrush import utils
|
||||||
|
|
||||||
class CSVSource(object):
|
class CSVSource(object):
|
||||||
@ -25,8 +26,8 @@ class CSVSource(object):
|
|||||||
def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs):
|
def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs):
|
||||||
import csv
|
import csv
|
||||||
self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs)
|
self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs)
|
||||||
for _ in xrange(skiprows):
|
for _ in range(skiprows):
|
||||||
self._dictreader.next()
|
next(self._dictreader)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self._dictreader
|
return self._dictreader
|
||||||
@ -59,13 +60,18 @@ class FixedWidthFileSource(object):
|
|||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def next(self):
|
def __next__(self):
|
||||||
line = self._fwfile.next()
|
line = next(self._fwfile)
|
||||||
record = {}
|
record = {}
|
||||||
for name, range_ in self._fields_dict.iteritems():
|
for name, range_ in self._fields_dict.items():
|
||||||
record[name] = line[range_[0]:range_[1]].rstrip(self._fillchars)
|
record[name] = line[range_[0]:range_[1]].rstrip(self._fillchars)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
""" Keep Python 2 next() method that defers to __next__().
|
||||||
|
"""
|
||||||
|
return self.__next__()
|
||||||
|
|
||||||
|
|
||||||
class HtmlTableSource(object):
|
class HtmlTableSource(object):
|
||||||
""" Saucebrush source for reading data from an HTML table.
|
""" Saucebrush source for reading data from an HTML table.
|
||||||
@ -86,26 +92,32 @@ class HtmlTableSource(object):
|
|||||||
def __init__(self, htmlfile, id_or_num, fieldnames=None, skiprows=0):
|
def __init__(self, htmlfile, id_or_num, fieldnames=None, skiprows=0):
|
||||||
|
|
||||||
# extract the table
|
# extract the table
|
||||||
from BeautifulSoup import BeautifulSoup
|
from lxml.html import parse
|
||||||
soup = BeautifulSoup(htmlfile.read())
|
doc = parse(htmlfile).getroot()
|
||||||
if isinstance(id_or_num, int):
|
if isinstance(id_or_num, int):
|
||||||
table = soup.findAll('table')[id_or_num]
|
table = doc.cssselect('table')[id_or_num]
|
||||||
elif isinstance(id_or_num, str):
|
else:
|
||||||
table = soup.find('table', id=id_or_num)
|
table = doc.cssselect('table#%s' % id_or_num)
|
||||||
|
|
||||||
|
table = table[0] # get the first table
|
||||||
|
|
||||||
# skip the necessary number of rows
|
# skip the necessary number of rows
|
||||||
self._rows = table.findAll('tr')[skiprows:]
|
self._rows = table.cssselect('tr')[skiprows:]
|
||||||
|
|
||||||
# determine the fieldnames
|
# determine the fieldnames
|
||||||
if not fieldnames:
|
if not fieldnames:
|
||||||
self._fieldnames = [td.string
|
self._fieldnames = [td.text_content()
|
||||||
for td in self._rows[0].findAll(('td','th'))]
|
for td in self._rows[0].cssselect('td, th')]
|
||||||
|
skiprows += 1
|
||||||
else:
|
else:
|
||||||
self._fieldnames = fieldnames
|
self._fieldnames = fieldnames
|
||||||
|
|
||||||
|
# skip the necessary number of rows
|
||||||
|
self._rows = table.cssselect('tr')[skiprows:]
|
||||||
|
|
||||||
def process_tr(self):
|
def process_tr(self):
|
||||||
for row in self._rows:
|
for row in self._rows:
|
||||||
strings = [utils.string_dig(td) for td in row.findAll('td')]
|
strings = [td.text_content() for td in row.cssselect('td')]
|
||||||
yield dict(zip(self._fieldnames, strings))
|
yield dict(zip(self._fieldnames, strings))
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
@ -182,7 +194,7 @@ class SqliteSource(object):
|
|||||||
self._conn = sqlite3.connect(self._dbpath)
|
self._conn = sqlite3.connect(self._dbpath)
|
||||||
self._conn.row_factory = dict_factory
|
self._conn.row_factory = dict_factory
|
||||||
if self._conn_params:
|
if self._conn_params:
|
||||||
for param, value in self._conn_params.iteritems():
|
for param, value in self._conn_params.items():
|
||||||
setattr(self._conn, param, value)
|
setattr(self._conn, param, value)
|
||||||
|
|
||||||
def _process_query(self):
|
def _process_query(self):
|
||||||
@ -214,20 +226,20 @@ class FileSource(object):
|
|||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
# This method would be a lot cleaner with the proposed
|
# This method would be a lot cleaner with the proposed
|
||||||
# 'yield from' expression (PEP 380)
|
# 'yield from' expression (PEP 380)
|
||||||
if hasattr(self._input, '__read__'):
|
if hasattr(self._input, '__read__') or hasattr(self._input, 'read'):
|
||||||
for record in self._process_file(input):
|
for record in self._process_file(self._input):
|
||||||
yield record
|
yield record
|
||||||
elif isinstance(self._input, basestring):
|
elif isinstance(self._input, str):
|
||||||
with open(self._input) as f:
|
with open(self._input) as f:
|
||||||
for record in self._process_file(f):
|
for record in self._process_file(f):
|
||||||
yield record
|
yield record
|
||||||
elif hasattr(self._input, '__iter__'):
|
elif hasattr(self._input, '__iter__'):
|
||||||
for el in self._input:
|
for el in self._input:
|
||||||
if isinstance(el, basestring):
|
if isinstance(el, str):
|
||||||
with open(el) as f:
|
with open(el) as f:
|
||||||
for record in self._process_file(f):
|
for record in self._process_file(f):
|
||||||
yield record
|
yield record
|
||||||
elif hasattr(el, '__read__'):
|
elif hasattr(el, '__read__') or hasattr(el, 'read'):
|
||||||
for record in self._process_file(f):
|
for record in self._process_file(f):
|
||||||
yield record
|
yield record
|
||||||
|
|
||||||
@ -244,10 +256,11 @@ class JSONSource(FileSource):
|
|||||||
object.
|
object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _process_file(self, file):
|
def _process_file(self, f):
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
obj = json.load(file)
|
obj = json.load(f)
|
||||||
|
|
||||||
# If the top-level JSON object in the file is a list
|
# If the top-level JSON object in the file is a list
|
||||||
# then yield each element separately; otherwise, yield
|
# then yield each element separately; otherwise, yield
|
||||||
|
@ -5,7 +5,7 @@ import math
|
|||||||
|
|
||||||
def _average(values):
|
def _average(values):
|
||||||
""" Calculate the average of a list of values.
|
""" Calculate the average of a list of values.
|
||||||
|
|
||||||
:param values: an iterable of ints or floats to average
|
:param values: an iterable of ints or floats to average
|
||||||
"""
|
"""
|
||||||
value_count = len(values)
|
value_count = len(values)
|
||||||
@ -14,64 +14,64 @@ def _average(values):
|
|||||||
|
|
||||||
def _median(values):
|
def _median(values):
|
||||||
""" Calculate the median of a list of values.
|
""" Calculate the median of a list of values.
|
||||||
|
|
||||||
:param values: an iterable of ints or floats to calculate
|
:param values: an iterable of ints or floats to calculate
|
||||||
"""
|
"""
|
||||||
|
|
||||||
count = len(values)
|
count = len(values)
|
||||||
|
|
||||||
# bail early before sorting if 0 or 1 values in list
|
# bail early before sorting if 0 or 1 values in list
|
||||||
if count == 0:
|
if count == 0:
|
||||||
return None
|
return None
|
||||||
elif count == 1:
|
elif count == 1:
|
||||||
return values[0]
|
return values[0]
|
||||||
|
|
||||||
values = sorted(values)
|
values = sorted(values)
|
||||||
|
|
||||||
if count % 2 == 1:
|
if count % 2 == 1:
|
||||||
# odd number of items, return middle value
|
# odd number of items, return middle value
|
||||||
return float(values[count / 2])
|
return float(values[int(count / 2)])
|
||||||
else:
|
else:
|
||||||
# even number of items, return average of middle two items
|
# even number of items, return average of middle two items
|
||||||
mid = count / 2
|
mid = int(count / 2)
|
||||||
return sum(values[mid - 1:mid + 1]) / 2.0
|
return sum(values[mid - 1:mid + 1]) / 2.0
|
||||||
|
|
||||||
def _stddev(values, population=False):
|
def _stddev(values, population=False):
|
||||||
""" Calculate the standard deviation and variance of a list of values.
|
""" Calculate the standard deviation and variance of a list of values.
|
||||||
|
|
||||||
:param values: an iterable of ints or floats to calculate
|
:param values: an iterable of ints or floats to calculate
|
||||||
:param population: True if values represents entire population,
|
:param population: True if values represents entire population,
|
||||||
False if it is a sample of the population
|
False if it is a sample of the population
|
||||||
"""
|
"""
|
||||||
|
|
||||||
avg = _average(values)
|
avg = _average(values)
|
||||||
count = len(values) if population else len(values) - 1
|
count = len(values) if population else len(values) - 1
|
||||||
|
|
||||||
# square the difference between each value and the average
|
# square the difference between each value and the average
|
||||||
diffsq = ((i - avg) ** 2 for i in values)
|
diffsq = ((i - avg) ** 2 for i in values)
|
||||||
|
|
||||||
# the average of the squared differences
|
# the average of the squared differences
|
||||||
variance = sum(diffsq) / float(count)
|
variance = sum(diffsq) / float(count)
|
||||||
|
|
||||||
return (math.sqrt(variance), variance) # stddev is sqrt of variance
|
return (math.sqrt(variance), variance) # stddev is sqrt of variance
|
||||||
|
|
||||||
class StatsFilter(Filter):
|
class StatsFilter(Filter):
|
||||||
""" Base for all stats filters.
|
""" Base for all stats filters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, field, test=None):
|
def __init__(self, field, test=None):
|
||||||
self._field = field
|
self._field = field
|
||||||
self._test = test
|
self._test = test
|
||||||
|
|
||||||
def process_record(self, record):
|
def process_record(self, record):
|
||||||
if self._test is None or self._test(record):
|
if self._test is None or self._test(record):
|
||||||
self.process_field(record[self._field])
|
self.process_field(record[self._field])
|
||||||
return record
|
return record
|
||||||
|
|
||||||
def process_field(self, record):
|
def process_field(self, record):
|
||||||
raise NotImplementedError('process_field not defined in ' +
|
raise NotImplementedError('process_field not defined in ' +
|
||||||
self.__class__.__name__)
|
self.__class__.__name__)
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
raise NotImplementedError('value not defined in ' +
|
raise NotImplementedError('value not defined in ' +
|
||||||
self.__class__.__name__)
|
self.__class__.__name__)
|
||||||
@ -80,14 +80,14 @@ class Sum(StatsFilter):
|
|||||||
""" Calculate the sum of the values in a field. Field must contain either
|
""" Calculate the sum of the values in a field. Field must contain either
|
||||||
int or float values.
|
int or float values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, field, initial=0, **kwargs):
|
def __init__(self, field, initial=0, **kwargs):
|
||||||
super(Sum, self).__init__(field, **kwargs)
|
super(Sum, self).__init__(field, **kwargs)
|
||||||
self._value = initial
|
self._value = initial
|
||||||
|
|
||||||
def process_field(self, item):
|
def process_field(self, item):
|
||||||
self._value += item or 0
|
self._value += item or 0
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
return self._value
|
return self._value
|
||||||
|
|
||||||
@ -95,35 +95,35 @@ class Average(StatsFilter):
|
|||||||
""" Calculate the average (mean) of the values in a field. Field must
|
""" Calculate the average (mean) of the values in a field. Field must
|
||||||
contain either int or float values.
|
contain either int or float values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, field, initial=0, **kwargs):
|
def __init__(self, field, initial=0, **kwargs):
|
||||||
super(Average, self).__init__(field, **kwargs)
|
super(Average, self).__init__(field, **kwargs)
|
||||||
self._value = initial
|
self._value = initial
|
||||||
self._count = 0
|
self._count = 0
|
||||||
|
|
||||||
def process_field(self, item):
|
def process_field(self, item):
|
||||||
if item is not None:
|
if item is not None:
|
||||||
self._value += item
|
self._value += item
|
||||||
self._count += 1
|
self._count += 1
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
return self._value / float(self._count)
|
return self._value / float(self._count)
|
||||||
|
|
||||||
class Median(StatsFilter):
|
class Median(StatsFilter):
|
||||||
""" Calculate the median of the values in a field. Field must contain
|
""" Calculate the median of the values in a field. Field must contain
|
||||||
either int or float values.
|
either int or float values.
|
||||||
|
|
||||||
**This filter keeps a list of field values in memory.**
|
**This filter keeps a list of field values in memory.**
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, field, **kwargs):
|
def __init__(self, field, **kwargs):
|
||||||
super(Median, self).__init__(field, **kwargs)
|
super(Median, self).__init__(field, **kwargs)
|
||||||
self._values = []
|
self._values = []
|
||||||
|
|
||||||
def process_field(self, item):
|
def process_field(self, item):
|
||||||
if item is not None:
|
if item is not None:
|
||||||
self._values.append(item)
|
self._values.append(item)
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
return _median(self._values)
|
return _median(self._values)
|
||||||
|
|
||||||
@ -131,19 +131,19 @@ class MinMax(StatsFilter):
|
|||||||
""" Find the minimum and maximum values in a field. Field must contain
|
""" Find the minimum and maximum values in a field. Field must contain
|
||||||
either int or float values.
|
either int or float values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, field, **kwargs):
|
def __init__(self, field, **kwargs):
|
||||||
super(MinMax, self).__init__(field, **kwargs)
|
super(MinMax, self).__init__(field, **kwargs)
|
||||||
self._max = None
|
self._max = None
|
||||||
self._min = None
|
self._min = None
|
||||||
|
|
||||||
def process_field(self, item):
|
def process_field(self, item):
|
||||||
if item is not None:
|
if item is not None:
|
||||||
if self._max is None or item > self._max:
|
if self._max is None or item > self._max:
|
||||||
self._max = item
|
self._max = item
|
||||||
if self._min is None or item < self._min:
|
if self._min is None or item < self._min:
|
||||||
self._min = item
|
self._min = item
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
return (self._min, self._max)
|
return (self._min, self._max)
|
||||||
|
|
||||||
@ -156,24 +156,24 @@ class StandardDeviation(StatsFilter):
|
|||||||
|
|
||||||
**This filter keeps a list of field values in memory.**
|
**This filter keeps a list of field values in memory.**
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, field, **kwargs):
|
def __init__(self, field, **kwargs):
|
||||||
super(StandardDeviation, self).__init__(field, **kwargs)
|
super(StandardDeviation, self).__init__(field, **kwargs)
|
||||||
self._values = []
|
self._values = []
|
||||||
|
|
||||||
def process_field(self, item):
|
def process_field(self, item):
|
||||||
if item is not None:
|
if item is not None:
|
||||||
self._values.append(item)
|
self._values.append(item)
|
||||||
|
|
||||||
def average(self):
|
def average(self):
|
||||||
return _average(self._values)
|
return _average(self._values)
|
||||||
|
|
||||||
def median(self):
|
def median(self):
|
||||||
return _median(self._values)
|
return _median(self._values)
|
||||||
|
|
||||||
def value(self, population=False):
|
def value(self, population=False):
|
||||||
""" Return a tuple of (standard_deviation, variance).
|
""" Return a tuple of (standard_deviation, variance).
|
||||||
|
|
||||||
:param population: True if values represents entire population,
|
:param population: True if values represents entire population,
|
||||||
False if values is a sample. Default: False
|
False if values is a sample. Default: False
|
||||||
"""
|
"""
|
||||||
@ -185,34 +185,34 @@ class Histogram(StatsFilter):
|
|||||||
generates a basic and limited histogram useful for printing to the
|
generates a basic and limited histogram useful for printing to the
|
||||||
command line. The label_length attribute determines the padding and
|
command line. The label_length attribute determines the padding and
|
||||||
cut-off of the basic histogram labels.
|
cut-off of the basic histogram labels.
|
||||||
|
|
||||||
**This filters maintains a dict of unique field values in memory.**
|
**This filters maintains a dict of unique field values in memory.**
|
||||||
"""
|
"""
|
||||||
|
|
||||||
label_length = 6
|
label_length = 6
|
||||||
|
|
||||||
def __init__(self, field, **kwargs):
|
def __init__(self, field, **kwargs):
|
||||||
super(Histogram, self).__init__(field, **kwargs)
|
super(Histogram, self).__init__(field, **kwargs)
|
||||||
self._counter = collections.Counter()
|
self._counter = collections.Counter()
|
||||||
|
|
||||||
def process_field(self, item):
|
def process_field(self, item):
|
||||||
self._counter[self.prep_field(item)] += 1
|
self._counter[self.prep_field(item)] += 1
|
||||||
|
|
||||||
def prep_field(self, item):
|
def prep_field(self, item):
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
return self._counter.copy()
|
return self._counter.copy()
|
||||||
|
|
||||||
def in_order(self):
|
def in_order(self):
|
||||||
ordered = []
|
ordered = []
|
||||||
for key in sorted(self._counter.keys()):
|
for key in sorted(self._counter.keys()):
|
||||||
ordered.append((key, self._counter[key]))
|
ordered.append((key, self._counter[key]))
|
||||||
return ordered
|
return ordered
|
||||||
|
|
||||||
def most_common(self, n=None):
|
def most_common(self, n=None):
|
||||||
return self._counter.most_common(n)
|
return self._counter.most_common(n)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def as_string(self, occurences, label_length):
|
def as_string(self, occurences, label_length):
|
||||||
output = "\n"
|
output = "\n"
|
||||||
@ -220,6 +220,6 @@ class Histogram(StatsFilter):
|
|||||||
key_str = str(key).ljust(label_length)[:label_length]
|
key_str = str(key).ljust(label_length)[:label_length]
|
||||||
output += "%s %s\n" % (key_str, "*" * count)
|
output += "%s %s\n" % (key_str, "*" * count)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return Histogram.as_string(self.in_order(), label_length=self.label_length)
|
return Histogram.as_string(self.in_order(), label_length=self.label_length)
|
||||||
|
@ -1,47 +1,86 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
from contextlib import closing
|
||||||
|
from io import StringIO
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from cStringIO import StringIO
|
|
||||||
from saucebrush.emitters import DebugEmitter, CSVEmitter, CountEmitter
|
from saucebrush.emitters import (
|
||||||
|
DebugEmitter, CSVEmitter, CountEmitter, SqliteEmitter, SqlDumpEmitter)
|
||||||
|
|
||||||
class EmitterTestCase(unittest.TestCase):
|
class EmitterTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.output = StringIO()
|
|
||||||
|
|
||||||
def test_debug_emitter(self):
|
def test_debug_emitter(self):
|
||||||
de = DebugEmitter(self.output)
|
with closing(StringIO()) as output:
|
||||||
data = de.attach([1,2,3])
|
de = DebugEmitter(output)
|
||||||
for _ in data:
|
list(de.attach([1,2,3]))
|
||||||
pass
|
self.assertEqual(output.getvalue(), '1\n2\n3\n')
|
||||||
self.assertEquals(self.output.getvalue(), '1\n2\n3\n')
|
|
||||||
|
|
||||||
def test_csv_emitter(self):
|
|
||||||
ce = CSVEmitter(self.output, ('x','y','z'))
|
|
||||||
data = ce.attach([{'x':1,'y':2,'z':3}, {'x':5, 'y':5, 'z':5}])
|
|
||||||
for _ in data:
|
|
||||||
pass
|
|
||||||
self.assertEquals(self.output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
|
|
||||||
|
|
||||||
def test_count_emitter(self):
|
def test_count_emitter(self):
|
||||||
|
|
||||||
# values for test
|
# values for test
|
||||||
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
|
values = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
|
||||||
|
|
||||||
# test without of parameter
|
with closing(StringIO()) as output:
|
||||||
ce = CountEmitter(every=10, outfile=self.output, format="%(count)s records\n")
|
|
||||||
list(ce.attach(values))
|
# test without of parameter
|
||||||
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n')
|
ce = CountEmitter(every=10, outfile=output, format="%(count)s records\n")
|
||||||
ce.done()
|
list(ce.attach(values))
|
||||||
self.assertEquals(self.output.getvalue(), '10 records\n20 records\n22 records\n')
|
self.assertEqual(output.getvalue(), '10 records\n20 records\n')
|
||||||
|
ce.done()
|
||||||
# reset output
|
self.assertEqual(output.getvalue(), '10 records\n20 records\n22 records\n')
|
||||||
self.output.truncate(0)
|
|
||||||
|
with closing(StringIO()) as output:
|
||||||
# test with of parameter
|
|
||||||
ce = CountEmitter(every=10, outfile=self.output, of=len(values))
|
# test with of parameter
|
||||||
list(ce.attach(values))
|
ce = CountEmitter(every=10, outfile=output, of=len(values))
|
||||||
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n')
|
list(ce.attach(values))
|
||||||
ce.done()
|
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n')
|
||||||
self.assertEquals(self.output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n')
|
ce.done()
|
||||||
|
self.assertEqual(output.getvalue(), '10 of 22\n20 of 22\n22 of 22\n')
|
||||||
|
|
||||||
|
def test_csv_emitter(self):
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cStringIO # if Python 2.x then use old cStringIO
|
||||||
|
io = cStringIO.StringIO()
|
||||||
|
except:
|
||||||
|
io = StringIO() # if Python 3.x then use StringIO
|
||||||
|
|
||||||
|
with closing(io) as output:
|
||||||
|
ce = CSVEmitter(output, ('x','y','z'))
|
||||||
|
list(ce.attach([{'x':1, 'y':2, 'z':3}, {'x':5, 'y':5, 'z':5}]))
|
||||||
|
self.assertEqual(output.getvalue(), 'x,y,z\r\n1,2,3\r\n5,5,5\r\n')
|
||||||
|
|
||||||
|
def test_sqlite_emitter(self):
|
||||||
|
|
||||||
|
import sqlite3, tempfile
|
||||||
|
|
||||||
|
with closing(tempfile.NamedTemporaryFile(suffix='.db')) as f:
|
||||||
|
db_path = f.name
|
||||||
|
|
||||||
|
sle = SqliteEmitter(db_path, 'testtable', fieldnames=('a','b','c'))
|
||||||
|
list(sle.attach([{'a': '1', 'b': '2', 'c': '3'}]))
|
||||||
|
sle.done()
|
||||||
|
|
||||||
|
with closing(sqlite3.connect(db_path)) as conn:
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute("""SELECT a, b, c FROM testtable""")
|
||||||
|
results = cur.fetchall()
|
||||||
|
|
||||||
|
os.unlink(db_path)
|
||||||
|
|
||||||
|
self.assertEqual(results, [('1', '2', '3')])
|
||||||
|
|
||||||
|
def test_sql_dump_emitter(self):
|
||||||
|
|
||||||
|
with closing(StringIO()) as bffr:
|
||||||
|
|
||||||
|
sde = SqlDumpEmitter(bffr, 'testtable', ('a', 'b'))
|
||||||
|
list(sde.attach([{'a': 1, 'b': '2'}]))
|
||||||
|
sde.done()
|
||||||
|
|
||||||
|
self.assertEqual(bffr.getvalue(), "INSERT INTO `testtable` (`a`,`b`) VALUES (1,'2');\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -57,41 +57,42 @@ class FilterTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def assert_filter_result(self, filter_obj, expected_data):
|
def assert_filter_result(self, filter_obj, expected_data):
|
||||||
result = filter_obj.attach(self._simple_data())
|
result = filter_obj.attach(self._simple_data())
|
||||||
self.assertEquals(list(result), expected_data)
|
self.assertEqual(list(result), expected_data)
|
||||||
|
|
||||||
def test_reject_record(self):
|
def test_reject_record(self):
|
||||||
recipe = DummyRecipe()
|
recipe = DummyRecipe()
|
||||||
f = Doubler()
|
f = Doubler()
|
||||||
result = f.attach([1,2,3], recipe=recipe)
|
result = f.attach([1,2,3], recipe=recipe)
|
||||||
result.next() # next has to be called for attach to take effect
|
# next has to be called for attach to take effect
|
||||||
|
next(result)
|
||||||
f.reject_record('bad', 'this one was bad')
|
f.reject_record('bad', 'this one was bad')
|
||||||
|
|
||||||
# ensure that the rejection propagated to the recipe
|
# ensure that the rejection propagated to the recipe
|
||||||
self.assertEquals('bad', recipe.rejected_record)
|
self.assertEqual('bad', recipe.rejected_record)
|
||||||
self.assertEquals('this one was bad', recipe.rejected_msg)
|
self.assertEqual('this one was bad', recipe.rejected_msg)
|
||||||
|
|
||||||
def test_simple_filter(self):
|
def test_simple_filter(self):
|
||||||
df = Doubler()
|
df = Doubler()
|
||||||
result = df.attach([1,2,3])
|
result = df.attach([1,2,3])
|
||||||
|
|
||||||
# ensure we got a generator that yields 2,4,6
|
# ensure we got a generator that yields 2,4,6
|
||||||
self.assertEquals(type(result), types.GeneratorType)
|
self.assertEqual(type(result), types.GeneratorType)
|
||||||
self.assertEquals(list(result), [2,4,6])
|
self.assertEqual(list(result), [2,4,6])
|
||||||
|
|
||||||
def test_simple_filter_return_none(self):
|
def test_simple_filter_return_none(self):
|
||||||
cf = OddRemover()
|
cf = OddRemover()
|
||||||
result = cf.attach(range(10))
|
result = cf.attach(range(10))
|
||||||
|
|
||||||
# ensure only even numbers remain
|
# ensure only even numbers remain
|
||||||
self.assertEquals(list(result), [0,2,4,6,8])
|
self.assertEqual(list(result), [0,2,4,6,8])
|
||||||
|
|
||||||
def test_simple_yield_filter(self):
|
def test_simple_yield_filter(self):
|
||||||
lf = ListFlattener()
|
lf = ListFlattener()
|
||||||
result = lf.attach([[1],[2,3],[4,5,6]])
|
result = lf.attach([[1],[2,3],[4,5,6]])
|
||||||
|
|
||||||
# ensure we got a generator that yields 1,2,3,4,5,6
|
# ensure we got a generator that yields 1,2,3,4,5,6
|
||||||
self.assertEquals(type(result), types.GeneratorType)
|
self.assertEqual(type(result), types.GeneratorType)
|
||||||
self.assertEquals(list(result), [1,2,3,4,5,6])
|
self.assertEqual(list(result), [1,2,3,4,5,6])
|
||||||
|
|
||||||
def test_simple_field_filter(self):
|
def test_simple_field_filter(self):
|
||||||
ff = FieldDoubler(['a', 'c'])
|
ff = FieldDoubler(['a', 'c'])
|
||||||
@ -107,7 +108,7 @@ class FilterTestCase(unittest.TestCase):
|
|||||||
result = cf.attach(range(10))
|
result = cf.attach(range(10))
|
||||||
|
|
||||||
# ensure only even numbers remain
|
# ensure only even numbers remain
|
||||||
self.assertEquals(list(result), [0,2,4,6,8])
|
self.assertEqual(list(result), [0,2,4,6,8])
|
||||||
|
|
||||||
### Tests for Subrecord
|
### Tests for Subrecord
|
||||||
|
|
||||||
@ -123,7 +124,7 @@ class FilterTestCase(unittest.TestCase):
|
|||||||
sf = SubrecordFilter('a', NonModifyingFieldDoubler('b'))
|
sf = SubrecordFilter('a', NonModifyingFieldDoubler('b'))
|
||||||
result = sf.attach(data)
|
result = sf.attach(data)
|
||||||
|
|
||||||
self.assertEquals(list(result), expected)
|
self.assertEqual(list(result), expected)
|
||||||
|
|
||||||
def test_subrecord_filter_deep(self):
|
def test_subrecord_filter_deep(self):
|
||||||
data = [{'a': {'d':[{'b': 2}, {'b': 4}]}},
|
data = [{'a': {'d':[{'b': 2}, {'b': 4}]}},
|
||||||
@ -137,7 +138,7 @@ class FilterTestCase(unittest.TestCase):
|
|||||||
sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b'))
|
sf = SubrecordFilter('a.d', NonModifyingFieldDoubler('b'))
|
||||||
result = sf.attach(data)
|
result = sf.attach(data)
|
||||||
|
|
||||||
self.assertEquals(list(result), expected)
|
self.assertEqual(list(result), expected)
|
||||||
|
|
||||||
def test_subrecord_filter_nonlist(self):
|
def test_subrecord_filter_nonlist(self):
|
||||||
data = [
|
data = [
|
||||||
@ -155,7 +156,7 @@ class FilterTestCase(unittest.TestCase):
|
|||||||
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
|
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
|
||||||
result = sf.attach(data)
|
result = sf.attach(data)
|
||||||
|
|
||||||
self.assertEquals(list(result), expected)
|
self.assertEqual(list(result), expected)
|
||||||
|
|
||||||
def test_subrecord_filter_list_in_path(self):
|
def test_subrecord_filter_list_in_path(self):
|
||||||
data = [
|
data = [
|
||||||
@ -173,7 +174,7 @@ class FilterTestCase(unittest.TestCase):
|
|||||||
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
|
sf = SubrecordFilter('a.b', NonModifyingFieldDoubler('c'))
|
||||||
result = sf.attach(data)
|
result = sf.attach(data)
|
||||||
|
|
||||||
self.assertEquals(list(result), expected)
|
self.assertEqual(list(result), expected)
|
||||||
|
|
||||||
def test_conditional_path(self):
|
def test_conditional_path(self):
|
||||||
|
|
||||||
@ -202,7 +203,7 @@ class FilterTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_field_keeper(self):
|
def test_field_keeper(self):
|
||||||
fk = FieldKeeper(['c'])
|
fk = FieldKeeper(['c'])
|
||||||
|
|
||||||
# check against expected results
|
# check against expected results
|
||||||
expected_data = [{'c':3}, {'c':5}, {'c':100}]
|
expected_data = [{'c':3}, {'c':5}, {'c':100}]
|
||||||
self.assert_filter_result(fk, expected_data)
|
self.assert_filter_result(fk, expected_data)
|
||||||
@ -295,7 +296,7 @@ class FilterTestCase(unittest.TestCase):
|
|||||||
expected_data = [{'a': 77}, {'a':33}]
|
expected_data = [{'a': 77}, {'a':33}]
|
||||||
result = u.attach(in_data)
|
result = u.attach(in_data)
|
||||||
|
|
||||||
self.assertEquals(list(result), expected_data)
|
self.assertEqual(list(result), expected_data)
|
||||||
|
|
||||||
# TODO: unicode & string filter tests
|
# TODO: unicode & string filter tests
|
||||||
|
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
from io import BytesIO, StringIO
|
||||||
import unittest
|
import unittest
|
||||||
import cStringIO
|
|
||||||
from saucebrush.sources import CSVSource, FixedWidthFileSource
|
from saucebrush.sources import (
|
||||||
|
CSVSource, FixedWidthFileSource, HtmlTableSource, JSONSource)
|
||||||
|
|
||||||
class SourceTestCase(unittest.TestCase):
|
class SourceTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@ -9,14 +12,14 @@ class SourceTestCase(unittest.TestCase):
|
|||||||
1,2,3
|
1,2,3
|
||||||
5,5,5
|
5,5,5
|
||||||
1,10,100'''
|
1,10,100'''
|
||||||
return cStringIO.StringIO(data)
|
return StringIO(data)
|
||||||
|
|
||||||
def test_csv_source_basic(self):
|
def test_csv_source_basic(self):
|
||||||
source = CSVSource(self._get_csv())
|
source = CSVSource(self._get_csv())
|
||||||
expected_data = [{'a':'1', 'b':'2', 'c':'3'},
|
expected_data = [{'a':'1', 'b':'2', 'c':'3'},
|
||||||
{'a':'5', 'b':'5', 'c':'5'},
|
{'a':'5', 'b':'5', 'c':'5'},
|
||||||
{'a':'1', 'b':'10', 'c':'100'}]
|
{'a':'1', 'b':'10', 'c':'100'}]
|
||||||
self.assertEquals(list(source), expected_data)
|
self.assertEqual(list(source), expected_data)
|
||||||
|
|
||||||
def test_csv_source_fieldnames(self):
|
def test_csv_source_fieldnames(self):
|
||||||
source = CSVSource(self._get_csv(), ['x','y','z'])
|
source = CSVSource(self._get_csv(), ['x','y','z'])
|
||||||
@ -24,35 +27,59 @@ class SourceTestCase(unittest.TestCase):
|
|||||||
{'x':'1', 'y':'2', 'z':'3'},
|
{'x':'1', 'y':'2', 'z':'3'},
|
||||||
{'x':'5', 'y':'5', 'z':'5'},
|
{'x':'5', 'y':'5', 'z':'5'},
|
||||||
{'x':'1', 'y':'10', 'z':'100'}]
|
{'x':'1', 'y':'10', 'z':'100'}]
|
||||||
self.assertEquals(list(source), expected_data)
|
self.assertEqual(list(source), expected_data)
|
||||||
|
|
||||||
def test_csv_source_skiprows(self):
|
def test_csv_source_skiprows(self):
|
||||||
source = CSVSource(self._get_csv(), skiprows=1)
|
source = CSVSource(self._get_csv(), skiprows=1)
|
||||||
expected_data = [{'a':'5', 'b':'5', 'c':'5'},
|
expected_data = [{'a':'5', 'b':'5', 'c':'5'},
|
||||||
{'a':'1', 'b':'10', 'c':'100'}]
|
{'a':'1', 'b':'10', 'c':'100'}]
|
||||||
self.assertEquals(list(source), expected_data)
|
self.assertEqual(list(source), expected_data)
|
||||||
|
|
||||||
def test_fixed_width_source(self):
|
def test_fixed_width_source(self):
|
||||||
data = cStringIO.StringIO('JamesNovember 3 1986\nTim September151999')
|
data = StringIO('JamesNovember 3 1986\nTim September151999')
|
||||||
fields = (('name',5), ('month',9), ('day',2), ('year',4))
|
fields = (('name',5), ('month',9), ('day',2), ('year',4))
|
||||||
source = FixedWidthFileSource(data, fields)
|
source = FixedWidthFileSource(data, fields)
|
||||||
expected_data = [{'name':'James', 'month':'November', 'day':'3',
|
expected_data = [{'name':'James', 'month':'November', 'day':'3',
|
||||||
'year':'1986'},
|
'year':'1986'},
|
||||||
{'name':'Tim', 'month':'September', 'day':'15',
|
{'name':'Tim', 'month':'September', 'day':'15',
|
||||||
'year':'1999'}]
|
'year':'1999'}]
|
||||||
self.assertEquals(list(source), expected_data)
|
self.assertEqual(list(source), expected_data)
|
||||||
|
|
||||||
def test_fixed_width_source(self):
|
def test_json_source(self):
|
||||||
data = cStringIO.StringIO('JamesNovember.3.1986\nTim..September151999')
|
|
||||||
fields = (('name',5), ('month',9), ('day',2), ('year',4))
|
|
||||||
source = FixedWidthFileSource(data, fields, fillchars='.')
|
|
||||||
expected_data = [{'name':'James', 'month':'November', 'day':'3',
|
|
||||||
'year':'1986'},
|
|
||||||
{'name':'Tim', 'month':'September', 'day':'15',
|
|
||||||
'year':'1999'}]
|
|
||||||
self.assertEquals(list(source), expected_data)
|
|
||||||
|
|
||||||
|
content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""")
|
||||||
|
|
||||||
|
js = JSONSource(content)
|
||||||
|
self.assertEqual(list(js), [{'a': 1, 'b': '2', 'c': 3}])
|
||||||
|
|
||||||
|
def test_html_table_source(self):
|
||||||
|
|
||||||
|
content = StringIO("""
|
||||||
|
<html>
|
||||||
|
<table id="thetable">
|
||||||
|
<tr>
|
||||||
|
<th>a</th>
|
||||||
|
<th>b</th>
|
||||||
|
<th>c</th>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>1</td>
|
||||||
|
<td>2</td>
|
||||||
|
<td>3</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</html>
|
||||||
|
""")
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
import lxml
|
||||||
|
|
||||||
|
hts = HtmlTableSource(content, 'thetable')
|
||||||
|
self.assertEqual(list(hts), [{'a': '1', 'b': '2', 'c': '3'}])
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
self.skipTest("lxml is not installed")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import urllib2
|
|
||||||
|
try:
|
||||||
|
from urllib.request import urlopen # attemp py3 first
|
||||||
|
except ImportError:
|
||||||
|
from urllib2 import urlopen # fallback to py2
|
||||||
|
|
||||||
"""
|
"""
|
||||||
General utilities used within saucebrush that may be useful elsewhere.
|
General utilities used within saucebrush that may be useful elsewhere.
|
||||||
"""
|
"""
|
||||||
@ -20,38 +25,23 @@ def get_django_model(dj_settings, app_label, model_name):
|
|||||||
from django.db.models import get_model
|
from django.db.models import get_model
|
||||||
return get_model(app_label, model_name)
|
return get_model(app_label, model_name)
|
||||||
|
|
||||||
|
|
||||||
def string_dig(element, separator=''):
|
|
||||||
"""
|
|
||||||
Dig into BeautifulSoup HTML elements looking for inner strings.
|
|
||||||
|
|
||||||
If element resembled: <p><b>test</b><em>test</em></p>
|
|
||||||
then string_dig(element, '~') would return test~test
|
|
||||||
"""
|
|
||||||
if element.string:
|
|
||||||
return element.string
|
|
||||||
else:
|
|
||||||
return separator.join([string_dig(child)
|
|
||||||
for child in element.findAll(True)])
|
|
||||||
|
|
||||||
|
|
||||||
def flatten(item, prefix='', separator='_', keys=None):
|
def flatten(item, prefix='', separator='_', keys=None):
|
||||||
"""
|
"""
|
||||||
Flatten nested dictionary into one with its keys concatenated together.
|
Flatten nested dictionary into one with its keys concatenated together.
|
||||||
|
|
||||||
>>> flatten({'a':1, 'b':{'c':2}, 'd':[{'e':{'r':7}}, {'e':5}],
|
>>> flatten({'a':1, 'b':{'c':2}, 'd':[{'e':{'r':7}}, {'e':5}],
|
||||||
'f':{'g':{'h':6}}})
|
'f':{'g':{'h':6}}})
|
||||||
{'a': 1, 'b_c': 2, 'd': [{'e_r': 7}, {'e': 5}], 'f_g_h': 6}
|
{'a': 1, 'b_c': 2, 'd': [{'e_r': 7}, {'e': 5}], 'f_g_h': 6}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# update dictionaries recursively
|
# update dictionaries recursively
|
||||||
|
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
# don't prepend a leading _
|
# don't prepend a leading _
|
||||||
if prefix != '':
|
if prefix != '':
|
||||||
prefix += separator
|
prefix += separator
|
||||||
retval = {}
|
retval = {}
|
||||||
for key, value in item.iteritems():
|
for key, value in item.items():
|
||||||
if (not keys) or (key in keys):
|
if (not keys) or (key in keys):
|
||||||
retval.update(flatten(value, prefix + key, separator, keys))
|
retval.update(flatten(value, prefix + key, separator, keys))
|
||||||
else:
|
else:
|
||||||
@ -60,9 +50,8 @@ def flatten(item, prefix='', separator='_', keys=None):
|
|||||||
#elif isinstance(item, (tuple, list)):
|
#elif isinstance(item, (tuple, list)):
|
||||||
# return {prefix: [flatten(i, prefix, separator, keys) for i in item]}
|
# return {prefix: [flatten(i, prefix, separator, keys) for i in item]}
|
||||||
else:
|
else:
|
||||||
print item, prefix
|
|
||||||
return {prefix: item}
|
return {prefix: item}
|
||||||
|
|
||||||
def str_or_list(obj):
|
def str_or_list(obj):
|
||||||
if isinstance(obj, str):
|
if isinstance(obj, str):
|
||||||
return [obj]
|
return [obj]
|
||||||
@ -76,7 +65,7 @@ def str_or_list(obj):
|
|||||||
class Files(object):
|
class Files(object):
|
||||||
""" Iterate over multiple files as a single file. Pass the paths of the
|
""" Iterate over multiple files as a single file. Pass the paths of the
|
||||||
files as arguments to the class constructor:
|
files as arguments to the class constructor:
|
||||||
|
|
||||||
for line in Files('/path/to/file/a', '/path/to/file/b'):
|
for line in Files('/path/to/file/a', '/path/to/file/b'):
|
||||||
pass
|
pass
|
||||||
"""
|
"""
|
||||||
@ -105,15 +94,15 @@ class Files(object):
|
|||||||
|
|
||||||
class RemoteFile(object):
|
class RemoteFile(object):
|
||||||
""" Stream data from a remote file.
|
""" Stream data from a remote file.
|
||||||
|
|
||||||
:param url: URL to remote file
|
:param url: URL to remote file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, url):
|
def __init__(self, url):
|
||||||
self._url = url
|
self._url = url
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
resp = urllib2.urlopen(self._url)
|
resp = urlopen(self._url)
|
||||||
for line in resp:
|
for line in resp:
|
||||||
yield line.rstrip()
|
yield line.rstrip()
|
||||||
resp.close()
|
resp.close()
|
Loading…
Reference in New Issue
Block a user