import jax
import jax.numpy as np
import jax.flatten_util
import numpy as onp
from jax.experimental.sparse import BCOO
import scipy
import time
from petsc4py import PETSc
from jax_fem import logger
from jax import config
config.update("jax_enable_x64", True)
try:
import pyamgx
pyamgx.initialize()
PYAMGX_AVAILABLE = True
except ImportError:
PYAMGX_AVAILABLE = False
def _timing_record(timing, name, dt):
timing[name] += dt
def _log_newton_iter_start(iter_num):
    """Print a blank separator line, then log the Newton iteration number."""
    print()
    logger.info(" iter %d", iter_num)
def _log_newton_iter_summary(iter_num, local_s, global_s, res_val, rel_res_val, linear_s=None):
    """Log the residual norms and the timing breakdown of one Newton iteration.

    ``iter_num`` is accepted but not used. The linear-solve time is included
    in the timing line only when ``linear_s`` is given.
    """
    logger.info(
        " nonlinear residual: L2 norm = %.3g (relative to initial = %.3g)",
        res_val, rel_res_val)
    if linear_s is not None:
        logger.info(
            " timing: linear solve %6.3f s, local assembly %6.3f s, global matrix %6.3f s",
            linear_s, local_s, global_s)
        return
    logger.info(
        " timing: local assembly %6.3f s, global matrix %6.3f s",
        local_s, global_s)
def _log_timing_table(n_iters, parts, wall_s):
    """Log a per-phase timing table for a finished Newton solve.

    ``parts`` maps phase names to seconds and must contain
    ``'local_assembly'``, ``'global_matrix'`` and ``'linear'``. Wall time not
    covered by any entry of ``parts`` is reported as "other" when it is at
    least 0.01 s.
    """
    def _share(seconds):
        # Percentage of the wall time; guarded against a zero wall clock.
        return 100. * seconds / wall_s if wall_s > 0 else 0.

    labels = {
        'local_assembly': 'local',
        'global_matrix': 'global',
        'linear': 'linear',
    }
    print()
    logger.info("Timing summary — %d Newton iter, %.3f s wall", n_iters, wall_s)
    for key, label in labels.items():
        seconds = parts[key]
        logger.info(" %-8s %7.3f s %5.1f%%", label, seconds, _share(seconds))
    unaccounted = wall_s - sum(parts.values())
    if unaccounted >= 0.01:
        logger.info(" %-8s %7.3f s %5.1f%%", "other", unaccounted, _share(unaccounted))
################################################################################
# Linear solvers (JAX / SciPy / PETSc / AMGX)
def jax_solve(A, b, x0, precond):
    """Solve ``A x = b`` with JAX's BiCGSTAB.

    Parameters
    ----------
    A : PETSc.Mat
        System matrix; its CSR data is extracted and converted to ``BCOO``.
    b : array
        Right-hand side.
    x0 : array or None
        Initial guess forwarded to ``bicgstab``.
    precond : bool
        If truthy, use a Jacobi (inverse-diagonal) preconditioner; otherwise
        no preconditioner is passed.

    Returns
    -------
    x : array
        The computed solution.

    Raises
    ------
    AssertionError
        If the final residual norm ``||A x - b||`` is not below 0.1.
    """
    logger.debug("JAX Solver - Solving linear system")
    indptr, indices, data = A.getValuesCSR()
    A_sp_scipy = scipy.sparse.csr_array((data, indices, indptr), shape=A.getSize())
    A = BCOO.from_scipy_sparse(A_sp_scipy).sort_indices()
    jacobi = np.array(A_sp_scipy.diagonal())
    # The conditional must select between a callable and None. Writing it as
    # `lambda x: ... if precond else None` puts the conditional *inside* the
    # lambda, which hands bicgstab an operator that returns None.
    if precond:
        pc = lambda x: x * (1. / jacobi)
    else:
        pc = None
    if issubclass(PETSc.ScalarType, np.complexfloating):
        logger.debug("JAX Solver - Using PETSc with complex number support")
        A = A.astype(complex)
        b = b.astype(complex)
        if x0 is not None:
            x0 = x0.astype(complex)
    x, info = jax.scipy.sparse.linalg.bicgstab(A,
                                               b,
                                               x0=x0,
                                               M=pc,
                                               tol=1e-10,
                                               atol=1e-10,
                                               maxiter=10000)
    # Verify convergence
    err = np.linalg.norm(A @ x - b)
    logger.debug("JAX Solver - Finished solving, linear solve res = %.3g", err)
    assert err < 0.1, f"JAX linear solver failed to converge with err = {err}"
    x = np.where(err < 0.1, x, np.nan)  # For assert purpose, somehow this also affects bicgstab.
    return x
def scipy_spsolve(A, b):
    """Solve ``A x = b`` with the direct solver ``scipy.sparse.linalg.spsolve``.

    Parameters
    ----------
    A : PETSc.Mat
        System matrix; converted to a SciPy CSR matrix.
    b : array
        Right-hand side (copied to a NumPy array for SciPy).

    Returns
    -------
    x : numpy.ndarray
        The solution returned by ``spsolve``.
    """
    logger.debug("Scipy Solver - Solving linear system with scipy.sparse.linalg.spsolve")
    indptr, indices, data = A.getValuesCSR()
    # Pass the PETSc size explicitly (as jax_solve does) so the shape is not
    # inferred from the largest stored column index.
    Asp = scipy.sparse.csr_matrix((data, indices, indptr), shape=A.getSize())
    # SciPy's spsolve uses UMFPACK only when scikits.umfpack is installed and
    # applicable; otherwise it falls back to SuperLU.
    x = scipy.sparse.linalg.spsolve(Asp, onp.array(b))
    # TODO: try https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.linalg.spsolve.html
    # x = jax.experimental.sparse.linalg.spsolve(av, aj, ai, b)
    logger.debug("Scipy Solver - Finished solving, linear solve res = %.3g",
                 np.linalg.norm(Asp @ x - b))
    return x
def petsc_solve(A, b, ksp_type, pc_type):
    """Solve ``A x = b`` with a PETSc KSP solver.

    Parameters
    ----------
    A : PETSc.Mat
        System matrix.
    b : array
        Right-hand side.
    ksp_type : str
        KSP type passed to ``ksp.setType``.
    pc_type : str
        Preconditioner type passed to ``ksp.pc.setType``.

    Returns
    -------
    The array of the PETSc solution vector.

    Raises
    ------
    AssertionError
        If the residual norm ``||A x - b||`` is not below 0.1.
    """
    n = len(b)
    rhs = PETSc.Vec().createSeq(n)
    rhs.setValues(range(n), onp.array(b))

    ksp = PETSc.KSP().create()
    ksp.setOperators(A)
    ksp.setFromOptions()
    # Explicit types are applied after setFromOptions, as in the original order.
    ksp.setType(ksp_type)
    ksp.pc.setType(pc_type)
    # TODO: This works better. Do we need to generalize the code a little bit?
    if ksp_type == 'tfqmr':
        ksp.pc.setFactorSolverType('mumps')
    logger.debug(f'PETSc Solver - Solving linear system with ksp_type = {ksp.getType()}, pc = {ksp.pc.getType()}')

    sol = PETSc.Vec().createSeq(n)
    ksp.solve(rhs, sol)

    # Verify convergence by recomputing A @ sol.
    product = PETSc.Vec().createSeq(n)
    A.mult(sol, product)
    err = np.linalg.norm(product.getArray() - rhs.getArray())
    logger.debug("PETSc Solver - Finished solving, linear solve res = %.3g", err)
    assert err < 0.1, f"PETSc linear solver failed to converge, err = {err}"
    return sol.getArray()
