diff --git a/test_modified_define_unit.py b/test_modified_define_unit.py new file mode 100644 index 00000000..759ec5eb --- /dev/null +++ b/test_modified_define_unit.py @@ -0,0 +1,20 @@ +import unyt + +unyt.define_unit("cloudRadius", (1.0, "pc")) +print(unyt.cloudRadius/(2.0 * unyt.pc)) + +#from unyt import define_unit +#from unyt import pc +#define_unit("cloudRadius", (1.0, "pc")) +#from unyt import cloudRadius +#print(cloudRadius/(2.0*pc)) + +from importlib import reload +reload(unyt) +# after some processing, switching to a simulation of a larger cloud +unyt.define_unit("cloudRadius", (10.0, "pc"), allow_override=True) +print(unyt.cloudRadius/(2.0 * unyt.pc)) + +#from unyt import cloudRadius +#print(cloudRadius/(2.0*pc)) + diff --git a/unyt/tests/test_define_unit.py b/unyt/tests/test_define_unit.py index 23898b1c..14ddcbe2 100644 --- a/unyt/tests/test_define_unit.py +++ b/unyt/tests/test_define_unit.py @@ -23,14 +23,20 @@ def test_define_unit(): second = unyt_quantity(1.0, "s") assert g == volt / second ** (0.5) + ## allow_override test + define_unit("Foo", (1.0, "V/s"), allow_override=True) + h = unyt_quantity(1.0, "Foo") + assert h != g + assert h == volt / second + # Test custom registry reg = UnitRegistry() define_unit("Foo", (1, "m"), registry=reg) define_unit("Baz", (1, "Foo**2"), registry=reg) - h = unyt_quantity(1, "Baz", registry=reg) - i = unyt_quantity(1, "m**2", registry=reg) - assert h == i - + i = unyt_quantity(1, "Baz", registry=reg) + j = unyt_quantity(1, "m**2", registry=reg) + assert i == j + print("done!") def test_define_unit_error(): from unyt import define_unit @@ -41,3 +47,7 @@ def test_define_unit_error(): define_unit("foobar", 12) with pytest.raises(RuntimeError): define_unit("C", (1.0, "A*s")) + +if __name__ == "__main__": + test_define_unit() + test_define_unit_error() diff --git a/unyt/unit_object.py b/unyt/unit_object.py index 37a0eb09..3662d349 100644 --- a/unyt/unit_object.py +++ b/unyt/unit_object.py @@ -1038,35 +1038,39 @@ def _validate_dimensions(dimensions): def define_unit( - symbol, value, tex_repr=None, offset=None, prefixable=False, registry=None + symbol, value, tex_repr=None, offset=None, prefixable=False, registry=None, allow_override=False ): """ - Define a new unit and add it to the specified unit registry. + Define a unit. Modifies the specified unit registry. Parameters ---------- symbol : string - The symbol for the new unit. + The symbol for the unit. value : tuple or :class:`unyt.array.unyt_quantity` - The definition of the new unit in terms of some other units. For - example, one would define a new "mph" unit with ``(1.0, "mile/hr")`` + The definition of the unit in terms of some other units. For + example, one would define an "mph" unit with ``(1.0, "mile/hr")`` or with ``1.0*unyt.mile/unyt.hr`` tex_repr : string, optional - The LaTeX representation of the new unit. If one is not supplied, it + The LaTeX representation of the unit. If one is not supplied, it will be generated automatically based on the symbol string. offset : float, optional The default offset for the unit. If not set, an offset of 0 is assumed. prefixable : boolean, optional - Whether or not the new unit can use SI prefixes. Default: False + Whether or not the unit can use SI prefixes. Default: False registry : :class:`unyt.unit_registry.UnitRegistry` or None - The unit registry to add the unit to. If None, then defaults to the + The unit registry to modify. If None, then defaults to the global default unit registry. If registry is set to None then the unit object will be added as an attribute to the top-level :mod:`unyt` - namespace to ease working with the newly defined unit. See the example + namespace to ease working with the unit. See the example below. + allow_override : boolean, optional + Whether or not to allow an override to an existing unit in the given + unit registry. Examples -------- + >>> from unyt import define_unit >>> from unyt import day >>> two_weeks = 14.0*day >>> one_day = 1.0*day @@ -1074,23 +1078,35 @@ def define_unit( >>> from unyt import two_weeks >>> print((3*two_weeks)/one_day) 42.0 dimensionless + + >>> from unyt import define_unit + >>> from unyt import pc + >>> define_unit("cloudRadius", (1.0, "pc")) + >>> from unyt import cloudRadius + >>> print(cloudRadius/(2.0*pc)) + 0.5 dimensionless + >>> # after some processing, switching to a simulation of a larger cloud + >>> define_unit("cloudRadius", (10.0, "pc"), allow_override=True) + >>> from unyt import cloudRadius + >>> print(cloudRadius/1.0) + 1.0 cloudRadius """ import unyt from unyt.array import _iterable, unyt_quantity if registry is None: registry = default_unit_registry - if symbol in registry: + if symbol in registry and not allow_override: raise RuntimeError( "Unit symbol '%s' already exists in the provided " "registry" % symbol ) - if not isinstance(value, unyt_quantity): - if _iterable(value) and len(value) == 2: - value = unyt_quantity(value[0], value[1], registry=registry) - else: - raise RuntimeError( - '"value" needs to be a quantity or ' "(value, unit) tuple!" - ) + if _iterable(value) and len(value) == 2: + print("1104 hi") + value = unyt_quantity(value[0], value[1], registry=registry) + else: + raise RuntimeError( + '"value" needs to be a quantity or ' "(value, unit) tuple!" + ) base_value = float(value.in_base(unit_system="mks")) dimensions = value.units.dimensions registry.add( @@ -1104,6 +1120,6 @@ def define_unit( if registry is default_unit_registry: u = Unit(symbol, registry=registry) setattr(unyt, symbol, u) - + print("did the setattr wohoo!") NULL_UNIT = Unit()