Skip to content

Commit

Permalink
Merge pull request #1465 from pengmiaoying/develop
Browse files Browse the repository at this point in the history
update
  • Loading branch information
weihuayi authored Jan 20, 2025
2 parents b7543bd + 1b63fa2 commit 389f7a3
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 43 deletions.
33 changes: 26 additions & 7 deletions app/tssim/sph/test/tgv_test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
@ref
'''
from fealpy.backend import backend_manager as bm
from fealpy.mesh.node_mesh import NodeMesh
from fealpy.cfd.sph.particle_solver_new import SPHSolver
from fealpy.mesh.node_mesh import NodeMesh, KDTree
from fealpy.cfd.sph.particle_solver_new import SPHSolver, Space, VmapBackend
from fealpy.cfd.sph.particle_kernel_function import QuinticKernel
from jax_md import space #?
import time

bm.set_backend('numpy')
bm.set_backend('pytorch')

EPS = bm.finfo(float).eps
dx = 0.02
dy = 0.02
dx = 0.25
dy = 0.25
h = dx
Vmax = 1.0 #预期最大速度
c0 =10 * Vmax #声速
Expand All @@ -33,8 +33,27 @@

mesh = NodeMesh.from_tgv_domain(box_size, dx)
solver = SPHSolver(mesh)
space = Space()
kernel = QuinticKernel(h=h, dim=2)
displacement, shift = space.periodic(side=box_size)
displacement, shift = space.periodic(box_size, True)
kdtree = KDTree(mesh.nodedata["position"],box_size)
vmap_backend = VmapBackend()


#start = time.time()
for i in range(1):
print(i)
mesh.nodedata['mv'] += 1.0*dt*mesh.nodedata["dmvdt"]
mesh.nodedata['tv'] = mesh.nodedata['mv']
mesh.nodedata["position"] = shift(mesh.nodedata["position"], 1.0 * dt * mesh.nodedata["tv"])

r = mesh.nodedata["position"]
i_s, j_s = kdtree.range_query(mesh.nodedata["position"], 3*h, include_self=True)
r_i_s, r_j_s = r[i_s], r[j_s]
dr_i_j = vmap_backend.apply(displacement, r_i_s, r_j_s)
#print(dr_i_j)



#end = time.time()
#print(end-start)
106 changes: 97 additions & 9 deletions fealpy/cfd/sph/particle_solver_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,37 @@
from fealpy.backend import backend_manager as bm
from fealpy.backend import TensorLike

import numpy as np
import jax
import jax.numpy as jnp
import torch

# Types
Box = TensorLike
f32 = bm.float32

class SPHSolver:
def __init__(self, mesh):
self.mesh = mesh
class Space:
def raw_transform(self, box:Box, R:TensorLike):
if box.ndim == 0 or box.size == 1:

return R * box
elif box.ndim == 1:
indices = self._get_free_indices(R.ndim - 1) + "i"

return bm.einsum(f"i,{indices}->{indices}", box, R)
elif box.ndim == 2:
free_indices = self._get_free_indices(R.ndim - 1)
left_indices = free_indices + "j"
right_indices = free_indices + "i"

return bm.einsum(f"ij,{left_indices}->{right_indices}", box, R)
raise ValueError(
("Box must be either: a scalar, a vector, or a matrix. " f"Found {box}.")
)

def _get_free_indices(self, n: int):

return "".join([chr(ord("a") + i) for i in range(n)])

def pairwise_displacement(self, Ra: TensorLike, Rb: TensorLike):
if len(Ra.shape) != 1:
Expand All @@ -34,12 +57,77 @@ def pairwise_displacement(self, Ra: TensorLike, Rb: TensorLike):
return Ra - Rb

def periodic_displacement(self, side: Box, dR: TensorLike):
_dR = ((dR + side * f32(0.5)) % side) - f32(0.5) * side

return _dR

def periodic_shift(self, side: Box, R: TensorLike, dR: TensorLike):

return (R + dR) % side

def periodic(self, side: Box, wrapped: bool = True):
def displacement_fn( Ra: TensorLike, Rb: TensorLike, perturbation = None, **unused_kwargs):
if "box" in unused_kwargs:
raise UnexpectedBoxException(
(
"`space.periodic` does not accept a box "
"argument. Perhaps you meant to use "
"`space.periodic_general`?"
)
)
dR = self.periodic_displacement(side, self.pairwise_displacement(Ra, Rb))
if perturbation is not None:
dR = self.raw_transform(perturbation, dR)

return dR
if wrapped:
def shift_fn(R: TensorLike, dR: TensorLike, **unused_kwargs):
if "box" in unused_kwargs:
raise UnexpectedBoxException(
(
"`space.periodic` does not accept a box "
"argument. Perhaps you meant to use "
"`space.periodic_general`?"
)
)

return bm.mod(dR + side * f32(0.5), side) - f32(0.5) * side
return self.periodic_shift(side, R, dR)
else:
def shift_fn(R: TensorLike, dR: TensorLike, **unused_kwargs):
if "box" in unused_kwargs:
raise UnexpectedBoxException(
(
"`space.periodic` does not accept a box "
"argument. Perhaps you meant to use "
"`space.periodic_general`?"
)
)
return R + dR

return displacement_fn, shift_fn

'''
def periodic(side: Box):
def displacement_fn
pass
'''
class VmapBackend:
def __init__(self):
current_backend = bm.backend_name

# 根据当前后端设置 vmap 函数
if current_backend == 'jax':
# 使用 JAX 的 vmap
self.vmap_func = jax.vmap
elif current_backend == 'numpy':
# 使用 NumPy 的 vectorize
self.vmap_func = np.vectorize
elif current_backend == 'pytorch':
# 使用 PyTorch 的 vmap
self.vmap_func = torch.vmap
else:
raise ValueError(f"Unsupported backend: {current_backend}")

def apply(self, func, *args, **kwargs):
# 返回已经适配的 vmap 函数
return self.vmap_func(func)(*args, **kwargs)

class SPHSolver:
def __init__(self, mesh):
self.mesh = mesh

Loading

0 comments on commit 389f7a3

Please sign in to comment.