Skip to content

Commit

Permalink
Additional matrix multiplication examples (#67)
Browse files Browse the repository at this point in the history
Fixes #64.
  • Loading branch information
emiljanogj authored Jul 21, 2021
1 parent e866b76 commit b4e8470
Show file tree
Hide file tree
Showing 5 changed files with 1,418 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
# Copyright 2021 The ShaderTrap Project Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Matrix multiplication by splitting the matrices into 2x2 submatrices.

GL 4.5

CREATE_BUFFER array1_mat_buf SIZE_BYTES 1024 INIT_VALUES
float 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.

CREATE_BUFFER array1 SIZE_BYTES 1024 INIT_VALUES
float 1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.
1. 3. 5. 7. 9. 11. 13. 15. 17. 19. 21. 23. 25. 27. 29. 31.

CREATE_BUFFER correct_array1_mat_buf SIZE_BYTES 1024 INIT_VALUES
float 1. 1. 3. 3. 5. 5. 7. 7. 9. 9. 11. 11. 13. 13. 15. 15.
17. 17. 19. 19. 21. 21. 23. 23. 25. 25. 27. 27. 29. 29. 31. 31.
1. 1. 3. 3. 5. 5. 7. 7. 9. 9. 11. 11. 13. 13. 15. 15.
17. 17. 19. 19. 21. 21. 23. 23. 25. 25. 27. 27. 29. 29. 31. 31.
1. 1. 3. 3. 5. 5. 7. 7. 9. 9. 11. 11. 13. 13. 15. 15.
17. 17. 19. 19. 21. 21. 23. 23. 25. 25. 27. 27. 29. 29. 31. 31.
1. 1. 3. 3. 5. 5. 7. 7. 9. 9. 11. 11. 13. 13. 15. 15.
17. 17. 19. 19. 21. 21. 23. 23. 25. 25. 27. 27. 29. 29. 31. 31.
1. 1. 3. 3. 5. 5. 7. 7. 9. 9. 11. 11. 13. 13. 15. 15.
17. 17. 19. 19. 21. 21. 23. 23. 25. 25. 27. 27. 29. 29. 31. 31.
1. 1. 3. 3. 5. 5. 7. 7. 9. 9. 11. 11. 13. 13. 15. 15.
17. 17. 19. 19. 21. 21. 23. 23. 25. 25. 27. 27. 29. 29. 31. 31.
1. 1. 3. 3. 5. 5. 7. 7. 9. 9. 11. 11. 13. 13. 15. 15.
17. 17. 19. 19. 21. 21. 23. 23. 25. 25. 27. 27. 29. 29. 31. 31.
1. 1. 3. 3. 5. 5. 7. 7. 9. 9. 11. 11. 13. 13. 15. 15.
17. 17. 19. 19. 21. 21. 23. 23. 25. 25. 27. 27. 29. 29. 31. 31.

CREATE_BUFFER array2_mat_buf SIZE_BYTES 1024 INIT_VALUES
float 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.

CREATE_BUFFER array2 SIZE_BYTES 1024 INIT_VALUES
float 2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.
2. 4. 6. 8. 10. 12. 14. 16. 18. 20. 22. 24. 26. 28. 30. 32.

CREATE_BUFFER correct_array2_mat_buf SIZE_BYTES 1024 INIT_VALUES
float 2. 2. 4. 4. 6. 6. 8. 8. 10. 10. 12. 12. 14. 14. 16. 16.
18. 18. 20. 20. 22. 22. 24. 24. 26. 26. 28. 28. 30. 30. 32. 32.
2. 2. 4. 4. 6. 6. 8. 8. 10. 10. 12. 12. 14. 14. 16. 16.
18. 18. 20. 20. 22. 22. 24. 24. 26. 26. 28. 28. 30. 30. 32. 32.
2. 2. 4. 4. 6. 6. 8. 8. 10. 10. 12. 12. 14. 14. 16. 16.
18. 18. 20. 20. 22. 22. 24. 24. 26. 26. 28. 28. 30. 30. 32. 32.
2. 2. 4. 4. 6. 6. 8. 8. 10. 10. 12. 12. 14. 14. 16. 16.
18. 18. 20. 20. 22. 22. 24. 24. 26. 26. 28. 28. 30. 30. 32. 32.
2. 2. 4. 4. 6. 6. 8. 8. 10. 10. 12. 12. 14. 14. 16. 16.
18. 18. 20. 20. 22. 22. 24. 24. 26. 26. 28. 28. 30. 30. 32. 32.
2. 2. 4. 4. 6. 6. 8. 8. 10. 10. 12. 12. 14. 14. 16. 16.
18. 18. 20. 20. 22. 22. 24. 24. 26. 26. 28. 28. 30. 30. 32. 32.
2. 2. 4. 4. 6. 6. 8. 8. 10. 10. 12. 12. 14. 14. 16. 16.
18. 18. 20. 20. 22. 22. 24. 24. 26. 26. 28. 28. 30. 30. 32. 32.
2. 2. 4. 4. 6. 6. 8. 8. 10. 10. 12. 12. 14. 14. 16. 16.
18. 18. 20. 20. 22. 22. 24. 24. 26. 26. 28. 28. 30. 30. 32. 32.

CREATE_BUFFER blocked_result SIZE_BYTES 1024 INIT_VALUES
float 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.

CREATE_BUFFER correct_blocked_result SIZE_BYTES 1024 INIT_VALUES
float 512. 512. 1024. 1024. 1536. 1536. 2048. 2048. 2560. 2560. 3072. 3072. 3584. 3584. 4096. 4096.
4608. 4608. 5120. 5120. 5632. 5632. 6144. 6144. 6656. 6656. 7168. 7168. 7680. 7680. 8192. 8192.
512. 512. 1024. 1024. 1536. 1536. 2048. 2048. 2560. 2560. 3072. 3072. 3584. 3584. 4096. 4096.
4608. 4608. 5120. 5120. 5632. 5632. 6144. 6144. 6656. 6656. 7168. 7168. 7680. 7680. 8192. 8192.
512. 512. 1024. 1024. 1536. 1536. 2048. 2048. 2560. 2560. 3072. 3072. 3584. 3584. 4096. 4096.
4608. 4608. 5120. 5120. 5632. 5632. 6144. 6144. 6656. 6656. 7168. 7168. 7680. 7680. 8192. 8192.
512. 512. 1024. 1024. 1536. 1536. 2048. 2048. 2560. 2560. 3072. 3072. 3584. 3584. 4096. 4096.
4608. 4608. 5120. 5120. 5632. 5632. 6144. 6144. 6656. 6656. 7168. 7168. 7680. 7680. 8192. 8192.
512. 512. 1024. 1024. 1536. 1536. 2048. 2048. 2560. 2560. 3072. 3072. 3584. 3584. 4096. 4096.
4608. 4608. 5120. 5120. 5632. 5632. 6144. 6144. 6656. 6656. 7168. 7168. 7680. 7680. 8192. 8192.
512. 512. 1024. 1024. 1536. 1536. 2048. 2048. 2560. 2560. 3072. 3072. 3584. 3584. 4096. 4096.
4608. 4608. 5120. 5120. 5632. 5632. 6144. 6144. 6656. 6656. 7168. 7168. 7680. 7680. 8192. 8192.
512. 512. 1024. 1024. 1536. 1536. 2048. 2048. 2560. 2560. 3072. 3072. 3584. 3584. 4096. 4096.
4608. 4608. 5120. 5120. 5632. 5632. 6144. 6144. 6656. 6656. 7168. 7168. 7680. 7680. 8192. 8192.
512. 512. 1024. 1024. 1536. 1536. 2048. 2048. 2560. 2560. 3072. 3072. 3584. 3584. 4096. 4096.
4608. 4608. 5120. 5120. 5632. 5632. 6144. 6144. 6656. 6656. 7168. 7168. 7680. 7680. 8192. 8192.

CREATE_BUFFER result SIZE_BYTES 1024 INIT_VALUES
float 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.

CREATE_BUFFER mat_mul_expected SIZE_BYTES 1024 INIT_VALUES
float 512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.
512. 1024. 1536. 2048. 2560. 3072. 3584. 4096. 4608. 5120. 5632. 6144. 6656. 7168. 7680. 8192.

DECLARE_SHADER array_to_blocked_array KIND COMPUTE
#version 450

#define KERNEL_SIZE 2
#define MAT_SIZE 16

layout(local_size_x=1, local_size_y=1, local_size_z=1) in;

layout(std430, binding = 0) buffer array1 {
float data[MAT_SIZE][MAT_SIZE];
};
layout(std430, binding=1) buffer array1_mat_buf{
mat2x2 array_mat[MAT_SIZE/KERNEL_SIZE][MAT_SIZE/KERNEL_SIZE];
};

void main() {
uint row = uint(gl_WorkGroupID[0]), col = uint(gl_WorkGroupID[1]);
// Converts a MAT_SIZE x MAT_SIZE array to a (MAT_SIZE/KERNEL_SIZE) x (MAT_SIZE/KERNEL_SIZE)
// array of KERNEL_SIZE x KERNEL_SIZE matrices
array_mat[row/KERNEL_SIZE][col/KERNEL_SIZE][col%KERNEL_SIZE][row%KERNEL_SIZE] = data[row][col];
}
END

COMPILE_SHADER array_to_blocked_array_compiled SHADER array_to_blocked_array
CREATE_PROGRAM array_to_blocked_array_prog SHADERS array_to_blocked_array_compiled

BIND_SHADER_STORAGE_BUFFER BUFFER array1 BINDING 0
BIND_SHADER_STORAGE_BUFFER BUFFER array1_mat_buf BINDING 1

RUN_COMPUTE
PROGRAM array_to_blocked_array_prog
NUM_GROUPS 16 16 1

ASSERT_EQUAL BUFFERS array1_mat_buf correct_array1_mat_buf

BIND_SHADER_STORAGE_BUFFER BUFFER array2 BINDING 0
BIND_SHADER_STORAGE_BUFFER BUFFER array2_mat_buf BINDING 1

RUN_COMPUTE
PROGRAM array_to_blocked_array_prog
NUM_GROUPS 16 16 1

ASSERT_EQUAL BUFFERS array2_mat_buf correct_array2_mat_buf

DECLARE_SHADER mat_mul KIND COMPUTE
#version 450

#define KERNEL_SIZE 2
#define MAT_SIZE 16

layout(local_size_x=1, local_size_y=1, local_size_z=1) in;

layout(std430, binding=0) buffer array1_mat_buf{
mat2x2 array1_mat[MAT_SIZE/KERNEL_SIZE][MAT_SIZE/KERNEL_SIZE];
};
layout(std430, binding=1) buffer array2_mat_buf{
mat2x2 array2_mat[MAT_SIZE/KERNEL_SIZE][MAT_SIZE/KERNEL_SIZE];
};
layout(std430, binding=2) buffer blocked_result{
mat2x2 blocked_res[MAT_SIZE/KERNEL_SIZE][MAT_SIZE/KERNEL_SIZE];
};

void main() {
for(int i = 0; i < MAT_SIZE/KERNEL_SIZE; i++){
for(int j = 0; j < MAT_SIZE/KERNEL_SIZE; j++){
blocked_res[i][j] = mat2x2(0);
for(int k = 0; k < MAT_SIZE/KERNEL_SIZE; k++)
blocked_res[i][j] += array1_mat[i][k] * array2_mat[k][j];
}
}
}
END

COMPILE_SHADER mat_mul_compiled SHADER mat_mul
CREATE_PROGRAM mat_mul_prog SHADERS mat_mul_compiled

BIND_SHADER_STORAGE_BUFFER BUFFER array1_mat_buf BINDING 0
BIND_SHADER_STORAGE_BUFFER BUFFER array2_mat_buf BINDING 1
BIND_SHADER_STORAGE_BUFFER BUFFER blocked_result BINDING 2

RUN_COMPUTE
PROGRAM mat_mul_prog
NUM_GROUPS 1 1 1

ASSERT_EQUAL BUFFERS blocked_result correct_blocked_result

DECLARE_SHADER blocked_array_to_array KIND COMPUTE
#version 450

#define KERNEL_SIZE 2
#define MAT_SIZE 16

layout(local_size_x=1, local_size_y=1, local_size_z=1) in;

layout(std430, binding=0) buffer result{
float res[MAT_SIZE][MAT_SIZE];
};
layout(std430, binding=1) buffer blocked_result{
mat2x2 blocked_res[MAT_SIZE/KERNEL_SIZE][MAT_SIZE/KERNEL_SIZE];
};

void main(){
uint row = uint(gl_WorkGroupID[0]), col = uint(gl_WorkGroupID[1]);
// Converts a (MAT_SIZE/KERNEL_SIZE) x (MAT_SIZE/KERNEL_SIZE) array
// of KERNEL_SIZE x KERNEL_SIZE matrices to a MAT_SIZE x MAT_SIZE array
res[row][col] = blocked_res[row/KERNEL_SIZE][col/KERNEL_SIZE][col%KERNEL_SIZE][row%KERNEL_SIZE];
}
END

COMPILE_SHADER blocked_array_to_array_compiled SHADER blocked_array_to_array
CREATE_PROGRAM blocked_array_to_array_prog SHADERS blocked_array_to_array_compiled

BIND_SHADER_STORAGE_BUFFER BUFFER result BINDING 0
BIND_SHADER_STORAGE_BUFFER BUFFER blocked_result BINDING 1

RUN_COMPUTE
PROGRAM blocked_array_to_array_prog
NUM_GROUPS 16 16 1

ASSERT_EQUAL BUFFERS result mat_mul_expected
Loading

0 comments on commit b4e8470

Please sign in to comment.