import basix
import numpy as onp
from jax_fem import logger
# def get_full_integration_poly_degree(ele_type, lag_order, dim):
# """Only works for weak forms of (grad_u, grad_v).
# TODO: Is this correct?
# Reference:
# https://zhuanlan.zhihu.com/p/521630645
# """
# if ele_type == 'hexahedron' or ele_type == 'quadrilateral':
# return 2 * (dim*lag_order - 1)
# if ele_type == 'tetrahedron' or ele_type == 'triangle':
# return 2 * (dim*(lag_order - 1) - 1)
[docs]
def get_elements(ele_type):
    """Return the basix description of a supported finite element type.

    Mesh node ordering is important. If the input mesh file is Gmsh .msh or
    Abaqus .inp, meshio converts it to its own ordering; in the author's
    experience meshio ordering is the same as Abaqus. For example, the
    ordering of a 10-node tetrahedron is documented at
    https://web.mit.edu/calculix_v2.7/CalculiX/ccx_2.7/doc/ccx/node33.html
    basix uses a different ordering, see
    https://defelement.com/elements/lagrange.html
    ``re_order`` bridges the two: ``re_order[i]`` is the basix local node
    index corresponding to mesh-file (meshio) local node ``i``. Indexing a
    basix tabulation with ``[..., re_order, ...]`` therefore yields columns
    in mesh-file node order.

    Parameters
    ----------
    ele_type : str
        One of 'HEX8', 'HEX27', 'HEX20', 'TET4', 'TET10', 'QUAD4', 'QUAD8',
        'TRI3', 'TRI6'.

    Returns
    -------
    element_family : basix.ElementFamily
        ``P`` (Lagrange), or ``serendipity`` for HEX20 and QUAD8.
    basix_ele : basix.CellType
        Reference cell of the element.
    basix_face_ele : basix.CellType
        Reference cell of the element's facets.
    gauss_order : int
        Default quadrature degree passed to ``basix.make_quadrature``.
    degree : int
        Polynomial degree of the element.
    re_order : list of int
        Mesh-file -> basix node permutation (see above).

    Raises
    ------
    NotImplementedError
        If ``ele_type`` is not supported.
    """
    element_family = basix.ElementFamily.P
    if ele_type == 'HEX8':
        re_order = [0, 1, 3, 2, 4, 5, 7, 6]
        basix_ele = basix.CellType.hexahedron
        basix_face_ele = basix.CellType.quadrilateral
        gauss_order = 2  # 2x2x2 points; TODO: confirm this is full integration
        degree = 1
    elif ele_type == 'HEX27':
        logger.warning("27-node hexahedron is rarely used in practice and not recommended.")
        re_order = [0, 1, 3, 2, 4, 5, 7, 6, 8, 11, 13, 9, 16, 18, 19,
                    17, 10, 12, 15, 14, 22, 23, 21, 24, 20, 25, 26]
        basix_ele = basix.CellType.hexahedron
        basix_face_ele = basix.CellType.quadrilateral
        gauss_order = 10  # 6x6x6 points, full integration
        degree = 2
    elif ele_type == 'HEX20':
        re_order = [0, 1, 3, 2, 4, 5, 7, 6, 8, 11, 13, 9, 16, 18, 19, 17, 10, 12, 15, 14]
        element_family = basix.ElementFamily.serendipity
        basix_ele = basix.CellType.hexahedron
        basix_face_ele = basix.CellType.quadrilateral
        # 2x2x2 points with basix's default hexahedron rule.
        # NOTE(review): this is reduced integration for a 20-node hex (full
        # would typically be 3x3x3); kept as-is to preserve existing results.
        gauss_order = 2
        degree = 2
    elif ele_type == 'TET4':
        re_order = [0, 1, 2, 3]
        basix_ele = basix.CellType.tetrahedron
        basix_face_ele = basix.CellType.triangle
        gauss_order = 0  # 1 point, full integration
        degree = 1
    elif ele_type == 'TET10':
        re_order = [0, 1, 2, 3, 9, 6, 8, 7, 5, 4]
        basix_ele = basix.CellType.tetrahedron
        basix_face_ele = basix.CellType.triangle
        gauss_order = 2  # 4 points, full integration
        degree = 2
    # TODO: Check if the 2D element settings below are correct.
    elif ele_type == 'QUAD4':
        re_order = [0, 1, 3, 2]
        basix_ele = basix.CellType.quadrilateral
        basix_face_ele = basix.CellType.interval
        gauss_order = 2  # 2x2 points
        degree = 1
    elif ele_type == 'QUAD8':
        re_order = [0, 1, 3, 2, 4, 6, 7, 5]
        element_family = basix.ElementFamily.serendipity
        basix_ele = basix.CellType.quadrilateral
        basix_face_ele = basix.CellType.interval
        gauss_order = 2  # 2x2 points
        degree = 2
    elif ele_type == 'TRI3':
        re_order = [0, 1, 2]
        basix_ele = basix.CellType.triangle
        basix_face_ele = basix.CellType.interval
        gauss_order = 0  # 1 point, full integration
        degree = 1
    elif ele_type == 'TRI6':
        re_order = [0, 1, 2, 5, 3, 4]
        basix_ele = basix.CellType.triangle
        basix_face_ele = basix.CellType.interval
        gauss_order = 2  # 3 points, full integration
        degree = 2
    else:
        raise NotImplementedError(f"Unsupported element type: {ele_type!r}")
    return element_family, basix_ele, basix_face_ele, gauss_order, degree, re_order
