Skip to content

Commit 7b2c99a

Browse files
author
ncassereau
committed
Tests
1 parent 55191bc commit 7b2c99a

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

test/test_da.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,14 @@ def test_mapping_transport_class():
565565
otda.fit(Xs=Xs, Xt=Xt)
566566
assert len(otda.log_.keys()) != 0
567567

568+
# check that it does not crash when derphi is very close to 0
569+
np.random.seed(39)
570+
Xs, ys = make_data_classif('3gauss', ns)
571+
Xt, yt = make_data_classif('3gauss2', nt)
572+
otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
573+
otda.fit(Xs=Xs, Xt=Xt)
574+
np.random.seed(None)
575+
568576

569577
def test_linear_mapping():
570578
ns = 150

test/test_optim.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,28 @@ def test_line_search_armijo():
114114
# Should not throw an exception and return None for alpha
115115
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
116116
assert alpha is None
117+
118+
# check line search armijo
119+
def f(x):
120+
return np.sum((x - 5.0) ** 2)
121+
122+
def grad(x):
123+
return 2 * (x - 5.0)
124+
125+
xk = np.array([[[-5.0, -5.0]]])
126+
pk = np.array([[[100.0, 100.0]]])
127+
gfk = grad(xk)
128+
old_fval = f(xk)
129+
130+
# chech the case where the optimum is on the direction
131+
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
132+
np.testing.assert_allclose(alpha, 0.1)
133+
134+
# check the case where the direction is not far enough
135+
pk = np.array([[[3.0, 3.0]]])
136+
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
137+
np.testing.assert_allclose(alpha, 1.0)
138+
139+
# check the case where the checking the wrong direction
140+
alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
141+
assert alpha <= 0

0 commit comments

Comments
 (0)