Description
Describe the bug
I was using the `ot.da.SinkhornTransport` class when I got the following warning: `UserWarning: Warning: numerical errors at iteration 0`.
Thinking about why I got this warning, I remembered that cost matrices should be normalised, so I set the `norm` parameter to `"max"` to divide the cost matrix by its maximum value. However, this raised an error.
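Why normalisation can matter here can be illustrated with a plain NumPy sketch. It is not POT's code, and the cost values and the `first_sinkhorn_step` helper are hypothetical. Sinkhorn iterates with the kernel `K = exp(-C / reg)`. If every cost is much larger than `reg`, `K` underflows to zero in float64, so the very first update divides by zero and the scaling vectors become `inf`/`nan`. That is the kind of failure a "numerical errors at iteration 0" check is meant to catch, though it is only one plausible cause of the warning. Dividing by the maximum keeps every cost in [0, 1], so `K` stays strictly positive. Uniform weights and `reg = 1` are assumptions.

```python
import numpy as np


def first_sinkhorn_step(C, reg):
    """Hypothetical helper: one Sinkhorn-Knopp scaling update with uniform weights.

    Uses the textbook updates v = b / (K^T u) and u = a / (K v),
    with the Gibbs kernel K = exp(-C / reg).
    """
    n, m = C.shape
    a = np.full(n, 1.0 / n)  # uniform source weights (assumption)
    b = np.full(m, 1.0 / m)  # uniform target weights (assumption)
    K = np.exp(-C / reg)     # underflows to exactly 0.0 when C / reg is large
    u = np.full(n, 1.0 / n)  # initial scaling vector
    # Silence divide/invalid warnings so the resulting inf/nan can be inspected.
    with np.errstate(divide="ignore", invalid="ignore"):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u, v


# Illustrative, unnormalised costs: exp(-800) and exp(-1000) are 0.0 in float64.
C = np.array([[800.0, 1000.0], [1000.0, 800.0]])
reg = 1.0

u_raw, v_raw = first_sinkhorn_step(C, reg)            # kernel is all zeros
u_max, v_max = first_sinkhorn_step(C / C.max(), reg)  # costs rescaled to [0, 1]

print("raw:       ", u_raw, v_raw)
print("normalised:", u_max, v_max)
```

With the raw costs, `v` is already infinite after one step, while the max-normalised costs give finite, positive scalings.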
The reason was that the `utils.cost_normalization` function is implemented only with NumPy, whereas I am using PyTorch tensors. As a local workaround, I obtained the backend of the cost array `C` inside this function and changed the line `C /= float(np.max(C))` to `C /= float(nx.max(C))`, after which the program worked fine.
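The same idea, a max normalisation that does not hard-code `np.max`, can be sketched outside POT. This is a minimal illustration, not POT's implementation: instead of POT's backend object, the hypothetical `normalize_cost_max` helper below relies on the no-argument `.max()` method that both NumPy arrays and PyTorch tensors provide. It handles only the `"max"` case and divides out of place, so the caller's array is not mutated.

```python
import numpy as np


def normalize_cost_max(C):
    """Hypothetical helper: scale a cost matrix so its largest entry is 1.

    Calls the array's own ``.max()`` method rather than ``np.max``, so the
    same code accepts NumPy arrays and PyTorch tensors.
    """
    c_max = float(C.max())  # global maximum as a Python float, for either type
    if c_max == 0.0:
        raise ValueError("cannot max-normalise an all-zero cost matrix")
    # Out-of-place division keeps the input's type (ndarray or tensor)
    # and leaves the caller's array unchanged.
    return C / c_max


if __name__ == "__main__":
    print(normalize_cost_max(np.array([[0.0, 2.0], [4.0, 8.0]])))

    try:  # PyTorch is optional for this sketch
        import torch
    except ImportError:
        torch = None
    if torch is not None:
        print(normalize_cost_max(torch.tensor([[0.0, 2.0], [4.0, 8.0]])))
```

Inside POT itself, the workaround described above takes the backend route instead (`nx.max(C)` with `nx` being the backend of `C`), which also extends to the other `norm` options.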
I mention this mainly because the same bug may affect someone else in another context.
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Windows
- Python version: 3.10
- How was POT installed (source, `pip`, `conda`): pip
- Build command you used (if compiling from source):
- Only for GPU related bugs:
  - CUDA version:
  - GPU models and configuration:
  - Any other relevant information:
Output of the following code snippet:

```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```

```
Windows-10-10.0.19045-SP0
Python 3.10.4 (tags/v3.10.4:9d38120, Mar 23 2022, 23:13:41) [MSC v.1929 64 bit (AMD64)]
NumPy 1.24.1
SciPy 1.9.3
POT 0.8.2
```