
Commit a408d71

collective config refactoring to add broadcast and reduce
1 parent 18ba297 commit a408d71

9 files changed, +446 -232 lines changed

crates/burn-collective/multinode-tests/src/bin/node.rs

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ fn launch_threads<B: Backend, const D: usize>(
 
     // Put all the parameters in the config
     let config = CollectiveConfig::default()
-        .with_all_reduce_kind(test_input.all_reduce_kind)
+        .with_all_reduce_kind(test_input.all_reduce_op)
         .with_num_devices(test_input.device_count)
         .with_device_id(id.into())
         .with_node_id(test_input.node_id)
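
The hunk above only renames the field read from the test input; the builder method keeps its old name. For reference, a sketch of the full builder chain with hypothetical literal values standing in for the `test_input` fields, assuming `DeviceId`/`NodeId` convert from integers via `Into` as the `id.into()` call suggests:

use burn_collective::{CollectiveConfig, ReduceOperation};

// Hypothetical values; the test binary reads these from `test_input`.
let config = CollectiveConfig::default()
    .with_all_reduce_kind(ReduceOperation::Mean) // now fed by `all_reduce_op`
    .with_num_devices(4)
    .with_device_id(0.into())
    .with_node_id(0.into());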

crates/burn-collective/multinode-tests/src/bin/test_launcher.rs

Lines changed: 9 additions & 9 deletions
@@ -13,7 +13,7 @@ use burn::{
     prelude::Backend,
     tensor::{Shape, Tensor, TensorData},
 };
-use burn_collective::{AllReduceStrategy, ReduceKind};
+use burn_collective::{AllReduceStrategy, ReduceOperation};
 use burn_collective_multinode_tests::shared::NodeTestData;
 use burn_common::rand::{SeedableRng, StdRng};
 
@@ -39,7 +39,7 @@ async fn main() {
     let nodes = launch_nodes(
         topology,
         tensor_shape,
-        ReduceKind::Mean,
+        ReduceOperation::Mean,
         AllReduceStrategy::Ring,
         AllReduceStrategy::Tree(2),
     );
@@ -131,13 +131,13 @@ fn launch_orchestrator(test_files_dir: &str) -> Child {
 fn launch_nodes(
     topology: Vec<u32>,
     tensor_shape: Shape,
-    reduce_kind: ReduceKind,
+    reduce_op: ReduceOperation,
     global_strategy: AllReduceStrategy,
     local_strategy: AllReduceStrategy,
 ) -> Vec<(String, Child)> {
     let total_device_count = topology.iter().sum();
     let (inputs, expected) =
-        generate_random_input(tensor_shape, reduce_kind, total_device_count, 42);
+        generate_random_input(tensor_shape, reduce_op, total_device_count, 42);
 
     // URL for the global orchestrator on port 3000
     let global_url = "ws://localhost:3000";
@@ -156,7 +156,7 @@
             expected.clone(),
             node_count,
             global_address.clone(),
-            reduce_kind,
+            reduce_op,
             global_strategy,
             local_strategy,
         );
@@ -190,7 +190,7 @@ fn write_node_input(
     expected: TensorData,
     node_count: u32,
     global_address: Address,
-    reduce_kind: ReduceKind,
+    reduce_op: ReduceOperation,
     global_strategy: AllReduceStrategy,
     local_strategy: AllReduceStrategy,
 ) -> String {
@@ -213,7 +213,7 @@
         global_address,
         node_address,
         data_service_port,
-        all_reduce_kind: reduce_kind,
+        all_reduce_op: reduce_op,
         global_strategy,
         local_strategy,
         inputs,
@@ -231,7 +231,7 @@
 /// Generates random input tensors and expected output based on the provided shape and reduce kind.
 fn generate_random_input(
     shape: Shape,
-    reduce_kind: ReduceKind,
+    reduce_kind: ReduceOperation,
     input_count: u32,
     seed: u64,
 ) -> (Vec<TensorData>, TensorData) {
@@ -256,7 +256,7 @@
         expected_tensor = expected_tensor.add(input_tensor);
     }
 
-    if reduce_kind == ReduceKind::Mean {
+    if reduce_kind == ReduceOperation::Mean {
         expected_tensor = expected_tensor.div_scalar(input_count);
     }
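
The last two hunks show how `generate_random_input` derives the expected result: an element-wise sum of all inputs, divided by the input count when the operation is `ReduceOperation::Mean`. A minimal self-contained sketch of that reference computation on plain `Vec<f32>` data (the enum is a local stand-in for the crate's `ReduceOperation`):

#[derive(PartialEq)]
enum ReduceOperation {
    Sum,
    Mean,
}

// Element-wise sum of all inputs, divided by the count for a mean reduce.
fn expected_output(inputs: &[Vec<f32>], op: ReduceOperation) -> Vec<f32> {
    let mut expected = vec![0.0; inputs[0].len()];
    for input in inputs {
        for (e, x) in expected.iter_mut().zip(input) {
            *e += x;
        }
    }
    if op == ReduceOperation::Mean {
        let n = inputs.len() as f32;
        for e in expected.iter_mut() {
            *e /= n;
        }
    }
    expected
}

fn main() {
    let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    assert_eq!(expected_output(&inputs, ReduceOperation::Mean), vec![2.0, 3.0]);
}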

crates/burn-collective/multinode-tests/src/shared.rs

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 use burn::tensor::TensorData;
-use burn_collective::{AllReduceStrategy, NodeId, ReduceKind};
+use burn_collective::{AllReduceStrategy, NodeId, ReduceOperation};
 use burn_communication::Address;
 use serde::{Deserialize, Serialize};
 
@@ -17,8 +17,8 @@ pub struct NodeTestData {
     pub node_address: Address,
     /// Node's data service port, for initializing the p2p tensor data service
     pub data_service_port: u16,
-    /// What kind of aggregation
-    pub all_reduce_kind: ReduceKind,
+    /// What kind of all-reduce
+    pub all_reduce_op: ReduceOperation,
     /// Node's data service port, for initializing the p2p tensor data service
     pub global_strategy: AllReduceStrategy,
     /// What kind of aggregation
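
`NodeTestData` derives `Serialize`/`Deserialize` so the launcher can hand each node its parameters through a file. A sketch of that round-trip with a reduced stand-in struct (two fields only; JSON via `serde_json` is an assumption, the diff does not show which serializer the launcher uses):

use serde::{Deserialize, Serialize};

// Reduced stand-in for the real `NodeTestData`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct TestData {
    node_count: u32,
    data_service_port: u16,
}

fn main() {
    let data = TestData {
        node_count: 3,
        data_service_port: 4000,
    };
    let json = serde_json::to_string(&data).unwrap();
    let back: TestData = serde_json::from_str(&json).unwrap();
    assert_eq!(data, back);
}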

crates/burn-collective/src/api.rs

Lines changed: 70 additions & 8 deletions
@@ -1,7 +1,7 @@
-use burn_tensor::{Tensor, backend::Backend};
+use burn_tensor::{backend::Backend, Tensor};
 
 use crate::{
-    CollectiveConfig, global::shared::GlobalCollectiveError, local_server::get_collective_client,
+    global::shared::GlobalCollectiveError, local_server::get_collective_client, CollectiveConfig, DeviceId, ReduceOperation
 };
 
 /// Errors from collective operations
@@ -33,32 +33,94 @@ pub enum CollectiveError {
 /// Registers a device. `num_devices` must be the same for every register,
 /// and `device_id` must be unique.
 ///
+/// * `id` - The peer id of the caller
+///
 /// With auto-diff backends, make sure to use the inner backend.
-pub fn register<B: Backend>(config: &CollectiveConfig) -> Result<(), CollectiveError> {
+pub fn register<B: Backend>(
+    id: DeviceId,
+    config: CollectiveConfig,
+) -> Result<(), CollectiveError> {
     let mut client = get_collective_client::<B>();
-    client.register(config)
+    client.register(id, config)
 }
 
 /// Calls for an all-reduce operation with the given parameters, and returns the result.
 /// The `params` must be the same as the parameters passed by the other nodes.
+///
+/// * `id` - The peer id of the caller
+/// * `tensor` - The input tensor to reduce with the peers' tensors
+/// * `op` - The reduce operation; must be coherent with the other peers' calls
 pub fn all_reduce<B: Backend, const D: usize>(
+    id: DeviceId,
     tensor: Tensor<B, D>,
-    config: &CollectiveConfig,
+    op: ReduceOperation,
 ) -> Result<Tensor<B, D>, CollectiveError> {
     let client = get_collective_client::<B>();
     let device = tensor.device();
     let tensor = tensor.into_primitive().tensor();
-    let primitive = client.all_reduce(tensor, config)?;
+    let primitive = client.all_reduce(id, tensor, op)?;
     let tensor =
         Tensor::from_primitive(burn_tensor::TensorPrimitive::Float(primitive)).to_device(&device);
 
     Ok(tensor)
 }
 
+/// Broadcasts, or receives, a broadcast tensor.
+///
+/// * `id` - The peer id of the caller
+/// * `tensor` - If defined, this tensor will be broadcast. Otherwise, this call will receive
+///   the broadcast tensor.
+/// * `root` - The peer that will broadcast the tensor.
+///
+/// Returns the broadcast tensor.
+pub fn broadcast<B: Backend, const D: usize>(
+    id: DeviceId,
+    tensor: Option<Tensor<B, D>>,
+    _device: B::Device, // TODO `register` should return a client, and collective ops should be done on the client.
+    root: DeviceId,
+) -> Result<Tensor<B, D>, CollectiveError> {
+    let client = get_collective_client::<B>();
+    let tensor = tensor.map(|tensor| tensor.into_primitive().tensor());
+    let primitive = client.broadcast(id, tensor, root)?;
+    let tensor = Tensor::from_primitive(burn_tensor::TensorPrimitive::Float(primitive));
+
+    Ok(tensor)
+}
+
+/// Reduces a tensor onto one device.
+///
+/// * `id` - The peer id of the caller
+/// * `tensor` - The tensor to send as input
+/// * `op` - The reduce operation; must be coherent with the other peers' calls
+/// * `root` - The ID of the peer that will receive the result.
+///
+/// Returns Ok(None) if the caller is not the root. Otherwise, returns the reduced tensor.
+pub fn reduce<B: Backend, const D: usize>(
+    id: DeviceId,
+    tensor: Tensor<B, D>,
+    op: ReduceOperation,
+    root: DeviceId,
+) -> Result<Option<Tensor<B, D>>, CollectiveError> {
+    let client = get_collective_client::<B>();
+    let device = tensor.device();
+    let tensor = tensor.into_primitive().tensor();
+    let primitive = client.reduce(id, tensor, op, root)?;
+    let tensor = primitive.map(|primitive| {
+        Tensor::from_primitive(burn_tensor::TensorPrimitive::Float(primitive)).to_device(&device)
+    });
+
+    Ok(tensor)
+}
+
 /// Closes the collective session, unregistering the device
-pub fn finish_collective<B: Backend>(config: &CollectiveConfig) -> Result<(), CollectiveError> {
+pub fn finish_collective<B: Backend>(id: DeviceId) -> Result<(), CollectiveError> {
     let client = get_collective_client::<B>();
-    client.finish(config.device_id)
+    client.finish(id)
 }
 
 /// Resets the local collective server. All registered callers and ongoing operations are forgotten
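
Taken together, the refactored API keys every call on an explicit `DeviceId` instead of threading `&CollectiveConfig` through each operation: `register` takes the config once, and `all_reduce`/`reduce`/`broadcast` take only per-call arguments. A hedged usage sketch from one peer's perspective; the backend choice, the `NdArray` import path, the integer-to-`DeviceId` conversion, and the config values are assumptions, not shown in this diff:

use burn::backend::NdArray;
use burn::tensor::Tensor;
use burn_collective::{
    all_reduce, finish_collective, register, CollectiveConfig, DeviceId, ReduceOperation,
};

type B = NdArray; // with auto-diff backends, use the inner backend

fn peer(id: DeviceId, num_devices: u32) {
    // Same `num_devices` for every peer; unique `id` per peer.
    let config = CollectiveConfig::default().with_num_devices(num_devices);
    register::<B>(id, config).expect("register failed");

    // Every peer must pass the same `ReduceOperation`.
    let tensor = Tensor::<B, 1>::ones([4], &Default::default());
    let _mean = all_reduce::<B, 1>(id, tensor, ReduceOperation::Mean).expect("all-reduce failed");

    finish_collective::<B>(id).expect("finish failed");
}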

crates/burn-collective/src/client.rs

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+use std::sync::mpsc::SyncSender;
+
+use burn_tensor::backend::Backend;
+
+use crate::{
+    CollectiveConfig, CollectiveError, DeviceId, ReduceOperation,
+    local_server::{
+        AllReduceResult, BroadcastResult, FinishResult, Message, ReduceResult, RegisterResult,
+    },
+};
+
+#[derive(Clone)]
+pub(crate) struct LocalCollectiveClient<B: Backend> {
+    pub channel: SyncSender<Message<B>>,
+}
+
+impl<B: Backend> LocalCollectiveClient<B> {
+    pub(crate) fn reset(&self) {
+        self.channel.send(Message::Reset).unwrap();
+    }
+
+    pub(crate) fn register(&mut self, id: DeviceId, config: CollectiveConfig) -> RegisterResult {
+        // Reject incoherent configs before contacting the server.
+        if !config.is_valid() {
+            return Err(CollectiveError::InvalidConfig);
+        }
+
+        let (callback, rec) = std::sync::mpsc::sync_channel::<RegisterResult>(1);
+
+        self.channel
+            .send(Message::Register {
+                device_id: id,
+                config,
+                callback,
+            })
+            .unwrap();
+
+        rec.recv()
+            .unwrap_or(Err(CollectiveError::LocalServerMissing))
+    }
+
+    pub(crate) fn all_reduce(
+        &self,
+        id: DeviceId,
+        tensor: B::FloatTensorPrimitive,
+        op: ReduceOperation,
+    ) -> AllReduceResult<B::FloatTensorPrimitive> {
+        let (callback, rec) =
+            std::sync::mpsc::sync_channel::<AllReduceResult<B::FloatTensorPrimitive>>(1);
+        let msg = Message::AllReduce {
+            device_id: id,
+            tensor,
+            op,
+            callback,
+        };
+
+        self.channel.send(msg).unwrap();
+
+        // Returns a tensor primitive that may or may not be on the correct device,
+        // depending on the strategy used.
+        rec.recv()
+            .unwrap_or(Err(CollectiveError::LocalServerMissing))
+    }
+
+    pub(crate) fn reduce(
+        &self,
+        id: DeviceId,
+        tensor: B::FloatTensorPrimitive,
+        op: ReduceOperation,
+        root: DeviceId,
+    ) -> ReduceResult<B::FloatTensorPrimitive> {
+        let (callback, rec) =
+            std::sync::mpsc::sync_channel::<ReduceResult<B::FloatTensorPrimitive>>(1);
+        let msg = Message::Reduce {
+            device_id: id,
+            tensor,
+            op,
+            root,
+            callback,
+        };
+
+        self.channel.send(msg).unwrap();
+
+        // Returns a tensor if this device is the root, `None` otherwise.
+        rec.recv()
+            .unwrap_or(Err(CollectiveError::LocalServerMissing))
+    }
+
+    pub(crate) fn broadcast(
+        &self,
+        id: DeviceId,
+        tensor: Option<B::FloatTensorPrimitive>,
+        root: DeviceId,
+    ) -> BroadcastResult<B::FloatTensorPrimitive> {
+        let (callback, rec) =
+            std::sync::mpsc::sync_channel::<BroadcastResult<B::FloatTensorPrimitive>>(1);
+        let msg = Message::Broadcast {
+            device_id: id,
+            tensor,
+            root,
+            callback,
+        };
+
+        self.channel.send(msg).unwrap();
+
+        // Returns the broadcast tensor.
+        rec.recv()
+            .unwrap_or(Err(CollectiveError::LocalServerMissing))
+    }
+
+    pub(crate) fn finish(&self, id: DeviceId) -> FinishResult {
+        let (callback, rec) = std::sync::mpsc::sync_channel::<FinishResult>(1);
+        self.channel.send(Message::Finish { id, callback }).unwrap();
+
+        rec.recv()
+            .unwrap_or(Err(CollectiveError::LocalServerMissing))
+    }
+}
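
Every method on `LocalCollectiveClient` follows the same request/response shape: open a bounded channel of capacity 1, ship the sender into the message as `callback`, then block on `recv()`, mapping a hung-up server to `LocalServerMissing`. A stripped-down sketch of that pattern with illustrative types (none of these names come from the crate):

use std::sync::mpsc::{sync_channel, SyncSender};
use std::thread;

// Illustrative message and error types, not the crate's own.
enum Message {
    Ping {
        callback: SyncSender<Result<u32, Error>>,
    },
}

#[derive(Debug)]
enum Error {
    ServerMissing,
}

fn request(server: &SyncSender<Message>) -> Result<u32, Error> {
    // Capacity-1 channel: the server sends exactly one reply.
    let (callback, rec) = sync_channel(1);
    server.send(Message::Ping { callback }).unwrap();
    // If the server dropped the reply sender, report it as missing.
    rec.recv().unwrap_or(Err(Error::ServerMissing))
}

fn main() {
    let (tx, rx) = sync_channel::<Message>(8);
    thread::spawn(move || {
        while let Ok(Message::Ping { callback }) = rx.recv() {
            callback.send(Ok(42)).ok();
        }
    });
    println!("{:?}", request(&tx));
}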