[docs]
def reorder_inds(inds, re_order):
    """Translate basix local node indices into mesh-file (meshio) local node indices.

    Each entry ``ind`` of ``inds`` is replaced by the position ``i`` at which
    ``re_order[i] == ind``. In other words, this applies the inverse of the
    ``re_order`` permutation returned by ``get_elements``.

    Parameters
    ----------
    inds : array_like of int
        Basix local node indices, of any shape.
    re_order : sequence of int
        Mesh-file -> basix node permutation.

    Returns
    -------
    ndarray of intp
        Array with the same shape as ``inds``.

    Raises
    ------
    ValueError
        If some entry of ``inds`` does not occur exactly once in ``re_order``.
    """
    inds = onp.asarray(inds)
    re_order = onp.asarray(re_order)
    flat = inds.reshape(-1)
    if flat.size == 0:
        return onp.zeros(inds.shape, dtype=onp.intp)
    # matches[k, i] is True where re_order[i] == flat[k].
    matches = flat[:, None] == re_order[None, :]
    counts = matches.sum(axis=1)
    if onp.any(counts != 1):
        bad = flat[counts != 1]
        raise ValueError(
            f"Indices {bad.tolist()} must each occur exactly once in re_order {re_order.tolist()}"
        )
    # Exactly one True per row, so argmax returns its position.
    return onp.argmax(matches, axis=1).reshape(inds.shape)
[docs]
def get_shape_vals_and_grads(ele_type, gauss_order=None):
    """Tabulate shape-function values and reference gradients at volume quadrature points.

    Parameters
    ----------
    ele_type : str
        Element name understood by ``get_elements``.
    gauss_order : int, optional
        Quadrature degree passed to ``basix.make_quadrature``. When None,
        the element default from ``get_elements`` is used.

    Returns
    -------
    shape_values : ndarray
        (num_quads, num_nodes), e.g. (8, 8) for HEX8.
    shape_grads_ref : ndarray
        (num_quads, num_nodes, dim), e.g. (8, 8, 3) for HEX8; gradients
        with respect to reference coordinates.
    weights : ndarray
        (num_quads,), e.g. (8,) for HEX8.

    The node axis follows mesh-file (meshio) ordering via ``re_order``.
    """
    family, cell, _face_cell, default_order, poly_degree, node_perm = get_elements(ele_type)
    order = default_order if gauss_order is None else gauss_order
    quad_points, weights = basix.make_quadrature(cell, order)
    basis = basix.create_element(family, cell, poly_degree)
    # Leading axis: index 0 holds values, indices 1: hold reference derivatives.
    table = basis.tabulate(1, quad_points)
    table = table[:, :, node_perm, :]
    shape_values = table[0, :, :, 0]
    # (dim, num_quads, num_nodes) -> (num_quads, num_nodes, dim)
    shape_grads_ref = onp.moveaxis(table[1:, :, :, 0], 0, -1)
    logger.debug(f"ele_type = {ele_type}, quad_points.shape = (num_quads, dim) = {quad_points.shape}")
    return shape_values, shape_grads_ref, weights
