Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change #238

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions learned_optimization/eval_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def fn(opt_state, key, data):
key, summary_key = jax.random.split(key)
(next_opt_state, loss,
key), metrics = summary.with_summary_output_reduced(fn)(
opt_state, key, data, summary_sample_rng_key=summary_key)
opt_state, key, data, sample_rng_key=summary_key)
key, key1 = jax.random.split(key)
metrics = summary.aggregate_metric_list([metrics], use_jnp=True, key=key1)
else:
next_opt_state, loss, key = fn(opt_state, key, data)
metrics = {}
Expand Down Expand Up @@ -142,6 +144,7 @@ def single_task_training_curves(
last_eval_batches: int = 20,
eval_task: Optional[tasks_base.Task] = None,
device: Optional[jax.lib.xla_client.Device] = None,
metrics_every: Optional[int] = None,
summary_writer: Optional[summary.SummaryWriterBase] = None,
) -> Mapping[str, jnp.ndarray]:
"""Compute training curves."""
Expand All @@ -160,11 +163,14 @@ def single_task_training_curves(
opt.init, static_argnames=("num_steps",))(
p, model_state=s, num_steps=num_steps)

losses = []
eval_auxs = []
use_data = task.datasets is not None
train_xs = []
eval_xs = []
losses = []
eval_auxs = []
use_data = task.datasets is not None
train_xs = []
eval_xs = []
metrics = []
metrics_xs = []

for i in tqdm.trange(num_steps + 1, position=0):
with profile.Profile("eval"):
m = {}
Expand Down Expand Up @@ -196,16 +202,41 @@ def single_task_training_curves(
batch = jax.device_put(batch, device=device)

with profile.Profile("next_state"):
opt_state, l, key, _ = _next_state(
task, opt, opt_state, batch, key, with_metrics=False)
with_metrics = False if (
metrics_every is None) else i % metrics_every == 0
opt_state, l, key, m = _next_state(
task, opt, opt_state, batch, key, with_metrics=with_metrics)
losses.append(l)
train_xs.append(i)

if summary_writer:
summary_writer.scalar("train/loss", l, step=i)

if metrics_every:
if summary_writer:
for k, v in m.items():
agg, k = k.split("||")
if agg in ["mean", "sample"]:
summary_writer.scalar(k, v, step=i)
elif agg == "tensor":
summary_writer.tensor(k, v, step=i)
else:
logging.warning(f"Not supported aggregation type {agg}." # pylint: disable=logging-fstring-interpolation
f"Dropping data for key {k}.")
metrics.append(m)
metrics_xs.append(i)

ret = {
"train/xs": onp.asarray(train_xs),
"train/loss": onp.asarray(losses),
}

if metrics_every:
stacked_metrics = tree_utils.tree_zip_onp(metrics)
metric_dict = {f"train/metrics/{k}": v for k, v in stacked_metrics.items()}
ret = {**ret, **metric_dict}
ret["train/metrics/xs"] = onp.asarray(metrics_xs)

if eval_batches:
stacked_metrics = tree_utils.tree_zip_onp(eval_auxs)
ret["eval/xs"] = onp.asarray(eval_xs)
Expand Down
28 changes: 27 additions & 1 deletion learned_optimization/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class AggregationType(str, enum.Enum):
sample = "sample" # pylint: disable=invalid-name
collect = "collect" # pylint: disable=invalid-name
none = "none" # pylint: disable=invalid-name
tensor = "tensor" # pylint: disable=invalid-name


def summary(
Expand Down Expand Up @@ -138,8 +139,12 @@ def summary(

oryx_name = aggregation + "||" + name

mode = "append"
if aggregation == AggregationType.tensor:
mode = "strict"

if ORYX_LOGGING:
val = oryx.core.sow(val, tag=_SOW_TAG, name=oryx_name, mode="append")
val = oryx.core.sow(val, tag=_SOW_TAG, name=oryx_name, mode=mode)

return val

Expand Down Expand Up @@ -203,6 +208,9 @@ def aggregate_metric(k: str,
elif agg == AggregationType.collect:
# This might be multi dim if vmap is used, so ravel first.
return xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0)
elif agg == AggregationType.tensor:
assert len(vs) == 1
return vs[0]
elif agg == AggregationType.none:
if len(vs) != 1:
raise ValueError("when using no aggregation one must ensure only scalar "
Expand Down Expand Up @@ -279,6 +287,8 @@ def out_fn(unused_in, *args):
to_sample.append((k, v))
elif agg == AggregationType.collect:
new_metrics[k] = v.ravel()
elif agg == AggregationType.tensor:
new_metrics[k] = v
else:
raise ValueError(f"unsupported aggregation {agg}")

Expand Down Expand Up @@ -387,6 +397,9 @@ def scalar(self, name, value, step):
def histogram(self, name, value, step):
    """Record a histogram summary for `name` at `step`; subclasses must override."""
    raise NotImplementedError()

def tensor(self, name, value, step):
    """Record a tensor-valued summary for `name` at `step`; subclasses must override."""
    raise NotImplementedError()

def flush(self):
raise NotImplementedError()

Expand Down Expand Up @@ -423,6 +436,10 @@ def histogram(self, name, value, step):
if self.filter_fn(name):
print(f"{step}] {name}={value}")

def tensor(self, name, value, step):
    """Print a one-line record of the tensor's shape if `name` passes the filter."""
    if not self.filter_fn(name):
        return
    print(f"{step}] {name}=Tensor: {value.shape}")

def flush(self):
    # No-op: this writer emits via print() immediately, so there is no
    # buffered state to flush.
    pass

Expand All @@ -439,6 +456,9 @@ def scalar(self, name, value, step):
def flush(self):
    """Flush every wrapped writer."""
    for writer in self.writers:
        writer.flush()

def tensor(self, name, value, step):
    """Forward the tensor summary to every wrapped writer."""
    for writer in self.writers:
        writer.tensor(name, value, step)

def histogram(self, name, value, step):
    """Forward the histogram summary to every wrapped writer."""
    for writer in self.writers:
        writer.histogram(name, value, step)

Expand Down Expand Up @@ -528,6 +548,12 @@ def text(self, name, textdata, step):

tf.summary.text(name=name, data=tf.constant(textdata), step=step)

def tensor(self, name, tensor, step):
    """Write a raw tensor summary to the TensorBoard event file.

    Args:
      name: Summary tag (also used as the op name).
      tensor: Tensor value to record as-is (no scalar/histogram reduction).
      step: Global step to associate with the summary.
    """
    # Lazily ensures a default summary writer/context exists before writing.
    self._ensure_default()
    tf.summary.write(tag=name, tensor=tensor, step=step, name=name)



JaxboardWriter = TensorboardWriter