# Source code for jax_fem.mma

"""
Copied and modified from https://github.com/UW-ERSL/AuTO
Under GNU General Public License v3.0

Original copy from https://github.com/arjendeetman/GCMMA-MMA-Python/blob/master/Code/MMA.py

Modified to avoid O(N^2) memory operations (dense N x N diagonal matrices), so that the MMA solver scales to large numbers of design variables.
"""
from numpy import diag as diags
from numpy.linalg import solve
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, random, jacfwd, value_and_grad
from functools import partial
import time
import scipy
from jax.experimental.sparse import BCOO

from jax import config

# Enable 64-bit (double-precision) arrays in JAX for the whole process.
config.update("jax_enable_x64", True)

# Module-level switches read by optimize():
#   density_filtering     -- hand the filtered field H @ rho / Hs to the
#                            objective/constraint handles instead of rho.
#   sensitivity_filtering -- smooth the gradients with applySensitivityFilter.
density_filtering, sensitivity_filtering = False, True

def compute_filter_kd_tree(fe, num_nbs=20, rmin_factor=1.5):
    """Assemble the linear filter matrix for the design cells using a k-d tree.

    This function is created by Tianju. Not from the original code. We use the
    k-d tree algorithm to compute the filter.

    For every flexible (design) cell ``i``, its ``num_nbs`` nearest flexible
    cell centroids ``j`` are found and weighted by ``max(rmin - dist(i, j), 0)``,
    where ``rmin = rmin_factor * avg_elem_size`` and ``avg_elem_size`` is the
    ``dim``-th root of the average cell volume ``sum(fe.JxW) / fe.num_cells``.

    Args:
        fe: finite-element object providing ``points``, ``cells``,
            ``flex_inds``, ``JxW``, ``num_cells`` and ``dim``.
        num_nbs: maximum number of nearest neighbours considered per cell
            (default 20). Clamped to the number of flexible cells.
            NOTE(review): cells inside ``rmin`` but beyond the ``num_nbs``
            nearest ones get no weight.
        rmin_factor: filter radius in units of the average element size
            (default 1.5).

    Returns:
        H: ``BCOO`` matrix of shape (num_flex, num_flex) holding the weights.
        Hs: dense row sums of ``H``, shape (num_flex,).
    """
    # Explicit submodule imports: a bare ``import scipy`` does not load these
    # on older SciPy versions.
    import scipy.sparse
    import scipy.spatial

    cell_centroids = np.mean(np.take(fe.points, fe.cells, axis=0), axis=1)
    flex_num_cells = len(fe.flex_inds)
    flex_cell_centroids = np.take(cell_centroids, fe.flex_inds, axis=0)

    domain_volume = np.sum(fe.JxW)
    avg_elem_size = (domain_volume / fe.num_cells)**(1. / fe.dim)
    rmin = rmin_factor * avg_elem_size

    # KDTree.query pads missing neighbours with index == n (and distance inf)
    # when k exceeds the number of points; those indices would overflow the
    # sparse matrix, so never ask for more neighbours than there are cells.
    k = min(num_nbs, flex_num_cells)
    kd_tree = scipy.spatial.KDTree(flex_cell_centroids)
    dd, ii = kd_tree.query(flex_cell_centroids, k)
    # For k == 1 the query drops the neighbour axis; restore it.
    dd = np.reshape(dd, (flex_num_cells, k))
    ii = np.reshape(ii, (flex_num_cells, k))

    vals = np.maximum(rmin - dd, 0.)
    rows = np.repeat(np.arange(flex_num_cells), k)
    H_sp = scipy.sparse.csc_array((vals.ravel(), (rows, ii.ravel())),
                                  shape=(flex_num_cells, flex_num_cells))
    H = BCOO.from_scipy_sparse(H_sp).sort_indices()
    Hs = H.sum(1).todense()
    return H, Hs
