-
Notifications
You must be signed in to change notification settings - Fork 303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sampling Performance Testing #3584
Merged
rapids-bot
merged 490 commits into
rapidsai:branch-24.02
from
alexbarghi-nv:perf-testing-v2
Jan 12, 2024
Merged
Changes from all commits
Commits
Show all changes
490 commits
Select commit
Hold shift + click to select a range
c09bb25
bug fix
seunghwak 4edb9ae
Merge branch 'branch-23.08' of github.com:rapidsai/cugraph into bug_mfg
seunghwak 57fb8e5
Merge branch 'bug_mfg' of https://github.com/seunghwak/cugraph into p…
alexbarghi-nv 3b95106
add latest updates
alexbarghi-nv 3269a4f
Merge branch 'perf-testing-v2' of https://github.com/alexbarghi-nv/cu…
alexbarghi-nv 3e009cd
bug fix (when edge list is empty)
seunghwak 622a17a
Merge branch 'branch-23.08' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv e4d7796
add latest updates
alexbarghi-nv a226a4e
revert cpp changes
alexbarghi-nv 5d3843f
revert plc changes
alexbarghi-nv 36464a9
revert notebook changes
alexbarghi-nv c5a81c2
Revert logging change
alexbarghi-nv 95a72ab
correction for dataset name
alexbarghi-nv aebe742
fix for empty batch issue
alexbarghi-nv 449984d
do merge
alexbarghi-nv bdaa22f
bring in changes
alexbarghi-nv 223dee3
remove redundant filter function
alexbarghi-nv 0c904ae
construct cugraph graph in CSC format
alexbarghi-nv 399976d
fixes for csc, update tests
alexbarghi-nv 6b1169e
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 3c9afc9
style fix, add comment explaining function
alexbarghi-nv 88831a8
Merge branch 'cugraph-pyg-loader-improvements' of https://github.com/…
alexbarghi-nv 246ac33
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 2fe3fe0
improve docstring
alexbarghi-nv f89a3fb
Merge branch 'cugraph-pyg-loader-improvements' of https://github.com/…
alexbarghi-nv 072b1ff
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 85a9c88
cleanup ahead of conversion to mg
alexbarghi-nv 53b334b
mg work
alexbarghi-nv f0e9f1f
move sampling relatd functions in graph_functions.hpp to sampling_fun…
seunghwak 3b1fd23
draft sampling post processing function APIs
seunghwak 5e99823
mg
alexbarghi-nv 7e4d041
resolve merge conflict
alexbarghi-nv d62f4f0
update to fix hop numbering issue
alexbarghi-nv 67f4d7b
API updates
seunghwak 5a8194e
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 19f66d0
Persist on host memory
alexbarghi-nv 8f521d2
API updates
seunghwak da3da9b
deprecate the existing renumber_sampeld_edgelist function
seunghwak 0b87ee1
combine renumber & compression/sorting functions
seunghwak 9b5950b
minor documentation updates
seunghwak 5fbb177
mionr documentation updates
seunghwak b9611ab
deprecate the existing sampling output renumber function
seunghwak d1c1440
improvements
alexbarghi-nv 846d3fd
Merge branch 'cugraph-pyg-loader-improvements' of https://github.com/…
alexbarghi-nv e52c614
split homogeneous/heterogeneous for better performance
alexbarghi-nv 2e5479d
Merge branch 'cugraph-pyg-loader-improvements' of https://github.com/…
alexbarghi-nv 6463445
add e2e test, fix a lot of bugs found by test
alexbarghi-nv c291110
style fix
alexbarghi-nv e9d1fcc
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 8f95c79
Merge branch 'cugraph-pyg-loader-improvements' of https://github.com/…
alexbarghi-nv 29aa194
correct docstrings
alexbarghi-nv 99b6f48
Merge branch 'cugraph-pyg-loader-improvements' of https://github.com/…
alexbarghi-nv ebf0d9c
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 26d48dd
rename sampling convert function
alexbarghi-nv 0069d9d
Merge branch 'cugraph-pyg-loader-improvements' of https://github.com/…
alexbarghi-nv 34d6bdc
update loader with new name
alexbarghi-nv baa8ea8
add comments to renumbering, clarify deprecation, add warning
alexbarghi-nv c3ee02b
initial implementation of sampling post processing
seunghwak 04c9105
cuda::std::atomic=>cuda::atomic
seunghwak bdc840c
update API documentation
seunghwak 8c304b3
add additional input testing
seunghwak b16a071
replace testing for sampling output post processing
seunghwak 09a38d7
cosmetic updates
seunghwak 82ad8e4
bug fixes
seunghwak e9b39e4
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv d99b512
Merge branch 'fea_mfg' of https://github.com/seunghwak/cugraph into c…
alexbarghi-nv c15d580
the c api
alexbarghi-nv 2ac8b86
work
alexbarghi-nv 9135629
fix compile errors
alexbarghi-nv dfd1cb7
reformat
alexbarghi-nv 6dfd4fe
rename test file from .cu to .cpp
seunghwak f600520
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 7d5821f
bug fixes
seunghwak 58189ed
add fill wrapper
seunghwak 39db98a
undo adding fill wrapper
seunghwak 98c8e0a
sampling test from .cpp to .cu
seunghwak 687d191
latest perf testing
alexbarghi-nv c151f95
fix a typo
seunghwak fc5a4f0
Merge branch 'branch-23.10' of github.com:rapidsai/cugraph into fea_mfg
seunghwak a7d1804
merge
alexbarghi-nv 3cda233
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv 0a18cde
do merge
alexbarghi-nv 094aaf9
do not return valid nzd vertices if doubly_compress is false
seunghwak cf57a6d
bug fix
seunghwak 2b48b7e
test code
seunghwak 79acc8e
Merge branch 'branch-23.10' of github.com:rapidsai/cugraph into fea_mfg
seunghwak 11009c6
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 0481bfb
Merge branch 'branch-23.10' into cugraph-sample-convert
alexbarghi-nv 2af9333
Merge branch 'fea_mfg' of https://github.com/seunghwak/cugraph into c…
alexbarghi-nv 23cd2c2
bug fix
seunghwak 6eaf67e
update documentation
seunghwak 4dc0a92
fix c api issues
alexbarghi-nv 2947b33
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv 0a2b2b7
C API fixes, Python/PLC API work
alexbarghi-nv db35940
adjust hop offsets when there is a jump in major vertex IDs between hops
seunghwak b8b72be
add sort only function
seunghwak 38dd11e
Merge branch 'branch-23.10' of github.com:rapidsai/cugraph into fea_mfg
seunghwak 2a799a6
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv c86ceac
various improvements
alexbarghi-nv 37a37bf
Merge branch 'fea_mfg' of https://github.com/seunghwak/cugraph into c…
alexbarghi-nv 002fe93
fix merge conflict
alexbarghi-nv 5051dfc
fix bad merge
alexbarghi-nv 6cdf92b
asdf
alexbarghi-nv 6682cb4
clarifying comments
alexbarghi-nv 0d12a28
t
alexbarghi-nv f5733f2
latest code
alexbarghi-nv 52e2f57
bug fix
seunghwak befeb25
Merge branch 'branch-23.10' of github.com:rapidsai/cugraph into bug_o…
seunghwak 8781612
additional bug fix
seunghwak f92b5f5
add additional checking to detect the previously neglected bugs
seunghwak 2bd93d9
Merge branch 'bug_offsets' of https://github.com/seunghwak/cugraph in…
alexbarghi-nv 3195298
wrap up sg API
alexbarghi-nv 74195cb
test fix, cleanup
alexbarghi-nv 374b103
refactor code into new shared utility
alexbarghi-nv bd625e3
get mg api working
alexbarghi-nv b2a4ed1
add offset mg test
alexbarghi-nv 9fb7438
fix renumber map issue in C++
alexbarghi-nv c770a17
verify new compression formats for sg
alexbarghi-nv b569563
complete csr/csc tests for both sg/mg
alexbarghi-nv ab2a185
get the bulk sampler working again
alexbarghi-nv 89a1b33
remove unwanted file
alexbarghi-nv a9d46ef
fix wrong dataframe issue
alexbarghi-nv 17e9013
update sg bulk sampler tests
alexbarghi-nv c5543b2
fix mg bulk sampler tests
alexbarghi-nv 6581f47
Merge branch 'branch-23.10' into cugraph-pyg-loader-improvements
alexbarghi-nv 16e83bc
write draft of csr bulk sampler
alexbarghi-nv 1e7098d
overhaul the writer methods
alexbarghi-nv ae94c35
remove unused method
alexbarghi-nv 7beba4b
style
alexbarghi-nv 16ed5ef
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv 79e3cef
remove notebook
alexbarghi-nv fd5cceb
add clarifying comment to c++
alexbarghi-nv a47691d
add future warnings
alexbarghi-nv 195d063
cleanup
alexbarghi-nv 0af1750
remove print statements
alexbarghi-nv d65632c
fix c api bug
alexbarghi-nv 247d8d2
revert dataloader change
alexbarghi-nv 72bebc2
fix empty df bug
alexbarghi-nv 4d51751
style
alexbarghi-nv 9dfa3fa
io
alexbarghi-nv 10c8c1f
fix test failures, remove c++ compression enum
alexbarghi-nv 08cf3e1
remove removed api from mg tests
alexbarghi-nv 897e6d6
change to future warning
alexbarghi-nv bb5e621
resolve checking issues
alexbarghi-nv d20e593
Merge branch 'cugraph-pyg-loader-improvements' into cugraph-pyg-mfg
alexbarghi-nv eb3aadc
fix wrong index + off by 1 error, add check in test
alexbarghi-nv a124964
Merge branch 'branch-23.10' into cugraph-sample-convert
alexbarghi-nv 6990c23
add annotations
alexbarghi-nv 920bed7
docstring correction
alexbarghi-nv f8df56f
remove empty batch check
alexbarghi-nv ef2ec5b
fix capi sg test
alexbarghi-nv 8e22ab9
disable broken tests, they are too expensive to fix and redundant
alexbarghi-nv 13bdd43
Merge branch 'cugraph-sample-convert' of https://github.com/alexbargh…
alexbarghi-nv c48a14b
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv cf612c7
update c code
alexbarghi-nv 09a3bd8
Merge branch 'branch-23.10' into cugraph-pyg-mfg
alexbarghi-nv 140b6e4
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv e4544b6
Merge branch 'branch-23.10' into cugraph-sample-convert
alexbarghi-nv 0ee3798
Resolve merge conflict
alexbarghi-nv 6212869
fix bad merge
alexbarghi-nv 0f1a144
initial rewrite
alexbarghi-nv b369e97
fixes, more testing
alexbarghi-nv 13be49c
fix issue with num nodes and edges
alexbarghi-nv 185143c
e2e smoke test
alexbarghi-nv 99efb9c
Merge branch 'branch-23.10' into cugraph-pyg-mfg
alexbarghi-nv bc1f30b
Merge branch 'cugraph-sample-convert' into perf-testing-v2
alexbarghi-nv 9ea6c6b
Merge branch 'branch-23.10' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv a127643
Merge branch 'cugraph-pyg-mfg' of https://github.com/alexbarghi-nv/cu…
alexbarghi-nv 262d1da
fix test column name issues
alexbarghi-nv 7a05c10
Merge branch 'branch-23.10' into cugraph-pyg-mfg
alexbarghi-nv c440f64
resolve merge conflicts
alexbarghi-nv d0d0cb2
copyright
alexbarghi-nv b4e6d06
testing
alexbarghi-nv 20f138c
Merge branch 'perf-testing-v2' of https://github.com/alexbarghi-nv/cu…
alexbarghi-nv 7e770ad
debugging
alexbarghi-nv 4ac962d
perf testing
alexbarghi-nv 55b4e84
regex
alexbarghi-nv 0fd367a
Merge branch 'perf-testing-v2' of https://github.com/alexbarghi-nv/cu…
alexbarghi-nv 894831e
update to latest
alexbarghi-nv 3cad3f2
fixes
alexbarghi-nv 912d6ca
node loader
alexbarghi-nv ea60f94
Merge branch 'branch-23.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv 9972619
finish patch
alexbarghi-nv 1c401d1
merge latest
alexbarghi-nv 02c7210
bulk sampling
alexbarghi-nv b67d5ed
perf testing
alexbarghi-nv da389e0
minor fixes
alexbarghi-nv e29b4e8
get the native workflow working
alexbarghi-nv d358257
wrap up first version of cugraph trainer
alexbarghi-nv e08c46c
remove stats file
alexbarghi-nv a9fc5af
Fixes
alexbarghi-nv 49094db
x
alexbarghi-nv b8e2354
output multiple epochs, train/test/val
alexbarghi-nv 0fd156b
remove unwanted file
alexbarghi-nv 663febe
Merge branch 'perf-testing-v2' of https://github.com/alexbarghi-nv/cu…
alexbarghi-nv 2a3ee5a
revert file
alexbarghi-nv b424e7c
remove unwanted file
alexbarghi-nv b727fcb
remove cmake files
alexbarghi-nv d37f0d7
train/test
alexbarghi-nv d0ca16b
reformat
alexbarghi-nv 06dc14d
add scripts
alexbarghi-nv a5f1b67
Merge branch 'perf-testing-v2' of https://github.com/alexbarghi-nv/cu…
alexbarghi-nv ad83725
reorganize, add scripts
alexbarghi-nv e3d28a6
init
alexbarghi-nv d15a4d4
update
alexbarghi-nv 70a509a
Merge branch 'pyg-nightly-input-nodes-fix' of https://github.com/alex…
alexbarghi-nv ecc2db1
cugraph
alexbarghi-nv 726c81d
loader debug
alexbarghi-nv c095769
fix small bugs in cugraph-pyg
alexbarghi-nv 4be1875
c
alexbarghi-nv 59f030d
fix fanout issues
alexbarghi-nv 4bc7f90
remove experimental warnings
alexbarghi-nv a58d358
remove test files
alexbarghi-nv 318212d
data preprocessing
alexbarghi-nv 68ca511
commit
alexbarghi-nv dbbd791
Merge branch 'dlfw-patch-24.01' of https://github.com/alexbarghi-nv/c…
alexbarghi-nv d47c3ba
comment
alexbarghi-nv 367c79c
fixing issues impacting accuracy
alexbarghi-nv ac1cfbd
add readme
alexbarghi-nv cc2635b
refactor
alexbarghi-nv f1ce3e1
Fix mixed experimental import
alexbarghi-nv e38fe66
update readme
alexbarghi-nv f3f68bd
update readme
alexbarghi-nv d2734c4
fix environment variables
alexbarghi-nv 7222cba
remove unwanted file
alexbarghi-nv c2e8520
minor change to avoid timeout
alexbarghi-nv a4dad32
remove stats file
alexbarghi-nv 2109bfb
Merge branch 'perf-testing-v2' of https://github.com/alexbarghi-nv/cu…
alexbarghi-nv 6358f9b
switch versions of simple distributed graph for 24.02
alexbarghi-nv 3898cb2
remove test python file
alexbarghi-nv 3f266f5
remove mg utils dir
alexbarghi-nv 864e55e
wait for workers
alexbarghi-nv 67d6aa0
reformat
alexbarghi-nv 78fc260
add copyrights
alexbarghi-nv d81a9a8
fix wrong file
alexbarghi-nv 16f225a
remove stats file
alexbarghi-nv 259ec47
Merge branch 'branch-24.02' into perf-testing-v2
alexbarghi-nv 18571fe
fix copyright
alexbarghi-nv 40502de
split off feature transfer time
alexbarghi-nv ea46748
style
alexbarghi-nv 61f30a2
Merge branch 'branch-24.02' into perf-testing-v2
alexbarghi-nv 89ac530
fixes to scripts
alexbarghi-nv 77b0788
compatibility issues
alexbarghi-nv 4e2a706
reset file
alexbarghi-nv 18e43de
c
alexbarghi-nv c4c45db
copyright
alexbarghi-nv 8ea5c92
whitespace
alexbarghi-nv 441810c
set nthreads to 8
alexbarghi-nv c053ed0
Merge branch 'branch-24.02' into perf-testing-v2
alexbarghi-nv 3039843
Merge branch 'branch-24.02' into perf-testing-v2
alexbarghi-nv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
mg_utils/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
251 changes: 251 additions & 0 deletions
251
benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
# Copyright (c) 2023-2024, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
os.environ["RAPIDS_NO_INITIALIZE"] = "1" | ||
os.environ["CUDF_SPILL"] = "1" | ||
os.environ["LIBCUDF_CUFILE_POLICY"] = "KVIKIO" | ||
os.environ["KVIKIO_NTHREADS"] = "8" | ||
|
||
import argparse | ||
import json | ||
import warnings | ||
|
||
import torch | ||
import numpy as np | ||
import pandas | ||
|
||
import torch.distributed as dist | ||
|
||
from datasets import OGBNPapers100MDataset | ||
|
||
from cugraph.testing.mg_utils import enable_spilling | ||
|
||
|
||
def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool = False) -> None: | ||
import cupy | ||
import rmm | ||
from pynvml.smi import nvidia_smi | ||
|
||
smi = nvidia_smi.getInstance() | ||
pool_size = 16e9 # FIXME calculate this | ||
|
||
rmm.reinitialize( | ||
devices=[rank], | ||
pool_allocator=True, | ||
initial_pool_size=pool_size, | ||
) | ||
|
||
if use_rmm_torch_allocator: | ||
warnings.warn( | ||
"Using the rmm pytorch allocator is currently unsupported." | ||
" The default allocator will be used instead." | ||
) | ||
# FIXME somehow get the pytorch allocator to work | ||
# from rmm.allocators.torch import rmm_torch_allocator | ||
# torch.cuda.memory.change_current_allocator(rmm_torch_allocator) | ||
|
||
from rmm.allocators.cupy import rmm_cupy_allocator | ||
|
||
cupy.cuda.set_allocator(rmm_cupy_allocator) | ||
|
||
cupy.cuda.Device(rank).use() | ||
torch.cuda.set_device(rank) | ||
|
||
# Pytorch training worker initialization | ||
torch.distributed.init_process_group(backend="nccl") | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"--gpus_per_node", | ||
type=int, | ||
default=8, | ||
help="# GPUs per node", | ||
required=False, | ||
) | ||
|
||
parser.add_argument( | ||
"--num_epochs", | ||
type=int, | ||
default=1, | ||
help="Number of training epochs", | ||
required=False, | ||
) | ||
|
||
parser.add_argument( | ||
"--batch_size", | ||
type=int, | ||
default=512, | ||
help="Batch size", | ||
required=False, | ||
) | ||
|
||
parser.add_argument( | ||
"--fanout", | ||
type=str, | ||
default="10_10_10", | ||
help="Fanout", | ||
required=False, | ||
) | ||
|
||
parser.add_argument( | ||
"--sample_dir", | ||
type=str, | ||
help="Directory with stored bulk samples (required for cuGraph run)", | ||
required=False, | ||
) | ||
|
||
parser.add_argument( | ||
"--output_file", | ||
type=str, | ||
help="File to store results", | ||
required=True, | ||
) | ||
|
||
parser.add_argument( | ||
"--framework", | ||
type=str, | ||
help="The framework to test (PyG, cuGraphPyG)", | ||
required=True, | ||
) | ||
|
||
parser.add_argument( | ||
"--model", | ||
type=str, | ||
default="GraphSAGE", | ||
help="The model to use (currently only GraphSAGE supported)", | ||
required=False, | ||
) | ||
|
||
parser.add_argument( | ||
"--replication_factor", | ||
type=int, | ||
default=1, | ||
help="The replication factor for the dataset", | ||
required=False, | ||
) | ||
|
||
parser.add_argument( | ||
"--dataset_dir", | ||
type=str, | ||
help="The directory where datasets are stored", | ||
required=True, | ||
) | ||
|
||
parser.add_argument( | ||
"--train_split", | ||
type=float, | ||
help="The percentage of the labeled data to use for training. The remainder is used for testing/validation.", | ||
default=0.8, | ||
required=False, | ||
) | ||
|
||
parser.add_argument( | ||
"--val_split", | ||
type=float, | ||
help="The percentage of the testing/validation data to allocate for validation.", | ||
default=0.5, | ||
required=False, | ||
) | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def main(args): | ||
import logging | ||
|
||
logging.basicConfig( | ||
level=logging.INFO, | ||
) | ||
logger = logging.getLogger("bench_cugraph_training") | ||
logger.setLevel(logging.INFO) | ||
|
||
local_rank = int(os.environ["LOCAL_RANK"]) | ||
global_rank = int(os.environ["RANK"]) | ||
|
||
init_pytorch_worker( | ||
local_rank, use_rmm_torch_allocator=(args.framework == "cuGraph") | ||
) | ||
enable_spilling() | ||
print(f"worker initialized") | ||
dist.barrier() | ||
|
||
world_size = int(os.environ["SLURM_JOB_NUM_NODES"]) * args.gpus_per_node | ||
|
||
dataset = OGBNPapers100MDataset( | ||
replication_factor=args.replication_factor, | ||
dataset_dir=args.dataset_dir, | ||
train_split=args.train_split, | ||
val_split=args.val_split, | ||
load_edge_index=(args.framework == "PyG"), | ||
) | ||
|
||
if global_rank == 0: | ||
dataset.download() | ||
dist.barrier() | ||
|
||
fanout = [int(f) for f in args.fanout.split("_")] | ||
|
||
if args.framework == "PyG": | ||
from trainers.pyg import PyGNativeTrainer | ||
|
||
trainer = PyGNativeTrainer( | ||
model=args.model, | ||
dataset=dataset, | ||
device=local_rank, | ||
rank=global_rank, | ||
world_size=world_size, | ||
num_epochs=args.num_epochs, | ||
shuffle=True, | ||
replace=False, | ||
num_neighbors=fanout, | ||
batch_size=args.batch_size, | ||
) | ||
elif args.framework == "cuGraphPyG": | ||
sample_dir = os.path.join( | ||
args.sample_dir, | ||
f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}", | ||
) | ||
from trainers.pyg import PyGCuGraphTrainer | ||
|
||
trainer = PyGCuGraphTrainer( | ||
model=args.model, | ||
dataset=dataset, | ||
sample_dir=sample_dir, | ||
device=local_rank, | ||
rank=global_rank, | ||
world_size=world_size, | ||
num_epochs=args.num_epochs, | ||
shuffle=True, | ||
replace=False, | ||
num_neighbors=fanout, | ||
batch_size=args.batch_size, | ||
) | ||
else: | ||
raise ValueError("unsupported framework") | ||
|
||
logger.info(f"Trainer ready on rank {global_rank}") | ||
stats = trainer.train() | ||
logger.info(stats) | ||
|
||
with open(f"{args.output_file}[{global_rank}]", "w") as f: | ||
json.dump(stats, f) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
main(args) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume we shall include "cuGraphDGL" here too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the next PR