@@ -114,3 +114,28 @@ def test_line_search_armijo():
114
114
# Should not throw an exception and return None for alpha
115
115
alpha , _ , _ = ot .optim .line_search_armijo (lambda x : 1 , xk , pk , gfk , old_fval )
116
116
assert alpha is None
117
+
118
+ # check line search armijo
119
+ def f (x ):
120
+ return np .sum ((x - 5.0 ) ** 2 )
121
+
122
+ def grad (x ):
123
+ return 2 * (x - 5.0 )
124
+
125
+ xk = np .array ([[[- 5.0 , - 5.0 ]]])
126
+ pk = np .array ([[[100.0 , 100.0 ]]])
127
+ gfk = grad (xk )
128
+ old_fval = f (xk )
129
+
130
+ # chech the case where the optimum is on the direction
131
+ alpha , _ , _ = ot .optim .line_search_armijo (f , xk , pk , gfk , old_fval )
132
+ np .testing .assert_allclose (alpha , 0.1 )
133
+
134
+ # check the case where the direction is not far enough
135
+ pk = np .array ([[[3.0 , 3.0 ]]])
136
+ alpha , _ , _ = ot .optim .line_search_armijo (f , xk , pk , gfk , old_fval , alpha0 = 1.0 )
137
+ np .testing .assert_allclose (alpha , 1.0 )
138
+
139
+ # check the case where the checking the wrong direction
140
+ alpha , _ , _ = ot .optim .line_search_armijo (f , xk , - pk , gfk , old_fval )
141
+ assert alpha <= 0
0 commit comments