[docs]
def get_face_shape_vals_and_grads(ele_type, gauss_order=None):
    """Tabulate shape functions at quadrature points on every facet of the reference cell.

    A quadrature rule is built on the facet reference cell and mapped onto
    each facet of the element's reference cell with an affine (degree-1
    Lagrange) map, following the basix facet-integral demo. The element's
    shape functions are then tabulated at the mapped points, and the node
    axis is permuted into mesh-file (meshio) order with ``re_order``.

    Parameters
    ----------
    ele_type : str
        Element name understood by ``get_elements``.
    gauss_order : int, optional
        Quadrature degree for the facet rule. When None, the element's
        default from ``get_elements`` is used, i.e. the same default as the
        volume rule.

    Returns
    -------
    face_shape_vals : ndarray
        (num_faces, num_face_quads, num_nodes), e.g. (6, 4, 8) for HEX8.
    face_shape_grads_ref : ndarray
        (num_faces, num_face_quads, num_nodes, dim), e.g. (6, 4, 8, 3) for
        HEX8; gradients with respect to reference coordinates.
    face_weights : ndarray
        (num_faces, num_face_quads). Facet quadrature weights scaled by the
        measure of the facet map (length in 2D, cross-product norm in 3D).
    face_normals : ndarray
        (num_faces, dim), as returned by ``basix.cell.facet_outward_normals``.
    face_inds : ndarray
        (num_faces, num_face_vertices). Facet vertex indices translated to
        mesh-file node numbering with ``reorder_inds``.
    """
    element_family, basix_ele, basix_face_ele, gauss_order_default, degree, re_order = get_elements(ele_type)
    if gauss_order is None:
        gauss_order = gauss_order_default
    # TODO: Allow a separate default quadrature order for surface integrals;
    # by default the facet rule currently reuses the volume default.
    points, weights = basix.make_quadrature(basix_face_ele, gauss_order)

    # Degree-1 Lagrange basis on the facet reference cell, evaluated at the
    # facet quadrature points; used to map those points onto each facet.
    map_degree = 1
    lagrange_map = basix.create_element(basix.ElementFamily.P, basix_face_ele, map_degree)
    values = lagrange_map.tabulate(0, points)[0, :, :, 0]  # (num_face_quads, num_face_vertices)

    vertices = basix.geometry(basix_ele)
    dim = len(vertices[0])
    facets = basix.cell.sub_entity_connectivity(basix_ele)[dim - 1]
    # Loop-invariant: query all facet Jacobians once instead of once per facet.
    facet_jacobians = basix.cell.facet_jacobians(basix_ele)

    # Map face points
    # Reference: https://docs.fenicsproject.org/basix/main/python/demo/demo_facet_integral.py.html
    face_quad_points = []
    face_inds = []
    face_weights = []
    for f, facet in enumerate(facets):
        facet_vertices = vertices[facet[0]]  # facet[0]: vertex indices of facet f
        # mapped_points[i] = sum_k values[i, k] * facet_vertices[k]
        mapped_points = onp.sum(values[:, :, None] * facet_vertices[None, :, :], axis=1)
        face_quad_points.append(mapped_points)
        face_inds.append(facet[0])
        jacobian = facet_jacobians[f]
        # Scale the reference-facet weights by the measure of the facet map:
        # edge length in 2D, norm of the cross product of the two tangents in 3D.
        if dim == 2:
            size_jacobian = onp.linalg.norm(jacobian)
        else:
            size_jacobian = onp.linalg.norm(onp.cross(jacobian[:, 0], jacobian[:, 1]))
        face_weights.append(weights * size_jacobian)
    face_quad_points = onp.stack(face_quad_points)  # (num_faces, num_face_quads, dim)
    face_weights = onp.stack(face_weights)  # (num_faces, num_face_quads)
    face_normals = basix.cell.facet_outward_normals(basix_ele)
    face_inds = reorder_inds(onp.array(face_inds), re_order)

    num_faces, num_face_quads, dim = face_quad_points.shape
    element = basix.create_element(element_family, basix_ele, degree)
    vals_and_grads = element.tabulate(1, face_quad_points.reshape(-1, dim))[:, :, re_order, :]
    face_shape_vals = vals_and_grads[0, :, :, 0].reshape(num_faces, num_face_quads, -1)
    face_shape_grads_ref = vals_and_grads[1:, :, :, 0].reshape(dim, num_faces, num_face_quads, -1)
    face_shape_grads_ref = onp.transpose(face_shape_grads_ref, axes=(1, 2, 3, 0))
    logger.debug(f"face_quad_points.shape = (num_faces, num_face_quads, dim) = {face_quad_points.shape}")
    return face_shape_vals, face_shape_grads_ref, face_weights, face_normals, face_inds