import jax
import jax.numpy as np
import jax.flatten_util
import numpy as onp
from jax.experimental.sparse import BCOO
import scipy
import time
from petsc4py import PETSc
from jax_fem import logger
from jax import config
config.update("jax_enable_x64", True)
################################################################################
# JAX solver or scipy solver or PETSc solver
def jax_solve(A, b, x0, precond):
"""Solves the equilibrium equation using a JAX solver.
Parameters
----------
A
The system matrix in PETSc AIJ (CSR) format.
b
The right-hand side vector.
x0
The initial guess for the iterative solve.
precond
Whether to apply a Jacobi (diagonal) preconditioner.
"""
logger.debug(f"JAX Solver - Solving linear system")
indptr, indices, data = A.getValuesCSR()
A_sp_scipy = scipy.sparse.csr_array((data, indices, indptr), shape=A.getSize())
A = BCOO.from_scipy_sparse(A_sp_scipy).sort_indices()
jacobi = np.array(A_sp_scipy.diagonal())
pc = (lambda x: x * (1. / jacobi)) if precond else None
x, info = jax.scipy.sparse.linalg.bicgstab(A,
b,
x0=x0,
M=pc,
tol=1e-10,
atol=1e-10,
maxiter=10000)
# Verify convergence
err = np.linalg.norm(A @ x - b)
logger.debug(f"JAX Solver - Finished solving, res = {err}")
assert err < 0.1, f"JAX linear solver failed to converge with err = {err}"
x = np.where(err < 0.1, x, np.nan) # For the assert above; somehow this also affects bicgstab.
return x
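# Illustrative sketch (not part of the solver): calling jax_solve directly on a tiny PETSc AIJ
# matrix. The CSR data below is made up for demonstration; in practice A comes from get_A(problem).
#   indptr = onp.array([0, 2, 4], dtype=onp.int32)
#   indices = onp.array([0, 1, 0, 1], dtype=onp.int32)
#   data = onp.array([4., 1., 1., 3.])
#   A_demo = PETSc.Mat().createAIJ(size=(2, 2), csr=(indptr, indices, data))
#   b_demo = np.array([1., 2.])
#   x_demo = jax_solve(A_demo, b_demo, x0=np.zeros(2), precond=True)  # approx [0.0909, 0.6364]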
def umfpack_solve(A, b):
logger.debug(f"Scipy Solver - Solving linear system with UMFPACK")
indptr, indices, data = A.getValuesCSR()
Asp = scipy.sparse.csr_matrix((data, indices, indptr))
x = scipy.sparse.linalg.spsolve(Asp, onp.array(b))
# TODO: try https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.linalg.spsolve.html
# x = jax.experimental.sparse.linalg.spsolve(av, aj, ai, b)
logger.debug(f'Scipy Solver - Finished solving, linear solve res = {np.linalg.norm(Asp @ x - b)}')
return x
def petsc_solve(A, b, ksp_type, pc_type):
rhs = PETSc.Vec().createSeq(len(b))
rhs.setValues(range(len(b)), onp.array(b))
ksp = PETSc.KSP().create()
ksp.setOperators(A)
ksp.setFromOptions()
ksp.setType(ksp_type)
ksp.pc.setType(pc_type)
# TODO: Using MUMPS as the factor solver works better here. Do we need to generalize this a bit?
if ksp_type == 'tfqmr':
ksp.pc.setFactorSolverType('mumps')
logger.debug(f'PETSc Solver - Solving linear system with ksp_type = {ksp.getType()}, pc = {ksp.pc.getType()}')
x = PETSc.Vec().createSeq(len(b))
ksp.solve(rhs, x)
# Verify convergence
y = PETSc.Vec().createSeq(len(b))
A.mult(x, y)
err = np.linalg.norm(y.getArray() - rhs.getArray())
logger.debug(f"PETSc Solver - Finished solving, linear solve res = {err}")
assert err < 0.1, f"PETSc linear solver failed to converge, err = {err}"
return x.getArray()
def linear_solver(A, b, x0, solver_options):
# If user does not specify any solver, set jax_solver as the default one.
if len(solver_options.keys() & {'jax_solver', 'umfpack_solver', 'petsc_solver', 'custom_solver'}) == 0:
solver_options['jax_solver'] = {}
if 'jax_solver' in solver_options:
precond = solver_options['jax_solver']['precond'] if 'precond' in solver_options['jax_solver'] else True
x = jax_solve(A, b, x0, precond)
elif 'umfpack_solver' in solver_options:
x = umfpack_solve(A, b)
elif 'petsc_solver' in solver_options:
ksp_type = solver_options['petsc_solver']['ksp_type'] if 'ksp_type' in solver_options['petsc_solver'] else 'bcgsl'
pc_type = solver_options['petsc_solver']['pc_type'] if 'pc_type' in solver_options['petsc_solver'] else 'ilu'
x = petsc_solve(A, b, ksp_type, pc_type)
elif 'custom_solver' in solver_options:
# Users can define their own solver
custom_solver = solver_options['custom_solver']
x = custom_solver(A, b, x0, solver_options)
else:
raise NotImplementedError(f"Unknown linear solver.")
return x
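# Illustrative sketch: the solver_options dictionaries that select each backend (values are
# examples, matching the defaults documented in solver() below).
#   x = linear_solver(A, b, x0, {'jax_solver': {'precond': True}})
#   x = linear_solver(A, b, x0, {'umfpack_solver': {}})
#   x = linear_solver(A, b, x0, {'petsc_solver': {'ksp_type': 'bcgsl', 'pc_type': 'ilu'}})
#   # A custom solver is any callable with signature custom_solver(A, b, x0, solver_options):
#   x = linear_solver(A, b, x0, {'custom_solver': lambda A, b, x0, opts: umfpack_solve(A, b)})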
################################################################################
# "row elimination" solver
def apply_bc_vec(res_vec, dofs, problem, scale=1.):
res_list = problem.unflatten_fn_sol_list(res_vec)
sol_list = problem.unflatten_fn_sol_list(dofs)
for ind, fe in enumerate(problem.fes):
res = res_list[ind]
sol = sol_list[ind]
for i in range(len(fe.node_inds_list)):
res = (res.at[fe.node_inds_list[i], fe.vec_inds_list[i]].set(
sol[fe.node_inds_list[i], fe.vec_inds_list[i]], unique_indices=True))
res = res.at[fe.node_inds_list[i], fe.vec_inds_list[i]].add(-fe.vals_list[i]*scale)
res_list[ind] = res
return jax.flatten_util.ravel_pytree(res_list)[0]
def apply_bc(res_fn, problem, scale=1.):
def res_fn_bc(dofs):
"""Apply Dirichlet boundary conditions
"""
res_vec = res_fn(dofs)
return apply_bc_vec(res_vec, dofs, problem, scale)
return res_fn_bc
def assign_bc(dofs, problem):
sol_list = problem.unflatten_fn_sol_list(dofs)
for ind, fe in enumerate(problem.fes):
sol = sol_list[ind]
for i in range(len(fe.node_inds_list)):
sol = sol.at[fe.node_inds_list[i],
fe.vec_inds_list[i]].set(fe.vals_list[i])
sol_list[ind] = sol
return jax.flatten_util.ravel_pytree(sol_list)[0]
def assign_ones_bc(dofs, problem):
sol_list = problem.unflatten_fn_sol_list(dofs)
for ind, fe in enumerate(problem.fes):
sol = sol_list[ind]
for i in range(len(fe.node_inds_list)):
sol = sol.at[fe.node_inds_list[i],
fe.vec_inds_list[i]].set(1.)
sol_list[ind] = sol
return jax.flatten_util.ravel_pytree(sol_list)[0]
def assign_zeros_bc(dofs, problem):
sol_list = problem.unflatten_fn_sol_list(dofs)
for ind, fe in enumerate(problem.fes):
sol = sol_list[ind]
for i in range(len(fe.node_inds_list)):
sol = sol.at[fe.node_inds_list[i],
fe.vec_inds_list[i]].set(0.)
sol_list[ind] = sol
return jax.flatten_util.ravel_pytree(sol_list)[0]
def copy_bc(dofs, problem):
new_dofs = np.zeros_like(dofs)
sol_list = problem.unflatten_fn_sol_list(dofs)
new_sol_list = problem.unflatten_fn_sol_list(new_dofs)
for ind, fe in enumerate(problem.fes):
sol = sol_list[ind]
new_sol = new_sol_list[ind]
for i in range(len(fe.node_inds_list)):
new_sol = (new_sol.at[fe.node_inds_list[i],
fe.vec_inds_list[i]].set(sol[fe.node_inds_list[i],
fe.vec_inds_list[i]]))
new_sol_list[ind] = new_sol
return jax.flatten_util.ravel_pytree(new_sol_list)[0]
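# The four helpers above differ only in what they write at the Dirichlet DOFs. A quick sketch,
# assuming a hypothetical single-variable problem with one Dirichlet DOF at index 0 prescribed to 2.:
#   dofs = np.array([5., 7., 9.])
#   assign_bc(dofs, problem)        # -> [2., 7., 9.]  Dirichlet entries set to prescribed values
#   assign_ones_bc(dofs, problem)   # -> [1., 7., 9.]  Dirichlet entries set to 1.
#   assign_zeros_bc(dofs, problem)  # -> [0., 7., 9.]  Dirichlet entries set to 0.
#   copy_bc(dofs, problem)          # -> [5., 0., 0.]  zeros except Dirichlet entries copied from dofs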
def get_flatten_fn(fn_sol_list, problem):
def fn_dofs(dofs):
sol_list = problem.unflatten_fn_sol_list(dofs)
val_list = fn_sol_list(sol_list)
return jax.flatten_util.ravel_pytree(val_list)[0]
return fn_dofs
def operator_to_matrix(operator_fn, problem):
"""Only used for debugging.
Can be used to print the matrix, check the condition number, etc.
"""
J = jax.jacfwd(operator_fn)(np.zeros(problem.num_total_dofs_all_vars))
return J
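# Debugging sketch: build the flattened residual with B.C. applied (as in solver()/line_search())
# and inspect the assembled operator, e.g. its condition number. Only feasible for small problems.
#   res_fn = apply_bc(get_flatten_fn(problem.compute_residual, problem), problem)
#   J = operator_to_matrix(res_fn, problem)
#   print(onp.linalg.cond(onp.array(J)))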
def linear_incremental_solver(problem, res_vec, A, dofs, solver_options):
"""
Linear solver at each Newton's iteration
"""
logger.debug(f"Solving linear system...")
b = -res_vec
# x0 will always be correct at boundary locations
x0_1 = assign_bc(np.zeros(problem.num_total_dofs_all_vars), problem)
if hasattr(problem, 'P_mat'):
x0_2 = copy_bc(problem.P_mat @ dofs, problem)
x0 = problem.P_mat.T @ (x0_1 - x0_2)
else:
x0_2 = copy_bc(dofs, problem)
x0 = x0_1 - x0_2
inc = linear_solver(A, b, x0, solver_options)
line_search_flag = solver_options['line_search_flag'] if 'line_search_flag' in solver_options else False
if line_search_flag:
dofs = line_search(problem, dofs, inc)
else:
dofs = dofs + inc
return dofs
def line_search(problem, dofs, inc):
"""
TODO: This is useful for finite deformation plasticity.
"""
res_fn = problem.compute_residual
res_fn = get_flatten_fn(res_fn, problem)
res_fn = apply_bc(res_fn, problem)
def res_norm_fn(alpha):
res_vec = res_fn(dofs + alpha*inc)
return np.linalg.norm(res_vec)
# grad_res_norm_fn = jax.grad(res_norm_fn)
# hess_res_norm_fn = jax.hessian(res_norm_fn)
# tol = 1e-3
# alpha = 1.
# lr = 1.
# grad_alpha = 1.
# while np.abs(grad_alpha) > tol:
# grad_alpha = grad_res_norm_fn(alpha)
# hess_alpha = hess_res_norm_fn(alpha)
# alpha = alpha - 1./hess_alpha*grad_alpha
# print(f"alpha = {alpha}, grad_alpha = {grad_alpha}, hess_alpha = {hess_alpha}")
alpha = 1.
res_norm = res_norm_fn(alpha)
for i in range(3):
alpha *= 0.5
res_norm_half = res_norm_fn(alpha)
print(f"i = {i}, res_norm = {res_norm}, res_norm_half = {res_norm_half}")
if res_norm_half > res_norm:
alpha *= 2.
break
res_norm = res_norm_half
return dofs + alpha*inc
def get_A(problem):
logger.debug(f"Creating sparse matrix with scipy...")
A_sp_scipy = scipy.sparse.csr_array((onp.array(problem.V), (problem.I, problem.J)),
shape=(problem.num_total_dofs_all_vars, problem.num_total_dofs_all_vars))
# logger.info(f"Global sparse matrix takes about {A_sp_scipy.data.shape[0]*8*3/2**30} G memory to store.")
A = PETSc.Mat().createAIJ(size=A_sp_scipy.shape,
csr=(A_sp_scipy.indptr.astype(PETSc.IntType, copy=False),
A_sp_scipy.indices.astype(PETSc.IntType, copy=False),
A_sp_scipy.data))
for ind, fe in enumerate(problem.fes):
for i in range(len(fe.node_inds_list)):
row_inds = onp.array(fe.node_inds_list[i] * fe.vec + fe.vec_inds_list[i] + problem.offset[ind], dtype=onp.int32)
A.zeroRows(row_inds)
# Linear multipoint constraints
if hasattr(problem, 'P_mat'):
P = PETSc.Mat().createAIJ(size=problem.P_mat.shape, csr=(problem.P_mat.indptr.astype(PETSc.IntType, copy=False),
problem.P_mat.indices.astype(PETSc.IntType, copy=False), problem.P_mat.data))
tmp = A.matMult(P)
P_T = P.transpose()
A = P_T.matMult(tmp)
return A
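# Note on the multipoint-constraint branch above: problem.P_mat is the prolongation matrix P that
# maps reduced DOFs to full DOFs (dofs_full = P @ dofs_reduced), so the returned operator is the
# reduced matrix P^T A P, consistent with the reduced residual P^T res used in solver() below.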
################################################################################
# The "row elimination" solver
def solver(problem, solver_options={}):
"""
Specify exactly one of 'jax_solver', 'umfpack_solver', 'petsc_solver', or 'custom_solver'.
Examples:
(1) solver_options = {'jax_solver': {}}
(2) solver_options = {'umfpack_solver': {}}
(3) solver_options = {'petsc_solver': {'ksp_type': 'bcgsl', 'pc_type': 'jacobi'}, 'initial_guess': some_guess}
Default parameters will be used if no instruction is found:
solver_options =
{
# If no solver is specified, 'jax_solver' will be used. If multiple are specified, the first
# match in the order 'jax_solver', 'umfpack_solver', 'petsc_solver', 'custom_solver' takes effect.
'jax_solver':
{
# The JAX built-in linear solver
# Reference: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.bicgstab.html
'precond': True,
},
'umfpack_solver':
{
# The scipy solver that calls UMFPACK
# Reference: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.spsolve.html
},
'petsc_solver':
{
# PETSc solver
# For more ksp_type and pc_type: https://www.mcs.anl.gov/petsc/petsc4py-current/docs/apiref/index.html
'ksp_type': 'bcgsl', # e.g., 'minres', 'gmres', 'tfqmr'
'pc_type': 'ilu', # e.g., 'jacobi'
},
'line_search_flag': False, # Whether to use a line search in the Newton update
'initial_guess': initial_guess, # Same shape as sol_list
'tol': 1e-6, # Absolute tolerance for residual vector (l2 norm), used in Newton's method
'rel_tol': 1e-8, # Relative tolerance for residual vector (l2 norm), used in Newton's method
}
The solver imposes Dirichlet B.C. with the "row elimination" method.
Some memo:
res(u) = D*r(u) + (I - D)u - u_b
D = [[1 0 0 0]
[0 1 0 0]
[0 0 0 0]
[0 0 0 1]]
I = [[1 0 0 0]
[0 1 0 0]
[0 0 1 0]
[0 0 0 1]]
A = d(res)/d(u) = D*dr/du + (I - D)
TODO: linear multipoint constraint
The function newton_update computes r(u) and dr/du
"""
logger.debug(f"Calling the row elimination solver for imposing Dirichlet B.C.")
logger.debug("Start timing")
start = time.time()
if 'initial_guess' in solver_options:
# We don't want the initial guess to play a role in the differentiation chain.
initial_guess = jax.lax.stop_gradient(solver_options['initial_guess'])
dofs = jax.flatten_util.ravel_pytree(initial_guess)[0]
else:
if hasattr(problem, 'P_mat'):
dofs = np.zeros(problem.P_mat.shape[1]) # reduced dofs
else:
dofs = np.zeros(problem.num_total_dofs_all_vars)
rel_tol = solver_options['rel_tol'] if 'rel_tol' in solver_options else 1e-8
tol = solver_options['tol'] if 'tol' in solver_options else 1e-6
def newton_update_helper(dofs):
if hasattr(problem, 'P_mat'):
dofs = problem.P_mat @ dofs
sol_list = problem.unflatten_fn_sol_list(dofs)
res_list = problem.newton_update(sol_list)
res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
res_vec = apply_bc_vec(res_vec, dofs, problem)
if hasattr(problem, 'P_mat'):
res_vec = problem.P_mat.T @ res_vec
A = get_A(problem)
return res_vec, A
res_vec, A = newton_update_helper(dofs)
res_val = np.linalg.norm(res_vec)
res_val_initial = res_val
rel_res_val = res_val/res_val_initial
logger.debug(f"Before, l_2 res = {res_val}, relative l_2 res = {rel_res_val}")
while (rel_res_val > rel_tol) and (res_val > tol):
dofs = linear_incremental_solver(problem, res_vec, A, dofs, solver_options)
res_vec, A = newton_update_helper(dofs)
# logger.debug(f"DEBUG: l_2 res = {np.linalg.norm(apply_bc_vec(A @ dofs, dofs, problem))}")
res_val = np.linalg.norm(res_vec)
rel_res_val = res_val/res_val_initial
logger.debug(f"l_2 res = {res_val}, relative l_2 res = {rel_res_val}")
assert np.all(np.isfinite(res_val)), f"res_val contains NaN, stop the program!"
assert np.all(np.isfinite(dofs)), f"dofs contains NaN, stop the program!"
if hasattr(problem, 'P_mat'):
dofs = problem.P_mat @ dofs
# If sol_list = [[[u1x, u1y],
# [u2x, u2y],
# [u3x, u3y],
# [u4x, u4y]],
# [[p1],
# [p2]]],
# the flattened DOF vector will be [u1x, u1y, u2x, u2y, u3x, u3y, u4x, u4y, p1, p2]
sol_list = problem.unflatten_fn_sol_list(dofs)
end = time.time()
solve_time = end - start
logger.info(f"Solve took {solve_time} [s]")
logger.info(f"max of dofs = {np.max(dofs)}")
logger.info(f"min of dofs = {np.min(dofs)}")
return sol_list
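# Illustrative usage sketch (`problem` stands for any jax_fem problem instance):
#   sol_list = solver(problem)                         # defaults to the JAX solver
#   sol_list = solver(problem, {'petsc_solver': {'ksp_type': 'bcgsl', 'pc_type': 'ilu'},
#                               'tol': 1e-6, 'rel_tol': 1e-8})
#   sol_list = solver(problem, {'umfpack_solver': {}, 'initial_guess': sol_list})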
################################################################################
# The "arc length" solver
# Reference: Vasios, Nikolaos. "Nonlinear analysis of structures." The Arc-Length method. Harvard (2015).
# Our implementation follows Crisfield's formulation
# TODO: Do we want to merge displacement-control and force-control codes?
def arc_length_solver_disp_driven(problem, prev_u_vec, prev_lamda, prev_Delta_u_vec, prev_Delta_lamda, Delta_l=0.1, psi=1.):
"""
TODO: Does not support periodic B.C.; needs some work here.
"""
def newton_update_helper(dofs):
sol_list = problem.unflatten_fn_sol_list(dofs)
res_list = problem.newton_update(sol_list)
res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
res_vec = apply_bc_vec(res_vec, dofs, problem, lamda)
A = get_A(problem)
return res_vec, A
def u_lamda_dot_product(Delta_u_vec1, Delta_lamda1, Delta_u_vec2, Delta_lamda2):
return np.sum(Delta_u_vec1*Delta_u_vec2) + psi**2.*Delta_lamda1*Delta_lamda2*np.sum(u_b**2.)
u_vec = prev_u_vec
lamda = prev_lamda
u_b = assign_bc(np.zeros_like(prev_u_vec), problem)
Delta_u_vec_dir = prev_Delta_u_vec
Delta_lamda_dir = prev_Delta_lamda
tol = 1e-6
res_val = 1.
while res_val > tol:
res_vec, A = newton_update_helper(u_vec)
res_val = np.linalg.norm(res_vec)
logger.debug(f"Arc length solver: res_val = {res_val}")
delta_u_bar = umfpack_solve(A, -res_vec)
delta_u_t = umfpack_solve(A, u_b)
Delta_u_vec = u_vec - prev_u_vec
Delta_lamda = lamda - prev_lamda
a1 = np.sum(delta_u_t**2.) + psi**2.*np.sum(u_b**2.)
a2 = 2.* np.sum((Delta_u_vec + delta_u_bar)*delta_u_t) + 2.*psi**2.*Delta_lamda*np.sum(u_b**2.)
a3 = np.sum((Delta_u_vec + delta_u_bar)**2.) + psi**2.*Delta_lamda**2.*np.sum(u_b**2.) - Delta_l**2.
delta_lamda1 = (-a2 + np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1)
delta_lamda2 = (-a2 - np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1)
logger.debug(f"Arc length solver: delta_lamda1 = {delta_lamda1}, delta_lamda2 = {delta_lamda2}")
assert np.isfinite(delta_lamda1) and np.isfinite(delta_lamda2), f"No valid solutions for delta lambda, a1 = {a1}, a2 = {a2}, a3 = {a3}"
delta_u_vec1 = delta_u_bar + delta_lamda1 * delta_u_t
delta_u_vec2 = delta_u_bar + delta_lamda2 * delta_u_t
Delta_u_vec_dir1 = u_vec + delta_u_vec1 - prev_u_vec
Delta_lamda_dir1 = lamda + delta_lamda1 - prev_lamda
dot_prod1 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir, Delta_u_vec_dir1, Delta_lamda_dir1)
Delta_u_vec_dir2 = u_vec + delta_u_vec2 - prev_u_vec
Delta_lamda_dir2 = lamda + delta_lamda2 - prev_lamda
dot_prod2 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir, Delta_u_vec_dir2, Delta_lamda_dir2)
if np.abs(dot_prod1) < 1e-10 and np.abs(dot_prod2) < 1e-10:
# At initial step, (Delta_u_vec_dir, Delta_lamda_dir) is zero, so both dot_prod1 and dot_prod2 are zero.
# We simply select the larger value for delta_lamda.
delta_lamda = np.maximum(delta_lamda1, delta_lamda2)
elif dot_prod1 > dot_prod2:
delta_lamda = delta_lamda1
else:
delta_lamda = delta_lamda2
lamda = lamda + delta_lamda
delta_u = delta_u_bar + delta_lamda * delta_u_t
u_vec = u_vec + delta_u
Delta_u_vec_dir = u_vec - prev_u_vec
Delta_lamda_dir = lamda - prev_lamda
logger.debug(f"Arc length solver: finished for one step, with Delta lambda = {lamda - prev_lamda}")
return u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir
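# Hypothetical driver sketch: march the displacement-driven arc-length solver over several
# increments, feeding each step's (u, lamda, Delta_u, Delta_lamda) back into the next call.
# `num_steps` is a user choice, not defined in this module.
#   u_vec = np.zeros(problem.num_total_dofs_all_vars)
#   lamda, Delta_u_vec, Delta_lamda = 0., np.zeros_like(u_vec), 0.
#   for step in range(num_steps):
#       u_vec, lamda, Delta_u_vec, Delta_lamda = arc_length_solver_disp_driven(
#           problem, u_vec, lamda, Delta_u_vec, Delta_lamda, Delta_l=0.1, psi=1.)
#   sol_list = problem.unflatten_fn_sol_list(u_vec)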
def arc_length_solver_force_driven(problem, prev_u_vec, prev_lamda, prev_Delta_u_vec, prev_Delta_lamda, q_vec, Delta_l=0.1, psi=1.):
"""
TODO: Does not support periodic B.C.; needs some work here.
"""
def newton_update_helper(dofs):
sol_list = problem.unflatten_fn_sol_list(dofs)
res_list = problem.newton_update(sol_list)
res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
res_vec = apply_bc_vec(res_vec, dofs, problem)
A = get_A(problem)
return res_vec, A
def u_lamda_dot_product(Delta_u_vec1, Delta_lamda1, Delta_u_vec2, Delta_lamda2):
return np.sum(Delta_u_vec1*Delta_u_vec2) + psi**2.*Delta_lamda1*Delta_lamda2*np.sum(q_vec_mapped**2.)
u_vec = prev_u_vec
lamda = prev_lamda
q_vec_mapped = assign_zeros_bc(q_vec, problem)
Delta_u_vec_dir = prev_Delta_u_vec
Delta_lamda_dir = prev_Delta_lamda
tol = 1e-6
res_val = 1.
while res_val > tol:
res_vec, A = newton_update_helper(u_vec)
res_val = np.linalg.norm(res_vec + lamda*q_vec_mapped)
logger.debug(f"Arc length solver: res_val = {res_val}")
# TODO: the scipy umfpack solver seems to be far better than the jax linear solver, so we use umfpack solver here.
# x0_1 = assign_bc(np.zeros_like(u_vec), problem)
# x0_2 = copy_bc(u_vec, problem)
# delta_u_bar = jax_solve(A, -(res_vec + lamda*q_vec_mapped), x0=x0_1 - x0_2, precond=True)
# delta_u_t = jax_solve(A, -q_vec_mapped, x0=np.zeros_like(u_vec), precond=True)
delta_u_bar = umfpack_solve(A, -(res_vec + lamda*q_vec_mapped))
delta_u_t = umfpack_solve(A, -q_vec_mapped)
Delta_u_vec = u_vec - prev_u_vec
Delta_lamda = lamda - prev_lamda
a1 = np.sum(delta_u_t**2.) + psi**2.*np.sum(q_vec_mapped**2.)
a2 = 2.* np.sum((Delta_u_vec + delta_u_bar)*delta_u_t) + 2.*psi**2.*Delta_lamda*np.sum(q_vec_mapped**2.)
a3 = np.sum((Delta_u_vec + delta_u_bar)**2.) + psi**2.*Delta_lamda**2.*np.sum(q_vec_mapped**2.) - Delta_l**2.
delta_lamda1 = (-a2 + np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1)
delta_lamda2 = (-a2 - np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1)
logger.debug(f"Arc length solver: delta_lamda1 = {delta_lamda1}, delta_lamda2 = {delta_lamda2}")
assert np.isfinite(delta_lamda1) and np.isfinite(delta_lamda2), f"No valid solutions for delta lambda, a1 = {a1}, a2 = {a2}, a3 = {a3}"
delta_u_vec1 = delta_u_bar + delta_lamda1 * delta_u_t
delta_u_vec2 = delta_u_bar + delta_lamda2 * delta_u_t
Delta_u_vec_dir1 = u_vec + delta_u_vec1 - prev_u_vec
Delta_lamda_dir1 = lamda + delta_lamda1 - prev_lamda
dot_prod1 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir, Delta_u_vec_dir1, Delta_lamda_dir1)
Delta_u_vec_dir2 = u_vec + delta_u_vec2 - prev_u_vec
Delta_lamda_dir2 = lamda + delta_lamda2 - prev_lamda
dot_prod2 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir, Delta_u_vec_dir2, Delta_lamda_dir2)
if np.abs(dot_prod1) < 1e-10 and np.abs(dot_prod2) < 1e-10:
# At initial step, (Delta_u_vec_dir, Delta_lamda_dir) is zero, so both dot_prod1 and dot_prod2 are zero.
# We simply select the larger value for delta_lamda.
delta_lamda = np.maximum(delta_lamda1, delta_lamda2)
elif dot_prod1 > dot_prod2:
delta_lamda = delta_lamda1
else:
delta_lamda = delta_lamda2
lamda = lamda + delta_lamda
delta_u = delta_u_bar + delta_lamda * delta_u_t
u_vec = u_vec + delta_u
Delta_u_vec_dir = u_vec - prev_u_vec
Delta_lamda_dir = lamda - prev_lamda
logger.debug(f"Arc length solver: finished for one step, with Delta lambda = {lamda - prev_lamda}")
return u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir
def get_q_vec(problem):
"""
Used in the arc length method only, to get the external force vector q_vec
"""
dofs = np.zeros(problem.num_total_dofs_all_vars)
sol_list = problem.unflatten_fn_sol_list(dofs)
res_list = problem.newton_update(sol_list)
q_vec = jax.flatten_util.ravel_pytree(res_list)[0]
return q_vec
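# Hypothetical driver sketch for the force-driven variant: q_vec from get_q_vec is the external
# force vector and lamda is the load factor multiplying it; the stopping criterion (here lamda
# reaching 1.) is a user choice.
#   q_vec = get_q_vec(problem)
#   u_vec = np.zeros(problem.num_total_dofs_all_vars)
#   lamda, Delta_u_vec, Delta_lamda = 0., np.zeros_like(u_vec), 0.
#   while lamda < 1.:
#       u_vec, lamda, Delta_u_vec, Delta_lamda = arc_length_solver_force_driven(
#           problem, u_vec, lamda, Delta_u_vec, Delta_lamda, q_vec, Delta_l=0.1, psi=1.)
#   sol_list = problem.unflatten_fn_sol_list(u_vec)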
################################################################################
# Dynamic relaxation solver
def assembleCSR(problem, dofs):
sol_list = problem.unflatten_fn_sol_list(dofs)
problem.newton_update(sol_list)
A_sp_scipy = scipy.sparse.csr_array((problem.V, (problem.I, problem.J)),
shape=(problem.fes[0].num_total_dofs, problem.fes[0].num_total_dofs))
for ind, fe in enumerate(problem.fes):
for i in range(len(fe.node_inds_list)):
row_inds = onp.array(fe.node_inds_list[i] * fe.vec + fe.vec_inds_list[i] + problem.offset[ind], dtype=onp.int32)
for row_ind in row_inds:
A_sp_scipy.data[A_sp_scipy.indptr[row_ind]: A_sp_scipy.indptr[row_ind + 1]] = 0.
A_sp_scipy[row_ind, row_ind] = 1.
return A_sp_scipy
def calC(t, cmin, cmax):
if t < 0.: t = 0.
c = 2. * onp.sqrt(t)
if (c < cmin): c = cmin
if (c > cmax): c = cmax
return c
def printInfo(error, t, c, tol, eps, qdot, qdotdot, nIters, nPrint, info, info_force):
## printing control
if nIters % nPrint == 1:
#logger.info('\t------------------------------------')
if info_force == True:
print(('\nDR Iteration %d: Max force (residual error) = %g (tol = %g), ' +
'Max velocity = %g') % (nIters, error, tol,
np.max(np.absolute(qdot))))
if info == True:
print('\nDamping t: ', t)
print('Damping coefficient: ', c)
print('Max epsilon: ',np.max(eps))
print('Max acceleration: ',np.max(np.absolute(qdotdot)))
def dynamic_relax_solve(problem, tol=1e-6, nKMat=50, nPrint=500, info=True, info_force=True, initial_guess=None):
"""
Implementation of
Luet, David Joseph. Bounding volume hierarchy and non-uniform rational B-splines for contact enforcement
in large deformation finite element analysis of sheet metal forming. Diss. Princeton University, 2016.
Chapter 4.3 Nonlinear System Solution
Particularly good for handling buckling behavior.
There is a FEniCS version of this dynamic relaxation algorithm.
The code below is a direct translation from the FEniCS version.
TODO: Does not support periodic B.C.; needs some work here.
"""
solver_options = {'umfpack_solver': {}}
# TODO: consider these in initial guess
def newton_update_helper(dofs):
sol_list = problem.unflatten_fn_sol_list(dofs)
res_list = problem.newton_update(sol_list)
res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
res_vec = apply_bc_vec(res_vec, dofs, problem)
A = get_A(problem)
return res_vec, A
dofs = np.zeros(problem.num_total_dofs_all_vars)
res_vec, A = newton_update_helper(dofs)
dofs = linear_incremental_solver(problem, res_vec, A, dofs, solver_options)
if initial_guess is not None:
dofs = initial_guess
dofs = assign_bc(dofs, problem)
# parameters not to change
cmin = 1e-3
cmax = 3.9
h_tilde = 1.1
h = 1.
# initialize all arrays
N = len(dofs) #print("--------num of DOF's: %d-----------" % N)
#initialize displacements, velocities and accelerations
q, qdot, qdotdot = onp.zeros(N), onp.zeros(N), onp.zeros(N)
#initialize displacements, velocities and accelerations from a previous time step
q_old, qdot_old, qdotdot_old = onp.zeros(N), onp.zeros(N), onp.zeros(N)
#initialize the M, eps, R_old arrays
eps, M, R, R_old = onp.zeros(N), onp.zeros(N), onp.zeros(N), onp.zeros(N)
@jax.jit
def assembleVec(dofs):
res_fn = get_flatten_fn(problem.compute_residual, problem)
res_vec = res_fn(dofs)
res_vec = assign_zeros_bc(res_vec, problem)
return res_vec
R = onp.array(assembleVec(dofs))
KCSR = assembleCSR(problem, dofs)
M[:] = h_tilde * h_tilde / 4. * onp.array(
onp.absolute(KCSR).sum(axis=1)).squeeze()
q[:] = dofs
qdot[:] = -h / 2. * R / M
# set the counters for iterations and
nIters, iKMat = 0, 0
error = 1.0
timeZ = time.time() #Measurement of loop time.
assert onp.all(onp.isfinite(M)), f"M not finite"
assert onp.all(onp.isfinite(q)), f"q not finite"
assert onp.all(onp.isfinite(qdot)), f"qdot not finite"
error = onp.max(onp.absolute(R))
while error > tol:
print(f"error = {error}")
# marching forward
q_old[:] = q[:]; R_old[:] = R[:]
q[:] += h*qdot; dofs = np.array(q)
R = onp.array(assembleVec(dofs))
nIters += 1
iKMat += 1
error = onp.max(onp.absolute(R))
# damping calculation
S0 = onp.dot((R - R_old) / h, qdot)
t = S0 / onp.einsum('i,i,i', qdot, M, qdot)
c = calC(t, cmin, cmax)
# determine whether to recal KMat
eps = h_tilde * h_tilde / 4. * onp.absolute(
onp.divide((qdotdot - qdotdot_old), (q - q_old),
out=onp.zeros_like((qdotdot - qdotdot_old)),
where=(q - q_old) != 0))
# calculating the jacobian matrix
if ((onp.max(eps) > 1) and (iKMat > nKMat)): #SPR JAN max --> min
if info == True:
print('\nRecalculating the tangent matrix: ', nIters)
iKMat = 0
KCSR = assembleCSR(problem, dofs)
M[:] = h_tilde * h_tilde / 4. * onp.array(
onp.absolute(KCSR).sum(axis=1)).squeeze()
# compute new velocities and accelerations
qdot_old[:] = qdot[:]; qdotdot_old[:] = qdotdot[:];
qdot = (2.- c*h)/(2 + c*h) * qdot_old - 2.*h/(2.+c*h)* R / M
qdot_old[:] = qdot[:]
qdotdot = qdot - qdot_old
# output on screen
printInfo(error, t, c, tol, eps, qdot, qdotdot, nIters, nPrint, info, info_force)
# check if converged
convergence = True
if onp.isnan(onp.max(onp.absolute(R))):
convergence = False
# print final info
if convergence:
print("DRSolve finished in %d iterations and %fs" % \
(nIters, time.time() - timeZ))
else:
print("FAILED to converge")
sol_list = problem.unflatten_fn_sol_list(dofs)
return sol_list[0]
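# Illustrative usage sketch: unlike solver(), dynamic_relax_solve returns the first variable's
# solution array directly (sol_list[0]) rather than the full sol_list.
#   sol = dynamic_relax_solve(problem, tol=1e-6)
#   # Optional warm start from a previously flattened solution vector:
#   # sol = dynamic_relax_solve(problem, initial_guess=jax.flatten_util.ravel_pytree(prev_sol_list)[0])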
################################################################################
# Implicit differentiation with the adjoint method
def implicit_vjp(problem, sol_list, params, v_list, adjoint_solver_options):
def constraint_fn(dofs, params):
"""c(u, p)
"""
problem.set_params(params)
res_fn = problem.compute_residual
res_fn = get_flatten_fn(res_fn, problem)
res_fn = apply_bc(res_fn, problem)
return res_fn(dofs)
def constraint_fn_sol_to_sol(sol_list, params):
dofs = jax.flatten_util.ravel_pytree(sol_list)[0]
con_vec = constraint_fn(dofs, params)
return problem.unflatten_fn_sol_list(con_vec)
def get_partial_params_c_fn(sol_list):
"""c(u=u, p)
"""
def partial_params_c_fn(params):
return constraint_fn_sol_to_sol(sol_list, params)
return partial_params_c_fn
def get_vjp_constraint_fn_params(params, sol_list):
"""v*(partial dc/dp)
"""
partial_c_fn = get_partial_params_c_fn(sol_list)
def vjp_linear_fn(v_list):
primals_output, f_vjp = jax.vjp(partial_c_fn, params)
val, = f_vjp(v_list)
return val
return vjp_linear_fn
problem.set_params(params)
problem.newton_update(sol_list)
A = get_A(problem)
v_vec = jax.flatten_util.ravel_pytree(v_list)[0]
if hasattr(problem, 'P_mat'):
v_vec = problem.P_mat.T @ v_vec
# Be careful: A.transpose() changes A in place
adjoint_vec = linear_solver(A.transpose(), v_vec, None, adjoint_solver_options)
if hasattr(problem, 'P_mat'):
adjoint_vec = problem.P_mat @ adjoint_vec
vjp_linear_fn = get_vjp_constraint_fn_params(params, sol_list)
vjp_result = vjp_linear_fn(problem.unflatten_fn_sol_list(adjoint_vec))
vjp_result = jax.tree_util.tree_map(lambda x: -x, vjp_result)
return vjp_result
def ad_wrapper(problem, solver_options={}, adjoint_solver_options={}):
@jax.custom_vjp
def fwd_pred(params):
problem.set_params(params)
sol_list = solver(problem, solver_options)
return sol_list
def f_fwd(params):
sol_list = fwd_pred(params)
return sol_list, (params, sol_list)
def f_bwd(res, v):
logger.info("Running backward and solving the adjoint problem...")
params, sol_list = res
vjp_result = implicit_vjp(problem, sol_list, params, v, adjoint_solver_options)
return (vjp_result, )
fwd_pred.defvjp(f_fwd, f_bwd)
return fwd_pred
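# Illustrative sketch: differentiate a scalar objective of the solution with respect to params via
# the adjoint method. `objective` is a user-defined function of sol_list (hypothetical here), and
# problem.set_params must accept `params`.
#   fwd_pred = ad_wrapper(problem, {'petsc_solver': {}}, {'petsc_solver': {}})
#   def loss(params):
#       sol_list = fwd_pred(params)
#       return objective(sol_list)
#   grads = jax.grad(loss)(params)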