diff --git a/fkreplace/__init__.py b/fkreplace/__init__.py index e38d0c4..e8ded47 100644 --- a/fkreplace/__init__.py +++ b/fkreplace/__init__.py @@ -1,10 +1,21 @@ -# thanks to https://djangosnippets.org/snippets/2283/ for inspiration on m2m -def migrate(from_obj, to_obj): +def merge(from_obj, to_obj): for related in from_obj._meta.get_all_related_objects(): accessor_name = related.get_accessor_name() varname = related.field.name - getattr(from_obj, accessor_name).all().update(**{varname: to_obj}) + field = getattr(from_obj, accessor_name) + if related.multiple: + field.all().update(**{varname: to_obj}) + elif related.one_to_one: + try: + getattr(to_obj, accessor_name) + except Exception as e: + # doesn't exist, safe to overwrite + setattr(field, varname, to_obj) + field.save() + else: + import pdb; pdb.set_trace() + raise Exception('unknown code path') for related_m2m in from_obj._meta.get_all_related_many_to_many_objects(): accessor_name = related.get_accessor_name() diff --git a/tests/tests.py b/tests/tests.py index 95e73fb..684954d 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,8 +1,39 @@ from django.test import TestCase -from .models import Person, Number, SSN +from tests.models import Person, Number, SSN, Group +from fkreplace import merge -def setUp(): - a = Person.objects.create(name='alf') - b = Person.objects.create(name='bee') - Number.objects.create(person=a, number='123') - Number.objects.create(person=a, number='1234') +class MergeTests(TestCase): + def setUp(self): + self.a = Person.objects.create(name='alf') + self.b = Person.objects.create(name='bee') + self.c = Person.objects.create(name='sea') + Number.objects.create(person=self.a, number='555-1111') + Number.objects.create(person=self.a, number='555-1112') + Number.objects.create(person=self.b, number='555-1113') + SSN.objects.create(person=self.a, number='1') + SSN.objects.create(person=self.b, number='2') + self.g = Group.objects.create(name='Team Awesome') + self.g.people.add(self.a) + self.g.people.add(self.b) + + def test_fk_simple(self): + merge(self.a, self.c) + # move FKs pointing at A to C + assert self.a.numbers.count() == 0 + assert self.c.numbers.count() == 2 + assert Number.objects.count() == 3 + + def test_fk_existing(self): + merge(self.a, self.b) + # everything now on b + assert self.a.numbers.count() == 0 + assert self.b.numbers.count() == 3 + assert Number.objects.count() == 3 + + def test_one2one_simple(self): + merge(self.a, self.c) + # move FKs pointing at A to C + #assert self.a.ssn is None + c = Person.objects.get(pk=self.c.pk) + c.ssn.number == 1 + assert SSN.objects.count() == 2