import basix
import numpy as onp
from jax_fem import logger
# def get_full_integration_poly_degree(ele_type, lag_order, dim):
# """Only works for weak forms of (grad_u, grad_v).
# TODO: Is this correct?
# Reference:
# https://zhuanlan.zhihu.com/p/521630645
# """
# if ele_type == 'hexahedron' or ele_type == 'quadrilateral':
# return 2 * (dim*lag_order - 1)
# if ele_type == 'tetrahedron' or ele_type == 'triangle':
# return 2 * (dim*(lag_order - 1) - 1)
[docs]
def get_elements(ele_type):
    """Obtain element information useful for `basix <https://github.com/FEniCS/basix>`_ to handle.

    Note that mesh node ordering is important.
    If the input mesh file is Gmsh .msh or Abaqus .inp, meshio converts it to its own ordering.
    Our experience shows that meshio ordering is the same as Abaqus.
    For example, for a 10-node tetrahedron element, the ordering of meshio follows this
    `instruction <https://web.mit.edu/calculix_v2.7/CalculiX/ccx_2.7/doc/ccx/node33.html>`_.
    The troublesome thing is that basix has `a different ordering <https://defelement.com/elements/lagrange.html>`_.
    The consequence is that we need to define the **re_order** variable to make sure the ordering is correct.

    Parameters
    ----------
    ele_type : str
        :attr:`~jax_fem.fe.FiniteElement.ele_type`. Supported values are
        'HEX8', 'HEX27', 'HEX20', 'TET4', 'TET10', 'QUAD4', 'QUAD8', 'TRI3' and 'TRI6'.

    Returns
    -------
    element_family : BasixObject
        `ElementFamily <https://docs.fenicsproject.org/basix/main/python/_autosummary/basix.html#basix.ElementFamily>`_
        (Lagrange ``P`` by default, ``serendipity`` for HEX20 and QUAD8).
    basix_ele : BasixObject
        For element: `CellType <https://docs.fenicsproject.org/basix/main/python/_autosummary/basix.html#basix.CellType>`_
    basix_face_ele : BasixObject
        For element face: `CellType <https://docs.fenicsproject.org/basix/main/python/_autosummary/basix.html#basix.CellType>`_
    gauss_order : int
        Default :attr:`~jax_fem.fe.FiniteElement.gauss_order`, i.e. the degree passed
        to ``basix.make_quadrature``.
    degree : int
        Element degree, used in basix.
    re_order : list
        Specifies the node re-ordering transformation. For example, [0, 1, 3, 2, 4, 5, 7, 6] for HEX8 element.

    Raises
    ------
    NotImplementedError
        If ``ele_type`` is not one of the supported element types.
    """
    element_family = basix.ElementFamily.P
    if ele_type == 'HEX8':
        re_order = [0, 1, 3, 2, 4, 5, 7, 6]
        basix_ele = basix.CellType.hexahedron
        basix_face_ele = basix.CellType.quadrilateral
        gauss_order = 2  # 2x2x2, TODO: is this full integration?
        degree = 1
    elif ele_type == 'HEX27':
        logger.warning("27-node hexahedron is rarely used in practice and not recommended.")
        re_order = [0, 1, 3, 2, 4, 5, 7, 6, 8, 11, 13, 9, 16, 18, 19,
                    17, 10, 12, 15, 14, 22, 23, 21, 24, 20, 25, 26]
        basix_ele = basix.CellType.hexahedron
        basix_face_ele = basix.CellType.quadrilateral
        gauss_order = 10  # 6x6x6, full integration
        degree = 2
    elif ele_type == 'HEX20':
        re_order = [0, 1, 3, 2, 4, 5, 7, 6, 8, 11, 13, 9, 16, 18, 19, 17, 10, 12, 15, 14]
        element_family = basix.ElementFamily.serendipity
        basix_ele = basix.CellType.hexahedron
        basix_face_ele = basix.CellType.quadrilateral
        # 2x2x2 points (same degree as HEX8). NOTE(review): this was previously
        # commented as "6x6x6, full integration"; a 2x2x2 rule is usually regarded
        # as reduced integration for a 20-node hex -- confirm this is intended.
        gauss_order = 2
        degree = 2
    elif ele_type == 'TET4':
        re_order = [0, 1, 2, 3]
        basix_ele = basix.CellType.tetrahedron
        basix_face_ele = basix.CellType.triangle
        gauss_order = 0  # 1, full integration
        degree = 1
    elif ele_type == 'TET10':
        re_order = [0, 1, 2, 3, 9, 6, 8, 7, 5, 4]
        basix_ele = basix.CellType.tetrahedron
        basix_face_ele = basix.CellType.triangle
        gauss_order = 2  # 4, full integration
        degree = 2
    elif ele_type == 'QUAD4':
        # TODO: Check if this is correct.
        re_order = [0, 1, 3, 2]
        basix_ele = basix.CellType.quadrilateral
        basix_face_ele = basix.CellType.interval
        gauss_order = 2
        degree = 1
    elif ele_type == 'QUAD8':
        re_order = [0, 1, 3, 2, 4, 6, 7, 5]
        element_family = basix.ElementFamily.serendipity
        basix_ele = basix.CellType.quadrilateral
        basix_face_ele = basix.CellType.interval
        gauss_order = 2
        degree = 2
    elif ele_type == 'TRI3':
        re_order = [0, 1, 2]
        basix_ele = basix.CellType.triangle
        basix_face_ele = basix.CellType.interval
        gauss_order = 0  # 1, full integration
        degree = 1
    elif ele_type == 'TRI6':
        re_order = [0, 1, 2, 5, 3, 4]
        basix_ele = basix.CellType.triangle
        basix_face_ele = basix.CellType.interval
        gauss_order = 2  # 3, full integration
        degree = 2
    else:
        raise NotImplementedError(
            f"Unsupported ele_type {ele_type!r}; expected one of "
            "'HEX8', 'HEX27', 'HEX20', 'TET4', 'TET10', 'QUAD4', 'QUAD8', 'TRI3', 'TRI6'."
        )
    return element_family, basix_ele, basix_face_ele, gauss_order, degree, re_order
