Skip to content
This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

Add vanilla and optimized cpu device to detectron #596

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions detectron/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,7 @@
'TRAIN.DROPOUT',
'USE_GPU_NMS',
'TEST.NUM_TEST_IMAGES',
'--device_id'
}
)

Expand Down
3 changes: 2 additions & 1 deletion detectron/core/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def result_getter():
dataset_name,
proposal_file,
output_dir,
multi_gpu=multi_gpu_testing
multi_gpu=multi_gpu_testing,
gpu_id=gpu_id
)
all_results.update(results)

Expand Down
25 changes: 23 additions & 2 deletions detectron/roi_data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,12 @@ def enqueue_blobs(self, gpu_id, blob_names, blobs):
assert len(blob_names) == len(blobs)
t = time.time()
dev = c2_utils.CudaDevice(gpu_id)
queue_name = 'gpu_{}/{}'.format(gpu_id, self._blobs_queue_name)
blob_names = ['gpu_{}/{}'.format(gpu_id, b) for b in blob_names]
if gpu_id < 0:
queue_name = self._blobs_queue_name
blob_names = blob_names
else:
queue_name = 'gpu_{}/{}'.format(gpu_id, self._blobs_queue_name)
blob_names = ['gpu_{}/{}'.format(gpu_id, b) for b in blob_names]
for (blob_name, blob) in zip(blob_names, blobs):
workspace.FeedBlob(blob_name, blob, device_option=dev)
logger.debug(
Expand Down Expand Up @@ -258,6 +262,14 @@ def create_blobs_queues(self):
capacity=self._blobs_queue_capacity
)
)
if self._num_gpus == 0:
workspace.RunOperatorOnce(
core.CreateOperator(
'CreateBlobsQueue', [], [self._blobs_queue_name],
num_blobs=len(self.get_output_names()),
capacity=self._blobs_queue_capacity
)
)
return self.create_enqueue_blobs()

def close_blobs_queues(self):
Expand All @@ -269,6 +281,12 @@ def close_blobs_queues(self):
'CloseBlobsQueue', [self._blobs_queue_name], []
)
)
if self._num_gpus == 0:
workspace.RunOperatorOnce(
core.CreateOperator(
'CloseBlobsQueue', [self._blobs_queue_name], []
)
)

def create_enqueue_blobs(self):
blob_names = self.get_output_names()
Expand All @@ -279,6 +297,9 @@ def create_enqueue_blobs(self):
with c2_utils.NamedCudaScope(gpu_id):
for blob in enqueue_blob_names:
workspace.CreateBlob(core.ScopedName(blob))
if self._num_gpus == 0:
for blob in enqueue_blob_names:
workspace.CreateBlob(core.ScopedName(blob))
return enqueue_blob_names

def register_sigint_handler(self):
Expand Down
55 changes: 45 additions & 10 deletions detectron/utils/c2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

import detectron.utils.env as envu

DEVICE_ID_CPU = -1
DEVICE_ID_IDEEP = -2

def import_contrib_ops():
"""Import contrib ops needed by Detectron."""
Expand Down Expand Up @@ -104,37 +106,70 @@ def UnscopeName(possibly_scoped_name):
def NamedCudaScope(gpu_id):
"""Creates a GPU name scope and CUDA device scope. This function is provided
to reduce `with ...` nesting levels."""
with GpuNameScope(gpu_id):
with CudaScope(gpu_id):
if gpu_id == DEVICE_ID_CPU:
with CpuScope():
yield
elif gpu_id == DEVICE_ID_IDEEP:
with IdeepScope():
yield
else:
with GpuNameScope(gpu_id):
with CudaScope(gpu_id):
yield


@contextlib.contextmanager
def GpuNameScope(gpu_id):
"""Create a name scope for GPU device `gpu_id`."""
with core.NameScope('gpu_{:d}'.format(gpu_id)):
if gpu_id < 0:
yield
else:
"""Create a name scope for GPU device `gpu_id`."""
with core.NameScope('gpu_{:d}'.format(gpu_id)):
yield


