import os
import gmsh
import numpy as onp
import meshio
from jax_fem.basis import get_elements
from jax_fem.basis import get_face_shape_vals_and_grads
import jax
import jax.numpy as np
class Mesh():
    """
    A custom mesh manager might be better using a third-party library like
    meshio.
    """

    def __init__(self, points, cells, ele_type='TET4'):
        # TODO (Very important for debugging purpose!): Assert that cells must have correct orders
        self.points = points      # node coordinates, indexed by the entries of `cells`
        self.cells = cells        # per-cell node indices into `points`
        self.ele_type = ele_type  # element type name, e.g. 'TET4'

    def count_selected_faces(self, location_fn):
        """Count the cell faces whose nodes all satisfy ``location_fn``.

        Useful for setting up distributed load conditions. A face is
        selected only when ``location_fn`` returns True for every one of
        its nodes. Faces are enumerated per cell, so a face shared by two
        cells would be counted once for each cell if it were selected.

        Parameters
        ----------
        location_fn : Callable
            A function that takes a single point (coordinate vector) and
            returns a boolean describing whether the point lies on the
            region of interest. It is traced by ``jax.vmap``, so it must be
            written with JAX-compatible operations.

        Returns
        -------
        face_count : int
            Number of (cell, local face) pairs that satisfy ``location_fn``.
        """
        _, _, _, _, face_inds = get_face_shape_vals_and_grads(self.ele_type)
        # (num_cells, num_nodes_per_cell, dim)
        cell_points = onp.take(self.points, self.cells, axis=0)
        # Assumes face_inds is (num_faces_per_cell, num_nodes_per_face), giving
        # (num_cells, num_faces_per_cell, num_nodes_per_face, dim) -- TODO confirm.
        cell_face_points = onp.take(cell_points, face_inds, axis=1)

        vmap_location_fn = jax.vmap(location_fn)

        def on_boundary(face_points):
            # Runs under jax.vmap, so reduce with jax.numpy rather than numpy.
            return np.all(vmap_location_fn(face_points))

        # Outer vmap over cells, inner vmap over the faces of each cell.
        vvmap_on_boundary = jax.vmap(jax.vmap(on_boundary))
        boundary_flags = vvmap_on_boundary(cell_face_points)
        return int(onp.count_nonzero(boundary_flags))
def check_mesh_TET4(points, cells):
    """Return the scalar triple product of every TET4 cell.

    For a cell with vertices (p1, p2, p3, p4), the value is
    ``((p2 - p1) x (p3 - p1)) . (p4 - p1)``. This is six times the signed
    volume of the tetrahedron. Its sign flips when the vertex ordering is
    mirrored, and it is close to zero for degenerate cells.

    Parameters
    ----------
    points : array
        Node coordinates (presumably ``(num_points, 3)``; the cross
        product needs 3-vectors).
    cells : array
        Node indices with exactly four entries per cell.

    Returns
    -------
    jax array of shape ``(num_cells,)``.
    """
    # TODO: this only returns the raw values; it does not flag or fix
    # inverted/degenerate elements.
    def _signed_triple_product(tet_vertices):
        origin = tet_vertices[0]
        edge_a = tet_vertices[1] - origin
        edge_b = tet_vertices[2] - origin
        edge_c = tet_vertices[3] - origin
        return np.dot(np.cross(edge_a, edge_b), edge_c)

    tet_coords = points[cells]  # (num_cells, 4, dim)
    return jax.vmap(_signed_triple_product)(tet_coords)
# jax_fem element type name -> meshio cell type name.
_MESHIO_CELL_TYPES = {
    'TET4': 'tetra',
    'TET10': 'tetra10',
    'HEX8': 'hexahedron',
    'HEX27': 'hexahedron27',
    'HEX20': 'hexahedron20',
    'TRI3': 'triangle',
    'TRI6': 'triangle6',
    'QUAD4': 'quad',
    'QUAD8': 'quad8',
}


