refactor Unit

This commit is contained in:
James Turk 2015-03-24 16:22:32 -04:00
parent a93cc51188
commit 9211b27415
2 changed files with 93 additions and 74 deletions

View File

@ -1,58 +1,59 @@
from .units import Mass from .units import Unit
from .units import UnitValue as V
def test_unit_basics():
assert Unit('kg') == Unit(['kg'])
assert Unit(['kg'], ['m']) == Unit(['kg'], ['m'])
def test_basic_conversions(): def test_basic_conversions():
a = Mass(2, 'kg') a = V(2, 'kg')
kg = a.as_unit('kg') kg = a.as_unit('kg')
assert kg.scalar == 2 assert kg.scalar == 2
assert kg.unit == 'kg' assert kg.unit == Unit('kg')
lb = a.as_unit('lb') lb = a.as_unit('lb')
assert abs(lb.scalar - 4.40925) < 0.0001 assert abs(lb.scalar - 4.40925) < 0.0001
assert lb.unit == 'lb' assert lb.unit == Unit('lb')
g = a.as_unit('g') g = a.as_unit('g')
assert abs(g.scalar - 2000) < 0.0001 assert abs(g.scalar - 2000) < 0.0001
assert g.unit == 'g' assert g.unit == Unit('g')
assert g.as_unit('lb').scalar == a.as_unit('lb').scalar assert g.as_unit('lb').scalar == a.as_unit('lb').scalar
def test_complex_conversion():
a = Mass(2, ['kg', 'kg'])
b = a.as_unit('lb')
def test_basic_cmp(): def test_basic_cmp():
assert Mass(1, 'kg') < Mass(2, 'kg') assert V(1, 'kg') < V(2, 'kg')
assert Mass(1, 'kg') <= Mass(2, 'kg') assert V(1, 'kg') <= V(2, 'kg')
assert Mass(2, 'kg') <= Mass(2, 'kg') assert V(2, 'kg') <= V(2, 'kg')
assert Mass(2, 'kg') == Mass(2, 'kg') assert V(2, 'kg') == V(2, 'kg')
assert Mass(2, 'kg') >= Mass(2, 'kg') assert V(2, 'kg') >= V(2, 'kg')
assert Mass(2, 'kg') > Mass(1, 'kg') assert V(2, 'kg') > V(1, 'kg')
assert Mass(2, 'kg') >= Mass(1, 'kg') assert V(2, 'kg') >= V(1, 'kg')
def test_conversion_cmp(): def test_conversion_cmp():
assert Mass(1, 'kg') < Mass(100, 'lb') assert V(1, 'kg') < V(100, 'lb')
assert Mass(1000000, 'g') > Mass(100, 'lb') assert V(1000000, 'g') > V(100, 'lb')
def test_addition(): def test_addition():
assert Mass(1, 'kg') + Mass(2, 'kg') == Mass(3, 'kg') assert V(1, 'kg') + V(2, 'kg') == V(3, 'kg')
assert Mass(1, 'kg') + Mass(1, 'lb') > Mass(1.4, 'kg') assert V(1, 'kg') + V(1, 'lb') > V(1.4, 'kg')
def test_subtraction(): def test_subtraction():
assert Mass(2, 'kg') - Mass(1, 'kg') == Mass(1, 'kg') assert V(2, 'kg') - V(1, 'kg') == V(1, 'kg')
assert Mass(1, 'kg') - Mass(1, 'lb') < Mass(0.55, 'kg') assert V(1, 'kg') - V(1, 'lb') < V(0.55, 'kg')
def test_multiplication(): def test_multiplication():
assert Mass(2, 'kg') * 2 == Mass(4, 'kg') assert V(2, 'kg') * 2 == V(4, 'kg')
assert Mass(2, 'kg') * Mass(1, 'kg') == Mass(2, ['kg', 'kg']) assert V(2, 'kg') * V(1, 'kg') == V(2, ['kg', 'kg'])
def test_division(): def test_division():
assert Mass(8, 'kg') / 2 == Mass(4, 'kg') assert V(8, 'kg') / 2 == V(4, 'kg')
assert Mass(2, 'kg') / Mass(1, 'kg') == Mass(2, []) assert V(2, 'kg') / V(1, 'kg') == V(2, [])

View File

