jax_fem.problem module
- class jax_fem.problem.Problem(mesh, vec, dim, ele_type='HEX8', gauss_order=None, dirichlet_bc_info=None, location_fns=None, additional_info=())
Problem class to handle one FE variable or multiple coupled FE variables.
- Parameters:
mesh (Mesh)
vec (int)
dim (int)
ele_type (str)
gauss_order (int)
dirichlet_bc_info (list)
location_fns (list)
additional_info (tuple)
- gauss_order
- Type:
int
- dirichlet_bc_info
- Type:
list
- location_fns
A list of location functions used for surface integrals in the weak form. Such a surface integral may come from a Neumann boundary condition, or it may be an integral that contributes to the stiffness matrix. Each callable takes a point (NumpyArray) and returns a boolean indicating whether the point satisfies the location condition. For example,
[lambda point: np.isclose(point[0], 0., atol=1e-5)]
- Type:
list
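To show what a location function returns, the sketch below evaluates two hypothetical callables, left and right (the edges x = 0 and x = 1 of a unit square), on a few made-up node coordinates. It uses NumPy in place of JAX and only demonstrates the boolean test itself. How the library applies these callables to select boundary faces is not reproduced here.

```python
import numpy as np

# Hypothetical node coordinates of a 2D unit square, one row per node.
points = np.array([
    [0.0, 0.0],
    [0.5, 0.0],
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
])

def left(point):
    # Edge x = 0, using the same tolerance as the example above.
    return np.isclose(point[0], 0.0, atol=1e-5)

def right(point):
    # Edge x = 1.
    return np.isclose(point[0], 1.0, atol=1e-5)

location_fns = [left, right]

# Evaluate each location function on every point: one boolean mask per function.
masks = [np.array([fn(p) for p in points]) for fn in location_fns]
```

Each mask marks the points that one location function accepts. The atol tolerance matters because mesh coordinates produced by floating-point arithmetic rarely equal 0.0 or 1.0 exactly.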
- additional_info
Any other information that might be useful can be stored here. This is problem dependent.
- Type:
tuple
- compute_residual(sol_list)
Given the FE solution list, compute the residual list.
- Parameters:
sol_list (list) – A list of JaxArrays, each with shape (num_total_nodes, vec).
- Returns:
res_list – Same shape as sol_list.
- Return type:
list
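The list-in, list-out contract can be illustrated with a toy stand-in. The function toy_compute_residual below is hypothetical and is not the library's implementation. It computes r_i = K_i u_i - f_i for each variable independently, with made-up matrices. The point is only that every residual array comes back with the same shape as the corresponding solution array, even when the variables have different numbers of nodes and components.

```python
import numpy as np

def toy_compute_residual(sol_list, K_list, f_list):
    """Hypothetical stand-in: r_i = K_i @ u_i - f_i for each variable.

    Each u_i has shape (num_total_nodes_i, vec_i). It is flattened for the
    matrix product, and the result is reshaped back so res_list mirrors sol_list.
    """
    res_list = []
    for u, K, f in zip(sol_list, K_list, f_list):
        r = K @ u.reshape(-1) - f
        res_list.append(r.reshape(u.shape))
    return res_list

# One scalar field (vec=1) on 4 nodes and one vector field (vec=2) on 3 nodes.
sol_list = [np.zeros((4, 1)), np.ones((3, 2))]
K_list = [np.eye(4), 2.0 * np.eye(6)]
f_list = [np.ones(4), np.ones(6)]
res_list = toy_compute_residual(sol_list, K_list, f_list)
```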
- newton_update(sol_list)
Given the FE solution list, compute the tangent stiffness matrix as well as the residual list.
- Parameters:
sol_list (list) – A list of JaxArrays, each with shape (num_total_nodes, vec).
- Returns:
res_list – Same shape as sol_list. The tangent stiffness matrix is not returned; it is stored internally as instance variables.
- Return type:
list
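A Newton solver needs both quantities that newton_update produces: the residual r and the tangent matrix K. With these it solves K du = -r and updates u. The sketch below runs this loop on a hypothetical two-unknown nonlinear system using NumPy. The residual and tangent functions are made-up stand-ins and are not the library's solver.

```python
import numpy as np

def residual(u):
    # Hypothetical nonlinear system: r(u) = u**3 + u - b, componentwise.
    b = np.array([2.0, 10.0])
    return u**3 + u - b

def tangent(u):
    # Jacobian dr/du of the residual above (diagonal for this toy system).
    return np.diag(3.0 * u**2 + 1.0)

def newton_solve(u0, tol=1e-10, max_iter=50):
    u = u0.copy()
    for _ in range(max_iter):
        r = residual(u)  # residual at the current iterate
        if np.linalg.norm(r) < tol:
            break
        du = np.linalg.solve(tangent(u), -r)  # Newton step: K du = -r
        u = u + du
    return u

u = newton_solve(np.zeros(2))
```

The exact roots are u = 1 (since 1 + 1 = 2) and u = 2 (since 8 + 2 = 10), so the loop should converge to [1, 2].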
- set_params(params)
This is the key method for solving differentiable inverse problems. We MUST define (override) this method so that params become differentiable. There is no need to define this method if only the forward problem is solved.
For parameters defined on the element quadrature points, we may define
    def set_params(self, params):
        # Generally, [params1, params2, ...]
        self.internal_vars = [params]
For parameters defined on the element surface quadrature points, we may define
    def set_params(self, params):
        surface_params = params
        # Generally, [[surface1_params1, surface1_params2, ...], [surface2_params1, surface2_params2, ...], ...]
        self.internal_vars_surfaces = [[surface_params]]
Note that params itself can be flexible, but self.internal_vars must accept a prescribed input shape. The following example takes theta as params and converts it into a well-defined self.internal_vars:
    def set_params(self, theta):
        # theta is a scalar (float)
        thetas = theta * np.ones((self.fes[0].num_cells, self.fes[0].num_quads))
        # thetas must have the prescribed shape (num_cells, num_quads, ...)
        self.internal_vars = [thetas]
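The sketch below shows the shape conversion on its own, using NumPy and hypothetical mesh sizes (num_cells = 4, num_quads = 8). It covers two cases: a single scalar broadcast to every quadrature point, as in the example above, and, as an assumed extension of the statement that params can be flexible, one value per cell repeated over that cell's quadrature points. Either way, the array handed to internal_vars ends up with shape (num_cells, num_quads).

```python
import numpy as np

num_cells, num_quads = 4, 8  # hypothetical mesh sizes

# Case 1: a scalar theta broadcast to every quadrature point.
theta = 2.5
thetas = theta * np.ones((num_cells, num_quads))

# Case 2 (hypothetical): one value per cell, repeated over the cell's quadrature points.
theta_per_cell = np.array([1.0, 2.0, 3.0, 4.0])
thetas_cells = np.repeat(theta_per_cell[:, None], num_quads, axis=1)

internal_vars = [thetas_cells]  # the kind of list set_params would store
```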
Once set_params is defined, the following automatic differentiation wrapper must be applied to make fwd_pred a differentiable function:
    from jax_fem.solver import ad_wrapper

    fwd_pred = ad_wrapper(problem)
    sol_list = fwd_pred(params)
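This sketch does not reproduce how ad_wrapper is implemented. It shows the adjoint method, a standard way to differentiate an objective through a linear FE solve, on a hypothetical 2×2 system with K(theta) = theta * K0 and objective J = c · u. Because the adjoint gradient requires one extra linear solve with K^T, its cost does not grow with the number of parameters. As a check, the sketch compares it against a central finite difference.

```python
import numpy as np

K0 = np.array([[2.0, -1.0], [-1.0, 2.0]])  # hypothetical base stiffness
f = np.array([1.0, 0.0])                   # hypothetical load vector
c = np.array([0.0, 1.0])                   # objective J(u) = c @ u

def forward(theta):
    # Toy forward problem: K(theta) u = f with K(theta) = theta * K0.
    return np.linalg.solve(theta * K0, f)

def objective(theta):
    return c @ forward(theta)

def grad_adjoint(theta):
    u = forward(theta)
    K = theta * K0
    lam = np.linalg.solve(K.T, c)  # adjoint solve: K^T lam = dJ/du
    dK_dtheta = K0
    # dJ/dtheta = -lam^T (dK/dtheta) u, since f does not depend on theta.
    return -lam @ (dK_dtheta @ u)

theta = 1.5
eps = 1e-6
fd = (objective(theta + eps) - objective(theta - eps)) / (2 * eps)
```

For this system J(theta) = 1 / (3 theta), so both estimates should agree with the analytic derivative -1 / (3 theta^2).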
A similar argument holds for self.internal_vars_surfaces and get_surface_maps.
Notes
The definition of self.internal_vars is tied to get_tensor_map and to the other per-point maps such as get_mass_map. If set_params is defined as
    def set_params(self, params):
        # params has shape (num_cells, num_quads, shape1, shape2)
        self.internal_vars = [params]
then we must have
    class Elasticity(Problem):
        def get_tensor_map(self):
            def stress_fn(u_grad, param):
                # param MUST have shape (shape1, shape2)
                # ...
                return stress
            return stress_fn

        def get_mass_map(self):
            def mass_fn(u, x, param):
                # param MUST have shape (shape1, shape2)
                # ...
                return val  # value of the mass-type term at this quadrature point
            return mass_fn
- Parameters:
params (JaxPytree) – The parameters to be differentiated.
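The shape rule in the Notes means that each per-point function sees only one quadrature point's slice of the internal variable. The NumPy sketch below makes that slicing explicit with loops, using hypothetical sizes (3 cells, 4 quadrature points, a 2D vector field, so u_grad has shape (vec, dim) and param has shape (dim, dim)) and a made-up stress_fn that is not an elasticity law. The library is assumed to perform this mapping internally in vectorized form; the loops are only for illustration.

```python
import numpy as np

num_cells, num_quads = 3, 4
vec, dim = 2, 2  # hypothetical: 2D vector field

# Hypothetical internal variable with shape (num_cells, num_quads, shape1, shape2).
params = np.broadcast_to(np.eye(dim), (num_cells, num_quads, dim, dim)) * 2.0
# Hypothetical solution gradients, one (vec, dim) array per quadrature point.
u_grads = np.ones((num_cells, num_quads, vec, dim))

def stress_fn(u_grad, param):
    # Per-point function: param arrives with shape (shape1, shape2) only.
    assert param.shape == (dim, dim)
    return u_grad @ param  # hypothetical constitutive law

# Apply stress_fn point by point so the per-point slicing is visible.
stresses = np.empty((num_cells, num_quads, vec, dim))
for i in range(num_cells):
    for q in range(num_quads):
        stresses[i, q] = stress_fn(u_grads[i, q], params[i, q])
```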