str_or_list for FieldFilters

This commit is contained in:
James Turk 2008-11-20 15:37:45 +00:00
parent aff9f2295c
commit a585fe95f2
2 changed files with 24 additions and 14 deletions

View File

@ -64,7 +64,7 @@ class FieldFilter(Filter):
def __init__(self, keys): def __init__(self, keys):
super(FieldFilter, self).__init__() super(FieldFilter, self).__init__()
self._target_keys = keys self._target_keys = utils.str_or_list(keys)
def process_record(self, record): def process_record(self, record):
""" Calls process_field on all keys passed to __init__. """ """ Calls process_field on all keys passed to __init__. """
@ -141,7 +141,7 @@ class FieldRemover(Filter):
def __init__(self, keys): def __init__(self, keys):
super(FieldRemover, self).__init__() super(FieldRemover, self).__init__()
self._target_keys = keys self._target_keys = utils.str_or_list(keys)
def process_record(self, record): def process_record(self, record):
for key in self._target_keys: for key in self._target_keys:
@ -194,9 +194,6 @@ class FieldAdder(Filter):
def __init__(self, field_name, field_value): def __init__(self, field_name, field_value):
super(FieldAdder, self).__init__() super(FieldAdder, self).__init__()
self._field_name = field_name self._field_name = field_name
try:
self._field_value = iter(field_value).next
except TypeError:
self._field_value = field_value self._field_value = field_value
def process_record(self, record): def process_record(self, record):
@ -318,11 +315,11 @@ class Flattener(FieldFilter):
class DictFlattener(Filter): class DictFlattener(Filter):
def __init__(self, keys, separator='_'): def __init__(self, keys, separator='_'):
super(DictFlattener, self).__init__() super(DictFlattener, self).__init__()
self._keys = keys self._keys = utils.str_or_list(keys)
self._separator = separator self._separator = separator
def process_record(self, record): def process_record(self, record):
utils.flatten(record, keys=self._keys, separator=self._separator) return utils.flatten(record, keys=self._keys, separator=self._separator)
class Unique(ConditionalFilter): class Unique(ConditionalFilter):
@ -403,11 +400,13 @@ class NameCleaner(Filter):
(?:\s+(?P<suffix>JR\.?|II|III|IV))? (?:\s+(?P<suffix>JR\.?|II|III|IV))?
\s*$''', re.VERBOSE | re.IGNORECASE) \s*$''', re.VERBOSE | re.IGNORECASE)
def __init__(self, keys, name_formats=None): def __init__(self, keys, prefix='', formats=None, nomatch_name=None):
super(NameCleaner, self).__init__() super(NameCleaner, self).__init__()
self._keys = keys self._keys = utils.str_or_list(keys)
if name_formats: self._name_prefix = prefix
self._name_formats = name_formats self._nomatch_name = nomatch_name
if formats:
self._name_formats = formats
else: else:
self._name_formats = [self.FIRST_LAST, self.LAST_FIRST] self._name_formats = [self.FIRST_LAST, self.LAST_FIRST]
@ -424,8 +423,13 @@ class NameCleaner(Filter):
if match: if match:
record.pop(key) record.pop(key)
for k,v in match.groupdict().iteritems(): for k,v in match.groupdict().iteritems():
record[k] = v record[self._name_prefix + k] = v
break break
# can add else statement here to log non-names
# if there is no match, move name into nomatch_name
else:
if self._nomatch_name:
record.pop(key)
record[self._nomatch_name] = name
return record return record

View File

@ -61,6 +61,12 @@ def flatten(item, prefix='', separator='_', keys=None):
print item, prefix print item, prefix
return {prefix: item} return {prefix: item}
def str_or_list(obj):
    """ Wrap a lone string in a list; return any other object unchanged.

    Lets callers accept either a single key or a collection of keys and
    then iterate over the result uniformly.

    obj -- a string (returned as ``[obj]``) or any other object, which is
           returned as-is without copying or validation
    """
    try:
        # Python 2: basestring covers both str and unicode, so a unicode
        # key is wrapped too instead of being iterated character by character
        string_types = basestring
    except NameError:
        # interpreters without basestring have a single text type
        string_types = str

    if isinstance(obj, string_types):
        return [obj]
    return obj
def dotted_key_lookup(dict_, dotted_key, default=KeyError, separator='.'): def dotted_key_lookup(dict_, dotted_key, default=KeyError, separator='.'):
""" """
Do a lookup within dict_ by the various elements of dotted_key. Do a lookup within dict_ by the various elements of dotted_key.