@ -1,6 +1,16 @@
from collections import defaultdict from collections import defaultdict
_UNITS = {
# mass
'kg': {'base': 'kg', 'scale': 1.0},
'lb': {'base': 'kg', 'scale': 2.20462},
'g': {'base': 'kg', 'scale': 1000.0},
'm': {'base': 'm', 'scale': 1.0},
}
class UnitError(ValueError): class UnitError(ValueError):
pass pass
@ -10,39 +20,66 @@ class ConversionError(ValueError):
class Unit(object): class Unit(object):
_mapping = {} def __init__(self, numerator, denominator=None):
if isinstance(numerator, str):
def __init__(self, n, unit, denom=None): self.numerator = [numerator]
self.scalar = float(n)
if isinstance(unit, str):
self.unit_numerator = [unit]
else: else:
self.unit_numerator = unit self.numerator = numerator
self.denominator = [] if denominator is None else denominator
for u in self.numerator + self.denominator:
if u not in _UNITS:
raise ValueError('invalid unit {}'.format(self))
self.unit_denominator = [] if denom is None else denom self._simplify()
for u in self.unit_numerator + self.unit_denominator: def __str__(self):
if u not in self._mapping: return '*'.join(self.numerator)
raise ValueError('invalid unit {} for {}'.format(unit, self.__class__.__name__))
@property def __repr__(self):
def unit(self): return 'Unit({!r}, {!r})'.format(self.numerator, self.denominator)
return '*'.join(self.unit_numerator)
def __eq__(self, other):
return self.numerator == other.numerator and self.denominator == other.denominator
def __mul__(self, other):
return Unit(self.numerator + other.numerator, self.denominator + other.denominator)
def __div__(self, other):
return Unit(self.numerator + other.denominator, self.denominator + other.numerator)
def _simplify(self):
for u in list(self.denominator):
if u in self.numerator:
self.numerator.remove(u)
self.denominator.remove(u)
def conversion_factor(self, other):
factor = 1
for u in self.numerator:
factor /= _UNITS[u]['scale']
for u in other.numerator:
factor *= _UNITS[u]['scale']
return factor
@staticmethod
def unit(u):
if isinstance(u, Unit):
return u
else:
return Unit(u)
class UnitValue(object):
def __init__(self, n, unit):
self.scalar = float(n)
self.unit = Unit.unit(unit)
def as_unit(self, unit): def as_unit(self, unit):
unit = Unit.unit(unit)
if self.unit == unit: if self.unit == unit:
return self.__class__(self.scalar, self.unit) return self.__class__(self.scalar, self.unit)
try: return self.__class__(self.scalar * self.unit.conversion_factor(unit), unit)
if self.unit == self._base:
factor = self._mapping[unit]
else:
factor = self._mapping[unit] / self._mapping[self.unit]
except KeyError:
raise ConversionError('cannot convert from {} to {}'.format(self.unit, unit))
return self.__class__(self.scalar * factor, unit)
def __str__(self): def __str__(self):
return '{}{}'.format(self.scalar, self.unit) return '{}{}'.format(self.scalar, self.unit)
@ -67,36 +104,17 @@ class Unit(object):
return self.__class__(self.scalar - other.scalar, self.unit) return self.__class__(self.scalar - other.scalar, self.unit)
def __mul__(self, other): def __mul__(self, other):
if isinstance(other, Unit): if isinstance(other, UnitValue):
if self.unit != other.unit: if self.unit != other.unit:
other = other.as_unit(self.unit) other = other.as_unit(self.unit)
return self.__class__(self.scalar * other.scalar, return self.__class__(self.scalar * other.scalar, self.unit * other.unit)
self.unit_numerator + other.unit_numerator)
else: else:
return self.__class__(self.scalar * other, self.unit) return self.__class__(self.scalar * other, self.unit)
def __div__(self, other): def __div__(self, other):
if isinstance(other, Unit): if isinstance(other, UnitValue):
if self.unit != other.unit: if self.unit != other.unit:
other = other.as_unit(self.unit) other = other.as_unit(self.unit)
new_numerator = list(self.unit_numerator) return self.__class__(self.scalar / other.scalar, self.unit / other.unit)
new_denominator = list(self.unit_denominator)
for u in other.unit_numerator:
if u in new_numerator:
new_numerator.remove(u)
else:
new_denominator.append(u)
return self.__class__(self.scalar / other.scalar, new_numerator, new_denominator)
else: else:
return self.__class__(self.scalar / other, self.unit) return self.__class__(self.scalar / other, self.unit)
class Mass(Unit):
_base = 'kg'
# 1kg ==
_mapping = {
'kg': 1.0,
'lb': 2.20462,
'g': 1000.0,
}