Skip to content

Commit

Permalink
Backport #7315 and #7380 to release/15.x branch (#7383)
Browse files Browse the repository at this point in the history
* Make Callable::call_argv_fast public (#7315)

* Make Callable::call_argv_fast public

* Add rough specification of the calling convention

* Fix a typo

* Add Callable default ctor + `defined()` method (#7380)

* Add Callable default ctor + `defined()` method

This allows it to behave like

* Add user_assert + test

---------

Co-authored-by: Tom Westerhout <[email protected]>
  • Loading branch information
steven-johnson and twesterhout authored Mar 1, 2023
1 parent 4ce5009 commit d7651f4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
10 changes: 9 additions & 1 deletion src/Callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ void destroy<CallableContents>(const CallableContents *p) {
} // namespace Internal

Callable::Callable()
: contents(new CallableContents) {
: contents(nullptr) {
}

bool Callable::defined() const {
return contents.defined();
}

Callable::Callable(const std::string &name,
Expand Down Expand Up @@ -136,6 +140,8 @@ Callable::FailureFn Callable::check_qcci(size_t argc, const QuickCallCheckInfo *
}

Callable::FailureFn Callable::check_fcci(size_t argc, const FullCallCheckInfo *actual_fcci) const {
user_assert(defined()) << "Cannot call() a default-constructed Callable.";

// Lazily create full_call_check_info upon the first call to make_std_function().
if (contents->full_call_check_info.empty()) {
contents->full_call_check_info.reserve(contents->jit_cache.arguments.size());
Expand Down Expand Up @@ -197,6 +203,8 @@ Callable::FailureFn Callable::check_fcci(size_t argc, const FullCallCheckInfo *a
}

int Callable::call_argv_checked(size_t argc, const void *const *argv, const QuickCallCheckInfo *actual_qcci) const {
user_assert(defined()) << "Cannot call() a default-constructed Callable.";

// It's *essential* we call this for safety.
const auto failure_fn = check_qcci(argc, actual_qcci);
if (failure_fn) {
Expand Down
27 changes: 25 additions & 2 deletions src/Callable.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,13 @@ class Callable {
}
};

Callable();
Callable(const std::string &name,
const JITHandlers &jit_handlers,
const std::map<std::string, JITExtern> &jit_externs,
Internal::JITCache &&jit_cache);

// Note that the first entry in argv must always be a JITUserContext*.
int call_argv_checked(size_t argc, const void *const *argv, const QuickCallCheckInfo *actual_cci) const;
int call_argv_fast(size_t argc, const void *const *argv) const;

using FailureFn = std::function<int(JITUserContext *)>;

Expand All @@ -304,6 +302,13 @@ class Callable {
const std::vector<Argument> &arguments() const;

public:
/** Construct a default Callable. This is not usable (trying to call it will fail).
* The defined() method will return false. */
Callable();

/** Return true if the Callable is well-defined and usable, false if it is a default-constructed empty Callable. */
bool defined() const;

template<typename... Args>
HALIDE_FUNCTION_ATTRS int
operator()(JITUserContext *context, Args &&...args) const {
Expand Down Expand Up @@ -380,6 +385,24 @@ class Callable {
};
}
}

/** Unsafe low-overhead way of invoking the Callable.
*
* This function relies on the same calling convention as the argv-based
* functions generated for ahead-of-time compiled Halide pilelines.
*
* Very rough specifications of the calling convention (but check the source
* code to be sure):
*
* * Arguments are passed in the same order as they appear in the C
* function argument list.
* * The first entry in argv must always be a JITUserContext*. Please,
* note that this means that argv[0] actually contains JITUserContext**.
* * All scalar arguments are passed by pointer, not by value, regardless of size.
* * All buffer arguments (input or output) are passed as halide_buffer_t*.
*
*/
int call_argv_fast(size_t argc, const void *const *argv) const;
};

} // namespace Halide
Expand Down
9 changes: 9 additions & 0 deletions test/correctness/callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ HalideExtern_2(float, my_extern_func, int, float);
int main(int argc, char **argv) {
const Target t = get_jit_target_from_environment();

{
// Check that we can default-construct a Callable.
Callable c;
assert(!c.defined());

// This will assert-fail.
// c(0,1,2);
}

{
Param<int32_t> p_int(42);
Param<float> p_float(1.0f);
Expand Down

0 comments on commit d7651f4

Please sign in to comment.