@@ -114,16 +114,13 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
114
114
from_vector<int64_t >(cols), from_vector<int64_t >(edges));
115
115
}
116
116
117
- bool satisfy_time_constraint (
118
- const c10::Dict<node_t , torch::Tensor> &node_time_dict,
119
- const node_t &src_node_type, const int64_t &dst_time,
120
- const int64_t &src_node) {
121
- // whether src -> dst obeys the time constraint
122
- try {
117
+ inline bool satisfy_time (const c10::Dict<node_t , torch::Tensor> &node_time_dict,
118
+ const node_t &src_node_type, const int64_t &dst_time,
119
+ const int64_t &src_node) {
120
+ try { // Check whether src -> dst obeys the time constraint:
123
121
auto src_time = node_time_dict.at (src_node_type).data_ptr <int64_t >();
124
122
return dst_time < src_time[src_node];
125
- } catch (int err) {
126
- // if the node type does not have timestamp, fall back to normal sampling
123
+ } catch (int err) { // If no time is given, fall back to normal sampling:
127
124
return true ;
128
125
}
129
126
}
@@ -137,8 +134,9 @@ hetero_sample(const vector<node_t> &node_types,
137
134
const c10::Dict<rel_t , torch::Tensor> &row_dict,
138
135
const c10::Dict<node_t , torch::Tensor> &input_node_dict,
139
136
const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
140
- const int64_t num_hops,
141
- const c10::Dict<node_t , torch::Tensor> &node_time_dict) {
137
+ const c10::Dict<node_t , torch::Tensor> &node_time_dict,
138
+ const int64_t num_hops) {
139
+
142
140
// Create a mapping to convert single string relations to edge type triplets:
143
141
unordered_map<rel_t , edge_t > to_edge_type;
144
142
for (const auto &k : edge_types)
@@ -155,8 +153,6 @@ hetero_sample(const vector<node_t> &node_types,
155
153
156
154
unordered_map<node_t , vector<int64_t >> samples_dict;
157
155
unordered_map<node_t , unordered_map<int64_t , int64_t >> to_local_node_dict;
158
- // The timestamp of the center node whose neighborhood that the sampled node
159
- // belongs to. It maps node_type to empty vector in non-temporal sampling.
160
156
unordered_map<node_t , vector<int64_t >> root_time_dict;
161
157
for (const auto &node_type : node_types) {
162
158
samples_dict[node_type];
@@ -169,10 +165,7 @@ hetero_sample(const vector<node_t> &node_types,
169
165
const auto &node_type = kv.key ();
170
166
const torch::Tensor &input_node = kv.value ();
171
167
const auto *input_node_data = input_node.data_ptr <int64_t >();
172
- // dummy value. will be reset to root time if is_temporal==true
173
168
int64_t *node_time_data;
174
- // root_time[i] stores the timestamp of the computation tree root
175
- // of the node samples[i]
176
169
if (temporal) {
177
170
torch::Tensor node_time = node_time_dict.at (node_type);
178
171
node_time_data = node_time.data_ptr <int64_t >();
@@ -185,9 +178,8 @@ hetero_sample(const vector<node_t> &node_types,
185
178
const auto &v = input_node_data[i];
186
179
samples.push_back (v);
187
180
to_local_node.insert ({v, i});
188
- if (temporal) {
181
+ if (temporal)
189
182
root_time.push_back (node_time_data[v]);
190
- }
191
183
}
192
184
}
193
185
@@ -217,11 +209,12 @@ hetero_sample(const vector<node_t> &node_types,
217
209
218
210
const auto &begin = slice_dict.at (dst_node_type).first ;
219
211
const auto &end = slice_dict.at (dst_node_type).second ;
220
- if (begin == end) {
212
+
213
+ if (begin == end)
221
214
continue ;
222
- }
223
- // for temporal sampling, sampled src node cannot have timestamp greater
224
- // than its corresponding dst_root_time
215
+
216
+ // For temporal sampling, sampled nodes cannot have a timestamp greater
217
+ // than the timestamp of the root nodes.
225
218
const auto &dst_root_time = root_time_dict.at (dst_node_type);
226
219
auto &src_root_time = root_time_dict.at (src_node_type);
227
220
@@ -236,16 +229,13 @@ hetero_sample(const vector<node_t> &node_types,
236
229
continue ;
237
230
238
231
if ((num_samples < 0 ) || (!replace && (num_samples >= col_count))) {
239
- // select all neighbors
232
+ // Select all neighbors:
240
233
for (int64_t offset = col_start; offset < col_end; offset++) {
241
234
const int64_t &v = row_data[offset];
242
- bool time_constraint = true ;
243
235
if (temporal) {
244
- time_constraint = satisfy_time_constraint (
245
- node_time_dict, src_node_type, dst_time, v) ;
236
+ if (! satisfy_time (node_time_dict, src_node_type, dst_time, v))
237
+ continue ;
246
238
}
247
- if (!time_constraint)
248
- continue ;
249
239
const auto res = to_local_src_node.insert ({v, src_samples.size ()});
250
240
if (res.second ) {
251
241
src_samples.push_back (v);
@@ -259,18 +249,16 @@ hetero_sample(const vector<node_t> &node_types,
259
249
}
260
250
}
261
251
} else if (replace) {
262
- // sample with replacement
252
+ // Sample with replacement:
263
253
int64_t num_neighbors = 0 ;
264
254
while (num_neighbors < num_samples) {
265
255
const int64_t offset = col_start + uniform_randint (col_count);
266
256
const int64_t &v = row_data[offset];
267
- bool time_constraint = true ;
268
257
if (temporal) {
269
- time_constraint = satisfy_time_constraint (
270
- node_time_dict, src_node_type, dst_time, v);
258
+ // TODO Infinity loop if no neighbor satisfies time constraint:
259
+ if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
260
+ continue ;
271
261
}
272
- if (!time_constraint)
273
- continue ;
274
262
const auto res = to_local_src_node.insert ({v, src_samples.size ()});
275
263
if (res.second ) {
276
264
src_samples.push_back (v);
@@ -285,7 +273,7 @@ hetero_sample(const vector<node_t> &node_types,
285
273
num_neighbors += 1 ;
286
274
}
287
275
} else {
288
- // sample without replacement
276
+ // Sample without replacement:
289
277
unordered_set<int64_t > rnd_indices;
290
278
for (int64_t j = col_count - num_samples; j < col_count; j++) {
291
279
int64_t rnd = uniform_randint (j);
@@ -295,13 +283,10 @@ hetero_sample(const vector<node_t> &node_types,
295
283
}
296
284
const int64_t offset = col_start + rnd;
297
285
const int64_t &v = row_data[offset];
298
- bool time_constraint = true ;
299
286
if (temporal) {
300
- time_constraint = satisfy_time_constraint (
301
- node_time_dict, src_node_type, dst_time, v) ;
287
+ if (! satisfy_time (node_time_dict, src_node_type, dst_time, v))
288
+ continue ;
302
289
}
303
- if (!time_constraint)
304
- continue ;
305
290
const auto res = to_local_src_node.insert ({v, src_samples.size ()});
306
291
if (res.second ) {
307
292
src_samples.push_back (v);
@@ -364,22 +349,6 @@ hetero_sample(const vector<node_t> &node_types,
364
349
from_vector<rel_t , int64_t >(edges_dict));
365
350
}
366
351
367
- template <bool replace, bool directed>
368
- tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
369
- c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
370
- hetero_sample_random (
371
- const vector<node_t > &node_types, const vector<edge_t > &edge_types,
372
- const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
373
- const c10::Dict<rel_t , torch::Tensor> &row_dict,
374
- const c10::Dict<node_t , torch::Tensor> &input_node_dict,
375
- const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
376
- const int64_t num_hops) {
377
- c10::Dict<node_t , torch::Tensor> empty_dict;
378
- return hetero_sample<replace, directed, false >(
379
- node_types, edge_types, colptr_dict, row_dict, input_node_dict,
380
- num_neighbors_dict, num_hops, empty_dict);
381
- }
382
-
383
352
} // namespace
384
353
385
354
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -409,28 +378,30 @@ hetero_neighbor_sample_cpu(
409
378
const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
410
379
const int64_t num_hops, const bool replace, const bool directed) {
411
380
381
+ c10::Dict<node_t , torch::Tensor> node_time_dict; // Empty dictionary.
382
+
412
383
if (replace && directed) {
413
- return hetero_sample_random <true , true >(node_types, edge_types, colptr_dict,
414
- row_dict, input_node_dict,
415
- num_neighbors_dict, num_hops);
384
+ return hetero_sample <true , true , false >(
385
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
386
+ num_neighbors_dict, node_time_dict , num_hops);
416
387
} else if (replace && !directed) {
417
- return hetero_sample_random <true , false >(
388
+ return hetero_sample <true , false , false >(
418
389
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
419
- num_neighbors_dict, num_hops);
390
+ num_neighbors_dict, node_time_dict, num_hops);
420
391
} else if (!replace && directed) {
421
- return hetero_sample_random <false , true >(
392
+ return hetero_sample <false , true , false >(
422
393
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
423
- num_neighbors_dict, num_hops);
394
+ num_neighbors_dict, node_time_dict, num_hops);
424
395
} else {
425
- return hetero_sample_random< false , false >(
396
+ return hetero_sample< false , false , false >(
426
397
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
427
- num_neighbors_dict, num_hops);
398
+ num_neighbors_dict, node_time_dict, num_hops);
428
399
}
429
400
}
430
401
431
402
tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
432
403
c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
433
- hetero_neighbor_temporal_sample_cpu (
404
+ hetero_temporal_neighbor_sample_cpu (
434
405
const vector<node_t > &node_types, const vector<edge_t > &edge_types,
435
406
const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
436
407
const c10::Dict<rel_t , torch::Tensor> &row_dict,
@@ -442,18 +413,18 @@ hetero_neighbor_temporal_sample_cpu(
442
413
if (replace && directed) {
443
414
return hetero_sample<true , true , true >(
444
415
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
445
- num_neighbors_dict, num_hops, node_time_dict );
416
+ num_neighbors_dict, node_time_dict, num_hops );
446
417
} else if (replace && !directed) {
447
418
return hetero_sample<true , false , true >(
448
419
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
449
- num_neighbors_dict, num_hops, node_time_dict );
420
+ num_neighbors_dict, node_time_dict, num_hops );
450
421
} else if (!replace && directed) {
451
422
return hetero_sample<false , true , true >(
452
423
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
453
- num_neighbors_dict, num_hops, node_time_dict );
424
+ num_neighbors_dict, node_time_dict, num_hops );
454
425
} else {
455
426
return hetero_sample<false , false , true >(
456
427
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
457
- num_neighbors_dict, num_hops, node_time_dict );
428
+ num_neighbors_dict, node_time_dict, num_hops );
458
429
}
459
430
}
0 commit comments