Skip to content

Commit caf7ddd

Browse files
authored
Temporal neighbor_sample adjustments (#225)
* version up * formatting * fix * reset * revert * temporal neighbor sampling adjustments * typo
1 parent 15c5635 commit caf7ddd

File tree

3 files changed

+51
-77
lines changed

3 files changed

+51
-77
lines changed

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 40 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,13 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
114114
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
115115
}
116116

117-
bool satisfy_time_constraint(
118-
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
119-
const node_t &src_node_type, const int64_t &dst_time,
120-
const int64_t &src_node) {
121-
// whether src -> dst obeys the time constraint
122-
try {
117+
inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
118+
const node_t &src_node_type, const int64_t &dst_time,
119+
const int64_t &src_node) {
120+
try { // Check whether src -> dst obeys the time constraint:
123121
auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
124122
return dst_time < src_time[src_node];
125-
} catch (int err) {
126-
// if the node type does not have timestamp, fall back to normal sampling
123+
} catch (int err) { // If no time is given, fall back to normal sampling:
127124
return true;
128125
}
129126
}
@@ -137,8 +134,9 @@ hetero_sample(const vector<node_t> &node_types,
137134
const c10::Dict<rel_t, torch::Tensor> &row_dict,
138135
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
139136
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
140-
const int64_t num_hops,
141-
const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
137+
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
138+
const int64_t num_hops) {
139+
142140
// Create a mapping to convert single string relations to edge type triplets:
143141
unordered_map<rel_t, edge_t> to_edge_type;
144142
for (const auto &k : edge_types)
@@ -155,8 +153,6 @@ hetero_sample(const vector<node_t> &node_types,
155153

156154
unordered_map<node_t, vector<int64_t>> samples_dict;
157155
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
158-
// The timestamp of the center node whose neighborhood that the sampled node
159-
// belongs to. It maps node_type to empty vector in non-temporal sampling.
160156
unordered_map<node_t, vector<int64_t>> root_time_dict;
161157
for (const auto &node_type : node_types) {
162158
samples_dict[node_type];
@@ -169,10 +165,7 @@ hetero_sample(const vector<node_t> &node_types,
169165
const auto &node_type = kv.key();
170166
const torch::Tensor &input_node = kv.value();
171167
const auto *input_node_data = input_node.data_ptr<int64_t>();
172-
// dummy value. will be reset to root time if is_temporal==true
173168
int64_t *node_time_data;
174-
// root_time[i] stores the timestamp of the computation tree root
175-
// of the node samples[i]
176169
if (temporal) {
177170
torch::Tensor node_time = node_time_dict.at(node_type);
178171
node_time_data = node_time.data_ptr<int64_t>();
@@ -185,9 +178,8 @@ hetero_sample(const vector<node_t> &node_types,
185178
const auto &v = input_node_data[i];
186179
samples.push_back(v);
187180
to_local_node.insert({v, i});
188-
if (temporal) {
181+
if (temporal)
189182
root_time.push_back(node_time_data[v]);
190-
}
191183
}
192184
}
193185

@@ -217,11 +209,12 @@ hetero_sample(const vector<node_t> &node_types,
217209

218210
const auto &begin = slice_dict.at(dst_node_type).first;
219211
const auto &end = slice_dict.at(dst_node_type).second;
220-
if (begin == end) {
212+
213+
if (begin == end)
221214
continue;
222-
}
223-
// for temporal sampling, sampled src node cannot have timestamp greater
224-
// than its corresponding dst_root_time
215+
216+
// For temporal sampling, sampled nodes cannot have a timestamp greater
217+
// than the timestamp of the root nodes.
225218
const auto &dst_root_time = root_time_dict.at(dst_node_type);
226219
auto &src_root_time = root_time_dict.at(src_node_type);
227220

@@ -236,16 +229,13 @@ hetero_sample(const vector<node_t> &node_types,
236229
continue;
237230

238231
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
239-
// select all neighbors
232+
// Select all neighbors:
240233
for (int64_t offset = col_start; offset < col_end; offset++) {
241234
const int64_t &v = row_data[offset];
242-
bool time_constraint = true;
243235
if (temporal) {
244-
time_constraint = satisfy_time_constraint(
245-
node_time_dict, src_node_type, dst_time, v);
236+
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
237+
continue;
246238
}
247-
if (!time_constraint)
248-
continue;
249239
const auto res = to_local_src_node.insert({v, src_samples.size()});
250240
if (res.second) {
251241
src_samples.push_back(v);
@@ -259,18 +249,16 @@ hetero_sample(const vector<node_t> &node_types,
259249
}
260250
}
261251
} else if (replace) {
262-
// sample with replacement
252+
// Sample with replacement:
263253
int64_t num_neighbors = 0;
264254
while (num_neighbors < num_samples) {
265255
const int64_t offset = col_start + uniform_randint(col_count);
266256
const int64_t &v = row_data[offset];
267-
bool time_constraint = true;
268257
if (temporal) {
269-
time_constraint = satisfy_time_constraint(
270-
node_time_dict, src_node_type, dst_time, v);
258+
// TODO Infinite loop if no neighbor satisfies the time constraint:
259+
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
260+
continue;
271261
}
272-
if (!time_constraint)
273-
continue;
274262
const auto res = to_local_src_node.insert({v, src_samples.size()});
275263
if (res.second) {
276264
src_samples.push_back(v);
@@ -285,7 +273,7 @@ hetero_sample(const vector<node_t> &node_types,
285273
num_neighbors += 1;
286274
}
287275
} else {
288-
// sample without replacement
276+
// Sample without replacement:
289277
unordered_set<int64_t> rnd_indices;
290278
for (int64_t j = col_count - num_samples; j < col_count; j++) {
291279
int64_t rnd = uniform_randint(j);
@@ -295,13 +283,10 @@ hetero_sample(const vector<node_t> &node_types,
295283
}
296284
const int64_t offset = col_start + rnd;
297285
const int64_t &v = row_data[offset];
298-
bool time_constraint = true;
299286
if (temporal) {
300-
time_constraint = satisfy_time_constraint(
301-
node_time_dict, src_node_type, dst_time, v);
287+
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
288+
continue;
302289
}
303-
if (!time_constraint)
304-
continue;
305290
const auto res = to_local_src_node.insert({v, src_samples.size()});
306291
if (res.second) {
307292
src_samples.push_back(v);
@@ -364,22 +349,6 @@ hetero_sample(const vector<node_t> &node_types,
364349
from_vector<rel_t, int64_t>(edges_dict));
365350
}
366351

367-
template <bool replace, bool directed>
368-
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
369-
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
370-
hetero_sample_random(
371-
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
372-
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
373-
const c10::Dict<rel_t, torch::Tensor> &row_dict,
374-
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
375-
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
376-
const int64_t num_hops) {
377-
c10::Dict<node_t, torch::Tensor> empty_dict;
378-
return hetero_sample<replace, directed, false>(
379-
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
380-
num_neighbors_dict, num_hops, empty_dict);
381-
}
382-
383352
} // namespace
384353

385354
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -409,28 +378,30 @@ hetero_neighbor_sample_cpu(
409378
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
410379
const int64_t num_hops, const bool replace, const bool directed) {
411380

381+
c10::Dict<node_t, torch::Tensor> node_time_dict; // Empty dictionary.
382+
412383
if (replace && directed) {
413-
return hetero_sample_random<true, true>(node_types, edge_types, colptr_dict,
414-
row_dict, input_node_dict,
415-
num_neighbors_dict, num_hops);
384+
return hetero_sample<true, true, false>(
385+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
386+
num_neighbors_dict, node_time_dict, num_hops);
416387
} else if (replace && !directed) {
417-
return hetero_sample_random<true, false>(
388+
return hetero_sample<true, false, false>(
418389
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
419-
num_neighbors_dict, num_hops);
390+
num_neighbors_dict, node_time_dict, num_hops);
420391
} else if (!replace && directed) {
421-
return hetero_sample_random<false, true>(
392+
return hetero_sample<false, true, false>(
422393
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
423-
num_neighbors_dict, num_hops);
394+
num_neighbors_dict, node_time_dict, num_hops);
424395
} else {
425-
return hetero_sample_random<false, false>(
396+
return hetero_sample<false, false, false>(
426397
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
427-
num_neighbors_dict, num_hops);
398+
num_neighbors_dict, node_time_dict, num_hops);
428399
}
429400
}
430401

431402
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
432403
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
433-
hetero_neighbor_temporal_sample_cpu(
404+
hetero_temporal_neighbor_sample_cpu(
434405
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
435406
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
436407
const c10::Dict<rel_t, torch::Tensor> &row_dict,
@@ -442,18 +413,18 @@ hetero_neighbor_temporal_sample_cpu(
442413
if (replace && directed) {
443414
return hetero_sample<true, true, true>(
444415
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
445-
num_neighbors_dict, num_hops, node_time_dict);
416+
num_neighbors_dict, node_time_dict, num_hops);
446417
} else if (replace && !directed) {
447418
return hetero_sample<true, false, true>(
448419
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
449-
num_neighbors_dict, num_hops, node_time_dict);
420+
num_neighbors_dict, node_time_dict, num_hops);
450421
} else if (!replace && directed) {
451422
return hetero_sample<false, true, true>(
452423
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
453-
num_neighbors_dict, num_hops, node_time_dict);
424+
num_neighbors_dict, node_time_dict, num_hops);
454425
} else {
455426
return hetero_sample<false, false, true>(
456427
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
457-
num_neighbors_dict, num_hops, node_time_dict);
428+
num_neighbors_dict, node_time_dict, num_hops);
458429
}
459430
}

csrc/cpu/neighbor_sample_cpu.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ hetero_neighbor_sample_cpu(
2525

2626
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
2727
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
28-
hetero_neighbor_temporal_sample_cpu(
29-
const std::vector<node_t> &node_types,
28+
hetero_temporal_neighbor_sample_cpu(
29+
const std::vector<node_t> &node_types,
3030
const std::vector<edge_t> &edge_types,
3131
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
3232
const c10::Dict<rel_t, torch::Tensor> &row_dict,
3333
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
3434
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
3535
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
36-
const int64_t num_hops, const bool replace, const bool directed);
36+
const int64_t num_hops, const bool replace, const bool directed);

csrc/neighbor_sample.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ PyMODINIT_FUNC PyInit__neighbor_sample_cpu(void) { return NULL; }
1616
#endif
1717

1818
// Returns 'output_node', 'row', 'col', 'output_edge'
19-
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
19+
SPARSE_API
20+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
2021
neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
2122
const torch::Tensor &input_node,
2223
const std::vector<int64_t> num_neighbors, const bool replace,
@@ -25,7 +26,8 @@ neighbor_sample(const torch::Tensor &colptr, const torch::Tensor &row,
2526
directed);
2627
}
2728

28-
SPARSE_API std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
29+
SPARSE_API
30+
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
2931
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
3032
hetero_neighbor_sample(
3133
const std::vector<node_t> &node_types,
@@ -42,7 +44,7 @@ hetero_neighbor_sample(
4244

4345
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
4446
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
45-
hetero_neighbor_temporal_sample(
47+
hetero_temporal_neighbor_sample(
4648
const std::vector<node_t> &node_types,
4749
const std::vector<edge_t> &edge_types,
4850
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
@@ -51,7 +53,7 @@ hetero_neighbor_temporal_sample(
5153
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
5254
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
5355
const int64_t num_hops, const bool replace, const bool directed) {
54-
return hetero_neighbor_temporal_sample_cpu(
56+
return hetero_temporal_neighbor_sample_cpu(
5557
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
5658
num_neighbors_dict, node_time_dict, num_hops, replace, directed);
5759
}
@@ -60,4 +62,5 @@ static auto registry =
6062
torch::RegisterOperators()
6163
.op("torch_sparse::neighbor_sample", &neighbor_sample)
6264
.op("torch_sparse::hetero_neighbor_sample", &hetero_neighbor_sample)
63-
.op("torch_sparse::hetero_neighbor_temporal_sample", &hetero_neighbor_temporal_sample);
65+
.op("torch_sparse::hetero_temporal_neighbor_sample",
66+
&hetero_temporal_neighbor_sample);

0 commit comments

Comments
 (0)