Skip to content

Commit

Permalink
SCAN THE CODE: refactor classifier using base class main function
Browse files Browse the repository at this point in the history
  • Loading branch information
kev-the-dev committed Apr 28, 2018
1 parent e03f75a commit a806d0a
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,16 @@ class ScanTheCodeClassifier(GaussianColorClassifier):
def __init__(self):
rospack = RosPack()
path = rospack.get_path('navigator_vision')
self.features_file = os.path.join(path, 'config/stc_colors.csv')
super(ScanTheCodeClassifier, self).__init__(ScanTheCodeClassifier.CLASSES)

def train_from_csv(self):
return super(ScanTheCodeClassifier, self).train_from_csv(self.features_file)

def save_csv(self, features, classes):
return super(ScanTheCodeClassifier, self).save_csv(features, classes, filename=self.features_file)
training_file = os.path.join(path, 'config/stc/training.csv')
labelfile = os.path.join(path, 'config/stc/labels.json')
super(ScanTheCodeClassifier, self).__init__(ScanTheCodeClassifier.CLASSES,
training_file=training_file, labelfile=labelfile)


if __name__ == '__main__':
'''
When run as an executable, saves the training features to a csv file
2 arguemnts: labelbox.io labelfile, and image directory
Can be run as executable to extract features or check accuracy score
'''
import sys
labelfile = sys.argv[1]
image_dir = sys.argv[2]
s = ScanTheCodeClassifier()
features, classes = s.extract_labels(labelfile, image_dir)
s.save_csv(features, classes)
s.main(sys.argv[1:])

0 comments on commit a806d0a

Please sign in to comment.