def applySensitivityFilter(ft, rho, dJ, dvc):
    """Smooth objective and constraint sensitivities with the filter in ``ft``.

    ``ft`` holds ``'H'`` (filter weight matrix) and ``'Hs'`` (its row sums),
    as assembled by optimize() from compute_filter_kd_tree.

    Args:
        ft: filter dict with keys ``'H'`` and ``'Hs'``.
        rho: design densities; presumably an (n, 1) column (it is broadcast
            against ``Hs[:, None]``) -- TODO confirm against callers.
        dJ: objective gradient, same layout as ``rho``.
        dvc: constraint gradients with a leading constraint axis; presumably
            (m, n, 1) -- TODO confirm.

    Returns:
        ``(dJ_filtered, dvc_filtered)``.

    NOTE(review): the 1/Hs and 1/max(1e-3, rho) factors are applied to the
    neighbour terms inside the product. The classic Sigmund/top88 sensitivity
    filter applies them to the receiving element, outside the product.
    Confirm which one is intended.
    """
    H = ft['H']
    Hs = ft['Hs']

    # Floor the densities so that void cells do not cause division by ~0.
    rho_floor = np.maximum(1e-3, rho)
    weighted_dJ = rho * dJ / rho_floor / Hs[:, None]
    filtered_dJ = H @ weighted_dJ

    # Same operation, batched over the leading constraint axis.
    rho_batched = rho[None, :, :]
    weighted_dvc = rho_batched * dvc / np.maximum(1e-3, rho_batched) / Hs[None, :, None]
    filtered_dvc = H[None, :, :] @ weighted_dvc

    return filtered_dJ, filtered_dvc
def applyDensityFilter(ft, rho):
    """Return the density-filtered field ``H @ rho / Hs``.

    Args:
        ft: filter dict with keys ``'H'`` (weight matrix) and ``'Hs'`` (its
            row sums), as assembled by optimize().
        rho: design densities with the filtered axis first; ``Hs`` is
            broadcast along a second axis, so ``rho`` is expected to be 2-D.

    Returns:
        Weighted average of each cell's neighbourhood densities.
    """
    weights, row_sums = ft['H'], ft['Hs']
    neighbourhood_sum = weights @ rho
    return neighbourhood_sum / row_sums[:, None]
