Skip to content

Commit 81985bd

Browse files
authored
Lazy tensor downloading in burn-remote (#3276)
1 parent f8273f0 commit 81985bd

File tree

14 files changed

+428
-98
lines changed

14 files changed

+428
-98
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/burn-remote/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ version.workspace = true
1515
workspace = true
1616

1717
[features]
18-
default = []
18+
default = ["client", "server"]
1919
doc = []
2020
client = ["tokio-tungstenite", "async-channel", "tokio/sync"]
21-
server = ["axum", "tracing-core", "tracing-subscriber"]
21+
server = ["tokio-tungstenite", "async-channel", "tokio/sync", "axum", "tracing-core", "tracing-subscriber"]
2222

2323

2424
[dependencies]
@@ -27,6 +27,8 @@ burn-tensor = { path = "../burn-tensor", version = "0.18.0", default-features =
2727
burn-common = { path = "../burn-common", version = "0.18.0", default-features = true }
2828
burn-router = { path = "../burn-router", version = "0.18.0", default-features = true }
2929

30+
bytes = { version = "1.0" }
31+
3032
# Basic dependencies
3133
derive-new = { workspace = true }
3234
log = { workspace = true }

crates/burn-remote/src/client/channel.rs

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use burn_ir::TensorIr;
2-
use burn_router::{RouterTensor, RunnerChannel, RunnerClient, TensorHandle};
2+
use burn_router::{RouterTensor, RunnerChannel, get_client};
3+
4+
use crate::shared::{ComputeTask, TensorRemote};
35

46
use super::{
57
WsClient,
6-
runner::{WsBridge, WsDevice},
8+
runner::{RemoteTensorHandle, WsBridge, WsDevice},
79
};
810

911
/// A local channel with direct connection to the backend runner clients.
@@ -29,17 +31,47 @@ impl RunnerChannel for WsChannel {
2931
WsClient::init(device.clone())
3032
}
3133

32-
fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> TensorHandle<Self::Bridge> {
33-
client.runtime.block_on(client.read_tensor(tensor.clone()))
34+
fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> RemoteTensorHandle {
35+
RemoteTensorHandle {
36+
client: client.clone(),
37+
tensor: tensor.clone(),
38+
}
3439
}
3540

3641
fn register_tensor(
3742
client: &Self::Client,
38-
handle: TensorHandle<Self::Bridge>,
39-
_shape: Vec<usize>,
40-
_dtype: burn_tensor::DType,
43+
handle: RemoteTensorHandle,
44+
shape: Vec<usize>,
45+
dtype: burn_tensor::DType,
46+
) -> RouterTensor<Self::Client> {
47+
let remote_tensor = TensorRemote {
48+
id: handle.tensor.id,
49+
address: client.device.address.to_string(),
50+
};
51+
let new_id = client.sender.new_tensor_id();
52+
client
53+
.sender
54+
.send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
55+
56+
RouterTensor::new(handle.tensor.id, shape, dtype, client.clone())
57+
}
58+
59+
fn change_client_backend(
60+
tensor: RouterTensor<Self::Client>,
61+
target_device: &Self::Device, // target device
4162
) -> RouterTensor<Self::Client> {
42-
let router_tensor = client.register_tensor_data(handle);
63+
// Get tensor handle from current client
64+
let original_client = tensor.client.clone();
65+
let desc = tensor.into_ir();
66+
let handle = Self::get_tensor_handle(&desc, &original_client);
67+
68+
let handle = handle.change_backend(target_device);
69+
70+
let id = handle.tensor.id;
71+
72+
let target_client = get_client::<Self>(target_device);
73+
let router_tensor: RouterTensor<WsClient> =
74+
RouterTensor::new(id, handle.tensor.shape, handle.tensor.dtype, target_client);
4375

4476
router_tensor
4577
}

crates/burn-remote/src/client/runner.rs

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use burn_common::future::DynFut;
2-
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient};
2+
use burn_ir::TensorIr;
3+
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
34
use burn_tensor::{
45
DType, TensorData,
56
backend::{DeviceId, DeviceOps},
@@ -9,9 +10,9 @@ use std::{
910
sync::Arc,
1011
};
1112

12-
use crate::shared::{ComputeTask, TaskResponseContent};
13+
use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
1314

14-
use super::WsClient;
15+
use super::{WsChannel, WsClient};
1516

1617
// It is very important to block on any request made with the sender, since ordering is crucial
1718
// when registering operations or creating tensors.
@@ -45,7 +46,7 @@ impl RunnerClient for WsClient {
4546

4647
self.sender.send(ComputeTask::RegisterTensor(id, data));
4748

48-
RouterTensor::new(Arc::new(id), shape, dtype, self.clone())
49+
RouterTensor::new(id, shape, dtype, self.clone())
4950
}
5051

5152
fn register_empty_tensor(
@@ -55,7 +56,7 @@ impl RunnerClient for WsClient {
5556
) -> RouterTensor<Self> {
5657
let id = self.sender.new_tensor_id();
5758

58-
RouterTensor::new(Arc::new(id), shape, dtype, self.clone())
59+
RouterTensor::new(id, shape, dtype, self.clone())
5960
}
6061

6162
fn register_float_tensor(
@@ -140,31 +141,67 @@ impl DeviceOps for WsDevice {
140141

141142
pub struct WsBridge;
142143

144+
pub struct RemoteTensorHandle {
145+
pub(crate) client: WsClient,
146+
pub(crate) tensor: TensorIr,
147+
}
148+
149+
impl RemoteTensorHandle {
150+
/// Changes the backend of the tensor via a WebSocket.
151+
/// We ask the original server to expose the tensor, then ask the target server to fetch
152+
/// the tensor. The target server will open a new WebSocket connection to the original server
153+
/// to download the data.
154+
/// This way the client never sees the tensor's data, and we avoid a bottleneck.
155+
pub(crate) fn change_backend(mut self, target_device: &WsDevice) -> Self {
156+
self.client.sender.send(ComputeTask::ExposeTensorRemote {
157+
tensor: self.tensor.clone(),
158+
count: 1,
159+
});
160+
161+
let target_client: WsClient = get_client::<WsChannel>(target_device);
162+
163+
let new_id = target_client.sender.new_tensor_id();
164+
165+
let remote_tensor = TensorRemote {
166+
id: self.tensor.id,
167+
address: self.client.device.address.to_string(),
168+
};
169+
target_client
170+
.sender
171+
.send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
172+
173+
self.tensor.id = new_id;
174+
self.client = target_client;
175+
176+
self
177+
}
178+
}
179+
143180
impl MultiBackendBridge for WsBridge {
144-
type TensorHandle = TensorData;
181+
type TensorHandle = RemoteTensorHandle;
145182
type Device = WsDevice;
146183

147184
fn change_backend_float(
148185
tensor: Self::TensorHandle,
149186
_shape: burn_tensor::Shape,
150-
_target_device: &Self::Device,
187+
target_device: &Self::Device,
151188
) -> Self::TensorHandle {
152-
tensor
189+
tensor.change_backend(target_device)
153190
}
154191

155192
fn change_backend_int(
156193
tensor: Self::TensorHandle,
157194
_shape: burn_tensor::Shape,
158-
_target_device: &Self::Device,
195+
target_device: &Self::Device,
159196
) -> Self::TensorHandle {
160-
tensor
197+
tensor.change_backend(target_device)
161198
}
162199

163200
fn change_backend_bool(
164201
tensor: Self::TensorHandle,
165202
_shape: burn_tensor::Shape,
166-
_target_device: &Self::Device,
203+
target_device: &Self::Device,
167204
) -> Self::TensorHandle {
168-
tensor
205+
tensor.change_backend(target_device)
169206
}
170207
}

0 commit comments

Comments
 (0)