@contextlib.contextmanager
def CudaScope(gpu_id):
"""Create a CUDA device scope for GPU device `gpu_id`."""
gpu_dev = CudaDevice(gpu_id)
with core.DeviceScope(gpu_dev):
yield
if gpu_id == DEVICE_ID_CPU:
with CpuScope():
yield
elif gpu_id == DEVICE_ID_IDEEP:
with IdeepScope():
yield
else:
"""Create a CUDA device scope for GPU device `gpu_id`."""
gpu_dev = CudaDevice(gpu_id)
with core.DeviceScope(gpu_dev):
yield


@contextlib.contextmanager
def CpuScope():
"""Create a CPU device scope."""
cpu_dev = core.DeviceOption(caffe2_pb2.CPU)
cpu_dev = CpuDevice()
with core.DeviceScope(cpu_dev):
yield

def CpuDevice():
return core.DeviceOption(caffe2_pb2.CPU)

@contextlib.contextmanager
def IdeepScope():
ideep_dev = IdeepDevice()
with core.DeviceScope(ideep_dev):
yield

def IdeepDevice():
return core.DeviceOption(caffe2_pb2.IDEEP)

def CudaDevice(gpu_id):
"""Create a Cuda device."""
return core.DeviceOption(caffe2_pb2.CUDA, gpu_id)
if gpu_id == DEVICE_ID_CPU:
return CpuDevice()
elif gpu_id == DEVICE_ID_IDEEP:
return IdeepDevice()
else:
return core.DeviceOption(caffe2_pb2.CUDA, gpu_id)


def gauss_fill(std):
Expand Down
5 changes: 5 additions & 0 deletions detectron/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def get_detectron_ops_lib():
# TODO(ilijar): Switch to using a logger
print('Found Detectron ops lib: {}'.format(ops_path))
break
ops_path = os.path.join(prefix, 'lib/libcaffe2_detectron_ops.so')
if os.path.exists(ops_path):
# TODO(ilijar): Switch to using a logger
print('Found Detectron ops lib: {}'.format(ops_path))
break
assert os.path.exists(ops_path), \
('Detectron ops lib not found; make sure that your Caffe2 '
'version includes Detectron module')
Expand Down
10 changes: 8 additions & 2 deletions tools/infer_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def parse_args():
default='pdf',
type=str
)
parser.add_argument(
'--device_id',
dest='device_id',
default=0,
type=int
)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
Expand All @@ -118,7 +124,7 @@ def main(args):
assert not cfg.TEST.PRECOMPUTED_PROPOSALS, \
'Models that require precomputed proposals are not supported'

model = infer_engine.initialize_model_from_cfg(args.weights)
model = infer_engine.initialize_model_from_cfg(args.weights, gpu_id = args.device_id)
dummy_coco_dataset = dummy_datasets.get_coco_dataset()

if os.path.isdir(args.im_or_folder):
Expand All @@ -134,7 +140,7 @@ def main(args):
im = cv2.imread(im_name)
timers = defaultdict(Timer)
t = time.time()
with c2_utils.NamedCudaScope(0):
with c2_utils.NamedCudaScope(args.device_id):
cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
model, im, None, timers=timers
)
Expand Down
7 changes: 7 additions & 0 deletions tools/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def parse_args():
type=int,
nargs=2
)
parser.add_argument(
'--device_id',
dest='device_id',
default=0,
type=int
)
parser.add_argument(
'opts',
help='See detectron/core/config.py for all options',
Expand Down Expand Up @@ -113,5 +119,6 @@ def parse_args():
cfg.TEST.WEIGHTS,
ind_range=args.range,
multi_gpu_testing=args.multi_gpu_testing,
gpu_id=args.device_id,
check_expected_results=True,
)