#%% Optimizer
class MMA:
    """State holder and subproblem builder for the Method of Moving Asymptotes.

    The code was modified from [MMA Svanberg 1987]. Please cite the paper if
    you end up using this code.

    Typical use, as in optimize():
    1. Configure the optimizer with the ``set*`` methods.
    2. Record the current and two previous iterates with ``registerMMAIter``.
    3. Call ``mmasub``.
    4. Read the new iterate back with ``getOptimalValues``.
    """

    def __init__(self):
        self.epoch = 0

    def resetMMACounter(self):
        """Reset the outer-iteration counter to zero."""
        self.epoch = 0

    def registerMMAIter(self, xval, xold1, xold2):
        """Increment the iteration counter and store the current/previous iterates."""
        self.epoch += 1
        self.xval = xval
        self.xold1 = xold1
        self.xold2 = xold2

    def setNumConstraints(self, numConstraints):
        """Set the number of constraints ``m``."""
        self.numConstraints = numConstraints

    def setNumDesignVariables(self, numDesVar):
        """Set the number of design variables ``n``."""
        self.numDesignVariables = numDesVar

    def setMinandMaxBoundsForDesignVariables(self, xmin, xmax):
        """Set the lower/upper bounds ``xmin``/``xmax`` of the design variables."""
        self.xmin = xmin
        self.xmax = xmax

    def setObjectiveWithGradient(self, obj, objGrad):
        """Store the objective value and its gradient at the current iterate."""
        self.objective = obj
        self.objectiveGradient = objGrad

    def setConstraintWithGradient(self, cons, consGrad):
        """Store the constraint values and their gradients at the current iterate."""
        self.constraint = cons
        self.consGrad = consGrad

    def setScalingParams(self, zconst, zscale, ylinscale, yquadscale):
        """Set the MMA constants a0 (``zconst``), a, c and d."""
        self.zconst = zconst
        self.zscale = zscale
        self.ylinscale = ylinscale
        self.yquadscale = yquadscale

    def setMoveLimit(self, movelim):
        """Set the move limit as a fraction of ``xmax - xmin``."""
        self.moveLimit = movelim

    def setLowerAndUpperAsymptotes(self, low, upp):
        """Set the lower/upper asymptotes used from the third iteration on."""
        self.lowAsymp = low
        self.upAsymp = upp

    def getOptimalValues(self):
        """Return ``(xmma, ymma, zmma)`` from the last ``mmasub`` call."""
        return self.xmma, self.ymma, self.zmma

    def getLagrangeMultipliers(self):
        """Return ``(lam, xsi, eta, mu, zet)`` from the last ``mmasub`` call."""
        return self.lam, self.xsi, self.eta, self.mu, self.zet

    def getSlackValue(self):
        """Return the constraint slacks from the last ``mmasub`` call."""
        return self.slack

    def getAsymptoteValues(self):
        """Return the current ``(low, upp)`` asymptotes."""
        return self.lowAsymp, self.upAsymp

    # Function for the MMA sub problem
    def mmasub(self, xval):
        """Build the MMA subproblem around ``xval`` and solve it with ``subsolv``.

        The method reads everything configured through the ``set*`` methods
        and ``registerMMAIter``. It stores:
        - the subproblem solution (see ``getOptimalValues``);
        - the multipliers (see ``getLagrangeMultipliers``);
        - the slack (see ``getSlackValue``);
        - the updated asymptotes (see ``getAsymptoteValues``).

        Args:
            xval: current design. Presumably an (n, 1) column, since
                optimize() passes that shape and all work vectors here are
                (n, 1).
        """
        m = self.numConstraints
        n = self.numDesignVariables
        outer_iter = self.epoch
        xmin, xmax = self.xmin, self.xmax
        xold1, xold2 = self.xold1, self.xold2
        df0dx = self.objectiveGradient
        fval, dfdx = self.constraint, self.consGrad
        low, upp = self.lowAsymp, self.upAsymp
        a0, a, c, d = self.zconst, self.zscale, self.ylinscale, self.yquadscale
        move = self.moveLimit

        epsimin = 0.0000001
        raa0 = 0.00001
        albefa = 0.1
        asyinit = 0.5
        asyincr = 1.2
        asydecr = 0.7
        eeen = np.ones((n, 1))

        # Calculation of the asymptotes low and upp
        if outer_iter <= 2:
            low = xval - asyinit * (xmax - xmin)
            upp = xval + asyinit * (xmax - xmin)
        else:
            # A sign change of the step (zzz < 0) pulls the asymptotes in;
            # consistent progress (zzz > 0) pushes them out.
            zzz = (xval - xold1) * (xold1 - xold2)
            factor = eeen.copy()
            factor[np.where(zzz > 0)] = asyincr
            factor[np.where(zzz < 0)] = asydecr
            low = xval - factor * (xold1 - low)
            upp = xval + factor * (upp - xold1)
            lowmin = xval - 10 * (xmax - xmin)
            lowmax = xval - 0.01 * (xmax - xmin)
            uppmin = xval + 0.01 * (xmax - xmin)
            uppmax = xval + 10 * (xmax - xmin)
            low = np.maximum(low, lowmin)
            low = np.minimum(low, lowmax)
            upp = np.minimum(upp, uppmax)
            upp = np.maximum(upp, uppmin)

        # Calculation of the bounds alfa and beta
        zzz1 = low + albefa * (xval - low)
        zzz2 = xval - move * (xmax - xmin)
        alfa = np.maximum(np.maximum(zzz1, zzz2), xmin)
        zzz1 = upp - albefa * (upp - xval)
        zzz2 = xval + move * (xmax - xmin)
        beta = np.minimum(np.minimum(zzz1, zzz2), xmax)

        # Calculations of p0, q0, P, Q and b
        xmami = np.maximum(xmax - xmin, 0.00001 * eeen)
        xmamiinv = eeen / xmami
        ux1 = upp - xval
        ux2 = ux1 * ux1
        xl1 = xval - low
        xl2 = xl1 * xl1
        uxinv = eeen / ux1
        xlinv = eeen / xl1

        p0 = np.maximum(df0dx, 0)
        q0 = np.maximum(-df0dx, 0)
        pq0 = 0.001 * (p0 + q0) + raa0 * xmamiinv
        p0 = (p0 + pq0) * ux2
        q0 = (q0 + pq0) * xl2

        P = np.maximum(dfdx, 0)
        Q = np.maximum(-dfdx, 0)
        # Row-broadcast xmamiinv.T instead of forming ones(m, 1) @ xmamiinv.T.
        PQ = 0.001 * (P + Q) + raa0 * xmamiinv.T
        # Scale column j by ux2_j / xl2_j via broadcasting; this avoids the
        # dense n-by-n diagonal matrices of the original implementation.
        P = ux2.T * (P + PQ)
        Q = xl2.T * (Q + PQ)
        b = np.dot(P, uxinv) + np.dot(Q, xlinv) - fval

        # Solving the subproblem by a primal-dual Newton method
        xmma, ymma, zmma, lam, xsi, eta, mu, zet, s = subsolv(
            m, n, epsimin, low, upp, alfa, beta, p0, q0, P, Q, a0, a, b, c, d)

        # Return values
        self.xmma, self.ymma, self.zmma = xmma, ymma, zmma
        self.lam, self.xsi, self.eta, self.mu, self.zet = lam, xsi, eta, mu, zet
        self.slack = s
        self.lowAsymp, self.upAsymp = low, upp
