From 2d916b232ba5c890eb76dfb1d485ea62e7ff78ca Mon Sep 17 00:00:00 2001 From: Skyler <530334303@qq.com> Date: Sat, 18 Jan 2025 17:43:48 +0800 Subject: [PATCH] update --- app/tssim/sph/test/tgv_test_backend.py | 33 ++++-- fealpy/cfd/sph/particle_solver_new.py | 106 +++++++++++++++-- fealpy/mesh/node_mesh.py | 152 ++++++++++++++++++++----- 3 files changed, 248 insertions(+), 43 deletions(-) diff --git a/app/tssim/sph/test/tgv_test_backend.py b/app/tssim/sph/test/tgv_test_backend.py index 46b9dea68..0ca8d11c8 100644 --- a/app/tssim/sph/test/tgv_test_backend.py +++ b/app/tssim/sph/test/tgv_test_backend.py @@ -8,16 +8,16 @@ @ref ''' from fealpy.backend import backend_manager as bm -from fealpy.mesh.node_mesh import NodeMesh -from fealpy.cfd.sph.particle_solver_new import SPHSolver +from fealpy.mesh.node_mesh import NodeMesh, KDTree +from fealpy.cfd.sph.particle_solver_new import SPHSolver, Space, VmapBackend from fealpy.cfd.sph.particle_kernel_function import QuinticKernel -from jax_md import space #? +import time -bm.set_backend('numpy') +bm.set_backend('pytorch') EPS = bm.finfo(float).eps -dx = 0.02 -dy = 0.02 +dx = 0.25 +dy = 0.25 h = dx Vmax = 1.0 #预期最大速度 c0 =10 * Vmax #声速 @@ -33,8 +33,27 @@ mesh = NodeMesh.from_tgv_domain(box_size, dx) solver = SPHSolver(mesh) +space = Space() kernel = QuinticKernel(h=h, dim=2) -displacement, shift = space.periodic(side=box_size) +displacement, shift = space.periodic(box_size, True) +kdtree = KDTree(mesh.nodedata["position"],box_size) +vmap_backend = VmapBackend() + +#start = time.time() for i in range(1): print(i) + mesh.nodedata['mv'] += 1.0*dt*mesh.nodedata["dmvdt"] + mesh.nodedata['tv'] = mesh.nodedata['mv'] + mesh.nodedata["position"] = shift(mesh.nodedata["position"], 1.0 * dt * mesh.nodedata["tv"]) + + r = mesh.nodedata["position"] + i_s, j_s = kdtree.range_query(mesh.nodedata["position"], 3*h, include_self=True) + r_i_s, r_j_s = r[i_s], r[j_s] + dr_i_j = vmap_backend.apply(displacement, r_i_s, r_j_s) + #print(dr_i_j) + + + +#end = time.time() +#print(end-start) \ No newline at end of file diff --git a/fealpy/cfd/sph/particle_solver_new.py b/fealpy/cfd/sph/particle_solver_new.py index 6ddaf42af..376b5e01e 100644 --- a/fealpy/cfd/sph/particle_solver_new.py +++ b/fealpy/cfd/sph/particle_solver_new.py @@ -10,14 +10,37 @@ from fealpy.backend import backend_manager as bm from fealpy.backend import TensorLike +import numpy as np +import jax +import jax.numpy as jnp +import torch + # Types Box = TensorLike f32 = bm.float32 -class SPHSolver: - def __init__(self, mesh): - self.mesh = mesh +class Space: + def raw_transform(self, box:Box, R:TensorLike): + if box.ndim == 0 or box.size == 1: + + return R * box + elif box.ndim == 1: + indices = self._get_free_indices(R.ndim - 1) + "i" + + return bm.einsum(f"i,{indices}->{indices}", box, R) + elif box.ndim == 2: + free_indices = self._get_free_indices(R.ndim - 1) + left_indices = free_indices + "j" + right_indices = free_indices + "i" + + return bm.einsum(f"ij,{left_indices}->{right_indices}", box, R) + raise ValueError( + ("Box must be either: a scalar, a vector, or a matrix. " f"Found {box}.") + ) + def _get_free_indices(self, n: int): + + return "".join([chr(ord("a") + i) for i in range(n)]) def pairwise_displacement(self, Ra: TensorLike, Rb: TensorLike): if len(Ra.shape) != 1: @@ -34,12 +57,77 @@ def pairwise_displacement(self, Ra: TensorLike, Rb: TensorLike): return Ra - Rb def periodic_displacement(self, side: Box, dR: TensorLike): + _dR = ((dR + side * f32(0.5)) % side) - f32(0.5) * side + + return _dR + + def periodic_shift(self, side: Box, R: TensorLike, dR: TensorLike): + + return (R + dR) % side + + def periodic(self, side: Box, wrapped: bool = True): + def displacement_fn( Ra: TensorLike, Rb: TensorLike, perturbation = None, **unused_kwargs): + if "box" in unused_kwargs: + raise UnexpectedBoxException( + ( + "`space.periodic` does not accept a box " + "argument. Perhaps you meant to use " + "`space.periodic_general`?" + ) + ) + dR = self.periodic_displacement(side, self.pairwise_displacement(Ra, Rb)) + if perturbation is not None: + dR = self.raw_transform(perturbation, dR) + + return dR + if wrapped: + def shift_fn(R: TensorLike, dR: TensorLike, **unused_kwargs): + if "box" in unused_kwargs: + raise UnexpectedBoxException( + ( + "`space.periodic` does not accept a box " + "argument. Perhaps you meant to use " + "`space.periodic_general`?" + ) + ) - return bm.mod(dR + side * f32(0.5), side) - f32(0.5) * side + return self.periodic_shift(side, R, dR) + else: + def shift_fn(R: TensorLike, dR: TensorLike, **unused_kwargs): + if "box" in unused_kwargs: + raise UnexpectedBoxException( + ( + "`space.periodic` does not accept a box " + "argument. Perhaps you meant to use " + "`space.periodic_general`?" + ) + ) + return R + dR + return displacement_fn, shift_fn -''' - def periodic(side: Box): - def displacement_fn - pass -''' \ No newline at end of file +class VmapBackend: + def __init__(self): + current_backend = bm.backend_name + + # 根据当前后端设置 vmap 函数 + if current_backend == 'jax': + # 使用 JAX 的 vmap + self.vmap_func = jax.vmap + elif current_backend == 'numpy': + # 使用 NumPy 的 vectorize + self.vmap_func = np.vectorize + elif current_backend == 'pytorch': + # 使用 PyTorch 的 vmap + self.vmap_func = torch.vmap + else: + raise ValueError(f"Unsupported backend: {current_backend}") + + def apply(self, func, *args, **kwargs): + # 返回已经适配的 vmap 函数 + return self.vmap_func(func)(*args, **kwargs) + +class SPHSolver: + def __init__(self, mesh): + self.mesh = mesh + \ No newline at end of file diff --git a/fealpy/mesh/node_mesh.py b/fealpy/mesh/node_mesh.py index 7267665d1..a357e6bd1 100644 --- a/fealpy/mesh/node_mesh.py +++ b/fealpy/mesh/node_mesh.py @@ -3,6 +3,7 @@ from ..backend import backend_manager as bm from ..typing import TensorLike, Index, _S from .. import logger +from fealpy.cfd.sph.particle_solver_new import Space from .mesh_base import MeshDS @@ -360,7 +361,9 @@ def _build_tree(self, points, depth): 'right': right_tree, 'axis': axis } - + + #@staticmethod + #@bm.compile def range_query(self, query_points, radius, include_self=False): """ 查找每个查询点距离为radius以内的所有邻居点,并返回邻居粒子的索引和自身粒子的索引。 @@ -381,32 +384,25 @@ def range_query(self, query_points, radius, include_self=False): return bm.array(neighbors), bm.array(indices) + #@staticmethod + #@bm.compile def _range_query(self, node, query_point, radius, depth, query_idx, include_self): - """ - 递归查找范围查询,返回邻居粒子的索引。 - """ if node is None: return [] axis = node['axis'] neighbors = [] - - # 计算当前点和查询点的距离 point = node['point'] dist = bm.linalg.norm(point - query_point) - # 如果当前点在查询范围内,加入结果 + # 剪枝策略:距离超过半径时,停止搜索 if dist <= radius: - # 如果include_self为True,允许加入查询点本身 - if include_self or not (point==query_point).all(): - # 获取该点的索引 + if include_self or not (point==query_point).all(): neighbor_idx = bm.where(bm.all(self.points == point, axis=1))[0][0] - neighbors.append(neighbor_idx) # 添加邻居粒子的索引 + neighbors.append(neighbor_idx) - # 判断是否需要继续递归查找子树 next_branch = None opposite_branch = None - if query_point[axis] < point[axis]: next_branch = node['left'] opposite_branch = node['right'] @@ -414,26 +410,30 @@ def _range_query(self, node, query_point, radius, depth, query_idx, include_self next_branch = node['right'] opposite_branch = node['left'] - # 递归查询邻居 + # 递归查询 neighbors.extend(self._range_query(next_branch, query_point, radius, depth + 1, query_idx, include_self)) - # 如果对面子树有可能有点在范围内,继续查询 + # 剪枝:如果查询点和当前节点的差异已经大于半径,停止递归查询对面子树 if abs(query_point[axis] - point[axis]) < radius: neighbors.extend(self._range_query(opposite_branch, query_point, radius, depth + 1, query_idx, include_self)) return neighbors + +''' ''' class KDTree: - def __init__(self, points): + def __init__(self, points, box_size): """ - 初始化KDTree,输入为一个N x D的NumPy数组,表示N个D维的点集。 + 初始化KDTree, 输入为一个N x D的NumPy数组,表示N个D维的点集,同时传入box_size用于周期边界处理。 """ self.points = points + self.box_size = box_size # 储存周期边界的盒子尺寸 self.tree = self._build_tree(points, depth=0) + self.space = Space() # 初始化周期边界空间 def _build_tree(self, points, depth): """ - 递归构建KD-Tree。points是当前节点的点集,depth是当前树的深度。 + 递归构建KD-Tree。points是当前节点的点集,depth是当前树的深度。 """ if len(points) == 0: return None @@ -458,12 +458,10 @@ def _build_tree(self, points, depth): 'axis': axis } - @staticmethod - @bm.compile def range_query(self, query_points, radius, include_self=False): """ - 查找每个查询点距离为radius以内的所有邻居点,并返回邻居粒子的索引和自身粒子的索引。 - query_points: (M, D) 形状的 NumPy 数组,表示M个查询点。 + 查找每个查询点距离为radius以内的所有邻居点,并返回邻居粒子的索引和自身粒子的索引。 + query_points: (M, D) 形状的NumPy数组,表示M个查询点。 radius: 查找邻居的半径 include_self: 是否将查询点本身的索引加入到邻居结果中 """ @@ -480,8 +478,6 @@ def range_query(self, query_points, radius, include_self=False): return bm.array(neighbors), bm.array(indices) - @staticmethod - @bm.compile def _range_query(self, node, query_point, radius, depth, query_idx, include_self): if node is None: return [] @@ -489,11 +485,14 @@ def _range_query(self, node, query_point, radius, depth, query_idx, include_self axis = node['axis'] neighbors = [] point = node['point'] - dist = bm.linalg.norm(point - query_point) + + # 计算周期边界下的距离 + dR = self.space.periodic_displacement(self.box_size, query_point - point) + dist = bm.linalg.norm(dR) # 剪枝策略:距离超过半径时,停止搜索 - if dist <= radius: - if include_self or not (point==query_point).all(): + if dist < radius: + if include_self or not (point == query_point).all(): neighbor_idx = bm.where(bm.all(self.points == point, axis=1))[0][0] neighbors.append(neighbor_idx) @@ -513,4 +512,103 @@ def _range_query(self, node, query_point, radius, depth, query_idx, include_self if abs(query_point[axis] - point[axis]) < radius: neighbors.extend(self._range_query(opposite_branch, query_point, radius, depth + 1, query_idx, include_self)) + # 对面子树的遍历应该也使用周期位移处理 + if opposite_branch is not None: + neighbors.extend(self._range_query(opposite_branch, query_point, radius, depth + 1, query_idx, include_self)) + return neighbors +''' +class KDTree: + def __init__(self, points, box_size): + """ + 初始化KDTree, 输入为一个N x D的NumPy数组,表示N个D维的点集,同时传入box_size用于周期边界处理。 + """ + self.points = points + self.box_size = box_size # 储存周期边界的盒子尺寸 + self.tree = self._build_tree(points, depth=0) + self.space = Space() # 初始化周期边界空间 + + def _build_tree(self, points, depth): + """ + 递归构建KD-Tree。points是当前节点的点集,depth是当前树的深度。 + """ + if len(points) == 0: + return None + + # 选择当前维度 + k = points.shape[1] # 点的维度 + axis = depth % k # 循环选择维度 + + # 排序并选择中位点作为根节点 + points = points[points[:, axis].argsort()] + median_idx = len(points) // 2 + median_point = points[median_idx] + + # 递归构建左子树和右子树 + left_tree = self._build_tree(points[:median_idx], depth + 1) + right_tree = self._build_tree(points[median_idx + 1:], depth + 1) + + return { + 'point': median_point, + 'left': left_tree, + 'right': right_tree, + 'axis': axis + } + + def range_query(self, query_points, radius, include_self=False): + neighbors_within_range = [] + self_indices = [] # 存储查询点的索引 + for i, query_point in enumerate(query_points): + neighbors = self._range_query(self.tree, query_point, radius, depth=0, query_idx=i, include_self=include_self) + + # 确保邻居索引去重并转换为列表 + unique_neighbors = sorted(set(map(int, neighbors))) # 使用set去重,再排序,确保neighbors是整数索引 + neighbors_within_range.append(unique_neighbors) + + self_indices.append([i] * len(unique_neighbors)) # 每个查询点的索引重复以匹配邻居数量 + + # 将查询点索引和邻居粒子的索引拆开,按索引顺序整理 + neighbors = [item for sublist in neighbors_within_range for item in sublist] + indices = [item for sublist in self_indices for item in sublist] + + return bm.array(neighbors), bm.array(indices) + + def _range_query(self, node, query_point, radius, depth, query_idx, include_self): + if node is None: + return [] + + axis = node['axis'] + neighbors = [] + point = node['point'] + + # 计算周期边界下的距离 + dR = self.space.periodic_displacement(self.box_size, query_point - point) + dist = bm.linalg.norm(dR) + + # 剪枝策略:距离超过半径时,停止搜索 + if dist < radius: + if include_self or not (point == query_point).all(): + neighbor_idx = bm.where(bm.all(self.points == point, axis=1))[0][0] + neighbors.append(neighbor_idx) + + next_branch = None + opposite_branch = None + if query_point[axis] < point[axis]: + next_branch = node['left'] + opposite_branch = node['right'] + else: + next_branch = node['right'] + opposite_branch = node['left'] + + # 递归查询 + neighbors.extend(self._range_query(next_branch, query_point, radius, depth + 1, query_idx, include_self)) + + # 剪枝:如果查询点和当前节点的差异已经大于半径,停止递归查询对面子树 + if abs(query_point[axis] - point[axis]) < radius: + neighbors.extend(self._range_query(opposite_branch, query_point, radius, depth + 1, query_idx, include_self)) + + # 对面子树的遍历应该也使用周期位移处理 + if opposite_branch is not None: + neighbors.extend(self._range_query(opposite_branch, query_point, radius, depth + 1, query_idx, include_self)) + + return neighbors \ No newline at end of file