import numpy as onp
import jax
import jax.numpy as np
import sys
import time
import functools
from dataclasses import dataclass
from typing import Any, Callable, Optional, List, Union
from jax_fem.generate_mesh import Mesh
from jax_fem.basis import get_face_shape_vals_and_grads, get_shape_vals_and_grads
from jax_fem import logger
onp.set_printoptions(threshold=sys.maxsize,
linewidth=1000,
suppress=True,
precision=5)
[docs]
@dataclass
class FiniteElement:
"""
Defines finite element related to one variable (can be vector valued)
Attributes
----------
mesh : Mesh object
The mesh object stores points (coordinates) and cells (connectivity).
vec : int
The number of vector variable components of the solution.
E.g., a 3D displacement field has u_x, u_y and u_z components, so vec=3
dim : int
The dimension of the problem.
ele_type : str
Element type
dirichlet_bc_info : [location_fns, vecs, value_fns]
location_fns : List[Callable]
Callable : a function that inputs a point and returns if the point satisfies the location condition
vecs: List[int]
integer value must be in the range of 0 to vec - 1,
specifying which component of the (vector) variable to apply Dirichlet condition to
value_fns : List[Callable]
Callable : a function that inputs a point and returns the Dirichlet value
periodic_bc_info : [location_fns_A, location_fns_B, mappings, vecs]
location_fns_A : List[Callable]
Callable : location function for boundary A
location_fns_B : List[Callable]
Callable : location function for boundary B
mappings : List[Callable]
Callable: function mapping a point from boundary A to boundary B
vecs: List[int]
which component of the (vector) variable to apply periodic condition to
"""
mesh: Mesh
vec: int
dim: int
ele_type: str
gauss_order: int
dirichlet_bc_info: Optional[List[Union[List[Callable], List[int], List[Callable]]]]
periodic_bc_info: Optional[List[Union[List[Callable], List[Callable], List[Callable], List[int]]]] = None
def __post_init__(self):
self.points = self.mesh.points
self.cells = self.mesh.cells
self.num_cells = len(self.cells)
self.num_total_nodes = len(self.mesh.points)
self.num_total_dofs = self.num_total_nodes * self.vec
start = time.time()
logger.debug(f"Computing shape function values, gradients, etc.")
self.shape_vals, self.shape_grads_ref, self.quad_weights = get_shape_vals_and_grads(self.ele_type, self.gauss_order)
self.face_shape_vals, self.face_shape_grads_ref, self.face_quad_weights, self.face_normals, self.face_inds \
= get_face_shape_vals_and_grads(self.ele_type, self.gauss_order)
self.num_quads = self.shape_vals.shape[0]
self.num_nodes = self.shape_vals.shape[1]
self.num_faces = self.face_shape_vals.shape[0]
self.shape_grads, self.JxW = self.get_shape_grads()
self.node_inds_list, self.vec_inds_list, self.vals_list = self.Dirichlet_boundary_conditions(self.dirichlet_bc_info)
# self.p_node_inds_list_A, self.p_node_inds_list_B, self.p_vec_inds_list = self.periodic_boundary_conditions()
# (num_cells, num_quads, num_nodes, 1, dim)
self.v_grads_JxW = self.shape_grads[:, :, :, None, :] * self.JxW[:, :, None, None, None]
self.num_face_quads = self.face_quad_weights.shape[1]
end = time.time()
compute_time = end - start
logger.debug(f"Done pre-computations, took {compute_time} [s]")
logger.info(f"Solving a problem with {len(self.cells)} cells, {self.num_total_nodes}x{self.vec} = {self.num_total_dofs} dofs.")
logger.info(f"Element type is {self.ele_type}, using {self.num_quads} quad points per element.")
[docs]
def get_shape_grads(self):
"""Compute shape function gradient value
The gradient is w.r.t physical coordinates.
See Hughes, Thomas JR. The finite element method: linear static and dynamic finite element analysis. Courier Corporation, 2012.
Page 147, Eq. (3.9.3)
Returns
-------
shape_grads_physical : onp.ndarray
(num_cells, num_quads, num_nodes, dim)
JxW : onp.ndarray
(num_cells, num_quads)
"""
assert self.shape_grads_ref.shape == (self.num_quads, self.num_nodes, self.dim)
physical_coos = onp.take(self.points, self.cells, axis=0) # (num_cells, num_nodes, dim)
# (num_cells, num_quads, num_nodes, dim, dim) -> (num_cells, num_quads, 1, dim, dim)
jacobian_dx_deta = onp.sum(physical_coos[:, None, :, :, None] *
self.shape_grads_ref[None, :, :, None, :], axis=2, keepdims=True)
jacobian_det = onp.linalg.det(jacobian_dx_deta)[:, :, 0] # (num_cells, num_quads)
jacobian_deta_dx = onp.linalg.inv(jacobian_dx_deta)
# (1, num_quads, num_nodes, 1, dim) @ (num_cells, num_quads, 1, dim, dim)
# (num_cells, num_quads, num_nodes, 1, dim) -> (num_cells, num_quads, num_nodes, dim)
shape_grads_physical = (self.shape_grads_ref[None, :, :, None, :]
[docs]
@ jacobian_deta_dx)[:, :, :, 0, :]
JxW = jacobian_det * self.quad_weights[None, :]
return shape_grads_physical, JxW
def get_face_shape_grads(self, boundary_inds):
"""Face shape function gradients and JxW (for surface integral)
Nanson's formula is used to map physical surface ingetral to reference domain
Reference: https://en.wikiversity.org/wiki/Continuum_mechanics/Volume_change_and_area_change
Parameters
----------
boundary_inds : List[onp.ndarray]
(num_selected_faces, 2)
Returns
-------
face_shape_grads_physical : onp.ndarray
(num_selected_faces, num_face_quads, num_nodes, dim)
nanson_scale : onp.ndarray
(num_selected_faces, num_face_quads)
"""
physical_coos = onp.take(self.points, self.cells, axis=0) # (num_cells, num_nodes, dim)
selected_coos = physical_coos[boundary_inds[:, 0]] # (num_selected_faces, num_nodes, dim)
selected_f_shape_grads_ref = self.face_shape_grads_ref[boundary_inds[:, 1]] # (num_selected_faces, num_face_quads, num_nodes, dim)
selected_f_normals = self.face_normals[boundary_inds[:, 1]] # (num_selected_faces, dim)
# (num_selected_faces, 1, num_nodes, dim, 1) * (num_selected_faces, num_face_quads, num_nodes, 1, dim)
# (num_selected_faces, num_face_quads, num_nodes, dim, dim) -> (num_selected_faces, num_face_quads, dim, dim)
jacobian_dx_deta = onp.sum(selected_coos[:, None, :, :, None] * selected_f_shape_grads_ref[:, :, :, None, :], axis=2)
jacobian_det = onp.linalg.det(jacobian_dx_deta) # (num_selected_faces, num_face_quads)
jacobian_deta_dx = onp.linalg.inv(jacobian_dx_deta) # (num_selected_faces, num_face_quads, dim, dim)
# (1, num_face_quads, num_nodes, 1, dim) @ (num_selected_faces, num_face_quads, 1, dim, dim)
# (num_selected_faces, num_face_quads, num_nodes, 1, dim) -> (num_selected_faces, num_face_quads, num_nodes, dim)
face_shape_grads_physical = (selected_f_shape_grads_ref[:, :, :, None, :] @ jacobian_deta_dx[:, :, None, :, :])[:, :, :, 0, :]
# (num_selected_faces, 1, 1, dim) @ (num_selected_faces, num_face_quads, dim, dim)
# (num_selected_faces, num_face_quads, 1, dim) -> (num_selected_faces, num_face_quads)
nanson_scale = onp.linalg.norm((selected_f_normals[:, None, None, :] @ jacobian_deta_dx)[:, :, 0, :], axis=-1)
selected_weights = self.face_quad_weights[boundary_inds[:, 1]] # (num_selected_faces, num_face_quads)
nanson_scale = nanson_scale * jacobian_det * selected_weights
return face_shape_grads_physical, nanson_scale
[docs]
def get_physical_quad_points(self):
"""Compute physical quadrature points
Returns
-------
physical_quad_points : onp.ndarray
(num_cells, num_quads, dim)
"""
physical_coos = onp.take(self.points, self.cells, axis=0)
# (1, num_quads, num_nodes, 1) * (num_cells, 1, num_nodes, dim) -> (num_cells, num_quads, dim)
physical_quad_points = onp.sum(self.shape_vals[None, :, :, None] * physical_coos[:, None, :, :], axis=2)
return physical_quad_points
[docs]
def get_physical_surface_quad_points(self, boundary_inds):
"""Compute physical quadrature points on the surface
Parameters
----------
boundary_inds : List[onp.ndarray]
ndarray shape: (num_selected_faces, 2)
Returns
-------
physical_surface_quad_points : ndarray
(num_selected_faces, num_face_quads, dim)
"""
physical_coos = onp.take(self.points, self.cells, axis=0)
selected_coos = physical_coos[boundary_inds[:, 0]] # (num_selected_faces, num_nodes, dim)
selected_face_shape_vals = self.face_shape_vals[boundary_inds[:, 1]] # (num_selected_faces, num_face_quads, num_nodes)
# (num_selected_faces, num_face_quads, num_nodes, 1) * (num_selected_faces, 1, num_nodes, dim) -> (num_selected_faces, num_face_quads, dim)
physical_surface_quad_points = onp.sum(selected_face_shape_vals[:, :, :, None] * selected_coos[:, None, :, :], axis=2)
return physical_surface_quad_points
[docs]
def Dirichlet_boundary_conditions(self, dirichlet_bc_info):
"""Indices and values for Dirichlet B.C.
Parameters
----------
dirichlet_bc_info : [location_fns, vecs, value_fns]
Returns
-------
node_inds_List : List[onp.ndarray]
The ndarray ranges from 0 to num_total_nodes - 1
vec_inds_List : List[onp.ndarray]
The ndarray ranges from 0 to to vec - 1
vals_List : List[ndarray]
Dirichlet values to be assigned
"""
node_inds_list = []
vec_inds_list = []
vals_list = []
if dirichlet_bc_info is not None:
location_fns, vecs, value_fns = dirichlet_bc_info
assert len(location_fns) == len(value_fns) and len(value_fns) == len(vecs)
for i in range(len(location_fns)):
num_args = location_fns[i].__code__.co_argcount
if num_args == 1:
location_fn = lambda point, ind: location_fns[i](point)
elif num_args == 2:
location_fn = location_fns[i]
else:
raise ValueError(f"Wrong number of arguments for location_fn: must be 1 or 2, get {num_args}")
node_inds = onp.argwhere(jax.vmap(location_fn)(self.mesh.points, np.arange(self.num_total_nodes))).reshape(-1)
vec_inds = onp.ones_like(node_inds, dtype=onp.int32) * vecs[i]
values = jax.vmap(value_fns[i])(self.mesh.points[node_inds].reshape(-1, self.dim)).reshape(-1)
node_inds_list.append(node_inds)
vec_inds_list.append(vec_inds)
vals_list.append(values)
return node_inds_list, vec_inds_list, vals_list
[docs]
def update_Dirichlet_boundary_conditions(self, dirichlet_bc_info):
"""Reset Dirichlet boundary conditions.
Useful when a time-dependent problem is solved, and at each iteration the boundary condition needs to be updated.
Parameters
----------
dirichlet_bc_info : [location_fns, vecs, value_fns]
"""
self.node_inds_list, self.vec_inds_list, self.vals_list = self.Dirichlet_boundary_conditions(dirichlet_bc_info)
[docs]
def get_boundary_conditions_inds(self, location_fns):
"""Given location functions, compute which faces satisfy the condition.
Parameters
----------
location_fns : List[Callable]
Callable: a location function that inputs a point (ndarray) and returns if the point satisfies the location condition
e.g., lambda x: np.isclose(x[0], 0.)
If this location function takes 2 arguments, then the first is point and the second is index.
e.g., lambda x, ind: np.isclose(x[0], 0.) & np.isin(ind, np.array([1, 3, 10]))
Returns
-------
boundary_inds_list : List[onp.ndarray]
(num_selected_faces, 2)
boundary_inds_list[k][i, 0] returns the global cell index of the ith selected face of boundary subset k
boundary_inds_list[k][i, 1] returns the local face index of the ith selected face of boundary subset k
"""
# TODO: assume this works for all variables, and return the same result
cell_points = onp.take(self.points, self.cells, axis=0) # (num_cells, num_nodes, dim)
cell_face_points = onp.take(cell_points, self.face_inds, axis=1) # (num_cells, num_faces, num_face_vertices, dim)
cell_face_inds = onp.take(self.cells, self.face_inds, axis=1) # (num_cells, num_faces, num_face_vertices)
boundary_inds_list = []
if location_fns is not None:
for i in range(len(location_fns)):
num_args = location_fns[i].__code__.co_argcount
if num_args == 1:
location_fn = lambda point, ind: location_fns[i](point)
elif num_args == 2:
location_fn = location_fns[i]
else:
raise ValueError(f"Wrong number of arguments for location_fn: must be 1 or 2, get {num_args}")
vmap_location_fn = jax.vmap(location_fn)
def on_boundary(cell_points, cell_inds):
boundary_flag = vmap_location_fn(cell_points, cell_inds)
return onp.all(boundary_flag)
vvmap_on_boundary = jax.vmap(jax.vmap(on_boundary))
boundary_flags = vvmap_on_boundary(cell_face_points, cell_face_inds)
boundary_inds = onp.argwhere(boundary_flags) # (num_selected_faces, 2)
boundary_inds_list.append(boundary_inds)
return boundary_inds_list
[docs]
def convert_from_dof_to_quad(self, sol):
"""Obtain quad values from nodal solution
Parameters
----------
sol : np.DeviceArray
(num_total_nodes, vec)
Returns
-------
u : np.DeviceArray
(num_cells, num_quads, vec)
"""
# (num_total_nodes, vec) -> (num_cells, num_nodes, vec)
cells_sol = sol[self.cells]
# (num_cells, 1, num_nodes, vec) * (1, num_quads, num_nodes, 1) -> (num_cells, num_quads, num_nodes, vec) -> (num_cells, num_quads, vec)
u = np.sum(cells_sol[:, None, :, :] * self.shape_vals[None, :, :, None], axis=2)
return u
[docs]
def convert_from_dof_to_face_quad(self, sol, boundary_inds):
"""Obtain surface solution from nodal solution
Parameters
----------
sol : np.DeviceArray
(num_total_nodes, vec)
boundary_inds : int
Returns
-------
u : np.DeviceArray
(num_selected_faces, num_face_quads, vec)
"""
cells_old_sol = sol[self.cells] # (num_cells, num_nodes, vec)
selected_cell_sols = cells_old_sol[boundary_inds[:, 0]] # (num_selected_faces, num_nodes, vec))
selected_face_shape_vals = self.face_shape_vals[boundary_inds[:, 1]] # (num_selected_faces, num_face_quads, num_nodes)
# (num_selected_faces, 1, num_nodes, vec) * (num_selected_faces, num_face_quads, num_nodes, 1)
# -> (num_selected_faces, num_face_quads, vec)
u = np.sum(selected_cell_sols[:, None, :, :] * selected_face_shape_vals[:, :, :, None], axis=2)
return u
[docs]
def sol_to_grad(self, sol):
"""Obtain solution gradient from nodal solution
Parameters
----------
sol : np.DeviceArray
(num_total_nodes, vec)
Returns
-------
u_grads : np.DeviceArray
(num_cells, num_quads, vec, dim)
"""
# (num_cells, 1, num_nodes, vec, 1) * (num_cells, num_quads, num_nodes, 1, dim) -> (num_cells, num_quads, num_nodes, vec, dim)
u_grads = np.take(sol, self.cells, axis=0)[:, None, :, :, None] * self.shape_grads[:, :, :, None, :]
u_grads = np.sum(u_grads, axis=2) # (num_cells, num_quads, vec, dim)
return u_grads
[docs]
def print_BC_info(self):
"""Print boundary condition information for debugging purposes.
TODO: Not working
"""
if hasattr(self, 'neumann_boundary_inds_list'):
print(f"\n\n### Neumann B.C. is specified")
for i in range(len(self.neumann_boundary_inds_list)):
print(f"\nNeumann Boundary part {i + 1} information:")
print(self.neumann_boundary_inds_list[i])
print(
f"Array.shape = (num_selected_faces, 2) = {self.neumann_boundary_inds_list[i].shape}"
)
print(f"Interpretation:")
print(
f" Array[i, 0] returns the global cell index of the ith selected face"
)
print(
f" Array[i, 1] returns the local face index of the ith selected face"
)
else:
print(f"\n\n### No Neumann B.C. found.")
if len(self.node_inds_list) != 0:
print(f"\n\n### Dirichlet B.C. is specified")
for i in range(len(self.node_inds_list)):
print(f"\nDirichlet Boundary part {i + 1} information:")
bc_array = onp.stack([
self.node_inds_list[i], self.vec_inds_list[i],
self.vals_list[i]
]).T
print(bc_array)
print(
f"Array.shape = (num_selected_dofs, 3) = {bc_array.shape}")
print(f"Interpretation:")
print(
f" Array[i, 0] returns the node index of the ith selected dof"
)
print(
f" Array[i, 1] returns the vec index of the ith selected dof"
)
print(
f" Array[i, 2] returns the value assigned to ith selected dof"
)
else:
print(f"\n\n### No Dirichlet B.C. found.")