Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

define_unit override argument #239

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions test_modified_define_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unyt

unyt.define_unit("cloudRadius", (1.0, "pc"))
print(unyt.cloudRadius/(2.0 * unyt.pc))

#from unyt import define_unit
#from unyt import pc
#define_unit("cloudRadius", (1.0, "pc"))
#from unyt import cloudRadius
#print(cloudRadius/(2.0*pc))

from importlib import reload
reload(unyt)
# after some processing, switching to a simulation of a larger cloud
unyt.define_unit("cloudRadius", (10.0, "pc"), allow_override=True)
print(unyt.cloudRadius/(2.0 * unyt.pc))

#from unyt import cloudRadius
#print(cloudRadius/(2.0*pc))

18 changes: 14 additions & 4 deletions unyt/tests/test_define_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,20 @@ def test_define_unit():
second = unyt_quantity(1.0, "s")
assert g == volt / second ** (0.5)

## allow_override test
define_unit("Foo", (1.0, "V/s"), allow_override=True)
h = unyt_quantity(1.0, "Foo")
assert h != g
assert h == volt / second

# Test custom registry
reg = UnitRegistry()
define_unit("Foo", (1, "m"), registry=reg)
define_unit("Baz", (1, "Foo**2"), registry=reg)
h = unyt_quantity(1, "Baz", registry=reg)
i = unyt_quantity(1, "m**2", registry=reg)
assert h == i

i = unyt_quantity(1, "Baz", registry=reg)
j = unyt_quantity(1, "m**2", registry=reg)
assert i == j
print("done!")

def test_define_unit_error():
from unyt import define_unit
Expand All @@ -41,3 +47,7 @@ def test_define_unit_error():
define_unit("foobar", 12)
with pytest.raises(RuntimeError):
define_unit("C", (1.0, "A*s"))

if __name__ == "__main__":
test_define_unit()
test_define_unit_error()
52 changes: 34 additions & 18 deletions unyt/unit_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,59 +1038,75 @@ def _validate_dimensions(dimensions):


def define_unit(
symbol, value, tex_repr=None, offset=None, prefixable=False, registry=None
symbol, value, tex_repr=None, offset=None, prefixable=False, registry=None, allow_override=False
):
"""
Define a new unit and add it to the specified unit registry.
Define a unit. Modifies the specified unit registry.

Parameters
----------
symbol : string
The symbol for the new unit.
The symbol for the unit.
value : tuple or :class:`unyt.array.unyt_quantity`
The definition of the new unit in terms of some other units. For
example, one would define a new "mph" unit with ``(1.0, "mile/hr")``
The definition of the unit in terms of some other units. For
example, one would define an "mph" unit with ``(1.0, "mile/hr")``
or with ``1.0*unyt.mile/unyt.hr``
tex_repr : string, optional
The LaTeX representation of the new unit. If one is not supplied, it
The LaTeX representation of the unit. If one is not supplied, it
will be generated automatically based on the symbol string.
offset : float, optional
The default offset for the unit. If not set, an offset of 0 is assumed.
prefixable : boolean, optional
Whether or not the new unit can use SI prefixes. Default: False
Whether or not the unit can use SI prefixes. Default: False
registry : :class:`unyt.unit_registry.UnitRegistry` or None
The unit registry to add the unit to. If None, then defaults to the
The unit registry to modify. If None, then defaults to the
global default unit registry. If registry is set to None then the
unit object will be added as an attribute to the top-level :mod:`unyt`
namespace to ease working with the newly defined unit. See the example
namespace to ease working with the unit. See the example
below.
allow_override : boolean, optional
Whether or not to allow an override to an existing unit in the given
unit registry.

Examples
--------
>>> from unyt import define_unit
>>> from unyt import day
>>> two_weeks = 14.0*day
>>> one_day = 1.0*day
>>> define_unit("two_weeks", two_weeks)
>>> from unyt import two_weeks
>>> print((3*two_weeks)/one_day)
42.0 dimensionless

>>> from unyt import define_unit
>>> from unyt import pc
>>> define_unit("cloudRadius", (1.0, "pc"))
>>> from unyt import cloudRadius
>>> print(cloudRadius/(2.0*pc))
0.5 dimensionless
>>> # after some processing, switching to a simulation of a larger cloud
>>> define_unit("cloudRadius", (10.0, "pc"), allow_override=True)
>>> from unyt import cloudRadius
>>> print(cloudRadius/1.0)
1.0 cloudRadius
"""
import unyt
from unyt.array import _iterable, unyt_quantity

if registry is None:
registry = default_unit_registry
if symbol in registry:
if symbol in registry and not allow_override:
raise RuntimeError(
"Unit symbol '%s' already exists in the provided " "registry" % symbol
)
if not isinstance(value, unyt_quantity):
if _iterable(value) and len(value) == 2:
value = unyt_quantity(value[0], value[1], registry=registry)
else:
raise RuntimeError(
'"value" needs to be a quantity or ' "(value, unit) tuple!"
)
if _iterable(value) and len(value) == 2:
print("1104 hi")
value = unyt_quantity(value[0], value[1], registry=registry)
else:
raise RuntimeError(
'"value" needs to be a quantity or ' "(value, unit) tuple!"
)
base_value = float(value.in_base(unit_system="mks"))
dimensions = value.units.dimensions
registry.add(
Expand All @@ -1104,6 +1120,6 @@ def define_unit(
if registry is default_unit_registry:
u = Unit(symbol, registry=registry)
setattr(unyt, symbol, u)

print("did the setattr wohoo!")

NULL_UNIT = Unit()