diff --git a/learned_optimization/summary.py b/learned_optimization/summary.py index f3687f1..5bc5191 100644 --- a/learned_optimization/summary.py +++ b/learned_optimization/summary.py @@ -195,7 +195,7 @@ def aggregate_metric(k: str, if agg == AggregationType.mean: # size is known at compile time. size = onp.sum([onp.prod(v.shape) for v in vs]) - return xnp.sum(xnp.asarray([xnp.sum(v) / size for v in vs])) + return xnp.sum(xnp.asarray([xnp.sum(v) / size for v in vs])) # pytype: disable=bad-return-type # jnp-type elif agg == AggregationType.sample: vs = xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0) if use_jnp: @@ -204,13 +204,13 @@ def aggregate_metric(k: str, else: i = onp.random.randint(0, len(vs), dtype=xnp.int32) - return vs[i] + return vs[i] # pytype: disable=bad-return-type # jnp-type elif agg == AggregationType.collect: # This might be multi dim if vmap is used, so ravel first. - return xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0) + return xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0) # pytype: disable=bad-return-type # jnp-type elif agg == AggregationType.tensor: assert len(vs) == 1 - return vs[0] + return vs[0] # pytype: disable=bad-return-type # jnp-type elif agg == AggregationType.none: if len(vs) != 1: raise ValueError("when using no aggregation one must ensure only scalar " @@ -220,7 +220,7 @@ def aggregate_metric(k: str, if val.size != 1: raise ValueError("Value with none aggregation type was not a scalar?" f" Found {val}") - return xnp.reshape(val, ()) + return xnp.reshape(val, ()) # pytype: disable=bad-return-type # jnp-type else: raise ValueError(f"Unsupported Aggregation type {agg}")