diff --git a/bia/test_units.py b/bia/test_units.py index dcdb9d0..bdc203a 100644 --- a/bia/test_units.py +++ b/bia/test_units.py @@ -30,9 +30,14 @@ def test_basic_cmp(): def test_conversion_cmp(): assert Mass(1, 'kg') < Mass(100, 'lb') - assert Mass(10000000, 'g') > Mass(100, 'lb') + assert Mass(1000000, 'g') > Mass(100, 'lb') -def test_add_sub(): +def test_addition(): assert Mass(1, 'kg') + Mass(2, 'kg') == Mass(3, 'kg') + assert Mass(1, 'kg') + Mass(1, 'lb') > Mass(1.4, 'kg') + + +def test_subtraction(): assert Mass(2, 'kg') - Mass(1, 'kg') == Mass(1, 'kg') + assert Mass(1, 'kg') - Mass(1, 'lb') < Mass(0.55, 'kg') diff --git a/bia/units.py b/bia/units.py index 17f8328..2285c2e 100644 --- a/bia/units.py +++ b/bia/units.py @@ -13,7 +13,7 @@ class Unit(object): _mapping = {} def __init__(self, n, unit): - self.scalar = n + self.scalar = float(n) if unit not in self._mapping: raise ValueError('invalid unit {} for {}'.format(unit, self.__class__.__name__)) self.unit = unit @@ -41,7 +41,8 @@ class Unit(object): def __cmp__(self, other): if self.unit != other.unit: other = other.as_unit(self.unit) - return self.scalar - other.scalar + # cmp() removed in Python 3, recommended to replace with this + return (self.scalar > other.scalar) - (self.scalar < other.scalar) def __add__(self, other): if self.unit != other.unit: