Compute gradients#
Problem definition#
In this example, we focus on inverse problems that aim to infer unknown model parameters or loading conditions from observed deformation data, which often requires efficient and accurate gradient computation under large strains. We demonstrate how to compute the derivatives by automatic differentiation and validate the results with the finite difference method. The same hyperelastic body as in the hyperelasticity example is considered, i.e., a unit cube with a neo-Hookean solid model. In addition, we have the following definitions:
\(\Omega = (0,1) \times (0,1) \times (0,1)\) (a unit cube)
\(\mathbf{b} = [0, 0, 0]\)
\(\Gamma_D = (0,1) \times (0,1) \times \{0\}\)
\(\mathbf{u}_D = [0, 0, \beta]\)
\(\Gamma_{N_1} = (0,1) \times (0,1) \times \{1\}\)
\(\mathbf{t}_{N_1} = [0, 0, -1000]\)
\(\Gamma_{N_2} = \partial \Omega \setminus (\Gamma_D \cup \Gamma_{N_1})\)
\(\mathbf{t}_{N_2} = [0, 0, 0]\)
To solve the inverse problem, we formulate an objective function that measures the discrepancy between the computed displacement and the target one, and compute its derivatives with respect to the model parameters. The objective function is defined as:
\[J(\boldsymbol{u}(\boldsymbol{\alpha}), \boldsymbol{\alpha}) = \sum_{i=1}^{N_d} (u[i])^2\]
where \(N_d\) is the total number of degrees of freedom and \(u[i]\) is the \(i\)-th component of the displacement vector \(\boldsymbol{u}\), which is obtained by solving the following discretized governing PDE:
\[\boldsymbol{C}(\boldsymbol{u}, \boldsymbol{\alpha}) = \boldsymbol{0}\]
where \(\boldsymbol{\alpha}\) is the parameter vector. Here, we set up three parameters: \(\boldsymbol{\alpha}_1 = E\), the elasticity modulus; \(\boldsymbol{\alpha}_2 = \rho\), the material density; and \(\boldsymbol{\alpha}_3 = \beta\), the scale factor of the Dirichlet boundary conditions. The displacement \(\boldsymbol{u}(\boldsymbol{\alpha})\) is therefore an implicit function of the parameter vector \(\boldsymbol{\alpha}\).
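Because \(\boldsymbol{u}\) depends on \(\boldsymbol{\alpha}\) only implicitly, the total derivative of the objective is usually obtained with the adjoint method. The following is a sketch of the standard derivation, assuming the discretized PDE is written in residual form \(\boldsymbol{C}(\boldsymbol{u}, \boldsymbol{\alpha}) = \boldsymbol{0}\) and introducing an adjoint vector \(\boldsymbol{\lambda}\) (both symbols are notation chosen for this sketch):

```latex
\begin{aligned}
\frac{\mathrm{d}J}{\mathrm{d}\boldsymbol{\alpha}}
  &= \frac{\partial J}{\partial \boldsymbol{\alpha}}
   + \frac{\partial J}{\partial \boldsymbol{u}}
     \frac{\mathrm{d}\boldsymbol{u}}{\mathrm{d}\boldsymbol{\alpha}}
  && \text{(chain rule)} \\
\boldsymbol{0}
  &= \frac{\partial \boldsymbol{C}}{\partial \boldsymbol{u}}
     \frac{\mathrm{d}\boldsymbol{u}}{\mathrm{d}\boldsymbol{\alpha}}
   + \frac{\partial \boldsymbol{C}}{\partial \boldsymbol{\alpha}}
  && \text{(differentiate the constraint)} \\
\left(\frac{\partial \boldsymbol{C}}{\partial \boldsymbol{u}}\right)^{T}
  \boldsymbol{\lambda}
  &= \left(\frac{\partial J}{\partial \boldsymbol{u}}\right)^{T}
  && \text{(adjoint equation)} \\
\frac{\mathrm{d}J}{\mathrm{d}\boldsymbol{\alpha}}
  &= \frac{\partial J}{\partial \boldsymbol{\alpha}}
   - \boldsymbol{\lambda}^{T}
     \frac{\partial \boldsymbol{C}}{\partial \boldsymbol{\alpha}}
  && \text{(sensitivity)}
\end{aligned}
```

The practical appeal is that a single linear solve with the transposed tangent matrix yields the sensitivities with respect to all parameters at once, regardless of how many entries \(\boldsymbol{\alpha}\) has.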
Implementation#
First, we need to import some useful modules and JAX-FEM specific modules:
[ ]:
# Import some useful modules.
import numpy as onp
import jax
import jax.numpy as np
import os
import glob
import matplotlib.pyplot as plt
# Import JAX-FEM specific modules.
from jax_fem.problem import Problem
from jax_fem.solver import solver, ad_wrapper
from jax_fem.utils import save_sol
from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, box_mesh_gmsh
Weak form#
Similarly, we use the Laplace Kernel to implement the hyperelastic constitutive relation by overriding the get_tensor_map method, and the Surface Kernel to implement the boundary conditions. The get_surface_maps method defines the boundary loading through the surface mapping function surface_map, whose returned traction vector is integrated over the boundary faces by the Surface Kernel.
For inverse problems, the set_params(params) method provides the interface for updating the model parameters dynamically. It splits the parameters into the material properties (\(E\), \(\rho\)) and the Dirichlet boundary condition scale (\(\beta\)), stores the material density as an internal variable via self.internal_vars = [rho], and modifies the Dirichlet boundary conditions by updating self.fe.dirichlet_bc_info.
[ ]:
class HyperElasticity(Problem):
    def custom_init(self):
        self.fe = self.fes[0]

    def get_tensor_map(self):
        # Neo-Hookean strain energy density; rho scales the modulus.
        def psi(F, rho):
            E = self.E * rho
            nu = 0.3
            mu = E/(2.*(1. + nu))
            kappa = E/(3.*(1. - 2.*nu))
            J = np.linalg.det(F)
            Jinv = J**(-2./3.)
            I1 = np.trace(F.T @ F)
            energy = (mu/2.)*(Jinv*I1 - 3.) + (kappa/2.) * (J - 1.)**2.
            return energy

        # First Piola-Kirchhoff stress: derivative of psi w.r.t. F.
        P_fn = jax.grad(psi)

        def first_PK_stress(u_grad, rho):
            I = np.eye(self.dim)
            F = u_grad + I
            P = P_fn(F, rho)
            return P

        return first_PK_stress

    def get_surface_maps(self):
        def surface_map(u, x):
            return np.array([0., 0., 1e3])

        return [surface_map]

    def set_params(self, params):
        E, rho, scale_d = params
        self.E = E
        self.internal_vars = [rho]
        self.fe.dirichlet_bc_info[-1][-1] = get_dirichlet_bottom(scale_d)
        self.fe.update_Dirichlet_boundary_conditions(self.fe.dirichlet_bc_info)
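The stress map above relies on jax.grad to differentiate the strain energy with respect to the deformation gradient. The following standalone sketch isolates that step. It assumes fixed values of E and nu (the class instead scales self.E by the density) and uses a hypothetical 10% uniaxial stretch purely for illustration:

```python
import jax
import jax.numpy as np

def psi(F, E=1.e6, nu=0.3):
    # Same neo-Hookean energy as in the class, with E and nu fixed here.
    mu = E / (2. * (1. + nu))
    kappa = E / (3. * (1. - 2. * nu))
    J = np.linalg.det(F)
    Jinv = J**(-2. / 3.)
    I1 = np.trace(F.T @ F)
    return (mu / 2.) * (Jinv * I1 - 3.) + (kappa / 2.) * (J - 1.)**2.

# First Piola-Kirchhoff stress as the gradient of psi w.r.t. F.
P_fn = jax.grad(psi)

# Undeformed state: the stress should vanish up to round-off.
P_ref = P_fn(np.eye(3))

# Hypothetical 10% stretch along z: expect a tensile zz component.
F_stretch = np.diag(np.array([1., 1., 1.1]))
P_stretch = P_fn(F_stretch)

print("max |P| at F = I:", np.max(np.abs(P_ref)))
print("P_zz at 10% stretch:", P_stretch[2, 2])
```

Checking that the reference configuration is stress-free is a cheap way to catch mistakes in a hand-written energy before running a full solve.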
Mesh#
[ ]:
# Specify mesh-related information (first-order hexahedron element).
ele_type = 'HEX8'
cell_type = get_meshio_cell_type(ele_type)
data_dir = os.path.join(os.path.dirname(__file__), 'data')
Lx, Ly, Lz = 1., 1., 1.
meshio_mesh = box_mesh_gmsh(Nx=5, Ny=5, Nz=5, Lx=Lx, Ly=Ly, Lz=Lz, data_dir=data_dir, ele_type=ele_type)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])
Boundary conditions#
A Dirichlet boundary condition is applied on the bottom surface (\(z = 0\)):
The displacement in the \(x\) and \(y\) directions is fixed to zero.
The displacement in the \(z\) direction is set to \(\beta L_z\), where \(\beta\) is a scaling parameter.
A Neumann boundary condition (surface traction) is applied on the top surface (\(z = L_z\)).
[ ]:
# Define Dirichlet boundary values.
def get_dirichlet_bottom(scale):
    def dirichlet_bottom(point):
        z_disp = scale*Lz
        return z_disp
    return dirichlet_bottom

def zero_dirichlet_val(point):
    return 0.

# Define boundary locations.
def bottom(point):
    return np.isclose(point[2], 0., atol=1e-5)

def top(point):
    return np.isclose(point[2], Lz, atol=1e-5)
dirichlet_bc_info = [[bottom]*3, [0, 1, 2], [zero_dirichlet_val]*2 + [get_dirichlet_bottom(1.)]]
location_fns = [top]
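To see what the three lists in dirichlet_bc_info encode, the sketch below evaluates them at a single point. It assumes the lists are paired entry by entry (location function, displacement component, value function), as the construction above suggests. The names location_fns_d, components, value_fns and the sample points are hypothetical, and the loop is only an illustration, not how JAX-FEM applies the conditions internally:

```python
import jax.numpy as np

Lz = 1.

def get_dirichlet_bottom(scale):
    def dirichlet_bottom(point):
        return scale * Lz
    return dirichlet_bottom

def zero_dirichlet_val(point):
    return 0.

def bottom(point):
    return np.isclose(point[2], 0., atol=1e-5)

# Same layout as dirichlet_bc_info, with a scale of 0.1 for illustration.
location_fns_d = [bottom] * 3
components = [0, 1, 2]
value_fns = [zero_dirichlet_val] * 2 + [get_dirichlet_bottom(0.1)]

def prescribed_values(point):
    # Collect {component: value} for every entry whose location matches.
    return {
        comp: val_fn(point)
        for loc_fn, comp, val_fn in zip(location_fns_d, components, value_fns)
        if bool(loc_fn(point))
    }

on_bottom = prescribed_values(np.array([0.3, 0.7, 0.]))
interior = prescribed_values(np.array([0.3, 0.7, 0.5]))
print(on_bottom)
print(interior)
```

Only the last value function depends on the scale, which is why set_params replaces dirichlet_bc_info[-1][-1] and leaves the other entries untouched.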
Problem#
We can then define the problem in JAX-FEM:
[ ]:
# Create an instance of the problem.
problem = HyperElasticity(mesh, vec=3, dim=3, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns)
Next, we define the parameter values:
[ ]:
rho = 0.5*np.ones((problem.fe.num_cells, problem.fe.num_quads))
E = 1.e6
scale_d = 1.
params = [E, rho, scale_d]
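The parameter list mixes two scalars with an array that holds one value per quadrature point. jax.grad treats such a list as a pytree and returns gradients with the same structure. A minimal sketch with a hypothetical toy objective (toy_objective and the array shape are made up for illustration):

```python
import jax
import jax.numpy as np

def toy_objective(params):
    # Stand-in for the real objective: any scalar function of the list.
    E, rho, scale_d = params
    return E * np.sum(rho**2) + scale_d**3

params = [2.0, 0.5 * np.ones((4, 8)), 1.0]

# The gradient has the same list structure and shapes as params.
dE, drho, dscale_d = jax.grad(toy_objective)(params)

print(dE, drho.shape, dscale_d)
```

This is why a single gradient call can later be unpacked into one sensitivity per parameter, with the density sensitivity keeping the (cells, quadrature points) layout.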
Solver#
In jax_fem, users can compute the derivatives of the objective function with respect to these parameters through automatic differentiation. We first wrap the forward problem with the function jax_fem.solver.ad_wrapper, which defines the implicit differentiation through @jax.custom_vjp. The wrapper defines custom forward and backward passes: the forward pass f_fwd calls the nonlinear solver to obtain the displacement field, while the backward pass f_bwd computes the parameter sensitivities with the adjoint method through implicit_vjp. This enables efficient gradient computation for inverse problems using standard jax operations like jax.grad on the composed objective function \(J\).
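The same pattern can be reproduced on a toy problem. The sketch below is not the JAX-FEM implementation. It assumes a hypothetical algebraic "PDE" u**3 + u - a = 0 solved by Newton iterations, and wraps the solve in jax.custom_vjp so that the backward pass performs one adjoint solve instead of differentiating through the iterations. All function names are illustrative.

```python
import jax
import jax.numpy as np

def residual(u, a):
    # Toy residual C(u, a) = 0, applied componentwise.
    return u**3 + u - a

def newton_solve(a, num_iters=20):
    u = np.zeros_like(a)
    for _ in range(num_iters):
        dC_du = jax.jacfwd(residual, argnums=0)(u, a)
        u = u - np.linalg.solve(dC_du, residual(u, a))
    return u

@jax.custom_vjp
def solve(a):
    return newton_solve(a)

def solve_fwd(a):
    # Forward pass: run the nonlinear solver and save what the adjoint needs.
    u = newton_solve(a)
    return u, (u, a)

def solve_bwd(saved, v):
    # Backward pass: solve (dC/du)^T lam = v, then return -(dC/da)^T lam.
    u, a = saved
    dC_du = jax.jacfwd(residual, argnums=0)(u, a)
    lam = np.linalg.solve(dC_du.T, v)
    _, vjp_a = jax.vjp(lambda a_: residual(u, a_), a)
    (grad_a,) = vjp_a(-lam)
    return (grad_a,)

solve.defvjp(solve_fwd, solve_bwd)

def objective(a):
    return np.sum(solve(a)**2)

a0 = np.array([2.0, 10.0])      # exact solutions are u = 1 and u = 2
grad_ad = jax.grad(objective)(a0)
print(grad_ad)                  # analytically 2u / (3u^2 + 1)
```

For this residual the gradient is known in closed form, 2u/(3u^2 + 1), which gives 1/2 and 4/13 at the two solutions and makes the adjoint formula easy to check by hand.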
[ ]:
# Implicit differentiation wrapper
fwd_pred = ad_wrapper(problem)
sol_list = fwd_pred(params)
def test_fn(sol_list):
    return np.sum(sol_list[0]**2)

def composed_fn(params):
    return test_fn(fwd_pred(params))
val = test_fn(sol_list)
We also use the finite difference method to validate the results.
[ ]:
h = 1e-3 # small perturbation
# Forward difference
E_plus = (1 + h)*E
params_E = [E_plus, rho, scale_d]
dE_fd = (composed_fn(params_E) - val)/(h*E)
rho_plus = rho.at[0, 0].set((1 + h)*rho[0, 0])
params_rho = [E, rho_plus, scale_d]
drho_fd_00 = (composed_fn(params_rho) - val)/(h*rho[0, 0])
scale_d_plus = (1 + h)*scale_d
params_scale_d = [E, rho, scale_d_plus]
dscale_d_fd = (composed_fn(params_scale_d) - val)/(h*scale_d)
# Derivative obtained by automatic differentiation
dE, drho, dscale_d = jax.grad(composed_fn)(params)
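A forward difference with relative step h carries a truncation error of order h, so a small mismatch with the AD value is expected and may account for part of the difference seen in such comparisons. The sketch below illustrates this on a hypothetical smooth function f (not the FEM objective) and contrasts it with a central difference, whose error is of order h squared. Double precision is enabled here to keep round-off out of the picture.

```python
import jax
import jax.numpy as np

# Enable double precision before any arrays are created.
jax.config.update("jax_enable_x64", True)

def f(x):
    # Hypothetical smooth test function.
    return np.sum(np.sin(x) * x**2)

def fd_grad(fn, x, i, h, central=False):
    # Relative perturbation of entry i, as in the forward differences above.
    step = h * x[i]
    x_plus = x.at[i].add(step)
    if central:
        x_minus = x.at[i].add(-step)
        return (fn(x_plus) - fn(x_minus)) / (2.0 * step)
    return (fn(x_plus) - fn(x)) / step

x0 = np.array([0.5, 1.0, 2.0])
g_ad = jax.grad(f)(x0)

h = 1e-3
err_fwd = abs(fd_grad(f, x0, 1, h) - g_ad[1])
err_ctr = abs(fd_grad(f, x0, 1, h, central=True) - g_ad[1])
print("forward-difference error:", err_fwd)
print("central-difference error:", err_ctr)
```

The central difference costs one extra forward solve per parameter, which matters when each evaluation is a nonlinear FEM solve, so the forward difference remains a reasonable choice for a quick sanity check.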
Postprocessing#
We then compare the computation results:
[ ]:
# Comparison
print(f"\nDerivative comparison between automatic differentiation (AD) and finite difference (FD)")
print(f"\ndrho[0, 0] = {drho[0, 0]}, drho_fd_00 = {drho_fd_00}")
print(f"\ndscale_d = {dscale_d}, dscale_d_fd = {dscale_d_fd}")
print(f"\ndE = {dE}, dE_fd = {dE_fd}, WRONG results! Please avoid gradients w.r.t self.E")
print(f"This is due to the use of global variable self.E, inside a jax jitted function.")
vtk_path = os.path.join(data_dir, f'vtk/u.vtu')
save_sol(problem.fe, sol_list[0], vtk_path)
which are shown as follows:
Derivative comparison between automatic differentiation (AD) and finite difference (FD)
dE = 4.0641751938577116e-07, dE_fd = 0.0, WRONG results! Please avoid gradients w.r.t self.E
drho[0, 0] = 0.002266954599447443, drho_fd_00 = 0.0022666187078357325
dscale_d = 431.59223609853564, dscale_d_fd = 431.80823609844765
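The zero finite-difference value for \(E\) is consistent with how jax.jit treats values captured from the enclosing scope: they are read once, when the function is first traced, and later updates are not seen by the cached trace. The toy class below (hypothetical, not part of JAX-FEM) reproduces this behavior and shows the usual remedy of passing the value as an explicit argument:

```python
import jax

class ToyModel:
    def __init__(self, E):
        self.E = E
        # self.E is read when the function is traced, not on every call.
        self.scaled = jax.jit(lambda x: self.E * x)

model = ToyModel(E=1.0)
before = model.scaled(2.0)   # traced with E = 1.0
model.E = 3.0
after = model.scaled(2.0)    # cached trace: the update to E is not seen

@jax.jit
def scaled_explicit(E, x):
    # E is a traced argument, so new values and gradients flow through.
    return E * x

fixed = scaled_explicit(3.0, 2.0)
dE = jax.grad(scaled_explicit, argnums=0)(3.0, 2.0)

print(before, after, fixed, dE)
```

In the example above, one way to follow this remedy would be to route \(E\) through the traced inputs, for instance by folding it into the internal variables alongside \(\rho\). This is a suggestion under that assumption, not something the source file does.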
Please refer to this link to download the source file.