@@ -20,9 +20,8 @@ project_tangent(M::Manifold, x) = project_tangent!(M, similar(x), x)
2020retract (M:: Manifold , x) = retract! (M, copy (x))
2121
2222# Fake objective function implementing a retraction
23- mutable struct ManifoldObjective{T<: NLSolversBase.AbstractObjective } < :
24- NLSolversBase. AbstractObjective
25- manifold:: Manifold
23+ struct ManifoldObjective{M<: Manifold ,T<: AbstractObjective } <: AbstractObjective
24+ manifold:: M
2625 inner_obj:: T
2726end
2827# TODO : is it safe here to call retract! and change x?
@@ -43,6 +42,20 @@ function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
4342 return f_x, g_x
4443end
4544
45+ # In general, we have to compute the gradient/Jacobian separately as it has to be projected
46+ function NLSolversBase. jvp! (obj:: ManifoldObjective , x, v)
47+ xin = retract (obj. manifold, x)
48+ g_x = gradient! (obj. inner_obj, xin)
49+ project_tangent! (obj. manifold, g_x, xin)
50+ return dot (g_x, v)
51+ end
52+ function NLSolversBase. value_jvp! (obj:: ManifoldObjective , x, v)
53+ xin = retract (obj. manifold, x)
54+ f_x, g_x = value_gradient! (obj. inner_obj, xin)
55+ project_tangent! (obj. manifold, g_x, xin)
56+ return f_x, dot (g_x, v)
57+ end
58+
4659""" Flat Euclidean space {R,C}^N, with projections equal to the identity."""
4760struct Flat <: Manifold end
4861# all the functions below are no-ops, and therefore the generated code
@@ -53,6 +66,10 @@ retract!(M::Flat, x) = x
5366project_tangent (M:: Flat , g, x) = g
5467project_tangent! (M:: Flat , g, x) = g
5568
69+ # Optimizations for `Flat` manifold
70+ NLSolversBase. jvp! (obj:: ManifoldObjective{Flat} , x, v) = NLSolversBase. jvp! (obj. inner_obj, x, v)
71+ NLSolversBase. value_jvp! (obj:: ManifoldObjective{Flat} , x, v) = NLSolversBase. value_jvp! (obj. inner_obj, x, v)
72+
5673""" Spherical manifold {|x| = 1}."""
5774struct Sphere <: Manifold end
5875retract! (S:: Sphere , x) = (x ./= norm (x))
0 commit comments