Skip to content

Commit

Permalink
remove forceinline. blame @arund42 if this ever does not get inlined …
Browse files Browse the repository at this point in the history
…by the compiler
  • Loading branch information
karpathy committed Apr 29, 2024
1 parent 7634f08 commit af2bc47
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
20 changes: 10 additions & 10 deletions dev/cuda/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ void cublasCheck(cublasStatus_t status, const char *file, int line)

template<class ElementType>
struct alignas(16) Packed128 {
__device__ __forceinline__ Packed128() = default;
__device__ __forceinline__ explicit Packed128(int4 bits) {
__device__ Packed128() = default;
__device__ explicit Packed128(int4 bits) {
static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
memcpy(&payload, &bits, sizeof(bits));
}

__device__ __forceinline__ ElementType& operator[](int index) {
__device__ ElementType& operator[](int index) {
return payload[index];
}
__device__ __forceinline__ const ElementType& operator[](int index) const {
__device__ const ElementType& operator[](int index) const {
return payload[index];
}
__device__ __forceinline__ float fp32(int index) {
__device__ float fp32(int index) {
return static_cast<float>(payload[index]);
}
__device__ __forceinline__ int4 get_bits() const {
__device__ int4 get_bits() const {
int4 bits;
static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
memcpy(&bits, &payload, sizeof(bits));
Expand All @@ -71,25 +71,25 @@ typedef Packed128<float> f128;

// Load a Packed128 from an aligned memory address.
// Reinterprets `address` as a `const int4*`, reads one int4 through it with a
// plain dereference, and wraps those bits in a Packed128<ElementType> using the
// explicit int4 constructor.
// NOTE(review): "aligned" presumably means int4 (16-byte) alignment — confirm at call sites.
// The two signature lines below are the removed / added sides of this diff;
// they differ only in the dropped __forceinline__ qualifier.
template<class ElementType>
__device__ __forceinline__ Packed128<ElementType> load128(const ElementType* address) {
__device__ Packed128<ElementType> load128(const ElementType* address) {
return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};
}

// Load a Packed128 from an aligned memory address with a streaming cache hint.
// Same shape as a plain int4 load, but the read goes through the __ldcs
// intrinsic instead of a direct dereference; the resulting int4 is wrapped in a
// Packed128<ElementType> using the explicit int4 constructor.
// NOTE(review): "aligned" presumably means int4 (16-byte) alignment — confirm at call sites.
// The two signature lines below are the removed / added sides of this diff;
// they differ only in the dropped __forceinline__ qualifier.
template<class ElementType>
__device__ __forceinline__ Packed128<ElementType> load128cs(const ElementType* address) {
__device__ Packed128<ElementType> load128cs(const ElementType* address) {
return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
}

// Store a Packed128 to an aligned memory address.
// Reinterprets `target` as an `int4*` and assigns the int4 returned by
// value.get_bits() through it in a single store expression.
// `value` is taken by copy, so the caller's object is not modified.
// NOTE(review): "aligned" presumably means int4 (16-byte) alignment — confirm at call sites.
// The two signature lines below are the removed / added sides of this diff;
// they differ only in the dropped __forceinline__ qualifier.
template<class ElementType>
__device__ __forceinline__ void store128(ElementType* target, Packed128<ElementType> value) {
__device__ void store128(ElementType* target, Packed128<ElementType> value) {
*reinterpret_cast<int4*>(target) = value.get_bits();
}

// Store a Packed128 to an aligned memory address with a streaming cache hint.
// Reinterprets `target` as an `int4*` and writes the int4 returned by
// value.get_bits() through the __stcs intrinsic instead of a plain assignment.
// `value` is taken by copy, so the caller's object is not modified.
// NOTE(review): "aligned" presumably means int4 (16-byte) alignment — confirm at call sites.
// The two signature lines below are the removed / added sides of this diff;
// they differ only in the dropped __forceinline__ qualifier.
template<class ElementType>
__device__ __forceinline__ void store128cs(ElementType* target, Packed128<ElementType> value) {
__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
__stcs(reinterpret_cast<int4*>(target), value.get_bits());
}

Expand Down
20 changes: 10 additions & 10 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -199,22 +199,22 @@ __device__ void atomicAddX(float* addr, float val) {

template<class ElementType>
struct alignas(16) Packed128 {
__device__ __forceinline__ Packed128() = default;
__device__ __forceinline__ explicit Packed128(int4 bits) {
__device__ Packed128() = default;
__device__ explicit Packed128(int4 bits) {
static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
memcpy(&payload, &bits, sizeof(bits));
}

__device__ __forceinline__ ElementType& operator[](int index) {
__device__ ElementType& operator[](int index) {
return payload[index];
}
__device__ __forceinline__ const ElementType& operator[](int index) const {
__device__ const ElementType& operator[](int index) const {
return payload[index];
}
__device__ __forceinline__ float fp32(int index) {
__device__ float fp32(int index) {
return static_cast<float>(payload[index]);
}
__device__ __forceinline__ int4 get_bits() const {
__device__ int4 get_bits() const {
int4 bits;
static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
memcpy(&bits, &payload, sizeof(bits));
Expand All @@ -230,25 +230,25 @@ typedef Packed128<float> f128;

// Load a Packed128 from an aligned memory address.
// Reinterprets `address` as a `const int4*`, reads one int4 through it with a
// plain dereference, and wraps those bits in a Packed128<ElementType> using the
// explicit int4 constructor.
// NOTE(review): "aligned" presumably means int4 (16-byte) alignment — confirm at call sites.
// The two signature lines below are the removed / added sides of this diff;
// they differ only in the dropped __forceinline__ qualifier.
template<class ElementType>
__device__ __forceinline__ Packed128<ElementType> load128(const ElementType* address) {
__device__ Packed128<ElementType> load128(const ElementType* address) {
return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};
}

// Load a Packed128 from an aligned memory address with a streaming cache hint.
// Same shape as a plain int4 load, but the read goes through the __ldcs
// intrinsic instead of a direct dereference; the resulting int4 is wrapped in a
// Packed128<ElementType> using the explicit int4 constructor.
// NOTE(review): "aligned" presumably means int4 (16-byte) alignment — confirm at call sites.
// The two signature lines below are the removed / added sides of this diff;
// they differ only in the dropped __forceinline__ qualifier.
template<class ElementType>
__device__ __forceinline__ Packed128<ElementType> load128cs(const ElementType* address) {
__device__ Packed128<ElementType> load128cs(const ElementType* address) {
return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
}

// Store a Packed128 to an aligned memory address.
// Reinterprets `target` as an `int4*` and assigns the int4 returned by
// value.get_bits() through it in a single store expression.
// `value` is taken by copy, so the caller's object is not modified.
// NOTE(review): "aligned" presumably means int4 (16-byte) alignment — confirm at call sites.
// The two signature lines below are the removed / added sides of this diff;
// they differ only in the dropped __forceinline__ qualifier.
template<class ElementType>
__device__ __forceinline__ void store128(ElementType* target, Packed128<ElementType> value) {
__device__ void store128(ElementType* target, Packed128<ElementType> value) {
*reinterpret_cast<int4*>(target) = value.get_bits();
}

// Store a Packed128 to an aligned memory address with a streaming cache hint.
// Reinterprets `target` as an `int4*` and writes the int4 returned by
// value.get_bits() through the __stcs intrinsic instead of a plain assignment.
// `value` is taken by copy, so the caller's object is not modified.
// NOTE(review): "aligned" presumably means int4 (16-byte) alignment — confirm at call sites.
// The two signature lines below are the removed / added sides of this diff;
// they differ only in the dropped __forceinline__ qualifier.
template<class ElementType>
__device__ __forceinline__ void store128cs(ElementType* target, Packed128<ElementType> value) {
__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
__stcs(reinterpret_cast<int4*>(target), value.get_bits());
}

Expand Down

0 comments on commit af2bc47

Please sign in to comment.