jax_fem.problem module

class jax_fem.problem.Problem(mesh, vec, dim, ele_type='HEX8', gauss_order=None, dirichlet_bc_info=None, location_fns=None, additional_info=())

Problem class to handle one FE variable or multiple coupled FE variables.

Parameters:
  • mesh (Mesh)

  • vec (int)

  • dim (int)

  • ele_type (str)

  • gauss_order (int)

  • dirichlet_bc_info (list)

  • location_fns (list)

  • additional_info (tuple)

mesh

mesh

Type:

Mesh

vec

vec

Type:

int

dim

dim

Type:

int

ele_type

ele_type

Type:

str

gauss_order

gauss_order

Type:

int

dirichlet_bc_info

dirichlet_bc_info

Type:

list

location_fns

A list of location functions used for surface integrals in the weak form. Such a surface integral may arise from a Neumann boundary condition or may contribute to the stiffness matrix. Each callable takes a point (NumpyArray) and returns a boolean indicating whether the point satisfies the location condition. For example,

[lambda point: np.isclose(point[0], 0., atol=1e-5)]

Type:

list
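
As a minimal sketch of how such location functions behave, the snippet below evaluates two of them on a few sample points. The function names `left` and `right`, the unit-cube geometry, and the sample points are assumptions made for illustration; plain NumPy stands in for whatever array module `np` refers to in an actual problem script.

```python
import numpy as np

# Hypothetical location functions for a unit-cube domain (illustrative names).
def left(point):
    # True when the point lies on the face x = 0
    return np.isclose(point[0], 0.0, atol=1e-5)

def right(point):
    # True when the point lies on the face x = 1
    return np.isclose(point[0], 1.0, atol=1e-5)

location_fns = [left, right]

# Three sample points: one on x = 0, one on x = 1, one in the interior in x.
points = np.array([
    [0.0, 0.5, 0.5],
    [1.0, 0.2, 0.3],
    [0.4, 0.0, 1.0],
])

# One row of flags per location function, one entry per point.
flags = [[bool(fn(p)) for p in points] for fn in location_fns]
```

Each function selects only the points on its own face, so a list of such functions can describe several boundary surfaces independently.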

additional_info

Any other information that might be useful can be stored here. This is problem dependent.

Type:

tuple

custom_init()

Child classes should override this method if additional initialization is needed.

compute_residual(sol_list)

Given the FE solution list, compute the residual list.

Parameters:

sol_list (list) – A list of JaxArray, each of shape (num_total_nodes, vec).

Returns:

res_list – Same shape as sol_list.

Return type:

list

newton_update(sol_list)

Given the FE solution list, compute the tangent stiffness matrix as well as the residual list.

Parameters:

sol_list (list) – A list of JaxArray, each of shape (num_total_nodes, vec).

Returns:

res_list – Same shape as sol_list. The tangent stiffness matrix is stored internally as instance variables.

Return type:

list

set_params(params)

This is the key method for solving differentiable inverse problems. It MUST be defined (overridden) so that params become differentiable. It need not be defined if only the forward problem is solved.

For parameters defined on the element quadrature points, we may define

def set_params(self, params):
    # Generally, [params1, params2, ...]
    self.internal_vars = [params]

For parameters defined on the element surface quadrature points, we may define

def set_params(self, params):
    surface_params = params
    # Generally, [[surface1_params1, surface1_params2, ...], [surface2_params1, surface2_params2, ...], ...]
    self.internal_vars_surfaces = [[surface_params]]

Note that params itself can be flexible, but self.internal_vars must have a prescribed shape. The following example takes theta as params and converts it into a well-defined self.internal_vars.

def set_params(self, theta):
    # theta is a scalar (float)
    thetas = theta * np.ones((self.fes[0].num_cells, self.fes[0].num_quads))
    # thetas must have the prescribed shape (num_cells, num_quads, ...)
    self.internal_vars = [thetas]

Then, the following automatic differentiation wrapper must be applied to obtain fwd_pred, a differentiable function.

fwd_pred = ad_wrapper(problem)
sol_list = fwd_pred(params)

A similar argument holds for self.internal_vars_surfaces and get_surface_maps.
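
The scalar-to-array conversion described above can be checked in isolation. The sketch below uses two hypothetical stand-in classes, `FakeFE` and `FakeProblem`, that mimic only the attributes the example touches (`fes`, `num_cells`, `num_quads`, `internal_vars`); they are not part of the library. NumPy is used so the snippet runs on its own, which also means nothing here is differentiable.

```python
import numpy as np

class FakeFE:
    """Hypothetical stand-in exposing only the sizes set_params reads."""
    num_cells = 5
    num_quads = 8

class FakeProblem:
    """Hypothetical stand-in for a Problem subclass; not the library class."""
    def __init__(self):
        self.fes = [FakeFE()]
        self.internal_vars = []

    def set_params(self, theta):
        # theta is a scalar; broadcast it to one value per quadrature point
        thetas = theta * np.ones((self.fes[0].num_cells, self.fes[0].num_quads))
        # internal_vars holds one array per parameter field
        self.internal_vars = [thetas]

problem = FakeProblem()
problem.set_params(2.5)
stored = problem.internal_vars[0]  # array of shape (num_cells, num_quads)
```

Whatever form params takes on input, the array stored in `internal_vars` always leads with the `(num_cells, num_quads)` axes.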

Notes

The definition of self.internal_vars is tied to get_tensor_map. If

def set_params(self, params):
    # params has shape (num_cells, num_quads, shape1, shape2)
    self.internal_vars = [params]

Then, we must have

class Elasticity(Problem):
    def get_tensor_map(self):
        def stress_fn(u_grad, param):
            # param MUST have shape (shape1, shape2)
            # ...
            return stress
        return stress_fn

    def get_mass_map(self):
        def mass_fn(u, x, param):
            # param MUST have shape (shape1, shape2)
            # ...
            return mass
        return mass_fn
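
The shape contract above can be illustrated with a small sketch. It assumes the map is applied pointwise at every quadrature point of every cell, as the shape comments suggest; explicit Python loops stand in for that vectorized mapping, and the constitutive law `param @ u_grad` is a made-up placeholder rather than a real material model.

```python
import numpy as np

num_cells, num_quads, vec, dim = 2, 4, 3, 3

def stress_fn(u_grad, param):
    # At a single quadrature point, param has only the trailing shape
    # (shape1, shape2), here (3, 3).
    assert param.shape == (3, 3)
    return param @ u_grad  # hypothetical law, for illustration only

# Parameter field as set_params would store it: (num_cells, num_quads, 3, 3).
params = np.broadcast_to(np.eye(3), (num_cells, num_quads, 3, 3))
# Solution gradients at every quadrature point: (num_cells, num_quads, vec, dim).
u_grads = np.ones((num_cells, num_quads, vec, dim))

# Loop over cells and quadrature points, peeling off the two leading axes.
stresses = np.stack([
    np.stack([stress_fn(u_grads[c, q], params[c, q]) for q in range(num_quads)])
    for c in range(num_cells)
])
```

The two leading axes of the stored array are consumed by the mapping, so the function body only ever sees the trailing `(shape1, shape2)` part.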

Parameters:

params (JaxPytree) – The parameters to be differentiated.