diff --git a/saucebrush/sources.py b/saucebrush/sources.py index 91461f3..bebeeae 100644 --- a/saucebrush/sources.py +++ b/saucebrush/sources.py @@ -154,6 +154,12 @@ class MongoDBSource(object): for doc in self.collection.find(self.spec): yield dict(doc) +# dict_factory for sqlite source +def dict_factory(cursor, row): + d = { } + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d class SqliteSource(object): """ Source that reads from a sqlite database. @@ -164,35 +170,32 @@ class SqliteSource(object): """ def __init__(self, dbpath, query, args=None, conn_params=None): + + import sqlite3 + self._dbpath = dbpath self._query = query self._args = args or [] self._conn_params = conn_params or [] - - def _process_query(self): - - import sqlite3 - def dict_factory(cursor, row): - d = { } - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d - - conn = sqlite3.connect(self._dbpath) - conn.row_factory = dict_factory - + # setup connection + self._conn = sqlite3.connect(self._dbpath) + self._conn.row_factory = dict_factory if self._conn_params: for param, value in self._conn_params.iteritems(): - setattr(conn, param, value) + setattr(self._conn, param, value) - cursor = conn.cursor() + def _process_query(self): + + cursor = self._conn.cursor() for row in cursor.execute(self._query, self._args): yield row cursor.close() - conn.close() def __iter__(self): return self._process_query() + + def done(self): + self._conn.close()