@@ -140,6 +140,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
# or for original Sinkhorn paper formulation [2]
res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
+ # Use implicit differentiation for memory saving
+ res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors
+ res.value.backward() # only the value is differentiable
+
+ Note that by default the Sinkhorn solver uses automatic differentiation to
+ compute the gradients of the value and the plan. This can be changed with
+ the ``grad`` parameter. The ``implicit`` mode computes the implicit gradient
+ of the value only; the other outputs are detached. This is useful to save
+ memory when only the gradient of the value is needed.
+
- **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):
.. math::
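
For context, here is a minimal end-to-end sketch of the pattern the added
note describes. It is not part of the patch: it assumes a PyTorch install,
and the tensor shapes and values are only illustrative.

    import torch
    import ot

    # illustrative problem: 10 source bins, 20 target bins
    M = torch.rand(10, 20, requires_grad=True)          # cost matrix
    a = torch.full((10,), 1 / 10, requires_grad=True)   # source histogram
    b = torch.full((20,), 1 / 20, requires_grad=True)   # target histogram

    res = ot.solve(M, a, b, reg=1.0, grad='implicit')
    res.value.backward()   # implicit gradient of the value only
    # M.grad, a.grad and b.grad should now be populated, while the other
    # outputs such as res.plan stay detached from the graph.
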
@@ -1024,6 +1034,16 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
# lazy OT plan
lazy_plan = res.lazy_plan
+ # Use implicit differentiation for memory saving
+ res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
+ res.value.backward() # only the value is differentiable
+
+ Note that by default the Sinkhorn solver uses automatic differentiation to
+ compute the gradients of the value and the plan. This can be changed with
+ the ``grad`` parameter. The ``implicit`` mode computes the implicit gradient
+ of the value only; the other outputs are detached. This is useful to save
+ memory when only the gradient of the value is needed.
+
We also have a very efficient solver with compiled CPU/CUDA code using
geomloss/PyKeOps that can be used with the following code:
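
Similarly, a sketch of the sample-based call shown in the second hunk, under
the same assumptions (PyTorch backend; sample sizes and dimensions are
illustrative, not part of the patch):

    import torch
    import ot

    xa = torch.rand(50, 2, requires_grad=True)   # source samples in R^2
    xb = torch.rand(60, 2, requires_grad=True)   # target samples in R^2
    a = torch.full((50,), 1 / 50)                # uniform source weights
    b = torch.full((60,), 1 / 60)                # uniform target weights

    res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
    res.value.backward()   # gradients reach xa and xb through the value only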