Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mpelchat04 committed Nov 2, 2021
1 parent 190e2ae commit 74e7216
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions airborne_lidar/airborne_lidar_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,41 +423,42 @@ def main():
if len(dataset_dict[dataset]) == 0:
warnings.warn(f"{base_dir / dataset} is empty")

print(f"Las files per dataset:\n Trn: {len(dataset_dict['trn'])} \n Val: {len(dataset_dict['val'])} \n Tst: {len(dataset_dict['tst'])}")

info_class = class_mode(args['training']['mode'])
if args['test']['test_model'] is None:
print(f"Prepared files per dataset:\n Trn: {len(dataset_dict['trn'])} \n Val: {len(dataset_dict['val'])}")
# Train + Validate model
model_folder = train(args, dataset_dict, info_class)

else:
# Test only
print(f"Prepared files for test")
model_folder = Path(args['test']['test_model'])

# Test model
if args['test']['test']:

# Uses .hdfs files from the test folder to process.
if args['test']['test_tiles'] is None:
print(f" \n Tst: {len(dataset_dict['tst'])}")
for filename in dataset_dict['tst']:
test(args, filename, model_folder, args['global']['rootdir'], info_class, dataset_dict['tst'].index(filename))

# Uses list of .las files from a provided folder.
else:
root_folder = Path(args['test']['test_tiles'])
files = list(root_folder.glob('*.las'))

print(f" \n Tst: {len(dataset_dict['tst'])}")
csv_file = root_folder / Path('Classification_comparison.csv')
iou_csv = CSVWriter(csv_file)
iou_csv.write(('filename', 'overall_iou', 'per_class_iou'))
for filename in files:
xyzni, label, nb_pts, header = read_las_format(filename)
prep_filename = Path(f"{filename.parent / Path(filename.stem)}_prepared.hdfs")
write_features(prep_filename, xyzni=xyzni, labels=label)
print(list(files))
iou = test(args, prep_filename.stem, model_folder, prep_filename.parent, info_class, files.index(filename), header=header)
line = (filename.stem, f"{iou[0]:.3f}")
line += tuple(iou[1])
print(line)
iou_csv.write(line)
iou_csv.close()
if __name__ == '__main__':
Expand Down

0 comments on commit 74e7216

Please sign in to comment.