Numerical issue with entropy regularization in generic_conditional_gradient #502

@kachayev

Description

Describe the bug

ot.optim.generic_conditional_gradient uses the following cost function for line search (when reg2 is provided):

def cost(G):
    return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))

When the matrix G contains negative values, nx.log(G) produces nan entries and emits a RuntimeWarning.
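A minimal reproduction of the underlying NumPy behavior (plain NumPy rather than the ot backend, but the backend's `log` delegates to `np.log` as the traceback below shows):

```python
import warnings
import numpy as np

# A transport-plan-like matrix with one negative entry, as can occur
# during line search.
G = np.array([[0.5, -0.1], [0.3, 0.3]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The entropy term from the cost function: nan because log(-0.1) is nan.
    entropy = np.sum(G * np.log(G))

print(np.isnan(entropy))  # True
print(any(issubclass(w.category, RuntimeWarning) for w in caught))  # True
```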

Code sample

>>> from ot.da import SinkhornL1l2Transport
>>> from ot.datasets import make_data_classif
>>> ns, nt = 50, 50
>>> Xs, ys = make_data_classif('3gauss', ns)
>>> Xt, yt = make_data_classif('3gauss2', nt)
>>> otda = SinkhornL1l2Transport()
>>> otda.fit(Xs=Xs, ys=ys, Xt=Xt)
ot/backend.py:1082: RuntimeWarning: invalid value encountered in log
  return np.log(a)
<ot.da.SinkhornL1l2Transport object at 0x10327cb50>

This example is used for test_sinkhorn_l1l2_transport_class in test/test_da.py.

Expected behavior

The question is whether negative values are expected to occur during line search. If so, should we evaluate the cost as -inf in that case? Or is it acceptable for the cost to be nan (in which case we should short-circuit the call to avoid the warnings)?
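A sketch of the short-circuit option, in plain NumPy rather than the ot backend (`M`, `reg1`, `reg2`, `f`, and the `infeasible` sentinel are stand-ins, not the actual POT API):

```python
import numpy as np

def cost(G, M, reg1, reg2, f, infeasible=np.nan):
    """Guarded version of the line-search cost (sketch only)."""
    # The entropy term is undefined for negative entries; bail out
    # before calling np.log so no RuntimeWarning is emitted. The
    # sentinel could equally be -inf or +inf depending on the sign
    # convention the line search uses.
    if np.any(G < 0):
        return infeasible
    # Use the convention 0 * log(0) = 0 for zero entries.
    safe = np.where(G > 0, G, 1.0)
    entropy = np.sum(G * np.log(safe))
    return np.sum(M * G) + reg1 * f(G) + reg2 * entropy

M = np.ones((2, 2))
G_bad = np.array([[0.6, -0.1], [0.2, 0.3]])
G_ok = np.full((2, 2), 0.25)
print(cost(G_bad, M, 1.0, 1.0, lambda G: 0.0))  # nan, with no warning raised
print(np.isfinite(cost(G_ok, M, 1.0, 1.0, lambda G: 0.0)))  # True
```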

Environment (please complete the following information):

macOS-13.4-arm64-arm-64bit
Python 3.8.16 (default, Mar  1 2023, 21:18:45)
[Clang 14.0.6 ]
NumPy 1.24.3
SciPy 1.10.1
POT 0.9.1
