| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #pragma once |
| | #include "base.h" |
| | #include <cudaTypedefs.h> |
| |
|
| | namespace marlin_24 { |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
// Opcode text for the sparse m16n8k32 MMA (fp16 operands, fp32 accumulate)
// that the inline asm in this file is built from.
// Toolkits reporting CUDA_VERSION >= 12050 get the `.sp::ordered_metadata`
// spelling; older toolkits, or builds where CUDA_VERSION is not defined, fall
// back to plain `.sp`.
// NOTE(review): presumably gated because older toolkits do not accept the
// `::ordered_metadata` qualifier -- confirm against the PTX ISA release notes.
#if !defined(CUDA_VERSION) || (CUDA_VERSION < 12050)
  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
  #define MMA_SP_INST \
    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif
| |
|
| | |
| | |
// Issues two sparse m16n8k32 tensor-core MMAs (fp16 operands, fp32
// accumulate) that update frag_c in place.
//
// Operand wiring, exactly as passed to the inline asm below:
//   - A operand (4 x b32): a_frag0[0], a_frag1[0], a_frag0[1], a_frag1[1],
//     identical for both instructions. Note that the FragB-typed arguments
//     feed the instruction's A slot and the FragA-typed argument feeds its
//     B slot.
//   - B operand (4 x b32): 32-bit words 0, 2, 4, 6 of frag_b for the first
//     instruction and words 1, 3, 5, 7 for the second.
//   - C / D: floats 0..3 of frag_c for the first instruction, floats 4..7 for
//     the second; each group is read as the accumulator and then overwritten.
//   - Metadata: the first 32-bit word of frag_m, for both instructions.
//   - psel: picks the trailing immediate of the instruction (its sparsity
//     selector): 0x0 when psel == 0, 0x1 for every other value. The choice is
//     a branch because the immediate is part of the asm text.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
                              const int psel) {
  // View every fragment as raw 32-bit registers for the asm constraints.
  const uint32_t* lhs0 = reinterpret_cast<const uint32_t*>(&a_frag0);
  const uint32_t* lhs1 = reinterpret_cast<const uint32_t*>(&a_frag1);
  const uint32_t* rhs = reinterpret_cast<const uint32_t*>(&frag_b);
  const uint32_t* meta = reinterpret_cast<const uint32_t*>(&frag_m);
  float* acc = reinterpret_cast<float*>(&frag_c);

  if (psel != 0) {
    // Selector immediate 0x1.
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
                 : "r"(lhs0[0]), "r"(lhs1[0]), "r"(lhs0[1]), "r"(lhs1[1]),
                   "r"(rhs[0]), "r"(rhs[2]), "r"(rhs[4]), "r"(rhs[6]),
                   "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]),
                   "r"(meta[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(acc[4]), "=f"(acc[5]), "=f"(acc[6]), "=f"(acc[7])
                 : "r"(lhs0[0]), "r"(lhs1[0]), "r"(lhs0[1]), "r"(lhs1[1]),
                   "r"(rhs[1]), "r"(rhs[3]), "r"(rhs[5]), "r"(rhs[7]),
                   "f"(acc[4]), "f"(acc[5]), "f"(acc[6]), "f"(acc[7]),
                   "r"(meta[0]));
  } else {
    // Selector immediate 0x0.
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
                 : "r"(lhs0[0]), "r"(lhs1[0]), "r"(lhs0[1]), "r"(lhs1[1]),
                   "r"(rhs[0]), "r"(rhs[2]), "r"(rhs[4]), "r"(rhs[6]),
                   "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]),
                   "r"(meta[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(acc[4]), "=f"(acc[5]), "=f"(acc[6]), "=f"(acc[7])
                 : "r"(lhs0[0]), "r"(lhs1[0]), "r"(lhs0[1]), "r"(lhs1[1]),
                   "r"(rhs[1]), "r"(rhs[3]), "r"(rhs[5]), "r"(rhs[7]),
                   "f"(acc[4]), "f"(acc[5]), "f"(acc[6]), "f"(acc[7]),
                   "r"(meta[0]));
  }
}
| |
|
| | |
| | |
| | |
// Three-input bitwise logic operation, emitted as one PTX `lop3.b32`.
// `lut` is the compile-time truth table handed to the instruction as its
// immediate operand. With the convention a = 0xf0, b = 0xcc, c = 0xaa, the
// table for `(a & b) | c` is (0xf0 & 0xcc) | 0xaa = 0xea.
// Returns the 32-bit result of applying that table bitwise to a, b and c.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int combined;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(combined)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return combined;
}
| |
|
// Rounds four floats to fp16 (round-to-nearest, `cvt.rn`) and packs them into
// two 32-bit words: .x holds {c0, c1} and .y holds {c2, c3}, with the first
// value of each pair in the low 16 bits.
__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
                                          float c3) {
  unsigned int pair01;
  unsigned int pair23;
  asm("{\n\t"
      ".reg .f16 a, b, c, d; \n\t"
      "cvt.rn.f16.f32 a, %2; \n\t"
      "cvt.rn.f16.f32 b, %3; \n\t"
      "cvt.rn.f16.f32 c, %4; \n\t"
      "cvt.rn.f16.f32 d, %5; \n\t"
      "mov.b32 %0, {a, b}; \n\t"
      "mov.b32 %1, {c, d}; \n\t"
      "}"
      : "=r"(pair01), "=r"(pair23)
      : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
  return uint2{pair01, pair23};
}
| |
|
| | |
| | |
// Byte permute, emitted as one PTX `prmt.b32`.
// The instruction receives `a` as its first source, the compile-time
// `start_byte` word as its second source, and `mask` as the selector that
// picks which source byte lands in each byte of the result.
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t permuted;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(permuted)
               : "r"(a), "n"(start_byte), "n"(mask));
  return permuted;
}
| |
|
| | |
| | |
| | |
| | |
// Converts four of the 4-bit fields packed in `q` into fp16 values, returned
// as two half2 lanes:
//   result[0] = fields at bits 0-3 and 16-19 of q, each minus 8
//   result[1] = fields at bits 4-7 and 20-23 of q, each minus 8
// The remaining bits of q are ignored by this call.
//
// How it works: OR-ing a masked field into the fp16 bit pattern 0x6400
// (1024.0) yields 1024 + field for the low nibbles and 1024 + 16 * field for
// the high nibbles. The low pair then subtracts 0x6408 (1032.0); the high pair
// multiplies by 0x2c00 (1/16) and adds 0xd480 (-72.0). Both give field - 8.
__device__ inline FragB dequant_4bit(int q) {
  const int kLowNibbles = 0x000f000f;
  const int kHighNibbles = 0x00f000f0;
  const int kExponent = 0x64006400;
  const int kLowBias = 0x64086408;
  const int kHighScale = 0x2c002c00;
  const int kHighBias = 0xd480d480;

  // Truth table for `(a & b) | c`, so each mask-and-merge is a single lop3.
  constexpr int kAndThenOr = (0xf0 & 0xcc) | 0xaa;
  int low_bits = lop3<kAndThenOr>(q, kLowNibbles, kExponent);
  int high_bits = lop3<kAndThenOr>(q, kHighNibbles, kExponent);

  // Reinterpret the integer bit patterns as packed fp16 pairs.
  const half2 low_pair = *reinterpret_cast<half2*>(&low_bits);
  const half2 high_pair = *reinterpret_cast<half2*>(&high_bits);
  const half2 low_bias = *reinterpret_cast<const half2*>(&kLowBias);
  const half2 high_scale = *reinterpret_cast<const half2*>(&kHighScale);
  const half2 high_bias = *reinterpret_cast<const half2*>(&kHighBias);

  FragB result;
  result[0] = __hsub2(low_pair, low_bias);
  result[1] = __hfma2(high_pair, high_scale, high_bias);
  return result;
}
| |
|
| | |
| | |
| | |
| | |
// Converts the four bytes packed in `q` into fp16 values, returned as two
// half2 lanes:
//   result[0] = (byte 0, byte 2) of q, each minus 128
//   result[1] = (byte 1, byte 3) of q, each minus 128
//
// How it works: prmt pairs every selected byte of q with a 0x64 high byte,
// producing the fp16 bit pattern 0x64bb (1024 + byte). Subtracting 0x6480
// (1152.0) then leaves byte - 128.
__device__ inline FragB dequant_8bit(int q) {
  // Selectors: q bytes {0, 2} / {1, 3} interleaved with a byte of the 0x64 fill.
  static constexpr uint32_t kSelectBytes02 = 0x5250;
  static constexpr uint32_t kSelectBytes13 = 0x5351;
  static constexpr uint32_t kFp16HighByteFill = 0x64646464;
  static constexpr uint32_t kBias1152 = 0x64806480;

  uint32_t even_bytes = prmt<kFp16HighByteFill, kSelectBytes02>(q);
  uint32_t odd_bytes = prmt<kFp16HighByteFill, kSelectBytes13>(q);

  const half2 bias = *reinterpret_cast<const half2*>(&kBias1152);

  FragB result;
  result[0] = __hsub2(*reinterpret_cast<half2*>(&even_bytes), bias);
  result[1] = __hsub2(*reinterpret_cast<half2*>(&odd_bytes), bias);
  return result;
}
| |
|
| | |
| | |
// Multiplies both half2 lanes of frag_b, in place, by a single fp16 scale.
// frag_s is viewed as a flat array of __half and element `i` of that view is
// broadcast to both halves of the multiplier.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  const __half* scale_values = reinterpret_cast<__half*>(&frag_s);
  const half2 factor = __half2half2(scale_values[i]);