def AMGX_solve_host(indptr, indices, data, shape_arr, x, b, cfg_path):
    """Solve ``A x = b`` on the host through pyamgx.

    Parameters
    ----------
    indptr, indices, data
        CSR arrays of the system matrix.
    shape_arr
        Two-entry array holding ``(n_rows, n_cols)``.
    x : array or None
        Initial guess; zeros are used when ``None``.
    b : array
        Right-hand side; the result is cast and reshaped to match it.
    cfg_path : str or None
        Path to an AMGX config file; when ``None`` a built-in BICGSTAB
        configuration is used.

    Returns
    -------
    numpy.ndarray
        Solution with the dtype and shape of ``b``.
    """
    out_dtype, out_shape = b.dtype, b.shape
    A_csr = scipy.sparse.csr_matrix(
        (data, indices, indptr),
        shape=(int(shape_arr[0]), int(shape_arr[1])),
    )
    b_host = onp.asarray(b)
    x_guess = onp.asarray(x) if x is not None else onp.zeros_like(b_host)

    # AMGX handles in creation order; the ``finally`` clause destroys them in
    # reverse order (vectors, matrix, solver, resources, config).
    handles = []
    try:
        ## See: https://github.com/NVIDIA/AMGX/tree/main/src/configs
        if cfg_path is None:
            cfg = pyamgx.Config().create_from_dict({
                "config_version": 2,
                "determinism_flag": 1,
                "exception_handling": 1,
                "solver": {
                    "solver": "BICGSTAB",  # "CG", BICGSTAB
                    # change to PBICGSTAB to use preconditioners
                    "use_scalar_norm": 1,
                    "norm": "L2",
                    "tolerance": 1e-10,
                    "monitor_residual": 1,
                    "max_iters": 10000,
                    "convergence": "ABSOLUTE",  # RELATIVE_INI_CORE
                    # "print_solve_stats": 1,
                    "preconditioner": {
                        "scope": "amg",
                        "solver": "AMG",
                        "algorithm": "CLASSICAL",
                        "smoother": "JACOBI",
                        "cycle": "V",
                        "max_levels": 10,
                        "max_iters": 2
                    }
                }
            })
        else:
            cfg = pyamgx.Config().create_from_file(cfg_path)
        handles.append(cfg)

        # Create resources
        resources = pyamgx.Resources().create_simple(cfg)
        handles.append(resources)
        amg_solver = pyamgx.Solver().create(resources, cfg)
        handles.append(amg_solver)
        A_amg = pyamgx.Matrix().create(resources)
        handles.append(A_amg)
        b_amg = pyamgx.Vector().create(resources)
        handles.append(b_amg)
        x_amg = pyamgx.Vector().create(resources)
        handles.append(x_amg)

        A_amg.upload_CSR(A_csr)
        b_amg.upload(b_host)
        x_amg.upload(x_guess)
        amg_solver.setup(A_amg)
        amg_solver.solve(b_amg, x_amg)

        result = onp.asarray(x_amg.download())
        res = onp.linalg.norm(A_csr @ result - b_host)
        logger.info("AMGX Solver - Finished solving, linear solve res = %.3g", res)
        return result.astype(out_dtype).reshape(out_shape)
    finally:
        for handle in reversed(handles):
            if handle is not None:
                handle.destroy()
        # pyamgx.finalize()
def AMGX_solve(A, b, x0, cfg_path):
    """Solve ``A x = b`` with AMGX via a ``jax.pure_callback``.

    Parameters
    ----------
    A : PETSc.Mat
        System matrix.
    b : array
        Right-hand side; fixes the shape and dtype of the result.
    x0 : array or None
        Initial guess; zeros like ``b`` when ``None``.
    cfg_path : str or None
        Optional AMGX config file forwarded to ``AMGX_solve_host``.

    Raises
    ------
    RuntimeError
        If pyamgx could not be imported.
    """
    if not PYAMGX_AVAILABLE:
        raise RuntimeError("pyamgx not installed. AMGX solver disabled.")

    # A is PETSc.Mat here.
    csr_indptr, csr_indices, csr_data = A.getValuesCSR()
    n_rows, n_cols = A.getSize()

    # Convert to numpy arrays directly to avoid JAX device memory copy overhead
    csr_indptr = onp.asarray(csr_indptr, dtype=onp.int32)
    csr_indices = onp.asarray(csr_indices, dtype=onp.int32)
    # Keep matrix data dtype consistent with b.
    csr_data = onp.asarray(csr_data, dtype=onp.asarray(b).dtype)
    shape_arr = onp.array([n_rows, n_cols], dtype=onp.int64)

    initial = np.zeros_like(b) if x0 is None else x0

    def amgx_solve_callback(x, b_in):
        return AMGX_solve_host(
            csr_indptr, csr_indices, csr_data, shape_arr, x, b_in, cfg_path)

    return jax.pure_callback(
        amgx_solve_callback,
        jax.ShapeDtypeStruct(b.shape, b.dtype),
        initial,
        b,
    )
def linear_solver(A, b, x0, linear_options):
    """Dispatch ``A x = b`` to the backend selected by ``linear_options``.

    Parameters
    ----------
    A : PETSc.Mat
        System matrix.
    b : array
        Right-hand side.
    x0 : array or None
        Initial guess (used by the JAX, AMGX and custom solvers).
    linear_options : dict
        Holds one of the keys ``'jax_solver'``, ``'amgx_solver'``,
        ``'spsolve_solver'``, ``'petsc_solver'`` or ``'custom_solver'``.
        The first four map to a dict of backend options; ``'custom_solver'``
        maps to a callable ``f(A, b, x0, linear_options)``. When none is
        present the JAX solver is used. The dict is not modified.

    Returns
    -------
    x : array
        Solution returned by the selected backend.
    """
    solver_keys = {'jax_solver', 'amgx_solver', 'spsolve_solver', 'petsc_solver', 'custom_solver'}
    # If user does not specify any solver, set jax_solver as the default one.
    # Work on a copy so the caller's options dict is left untouched.
    if not (linear_options.keys() & solver_keys):
        linear_options = {**linear_options, 'jax_solver': {}}

    if 'jax_solver' in linear_options:
        precond = linear_options['jax_solver'].get('precond', True)
        x = jax_solve(A, b, x0, precond)
    elif 'amgx_solver' in linear_options:
        cfg_path = linear_options['amgx_solver'].get('cfg_path', None)
        x = AMGX_solve(A, b, x0, cfg_path)
    elif 'spsolve_solver' in linear_options:
        x = scipy_spsolve(A, b)
    elif 'petsc_solver' in linear_options:
        ksp_type = linear_options['petsc_solver'].get('ksp_type', 'bcgsl')
        pc_type = linear_options['petsc_solver'].get('pc_type', 'ilu')
        x = petsc_solve(A, b, ksp_type, pc_type)
    elif 'custom_solver' in linear_options:
        custom_solver = linear_options['custom_solver']
        x = custom_solver(A, b, x0, linear_options)
    else:
        # Defensive: unreachable after the default injection above.
        raise NotImplementedError("Unknown linear solver.")
    return x
