diff --git a/behavior_metrics/brains/brains_handler.py b/behavior_metrics/brains/brains_handler.py
index 7977ac2a..fbadd367 100755
--- a/behavior_metrics/brains/brains_handler.py
+++ b/behavior_metrics/brains/brains_handler.py
@@ -44,7 +44,13 @@ def load_brain(self, path, model=None):
         module = importlib.import_module(import_name)
         Brain = getattr(module, 'Brain')
         if robot_type == 'drone':
-            self.active_brain = Brain(handler=self, config=self.config)
+            # self.active_brain = Brain(handler=self, config=self.config)
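+            # Prefer an explicit model argument, then a model attribute already
+            # stored on the handler, and fall back to a Brain without a model.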
+            if model:
+                self.active_brain = Brain(model=model, handler=self, config=self.config)
+            elif hasattr(self, 'model'):
+                self.active_brain = Brain(model=self.model, handler=self, config=self.config)
+            else:
+                self.active_brain = Brain(handler=self, config=self.config)
         else:
             if model:
                 self.active_brain = Brain(self.sensors, self.actuatrors, model=model, handler=self, config=self.config)
diff --git a/behavior_metrics/brains/drone/brain_drone_explicit.py b/behavior_metrics/brains/drone/brain_drone_explicit.py
index 9718fe5c..8e665e8b 100644
--- a/behavior_metrics/brains/drone/brain_drone_explicit.py
+++ b/behavior_metrics/brains/drone/brain_drone_explicit.py
@@ -152,7 +152,7 @@ def execute(self):
         try:
             #cv2.imwrite(SAVE_DIR + 'many_curves_data/Images/image{}.png'.format(self.iteration), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
             #print('written many_curves_data/Images/image{}.png'.format(self.iteration))
-            image_cropped = image[230:, :, :]
+            image_cropped = image[120:240,0:320]
             image_hsv = cv2.cvtColor(image_cropped, cv2.COLOR_BGR2HSV)
             lower_red = np.array([0,50,50])
             upper_red = np.array([180,255,255])
diff --git a/behavior_metrics/brains/drone/brain_drone_torch.py b/behavior_metrics/brains/drone/brain_drone_torch.py
new file mode 100644
index 00000000..6af89f82
--- /dev/null
+++ b/behavior_metrics/brains/drone/brain_drone_torch.py
@@ -0,0 +1,167 @@
+"""
+    Robot: drone
+    Framework: torch
+    Number of networks: 1
+    Network type: None
+    Predictions:
+        linear speed (v)
+        angular speed (w)
+        z-velocity (vz)
+
+"""
+
+import torch
+import torchvision
+from torchvision import transforms
+import numpy as np
+import cv2
+import time
+import os
+from PIL import Image
+from brains.drone.torch_utils.deeppilot import DeepPilot
+from drone_wrapper import DroneWrapper
+from utils.constants import PRETRAINED_MODELS_DIR, ROOT_PATH
+from os import path
+from collections import deque
+
+PRETRAINED_MODELS = ROOT_PATH + '/' + PRETRAINED_MODELS_DIR + 'torch_drone_models/'
+FLOAT = torch.FloatTensor
+
+class Brain:
+    """Specific brain for the drone robot. See header."""
+
+    def __init__(self, model=None, handler=None, config=None):
+        """Constructor of the class.
+
+        Keyword Arguments:
+            model {str} -- Filename of the pretrained model to load (default: {None})
+            handler {brains.brain_handler.Brains} -- Handler of the current brain. Communication with the controller
+            (default: {None})
+            config {dict} -- Brain configuration from the YAML file (default: {None})
+        """
+        self.drone = DroneWrapper()
+        self.handler = handler
+        # self.drone.takeoff()
+        self.takeoff = False
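+        # Sliding windows used in execute() to smooth the raw network outputs
+        # into velocity commands: each command is the mean of its deque, and
+        # speedz_history is pre-filled with zeros so early vertical commands
+        # stay conservative.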
+        self.speed_history = deque([], maxlen=100)
+        self.speedz_history = deque([0]*100, maxlen=100)
+        self.rot_history = deque([], maxlen=1)
+
+        self.cont = 0
+        self.iteration = 0
+        self.inference_times = []
+        self.device = torch.device("cpu")
+        self.gpu_inferencing = torch.cuda.is_available()
+        self.first_image = None
+        self.transformations = transforms.Compose([
+            transforms.ToTensor()
+        ])
+
+        if config and 'ImageCrop' in config.keys():
+            self.cropImage = config['ImageCrop']
+        else:
+            self.cropImage = True
+
+        if model:
+            if not path.exists(PRETRAINED_MODELS + model):
+                print("File " + model + " cannot be found in " + PRETRAINED_MODELS)
+
+            self.net = DeepPilot((224, 224, 3), 3).to(self.device)
+            self.net.load_state_dict(torch.load(PRETRAINED_MODELS + model, map_location=self.device))
+        else:
+            print("Brain not loaded")
+
+    def update_frame(self, frame_id, data):
+        """Update the information to be shown in one of the GUI's frames.
+
+        Arguments:
+            frame_id {str} -- Id of the frame that will represent the data
+            data {*} -- Data to be shown in the frame. Depending on the type of frame (rgbimage, laser, pose3d, etc)
+        """
+        self.handler.update_frame(frame_id, data)
+
+    def getPose3d(self):
+        return self.drone.get_position()
+
+    def addPadding(self, img):
+        # Resize to a 224-pixel height keeping the aspect ratio, then pad
+        # left and right with black pixels up to a 224-pixel width.
+        target_height = int(224)
+        target_width = int(target_height * img.shape[1] / img.shape[0])
+        img_resized = cv2.resize(img, (target_width, target_height))
+        padding_left = int((224 - target_width) / 2)
+        padding_right = 224 - target_width - padding_left
+        img = cv2.copyMakeBorder(img_resized.copy(), 0, 0, padding_left, padding_right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+        return img
+
+    def execute(self):
+        """Main loop of the brain. This will be called iteratively each TIME_CYCLE (see pilot.py)"""
+
+        self.cont += 1
+
+        if self.iteration == 0 and not self.takeoff:
+            self.drone.takeoff()
+            self.takeoff = True
+            self.initial_flight_done = False
+
+        img_frontal = self.drone.get_frontal_image()
+        img_ventral = self.drone.get_ventral_image()
+
+        # Placeholder 3x3 frames mean the camera topics are not publishing yet.
+        if self.cont == 1 and (img_frontal.shape == (3, 3, 3) or img_ventral.shape == (3, 3, 3)):
+            time.sleep(3)
+            self.cont = 0
+        else:
+            self.first_image = img_frontal
+
+        self.update_frame('frame_0', img_frontal)
+        self.update_frame('frame_1', img_ventral)
+
+        image = img_frontal
+
+        try:
+            if self.cropImage:
+                image = image[120:240, 0:320]
+            else:
+                image = self.addPadding(image)
+            show_image = image
+            X = image.copy()
+            if X is not None:
+                X = cv2.resize(X, (224, 224))
+                img = Image.fromarray(X)
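+                # ToTensor() converts the HWC uint8 image to a CHW float tensor
+                # scaled to [0, 1]; unsqueeze(0) adds the batch dimension the
+                # network expects.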
+                start_time = time.time()
+                with torch.no_grad():
+                    image = self.transformations(img).unsqueeze(0)
+                    image = FLOAT(image).to(self.device)
+                    prediction = self.net(image).numpy()
+                self.inference_times.append(time.time() - start_time)
+
+                prediction_v = prediction[0][0]
+                prediction_w = prediction[0][1]
+                prediction_vz = prediction[0][2]
+                if prediction_v != '' and prediction_w != '' and prediction_vz != '':
+                    self.speed_history.append(prediction_v)
+                    self.speedz_history.append(prediction_w)
+                    self.rot_history.append(prediction_vz)
+
+                    speed_cmd = np.mean(self.speed_history)
+                    speed_z_cmd = np.clip(np.mean(self.speedz_history), -2, 2)
+                    rotation_cmd = np.mean(self.rot_history)
+
+                    self.drone.set_cmd_vel(speed_cmd, 0, speed_z_cmd, rotation_cmd)
+
+            self.iteration += 1
+
+        except Exception as err:
+            print(err)
+
+        self.update_frame('frame_0', show_image)
diff --git a/behavior_metrics/brains/drone/brain_drone_torch_pilotnet.py b/behavior_metrics/brains/drone/brain_drone_torch_pilotnet.py
new file mode 100644
index 00000000..2077497e
--- /dev/null
+++ b/behavior_metrics/brains/drone/brain_drone_torch_pilotnet.py
@@ -0,0 +1,167 @@
+"""
+    Robot: drone
+    Framework: torch
+    Number of networks: 1
+    Network type: None
+    Predictions:
+        linear speed (v)
+        angular speed (w)
+        z-velocity (vz)
+
+"""
+
+import torch
+import torchvision
+from torchvision import transforms
+import numpy as np
+import cv2
+import time
+import os
+from PIL import Image
+from brains.drone.torch_utils.pilotnet import PilotNet
+from drone_wrapper import DroneWrapper
+from utils.constants import PRETRAINED_MODELS_DIR, ROOT_PATH
+from os import path
+from collections import deque
+
+PRETRAINED_MODELS = ROOT_PATH + '/' + PRETRAINED_MODELS_DIR + 'torch_drone_models/'
+FLOAT = torch.FloatTensor
+
+class Brain:
+    """Specific brain for the drone robot. See header."""
+
+    def __init__(self, model=None, handler=None, config=None):
+        """Constructor of the class.
+
+        Keyword Arguments:
+            model {str} -- Filename of the pretrained model to load (default: {None})
+            handler {brains.brain_handler.Brains} -- Handler of the current brain. Communication with the controller
+            (default: {None})
+            config {dict} -- Brain configuration from the YAML file (default: {None})
+        """
+        self.drone = DroneWrapper()
+        self.handler = handler
+        # self.drone.takeoff()
+        self.takeoff = False
+        self.speed_history = deque([], maxlen=100)
+        self.speedz_history = deque([0]*100, maxlen=100)
+        self.rot_history = deque([], maxlen=1)
+
+        self.cont = 0
+        self.iteration = 0
+        self.inference_times = []
+        self.device = torch.device("cpu")
+        self.gpu_inferencing = torch.cuda.is_available()
+        self.first_image = None
+        self.transformations = transforms.Compose([
+            transforms.ToTensor()
+        ])
+
+        if config and 'ImageCrop' in config.keys():
+            self.cropImage = config['ImageCrop']
+        else:
+            self.cropImage = True
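+        # ImageCrop selects between slicing rows 120:240 out of the frontal
+        # image and aspect-preserving black padding via addPadding() before
+        # the resize to the network input size.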
+
+        if model:
+            if not path.exists(PRETRAINED_MODELS + model):
+                print("File " + model + " cannot be found in " + PRETRAINED_MODELS)
+
+            self.net = PilotNet((200, 66, 3), 3).to(self.device)
+            self.net.load_state_dict(torch.load(PRETRAINED_MODELS + model, map_location=self.device))
+        else:
+            print("Brain not loaded")
+
+    def update_frame(self, frame_id, data):
+        """Update the information to be shown in one of the GUI's frames.
+
+        Arguments:
+            frame_id {str} -- Id of the frame that will represent the data
+            data {*} -- Data to be shown in the frame. Depending on the type of frame (rgbimage, laser, pose3d, etc)
+        """
+        self.handler.update_frame(frame_id, data)
+
+    def getPose3d(self):
+        return self.drone.get_position()
+
+    def addPadding(self, img):
+        # Resize to a 66-pixel height keeping the aspect ratio, then pad
+        # left and right with black pixels up to a 200-pixel width.
+        target_height = int(66)
+        target_width = int(target_height * img.shape[1] / img.shape[0])
+        img_resized = cv2.resize(img, (target_width, target_height))
+        padding_left = int((200 - target_width) / 2)
+        padding_right = 200 - target_width - padding_left
+        img = cv2.copyMakeBorder(img_resized.copy(), 0, 0, padding_left, padding_right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+        return img
+
+    def execute(self):
+        """Main loop of the brain. This will be called iteratively each TIME_CYCLE (see pilot.py)"""
+
+        self.cont += 1
+
+        if self.iteration == 0 and not self.takeoff:
+            self.drone.takeoff()
+            self.takeoff = True
+            self.initial_flight_done = False
+
+        img_frontal = self.drone.get_frontal_image()
+        img_ventral = self.drone.get_ventral_image()
+
+        # Placeholder 3x3 frames mean the camera topics are not publishing yet.
+        if self.cont == 1 and (img_frontal.shape == (3, 3, 3) or img_ventral.shape == (3, 3, 3)):
+            time.sleep(3)
+            self.cont = 0
+        else:
+            self.first_image = img_frontal
+
+        self.update_frame('frame_0', img_frontal)
+        self.update_frame('frame_1', img_ventral)
+
+        image = img_frontal
+
+        try:
+            if self.cropImage:
+                image = image[120:240, 0:320]
+            else:
+                image = self.addPadding(image)
+            show_image = image
+            X = image.copy()
+            if X is not None:
+                X = cv2.resize(X, (200, 66))
+                img = Image.fromarray(X)
+                start_time = time.time()
+                with torch.no_grad():
+                    image = self.transformations(img).unsqueeze(0)
+                    image = FLOAT(image).to(self.device)
+                    prediction = self.net(image).numpy()
+                self.inference_times.append(time.time() - start_time)
+
+                prediction_v = prediction[0][0]
+                prediction_w = prediction[0][1]
+                prediction_vz = prediction[0][2]
+                if prediction_v != '' and prediction_w != '' and prediction_vz != '':
+                    self.speed_history.append(prediction_v)
+                    self.speedz_history.append(prediction_w)
+                    self.rot_history.append(prediction_vz)
+
+                    speed_cmd = np.mean(self.speed_history)
+                    speed_z_cmd = np.clip(np.mean(self.speedz_history), -2, 2)
+                    rotation_cmd = np.mean(self.rot_history)
+
+                    self.drone.set_cmd_vel(speed_cmd, 0, speed_z_cmd, rotation_cmd)
+
+            self.iteration += 1
+
+        except Exception as err:
+            print(err)
+
+        self.update_frame('frame_0', show_image)
diff --git a/behavior_metrics/brains/drone/torch_utils/deeppilot.py b/behavior_metrics/brains/drone/torch_utils/deeppilot.py
new file mode 100644
index 00000000..a4e778f9
--- /dev/null
+++ b/behavior_metrics/brains/drone/torch_utils/deeppilot.py
@@ -0,0 +1,160 @@
+import torch
+import torch.nn as nn
+
+
+class DeepPilot(nn.Module):
+    def __init__(self, image_shape, num_labels):
+        super(DeepPilot, self).__init__()
+
+        self.img_height = image_shape[0]
+        self.img_width = image_shape[1]
+        self.num_channels = image_shape[2]
+
+        # Stem: 7x7 strided conv, max pooling, batch norm, then 1x1 and 3x3 convs.
+        self.cn_1 = nn.Conv2d(self.num_channels, 64, kernel_size=(7, 7), stride=(2, 2))
+        self.po_1 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
+        self.ln_1 = nn.BatchNorm2d(64)
+        self.re_1 = nn.Conv2d(64, 64, kernel_size=(1, 1))
+
+        self.cn_2 = nn.Conv2d(64, 192, kernel_size=(3, 3))
+        self.ln_2 = nn.BatchNorm2d(192)
+        self.po_2 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
+
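+        # Three GoogLeNet-style inception blocks follow. In each block a
+        # "reduce" conv feeds a larger 3x3 or 5x5 conv, a max-pooled branch
+        # gets a 1x1 projection, and a direct 1x1 branch passes through;
+        # forward() concatenates the four branches along the channel dimension.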
+        # Inception block 1.
+        self.im_1_re_1 = nn.Conv2d(192, 96, kernel_size=(1, 1))
+        self.im_1_o_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), padding=(1, 1))
+        self.im_1_re_2 = nn.Conv2d(192, 16, kernel_size=(3, 3))
+        self.im_1_o_2 = nn.Conv2d(16, 32, kernel_size=(5, 5), padding=(3, 3))
+        self.im_1_re_3 = nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1))
+        self.im_1_o_3 = nn.Conv2d(192, 32, kernel_size=(1, 1), padding=(1, 1))
+        self.im_1_o_0 = nn.Conv2d(192, 64, kernel_size=(1, 1))
+
+        # Inception block 2 (input channels: 128 + 32 + 32 + 64).
+        self.im_2_re_1 = nn.Conv2d(128+32+32+64, 128, kernel_size=(1, 1))
+        self.im_2_o_1 = nn.Conv2d(128, 192, kernel_size=(3, 3), padding=(1, 1))
+        self.im_2_re_2 = nn.Conv2d(128+32+32+64, 32, kernel_size=(1, 1))
+        self.im_2_o_2 = nn.Conv2d(32, 96, kernel_size=(5, 5), padding=(2, 2))
+        self.im_2_re_3 = nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1))
+        self.im_2_o_3 = nn.Conv2d(128+32+32+64, 64, kernel_size=(1, 1), padding=(1, 1))
+        self.im_2_o_0 = nn.Conv2d(128+32+32+64, 128, kernel_size=(1, 1))
+
+        self.im_3_in = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
+
+        # Inception block 3 (input channels: 192 + 96 + 64 + 128).
+        self.im_3_re_1 = nn.Conv2d(192+96+64+128, 96, kernel_size=(1, 1))
+        self.im_3_o_1 = nn.Conv2d(96, 208, kernel_size=(3, 3), padding=(1, 1))
+        self.im_3_re_2 = nn.Conv2d(192+96+64+128, 16, kernel_size=(1, 1))
+        self.im_3_o_2 = nn.Conv2d(16, 48, kernel_size=(5, 5), padding=(2, 2))
+        self.im_3_re_3 = nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1))
+        self.im_3_o_3 = nn.Conv2d(192+96+64+128, 64, kernel_size=(1, 1), padding=(1, 1))
+        self.im_3_o_0 = nn.Conv2d(192+96+64+128, 192, kernel_size=(1, 1))
+
+        self.last_po_1 = nn.AvgPool2d(kernel_size=(5, 5), stride=(3, 3))
+        self.last_re_1 = nn.Conv2d(208+48+64+192, 128, kernel_size=(1, 1))
+
+        # Three independent two-layer heads, named after DeepPilot's original
+        # roll/pitch/yaw outputs; the altitude head is left commented out.
+        self.fc_r1 = nn.Linear(128 * 3 * 3, 1024)
+        self.fc_p1 = nn.Linear(128 * 3 * 3, 1024)
+        self.fc_y1 = nn.Linear(128 * 3 * 3, 1024)
+        # self.fc_a1 = nn.Linear(128 * 3 * 3, 1024)
+
+        self.fc_r2 = nn.Linear(1024, 1)
+        self.fc_p2 = nn.Linear(1024, 1)
+        self.fc_y2 = nn.Linear(1024, 1)
+        # self.fc_a2 = nn.Linear(1024, 1)
+
+    def forward(self, img):
+
+        inp = self.cn_1(img)
+        inp = torch.relu(inp)
+        inp = self.po_1(inp)
+        inp = torch.relu(inp)
+        inp = self.ln_1(inp)
+        inp = torch.relu(inp)
+        inp = self.re_1(inp)
+        inp = torch.relu(inp)
+        inp = self.cn_2(inp)
+        inp = torch.relu(inp)
+        inp = self.ln_2(inp)
+        inp = torch.relu(inp)
+        inp = self.po_2(inp)
+        inp = torch.relu(inp)
+
+        icp1_out1 = self.im_1_re_1(inp)
+        icp1_out1 = torch.relu(icp1_out1)
+        icp1_out1 = self.im_1_o_1(icp1_out1)
+        icp1_out1 = torch.relu(icp1_out1)
+
+        icp1_out2 = self.im_1_re_2(inp)
+        icp1_out2 = torch.relu(icp1_out2)
+        icp1_out2 = self.im_1_o_2(icp1_out2)
+        icp1_out2 = torch.relu(icp1_out2)
+
+        icp1_out3 = self.im_1_re_3(inp)
+        icp1_out3 = self.im_1_o_3(icp1_out3)
+        icp1_out3 = torch.relu(icp1_out3)
+
+        icp1_out0 = self.im_1_o_0(inp)
+
+        icp1_out = torch.cat((icp1_out0, icp1_out1, icp1_out2, icp1_out3), dim=1)
+
+        icp2_out1 = self.im_2_re_1(icp1_out)
+        icp2_out1 = torch.relu(icp2_out1)
+        icp2_out1 = self.im_2_o_1(icp2_out1)
+        icp2_out1 = torch.relu(icp2_out1)
+
+        icp2_out2 = self.im_2_re_2(icp1_out)
+        icp2_out2 = torch.relu(icp2_out2)
+        icp2_out2 = self.im_2_o_2(icp2_out2)
+        icp2_out2 = torch.relu(icp2_out2)
+
+        icp2_out3 = self.im_2_re_3(icp1_out)
+        icp2_out3 = self.im_2_o_3(icp2_out3)
+        icp2_out3 = torch.relu(icp2_out3)
+
+        icp2_out0 = self.im_2_o_0(icp1_out)
+
+        icp2_out = torch.cat((icp2_out0, icp2_out1, icp2_out2, icp2_out3), dim=1)
+
+        icp3_in = self.im_3_in(icp2_out)
+
+        icp3_out1 = self.im_3_re_1(icp3_in)
+        icp3_out1 = torch.relu(icp3_out1)
+        icp3_out1 = self.im_3_o_1(icp3_out1)
+        icp3_out1 = torch.relu(icp3_out1)
+
+        icp3_out2 = self.im_3_re_2(icp3_in)
+        icp3_out2 = torch.relu(icp3_out2)
+        icp3_out2 = self.im_3_o_2(icp3_out2)
+        icp3_out2 = torch.relu(icp3_out2)
+
+        icp3_out3 = self.im_3_re_3(icp3_in)
+        icp3_out3 = self.im_3_o_3(icp3_out3)
+        icp3_out3 = torch.relu(icp3_out3)
+
+        icp3_out0 = self.im_3_o_0(icp3_in)
+
+        icp3_out = torch.cat((icp3_out0, icp3_out1, icp3_out2, icp3_out3), dim=1)
+
+        out = self.last_po_1(icp3_out)
+        out = self.last_re_1(out)
+        out = torch.relu(out)
+
+        out = out.reshape(out.size(0), -1)
+
+        rout = self.fc_r1(out)
+        rout = torch.relu(rout)
+        rout = self.fc_r2(rout)
+
+        pout = self.fc_p1(out)
+        pout = torch.relu(pout)
+        pout = self.fc_p2(pout)
+
+        yout = self.fc_y1(out)
+        yout = torch.relu(yout)
+        yout = self.fc_y2(yout)
+
+        # aout = self.fc_a1(out)
+        # aout = torch.relu(aout)
+        # aout = self.fc_a2(aout)
+
+        # out_final = torch.cat((rout, pout, yout, aout), dim=1)
+        out_final = torch.cat((rout, pout, yout), dim=1)
+
+        return out_final
\ No newline at end of file
diff --git a/behavior_metrics/brains/drone/torch_utils/pilotnet.py b/behavior_metrics/brains/drone/torch_utils/pilotnet.py
new file mode 100644
index 00000000..b61658ff
--- /dev/null
+++ b/behavior_metrics/brains/drone/torch_utils/pilotnet.py
@@ -0,0 +1,59 @@
+
+import torch
+import torch.nn as nn
+
+
+class PilotNet(nn.Module):
+    def __init__(self, image_shape, num_labels):
+        super(PilotNet, self).__init__()
+
+        self.img_height = image_shape[0]
+        self.img_width = image_shape[1]
+        self.num_channels = image_shape[2]
+
+        self.output_size = num_labels
+
+        self.ln_1 = nn.BatchNorm2d(self.num_channels, eps=1e-03)
+
+        self.cn_1 = nn.Conv2d(self.num_channels, 24, kernel_size=5, stride=2)
+        self.cn_2 = nn.Conv2d(24, 36, kernel_size=5, stride=2)
+        self.cn_3 = nn.Conv2d(36, 48, kernel_size=5, stride=2)
+        self.cn_4 = nn.Conv2d(48, 64, kernel_size=3, stride=1)
+        self.cn_5 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
+
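+        # With the 66x200 inputs used by brain_drone_torch_pilotnet, the five
+        # conv layers reduce the image to a 1x18 feature map with 64 channels,
+        # which is why fc_1 expects 1 * 18 * 64 input features.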
+        self.fc_1 = nn.Linear(1 * 18 * 64, 1164)
+        self.fc_2 = nn.Linear(1164, 100)
+        self.fc_3 = nn.Linear(100, 50)
+        self.fc_4 = nn.Linear(50, 10)
+        self.fc_5 = nn.Linear(10, self.output_size)
+
+    def forward(self, img):
+
+        out = self.ln_1(img)
+
+        out = self.cn_1(out)
+        out = torch.relu(out)
+        out = self.cn_2(out)
+        out = torch.relu(out)
+        out = self.cn_3(out)
+        out = torch.relu(out)
+        out = self.cn_4(out)
+        out = torch.relu(out)
+        out = self.cn_5(out)
+        out = torch.relu(out)
+
+        out = out.reshape(out.size(0), -1)
+
+        out = self.fc_1(out)
+        out = torch.relu(out)
+        out = self.fc_2(out)
+        out = torch.relu(out)
+        out = self.fc_3(out)
+        out = torch.relu(out)
+        out = self.fc_4(out)
+        out = torch.relu(out)
+        out = self.fc_5(out)
+
+        return out
diff --git a/behavior_metrics/configs/drone-torch.yml b/behavior_metrics/configs/drone-torch.yml
new file mode 100644
index 00000000..89b93927
--- /dev/null
+++ b/behavior_metrics/configs/drone-torch.yml
@@ -0,0 +1,53 @@
+Behaviors:
+    Robot:
+        Sensors:
+            Cameras:
+                Camera_0:
+                    Name: 'camera_0'
+                    Topic: '/F1ROS/cameraL/image_raw'
+            Pose3D:
+                Pose3D_0:
+                    Name: 'pose3d_0'
+                    Topic: '/F1ROS/odom'
+        Actuators:
+            Motors:
+                Motors_0:
+                    Name: 'motors_0'
+                    Topic: '/F1ROS/cmd_vel'
+                    MaxV: 3
+                    MaxW: 0.3
+        BrainPath: 'brains/drone/brain_drone_torch.py'
+        Parameters:
+            Model: 'model_deeppilot_torch.ckpt'
+            ImageCrop: True
+        Type: 'drone'
+    Experiment:
+        Name: "Experiment name"
+        Description: "Experiment description"
+        Timeout: 30
+        Repetitions: 2
+    Simulation:
+        World: /opt/jderobot/share/jderobot/gazebo/launch/simple_circuit_drone.launch
+    Dataset:
+        In: '/tmp/my_bag.bag'
+        Out: ''
+    Stats:
+        Out: './'
+        PerfectLap: './perfect_bags/lap-simple-circuit.bag'
+    Layout:
+        Frame_0:
+            Name: frame_0
+            Geometry: [1, 1, 2, 2]
+            Data: rgbimage
+        Frame_1:
+            Name: frame_1
+            Geometry: [0, 1, 1, 1]
+            Data: rgbimage
+        Frame_2:
+            Name: frame_2
+            Geometry: [0, 2, 1, 1]
+            Data: rgbimage
+        Frame_3:
+            Name: frame_3
+            Geometry: [0, 3, 3, 1]
+            Data: rgbimage
diff --git a/behavior_metrics/models/torch_drone_models/model_deeppilot_torch.ckpt b/behavior_metrics/models/torch_drone_models/model_deeppilot_torch.ckpt
new file mode 100644
index 00000000..e656e6eb
Binary files /dev/null and b/behavior_metrics/models/torch_drone_models/model_deeppilot_torch.ckpt differ
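
# A minimal launch sketch, assuming the repository's standard driver.py entry
# point and that the checkpoint above sits in models/torch_drone_models/:
#
#   python3 driver.py -c configs/drone-torch.yml -g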