-
Notifications
You must be signed in to change notification settings - Fork 645
Add Dice-Sorenson Coefficient Metric #3407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Mathijs de Boer <mathijs.de.boer0@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for contributing! That's a great start 🙂
I think the metric strictly assumes one-hot encoded targets, but it is not clearly documented. We should add more info regarding the outputs and targets shapes.
We could also accept int targets for multiclass and convert them to one-hot.
Also, it seems that the background is always included regardless of the config?
Hey there, thanks for the feedback! I've quickly commented on some of your comments where I could form an answer, and will have a closer look at the rest when I have time. |
No rush 🙂 implementation looks good btw, comments were mainly directed at the API/usage and asking questions to better grasp the scope, |
…nd improve input validation Signed-off-by: Mathijs de Boer <mathijs.de.boer0@gmail.com>
I've tackled the points you've raised. It's pretty late here, so there is a reasonable chance that some new mistakes slipped in. All checks pass though! |
Codecov ReportAttention: Patch coverage is
❌ Your project check has failed because the head coverage (63.48%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #3407 +/- ##
==========================================
+ Coverage 63.45% 63.48% +0.02%
==========================================
Files 981 982 +1
Lines 109662 109891 +229
==========================================
+ Hits 69589 69761 +172
- Misses 40073 40130 +57 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Signed-off-by: Mathijs de Boer <m.deboer-41@umcutrecht.nl>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a minor comment for the checks, otherwise looks good!
assert!( | ||
outputs.dims() == targets.dims(), | ||
"Outputs and targets must have the same dimensions. Got {:?} and {:?}", | ||
outputs.dims(), | ||
targets.dims() | ||
); | ||
assert!( | ||
outputs.dims().len() == D, | ||
"Outputs must have exactly {} dimensions. Got {:?}", | ||
D, | ||
outputs.dims() | ||
); | ||
assert!( | ||
targets.dims().len() == D, | ||
"Targets must have exactly {} dimensions. Got {:?}", | ||
D, | ||
targets.dims() | ||
); | ||
assert!( | ||
outputs.shape() == targets.shape(), | ||
"Outputs and targets must have the same shape. Got {:?} and {:?}", | ||
outputs.shape(), | ||
targets.shape() | ||
); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to check dims().len() == D
, that should always be true for a given tensor.
Also, shape()
and dims()
check are essentially the same thing, so we can keep only one of those.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, there's probably some more of my Python main background coming in. I'll simplify that section.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These have now been reduced to a single assert
/// Model outputs (predictions), as a tensor. | ||
outputs: Tensor<B, D, Int>, | ||
/// Ground truth targets, as a tensor. | ||
targets: Tensor<B, D, Int>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I'll go ahead and make it int-only for now, pending anyone bringing this up in the future
Sure! We can broaden the support in the future.
Signed-off-by: Mathijs de Boer <m.deboer-41@umcutrecht.nl>
Vision Metrics
I saw that the current set metrics did not contain vision metrics. For this pull request, I have currently implemented the Dice-Sorenson Coefficient, and placed it in its own module to differentiate it from the other metrics. I could've added more metrics to this, but I figured I'd open a PR early, to get some feedback from the maintainers regarding organization, codestyle, etc. or if there even is an interest to include these metrics in the core Burn project. (as opposed to a spinoff)
I mainly come from the usual PyTorch scene, and am learning to use Rust in my research. I see Burn as a solid option to move some of our research into a faster language than Python for inference when we're trying to get some of it into clinical practice.
Checklist
cargo run-checks
command has been executed.Related Issues/PRs
None
Changes
Add
DiceMetric
toburn-train
'smetric/vision
module.Testing
Unit testing code has been included to evaluate the following cases:
LLM Notice
I used Copilot in VS Code to set up the scaffolding for this metric, based on the implementations in
acc.rs
andrecall.rs
. All code is double checked by me and altered where I deemed it necessary.