Add replace and quiet options to SqliteEmitter, and pass extra keyword arguments through to csv.DictReader in CSVSource

This commit is contained in:
Jeremy Carbaugh 2009-09-01 17:45:05 -04:00
parent 1d05d434b9
commit dcda3db140
2 changed files with 14 additions and 5 deletions

View File

@@ -85,24 +85,33 @@ class SqliteEmitter(Emitter):
as a third parameter to SqliteEmitter.) as a third parameter to SqliteEmitter.)
""" """
def __init__(self, dbname, table_name, fieldnames=None): def __init__(self, dbname, table_name, fieldnames=None, replace=False, quiet=False):
super(SqliteEmitter, self).__init__() super(SqliteEmitter, self).__init__()
import sqlite3 import sqlite3
self._conn = sqlite3.connect(dbname) self._conn = sqlite3.connect(dbname)
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._table_name = table_name self._table_name = table_name
self._replace = replace
self._quiet = quiet
if fieldnames: if fieldnames:
create = "CREATE TABLE IF NOT EXISTS %s (%s)" % (table_name, create = "CREATE TABLE IF NOT EXISTS %s (%s)" % (table_name,
', '.join([' '.join((field, 'TEXT')) for field in fieldnames])) ', '.join([' '.join((field, 'TEXT')) for field in fieldnames]))
self._cursor.execute(create) self._cursor.execute(create)
def emit_record(self, record): def emit_record(self, record):
import sqlite3
# input should be escaped with ? if data isn't trusted # input should be escaped with ? if data isn't trusted
qmarks = ','.join(('?',) * len(record)) qmarks = ','.join(('?',) * len(record))
insert = 'INSERT INTO %s (%s) VALUES (%s)' % (self._table_name, insert = 'INSERT OR REPLACE' if self._replace else 'INSERT'
insert = '%s INTO %s (%s) VALUES (%s)' % (insert, self._table_name,
','.join(record.keys()), ','.join(record.keys()),
qmarks) qmarks)
self._cursor.execute(insert, record.values()) try:
self._cursor.execute(insert, record.values())
except sqlite3.IntegrityError, ie:
if not self._quiet:
raise ie
self.reject_record(record, ie.message)
def done(self): def done(self):
self._conn.commit() self._conn.commit()

View File

@@ -22,9 +22,9 @@ class CSVSource(object):
ignoring the first row (presumed to be column names). ignoring the first row (presumed to be column names).
""" """
def __init__(self, csvfile, fieldnames=None, skiprows=0, delimiter=','): def __init__(self, csvfile, fieldnames=None, skiprows=0, **kwargs):
import csv import csv
self._dictreader = csv.DictReader(csvfile, fieldnames, delimiter=delimiter) self._dictreader = csv.DictReader(csvfile, fieldnames, **kwargs)
for _ in xrange(skiprows): for _ in xrange(skiprows):
self._dictreader.next() self._dictreader.next()