We list some common troubles faced by many users and their corresponding solutions here. Feel free to enrich the list if you find any frequent issues and have ways to help others solve them. If the contents here do not cover your issue, please create an issue using the provided templates and make sure you fill in all required information in the template.
The vast majority of algorithms in MMDetection now support PyTorch 2.0 and its `torch.compile` function. Users only need to install MMDetection 3.0.0rc7 or later to use this feature. If you find any unsupported algorithms, please feel free to give us feedback. We also welcome contributions from the community to benchmark the speed improvement brought by `torch.compile`.
To enable the `torch.compile` function, simply add `--cfg-options compile=True` after `train.py` or `test.py`. For example, to enable `torch.compile` for RTMDet, you can use the following commands:
```shell
# Single GPU
python tools/train.py configs/rtmdet/rtmdet_s_8xb32-300e_coco.py --cfg-options compile=True

# Single node multiple GPUs
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True

# Single node multiple GPUs + AMP
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True --amp
```
It is important to note that PyTorch 2.0's support for dynamic shapes is not yet mature. In most object detection algorithms, not only are the input shapes dynamic, but the loss calculation and post-processing are also dynamic. This can make training slower when `torch.compile` is enabled. Therefore, if you wish to enable `torch.compile`, you should follow these principles:
- Input images to the network should have a fixed shape, not multiple scales.
- Set the `torch._dynamo.config.cache_size_limit` parameter. TorchDynamo converts and caches Python bytecode, and the compiled functions are stored in the cache. When a later check finds that a function needs to be recompiled, it is recompiled and cached again. However, once the number of recompilations exceeds the configured maximum (64), the function is no longer cached or recompiled. As mentioned above, the loss calculation and post-processing of object detection algorithms are also dynamic, so these functions need to be recompiled every time. Therefore, setting `torch._dynamo.config.cache_size_limit` to a smaller value can effectively reduce the compilation time.
In MMDetection, you can set `torch._dynamo.config.cache_size_limit` through the environment variable `DYNAMO_CACHE_SIZE_LIMIT`. For example:
```shell
# Single GPU
export DYNAMO_CACHE_SIZE_LIMIT=4
python tools/train.py configs/rtmdet/rtmdet_s_8xb32-300e_coco.py --cfg-options compile=True

# Single node multiple GPUs
export DYNAMO_CACHE_SIZE_LIMIT=4
./tools/dist_train.sh configs/rtmdet/rtmdet_s_8xb32-300e_coco.py 8 --cfg-options compile=True
```
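If you would rather set the limit from Python than from the shell, the sketch below reads the same `DYNAMO_CACHE_SIZE_LIMIT` variable and writes it to `torch._dynamo.config.cache_size_limit`. The helper names are hypothetical and this is not MMDetection's own code; it only assumes PyTorch 2.0 or later, and `torch` is imported lazily so the parsing part runs without it.

```python
import os


def dynamo_cache_size_limit_from_env(env=None):
    """Return DYNAMO_CACHE_SIZE_LIMIT as a positive int, or None if unset/blank."""
    env = os.environ if env is None else env
    raw = env.get("DYNAMO_CACHE_SIZE_LIMIT")
    if raw is None or raw.strip() == "":
        return None
    limit = int(raw)  # int() tolerates surrounding whitespace
    if limit <= 0:
        raise ValueError(f"DYNAMO_CACHE_SIZE_LIMIT must be positive, got {raw!r}")
    return limit


def apply_dynamo_cache_size_limit(env=None):
    """Copy the env value into torch._dynamo.config (hypothetical helper)."""
    limit = dynamo_cache_size_limit_from_env(env)
    if limit is None:
        return None  # leave PyTorch's default untouched
    import torch._dynamo  # lazy import: only needed when a limit is set

    torch._dynamo.config.cache_size_limit = limit
    return limit
```

Call `apply_dynamo_cache_size_limit()` before the model is compiled; when the variable is unset, nothing is changed.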
For common questions about PyTorch 2.0's dynamo, please see here.
Compatibility issue between MMCV and MMDetection; "ConvWS is already registered in conv layer"; "AssertionError: MMCV==xxx is used but incompatible. Please install mmcv>=xxx, <=xxx."
Compatible MMDetection, MMEngine, and MMCV versions are shown below. Please choose the correct version of MMCV to avoid installation issues.
MMDetection version | MMCV version | MMEngine version |
---|---|---|
main | mmcv>=2.0.0, <2.2.0 | mmengine>=0.7.1, <1.0.0 |
3.3.0 | mmcv>=2.0.0, <2.2.0 | mmengine>=0.7.1, <1.0.0 |
3.2.0 | mmcv>=2.0.0, <2.2.0 | mmengine>=0.7.1, <1.0.0 |
3.1.0 | mmcv>=2.0.0, <2.1.0 | mmengine>=0.7.1, <1.0.0 |
3.0.0 | mmcv>=2.0.0, <2.1.0 | mmengine>=0.7.1, <1.0.0 |
3.0.0rc6 | mmcv>=2.0.0rc4, <2.1.0 | mmengine>=0.6.0, <1.0.0 |
3.0.0rc5 | mmcv>=2.0.0rc1, <2.1.0 | mmengine>=0.3.0, <1.0.0 |
3.0.0rc4 | mmcv>=2.0.0rc1, <2.1.0 | mmengine>=0.3.0, <1.0.0 |
3.0.0rc3 | mmcv>=2.0.0rc1, <2.1.0 | mmengine>=0.3.0, <1.0.0 |
3.0.0rc2 | mmcv>=2.0.0rc1, <2.1.0 | mmengine>=0.1.0, <1.0.0 |
3.0.0rc1 | mmcv>=2.0.0rc1, <2.1.0 | mmengine>=0.1.0, <1.0.0 |
3.0.0rc0 | mmcv>=2.0.0rc1, <2.1.0 | mmengine>=0.1.0, <1.0.0 |
Note:
- If you want to install mmdet-v2.x, the table of compatible MMDetection and MMCV versions can be found here. Please choose the correct version of MMCV to avoid installation issues.
- In MMCV-v2.x, `mmcv-full` is renamed to `mmcv`. If you want to install `mmcv` without CUDA ops, you can install `mmcv-lite`.
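To check a combination against the table programmatically, here is a minimal sketch. The function names are hypothetical, only a few rows are copied, and the hand-written parser understands just `X.Y.Z` and `X.Y.ZrcN`; it also ignores pip's finer pre-release rules, so `packaging.version` would be the more robust choice in practice.

```python
import re

# A subset of the table above: MMDetection version -> (lowest MMCV, first excluded MMCV)
MMCV_RANGES = {
    "3.3.0": ("2.0.0", "2.2.0"),
    "3.2.0": ("2.0.0", "2.2.0"),
    "3.1.0": ("2.0.0", "2.1.0"),
    "3.0.0": ("2.0.0", "2.1.0"),
    "3.0.0rc6": ("2.0.0rc4", "2.1.0"),
}

_VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:rc(\d+))?$")


def parse_version(text):
    """Turn '2.0.0rc4' into a sortable tuple; a final release sorts after its rcs."""
    match = _VERSION_RE.match(text.strip())
    if match is None:
        raise ValueError(f"unsupported version string: {text!r}")
    major, minor, patch, rc = match.groups()
    # (1, 0) marks a final release, (0, n) marks release candidate n
    pre = (1, 0) if rc is None else (0, int(rc))
    return (int(major), int(minor), int(patch)) + pre


def mmcv_is_compatible(mmdet_version, mmcv_version):
    """True if low <= mmcv_version < high for the given MMDetection version."""
    low, high = MMCV_RANGES[mmdet_version]
    version = parse_version(mmcv_version)
    return parse_version(low) <= version < parse_version(high)
```

For example, `mmcv_is_compatible("3.1.0", "2.1.0")` is `False` because MMDetection 3.1.0 requires `mmcv<2.1.0`. The installed version can be read with `importlib.metadata.version("mmcv")`.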
- "No module named 'mmcv.ops'"; "No module named 'mmcv._ext'".
  - Uninstall the existing `mmcv-lite` in the environment using `pip uninstall mmcv-lite`.
  - Install `mmcv` following the installation instructions.
- "Microsoft Visual C++ 14.0 or greater is required" during installation on Windows.
  This error happens when building the `pycocotools._mask` extension of pycocotools and the environment lacks the corresponding C++ compilation dependencies. Download the build tools from Microsoft's official visual-cpp-build-tools page, select the "Desktop development with C++" option to install the minimum dependencies, and then reinstall pycocotools.
- Using Albumentations
  If you would like to use `albumentations`, we suggest using `pip install -r requirements/albu.txt` or `pip install -U albumentations --no-binary qudida,albumentations`. If you simply use `pip install albumentations>=0.3.2`, it will install `opencv-python-headless` simultaneously (even if you have already installed `opencv-python`). Please refer to the official documentation for details.
- ModuleNotFoundError is raised when using some algorithms
  Some extra dependencies are required for InstaBoost, panoptic segmentation, the LVIS dataset, etc. Please note the error message and install the corresponding packages, e.g.,
  ```shell
  # for instaboost
  pip install instaboostfast
  # for panoptic segmentation
  pip install git+https://github.com/cocodataset/panopticapi.git
  # for LVIS dataset
  pip install git+https://github.com/lvis-dataset/lvis-api.git
  ```
- Do I need to reinstall mmdet after some code modifications?
  If you follow the best practice and install mmdet with `pip install -e .`, any local modifications made to the code will take effect without reinstallation.
- How to develop with multiple MMDetection versions
  You can have multiple folders like mmdet-3.0 and mmdet-3.1. When you run the train or test script, it will use the mmdet package in the current folder.
  To use the default MMDetection installed in the environment rather than the one you are working with, you can remove the following line from those scripts:
  ```shell
  PYTHONPATH="$(dirname $0)/..":$PYTHONPATH
  ```
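The folder-switching behaviour comes from Python's import search order: PYTHONPATH entries are added near the front of `sys.path`, ahead of site-packages, and the first matching entry wins. The sketch below reproduces that with two throwaway folders holding a toy package; `toy_mmdet`, `import_from_checkout` and `demo` are hypothetical names used only for illustration.

```python
import importlib
import os
import sys
import tempfile


def import_from_checkout(package_name, checkout_dir):
    """Import package_name, preferring the copy that lives in checkout_dir.

    Putting checkout_dir first on sys.path has a similar effect to
    PYTHONPATH="$(dirname $0)/..":$PYTHONPATH in the train/test scripts.
    """
    sys.path.insert(0, checkout_dir)
    try:
        # Forget a previously imported copy (top-level module only in this sketch)
        sys.modules.pop(package_name, None)
        importlib.invalidate_caches()
        return importlib.import_module(package_name)
    finally:
        sys.path.remove(checkout_dir)


def demo():
    """Create two temporary 'checkouts' of a toy package and import each one."""
    versions = []
    with tempfile.TemporaryDirectory() as old, tempfile.TemporaryDirectory() as new:
        for root, version in ((old, "3.0"), (new, "3.1")):
            pkg_dir = os.path.join(root, "toy_mmdet")
            os.makedirs(pkg_dir)
            with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
                f.write(f'__version__ = "{version}"\n')
        for root in (old, new):
            versions.append(import_from_checkout("toy_mmdet", root).__version__)
    sys.modules.pop("toy_mmdet", None)  # clean up after the demo
    return tuple(versions)
```

`demo()` imports the toy package from each folder in turn, so the version it sees depends only on which folder was put first on the search path.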
- "RTX 30 series card fails when building MMCV or MMDet"
  - Temporary work-around: run `MMCV_WITH_OPS=1 MMCV_CUDA_ARGS='-gencode=arch=compute_80,code=sm_80' pip install -e .`. The common issue is `nvcc fatal : Unsupported gpu architecture 'compute_86'`. This means that the compiler should optimize for sm_86, i.e., NVIDIA 30 series cards, but such optimizations are not supported by CUDA toolkit 11.0. This work-around modifies the compile flag by adding `MMCV_CUDA_ARGS='-gencode=arch=compute_80,code=sm_80'`, which tells `nvcc` to optimize for sm_80, i.e., NVIDIA A100. Although the A100 differs from the 30 series cards, they use a similar Ampere architecture. This may hurt performance, but it works.
  - PyTorch developers have indicated that the default compiler flags should be fixed by pytorch/pytorch#47585. So using PyTorch-nightly may also solve the problem, though we have not tested it yet.
- "invalid device function" or "no kernel image is available for execution".
  - Check whether your CUDA runtime version (under `/usr/local/`), `nvcc --version`, and `conda list cudatoolkit` versions match.
  - Run `python mmdet/utils/collect_env.py` to check whether PyTorch, torchvision, and MMCV are built for the correct GPU architecture. You may need to set `TORCH_CUDA_ARCH_LIST` to reinstall MMCV. The GPU arch table can be found here, e.g., run `TORCH_CUDA_ARCH_LIST=7.0 pip install mmcv` to build MMCV for Volta GPUs. The compatibility issue can happen when using old GPUs, e.g., Tesla K80 (3.7) on Colab.
  - Check whether the running environment is the same as the one in which mmcv/mmdet was compiled. For example, you may compile mmcv using CUDA 10.0 but run it in a CUDA 9.0 environment.
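The compute-capability numbers used above (7.0 for Volta, 3.7 for the K80, 8.0/8.6 for Ampere) map mechanically onto both the `TORCH_CUDA_ARCH_LIST` value and the nvcc `-gencode` flag. A small sketch, with hypothetical helper names, that turns `(major, minor)` tuples into those strings:

```python
def arch_list_entry(capability):
    """(7, 0) -> '7.0', the per-GPU format used in TORCH_CUDA_ARCH_LIST."""
    major, minor = capability
    return f"{major}.{minor}"


def gencode_flag(capability):
    """(8, 0) -> '-gencode=arch=compute_80,code=sm_80' for nvcc."""
    major, minor = capability
    sm = f"{major}{minor}"
    return f"-gencode=arch=compute_{sm},code=sm_{sm}"


def build_env(capabilities):
    """TORCH_CUDA_ARCH_LIST for a set of GPUs, de-duplicated and sorted."""
    entries = {arch_list_entry(c) for c in capabilities}
    ordered = sorted(entries, key=lambda s: tuple(int(p) for p in s.split(".")))
    return {"TORCH_CUDA_ARCH_LIST": ";".join(ordered)}
```

On a machine with PyTorch installed, `arch_list_entry(torch.cuda.get_device_capability())` gives the entry for the current GPU.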
- "undefined symbol" or "cannot open xxx.so".
  - If those symbols are CUDA/C++ symbols (e.g., libcudart.so or GLIBCXX), check whether the CUDA/GCC runtimes are the same as those used for compiling mmcv, i.e., run `python mmdet/utils/collect_env.py` to see if `"MMCV Compiler"`/`"MMCV CUDA Compiler"` is the same as `"GCC"`/`"CUDA_HOME"`.
  - If those symbols are PyTorch symbols (e.g., symbols containing caffe, aten, and TH), check whether the PyTorch version is the same as that used for compiling mmcv.
  - Run `python mmdet/utils/collect_env.py` to check whether PyTorch, torchvision, and MMCV are built by and running on the same environment.
- setuptools.sandbox.UnpickleableException: DistutilsSetupError("each element of 'ext_modules' option must be an Extension instance or 2-tuple")
  - If you are using miniconda rather than anaconda, check whether Cython is installed as indicated in #3379. You need to manually install Cython first and then run `pip install -r requirements.txt`.
  - You may also need to check the compatibility between `setuptools`, `Cython`, and `PyTorch` in your environment.
- "Segmentation fault".
  - Check your GCC version and use GCC 5.4. This is usually caused by an incompatibility between PyTorch and the environment (e.g., GCC < 4.9 for PyTorch). We also recommend that users avoid GCC 5.5, because many users have reported that GCC 5.5 causes "segmentation fault" and that simply switching to GCC 5.4 solves the problem.
  - Check whether PyTorch is correctly installed and can use CUDA ops, e.g., type the following command in your terminal and see whether it outputs the result correctly:
    ```shell
    python -c 'import torch; print(torch.cuda.is_available())'
    ```
  - If PyTorch is correctly installed, check whether MMCV is correctly installed:
    ```shell
    python -c 'import mmcv; import mmcv.ops'
    ```
    If MMCV is correctly installed, neither of the two commands above will report an error.
  - If MMCV and PyTorch are correctly installed, you can use `ipdb` or `pdb` to set breakpoints, or directly add `print` statements in the MMDetection code, to see which part leads to the segmentation fault.
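Because a segmentation fault kills the Python process outright, it can help to run the two checks above in separate child processes and read their exit codes. A minimal sketch using only the standard library; `run_probe` and `CHECKS` are hypothetical names, and on POSIX a negative return code is the number of the signal that killed the child (e.g., -11 for SIGSEGV).

```python
import subprocess
import sys

CHECKS = [
    ("torch + CUDA", "import torch; print(torch.cuda.is_available())"),
    ("mmcv.ops", "import mmcv; import mmcv.ops"),
]


def run_probe(code, timeout=120):
    """Run `code` in a fresh interpreter so a crash cannot kill this process."""
    try:
        proc = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return None, f"timed out after {timeout} s"
    if proc.returncode < 0:
        # POSIX: the child was terminated by a signal, e.g. -11 == SIGSEGV
        status = f"killed by signal {-proc.returncode}"
    elif proc.returncode != 0:
        last_line = proc.stderr.strip().splitlines()[-1:]
        status = "failed: " + (last_line[0] if last_line else "no stderr")
    else:
        status = "ok: " + proc.stdout.strip()
    return proc.returncode, status


def main():
    for name, code in CHECKS:
        _, status = run_probe(code)
        print(f"{name:>14}: {status}")


if __name__ == "__main__":
    main()
```

A "killed by signal 11" line for `mmcv.ops` while the PyTorch check passes points at the MMCV build rather than at PyTorch.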
- "Loss goes NaN"
  - Check if the dataset annotations are valid: zero-size bounding boxes will cause the regression loss to be NaN due to the commonly used transformation for box regression. Some small boxes (width or height smaller than 1) will also cause this problem after data augmentation (e.g., InstaBoost). So check the data, try to filter out those zero-size boxes, and skip risky augmentations on small boxes when you face the problem.
  - Reduce the learning rate: the learning rate might be too large for some reason, e.g., a change of batch size. You can rescale it to a value that trains the model stably.
  - Extend the warmup iterations: some models are sensitive to the learning rate at the start of training. You can extend the warmup iterations, e.g., change `warmup_iters` from 500 to 1000 or 2000.
  - Add gradient clipping: some models require gradient clipping to stabilize the training process. The default of `clip_grad` is `None`; you can add gradient clipping to avoid gradients that are too large, e.g., set `optim_wrapper=dict(clip_grad=dict(max_norm=35, norm_type=2))` in your config file.
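Two of these remedies can be illustrated in plain Python: why a zero-width box breaks the log-space width/height regression target used by common delta box coders, how small boxes can be filtered, and what clipping by global L2 norm (the `max_norm=35, norm_type=2` setting) does to a set of gradients. This is a simplified sketch on Python floats with hypothetical function names, not MMDetection's box coder or MMEngine's optimizer wrapper.

```python
import math


def encode_wh_deltas(anchor_wh, gt_wh):
    """Log-space width/height targets, as in common delta-style box coders.

    With gt width 0, math.log(0) raises here; with tensors the same target
    becomes -inf and the regression loss turns into NaN.
    """
    (aw, ah), (gw, gh) = anchor_wh, gt_wh
    return math.log(gw / aw), math.log(gh / ah)


def filter_small_boxes(boxes, min_size=1.0):
    """Keep [x1, y1, x2, y2] boxes whose width and height are both >= min_size."""
    return [b for b in boxes if b[2] - b[0] >= min_size and b[3] - b[1] >= min_size]


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients together so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / (total_norm + eps)
    return [g * scale for g in grads], total_norm
```

Clipping rescales every gradient by the same factor, so the update direction is kept while its size is capped.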
- "GPU out of memory"
  - In some scenarios there is a large number of ground-truth boxes, which may cause OOM during target assignment. You can set `gpu_assign_thr=N` in the assigner config so that the assigner computes box overlaps on the CPU when there are more than N GT boxes.
  - Set `with_cp=True` in the backbone. This uses the sublinear strategy (gradient checkpointing) in PyTorch to reduce GPU memory cost in the backbone.
  - Try mixed precision training following the examples in `config/fp16`. The `loss_scale` might need further tuning for different models.
  - Try `AvoidCUDAOOM` to avoid GPU out of memory. It first retries after calling `torch.cuda.empty_cache()`. If that still fails, it retries after converting the inputs to FP16. If that still fails, it copies the inputs from GPU to CPU and continues computing there. Use `AvoidCUDAOOM` in your code to keep it running when GPU memory runs out:
    ```python
    from mmdet.utils import AvoidCUDAOOM

    output = AvoidCUDAOOM.retry_if_cuda_oom(some_function)(input1, input2)
    ```
    You can also use `AvoidCUDAOOM` as a decorator:
    ```python
    from mmdet.utils import AvoidCUDAOOM

    @AvoidCUDAOOM.retry_if_cuda_oom
    def function(*args, **kwargs):
        ...
        return xxx
    ```
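The retry strategy described for `AvoidCUDAOOM` (try, apply a fallback, try again) can be sketched without a GPU. The example below is a simplified, hypothetical re-implementation of that pattern, not the code in `mmdet.utils`: a stand-in `FakeOutOfMemory` exception and a simulated workload let it run anywhere.

```python
import functools


class FakeOutOfMemory(RuntimeError):
    """Stand-in for a CUDA out-of-memory error so the sketch runs without a GPU."""


def retry_with_fallbacks(fallbacks):
    """Decorator: on FakeOutOfMemory, apply the next fallback and call again.

    Each fallback takes (args, kwargs) and returns new (args, kwargs). The
    utility described above would free the CUDA cache, cast inputs to FP16
    and then move them to the CPU; here the steps are plain functions.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            for step in (None, *fallbacks):  # None = first, unmodified attempt
                if step is not None:
                    args, kwargs = step(args, kwargs)
                try:
                    return func(*args, **kwargs)
                except FakeOutOfMemory as err:
                    last_error = err
            raise last_error  # every attempt failed

        return wrapper

    return decorator


def free_cache_step(args, kwargs):
    # A real implementation would call torch.cuda.empty_cache() here.
    return args, kwargs


def move_to_cpu_step(args, kwargs):
    return args, {**kwargs, "device": "cpu"}


@retry_with_fallbacks([free_cache_step, move_to_cpu_step])
def double_on_device(x, device="cuda"):
    """Toy workload that 'runs out of memory' unless it runs on the CPU."""
    if device != "cpu":
        raise FakeOutOfMemory("CUDA out of memory (simulated)")
    return 2 * x
```

Ordering the fallbacks from cheapest to most expensive means the slow CPU path is only taken when everything else has failed.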
- "RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one"
  - This error indicates that your module has parameters that were not used in producing the loss. This may happen when different branches of your code run in DDP mode.
  - You can set `find_unused_parameters = True` in the config to solve the above problem, but this will slow down training.
  - You can set `detect_anomalous_params = True` in the config or `model_wrapper_cfg = dict(type='MMDistributedDataParallel', detect_anomalous_params=True)` (for more details, please refer to MMEngine) to get the names of those unused parameters. Note that `detect_anomalous_params = True` will slow down training, so it is recommended for debugging only.
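To see what "unused parameters" means, picture the loss as a graph: a parameter takes part in an iteration only if it can be reached by walking backwards from the loss. The sketch below does that walk on a hypothetical toy graph (`Node`, `unused_parameters` are illustrative names). It mimics the idea behind such a check on plain Python objects and is not MMEngine's implementation, which works on autograd nodes.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can go in a set
class Node:
    """A value in a toy computation graph; leaves stand for parameters."""
    name: str
    inputs: list = field(default_factory=list)


def reachable_leaves(loss):
    """Names of leaf nodes reachable by walking backwards from the loss."""
    seen, stack, leaves = set(), [loss], set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if not node.inputs:
            leaves.add(node.name)
        stack.extend(node.inputs)
    return leaves


def unused_parameters(param_names, loss):
    """Parameters with no path to the loss in this iteration."""
    return sorted(set(param_names) - reachable_leaves(loss))


backbone_w = Node("backbone.weight")
head_a_w = Node("head_a.weight")
head_b_w = Node("head_b.weight")
features = Node("features", [backbone_w])
# Only branch A ran in this iteration, so head_b never reaches the loss.
loss = Node("loss", [Node("head_a_out", [features, head_a_w])])
ALL_PARAMS = [n.name for n in (backbone_w, head_a_w, head_b_w)]
```

Here `unused_parameters(ALL_PARAMS, loss)` reports `head_b.weight`, which is exactly the kind of parameter DDP complains about.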
- Save the best model
  It can be turned on by configuring `default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='auto'))`. With `save_best='auto'`, the first key in the returned evaluation results is used as the basis for selecting the best model. You can also set a specific key from the evaluation results, for example, `save_best='coco/bbox_mAP'`.
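The `'auto'` rule can be pictured as "take the first metric name, then keep the epoch where it was highest". A simplified sketch of that behaviour, assuming greater is better; `resolve_save_best` and `BestTracker` are hypothetical names, not `CheckpointHook` internals.

```python
def resolve_save_best(save_best, metrics):
    """Return the metric key used to rank checkpoints."""
    if save_best == "auto":
        return next(iter(metrics))  # first key of the evaluation results
    if save_best not in metrics:
        raise KeyError(f"{save_best!r} not in evaluation results: {list(metrics)}")
    return save_best


class BestTracker:
    """Remember the epoch with the highest value of one metric."""

    def __init__(self, save_best="auto"):
        self.save_best = save_best
        self.key = None
        self.best_value = float("-inf")
        self.best_epoch = None

    def update(self, epoch, metrics):
        """Return True when this epoch is the new best (a hook would save here)."""
        if self.key is None:
            self.key = resolve_save_best(self.save_best, metrics)
        value = metrics[self.key]
        if value > self.best_value:
            self.best_value, self.best_epoch = value, epoch
            return True
        return False
```

Because dicts keep insertion order, "first key" is well defined; it is whatever the evaluator happens to report first, which is why naming the key explicitly is the safer choice.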
- COCO dataset, AP or AR = -1
  - According to the definition of the COCO dataset, small objects have an area smaller than 1024 (32*32) and medium objects an area smaller than 9216 (96*96); larger objects count as large.
  - If there is no object in the corresponding area range, AP and AR for that range are set to -1.
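A sketch of the area buckets and the -1 convention, with hypothetical function names. It uses half-open ranges for simplicity; pycocotools itself treats the range boundaries as inclusive, so an area of exactly 1024 falls in both the small and the medium range there.

```python
SMALL_MAX = 32 * 32   # 1024
MEDIUM_MAX = 96 * 96  # 9216


def area_range(area):
    """Bucket an object area into COCO's small / medium / large ranges."""
    if area < SMALL_MAX:
        return "small"
    if area < MEDIUM_MAX:
        return "medium"
    return "large"


def mask_empty_ranges(gt_areas, metric_by_range):
    """Report -1 for every range that has no ground-truth object at all."""
    present = {area_range(a) for a in gt_areas}
    return {
        name: (value if name in present else -1)
        for name, value in metric_by_range.items()
    }
```

So a validation set with no medium-sized objects prints `-1` for the medium AP/AR rows; this signals "undefined", not a bad model.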
- `style` in ResNet
  The `style` parameter in ResNet allows either `pytorch` or `caffe` style. It indicates the difference in the Bottleneck module. Bottleneck is a stack of `1x1-3x3-1x1` convolutional layers. In `caffe` mode, the convolution layer with `stride=2` is the first `1x1` convolution, while in `pytorch` mode, the second `3x3` convolution has `stride=2`. Sample code is as below:
  ```python
  if self.style == 'pytorch':
      self.conv1_stride = 1
      self.conv2_stride = stride
  else:
      self.conv1_stride = stride
      self.conv2_stride = 1
  ```
- ResNeXt parameter description
  ResNeXt comes from the paper Aggregated Residual Transformations for Deep Neural Networks. It introduces grouped convolution and uses "cardinality" to control the number of groups, balancing accuracy and complexity. It controls the basic width and the grouping of the internal Bottleneck module through two hyperparameters, `baseWidth` and `cardinality`. An example configuration name in MMDetection is `mask_rcnn_x101_64x4d_fpn_mstrain-poly_3x_coco.py`, where `mask_rcnn` means the algorithm is Mask R-CNN, `x101` means the backbone network is ResNeXt-101, and `64x4d` means that the bottleneck block has 64 groups and each group has a basic width of 4.
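As an illustration of how the name maps to layer sizes, the sketch below parses the `x<depth>_<groups>x<width>d` tag (also accepting a `-` separator) and computes the channel count of the grouped 3x3 convolution with the ResNeXt paper's rule, where the per-group width grows with each stage's `planes` and the base is 64 channels. The helper names are hypothetical; treat this as a worked example of the convention, not as MMDetection's backbone code.

```python
import re

_BACKBONE_RE = re.compile(r"x(\d+)[_-](\d+)x(\d+)d")


def parse_resnext_name(config_name):
    """'..._x101_64x4d_...' -> (depth, groups, base_width)."""
    match = _BACKBONE_RE.search(config_name)
    if match is None:
        raise ValueError(f"no ResNeXt tag in {config_name!r}")
    depth, groups, base_width = (int(g) for g in match.groups())
    return depth, groups, base_width


def grouped_width(planes, groups, base_width, base_channels=64):
    """Channels of the grouped 3x3 conv: per-group width scaled by stage, times groups."""
    return planes * base_width // base_channels * groups
```

For `64x4d`, the first stage (`planes=64`) therefore has 64 groups of 4 channels, i.e. 256 channels, doubling at each later stage; `32x4d` gives 128 channels in the same stage.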
- `norm_eval` in backbone
  Since detection models are usually large and the input image resolution is high, the batch size per GPU is small. This makes the statistics computed by BatchNorm during training highly variable and much less stable than the statistics obtained when the backbone was pre-trained. Therefore, `norm_eval=True` is generally used during training, so that the BatchNorm statistics of the pre-trained backbone are used directly. The few algorithms that use large batches, such as NAS-FPN, use `norm_eval=False`. For backbones without ImageNet pre-training and with relatively small batches, you can consider using `SyncBN`.
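The batch-size argument can be checked numerically: the spread of a batch mean shrinks roughly as 1/sqrt(batch size), so with only one or two images per GPU the statistics jump around far more than with a large pre-training batch. The simulation below uses synthetic N(0, 1) values and treats each image as one independent sample; real BatchNorm also averages over spatial positions, so take it as intuition only. `batch_mean_spread` is a hypothetical name.

```python
import random
import statistics


def batch_mean_spread(batch_size, num_batches=2000, seed=0):
    """Std-dev of per-batch means of N(0, 1) samples (how noisy a batch mean is)."""
    rng = random.Random(seed)
    means = [
        statistics.fmean(rng.gauss(0.0, 1.0) for _ in range(batch_size))
        for _ in range(num_batches)
    ]
    return statistics.pstdev(means)
```

`batch_mean_spread(2)` comes out close to 1/sqrt(2), about 0.71, while `batch_mean_spread(32)` is close to 1/sqrt(32), about 0.18; that gap is the instability `norm_eval=True` sidesteps.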