diff --git a/fkreplace/__init__.py b/fkreplace/__init__.py index e8ded47..5007f91 100644 --- a/fkreplace/__init__.py +++ b/fkreplace/__init__.py @@ -17,9 +17,12 @@ def merge(from_obj, to_obj): import pdb; pdb.set_trace() raise Exception('unknown code path') - for related_m2m in from_obj._meta.get_all_related_many_to_many_objects(): + for related in from_obj._meta.get_all_related_many_to_many_objects(): accessor_name = related.get_accessor_name() - varname = related.field.name - - #for obj in getattr(from_obj, varname) - + if accessor_name: + varname = related.field.name + field = getattr(from_obj, accessor_name) + if related.many_to_many: + for f in field.all(): + getattr(f, varname).remove(from_obj) + getattr(f, varname).add(to_obj) diff --git a/tests/models.py b/tests/models.py index 1aa860c..4d589f0 100644 --- a/tests/models.py +++ b/tests/models.py @@ -18,4 +18,4 @@ class SSN(models.Model): class Group(models.Model): name = models.CharField(max_length=100) - people = models.ManyToManyField(Person) + people = models.ManyToManyField(Person, related_name='groups') diff --git a/tests/tests.py b/tests/tests.py index fdd428a..17183a8 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -37,3 +37,27 @@ class MergeTests(TestCase): c = Person.objects.get(pk=self.c.pk) c.ssn.number == 1 assert SSN.objects.count() == 2 + + # TODO: test one2one when there's a conflict + + def test_many2many_simple(self): + merge(self.a, self.c) + # A's membership in G has been moved to C + assert self.g.people.all().count() == 2 + assert self.a.groups.all().count() == 0 + assert self.c.groups.all().count() == 1 + + def test_many2many_redundant(self): + merge(self.a, self.b) + # A's membership in G is redundant with B's + assert self.g.people.all().count() == 1 + assert self.a.groups.all().count() == 0 + assert self.b.groups.all().count() == 1 + + def test_many2many_self(self): + self.a.friends.add(self.b) + merge(self.a, self.c) + import pdb; pdb.set_trace() + assert self.a.friends.all().count() == 0 + assert self.b.friends.get().name == self.c + assert self.c.friends.get().name == self.b