Skip to content

Commit

Permalink
clean up xla
Browse files Browse the repository at this point in the history
  • Loading branch information
vict0rsch committed Sep 8, 2021
1 parent 3e61ace commit 0659f96
Showing 1 changed file with 1 addition and 38 deletions.
39 changes: 1 addition & 38 deletions apply_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,6 @@ def parse_args():
+ "a directory but n is 10 then only the first 10 images will be loaded"
+ " for processing)",
)
parser.add_argument(
"-x",
"--xla_purge_samples",
type=int,
default=-1,
help="(TPU) XLA compile time induces extra computations."
+ " Use this flag to ignore x samples when computing time averages."
+ " Defaults to -1 (no purge)",
)
parser.add_argument(
"--no_conf",
action="store_true",
Expand Down Expand Up @@ -154,7 +145,6 @@ def parse_args():
import_time = time.time()
import sys
from collections import OrderedDict
from datetime import datetime
from pathlib import Path

import comet_ml # noqa: F401
Expand All @@ -171,15 +161,6 @@ def parse_args():

import_time = time.time() - import_time

XLA = False
try:
import torch_xla.core.xla_model as xm # type: ignore
import torch_xla.debug.metrics as met # type: ignore

XLA = True
except ImportError:
pass


def to_m1_p1(img, i):
if img.min() >= 0 and img.max() <= 1:
Expand Down Expand Up @@ -301,7 +282,6 @@ def write_apply_config(out):
fuse = args.fuse
bin_value = args.flood_mask_binarization
resume_path = args.resume_path
xla_purge_samples = args.xla_purge_samples
n_images = args.n_images
cloudy = not args.no_cloudy
time_inference = not args.no_time
Expand Down Expand Up @@ -375,16 +355,11 @@ def write_apply_config(out):

with Timer(store=stores.get("setup", []), ignore=time_inference):
torch.set_grad_enabled(False)
device = None
if XLA:
device = xm.xla_device() # type: ignore

trainer = Trainer.resume_from_path(
resume_path,
setup=True,
inference=True,
new_exp=None,
device=device,
)
print()
print_num_parameters(trainer, True)
Expand Down Expand Up @@ -456,7 +431,6 @@ def write_apply_config(out):
stores=stores,
bin_value=bin_value,
half=half,
xla=XLA,
cloudy=cloudy,
)

Expand Down Expand Up @@ -539,18 +513,7 @@ def write_apply_config(out):
# ---------------------------
if time_inference:
print("\n• Timings\n")
print_store(stores, purge=xla_purge_samples)

if XLA:
metrics_dir = Path(__file__).parent / "config" / "metrics"
metrics_dir.mkdir(exist_ok=True, parents=True)
now = str(datetime.now()).replace(" ", "_")
with open(
metrics_dir / f"xla_metrics_{now}.txt",
"w",
) as f:
report = met.metrics_report() # type: ignore
print(report, file=f)
print_store(stores)

if not args.no_conf and outdir is not None:
write_apply_config(outdir)

0 comments on commit 0659f96

Please sign in to comment.