Skip to content

Commit

Permalink
update ConvKernel isAvailable checking
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 69e316b8ca9f826d8134ebd7a932ccabc328e401
  • Loading branch information
megvii-mge committed Nov 20, 2023
1 parent 252a8d9 commit feee3f2
Show file tree
Hide file tree
Showing 9 changed files with 13 additions and 12 deletions.
4 changes: 2 additions & 2 deletions compiler/lib/KernelGen/Arm/Arm64/KernelPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
#include "BatchedMatmul/BatchedMatmul.h"
#include "ConvKernel.h"
#include "Elemwise/Elemwise.h"
#include "Rotate.h"
#include "InternalKernel/InternalKernel.h"
#include "KernelPack.h"
#include "MatMulKernel/MatMul.h"
#include "Rotate.h"

using namespace megcc;
using namespace KernelGen;
Expand Down Expand Up @@ -46,7 +46,7 @@ struct AllA64Kernel {

inner_map[KernelPack::KernType::BatchMatmulKernel] = {
std::make_shared<Arm64::Fp32BatchedMatmul>()};

inner_map[KernelPack::KernType::RotateKernel] = {
std::make_shared<Arm64::RotateKernel>()};
}
Expand Down
6 changes: 3 additions & 3 deletions compiler/lib/KernelGen/Arm/Arm64/Rotate.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#pragma once
#include <string>
#include <sstream>
#include "compiler/KernelGen/KernelGen.h"
#include <string>
#include "Utils/SymbolHelper.h"
#include "Utils/Utils.h"
#include "compiler/KernelGen/KernelGen.h"
namespace megcc {
namespace KernelGen {
namespace Arm64 {

class RotateKernel : public KernelFunc {
class RotateKernel : public KernelFunc {
public:
bool IsAvailable(TContext* context) const override { return false; };
std::string GetKernelBody(TContext* context) const override { return ""; };
Expand Down
3 changes: 2 additions & 1 deletion compiler/lib/KernelGen/Common/ConvKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class ConvImpl : public KernelFunc {
static bool is_channel_broadcast_bias(TContext* ctx) {
if (is_bias(ctx)) {
CCOperand bias = ctx->getAttrOprand("operand:2");
return bias.shape[0] == 1 && bias.shape[2] == 1 && bias.shape[3] == 1;
return (bias.shape[0] == 1 && bias.shape[2] == 1 && bias.shape[3] == 1) ||
bias.shape.size() == 1;
}
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/lib/Target/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_subdirectory(MGB)
add_subdirectory(TinyNN)
add_subdirectory(Hako)
add_subdirectory(onnx)
add_subdirectory(onnx)
2 changes: 1 addition & 1 deletion compiler/lib/Target/onnx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ target_include_directories(
${PROTOBUF_INCLUDE_DIR})
# add onnx-imported
target_link_libraries(MLIRONNXImporter PUBLIC $<BUILD_INTERFACE:onnx_imported>)
# target_compile_options(MLIRONNXImporter PUBLIC -fexceptions)
# target_compile_options(MLIRONNXImporter PUBLIC -fexceptions)
2 changes: 1 addition & 1 deletion compiler/test/kernel/opr/arm/cv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,6 @@ TEST(AARCH64, CVrotateFp16) {
checker.exec({{1, 19, 19, 1}, {}});

checker.exec({{1, 19, 19, 3}, {}});
}
}
}
#endif
2 changes: 1 addition & 1 deletion compiler/tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ add_subdirectory(dump-kernel)
add_subdirectory(megcc-translate)
add_subdirectory(kernel_exporter)
add_subdirectory(onnx-importer)
add_subdirectory(onnx-to-tinynn)
add_subdirectory(onnx-to-tinynn)
2 changes: 1 addition & 1 deletion compiler/tools/onnx-importer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ add_llvm_executable(onnx-importer onnx-importer.cpp NO_INSTALL_RPATH)
llvm_update_compile_flags(onnx-importer)
target_link_libraries(onnx-importer PRIVATE ${dialect_libs} MLIRONNXImporter Common
${ONNX_LIBS} ${PROTOBUF_LIBS})
mlir_check_all_link_libraries(onnx-importer)
mlir_check_all_link_libraries(onnx-importer)
2 changes: 1 addition & 1 deletion compiler/tools/onnx-to-tinynn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ target_link_libraries(
HakoParse
${ONNX_LIBS}
${PROTOBUF_LIBS})
mlir_check_all_link_libraries(onnx-to-tinynn)
mlir_check_all_link_libraries(onnx-to-tinynn)

0 comments on commit feee3f2

Please sign in to comment.