#pragma unroll
  for (int lane = 0; lane < 2; ++lane) {
    frag_b[lane] = __hmul2(frag_b[lane], factor);
  }
}
| |
|
// Scales eight floats in place by fp16 factors widened to fp32.
//   c0, c1 <- s0[0].x, s0[0].y      c2, c3 <- s0[1].x, s0[1].y
//   c4, c5 <- s1[0].x, s1[0].y      c6, c7 <- s1[1].x, s1[1].y
// Every product uses __fmul_rn (round-to-nearest multiply).
__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
                                    FragS& s0, float* c4, float* c5, float* c6,
                                    float* c7, FragS& s1) {
  // First group: c0..c3 take their factors from s0.
  const half2& s0_first = s0[0];
  const half2& s0_second = s0[1];
  *c0 = __fmul_rn(*c0, __half2float(s0_first.x));
  *c1 = __fmul_rn(*c1, __half2float(s0_first.y));
  *c2 = __fmul_rn(*c2, __half2float(s0_second.x));
  *c3 = __fmul_rn(*c3, __half2float(s0_second.y));

  // Second group: c4..c7 take their factors from s1.
  const half2& s1_first = s1[0];
  const half2& s1_second = s1[1];
  *c4 = __fmul_rn(*c4, __half2float(s1_first.x));
  *c5 = __fmul_rn(*c5, __half2float(s1_first.y));
  *c6 = __fmul_rn(*c6, __half2float(s1_second.x));
  *c7 = __fmul_rn(*c7, __half2float(s1_second.y));
}
| |
|
| | } |
| |
|