Repeated matrix calculation in predict() inflates runtime #236

Open
Max-Bladen opened this issue Sep 7, 2022 · 1 comment
Labels
bug Something isn't working wip work-in-progress

Comments

@Max-Bladen
Collaborator

🐞 Describe the bug:

Referring to the following formula used to make predictions via a (s)PLS model:

$Y = XW(P'W)^{-1}C' + Y_r$

In the academic literature, the $W(P'W)^{-1}$ matrix is referred to as $W^*$. Looking at the following lines (593-603 of predict()), we can see $W^*$ being calculated three times within a loop. This matrix is constant, as $W$, $P$ and $C$ are all outputs of the model used to make the predictions. Hence, this repeated calculation is inefficient and inflates runtime significantly.

```r
Ypred = lapply(1 : ncomp[i], function(x){concat.newdata[[i]] %*% Wmat[, 1:x] %*% solve(t(Pmat[, 1:x]) %*% Wmat[, 1:x]) %*% t(Cmat)[1:x, ]})
Ypred = sapply(Ypred, function(x){x*sigma.Y + means.Y}, simplify = "array")

Y.hat[[i]] = array(Ypred, c(nrow(newdata[[i]]), ncol(Y), ncomp[i])) # in case one observation and only one Y, we need array() to keep it an array with a third dimension being ncomp

t.pred[[i]] = concat.newdata[[i]] %*% Wmat %*% solve(t(Pmat) %*% Wmat)
t.pred[[i]] = matrix(data = sapply(1:ncol(t.pred[[i]]),
                                   function(x) {t.pred[[i]][, x] * apply(variatesX[[i]], 2,
                                                                         function(y){(norm(y, type = "2"))^2})[x]}), nrow = nrow(concat.newdata[[i]]), ncol = ncol(t.pred[[i]]))

B.hat[[i]] = sapply(1 : ncomp[i], function(x){Wmat[, 1:x] %*% solve(t(Pmat[, 1:x]) %*% Wmat[, 1:x]) %*% t(Cmat)[1:x, ]}, simplify = "array")
```
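For reference, the quantities in this code can be written as follows, in the scaled space (before the predictions are multiplied by sigma.Y and shifted by means.Y, and before t.pred is rescaled by the squared norms of variatesX). Here $X_{new}$ stands for concat.newdata[[i]] and $x$ is the number of components. The last identity, which would let the first $x$ columns of a single precomputed $W^*$ replace the per-$x$ inverse, assumes NIPALS-style deflation of $X$, under which $P'W$ is upper triangular:

$$
\begin{aligned}
W^* &= W\,(P'W)^{-1}, \qquad \hat{B}_x = W^*_{[\,,1:x]}\,C'_{[1:x,\,]}, \qquad \hat{Y}_x = X_{new}\,\hat{B}_x, \qquad \hat{T} = X_{new}\,W^*, \\
W_{[\,,1:x]}&\left(P_{[\,,1:x]}'\,W_{[\,,1:x]}\right)^{-1} = W^*_{[\,,1:x]}.
\end{aligned}
$$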

🤔 Expected behavior:
The same calculation should not be performed three times within a loop.


💡 Possible solution:
A simple, initial fix will be implemented first: performing the calculation once before this loop and referring to the stored object should hopefully reduce runtime.
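The simple fix can be sketched in base R. This is a minimal, self-contained illustration on random data, not mixOmics code: nipals_pls() is a hypothetical stand-in for a fitted (s)PLS model and assumes NIPALS-style deflation of X. It checks that slicing a $W^*$ computed once before the loop reproduces the per-component B.hat that is currently recomputed for every x.

```r
set.seed(1)
n <- 30; p <- 8; q <- 3; ncomp <- 3

# Random training data (illustrative only)
X <- scale(matrix(rnorm(n * p), n, p))
Y <- scale(X[, 1:q] %*% matrix(rnorm(q * q), q, q) + matrix(rnorm(n * q, sd = 0.5), n, q))

# Hypothetical helper: minimal NIPALS PLS2 with X deflation (not the mixOmics implementation)
nipals_pls <- function(X, Y, ncomp, tol = 1e-12, max.iter = 500) {
  W <- matrix(0, ncol(X), ncomp)
  P <- matrix(0, ncol(X), ncomp)
  C <- matrix(0, ncol(Y), ncomp)
  Xa <- X; Ya <- Y
  for (a in seq_len(ncomp)) {
    u <- Ya[, 1]
    for (iter in seq_len(max.iter)) {
      w <- crossprod(Xa, u); w <- w / sqrt(sum(w^2))  # X weights
      tvec <- Xa %*% w                                 # X scores
      cvec <- crossprod(Ya, tvec) / sum(tvec^2)        # Y loadings
      u.new <- Ya %*% cvec / sum(cvec^2)               # Y scores
      converged <- sum((u.new - u)^2) < tol
      u <- u.new
      if (converged) break
    }
    pvec <- crossprod(Xa, tvec) / sum(tvec^2)          # X loadings
    Xa <- Xa - tvec %*% t(pvec)                        # deflate X
    Ya <- Ya - tvec %*% t(cvec)                        # deflate Y
    W[, a] <- w; P[, a] <- pvec; C[, a] <- cvec
  }
  list(W = W, P = P, C = C)
}

fit <- nipals_pls(X, Y, ncomp)
Wmat <- fit$W; Pmat <- fit$P; Cmat <- fit$C

# Current approach: W* is rebuilt from scratch for every number of components x
B.old <- sapply(1:ncomp, function(x) {
  Wx <- Wmat[, 1:x, drop = FALSE]
  Px <- Pmat[, 1:x, drop = FALSE]
  Wx %*% solve(t(Px) %*% Wx) %*% t(Cmat)[1:x, , drop = FALSE]
}, simplify = "array")

# Hoisted approach: W* is computed once, before the loop, and then sliced
W.star <- Wmat %*% solve(t(Pmat) %*% Wmat)
B.new <- sapply(1:ncomp, function(x) {
  W.star[, 1:x, drop = FALSE] %*% t(Cmat)[1:x, , drop = FALSE]
}, simplify = "array")

stopifnot(isTRUE(all.equal(B.old, B.new)))
```

The slicing is only valid because $P'W$ is upper triangular under this deflation scheme; with arbitrary $W$ and $P$ the two arrays would differ for x < ncomp.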

A more involved solution would be to adjust the output of our various functions (e.g. spls(), splsda()) to return these key matrices. This would allow users to understand the model and use it for their own purposes, as well as decrease runtime by reducing the number of matrix calculations.


Credit:
This issue was reported by @psalguerog. I greatly appreciate you bringing this to my attention.

@Max-Bladen Max-Bladen added the bug Something isn't working label Sep 7, 2022
@Max-Bladen Max-Bladen self-assigned this Sep 7, 2022
@Max-Bladen Max-Bladen added the wip work-in-progress label Sep 7, 2022
@Max-Bladen
Collaborator Author

Max-Bladen commented Sep 7, 2022

Here's a brief summary of the work I've done so far, specifically regarding the inflated runtime.

The old code can be seen above. It was suspected to be inefficient because $W^*$ was seemingly calculated three times:

  • This is found in the first Ypred line and the B.hat line
    • as Wmat[, 1:x] %*% solve(t(Pmat[, 1:x]) %*% Wmat[, 1:x])
  • It is also found in the first t.pred line
    • as Wmat %*% solve(t(Pmat) %*% Wmat)

I adjusted the code as follows:

```r
W.star <- Wmat %*% solve(t(Pmat) %*% Wmat)

B.hat[[i]] = sapply(1 : ncomp[i], function(x){matrix(W.star[, 1:x], ncol=x) %*% t(Cmat)[1:x, ]}, simplify = "array")

# Prediction Y.hat, B.hat and t.pred
Ypred = lapply(1 : ncomp[i], function(x){concat.newdata[[i]] %*% B.hat[[i]][,,x]})
Ypred = sapply(Ypred, function(x){x*sigma.Y + means.Y}, simplify = "array")

Y.hat[[i]] = array(Ypred, c(nrow(newdata[[i]]), ncol(Y), ncomp[i])) # in case one observation and only one Y, we need array() to keep it an array with a third dimension being ncomp

t.pred[[i]] = concat.newdata[[i]] %*% W.star
t.pred[[i]] = matrix(data = sapply(1:ncol(t.pred[[i]]),
                                   function(x) {t.pred[[i]][, x] * apply(variatesX[[i]], 2,
                                                                         function(y){(norm(y, type = "2"))^2})[x]}), nrow = nrow(concat.newdata[[i]]), ncol = ncol(t.pred[[i]]))
```
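Separately from $W^*$, the t.pred step still repeats work inside its sapply(): for every column x, apply() recomputes the squared norms of all columns of variatesX[[i]] and then keeps only the x-th one. The sketch below hoists that too. The names follow the code above, but t.pred and variatesX are random stand-ins, not real model output. With only a few components this is probably cheap, so it is not offered as an explanation of the benchmark results.

```r
set.seed(42)
n.train <- 50; n.new <- 10; ncomp <- 4

# Random stand-ins: variatesX plays the training X-variates,
# t.pred plays concat.newdata %*% W.star
variatesX <- matrix(rnorm(n.train * ncomp), n.train, ncomp)
t.pred <- matrix(rnorm(n.new * ncomp), n.new, ncomp)

# Current form: the squared 2-norms of *all* variates are recomputed for each column x
old <- matrix(data = sapply(1:ncol(t.pred), function(x) {
  t.pred[, x] * apply(variatesX, 2, function(y) (norm(y, type = "2"))^2)[x]
}), nrow = nrow(t.pred), ncol = ncol(t.pred))

# Hoisted form: compute the squared column norms once, then scale each column of t.pred
sq.norms <- colSums(variatesX^2)
new <- sweep(t.pred, 2, sq.norms, "*")

stopifnot(isTRUE(all.equal(old, new)))
```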

The main differences are:

  • Calculating $W^*$ at the start, as it is used multiple times
  • Calculating B.hat next, as it is used to calculate Ypred. Previously it was effectively calculated twice; now it is calculated only once
  • Replacing the "long form" of $W^*$ with the precomputed W.star in every calculation that requires it

Since this reduces the number of required matrix multiplications, I assumed it would reduce runtime. Using:

```r
library(mixOmics)
data(liver.toxicity)

X <- liver.toxicity$gene
Y <- liver.toxicity$clinic
```

as training data, and randomly generating 100 samples with the same number of columns as X (each drawn from a different normal distribution) as testing data, I ran the default predict() function and the adjusted predict() function 5000 times and evaluated their runtimes. For peace of mind, the predictions of the two versions of predict() were also checked in every iteration to be equal (to 10 significant figures). Histograms of the runtimes can be seen below:

[Images: histograms of the runtimes of the default and adjusted predict()]
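A benchmark loop of the kind described above can be sketched in base R. This is a hypothetical, scaled-down stand-in rather than the actual test script: predict_old() and predict_new() only reproduce the full-component Ypred product of the two versions, Wmat, Pmat and Cmat are random matrices instead of a fitted model, and "each drawn from a different normal distribution" is assumed here to mean one normal distribution per column.

```r
set.seed(2022)
p <- 500; q <- 10; ncomp <- 2; n.test <- 100; n.reps <- 200

# Random stand-ins for the fitted model matrices (hypothetical)
Wmat <- matrix(rnorm(p * ncomp), p, ncomp)
Pmat <- matrix(rnorm(p * ncomp), p, ncomp)
Cmat <- matrix(rnorm(q * ncomp), q, ncomp)

# 100 test samples; each column drawn from its own normal distribution
col.means <- runif(p, -5, 5)
col.sds <- runif(p, 0.5, 3)
newdata <- sapply(seq_len(p), function(j) rnorm(n.test, col.means[j], col.sds[j]))

# Hypothetical simplified versions of the old and adjusted Ypred calculation
predict_old <- function(newdata) {
  newdata %*% Wmat %*% solve(t(Pmat) %*% Wmat) %*% t(Cmat)
}
predict_new <- function(newdata) {
  W.star <- Wmat %*% solve(t(Pmat) %*% Wmat)
  B.hat <- W.star %*% t(Cmat)
  newdata %*% B.hat
}

time.old <- numeric(n.reps)
time.new <- numeric(n.reps)
for (r in seq_len(n.reps)) {
  time.old[r] <- system.time(y.old <- predict_old(newdata))[["elapsed"]]
  time.new[r] <- system.time(y.new <- predict_new(newdata))[["elapsed"]]
  # Check that both versions agree on every iteration
  stopifnot(isTRUE(all.equal(y.old, y.new, tolerance = 1e-10)))
}

# Compare the two runtime distributions (hist() could be used as in the images above)
print(summary(time.old))
print(summary(time.new))
```

system.time() has coarse resolution for operations this fast; a package such as bench or microbenchmark gives finer-grained timings.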

As the histograms show, the runtime did not improve at all; if anything, it got worse. The same was true when using subsets of the liver.toxicity$gene data for training and testing. This result seems counterintuitive.
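One possible factor, offered only as an untested hypothesis: R's `%*%` is left-associative, so the old Ypred chain first multiplies the n × p test matrix by the narrow p × x slice of Wmat, whereas the adjusted version multiplies it by the p × q matrix B.hat[[i]][,,x]. When q is larger than the number of components, the adjusted form can need more arithmetic despite containing fewer `%*%` calls. The sketch below only counts multiply-adds; the sizes are assumptions meant to resemble the benchmark above and should be replaced with the real dimensions.

```r
# Multiply-add count of an (a x b) %*% (b x c) product
mm_cost <- function(a, b, c) a * b * c

# Assumed sizes: test samples, predictors, responses, components
n <- 100; p <- 3116; q <- 10; ncomp <- 2

# Old Ypred: ((newdata %*% W[, 1:x]) %*% solve(...)) %*% t(C)[1:x, ]
# (forming and inverting the small x-by-x matrix P'W is ignored)
old.ypred <- sum(sapply(1:ncomp, function(x) {
  mm_cost(n, p, x) + mm_cost(n, x, x) + mm_cost(n, x, q)
}))

# Adjusted Ypred: newdata %*% B.hat[, , x], where B.hat[, , x] is p x q.
# B.hat is built in both versions, so its own cost is not counted here.
new.ypred <- sum(sapply(1:ncomp, function(x) {
  mm_cost(n, p, q)
}))

print(c(old = old.ypred, new = new.ypred, ratio = new.ypred / old.ypred))
```

For these assumed sizes the adjusted Ypred needs several times more multiply-adds than the old one. The count ignores BLAS efficiency, memory allocation and R-level overhead, so whether it actually dominates the measured runtime would need profiling, for example with Rprof().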

Max-Bladen added a commit that referenced this issue Sep 8, 2022
perf: added `custom.predict()` function to experiment in the improvement of the runtime of `predict()`