Skip to content

Use NoCache to improve set_to_zero!! performance with Mooncake #975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

sunxd3
Copy link
Member

@sunxd3 sunxd3 commented Jul 8, 2025

Copy link
Contributor

github-actions bot commented Jul 8, 2025

Benchmark Report for Commit de57edd

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  9.2 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                713.0 |                39.1 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                304.1 |                72.0 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1195.9 |                28.4 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               3170.8 |                24.5 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1411.5 |                29.4 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                912.0 |                 5.1 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5569.9 |                 3.9 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                978.5 |                 8.9 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              61906.4 |                 3.6 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8969.8 |                 9.4 |
|               Dynamic |        10 |    mooncake |             typed |   true |                125.0 |                14.8 |
|              Submodel |         1 |    mooncake |             typed |   true |                 12.8 |                 6.8 |
|                   LDA |        12 | reversediff |             typed |   true |                487.3 |                 4.6 |

Copy link

codecov bot commented Jul 8, 2025

Codecov Report

Attention: Patch coverage is 84.61538% with 10 lines in your changes missing coverage. Please review.

Project coverage is 83.10%. Comparing base (ce7c8b1) to head (de57edd).

Files with missing lines Patch % Lines
ext/DynamicPPLMooncakeExt.jl 84.61% 10 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #975      +/-   ##
==========================================
+ Coverage   82.97%   83.10%   +0.12%     
==========================================
  Files          36       37       +1     
  Lines        3965     4025      +60     
==========================================
+ Hits         3290     3345      +55     
- Misses        675      680       +5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@coveralls
Copy link

coveralls commented Jul 8, 2025

Pull Request Test Coverage Report for Build 16414511158

Details

  • 55 of 65 (84.62%) changed or added relevant lines in 1 file are covered.
  • 1928 unchanged lines in 28 files lost coverage.
  • Overall coverage increased (+0.02%) to 83.168%

Changes Missing Coverage Covered Lines Changed/Added Lines %
ext/DynamicPPLMooncakeExt.jl 55 65 84.62%
Files with Coverage Reduction New Missed Lines %
ext/DynamicPPLEnzymeCoreExt.jl 1 0.0%
ext/DynamicPPLForwardDiffExt.jl 1 63.64%
src/extract_priors.jl 5 53.57%
src/test_utils/model_interface.jl 5 22.22%
src/test_utils/varinfo.jl 5 76.19%
src/transforming.jl 5 72.22%
src/model_utils.jl 11 0.0%
src/logdensityfunction.jl 12 59.57%
src/distribution_wrappers.jl 16 0.0%
src/submodel_macro.jl 26 0.0%
Totals Coverage Status
Change from base Build 16174747412: 0.02%
Covered Lines: 3345
Relevant Lines: 4022

💛 - Coveralls

Copy link
Contributor

github-actions bot commented Jul 9, 2025

DynamicPPL.jl documentation for PR #975 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR975/

@sunxd3 sunxd3 requested a review from penelopeysm July 21, 2025 10:53
@@ -1,9 +1,184 @@
module DynamicPPLMooncakeExt

__precompile__(false)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to do this because I am overloading Mooncake.set_to_zero!! at the bottom of this file.

Alternatively, I can define set_to_zero!! only on tangent types, but it might be trivial as these can be deeply recursive functions. So careful implementation might need to define this function for many types.

Any better ideas?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the problem in defining set_to_zero!! for all possible types that DPPL might bring up, but the pretty heavy-handed type piracy of defining Mooncake.set_to_zero!!(x) does trouble me. For instance, might that cause a lot of invalidations, and thus a lot more Mooncake recompilation?

Also, the fact that the code relies on checking field names means that if someone just happens to define a type with the same field names, the behaviour of Mooncake on those types would depend on whether DynamicPPL is loaded in the same environment. It feels unlikely to happen, but it could lead to some truly horrendous bugs to track if it did, and also just feels like we are messing with other people's code in an inconsiderate manner.