def reorder_inds(inds, re_order):
    """Apply re-ordering transformation for node indices.

    Each entry ``ind`` of ``inds`` is replaced by its position in ``re_order``,
    i.e. the ``i`` for which ``re_order[i] == ind``. For a permutation
    ``re_order`` this is the inverse permutation applied to ``inds``.

    Parameters
    ----------
    inds : array_like of int
        Node indices to convert. Any shape.
    re_order : array_like of int
        Node re-ordering transformation (as returned by :func:`get_elements`).
        Each node index is expected to appear exactly once.

    Returns
    -------
    new_inds : NumpyArray
        Integer array with the same shape as ``inds``.

    Raises
    ------
    ValueError
        If ``re_order`` contains duplicate entries, or an entry of ``inds``
        does not appear in ``re_order``.
    """
    inds = onp.asarray(inds)
    re_order = onp.asarray(re_order).reshape(-1)
    # Build the node -> position lookup once (O(n)) instead of scanning
    # re_order for every index; also works whether re_order is a list or array.
    position = {int(node): i for i, node in enumerate(re_order)}
    if len(position) != re_order.size:
        raise ValueError("re_order contains duplicate node indices.")
    try:
        new_inds = [position[int(ind)] for ind in inds.reshape(-1)]
    except KeyError as err:
        raise ValueError(f"Node index {err.args[0]} does not appear in re_order.") from err
    # Explicit integer dtype so that empty input does not come back as float64.
    return onp.array(new_inds, dtype=onp.intp).reshape(inds.shape)
[docs]
def get_shape_vals_and_grads(ele_type, gauss_order=None):
    """Tabulate reference shape functions and their gradients at the volume quadrature points.

    Uses `basix <https://github.com/FEniCS/basix>`_ to build the quadrature rule and the
    element, then permutes the basix node ordering into the mesh node ordering.

    Parameters
    ----------
    ele_type : str
        :attr:`~jax_fem.fe.FiniteElement.ele_type`
    gauss_order : int, optional
        :attr:`~jax_fem.fe.FiniteElement.gauss_order`. When ``None``, the element's
        default returned by :func:`get_elements` is used.

    Returns
    -------
    shape_values : NumpyArray
        Shape is (num_quads, num_nodes), e.g, (8, 8) for HEX8 element.
    shape_grads_ref : NumpyArray
        Shape is (num_quads, num_nodes, dim), e.g, (8, 8, 3) for HEX8 element.
    weights : NumpyArray
        Shape is (num_quads,), e.g, (8,) for HEX8 element.
    """
    family, cell, _face_cell, default_order, degree, node_order = get_elements(ele_type)
    order = default_order if gauss_order is None else gauss_order

    quad_points, weights = basix.make_quadrature(cell, order)
    basis = basix.create_element(family, cell, degree)

    # Tabulate values and first derivatives: entry 0 of the leading axis holds the
    # function values, the remaining entries hold the partial derivatives.
    table = basis.tabulate(1, quad_points)
    # Reorder the node axis from basix ordering to the mesh ordering.
    table = table[:, :, node_order, :]

    shape_values = table[0, :, :, 0]
    # (dim, num_quads, num_nodes) -> (num_quads, num_nodes, dim)
    shape_grads_ref = onp.moveaxis(table[1:, :, :, 0], 0, -1)

    logger.debug(f"ele_type = {ele_type}, quad_points.shape = (num_quads, dim) = {quad_points.shape}")
    return shape_values, shape_grads_ref, weights
