diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index e8159f46f..bffba50db 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1284,6 +1284,7 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::F32) => "sa_u8_f32", (DType::U8, DType::F16) => "sa_u8_f16", (DType::U8, DType::BF16) => "sa_u8_bf16", + (DType::U32, DType::U32) => "sa_u32_u32", (DType::U32, DType::F32) => "sa_u32_f32", (DType::U32, DType::F16) => "sa_u32_f16", (DType::U32, DType::BF16) => "sa_u32_bf16", diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 2594689cf..7509b6280 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -219,6 +219,7 @@ GATHER_OP(gather_u32_u32, uint, uint) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) SCATTER_ADD_OP(sa_i64_f32, int64_t, float) +SCATTER_ADD_OP(sa_u32_u32, uint32_t, uint32_t) SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) SCATTER_ADD_OP(sa_i64_f16, int64_t, half)