[MTN] more permissive check_backend #494

Merged · 3 commits · Aug 2, 2023

73 changes: 45 additions & 28 deletions ot/backend.py
@@ -131,23 +131,27 @@
 str_type_error = "All array should be from the same type/backend. Current types are : {}"
 
 
-def get_backend_list():
-    """Returns the list of available backends"""
-    lst = [NumpyBackend(), ]
-
-    if torch:
-        lst.append(TorchBackend())
-
-    if jax:
-        lst.append(JaxBackend())
-
-    if cp:  # pragma: no cover
-        lst.append(CupyBackend())
-
-    if tf:
-        lst.append(TensorflowBackend())
-
-    return lst
+# Mapping between argument types and the existing backend
+_BACKENDS = []
+
+
+def register_backend(backend):
+    _BACKENDS.append(backend)
+
+
+def get_backend_list():
+    """Returns the list of available backends"""
+    return _BACKENDS
+
+
+def _check_args_backend(backend, args):
+    is_instance = set(isinstance(a, backend.__type__) for a in args)
+    # check that all arguments matched the backend type, or none did
+    if len(is_instance) == 1:
+        return is_instance.pop()
+
+    # Otherwise raise an error
+    raise ValueError(str_type_error.format([type(a) for a in args]))
 
 
 def get_backend(*args):
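
The new `_BACKENDS` registry turns backend discovery into a plug-in mechanism: anything passed to `register_backend` becomes visible to `get_backend_list` and to the dispatch loop in `get_backend` below. As a rough illustration (not part of this PR, and assuming dispatch only needs the `__name__` and `__type__` class attributes, as `_check_args_backend` suggests), a hypothetical third-party backend could hook in like this:

```python
from ot.backend import Backend, register_backend, get_backend_list


class MyArray:
    """Hypothetical array type, used only for this sketch."""


class MyBackend(Backend):
    # get_backend dispatches on these two class attributes; a usable
    # backend would also have to implement the full Backend API.
    __name__ = 'mybackend'
    __type__ = MyArray


register_backend(MyBackend())
print([nx.__name__ for nx in get_backend_list()])  # [..., 'mybackend']
```
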
@@ -158,22 +162,12 @@ def get_backend(*args):
     # check that some arrays given
     if not len(args) > 0:
         raise ValueError(" The function takes at least one parameter")
-    # check all same type
-    if not len(set(type(a) for a in args)) == 1:
-        raise ValueError(str_type_error.format([type(a) for a in args]))
-
-    if isinstance(args[0], np.ndarray):
-        return NumpyBackend()
-    elif isinstance(args[0], torch_type):
-        return TorchBackend()
-    elif isinstance(args[0], jax_type):
-        return JaxBackend()
-    elif isinstance(args[0], cp_type):  # pragma: no cover
-        return CupyBackend()
-    elif isinstance(args[0], tf_type):
-        return TensorflowBackend()
-    else:
-        raise ValueError("Unknown type of non implemented backend.")
+
+    for backend in _BACKENDS:
+        if _check_args_backend(backend, args):
+            return backend
+
+    raise ValueError("Unknown type of non implemented backend.")
 
 
 def to_numpy(*args):
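
With the loop in place, dispatch is driven purely by `isinstance` checks against each registered backend's `__type__`. This is what makes the check "more permissive": subclasses of a backend's array type now match, where the old `type(...)`-based comparison demanded exact types. A small usage sketch that runs with NumPy alone:

```python
import numpy as np

from ot.backend import get_backend

A = np.random.rand(3, 2)
B = np.random.rand(3, 1)

nx = get_backend(A, B)  # NumpyBackend instance
print(nx.__name__)      # 'numpy'

# Arguments that only partially match a backend's type make
# _check_args_backend raise the str_type_error ValueError:
try:
    get_backend(A, "not an array")
except ValueError as err:
    print(err)
```
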
@@ -1318,6 +1312,9 @@ def matmul(self, a, b):
         return np.matmul(a, b)
 
 
+register_backend(NumpyBackend())
+
+
 class JaxBackend(Backend):
     """
     JAX implementation of the backend
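
Registration happens at import time, right after each class definition, so `_BACKENDS` is fully populated once `ot.backend` finishes importing; NumPy is always available and is therefore registered unconditionally. The optional backends below reuse the module's existing import guards, which follow roughly this shape (the exact fallback values in `ot/backend.py` may differ):

```python
# Guarded optional import: the module-level name stays falsy when the
# dependency is missing, so a later `if torch:` skips registration.
try:
    import torch
except ImportError:
    torch = False
```
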
@@ -1676,6 +1673,11 @@ def matmul(self, a, b):
         return jnp.matmul(a, b)
 
 
+if jax:
+    # Only register jax backend if it is installed
+    register_backend(JaxBackend())
+
+
 class TorchBackend(Backend):
     """
     PyTorch implementation of the backend
@@ -2148,6 +2150,11 @@ def matmul(self, a, b):
         return torch.matmul(a, b)
 
 
+if torch:
+    # Only register torch backend if it is installed
+    register_backend(TorchBackend())
+
+
 class CupyBackend(Backend):  # pragma: no cover
     """
     CuPy implementation of the backend
@@ -2530,6 +2537,11 @@ def matmul(self, a, b):
         return cp.matmul(a, b)
 
 
+if cp:
+    # Only register cp backend if it is installed
+    register_backend(CupyBackend())
+
+
 class TensorflowBackend(Backend):
 
     __name__ = "tf"
@@ -2930,3 +2942,8 @@ def detach(self, *args):
 
     def matmul(self, a, b):
         return tnp.matmul(a, b)
+
+
+if tf:
+    # Only register tensorflow backend if it is installed
+    register_backend(TensorflowBackend())
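
After all registrations run, `get_backend_list()` reflects exactly what is importable in the current environment, for example:

```python
from ot.backend import get_backend_list

# NumPy is always first; the others appear only if installed.
print([nx.__name__ for nx in get_backend_list()])
# e.g. ['numpy', 'jax', 'torch']
```
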
2 changes: 1 addition & 1 deletion ot/partial.py
@@ -873,7 +873,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
 
     log_e = {'err': []}
 
-    if type(a) == type(b) == type(M) == np.ndarray:
+    if nx.__name__ == "numpy":
         # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
         K = np.empty(M.shape, dtype=M.dtype)
         np.divide(M, -reg, out=K)
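
The old condition required `a`, `b`, and `M` to be exactly `np.ndarray`, which the permissive `isinstance`-based `get_backend` no longer guarantees; testing `nx.__name__` keeps this NumPy fast path reachable for ndarray subclasses as well. For reference, the in-place pattern the comment describes presumably completes with an in-place exponential (the third line is cut off in this view); a self-contained sketch:

```python
import numpy as np

M = np.random.rand(4, 4)
reg = 0.1

# Equivalent to K = np.exp(-M / reg) but avoids temporary arrays
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)

assert np.allclose(K, np.exp(-M / reg))
```
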
2 changes: 1 addition & 1 deletion ot/stochastic.py
@@ -258,7 +258,7 @@ def c_transform_entropic(b, M, reg, beta):
 
 
 def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
-                            log=False):
+                             log=False):
     r'''
     Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem
 
82 changes: 23 additions & 59 deletions test/test_backend.py
@@ -7,7 +7,7 @@
 
 import ot
 import ot.backend
-from ot.backend import torch, jax, cp, tf
+from ot.backend import torch, jax, tf
 
 import pytest
 
@@ -37,17 +37,7 @@ def test_to_numpy(nx):
     assert isinstance(M2, np.ndarray)
 
 
-def test_get_backend():
-
-    A = np.zeros((3, 2))
-    B = np.zeros((3, 1))
-
-    nx = get_backend(A)
-    assert nx.__name__ == 'numpy'
-
-    nx = get_backend(A, B)
-    assert nx.__name__ == 'numpy'
-
+def test_get_backend_invalid():
     # error if no parameters
     with pytest.raises(ValueError):
         get_backend()
@@ -56,64 +46,38 @@ def test_get_backend():
     with pytest.raises(ValueError):
         get_backend(1, 2.0)
 
-    # test torch
-    if torch:
-
-        A2 = torch.from_numpy(A)
-        B2 = torch.from_numpy(B)
-
-        nx = get_backend(A2)
-        assert nx.__name__ == 'torch'
-
-        nx = get_backend(A2, B2)
-        assert nx.__name__ == 'torch'
-
-        # test not unique types in input
-        with pytest.raises(ValueError):
-            get_backend(A, B2)
-
-    if jax:
-
-        A2 = jax.numpy.array(A)
-        B2 = jax.numpy.array(B)
-
-        nx = get_backend(A2)
-        assert nx.__name__ == 'jax'
-
-        nx = get_backend(A2, B2)
-        assert nx.__name__ == 'jax'
-
-        # test not unique types in input
-        with pytest.raises(ValueError):
-            get_backend(A, B2)
-
-    if cp:
-        A2 = cp.asarray(A)
-        B2 = cp.asarray(B)
-
-        nx = get_backend(A2)
-        assert nx.__name__ == 'cupy'
-
-        nx = get_backend(A2, B2)
-        assert nx.__name__ == 'cupy'
-
-        # test not unique types in input
-        with pytest.raises(ValueError):
-            get_backend(A, B2)
-
-    if tf:
-        A2 = tf.convert_to_tensor(A)
-        B2 = tf.convert_to_tensor(B)
-
-        nx = get_backend(A2)
-        assert nx.__name__ == 'tf'
-
-        nx = get_backend(A2, B2)
-        assert nx.__name__ == 'tf'
-
-        # test not unique types in input
-        with pytest.raises(ValueError):
-            get_backend(A, B2)
+
+def test_get_backend(nx):
+
+    A = np.zeros((3, 2))
+    B = np.zeros((3, 1))
+
+    nx_np = get_backend(A)
+    assert nx_np.__name__ == 'numpy'
+
+    A2, B2 = nx.from_numpy(A, B)
+
+    effective_nx = get_backend(A2)
+    assert effective_nx.__name__ == nx.__name__
+
+    effective_nx = get_backend(A2, B2)
+    assert effective_nx.__name__ == nx.__name__
+
+    if nx.__name__ != "numpy":
+        # test that types matching different backends in input raise an error
+        with pytest.raises(ValueError):
+            get_backend(A, B2)
+    else:
+        # Check that subclassing a numpy array does not break get_backend
+        # note: This is only tested for numpy as this is hard to be consistent
+        # with other backends
+        class nx_subclass(nx.__type__):
+            pass
 
+        A3 = nx_subclass(0)
+
+        effective_nx = get_backend(A3, B2)
+        assert effective_nx.__name__ == nx.__name__
 
 
 def test_convert_between_backends(nx):
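
The rewritten `test_get_backend` takes an `nx` fixture that parametrizes the test over every available backend; the fixture itself is not defined in this diff. A minimal sketch of its assumed shape (not the project's verbatim conftest):

```python
import pytest

from ot.backend import get_backend_list


@pytest.fixture(params=get_backend_list(), ids=lambda nx: nx.__name__)
def nx(request):
    # Each test that requests `nx` runs once per registered backend.
    return request.param
```
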