Describe the bug
`ot.optim.generic_conditional_gradient` uses the following cost function for line search (when `reg2` is provided):
```python
def cost(G):
    return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))
```
When the matrix `G` contains negative values, `nx.log(G)` returns `nan` and emits a `RuntimeWarning`.
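The failure can be reproduced outside POT with a direct NumPy transcription of the cost above. This is a minimal sketch: `nx` is replaced by `np`, and the inputs `M`, `f`, `reg1`, `reg2` and the candidate `G` are hypothetical values chosen only to trigger the problem, not values taken from the line search.

```python
import warnings

import numpy as np


def cost(G, M, reg1, f, reg2):
    # NumPy transcription of the line-search cost quoted above (nx -> np).
    return np.sum(M * G) + reg1 * f(G) + reg2 * np.sum(G * np.log(G))


def f(G):
    # Hypothetical regularizer: half the squared Frobenius norm.
    return 0.5 * np.sum(G ** 2)


# Hypothetical inputs: uniform cost matrix and a candidate point with one
# negative entry, as can appear during line search.
M = np.ones((2, 2))
G = np.array([[0.5, -0.1], [0.3, 0.3]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    value = cost(G, M, reg1=1.0, f=f, reg2=0.1)

print(value)  # nan
print([str(w.message) for w in caught])
```

The single negative entry is enough to turn the whole sum into `nan`, so the line search receives `nan` for that step.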
Code sample
```python
>>> from ot.da import SinkhornL1l2Transport
>>> from ot.datasets import make_data_classif
>>> ns, nt = 50, 50
>>> Xs, ys = make_data_classif('3gauss', ns)
>>> Xt, yt = make_data_classif('3gauss2', nt)
>>> otda = SinkhornL1l2Transport()
>>> otda.fit(Xs=Xs, ys=ys, Xt=Xt)
ot/backend.py:1082: RuntimeWarning: invalid value encountered in log
  return np.log(a)
<ot.da.SinkhornL1l2Transport object at 0x10327cb50>
```
This example is used in `test_sinkhorn_l1l2_transport_class` in `test/test_da.py`.
Expected behavior
The open question is whether negative values are expected to occur during line search. If so, should the cost evaluate to `-inf` in this case? Or is it acceptable for the cost to be `nan`? If the latter, we should short-circuit the call to avoid the warning.
Environment
macOS-13.4-arm64-arm-64bit
Python 3.8.16 (default, Mar 1 2023, 21:18:45)
[Clang 14.0.6 ]
NumPy 1.24.3
SciPy 1.10.1
POT 0.9.1