################################################################################
# Dirichlet boundary conditions ("row elimination")
def apply_bc_vec(res_vec, dofs, problem, scale=1.):
    """Overwrite the Dirichlet rows of a flat residual vector.

    At every constrained entry the residual becomes
    ``dofs_entry - scale * prescribed_value``; all other entries are left
    unchanged.

    Parameters
    ----------
    res_vec : array
        Flat residual vector.
    dofs : array
        Flat solution vector.
    problem
        Provides ``unflatten_fn_sol_list`` and ``fes``.
    scale : float
        Multiplier applied to the prescribed boundary values.

    Returns
    -------
    array
        The modified residual, flattened again.
    """
    residuals = problem.unflatten_fn_sol_list(res_vec)
    solutions = problem.unflatten_fn_sol_list(dofs)
    for var_idx, fe in enumerate(problem.fes):
        res_var = residuals[var_idx]
        sol_var = solutions[var_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            vec_inds = fe.vec_inds_list[bc_idx]
            # First copy the current solution values into the residual ...
            res_var = res_var.at[node_inds, vec_inds].set(
                sol_var[node_inds, vec_inds], unique_indices=True)
            # ... then subtract the (scaled) prescribed values.
            res_var = res_var.at[node_inds, vec_inds].add(-fe.vals_list[bc_idx]*scale)
        residuals[var_idx] = res_var
    return jax.flatten_util.ravel_pytree(residuals)[0]
def apply_bc(res_fn, problem, scale=1.):
    """Wrap a flat residual function so its output has Dirichlet rows applied.

    Returns a function of ``dofs`` that evaluates ``res_fn`` and passes the
    result through ``apply_bc_vec`` with the given ``scale``.
    """
    def res_fn_bc(dofs):
        """Apply Dirichlet boundary conditions
        """
        return apply_bc_vec(res_fn(dofs), dofs, problem, scale)

    return res_fn_bc
def assign_bc(dofs, problem):
    """Return a copy of ``dofs`` with Dirichlet entries set to their prescribed values.

    Entries not listed in any ``fe.node_inds_list`` / ``fe.vec_inds_list``
    pair are left unchanged. The result is a flat vector.
    """
    fields = problem.unflatten_fn_sol_list(dofs)
    for var_idx, fe in enumerate(problem.fes):
        field = fields[var_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            vec_inds = fe.vec_inds_list[bc_idx]
            field = field.at[node_inds, vec_inds].set(fe.vals_list[bc_idx])
        fields[var_idx] = field
    return jax.flatten_util.ravel_pytree(fields)[0]
def assign_ones_bc(dofs, problem):
    """Return a copy of ``dofs`` with every Dirichlet entry set to 1.

    Entries outside the Dirichlet index sets are left unchanged. The result
    is a flat vector.
    """
    fields = problem.unflatten_fn_sol_list(dofs)
    for var_idx, fe in enumerate(problem.fes):
        field = fields[var_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            vec_inds = fe.vec_inds_list[bc_idx]
            field = field.at[node_inds, vec_inds].set(1.)
        fields[var_idx] = field
    return jax.flatten_util.ravel_pytree(fields)[0]
def assign_zeros_bc(dofs, problem):
    """Return a copy of ``dofs`` with every Dirichlet entry set to 0.

    Entries outside the Dirichlet index sets are left unchanged. The result
    is a flat vector.
    """
    fields = problem.unflatten_fn_sol_list(dofs)
    for var_idx, fe in enumerate(problem.fes):
        field = fields[var_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            vec_inds = fe.vec_inds_list[bc_idx]
            field = field.at[node_inds, vec_inds].set(0.)
        fields[var_idx] = field
    return jax.flatten_util.ravel_pytree(fields)[0]
def copy_bc(dofs, problem):
    """Return a zero vector that carries only the Dirichlet entries of ``dofs``.

    The output has the shape of ``dofs``; entries at Dirichlet locations are
    copied from ``dofs`` and everything else is zero.
    """
    source_fields = problem.unflatten_fn_sol_list(dofs)
    target_fields = problem.unflatten_fn_sol_list(np.zeros_like(dofs))
    for var_idx, fe in enumerate(problem.fes):
        source = source_fields[var_idx]
        target = target_fields[var_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            vec_inds = fe.vec_inds_list[bc_idx]
            target = target.at[node_inds, vec_inds].set(source[node_inds, vec_inds])
        target_fields[var_idx] = target
    return jax.flatten_util.ravel_pytree(target_fields)[0]
################################################################################
# Newton helpers: flattening and tangent probe
def get_flatten_fn(fn_sol_list, problem):
    """Adapt a function of a solution list into a function of a flat dof vector.

    The returned function unflattens its argument with
    ``problem.unflatten_fn_sol_list``, calls ``fn_sol_list`` and ravels the
    result back into a single flat vector.
    """
    def fn_dofs(dofs):
        values = fn_sol_list(problem.unflatten_fn_sol_list(dofs))
        flat, _ = jax.flatten_util.ravel_pytree(values)
        return flat

    return fn_dofs
def operator_to_matrix(operator_fn, problem):
    """Return the forward-mode Jacobian of ``operator_fn`` at the zero vector.

    Debugging helper only: useful for printing the matrix, checking its
    condition number, etc.
    """
    jacobian_fn = jax.jacfwd(operator_fn)
    origin = np.zeros(problem.num_total_dofs_all_vars)
    return jacobian_fn(origin)
################################################################################
# Newton step (linear increment + optional line search)
def newton_step(problem, res_vec, A, dofs, newton_cfg, timing):
    """One Newton correction: solve :math:`A\\,\\Delta u = -R`, then update ``dofs``.

    Parameters
    ----------
    problem
        Provides the Dirichlet data and, optionally, ``P_mat``.
    res_vec : array
        Current residual :math:`R`.
    A : PETSc.Mat
        Tangent matrix.
    dofs : array
        Current flat solution vector.
    newton_cfg : dict
        Reads ``'linear'`` (linear solver options) and ``'line_search_flag'``.
    timing : dict
        Its ``'linear'`` entry is incremented by the solve time.

    Returns
    -------
    dofs : ndarray
    linear_s : float
        Linear solve wall time (also accumulated in ``timing``).
    """
    logger.debug("Solving linear system...")
    rhs = -res_vec

    # x0 will always be correct at boundary locations
    bc_target = assign_bc(np.zeros(problem.num_total_dofs_all_vars), problem)
    if hasattr(problem, 'P_mat'):
        bc_current = copy_bc(problem.P_mat @ dofs, problem)
        x0 = problem.P_mat.T @ (bc_target - bc_current)
    else:
        bc_current = copy_bc(dofs, problem)
        x0 = bc_target - bc_current

    tic = time.perf_counter()
    inc = linear_solver(A, rhs, x0, newton_cfg.get('linear', {}))
    linear_s = time.perf_counter() - tic
    _timing_record(timing, 'linear', linear_s)

    if newton_cfg.get('line_search_flag', False):
        return line_search(problem, dofs, inc), linear_s
    return dofs + inc, linear_s
def line_search(problem, dofs, inc):
    """Backtracking line search along the Newton increment ``inc``.

    Starting from a full step, the step length is halved up to three times as
    long as the residual norm (with Dirichlet rows applied) does not increase;
    the last accepted step length is used.

    TODO: This is useful for finite deformation plasticity.
    """
    res_fn = apply_bc(get_flatten_fn(problem.compute_residual, problem), problem)

    def res_norm_fn(alpha):
        return np.linalg.norm(res_fn(dofs + alpha*inc))

    alpha = 1.
    res_norm = res_norm_fn(alpha)
    for i in range(3):
        trial = alpha * 0.5
        res_norm_half = res_norm_fn(trial)
        logger.debug(f"i = {i}, res_norm = {res_norm}, res_norm_half = {res_norm_half}")
        if res_norm_half > res_norm:
            # Halving made things worse: keep the previous step length.
            break
        alpha = trial
        res_norm = res_norm_half
    return dofs + alpha*inc
################################################################################
# Tangent stiffness matrix (PETSc cache)
class _PetscTangentCache:
    """Reusable full-space PETSc tangent built from fixed ``problem.I/J`` COO pattern."""

    def __init__(self, problem):
        """Preallocate the AIJ matrix and precompute the Dirichlet row indices."""
        n_dofs = problem.num_total_dofs_all_vars
        rows = onp.asarray(problem.I, dtype=PETSc.IntType)
        cols = onp.asarray(problem.J, dtype=PETSc.IntType)

        self.mat = PETSc.Mat().createAIJ(size=(n_dofs, n_dofs))
        self.mat.setOption(PETSc.Mat.Option.KEEP_NONZERO_PATTERN, True)
        self.mat.setPreallocationCOO(rows, cols)

        # One index array per Dirichlet condition, in global row numbering.
        self.bc_row_inds_list = [
            onp.array(
                node_inds * fe.vec + fe.vec_inds_list[i] + problem.offset[ind],
                dtype=onp.int32,
            )
            for ind, fe in enumerate(problem.fes)
            for i, node_inds in enumerate(fe.node_inds_list)
        ]

    def update(self, problem):
        """Load ``problem.V`` into the matrix, assemble, zero the Dirichlet rows.

        Returns the cached ``PETSc.Mat`` (the same object on every call).
        """
        coo_values = onp.asarray(problem.V, dtype=onp.float64)
        self.mat.setValuesCOO(coo_values)
        self.mat.assemble()
        for bc_rows in self.bc_row_inds_list:
            self.mat.zeroRows(bc_rows)
        return self.mat
def _get_petsc_tangent_cache(problem):
cache = getattr(problem, '_petsc_tangent_cache', None)
if cache is None:
cache = _PetscTangentCache(problem)
problem._petsc_tangent_cache = cache
return cache
def get_A(problem):
    """Return the PETSc tangent matrix for the current ``problem.V`` values.

    When ``problem`` has a ``P_mat`` attribute (linear multipoint
    constraints), the reduced operator ``P^T A P`` is returned instead.
    """
    logger.debug("Updating cached PETSc tangent from COO values...")
    A = _get_petsc_tangent_cache(problem).update(problem)
    if not hasattr(problem, 'P_mat'):
        return A

    # Linear multipoint constraints
    csr = (
        problem.P_mat.indptr.astype(PETSc.IntType, copy=False),
        problem.P_mat.indices.astype(PETSc.IntType, copy=False),
        problem.P_mat.data,
    )
    P = PETSc.Mat().createAIJ(size=problem.P_mat.shape, csr=csr)
    # NOTE(review): Mat.transpose() without ``out`` may transpose P in place
    # (confirm against petsc4py); A @ P is therefore formed before it.
    AP = A.matMult(P)
    P_T = P.transpose()
    return P_T.matMult(AP)
################################################################################
# Arc-length: Crisfield formulation (displacement / force control)
def arc_length_solver_disp_driven(problem, prev_u_vec, prev_lamda, prev_Delta_u_vec,
                                  prev_Delta_lamda, Delta_l=0.1, psi=1.):
    """One arc-length continuation step with the load applied through Dirichlet values.

    Starting from ``(prev_u_vec, prev_lamda)``, iterate on the dof vector and
    the load factor ``lamda`` until the residual norm measured at the start of
    an iteration is at most ``1e-6``. The Dirichlet values enter the residual
    scaled by ``lamda``.

    Parameters
    ----------
    problem
        Provides ``newton_update``, ``unflatten_fn_sol_list`` and the
        Dirichlet data.
    prev_u_vec : array
        Flat dof vector at the start of the step.
    prev_lamda : float
        Load factor at the start of the step.
    prev_Delta_u_vec, prev_Delta_lamda
        Increment of the previous step; used only to pick between the two
        roots of the constraint equation.
    Delta_l : float
        Enters the constraint as ``- Delta_l**2`` (the step size).
    psi : float
        Weight of the load-factor term in the constraint and dot product.

    Returns
    -------
    u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir
        New state and the increments ``u_vec - prev_u_vec`` and
        ``lamda - prev_lamda``.

    Raises
    ------
    AssertionError
        If either root of the quadratic in ``delta_lamda`` is not finite.
    """
    def newton_update_helper(dofs):
        # Residual with Dirichlet rows applied (values scaled by the current
        # ``lamda``, read from the enclosing scope) and the matching tangent.
        sol_list = problem.unflatten_fn_sol_list(dofs)
        res_list = problem.newton_update(sol_list)
        res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
        res_vec = apply_bc_vec(res_vec, dofs, problem, lamda)
        A = get_A(problem)
        return res_vec, A
    def u_lamda_dot_product(Delta_u_vec1, Delta_lamda1, Delta_u_vec2, Delta_lamda2):
        # Inner product of two (Delta u, Delta lamda) pairs, with the lamda
        # part weighted by psi**2 * |u_b|**2.
        return (np.sum(Delta_u_vec1 * Delta_u_vec2)
                + psi**2. * Delta_lamda1 * Delta_lamda2 * np.sum(u_b**2.))
    u_vec = prev_u_vec
    lamda = prev_lamda
    # Vector holding the prescribed Dirichlet values, zero elsewhere.
    u_b = assign_bc(np.zeros_like(prev_u_vec), problem)
    Delta_u_vec_dir = prev_Delta_u_vec
    Delta_lamda_dir = prev_Delta_lamda
    tol = 1e-6
    res_val = 1.
    # NOTE(review): there is no iteration cap, and the update below is still
    # applied in the iteration whose residual first satisfies the tolerance.
    while res_val > tol:
        res_vec, A = newton_update_helper(u_vec)
        res_val = np.linalg.norm(res_vec)
        logger.debug(f"Arc length solver: res_val = {res_val}")
        # Two solves with the same tangent: residual part and load part.
        delta_u_bar = scipy_spsolve(A, -res_vec)
        delta_u_t = scipy_spsolve(A, u_b)
        Delta_u_vec = u_vec - prev_u_vec
        Delta_lamda = lamda - prev_lamda
        # Coefficients of a1*x**2 + a2*x + a3 = 0 for x = delta_lamda.
        a1 = np.sum(delta_u_t**2.) + psi**2. * np.sum(u_b**2.)
        a2 = (2. * np.sum((Delta_u_vec + delta_u_bar) * delta_u_t)
              + 2. * psi**2. * Delta_lamda * np.sum(u_b**2.))
        a3 = (np.sum((Delta_u_vec + delta_u_bar)**2.)
              + psi**2. * Delta_lamda**2. * np.sum(u_b**2.) - Delta_l**2.)
        delta_lamda1 = (-a2 + np.sqrt(a2**2. - 4. * a1 * a3)) / (2. * a1)
        delta_lamda2 = (-a2 - np.sqrt(a2**2. - 4. * a1 * a3)) / (2. * a1)
        logger.debug(f"Arc length solver: delta_lamda1 = {delta_lamda1}, delta_lamda2 = {delta_lamda2}")
        # A negative discriminant yields NaN roots and trips this assert.
        assert np.isfinite(delta_lamda1) and np.isfinite(delta_lamda2), (
            f"No valid solutions for delta lambda, a1 = {a1}, a2 = {a2}, a3 = {a3}")
        delta_u_vec1 = delta_u_bar + delta_lamda1 * delta_u_t
        delta_u_vec2 = delta_u_bar + delta_lamda2 * delta_u_t
        Delta_u_vec_dir1 = u_vec + delta_u_vec1 - prev_u_vec
        Delta_lamda_dir1 = lamda + delta_lamda1 - prev_lamda
        dot_prod1 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir,
                                        Delta_u_vec_dir1, Delta_lamda_dir1)
        Delta_u_vec_dir2 = u_vec + delta_u_vec2 - prev_u_vec
        Delta_lamda_dir2 = lamda + delta_lamda2 - prev_lamda
        dot_prod2 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir,
                                        Delta_u_vec_dir2, Delta_lamda_dir2)
        # Root selection: with no usable direction (both products ~ 0) take
        # the larger root; otherwise take the root whose increment has the
        # larger dot product with the current direction.
        if np.abs(dot_prod1) < 1e-10 and np.abs(dot_prod2) < 1e-10:
            delta_lamda = np.maximum(delta_lamda1, delta_lamda2)
        elif dot_prod1 > dot_prod2:
            delta_lamda = delta_lamda1
        else:
            delta_lamda = delta_lamda2
        lamda = lamda + delta_lamda
        delta_u = delta_u_bar + delta_lamda * delta_u_t
        u_vec = u_vec + delta_u
        Delta_u_vec_dir = u_vec - prev_u_vec
        Delta_lamda_dir = lamda - prev_lamda
    logger.debug(f"Arc length solver: finished for one step, with Delta lambda = {lamda - prev_lamda}")
    return u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir
def get_q_vec(problem):
    """Load vector at ``u=0`` for force-controlled arc-length (``arc_length`` cfg ``q_vec_aux``).

    Evaluates ``problem.newton_update`` on an all-zero solution and returns
    the result as a flat vector.
    """
    zero_dofs = np.zeros(problem.num_total_dofs_all_vars)
    res_list = problem.newton_update(problem.unflatten_fn_sol_list(zero_dofs))
    flat, _ = jax.flatten_util.ravel_pytree(res_list)
    return flat
def arc_length_solver_force_driven(problem, prev_u_vec, prev_lamda, prev_Delta_u_vec,
                                   prev_Delta_lamda, q_aux, Delta_l=0.1, psi=1.):
    """One arc-length continuation step with the load applied through a force vector.

    Starting from ``(prev_u_vec, prev_lamda)``, iterate on the dof vector and
    the load factor ``lamda`` until the norm of
    ``res_vec + (1 - lamda) * q_aux_mapped``, measured at the start of an
    iteration, is at most ``1e-6``.

    Parameters
    ----------
    problem
        Provides ``newton_update``, ``unflatten_fn_sol_list`` and the
        Dirichlet data.
    prev_u_vec : array
        Flat dof vector at the start of the step.
    prev_lamda : float
        Load factor at the start of the step.
    prev_Delta_u_vec, prev_Delta_lamda
        Increment of the previous step; used only to pick between the two
        roots of the constraint equation.
    q_aux : array
        Flat load vector; its Dirichlet entries are zeroed before use.
    Delta_l : float
        Enters the constraint as ``- Delta_l**2`` (the step size).
    psi : float
        Weight of the load-factor term in the constraint and dot product.

    Returns
    -------
    u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir
        New state and the increments ``u_vec - prev_u_vec`` and
        ``lamda - prev_lamda``.

    Raises
    ------
    AssertionError
        If either root of the quadratic in ``delta_lamda`` is not finite.
    """
    def newton_update_helper(dofs):
        # Residual with (unscaled) Dirichlet rows applied and the matching tangent.
        sol_list = problem.unflatten_fn_sol_list(dofs)
        res_list = problem.newton_update(sol_list)
        res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
        res_vec = apply_bc_vec(res_vec, dofs, problem)
        A = get_A(problem)
        return res_vec, A
    def u_lamda_dot_product(Delta_u_vec1, Delta_lamda1, Delta_u_vec2, Delta_lamda2):
        # Inner product of two (Delta u, Delta lamda) pairs, with the lamda
        # part weighted by psi**2 * |q_aux_mapped|**2.
        return (np.sum(Delta_u_vec1 * Delta_u_vec2)
                + psi**2. * Delta_lamda1 * Delta_lamda2 * np.sum(q_aux_mapped**2.))
    u_vec = prev_u_vec
    lamda = prev_lamda
    # Load vector with its Dirichlet entries set to zero.
    q_aux_mapped = assign_zeros_bc(q_aux, problem)
    Delta_u_vec_dir = prev_Delta_u_vec
    Delta_lamda_dir = prev_Delta_lamda
    tol = 1e-6
    res_val = 1.
    # NOTE(review): there is no iteration cap, and the update below is still
    # applied in the iteration whose residual first satisfies the tolerance.
    while res_val > tol:
        res_vec, A = newton_update_helper(u_vec)
        # The load term vanishes at lamda = 1.
        load_term = (1. - lamda) * q_aux_mapped
        res_val = np.linalg.norm(res_vec + load_term)
        logger.debug(f"Arc length solver: res_val = {res_val}")
        # Two solves with the same tangent: residual part and load part.
        delta_u_bar = scipy_spsolve(A, -(res_vec + load_term))
        delta_u_t = scipy_spsolve(A, q_aux_mapped)
        Delta_u_vec = u_vec - prev_u_vec
        Delta_lamda = lamda - prev_lamda
        # Coefficients of a1*x**2 + a2*x + a3 = 0 for x = delta_lamda.
        a1 = np.sum(delta_u_t**2.) + psi**2. * np.sum(q_aux_mapped**2.)
        a2 = (2. * np.sum((Delta_u_vec + delta_u_bar) * delta_u_t)
              + 2. * psi**2. * Delta_lamda * np.sum(q_aux_mapped**2.))
        a3 = (np.sum((Delta_u_vec + delta_u_bar)**2.)
              + psi**2. * Delta_lamda**2. * np.sum(q_aux_mapped**2.) - Delta_l**2.)
        delta_lamda1 = (-a2 + np.sqrt(a2**2. - 4. * a1 * a3)) / (2. * a1)
        delta_lamda2 = (-a2 - np.sqrt(a2**2. - 4. * a1 * a3)) / (2. * a1)
        logger.debug(f"Arc length solver: delta_lamda1 = {delta_lamda1}, delta_lamda2 = {delta_lamda2}")
        # A negative discriminant yields NaN roots and trips this assert.
        assert np.isfinite(delta_lamda1) and np.isfinite(delta_lamda2), (
            f"No valid solutions for delta lambda, a1 = {a1}, a2 = {a2}, a3 = {a3}")
        delta_u_vec1 = delta_u_bar + delta_lamda1 * delta_u_t
        delta_u_vec2 = delta_u_bar + delta_lamda2 * delta_u_t
        Delta_u_vec_dir1 = u_vec + delta_u_vec1 - prev_u_vec
        Delta_lamda_dir1 = lamda + delta_lamda1 - prev_lamda
        dot_prod1 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir,
                                        Delta_u_vec_dir1, Delta_lamda_dir1)
        Delta_u_vec_dir2 = u_vec + delta_u_vec2 - prev_u_vec
        Delta_lamda_dir2 = lamda + delta_lamda2 - prev_lamda
        dot_prod2 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir,
                                        Delta_u_vec_dir2, Delta_lamda_dir2)
        # Root selection: with no usable direction (both products ~ 0) take
        # the larger root; otherwise take the root whose increment has the
        # larger dot product with the current direction.
        if np.abs(dot_prod1) < 1e-10 and np.abs(dot_prod2) < 1e-10:
            delta_lamda = np.maximum(delta_lamda1, delta_lamda2)
        elif dot_prod1 > dot_prod2:
            delta_lamda = delta_lamda1
        else:
            delta_lamda = delta_lamda2
        lamda = lamda + delta_lamda
        delta_u = delta_u_bar + delta_lamda * delta_u_t
        u_vec = u_vec + delta_u
        Delta_u_vec_dir = u_vec - prev_u_vec
        Delta_lamda_dir = lamda - prev_lamda
    logger.debug(f"Arc length solver: finished for one step, with Delta lambda = {lamda - prev_lamda}")
    return u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir
def _arc_length_newton_polish(problem, sol_list, cfg, lam_continuation):
    """Run a standard Newton solve seeded with the arc-length result.

    The Newton options come from ``cfg['newton']`` (if any), with
    ``'initial_guess'`` set to ``sol_list`` and ``'linear'`` taken from
    ``cfg['linear']`` when that entry is truthy. ``lam_continuation`` is only
    logged.
    """
    logger.info(
        "Arc-length continuation ended at lambda=%.6f (target=%.6f); "
        "standard Newton polish at full load",
        lam_continuation, _LAMBDA_TARGET)
    newton_opts = dict(cfg.get('newton', {}))
    newton_opts['initial_guess'] = sol_list
    linear_opts = cfg.get('linear')
    if linear_opts:
        newton_opts['linear'] = linear_opts
    return solver(problem, {'newton': newton_opts})
def _finish_arc_length(problem, u_vec, lam, cfg, max_steps, history, control):
    """Package the outcome of an arc-length continuation.

    If the final load factor reached ``_LAMBDA_TARGET`` the solution is
    polished with a Newton solve; otherwise a warning is logged and the
    unpolished continuation state is returned.

    Returns
    -------
    sol_list
        Solution list (polished when the target was reached).
    info : dict
        Keys ``'lam'``, ``'lambda_target'``, ``'polished'``, ``'history'``
        and ``'control'``.
    """
    sol_list = problem.unflatten_fn_sol_list(onp.asarray(u_vec))
    lam_continuation = float(lam)
    reached_target = lam_continuation >= _LAMBDA_TARGET

    if reached_target:
        sol_list = _arc_length_newton_polish(
            problem, sol_list, cfg, lam_continuation)
    else:
        logger.warning(
            "Arc-length stopped at lambda=%.6f after %d continuation steps "
            "(max_continuation_steps=%d); lambda=1 was not reached — "
            "the intended forward problem was not solved. "
            "Increase max_continuation_steps or adjust arc-length settings.",
            lam_continuation, len(history), max_steps)

    info = {
        'lam': lam_continuation,
        'lambda_target': _LAMBDA_TARGET,
        'polished': reached_target,
        'history': history,
        'control': control,
    }
    return sol_list, info
def _solve_arc_length_disp(problem, cfg):
    """Displacement-controlled arc-length continuation (Crisfield outer loop).

    Reads ``psi``, ``Delta_l``, ``max_continuation_steps`` and
    ``step_callback`` from ``cfg``. Steps from ``lambda = 0`` until the load
    factor reaches ``_LAMBDA_TARGET`` or the step budget is exhausted, then
    hands over to ``_finish_arc_length``.
    """
    psi = cfg.get('psi', 1.)
    delta_l = cfg.get('Delta_l', 0.1)
    max_steps = cfg.get('max_continuation_steps', 600)
    step_callback = cfg.get('step_callback')

    # Continuation state: start from the unloaded configuration.
    u_vec = onp.zeros(problem.num_total_dofs_all_vars)
    delta_u_dir = onp.zeros_like(u_vec)
    lam = 0.
    delta_lam_dir = 0.
    history = []

    logger.info("Arc-length solve started (displacement control, Crisfeld).")
    t_start = time.time()
    for step in range(max_steps):
        u_vec, lam, delta_u_dir, delta_lam_dir = arc_length_solver_disp_driven(
            problem, u_vec, lam, delta_u_dir, delta_lam_dir, Delta_l=delta_l, psi=psi)
        snapshot = onp.asarray(u_vec, dtype=onp.float64)
        history.append({'step': step, 'lam': float(lam), 'u': snapshot})
        if step_callback is not None:
            step_callback(step, snapshot, lam)
        if lam >= _LAMBDA_TARGET:
            break
    elapsed = time.time() - t_start
    logger.info("Arc-length finished in %.3f s, %d continuation steps, final lambda=%.6f",
                elapsed, len(history), lam)
    return _finish_arc_length(
        problem, u_vec, lam, cfg, max_steps, history, 'displacement')
def _solve_arc_length_force(problem, cfg):
    """Force-controlled arc-length continuation (Crisfield outer loop).

    Requires ``cfg['q_vec_aux']``. Also reads ``psi``, ``Delta_l``,
    ``Delta_l_late``, ``Delta_l_switch_step``, ``max_continuation_steps`` and
    ``step_callback``. The step size is ``Delta_l`` before
    ``Delta_l_switch_step`` and ``Delta_l_late`` from that step on.

    Raises
    ------
    ValueError
        If ``cfg['q_vec_aux']`` is missing or ``None``.
    """
    q_aux = cfg.get('q_vec_aux')
    if q_aux is None:
        raise ValueError("arc_length force control requires cfg['q_vec_aux'].")

    psi = cfg.get('psi', 0.5)
    delta_l = cfg.get('Delta_l', 0.1)
    delta_l_late = cfg.get('Delta_l_late', 1.0)
    switch_step = cfg.get('Delta_l_switch_step', 200)
    max_steps = cfg.get('max_continuation_steps', 500)
    step_callback = cfg.get('step_callback')

    # Continuation state: start from the unloaded configuration.
    u_vec = onp.zeros(problem.num_total_dofs_all_vars)
    delta_u_dir = onp.zeros_like(u_vec)
    lam = 0.
    delta_lam_dir = 0.
    history = []

    logger.info("Arc-length solve started (force control, Crisfeld).")
    t_start = time.time()
    for step in range(max_steps):
        if step < switch_step:
            dl = delta_l
        else:
            dl = delta_l_late
        u_vec, lam, delta_u_dir, delta_lam_dir = arc_length_solver_force_driven(
            problem, u_vec, lam, delta_u_dir, delta_lam_dir, q_aux, Delta_l=dl, psi=psi)
        snapshot = onp.asarray(u_vec, dtype=onp.float64)
        history.append({'step': step, 'lam': float(lam), 'u': snapshot})
        if step_callback is not None:
            step_callback(step, snapshot, lam)
        if lam >= _LAMBDA_TARGET:
            break
    elapsed = time.time() - t_start
    logger.info("Arc-length finished in %.3f s, %d continuation steps, final lambda=%.6f",
                elapsed, len(history), lam)
    return _finish_arc_length(
        problem, u_vec, lam, cfg, max_steps, history, 'force')
def _solve_arc_length(problem, cfg):
"""
Reference: Vasios, Nikolaos. "Nonlinear analysis of structures." The Arc-Length method.
"""
if 'control' not in cfg:
raise ValueError(
"arc_length requires cfg['control']; use 'displacement' or 'force'.")
control = cfg['control']
if control == 'displacement':
return _solve_arc_length_disp(problem, cfg)
if control == 'force':
return _solve_arc_length_force(problem, cfg)
raise ValueError(f"Unknown arc_length control={control!r}; use 'displacement' or 'force'.")
################################################################################
# Dynamic relaxation
def _solve_dynamic_relax(problem, cfg):
    """Translate a ``dynamic_relax`` options dict into a ``dynamic_relax_solve`` call.

    ``cfg['initial_guess']``, when present, is detached from differentiation
    with ``stop_gradient`` and flattened. Remaining options fall back to the
    defaults listed below.
    """
    flat_guess = None
    if 'initial_guess' in cfg:
        detached = jax.lax.stop_gradient(cfg['initial_guess'])
        flat_guess, _ = jax.flatten_util.ravel_pytree(detached)

    defaults = {
        'tol': 1e-6,
        'nKMat': 50,
        'nPrint': 500,
        'info': True,
        'info_force': True,
    }
    options = {name: cfg.get(name, default) for name, default in defaults.items()}
    return dynamic_relax_solve(
        problem,
        initial_guess=flat_guess,
        linear_options=cfg.get('linear', {}),
        **options,
    )
def assembleCSR(problem, dofs):
    """Assemble the tangent at ``dofs`` as a SciPy CSR array with Dirichlet rows eliminated.

    ``problem.newton_update`` is called to refresh ``problem.V``; the matrix
    is then built from the COO triplets ``(problem.V, problem.I, problem.J)``.
    Every Dirichlet row is zeroed and given a unit diagonal.

    Parameters
    ----------
    problem
        Provides ``newton_update``, ``I``, ``J``, ``V``, ``fes``, ``offset``
        and ``num_total_dofs_all_vars``.
    dofs : array
        Flat solution vector at which to linearize.

    Returns
    -------
    scipy.sparse.csr_array
    """
    sol_list = problem.unflatten_fn_sol_list(dofs)
    problem.newton_update(sol_list)
    # Size by the dof count over *all* variables: the row indices below are
    # shifted by problem.offset[ind], so sizing by fes[0] alone is too small
    # as soon as there is more than one variable.
    n_dofs = problem.num_total_dofs_all_vars
    A_sp_scipy = scipy.sparse.csr_array((problem.V, (problem.I, problem.J)),
                                        shape=(n_dofs, n_dofs))
    for ind, fe in enumerate(problem.fes):
        for i, node_inds in enumerate(fe.node_inds_list):
            row_inds = onp.array(node_inds * fe.vec + fe.vec_inds_list[i] + problem.offset[ind],
                                 dtype=onp.int32)
            for row_ind in row_inds:
                # Zero the stored entries of the row, then set the diagonal to 1.
                A_sp_scipy.data[A_sp_scipy.indptr[row_ind]: A_sp_scipy.indptr[row_ind + 1]] = 0.
                A_sp_scipy[row_ind, row_ind] = 1.
    return A_sp_scipy
def calC(t, cmin, cmax):
    """Return the damping coefficient ``2 * sqrt(max(t, 0))`` limited to ``[cmin, cmax]``.

    The lower bound is applied first, then the upper bound.
    """
    t_nonneg = 0. if t < 0. else t
    c = 2. * onp.sqrt(t_nonneg)
    if c < cmin:
        c = cmin
    if c > cmax:
        c = cmax
    return c
def printInfo(error, t, c, tol, eps, qdot, qdotdot, nIters, nPrint, info, info_force):
    """Print dynamic-relaxation diagnostics on selected iterations.

    Output is produced only when ``nIters % nPrint == 1``. ``info_force``
    enables the residual / velocity line; ``info`` enables the damping,
    epsilon and acceleration lines.
    """
    ## printing control
    if nIters % nPrint != 1:
        return
    if info_force == True:
        max_velocity = np.max(np.absolute(qdot))
        print(('\nDR Iteration %d: Max force (residual error) = %g (tol = %g)' +
               'Max velocity = %g') % (nIters, error, tol, max_velocity))
    if info == True:
        print('\nDamping t: ', t)
        print('Damping coefficient: ', c)
        print('Max epsilon: ', np.max(eps))
        print('Max acceleration: ', np.max(np.absolute(qdotdot)))
def dynamic_relax_solve(problem, tol=1e-6, nKMat=50, nPrint=500, info=True, info_force=True,
                        initial_guess=None, linear_options=None):
    """
    Implementation of
    Luet, David Joseph. Bounding volume hierarchy and non-uniform rational B-splines for contact enforcement
    in large deformation finite element analysis of sheet metal forming. Diss. Princeton University, 2016.
    Chapter 4.3 Nonlinear System Solution
    Particularly good for handling buckling behavior.
    There is a FEniCS version of this dynamic relaxation algorithm.
    The code below is a direct translation from the FEniCS version.
    TODO: Does not support periodic B.C., need some work here.

    Parameters
    ----------
    problem
        Provides residual / tangent assembly and the Dirichlet data.
    tol : float
        The loop stops once the max-abs residual is no longer above ``tol``.
    nKMat : int
        Minimum number of iterations between tangent re-assemblies.
    nPrint : int
        Print period forwarded to ``printInfo``.
    info, info_force : bool
        Verbosity switches forwarded to ``printInfo``.
    initial_guess : array or None
        Flat dof vector to start from; Dirichlet values are imposed on it.
    linear_options : dict or None
        Linear solver options for the initial Newton step; a falsy value
        selects ``{'spsolve_solver': {}}``.

    Returns
    -------
    sol_list
        ``problem.unflatten_fn_sol_list`` applied to the final dof vector.
    """
    linear_options = linear_options or {'spsolve_solver': {}}
    # TODO: consider these in initial guess
    def newton_update_helper(dofs):
        sol_list = problem.unflatten_fn_sol_list(dofs)
        res_list = problem.newton_update(sol_list)
        res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
        res_vec = apply_bc_vec(res_vec, dofs, problem)
        A = get_A(problem)
        return res_vec, A
    # One Newton step from zero provides the default starting point.
    dofs = np.zeros(problem.num_total_dofs_all_vars)
    res_vec, A = newton_update_helper(dofs)
    dofs, _ = newton_step(problem, res_vec, A, dofs, {'linear': linear_options},
                          {'local_assembly': 0., 'global_matrix': 0., 'linear': 0.})
    # A user-supplied guess replaces the Newton-step result computed above.
    if initial_guess is not None:
        dofs = initial_guess
        dofs = assign_bc(dofs, problem)
    # parameters not to change
    cmin = 1e-3
    cmax = 3.9
    h_tilde = 1.1
    h = 1.
    # initialize all arrays
    N = len(dofs) #print("--------num of DOF's: %d-----------" % N)
    #initialize displacements, velocities and accelerations
    q, qdot, qdotdot = onp.zeros(N), onp.zeros(N), onp.zeros(N)
    #initialize displacements, velocities and accelerations from a previous time step
    q_old, qdot_old, qdotdot_old = onp.zeros(N), onp.zeros(N), onp.zeros(N)
    #initialize the M, eps, R_old arrays
    eps, M, R, R_old = onp.zeros(N), onp.zeros(N), onp.zeros(N), onp.zeros(N)
    @jax.jit
    def assembleVec(dofs):
        # Residual with the Dirichlet entries set to zero.
        res_fn = get_flatten_fn(problem.compute_residual, problem)
        res_vec = res_fn(dofs)
        res_vec = assign_zeros_bc(res_vec, problem)
        return res_vec
    R = onp.array(assembleVec(dofs))
    KCSR = assembleCSR(problem, dofs)
    # Diagonal mass: scaled absolute row sums of the tangent.
    M[:] = h_tilde * h_tilde / 4. * onp.array(
        onp.absolute(KCSR).sum(axis=1)).squeeze()
    q[:] = dofs
    qdot[:] = -h / 2. * R / M
    # set the counters for iterations and
    nIters, iKMat = 0, 0
    error = 1.0
    timeZ = time.time() #Measurement of loop time.
    assert onp.all(onp.isfinite(M)), f"M not finite"
    assert onp.all(onp.isfinite(q)), f"q not finite"
    assert onp.all(onp.isfinite(qdot)), f"qdot not finite"
    error = onp.max(onp.absolute(R))
    # NOTE(review): no iteration cap; a NaN residual makes ``error > tol``
    # false, which ends the loop and is reported as failure below.
    while error > tol:
        print(f"error = {error}")
        # marching forward
        q_old[:] = q[:]; R_old[:] = R[:]
        q[:] += h*qdot; dofs = np.array(q)
        R = onp.array(assembleVec(dofs))
        nIters += 1
        iKMat += 1
        error = onp.max(onp.absolute(R))
        # damping calculation
        S0 = onp.dot((R - R_old) / h, qdot)
        t = S0 / onp.einsum('i,i,i', qdot, M, qdot)
        c = calC(t, cmin, cmax)
        # determine whether to recal KMat
        eps = h_tilde * h_tilde / 4. * onp.absolute(
            onp.divide((qdotdot - qdotdot_old), (q - q_old),
                       out=onp.zeros_like((qdotdot - qdotdot_old)),
                       where=(q - q_old) != 0))
        # calculating the jacobian matrix
        if ((onp.max(eps) > 1) and (iKMat > nKMat)): #SPR JAN max --> min
            if info == True:
                print('\nRecalculating the tangent matrix: ', nIters)
            iKMat = 0
            KCSR = assembleCSR(problem, dofs)
            M[:] = h_tilde * h_tilde / 4. * onp.array(
                onp.absolute(KCSR).sum(axis=1)).squeeze()
        # compute new velocities and accelerations
        qdot_old[:] = qdot[:]; qdotdot_old[:] = qdotdot[:];
        qdot = (2.- c*h)/(2 + c*h) * qdot_old - 2.*h/(2.+c*h)* R / M
        # NOTE(review): qdot_old is overwritten with the *new* qdot right
        # before the difference is taken, so qdotdot is identically zero; eps
        # then stays zero and the tangent refresh above never fires. Confirm
        # against the reference implementation before changing.
        qdot_old[:] = qdot[:]
        qdotdot = qdot - qdot_old
        # output on screen
        printInfo(error, t, c, tol, eps, qdot, qdotdot, nIters, nPrint, info, info_force)
    # check if converged
    convergence = True
    if onp.isnan(onp.max(onp.absolute(R))):
        convergence = False
    # print final info
    if convergence:
        print("DRSolve finished in %d iterations and %fs" % \
              (nIters, time.time() - timeZ))
    else:
        print("FAILED to converged")
    sol_list = problem.unflatten_fn_sol_list(dofs)
    return sol_list
################################################################################
# solver_options registry and dispatch
#
# Layout:
#
# Top level: at most ONE method key. Omit for Newton; legacy flat dicts
# (petsc_solver, tol, initial_guess, ...) are auto-wrapped as newton.
#
# {'newton': {
# 'tol': 1e-6, 'rel_tol': 1e-8, 'line_search_flag': False,
# 'initial_guess': sol_list,
# 'linear': {'petsc_solver': {}},
# }}
#
# {'arc_length': {
# 'control': 'displacement' | 'force',
# 'return_info': True,
# 'q_vec_aux': ..., 'Delta_l': 0.1, 'step_callback': fn, ...
# 'linear': {'petsc_solver': {}}, # polish + inner solves
# 'newton': {'tol': 1e-6}, # polish only
# }}
#
# {'dynamic_relax': {
# 'tol': 1e-8, 'nKMat': 50, 'initial_guess': sol_list, ...
# 'linear': {'spsolve_solver': {}},
# }}
# Top-level ``solver_options`` keys that select the nonlinear method.
_METHOD_KEYS = frozenset({'newton', 'arc_length', 'dynamic_relax'})
# Keys that name a linear-solver backend (the same keys ``linear_solver`` checks).
_LINEAR_OPTION_KEYS = frozenset({
    'jax_solver', 'amgx_solver', 'spsolve_solver', 'petsc_solver', 'custom_solver',
})
# Newton-level keys picked up from a legacy flat options dict.
_NEWTON_OPTION_KEYS = frozenset({'tol', 'rel_tol', 'line_search_flag', 'initial_guess'})
# Load factor at which arc-length continuation stops stepping.
_LAMBDA_TARGET = 1.
def _resolve_solver_options(solver_options):
"""Return (nonlinear_method, method_cfg). Legacy flat dicts become Newton."""
opts = solver_options or {}
methods = [m for m in _METHOD_KEYS if m in opts]
if not methods:
linear = {k: opts[k] for k in _LINEAR_OPTION_KEYS if k in opts}
cfg = {k: opts[k] for k in _NEWTON_OPTION_KEYS if k in opts}
if linear:
cfg['linear'] = linear
return 'newton', cfg
if len(methods) > 1:
raise ValueError(f"Pick one nonlinear method, got {methods}.")
method = methods[0]
if not isinstance(opts[method], dict):
raise ValueError(f"solver_options['{method}'] must be a dict.")
return method, opts[method]
################################################################################
# Nonlinear solver entry point (``solver()``)
[docs]
def solver(problem, solver_options=None):
    r"""Solve a nonlinear problem (Newton by default, or arc-length / dynamic relaxation).

    The solver imposes Dirichlet B.C. with "row elimination" method. Conceptually,

    .. math::
        r(u) = D \, r_{\text{unc}}(u) + (I - D)u - u_b \\
        A = \frac{\text{d}r}{\text{d}u} = D \frac{\text{d}r_{\text{unc}}}{\text{d}u} + (I - D)

    where:

    - :math:`r_{\text{unc}}: \mathbb{R}^N\rightarrow\mathbb{R}^N` is the residual function without considering Dirichlet boundary conditions.

    - :math:`u\in\mathbb{R}^N` is the FE solution vector.

    - :math:`u_b\in\mathbb{R}^N` is the vector for Dirichlet boundary conditions, e.g.,

      .. math::
          u_b = \begin{bmatrix}
          0 \\
          0 \\
          2 \\
          3
          \end{bmatrix}

    - :math:`D\in\mathbb{R}^{N\times N}` is the auxiliary matrix for masking, e.g.,

      .. math::
          D = \begin{bmatrix}
          1 & 0 & 0 & 0 \\
          0 & 1 & 0 & 0 \\
          0 & 0 & 0 & 0 \\
          0 & 0 & 0 & 0
          \end{bmatrix}

    - :math:`I\in\mathbb{R}^{N\times N}` is the identity matrix, e.g.,

      .. math::
          I = \begin{bmatrix}
          1 & 0 & 0 & 0 \\
          0 & 1 & 0 & 0 \\
          0 & 0 & 1 & 0 \\
          0 & 0 & 0 & 1
          \end{bmatrix}

    - :math:`A\in\mathbb{R}^{N\times N}` is the tangent stiffness matrix (the global Jacobian matrix).

    Notes
    -----
    - TODO: Show some comments for linear multipoint constraint handling.

    Parameters
    ----------
    problem : Problem
        The nonlinear problem to solve
    solver_options : dict, optional
        Configuration for the nonlinear solve. Use exactly one top-level method key—
        ``newton``, ``arc_length``, or ``dynamic_relax``. Nest the linear
        solver and method-specific options inside that block. ``None``
        (the default) behaves like an empty dict.

        **Newton** (default nonlinear backend)::

            solver_options = {
                'newton': {
                    'tol': 1e-5,
                    'rel_tol': 1e-8,
                    'line_search_flag': False,
                    'initial_guess': initial_guess,
                    'linear': {'petsc_solver': {}},
                },
            }

        **Linear solvers** (keys under ``linear`` in any method block).
        Four backends are currently available:

        - `JAX solver <https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.bicgstab.html>`_
        - `SciPy solver <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.spsolve.html>`_
        - `PETSc solver <https://www.mcs.anl.gov/petsc/petsc4py-current/docs/apiref/index.html>`_
        - `AMGX solver <https://github.com/NVIDIA/AMGX>`_ (requires ``pyamgx``)

        Examples nested under ``newton``::

            solver_options = {'newton': {'linear': {'jax_solver': {}}}}
            solver_options = {'newton': {'linear': {'spsolve_solver': {}}}}
            solver_options = {
                'newton': {
                    'linear': {
                        'petsc_solver': {
                            'ksp_type': 'bcgsl',  # e.g. 'minres', 'gmres', 'tfqmr'
                            'pc_type': 'ilu',     # e.g. 'jacobi'
                        },
                    },
                },
            }
            solver_options = {'newton': {'linear': {'amgx_solver': {'cfg_path': 'path/to/amgx.json'}}}}

        **Defaults.** Omitted keys are filled in as follows.

        Newton (inside a ``newton`` block, or implied when no method key is
        given):

        - ``tol`` → ``1e-6`` (absolute residual :math:`\ell_2` norm)
        - ``rel_tol`` → ``1e-8`` (relative to the initial residual)
        - ``line_search_flag`` → ``False``
        - ``initial_guess`` → zero displacement vector
        - ``linear``: The following are all equivalent for the linear solve::

              solver_options = {}
              solver_options = {'newton': {}}
              solver_options = {'newton': {'linear': {}}}
              solver_options = {'newton': {'linear': {'jax_solver': {}}}}
              solver_options = {'newton': {'linear': {'jax_solver': {'precond': True}}}}

        - ``{'jax_solver': {}}`` → ``precond`` → ``True``
        - ``{'petsc_solver': {}}`` → ``ksp_type`` → ``'bcgsl'``; ``pc_type`` → ``'ilu'``
        - ``{'amgx_solver': {}}`` → ``cfg_path`` → ``None`` (built-in BICGSTAB + AMG)

        **Arc-length** (Crisfield; ``control`` is required; set
        ``return_info`` to obtain continuation metadata)::

            solver_options = {
                'arc_length': {
                    'control': 'displacement',  # or 'force' (needs q_vec_aux)
                    'return_info': True,
                    'Delta_l': 0.1,
                    'linear': {'petsc_solver': {}},
                    'newton': {'tol': 1e-6},  # optional polish at lambda=1
                },
            }

        **Dynamic relaxation** (useful for buckling paths)::

            solver_options = {
                'dynamic_relax': {
                    'tol': 1e-8,
                    'linear': {'spsolve_solver': {}},
                },
            }

        **Legacy flat dict.** For backward compatibility, a dict with *no*
        method key is still accepted and interpreted as Newton. Linear and
        Newton keys may appear at the top level, e.g.::

            solver_options = {'petsc_solver': {}, 'tol': 1e-5}

        is equivalent to specifying::

            solver_options = {'newton': {'linear': {'petsc_solver': {}}, 'tol': 1e-5}}

    Returns
    -------
    sol_list : list
        The solution, one entry per variable. For ``arc_length`` with a truthy
        ``return_info`` the return value is the tuple ``(sol_list, arc_info)``
        instead.

    Raises
    ------
    AssertionError
        If the Newton iteration ends with a non-finite residual norm or
        non-finite dofs.
    """
    method, cfg = _resolve_solver_options(solver_options)

    if method == 'arc_length':
        sol_list, arc_info = _solve_arc_length(problem, cfg)
        if cfg.get('return_info', False):
            return sol_list, arc_info
        return sol_list

    if method == 'dynamic_relax':
        return _solve_dynamic_relax(problem, cfg)

    print()
    logger.info("Solving the nonlinear problem...")
    timing = {'local_assembly': 0., 'global_matrix': 0., 'linear': 0.}
    wall_start = time.perf_counter()

    if 'initial_guess' in cfg:
        # We don't want the initial guess to play a role in the differentiation chain.
        initial_guess = jax.lax.stop_gradient(cfg['initial_guess'])
        # NOTE(review): the guess is flattened as given, while the zero default
        # below uses the reduced size when ``problem`` has ``P_mat`` -- confirm
        # which form callers are expected to supply in that case.
        dofs = jax.flatten_util.ravel_pytree(initial_guess)[0]
    else:
        if hasattr(problem, 'P_mat'):
            dofs = np.zeros(problem.P_mat.shape[1])  # reduced dofs
        else:
            dofs = np.zeros(problem.num_total_dofs_all_vars)

    rel_tol = cfg.get('rel_tol', 1e-8)
    tol = cfg.get('tol', 1e-6)

    def newton_update_helper(dofs):
        """Return (residual vector, tangent matrix, local time, global time) at ``dofs``."""
        if hasattr(problem, 'P_mat'):
            # Expand reduced dofs before evaluating the residual.
            dofs = problem.P_mat @ dofs
        sol_list = problem.unflatten_fn_sol_list(dofs)

        t0 = time.perf_counter()
        res_list = problem.newton_update(sol_list)
        local_s = time.perf_counter() - t0
        _timing_record(timing, 'local_assembly', local_s)

        res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
        res_vec = apply_bc_vec(res_vec, dofs, problem)
        if hasattr(problem, 'P_mat'):
            # Project the residual back onto the reduced dofs.
            res_vec = problem.P_mat.T @ res_vec

        t0 = time.perf_counter()
        A = get_A(problem)
        global_s = time.perf_counter() - t0
        _timing_record(timing, 'global_matrix', global_s)
        return res_vec, A, local_s, global_s

    _log_newton_iter_start(0)
    res_vec, A, local_s, global_s = newton_update_helper(dofs)
    res_val = np.linalg.norm(res_vec)
    res_val_initial = res_val
    rel_res_val = res_val/res_val_initial
    _log_newton_iter_summary(0, local_s, global_s, res_val, rel_res_val)

    n_iters = 0
    while (rel_res_val > rel_tol) and (res_val > tol):
        n_iters += 1
        _log_newton_iter_start(n_iters)
        dofs, linear_s = newton_step(problem, res_vec, A, dofs, cfg, timing)
        res_vec, A, local_s, global_s = newton_update_helper(dofs)
        res_val = np.linalg.norm(res_vec)
        rel_res_val = res_val/res_val_initial
        _log_newton_iter_summary(n_iters, local_s, global_s, res_val, rel_res_val, linear_s)

    # A NaN residual makes both loop comparisons False, so divergence leaves
    # the loop and must be caught here. Raised explicitly (not via ``assert``)
    # so the checks still run under ``python -O``.
    if not np.all(np.isfinite(res_val)):
        raise AssertionError("res_val contains NaN, stop the program!")
    if not np.all(np.isfinite(dofs)):
        raise AssertionError("dofs contains NaN, stop the program!")

    if hasattr(problem, 'P_mat'):
        dofs = problem.P_mat @ dofs

    # If sol_list = [[[u1x, u1y],
    #                 [u2x, u2y],
    #                 [u3x, u3y],
    #                 [u4x, u4y]],
    #                [[p1],
    #                 [p2]]],
    # the flattened DOF vector will be [u1x, u1y, u2x, u2y, u3x, u3y, u4x, u4y, p1, p2]
    sol_list = problem.unflatten_fn_sol_list(dofs)

    _log_timing_table(n_iters, timing, time.perf_counter() - wall_start)
    print()
    logger.info(f"max of dofs = {np.max(dofs)}")
    logger.info(f"min of dofs = {np.min(dofs)}")
    return sol_list
################################################################################
# Implicit differentiation (adjoint method)
def implicit_vjp(problem, sol_list, params, v_list, adjoint_solver_options):
    """Vector-Jacobian product of the solution w.r.t. ``params`` (adjoint method).

    Solves the transposed tangent system for the adjoint vector with
    ``v_list`` as right-hand side, then pulls that vector back through the
    constraint c(u, p) at fixed ``u = sol_list`` and negates the result.

    Parameters
    ----------
    problem : Problem
        Its parameters are overwritten via ``set_params`` as a side effect.
    sol_list : list
        Solution at which the tangent matrix and the constraint are evaluated.
    params : pytree
        Parameters to differentiate with respect to.
    v_list : pytree
        Cotangent of the solution, flattened to form the adjoint right-hand side.
    adjoint_solver_options : dict
        Passed unchanged to ``linear_solver``.

    Returns
    -------
    pytree
        ``-(v * dc/dp)``; comes from ``jax.vjp`` with ``params`` as primal.
    """
    def constraint_at_solution(p):
        """c(u=sol_list, p), reshaped back into the solution-list layout."""
        flat_sol = jax.flatten_util.ravel_pytree(sol_list)[0]
        problem.set_params(p)
        flat_residual = apply_bc(get_flatten_fn(problem.compute_residual, problem), problem)
        return problem.unflatten_fn_sol_list(flat_residual(flat_sol))

    # Build the tangent matrix for the given parameters and solution.
    problem.set_params(params)
    problem.newton_update(sol_list)
    A = get_A(problem)

    rhs = jax.flatten_util.ravel_pytree(v_list)[0]
    if hasattr(problem, 'P_mat'):
        rhs = problem.P_mat.T @ rhs

    # A.transpose() with no argument would modify A in place; writing into a
    # separate Mat leaves A untouched.
    A_T = PETSc.Mat()
    A.transpose(A_T)

    adjoint_vec = linear_solver(A_T, rhs, None, adjoint_solver_options)
    if hasattr(problem, 'P_mat'):
        adjoint_vec = problem.P_mat @ adjoint_vec

    # v * (partial dc/dp), evaluated with the adjoint vector as cotangent.
    adjoint_list = problem.unflatten_fn_sol_list(adjoint_vec)
    _, pullback = jax.vjp(constraint_at_solution, params)
    (param_cotangent,) = pullback(adjoint_list)
    return jax.tree_util.tree_map(lambda leaf: -leaf, param_cotangent)
[docs]
def ad_wrapper(problem, solver_options=None, adjoint_solver_options=None):
    """Automatic differentiation wrapper for the forward problem.

    Wraps the forward solve in a ``jax.custom_vjp`` whose backward rule is
    :func:`implicit_vjp`.

    Parameters
    ----------
    problem : Problem
    solver_options : dict, optional
        Same layout as :func:`solver` (nonlinear method + nested ``linear``).
        ``None`` (the default) is treated as an empty dict.
    adjoint_solver_options : dict, optional
        Linear solver options for the adjoint solve only (flat dict, e.g.
        ``{'petsc_solver': {}}``). ``None`` (the default) is treated as an
        empty dict.

    Returns
    -------
    fwd_pred : callable
        ``fwd_pred(params)`` sets ``params`` on ``problem``, runs
        :func:`solver` and returns its result.

    Notes
    -----
    NOTE(review): the backward rule forwards the forward result to
    ``implicit_vjp`` as ``sol_list``. :func:`solver` returns a tuple for
    ``arc_length`` with ``return_info`` set -- confirm that combination is
    not meant to be used here.
    """
    # Use a fresh dict per wrapper rather than one default object shared by
    # every call; the adjoint dict is handed on to the linear solver.
    solver_options = {} if solver_options is None else solver_options
    adjoint_solver_options = (
        {} if adjoint_solver_options is None else adjoint_solver_options
    )

    @jax.custom_vjp
    def fwd_pred(params):
        problem.set_params(params)
        sol_list = solver(problem, solver_options)
        return sol_list

    def f_fwd(params):
        # Keep both parameters and solution as residuals for the adjoint solve.
        sol_list = fwd_pred(params)
        return sol_list, (params, sol_list)

    def f_bwd(res, v):
        print()
        logger.info("Running backward and solving the adjoint problem...")
        params, sol_list = res
        vjp_result = implicit_vjp(problem, sol_list, params, v, adjoint_solver_options)
        # custom_vjp expects one cotangent per primal argument.
        return (vjp_result, )

    fwd_pred.defvjp(f_fwd, f_bwd)
    return fwd_pred