Skip to content

Commit 28fe869

Browse files
committed
update documentation
1 parent 2c27a43 commit 28fe869

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

ot/solvers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
140140
# or for original Sinkhorn paper formulation [2]
141141
res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
142142
143+
# Use implicit differentiation for memory saving
144+
res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors
145+
res.value.backward() # only the value is differentiable
146+
147+
Note that by default the Sinkhorn solver uses automatic differentiation to
148+
compute the gradients of the values and plan. This can be changed with the
149+
`grad` parameter. The `implicit` mode computes the implicit gradients only
150+
for the value and the other outputs are detached. This is useful for
151+
memory saving when only the gradient of value is needed.
152+
143153
- **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):
144154
145155
.. math::
@@ -1024,6 +1034,16 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
10241034
# lazy OT plan
10251035
lazy_plan = res.lazy_plan
10261036
1037+
# Use implicit differentiation for memory saving
1038+
res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
1039+
res.value.backward() # only the value is differentiable
1040+
1041+
Note that by default the Sinkhorn solver uses automatic differentiation to
1042+
compute the gradients of the values and plan. This can be changed with the
1043+
`grad` parameter. The `implicit` mode computes the implicit gradients only
1044+
for the value and the other outputs are detached. This is useful for
1045+
memory saving when only the gradient of value is needed.
1046+
10271047
We also have a very efficient solver with compiled CPU/CUDA code using
10281048
geomloss/PyKeOps that can be used with the following code:
10291049

0 commit comments

Comments
 (0)