test basics of division and multiplication

This commit is contained in:
James Turk 2015-03-24 00:32:24 -04:00
parent d6096b9ed2
commit a93cc51188
2 changed files with 57 additions and 6 deletions

View File

@ -18,6 +18,11 @@ def test_basic_conversions():
assert g.as_unit('lb').scalar == a.as_unit('lb').scalar assert g.as_unit('lb').scalar == a.as_unit('lb').scalar
def test_complex_conversion():
a = Mass(2, ['kg', 'kg'])
b = a.as_unit('lb')
def test_basic_cmp(): def test_basic_cmp():
assert Mass(1, 'kg') < Mass(2, 'kg') assert Mass(1, 'kg') < Mass(2, 'kg')
assert Mass(1, 'kg') <= Mass(2, 'kg') assert Mass(1, 'kg') <= Mass(2, 'kg')
@ -41,3 +46,13 @@ def test_addition():
def test_subtraction(): def test_subtraction():
assert Mass(2, 'kg') - Mass(1, 'kg') == Mass(1, 'kg') assert Mass(2, 'kg') - Mass(1, 'kg') == Mass(1, 'kg')
assert Mass(1, 'kg') - Mass(1, 'lb') < Mass(0.55, 'kg') assert Mass(1, 'kg') - Mass(1, 'lb') < Mass(0.55, 'kg')
def test_multiplication():
assert Mass(2, 'kg') * 2 == Mass(4, 'kg')
assert Mass(2, 'kg') * Mass(1, 'kg') == Mass(2, ['kg', 'kg'])
def test_division():
assert Mass(8, 'kg') / 2 == Mass(4, 'kg')
assert Mass(2, 'kg') / Mass(1, 'kg') == Mass(2, [])

View File

@ -12,11 +12,23 @@ class ConversionError(ValueError):
class Unit(object): class Unit(object):
_mapping = {} _mapping = {}
def __init__(self, n, unit): def __init__(self, n, unit, denom=None):
self.scalar = float(n) self.scalar = float(n)
if unit not in self._mapping:
raise ValueError('invalid unit {} for {}'.format(unit, self.__class__.__name__)) if isinstance(unit, str):
self.unit = unit self.unit_numerator = [unit]
else:
self.unit_numerator = unit
self.unit_denominator = [] if denom is None else denom
for u in self.unit_numerator + self.unit_denominator:
if u not in self._mapping:
raise ValueError('invalid unit {} for {}'.format(unit, self.__class__.__name__))
@property
def unit(self):
return '*'.join(self.unit_numerator)
def as_unit(self, unit): def as_unit(self, unit):
if self.unit == unit: if self.unit == unit:
@ -47,13 +59,37 @@ class Unit(object):
def __add__(self, other): def __add__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
other = other.as_unit(self.unit) other = other.as_unit(self.unit)
return self.__class__(self.scalar.__add__(other.scalar), self.unit) return self.__class__(self.scalar + other.scalar, self.unit)
def __sub__(self, other): def __sub__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
other = other.as_unit(self.unit) other = other.as_unit(self.unit)
return self.__class__(self.scalar.__sub__(other.scalar), self.unit) return self.__class__(self.scalar - other.scalar, self.unit)
def __mul__(self, other):
if isinstance(other, Unit):
if self.unit != other.unit:
other = other.as_unit(self.unit)
return self.__class__(self.scalar * other.scalar,
self.unit_numerator + other.unit_numerator)
else:
return self.__class__(self.scalar * other, self.unit)
def __div__(self, other):
if isinstance(other, Unit):
if self.unit != other.unit:
other = other.as_unit(self.unit)
new_numerator = list(self.unit_numerator)
new_denominator = list(self.unit_denominator)
for u in other.unit_numerator:
if u in new_numerator:
new_numerator.remove(u)
else:
new_denominator.append(u)
return self.__class__(self.scalar / other.scalar, new_numerator, new_denominator)
else:
return self.__class__(self.scalar / other, self.unit)
class Mass(Unit): class Mass(Unit):
_base = 'kg' _base = 'kg'