diff --git a/examples/pytorch/eeg-gcnn/EEGGraphDataset.py b/examples/pytorch/eeg-gcnn/EEGGraphDataset.py index 8af8f6c59dc7..affda325f878 100644 --- a/examples/pytorch/eeg-gcnn/EEGGraphDataset.py +++ b/examples/pytorch/eeg-gcnn/EEGGraphDataset.py @@ -84,17 +84,22 @@ def get_sensor_distances(self): def get_geodesic_distance( self, montage_sensor1_idx, montage_sensor2_idx, coords_1010 ): + def get_coord(ref_sensor, coord): + return float( + (coords_1010[coords_1010.label == ref_sensor][coord]).iloc[0] + ) + # get the reference sensor in the 10-10 system for the current montage pair in 10-20 system ref_sensor1 = self.ref_names[montage_sensor1_idx] ref_sensor2 = self.ref_names[montage_sensor2_idx] - x1 = float(coords_1010[coords_1010.label == ref_sensor1]["x"]) - y1 = float(coords_1010[coords_1010.label == ref_sensor1]["y"]) - z1 = float(coords_1010[coords_1010.label == ref_sensor1]["z"]) + x1 = get_coord(ref_sensor1, "x") + y1 = get_coord(ref_sensor1, "y") + z1 = get_coord(ref_sensor1, "z") - x2 = float(coords_1010[coords_1010.label == ref_sensor2]["x"]) - y2 = float(coords_1010[coords_1010.label == ref_sensor2]["y"]) - z2 = float(coords_1010[coords_1010.label == ref_sensor2]["z"]) + x2 = get_coord(ref_sensor2, "x") + y2 = get_coord(ref_sensor2, "y") + z2 = get_coord(ref_sensor2, "z") # https://math.stackexchange.com/questions/1304169/distance-between-two-points-on-a-sphere r = 1 # since coords are on unit sphere