Skip to content

Commit

Permalink
ggml: add GGML_SET Metal kernel + i32 CPU kernel (ggml/1037)
Browse files Browse the repository at this point in the history
* implemented cpu kernel

* add i32 test cases in test-backend-ops

* typedef `ggml_metal_kargs_set`

* implemented `kernel_set`

* memcpy
  • Loading branch information
PABannier authored and ggerganov committed Dec 5, 2024
1 parent c2082d9 commit a8cbab2
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 1 deletion.
80 changes: 79 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,10 @@ struct ggml_compute_state {

inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }

inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
Expand Down Expand Up @@ -8248,6 +8251,77 @@ static void ggml_compute_forward_set_f32(
}
}

static void ggml_compute_forward_set_i32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];

GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

// view src0 and dst with these strides and data offset inbytes during set
// nb0 is implicitly element_size because src0 and dst are contiguous
size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2];
size_t offset = ((int32_t *) dst->op_params)[3];
bool inplace = (bool) ((int32_t *) dst->op_params)[4];

if (!inplace) {
if (params->ith == 0) {
// memcpy needs to be synchronized across threads to avoid race conditions.
// => do it in INIT phase
memcpy(
((char *) dst->data),
((char *) src0->data),
ggml_nbytes(dst));
}
ggml_barrier(params->threadpool);
}

const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(src1);
const int nc = src1->ne[0];

GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)

// src0 and dst as viewed during set
const size_t nb0 = ggml_element_size(src0);

const int im0 = (ne10 == 0 ? 0 : ne10-1);
const int im1 = (ne11 == 0 ? 0 : ne11-1);
const int im2 = (ne12 == 0 ? 0 : ne12-1);
const int im3 = (ne13 == 0 ? 0 : ne13-1);

GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));

GGML_ASSERT(nb10 == sizeof(int32_t));

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

for (int ir = ir0; ir < ir1; ++ir) {
// src0 and dst are viewed with shape of src1 and offset
// => same indices
const int i3 = ir/(ne12*ne11);
const int i2 = (ir - i3*ne12*ne11)/ne11;
const int i1 = (ir - i3*ne12*ne11 - i2*ne11);

ggml_vec_cpy_i32(nc,
(int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
(int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
}
}

static void ggml_compute_forward_set(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
Expand All @@ -8259,6 +8333,10 @@ static void ggml_compute_forward_set(
{
ggml_compute_forward_set_f32(params, dst);
} break;
case GGML_TYPE_I32:
{
ggml_compute_forward_set_i32(params, dst);
} break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q4_0:
Expand Down
15 changes: 15 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ typedef struct {
uint64_t nb3;
} ggml_metal_kargs_cpy;

typedef struct {
int64_t ne10;
int64_t ne11;
int64_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
uint64_t offs;
bool inplace;
} ggml_metal_kargs_set;

typedef struct {
int32_t ne00;
int32_t ne01;
Expand Down
76 changes: 76 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
Expand Down Expand Up @@ -940,6 +942,8 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
Expand Down Expand Up @@ -1159,6 +1163,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
return false;
};
}
case GGML_OP_SET:
{
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_I32:
return true;
default:
return false;
};
}
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{
Expand Down Expand Up @@ -3824,6 +3838,68 @@ static void ggml_metal_encode_node(

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SET:
{
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

// src0 and dst as viewed during set
const size_t dst_nb0 = ggml_element_size(src0);

const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
const size_t offset = ((int32_t *) dst->op_params)[3];
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];

if (!inplace) {
memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst));
}

const int im0 = (ne10 == 0 ? 0 : ne10-1);
const int im1 = (ne11 == 0 ? 0 : ne11-1);
const int im2 = (ne12 == 0 ? 0 : ne12-1);
const int im3 = (ne13 == 0 ? 0 : ne13-1);

GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst));

id<MTLComputePipelineState> pipeline = nil;

switch (src0t) {
case GGML_TYPE_F32:
GGML_ASSERT(nb10 == sizeof(float));
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
case GGML_TYPE_I32:
GGML_ASSERT(nb10 == sizeof(int32_t));
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
default: GGML_ABORT("fatal error");
}

ggml_metal_kargs_set args = {
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb1 =*/ dst_nb1,
/*.nb2 =*/ dst_nb2,
/*.nb3 =*/ dst_nb3,
/*.offs =*/ offset,
/*.inplace =*/ inplace,
};

const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);

[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];

[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_POOL_2D:
{
GGML_ASSERT(ggml_is_contiguous(src0));
Expand Down
32 changes: 32 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3927,6 +3927,38 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_

#undef FA_TYPES

template<typename T>
kernel void kernel_set(
constant ggml_metal_kargs_set & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i13 = tgpig[2];
const int i12 = tgpig[1];
const int i11 = tgpig[0];

const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;

const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;

device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);

for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
dst_data[i10] = (T) src[0];
}
}

typedef decltype(kernel_set<float>) kernel_set_t;

template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;

template<typename T0, typename T1>
kernel void kernel_cpy(
constant ggml_metal_kargs_cpy & args,
Expand Down
4 changes: 4 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3521,6 +3521,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
}

for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
}

for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (ggml_type type_dst : all_types) {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
Expand Down

0 comments on commit a8cbab2

Please sign in to comment.