Skip to content

Commit 8e86b7c

Browse files
committed
ci: execute wgpu tests on self-hosted runner
1 parent 0e1c451 commit 8e86b7c

File tree

2 files changed

+93
-81
lines changed

2 files changed

+93
-81
lines changed

.github/workflows/test.yml

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,35 @@ jobs:
199199
- name: Tests (burn-vulkan)
200200
run: cargo xtask test --ci gcp-vulkan-runner
201201

202+
linux-std-wgpu-tests:
203+
needs: [prepare-checks, code-quality]
204+
# '@id:' label must be unique within this worklow
205+
runs-on: [
206+
'@id:wgpu-job-${{github.run_id}}-${{github.run_attempt}}',
207+
'@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}',
208+
'@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}',
209+
'@zone:${{ needs.prepare-checks.outputs.gcp_runners_zone }}',
210+
'gpu' ]
211+
env:
212+
# disable incremental compilation (reduces artifact size)
213+
CARGO_PROFILE_TEST_INCREMENTAL: 'false'
214+
# Keep the stragegy to be able to easily add new rust versions if required
215+
strategy:
216+
matrix:
217+
rust: [stable]
218+
include:
219+
- rust: stable
220+
toolchain: stable
221+
steps:
222+
- name: Setup Rust
223+
uses: tracel-ai/github-actions/setup-rust@v3
224+
with:
225+
rust-toolchain: ${{ matrix.toolchain }}
226+
enable-cache: false
227+
# --------------------------------------------------------------------------------
228+
- name: Tests (burn-wgpu)
229+
run: cargo xtask test --ci gcp-wgpu-runner
230+
202231
linux-std-tests:
203232
runs-on: ubuntu-22.04
204233
needs: [prepare-checks, code-quality]
@@ -223,15 +252,7 @@ jobs:
223252
cache-key: ${{ matrix.rust }}-linux
224253
# Disable cache on linux-std (stable) runner which currently always runs out of disk space with tests + coverage
225254
enable-cache: ${{ matrix.rust != 'stable' }}
226-
# --------------------------------------------------------------------------------
227-
- name: Setup Linux runner
228-
uses: tracel-ai/github-actions/setup-linux@v3
229-
with:
230-
vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }}
231-
mesa-version: ${{ env.MESA_VERSION }}
232-
mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }}
233-
cargo-package-to-clean: burn-tch
234-
# --------------------------------------------------------------------------------
255+
# # --------------------------------------------------------------------------------
235256
- name: Install grcov
236257
if: matrix.rust == 'stable'
237258
shell: bash
@@ -335,4 +356,4 @@ jobs:
335356
cache-key: ${{ matrix.rust }}-macos
336357
# --------------------------------------------------------------------------------
337358
- name: Tests
338-
run: CUBECL_DEBUG_OPTION=profile cargo xtask test --release --ci github-runner
359+
run: CUBECL_DEBUG_OPTION=profile cargo xtask test --release --ci github-mac-runner

xtask/src/commands/test.rs

