## Summary

This PR removes `ThreadSafeVarInfo`. In its place, a `@pobserve` macro is added to enable multithreaded tilde-observe statements, according to the plan outlined in #924 (comment). Broadly speaking, the following
```julia
@model function f(x)
    a ~ Normal()
    @pobserve for i in eachindex(x)
        b = my_fancy_calculation(a)
        x[i] ~ Normal(b)
    end
end
```
is converted into (modulo variable names)
```julia
@model function f(x)
    a ~ Normal()
    thread_results = map(eachindex(x)) do i
        Threads.@spawn begin
            loglike = zero(DynamicPPL.getloglikelihood(__varinfo__))
            b = my_fancy_calculation(a)
            loglike += Distributions.logpdf(Normal(b), x[i])
            loglike
        end
    end
    __varinfo__ = DynamicPPL.accloglikelihood!!(__varinfo__, sum(fetch.(thread_results)))
end
```
No actual varinfo manipulation happens inside the `Threads.@spawn`: instead, the log-likelihood contributions are calculated in each thread, then summed after the individual threads have finished their tasks. Because of this, there is no need to maintain one log-likelihood accumulator per thread, and consequently no need for `ThreadSafeVarInfo`.
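To make the thread-safety argument concrete, here is a minimal standalone sketch of the same pattern (plain Julia plus Distributions, no DynamicPPL involved): each spawned task only produces its own local number, and the only shared operation is the final reduction on the main task.

```julia
using Distributions

x = randn(100)

# Each task computes its own log-likelihood contribution as a plain number;
# nothing shared is mutated inside the spawned tasks.
tasks = map(eachindex(x)) do i
    Threads.@spawn logpdf(Normal(), x[i])
end

# The contributions are only combined once, after all tasks have finished.
total_loglike = sum(fetch.(tasks))
```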
Closes #429. Closes #924. Closes #947.

## Why?

Code simplification in DynamicPPL, and reducing the number of `AbstractVarInfo` subtypes, is obviously a big argument. But in fact, that's not my main motivation. I'm mostly motivated to do this because TSVI in general is IMO not good code: it works, but in many ways it's a hack.
- Any time Julia is launched with more than 1 thread, all models will be executed with TSVI, even if there is no parallelisation within the model itself. This violates the general principle that users should only 'pay for what they need'.
- Currently, parallelism within a model is expressed using `Threads.@threads for i in x ... end`, and then internally we use `Threads.threadid()` to index into a vector of accumulators (see the sketch after this list). This is now regarded as "incorrect parallel code that contains the possibility of race conditions which can give wrong results". See https://julialang.org/blog/2023/07/PSA-dont-use-threadid/ and https://discourse.julialang.org/t/behavior-of-threads-threads-for-loop/76042.
- Furthermore, to determine the requisite length for that vector of accumulators, we currently use `Threads.nthreads() * 2`, which is a hacky heuristic. The correct solution would be `Threads.maxthreadid()`, but Mooncake couldn't differentiate through that.
- In fact, even the 'correct' solution is not actually correct. Quoting from the Julia blog post above: "relying on threadid, nthreads and even maxthreadid [is] perilous. Any code that relies on a specific threadid staying constant, or on a constant number of threads during execution, is bound to be incorrect."
- The choice of whether to use TSVI or not is determined by `if Threads.nthreads() > 1`, which cannot be resolved at compile time. This means that:
  - Extra effort is needed to make sure that not only is TSVI itself type stable, but that the TSVI and non-TSVI branches in `evaluate!!` are together type stable.
  - Even though Mooncake can't differentiate through multithreaded code, we still need to make sure it's able to differentiate through TSVI, otherwise it can't differentiate through `evaluate!!`. That's just silly IMO.
- There is also a related Enzyme issue involving `cacheForReverse`: see EnzymeAD/Enzyme.jl#2518.
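For illustration, here is a minimal sketch of the `threadid`-indexed accumulation pattern criticised above. This shows the general shape of the pattern, not DynamicPPL's actual code:

```julia
# One accumulator slot per thread, indexed by Threads.threadid() inside the
# loop body. The blog post linked above explains why this is considered
# unsafe: with task migration, threadid() is not guaranteed to stay constant
# across yield points, so reads and writes can end up hitting different slots.
logps = zeros(Threads.nthreads())
Threads.@threads for i in 1:100
    logps[Threads.threadid()] += log(i)
end
total = sum(logps)
```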
## Does this actually work?

This PR has no tests yet, but I ran this locally and the log-likelihood gets accumulated correctly:
```julia
julia> using DynamicPPL, Distributions

julia> @model function f(x)
           @pobserve for i in eachindex(x)
               println(Threads.threadid(), "->", x[i])
               x[i] ~ Normal()
           end
           return DynamicPPL.getlogp(__varinfo__)
       end
f (generic function with 2 methods)

julia> f([1.0, 2.0])()  # note that this was run with 2 threads
2->2.0
1->1.0
(logprior = 0.0, logjac = 0.0, loglikelihood = -4.337877066409345)

julia> logpdf(Normal(), 1.0) + logpdf(Normal(), 2.0)  # for comparison
-4.337877066409345
```
I can also confirm that the parallelisation is correctly occurring with this model:
```julia
using DynamicPPL, Distributions

@model function g(x)
    @pobserve for i in eachindex(x)
        # can't use Base.sleep as that doesn't fully block
        Libc.systemsleep(1.0)
        x[i] ~ Normal()
    end
end

println(Threads.nthreads())
@time g([1.0, 2.0])()
```
If you run this with 1 thread it takes 2 seconds, and if you run it with 2 threads it takes 1 second.
It also works correctly with `MCMCThreads()` (with some minor adjustments to Turing.jl for compatibility with this branch). NOTE: sampling with `@pobserve` is now fully reproducible, whereas `Threads.@threads` was not reproducible even when seeded.
```julia
using Turing, DynamicPPL, Random

@model function h(y)
    x ~ MvNormal(zeros(length(y)), I)
    @pobserve for i in eachindex(y)
        y[i] ~ Normal(x[i])
    end
end

chn = sample(Xoshiro(468), h([1.0, 2.0, 3.0]), NUTS(), MCMCThreads(), 2000, 4; check_model=false)
describe(chn)
```
## What now?
There are a handful of limitations to this PR. These are the ones I can think of right now:
- It will crash if the `VarInfo` used for evaluation does not have a likelihood accumulator.
- It only works with likelihood terms. This mimics the pre-0.37 behaviour, but in principle 0.37 does allow users to accumulate prior probabilities (or any accumulator) in a thread-safe manner. Of course, they can't do it with tilde statements; they can only do it by calling something like `DynamicPPL.acclogprior!!()`.
- It doesn't work with `.~` (or maybe it does; I haven't tested, but my guess is that it will bug out).
- It doesn't work with conditioned values.
- If `x` is not a model argument or conditioned upon, this will yield wrong results for the typical `x = Vector{Float64}(undef, 2); @pobserve for i in eachindex(x); x[i] ~ dist; end`, as it will naively accumulate `logpdf(dist, x[i])` even though this should be an assumption rather than an observation.
- There is no way to extract other computations from the threads.
- Libtask doesn't work with `Threads.@spawn`, so PG will throw an error with `@pobserve`.
- `@pobserve` is a bit too unambitious. If one day we make it work with assume, then it will have to be renamed, i.e. a breaking change.
I believe that all of these are either unimportant or can be worked around with some additional macro leg-work:
- Not important; nobody is running around evaluating their models with no likelihood accumulator. Not even Turing does this. It is also easy enough for us to guard against by wrapping the entire thing in an if/else.
- You can still manually calculate log-prior terms in a thread and then do a single `acclogprior!!` outside the threaded bit (see the sketch after this list).
- This is a bit boilerplate-y but otherwise quite straightforward to fix.
- This can be fixed by performing the same checks on the tilde lhs that we do in the main model macro.
- This could be fixed by returning `(retval, loglike)` from each spawned task rather than just `loglike`.
- I am more than happy to take ideas for other names.
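As a rough illustration of the log-prior workaround above, here is a hedged sketch: the helper name is made up, and it assumes an `acclogprior!!(varinfo, logp)`-style signature analogous to the `accloglikelihood!!` call shown earlier. The prior terms are computed in parallel as plain numbers, and the varinfo is only touched once, outside the threaded region.

```julia
using DynamicPPL, Distributions

# Hypothetical helper: accumulate log-prior contributions for `mus` in
# parallel, then add them to the varinfo with a single acclogprior!! call.
function accumulate_priors(varinfo, mus)
    tasks = map(eachindex(mus)) do i
        # Plain number per task; no varinfo access inside the spawned tasks.
        Threads.@spawn logpdf(Normal(0, 1), mus[i])
    end
    return DynamicPPL.acclogprior!!(varinfo, sum(fetch.(tasks)))
end
```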
So for now this should mostly be considered a proof of principle rather than a complete PR.
Finally, note that this PR already removes more than 550 lines of code, but this is not a full picture of the simplification afforded. For example, I did not remove the `split`, `combine`, and `convert_eltype` methods on accumulators, which I believe can either be removed or simplified once TSVI is removed.