Use NoCache to improve set_to_zero!! performance with Mooncake #975
base: main
Conversation
Benchmark Report for Commit de57edd
Computer Information
Benchmark Results
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #975      +/-   ##
==========================================
+ Coverage   82.97%   83.10%   +0.12%
==========================================
  Files          36       37       +1
  Lines        3965     4025      +60
==========================================
+ Hits         3290     3345      +55
- Misses        675      680       +5
Pull Request Test Coverage Report for Build 16414511158 (Coveralls)
DynamicPPL.jl documentation for PR #975 is available at:
@@ -1,9 +1,184 @@
module DynamicPPLMooncakeExt

__precompile__(false)
I had to do this because I am overloading Mooncake.set_to_zero!! at the bottom of this file.
Alternatively, I could define set_to_zero!! only on tangent types, but that might not be trivial, as these can be deeply recursive functions, so a careful implementation might need to define this function for many types.
Any better ideas?
I see the problem in defining set_to_zero!! for all possible types that DPPL might bring up, but the pretty heavy-handed type piracy of defining Mooncake.set_to_zero!!(x) does trouble me. For instance, might that cause a lot of invalidations, and thus a lot more Mooncake recompilation?
Also, the fact that the code relies on checking field names means that if someone just happens to define a type with the same field names, the behaviour of Mooncake on those types would depend on whether DynamicPPL is loaded in the same environment. It feels unlikely to happen, but it could lead to some truly horrendous bugs to track down if it did, and it also just feels like we are messing with other people's code in an inconsiderate manner.
I don't really understand the context here, but it looks like this is dealing with some Mooncake issues related to circular references. Any chance that some of the machinery for dealing with that (like declaring certain types as safe/unsafe) could be implemented in Mooncake itself?
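To make the field-name concern concrete, here is a minimal sketch in plain Julia. It does not use DynamicPPL or Mooncake; `OwnedVarInfo`, `UnrelatedRecord`, `looks_owned_by_fields`, and `is_owned` are hypothetical names invented for illustration. It only shows that a `hasfield`-based check accepts any type with matching field names, whereas a dispatch-based check does not.

```julia
# Hypothetical stand-ins; none of these names come from DynamicPPL or Mooncake.
struct OwnedVarInfo
    metadata::Vector{Float64}
    logp::Float64
end

# A type from some unrelated package that happens to share the field names.
struct UnrelatedRecord
    metadata::String
    logp::Int
end

# Field-name check: matches anything that has fields with these names.
looks_owned_by_fields(x) =
    hasfield(typeof(x), :metadata) && hasfield(typeof(x), :logp)

# Dispatch-based check: matches only the type we actually own.
is_owned(::OwnedVarInfo) = true
is_owned(::Any) = false

looks_owned_by_fields(UnrelatedRecord("notes", 0))  # true: an accidental match
is_owned(UnrelatedRecord("notes", 0))               # false
```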
I agree with and share the concerns.
The issue, from my POV, is that Mooncake needs to handle potential cases of circular references, and does so conservatively, at a small cost. It turns out that, for simple DynamicPPL models, the small cost of initializing an IdDict and looking things up in it matters and makes benchmarks look bad.
Ideally, there would be systematic changes in Mooncake so that one can tell Mooncake "I promise there will not be circular refs, so no need for an IdDict".
As for the invalidations, I have to admit that I don't know how bad it would be. Given that Julia would precompile set_to_zero_internal!! for Mooncake (https://github.com/chalk-lab/Mooncake.jl/blob/a26b5c35c55d1e98b9e8c6bfafbbe3dc55784140/src/tangents.jl#L728-L733), and that set_to_zero!! is at the moment pretty much just an alias, I don't expect the cost to be too great.
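The trade-off described above can be sketched with a toy example. This is not Mooncake's implementation: `ToyNode`, `zero_cached!`, and `zero_uncached!` are hypothetical names, and the sketch only assumes that the cache plays the role of an `IdDict` of already-visited objects. The cached version is safe on cycles but allocates and queries a dictionary on every call; the uncached version skips that work and is only correct when the caller knows there are no cycles.

```julia
# Toy mutable "tangent" node; hypothetical, not a Mooncake type.
mutable struct ToyNode
    value::Float64
    next::Union{ToyNode,Nothing}
end

# Cached variant: an IdDict records visited nodes, so cycles terminate.
function zero_cached!(node::ToyNode, seen::IdDict{Any,Bool}=IdDict{Any,Bool}())
    haskey(seen, node) && return node
    seen[node] = true
    node.value = 0.0
    node.next === nothing || zero_cached!(node.next, seen)
    return node
end

# Uncached variant: no dictionary allocation or lookups, but it would recurse
# forever on a cycle, so the caller must promise that there is none.
function zero_uncached!(node::ToyNode)
    node.value = 0.0
    node.next === nothing || zero_uncached!(node.next)
    return node
end

# A two-node cycle: only the cached variant is safe here.
a = ToyNode(1.0, nothing)
b = ToyNode(2.0, a)
a.next = b
zero_cached!(a)

# An acyclic chain: the uncached variant is sufficient.
chain = ToyNode(3.0, ToyNode(4.0, nothing))
zero_uncached!(chain)
```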
> Given that for Mooncake, Julia would precompile set_to_zero_internal!! (https://github.com/chalk-lab/Mooncake.jl/blob/a26b5c35c55d1e98b9e8c6bfafbbe3dc55784140/src/tangents.jl#L728-L733).
I wonder if inlining might mess with this, if the call to set_to_zero_internal!! is inlined into set_to_zero!!.
How bad would it be if you used dispatch to make set_to_zero!! overloaded only for DPPL-owned types, and thus avoided the type piracy? I see two ways of doing this:
- Replace the hasfield checks with type bounds. I don't know if Mooncake's several different tangent types make this really hard. The hasfield checks aren't any more generic than isa checks, though, because the field names effectively determine the type, right?
- Have a really high-level check, where you define Mooncake.set_to_zero!!(x::AllDPPLTypes) and AllDPPLTypes = Union{AbstractVarInfo,Context,AbstractAccumulator, etc.}, i.e. catch all the highest-level types in our type hierarchy.
Is either of those workable?
Note another reason to avoid hasfield checks: they create extra maintenance burden, because every time we change the name of a field of any of these types they would fail silently, in a way that only shows up as reduced performance.
Gonna ping @willtebbutt for his thoughts on how set_to_zero!! should be overloaded.
julia> struct A
           a
           b
       end

julia> struct B
           a
           b
       end

julia> Mooncake.tangent_type(A)
Mooncake.Tangent{@NamedTuple{a, b}}

julia> Mooncake.tangent_type(B)
Mooncake.Tangent{@NamedTuple{a, b}}
It could very well be that I haven't internalized Mooncake well enough and so missed a clear solution, but since both structs above map to the same tangent type, it seems that Mooncake.set_to_zero!!(x::AllDPPLTypes) via multiple dispatch is not simple to implement at the moment.
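The same effect can be reproduced without Mooncake. The sketch below is a toy model, not Mooncake's `tangent_type`: `toy_tangent_type`, `ToyA`, `ToyB`, and `owner` are hypothetical names, and the only assumption is that a structural tangent keeps field names and field types while dropping the struct's own name. Under that assumption, two distinct structs produce the identical type, so a method written for one tangent necessarily applies to the other.

```julia
# Hypothetical toy model of a structural tangent: keep field names and field
# types, drop the nominal struct type. This is NOT Mooncake.tangent_type.
toy_tangent_type(::Type{T}) where {T} =
    NamedTuple{fieldnames(T),Tuple{fieldtypes(T)...}}

struct ToyA
    a::Float64
    b::Float64
end

struct ToyB
    a::Float64
    b::Float64
end

const TA = toy_tangent_type(ToyA)
const TB = toy_tangent_type(ToyB)

TA === TB  # true: the two tangent types are one and the same type

# A method intended only for ToyA's tangent...
owner(::TA) = :ToyA

# ...is also selected for a tangent built for ToyB.
owner(TB((1.0, 2.0)))  # :ToyA
```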
> Note another reason to avoid hasfield checks: They create extra maintenance burden, because they would fail in a silent way that only shows as reduced performance every time we change the name of a field of any of these types.
Agreed with this; I added
DynamicPPL.jl/test/ext/DynamicPPLMooncakeExt.jl
Lines 168 to 248 in de57edd
@testset "Struct field assumptions" begin
    # Test that our assumptions about DynamicPPL struct fields are correct
    # These tests will fail if DynamicPPL changes its internal structure
    @testset "LogDensityFunction tangent structure" begin
        model = test_model1([1.0, 2.0, 3.0])
        vi = VarInfo(Random.default_rng(), model)
        ldf = LogDensityFunction(model, vi, DefaultContext())
        tangent = zero_tangent(ldf)
        # Test expected fields exist
        @test hasfield(typeof(tangent), :fields)
        @test hasfield(typeof(tangent.fields), :model)
        @test hasfield(typeof(tangent.fields), :varinfo)
        @test hasfield(typeof(tangent.fields), :context)
        @test hasfield(typeof(tangent.fields), :adtype)
        @test hasfield(typeof(tangent.fields), :prep)
        # Test exact field names match
        @test propertynames(tangent.fields) ==
            (:model, :varinfo, :context, :adtype, :prep)
    end
    @testset "VarInfo tangent structure" begin
        model = test_model1([1.0, 2.0, 3.0])
        vi = VarInfo(Random.default_rng(), model)
        tangent_vi = zero_tangent(vi)
        # Test expected fields exist
        @test hasfield(typeof(tangent_vi), :fields)
        @test hasfield(typeof(tangent_vi.fields), :metadata)
        @test hasfield(typeof(tangent_vi.fields), :logp)
        @test hasfield(typeof(tangent_vi.fields), :num_produce)
        # Test exact field names match
        @test propertynames(tangent_vi.fields) == (:metadata, :logp, :num_produce)
    end
    @testset "Model tangent structure" begin
        model = test_model1([1.0, 2.0, 3.0])
        tangent_model = zero_tangent(model)
        # Test expected fields exist
        @test hasfield(typeof(tangent_model), :fields)
        @test hasfield(typeof(tangent_model.fields), :f)
        @test hasfield(typeof(tangent_model.fields), :args)
        @test hasfield(typeof(tangent_model.fields), :defaults)
        @test hasfield(typeof(tangent_model.fields), :context)
        # Test exact field names match
        @test propertynames(tangent_model.fields) == (:f, :args, :defaults, :context)
    end
    @testset "Metadata tangent structure" begin
        model = test_model1([1.0, 2.0, 3.0])
        vi = VarInfo(Random.default_rng(), model)
        tangent_vi = zero_tangent(vi)
        metadata = tangent_vi.fields.metadata
        # Metadata is a NamedTuple with variable names as keys
        @test metadata isa NamedTuple
        # Each variable's metadata should be a Tangent with the expected fields
        for (varname, var_metadata) in pairs(metadata)
            @test var_metadata isa Mooncake.Tangent
            @test hasfield(typeof(var_metadata), :fields)
            # Test expected fields exist
            @test hasfield(typeof(var_metadata.fields), :idcs)
            @test hasfield(typeof(var_metadata.fields), :vns)
            @test hasfield(typeof(var_metadata.fields), :ranges)
            @test hasfield(typeof(var_metadata.fields), :vals)
            @test hasfield(typeof(var_metadata.fields), :dists)
            @test hasfield(typeof(var_metadata.fields), :orders)
            @test hasfield(typeof(var_metadata.fields), :flags)
            # Test exact field names match
            @test propertynames(var_metadata.fields) ==
                (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
        end
    end
Damn, I hadn't realised structs just turn into NamedTuples. I guess we could say that in some sense we morally almost "own" types like NamedTuple{idcs,vns,ranges,vals,dists,orders,flags}. It doesn't seem like a great solution either, but it sounds like you may well have exhausted the avenues for great solutions.
Sorry to be such a pain in the arse about this, and I appreciate you've thought about this a lot, but I remain quite uncomfortable with overloading Mooncake.set_to_zero!!(x) for all x.
No worries and thanks for the comments and discussions.
I don't want to shoehorn this. I think it's worth it to have some Mooncake side support.
I really think that we need to avoid this type piracy: as @mhauru points out, this has the potential to change the behaviour of user code in really weird ways.
I think that we probably need to modify the way set_to_zero!! is defined, in order to accommodate individual type / object owners asserting that their thing definitely
- doesn't contain circular references, and
- doesn't contain any aliasing.
I believe that this is a consistent requirement across all of the internal functions which require some kind of cache to keep track of the two things above (set_to_zero!!, _dot, increment!!, etc).
So I would propose the following:
- add a function to Mooncake itself called requires_cache, which can be applied to a primal, and returns a Bool.
- to set_to_zero!!, we add a kwarg, also called requires_cache, which defaults to true, and can be used inside that function to determine whether to construct a cache, or to produce a NoCache (or whatever the correct type is).
We can put a very simple implementation of requires_cache in Mooncake itself (something as simple as returning true if it's anything other than a bits type would do), and then overload this for specific package-owned types in extensions.
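A rough sketch of how the proposal above might look, written as standalone Julia rather than against Mooncake. Everything here is an assumption made for illustration: `ToyNoCache`, `toy_requires_cache`, `toy_set_to_zero!!`, `already_seen!`, and `OwnedParams` are hypothetical names, and the real signatures in Mooncake may differ. The trait is evaluated on the primal at the call site and the resulting Bool is passed as the keyword argument.

```julia
# Hypothetical sketch of the proposal; these are NOT Mooncake's functions.
struct ToyNoCache end

# Conservative default: anything that is not a bits type might alias or cycle.
toy_requires_cache(x) = !isbitstype(typeof(x))

# A package-owned primal whose owner promises: no circular references and
# no aliasing. In the proposal this overload would live in an extension.
struct OwnedParams
    values::Vector{Float64}
end
toy_requires_cache(::OwnedParams) = false

# With ToyNoCache nothing is tracked; with an IdDict visited objects are recorded.
already_seen!(::ToyNoCache, x) = false
function already_seen!(cache::IdDict{Any,Bool}, x)
    haskey(cache, x) && return true
    cache[x] = true
    return false
end

# The keyword decides whether a real cache is constructed at all.
function toy_set_to_zero!!(t::Vector{Float64}; requires_cache::Bool=true)
    cache = requires_cache ? IdDict{Any,Bool}() : ToyNoCache()
    already_seen!(cache, t) || fill!(t, 0.0)
    return t
end

primal = OwnedParams([1.0, 2.0])
tangent = [0.5, -0.5]
toy_set_to_zero!!(tangent; requires_cache=toy_requires_cache(primal))
```

Defaulting the keyword to `true` keeps the conservative behaviour for every caller that does not opt out, which is the property the proposal relies on to avoid changing behaviour for types nobody has vouched for.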
Sounds good to me, except I didn't understand the need for the kwarg. Is it because requires_cache takes the primal rather than the tangent, and thus needs to be applied at the call-site and the result passed to set_to_zero!!?
From chalk-lab/Mooncake.jl#644.