Source code for jax_fem.basis

import basix
import numpy as onp

from jax_fem import logger


# def get_full_integration_poly_degree(ele_type, lag_order, dim):
#     """Only works for weak forms of (grad_u, grad_v).
#     TODO: Is this correct?
#     Reference:
#     https://zhuanlan.zhihu.com/p/521630645
#     """
#     if ele_type == 'hexahedron' or ele_type == 'quadrilateral':
#         return 2 * (dim*lag_order - 1)

#     if ele_type == 'tetrahedron' or ele_type == 'triangle':
#         return 2 * (dim*(lag_order - 1) - 1)

[docs] def get_elements(ele_type): """Mesh node ordering is important. If the input mesh file is Gmsh .msh or Abaqus .inp, meshio would convert it to its own ordering. My experience shows that meshio ordering is the same as Abaqus. For example, for a 10-node tetrahedron element, the ordering of meshio is the following https://web.mit.edu/calculix_v2.7/CalculiX/ccx_2.7/doc/ccx/node33.html The troublesome thing is that basix has a different ordering. As shown below https://defelement.com/elements/lagrange.html The consequence is that we need to define this "re_order" variable to make sure the ordering is correct. """ element_family = basix.ElementFamily.P if ele_type == 'HEX8': re_order = [0, 1, 3, 2, 4, 5, 7, 6] basix_ele = basix.CellType.hexahedron basix_face_ele = basix.CellType.quadrilateral gauss_order = 2 # 2x2x2, TODO: is this full integration? degree = 1 elif ele_type == 'HEX27': print(f"Warning: 27-node hexahedron is rarely used in practice and not recommended.") re_order = [0, 1, 3, 2, 4, 5, 7, 6, 8, 11, 13, 9, 16, 18, 19, 17, 10, 12, 15, 14, 22, 23, 21, 24, 20, 25, 26] basix_ele = basix.CellType.hexahedron basix_face_ele = basix.CellType.quadrilateral gauss_order = 10 # 6x6x6, full integration degree = 2 elif ele_type == 'HEX20': re_order = [0, 1, 3, 2, 4, 5, 7, 6, 8, 11, 13, 9, 16, 18, 19, 17, 10, 12, 15, 14] element_family = basix.ElementFamily.serendipity basix_ele = basix.CellType.hexahedron basix_face_ele = basix.CellType.quadrilateral gauss_order = 2 # 6x6x6, full integration degree = 2 elif ele_type == 'TET4': re_order = [0, 1, 2, 3] basix_ele = basix.CellType.tetrahedron basix_face_ele = basix.CellType.triangle gauss_order = 0 # 1, full integration degree = 1 elif ele_type == 'TET10': re_order = [0, 1, 2, 3, 9, 6, 8, 7, 5, 4] basix_ele = basix.CellType.tetrahedron basix_face_ele = basix.CellType.triangle gauss_order = 2 # 4, full integration degree = 2 # TODO: Check if this is correct. elif ele_type == 'QUAD4': re_order = [0, 1, 3, 2] basix_ele = basix.CellType.quadrilateral basix_face_ele = basix.CellType.interval gauss_order = 2 degree = 1 elif ele_type == 'QUAD8': re_order = [0, 1, 3, 2, 4, 6, 7, 5] element_family = basix.ElementFamily.serendipity basix_ele = basix.CellType.quadrilateral basix_face_ele = basix.CellType.interval gauss_order = 2 degree = 2 elif ele_type == 'TRI3': re_order = [0, 1, 2] basix_ele = basix.CellType.triangle basix_face_ele = basix.CellType.interval gauss_order = 0 # 1, full integration degree = 1 elif ele_type == 'TRI6': re_order = [0, 1, 2, 5, 3, 4] basix_ele = basix.CellType.triangle basix_face_ele = basix.CellType.interval gauss_order = 2 # 3, full integration degree = 2 else: raise NotImplementedError return element_family, basix_ele, basix_face_ele, gauss_order, degree, re_order
[docs] def reorder_inds(inds, re_order): new_inds = [] for ind in inds.reshape(-1): new_inds.append(onp.argwhere(re_order == ind)) new_inds = onp.array(new_inds).reshape(inds.shape) return new_inds
[docs] def get_shape_vals_and_grads(ele_type, gauss_order=None): """TODO: Add comments Returns ------- shape_values: ndarray (8, 8) = (num_quads, num_nodes) shape_grads_ref: ndarray (8, 8, 3) = (num_quads, num_nodes, dim) weights: ndarray (8,) = (num_quads,) """ element_family, basix_ele, basix_face_ele, gauss_order_default, degree, re_order = get_elements(ele_type) if gauss_order is None: gauss_order = gauss_order_default quad_points, weights = basix.make_quadrature(basix_ele, gauss_order) element = basix.create_element(element_family, basix_ele, degree) vals_and_grads = element.tabulate(1, quad_points)[:, :, re_order, :] shape_values = vals_and_grads[0, :, :, 0] shape_grads_ref = onp.transpose(vals_and_grads[1:, :, :, 0], axes=(1, 2, 0)) logger.debug(f"ele_type = {ele_type}, quad_points.shape = (num_quads, dim) = {quad_points.shape}") return shape_values, shape_grads_ref, weights
[docs] def get_face_shape_vals_and_grads(ele_type, gauss_order=None): """TODO: Add comments Returns ------- face_shape_vals: ndarray (6, 4, 8) = (num_faces, num_face_quads, num_nodes) face_shape_grads_ref: ndarray (6, 4, 3) = (num_faces, num_face_quads, num_nodes, dim) face_weights: ndarray (6, 4) = (num_faces, num_face_quads) face_normals:ndarray (6, 3) = (num_faces, dim) face_inds: ndarray (6, 4) = (num_faces, num_face_vertices) """ element_family, basix_ele, basix_face_ele, gauss_order_default, degree, re_order = get_elements(ele_type) if gauss_order is None: gauss_order = gauss_order_default # TODO: Check if this is correct. # We should provide freedom for seperate gauss_order for volume integral and surface integral # Currently, they're using the same gauss_order! points, weights = basix.make_quadrature(basix_face_ele, gauss_order) map_degree = 1 lagrange_map = basix.create_element(basix.ElementFamily.P, basix_face_ele, map_degree) values = lagrange_map.tabulate(0, points)[0, :, :, 0] vertices = basix.geometry(basix_ele) dim = len(vertices[0]) facets = basix.cell.sub_entity_connectivity(basix_ele)[dim - 1] # Map face points # Reference: https://docs.fenicsproject.org/basix/main/python/demo/demo_facet_integral.py.html face_quad_points = [] face_inds = [] face_weights = [] for f, facet in enumerate(facets): mapped_points = [] for i in range(len(points)): vals = values[i] mapped_point = onp.sum(vertices[facet[0]] * vals[:, None], axis=0) mapped_points.append(mapped_point) face_quad_points.append(mapped_points) face_inds.append(facet[0]) jacobian = basix.cell.facet_jacobians(basix_ele)[f] if dim == 2: size_jacobian = onp.linalg.norm(jacobian) else: size_jacobian = onp.linalg.norm(onp.cross(jacobian[:, 0], jacobian[:, 1])) face_weights.append(weights*size_jacobian) face_quad_points = onp.stack(face_quad_points) face_weights = onp.stack(face_weights) face_normals = basix.cell.facet_outward_normals(basix_ele) face_inds = onp.array(face_inds) face_inds = reorder_inds(face_inds, re_order) num_faces, num_face_quads, dim = face_quad_points.shape element = basix.create_element(element_family, basix_ele, degree) vals_and_grads = element.tabulate(1, face_quad_points.reshape(-1, dim))[:, :, re_order, :] face_shape_vals = vals_and_grads[0, :, :, 0].reshape(num_faces, num_face_quads, -1) face_shape_grads_ref = vals_and_grads[1:, :, :, 0].reshape(dim, num_faces, num_face_quads, -1) face_shape_grads_ref = onp.transpose(face_shape_grads_ref, axes=(1, 2, 3, 0)) logger.debug(f"face_quad_points.shape = (num_faces, num_face_quads, dim) = {face_quad_points.shape}") return face_shape_vals, face_shape_grads_ref, face_weights, face_normals, face_inds