[docs]
def get_face_shape_vals_and_grads(ele_type, gauss_order=None):
    """Use `basix <https://github.com/FEniCS/basix>`_ to get shape function values and gradients for element faces.

    Quadrature points of the reference face cell are mapped onto every facet of the
    reference element. The element's shape functions (in mesh node ordering) are then
    tabulated at those mapped points.

    Parameters
    ----------
    ele_type : str
        :attr:`~jax_fem.fe.FiniteElement.ele_type`
    gauss_order : int, optional
        :attr:`~jax_fem.fe.FiniteElement.gauss_order`. When ``None``, the element's
        default returned by :func:`get_elements` is used.

    Returns
    -------
    face_shape_vals : NumpyArray
        Shape is (num_faces, num_face_quads, num_nodes), e.g, (6, 4, 8) for HEX8 element.
    face_shape_grads_ref : NumpyArray
        Shape is (num_faces, num_face_quads, num_nodes, dim), e.g, (6, 4, 8, 3) for HEX8 element.
    face_weights : NumpyArray
        Shape is (num_faces, num_face_quads), e.g, (6, 4) for HEX8 element.
        Face quadrature weights scaled by the size of each reference facet's Jacobian.
    face_normals : NumpyArray
        Shape is (num_faces, dim), e.g, (6, 3) for HEX8 element.
        Outward normals of the reference facets, as given by basix.
    face_inds : NumpyArray
        Shape is (num_faces, num_face_vertices), e.g, (6, 4) for HEX8 element.
        Vertex node indices of each face, expressed in the re-ordered (mesh) node numbering.
    """
    element_family, basix_ele, basix_face_ele, gauss_order_default, degree, re_order = get_elements(ele_type)
    if gauss_order is None:
        gauss_order = gauss_order_default

    # TODO: Check if this is correct.
    # We should provide freedom for a separate gauss_order for volume integral and surface integral.
    # Currently, they're using the same gauss_order!
    points, weights = basix.make_quadrature(basix_face_ele, gauss_order)

    # Linear Lagrange basis on the reference face cell. Its values at the face
    # quadrature points are the coefficients that map those points onto a facet.
    map_degree = 1
    lagrange_map = basix.create_element(basix.ElementFamily.P, basix_face_ele, map_degree)
    values = lagrange_map.tabulate(0, points)[0, :, :, 0]  # (num_face_quads, num_face_vertices)

    vertices = basix.geometry(basix_ele)
    dim = len(vertices[0])
    facets = basix.cell.sub_entity_connectivity(basix_ele)[dim - 1]
    # Loop-invariant: one Jacobian per facet of the reference cell, computed once.
    facet_jacobians = basix.cell.facet_jacobians(basix_ele)

    # Map face points
    # Reference: https://docs.fenicsproject.org/basix/main/python/demo/demo_facet_integral.py.html
    face_quad_points = []
    face_inds = []
    face_weights = []
    for f, facet in enumerate(facets):
        facet_vertex_inds = facet[0]
        # Mapped point q is sum_k values[q, k] * vertices[facet_vertex_inds[k]];
        # one matrix product maps all quadrature points of this facet at once.
        face_quad_points.append(values @ vertices[facet_vertex_inds])
        face_inds.append(facet_vertex_inds)
        jacobian = facet_jacobians[f]
        if dim == 2:
            # Edge of a 2D cell: the Jacobian has a single column whose length scales the weights.
            size_jacobian = onp.linalg.norm(jacobian)
        else:
            # Face of a 3D cell: the area scaling is |J[:, 0] x J[:, 1]|.
            size_jacobian = onp.linalg.norm(onp.cross(jacobian[:, 0], jacobian[:, 1]))
        face_weights.append(weights * size_jacobian)
    face_quad_points = onp.stack(face_quad_points)
    face_weights = onp.stack(face_weights)
    face_normals = basix.cell.facet_outward_normals(basix_ele)
    # Express the basix vertex indices in the re-ordered (mesh) node numbering.
    face_inds = reorder_inds(onp.array(face_inds), re_order)

    num_faces, num_face_quads, dim = face_quad_points.shape
    element = basix.create_element(element_family, basix_ele, degree)
    # Tabulate values and first derivatives at all mapped points, with the node
    # axis permuted into mesh ordering.
    vals_and_grads = element.tabulate(1, face_quad_points.reshape(-1, dim))[:, :, re_order, :]
    face_shape_vals = vals_and_grads[0, :, :, 0].reshape(num_faces, num_face_quads, -1)
    face_shape_grads_ref = vals_and_grads[1:, :, :, 0].reshape(dim, num_faces, num_face_quads, -1)
    # (dim, num_faces, num_face_quads, num_nodes) -> (num_faces, num_face_quads, num_nodes, dim)
    face_shape_grads_ref = onp.transpose(face_shape_grads_ref, axes=(1, 2, 3, 0))
    logger.debug(f"face_quad_points.shape = (num_faces, num_face_quads, dim) = {face_quad_points.shape}")
    return face_shape_vals, face_shape_grads_ref, face_weights, face_normals, face_inds