Skip to content

Commit 9f51c14

Browse files
example for log treatment in bregman.py
1 parent ea6642c commit 9f51c14

File tree

1 file changed

+73
-49
lines changed

1 file changed

+73
-49
lines changed

ot/bregman.py

Lines changed: 73 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,22 @@
1414
#
1515
# License: MIT License
1616

17-
import numpy as np
17+
import math
1818
import warnings
19-
from .utils import unif, dist
19+
20+
import numpy as np
2021
from scipy.optimize import fmin_l_bfgs_b
22+
from scipy.special import logsumexp
23+
24+
from .utils import unif, dist
25+
26+
27+
def log_matvec(matrix, u, out):
28+
max_matrix = np.max(matrix)
29+
max_u = np.max(u)
30+
np.dot(np.exp(matrix - max_matrix), np.exp(u - max_u), out=out)
31+
np.log(out, out=out)
32+
out += max_matrix + max_u
2133

2234

2335
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -311,61 +323,68 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
311323
ot.optim.cg : General regularized OT
312324
313325
"""
314-
315326
a = np.asarray(a, dtype=np.float64)
316327
b = np.asarray(b, dtype=np.float64)
328+
317329
M = np.asarray(M, dtype=np.float64)
318330

319331
if len(a) == 0:
320-
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
332+
a = np.ones((M.shape[0], 1), dtype=np.float64) / M.shape[0]
321333
if len(b) == 0:
322-
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
323-
324-
# init data
325-
dim_a = len(a)
326-
dim_b = len(b)
334+
b = np.ones((M.shape[1], 1), dtype=np.float64) / M.shape[1]
327335

328336
if len(b.shape) > 1:
329337
n_hists = b.shape[1]
330338
else:
331339
n_hists = 0
332340

341+
if len(a.shape) == 1:
342+
a = a[:, None]
343+
344+
if len(b.shape) == 1:
345+
b = b[:, None]
346+
347+
log_threshold = math.log(stopThr)
348+
is_logweight = kwargs.get('is_logweight', False)
349+
350+
if not is_logweight:
351+
a = np.log(a)
352+
b = np.log(b)
353+
354+
# init data
355+
dim_a = len(a)
356+
dim_b = len(b)
357+
333358
if log:
334359
log = {'err': []}
335360

336361
# we assume that no distances are null except those of the diagonal of
337362
# distances
338363
if n_hists:
339-
u = np.ones((dim_a, n_hists)) / dim_a
340-
v = np.ones((dim_b, n_hists)) / dim_b
364+
u = np.zeros((dim_a, n_hists)) - math.log(dim_a)
365+
v = np.zeros((dim_b, n_hists)) - math.log(dim_b)
341366
else:
342-
u = np.ones(dim_a) / dim_a
343-
v = np.ones(dim_b) / dim_b
344-
345-
# print(reg)
346-
347-
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
348-
K = np.empty(M.shape, dtype=M.dtype)
349-
np.divide(M, -reg, out=K)
350-
np.exp(K, out=K)
367+
u = np.zeros((dim_a, 1)) - math.log(dim_a)
368+
v = np.zeros((dim_b, 1)) - math.log(dim_b)
351369

352-
# print(np.min(K))
353-
tmp2 = np.empty(b.shape, dtype=M.dtype)
370+
log_K = -M / reg
354371

355-
Kp = (1 / a).reshape(-1, 1) * K
372+
log_Kp = -a.reshape(-1, 1) + log_K
373+
log_K_T = log_K.T
356374
cpt = 0
357-
err = 1
358-
while (err > stopThr and cpt < numItermax):
375+
log_err = 0.5 * log_threshold
376+
377+
while log_err > log_threshold and cpt < numItermax:
359378
uprev = u
360379
vprev = v
361380

362-
KtransposeU = np.dot(K.T, u)
363-
v = np.divide(b, KtransposeU)
364-
u = 1. / np.dot(Kp, v)
381+
log_matvec(log_K_T, u, v)
382+
v *= -1
383+
v += b
384+
log_matvec(log_Kp, v, u)
385+
u *= -1
365386

366-
if (np.any(KtransposeU == 0)
367-
or np.any(np.isnan(u)) or np.any(np.isnan(v))
368-
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
387+
if np.any(~np.isfinite(u)) or np.any(~np.isfinite(v)):
369388
# we have reached the machine precision
370389
# come back to previous solution and quit loop
371390
print('Warning: numerical errors at iteration', cpt)
@@ -375,27 +394,32 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
375394
if cpt % 10 == 0:
376395
# we can speed up the process by checking for the error only all
377396
# the 10th iterations
378-
if n_hists:
379-
np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
380-
else:
381-
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
382-
np.einsum('i,ij,j->j', u, K, v, out=tmp2)
383-
err = np.linalg.norm(tmp2 - b) # violation of marginal
397+
temp2 = u + log_K + v.T
398+
temp2 = logsumexp(temp2, axis=0, keepdims=True).T
399+
# noinspection PyTypeChecker
400+
log_err = 0.5 * np.sum(np.exp(2 * temp2) - np.exp(2 * b)) # violation of marginal
401+
# would be more efficient with a check on stability of dual vectors
384402
if log:
385-
log['err'].append(err)
403+
log['err'].append(math.exp(log_err))
386404

387405
if verbose:
388406
if cpt % 200 == 0:
389407
print(
390408
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
391-
print('{:5d}|{:8e}|'.format(cpt, err))
409+
print('{:5d}|{:8e}|'.format(cpt, np.exp(log_err)))
392410
cpt = cpt + 1
393411
if log:
394-
log['u'] = u
395-
log['v'] = v
396-
412+
log['u'] = np.exp(u) if not is_logweight else u
413+
log['v'] = np.exp(v) if not is_logweight else v
414+
415+
gamma = u + log_K + v.T
416+
res = logsumexp(gamma, axis=(0, 1), b=M)
417+
if not is_logweight:
418+
gamma = np.exp(gamma)
419+
res = np.exp(res)
420+
if log:
421+
log['cost'] = res
397422
if n_hists: # return only loss
398-
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
399423
if log:
400424
return res, log
401425
else:
@@ -404,9 +428,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
404428
else: # return OT matrix
405429

406430
if log:
407-
return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log
431+
return gamma.squeeze(), log
408432
else:
409-
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
433+
return gamma.squeeze()
410434

411435

412436
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
@@ -716,7 +740,7 @@ def get_Gamma(alpha, beta, u, v):
716740
if np.abs(u).max() > tau or np.abs(v).max() > tau:
717741
if n_hists:
718742
alpha, beta = alpha + reg * \
719-
np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
743+
np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
720744
else:
721745
alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
722746
if n_hists:
@@ -2182,11 +2206,11 @@ def projection(u, epsilon):
21822206

21832207
# box constraints in L-BFGS-B (see Proposition 1 in [26])
21842208
bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
2185-
ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
2209+
ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
21862210

21872211
bounds_v = [(
2188-
max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
2189-
epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
2212+
max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
2213+
epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
21902214

21912215
# pre-calculated constants for the objective
21922216
vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)

0 commit comments

Comments
 (0)