I don't really understand the context here, but it look like this is dealing with some Mooncake issues related to circular references. Any chance that some of the machinery for dealing with that (like declaring certain types as safe/unsafe) could be implemented in Mooncake itself?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with and share the concerns.

The issue, from my POV, is that, Mooncake need to handle potential cases of circular references, by paying little price in a slightly conservative manner. It turns out, for simple DynamicPPL models, the little price of initializing and looking up in IdDict matter and make benchmarks look bad.

Ideally, there would be systematic changes in Mooncake so that one can tell Mooncake "I promise there will not be circular ref, so no need for IdDict".

As for the invalidations, I have to admit that I don't know how bad it would be. Given that for Mooncake, Julia would precompile set_to_zero_internal!! (https://github.com/chalk-lab/Mooncake.jl/blob/a26b5c35c55d1e98b9e8c6bfafbbe3dc55784140/src/tangents.jl#L728-L733). Given that set_to_zero!! at the moment is pretty much just an alias, I don't expect the cost is too grand.

ref chalk-lab/Mooncake.jl#552 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that for Mooncake, Julia would precompile set_to_zero_internal!! (https://github.com/chalk-lab/Mooncake.jl/blob/a26b5c35c55d1e98b9e8c6bfafbbe3dc55784140/src/tangents.jl#L728-L733).

I wonder if inlining might mess with this, if the call to set_to_zero_internal!! is inlined into set_to_zero!!.

How bad would it be if you used dispatch to make set_to_zero overloaded only for DPPL-owned types, and thus avoided the type piracy? I see the two ways of doing this:

  1. Replace the checks of hasfield with type bounds. I don't know if Mooncake's several different tangent types make this really hard. The hasfield checks aren't any more generic than isa checks, though, because the field names effectively determine the type, right?
  2. Have a really high-level check, where you define Mooncake.set_to_zero!!(x::AllDPPLTypes) and AllDPPTypes = Union{AbstractVarInfo,Context,AbstractAccumulator, etc.}, i.e. catch all the highest-level types in our type hierarchy.

Is either of those workable?

Note another reason to avoid hasfield checks: They create extra maintenance burden, because they would fail in a silent way that only shows as reduced performance every time we change the name of a field of any of these types.

Gonna ping @willtebbutt for his thoughts on how set_to_zero!! should be overloaded.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

julia> struct A
       a
       b
       end

julia> struct B
       a
       b
       end

julia> Mooncake.tangent_type(A)
Mooncake.Tangent{@NamedTuple{a, b}}

julia> Mooncake.tangent_type(B)
Mooncake.Tangent{@NamedTuple{a, b}}

It could very well be that I haven't internalize Mooncake so missed a clear solution, but it seems that Mooncake.set_to_zero!!(x::AllDPPLTypes) via multiple dispatch is not simple to implement at the moment.

Note another reason to avoid hasfield checks: They create extra maintenance burden, because they would fail in a silent way that only shows as reduced performance every time we change the name of a field of any of these types.

agree with this, I added

