Skip to content

Commit

Permalink
Added support for thread exit callbacks (#503)
Browse files Browse the repository at this point in the history
* Added support for thread exit callbacks

* Added user data support to aws_thread_call_once
  • Loading branch information
Justin Boswell authored Sep 11, 2019
1 parent 043b979 commit 94351a9
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 22 deletions.
13 changes: 12 additions & 1 deletion include/aws/common/thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ AWS_EXTERN_C_BEGIN
AWS_COMMON_API
const struct aws_thread_options *aws_default_thread_options(void);

AWS_COMMON_API void aws_thread_call_once(aws_thread_once *flag, void (*call_once)(void));
AWS_COMMON_API void aws_thread_call_once(aws_thread_once *flag, void (*call_once)(void *), void *user_data);

/**
* Initializes a new platform specific thread object struct (not the os-level
Expand Down Expand Up @@ -121,6 +121,17 @@ uint64_t aws_thread_current_thread_id(void);
AWS_COMMON_API
void aws_thread_current_sleep(uint64_t nanos);

typedef void(aws_thread_atexit_fn)(void *user_data);

/**
* Adds a callback to the chain to be called when the current thread joins.
* Callbacks are called from the current thread, in the reverse order they
* were added, after the thread function returns.
* If not called from within an aws_thread, has no effect.
*/
AWS_COMMON_API
int aws_thread_current_at_exit(aws_thread_atexit_fn *callback, void *user_data);

AWS_EXTERN_C_END

#endif /* AWS_COMMON_THREAD_H */
7 changes: 4 additions & 3 deletions source/posix/clock.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ static int s_legacy_get_time(uint64_t *timestamp) {
static aws_thread_once s_thread_once_flag = AWS_THREAD_ONCE_STATIC_INIT;
static int (*s_gettime_fn)(clockid_t __clock_id, struct timespec *__tp) = NULL;

static void s_do_osx_loads(void) {
static void s_do_osx_loads(void *user_data) {
(void)user_data;
s_gettime_fn = (int (*)(clockid_t __clock_id, struct timespec * __tp)) dlsym(RTLD_DEFAULT, "clock_gettime");
}

int aws_high_res_clock_get_ticks(uint64_t *timestamp) {
aws_thread_call_once(&s_thread_once_flag, s_do_osx_loads);
aws_thread_call_once(&s_thread_once_flag, s_do_osx_loads, NULL);
int ret_val = 0;

if (s_gettime_fn) {
Expand All @@ -81,7 +82,7 @@ int aws_high_res_clock_get_ticks(uint64_t *timestamp) {
}

int aws_sys_clock_get_ticks(uint64_t *timestamp) {
aws_thread_call_once(&s_thread_once_flag, s_do_osx_loads);
aws_thread_call_once(&s_thread_once_flag, s_do_osx_loads, NULL);
int ret_val = 0;

if (s_gettime_fn) {
Expand Down
5 changes: 3 additions & 2 deletions source/posix/device_random.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ static aws_thread_once s_rand_init = AWS_THREAD_ONCE_STATIC_INIT;
#else
# define OPEN_FLAGS (O_RDONLY)
#endif
static void s_init_rand(void) {
static void s_init_rand(void *user_data) {
(void)user_data;
s_rand_fd = open("/dev/urandom", OPEN_FLAGS);

if (s_rand_fd == -1) {
Expand All @@ -46,7 +47,7 @@ static void s_init_rand(void) {

static int s_fallback_device_random_buffer(struct aws_byte_buf *output) {

aws_thread_call_once(&s_rand_init, s_init_rand);
aws_thread_call_once(&s_rand_init, s_init_rand, NULL);

size_t diff = output->capacity - output->len;

Expand Down
64 changes: 57 additions & 7 deletions source/posix/thread.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,35 @@ static struct aws_thread_options s_default_options = {
/* this will make sure platform default stack size is used. */
.stack_size = 0};

struct thread_atexit_callback {
aws_thread_atexit_fn *callback;
void *user_data;
struct thread_atexit_callback *next;
};

struct thread_wrapper {
struct aws_allocator *allocator;
void (*func)(void *arg);
void *arg;
struct thread_atexit_callback *atexit;
void (*call_once)(void *);
void *once_arg;
};

static void *thread_fn(void *arg) {
struct thread_wrapper wrapper = *(struct thread_wrapper *)arg;
aws_mem_release(wrapper.allocator, arg);
static AWS_THREAD_LOCAL struct thread_wrapper *tl_wrapper = NULL;

wrapper.func(wrapper.arg);
static void *thread_fn(void *arg) {
struct thread_wrapper *wrapper = arg;
tl_wrapper = wrapper;
wrapper->func(wrapper->arg);
while (wrapper->atexit) {
struct thread_atexit_callback *cb = wrapper->atexit;
cb->callback(cb->user_data);
wrapper->atexit = wrapper->atexit->next;
aws_mem_release(wrapper->allocator, cb);
}
tl_wrapper = NULL;
aws_mem_release(wrapper->allocator, wrapper);
return NULL;
}

Expand All @@ -49,8 +67,24 @@ void aws_thread_clean_up(struct aws_thread *thread) {
}
}

void aws_thread_call_once(aws_thread_once *flag, void (*call_once)(void)) {
pthread_once(flag, call_once);
static void s_call_once(void) {
tl_wrapper->call_once(tl_wrapper->once_arg);
}

void aws_thread_call_once(aws_thread_once *flag, void (*call_once)(void *), void *user_data) {
// If this is a non-aws_thread, then gin up a temp thread wrapper
struct thread_wrapper temp_wrapper;
if (!tl_wrapper) {
tl_wrapper = &temp_wrapper;
}

tl_wrapper->call_once = call_once;
tl_wrapper->once_arg = user_data;
pthread_once(flag, s_call_once);

if (tl_wrapper == &temp_wrapper) {
tl_wrapper = NULL;
}
}

int aws_thread_init(struct aws_thread *thread, struct aws_allocator *allocator) {
Expand Down Expand Up @@ -91,7 +125,7 @@ int aws_thread_launch(
}

struct thread_wrapper *wrapper =
(struct thread_wrapper *)aws_mem_acquire(thread->allocator, sizeof(struct thread_wrapper));
(struct thread_wrapper *)aws_mem_calloc(thread->allocator, 1, sizeof(struct thread_wrapper));

if (!wrapper) {
allocation_failed = 1;
Expand Down Expand Up @@ -179,3 +213,19 @@ void aws_thread_current_sleep(uint64_t nanos) {

nanosleep(&tm, &output);
}

int aws_thread_current_at_exit(aws_thread_atexit_fn *callback, void *user_data) {
if (!tl_wrapper) {
return aws_raise_error(AWS_ERROR_THREAD_NOT_JOINABLE);
}

struct thread_atexit_callback *cb = aws_mem_calloc(tl_wrapper->allocator, 1, sizeof(struct thread_atexit_callback));
if (!cb) {
return AWS_OP_ERR;
}
cb->callback = callback;
cb->user_data = user_data;
cb->next = tl_wrapper->atexit;
tl_wrapper->atexit = cb;
return AWS_OP_SUCCESS;
}
5 changes: 3 additions & 2 deletions source/windows/device_random.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
static BCRYPT_ALG_HANDLE s_alg_handle = NULL;
static aws_thread_once s_rand_init = AWS_THREAD_ONCE_STATIC_INIT;

static void s_init_rand(void) {
static void s_init_rand(void *user_data) {
(void)user_data;
NTSTATUS status = 0;

status = BCryptOpenAlgorithmProvider(&s_alg_handle, BCRYPT_RNG_ALGORITHM, NULL, 0);
Expand All @@ -34,7 +35,7 @@ static void s_init_rand(void) {
}

int aws_device_random_buffer(struct aws_byte_buf *output) {
aws_thread_call_once(&s_rand_init, s_init_rand);
aws_thread_call_once(&s_rand_init, s_init_rand, NULL);

size_t offset = output->capacity - output->len;
NTSTATUS status = BCryptGenRandom(s_alg_handle, output->buffer + output->len, (ULONG)offset, 0);
Expand Down
50 changes: 43 additions & 7 deletions source/windows/thread.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,33 @@ static struct aws_thread_options s_default_options = {
.stack_size = 0,
};

struct thread_atexit_callback {
aws_thread_atexit_fn *callback;
void *user_data;
struct thread_atexit_callback *next;
};

struct thread_wrapper {
struct aws_allocator *allocator;
void (*func)(void *arg);
void *arg;
struct thread_atexit_callback *atexit;
};

static AWS_THREAD_LOCAL struct thread_wrapper *tl_wrapper = NULL;

static DWORD WINAPI thread_wrapper_fn(LPVOID arg) {
struct thread_wrapper thread_wrapper = *(struct thread_wrapper *)arg;
aws_mem_release(thread_wrapper.allocator, (void *)arg);
thread_wrapper.func(thread_wrapper.arg);
struct thread_wrapper *thread_wrapper = arg;
tl_wrapper = thread_wrapper;
thread_wrapper->func(thread_wrapper->arg);
while (thread_wrapper->atexit) {
struct thread_atexit_callback *cb = thread_wrapper->atexit;
cb->callback(cb->user_data);
thread_wrapper->atexit = thread_wrapper->atexit->next;
aws_mem_release(thread_wrapper->allocator, cb);
}
tl_wrapper = NULL;
aws_mem_release(thread_wrapper->allocator, thread_wrapper);
return 0;
}

Expand All @@ -41,21 +58,23 @@ const struct aws_thread_options *aws_default_thread_options(void) {
}

struct callback_fn_wrapper {
void (*call_once)(void);
void (*call_once)(void *);
void *user_data;
};

BOOL WINAPI s_init_once_wrapper(PINIT_ONCE init_once, void *param, void **context) {
(void)context;
(void)init_once;

struct callback_fn_wrapper *callback_fn_wrapper = param;
callback_fn_wrapper->call_once();
callback_fn_wrapper->call_once(callback_fn_wrapper->user_data);
return TRUE;
}

void aws_thread_call_once(aws_thread_once *flag, void (*call_once)(void)) {
void aws_thread_call_once(aws_thread_once *flag, void (*call_once)(void *), void *user_data) {
struct callback_fn_wrapper wrapper;
wrapper.call_once = call_once;
wrapper.user_data = user_data;
InitOnceExecuteOnce((PINIT_ONCE)flag, s_init_once_wrapper, &wrapper, NULL);
}

Expand All @@ -81,10 +100,11 @@ int aws_thread_launch(
}

struct thread_wrapper *thread_wrapper =
(struct thread_wrapper *)aws_mem_acquire(thread->allocator, sizeof(struct thread_wrapper));
(struct thread_wrapper *)aws_mem_calloc(thread->allocator, 1, sizeof(struct thread_wrapper));
thread_wrapper->allocator = thread->allocator;
thread_wrapper->arg = arg;
thread_wrapper->func = func;

thread->thread_handle =
CreateThread(0, stack_size, thread_wrapper_fn, (LPVOID)thread_wrapper, 0, &thread->thread_id);

Expand Down Expand Up @@ -129,3 +149,19 @@ void aws_thread_current_sleep(uint64_t nanos) {
* arises put the effort in here. */
Sleep((DWORD)aws_timestamp_convert(nanos, AWS_TIMESTAMP_NANOS, AWS_TIMESTAMP_MILLIS, NULL));
}

int aws_thread_current_at_exit(aws_thread_atexit_fn *callback, void *user_data) {
if (!tl_wrapper) {
return aws_raise_error(AWS_ERROR_THREAD_NOT_JOINABLE);
}

struct thread_atexit_callback *cb = aws_mem_calloc(tl_wrapper->allocator, 1, sizeof(struct thread_atexit_callback));
if (!cb) {
return AWS_OP_ERR;
}
cb->callback = callback;
cb->user_data = user_data;
cb->next = tl_wrapper->atexit;
tl_wrapper->atexit = cb;
return AWS_OP_SUCCESS;
}
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ add_test_case(aws_load_error_strings_test)
add_test_case(aws_assume_compiles_test)

add_test_case(thread_creation_join_test)
add_test_case(thread_atexit_test)

add_test_case(mutex_aquire_release_test)
add_test_case(mutex_is_actually_mutex_test)
Expand Down
36 changes: 36 additions & 0 deletions tests/thread_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,39 @@ static int s_test_thread_creation_join_fn(struct aws_allocator *allocator, void
}

AWS_TEST_CASE(thread_creation_join_test, s_test_thread_creation_join_fn)

static uint32_t s_atexit_call_count = 0;
static void s_thread_atexit_fn(void *user_data) {
(void)user_data;
AWS_FATAL_ASSERT(s_atexit_call_count == 0);
s_atexit_call_count = 1;
}

static void s_thread_atexit_fn2(void *user_data) {
(void)user_data;
AWS_FATAL_ASSERT(s_atexit_call_count == 1);
s_atexit_call_count = 2;
}

static void s_thread_worker_with_atexit(void *arg) {
(void)arg;
AWS_FATAL_ASSERT(AWS_OP_SUCCESS == aws_thread_current_at_exit(s_thread_atexit_fn2, NULL));
AWS_FATAL_ASSERT(AWS_OP_SUCCESS == aws_thread_current_at_exit(s_thread_atexit_fn, NULL));
}

static int s_test_thread_atexit(struct aws_allocator *allocator, void *ctx) {
(void)ctx;
struct aws_thread thread;
ASSERT_SUCCESS(aws_thread_init(&thread, allocator));

ASSERT_SUCCESS(aws_thread_launch(&thread, s_thread_worker_with_atexit, NULL, 0), "thread creation failed");
ASSERT_SUCCESS(aws_thread_join(&thread), "thread join failed");

ASSERT_INT_EQUALS(2, s_atexit_call_count);

aws_thread_clean_up(&thread);

return 0;
}

AWS_TEST_CASE(thread_atexit_test, s_test_thread_atexit)

0 comments on commit 94351a9

Please sign in to comment.