From 9211b27415c1de3b05cf8266cf4dcc99f3bc7f79 Mon Sep 17 00:00:00 2001 From: James Turk Date: Tue, 24 Mar 2015 16:22:32 -0400 Subject: [PATCH] refactor Unit --- bia/test_units.py | 55 ++++++++++++----------- bia/units.py | 112 +++++++++++++++++++++++++++------------------- 2 files changed, 93 insertions(+), 74 deletions(-) diff --git a/bia/test_units.py b/bia/test_units.py index 956dbaf..e54743e 100644 --- a/bia/test_units.py +++ b/bia/test_units.py @@ -1,58 +1,59 @@ -from .units import Mass +from .units import Unit +from .units import UnitValue as V + + +def test_unit_basics(): + assert Unit('kg') == Unit(['kg']) + assert Unit(['kg'], ['m']) == Unit(['kg'], ['m']) def test_basic_conversions(): - a = Mass(2, 'kg') + a = V(2, 'kg') kg = a.as_unit('kg') assert kg.scalar == 2 - assert kg.unit == 'kg' + assert kg.unit == Unit('kg') lb = a.as_unit('lb') assert abs(lb.scalar - 4.40925) < 0.0001 - assert lb.unit == 'lb' + assert lb.unit == Unit('lb') g = a.as_unit('g') assert abs(g.scalar - 2000) < 0.0001 - assert g.unit == 'g' + assert g.unit == Unit('g') assert g.as_unit('lb').scalar == a.as_unit('lb').scalar -def test_complex_conversion(): - a = Mass(2, ['kg', 'kg']) - b = a.as_unit('lb') - - def test_basic_cmp(): - assert Mass(1, 'kg') < Mass(2, 'kg') - assert Mass(1, 'kg') <= Mass(2, 'kg') - assert Mass(2, 'kg') <= Mass(2, 'kg') - assert Mass(2, 'kg') == Mass(2, 'kg') - assert Mass(2, 'kg') >= Mass(2, 'kg') - assert Mass(2, 'kg') > Mass(1, 'kg') - assert Mass(2, 'kg') >= Mass(1, 'kg') + assert V(1, 'kg') < V(2, 'kg') + assert V(1, 'kg') <= V(2, 'kg') + assert V(2, 'kg') <= V(2, 'kg') + assert V(2, 'kg') == V(2, 'kg') + assert V(2, 'kg') >= V(2, 'kg') + assert V(2, 'kg') > V(1, 'kg') + assert V(2, 'kg') >= V(1, 'kg') def test_conversion_cmp(): - assert Mass(1, 'kg') < Mass(100, 'lb') - assert Mass(1000000, 'g') > Mass(100, 'lb') + assert V(1, 'kg') < V(100, 'lb') + assert V(1000000, 'g') > V(100, 'lb') def test_addition(): - assert Mass(1, 'kg') + Mass(2, 'kg') == Mass(3, 'kg') - assert Mass(1, 'kg') + Mass(1, 'lb') > Mass(1.4, 'kg') + assert V(1, 'kg') + V(2, 'kg') == V(3, 'kg') + assert V(1, 'kg') + V(1, 'lb') > V(1.4, 'kg') def test_subtraction(): - assert Mass(2, 'kg') - Mass(1, 'kg') == Mass(1, 'kg') - assert Mass(1, 'kg') - Mass(1, 'lb') < Mass(0.55, 'kg') + assert V(2, 'kg') - V(1, 'kg') == V(1, 'kg') + assert V(1, 'kg') - V(1, 'lb') < V(0.55, 'kg') def test_multiplication(): - assert Mass(2, 'kg') * 2 == Mass(4, 'kg') - assert Mass(2, 'kg') * Mass(1, 'kg') == Mass(2, ['kg', 'kg']) + assert V(2, 'kg') * 2 == V(4, 'kg') + assert V(2, 'kg') * V(1, 'kg') == V(2, ['kg', 'kg']) def test_division(): - assert Mass(8, 'kg') / 2 == Mass(4, 'kg') - assert Mass(2, 'kg') / Mass(1, 'kg') == Mass(2, []) + assert V(8, 'kg') / 2 == V(4, 'kg') + assert V(2, 'kg') / V(1, 'kg') == V(2, []) diff --git a/bia/units.py b/bia/units.py index d1595f9..1815a5b 100644 --- a/bia/units.py +++ b/bia/units.py @@ -1,6 +1,16 @@ from collections import defaultdict +_UNITS = { + # mass + 'kg': {'base': 'kg', 'scale': 1.0}, + 'lb': {'base': 'kg', 'scale': 2.20462}, + 'g': {'base': 'kg', 'scale': 1000.0}, + + 'm': {'base': 'm', 'scale': 1.0}, +} + + class UnitError(ValueError): pass @@ -10,39 +20,66 @@ class ConversionError(ValueError): class Unit(object): - _mapping = {} - - def __init__(self, n, unit, denom=None): - self.scalar = float(n) - - if isinstance(unit, str): - self.unit_numerator = [unit] + def __init__(self, numerator, denominator=None): + if isinstance(numerator, str): + self.numerator = [numerator] else: - self.unit_numerator = unit + self.numerator = numerator + self.denominator = [] if denominator is None else denominator + for u in self.numerator + self.denominator: + if u not in _UNITS: + raise ValueError('invalid unit {}'.format(self)) - self.unit_denominator = [] if denom is None else denom + self._simplify() - for u in self.unit_numerator + self.unit_denominator: - if u not in self._mapping: - raise ValueError('invalid unit {} for {}'.format(unit, self.__class__.__name__)) + def __str__(self): + return '*'.join(self.numerator) - @property - def unit(self): - return '*'.join(self.unit_numerator) + def __repr__(self): + return 'Unit({!r}, {!r})'.format(self.numerator, self.denominator) + + def __eq__(self, other): + return self.numerator == other.numerator and self.denominator == other.denominator + + def __mul__(self, other): + return Unit(self.numerator + other.numerator, self.denominator + other.denominator) + + def __div__(self, other): + return Unit(self.numerator + other.denominator, self.denominator + other.numerator) + + def _simplify(self): + for u in list(self.denominator): + if u in self.numerator: + self.numerator.remove(u) + self.denominator.remove(u) + + def conversion_factor(self, other): + factor = 1 + for u in self.numerator: + factor /= _UNITS[u]['scale'] + for u in other.numerator: + factor *= _UNITS[u]['scale'] + return factor + + @staticmethod + def unit(u): + if isinstance(u, Unit): + return u + else: + return Unit(u) + + +class UnitValue(object): + def __init__(self, n, unit): + self.scalar = float(n) + self.unit = Unit.unit(unit) def as_unit(self, unit): + unit = Unit.unit(unit) if self.unit == unit: return self.__class__(self.scalar, self.unit) - try: - if self.unit == self._base: - factor = self._mapping[unit] - else: - factor = self._mapping[unit] / self._mapping[self.unit] - except KeyError: - raise ConversionError('cannot convert from {} to {}'.format(self.unit, unit)) - - return self.__class__(self.scalar * factor, unit) + return self.__class__(self.scalar * self.unit.conversion_factor(unit), unit) def __str__(self): return '{}{}'.format(self.scalar, self.unit) @@ -67,36 +104,17 @@ class Unit(object): return self.__class__(self.scalar - other.scalar, self.unit) def __mul__(self, other): - if isinstance(other, Unit): + if isinstance(other, UnitValue): if self.unit != other.unit: other = other.as_unit(self.unit) - return self.__class__(self.scalar * other.scalar, - self.unit_numerator + other.unit_numerator) + return self.__class__(self.scalar * other.scalar, self.unit * other.unit) else: return self.__class__(self.scalar * other, self.unit) def __div__(self, other): - if isinstance(other, Unit): + if isinstance(other, UnitValue): if self.unit != other.unit: other = other.as_unit(self.unit) - new_numerator = list(self.unit_numerator) - new_denominator = list(self.unit_denominator) - for u in other.unit_numerator: - if u in new_numerator: - new_numerator.remove(u) - else: - new_denominator.append(u) - - return self.__class__(self.scalar / other.scalar, new_numerator, new_denominator) + return self.__class__(self.scalar / other.scalar, self.unit / other.unit) else: return self.__class__(self.scalar / other, self.unit) - -class Mass(Unit): - _base = 'kg' - - # 1kg == - _mapping = { - 'kg': 1.0, - 'lb': 2.20462, - 'g': 1000.0, - }