diff --git a/RELEASES.md b/RELEASES.md index 56cb6fd4b..294614114 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -23,6 +23,7 @@ - Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602) - Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592) - Fiw linesearch import error on Scipy 1.14 (PR #642, Issue #641) +- Upgrade supported JAX versions from jax<=0.4.24 to jax<=0.4.30 (PR #643) ## 0.9.3 *January 2024* diff --git a/ot/backend.py b/ot/backend.py index 9cc6446bf..534c03293 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -120,6 +120,7 @@ import jax.scipy.special as jspecial from jax.lib import xla_bridge jax_type = jax.numpy.ndarray + jax_new_version = float('.'.join(jax.__version__.split('.')[1:])) > 4.24 except ImportError: jax = False jax_type = float @@ -1439,11 +1440,19 @@ def __init__(self): jax.device_put(jnp.array(1, dtype=jnp.float64), d) ] + self.jax_new_version = jax_new_version + def _to_numpy(self, a): return np.array(a) + def _get_device(self, a): + if self.jax_new_version: + return list(a.devices())[0] + else: + return a.device_buffer.device() + def _change_device(self, a, type_as): - return jax.device_put(a, type_as.device_buffer.device()) + return jax.device_put(a, self._get_device(type_as)) def _from_numpy(self, a, type_as=None): if isinstance(a, float): @@ -1688,7 +1697,10 @@ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) def dtype_device(self, a): - return a.dtype, a.device_buffer.device() + if self.jax_new_version: + return a.dtype, list(a.devices())[0] + else: + return a.dtype, a.device_buffer.device() def assert_same_dtype_device(self, a, b): a_dtype, a_device = self.dtype_device(a) diff --git a/requirements_all.txt b/requirements_all.txt index 66a7c2dfc..a015855f6 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -6,8 +6,8 @@ pymanopt @ git+https://github.com/pymanopt/pymanopt.git@master cvxopt scikit-learn torch -jax<=0.4.24 -jaxlib<=0.4.24 +jax +jaxlib tensorflow pytest torch_geometric