Skip to content

Commit

Permalink
SCAN THE CODE: refactor classifier using base class main function
Browse files Browse the repository at this point in the history
  • Loading branch information
kev-the-dev committed Apr 28, 2018
1 parent e03f75a commit 1b2accd
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,16 @@ class ScanTheCodeClassifier(GaussianColorClassifier):
def __init__(self):
rospack = RosPack()
path = rospack.get_path('navigator_vision')
self.features_file = os.path.join(path, 'config/stc_colors.csv')
super(ScanTheCodeClassifier, self).__init__(ScanTheCodeClassifier.CLASSES)

def train_from_csv(self):
return super(ScanTheCodeClassifier, self).train_from_csv(self.features_file)

def save_csv(self, features, classes):
return super(ScanTheCodeClassifier, self).save_csv(features, classes, filename=self.features_file)
training_file = os.path.join(path, 'config/stc/training.csv')
labelfile = os.path.join(path, 'config/stc/labels.json')
super(ScanTheCodeClassifier, self).__init__(ScanTheCodeClassifier.CLASSES,
training_file=training_file, labelfile=labelfile)


if __name__ == '__main__':
'''
When run as an executable, saves the training features to a csv file
2 arguemnts: labelbox.io labelfile, and image directory
Can be run as executable to extract features or check accuracy score
'''
import sys
labelfile = sys.argv[1]
image_dir = sys.argv[2]
s = ScanTheCodeClassifier()
features, classes = s.extract_labels(labelfile, image_dir)
s.save_csv(features, classes)
s.main(sys.argv[1:])
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python
from mil_vision_tools import GaussianColorClassifier
from rospkg import RosPack
import os


class TotemsColorClassifier(GaussianColorClassifier):
CLASSES = ['{}_totem'.format(color) for color in ['white', 'red', 'green', 'blue', 'yellow']]

def __init__(self):
rospack = RosPack()
path = rospack.get_path('navigator_vision')
training_file = os.path.join(path, 'config/totems_color/training.csv')
labelfile = os.path.join(path, 'config/totems_color/labels.json')
super(TotemsColorClassifier, self).__init__(self.CLASSES,
training_file=training_file, labelfile=labelfile)


if __name__ == '__main__':
'''
Can be run as executable to extract features or check accuracy score
'''
import sys
c = TotemsColorClassifier()
c.main(sys.argv[1:])

0 comments on commit 1b2accd

Please sign in to comment.