Skip to content

MRG: Forgotten weights arg in barycenter funcs #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,11 +1037,13 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
"""

if method.lower() == 'sinkhorn':
return barycenter_sinkhorn(A, M, reg, numItermax=numItermax,
return barycenter_sinkhorn(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return barycenter_stabilized(A, M, reg, numItermax=numItermax,
return barycenter_stabilized(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
else:
Expand Down
105 changes: 53 additions & 52 deletions ot/unbalanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,

fi = reg_m / (reg_m + reg)

cpt = 0
err = 1.

while (err > stopThr and cpt < numItermax):
for i in range(numItermax):
uprev = u
vprev = v

Expand All @@ -401,28 +400,27 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
warnings.warn('Numerical errors at iteration %s' % i)
u = uprev
v = vprev
break
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)

err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
if cpt % 200 == 0:
if i % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
cpt += 1
print('{:5d}|{:8e}|'.format(i, err))
if err < stopThr:
break

if log:
log['logu'] = np.log(u + 1e-16)
log['logv'] = np.log(v + 1e-16)
log['logu'] = np.log(u + 1e-300)
log['logv'] = np.log(v + 1e-300)

if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
Expand Down Expand Up @@ -747,8 +745,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
alpha = np.zeros(dim)
beta = np.zeros(dim)
q = np.ones(dim) / dim
while (err > stopThr and cpt < numItermax):
qprev = q
for i in range(numItermax):
qprev = q.copy()
Kv = K.dot(v)
f_alpha = np.exp(- alpha / (reg + reg_m))
f_beta = np.exp(- beta / (reg + reg_m))
Expand Down Expand Up @@ -777,28 +775,29 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
warnings.warn('Numerical errors at iteration %s' % cpt)
q = qprev
break
if (cpt % 10 == 0 and not absorbing) or cpt == 0:
if (i % 10 == 0 and not absorbing) or i == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
err = abs(q - qprev).max() / max(abs(q).max(),
abs(qprev).max(), 1.)
if log:
log['err'].append(err)
if verbose:
if cpt % 50 == 0:
if i % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
print('{:5d}|{:8e}|'.format(i, err))
if err < stopThr:
break

cpt += 1
if err > stopThr:
warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
"Try a larger entropy `reg` or a lower mass `reg_m`." +
"Or a larger absorption threshold `tau`.")
if log:
log['niter'] = cpt
log['logu'] = np.log(u + 1e-16)
log['logv'] = np.log(v + 1e-16)
log['niter'] = i
log['logu'] = np.log(u + 1e-300)
log['logv'] = np.log(v + 1e-300)
return q, log
else:
return q
Expand Down Expand Up @@ -882,15 +881,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,

fi = reg_m / (reg_m + reg)

v = np.ones((dim, n_hists)) / dim
u = np.ones((dim, 1)) / dim

cpt = 0
v = np.ones((dim, n_hists))
u = np.ones((dim, 1))
q = np.ones(dim)
err = 1.

while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
for i in range(numItermax):
uprev = u.copy()
vprev = v.copy()
qprev = q.copy()

Kv = K.dot(v)
u = (A / Kv) ** fi
Expand All @@ -905,31 +904,30 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
warnings.warn('Numerical errors at iteration %s' % i)
u = uprev
v = vprev
q = qprev
break
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
err_u = abs(u - uprev).max()
err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
err_v = abs(v - vprev).max()
err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
if cpt % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
# compute change in barycenter
err = abs(q - qprev).max()
err /= max(abs(q).max(), abs(qprev).max(), 1.)
if log:
log['err'].append(err)
# if barycenter did not change + at least 10 iterations - stop
if err < stopThr and i > 10:
break

if verbose:
if i % 10 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(i, err))

cpt += 1
if log:
log['niter'] = cpt
log['logu'] = np.log(u + 1e-16)
log['logv'] = np.log(v + 1e-16)
log['niter'] = i
log['logu'] = np.log(u + 1e-300)
log['logv'] = np.log(v + 1e-300)
return q, log
else:
return q
Expand Down Expand Up @@ -1002,19 +1000,22 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,

if method.lower() == 'sinkhorn':
return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)

elif method.lower() == 'sinkhorn_stabilized':
return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
weights=weights,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log, **kwargs)
elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
return barycenter_unbalanced(A, M, reg, reg_m,
weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
Expand Down