@@ -114,35 +114,31 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
114
114
from_vector<int64_t >(cols), from_vector<int64_t >(edges));
115
115
}
116
116
117
- bool satisfy_time_constraint (const c10::Dict< node_t , torch::Tensor> &node_time_dict,
118
- const std::string &src_node_type ,
119
- const int64_t &dst_time,
120
- const int64_t &sampled_node ) {
117
+ bool satisfy_time_constraint (
118
+ const c10::Dict< node_t , torch::Tensor> &node_time_dict ,
119
+ const node_t &src_node_type, const int64_t &dst_time,
120
+ const int64_t &src_node ) {
121
121
// whether src -> dst obeys the time constraint
122
122
try {
123
- const auto *src_time = node_time_dict.at (src_node_type).data_ptr <int64_t >();
124
- return dst_time < src_time[sampled_node];
125
- }
126
- catch (int err) {
123
+ auto src_time = node_time_dict.at (src_node_type).data_ptr <int64_t >();
124
+ return dst_time < src_time[src_node];
125
+ } catch (int err) {
127
126
// if the node type does not have timestamp, fall back to normal sampling
128
127
return true ;
129
128
}
130
129
}
131
130
132
-
133
131
template <bool replace, bool directed, bool temporal>
134
132
tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
135
133
c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
136
134
hetero_sample (const vector<node_t > &node_types,
137
- const vector<edge_t > &edge_types,
138
- const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
139
- const c10::Dict<rel_t , torch::Tensor> &row_dict,
140
- const c10::Dict<node_t , torch::Tensor> &input_node_dict,
141
- const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
142
- const int64_t num_hops,
143
- const c10::Dict<node_t , torch::Tensor> &node_time_dict) {
144
- // bool temporal = (!node_time_dict.empty());
145
-
135
+ const vector<edge_t > &edge_types,
136
+ const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
137
+ const c10::Dict<rel_t , torch::Tensor> &row_dict,
138
+ const c10::Dict<node_t , torch::Tensor> &input_node_dict,
139
+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
140
+ const int64_t num_hops,
141
+ const c10::Dict<node_t , torch::Tensor> &node_time_dict) {
146
142
// Create a mapping to convert single string relations to edge type triplets:
147
143
unordered_map<rel_t , edge_t > to_edge_type;
148
144
for (const auto &k : edge_types)
@@ -174,11 +170,12 @@ hetero_sample(const vector<node_t> &node_types,
174
170
const torch::Tensor &input_node = kv.value ();
175
171
const auto *input_node_data = input_node.data_ptr <int64_t >();
176
172
// dummy value. will be reset to root time if is_temporal==true
177
- auto *node_time_data = input_node. data_ptr < int64_t >() ;
173
+ int64_t *node_time_data;
178
174
// root_time[i] stores the timestamp of the computation tree root
179
175
// of the node samples[i]
180
176
if (temporal) {
181
- node_time_data = node_time_dict.at (node_type).data_ptr <int64_t >();
177
+ torch::Tensor node_time = node_time_dict.at (node_type);
178
+ node_time_data = node_time.data_ptr <int64_t >();
182
179
}
183
180
184
181
auto &samples = samples_dict.at (node_type);
@@ -220,7 +217,7 @@ hetero_sample(const vector<node_t> &node_types,
220
217
221
218
const auto &begin = slice_dict.at (dst_node_type).first ;
222
219
const auto &end = slice_dict.at (dst_node_type).second ;
223
- if (begin == end){
220
+ if (begin == end) {
224
221
continue ;
225
222
}
226
223
// for temporal sampling, sampled src node cannot have timestamp greater
@@ -370,22 +367,17 @@ hetero_sample(const vector<node_t> &node_types,
370
367
template <bool replace, bool directed>
371
368
tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
372
369
c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
373
- hetero_sample_random (const vector< node_t > &node_types,
374
- const vector<edge_t > &edge_types,
375
- const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
376
- const c10::Dict<rel_t , torch::Tensor> &row_dict,
377
- const c10::Dict<node_t , torch::Tensor> &input_node_dict,
378
- const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
379
- const int64_t num_hops) {
370
+ hetero_sample_random (
371
+ const vector< node_t > &node_types, const vector<edge_t > &edge_types,
372
+ const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
373
+ const c10::Dict<rel_t , torch::Tensor> &row_dict,
374
+ const c10::Dict<node_t , torch::Tensor> &input_node_dict,
375
+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
376
+ const int64_t num_hops) {
380
377
c10::Dict<node_t , torch::Tensor> empty_dict;
381
- return hetero_sample<replace, directed, false >(node_types,
382
- edge_types,
383
- colptr_dict,
384
- row_dict,
385
- input_node_dict,
386
- num_neighbors_dict,
387
- num_hops,
388
- empty_dict);
378
+ return hetero_sample<replace, directed, false >(
379
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
380
+ num_neighbors_dict, num_hops, empty_dict);
389
381
}
390
382
391
383
} // namespace
@@ -418,24 +410,20 @@ hetero_neighbor_sample_cpu(
418
410
const int64_t num_hops, const bool replace, const bool directed) {
419
411
420
412
if (replace && directed) {
421
- return hetero_sample_random<true , true >(
422
- node_types, edge_types, colptr_dict,
423
- row_dict, input_node_dict,
424
- num_neighbors_dict, num_hops);
413
+ return hetero_sample_random<true , true >(node_types, edge_types, colptr_dict,
414
+ row_dict, input_node_dict,
415
+ num_neighbors_dict, num_hops);
425
416
} else if (replace && !directed) {
426
417
return hetero_sample_random<true , false >(
427
- node_types, edge_types, colptr_dict,
428
- row_dict, input_node_dict,
418
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
429
419
num_neighbors_dict, num_hops);
430
420
} else if (!replace && directed) {
431
421
return hetero_sample_random<false , true >(
432
- node_types, edge_types, colptr_dict,
433
- row_dict, input_node_dict,
422
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
434
423
num_neighbors_dict, num_hops);
435
424
} else {
436
425
return hetero_sample_random<false , false >(
437
- node_types, edge_types, colptr_dict,
438
- row_dict, input_node_dict,
426
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
439
427
num_neighbors_dict, num_hops);
440
428
}
441
429
}
@@ -453,23 +441,19 @@ hetero_neighbor_temporal_sample_cpu(
453
441
454
442
if (replace && directed) {
455
443
return hetero_sample<true , true , true >(
456
- node_types, edge_types, colptr_dict,
457
- row_dict, input_node_dict,
444
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
458
445
num_neighbors_dict, num_hops, node_time_dict);
459
446
} else if (replace && !directed) {
460
447
return hetero_sample<true , false , true >(
461
- node_types, edge_types, colptr_dict,
462
- row_dict, input_node_dict,
448
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
463
449
num_neighbors_dict, num_hops, node_time_dict);
464
450
} else if (!replace && directed) {
465
451
return hetero_sample<false , true , true >(
466
- node_types, edge_types, colptr_dict,
467
- row_dict, input_node_dict,
452
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
468
453
num_neighbors_dict, num_hops, node_time_dict);
469
454
} else {
470
455
return hetero_sample<false , false , true >(
471
- node_types, edge_types, colptr_dict,
472
- row_dict, input_node_dict,
456
+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
473
457
num_neighbors_dict, num_hops, node_time_dict);
474
458
}
475
459
}
0 commit comments