Skip to content

Commit

Permalink
Compatibility fixers feb24 (#160)
Browse files Browse the repository at this point in the history
Trying to get the Petsc compilation issues out of the way
  • Loading branch information
lmoresi authored Feb 17, 2024
2 parents 1853bd8 + 5c2b57b commit a287bf9
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 18 deletions.
15 changes: 8 additions & 7 deletions src/underworld3/cython/petsc_compat.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,12 @@ PetscErrorCode UW_PetscDSViewBdWF(PetscDS ds, PetscInt bd)
return 1;
}

// PetscErrorCode UW_PetscVecConcatenate(PetscInt nx, Vec inputVecs[], Vec *outputVec)
// {
// IS *x_is;

// PetscErrorCode VecConcatenate(nx, inputVecs, outputVec, &x_is);
PetscErrorCode UW_DMPlexSetSNESLocalFEM(DM dm, PetscBool flag, void *ctx)
{

// return 1;
// }
#if PETSC_VERSION_LE(3, 20, 4)
return DMPlexSetSNESLocalFEM(dm, NULL, NULL, NULL);
#else
return DMPlexSetSNESLocalFEM(dm, flag, NULL);
#endif
}
6 changes: 2 additions & 4 deletions src/underworld3/cython/petsc_extras.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ cdef extern from "petsc_compat.h":
PetscErrorCode UW_PetscDSSetBdTerms (PetscDS, PetscDMLabel, PetscInt, PetscInt, PetscInt, PetscInt, PetscInt, void*, void*, void*, void*, void*, void* )
PetscErrorCode UW_PetscDSViewWF(PetscDS)
PetscErrorCode UW_PetscDSViewBdWF(PetscDS, PetscInt)

# PetscErrorCode UW_PetscVecConcatenate(PetscInt, PetscVec[], PetscVec *)
# PetscErrorCode DMAddBoundary(PetscDM, DMBoundaryConditionType, const char name[], DMLabel label, PetscInt Nv, PetscInt values[], PetscInt field, PetscInt Nc, PetscInt comps[], void (*)(), void (*)(), void *ctx, PetscInt *bd)
PetscErrorCode UW_DMPlexSetSNESLocalFEM( PetscDM, PetscBool, void *)

cdef extern from "petsc.h" nogil:
PetscErrorCode DMPlexSNESComputeBoundaryFEM( PetscDM, void *, void *)
# PetscErrorCode DMPlexSetSNESLocalFEM( PetscDM, void *, void *, void *)
PetscErrorCode DMPlexSetSNESLocalFEM( PetscDM, PetscBool, void *)
# PetscErrorCode DMPlexSetSNESLocalFEM( PetscDM, PetscBool, void *)
PetscErrorCode DMPlexComputeGeometryFVM( PetscDM dm, PetscVec *cellgeom, PetscVec *facegeom)
PetscErrorCode MatInterpolate(PetscMat A, PetscVec x, PetscVec y)
PetscErrorCode DMSetLocalSection(PetscDM, PetscSection)
Expand Down
7 changes: 4 additions & 3 deletions src/underworld3/cython/petsc_generic_snes_solvers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ class SNES_Scalar(Solver):


cdef DM dm = self.dm
DMPlexSetSNESLocalFEM(dm.dm, PETSC_FALSE, NULL)
UW_DMPlexSetSNESLocalFEM(dm.dm, PETSC_FALSE, NULL)

self.is_setup = True
self.constitutive_model._solver_is_setup = True
Expand Down Expand Up @@ -1333,7 +1333,8 @@ class SNES_Vector(Solver):
self.snes.setFromOptions()

cdef DM dm = self.dm
DMPlexSetSNESLocalFEM(dm.dm, PETSC_FALSE, NULL)
UW_DMPlexSetSNESLocalFEM(dm.dm, PETSC_FALSE, NULL)


self.is_setup = True
self.constitutive_model._solver_is_setup = True
Expand Down Expand Up @@ -2298,7 +2299,7 @@ class SNES_Stokes_SaddlePt(Solver):
self.snes.setFromOptions()

cdef DM c_dm = self.dm
DMPlexSetSNESLocalFEM(c_dm.dm, PETSC_FALSE, NULL)
UW_DMPlexSetSNESLocalFEM(c_dm.dm, PETSC_FALSE, NULL)

# Setup subdms here too.
# These will be used to copy back/forth SNES solutions
Expand Down
4 changes: 3 additions & 1 deletion src/underworld3/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _from_gmsh(
# we do this by saving the mesh as h5 which is more flexible to re-use later

if uw.mpi.rank == 0:
plex_0 = PETSc.DMPlex().createFromFile(filename, comm=PETSc.COMM_SELF)
plex_0 = PETSc.DMPlex().createFromFile(
filename, interpolate=True, comm=PETSc.COMM_SELF
)

plex_0.setName("uw_mesh")
plex_0.markBoundaryFaces("All_Boundaries", 1001)
Expand Down
9 changes: 6 additions & 3 deletions src/underworld3/utilities/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def mem_footprint():
return python_process.memory_info().rss // 1000000


def gather_data(val, bcast=False):
def gather_data(val, bcast=False, dtype="float64"):

"""
gather values on root (bcast=False) or all (bcast = True) processors
Parameters:
Expand All @@ -110,15 +111,17 @@ def gather_data(val, bcast=False):
if len(val > 0):
val_local = np.ascontiguousarray(val.copy())
else:
val_local = np.array([np.nan], dtype="float64")
val_local = np.array([np.nan], dtype=dtype)


comm.barrier()

### Collect local array sizes using the high-level mpi4py gather
sendcounts = np.array(comm.gather(len(val_local), root=0))

if rank == 0:
val_global = np.zeros((sum(sendcounts)), dtype="float64")
val_global = np.zeros((sum(sendcounts)), dtype=dtype)

else:
val_global = None

Expand Down

0 comments on commit a287bf9

Please sign in to comment.