Skip to content

Commit

Permalink
Moved DeviceType and DeviceInfo to platform rather than scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle committed Aug 18, 2023
1 parent 1523781 commit 3e42f3e
Show file tree
Hide file tree
Showing 13 changed files with 191 additions and 97 deletions.
33 changes: 1 addition & 32 deletions device/include/roughpy/device/core.h
Original file line number Diff line number Diff line change
@@ -1,40 +1,9 @@
#ifndef ROUGHPY_DEVICE_CORE_H_
#define ROUGHPY_DEVICE_CORE_H_

#ifdef __NVCC__
# include <cuda.h>

# define RPY_DEVICE __device__
# define RPY_HOST __host__
# define RPY_DEVICE_HOST __device__ __host__
# define RPY_KERNEL __global__
# define RPY_DEVICE_SHARED __shared__
# define RPY_STRONG_INLINE __inline__

#elif defined(__HIPCC__)

# define RPY_DEVICE __device__
# define RPY_HOST __host__
# define RPY_DEVICE_HOST __device__ __host__
# define RPY_KERNEL __global__
# define RPY_DEVICE_SHARED __shared__
# define RPY_STRONG_INLINE

#else
# define RPY_DEVICE
# define RPY_HOST
# define RPY_DEVICE_HOST
# define RPY_KERNEL
# define RPY_DEVICE_SHARED
# define RPY_STRONG_INLINE

#endif

namespace rpy {
namespace device {

using dindex_t = int;
using dsize_t = unsigned int;


}// namespace device
}// namespace rpy
Expand Down
13 changes: 8 additions & 5 deletions platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,24 @@ cmake_minimum_required(VERSION 3.21)
project(Roughpy_Platform VERSION 0.0.1)

add_roughpy_component(Platform
SOURCES
SOURCES
src/configuration.cpp
src/threading/openmp_threading.cpp
PUBLIC_HEADERS
src/device.cpp
PUBLIC_HEADERS
include/roughpy/platform.h
include/roughpy/platform/filesystem.h
include/roughpy/platform/configuration.h
include/roughpy/platform/serialization.h
include/roughpy/platform/threads.h
PUBLIC_DEPS
include/roughpy/platform/device.h
PUBLIC_DEPS
Boost::boost
Boost::url
Boost::disable_autolinking
cereal::cereal
OpenMP::OpenMP_CXX
NEEDS
NEEDS
RoughPy::Core
)
)

130 changes: 130 additions & 0 deletions platform/include/roughpy/platform/device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
//
// Created by sam on 17/08/23.
//

#ifndef ROUGHPY_DEVICE_H
#define ROUGHPY_DEVICE_H

#include <roughpy/core/macros.h>
#include <roughpy/core/types.h>

#include "filesystem.h"

#if defined(__NVCC__)
# include <cuda.h>

# define RPY_DEVICE __device__
# define RPY_HOST __host__
# define RPY_DEVICE_HOST __device__ __host__
# define RPY_KERNEL __global__
# define RPY_DEVICE_SHARED __shared__
# define RPY_STRONG_INLINE __inline__

#elif defined(__HIPCC__)

# define RPY_DEVICE __device__
# define RPY_HOST __host__
# define RPY_DEVICE_HOST __device__ __host__
# define RPY_KERNEL __global__
# define RPY_DEVICE_SHARED __shared__
# define RPY_STRONG_INLINE

#else
# define RPY_DEVICE
# define RPY_HOST
# define RPY_DEVICE_HOST
# define RPY_KERNEL
# define RPY_DEVICE_SHARED
# define RPY_STRONG_INLINE

#endif



namespace rpy { namespace platform {

using dindex_t = int;
using dsize_t = unsigned int;


/**
* @brief Code for different device types
*
* These codes are chosen to be compatible with the DLPack
* array interchange protocol. They enumerate the various different
* device types that scalar data may be allocated on. This code goes
* with a 32bit integer device ID, which is implementation specific.
*/
enum DeviceType : int32_t {
CPU = 1,
CUDA = 2,
CUDAHost = 3,
OpenCL = 4,
Vulkan = 7,
Metal = 8,
VPI = 9,
ROCM = 10,
ROCMHost = 11,
ExtDev = 12,
CUDAManaged = 13,
OneAPI = 14,
WebGPU = 15,
Hexagon = 16
};

/**
* @brief Device type/id pair to identify a device
*
*
*/
struct DeviceInfo {
DeviceType device_type;
int32_t device_id;
};


class RPY_EXPORT DeviceHandle {
DeviceInfo m_info;

public:

virtual ~DeviceHandle();

explicit DeviceHandle(DeviceInfo info)
: m_info(info)
{}

explicit DeviceHandle(DeviceType type, int32_t device_id)
: m_info {type, device_id}
{}

RPY_NO_DISCARD
const DeviceInfo& info() const noexcept { return m_info; }

RPY_NO_DISCARD
virtual const fs::path& runtime_library() const noexcept = 0;


virtual void launch_kernel(const void* kernel,
const void* launch_config,
void** args
) = 0;


};






constexpr bool
operator==(const DeviceInfo& lhs, const DeviceInfo& rhs) noexcept
{
return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id;
}

}}


#endif// ROUGHPY_DEVICE_H
13 changes: 13 additions & 0 deletions platform/src/device.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//
// Created by sam on 17/08/23.
//

#include <roughpy/platform/device.h>


using namespace rpy;
using namespace rpy::platform;



