Skip to content

Commit 726e84e

Browse files
authored
[MRG] Torch random generator not working for Cuda tensor (#373)
* Solve bug * Update release file
1 parent ccc076e commit 726e84e

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

RELEASES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66

77
- Added Generalized Wasserstein Barycenter solver + example (PR #372)
88

9+
#### Closed issues
10+
11+
- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
12+
(Issue #371, PR #373)
13+
914

1015
## 0.8.2
1116

ot/backend.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,15 +1507,19 @@ class TorchBackend(Backend):
15071507

15081508
def __init__(self):
15091509

1510-
self.rng_ = torch.Generator()
1510+
self.rng_ = torch.Generator("cpu")
15111511
self.rng_.seed()
15121512

15131513
self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
15141514
torch.tensor(1, dtype=torch.float64)]
15151515

15161516
if torch.cuda.is_available():
1517+
self.rng_cuda_ = torch.Generator("cuda")
1518+
self.rng_cuda_.seed()
15171519
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
15181520
self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
1521+
else:
1522+
self.rng_cuda_ = torch.Generator("cpu")
15191523

15201524
from torch.autograd import Function
15211525

@@ -1761,20 +1765,26 @@ def reshape(self, a, shape):
17611765
def seed(self, seed=None):
17621766
if isinstance(seed, int):
17631767
self.rng_.manual_seed(seed)
1768+
self.rng_cuda_.manual_seed(seed)
17641769
elif isinstance(seed, torch.Generator):
1765-
self.rng_ = seed
1770+
if self.device_type(seed) == "GPU":
1771+
self.rng_cuda_ = seed
1772+
else:
1773+
self.rng_ = seed
17661774
else:
17671775
raise ValueError("Non compatible seed : {}".format(seed))
17681776

17691777
def rand(self, *size, type_as=None):
17701778
if type_as is not None:
1771-
return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
1779+
generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
1780+
return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device)
17721781
else:
17731782
return torch.rand(size=size, generator=self.rng_)
17741783

17751784
def randn(self, *size, type_as=None):
17761785
if type_as is not None:
1777-
return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
1786+
generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
1787+
return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device)
17781788
else:
17791789
return torch.randn(size=size, generator=self.rng_)
17801790

0 commit comments

Comments
 (0)