@@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
276
276
- p : distribution in the source space
277
277
- q : distribution in the target space
278
278
- L : loss function to account for the misfit between the similarity matrices
279
- - H : entropy
280
279
281
280
Parameters
282
281
----------
@@ -343,6 +342,83 @@ def df(G):
343
342
return cg (p , q , 0 , 1 , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
344
343
345
344
345
+ def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , log = False , armijo = False , ** kwargs ):
346
+ """
347
+ Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
348
+
349
+ The function solves the following optimization problem:
350
+
351
+ .. math::
352
+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
353
+
354
+ Where :
355
+ - C1 : Metric cost matrix in the source space
356
+ - C2 : Metric cost matrix in the target space
357
+ - p : distribution in the source space
358
+ - q : distribution in the target space
359
+ - L : loss function to account for the misfit between the similarity matrices
360
+
361
+ Parameters
362
+ ----------
363
+ C1 : ndarray, shape (ns, ns)
364
+ Metric cost matrix in the source space
365
+ C2 : ndarray, shape (nt, nt)
366
+ Metric cost matrix in the target space
367
+ p : ndarray, shape (ns,)
368
+ Distribution in the source space.
369
+ q : ndarray, shape (nt,)
370
+ Distribution in the target space.
371
+ loss_fun : str
372
+ loss function used for the solver either 'square_loss' or 'kl_loss'
373
+ max_iter : int, optional
374
+ Max number of iterations
375
+ tol : float, optional
376
+ Stop threshold on error (>0)
377
+ verbose : bool, optional
378
+ Print information along iterations
379
+ log : bool, optional
380
+ record log if True
381
+ armijo : bool, optional
382
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
383
+ If there is convergence issues use False.
384
+
385
+ Returns
386
+ -------
387
+ gw_dist : float
388
+ Gromov-Wasserstein distance
389
+ log : dict
390
+ convergence information and Coupling marix
391
+
392
+ References
393
+ ----------
394
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
395
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
396
+ International Conference on Machine Learning (ICML). 2016.
397
+
398
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
399
+ metric approach to object matching. Foundations of computational
400
+ mathematics 11.4 (2011): 417-487.
401
+
402
+ """
403
+
404
+ constC , hC1 , hC2 = init_matrix (C1 , C2 , p , q , loss_fun )
405
+
406
+ G0 = p [:, None ] * q [None , :]
407
+
408
+ def f (G ):
409
+ return gwloss (constC , hC1 , hC2 , G )
410
+
411
+ def df (G ):
412
+ return gwggrad (constC , hC1 , hC2 , G )
413
+ res , log_gw = cg (p , q , 0 , 1 , f , df , G0 , log = True , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
414
+ log_gw ['gw_dist' ] = gwloss (constC , hC1 , hC2 , res )
415
+ log_gw ['T' ] = res
416
+ if log :
417
+ return log_gw ['gw_dist' ], log_gw
418
+ else :
419
+ return log_gw ['gw_dist' ]
420
+
421
+
346
422
def fused_gromov_wasserstein (M , C1 , C2 , p , q , loss_fun = 'square_loss' , alpha = 0.5 , armijo = False , log = False , ** kwargs ):
347
423
"""
348
424
Computes the FGW transport between two graphs see [24]
@@ -506,84 +582,6 @@ def df(G):
506
582
return log ['fgw_dist' ]
507
583
508
584
509
- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , log = False , armijo = False , ** kwargs ):
510
- """
511
- Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
512
-
513
- The function solves the following optimization problem:
514
-
515
- .. math::
516
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
517
-
518
- Where :
519
- - C1 : Metric cost matrix in the source space
520
- - C2 : Metric cost matrix in the target space
521
- - p : distribution in the source space
522
- - q : distribution in the target space
523
- - L : loss function to account for the misfit between the similarity matrices
524
- - H : entropy
525
-
526
- Parameters
527
- ----------
528
- C1 : ndarray, shape (ns, ns)
529
- Metric cost matrix in the source space
530
- C2 : ndarray, shape (nt, nt)
531
- Metric cost matrix in the target space
532
- p : ndarray, shape (ns,)
533
- Distribution in the source space.
534
- q : ndarray, shape (nt,)
535
- Distribution in the target space.
536
- loss_fun : str
537
- loss function used for the solver either 'square_loss' or 'kl_loss'
538
- max_iter : int, optional
539
- Max number of iterations
540
- tol : float, optional
541
- Stop threshold on error (>0)
542
- verbose : bool, optional
543
- Print information along iterations
544
- log : bool, optional
545
- record log if True
546
- armijo : bool, optional
547
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
548
- If there is convergence issues use False.
549
-
550
- Returns
551
- -------
552
- gw_dist : float
553
- Gromov-Wasserstein distance
554
- log : dict
555
- convergence information and Coupling marix
556
-
557
- References
558
- ----------
559
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
560
- "Gromov-Wasserstein averaging of kernel and distance matrices."
561
- International Conference on Machine Learning (ICML). 2016.
562
-
563
- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
564
- metric approach to object matching. Foundations of computational
565
- mathematics 11.4 (2011): 417-487.
566
-
567
- """
568
-
569
- constC , hC1 , hC2 = init_matrix (C1 , C2 , p , q , loss_fun )
570
-
571
- G0 = p [:, None ] * q [None , :]
572
-
573
- def f (G ):
574
- return gwloss (constC , hC1 , hC2 , G )
575
-
576
- def df (G ):
577
- return gwggrad (constC , hC1 , hC2 , G )
578
- res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
579
- log ['gw_dist' ] = gwloss (constC , hC1 , hC2 , res )
580
- log ['T' ] = res
581
- if log :
582
- return log ['gw_dist' ], log
583
- else :
584
- return log ['gw_dist' ]
585
-
586
-
587
585
def entropic_gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon ,
588
586
max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
589
587
"""
0 commit comments