Lines changed: 62 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ pub struct BurnTestCmdArgs {
1313
#[derive(Debug, Clone, ValueEnum, PartialEq)]
1414
pub enum CiTestType {
1515
GithubRunner,
16+
GithubMacRunner,
1617
GcpCudaRunner,
1718
GcpVulkanRunner,
19+
GcpWgpuRunner,
1820
}
1921

2022
pub(crate) fn handle_command(
@@ -40,32 +42,25 @@ pub(crate) fn handle_command(
4042
Ok(())
4143
}
4244
Context::Std => {
43-
let disable_wgpu = std::env::var("DISABLE_WGPU")
44-
.map(|val| val == "1" || val == "true")
45-
.unwrap_or(false);
46-
45+
// 1) Tests with default features
46+
// ------------------------------
4747
match args.ci {
48-
CiTestType::GithubRunner => {
48+
CiTestType::GithubRunner | CiTestType::GithubMacRunner => {
4949
// Exclude crates that are not supported on CI
5050
args.exclude.extend(vec![
5151
"burn-cuda".to_string(),
5252
"burn-rocm".to_string(),
53+
// "burn-router" uses "burn-wgpu" for the tests.
54+
"burn-router".to_string(),
5355
"burn-tch".to_string(),
56+
"burn-wgpu".to_string(),
5457
]);
5558

5659
// Burn remote tests don't work on windows for now
5760
#[cfg(target_os = "windows")]
5861
{
5962
args.exclude.extend(vec!["burn-remote".to_string()]);
6063
};
61-
62-
if disable_wgpu {
63-
args.exclude.extend(vec![
64-
"burn-wgpu".to_string(),
65-
// "burn-router" uses "burn-wgpu" for the tests.
66-
"burn-router".to_string(),
67-
]);
68-
};
6964
}
7065
CiTestType::GcpCudaRunner => {
7166
args.target = Target::AllPackages;
@@ -78,41 +73,66 @@ pub(crate) fn handle_command(
7873
.get_or_insert_with(Vec::new)
7974
.push("vulkan".to_string());
8075
}
76+
CiTestType::GcpWgpuRunner => {
77+
args.target = Target::AllPackages;
78+
// "burn-router" uses "burn-wgpu" for the tests.
79+
args.only
80+
.extend(vec!["burn-wgpu".to_string(), "burn-router".to_string()]);
81+
}
8182
}
8283

8384
// test workspace
8485
base_commands::test::handle_command(args.clone().try_into().unwrap(), env, context)?;
8586

86-
// Specific additional commands to test specific features
87-
if args.ci == CiTestType::GithubRunner {
88-
// burn-dataset
89-
helpers::custom_crates_tests(
90-
vec!["burn-dataset"],
91-
vec!["--all-features"],
92-
None,
93-
None,
94-
"std all features",
95-
)?;
96-
97-
// burn-core
98-
helpers::custom_crates_tests(
99-
vec!["burn-core"],
100-
vec!["--features", "test-tch,record-item-custom-serde"],
101-
None,
102-
None,
103-
"std with features: test-tch,record-item-custom-serde",
104-
)?;
87+
// 2) Specific additional commands to test specific features
88+
// ---------------------------------------------------------
89+
match args.ci {
90+
CiTestType::GithubRunner => {
91+
// burn-dataset
92+
helpers::custom_crates_tests(
93+
vec!["burn-dataset"],
94+
vec!["--all-features"],
95+
None,
96+
None,
97+
"std all features",
98+
)?;
10599

106-
// burn-vision
107-
helpers::custom_crates_tests(
108-
vec!["burn-vision"],
109-
vec!["--features", "test-cpu"],
110-
None,
111-
None,
112-
"std cpu",
113-
)?;
100+
// burn-core
101+
helpers::custom_crates_tests(
102+
vec!["burn-core"],
103+
vec!["--features", "test-tch,record-item-custom-serde"],
104+
None,
105+
None,
106+
"std with features: test-tch,record-item-custom-serde",
107+
)?;
114108

115-
if !disable_wgpu {
109+
// burn-vision
110+
helpers::custom_crates_tests(
111+
vec!["burn-vision"],
112+
vec!["--features", "test-cpu"],
113+
None,
114+
None,
115+
"std cpu",
116+
)?;
117+
}
118+
CiTestType::GcpCudaRunner => (),
119+
CiTestType::GcpVulkanRunner => {
120+
helpers::custom_crates_tests(
121+
vec!["burn-core"],
122+
vec!["--features", "test-wgpu-spirv"],
123+
None,
124+
None,
125+
"std vulkan",
126+
)?;
127+
helpers::custom_crates_tests(
128+
vec!["burn-vision"],
129+
vec!["--features", "test-vulkan"],
130+
None,
131+
None,
132+
"std vulkan",
133+
)?;
134+
}
135+
CiTestType::GcpWgpuRunner => {
116136
helpers::custom_crates_tests(
117137
vec!["burn-core"],
118138
vec!["--features", "test-wgpu"],
@@ -127,36 +147,8 @@ pub(crate) fn handle_command(
127147
None,
128148
"std wgpu",
129149
)?;
130-
131-
// Vulkan isn't available on MacOS
132-
#[cfg(not(target_os = "macos"))]
133-
{
134-
let disable_wgpu_spirv = std::env::var("DISABLE_WGPU_SPIRV")
135-
.map(|val| val == "1" || val == "true")
136-
.unwrap_or(false);
137-
138-
if !disable_wgpu_spirv {
139-
helpers::custom_crates_tests(
140-
vec!["burn-core"],
141-
vec!["--features", "test-wgpu-spirv"],
142-
None,
143-
None,
144-
"std vulkan",
145-
)?;
146-
helpers::custom_crates_tests(
147-
vec!["burn-vision"],
148-
vec!["--features", "test-vulkan"],
149-
None,
150-
None,
151-
"std vulkan",
152-
)?;
153-
}
154-
}
155150
}
156-
157-
// MacOS specific tests
158-
#[cfg(target_os = "macos")]
159-
{
151+
CiTestType::GithubMacRunner => {
160152
// burn-candle
161153
helpers::custom_crates_tests(
162154
vec!["burn-candle"],
@@ -175,7 +167,6 @@ pub(crate) fn handle_command(
175167
)?;
176168
}
177169
}
178-
179170
Ok(())
180171
}
181172
Context::All => Context::value_variants()

0 commit comments

Comments
 (0)