Added functorch to functional_autograd_benchmark
Description:

- Following pytorch/functorch#497, this adds an option to run the benchmarks with functorch and compare the results against the original `torch.autograd.functional` results (a minimal sketch of the two code paths is shown just below).
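
For the `vjp` task, for example, the two code paths being compared look roughly like this (a minimal sketch on a toy function `f`; names and shapes here are illustrative only, the benchmark itself runs the full models and tasks listed below):

```python
import torch
import functorch as ft
from torch.autograd.functional import vjp as af_vjp

def f(x):
    return (x ** 3).sum(dim=1)

x = torch.randn(8, 4)
v = torch.randn(8)  # cotangent, same shape as f(x)

# Existing path: torch.autograd.functional returns (output, vjp) in one call.
out_af, vjp_af = af_vjp(f, x, v=v)

# functorch path: vjp returns (output, vjp_fn); the cotangent is applied afterwards.
out_ft, vjp_fn = ft.vjp(f, x)
(vjp_ft,) = vjp_fn(v)

torch.testing.assert_close(vjp_ft, vjp_af)
```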

Running the benchmark, we get the table below (times are in seconds):
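
Each `mean`/`var` pair is wall-clock time per call, computed over `--num-iters` timed runs after one warmup call; a self-contained sketch of that measurement on a toy function (illustrative only — the real benchmark also synchronizes the GPU around each timed call):

```python
import time
import torch
from torch.autograd.functional import vjp

def f(x):
    return (x ** 2).sum(dim=1)

x = torch.randn(64, 128)
v = torch.randn(64)

vjp(f, x, v=v)  # warmup call, mirroring the benchmark's warmup

elapsed = []
for _ in range(10):  # the benchmark uses --num-iters timed runs
    start = time.time()
    vjp(f, x, v=v)
    elapsed.append(time.time() - start)

runtimes = torch.tensor(elapsed)
print(runtimes.mean().item(), runtimes.var().item())  # the "mean" and "var" columns
```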

<details>
<summary>
Table
</summary>

```
| model | task | mean | var |
| -- | -- | -- | -- |
| resnet18 | vjp | 0.03826599195599556 | 4.3332115637895186e-06 |
| resnet18 | functorch vjp | 0.037201929837465286 | 6.139693198292662e-09 |
| resnet18 | vhp | 0.2202976644039154 | 2.8687209052691287e-08 |
| resnet18 | functorch vhp | 0.22117868065834045 | 4.108771278765744e-08 |
| resnet18 | jvp | 0.18679651618003845 | 1.832455254202614e-08 |
| resnet18 | functorch jvp | 0.05305683612823486 | 1.6690266946284282e-08 |
| fcn_resnet | vjp | 0.6071907877922058 | 7.436695454998699e-07 |
| fcn_resnet | functorch vjp | 0.6115708947181702 | 1.121692207561864e-06 |
| fcn_resnet | vhp | 3.419469118118286 | 0.020633839070796967 |
| fcn_resnet | jvp | 2.5421929359436035 | 3.1765587209520163e-06 |
| fcn_resnet | functorch jvp | 0.7628333568572998 | 1.4555752159139956e-07 |
| detr | vjp | 0.19494840502738953 | 1.9122715457342565e-05 |
| detr | vhp | 1.1664292812347412 | 0.000948643428273499 |
| detr | jvp | 0.9990308880805969 | 1.0214127541985363e-05 |
| ppl_simple_reg | vjp | 0.0007535457843914628 | 6.024204690646684e-09 |
| ppl_simple_reg | functorch vjp | 0.0016954183811321855 | 1.160151974488599e-08 |
| ppl_simple_reg | vhp | 0.0011888503795489669 | 5.93119386937957e-10 |
| ppl_simple_reg | functorch vhp | 0.0026826143730431795 | 1.6787025103326414e-08 |
| ppl_simple_reg | jvp | 0.001067900680936873 | 7.409912128331086e-10 |
| ppl_simple_reg | functorch jvp | 0.002065300941467285 | 9.710328185974504e-08 |
| ppl_simple_reg | hvp | 0.001212477684020996 | 1.974137298077494e-09 |
| ppl_simple_reg | functorch hvp | 0.00482442369684577 | 2.327668653379078e-07 |
| ppl_simple_reg | jacobian | 0.0009108781814575195 | 3.489469158068914e-09 |
| ppl_simple_reg | functorch jacobian | 0.0019866942893713713 | 1.938326299466553e-08 |
| ppl_simple_reg | hessian | 0.005053090862929821 | 3.370298600202659e-07 |
| ppl_simple_reg | functorch hessian | 0.006374978926032782 | 7.556796077778927e-08 |
| ppl_simple_reg | hessian_fwdrev | 0.0036706924438476562 | 1.996075527088692e-09 |
| ppl_simple_reg | functorch hessian_fwdrev | 0.0058908225037157536 | 7.548283775804521e-08 |
| ppl_simple_reg | hessian_revrev | 0.0015769004821777344 | 1.5754418214442012e-08 |
| ppl_simple_reg | functorch hessian_revrev | 0.0041002752259373665 | 6.713568723171193e-08 |
| ppl_simple_reg | jacfwd | 0.0018048763740807772 | 2.7375660849315864e-08 |
| ppl_simple_reg | functorch jacfwd | 0.002047991845756769 | 2.432247070416338e-09 |
| ppl_simple_reg | jacrev | 0.0009733677143231034 | 1.0078769818733235e-08 |
| ppl_simple_reg | functorch jacrev | 0.0021971464157104492 | 1.2729884701911942e-08 |
| ppl_robust_reg | vjp | 0.005820560269057751 | 8.582588151284654e-08 |
| ppl_robust_reg | functorch vjp | 0.00796132069081068 | 9.663100541956737e-09 |
| ppl_robust_reg | vhp | 0.009825301356613636 | 2.0081762386325863e-07 |
| ppl_robust_reg | functorch vhp | 0.014890861697494984 | 4.558066279969353e-07 |
| ppl_robust_reg | jvp | 0.008297419175505638 | 2.9454400873873965e-07 |
| ppl_robust_reg | functorch jvp | 0.008052706718444824 | 7.120377176761394e-08 |
| ppl_robust_reg | hvp | 0.015414690598845482 | 7.42123745567369e-07 |
| ppl_robust_reg | functorch hvp | 0.02699306048452854 | 1.4650488537881756e-06 |
| ppl_robust_reg | jacobian | 0.006207776255905628 | 1.7068457225377642e-07 |
| ppl_robust_reg | functorch jacobian | 0.009173822589218616 | 1.2214455580306094e-07 |
| ppl_robust_reg | hessian | 0.04670915752649307 | 1.4299343092716299e-05 |
| ppl_robust_reg | functorch hessian | 0.02337808534502983 | 3.0397418413485866e-06 |
| ppl_robust_reg | hessian_fwdrev | 0.024229884147644043 | 2.0425247839739313e-06 |
| ppl_robust_reg | functorch hessian_fwdrev | 0.022021746262907982 | 3.512146236062108e-07 |
| ppl_robust_reg | hessian_revrev | 0.012355780228972435 | 7.090877147675201e-07 |
| ppl_robust_reg | functorch hessian_revrev | 0.013960313983261585 | 6.326549737423193e-07 |
| ppl_robust_reg | jacfwd | 0.008112502284348011 | 2.88503088086145e-08 |
| ppl_robust_reg | functorch jacfwd | 0.008947920985519886 | 4.2070990247111695e-08 |
| ppl_robust_reg | jacrev | 0.00635871896520257 | 1.3403841592207755e-07 |
| ppl_robust_reg | functorch jacrev | 0.009123563766479492 | 2.677554675756255e-07 |
| wav2letter | vjp | 0.02078995667397976 | 2.1110793113621185e-06 |
| wav2letter | functorch vjp | 0.019202351570129395 | 9.210506135559626e-09 |
| wav2letter | vhp | 0.05997290462255478 | 8.558587616391833e-09 |
| wav2letter | functorch vhp | 0.06035261228680611 | 1.6448565842708263e-09 |
| wav2letter | jvp | 0.04507789760828018 | 1.5771547401399744e-09 |
| wav2letter | functorch jvp | 0.013057494536042213 | 3.804750292601966e-09 |
| deepspeech | vjp | 0.3648746609687805 | 1.5359055396402255e-05 |
| transformer | vjp | 0.05496881157159805 | 1.242562319703211e-08 |
| transformer | functorch vjp | 0.057835936546325684 | 2.6113376350167528e-08 |
| transformer | vhp | 0.18313491344451904 | 7.226336151688884e-08 |
| transformer | jvp | 0.13924935460090637 | 1.6989159234981344e-07 |
| multiheadattn | vjp | 0.0014708995586261153 | 3.710916729460223e-08 |
| multiheadattn | functorch vjp | 0.002404856728389859 | 2.1910574687922235e-08 |
| multiheadattn | vhp | 0.003382015274837613 | 5.3098595742540056e-08 |
| multiheadattn | functorch vhp | 0.005340623669326305 | 5.897558708056749e-08 |
| multiheadattn | jvp | 0.0027526854537427425 | 3.508620949332908e-08 |
| multiheadattn | functorch jvp | 0.0022981404326856136 | 1.327894807445773e-07 |

```

</details>

<details>
<summary>
Stdout
</summary>

```
Found functorch: 0.2.0a0+386a541
Results for model resnet18 on task vjp: 0.03826599195599556s (var: 4.3332115637895186e-06)
Results for model resnet18 on task vjp using Functorch: 0.037201929837465286s (var: 6.139693198292662e-09)
Results for model resnet18 on task vhp: 0.2202976644039154s (var: 2.8687209052691287e-08)
Results for model resnet18 on task vhp using Functorch: 0.22117868065834045s (var: 4.108771278765744e-08)
Results for model resnet18 on task jvp: 0.18679651618003845s (var: 1.832455254202614e-08)
Results for model resnet18 on task jvp using Functorch: 0.05305683612823486s (var: 1.6690266946284282e-08)
Results for model fcn_resnet on task vjp: 0.6071907877922058s (var: 7.436695454998699e-07)
Results for model fcn_resnet on task vjp using Functorch: 0.6115708947181702s (var: 1.121692207561864e-06)
Results for model fcn_resnet on task vhp: 3.419469118118286s (var: 0.020633839070796967)
Failed model using Functorch: fcn_resnet, task: vhp, Error message:
	 CUDA out of memory. Tried to allocate 114.00 MiB (GPU 0; 47.46 GiB total capacity; 45.62 GiB already allocated; 5.31 MiB free; 46.02 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Results for model fcn_resnet on task jvp: 2.5421929359436035s (var: 3.1765587209520163e-06)
Results for model fcn_resnet on task jvp using Functorch: 0.7628333568572998s (var: 1.4555752159139956e-07)
Results for model detr on task vjp: 0.19494840502738953s (var: 1.9122715457342565e-05)
Failed model using Functorch: detr, task: vjp, Error message:
	 Cannot access data pointer of Tensor that doesn't have storage
Results for model detr on task vhp: 1.1664292812347412s (var: 0.000948643428273499)
Failed model using Functorch: detr, task: vhp, Error message:
	 Cannot access data pointer of Tensor that doesn't have storage
Results for model detr on task jvp: 0.9990308880805969s (var: 1.0214127541985363e-05)
Failed model using Functorch: detr, task: jvp, Error message:
	 Trying to use forward AD with _cdist_forward that does not support it because it has not been implemented yet.
Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation.
Results for model ppl_simple_reg on task vjp: 0.0007535457843914628s (var: 6.024204690646684e-09)
Results for model ppl_simple_reg on task vjp using Functorch: 0.0016954183811321855s (var: 1.160151974488599e-08)
Results for model ppl_simple_reg on task vhp: 0.0011888503795489669s (var: 5.93119386937957e-10)
Results for model ppl_simple_reg on task vhp using Functorch: 0.0026826143730431795s (var: 1.6787025103326414e-08)
Results for model ppl_simple_reg on task jvp: 0.001067900680936873s (var: 7.409912128331086e-10)
Results for model ppl_simple_reg on task jvp using Functorch: 0.002065300941467285s (var: 9.710328185974504e-08)
Results for model ppl_simple_reg on task hvp: 0.001212477684020996s (var: 1.974137298077494e-09)
Results for model ppl_simple_reg on task hvp using Functorch: 0.00482442369684577s (var: 2.327668653379078e-07)
Results for model ppl_simple_reg on task jacobian: 0.0009108781814575195s (var: 3.489469158068914e-09)
Results for model ppl_simple_reg on task jacobian using Functorch: 0.0019866942893713713s (var: 1.938326299466553e-08)
Results for model ppl_simple_reg on task hessian: 0.005053090862929821s (var: 3.370298600202659e-07)
Results for model ppl_simple_reg on task hessian using Functorch: 0.006374978926032782s (var: 7.556796077778927e-08)
Results for model ppl_simple_reg on task hessian_fwdrev: 0.0036706924438476562s (var: 1.996075527088692e-09)
Results for model ppl_simple_reg on task hessian_fwdrev using Functorch: 0.0058908225037157536s (var: 7.548283775804521e-08)
Results for model ppl_simple_reg on task hessian_revrev: 0.0015769004821777344s (var: 1.5754418214442012e-08)
Results for model ppl_simple_reg on task hessian_revrev using Functorch: 0.0041002752259373665s (var: 6.713568723171193e-08)
Results for model ppl_simple_reg on task jacfwd: 0.0018048763740807772s (var: 2.7375660849315864e-08)
Results for model ppl_simple_reg on task jacfwd using Functorch: 0.002047991845756769s (var: 2.432247070416338e-09)
Results for model ppl_simple_reg on task jacrev: 0.0009733677143231034s (var: 1.0078769818733235e-08)
Results for model ppl_simple_reg on task jacrev using Functorch: 0.0021971464157104492s (var: 1.2729884701911942e-08)
Results for model ppl_robust_reg on task vjp: 0.005820560269057751s (var: 8.582588151284654e-08)
Results for model ppl_robust_reg on task vjp using Functorch: 0.00796132069081068s (var: 9.663100541956737e-09)
Results for model ppl_robust_reg on task vhp: 0.009825301356613636s (var: 2.0081762386325863e-07)
Results for model ppl_robust_reg on task vhp using Functorch: 0.014890861697494984s (var: 4.558066279969353e-07)
Results for model ppl_robust_reg on task jvp: 0.008297419175505638s (var: 2.9454400873873965e-07)
Results for model ppl_robust_reg on task jvp using Functorch: 0.008052706718444824s (var: 7.120377176761394e-08)
Results for model ppl_robust_reg on task hvp: 0.015414690598845482s (var: 7.42123745567369e-07)
Results for model ppl_robust_reg on task hvp using Functorch: 0.02699306048452854s (var: 1.4650488537881756e-06)
Results for model ppl_robust_reg on task jacobian: 0.006207776255905628s (var: 1.7068457225377642e-07)
Results for model ppl_robust_reg on task jacobian using Functorch: 0.009173822589218616s (var: 1.2214455580306094e-07)
Results for model ppl_robust_reg on task hessian: 0.04670915752649307s (var: 1.4299343092716299e-05)
Results for model ppl_robust_reg on task hessian using Functorch: 0.02337808534502983s (var: 3.0397418413485866e-06)
Results for model ppl_robust_reg on task hessian_fwdrev: 0.024229884147644043s (var: 2.0425247839739313e-06)
Results for model ppl_robust_reg on task hessian_fwdrev using Functorch: 0.022021746262907982s (var: 3.512146236062108e-07)
Results for model ppl_robust_reg on task hessian_revrev: 0.012355780228972435s (var: 7.090877147675201e-07)
Results for model ppl_robust_reg on task hessian_revrev using Functorch: 0.013960313983261585s (var: 6.326549737423193e-07)
Results for model ppl_robust_reg on task jacfwd: 0.008112502284348011s (var: 2.88503088086145e-08)
Results for model ppl_robust_reg on task jacfwd using Functorch: 0.008947920985519886s (var: 4.2070990247111695e-08)
Results for model ppl_robust_reg on task jacrev: 0.00635871896520257s (var: 1.3403841592207755e-07)
Results for model ppl_robust_reg on task jacrev using Functorch: 0.009123563766479492s (var: 2.677554675756255e-07)
Results for model wav2letter on task vjp: 0.02078995667397976s (var: 2.1110793113621185e-06)
Results for model wav2letter on task vjp using Functorch: 0.019202351570129395s (var: 9.210506135559626e-09)
Results for model wav2letter on task vhp: 0.05997290462255478s (var: 8.558587616391833e-09)
Results for model wav2letter on task vhp using Functorch: 0.06035261228680611s (var: 1.6448565842708263e-09)
Results for model wav2letter on task jvp: 0.04507789760828018s (var: 1.5771547401399744e-09)
Results for model wav2letter on task jvp using Functorch: 0.013057494536042213s (var: 3.804750292601966e-09)
Results for model deepspeech on task vjp: 0.3648746609687805s (var: 1.5359055396402255e-05)
Failed model using Functorch: deepspeech, task: vjp, Error message:
	 Cannot access storage of TensorWrapper
Results for model transformer on task vjp: 0.05496881157159805s (var: 1.242562319703211e-08)
Results for model transformer on task vjp using Functorch: 0.057835936546325684s (var: 2.6113376350167528e-08)
Results for model transformer on task vhp: 0.18313491344451904s (var: 7.226336151688884e-08)
Failed model using Functorch: transformer, task: vhp, Error message:
	 bad optional access
Results for model transformer on task jvp: 0.13924935460090637s (var: 1.6989159234981344e-07)
Failed model using Functorch: transformer, task: jvp, Error message:
	 Trying to use forward AD with embedding that does not support it because it has not been implemented yet.
Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation.
Results for model multiheadattn on task vjp: 0.0014708995586261153s (var: 3.710916729460223e-08)
Results for model multiheadattn on task vjp using Functorch: 0.002404856728389859s (var: 2.1910574687922235e-08)
Results for model multiheadattn on task vhp: 0.003382015274837613s (var: 5.3098595742540056e-08)
Results for model multiheadattn on task vhp using Functorch: 0.005340623669326305s (var: 5.897558708056749e-08)
Results for model multiheadattn on task jvp: 0.0027526854537427425s (var: 3.508620949332908e-08)
Results for model multiheadattn on task jvp using Functorch: 0.0022981404326856136s (var: 1.327894807445773e-07)
```

</details>

All of the functorch errors above have been reported in the functorch repository.

cc @zou3519
Pull Request resolved: #75689
Approved by: https://github.com/zou3519
vfdev-5 authored and pytorchmergebot committed Apr 22, 2022
1 parent b3aa2de commit 6593d29
Showing 5 changed files with 170 additions and 8 deletions.
19 changes: 19 additions & 0 deletions benchmarks/functional_autograd_benchmark/README.md
@@ -20,6 +20,10 @@ export OMP_NUM_THREADS=10
git checkout master
python setup.py develop

# Install dependencies:
# Scipy is required by detr
pip install scipy

# Run the benchmark for the base
# This will use the GPU if available.
pushd benchmarks/functional_autograd_benchmark
@@ -46,3 +50,18 @@ popd
- `compare.py` is the entry point to run the comparison script that generates a markdown table.
- `torchaudio_models.py` and `torchvision_models.py` contain code extracted from torchaudio and torchvision so that the models can be run without having a specific version of these libraries installed.
- `ppl_models.py`, `vision_models.py` and `audio_text_models.py` contain all the getter functions used for the benchmark.


### Benchmarking against `functorch`

```bash
# Install stable functorch:
pip install functorch
# or install from source:
pip install git+https://github.com/pytorch/functorch

# Run the benchmark for the base
# This will use the GPU if available.
pushd benchmarks/functional_autograd_benchmark
python functional_autograd_benchmark.py --output bench-with-functorch.txt
```
17 changes: 16 additions & 1 deletion benchmarks/functional_autograd_benchmark/audio_text_models.py
@@ -3,7 +3,11 @@

import torchaudio_models as models

from utils import extract_weights, load_weights, GetterReturnType
from utils import check_for_functorch, extract_weights, load_weights, GetterReturnType


has_functorch = check_for_functorch()


def get_wav2letter(device: torch.device) -> GetterReturnType:
    N = 10
@@ -50,6 +54,12 @@ def get_deepspeech(device: torch.device) -> GetterReturnType:

    model = models.DeepSpeech(rnn_type=nn.LSTM, labels=labels, rnn_hidden_size=1024, nb_layers=5,
                              audio_conf=audio_conf, bidirectional=True)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    model = model.to(device)
    criterion = nn.CTCLoss()
    params, names = extract_weights(model)
@@ -71,6 +81,11 @@ def get_transformer(device: torch.device) -> GetterReturnType:
    ntoken = 50
    model = models.TransformerModel(ntoken=ntoken, ninp=720, nhead=12, nhid=2048, nlayers=2)
    model.to(device)

    if has_functorch:
        # disable dropout for consistency checking
        model.eval()

    criterion = nn.NLLLoss()
    params, names = extract_weights(model)

106 changes: 102 additions & 4 deletions benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py
@@ -6,6 +6,13 @@
from collections import defaultdict
from typing import NamedTuple, Callable, List, Any

try:
    import functorch as ft
    has_functorch = True
    print(f"Found functorch: {ft.__version__}")
except ImportError:
    has_functorch = False

import ppl_models
import vision_models
import audio_text_models
@@ -36,6 +43,65 @@ def jacrev(model, inp, strict=None):
    else:
        return getattr(functional, task)

def get_task_functorch(task: str) -> Callable:

    @torch.no_grad()
    def vjp(model, inp, v=None, strict=None):
        assert v is not None
        out, vjpfunc = ft.vjp(model, *inp)
        return out, vjpfunc(v)

    @torch.no_grad()
    def jvp(model, inp, v=None, strict=None):
        assert v is not None
        return ft.jvp(model, inp, v)

    @torch.no_grad()
    def vhp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, vjpfunc, aux = ft.vjp(ft.grad_and_value(model, argnums), *inp, has_aux=True)
        return aux, vjpfunc(v)

    @torch.no_grad()
    def hvp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, hvp_out, aux = ft.jvp(ft.grad_and_value(model, argnums), inp, v, has_aux=True)
        return aux, hvp_out

    @torch.no_grad()
    def jacfwd(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(model, argnums)(*inp)

    @torch.no_grad()
    def jacrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(model, argnums)(*inp)

    @torch.no_grad()
    def hessian(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.hessian(model, argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_fwdrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_revrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    if task in locals():
        return locals()[task]
    elif task == "jacobian":
        raise RuntimeError("functorch has no equivalent of autograd.functional.jacobian with vectorize=False yet")
    else:
        raise RuntimeError(f"Unsupported task: {task}")

# Listing of the different tasks
FAST_TASKS_NO_DOUBLE_BACK = [
"vjp",
@@ -99,15 +165,32 @@ def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:

    return v

def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None:
def run_once(model: Callable, inp: InputsType, task: str, v: VType, **kwargs) -> None:
    func = get_task_func(task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)

def run_model(model_getter: GetterType, args: Any, task: str) -> List[float]:
def run_once_functorch(model: Callable, inp: InputsType, task: str, v: VType, maybe_check_consistency=False) -> None:
    func = get_task_functorch(task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)

    if maybe_check_consistency:
        af_func = get_task_func(task)
        if v is not None:
            expected = af_func(model, inp, v=v, strict=True)
        else:
            expected = af_func(model, inp, strict=True)
        atol = 1e-2 if task == "vhp" else 5e-3
        torch.testing.assert_close(res, expected, rtol=1e-5, atol=atol, msg=f"Consistency fail for task '{task}'")

def run_model(model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once) -> List[float]:
    if args.gpu == -1:
        device = torch.device("cpu")

@@ -121,14 +204,17 @@ def noop():
    model, inp = model_getter(device)

    v = get_v_for(model, inp, task)

    # Warmup
    run_once(model, inp, task, v)
    # maybe_check_consistency=True checks for consistency between
    # functorch vs autograd.functional and is done in run_once_functorch only
    run_once_fn(model, inp, task, v, maybe_check_consistency=True)

    elapsed = []
    for it in range(args.num_iters):
        do_sync()
        start = time.time()
        run_once(model, inp, task, v)
        run_once_fn(model, inp, task, v)
        do_sync()
        elapsed.append(time.time() - start)

Expand Down Expand Up @@ -173,6 +259,18 @@ def main():
results[name][task] = (mean.item(), var.item())
print("Results for model {} on task {}: {}s (var: {})".format(name, task, mean, var))

if has_functorch:
try:
runtimes = run_model(model_getter, args, task, run_once_fn=run_once_functorch)
except RuntimeError as e:
print(f"Failed model using Functorch: {name}, task: {task}, Error message: \n\t", e)
continue

runtimes = torch.tensor(runtimes)
mean, var = runtimes.mean(), runtimes.var()
results[name][f"functorch {task}"] = (mean.item(), var.item())
print("Results for model {} on task {} using Functorch: {}s (var: {})".format(name, task, mean, var))

if args.output:
with open(args.output, "w") as f:
f.write(to_markdown_table(results))
7 changes: 7 additions & 0 deletions benchmarks/functional_autograd_benchmark/utils.py
@@ -101,3 +101,10 @@ def from_markdown_table(data: str) -> TimingResultType:
        res[model][task] = (float(mean), float(var))

    return res

def check_for_functorch():
    try:
        import functorch  # noqa: F401
        return True
    except ImportError:
        return False
29 changes: 26 additions & 3 deletions benchmarks/functional_autograd_benchmark/vision_models.py
@@ -2,13 +2,22 @@
from torch import Tensor
import torchvision_models as models

from utils import extract_weights, load_weights, GetterReturnType
from utils import check_for_functorch, extract_weights, load_weights, GetterReturnType

from typing import cast

has_functorch = check_for_functorch()


def get_resnet18(device: torch.device) -> GetterReturnType:
    N = 32
    model = models.resnet18(pretrained=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    params, names = extract_weights(model)
@@ -29,6 +38,14 @@ def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    N = 8
    criterion = torch.nn.MSELoss()
    model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)
        # disable dropout for consistency checking
        model.eval()

    model.to(device)
    params, names = extract_weights(model)

@@ -56,6 +73,12 @@ def get_detr(device: torch.device) -> GetterReturnType:

    model = models.DETR(num_classes=num_classes, hidden_dim=hidden_dim, nheads=nheads,
                        num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    losses = ['labels', 'boxes', 'cardinality']
    eos_coef = 0.1
    bbox_loss_coef = 5
@@ -74,9 +97,9 @@ for idx in range(N):
    for idx in range(N):
        targets = {}
        n_targets: int = int(torch.randint(5, 10, size=tuple()).item())
        label = torch.randint(5, 10, size=(n_targets,))
        label = torch.randint(5, 10, size=(n_targets,), device=device)
        targets["labels"] = label
        boxes = torch.randint(100, 800, size=(n_targets, 4))
        boxes = torch.randint(100, 800, size=(n_targets, 4), device=device)
        for t in range(n_targets):
            if boxes[t, 0] > boxes[t, 2]:
                boxes[t, 0], boxes[t, 2] = boxes[t, 2], boxes[t, 0]