def get_meshio_cell_type(ele_type):
    """Map a jax_fem element type name to the corresponding meshio cell type.

    Reference:
    https://github.com/nschloe/meshio/blob/9dc6b0b05c9606cad73ef11b8b7785dd9b9ea325/src/meshio/xdmf/common.py#L36

    Parameters
    ----------
    ele_type : str
        One of 'TET4', 'TET10', 'HEX8', 'HEX27', 'HEX20', 'TRI3', 'TRI6',
        'QUAD4', 'QUAD8'.

    Returns
    -------
    cell_type : str
        The meshio cell type name, e.g. 'tetra' for 'TET4'.

    Raises
    ------
    NotImplementedError
        If ``ele_type`` has no mapping. The message names the offending value.
    """
    try:
        return _MESHIO_CELL_TYPES[ele_type]
    except (KeyError, TypeError):  # TypeError: unhashable input is also unsupported
        raise NotImplementedError(
            f"Element type {ele_type!r} has no meshio cell type mapping; "
            f"supported: {sorted(_MESHIO_CELL_TYPES)}") from None
def rectangle_mesh(Nx, Ny, domain_x, domain_y):
    """Build a structured QUAD4 mesh of the rectangle [0, domain_x] x [0, domain_y].

    Generated by our own code (no gmsh). There are Nx cells along x and
    Ny cells along y. Nodes are numbered with the y index varying fastest
    (``meshgrid(..., indexing='ij')``). Each cell lists its corners as
    (i, j), (i+1, j), (i+1, j+1), (i, j+1), i.e. counter-clockwise in the
    x-y plane starting from the lower-left corner.

    Returns
    -------
    meshio.Mesh with a single 'quad' cell block.
    """
    xs = onp.linspace(0, domain_x, Nx + 1)
    ys = onp.linspace(0, domain_y, Ny + 1)
    grid_x, grid_y = onp.meshgrid(xs, ys, indexing='ij')
    points = onp.stack((grid_x, grid_y), axis=-1).reshape(-1, 2)

    node_ids = onp.arange(points.shape[0]).reshape(Nx + 1, Ny + 1)
    lo, hi = slice(None, -1), slice(1, None)
    corner_slices = [(lo, lo), (hi, lo), (hi, hi), (lo, hi)]
    cells = onp.stack([node_ids[s] for s in corner_slices], axis=-1).reshape(-1, 4)

    return meshio.Mesh(points=points, cells={'quad': cells})
def box_mesh(Nx, Ny, Nz, domain_x, domain_y, domain_z):
    """Build a structured HEX8 mesh of the box [0, domain_x] x [0, domain_y] x [0, domain_z].

    Generated by our own code (no gmsh). There are Nx, Ny, Nz cells along
    x, y, z. Nodes are numbered with z varying fastest, then y, then x
    (``meshgrid(..., indexing='ij')``). Each hexahedron lists the four
    corners of its lower-z face, (i, j), (i+1, j), (i+1, j+1), (i, j+1),
    followed by the same four corners on its upper-z face.

    Returns
    -------
    meshio.Mesh with a single 'hexahedron' cell block.
    """
    axes = ((domain_x, Nx), (domain_y, Ny), (domain_z, Nz))
    coords_1d = [onp.linspace(0, length, n + 1) for length, n in axes]
    grids = onp.meshgrid(*coords_1d, indexing='ij')
    points = onp.stack(grids, axis=-1).reshape(-1, 3)

    node_ids = onp.arange(points.shape[0]).reshape(Nx + 1, Ny + 1, Nz + 1)
    lo, hi = slice(None, -1), slice(1, None)
    corner_slices = [
        (lo, lo, lo), (hi, lo, lo), (hi, hi, lo), (lo, hi, lo),  # lower-z face
        (lo, lo, hi), (hi, lo, hi), (hi, hi, hi), (lo, hi, hi),  # upper-z face
    ]
    cells = onp.stack([node_ids[s] for s in corner_slices], axis=-1).reshape(-1, 8)

    return meshio.Mesh(points=points, cells={'hexahedron': cells})
