@@ -13,8 +13,10 @@ pub struct BurnTestCmdArgs {
13
13
#[ derive( Debug , Clone , ValueEnum , PartialEq ) ]
14
14
pub enum CiTestType {
15
15
GithubRunner ,
16
+ GithubMacRunner ,
16
17
GcpCudaRunner ,
17
18
GcpVulkanRunner ,
19
+ GcpWgpuRunner ,
18
20
}
19
21
20
22
pub ( crate ) fn handle_command (
@@ -40,32 +42,25 @@ pub(crate) fn handle_command(
40
42
Ok ( ( ) )
41
43
}
42
44
Context :: Std => {
43
- let disable_wgpu = std:: env:: var ( "DISABLE_WGPU" )
44
- . map ( |val| val == "1" || val == "true" )
45
- . unwrap_or ( false ) ;
46
-
45
+ // 1) Tests with default features
46
+ // ------------------------------
47
47
match args. ci {
48
- CiTestType :: GithubRunner => {
48
+ CiTestType :: GithubRunner | CiTestType :: GithubMacRunner => {
49
49
// Exclude crates that are not supported on CI
50
50
args. exclude . extend ( vec ! [
51
51
"burn-cuda" . to_string( ) ,
52
52
"burn-rocm" . to_string( ) ,
53
+ // "burn-router" uses "burn-wgpu" for the tests.
54
+ "burn-router" . to_string( ) ,
53
55
"burn-tch" . to_string( ) ,
56
+ "burn-wgpu" . to_string( ) ,
54
57
] ) ;
55
58
56
59
// Burn remote tests don't work on windows for now
57
60
#[ cfg( target_os = "windows" ) ]
58
61
{
59
62
args. exclude . extend ( vec ! [ "burn-remote" . to_string( ) ] ) ;
60
63
} ;
61
-
62
- if disable_wgpu {
63
- args. exclude . extend ( vec ! [
64
- "burn-wgpu" . to_string( ) ,
65
- // "burn-router" uses "burn-wgpu" for the tests.
66
- "burn-router" . to_string( ) ,
67
- ] ) ;
68
- } ;
69
64
}
70
65
CiTestType :: GcpCudaRunner => {
71
66
args. target = Target :: AllPackages ;
@@ -78,41 +73,66 @@ pub(crate) fn handle_command(
78
73
. get_or_insert_with ( Vec :: new)
79
74
. push ( "vulkan" . to_string ( ) ) ;
80
75
}
76
+ CiTestType :: GcpWgpuRunner => {
77
+ args. target = Target :: AllPackages ;
78
+ // "burn-router" uses "burn-wgpu" for the tests.
79
+ args. only
80
+ . extend ( vec ! [ "burn-wgpu" . to_string( ) , "burn-router" . to_string( ) ] ) ;
81
+ }
81
82
}
82
83
83
84
// test workspace
84
85
base_commands:: test:: handle_command ( args. clone ( ) . try_into ( ) . unwrap ( ) , env, context) ?;
85
86
86
- // Specific additional commands to test specific features
87
- if args. ci == CiTestType :: GithubRunner {
88
- // burn-dataset
89
- helpers:: custom_crates_tests (
90
- vec ! [ "burn-dataset" ] ,
91
- vec ! [ "--all-features" ] ,
92
- None ,
93
- None ,
94
- "std all features" ,
95
- ) ?;
96
-
97
- // burn-core
98
- helpers:: custom_crates_tests (
99
- vec ! [ "burn-core" ] ,
100
- vec ! [ "--features" , "test-tch,record-item-custom-serde" ] ,
101
- None ,
102
- None ,
103
- "std with features: test-tch,record-item-custom-serde" ,
104
- ) ?;
87
+ // 2) Specific additional commands to test specific features
88
+ // ---------------------------------------------------------
89
+ match args. ci {
90
+ CiTestType :: GithubRunner => {
91
+ // burn-dataset
92
+ helpers:: custom_crates_tests (
93
+ vec ! [ "burn-dataset" ] ,
94
+ vec ! [ "--all-features" ] ,
95
+ None ,
96
+ None ,
97
+ "std all features" ,
98
+ ) ?;
105
99
106
- // burn-vision
107
- helpers:: custom_crates_tests (
108
- vec ! [ "burn-vision " ] ,
109
- vec ! [ "--features" , "test-cpu " ] ,
110
- None ,
111
- None ,
112
- "std cpu " ,
113
- ) ?;
100
+ // burn-core
101
+ helpers:: custom_crates_tests (
102
+ vec ! [ "burn-core " ] ,
103
+ vec ! [ "--features" , "test-tch,record-item-custom-serde " ] ,
104
+ None ,
105
+ None ,
106
+ "std with features: test-tch,record-item-custom-serde " ,
107
+ ) ?;
114
108
115
- if !disable_wgpu {
109
+ // burn-vision
110
+ helpers:: custom_crates_tests (
111
+ vec ! [ "burn-vision" ] ,
112
+ vec ! [ "--features" , "test-cpu" ] ,
113
+ None ,
114
+ None ,
115
+ "std cpu" ,
116
+ ) ?;
117
+ }
118
+ CiTestType :: GcpCudaRunner => ( ) ,
119
+ CiTestType :: GcpVulkanRunner => {
120
+ helpers:: custom_crates_tests (
121
+ vec ! [ "burn-core" ] ,
122
+ vec ! [ "--features" , "test-wgpu-spirv" ] ,
123
+ None ,
124
+ None ,
125
+ "std vulkan" ,
126
+ ) ?;
127
+ helpers:: custom_crates_tests (
128
+ vec ! [ "burn-vision" ] ,
129
+ vec ! [ "--features" , "test-vulkan" ] ,
130
+ None ,
131
+ None ,
132
+ "std vulkan" ,
133
+ ) ?;
134
+ }
135
+ CiTestType :: GcpWgpuRunner => {
116
136
helpers:: custom_crates_tests (
117
137
vec ! [ "burn-core" ] ,
118
138
vec ! [ "--features" , "test-wgpu" ] ,
@@ -127,36 +147,8 @@ pub(crate) fn handle_command(
127
147
None ,
128
148
"std wgpu" ,
129
149
) ?;
130
-
131
- // Vulkan isn't available on MacOS
132
- #[ cfg( not( target_os = "macos" ) ) ]
133
- {
134
- let disable_wgpu_spirv = std:: env:: var ( "DISABLE_WGPU_SPIRV" )
135
- . map ( |val| val == "1" || val == "true" )
136
- . unwrap_or ( false ) ;
137
-
138
- if !disable_wgpu_spirv {
139
- helpers:: custom_crates_tests (
140
- vec ! [ "burn-core" ] ,
141
- vec ! [ "--features" , "test-wgpu-spirv" ] ,
142
- None ,
143
- None ,
144
- "std vulkan" ,
145
- ) ?;
146
- helpers:: custom_crates_tests (
147
- vec ! [ "burn-vision" ] ,
148
- vec ! [ "--features" , "test-vulkan" ] ,
149
- None ,
150
- None ,
151
- "std vulkan" ,
152
- ) ?;
153
- }
154
- }
155
150
}
156
-
157
- // MacOS specific tests
158
- #[ cfg( target_os = "macos" ) ]
159
- {
151
+ CiTestType :: GithubMacRunner => {
160
152
// burn-candle
161
153
helpers:: custom_crates_tests (
162
154
vec ! [ "burn-candle" ] ,
@@ -175,7 +167,6 @@ pub(crate) fn handle_command(
175
167
) ?;
176
168
}
177
169
}
178
-
179
170
Ok ( ( ) )
180
171
}
181
172
Context :: All => Context :: value_variants ( )
0 commit comments