diff --git a/lifting/importers.py b/lifting/importers.py index f373803..0bd66b3 100644 --- a/lifting/importers.py +++ b/lifting/importers.py @@ -7,7 +7,7 @@ def _clean_name(name): return name.lower() -def import_fitnotes_db(filename): +def import_fitnotes_db(filename, user): # exercise names to db ids exercises = {} for e in Exercise.objects.all(): @@ -37,4 +37,4 @@ def import_fitnotes_db(filename): exercise_id = exercise_id_mapping[fnid] Set.objects.create(exercise_id=exercise_id, date=date, weight_kg=weight_kg, reps=reps, - source='fitnotes') + source='fitnotes', user=user) diff --git a/lifting/migrations/0002_set_user.py b/lifting/migrations/0002_set_user.py new file mode 100644 index 0000000..d9ca992 --- /dev/null +++ b/lifting/migrations/0002_set_user.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import models, migrations +from django.conf import settings + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('lifting', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='set', + name='user', + field=models.ForeignKey(to=settings.AUTH_USER_MODEL, related_name='sets', default=None), + preserve_default=False, + ), + ] diff --git a/lifting/models.py b/lifting/models.py index 76aca9a..9dae07a 100644 --- a/lifting/models.py +++ b/lifting/models.py @@ -1,4 +1,5 @@ from django.db import models +from django.contrib.auth.models import User from django.contrib.postgres.fields import ArrayField SET_TYPES = ( @@ -15,6 +16,7 @@ class Exercise(models.Model): class Set(models.Model): + user = models.ForeignKey(User, related_name='sets') date = models.DateField() exercise = models.ForeignKey(Exercise, related_name='sets') weight_kg = models.DecimalField(max_digits=7, decimal_places=3) diff --git a/lifting/tests.py b/lifting/tests.py index 17f811a..3f45138 100644 --- a/lifting/tests.py +++ b/lifting/tests.py @@ -1,4 +1,5 @@ from django.test import TestCase +from django.contrib.auth.models import User from lifting.models import Exercise, Set from lifting.importers import import_fitnotes_db @@ -17,9 +18,12 @@ class TestFitnotesImport(TestCase): # squat 2 @ 185 # squat 5 @ 225 + def setUp(self): + self.user = User.objects.create_user('default', 'default@example.com', 'default') + def test_basic_import(self): # ensure that the data comes in - import_fitnotes_db('lifting/testdata/example.fitnotes') + import_fitnotes_db('lifting/testdata/example.fitnotes', self.user) assert Exercise.objects.count() == 2 bp = Exercise.objects.get(names__contains=["flat barbell bench press"]) @@ -28,25 +32,24 @@ class TestFitnotesImport(TestCase): def test_double_import(self): # two identical dbs, should be idempotent - import_fitnotes_db('lifting/testdata/example.fitnotes') - import_fitnotes_db('lifting/testdata/example.fitnotes') + import_fitnotes_db('lifting/testdata/example.fitnotes', self.user) + import_fitnotes_db('lifting/testdata/example.fitnotes', self.user) assert Exercise.objects.count() == 2 assert Set.objects.count() == 9 def test_import_with_other_data(self): Exercise.objects.create(names=['incline bench press']) e = Exercise.objects.create(names=['flat barbell bench press']) - Set.objects.create(exercise=e, weight_kg=100, reps=10, date='2014-01-01') - import_fitnotes_db('lifting/testdata/example.fitnotes') + Set.objects.create(exercise=e, weight_kg=100, reps=10, date='2014-01-01', user=self.user) + import_fitnotes_db('lifting/testdata/example.fitnotes', self.user) assert Exercise.objects.count() == 3 assert Set.objects.count() == 10 - def test_bad_import(self): # good db then bad db, should fail without screwing up existing data - import_fitnotes_db('lifting/testdata/example.fitnotes') + import_fitnotes_db('lifting/testdata/example.fitnotes', self.user) with self.assertRaises(Exception): # baddata.fitnotes has all exercise ids set to 9999 - import_fitnotes_db('lifting/testdata/baddata.fitnotes') + import_fitnotes_db('lifting/testdata/baddata.fitnotes', self.user) assert Exercise.objects.count() == 2 assert Set.objects.count() == 9