Skip to content

Commit 15c5635

Browse files
authored
Version up (#224)
* version up * formatting * fix * reset * revert
1 parent 79535f3 commit 15c5635

File tree

6 files changed

+47
-58
lines changed

6 files changed

+47
-58
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cmake_minimum_required(VERSION 3.0)
22
project(torchsparse)
33
set(CMAKE_CXX_STANDARD 14)
4-
set(TORCHSPARSE_VERSION 0.6.13)
4+
set(TORCHSPARSE_VERSION 0.7.0)
55

66
option(WITH_CUDA "Enable CUDA support" OFF)
77
option(WITH_PYTHON "Link to Python when building" ON)

conda/pytorch-sparse/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package:
22
name: pytorch-sparse
3-
version: 0.6.13
3+
version: 0.7.0
44

55
source:
66
path: ../..

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 38 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -114,35 +114,31 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
114114
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
115115
}
116116

117-
bool satisfy_time_constraint(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
118-
const std::string &src_node_type,
119-
const int64_t &dst_time,
120-
const int64_t &sampled_node) {
117+
bool satisfy_time_constraint(
118+
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
119+
const node_t &src_node_type, const int64_t &dst_time,
120+
const int64_t &src_node) {
121121
// whether src -> dst obeys the time constraint
122122
try {
123-
const auto *src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
124-
return dst_time < src_time[sampled_node];
125-
}
126-
catch (int err) {
123+
auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
124+
return dst_time < src_time[src_node];
125+
} catch (int err) {
127126
// if the node type does not have timestamp, fall back to normal sampling
128127
return true;
129128
}
130129
}
131130

132-
133131
template <bool replace, bool directed, bool temporal>
134132
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
135133
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
136134
hetero_sample(const vector<node_t> &node_types,
137-
const vector<edge_t> &edge_types,
138-
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
139-
const c10::Dict<rel_t, torch::Tensor> &row_dict,
140-
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
141-
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
142-
const int64_t num_hops,
143-
const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
144-
//bool temporal = (!node_time_dict.empty());
145-
135+
const vector<edge_t> &edge_types,
136+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
137+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
138+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
139+
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
140+
const int64_t num_hops,
141+
const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
146142
// Create a mapping to convert single string relations to edge type triplets:
147143
unordered_map<rel_t, edge_t> to_edge_type;
148144
for (const auto &k : edge_types)
@@ -174,11 +170,12 @@ hetero_sample(const vector<node_t> &node_types,
174170
const torch::Tensor &input_node = kv.value();
175171
const auto *input_node_data = input_node.data_ptr<int64_t>();
176172
// dummy value. will be reset to root time if is_temporal==true
177-
auto *node_time_data = input_node.data_ptr<int64_t>();
173+
int64_t *node_time_data;
178174
// root_time[i] stores the timestamp of the computation tree root
179175
// of the node samples[i]
180176
if (temporal) {
181-
node_time_data = node_time_dict.at(node_type).data_ptr<int64_t>();
177+
torch::Tensor node_time = node_time_dict.at(node_type);
178+
node_time_data = node_time.data_ptr<int64_t>();
182179
}
183180

184181
auto &samples = samples_dict.at(node_type);
@@ -220,7 +217,7 @@ hetero_sample(const vector<node_t> &node_types,
220217

221218
const auto &begin = slice_dict.at(dst_node_type).first;
222219
const auto &end = slice_dict.at(dst_node_type).second;
223-
if (begin == end){
220+
if (begin == end) {
224221
continue;
225222
}
226223
// for temporal sampling, sampled src node cannot have timestamp greater
@@ -370,22 +367,17 @@ hetero_sample(const vector<node_t> &node_types,
370367
template <bool replace, bool directed>
371368
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
372369
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
373-
hetero_sample_random(const vector<node_t> &node_types,
374-
const vector<edge_t> &edge_types,
375-
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
376-
const c10::Dict<rel_t, torch::Tensor> &row_dict,
377-
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
378-
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
379-
const int64_t num_hops) {
370+
hetero_sample_random(
371+
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
372+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
373+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
374+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
375+
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
376+
const int64_t num_hops) {
380377
c10::Dict<node_t, torch::Tensor> empty_dict;
381-
return hetero_sample<replace, directed, false>(node_types,
382-
edge_types,
383-
colptr_dict,
384-
row_dict,
385-
input_node_dict,
386-
num_neighbors_dict,
387-
num_hops,
388-
empty_dict);
378+
return hetero_sample<replace, directed, false>(
379+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
380+
num_neighbors_dict, num_hops, empty_dict);
389381
}
390382

391383
} // namespace
@@ -418,24 +410,20 @@ hetero_neighbor_sample_cpu(
418410
const int64_t num_hops, const bool replace, const bool directed) {
419411

420412
if (replace && directed) {
421-
return hetero_sample_random<true, true>(
422-
node_types, edge_types, colptr_dict,
423-
row_dict, input_node_dict,
424-
num_neighbors_dict, num_hops);
413+
return hetero_sample_random<true, true>(node_types, edge_types, colptr_dict,
414+
row_dict, input_node_dict,
415+
num_neighbors_dict, num_hops);
425416
} else if (replace && !directed) {
426417
return hetero_sample_random<true, false>(
427-
node_types, edge_types, colptr_dict,
428-
row_dict, input_node_dict,
418+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
429419
num_neighbors_dict, num_hops);
430420
} else if (!replace && directed) {
431421
return hetero_sample_random<false, true>(
432-
node_types, edge_types, colptr_dict,
433-
row_dict, input_node_dict,
422+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
434423
num_neighbors_dict, num_hops);
435424
} else {
436425
return hetero_sample_random<false, false>(
437-
node_types, edge_types, colptr_dict,
438-
row_dict, input_node_dict,
426+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
439427
num_neighbors_dict, num_hops);
440428
}
441429
}
@@ -453,23 +441,19 @@ hetero_neighbor_temporal_sample_cpu(
453441

454442
if (replace && directed) {
455443
return hetero_sample<true, true, true>(
456-
node_types, edge_types, colptr_dict,
457-
row_dict, input_node_dict,
444+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
458445
num_neighbors_dict, num_hops, node_time_dict);
459446
} else if (replace && !directed) {
460447
return hetero_sample<true, false, true>(
461-
node_types, edge_types, colptr_dict,
462-
row_dict, input_node_dict,
448+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
463449
num_neighbors_dict, num_hops, node_time_dict);
464450
} else if (!replace && directed) {
465451
return hetero_sample<false, true, true>(
466-
node_types, edge_types, colptr_dict,
467-
row_dict, input_node_dict,
452+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
468453
num_neighbors_dict, num_hops, node_time_dict);
469454
} else {
470455
return hetero_sample<false, false, true>(
471-
node_types, edge_types, colptr_dict,
472-
row_dict, input_node_dict,
456+
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
473457
num_neighbors_dict, num_hops, node_time_dict);
474458
}
475459
}

setup.cfg

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,8 @@ test = pytest
1717

1818
[tool:pytest]
1919
addopts = --capture=no
20+
21+
[isort]
22+
multi_line_output=3
23+
include_trailing_comma = True
24+
skip=.gitignore,__init__.py

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
1212
CUDAExtension)
1313

14-
__version__ = '0.6.13'
14+
__version__ = '0.7.0'
1515
URL = 'https://github.com/rusty1s/pytorch_sparse'
1616

1717
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None

torch_sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
__version__ = '0.6.13'
6+
__version__ = '0.7.0'
77

88
for library in [
99
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',

0 commit comments

Comments
 (0)