Skip to content

Commit

Permalink
Add flag to dump mlir modules to mxrs (#3182)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Oct 21, 2024
1 parent c38c7ca commit c3a5367
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 11 deletions.
5 changes: 5 additions & 0 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@ Enable reduction fusions in MLIR.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable Split-k perf configs when tuning with MLIR.

.. envvar:: MIGRAPHX_MLIR_DUMP_TO_MXR

Set to path where MXRs will be saved.
Dumps MLIRs module to mxr files.

CK vars
-----------

Expand Down
2 changes: 2 additions & 0 deletions src/include/migraphx/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ struct MIGRAPHX_EXPORT program
{
program();

explicit program(module m);

// move constructor
program(program&&) noexcept;

Expand Down
2 changes: 2 additions & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ struct MIGRAPHX_EXPORT shape
const std::size_t& y);
};

static std::string to_sizes_string(const std::vector<shape>& shapes);

static const std::vector<type_t>& types();

static std::string name(type_t t);
Expand Down
16 changes: 5 additions & 11 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ struct program_impl
};

program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
program::program(module m) : impl(std::make_unique<program_impl>())
{
this->create_module("main", std::move(m));
}

program::program(program&&) noexcept = default;
program::~program() noexcept = default;
Expand Down Expand Up @@ -852,17 +856,7 @@ std::string perf_group(instruction_ref ins, bool detailed)
if(detailed)
{
result += "<" + ins->get_shape().type_string();
std::vector<std::string> sizes;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(sizes),
[&](instruction_ref input) {
std::string r = to_string_range(input->get_shape().lens(), "x");
if(not input->get_shape().standard())
r += ":" + to_string_range(input->get_shape().strides(), "x");
return r;
});
result += "(" + join_strings(sizes, ", ") + ")>";
result += "(" + shape::to_sizes_string(to_shapes(ins->inputs())) + ")>";
}
return result;
}
Expand Down
12 changes: 12 additions & 0 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,18 @@ struct shape_impl
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
};

std::string shape::to_sizes_string(const std::vector<shape>& shapes)
{
std::vector<std::string> sizes;
std::transform(shapes.begin(), shapes.end(), std::back_inserter(sizes), [&](const shape& s) {
std::string r = to_string_range(s.lens(), "x");
if(not s.standard())
r += ":" + to_string_range(s.strides(), "x");
return r;
});
return join_strings(sizes, ", ");
}

const std::vector<shape::type_t>& shape::types()
{
static const std::vector<shape::type_t> result = {
Expand Down
4 changes: 4 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/mlir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <string>
#include <vector>
#include <migraphx/value.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/gpu/config.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/instruction_ref.hpp>
Expand Down Expand Up @@ -67,6 +68,9 @@ MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx
const std::vector<shape>& inputs,
bool exhaustive);

MIGRAPHX_GPU_EXPORT void
dump_mlir_to_mxr(module m, const std::vector<instruction_ref>& inputs, const fs::path& location);

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
8 changes: 8 additions & 0 deletions src/targets/gpu/jit/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_DUMP_TO_MXR);

static module create_pointwise_module(module_ref in_mod)
{
module pw_mod;
Expand Down Expand Up @@ -209,8 +211,14 @@ struct mlir_compiler : compiler<mlir_compiler>
const operation&,
bool exhaustive) const
{
static const auto mxr_loc = string_value_of(MIGRAPHX_MLIR_DUMP_TO_MXR{});

auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front();
if(not mxr_loc.empty())
{
dump_mlir_to_mxr(*smod, ins->inputs(), mxr_loc);
}
return get_tuning_config_mlir(ctx, *smod, shapes, exhaustive);
}

Expand Down
43 changes: 43 additions & 0 deletions src/targets/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
#include <migraphx/env.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
Expand Down Expand Up @@ -1063,6 +1065,23 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
}
}

void replace_params_with_literals(module& m, const std::vector<instruction_ref>& inputs)
{
auto names = m.get_parameter_names();
std::sort(names.begin(), names.end());
for(auto i : range(names.size()))
{
const auto& name = names[i];
const auto& input = inputs[i];
if(input->name() != "@literal")
continue;
auto param = m.get_parameter(name);
auto lit = m.add_literal(input->get_literal());
m.replace_instruction(param, lit);
m.remove_instruction(param);
}
}

std::string dump_mlir(module m, const std::vector<shape>& inputs)
{
const_module_ref mr = &m;
Expand Down Expand Up @@ -1182,6 +1201,30 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
return tc;
}

void dump_mlir_to_mxr(module m,
const std::vector<instruction_ref>& inputs,
const fs::path& location)
{
static std::mutex mutex;
const std::lock_guard<std::mutex> lock(mutex);

adjust_param_shapes(m, to_shapes(inputs));
replace_params_with_literals(m, inputs);
std::vector<instruction_ref> sizes;
for(auto ins : iterator_for(m))
{
if(not contains({"convolution", "dot"}, ins->name()))
continue;
sizes.insert(sizes.end(), ins->inputs().begin(), ins->inputs().end());
}
auto name =
mlir_program::get_symbol_name(m) + "_" + shape::to_sizes_string(to_shapes(sizes)) + ".mxr";
replace_string_inplace(name, ", ", "_");
replace_string_inplace(name, ":", "s");
auto f = location / name;
save(program{std::move(m)}, f.string());
}

#else

template <class T>
Expand Down

0 comments on commit c3a5367

Please sign in to comment.