def subsolv(m, n, epsimin, low, upp, alfa, beta, p0, q0, P, Q, a0, a, b, c, d):
    """Solve the MMA subproblem with a primal-dual interior-point Newton method.

    The subproblem solved is::

        minimize    sum_j (p0_j / (upp_j - x_j) + q0_j / (x_j - low_j))
                    + a0 * z + sum_i (c_i * y_i + 0.5 * d_i * y_i**2)
        subject to  sum_j (P_ij / (upp_j - x_j) + Q_ij / (x_j - low_j))
                    - a_i * z - y_i <= b_i,
                    alfa_j <= x_j <= beta_j,   y_i >= 0,   z >= 0.

    The barrier parameter ``epsi`` starts at 1 and is divided by 10 while it
    exceeds ``epsimin``. For each value of ``epsi``, up to 200 Newton steps
    are taken, each with a backtracking line search on the norm of the
    perturbed KKT residual.

    Args:
        m, n: number of constraints and of design variables.
        epsimin: final barrier parameter (stopping tolerance).
        low, upp, alfa, beta: (n, 1) asymptotes and move-limited bounds.
        p0, q0: (n, 1) objective approximation coefficients.
        P, Q: (m, n) constraint approximation coefficients.
        a0: scalar weight of z. a, b, c, d: (m, 1) vectors.

    Returns:
        ``(x, y, z, lam, xsi, eta, mu, zet, s)``:
        - the primal solution ``x``, ``y``, ``z``;
        - the multipliers of the constraints (``lam``), of the lower and upper
          bounds on x (``xsi``, ``eta``), of y >= 0 (``mu``) and of z >= 0
          (``zet``);
        - the constraint slacks ``s``.
    """
    een = np.ones((n, 1))
    eem = np.ones((m, 1))

    def _kkt_residual(x, y, z, lam, xsi, eta, mu, zet, s, epsi):
        # Stacked residual of the epsi-perturbed KKT conditions, in the order
        # of Svanberg's rex, rey, rez, relam, rexsi, reeta, remu, rezet, res.
        ux1 = upp - x
        xl1 = x - low
        ux2 = ux1 * ux1
        xl2 = xl1 * xl1
        uxinv1 = een / ux1
        xlinv1 = een / xl1
        plam = p0 + np.dot(P.T, lam)
        qlam = q0 + np.dot(Q.T, lam)
        gvec = np.dot(P, uxinv1) + np.dot(Q, xlinv1)
        dpsidx = plam / ux2 - qlam / xl2
        return np.concatenate((
            dpsidx - xsi + eta,
            c + d * y - mu - lam,
            a0 - zet - np.dot(a.T, lam),
            gvec - a * z - y + s - b,
            xsi * (x - alfa) - epsi * een,
            eta * (beta - x) - epsi * een,
            mu * y - epsi * eem,
            zet * z - epsi,
            lam * s - epsi * eem,
        ), axis=0)

    # Interior starting point.
    epsi = 1
    x = 0.5 * (alfa + beta)
    y = eem.copy()
    z = np.array([[1.0]])
    lam = eem.copy()
    xsi = np.maximum(een / (x - alfa), een)
    eta = np.maximum(een / (beta - x), een)
    mu = np.maximum(eem, 0.5 * c)
    zet = np.array([[1.0]])
    s = eem.copy()

    while epsi > epsimin:
        epsvecn = epsi * een
        epsvecm = epsi * eem
        residu = _kkt_residual(x, y, z, lam, xsi, eta, mu, zet, s, epsi)
        residunorm = np.sqrt(np.dot(residu.T, residu).item())
        residumax = np.max(np.abs(residu))
        ittt = 0
        while residumax > 0.9 * epsi and ittt < 200:
            ittt += 1
            ux1 = upp - x
            xl1 = x - low
            ux2 = ux1 * ux1
            xl2 = xl1 * xl1
            ux3 = ux1 * ux2
            xl3 = xl1 * xl2
            uxinv1 = een / ux1
            xlinv1 = een / xl1
            uxinv2 = een / ux2
            xlinv2 = een / xl2
            plam = p0 + np.dot(P.T, lam)
            qlam = q0 + np.dot(Q.T, lam)
            gvec = np.dot(P, uxinv1) + np.dot(Q, xlinv1)
            # Row-broadcast instead of dense n-by-n diagonal matrices.
            GG = uxinv2.T * P - xlinv2.T * Q
            dpsidx = plam / ux2 - qlam / xl2
            delx = dpsidx - epsvecn / (x - alfa) + epsvecn / (beta - x)
            dely = c + d * y - lam - epsvecm / y
            delz = a0 - np.dot(a.T, lam) - epsi / z
            dellam = gvec - a * z - y - b + epsvecm / lam
            diagx = plam / ux3 + qlam / xl3
            diagx = 2 * diagx + xsi / (x - alfa) + eta / (beta - x)
            diagxinv = een / diagx
            diagy = d + mu / y
            diagyinv = eem / diagy
            diaglam = s / lam
            diaglamyi = diaglam + diagyinv

            if m < n:
                # Eliminate x and solve the (m+1) x (m+1) system in (lam, z).
                blam = dellam + dely / diagy - np.dot(GG, (delx / diagx))
                bb = np.concatenate((blam, delz), axis=0)
                Alam = diags(diaglamyi.flatten(), 0) + (diagxinv.T * GG).dot(GG.T)
                AAr1 = np.concatenate((Alam, a), axis=1)
                AAr2 = np.concatenate((a, -zet / z), axis=0).T
                AA = np.concatenate((AAr1, AAr2), axis=0)
                solut = solve(AA, bb)
                dlam = solut[0:m]
                dz = solut[m:m + 1]
                dx = -delx / diagx - np.dot(GG.T, dlam) / diagx
            else:
                # Eliminate lam and solve the (n+1) x (n+1) system in (x, z).
                diaglamyiinv = eem / diaglamyi
                dellamyi = dellam + dely / diagy
                Axx = np.asarray(diags(diagx.flatten(), 0)
                                 + (diags(diaglamyiinv.flatten(), 0).dot(GG).T).dot(GG))
                azz = zet / z + np.dot(a.T, (a / diaglamyi))
                axz = np.dot(-GG.T, (a / diaglamyi))
                bx = delx + np.dot(GG.T, (dellamyi / diaglamyi))
                bz = delz - np.dot(a.T, (dellamyi / diaglamyi))
                AAr1 = np.concatenate((Axx, axz), axis=1)
                AAr2 = np.concatenate((axz.T, azz), axis=1)
                AA = np.concatenate((AAr1, AAr2), axis=0)
                bb = np.concatenate((-bx, -bz), axis=0)
                solut = solve(AA, bb)
                dx = solut[0:n]
                dz = solut[n:n + 1]
                dlam = (np.dot(GG, dx) / diaglamyi - dz * (a / diaglamyi)
                        + dellamyi / diaglamyi)

            dy = -dely / diagy + dlam / diagy
            dxsi = -xsi + epsvecn / (x - alfa) - (xsi * dx) / (x - alfa)
            deta = -eta + epsvecn / (beta - x) + (eta * dx) / (beta - x)
            dmu = -mu + epsvecm / y - (mu * dy) / y
            dzet = -zet + epsi / z - zet * dz / z
            ds = -s + epsvecm / lam - (s * dlam) / lam

            # Largest step (capped at 1) that keeps all positive variables
            # and alfa < x < beta strictly feasible, with a 1% safety margin.
            xx = np.concatenate((y, z, lam, xsi, eta, mu, zet, s), axis=0)
            dxx = np.concatenate((dy, dz, dlam, dxsi, deta, dmu, dzet, ds), axis=0)
            stmxx = np.max(-1.01 * dxx / xx)
            stmalfa = np.max(-1.01 * dx / (x - alfa))
            stmbeta = np.max(1.01 * dx / (beta - x))
            stminv = max(max(max(stmalfa, stmbeta), stmxx), 1.0)
            steg = 1.0 / stminv

            # Backtracking line search on the residual norm. All updates
            # below rebind names, so the old iterates need no copies.
            xold, yold, zold, lamold = x, y, z, lam
            xsiold, etaold, muold, zetold, sold = xsi, eta, mu, zet, s
            itto = 0
            resinew = 2 * residunorm
            while resinew > residunorm and itto < 50:
                itto += 1
                x = xold + steg * dx
                y = yold + steg * dy
                z = zold + steg * dz
                lam = lamold + steg * dlam
                xsi = xsiold + steg * dxsi
                eta = etaold + steg * deta
                mu = muold + steg * dmu
                zet = zetold + steg * dzet
                s = sold + steg * ds
                residu = _kkt_residual(x, y, z, lam, xsi, eta, mu, zet, s, epsi)
                resinew = np.sqrt(np.dot(residu.T, residu).item())
                steg = steg / 2
            residunorm = resinew
            # np.max, not builtin max: the builtin iterated the (n+...)-row
            # array element by element in Python on every Newton step.
            residumax = np.max(np.abs(residu))
        epsi = 0.1 * epsi

    return x, y, z, lam, xsi, eta, mu, zet, s