def box_mesh_gmsh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir, ele_type='HEX8'):
    """Generate a structured box mesh with gmsh and return it as a meshio.Mesh.

    The box spans [0, Lx] x [0, Ly] x [0, Lz]. A point is extruded into a
    line (Nx layers along x), the line into a surface (Ny layers along y),
    and the surface into a volume (Nz layers along z). For tetrahedral
    element types the extrusions are not recombined (simplices). For all
    other types they are recombined (quads / hexahedra). The mesh is then
    set to the polynomial degree reported by ``get_elements(ele_type)`` and
    written to ``<data_dir>/msh/box.msh`` in MSH 2.2 format.

    Accepts ele_type = 'HEX8', 'TET4' or 'TET10'; mesh is generated with the
    help of gmsh.

    References:
    https://gitlab.onelab.info/gmsh/gmsh/-/blob/master/examples/api/hex.py
    https://gitlab.onelab.info/gmsh/gmsh/-/blob/gmsh_4_7_1/tutorial/python/t1.py
    https://gitlab.onelab.info/gmsh/gmsh/-/blob/gmsh_4_7_1/tutorial/python/t3.py

    Parameters
    ----------
    Nx, Ny, Nz : int
        Number of element layers along x, y, z.
    Lx, Ly, Lz : float
        Box side lengths.
    data_dir : str
        Directory under which ``msh/box.msh`` is written.
    ele_type : str
        Element type name (default 'HEX8').

    Returns
    -------
    meshio.Mesh with a single cell block of type ``get_meshio_cell_type(ele_type)``.

    Raises
    ------
    AssertionError
        If ``ele_type == 'HEX20'``.
    NotImplementedError
        From ``get_meshio_cell_type`` for element types it does not map.
    """
    assert ele_type != 'HEX20', "gmsh cannot produce HEX20 mesh?"
    cell_type = get_meshio_cell_type(ele_type)
    _, _, _, _, degree, _ = get_elements(ele_type)

    msh_dir = os.path.join(data_dir, 'msh')
    os.makedirs(msh_dir, exist_ok=True)
    msh_file = os.path.join(msh_dir, 'box.msh')

    # Tets: keep tris/tets. Otherwise recombine into quads/hexas.
    recombine = not cell_type.startswith('tetra')

    gmsh.initialize()
    try:
        gmsh.option.setNumber("Mesh.MshFileVersion", 2.2)  # save in old MSH format
        point = gmsh.model.geo.addPoint(0., 0., 0.)
        line_dimtags = gmsh.model.geo.extrude([(0, point)], Lx, 0, 0, [Nx], [1])
        surf_dimtags = gmsh.model.geo.extrude([line_dimtags[1]], 0, Ly, 0, [Ny], [1],
                                              recombine=recombine)
        gmsh.model.geo.extrude([surf_dimtags[1]], 0, 0, Lz, [Nz], [1], recombine=recombine)
        gmsh.model.geo.synchronize()
        gmsh.model.mesh.generate(3)
        gmsh.model.mesh.setOrder(degree)
        gmsh.write(msh_file)
    finally:
        # Always close the session: an open gmsh session would leave this
        # model's entities behind for the next call if anything above raised.
        gmsh.finalize()

    mesh = meshio.read(msh_file)
    points = mesh.points  # (num_total_nodes, dim)
    cells = mesh.cells_dict[cell_type]  # (num_cells, num_nodes)
    out_mesh = meshio.Mesh(points=points, cells={cell_type: cells})
    return out_mesh
