This is a PyTorch code for Driver Posture Classification task. We use the AUC Distracted Driver Dataset. The dataset was captured to develop the state-of-the-art in detection of distracted drivers. Here are some samples from the dataset:
The task is to classify an image to one of these pre-defined categories, namely "Drive Safe", "Talk Passenger", "Text Right", "Drink", and etc. We use a pretrained resnet34 model to achieve comparable performance with the orignal paper Real-time Distracted Driver Posture Classification. The classification accuracy is about 95%.
- python 3.5+
- pytorch 0.4
- visdom (optional)
-
Download the dataset and its training and testing splits (train.csv and test.csv). Put them in a directory together.
-
Clone the repository
git clone https://github.com/husencd/DriverPostureClassification.git
cd DriverPostureClassification
-
Download the resnet model pretrained on ImageNet from pytorch official model urls.
cd pretrained_models
sh download.sh
-
Now you can train/fine-tune the model
cd ..
python main.py [--model resnet] [--model_depth 34]
If you want to monitor the training process, use visdom
python -m visdom.server
- Our code is partially based on https://github.com/chenyuntc/pytorch-best-practice.