-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_lfd_dataset.m
76 lines (62 loc) · 2.14 KB
/
demo_lfd_dataset.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
% Script for a robot following the demonstration trajectories
%
% For every demo type of the selected dataset, the demonstrations are
% loaded, their mean trajectory is computed and plotted, and (optionally) a
% Franka Emika Panda arm is drawn along that mean trajectory. One PNG figure
% is saved per demo type.
%
% Author
% Sipu Ruan, 2022

% Start from a clean session
close all;
clear;
clc;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Tunable parameters
% ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
% Whether to load the robot model and draw it following the mean
% trajectory (true), or to plot only the demonstrations (false)
is_load_robot = true;

% Dataset name; its demo set info is looked up by load_dataset_param
dataset_name = 'panda_arm';
% dataset_name = 'lasa_handwriting/pose_data';
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Demo types of the dataset, and one data folder per demo type of the form
% "../data/<dataset_name>/<demo_type>/" (indexed by demo type below)
demo_type = load_dataset_param(dataset_name);
data_folder = strcat("../data/", dataset_name, "/", demo_type, "/");

% Robot model and IK solver; only needed for the optional robot overlay
if is_load_robot
    robot = loadrobot('frankaEmikaPanda');
    % Seed for the IK solver: the robot's home configuration
    initialguess = homeConfiguration(robot);
    % Uniform weights on the 6 pose-error components of the IK objective
    weights = ones(1, 6);
    ik = inverseKinematics('RigidBodyTree', robot);
end
% Demo trajectories
% For each demo type: load its *.json demonstrations, compute the mean
% trajectory (and covariance) from them, plot the demos and the mean path
% (optionally with the robot following the mean), and save the figure as
% "<demo type with '/' replaced by '_'>.png" in the current folder.

% Settings shared by every demo type (loop-invariant, so set once)
argin.n_step = 50;          % presumably the number of time steps per demo
                            % (g_mean.matrix is indexed up to it below)
argin.group_name = "PCG";   % group name forwarded to the parser and PDF fit
robot_stride = 20;          % draw the robot at every robot_stride-th mean pose

for i = 1:numel(demo_type)
    argin.data_folder = data_folder(i);

    % Load demos
    filenames = dir(strcat(argin.data_folder, "*.json"));
    if isempty(filenames)
        % NOTE(review): parse_demo_trajectory / get_pdf_from_demo presumably
        % cannot handle an empty demo set, so skip with a clear message.
        warning("No *.json demonstration files found in %s; skipping.", ...
            argin.data_folder);
        continue;
    end
    g_demo = parse_demo_trajectory(filenames, argin);

    % Compute trajectory distribution from demonstrations
    % (cov_t is not plotted; it is left in the workspace for inspection)
    [g_mean, cov_t] = get_pdf_from_demo(g_demo, argin.group_name);
    n_demo = numel(g_demo);

    % ---- PLOT: Demonstrations ----
    fig = figure;

    % (Optional) Robot following mean trajectory
    if is_load_robot
        % Warm-start each IK solve from the previous solution so consecutive
        % configurations stay close to each other; use a local copy so the
        % shared home-configuration seed is reset for every demo type.
        guess = initialguess;
        for j = 1:robot_stride:argin.n_step
            config = ik('panda_link8', g_mean.matrix(:,:,j), weights, guess);
            guess = config;
            show(robot, config);
            hold on; axis off;
        end
    end
    hold on; axis equal; axis off;

    % Demos
    for j = 1:n_demo
        plot3(g_demo{j}.pose(1,:), g_demo{j}.pose(2,:), g_demo{j}.pose(3,:))
    end

    % Prior mean poses: mean path (blue line), first pose (green circle),
    % last pose (red star)
    pose_mean = g_mean.pose;
    plot3(pose_mean(1,:), pose_mean(2,:), pose_mean(3,:), ...
        'b-', 'LineWidth', 3)
    plot3(pose_mean(1,1), pose_mean(2,1), pose_mean(3,1), 'go', ...
        'LineWidth', 1.5)
    plot3(pose_mean(1,end), pose_mean(2,end), pose_mean(3,end), 'r*', ...
        'LineWidth', 1.5)

    % Save figure. char() keeps the file name a char vector even if
    % demo_type{i} is a string scalar: [string, '.png'] would otherwise
    % build a 1x2 string array instead of a single file name.
    fig_name = strrep(demo_type{i}, '/', '_');
    saveas(fig, [char(fig_name), '.png'])
end