Skip to content

Commit 1aa0183

Browse files
committed
ci: execute wgpu tests on self-hosted runner
1 parent 244b2ca commit 1aa0183

File tree

2 files changed

+122
-100
lines changed

2 files changed

+122
-100
lines changed

.github/workflows/test.yml

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,35 @@ jobs:
199199
- name: Tests (burn-vulkan)
200200
run: cargo xtask test --ci gcp-vulkan-runner
201201

202+
linux-std-wgpu-tests:
203+
needs: [prepare-checks, code-quality]
204+
# '@id:' label must be unique within this worklow
205+
runs-on: [
206+
'@id:wgpu-job-${{github.run_id}}-${{github.run_attempt}}',
207+
'@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}',
208+
'@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}',
209+
'@zone:${{ needs.prepare-checks.outputs.gcp_runners_zone }}',
210+
'gpu' ]
211+
env:
212+
# disable incremental compilation (reduces artifact size)
213+
CARGO_PROFILE_TEST_INCREMENTAL: 'false'
214+
# Keep the stragegy to be able to easily add new rust versions if required
215+
strategy:
216+
matrix:
217+
rust: [stable]
218+
include:
219+
- rust: stable
220+
toolchain: stable
221+
steps:
222+
- name: Setup Rust
223+
uses: tracel-ai/github-actions/setup-rust@v3
224+
with:
225+
rust-toolchain: ${{ matrix.toolchain }}
226+
enable-cache: false
227+
# --------------------------------------------------------------------------------
228+
- name: Tests (burn-vulkan)
229+
run: cargo xtask test --ci gcp-wgpu-runner
230+
202231
linux-std-tests:
203232
runs-on: ubuntu-22.04
204233
needs: [prepare-checks, code-quality]
@@ -223,14 +252,14 @@ jobs:
223252
cache-key: ${{ matrix.rust }}-linux
224253
# Disable cache on linux-std (stable) runner which currently always runs out of disk space with tests + coverage
225254
enable-cache: ${{ matrix.rust != 'stable' }}
226-
# --------------------------------------------------------------------------------
227-
- name: Setup Linux runner
228-
uses: tracel-ai/github-actions/setup-linux@v3
229-
with:
230-
vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }}
231-
mesa-version: ${{ env.MESA_VERSION }}
232-
mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }}
233-
cargo-package-to-clean: burn-tch
255+
# # --------------------------------------------------------------------------------
256+
# - name: Setup Linux runner
257+
# uses: tracel-ai/github-actions/setup-linux@v3
258+
# with:
259+
# vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }}
260+
# mesa-version: ${{ env.MESA_VERSION }}
261+
# mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }}
262+
# cargo-package-to-clean: burn-tch
234263
# --------------------------------------------------------------------------------
235264
- name: Install grcov
236265
if: matrix.rust == 'stable'
@@ -335,4 +364,4 @@ jobs:
335364
cache-key: ${{ matrix.rust }}-macos
336365
# --------------------------------------------------------------------------------
337366
- name: Tests
338-
run: CUBECL_DEBUG_OPTION=profile cargo xtask test --release --ci github-runner
367+
run: CUBECL_DEBUG_OPTION=profile cargo xtask test --release --ci github-mac-runner

xtask/src/commands/test.rs

