Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568283317
  • Loading branch information
Jake VanderPlas authored and learned_optimization authors committed Sep 28, 2023
1 parent fcded5d commit 45eff8c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions learned_optimization/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def aggregate_metric(k: str,
if agg == AggregationType.mean:
# size is known at compile time.
size = onp.sum([onp.prod(v.shape) for v in vs])
return xnp.sum(xnp.asarray([xnp.sum(v) / size for v in vs]))
return xnp.sum(xnp.asarray([xnp.sum(v) / size for v in vs])) # pytype: disable=bad-return-type # jnp-type
elif agg == AggregationType.sample:
vs = xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0)
if use_jnp:
Expand All @@ -204,13 +204,13 @@ def aggregate_metric(k: str,
else:
i = onp.random.randint(0, len(vs), dtype=xnp.int32)

return vs[i]
return vs[i] # pytype: disable=bad-return-type # jnp-type
elif agg == AggregationType.collect:
# This might be multi dim if vmap is used, so ravel first.
return xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0)
return xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0) # pytype: disable=bad-return-type # jnp-type
elif agg == AggregationType.tensor:
assert len(vs) == 1
return vs[0]
return vs[0] # pytype: disable=bad-return-type # jnp-type
elif agg == AggregationType.none:
if len(vs) != 1:
raise ValueError("when using no aggregation one must ensure only scalar "
Expand All @@ -220,7 +220,7 @@ def aggregate_metric(k: str,
if val.size != 1:
raise ValueError("Value with none aggregation type was not a scalar?"
f" Found {val}")
return xnp.reshape(val, ())
return xnp.reshape(val, ()) # pytype: disable=bad-return-type # jnp-type
else:
raise ValueError(f"Unsupported Aggregation type {agg}")

Expand Down

0 comments on commit 45eff8c

Please sign in to comment.