The function 'cost_normalization' does not use the backend and only uses numpy #465

@framunoz

Describe the bug

I was using the ot.da.SinkhornTransport class when I got the following warning:
UserWarning: Warning: numerical errors at iteration 0.

While looking into the cause of this warning, I remembered that cost matrices should usually be normalised, so I set the norm parameter to "max" to normalise the cost matrix by its maximum value. However, this raised an error.
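One plausible reason normalisation helps, assuming the warning comes from the Gibbs kernel exp(-C / reg) underflowing (the issue does not confirm this): with large cost entries every kernel entry can round to 0.0 in float64, which leaves the Sinkhorn scaling vectors undefined from the first iteration. The small numpy sketch below uses a hypothetical cost matrix and regularization value to show the effect of dividing C by its maximum.

```python
import numpy as np

reg = 1.0  # entropic regularization strength (illustrative value, an assumption)
C = np.array([[800.0, 1000.0], [900.0, 1200.0]])  # hypothetical, unnormalised costs

# Gibbs kernel used by Sinkhorn-type solvers
K_raw = np.exp(-C / reg)

# Same kernel after "max" normalisation: C / C.max() lies in [0, 1] for
# nonnegative costs, so every entry lies in [exp(-1 / reg), 1]
K_norm = np.exp(-(C / C.max()) / reg)

print(K_raw)   # every entry underflows to 0.0 in float64
print(K_norm)  # strictly positive entries, so no row or column is all zeros
```

The unnormalised kernel has rows that are entirely zero, so any division by K @ v produces inf or nan; after normalisation the kernel stays well away from zero.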

The reason is that the utils.cost_normalization function is implemented using numpy only, whereas I am working with PyTorch tensors. Inside that function, I obtained the backend (nx) of the cost matrix C and replaced the line C /= float(np.max(C)) with C /= float(nx.max(C)); after that change, the program worked fine.

I am reporting this mainly because other users may run into the same bug in other contexts.

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Windows
  • Python version: 3.10
  • How was POT installed (source, pip, conda): pip
  • Build command you used (if compiling from source):
  • Only for GPU related bugs:
    • CUDA version:
    • GPU models and configuration:
    • Any other relevant information:

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Windows-10-10.0.19045-SP0
Python 3.10.4 (tags/v3.10.4:9d38120, Mar 23 2022, 23:13:41) [MSC v.1929 64 bit (AMD64)]
NumPy 1.24.1
SciPy 1.9.3
POT 0.8.2
