Remove ThreadSafeVarInfo #1023

Draft · wants to merge 1 commit into main
Conversation

@penelopeysm (Member) commented Aug 17, 2025

Summary

This PR removes ThreadSafeVarInfo.

In its place, a @pobserve macro is added to enable multithreaded tilde-observe statements, according to the plan outlined in #924 (comment). Broadly speaking, the following

```julia
@model function f(x)
    a ~ Normal()
    @pobserve for i in eachindex(x)
        b = my_fancy_calculation(a)
        x[i] ~ Normal(b)
    end
end
```

is converted into (modulo variable names)

```julia
@model function f(x)
    a ~ Normal()
    thread_results = map(eachindex(x)) do i
        Threads.@spawn begin
            loglike = zero(DynamicPPL.getloglikelihood(__varinfo__))
            b = my_fancy_calculation(a)
            loglike += Distributions.logpdf(Normal(b), x[i])
            loglike
        end
    end
    __varinfo__ = DynamicPPL.accloglikelihood!!(__varinfo__, sum(fetch.(thread_results)))
end
```

No actual varinfo manipulation happens inside the Threads.@spawn: instead, the log-likelihood contributions are calculated in each thread, then summed after the individual threads have finished their tasks. Because of this, there is no need to maintain one log-likelihood accumulator per thread, and consequently no need for ThreadSafeVarInfo.
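The reduction pattern can be seen in miniature without any DynamicPPL machinery. Below is a minimal sketch in plain Julia; `logpdf_normal` and `threaded_loglikelihood` are made-up names, with `logpdf_normal` standing in for `Distributions.logpdf(Normal(), x)`:

```julia
# Hand-rolled standard-normal logpdf, standing in for
# Distributions.logpdf(Normal(), x).
logpdf_normal(x) = -0.5 * (x^2 + log(2π))

function threaded_loglikelihood(xs)
    # One task per observation; each task returns only its own contribution.
    tasks = map(xs) do x
        Threads.@spawn logpdf_normal(x)
    end
    # The sum happens on the calling thread after all tasks have finished,
    # so no shared accumulator is ever mutated concurrently.
    return sum(fetch.(tasks))
end
```

Since each task's result is an ordinary return value, the only synchronisation point is the `fetch`, and the final sum is deterministic regardless of how the tasks are scheduled.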

Closes #429.
Closes #924.
Closes #947.

Why?

Code simplification in DynamicPPL, and reducing the number of AbstractVarInfo subtypes, is obviously a major argument in favour.

But in fact, that's not my main motivation. I'm mostly motivated by the fact that TSVI is, in my opinion, not good code: it works, but in many ways it's a hack.

  1. Any time Julia is launched with more than 1 thread, all models will be executed with TSVI, even if there is no parallelisation within the model itself. This violates a general principle that users should only 'pay for what they need'.
  2. TSVI encourages users to use Threads.@threads for i in x ... end, and then internally we use Threads.threadid() to index into a vector of accumulators. This is now regarded as "incorrect parallel code that contains the possibility of race conditions which can give wrong results". See https://julialang.org/blog/2023/07/PSA-dont-use-threadid/ and https://discourse.julialang.org/t/behavior-of-threads-threads-for-loop/76042.
  3. Furthermore, to determine the requisite length for that vector of accumulators, we currently use Threads.nthreads() * 2 which is a hacky heuristic. The correct solution would be Threads.maxthreadid(), but Mooncake couldn't differentiate through that.
  4. In fact, even the correct solution is not actually correct. Quoting from the Julia blog post above, "relying on threadid, nthreads and even maxthreadid [is] perilous. Any code that relies on a specific threadid staying constant, or on a constant number of threads during execution, is bound to be incorrect."
  5. In general TSVI is difficult to make type-stable. See, e.g., #885 ('Accumulators, stage 1') (comment).
  6. The choice of whether to use TSVI or not is determined by if Threads.nthreads() > 1, which cannot be determined at compile time. This means that the varinfo type used for model evaluation is only known at runtime, which further hurts type stability.
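Point (2) can also be illustrated without DynamicPPL. The first function below is the discouraged threadid-indexed pattern that TSVI mirrors (it happens to give the right answer for plain summation here, but the indexing scheme is exactly the one the Julia blog post warns about); the second gives each task its own state and reduces afterwards. Both function names are made up for illustration:

```julia
# Discouraged: a buffer indexed by threadid(). With task migration, two
# iterations can observe the same threadid concurrently and race on the
# read-modify-write of acc[Threads.threadid()].
function sum_threadid_style(xs)
    acc = zeros(Threads.nthreads())
    Threads.@threads for i in eachindex(xs)
        acc[Threads.threadid()] += xs[i]   # shared mutable state
    end
    return sum(acc)
end

# Safer: split the work into chunks, let each task compute its own partial
# sum, and reduce only after every task has finished.
function sum_chunked(xs; nchunks = Threads.nthreads())
    chunks = Iterators.partition(xs, cld(length(xs), nchunks))
    tasks = [Threads.@spawn sum(c) for c in chunks]
    return sum(fetch.(tasks))
end
```

The second pattern also sidesteps points (3) and (4) entirely: no buffer length ever needs to be guessed from `nthreads()` or `maxthreadid()`.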

Does this actually work?

This PR has no tests yet, but I ran this locally and the log-likelihood gets accumulated correctly:

```julia
julia> using DynamicPPL, Distributions

julia> @model function f(x)
           @pobserve for i in eachindex(x)
               println(Threads.threadid(), "->", x[i])
               x[i] ~ Normal()
           end
           return DynamicPPL.getlogp(__varinfo__)
       end
f (generic function with 2 methods)

julia> f([1.0, 2.0])() # note that this was run with 2 threads
2->2.0
1->1.0
(logprior = 0.0, logjac = 0.0, loglikelihood = -4.337877066409345)

julia> logpdf(Normal(), 1.0) + logpdf(Normal(), 2.0) # for comparison
-4.337877066409345
```

I can also confirm that the parallelisation is correctly occurring with this model:

```julia
using DynamicPPL, Distributions
@model function g(x)
    @pobserve for i in eachindex(x)
        # can't use Base.sleep as that doesn't fully block
        Libc.systemsleep(1.0)
        x[i] ~ Normal()
    end
end
println(Threads.nthreads())
@time g([1.0, 2.0])()
```

If you run this with 1 thread it takes 2 seconds, and if you run it with 2 threads it takes 1 second.

It also works correctly with MCMCThreads() (with some minor adjustments to Turing.jl for compatibility with this branch). NOTE: Sampling with @pobserve is now fully reproducible, whereas Threads.@threads was not reproducible even when seeded.

```julia
using Turing, DynamicPPL, Random
@model function h(y)
    x ~ MvNormal(zeros(length(y)), I)
    @pobserve for i in eachindex(y)
        y[i] ~ Normal(x[i])
    end
end
chn = sample(Xoshiro(468), h([1.0, 2.0, 3.0]), NUTS(), MCMCThreads(), 2000, 4; check_model=false)
describe(chn)
```

What now?

There are a handful of limitations to this PR. These are the ones I can think of right now:

  1. It will crash if the VarInfo used for evaluation does not have a likelihood accumulator.
  2. It only works with likelihood terms. This mimics the pre-0.37 behaviour but in principle, 0.37 does allow users to accumulate prior probabilities (or any accumulator) in a thread-safe manner. Of course, they can't do it with tilde statements; they can only do it by calling something like DynamicPPL.acclogprior!!().
  3. It doesn't work with .~ (or maybe it does; I haven't tested, but my guess is that it will bug out).
  4. It doesn't work with conditioned values.
  5. If x is not a model argument and is not conditioned upon, this will yield wrong results for the typical pattern x = Vector{Float64}(undef, 2); @pobserve for i in eachindex(x); x[i] ~ dist; end, since it will naively accumulate logpdf(dist, x[i]) even though this should be treated as an assumption rather than an observation.
  6. There is no way to extract other computations from the threads.
  7. Libtask doesn't work with Threads.@spawn, so PG will throw an error with @pobserve.
  8. @pobserve is a bit too unambitious. If one day we make it work with assume, then it will have to be renamed, i.e. a breaking change.

I believe that all of these are either unimportant or can be worked around with some additional macro leg-work:

  1. Not important, nobody is running around evaluating their models with no likelihood accumulator. Not even Turing does this. Also easy enough for us to guard against by wrapping the entire thing in an if/else.
  2. You can still manually calculate log-prior terms in a thread and then do a single acclogprior!! outside the threaded bit.
  3. This is a bit boilerplate-y but otherwise quite straightforward to fix.
  4. This can be fixed by performing the same checks on the tilde lhs that we do in the main model macro.
  5. Same as (4). Note that this is broadly also a problem with non-parallel models (it's just the inverse of #519, 'Derived variables from data on the LHS of tilde'), and in general #965 ('Forcing all LHS variables of tilde to be part of model arguments') or similar 'static VarInfo' approaches would fix this.
  6. This can be fixed easily by changing the macro to return a tuple of (retval, loglike) rather than just loglike.
  7. This is actually an improvement over the current behaviour because right now PG silently yields incorrect results with Threads.@threads.
  8. I am more than happy to take ideas for other names.
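For (6), the fix amounts to having each task return a tuple so that per-iteration values survive the reduction. A minimal sketch of that lowered shape in plain Julia (all names here are hypothetical, not DynamicPPL API; `logpdf_normal` stands in for `Distributions.logpdf(Normal(), x)`):

```julia
logpdf_normal(x) = -0.5 * (x^2 + log(2π))  # stand-in for logpdf(Normal(), x)

# Hypothetical lowered form once each task returns (retval, loglike).
function observe_with_retvals(xs)
    tasks = map(xs) do x
        Threads.@spawn begin
            b = 2x                       # stand-in for the user's computation
            (b, logpdf_normal(x))        # keep the value AND the contribution
        end
    end
    results = fetch.(tasks)
    # Per-iteration return values come back in order; only the log-likelihood
    # contributions get reduced.
    return first.(results), sum(last.(results))
end
```

The macro would then splice `first.(results)` back in as the value of the `@pobserve` block, while only `sum(last.(results))` is passed to the accumulator.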

So for now this should mostly be considered a proof of principle rather than a complete PR.

Finally, note that this PR already removes over 550 lines of code, and even that does not give the full picture of the simplification afforded. For example, I did not remove the split, combine, and convert_eltype methods on accumulators, which I believe can be removed or simplified once TSVI is gone.

@penelopeysm penelopeysm marked this pull request as draft August 17, 2025 18:55

codecov bot commented Aug 17, 2025

Codecov Report

❌ Patch coverage is 28.12500% with 23 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.70%. Comparing base (0cf3440) to head (79bedaf).

Files with missing lines Patch % Lines
src/pobserve_macro.jl 0.00% 23 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1023      +/-   ##
==========================================
- Coverage   82.26%   81.70%   -0.56%     
==========================================
  Files          38       38              
  Lines        3947     3827     -120     
==========================================
- Hits         3247     3127     -120     
  Misses        700      700              


github-actions bot commented Aug 17, 2025

Benchmark Report for Commit 79bedaf

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  8.5 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                678.8 |                42.2 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                422.2 |                53.4 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1013.1 |                34.4 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               6888.6 |                27.4 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1065.0 |                40.9 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1044.7 |                 4.6 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5938.3 |                 4.3 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                993.0 |                 9.1 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              66437.8 |                 3.9 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8291.6 |                10.3 |
|               Dynamic |        10 |    mooncake |             typed |   true |                141.9 |                12.2 |
|              Submodel |         1 |    mooncake |             typed |   true |                 12.7 |                 5.5 |
|                   LDA |        12 | reversediff |             typed |   true |               1047.7 |                 2.5 |


coveralls commented Aug 17, 2025

Pull Request Test Coverage Report for Build 17024844925

Details

  • 9 of 32 (28.13%) changed or added relevant lines in 5 files are covered.
  • 6 unchanged lines in 3 files lost coverage.
  • Overall coverage decreased (-0.6%) to 81.966%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/pobserve_macro.jl 0 23 0.0%
Files with Coverage Reduction New Missed Lines %
src/abstract_varinfo.jl 1 75.42%
src/compiler.jl 1 86.98%
src/varinfo.jl 4 86.21%
Totals Coverage Status
Change from base Build 16942570157: -0.6%
Covered Lines: 3127
Relevant Lines: 3815



DynamicPPL.jl documentation for PR #1023 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1023/

@penelopeysm (Member, Author) commented:

As a bonus, this PR completely fixes all Enzyme issues arising from DPPL 0.37. #947

@penelopeysm penelopeysm changed the title Attempt to remove TSVI Remove ThreadSafeVarInfo Aug 17, 2025
Successfully merging this pull request may close these issues.

  • Enzyme doesn't like accumulators
  • ThreadSafeVarInfo and threadid
  • Remove use of threadid