DeviceHandle::~DeviceHandle() = default;
4 changes: 2 additions & 2 deletions roughpy/src/scalars/scalars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void python::init_scalars(pybind11::module_& m)
static const scalars::ScalarType*
dlpack_dtype_to_scalar_type(DLDataType dtype, DLDevice device)
{
using scalars::ScalarDeviceType;
using platform::DeviceType;

scalars::ScalarTypeCode type;
switch (dtype.code) {
Expand All @@ -154,7 +154,7 @@ dlpack_dtype_to_scalar_type(DLDataType dtype, DLDevice device)

return scalars::ScalarType::from_type_details(
{type, dtype.bits, dtype.lanes},
{static_cast<ScalarDeviceType>(device.device_type),
{static_cast<DeviceType>(device.device_type),
device.device_id}
);
}
Expand Down
6 changes: 3 additions & 3 deletions scalars/include/roughpy/scalars/scalar_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class RPY_EXPORT ScalarType

template <typename T>
RPY_NO_DISCARD inline static const ScalarType*
of(const ScalarDeviceInfo& device);
of(const platform::DeviceInfo& device);

/*
* ScalarTypes objects should be unique for each configuration,
Expand Down Expand Up @@ -106,7 +106,7 @@ class RPY_EXPORT ScalarType
* @return const pointer to appropriate scalar type
*/
RPY_NO_DISCARD static const ScalarType* from_type_details(
const BasicScalarInfo& details, const ScalarDeviceInfo& device
const BasicScalarInfo& details, const platform::DeviceInfo& device
);

/**
Expand Down Expand Up @@ -509,7 +509,7 @@ inline const ScalarType* ScalarType::of()
}

template <typename T>
const ScalarType* ScalarType::of(const ScalarDeviceInfo& device)
const ScalarType* ScalarType::of(const platform::DeviceInfo& device)
{
return get_type(type_id_of<T>(), device);
}
Expand Down
49 changes: 5 additions & 44 deletions scalars/include/roughpy/scalars/scalars_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <roughpy/core/macros.h>
#include <roughpy/core/traits.h>
#include <roughpy/core/types.h>
#include <roughpy/platform/device.h>

#include <complex>

Expand Down Expand Up @@ -92,32 +93,6 @@ struct signed_size_type_marker {
struct unsigned_size_type_marker {
};

/**
* @brief Code for different device types
*
* These codes are chosen to be compatible with the DLPack
* array interchange protocol. They enumerate the various different
* device types that scalar data may be allocated on. This code goes
* with a 32bit integer device ID, which is implementation specific.
*/
enum class ScalarDeviceType : int32_t
{
CPU = 1,
CUDA = 2,
CUDAHost = 3,
OpenCL = 4,
Vulkan = 7,
Metal = 8,
VPI = 9,
ROCM = 10,
ROCMHost = 11,
ExtDev = 12,
CUDAManaged = 13,
OneAPI = 14,
WebGPU = 15,
Hexagon = 16
};

/**
* @brief Type codes for different scalar types.
*
Expand All @@ -138,15 +113,6 @@ enum class ScalarTypeCode : uint8_t
Bool = 6U
};

/**
* @brief Device type/id pair to identify a device
*
*
*/
struct ScalarDeviceInfo {
ScalarDeviceType device_type;
std::int32_t device_id;
};

/**
* @brief Basic information for identifying the type, size, and
Expand All @@ -168,10 +134,10 @@ struct BasicScalarInfo {
struct ScalarTypeInfo {
string name;
string id;
std::size_t n_bytes;
std::size_t alignment;
size_t n_bytes;
size_t alignment;
BasicScalarInfo basic_info;
ScalarDeviceInfo device;
platform::DeviceInfo device;
};

// Forward declarations
Expand Down Expand Up @@ -202,11 +168,6 @@ inline remove_cv_ref_t<T> scalar_cast(const Scalar& arg);
using conversion_function
= std::function<void(ScalarPointer, ScalarPointer, dimn_t)>;

constexpr bool
operator==(const ScalarDeviceInfo& lhs, const ScalarDeviceInfo& rhs) noexcept
{
return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id;
}

constexpr bool
operator==(const BasicScalarInfo& lhs, const BasicScalarInfo& rhs) noexcept
Expand All @@ -233,7 +194,7 @@ RPY_EXPORT
const ScalarType* get_type(const string& id);

RPY_EXPORT
const ScalarType* get_type(const string& id, const ScalarDeviceInfo& device);
const ScalarType* get_type(const string& id, const platform::DeviceInfo& device);

/**
* @brief Get a list of all registered ScalarTypes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

#include "b_float_16_type.h"

#include <roughpy/platform/device.h>

#include <string>
#include <utility>

Expand All @@ -42,5 +44,5 @@ BFloat16Type::BFloat16Type()
string("BFloat16"), string("bf16"), sizeof(bfloat16),
alignof(bfloat16),
{ScalarTypeCode::BFloat, sizeof(bfloat16) * CHAR_BIT, 1U},
{ScalarDeviceType::CPU, 0})
{platform::DeviceType::CPU, 0})
{}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ RationalType::RationalType()
ScalarTypeCode::OpaqueHandle,
0, 0,
},
{ ScalarDeviceType::CPU, 0 }
{ DeviceType::CPU, 0 }
})
{}
ScalarPointer RationalType::allocate(std::size_t count) const
Expand Down
Loading

0 comments on commit 3e42f3e

Please sign in to comment.