-
Notifications
You must be signed in to change notification settings - Fork 645
Lazy tensor downloading in burn-remote #3276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #3276 +/- ##
==========================================
+ Coverage 82.51% 82.66% +0.14%
==========================================
Files 990 995 +5
Lines 127088 127626 +538
==========================================
+ Hits 104865 105500 +635
+ Misses 22223 22126 -97 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
_shape: Vec<usize>, | ||
_dtype: burn_tensor::DType, | ||
) -> RouterTensor<Self::Client> { | ||
let router_tensor = client.register_tensor_data(handle); | ||
unimplemented!(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why unimplemented?
} | ||
|
||
impl RemoteTensorHandle { | ||
pub(crate) fn change_backend(mut self, target_device: &WsDevice) -> Self { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add comments on the global strategy with the why!
pub struct TensorUploadState { | ||
pub data: TensorData, | ||
pub total_upload_count: u32, | ||
pub cur_upload_count: u32, | ||
} | ||
|
||
pub struct WsServerState { | ||
pub current_uploads: Mutex<HashMap<TensorId, TensorUploadState>>, | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would put that in tensor_data_service.rs
}; | ||
|
||
// build our application with some routes | ||
let app = Router::new() | ||
.route("/response", any(Self::handler_response)) | ||
.route("/request", any(Self::handler_request)) | ||
.with_state(state); | ||
.route("/upload", any(Self::handler_upload)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/data
let mut upload_state; | ||
{ | ||
let mut current_uploads = self.state.current_uploads.lock().unwrap(); | ||
if current_uploads.contains_key(&id) { | ||
// take the upload out of the hashmap while we download | ||
upload_state = current_uploads.remove(&id).unwrap(); | ||
log::info!("Tensor found (id: {id:?})"); | ||
} else { | ||
panic!("A tensor was requested (id: {id:?}) that isn't being served"); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let upload_state = {
};
if upload_state.total_upload_count != upload_state.cur_upload_count { | ||
let mut current_uploads = self.state.current_uploads.lock().unwrap(); | ||
current_uploads.insert(id, upload_state); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would do that during the first lock.
|
||
match msg { | ||
Message::Binary(bytes) => { | ||
let data: TensorData = rmp_serde::from_slice(&bytes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: Optimize that at some point.
pub fn register_remote_tensor(&self, tensor: TensorNetwork, new_id: TensorId) { | ||
self.compute_sender | ||
.send(ProcessorTask::RegisterRemoteTensor(tensor, new_id)) | ||
.unwrap() | ||
} | ||
|
||
pub fn upload_tensor(&self, tensor: TensorIr, count: u32) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change naming
@@ -61,7 +83,7 @@ impl<B: BackendIr> Stream<B> { | |||
|
|||
self.compute_sender | |||
.send(ProcessorTask::Sync(id, callback_sender)) | |||
.unwrap(); | |||
.unwrap_or_else(|x| println!("{x:?}")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unwrap
crates/burn-router/tmp.rs
Outdated
@@ -0,0 +1,2451 @@ | |||
mod types { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove
Co-authored-by: Jonathan Richard <jwric@users.noreply.github.com>
No description provided.