@@ -1507,15 +1507,19 @@ class TorchBackend(Backend):
1507
1507
1508
1508
def __init__ (self ):
1509
1509
1510
- self .rng_ = torch .Generator ()
1510
+ self .rng_ = torch .Generator ("cpu" )
1511
1511
self .rng_ .seed ()
1512
1512
1513
1513
self .__type_list__ = [torch .tensor (1 , dtype = torch .float32 ),
1514
1514
torch .tensor (1 , dtype = torch .float64 )]
1515
1515
1516
1516
if torch .cuda .is_available ():
1517
+ self .rng_cuda_ = torch .Generator ("cuda" )
1518
+ self .rng_cuda_ .seed ()
1517
1519
self .__type_list__ .append (torch .tensor (1 , dtype = torch .float32 , device = 'cuda' ))
1518
1520
self .__type_list__ .append (torch .tensor (1 , dtype = torch .float64 , device = 'cuda' ))
1521
+ else :
1522
+ self .rng_cuda_ = torch .Generator ("cpu" )
1519
1523
1520
1524
from torch .autograd import Function
1521
1525
@@ -1761,20 +1765,26 @@ def reshape(self, a, shape):
1761
1765
def seed (self , seed = None ):
1762
1766
if isinstance (seed , int ):
1763
1767
self .rng_ .manual_seed (seed )
1768
+ self .rng_cuda_ .manual_seed (seed )
1764
1769
elif isinstance (seed , torch .Generator ):
1765
- self .rng_ = seed
1770
+ if self .device_type (seed ) == "GPU" :
1771
+ self .rng_cuda_ = seed
1772
+ else :
1773
+ self .rng_ = seed
1766
1774
else :
1767
1775
raise ValueError ("Non compatible seed : {}" .format (seed ))
1768
1776
1769
1777
def rand (self , * size , type_as = None ):
1770
1778
if type_as is not None :
1771
- return torch .rand (size = size , generator = self .rng_ , dtype = type_as .dtype , device = type_as .device )
1779
+ generator = self .rng_cuda_ if self .device_type (type_as ) == "GPU" else self .rng_
1780
+ return torch .rand (size = size , generator = generator , dtype = type_as .dtype , device = type_as .device )
1772
1781
else :
1773
1782
return torch .rand (size = size , generator = self .rng_ )
1774
1783
1775
1784
def randn (self , * size , type_as = None ):
1776
1785
if type_as is not None :
1777
- return torch .randn (size = size , dtype = type_as .dtype , generator = self .rng_ , device = type_as .device )
1786
+ generator = self .rng_cuda_ if self .device_type (type_as ) == "GPU" else self .rng_
1787
+ return torch .randn (size = size , dtype = type_as .dtype , generator = generator , device = type_as .device )
1778
1788
else :
1779
1789
return torch .randn (size = size , generator = self .rng_ )
1780
1790
0 commit comments