Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wire up Q/DQ elements in the FInAT interface #71

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-git.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
git+https://github.com/coneoproject/COFFEE#egg=COFFEE
git+https://github.com/firedrakeproject/ufl.git#egg=ufl
git+https://github.com/firedrakeproject/fiat.git#egg=fiat
git+https://github.com/FInAT/FInAT.git#egg=finat
git+https://github.com/FInAT/FInAT.git@tensor-product#egg=finat
29 changes: 29 additions & 0 deletions tests/test_create_finat_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,35 @@ def test_triangle_vector(ufl_element, ufl_vector_element):
assert scalar == vector.base_element


@pytest.fixture(params=["CG", "DG"])
def tensor_name(request):
return request.param


@pytest.fixture(params=[ufl.interval, ufl.triangle,
ufl.quadrilateral],
ids=lambda x: x.cellname())
def ufl_A(request, tensor_name):
return ufl.FiniteElement(tensor_name, request.param, 1)


@pytest.fixture
def ufl_B(tensor_name):
return ufl.FiniteElement(tensor_name, ufl.interval, 1)


def test_tensor_prod_simple(ufl_A, ufl_B):
tensor_ufl = ufl.TensorProductElement(ufl_A, ufl_B)

tensor = f.create_element(tensor_ufl)
A = f.create_element(ufl_A)
B = f.create_element(ufl_B)

assert isinstance(tensor, finat.TensorProductElement)

assert tensor.factors == (A, B)


def test_cache_hit(ufl_element):
A = f.create_element(ufl_element)
B = f.create_element(ufl_element)
Expand Down
7 changes: 1 addition & 6 deletions tsfc/coffee.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,6 @@ def statement_evaluate(leaf, parameters):
return coffee.Block(ops, open_scope=False)
elif isinstance(expr, gem.Constant):
assert parameters.declare[leaf]
# Take all axes except the last one
axes = tuple(range(len(expr.array.shape) - 1))
nz_indices, = expr.array.any(axis=axes).nonzero()
nz_bounds = tuple([(i, 0)] for i in expr.array.shape[:-1])
nz_bounds += ([(max(nz_indices) - min(nz_indices) + 1, min(nz_indices))],)
table = numpy.array(expr.array)
# FFC uses one less digits for rounding than for printing
epsilon = eval("1e-%d" % (parameters.precision - 1))
Expand All @@ -195,7 +190,7 @@ def statement_evaluate(leaf, parameters):
table[abs(table + 1.0) < epsilon] = -1.0
table[abs(table - 0.5) < epsilon] = 0.5
table[abs(table + 0.5) < epsilon] = -0.5
init = coffee.SparseArrayInit(table, parameters.precision, nz_bounds)
init = coffee.ArrayInit(table, parameters.precision)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We assume this basically doesn't matter because the big zero blocks are gone in this implementation anyway.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. No zero blocks from mixed or vector. There are still zeros if many basis functions are zero at the quadrature points, like here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually what I have on the link is something I might be able to optimise away.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, this has more compact tables, but code actually looks worse. I think it is better to just leave it like above for now.

return coffee.Decl(SCALAR_TYPE,
_decl_symbol(expr, parameters),
init,
Expand Down
32 changes: 28 additions & 4 deletions tsfc/finatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import finat

import ufl
from ufl.algorithms.elementtransformations import reconstruct_element

from tsfc.ufl_utils import spanning_degree

Expand All @@ -48,6 +49,9 @@
"Nedelec 2nd kind H(curl)": finat.NedelecSecondKind,
"Raviart-Thomas": finat.RaviartThomas,
"Regge": finat.Regge,
# These require special treatment below
"DQ": None,
"Q": None,
}
"""A :class:`.dict` mapping UFL element family names to their
FIAT-equivalent constructors. If the value is ``None``, the UFL
Expand Down Expand Up @@ -90,11 +94,20 @@ def convert(element):
@convert.register(ufl.FiniteElement)
def convert_finiteelement(element):
cell = as_fiat_cell(element.cell())
lmbda = supported_elements.get(element.family())
if lmbda:
return lmbda(cell, element.degree())
else:
if element.family() not in supported_elements:
return fiat_compat(element)
lmbda = supported_elements.get(element.family())
if lmbda is None:
if element.cell().cellname() != "quadrilateral":
raise ValueError("%s is supported, but handled incorrectly" %
element.family())
# Handle quadrilateral short names like RTCF and RTCE.
element = reconstruct_element(element,
element.family(),
quad_opc,
element.degree())
return finat.QuadrilateralElement(create_element(element))
return lmbda(cell, element.degree())


# MixedElement case
Expand All @@ -117,6 +130,17 @@ def convert_tensorelement(element):
return finat.TensorFiniteElement(scalar_element, element.reference_value_shape())


# TensorProductElement case
@convert.register(ufl.TensorProductElement)
def convert_tensorproductelement(element):
cell = element.cell()
if type(cell) is not ufl.TensorProductCell:
raise ValueError("TensorProductElement not on TensorProductCell?")
return finat.TensorProductElement([create_element(elem)
for elem in element.sub_elements()])


quad_opc = ufl.TensorProductCell(ufl.Cell("interval"), ufl.Cell("interval"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you could take this opportunity to rename this variable as quad_tpc?

_cache = weakref.WeakKeyDictionary()


Expand Down