Description
The model initialization step in Colab, using a TPU and the default configuration, exits with an error.
The errors are nested through JAX and HyperNeRF frames, but the root appears to be in hypernerf/hypernerf/model_utils.py (line 119 in d433ebe: accum_prod = jnp.concatenate([), within the volumetric_rendering function, at the call jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape).
The relevant error is:

/usr/local/lib/python3.7/dist-packages/hypernerf/model_utils.py in volumetric_rendering(rgb, sigma, z_vals, dirs, use_white_background, sample_at_infinity, eps)
113 z_vals[..., 1:] - z_vals[..., :-1],
--> 114 jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
115 ], -1)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/util.py in _broadcast_to(arr, shape)
341 return arr.broadcast_to(shape)
--> 342 _check_arraylike("broadcast_to", arr)
343 arr = arr if isinstance(arr, ndarray) else _asarray(arr)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/util.py in _check_arraylike(fun_name, *args)
294 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 295 raise TypeError(msg.format(fun_name, type(arg), pos))
296

UnfilteredStackTrace: TypeError: broadcast_to requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
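The same TypeError can be reproduced outside HyperNeRF with a few lines of JAX. This is a minimal sketch: the z_vals array (2 rays with 4 samples each) and the last_sample_z value are made-up stand-ins for the real inputs to volumetric_rendering, and the exact error wording may differ between JAX versions.

```python
import jax.numpy as jnp

# Hypothetical stand-ins: 2 rays x 4 sample depths, plus a far sentinel value.
z_vals = jnp.linspace(0.0, 1.0, 8).reshape(2, 4)
last_sample_z = 1e10


def broadcast_python_list():
    # Same pattern as the failing line: a Python list as the first argument.
    return jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)


try:
    broadcast_python_list()
    raised_type_error = False
except TypeError as err:
    # JAX rejects list inputs instead of silently converting them.
    raised_type_error = True
    print(err)
```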
A quick search turned up pages such as https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax, which suggest that non-array inputs (such as Python lists) must be converted to jnp arrays before being passed to jax.numpy functions. I haven't gotten it working yet, though.
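Following that advice, one possible change is to wrap the list in jnp.array before calling jnp.broadcast_to. The sketch below rebuilds only the concatenate call shown in the traceback; the function name sample_distances and the default last_sample_z value are hypothetical, and this is not claimed to be the upstream fix.

```python
import jax.numpy as jnp


def sample_distances(z_vals, last_sample_z=1e10):
    """Gaps between adjacent samples along each ray, plus one final gap.

    Mirrors the concatenate call from the traceback; the function name and
    the default last_sample_z are assumptions for illustration.
    """
    return jnp.concatenate([
        z_vals[..., 1:] - z_vals[..., :-1],
        # jnp.array(...) turns the list into an ndarray, which broadcast_to accepts.
        jnp.broadcast_to(jnp.array([last_sample_z]), z_vals[..., :1].shape),
    ], -1)


z_vals = jnp.linspace(0.0, 1.0, 8).reshape(2, 4)
dists = sample_distances(z_vals)
print(dists.shape)  # (2, 4)
```

Passing last_sample_z as a plain Python float would also satisfy the check, since the error message itself says scalars are accepted alongside ndarrays.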