Skip to content

Commit b5c9710

Browse files
committed
Compute JVP in line searches
1 parent a466ee4 commit b5c9710

File tree

24 files changed

+167
-133
lines changed

24 files changed

+167
-133
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ jobs:
1111
strategy:
1212
matrix:
1313
version:
14-
- "min"
15-
- "lts"
14+
# - "min"
15+
# - "lts"
1616
- "1"
1717
os:
1818
- ubuntu-latest

.github/workflows/Docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
with:
2020
version: '1'
2121
- name: Install dependencies
22-
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
22+
run: julia --project=docs/ -e 'using Pkg; Pkg.instantiate()'
2323
- name: Build and deploy
2424
env:
2525
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ ExplicitImports = "1.13.2"
3030
FillArrays = "0.6.2, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
3131
ForwardDiff = "0.10, 1"
3232
JET = "0.9, 0.10"
33-
LineSearches = "7.5.1"
33+
LineSearches = "8"
3434
LinearAlgebra = "<0.0.1, 1.6"
3535
MathOptInterface = "1.17"
3636
Measurements = "2.14.1"
37-
NLSolversBase = "7.9.0"
37+
NLSolversBase = "8"
3838
NaNMath = "0.3.2, 1"
3939
OptimTestProblems = "2.0.3"
4040
PositiveFactorizations = "0.2.2"
@@ -65,3 +65,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6565

6666
[targets]
6767
test = ["Test", "Aqua", "Distributions", "ExplicitImports", "ForwardDiff", "JET", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "ReverseDiff"]
68+
69+
[sources]
70+
LineSearches = { url = "https://github.com/devmotion/LineSearches.jl.git", rev = "dmw/jvp" }
71+
NLSolversBase = { url = "https://github.com/devmotion/NLSolversBase.jl.git", rev = "dmw/jvp" }

docs/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1212
Documenter = "1"
1313
Literate = "2"
1414

15-
[sources.Optim]
16-
path = ".."
15+
[sources]
16+
Optim = { path = ".." }
17+
NLSolversBase = { url = "https://github.com/devmotion/NLSolversBase.jl.git", rev = "dmw/jvp" }

docs/src/examples/ipnewton_basics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# constraint is unbounded from below or above respectively.
2323

2424
using Optim, NLSolversBase #hide
25+
import ADTypes #hide
2526
import NLSolversBase: clear! #hide
2627

2728
# # Constrained optimization with `IPNewton`

ext/OptimMOIExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module OptimMOIExt
22

33
using Optim
4-
using Optim.LinearAlgebra: rmul!
4+
using Optim.LinearAlgebra: rmul!
55
import MathOptInterface as MOI
66

77
function __init__()

src/Manifolds.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ project_tangent(M::Manifold, x) = project_tangent!(M, similar(x), x)
2020
retract(M::Manifold, x) = retract!(M, copy(x))
2121

2222
# Fake objective function implementing a retraction
23-
mutable struct ManifoldObjective{T<:NLSolversBase.AbstractObjective} <:
24-
NLSolversBase.AbstractObjective
25-
manifold::Manifold
23+
struct ManifoldObjective{M<:Manifold,T<:AbstractObjective} <: AbstractObjective
24+
manifold::M
2625
inner_obj::T
2726
end
2827
# TODO: is it safe here to call retract! and change x?
@@ -43,6 +42,20 @@ function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
4342
return f_x, g_x
4443
end
4544

45+
# In general, we have to compute the gradient/Jacobian separately as it has to be projected
46+
function NLSolversBase.jvp!(obj::ManifoldObjective, x, v)
47+
xin = retract(obj.manifold, x)
48+
g_x = gradient!(obj.inner_obj, xin)
49+
project_tangent!(obj.manifold, g_x, xin)
50+
return dot(g_x, v)
51+
end
52+
function NLSolversBase.value_jvp!(obj::ManifoldObjective, x, v)
53+
xin = retract(obj.manifold, x)
54+
f_x, g_x = value_gradient!(obj.inner_obj, xin)
55+
project_tangent!(obj.manifold, g_x, xin)
56+
return f_x, dot(g_x, v)
57+
end
58+
4659
"""Flat Euclidean space {R,C}^N, with projections equal to the identity."""
4760
struct Flat <: Manifold end
4861
# all the functions below are no-ops, and therefore the generated code
@@ -53,6 +66,10 @@ retract!(M::Flat, x) = x
5366
project_tangent(M::Flat, g, x) = g
5467
project_tangent!(M::Flat, g, x) = g
5568

69+
# Optimizations for `Flat` manifold
70+
NLSolversBase.jvp!(obj::ManifoldObjective{Flat}, x, v) = NLSolversBase.jvp!(obj.inner_obj, x, v)
71+
NLSolversBase.value_jvp!(obj::ManifoldObjective{Flat}, x, v) = NLSolversBase.value_jvp!(obj.inner_obj, x, v)
72+
5673
"""Spherical manifold {|x| = 1}."""
5774
struct Sphere <: Manifold end
5875
retract!(S::Sphere, x) = (x ./= norm(x))

src/Optim.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ documentation online at http://julianlsolvers.github.io/Optim.jl/stable/ .
1616
"""
1717
module Optim
1818

19+
import ADTypes
20+
1921
using PositiveFactorizations: Positive # for globalization strategy in Newton
2022

2123
using LineSearches: LineSearches # for globalization strategy in Quasi-Newton algs
@@ -35,7 +37,6 @@ using NLSolversBase:
3537
NonDifferentiable,
3638
OnceDifferentiable,
3739
TwiceDifferentiable,
38-
TwiceDifferentiableHV,
3940
AbstractConstraints,
4041
ConstraintBounds,
4142
TwiceDifferentiableConstraints,

src/api.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ f_calls(d::AbstractObjective) = NLSolversBase.f_calls(d)
104104

105105
g_calls(r::OptimizationResults) = error("g_calls is not implemented for $(summary(r)).")
106106
g_calls(r::MultivariateOptimizationResults) = r.g_calls
107-
g_calls(d::AbstractObjective) = NLSolversBase.g_calls(d)
107+
g_calls(d::AbstractObjective) = NLSolversBase.g_calls(d) + NLSolversBase.jvp_calls(d)
108108

109109
h_calls(r::OptimizationResults) = error("h_calls is not implemented for $(summary(r)).")
110110
h_calls(r::MultivariateOptimizationResults) = r.h_calls

src/multivariate/optimize/interface.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,6 @@ promote_objtype(
6666
inplace::Bool,
6767
f::InplaceObjective,
6868
) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
69-
promote_objtype(
70-
method::SecondOrderOptimizer,
71-
x,
72-
autodiff::ADTypes.AbstractADType,
73-
inplace::Bool,
74-
f::NLSolversBase.InPlaceObjectiveFGHv,
75-
) = TwiceDifferentiableHV(f, x)
76-
promote_objtype(
77-
method::SecondOrderOptimizer,
78-
x,
79-
autodiff::ADTypes.AbstractADType,
80-
inplace::Bool,
81-
f::NLSolversBase.InPlaceObjectiveFG_Hv,
82-
) = TwiceDifferentiableHV(f, x)
8369
promote_objtype(method::SecondOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, f, g) =
8470
TwiceDifferentiable(
8571
f,

0 commit comments

Comments
 (0)