diff --git a/saucebrush/sources.py b/saucebrush/sources.py index 9ebe532..92844a7 100644 --- a/saucebrush/sources.py +++ b/saucebrush/sources.py @@ -226,20 +226,20 @@ class FileSource(object): def __iter__(self): # This method would be a lot cleaner with the proposed # 'yield from' expression (PEP 380) - if hasattr(self._input, '__read__'): - for record in self._process_file(input): + if hasattr(self._input, '__read__') or hasattr(self._input, 'read'): + for record in self._process_file(self._input): yield record - elif isinstance(self._input, basestring): + elif isinstance(self._input, str): with open(self._input) as f: for record in self._process_file(f): yield record elif hasattr(self._input, '__iter__'): for el in self._input: - if isinstance(el, basestring): + if isinstance(el, str): with open(el) as f: for record in self._process_file(f): yield record - elif hasattr(el, '__read__'): + elif hasattr(el, '__read__') or hasattr(el, 'read'): for record in self._process_file(f): yield record @@ -256,10 +256,11 @@ class JSONSource(FileSource): object. """ - def _process_file(self, file): + def _process_file(self, f): + import json - obj = json.load(file) + obj = json.load(f) # If the top-level JSON object in the file is a list # then yield each element separately; otherwise, yield diff --git a/saucebrush/tests/sources.py b/saucebrush/tests/sources.py index 1db434c..83c645f 100644 --- a/saucebrush/tests/sources.py +++ b/saucebrush/tests/sources.py @@ -2,7 +2,8 @@ from __future__ import unicode_literals from io import BytesIO, StringIO import unittest -from saucebrush.sources import CSVSource, FixedWidthFileSource, HtmlTableSource +from saucebrush.sources import ( + CSVSource, FixedWidthFileSource, HtmlTableSource, JSONSource) class SourceTestCase(unittest.TestCase): @@ -44,6 +45,13 @@ class SourceTestCase(unittest.TestCase): 'year':'1999'}] self.assertEqual(list(source), expected_data) + def test_json_source(self): + + content = StringIO("""[{"a": 1, "b": "2", "c": 3}]""") + + js = JSONSource(content) + self.assertEqual(list(js), [{'a': 1, 'b': '2', 'c': 3}]) + def test_html_table_source(self): content = StringIO("""