diff --git a/oyster/core.py b/oyster/core.py index 194c99a..bb9ce69 100644 --- a/oyster/core.py +++ b/oyster/core.py @@ -83,7 +83,7 @@ class Kernel(object): self.doc_classes[doc_class] = properties - def track_url(self, url, doc_class, **kwargs): + def track_url(self, url, doc_class, id=None, **kwargs): """ Add a URL to the set of tracked URLs, accessible via a given filename. @@ -99,8 +99,11 @@ class Kernel(object): # if data is already tracked and this is just a duplicate call # return the original object if tracked: + # only check id if id was passed in + id_matches = (tracked['_id'] == id) if id else True if (tracked['metadata'] == kwargs and - tracked['doc_class'] == doc_class): + tracked['doc_class'] == doc_class and + id_matches): return tracked['_id'] else: self.log('track', url=url, error='tracking conflict') @@ -108,9 +111,13 @@ class Kernel(object): 'metadata' % url) self.log('track', url=url) - return self.db.tracked.insert(dict(url=url, doc_class=doc_class, - _random=random.randint(0, sys.maxint), - versions=[], metadata=kwargs)) + + newdoc = dict(url=url, doc_class=doc_class, + _random=random.randint(0, sys.maxint), + versions=[], metadata=kwargs) + if id: + newdoc['_id'] = id + return self.db.tracked.insert(newdoc) def md5_versioning(self, olddata, newdata): diff --git a/oyster/tests/test_kernel.py b/oyster/tests/test_kernel.py index 7be0642..8d9c28b 100644 --- a/oyster/tests/test_kernel.py +++ b/oyster/tests/test_kernel.py @@ -78,12 +78,24 @@ class KernelTests(TestCase): id2 = self.kernel.track_url('http://example.com', 'default', pi=3) assert id1 == id2 - # can't track same URL twice with different metadata + # test setting id + out = self.kernel.track_url('http://example.com/2', 'default', + 'fixed-id') + assert out == 'fixed-id' + + # can't track same URL twice with different id + assert_raises(ValueError, self.kernel.track_url, 'http://example.com', + 'default', 'hard-coded-id') + # logged error + assert self.kernel.db.logs.find_one({'error': 'tracking conflict'}) + + # ... with different metadata assert_raises(ValueError, self.kernel.track_url, 'http://example.com', 'default') # logged error assert self.kernel.db.logs.find_one({'error': 'tracking conflict'}) + # ... different doc class assert_raises(ValueError, self.kernel.track_url, 'http://example.com', 'special-doc-class', pi=3) # logged error