@@ -13,8 +13,10 @@ pub struct BurnTestCmdArgs {
13
13
#[ derive( Debug , Clone , ValueEnum , PartialEq ) ]
14
14
pub enum CiTestType {
15
15
GithubRunner ,
16
+ GithubMacRunner ,
16
17
GcpCudaRunner ,
17
18
GcpVulkanRunner ,
19
+ GcpWgpuRunner ,
18
20
}
19
21
20
22
pub ( crate ) fn handle_command (
@@ -40,123 +42,115 @@ pub(crate) fn handle_command(
40
42
Ok ( ( ) )
41
43
}
42
44
Context :: Std => {
43
- let disable_wgpu = std:: env:: var ( "DISABLE_WGPU" )
44
- . map ( |val| val == "1" || val == "true" )
45
- . unwrap_or ( false ) ;
46
-
45
+ // 1) Tests with default features
46
+ // ------------------------------
47
47
match args. ci {
48
- CiTestType :: GithubRunner => {
48
+ CiTestType :: GithubRunner | CiTestType :: GithubMacRunner => {
49
49
// Exclude crates that are not supported on CI
50
50
args. exclude . extend ( vec ! [
51
51
"burn-cuda" . to_string( ) ,
52
52
"burn-rocm" . to_string( ) ,
53
+ // "burn-router" uses "burn-wgpu" for the tests.
54
+ "burn-router" . to_string( ) ,
53
55
"burn-tch" . to_string( ) ,
56
+ "burn-wgpu" . to_string( ) ,
54
57
] ) ;
55
58
56
59
// Burn remote tests don't work on windows for now
57
60
#[ cfg( target_os = "windows" ) ]
58
61
{
59
62
args. exclude . extend ( vec ! [ "burn-remote" . to_string( ) ] ) ;
60
63
} ;
61
-
62
- if disable_wgpu {
63
- args. exclude . extend ( vec ! [
64
- "burn-wgpu" . to_string( ) ,
65
- // "burn-router" uses "burn-wgpu" for the tests.
66
- "burn-router" . to_string( ) ,
67
- ] ) ;
68
- } ;
69
- }
64
+ } ,
70
65
CiTestType :: GcpCudaRunner => {
71
66
args. target = Target :: AllPackages ;
72
67
args. only . push ( "burn-cuda" . to_string ( ) ) ;
73
- }
68
+ } ,
74
69
CiTestType :: GcpVulkanRunner => {
75
70
args. target = Target :: AllPackages ;
76
71
args. only . push ( "burn-wgpu" . to_string ( ) ) ;
77
72
args. features
78
73
. get_or_insert_with ( Vec :: new)
79
74
. push ( "vulkan" . to_string ( ) ) ;
75
+ } ,
76
+ CiTestType :: GcpWgpuRunner => {
77
+ args. target = Target :: AllPackages ;
78
+ // "burn-router" uses "burn-wgpu" for the tests.
79
+ args. only . extend ( vec ! [
80
+ "burn-wgpu" . to_string( ) ,
81
+ "burn-router" . to_string( )
82
+ ] ) ;
80
83
}
81
84
}
82
85
83
86
// test workspace
84
87
base_commands:: test:: handle_command ( args. clone ( ) . try_into ( ) . unwrap ( ) , env, context) ?;
85
88
86
- // Specific additional commands to test specific features
87
- if args. ci == CiTestType :: GithubRunner {
88
- // burn-dataset
89
- helpers:: custom_crates_tests (
90
- vec ! [ "burn-dataset" ] ,
91
- vec ! [ "--all-features" ] ,
92
- None ,
93
- None ,
94
- "std all features" ,
95
- ) ?;
96
-
97
- // burn-core
98
- helpers:: custom_crates_tests (
99
- vec ! [ "burn-core" ] ,
100
- vec ! [ "--features" , "test-tch,record-item-custom-serde" ] ,
101
- None ,
102
- None ,
103
- "std with features: test-tch,record-item-custom-serde" ,
104
- ) ?;
105
-
106
- // burn-vision
107
- helpers:: custom_crates_tests (
108
- vec ! [ "burn-vision" ] ,
109
- vec ! [ "--features" , "test-cpu" ] ,
110
- None ,
111
- None ,
112
- "std cpu" ,
113
- ) ?;
114
-
115
- if !disable_wgpu {
116
- helpers:: custom_crates_tests (
117
- vec ! [ "burn-core" ] ,
118
- vec ! [ "--features" , "test-wgpu" ] ,
119
- None ,
120
- None ,
121
- "std wgpu" ,
122
- ) ?;
123
- helpers:: custom_crates_tests (
124
- vec ! [ "burn-vision" ] ,
125
- vec ! [ "--features" , "test-wgpu" ] ,
126
- None ,
127
- None ,
128
- "std wgpu" ,
129
- ) ?;
130
-
131
- // Vulkan isn't available on MacOS
132
- #[ cfg( not( target_os = "macos" ) ) ]
133
- {
134
- let disable_wgpu_spirv = std:: env:: var ( "DISABLE_WGPU_SPIRV" )
135
- . map ( |val| val == "1" || val == "true" )
136
- . unwrap_or ( false ) ;
137
-
138
- if !disable_wgpu_spirv {
139
- helpers:: custom_crates_tests (
140
- vec ! [ "burn-core" ] ,
141
- vec ! [ "--features" , "test-wgpu-spirv" ] ,
142
- None ,
143
- None ,
144
- "std vulkan" ,
145
- ) ?;
146
- helpers:: custom_crates_tests (
147
- vec ! [ "burn-vision" ] ,
148
- vec ! [ "--features" , "test-vulkan" ] ,
149
- None ,
150
- None ,
151
- "std vulkan" ,
152
- ) ?;
153
- }
154
- }
155
- }
156
-
157
- // MacOS specific tests
158
- #[ cfg( target_os = "macos" ) ]
159
- {
89
+ // 2) Specific additional commands to test specific features
90
+ // ---------------------------------------------------------
91
+ match args. ci {
92
+ CiTestType :: GithubRunner => {
93
+ // burn-dataset
94
+ helpers:: custom_crates_tests (
95
+ vec ! [ "burn-dataset" ] ,
96
+ vec ! [ "--all-features" ] ,
97
+ None ,
98
+ None ,
99
+ "std all features" ,
100
+ ) ?;
101
+
102
+ // burn-core
103
+ helpers:: custom_crates_tests (
104
+ vec ! [ "burn-core" ] ,
105
+ vec ! [ "--features" , "test-tch,record-item-custom-serde" ] ,
106
+ None ,
107
+ None ,
108
+ "std with features: test-tch,record-item-custom-serde" ,
109
+ ) ?;
110
+
111
+ // burn-vision
112
+ helpers:: custom_crates_tests (
113
+ vec ! [ "burn-vision" ] ,
114
+ vec ! [ "--features" , "test-cpu" ] ,
115
+ None ,
116
+ None ,
117
+ "std cpu" ,
118
+ ) ?;
119
+ } ,
120
+ CiTestType :: GcpCudaRunner => ( ) ,
121
+ CiTestType :: GcpVulkanRunner => {
122
+ helpers:: custom_crates_tests (
123
+ vec ! [ "burn-core" ] ,
124
+ vec ! [ "--features" , "test-wgpu-spirv" ] ,
125
+ None ,
126
+ None ,
127
+ "std vulkan" ,
128
+ ) ?;
129
+ helpers:: custom_crates_tests (
130
+ vec ! [ "burn-vision" ] ,
131
+ vec ! [ "--features" , "test-vulkan" ] ,
132
+ None ,
133
+ None ,
134
+ "std vulkan" ,
135
+ ) ?;
136
+ } ,
137
+ CiTestType :: GcpWgpuRunner => {
138
+ helpers:: custom_crates_tests (
139
+ vec ! [ "burn-core" ] ,
140
+ vec ! [ "--features" , "test-wgpu" ] ,
141
+ None ,
142
+ None ,
143
+ "std wgpu" ,
144
+ ) ?;
145
+ helpers:: custom_crates_tests (
146
+ vec ! [ "burn-vision" ] ,
147
+ vec ! [ "--features" , "test-wgpu" ] ,
148
+ None ,
149
+ None ,
150
+ "std wgpu" ,
151
+ ) ?;
152
+ } ,
153
+ CiTestType :: GithubMacRunner => {
160
154
// burn-candle
161
155
helpers:: custom_crates_tests (
162
156
vec ! [ "burn-candle" ] ,
@@ -173,9 +167,8 @@ pub(crate) fn handle_command(
173
167
None ,
174
168
"std blas-accelerate" ,
175
169
) ?;
176
- }
170
+ } ,
177
171
}
178
-
179
172
Ok ( ( ) )
180
173
}
181
174
Context :: All => Context :: value_variants ( )
0 commit comments