Skip to content

Commit

Permalink
Fixes #561: Launch config builder validation code sort-out
Browse files Browse the repository at this point in the history
* Dropped the unused static function `validate_all_dimension_compatibility`
* Added that validation when obtaining the composite block and grid dimensions
* Removed the `resolve_dimensions()` method - it's not useful (and not used anywhere else)
  • Loading branch information
eyalroz committed Dec 24, 2023
1 parent e77e78d commit a878633
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions src/cuda/api/launch_config_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ inline dimensions_t div_rounding_up(overall_dimensions_t overall_dims, block_dim

namespace detail_ {

static void validate_all_dimension_compatibility(
static void validate_all_dimensions_compatibility(
grid::block_dimensions_t block,
grid::dimensions_t grid,
grid::overall_dimensions_t overall)
Expand All @@ -59,16 +59,6 @@ static void validate_all_dimension_compatibility(
} // namespace detail_

class launch_config_builder_t {
public:
void resolve_dimensions() {
grid::composite_dimensions_t cd = get_composite_dimensions();
dimensions_.block = cd.block;
dimensions_.grid = cd.grid;
if (not dimensions_.overall) {
dimensions_.overall = cd.grid * cd.block;
}
}

protected:
memory::shared::size_t get_dynamic_shared_memory_size(grid::block_dimensions_t block_dims) const
{
Expand Down Expand Up @@ -128,17 +118,21 @@ class launch_config_builder_t {
return result;
}
#endif
if (dimensions_.block and dimensions_.overall) {
if (dimensions_.block and dimensions_.overall and not dimensions_.grid) {
result.grid = grid::detail_::div_rounding_up(dimensions_.overall.value(), dimensions_.block.value());
result.block = dimensions_.block.value();
return result;
}
if (dimensions_.grid and dimensions_.overall) {
if (dimensions_.grid and dimensions_.overall and not dimensions_.block) {
result.block = grid::detail_::div_rounding_up(dimensions_.overall.value(), dimensions_.grid.value());
result.grid = dimensions_.grid.value();
return result;
}

if (dimensions_.grid and dimensions_.block) {
if (dimensions_.overall and (dimensions_.grid.value() * dimensions_.block.value() != dimensions_.overall.value())) {
throw ::std::invalid_argument("specified block, grid and overall dimensions do not agree");
}
result.block = dimensions_.block.value();
result.grid = dimensions_.grid.value();
return result;
Expand All @@ -149,7 +143,7 @@ class launch_config_builder_t {
"Neither block nor grid dimensions have been specified");
} else if (not dimensions_.block and not dimensions_.overall) {
throw ::std::logic_error(
"Attempt to obtain the composite grid dimensions, while the grid dimensions have only bee specified "
"Attempt to obtain the composite grid dimensions, while the grid dimensions have only been specified "
"in terms of blocks, not threads, with no block dimensions specified");
} else { // it must be the case that (not dimensions_.block and not dimensions_.overall)
throw ::std::logic_error(
Expand Down Expand Up @@ -262,7 +256,7 @@ class launch_config_builder_t {
{
detail_::validate_block_dimensions(block_dims);
if (dimensions_.grid and dimensions_.overall) {
detail_::validate_all_dimension_compatibility(
detail_::validate_all_dimensions_compatibility(
block_dims, dimensions_.grid.value(), dimensions_.overall.value());
}
// TODO: Check divisibility
Expand All @@ -274,7 +268,7 @@ class launch_config_builder_t {
{
detail_::validate_grid_dimensions(grid_dims);
if (dimensions_.block and dimensions_.overall) {
detail_::validate_all_dimension_compatibility(
detail_::validate_all_dimensions_compatibility(
dimensions_.block.value(), grid_dims, dimensions_.overall.value());
}
// TODO: Check divisibility
Expand Down

0 comments on commit a878633

Please sign in to comment.