def optimize(fe, rho_ini, optimizationParams, objectiveHandle, consHandle, numConstraints):
    """Run MMA-based topology optimization on the flexible cells of ``fe``.

    Each iteration does the following:
    - evaluates the objective and constraints with their gradients;
    - applies the filters selected by the module flags
      ``density_filtering`` / ``sensitivity_filtering``;
    - takes one MMA step.

    Args:
        fe: finite-element object passed to compute_filter_kd_tree.
        rho_ini: initial design variables. They are flattened to an (n, 1)
            column for MMA and reshaped back to ``rho_ini.shape``.
        optimizationParams: dict with ``'maxIters'`` (number of MMA
            iterations) and ``'movelimit'`` (MMA move limit).
        objectiveHandle: callable ``rho_physical -> (J, dJ)``.
        consHandle: callable ``(rho_physical, loop) -> (vc, dvc)`` returning
            ``numConstraints`` constraint values and their gradients.
        numConstraints: number of constraints ``m``.

    Returns:
        The final (unfiltered) design variables, shaped like ``rho_ini``.
    """
    # TODO: Scale objective function value to be always within 1-100
    # See comments in https://doi.org/10.1016/j.compstruc.2018.01.008
    H, Hs = compute_filter_kd_tree(fe)
    ft = {'H': H, 'Hs': Hs}
    rho = rho_ini
    loop = 0
    m = numConstraints  # num constraints
    n = len(rho.reshape(-1))  # num params

    mma = MMA()
    mma.setNumConstraints(numConstraints)
    mma.setNumDesignVariables(n)
    mma.setMinandMaxBoundsForDesignVariables(np.zeros((n, 1)), np.ones((n, 1)))
    xval = rho.reshape(-1)[:, None]
    xold1, xold2 = xval.copy(), xval.copy()
    mma.registerMMAIter(xval, xold1, xold2)
    # Placeholder asymptotes: mmasub recomputes them from scratch on its first
    # two iterations (epoch <= 2).
    mma.setLowerAndUpperAsymptotes(np.ones((n, 1)), np.ones((n, 1)))
    mma.setScalingParams(1.0, np.zeros((m, 1)), 10000 * np.ones((m, 1)), np.zeros((m, 1)))
    # Move limit is an important parameter that affects TO result; default can be 0.2
    mma.setMoveLimit(optimizationParams['movelimit'])

    while loop < optimizationParams['maxIters']:
        loop = loop + 1
        print("MMA solver...")
        if density_filtering:
            rho_physical = applyDensityFilter(ft, rho)
        else:
            rho_physical = rho
        J, dJ = objectiveHandle(rho_physical)
        vc, dvc = consHandle(rho_physical, loop)

        if density_filtering:
            # The handles differentiate w.r.t. rho_physical = H @ rho / Hs,
            # but MMA updates rho, so apply the chain rule H^T @ (g / Hs).
            Ht = ft['H'].T
            dJ = (Ht @ (dJ.reshape(-1) / ft['Hs'])).reshape(dJ.shape)
            dvc_rows = dvc.reshape(dvc.shape[0], -1) / ft['Hs'][None, :]
            dvc = (Ht @ dvc_rows.T).T.reshape(dvc.shape)
        if sensitivity_filtering:
            dJ, dvc = applySensitivityFilter(ft, rho, dJ, dvc)

        dJ = dJ.reshape(-1)[:, None]
        vc, dvc = vc[:, None], dvc.reshape(dvc.shape[0], -1)
        print(f"J.shape = {J.shape}")
        print(f"dJ.shape = {dJ.shape}")
        print(f"vc.shape = {vc.shape}")
        print(f"dvc.shape = {dvc.shape}")
        J, dJ, vc, dvc = np.array(J), np.array(dJ), np.array(vc), np.array(dvc)

        start = time.perf_counter()
        mma.setObjectiveWithGradient(J, dJ)
        mma.setConstraintWithGradient(vc, dvc)
        mma.mmasub(xval)
        xmma, _, _ = mma.getOptimalValues()
        xold2 = xold1.copy()
        xold1 = xval.copy()
        xval = xmma.copy()
        mma.registerMMAIter(xval, xold1, xold2)
        rho = xval.reshape(rho.shape)
        time_elapsed = time.perf_counter() - start
        print(f"MMA took {time_elapsed} [s]")
        print(f'Iter {loop:d}; J {J:.5f}; constraint {vc}\n\n\n')
    return rho