Source code for jax_fem.problem

import numpy as onp
import jax
import jax.numpy as np
import jax.flatten_util
from dataclasses import dataclass
from typing import Any, Callable, Optional, List, Union
import functools

from jax_fem.utils import timeit 
from jax_fem.generate_mesh import Mesh
from jax_fem.fe import FiniteElement
from jax_fem import logger


[docs] @dataclass class Problem: mesh: Mesh vec: int dim: int ele_type: str = 'HEX8' gauss_order: int = None dirichlet_bc_info: Optional[List[Union[List[Callable], List[int], List[Callable]]]] = None location_fns: Optional[List[Callable]] = None additional_info: Any = () def __post_init__(self): if type(self.mesh) != type([]): self.mesh = [self.mesh] self.vec = [self.vec] self.ele_type = [self.ele_type] self.gauss_order = [self.gauss_order] self.dirichlet_bc_info = [self.dirichlet_bc_info] self.num_vars = len(self.mesh) self.fes = [FiniteElement(mesh=self.mesh[i], vec=self.vec[i], dim=self.dim, ele_type=self.ele_type[i], gauss_order=self.gauss_order[i] if type(self.gauss_order) == type([]) else self.gauss_order, dirichlet_bc_info=self.dirichlet_bc_info[i] if type(self.dirichlet_bc_info) == type([]) else self.dirichlet_bc_info) \ for i in range(self.num_vars)] self.cells_list = [fe.cells for fe in self.fes] # Assume all fes have the same number of cells, same dimension self.num_cells = self.fes[0].num_cells self.boundary_inds_list = self.fes[0].get_boundary_conditions_inds(self.location_fns) self.offset = [0] for i in range(len(self.fes) - 1): self.offset.append(self.offset[i] + self.fes[i].num_total_dofs) def find_ind(*x): inds = [] for i in range(len(x)): x[i].reshape(-1) crt_ind = self.fes[i].vec * x[i][:, None] + np.arange(self.fes[i].vec)[None, :] + self.offset[i] inds.append(crt_ind.reshape(-1)) return np.hstack(inds) # (num_cells, num_nodes*vec + ...) inds = onp.array(jax.vmap(find_ind)(*self.cells_list)) self.I = onp.repeat(inds[:, :, None], inds.shape[1], axis=2).reshape(-1) self.J = onp.repeat(inds[:, None, :], inds.shape[1], axis=1).reshape(-1) self.cells_list_face_list = [] for i, boundary_inds in enumerate(self.boundary_inds_list): cells_list_face = [cells[boundary_inds[:, 0]] for cells in self.cells_list] # [(num_selected_faces, num_nodes), ...] inds_face = onp.array(jax.vmap(find_ind)(*cells_list_face)) # (num_selected_faces, num_nodes*vec + ...) I_face = onp.repeat(inds_face[:, :, None], inds_face.shape[1], axis=2).reshape(-1) J_face = onp.repeat(inds_face[:, None, :], inds_face.shape[1], axis=1).reshape(-1) self.I = onp.hstack((self.I, I_face)) self.J = onp.hstack((self.J, J_face)) self.cells_list_face_list.append(cells_list_face) self.cells_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])(*self.cells_list) # (num_cells, num_nodes + ...) dumb_array_dof = [np.zeros((fe.num_nodes, fe.vec)) for fe in self.fes] # TODO: dumb_array_dof is useless? dumb_array_node = [np.zeros(fe.num_nodes) for fe in self.fes] # _, unflatten_fn_node = jax.flatten_util.ravel_pytree(dumb_array_node) _, self.unflatten_fn_dof = jax.flatten_util.ravel_pytree(dumb_array_dof) dumb_sol_list = [np.zeros((fe.num_total_nodes, fe.vec)) for fe in self.fes] dumb_dofs, self.unflatten_fn_sol_list = jax.flatten_util.ravel_pytree(dumb_sol_list) self.num_total_dofs_all_vars = len(dumb_dofs) self.num_nodes_cumsum = onp.cumsum([0] + [fe.num_nodes for fe in self.fes]) # (num_cells, num_vars, num_quads) self.JxW = onp.transpose(onp.stack([fe.JxW for fe in self.fes]), axes=(1, 0, 2)) # (num_cells, num_quads, num_nodes +..., dim) self.shape_grads = onp.concatenate([fe.shape_grads for fe in self.fes], axis=2) # (num_cells, num_quads, num_nodes + ..., 1, dim) self.v_grads_JxW = onp.concatenate([fe.v_grads_JxW for fe in self.fes], axis=2) # TODO: assert all vars quad points be the same # (num_cells, num_quads, dim) self.physical_quad_points = self.fes[0].get_physical_quad_points() self.selected_face_shape_grads = [] self.nanson_scale = [] self.selected_face_shape_vals = [] self.physical_surface_quad_points = [] for boundary_inds in self.boundary_inds_list: s_shape_grads = [] n_scale = [] s_shape_vals = [] for fe in self.fes: # (num_selected_faces, num_face_quads, num_nodes, dim), (num_selected_faces, num_face_quads) face_shape_grads_physical, nanson_scale = fe.get_face_shape_grads(boundary_inds) selected_face_shape_vals = fe.face_shape_vals[boundary_inds[:, 1]] # (num_selected_faces, num_face_quads, num_nodes) s_shape_grads.append(face_shape_grads_physical) n_scale.append(nanson_scale) s_shape_vals.append(selected_face_shape_vals) # (num_selected_faces, num_face_quads, num_nodes + ..., dim) s_shape_grads = onp.concatenate(s_shape_grads, axis=2) # (num_selected_faces, num_vars, num_face_quads) n_scale = onp.transpose(onp.stack(n_scale), axes=(1, 0, 2)) # (num_selected_faces, num_face_quads, num_nodes + ...) s_shape_vals = onp.concatenate(s_shape_vals, axis=2) # (num_selected_faces, num_face_quads, dim) physical_surface_quad_points = self.fes[0].get_physical_surface_quad_points(boundary_inds) self.selected_face_shape_grads.append(s_shape_grads) self.nanson_scale.append(n_scale) self.selected_face_shape_vals.append(s_shape_vals) # TODO: assert all vars face quad points be the same self.physical_surface_quad_points.append(physical_surface_quad_points) self.internal_vars = () self.internal_vars_surfaces = [() for _ in range(len(self.boundary_inds_list))] self.custom_init(*self.additional_info) self.pre_jit_fns()
[docs] def custom_init(self): """Child class should override if more things need to be done in initialization """ pass
[docs] def get_laplace_kernel(self, tensor_map): def laplace_kernel(cell_sol_flat, cell_shape_grads, cell_v_grads_JxW, *cell_internal_vars): # cell_sol_flat: (num_nodes*vec + ...,) # cell_sol_list: [(num_nodes, vec), ...] # cell_shape_grads: (num_quads, num_nodes + ..., dim) # cell_v_grads_JxW: (num_quads, num_nodes + ..., 1, dim) cell_sol_list = self.unflatten_fn_dof(cell_sol_flat) cell_shape_grads = cell_shape_grads[:, :self.fes[0].num_nodes, :] cell_sol = cell_sol_list[0] cell_v_grads_JxW = cell_v_grads_JxW[:, :self.fes[0].num_nodes, :, :] vec = self.fes[0].vec # (1, num_nodes, vec, 1) * (num_quads, num_nodes, 1, dim) -> (num_quads, num_nodes, vec, dim) u_grads = cell_sol[None, :, :, None] * cell_shape_grads[:, :, None, :] u_grads = np.sum(u_grads, axis=1) # (num_quads, vec, dim) u_grads_reshape = u_grads.reshape(-1, vec, self.dim) # (num_quads, vec, dim) # (num_quads, vec, dim) u_physics = jax.vmap(tensor_map)(u_grads_reshape, *cell_internal_vars).reshape(u_grads.shape) # (num_quads, num_nodes, vec, dim) -> (num_nodes, vec) val = np.sum(u_physics[:, None, :, :] * cell_v_grads_JxW, axis=(0, -1)) val = jax.flatten_util.ravel_pytree(val)[0] # (num_nodes*vec + ...,) return val return laplace_kernel
[docs] def get_mass_kernel(self, mass_map): def mass_kernel(cell_sol_flat, x, cell_JxW, *cell_internal_vars): # cell_sol_flat: (num_nodes*vec + ...,) # cell_sol_list: [(num_nodes, vec), ...] # x: (num_quads, dim) # cell_JxW: (num_vars, num_quads) cell_sol_list = self.unflatten_fn_dof(cell_sol_flat) cell_sol = cell_sol_list[0] cell_JxW = cell_JxW[0] vec = self.fes[0].vec # (1, num_nodes, vec) * (num_quads, num_nodes, 1) -> (num_quads, num_nodes, vec) -> (num_quads, vec) u = np.sum(cell_sol[None, :, :] * self.fes[0].shape_vals[:, :, None], axis=1) u_physics = jax.vmap(mass_map)(u, x, *cell_internal_vars) # (num_quads, vec) # (num_quads, 1, vec) * (num_quads, num_nodes, 1) * (num_quads, 1, 1) -> (num_nodes, vec) val = np.sum(u_physics[:, None, :] * self.fes[0].shape_vals[:, :, None] * cell_JxW[:, None, None], axis=0) val = jax.flatten_util.ravel_pytree(val)[0] # (num_nodes*vec + ...,) return val return mass_kernel
[docs] def get_surface_kernel(self, surface_map): def surface_kernel(cell_sol_flat, x, face_shape_vals, face_shape_grads, face_nanson_scale, *cell_internal_vars_surface): # face_shape_vals: (num_face_quads, num_nodes + ...) # face_shape_grads: (num_face_quads, num_nodes + ..., dim) # x: (num_face_quads, dim) # face_nanson_scale: (num_vars, num_face_quads) cell_sol_list = self.unflatten_fn_dof(cell_sol_flat) cell_sol = cell_sol_list[0] face_shape_vals = face_shape_vals[:, :self.fes[0].num_nodes] face_nanson_scale = face_nanson_scale[0] # (1, num_nodes, vec) * (num_face_quads, num_nodes, 1) -> (num_face_quads, vec) u = np.sum(cell_sol[None, :, :] * face_shape_vals[:, :, None], axis=1) u_physics = jax.vmap(surface_map)(u, x, *cell_internal_vars_surface) # (num_face_quads, vec) # (num_face_quads, 1, vec) * (num_face_quads, num_nodes, 1) * (num_face_quads, 1, 1) -> (num_nodes, vec) val = np.sum(u_physics[:, None, :] * face_shape_vals[:, :, None] * face_nanson_scale[:, None, None], axis=0) return jax.flatten_util.ravel_pytree(val)[0] return surface_kernel
[docs] def pre_jit_fns(self): def value_and_jacfwd(f, x): pushfwd = functools.partial(jax.jvp, f, (x, )) basis = np.eye(len(x.reshape(-1)), dtype=x.dtype).reshape(-1, *x.shape) y, jac = jax.vmap(pushfwd, out_axes=(None, -1))((basis, )) return y, jac def get_kernel_fn_cell(): def kernel(cell_sol_flat, physical_quad_points, cell_shape_grads, cell_JxW, cell_v_grads_JxW, *cell_internal_vars): """ universal_kernel should be able to cover all situations (including mass_kernel and laplace_kernel). mass_kernel and laplace_kernel are from legacy JAX-FEM. They can still be used, but not mandatory. """ # TODO: If there is no kernel map, returning 0. is not a good choice. # Return a zero array with proper shape will be better. if hasattr(self, 'get_mass_map'): mass_kernel = self.get_mass_kernel(self.get_mass_map()) mass_val = mass_kernel(cell_sol_flat, physical_quad_points, cell_JxW, *cell_internal_vars) else: mass_val = 0. if hasattr(self, 'get_tensor_map'): laplace_kernel = self.get_laplace_kernel(self.get_tensor_map()) laplace_val = laplace_kernel(cell_sol_flat, cell_shape_grads, cell_v_grads_JxW, *cell_internal_vars) else: laplace_val = 0. if hasattr(self, 'get_universal_kernel'): universal_kernel = self.get_universal_kernel() universal_val = universal_kernel(cell_sol_flat, physical_quad_points, cell_shape_grads, cell_JxW, cell_v_grads_JxW, *cell_internal_vars) else: universal_val = 0. return laplace_val + mass_val + universal_val def kernel_jac(cell_sol_flat, *args): kernel_partial = lambda cell_sol_flat: kernel(cell_sol_flat, *args) return value_and_jacfwd(kernel_partial, cell_sol_flat) # kernel(cell_sol_flat, *args), jax.jacfwd(kernel)(cell_sol_flat, *args) return kernel, kernel_jac def get_kernel_fn_face(ind): def kernel(cell_sol_flat, physical_surface_quad_points, face_shape_vals, face_shape_grads, face_nanson_scale, *cell_internal_vars_surface): """ universal_kernel should be able to cover all situations (including surface_kernel). surface_kernel is from legacy JAX-FEM. It can still be used, but not mandatory. """ if hasattr(self, 'get_surface_maps'): surface_kernel = self.get_surface_kernel(self.get_surface_maps()[ind]) surface_val = surface_kernel(cell_sol_flat, physical_surface_quad_points, face_shape_vals, face_shape_grads, face_nanson_scale, *cell_internal_vars_surface) else: surface_val = 0. if hasattr(self, 'get_universal_kernels_surface'): universal_kernel = self.get_universal_kernels_surface()[ind] universal_val = universal_kernel(cell_sol_flat, physical_surface_quad_points, face_shape_vals, face_shape_grads, face_nanson_scale, *cell_internal_vars_surface) else: universal_val = 0. return surface_val + universal_val def kernel_jac(cell_sol_flat, *args): # return jax.jacfwd(kernel)(cell_sol_flat, *args) kernel_partial = lambda cell_sol_flat: kernel(cell_sol_flat, *args) return value_and_jacfwd(kernel_partial, cell_sol_flat) # kernel(cell_sol_flat, *args), jax.jacfwd(kernel)(cell_sol_flat, *args) return kernel, kernel_jac kernel, kernel_jac = get_kernel_fn_cell() kernel = jax.jit(jax.vmap(kernel)) kernel_jac = jax.jit(jax.vmap(kernel_jac)) self.kernel = kernel self.kernel_jac = kernel_jac num_surfaces = len(self.boundary_inds_list) if hasattr(self, 'get_surface_maps'): assert num_surfaces == len(self.get_surface_maps()) elif hasattr(self, 'get_universal_kernels_surface'): assert num_surfaces == len(self.get_universal_kernels_surface()) else: assert num_surfaces == 0, "Missing definitions for surface integral" self.kernel_face = [] self.kernel_jac_face = [] for i in range(len(self.boundary_inds_list)): kernel_face, kernel_jac_face = get_kernel_fn_face(i) kernel_face = jax.jit(jax.vmap(kernel_face)) kernel_jac_face = jax.jit(jax.vmap(kernel_jac_face)) self.kernel_face.append(kernel_face) self.kernel_jac_face.append(kernel_jac_face)
[docs] @timeit def split_and_compute_cell(self, cells_sol_flat, np_version, jac_flag, internal_vars): """Volume integral in weak form """ vmap_fn = self.kernel_jac if jac_flag else self.kernel num_cuts = 20 if num_cuts > self.num_cells: num_cuts = self.num_cells batch_size = self.num_cells // num_cuts input_collection = [cells_sol_flat, self.physical_quad_points, self.shape_grads, self.JxW, self.v_grads_JxW, *internal_vars] if jac_flag: values = [] jacs = [] for i in range(num_cuts): if i < num_cuts - 1: input_col = jax.tree_map(lambda x: x[i * batch_size:(i + 1) * batch_size], input_collection) else: input_col = jax.tree_map(lambda x: x[i * batch_size:], input_collection) val, jac = vmap_fn(*input_col) values.append(val) jacs.append(jac) values = np_version.vstack(values) jacs = np_version.vstack(jacs) return values, jacs else: values = [] for i in range(num_cuts): if i < num_cuts - 1: input_col = jax.tree_map(lambda x: x[i * batch_size:(i + 1) * batch_size], input_collection) else: input_col = jax.tree_map(lambda x: x[i * batch_size:], input_collection) val = vmap_fn(*input_col) values.append(val) values = np_version.vstack(values) return values
[docs] def compute_face(self, cells_sol_flat, np_version, jac_flag, internal_vars_surfaces): """Surface integral in weak form """ if jac_flag: values = [] jacs = [] for i, boundary_inds in enumerate(self.boundary_inds_list): vmap_fn = self.kernel_jac_face[i] selected_cell_sols_flat = cells_sol_flat[boundary_inds[:, 0]] # (num_selected_faces, num_nodes*vec + ...)) input_collection = [selected_cell_sols_flat, self.physical_surface_quad_points[i], self.selected_face_shape_vals[i], self.selected_face_shape_grads[i], self.nanson_scale[i], *internal_vars_surfaces[i]] val, jac = vmap_fn(*input_collection) values.append(val) jacs.append(jac) return values, jacs else: values = [] for i, boundary_inds in enumerate(self.boundary_inds_list): vmap_fn = self.kernel_face[i] selected_cell_sols_flat = cells_sol_flat[boundary_inds[:, 0]] # (num_selected_faces, num_nodes*vec + ...)) # TODO: duplicated code input_collection = [selected_cell_sols_flat, self.physical_surface_quad_points[i], self.selected_face_shape_vals[i], self.selected_face_shape_grads[i], self.nanson_scale[i], *internal_vars_surfaces[i]] val = vmap_fn(*input_collection) values.append(val) return values
[docs] def compute_residual_vars_helper(self, weak_form_flat, weak_form_face_flat): res_list = [np.zeros((fe.num_total_nodes, fe.vec)) for fe in self.fes] weak_form_list = jax.vmap(lambda x: self.unflatten_fn_dof(x))(weak_form_flat) # [(num_cells, num_nodes, vec), ...] res_list = [res_list[i].at[self.cells_list[i].reshape(-1)].add(weak_form_list[i].reshape(-1, self.fes[i].vec)) for i in range(self.num_vars)] for ind, cells_list_face in enumerate(self.cells_list_face_list): weak_form_face_list = jax.vmap(lambda x: self.unflatten_fn_dof(x))(weak_form_face_flat[ind]) # [(num_selected_faces, num_nodes, vec), ...] res_list = [res_list[i].at[cells_list_face[i].reshape(-1)].add(weak_form_face_list[i].reshape(-1, self.fes[i].vec)) for i in range(self.num_vars)] return res_list
[docs] def compute_residual_vars(self, sol_list, internal_vars, internal_vars_surfaces): logger.debug(f"Computing cell residual...") cells_sol_list = [sol[cells] for cells, sol in zip(self.cells_list, sol_list)] # [(num_cells, num_nodes, vec), ...] cells_sol_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])(*cells_sol_list) # (num_cells, num_nodes*vec + ...) weak_form_flat = self.split_and_compute_cell(cells_sol_flat, np, False, internal_vars) # (num_cells, num_nodes*vec + ...) weak_form_face_flat = self.compute_face(cells_sol_flat, np, False, internal_vars_surfaces) # [(num_selected_faces, num_nodes*vec + ...), ...] return self.compute_residual_vars_helper(weak_form_flat, weak_form_face_flat)
[docs] def compute_newton_vars(self, sol_list, internal_vars, internal_vars_surfaces): logger.debug(f"Computing cell Jacobian and cell residual...") cells_sol_list = [sol[cells] for cells, sol in zip(self.cells_list, sol_list)] # [(num_cells, num_nodes, vec), ...] cells_sol_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])(*cells_sol_list) # (num_cells, num_nodes*vec + ...) # (num_cells, num_nodes*vec + ...), (num_cells, num_nodes*vec + ..., num_nodes*vec + ...) weak_form_flat, cells_jac_flat = self.split_and_compute_cell(cells_sol_flat, onp, True, internal_vars) self.V = onp.array(cells_jac_flat.reshape(-1)) # [(num_selected_faces, num_nodes*vec + ...,), ...], [(num_selected_faces, num_nodes*vec + ..., num_nodes*vec + ...,), ...] weak_form_face_flat, cells_jac_face_flat = self.compute_face(cells_sol_flat, onp, True, internal_vars_surfaces) for cells_jac_f_flat in cells_jac_face_flat: self.V = onp.hstack((self.V, onp.array(cells_jac_f_flat.reshape(-1)))) return self.compute_residual_vars_helper(weak_form_flat, weak_form_face_flat)
[docs] def compute_residual(self, sol_list): return self.compute_residual_vars(sol_list, self.internal_vars, self.internal_vars_surfaces)
[docs] def newton_update(self, sol_list): return self.compute_newton_vars(sol_list, self.internal_vars, self.internal_vars_surfaces)
[docs] def set_params(self, params): """Used for solving inverse problems. """ raise NotImplementedError("Child class must implement this function!")