@testset "Struct field assumptions" begin
# Test that our assumptions about DynamicPPL struct fields are correct
# These tests will fail if DynamicPPL changes its internal structure
@testset "LogDensityFunction tangent structure" begin
model = test_model1([1.0, 2.0, 3.0])
vi = VarInfo(Random.default_rng(), model)
ldf = LogDensityFunction(model, vi, DefaultContext())
tangent = zero_tangent(ldf)
# Test expected fields exist
@test hasfield(typeof(tangent), :fields)
@test hasfield(typeof(tangent.fields), :model)
@test hasfield(typeof(tangent.fields), :varinfo)
@test hasfield(typeof(tangent.fields), :context)
@test hasfield(typeof(tangent.fields), :adtype)
@test hasfield(typeof(tangent.fields), :prep)
# Test exact field names match
@test propertynames(tangent.fields) ==
(:model, :varinfo, :context, :adtype, :prep)
end
@testset "VarInfo tangent structure" begin
model = test_model1([1.0, 2.0, 3.0])
vi = VarInfo(Random.default_rng(), model)
tangent_vi = zero_tangent(vi)
# Test expected fields exist
@test hasfield(typeof(tangent_vi), :fields)
@test hasfield(typeof(tangent_vi.fields), :metadata)
@test hasfield(typeof(tangent_vi.fields), :logp)
@test hasfield(typeof(tangent_vi.fields), :num_produce)
# Test exact field names match
@test propertynames(tangent_vi.fields) == (:metadata, :logp, :num_produce)
end
@testset "Model tangent structure" begin
model = test_model1([1.0, 2.0, 3.0])
tangent_model = zero_tangent(model)
# Test expected fields exist
@test hasfield(typeof(tangent_model), :fields)
@test hasfield(typeof(tangent_model.fields), :f)
@test hasfield(typeof(tangent_model.fields), :args)
@test hasfield(typeof(tangent_model.fields), :defaults)
@test hasfield(typeof(tangent_model.fields), :context)
# Test exact field names match
@test propertynames(tangent_model.fields) == (:f, :args, :defaults, :context)
end
@testset "Metadata tangent structure" begin
model = test_model1([1.0, 2.0, 3.0])
vi = VarInfo(Random.default_rng(), model)
tangent_vi = zero_tangent(vi)
metadata = tangent_vi.fields.metadata
# Metadata is a NamedTuple with variable names as keys
@test metadata isa NamedTuple
# Each variable's metadata should be a Tangent with the expected fields
for (varname, var_metadata) in pairs(metadata)
@test var_metadata isa Mooncake.Tangent
@test hasfield(typeof(var_metadata), :fields)
# Test expected fields exist
@test hasfield(typeof(var_metadata.fields), :idcs)
@test hasfield(typeof(var_metadata.fields), :vns)
@test hasfield(typeof(var_metadata.fields), :ranges)
@test hasfield(typeof(var_metadata.fields), :vals)
@test hasfield(typeof(var_metadata.fields), :dists)
@test hasfield(typeof(var_metadata.fields), :orders)
@test hasfield(typeof(var_metadata.fields), :flags)
# Test exact field names match
@test propertynames(var_metadata.fields) ==
(:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
end
end
to at least make it fail louder.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Damn, I hadn't realised structs just turn into NamedTuples. I guess we could say that in some sense we morally almost "own" types like NamedTuple{idcs,vns,ranges,vals,dists,orders,flags}. It doesn't seem like a great solution either, but sounds like you may well have exhausted the avenues for great solutions.

Sorry to be such a pain in the arse about this and I appreciate you've thought about this a lot, but I remain quite uncomfortable with overloading Base.set_to_zero!!(x) for all x.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries and thanks for the comments and discussions.

I don't want to shoehorn this. I think it's worth it to have some Mooncake side support.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really think that we need to avoid this type piracy -- as @mhauru points out, this has the potential to change the behaviour of user code in really weird ways.

I think that we probably need to modify the way set_to_zero!! is defined, in order to accomodate individual type / object owners asserting that their thing definitely doesn't

  1. contain circular references, and
  2. doesn't contain any aliasing.

I believe that this is a consistent requirement across all of the internal functions which require some kind of cache to keep track of the two things above (set_to_zero!!, _dot, increment!!, etc).

So I would propose the following:

  1. add a function to Mooncake itself called requires_cache, which can be applied to a primal, and returns a Bool.
  2. to set_to_zero!!, we add a kwarg, also called requires_cache, which defaults to true, and can be used inside that function to determine whether to construct a cache, or to produce a NoCache (or whatever the correct type is).

We can put a very simple implementation of requires_cache in Mooncake itself (something as simple as returning true if it's anything other than a bits type would do), and then overload this for specific package-owned types in extensions.

@sunxd3 @mhauru do you think this would cut it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me, except I didn't understand the need for the kwarg. Is it because requires_cache takes the primal rather than the tangent, and thus needs to be applied at the call-site and the result passed to set_to_zero!!?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants