Source code for jax_fem.solver

import jax
import jax.numpy as np
import jax.flatten_util
import numpy as onp
from jax.experimental.sparse import BCOO
import scipy
import time
from petsc4py import PETSc

from jax_fem import logger

from jax import config
# Enable 64-bit (double precision) arrays in JAX. This is a process-wide setting
# applied as a side effect of importing this module.
config.update("jax_enable_x64", True)


################################################################################
# JAX solver or scipy solver or PETSc solver

def jax_solve(A, b, x0, precond):
    """Solve the linear system ``A x = b`` with JAX's BiCGSTAB.

    Parameters
    ----------
    A
        Assembled PETSc matrix; its CSR arrays are read with ``getValuesCSR``.
    b
        Right-hand side vector.
    x0
        Initial guess passed to BiCGSTAB (may be ``None``).
    precond
        If truthy, use a Jacobi (diagonal) preconditioner; otherwise BiCGSTAB
        runs without a preconditioner.

    Returns
    -------
    x
        The solution vector.

    Raises
    ------
    AssertionError
        If the final residual norm ``||A x - b||`` is not below 0.1.
    """
    logger.debug(f"JAX Solver - Solving linear system")
    indptr, indices, data = A.getValuesCSR()
    A_sp_scipy = scipy.sparse.csr_array((data, indices, indptr), shape=A.getSize())
    A = BCOO.from_scipy_sparse(A_sp_scipy).sort_indices()
    jacobi = np.array(A_sp_scipy.diagonal())
    # The lambda must be parenthesized: otherwise the conditional expression
    # becomes the lambda's body, so ``pc`` is always a function that returns
    # None when ``precond`` is False (which breaks bicgstab).
    pc = (lambda x: x * (1. / jacobi)) if precond else None
    x, info = jax.scipy.sparse.linalg.bicgstab(A, b, x0=x0, M=pc, tol=1e-10, atol=1e-10, maxiter=10000)

    # Verify convergence
    err = np.linalg.norm(A @ x - b)
    logger.debug(f"JAX Solver - Finshed solving, res = {err}")
    assert err < 0.1, f"JAX linear solver failed to converge with err = {err}"
    # NOTE(original author): kept "for assert purpose"; reportedly this line
    # also somehow affects bicgstab, so it is left in place.
    x = np.where(err < 0.1, x, np.nan)
    return x
def umfpack_solve(A, b):
    """Solve ``A x = b`` directly with ``scipy.sparse.linalg.spsolve``.

    SciPy uses UMFPACK only when scikit-umfpack is installed; otherwise
    ``spsolve`` falls back to SuperLU.

    Parameters
    ----------
    A
        Assembled PETSc matrix; its CSR arrays are read with ``getValuesCSR``.
    b
        Right-hand side vector.

    Returns
    -------
    x
        The solution as a NumPy array.
    """
    logger.debug(f"Scipy Solver - Solving linear system with UMFPACK")
    indptr, indices, data = A.getValuesCSR()
    # Pass the shape explicitly (as jax_solve does). Otherwise SciPy infers the
    # column count from the largest stored column index, which is wrong when
    # trailing columns have no stored entries.
    Asp = scipy.sparse.csr_matrix((data, indices, indptr), shape=A.getSize())
    x = scipy.sparse.linalg.spsolve(Asp, onp.array(b))

    # TODO: try https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.linalg.spsolve.html
    # x = jax.experimental.sparse.linalg.spsolve(av, aj, ai, b)

    logger.debug(f'Scipy Solver - Finished solving, linear solve res = {np.linalg.norm(Asp @ x - b)}')
    return x
def petsc_solve(A, b, ksp_type, pc_type):
    """Solve ``A x = b`` with a PETSc Krylov solver.

    Parameters
    ----------
    A
        Assembled PETSc matrix.
    b
        Right-hand side vector (anything ``numpy.array`` accepts).
    ksp_type, pc_type
        PETSc KSP and PC type names, e.g. 'bcgsl' / 'ilu'. For 'tfqmr' the
        factor solver type is additionally set to MUMPS.

    Returns
    -------
    x
        The solution as the array view of the PETSc solution vector.

    Raises
    ------
    AssertionError
        If the final residual norm ``||A x - b||`` is not below 0.1.
    """
    rhs = PETSc.Vec().createSeq(len(b))
    rhs.setValues(range(len(b)), onp.array(b))
    # PETSc requires VecAssemblyBegin/End after VecSetValues before the vector
    # is used; Vec.assemble() performs both.
    rhs.assemble()
    ksp = PETSc.KSP().create()
    ksp.setOperators(A)
    ksp.setFromOptions()
    ksp.setType(ksp_type)
    ksp.pc.setType(pc_type)

    # TODO: This works better. Do we need to generalize the code a little bit?
    # NOTE(review): the factor solver type only matters for factorization PCs
    # (e.g. 'lu'); confirm the intended pc_type for 'tfqmr'.
    if ksp_type == 'tfqmr':
        ksp.pc.setFactorSolverType('mumps')

    logger.debug(f'PETSc Solver - Solving linear system with ksp_type = {ksp.getType()}, pc = {ksp.pc.getType()}')
    x = PETSc.Vec().createSeq(len(b))
    ksp.solve(rhs, x)

    # Verify convergence
    y = PETSc.Vec().createSeq(len(b))
    A.mult(x, y)

    err = np.linalg.norm(y.getArray() - rhs.getArray())
    logger.debug(f"PETSc Solver - Finished solving, linear solve res = {err}")
    assert err < 0.1, f"PETSc linear solver failed to converge, err = {err}"

    return x.getArray()
def linear_solver(A, b, x0, solver_options):
    """Solve ``A x = b`` with the backend selected by ``solver_options``.

    The backend is the first of these keys present in ``solver_options``:
    'jax_solver', 'umfpack_solver', 'petsc_solver', 'custom_solver'. If none
    is present, 'jax_solver' with default settings is used.
    ``solver_options`` itself is never modified.

    Parameters
    ----------
    A
        System matrix (a PETSc matrix for the built-in backends).
    b
        Right-hand side vector.
    x0
        Initial guess; used only by 'jax_solver' and 'custom_solver'.
    solver_options : dict
        'jax_solver': {'precond': bool (default True)}
        'umfpack_solver': {}
        'petsc_solver': {'ksp_type': str (default 'bcgsl'), 'pc_type': str (default 'ilu')}
        'custom_solver': callable ``(A, b, x0, solver_options) -> x``

    Returns
    -------
    x
        The solution vector.
    """
    solver_names = {'jax_solver', 'umfpack_solver', 'petsc_solver', 'custom_solver'}
    if not (solver_options.keys() & solver_names):
        # Default to the JAX solver. Build a new dict rather than inserting the
        # key into the caller's options (which may be a shared default dict).
        solver_options = {**solver_options, 'jax_solver': {}}

    if 'jax_solver' in solver_options:
        jax_options = solver_options['jax_solver']
        precond = jax_options['precond'] if 'precond' in jax_options else True
        x = jax_solve(A, b, x0, precond)
    elif 'umfpack_solver' in solver_options:
        x = umfpack_solve(A, b)
    elif 'petsc_solver' in solver_options:
        petsc_options = solver_options['petsc_solver']
        ksp_type = petsc_options['ksp_type'] if 'ksp_type' in petsc_options else 'bcgsl'
        pc_type = petsc_options['pc_type'] if 'pc_type' in petsc_options else 'ilu'
        x = petsc_solve(A, b, ksp_type, pc_type)
    elif 'custom_solver' in solver_options:
        # Users can define their own solver
        custom_solver = solver_options['custom_solver']
        x = custom_solver(A, b, x0, solver_options)
    else:
        raise NotImplementedError(f"Unknown linear solver.")

    return x
################################################################################ # "row elimination" solver
def apply_bc_vec(res_vec, dofs, problem, scale=1.):
    """Overwrite the Dirichlet entries of a flat residual vector.

    At each Dirichlet location the residual becomes
    ``dofs[loc] - scale * prescribed_value``; all other entries are unchanged.

    Parameters
    ----------
    res_vec
        Flat residual vector.
    dofs
        Flat DOF vector at which the residual was evaluated.
    problem
        Object providing ``unflatten_fn_sol_list`` and ``fes`` (each with
        ``node_inds_list``, ``vec_inds_list`` and ``vals_list``).
    scale
        Factor applied to the prescribed Dirichlet values.

    Returns
    -------
    The modified residual, flattened again.
    """
    res_list = problem.unflatten_fn_sol_list(res_vec)
    sol_list = problem.unflatten_fn_sol_list(dofs)

    for fe_idx, fe in enumerate(problem.fes):
        fe_res = res_list[fe_idx]
        fe_sol = sol_list[fe_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            vec_inds = fe.vec_inds_list[bc_idx]
            # First copy the current DOF value, then subtract the scaled target.
            fe_res = fe_res.at[node_inds, vec_inds].set(fe_sol[node_inds, vec_inds], unique_indices=True)
            fe_res = fe_res.at[node_inds, vec_inds].add(-fe.vals_list[bc_idx] * scale)
        res_list[fe_idx] = fe_res

    flat_res, _ = jax.flatten_util.ravel_pytree(res_list)
    return flat_res
def apply_bc(res_fn, problem, scale=1.):
    """Wrap a flat residual function so that its output has Dirichlet entries replaced.

    The returned function evaluates ``res_fn(dofs)`` and passes the result
    through ``apply_bc_vec`` with the given ``scale``.
    """
    def res_fn_bc(dofs):
        """Apply Dirichlet boundary conditions"""
        return apply_bc_vec(res_fn(dofs), dofs, problem, scale)

    return res_fn_bc
def assign_bc(dofs, problem):
    """Return a copy of the flat DOF vector with every Dirichlet entry set to its prescribed value.

    ``problem`` must provide ``unflatten_fn_sol_list`` and ``fes`` (each with
    ``node_inds_list``, ``vec_inds_list`` and ``vals_list``).
    """
    sol_list = problem.unflatten_fn_sol_list(dofs)
    for fe_idx, fe in enumerate(problem.fes):
        fe_sol = sol_list[fe_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            fe_sol = fe_sol.at[node_inds, fe.vec_inds_list[bc_idx]].set(fe.vals_list[bc_idx])
        sol_list[fe_idx] = fe_sol

    flat, _ = jax.flatten_util.ravel_pytree(sol_list)
    return flat
def assign_ones_bc(dofs, problem):
    """Return a copy of the flat DOF vector with every Dirichlet entry set to 1."""
    sol_list = problem.unflatten_fn_sol_list(dofs)
    for fe_idx, fe in enumerate(problem.fes):
        fe_sol = sol_list[fe_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            fe_sol = fe_sol.at[node_inds, fe.vec_inds_list[bc_idx]].set(1.)
        sol_list[fe_idx] = fe_sol

    flat, _ = jax.flatten_util.ravel_pytree(sol_list)
    return flat
def assign_zeros_bc(dofs, problem):
    """Return a copy of the flat DOF vector with every Dirichlet entry set to 0."""
    sol_list = problem.unflatten_fn_sol_list(dofs)
    for fe_idx, fe in enumerate(problem.fes):
        fe_sol = sol_list[fe_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            fe_sol = fe_sol.at[node_inds, fe.vec_inds_list[bc_idx]].set(0.)
        sol_list[fe_idx] = fe_sol

    flat, _ = jax.flatten_util.ravel_pytree(sol_list)
    return flat
def copy_bc(dofs, problem):
    """Return a flat vector that equals ``dofs`` at Dirichlet entries and is zero elsewhere."""
    src_list = problem.unflatten_fn_sol_list(dofs)
    out_list = problem.unflatten_fn_sol_list(np.zeros_like(dofs))

    for fe_idx, fe in enumerate(problem.fes):
        src = src_list[fe_idx]
        out = out_list[fe_idx]
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            vec_inds = fe.vec_inds_list[bc_idx]
            out = out.at[node_inds, vec_inds].set(src[node_inds, vec_inds])
        out_list[fe_idx] = out

    flat, _ = jax.flatten_util.ravel_pytree(out_list)
    return flat
def get_flatten_fn(fn_sol_list, problem):
    """Lift a function on solution lists to a function on flat DOF vectors.

    The returned function unflattens its input with
    ``problem.unflatten_fn_sol_list``, applies ``fn_sol_list``, and flattens
    the resulting pytree again.
    """
    def fn_dofs(dofs):
        values = fn_sol_list(problem.unflatten_fn_sol_list(dofs))
        flat, _ = jax.flatten_util.ravel_pytree(values)
        return flat

    return fn_dofs
def operator_to_matrix(operator_fn, problem):
    """Only used for when debugging.

    Returns the dense Jacobian of ``operator_fn`` evaluated at the all-zero DOF
    vector (forward-mode AD). This can be used to print the matrix, check the
    condition number, etc.
    """
    zero_dofs = np.zeros(problem.num_total_dofs_all_vars)
    jacobian_fn = jax.jacfwd(operator_fn)
    return jacobian_fn(zero_dofs)
def linear_incremental_solver(problem, res_vec, A, dofs, solver_options):
    """Linear solve at each Newton iteration: solve ``A inc = -res_vec`` and update ``dofs``.

    Parameters
    ----------
    problem
        The problem object (may carry a ``P_mat`` for linear multipoint constraints).
    res_vec
        Current (BC-applied) residual vector.
    A
        Current system matrix.
    dofs
        Current DOF vector (reduced DOFs when ``problem`` has ``P_mat``).
    solver_options : dict
        Passed to ``linear_solver``; ``'line_search_flag'`` (default False)
        selects a line-search update instead of a full step.

    Returns
    -------
    The updated DOF vector.
    """
    logger.debug(f"Solving linear system...")
    rhs = -res_vec

    # Initial guess for the increment. Without P_mat it equals
    # (prescribed value - current value) at Dirichlet DOFs and zero elsewhere,
    # so x0 is already correct at boundary locations.
    bc_target = assign_bc(np.zeros(problem.num_total_dofs_all_vars), problem)
    if hasattr(problem, 'P_mat'):
        bc_current = copy_bc(problem.P_mat @ dofs, problem)
        x0 = problem.P_mat.T @ (bc_target - bc_current)
    else:
        bc_current = copy_bc(dofs, problem)
        x0 = bc_target - bc_current

    inc = linear_solver(A, rhs, x0, solver_options)

    use_line_search = solver_options['line_search_flag'] if 'line_search_flag' in solver_options else False
    if use_line_search:
        # NOTE(review): `line_search` is not defined in the visible part of this
        # module; confirm it is provided elsewhere before enabling this flag.
        return line_search(problem, dofs, inc)
    return dofs + inc
def get_A(problem):
    """Build the global PETSc AIJ matrix from the triplets ``(problem.V, problem.I, problem.J)``.

    Duplicate (I, J) entries are summed by the SciPy COO-to-CSR conversion.
    Every Dirichlet row is then eliminated with ``Mat.zeroRows``, which uses its
    default diagonal value of 1. If ``problem`` has a ``P_mat`` (linear
    multipoint constraints), the reduced matrix ``P^T A P`` is returned instead.
    """
    logger.debug(f"Creating sparse matrix with scipy...")
    n_dofs = problem.num_total_dofs_all_vars
    A_sp_scipy = scipy.sparse.csr_array((onp.array(problem.V), (problem.I, problem.J)),
                                        shape=(n_dofs, n_dofs))
    # logger.info(f"Global sparse matrix takes about {A_sp_scipy.data.shape[0]*8*3/2**30} G memory to store.")

    csr_arrays = (A_sp_scipy.indptr.astype(PETSc.IntType, copy=False),
                  A_sp_scipy.indices.astype(PETSc.IntType, copy=False),
                  A_sp_scipy.data)
    A = PETSc.Mat().createAIJ(size=A_sp_scipy.shape, csr=csr_arrays)

    # Row elimination for Dirichlet DOFs, one zeroRows call per BC entry.
    for fe_idx, fe in enumerate(problem.fes):
        for bc_idx, node_inds in enumerate(fe.node_inds_list):
            global_rows = onp.array(node_inds * fe.vec + fe.vec_inds_list[bc_idx] + problem.offset[fe_idx],
                                    dtype=onp.int32)
            A.zeroRows(global_rows)

    # Linear multipoint constraints
    if not hasattr(problem, 'P_mat'):
        return A

    P_mat = problem.P_mat
    P = PETSc.Mat().createAIJ(size=P_mat.shape,
                              csr=(P_mat.indptr.astype(PETSc.IntType, copy=False),
                                   P_mat.indices.astype(PETSc.IntType, copy=False),
                                   P_mat.data))
    AP = A.matMult(P)
    # petsc4py transposes in place when no output matrix is given; P is not
    # needed afterwards.
    P_T = P.transpose()
    return P_T.matMult(AP)
################################################################################ # The "row elimination" solver
def solver(problem, solver_options=None):
    """Solve the nonlinear problem with Newton's method, imposing Dirichlet B.C. by "row elimination".

    Specify one of 'jax_solver', 'umfpack_solver', 'petsc_solver' or 'custom_solver'.

    Examples:
    (1) solver_options = {'jax_solver': {}}
    (2) solver_options = {'umfpack_solver': {}}
    (3) solver_options = {'petsc_solver': {'ksp_type': 'bcgsl', 'pc_type': 'jacobi'}, 'initial_guess': some_guess}

    Default parameters will be used if no instruction is found:

    solver_options = {
        # If no solver is specified, 'jax_solver' is used. If several are specified, the first one
        # found in the order 'jax_solver', 'umfpack_solver', 'petsc_solver', 'custom_solver' is used.
        'jax_solver': {
            # The JAX built-in linear solver
            # Reference: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.bicgstab.html
            'precond': True,
        }

        'umfpack_solver': {
            # The scipy solver that calls UMFPACK
            # Reference: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.spsolve.html
        }

        'petsc_solver': {
            # PETSc solver
            # For more ksp_type and pc_type: https://www.mcs.anl.gov/petsc/petsc4py-current/docs/apiref/index.html
            'ksp_type': 'bcgsl', # e.g., 'minres', 'gmres', 'tfqmr'
            'pc_type': 'ilu', # e.g., 'jacobi'
        }

        'line_search_flag': False, # Line search method
        'initial_guess': initial_guess, # Same shape as sol_list
        'tol': 1e-6, # Absolute tolerance for residual vector (l2 norm), used in Newton's method
        'rel_tol': 1e-8, # Relative tolerance for residual vector (l2 norm), used in Newton's method
    }

    The caller's ``solver_options`` dict is not modified.

    The solver imposes Dirichlet B.C. with "row elimination" method.

    Some memo:

    res(u) = D*r(u) + (I - D)u - u_b
    D = [[1 0 0 0]
         [0 1 0 0]
         [0 0 0 0]
         [0 0 0 1]]
    I = [[1 0 0 0]
         [0 1 0 0]
         [0 0 1 0]
         [0 0 0 1]]
    A = d(res)/d(u) = D*dr/du + (I - D)

    TODO: linear multipoint constraint

    The function newton_update computes r(u) and dr/du

    Returns
    -------
    sol_list
        The solution unflattened with ``problem.unflatten_fn_sol_list``.
    """
    # Copy so that neither the default nor the caller's dict can be mutated
    # downstream.
    solver_options = {} if solver_options is None else dict(solver_options)

    logger.debug(f"Calling the row elimination solver for imposing Dirichlet B.C.")
    logger.debug("Start timing")
    start = time.time()

    if 'initial_guess' in solver_options:
        # We don't want the initial guess to play a role in the differentiation chain.
        initial_guess = jax.lax.stop_gradient(solver_options['initial_guess'])
        dofs = jax.flatten_util.ravel_pytree(initial_guess)[0]
    elif hasattr(problem, 'P_mat'):
        dofs = np.zeros(problem.P_mat.shape[1])  # reduced dofs
    else:
        dofs = np.zeros(problem.num_total_dofs_all_vars)

    rel_tol = solver_options['rel_tol'] if 'rel_tol' in solver_options else 1e-8
    tol = solver_options['tol'] if 'tol' in solver_options else 1e-6

    def newton_update_helper(dofs):
        # Residual (with Dirichlet rows applied) and system matrix at `dofs`;
        # with P_mat, `dofs` are reduced and are expanded / projected around the update.
        if hasattr(problem, 'P_mat'):
            dofs = problem.P_mat @ dofs

        sol_list = problem.unflatten_fn_sol_list(dofs)
        res_list = problem.newton_update(sol_list)
        res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
        res_vec = apply_bc_vec(res_vec, dofs, problem)

        if hasattr(problem, 'P_mat'):
            res_vec = problem.P_mat.T @ res_vec

        A = get_A(problem)
        return res_vec, A

    res_vec, A = newton_update_helper(dofs)
    res_val = np.linalg.norm(res_vec)
    res_val_initial = res_val
    rel_res_val = res_val/res_val_initial
    logger.debug(f"Before, l_2 res = {res_val}, relative l_2 res = {rel_res_val}")

    while (rel_res_val > rel_tol) and (res_val > tol):
        dofs = linear_incremental_solver(problem, res_vec, A, dofs, solver_options)
        res_vec, A = newton_update_helper(dofs)
        # logger.debug(f"DEBUG: l_2 res = {np.linalg.norm(apply_bc_vec(A @ dofs, dofs, problem))}")
        res_val = np.linalg.norm(res_vec)
        rel_res_val = res_val/res_val_initial
        logger.debug(f"l_2 res = {res_val}, relative l_2 res = {rel_res_val}")

    # A NaN residual also terminates the loop above (NaN comparisons are False).
    assert np.all(np.isfinite(res_val)), f"res_val contains NaN, stop the program!"
    assert np.all(np.isfinite(dofs)), f"dofs contains NaN, stop the program!"

    if hasattr(problem, 'P_mat'):
        dofs = problem.P_mat @ dofs

    # If sol_list = [[[u1x, u1y],
    #                 [u2x, u2y],
    #                 [u3x, u3y],
    #                 [u4x, u4y]],
    #                [[p1],
    #                 [p2]]],
    # the flattened DOF vector will be [u1x, u1y, u2x, u2y, u3x, u3y, u4x, u4y, p1, p2]
    sol_list = problem.unflatten_fn_sol_list(dofs)

    end = time.time()
    solve_time = end - start
    logger.info(f"Solve took {solve_time} [s]")
    logger.info(f"max of dofs = {np.max(dofs)}")
    logger.info(f"min of dofs = {np.min(dofs)}")

    return sol_list
################################################################################
# The "arc length" solver
# Reference: Vasios, Nikolaos. "Nonlinear analysis of structures." The Arc-Length method. Harvard (2015).
# Our implementation follows Crisfield's formulation.
# TODO: Do we want to merge displacement-control and force-control codes?
def arc_length_solver_disp_driven(problem, prev_u_vec, prev_lamda, prev_Delta_u_vec, prev_Delta_lamda, Delta_l=0.1, psi=1.):
    """Advance one arc-length step in which the Dirichlet values are scaled by the load factor ``lamda``.

    The residual is evaluated with ``apply_bc_vec(..., scale=lamda)``, so at
    Dirichlet DOFs it equals ``u - lamda * prescribed_value``.

    Parameters
    ----------
    problem
        Problem object (``newton_update``, ``unflatten_fn_sol_list``, ``fes``).
    prev_u_vec, prev_lamda
        Flat DOF vector and load factor at the start of the step.
    prev_Delta_u_vec, prev_Delta_lamda
        Increments of the previous step; used to choose between the two roots
        for the load-factor correction (zero at the very first step).
    Delta_l
        Arc length of the step.
    psi
        Weight of the load term in the arc-length constraint.

    Returns
    -------
    (u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir)
        New state and the increments relative to the start of the step.

    TODO: Does not support periodic B.C., need some work here.
    """
    def newton_update_helper(dofs):
        # Reads `lamda` from the enclosing scope at call time.
        sol_list = problem.unflatten_fn_sol_list(dofs)
        res_list = problem.newton_update(sol_list)
        res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
        res_vec = apply_bc_vec(res_vec, dofs, problem, lamda)
        A = get_A(problem)
        return res_vec, A

    def u_lamda_dot_product(Delta_u_vec1, Delta_lamda1, Delta_u_vec2, Delta_lamda2):
        # Inner product on (u, lambda) increments used by the arc-length constraint.
        return np.sum(Delta_u_vec1*Delta_u_vec2) + psi**2.*Delta_lamda1*Delta_lamda2*np.sum(u_b**2.)

    u_vec = prev_u_vec
    lamda = prev_lamda
    # Dirichlet values at BC DOFs, zero elsewhere (the "load" direction here).
    u_b = assign_bc(np.zeros_like(prev_u_vec), problem)

    Delta_u_vec_dir = prev_Delta_u_vec
    Delta_lamda_dir = prev_Delta_lamda

    tol = 1e-6
    res_val = 1.
    # NOTE: the correction below is applied even in the iteration whose
    # residual already satisfies tol; the loop exits on the next check.
    while res_val > tol:

        res_vec, A = newton_update_helper(u_vec)
        res_val = np.linalg.norm(res_vec)
        logger.debug(f"Arc length solver: res_val = {res_val}")

        # Residual correction and response to a unit change of lambda.
        delta_u_bar = umfpack_solve(A, -res_vec)
        delta_u_t = umfpack_solve(A, u_b)

        # Coefficients of the quadratic a1*dl^2 + a2*dl + a3 = 0 obtained from
        # ||Delta_u + delta_u_bar + dl*delta_u_t||^2 + psi^2*(Delta_lamda + dl)^2*||u_b||^2 = Delta_l^2.
        Delta_u_vec = u_vec - prev_u_vec
        Delta_lamda = lamda - prev_lamda
        a1 = np.sum(delta_u_t**2.) + psi**2.*np.sum(u_b**2.)
        a2 = 2.* np.sum((Delta_u_vec + delta_u_bar)*delta_u_t) + 2.*psi**2.*Delta_lamda*np.sum(u_b**2.)
        a3 = np.sum((Delta_u_vec + delta_u_bar)**2.) + psi**2.*Delta_lamda**2.*np.sum(u_b**2.) - Delta_l**2.

        # A negative discriminant yields NaN roots, caught by the assert below.
        delta_lamda1 = (-a2 + np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1)
        delta_lamda2 = (-a2 - np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1)

        logger.debug(f"Arc length solver: delta_lamda1 = {delta_lamda1}, delta_lamda2 = {delta_lamda2}")
        assert np.isfinite(delta_lamda1) and np.isfinite(delta_lamda2), f"No valid solutions for delta lambda, a1 = {a1}, a2 = {a2}, a3 = {a3}"

        delta_u_vec1 = delta_u_bar + delta_lamda1 * delta_u_t
        delta_u_vec2 = delta_u_bar + delta_lamda2 * delta_u_t

        # Pick the root whose resulting step direction best aligns with the current direction.
        Delta_u_vec_dir1 = u_vec + delta_u_vec1 - prev_u_vec
        Delta_lamda_dir1 = lamda + delta_lamda1 - prev_lamda
        dot_prod1 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir, Delta_u_vec_dir1, Delta_lamda_dir1)

        Delta_u_vec_dir2 = u_vec + delta_u_vec2 - prev_u_vec
        Delta_lamda_dir2 = lamda + delta_lamda2 - prev_lamda
        dot_prod2 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir, Delta_u_vec_dir2, Delta_lamda_dir2)

        if np.abs(dot_prod1) < 1e-10 and np.abs(dot_prod2) < 1e-10:
            # At initial step, (Delta_u_vec_dir, Delta_lamda_dir) is zero, so both dot_prod1 and dot_prod2 are zero.
            # We simply select the larger value for delta_lamda.
            delta_lamda = np.maximum(delta_lamda1, delta_lamda2)
        elif dot_prod1 > dot_prod2:
            delta_lamda = delta_lamda1
        else:
            delta_lamda = delta_lamda2

        lamda = lamda + delta_lamda
        delta_u = delta_u_bar + delta_lamda * delta_u_t
        u_vec = u_vec + delta_u

        Delta_u_vec_dir = u_vec - prev_u_vec
        Delta_lamda_dir = lamda - prev_lamda

    logger.debug(f"Arc length solver: finished for one step, with Delta lambda = {lamda - prev_lamda}")

    return u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir
def arc_length_solver_force_driven(problem, prev_u_vec, prev_lamda, prev_Delta_u_vec, prev_Delta_lamda, q_vec, Delta_l=0.1, psi=1.):
    """Advance one arc-length step for a force-driven problem with load ``lamda * q_vec``.

    The total residual is ``res(u) + lamda * q_vec_mapped``, where
    ``q_vec_mapped`` is ``q_vec`` with its Dirichlet entries set to zero.

    Parameters
    ----------
    problem
        Problem object (``newton_update``, ``unflatten_fn_sol_list``, ``fes``).
    prev_u_vec, prev_lamda
        Flat DOF vector and load factor at the start of the step.
    prev_Delta_u_vec, prev_Delta_lamda
        Increments of the previous step; used to choose between the two roots
        for the load-factor correction (zero at the very first step).
    q_vec
        Flat load vector (e.g. from ``get_q_vec``).
    Delta_l
        Arc length of the step.
    psi
        Weight of the load term in the arc-length constraint.

    Returns
    -------
    (u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir)
        New state and the increments relative to the start of the step.

    TODO: Does not support periodic B.C., need some work here.
    """
    def newton_update_helper(dofs):
        sol_list = problem.unflatten_fn_sol_list(dofs)
        res_list = problem.newton_update(sol_list)
        res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
        res_vec = apply_bc_vec(res_vec, dofs, problem)
        A = get_A(problem)
        return res_vec, A

    def u_lamda_dot_product(Delta_u_vec1, Delta_lamda1, Delta_u_vec2, Delta_lamda2):
        # Inner product on (u, lambda) increments used by the arc-length constraint.
        return np.sum(Delta_u_vec1*Delta_u_vec2) + psi**2.*Delta_lamda1*Delta_lamda2*np.sum(q_vec_mapped**2.)

    u_vec = prev_u_vec
    lamda = prev_lamda
    # Load vector with Dirichlet entries zeroed so the load does not act on BC rows.
    q_vec_mapped = assign_zeros_bc(q_vec, problem)

    Delta_u_vec_dir = prev_Delta_u_vec
    Delta_lamda_dir = prev_Delta_lamda

    tol = 1e-6
    res_val = 1.
    # NOTE: the correction below is applied even in the iteration whose
    # residual already satisfies tol; the loop exits on the next check.
    while res_val > tol:

        res_vec, A = newton_update_helper(u_vec)
        res_val = np.linalg.norm(res_vec + lamda*q_vec_mapped)
        logger.debug(f"Arc length solver: res_val = {res_val}")

        # TODO: the scipy umfpack solver seems to be far better than the jax linear solver, so we use umfpack solver here.
        # NOTE(review): the commented calls below use an older jax_solve signature
        # (current one is jax_solve(A, b, x0, precond)).
        # x0_1 = assign_bc(np.zeros_like(u_vec), problem)
        # x0_2 = copy_bc(u_vec, problem)
        # delta_u_bar = jax_solve(problem, A, -(res_vec + lamda*q_vec_mapped), x0=x0_1 - x0_2, precond=True)
        # delta_u_t = jax_solve(problem, A, -q_vec_mapped, x0=np.zeros_like(u_vec), precond=True)

        # Residual correction and response to a unit change of lambda.
        delta_u_bar = umfpack_solve(A, -(res_vec + lamda*q_vec_mapped))
        delta_u_t = umfpack_solve(A, -q_vec_mapped)

        # Coefficients of the quadratic a1*dl^2 + a2*dl + a3 = 0 obtained from
        # ||Delta_u + delta_u_bar + dl*delta_u_t||^2 + psi^2*(Delta_lamda + dl)^2*||q||^2 = Delta_l^2.
        Delta_u_vec = u_vec - prev_u_vec
        Delta_lamda = lamda - prev_lamda
        a1 = np.sum(delta_u_t**2.) + psi**2.*np.sum(q_vec_mapped**2.)
        a2 = 2.* np.sum((Delta_u_vec + delta_u_bar)*delta_u_t) + 2.*psi**2.*Delta_lamda*np.sum(q_vec_mapped**2.)
        a3 = np.sum((Delta_u_vec + delta_u_bar)**2.) + psi**2.*Delta_lamda**2.*np.sum(q_vec_mapped**2.) - Delta_l**2.

        # A negative discriminant yields NaN roots, caught by the assert below.
        delta_lamda1 = (-a2 + np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1)
        delta_lamda2 = (-a2 - np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1)

        logger.debug(f"Arc length solver: delta_lamda1 = {delta_lamda1}, delta_lamda2 = {delta_lamda2}")
        assert np.isfinite(delta_lamda1) and np.isfinite(delta_lamda2), f"No valid solutions for delta lambda, a1 = {a1}, a2 = {a2}, a3 = {a3}"

        delta_u_vec1 = delta_u_bar + delta_lamda1 * delta_u_t
        delta_u_vec2 = delta_u_bar + delta_lamda2 * delta_u_t

        # Pick the root whose resulting step direction best aligns with the current direction.
        Delta_u_vec_dir1 = u_vec + delta_u_vec1 - prev_u_vec
        Delta_lamda_dir1 = lamda + delta_lamda1 - prev_lamda
        dot_prod1 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir, Delta_u_vec_dir1, Delta_lamda_dir1)

        Delta_u_vec_dir2 = u_vec + delta_u_vec2 - prev_u_vec
        Delta_lamda_dir2 = lamda + delta_lamda2 - prev_lamda
        dot_prod2 = u_lamda_dot_product(Delta_u_vec_dir, Delta_lamda_dir, Delta_u_vec_dir2, Delta_lamda_dir2)

        if np.abs(dot_prod1) < 1e-10 and np.abs(dot_prod2) < 1e-10:
            # At initial step, (Delta_u_vec_dir, Delta_lamda_dir) is zero, so both dot_prod1 and dot_prod2 are zero.
            # We simply select the larger value for delta_lamda.
            delta_lamda = np.maximum(delta_lamda1, delta_lamda2)
        elif dot_prod1 > dot_prod2:
            delta_lamda = delta_lamda1
        else:
            delta_lamda = delta_lamda2

        lamda = lamda + delta_lamda
        delta_u = delta_u_bar + delta_lamda * delta_u_t
        u_vec = u_vec + delta_u

        Delta_u_vec_dir = u_vec - prev_u_vec
        Delta_lamda_dir = lamda - prev_lamda

    logger.debug(f"Arc length solver: finished for one step, with Delta lambda = {lamda - prev_lamda}")

    return u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir
def get_q_vec(problem):
    """Used in the arc length method only, to get the external force vector q_vec.

    Returns the flattened output of ``problem.newton_update`` evaluated at the
    all-zero DOF vector.
    """
    zero_sol_list = problem.unflatten_fn_sol_list(np.zeros(problem.num_total_dofs_all_vars))
    q_vec, _ = jax.flatten_util.ravel_pytree(problem.newton_update(zero_sol_list))
    return q_vec
################################################################################ # Dynamic relaxation solver
def assembleCSR(problem, dofs):
    """Assemble the tangent matrix at ``dofs`` as a SciPy CSR array with Dirichlet rows eliminated.

    Calls ``problem.newton_update`` (presumably refreshing ``problem.V``) and
    builds the matrix from the triplets ``(problem.V, problem.I, problem.J)``;
    duplicate entries are summed. Each Dirichlet row is then zeroed and its
    diagonal set to 1, mirroring the ``zeroRows`` elimination in ``get_A``.

    Returns
    -------
    scipy.sparse.csr_array
        Square matrix of size ``problem.num_total_dofs_all_vars``.
    """
    sol_list = problem.unflatten_fn_sol_list(dofs)
    problem.newton_update(sol_list)
    # Size by the DOF count over *all* variables (as get_A does). Using only
    # fes[0].num_total_dofs is out of bounds for multi-variable problems.
    num_dofs = problem.num_total_dofs_all_vars
    A_sp_scipy = scipy.sparse.csr_array((onp.array(problem.V), (problem.I, problem.J)),
                                        shape=(num_dofs, num_dofs))

    for ind, fe in enumerate(problem.fes):
        for node_inds, vec_inds in zip(fe.node_inds_list, fe.vec_inds_list):
            row_inds = onp.array(node_inds * fe.vec + vec_inds + problem.offset[ind], dtype=onp.int32)
            for row_ind in row_inds:
                # Zero the stored entries of this row, then put 1 on the diagonal.
                A_sp_scipy.data[A_sp_scipy.indptr[row_ind]:A_sp_scipy.indptr[row_ind + 1]] = 0.
                A_sp_scipy[row_ind, row_ind] = 1.

    return A_sp_scipy
def calC(t, cmin, cmax):
    """Return the dynamic-relaxation damping coefficient.

    ``c = 2 * sqrt(max(t, 0))``, first raised to ``cmin`` if below it and then
    lowered to ``cmax`` if above it.
    """
    t = max(t, 0.)
    c = 2. * onp.sqrt(t)
    return min(max(c, cmin), cmax)
def printInfo(error, t, c, tol, eps, qdot, qdotdot, nIters, nPrint, info, info_force):
    """Print dynamic-relaxation progress on iterations where ``nIters % nPrint == 1``.

    ``info_force`` prints the residual/velocity summary; ``info`` prints the
    damping and acceleration details.
    """
    ## printing control
    if nIters % nPrint != 1:
        return

    #logger.info('\t------------------------------------')
    if info_force == True:
        template = ('\nDR Iteration %d: Max force (residual error) = %g (tol = %g)' +
                    'Max velocity = %g')
        print(template % (nIters, error, tol, np.max(np.absolute(qdot))))
    if info == True:
        print('\nDamping t: ', t, )
        print('Damping coefficient: ', c)
        print('Max epsilon: ', np.max(eps))
        print('Max acceleration: ', np.max(np.absolute(qdotdot)))
def dynamic_relax_solve(problem, tol=1e-6, nKMat=50, nPrint=500, info=True, info_force=True, initial_guess=None):
    """Solve the nonlinear problem with the dynamic relaxation (DR) method.

    Implementation of

    Luet, David Joseph. Bounding volume hierarchy and non-uniform rational B-splines for contact enforcement in
    large deformation finite element analysis of sheet metal forming. Diss. Princeton University, 2016.
    Chapter 4.3 Nonlinear System Solution

    Particularly good for handling buckling behavior.
    There is a FEniCS version of this dynamic relaxation algorithm.
    The code below is a direct translation from the FEniCS version.

    Parameters
    ----------
    problem
        Problem object (``compute_residual``, ``newton_update``, ``unflatten_fn_sol_list``, ``fes``).
    tol : float
        Iterate until the max absolute entry of the (BC-zeroed) residual is <= tol.
    nKMat : int
        The tangent matrix (and the fictitious mass M) is recomputed only after
        more than nKMat iterations since the last recomputation, and only when
        max(eps) > 1.
    nPrint : int
        Progress is printed when ``nIters % nPrint == 1``.
    info, info_force : bool
        Verbosity flags (see printInfo).
    initial_guess
        Optional flat DOF vector used as the starting point (its Dirichlet
        entries are overwritten with the prescribed values).

    Returns
    -------
    The first entry of the unflattened solution list.

    TODO: Does not support periodic B.C., need some work here.
    """
    solver_options = {'umfpack_solver': {}}

    # TODO: consider these in initial guess
    def newton_update_helper(dofs):
        sol_list = problem.unflatten_fn_sol_list(dofs)
        res_list = problem.newton_update(sol_list)
        res_vec = jax.flatten_util.ravel_pytree(res_list)[0]
        res_vec = apply_bc_vec(res_vec, dofs, problem)
        A = get_A(problem)
        return res_vec, A

    # Default starting point: one linear Newton step from zero.
    dofs = np.zeros(problem.num_total_dofs_all_vars)
    res_vec, A = newton_update_helper(dofs)
    dofs = linear_incremental_solver(problem, res_vec, A, dofs, solver_options)

    if initial_guess is not None:
        dofs = initial_guess
        dofs = assign_bc(dofs, problem)

    # parameters not to change
    cmin = 1e-3
    cmax = 3.9
    h_tilde = 1.1
    h = 1.

    # initialize all arrays
    N = len(dofs)
    # displacements, velocities and accelerations
    q, qdot, qdotdot = onp.zeros(N), onp.zeros(N), onp.zeros(N)
    # displacements, velocities and accelerations from the previous time step
    q_old, qdot_old, qdotdot_old = onp.zeros(N), onp.zeros(N), onp.zeros(N)
    # the M, eps, R_old arrays
    eps, M, R, R_old = onp.zeros(N), onp.zeros(N), onp.zeros(N), onp.zeros(N)

    @jax.jit
    def assembleVec(dofs):
        # Residual with Dirichlet entries zeroed, so BC DOFs receive no force.
        res_fn = get_flatten_fn(problem.compute_residual, problem)
        res_vec = res_fn(dofs)
        res_vec = assign_zeros_bc(res_vec, problem)
        return res_vec

    R = onp.array(assembleVec(dofs))
    KCSR = assembleCSR(problem, dofs)

    # Fictitious diagonal mass from the absolute row sums of the tangent matrix.
    M[:] = h_tilde * h_tilde / 4. * onp.array(onp.absolute(KCSR).sum(axis=1)).squeeze()
    q[:] = dofs
    qdot[:] = -h / 2. * R / M
    # set the counters for iterations and tangent-matrix recomputation
    nIters, iKMat = 0, 0
    error = 1.0
    timeZ = time.time()  # Measurement of loop time.

    assert onp.all(onp.isfinite(M)), f"M not finite"
    assert onp.all(onp.isfinite(q)), f"q not finite"
    assert onp.all(onp.isfinite(qdot)), f"qdot not finite"

    error = onp.max(onp.absolute(R))

    while error > tol:
        print(f"error = {error}")
        # marching forward
        q_old[:] = q[:]
        R_old[:] = R[:]
        q[:] += h * qdot
        dofs = np.array(q)

        R = onp.array(assembleVec(dofs))

        nIters += 1
        iKMat += 1
        error = onp.max(onp.absolute(R))

        # damping calculation
        S0 = onp.dot((R - R_old) / h, qdot)
        t = S0 / onp.einsum('i,i,i', qdot, M, qdot)
        c = calC(t, cmin, cmax)

        # determine whether to recalculate KMat
        eps = h_tilde * h_tilde / 4. * onp.absolute(
            onp.divide((qdotdot - qdotdot_old), (q - q_old),
                       out=onp.zeros_like((qdotdot - qdotdot_old)),
                       where=(q - q_old) != 0))

        # calculating the jacobian matrix
        if (onp.max(eps) > 1) and (iKMat > nKMat):  # SPR JAN max --> min
            if info:
                print('\nRecalculating the tangent matrix: ', nIters)
            iKMat = 0
            KCSR = assembleCSR(problem, dofs)
            M[:] = h_tilde * h_tilde / 4. * onp.array(onp.absolute(KCSR).sum(axis=1)).squeeze()

        # compute new velocities and accelerations
        qdot_old[:] = qdot[:]
        qdotdot_old[:] = qdotdot[:]
        qdot = (2. - c * h) / (2. + c * h) * qdot_old - 2. * h / (2. + c * h) * R / M
        # Acceleration is the velocity change over this step (h = 1). The
        # previous code copied the new qdot into qdot_old before taking this
        # difference. That made qdotdot (and hence eps) identically zero, so the
        # tangent matrix was never recomputed.
        qdotdot = qdot - qdot_old

        # output on screen
        printInfo(error, t, c, tol, eps, qdot, qdotdot, nIters, nPrint, info, info_force)

    # check if converged (a NaN residual also ends the loop above, since NaN > tol is False)
    convergence = not onp.isnan(onp.max(onp.absolute(R)))

    # print final info
    if convergence:
        print("DRSolve finished in %d iterations and %fs" % (nIters, time.time() - timeZ))
    else:
        print("FAILED to converged")

    sol_list = problem.unflatten_fn_sol_list(dofs)

    return sol_list[0]
################################################################################ # Implicit differentiation with the adjoint method
def implicit_vjp(problem, sol_list, params, v_list, adjoint_solver_options):
    """Compute the vector-Jacobian product of the solution with respect to ``params`` via the adjoint method.

    With constraint ``c(u, p) = 0`` (the BC-applied residual), it solves
    ``A^T lambda = v`` for the adjoint vector and returns
    ``-lambda^T (partial c / partial p)``.

    Parameters
    ----------
    problem
        Problem object (``set_params``, ``newton_update``, ``compute_residual``, ...).
    sol_list
        Forward solution.
    params
        Parameters at which the forward problem was solved.
    v_list
        Cotangent with the same structure as ``sol_list``.
    adjoint_solver_options : dict
        Options for ``linear_solver`` for the adjoint solve (not modified).

    Returns
    -------
    Cotangent with the same pytree structure as ``params``.
    """
    def constraint_fn(dofs, params):
        """c(u, p)
        """
        problem.set_params(params)
        res_fn = problem.compute_residual
        res_fn = get_flatten_fn(res_fn, problem)
        res_fn = apply_bc(res_fn, problem)
        return res_fn(dofs)

    def constraint_fn_sol_to_sol(sol_list, params):
        dofs = jax.flatten_util.ravel_pytree(sol_list)[0]
        con_vec = constraint_fn(dofs, params)
        return problem.unflatten_fn_sol_list(con_vec)

    def get_partial_params_c_fn(sol_list):
        """c(u=u, p)
        """
        def partial_params_c_fn(params):
            return constraint_fn_sol_to_sol(sol_list, params)

        return partial_params_c_fn

    def get_vjp_constraint_fn_params(params, sol_list):
        """v*(partial dc/dp)
        """
        partial_c_fn = get_partial_params_c_fn(sol_list)

        def vjp_linear_fn(v_list):
            _, f_vjp = jax.vjp(partial_c_fn, params)
            val, = f_vjp(v_list)
            return val

        return vjp_linear_fn

    problem.set_params(params)
    problem.newton_update(sol_list)

    A = get_A(problem)
    v_vec = jax.flatten_util.ravel_pytree(v_list)[0]

    if hasattr(problem, 'P_mat'):
        v_vec = problem.P_mat.T @ v_vec

    # Be careful that A.transpose() does in-place change to A.
    # Pass a copy of the options so the caller's dict cannot be mutated downstream.
    adjoint_vec = linear_solver(A.transpose(), v_vec, None, dict(adjoint_solver_options))

    if hasattr(problem, 'P_mat'):
        adjoint_vec = problem.P_mat @ adjoint_vec

    vjp_linear_fn = get_vjp_constraint_fn_params(params, sol_list)
    vjp_result = vjp_linear_fn(problem.unflatten_fn_sol_list(adjoint_vec))
    # jax.tree_map is deprecated/removed in recent JAX; tree_util.tree_map is
    # the stable equivalent.
    vjp_result = jax.tree_util.tree_map(lambda x: -x, vjp_result)
    return vjp_result
def ad_wrapper(problem, solver_options=None, adjoint_solver_options=None):
    """Return a function ``params -> sol_list`` whose reverse-mode derivative uses the adjoint method.

    Parameters
    ----------
    problem
        Problem object; ``problem.set_params(params)`` is called before each forward solve.
    solver_options : dict, optional
        Options for the forward nonlinear ``solver`` (default: empty, i.e. solver defaults).
    adjoint_solver_options : dict, optional
        Options for the linear adjoint solve (default: empty, i.e. ``linear_solver`` defaults).

    Returns
    -------
    A ``jax.custom_vjp`` function of ``params``.
    """
    # None defaults instead of shared mutable {} defaults.
    if solver_options is None:
        solver_options = {}
    if adjoint_solver_options is None:
        adjoint_solver_options = {}

    @jax.custom_vjp
    def fwd_pred(params):
        problem.set_params(params)
        # Pass a copy so downstream code cannot mutate the options held by this closure.
        sol_list = solver(problem, dict(solver_options))
        return sol_list

    def f_fwd(params):
        sol_list = fwd_pred(params)
        return sol_list, (params, sol_list)

    def f_bwd(res, v):
        logger.info("Running backward and solving the adjoint problem...")
        params, sol_list = res
        vjp_result = implicit_vjp(problem, sol_list, params, v, dict(adjoint_solver_options))
        return (vjp_result, )

    fwd_pred.defvjp(f_fwd, f_bwd)
    return fwd_pred