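! saxpy_acc_mpi.f90
!
! SAXPY (Y = a*X + Y) distributed across MPI ranks: the root scatters the
! operands, each rank computes its slice on a GPU with OpenACC, and the root
! gathers the result and verifies it against a host-computed reference.
!
! Example build/run (assumed NVHPC toolchain; adjust flags for your compiler/MPI):
!   mpifort -acc -Minfo=accel saxpy_acc_mpi.f90 -o saxpy_acc_mpi
!   mpirun -np 4 ./saxpy_acc_mpi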
program saxpy_acc_mpi
  use openacc
  use mpi
  implicit none
  integer :: len_global, len_local, i
  integer :: irank, nranks, igpu, ngpus, ierr, istat(MPI_STATUS_SIZE)
  real(4), allocatable, dimension(:) :: X_local, X_global, Y_global, Y_local, Y_ref
  real(4) :: a
  character(len=128) :: argv
  a = 2.0
  len_global = 1024
  ! Initialize MPI
  call MPI_Init(ierr)
  call MPI_Comm_rank(MPI_COMM_WORLD, irank, ierr)
  call MPI_Comm_size(MPI_COMM_WORLD, nranks, ierr)
  ! Check that the global array length is evenly divisible by the number of MPI ranks
  if (mod(len_global, nranks) .ne. 0) then
    if (irank .eq. 0) then
      write(*,'(a,i0,a,i0)') 'The global array length, ', len_global, &
        ', is not divisible by the number of ranks, ', nranks
    endif
    call MPI_Abort(MPI_COMM_WORLD, 1, ierr)
  else
    len_local = len_global / nranks
  endif
  ! Find GPU devices and set the device number
  ngpus = acc_get_num_devices(acc_device_nvidia)
  if (ngpus .le. 0) then
    if (irank .eq. 0) write(*,'(a)') 'No NVIDIA GPUs available'
    call MPI_Abort(MPI_COMM_WORLD, 1, ierr)
  else
    igpu = mod(irank, ngpus)
    call acc_set_device_num(igpu, acc_device_nvidia)
  endif
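  ! Ranks are mapped to devices round-robin via mod(irank, ngpus), so several
  ! ranks may share a GPU when there are more ranks than devices.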
  ! Allocate local and global arrays
  allocate(X_local(len_local))
  allocate(Y_local(len_local))
  if (irank .eq. 0) then
    allocate(X_global(len_global))
    allocate(Y_global(len_global))
    allocate(Y_ref(len_global))
  endif
  ! If root, set global and reference arrays
  if (irank .eq. 0) then
    do i = 1, len_global
      X_global(i) = i
      Y_global(i) = i + len_global
      Y_ref(i) = a * i + (i + len_global)
    enddo
  endif
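  ! Note: X_global and Y_global appear as send buffers below even though they are
  ! allocated only on the root; MPI only reads the send buffer on the root rank
  ! (passing them unallocated on other ranks is a common idiom, though strictly
  ! nonconforming Fortran).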
  ! Scatter operands
  call MPI_Scatter( &
    X_global, &
    len_local, &
    MPI_REAL4, &
    X_local, &
    len_local, &
    MPI_REAL4, &
    0, &
    MPI_COMM_WORLD, &
    ierr &
  )
  call MPI_Scatter( &
    Y_global, &
    len_local, &
    MPI_REAL4, &
    Y_local, &
    len_local, &
    MPI_REAL4, &
    0, &
    MPI_COMM_WORLD, &
    ierr &
  )
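  ! Each rank now holds a contiguous len_local-element block of X and Y: rank 0
  ! has elements 1:len_local, rank 1 the next block, and so on.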
  ! Do local calculation
  !$ACC KERNELS
  do i = 1, len_local
    Y_local(i) = a*X_local(i) + Y_local(i)
  enddo
  !$ACC END KERNELS
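  ! With no explicit data clauses, the kernels construct gives X_local and Y_local
  ! implicit copy behavior, so the compiler moves them to the device before the
  ! region and back to the host afterwards.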
  ! Gather result
  call MPI_Gather( &
    Y_local, &
    len_local, &
    MPI_REAL4, &
    Y_global, &
    len_local, &
    MPI_REAL4, &
    0, &
    MPI_COMM_WORLD, &
    ierr &
  )
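  ! The exact equality check below is safe here: every element is a small integer
  ! (at most 4096 for len_global = 1024), which is exactly representable in real(4).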
  ! Root checks result
  if (irank .eq. 0) then
    print *, 'Ran SAXPY for n = ', len_global
    if (all(Y_ref .eq. Y_global)) then
      print *, 'SUCCESS: Y_global matches Y_ref'
    else
      print *, 'FAILURE: Y_global does not match Y_ref'
      print *, 'Y_global = ', Y_global
      print *, 'Y_ref = ', Y_ref
    endif
  endif
  ! Cleanup
  deallocate(X_local)
  deallocate(Y_local)
  if (irank .eq. 0) then
    deallocate(X_global)
    deallocate(Y_global)
    deallocate(Y_ref)
  endif
  call MPI_Finalize(ierr)
end program saxpy_acc_mpi