diff --git a/bia/test_units.py b/bia/test_units.py index 321bcbe..dcdb9d0 100644 --- a/bia/test_units.py +++ b/bia/test_units.py @@ -16,3 +16,23 @@ def test_basic_conversions(): assert g.unit == 'g' assert g.as_unit('lb').scalar == a.as_unit('lb').scalar + + +def test_basic_cmp(): + assert Mass(1, 'kg') < Mass(2, 'kg') + assert Mass(1, 'kg') <= Mass(2, 'kg') + assert Mass(2, 'kg') <= Mass(2, 'kg') + assert Mass(2, 'kg') == Mass(2, 'kg') + assert Mass(2, 'kg') >= Mass(2, 'kg') + assert Mass(2, 'kg') > Mass(1, 'kg') + assert Mass(2, 'kg') >= Mass(1, 'kg') + + +def test_conversion_cmp(): + assert Mass(1, 'kg') < Mass(100, 'lb') + assert Mass(10000000, 'g') > Mass(100, 'lb') + + +def test_add_sub(): + assert Mass(1, 'kg') + Mass(2, 'kg') == Mass(3, 'kg') + assert Mass(2, 'kg') - Mass(1, 'kg') == Mass(1, 'kg') diff --git a/bia/units.py b/bia/units.py index 77f395a..17f8328 100644 --- a/bia/units.py +++ b/bia/units.py @@ -38,8 +38,20 @@ class Unit(object): def __repr__(self): return 'U({}, {!r})'.format(self.scalar, self.unit) - def __cmp__(self): - pass + def __cmp__(self, other): + if self.unit != other.unit: + other = other.as_unit(self.unit) + return self.scalar - other.scalar + + def __add__(self, other): + if self.unit != other.unit: + other = other.as_unit(self.unit) + return self.__class__(self.scalar.__add__(other.scalar), self.unit) + + def __sub__(self, other): + if self.unit != other.unit: + other = other.as_unit(self.unit) + return self.__class__(self.scalar.__sub__(other.scalar), self.unit) class Mass(Unit):