def cylinder_mesh_gmsh(data_dir, R=5, H=10, circle_mesh=5, hight_mesh=20, rect_ratio=0.4):
    """Generate a structured HEX8 mesh of a cylinder using the gmsh command line.

    By Xinxin Wu at PKU in July, 2022.
    Reference: https://www.researchgate.net/post/How_can_I_create_a_structured_mesh_using_a_transfinite_volume_in_gmsh

    The cross-section at z = 0 is built from a central square with corners
    at (+/-R*rect_ratio, +/-R*rect_ratio) and four curved patches connecting
    it to the circle of radius R. All five patches are transfinite and
    recombined, then extruded by H along z. The generated ``.geo`` script
    and the resulting ``.msh`` (MSH2 format) are written to
    ``<data_dir>/msh/``. Requires the ``gmsh`` executable on PATH.

    Parameters
    ----------
    data_dir : str
        Directory under which ``msh/cylinder.geo`` and ``msh/cylinder.msh`` are written.
    R : float
        Radius.
    H : float
        Height.
    circle_mesh : int
        Value passed to gmsh ``Transfinite Curve ... = circle_mesh`` for all
        twelve in-plane curves (gmsh treats this as a node count per curve).
    hight_mesh : int
        Number of element layers along the height.
    rect_ratio : float
        Corner coordinate of the central square divided by R.

    Returns
    -------
    meshio.Mesh with a single 'hexahedron' cell block.

    Raises
    ------
    subprocess.CalledProcessError
        If gmsh exits with a non-zero status.
    FileNotFoundError
        If the ``gmsh`` executable cannot be found.
    """
    import subprocess

    rect_coor = R*rect_ratio
    msh_dir = os.path.join(data_dir, 'msh')
    os.makedirs(msh_dir, exist_ok=True)
    geo_file = os.path.join(msh_dir, 'cylinder.geo')
    msh_file = os.path.join(msh_dir, 'cylinder.msh')
    string='''
Point(1) = {{0, 0, 0, 1.0}};
Point(2) = {{-{rect_coor}, {rect_coor}, 0, 1.0}};
Point(3) = {{{rect_coor}, {rect_coor}, 0, 1.0}};
Point(4) = {{{rect_coor}, -{rect_coor}, 0, 1.0}};
Point(5) = {{-{rect_coor}, -{rect_coor}, 0, 1.0}};
Point(6) = {{{R}*Cos(3*Pi/4), {R}*Sin(3*Pi/4), 0, 1.0}};
Point(7) = {{{R}*Cos(Pi/4), {R}*Sin(Pi/4), 0, 1.0}};
Point(8) = {{{R}*Cos(-Pi/4), {R}*Sin(-Pi/4), 0, 1.0}};
Point(9) = {{{R}*Cos(-3*Pi/4), {R}*Sin(-3*Pi/4), 0, 1.0}};
Line(1) = {{2, 3}};
Line(2) = {{3, 4}};
Line(3) = {{4, 5}};
Line(4) = {{5, 2}};
Line(5) = {{2, 6}};
Line(6) = {{3, 7}};
Line(7) = {{4, 8}};
Line(8) = {{5, 9}};
Circle(9) = {{6, 1, 7}};
Circle(10) = {{7, 1, 8}};
Circle(11) = {{8, 1, 9}};
Circle(12) = {{9, 1, 6}};
Curve Loop(1) = {{1, 2, 3, 4}};
Plane Surface(1) = {{1}};
Curve Loop(2) = {{1, 6, -9, -5}};
Plane Surface(2) = {{2}};
Curve Loop(3) = {{2, 7, -10, -6}};
Plane Surface(3) = {{3}};
Curve Loop(4) = {{3, 8, -11, -7}};
Plane Surface(4) = {{4}};
Curve Loop(5) = {{4, 5, -12, -8}};
Plane Surface(5) = {{5}};
Transfinite Curve {{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}} = {circle_mesh} Using Progression 1;
Transfinite Surface {{1}};
Transfinite Surface {{2}};
Transfinite Surface {{3}};
Transfinite Surface {{4}};
Transfinite Surface {{5}};
Recombine Surface {{1, 2, 3, 4, 5}};
Extrude {{0, 0, {H}}} {{
Surface{{1:5}}; Layers {{{hight_mesh}}}; Recombine;
}}
Mesh 3;'''.format(R=R, H=H, rect_coor=rect_coor, circle_mesh=circle_mesh, hight_mesh=hight_mesh)
    with open(geo_file, "w") as f:
        f.write(string)
    # argv list (no shell): paths containing spaces or shell metacharacters are
    # passed through intact. check=True stops us from silently reading a stale
    # or missing cylinder.msh when gmsh fails.
    subprocess.run(["gmsh", "-3", geo_file, "-o", msh_file, "-format", "msh2"], check=True)

    mesh = meshio.read(msh_file)
    points = mesh.points  # (num_total_nodes, dim)
    cells = mesh.cells_dict['hexahedron']  # (num_cells, num_nodes)

    # The .msh also contains nodes that no hexahedron references. Two were
    # observed previously, at indices 0 and 14 -- presumably the circle-center
    # points. Drop every unreferenced node and renumber the cells, keeping the
    # original node order.
    used = onp.unique(cells)
    new_index = onp.zeros(len(points), dtype=cells.dtype)
    new_index[used] = onp.arange(len(used), dtype=cells.dtype)
    points = points[used]
    cells = new_index[cells]

    out_mesh = meshio.Mesh(points=points, cells={'hexahedron': cells})
    return out_mesh