Lines changed: 84 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ pub struct BurnTestCmdArgs {
1313
#[derive(Debug, Clone, ValueEnum, PartialEq)]
1414
pub enum CiTestType {
1515
GithubRunner,
16+
GithubMacRunner,
1617
GcpCudaRunner,
1718
GcpVulkanRunner,
19+
GcpWgpuRunner,
1820
}
1921

2022
pub(crate) fn handle_command(
@@ -40,123 +42,115 @@ pub(crate) fn handle_command(
4042
Ok(())
4143
}
4244
Context::Std => {
43-
let disable_wgpu = std::env::var("DISABLE_WGPU")
44-
.map(|val| val == "1" || val == "true")
45-
.unwrap_or(false);
46-
45+
// 1) Tests with default features
46+
// ------------------------------
4747
match args.ci {
48-
CiTestType::GithubRunner => {
48+
CiTestType::GithubRunner | CiTestType::GithubMacRunner => {
4949
// Exclude crates that are not supported on CI
5050
args.exclude.extend(vec![
5151
"burn-cuda".to_string(),
5252
"burn-rocm".to_string(),
53+
// "burn-router" uses "burn-wgpu" for the tests.
54+
"burn-router".to_string(),
5355
"burn-tch".to_string(),
56+
"burn-wgpu".to_string(),
5457
]);
5558

5659
// Burn remote tests don't work on windows for now
5760
#[cfg(target_os = "windows")]
5861
{
5962
args.exclude.extend(vec!["burn-remote".to_string()]);
6063
};
61-
62-
if disable_wgpu {
63-
args.exclude.extend(vec![
64-
"burn-wgpu".to_string(),
65-
// "burn-router" uses "burn-wgpu" for the tests.
66-
"burn-router".to_string(),
67-
]);
68-
};
69-
}
64+
},
7065
CiTestType::GcpCudaRunner => {
7166
args.target = Target::AllPackages;
7267
args.only.push("burn-cuda".to_string());
73-
}
68+
},
7469
CiTestType::GcpVulkanRunner => {
7570
args.target = Target::AllPackages;
7671
args.only.push("burn-wgpu".to_string());
7772
args.features
7873
.get_or_insert_with(Vec::new)
7974
.push("vulkan".to_string());
75+
},
76+
CiTestType::GcpWgpuRunner => {
77+
args.target = Target::AllPackages;
78+
// "burn-router" uses "burn-wgpu" for the tests.
79+
args.only.extend(vec![
80+
"burn-wgpu".to_string(),
81+
"burn-router".to_string()
82+
]);
8083
}
8184
}
8285

8386
// test workspace
8487
base_commands::test::handle_command(args.clone().try_into().unwrap(), env, context)?;
8588

86-
// Specific additional commands to test specific features
87-
if args.ci == CiTestType::GithubRunner {
88-
// burn-dataset
89-
helpers::custom_crates_tests(
90-
vec!["burn-dataset"],
91-
vec!["--all-features"],
92-
None,
93-
None,
94-
"std all features",
95-
)?;
96-
97-
// burn-core
98-
helpers::custom_crates_tests(
99-
vec!["burn-core"],
100-
vec!["--features", "test-tch,record-item-custom-serde"],
101-
None,
102-
None,
103-
"std with features: test-tch,record-item-custom-serde",
104-
)?;
105-
106-
// burn-vision
107-
helpers::custom_crates_tests(
108-
vec!["burn-vision"],
109-
vec!["--features", "test-cpu"],
110-
None,
111-
None,
112-
"std cpu",
113-
)?;
114-
115-
if !disable_wgpu {
116-
helpers::custom_crates_tests(
117-
vec!["burn-core"],
118-
vec!["--features", "test-wgpu"],
119-
None,
120-
None,
121-
"std wgpu",
122-
)?;
123-
helpers::custom_crates_tests(
124-
vec!["burn-vision"],
125-
vec!["--features", "test-wgpu"],
126-
None,
127-
None,
128-
"std wgpu",
129-
)?;
130-
131-
// Vulkan isn't available on MacOS
132-
#[cfg(not(target_os = "macos"))]
133-
{
134-
let disable_wgpu_spirv = std::env::var("DISABLE_WGPU_SPIRV")
135-
.map(|val| val == "1" || val == "true")
136-
.unwrap_or(false);
137-
138-
if !disable_wgpu_spirv {
139-
helpers::custom_crates_tests(
140-
vec!["burn-core"],
141-
vec!["--features", "test-wgpu-spirv"],
142-
None,
143-
None,
144-
"std vulkan",
145-
)?;
146-
helpers::custom_crates_tests(
147-
vec!["burn-vision"],
148-
vec!["--features", "test-vulkan"],
149-
None,
150-
None,
151-
"std vulkan",
152-
)?;
153-
}
154-
}
155-
}
156-
157-
// MacOS specific tests
158-
#[cfg(target_os = "macos")]
159-
{
89+
// 2) Specific additional commands to test specific features
90+
// ---------------------------------------------------------
91+
match args.ci {
92+
CiTestType::GithubRunner => {
93+
// burn-dataset
94+
helpers::custom_crates_tests(
95+
vec!["burn-dataset"],
96+
vec!["--all-features"],
97+
None,
98+
None,
99+
"std all features",
100+
)?;
101+
102+
// burn-core
103+
helpers::custom_crates_tests(
104+
vec!["burn-core"],
105+
vec!["--features", "test-tch,record-item-custom-serde"],
106+
None,
107+
None,
108+
"std with features: test-tch,record-item-custom-serde",
109+
)?;
110+
111+
// burn-vision
112+
helpers::custom_crates_tests(
113+
vec!["burn-vision"],
114+
vec!["--features", "test-cpu"],
115+
None,
116+
None,
117+
"std cpu",
118+
)?;
119+
},
120+
CiTestType::GcpCudaRunner => (),
121+
CiTestType::GcpVulkanRunner => {
122+
helpers::custom_crates_tests(
123+
vec!["burn-core"],
124+
vec!["--features", "test-wgpu-spirv"],
125+
None,
126+
None,
127+
"std vulkan",
128+
)?;
129+
helpers::custom_crates_tests(
130+
vec!["burn-vision"],
131+
vec!["--features", "test-vulkan"],
132+
None,
133+
None,
134+
"std vulkan",
135+
)?;
136+
},
137+
CiTestType::GcpWgpuRunner => {
138+
helpers::custom_crates_tests(
139+
vec!["burn-core"],
140+
vec!["--features", "test-wgpu"],
141+
None,
142+
None,
143+
"std wgpu",
144+
)?;
145+
helpers::custom_crates_tests(
146+
vec!["burn-vision"],
147+
vec!["--features", "test-wgpu"],
148+
None,
149+
None,
150+
"std wgpu",
151+
)?;
152+
},
153+
CiTestType::GithubMacRunner => {
160154
// burn-candle
161155
helpers::custom_crates_tests(
162156
vec!["burn-candle"],
@@ -173,9 +167,8 @@ pub(crate) fn handle_command(
173167
None,
174168
"std blas-accelerate",
175169
)?;
176-
}
170+
},
177171
}
178-
179172
Ok(())
180173
}
181174
Context::All => Context::value_variants()

0 commit comments

Comments
 (0)