Fix kwargs to fmm/fmmv
gmeanti committed Sep 21, 2023
1 parent eb51da4 commit 95c7005
Showing 3 changed files with 18 additions and 18 deletions.
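Nearly every hunk below applies the same one-line fix: the per-row keyword dictionaries (`kwargs_m1`, `kwargs_m2`) were being iterated directly inside dict comprehensions, which yields only the keys, so unpacking into `k, v` raises a `ValueError` (or, for a two-character key, silently unpacks into two characters). Adding `.items()` yields the expected `(key, value)` pairs. The remaining change removes `Tuple` from a `typing` import in `falkon/kernels/distance_kernel.py`. A minimal standalone sketch of the failure mode (the dictionary contents here are made up for illustration, not Falkon data):

```python
# Illustrative reproduction of the bug this commit fixes (not Falkon code).
kwargs_m1 = {"indices": list(range(10))}

# Iterating a dict directly yields only its keys, so `k, v` unpacking breaks:
try:
    {k: v[0:4] for k, v in kwargs_m1}
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)

# With .items() we get (key, value) pairs, which is exactly what the fix adds:
sliced = {k: v[0:4] for k, v in kwargs_m1.items()}
print(sliced)  # {'indices': [0, 1, 2, 3]}
```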
falkon/kernels/distance_kernel.py (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple, Type, Union
+from typing import Dict, Optional, Type, Union

import numpy as np
import torch
falkon/mmv_ops/fmm.py (14 changes: 7 additions & 7 deletions)
@@ -241,7 +241,7 @@ def sparse_mm_run_thread(

for j in range(0, M, m):
lenj = min(m, M - j)
-c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2}
+c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2.items()}

c_m2 = m2.narrow_rows(j, lenj).to(dtype=comp_dt)
# On CUDA the second argument to apply (a Sparse*Sparse multiplication) must be
@@ -258,7 +258,7 @@ def sparse_mm_run_thread(

for i in range(0, N, n):
leni = min(n, N - i)
-c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1.items()}

c_m1 = m1.narrow_rows(i, leni).to(dtype=comp_dt)
if dev.type == "cuda":
@@ -336,7 +336,7 @@ def mm_run_thread(

for i in range(0, N, n):
leni = min(n, N - i)
-c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1.items()}

if has_gpu_bufs:
c_dev_m1 = copy(
@@ -350,7 +350,7 @@ def mm_run_thread(

for j in range(0, M, m):
lenj = min(m, M - j)
-c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2}
+c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2.items()}

if has_gpu_bufs:
c_dev_m2 = copy(
@@ -415,12 +415,12 @@ def mm_diff_run_thread(

for i in range(0, N, n):
leni = min(n, N - i)
-c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1.items()}

c_dev_m1 = m1[i : i + leni, :].to(device=dev, dtype=comp_dt, non_blocking=True, copy=False)
for j in range(0, M, m):
lenj = min(m, M - j)
-c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2}
+c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2.items()}

c_dev_m2 = m2[j : j + lenj, :].to(device=dev, dtype=comp_dt, non_blocking=True, copy=False)
c_dev_out = kernel.compute_diff(
@@ -493,7 +493,7 @@ def run_cpu_gpu(
X1_block = X1.narrow(0, block_sizes[i], bwidth)
c_kwargs_m1 = {}
if kwargs_m1 is not None:
-c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1.items()}
args.append(
(
ArgsFmm(
falkon/mmv_ops/fmmv.py (20 changes: 10 additions & 10 deletions)
@@ -258,7 +258,7 @@ def sparse_mmv_run_thread(
s1, s2 = _init_two_streams(stack, dev, tid) # enters stream 1
for i in range(0, N, blk_n):
leni = min(blk_n, N - i)
-c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1.items()}

c_m1 = m1.narrow_rows(i, leni)
if incore: # Note that CUDA-incore is not allowed to happen (so this is CPU->CPU)
@@ -271,7 +271,7 @@ def sparse_mmv_run_thread(

for j in range(0, M, blk_m):
lenj = min(blk_m, M - j)
-c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2}
+c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2.items()}

c_m2 = m2.narrow_rows(j, lenj)
if incore: # CPU -> CPU
@@ -357,7 +357,7 @@ def mmv_run_thread(
s1, s2 = _init_two_streams(stack, dev, tid)
for i in range(0, N, blk_n):
leni = min(blk_n, N - i)
-c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1.items()}
if m1_ic:
c_dev_m1 = m1[i : i + leni, :]
else:
@@ -370,7 +370,7 @@ def mmv_run_thread(

for j in range(0, M, blk_m):
lenj = min(blk_m, M - j)
-c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2}
+c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2.items()}
if m2_ic:
c_dev_m2 = m2[j : j + lenj, :]
else:
@@ -433,14 +433,14 @@ def mmv_diff_run_thread(
s1, s2 = _init_two_streams(stack, dev, tid)
for i in range(0, N, blk_n):
leni = min(blk_n, N - i)
-c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1.items()}
c_dev_m1 = m1[i : i + leni, :].to(dev, non_blocking=True, copy=False)
c_dev_m1_g = None if grads[0] is None else grads[0][i : i + leni, :].to(dev, non_blocking=True, copy=False)
c_dev_out = out[i : i + leni, :].to(dev, non_blocking=True, copy=False)

for j in range(0, M, blk_m):
lenj = min(blk_m, M - j)
-c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2}
+c_kwargs_m2 = {k: v[j:lenj] for k, v in kwargs_m2.items()}
c_dev_m2 = m2[j : j + lenj, :].to(dev, non_blocking=True, copy=False)
c_dev_m2_g = (
None if grads[1] is None else grads[1][j : j + lenj, :].to(dev, non_blocking=True, copy=False)
@@ -655,7 +655,7 @@ def sparse_dmmv_run_thread(

for i in range(0, N, blk_n):
leni = min(blk_n, N - i)
-c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1.items()}

c_m1 = m1.narrow_rows(i, leni)
if incore: # Note that CUDA-incore is not allowed to happen (so this is CPU->CPU)
@@ -739,7 +739,7 @@ def dmmv_run_thread(
copy(v, dev_v, non_blocking=True)
for i in range(0, N, blk_n):
leni = min(blk_n, N - i)
-c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[i:leni] for k, v in kwargs_m1.items()}
if m1_ic:
c_dev_m1 = m1[i : i + leni, :]
else:
@@ -820,7 +820,7 @@ def run_cpu_gpu(
X1_block = X1.narrow(0, block_sizes[i], bwidth)
c_kwargs_m1 = {}
if kwargs_m1 is not None:
-c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1.items()}
args.append(
(
ArgsFmmv(
@@ -1098,7 +1098,7 @@ def fdmmv(
X1_block = X1.narrow(0, block_sizes[i], bwidth)
c_kwargs_m1 = {}
if kwargs_m1 is not None:
-c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1}
+c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1.items()}
args.append(
(
ArgsFmmv(
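For context on why the kwargs need slicing at all: the fmm/fmmv thread functions walk `m1` and `m2` in row blocks, and any per-row keyword tensors must be cut to the same rows as the block they accompany before each kernel call. Below is a minimal, self-contained sketch of that block-wise pattern; the function and argument names are illustrative rather than Falkon's actual API, and the slice bounds are written out explicitly for clarity.

```python
import torch

def blockwise_kernel(kernel_fn, m1, m2, blk_n=128, blk_m=128,
                     kwargs_m1=None, kwargs_m2=None):
    # Evaluate kernel_fn(m1_block, m2_block, **sliced_kwargs) block by block,
    # slicing the per-row keyword tensors together with the data blocks.
    kwargs_m1 = kwargs_m1 or {}
    kwargs_m2 = kwargs_m2 or {}
    N, M = m1.shape[0], m2.shape[0]
    out = torch.empty(N, M, dtype=m1.dtype)
    for i in range(0, N, blk_n):
        leni = min(blk_n, N - i)
        # .items() is required here: iterating the dict alone yields only keys.
        c_kwargs_m1 = {k: v[i:i + leni] for k, v in kwargs_m1.items()}
        for j in range(0, M, blk_m):
            lenj = min(blk_m, M - j)
            c_kwargs_m2 = {k: v[j:j + lenj] for k, v in kwargs_m2.items()}
            out[i:i + leni, j:j + lenj] = kernel_fn(
                m1[i:i + leni], m2[j:j + lenj], **c_kwargs_m1, **c_kwargs_m2
            )
    return out
```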
