diff --git a/Lecture_TUKR/tanaka/data_scratch_tanaka.py b/Lecture_TUKR/tanaka/data_scratch_tanaka.py new file mode 100644 index 00000000..98db63ca --- /dev/null +++ b/Lecture_TUKR/tanaka/data_scratch_tanaka.py @@ -0,0 +1,38 @@ +import numpy as np +import matplotlib.pyplot as plt + +def load_kura_tsom(xsamples, ysamples, missing_rate=None,retz=False): + z1 = np.linspace(-1,1,xsamples) + z2 = np.linspace(-1,1,ysamples) + + z1_repeated, z2_repeated = np.meshgrid(z1,z2) + x1 = z1_repeated + x2 = z2_repeated + x3 = (x1**2-x2**2) + #ノイズを加えたい時はここをいじる,locがガウス分布の平均、scaleが分散,size何個ノイズを作るか + #このノイズを加えることによって三次元空間のデータ点は上下に動く + + x = np.concatenate((x1[:, :, np.newaxis], x2[:, :, np.newaxis], x3[:, :, np.newaxis]), axis=2) + truez = np.concatenate((z1_repeated[:, :, np.newaxis], z2_repeated[:, :, np.newaxis]), axis=2) + print(x.shape) + + if missing_rate == 0 or missing_rate == None: + if retz: + return x, truez + else: + return x + +if __name__ == '__main__': + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + + xsamples = 10 + ysamples = 15 + + x, truez = load_kura_tsom(xsamples,ysamples,retz=True) + + fig = plt.figure(figsize=[5, 5]) + ax_x = fig.add_subplot(projection='3d') + ax_x.scatter(x[:, :, 0].flatten(), x[:, :, 1].flatten(), x[:, :, 2].flatten(), c=x[:, :, 0].flatten()) + ax_x.set_title('Generated three-dimensional data') + plt.show() diff --git a/Lecture_TUKR/tanaka/tukr.py b/Lecture_TUKR/tanaka/tukr.py new file mode 100644 index 00000000..5cfca776 --- /dev/null +++ b/Lecture_TUKR/tanaka/tukr.py @@ -0,0 +1,140 @@ +import numpy as np +import jax,jaxlib +import jax.numpy as jnp +import tensorflow as tf +from tqdm import tqdm #プログレスバーを表示させてくれる + + +class TUKR: + def __init__(self, X, latent_dim1, latent_dim2, sigma, prior='random', Uinit=None, Vinit=None): + #--------初期値を設定する.--------- + self.X = X + #ここから下は書き換えてね + self.nb_samples1,self.nb_samples2, self.ob_dim = self.X.shape + self.sigma = sigma + self.latent_dim1 = latent_dim1 + self.latent_dim2 = latent_dim2 + self.alpha = alpha + self.norm = norm + + if Uinit is None: + if prior == 'random': #一様事前分布のとき + self.U = np.random.normal(0, 0.1 * self.sigma, size=(self.nb_samples1, self.latent_dim1)) + #(平均,標準偏差,配列のサイズ) + # else: #ガウス事前分布のとき + # U = + else: #Zの初期値が与えられた時 + self.U = Uinit + + self.history = {} + + if Vinit is None: + if prior == 'random': # 一様事前分布のとき + self.V = np.random.normal(0, 0.1 * self.sigma, size=(self.nb_samples2, self.latent_dim2)) + # (平均,標準偏差,配列のサイズ) + # else: #ガウス事前分布のとき + # V = + else: # Zの初期値が与えられた時 + self.V = Vinit + + self.history = {} + + def f(self, U, V): #写像の計算 + DistU = jnp.sum((U[:, None, :] - U[None, :, :])**2, axis=2) + DistV = jnp.sum((V[:, None, :] - V[None, :, :]) ** 2, axis=2) + HU = jnp.exp((-1 * DistU) / (2 * (self.sigma) **2)) + HV = jnp.exp((-1 * DistV) / (2 * (self.sigma) ** 2)) + GU = jnp.sum(HU, axis=1)[:, None] + GV = jnp.sum(HV, axis=1)[:, None] + RU = HU / GU + RV = HV / GV + f1 = RU @ self.X + f2 = RV @ self.X + return (f1,f2) + + #def E(self,Z,X,alpha=1,norm=2): + def E(self,Z,X,alpha,norm):#目的関数の計算 + Y = self.f(Z,Z) + + e = jnp.sum((X - Y) ** 2) + r = alpha*jnp.sum(Z**norm) + e = e/self.nb_samples + r = r/self.nb_samples + return e + r + + def fit(self, nb_epoch: int, eta: float): + # 学習過程記録用 + self.history['z'] = np.zeros((nb_epoch, self.nb_samples, self.latent_dim)) + self.history['f'] = np.zeros((nb_epoch, self.nb_samples, self.ob_dim)) + self.history['error'] = np.zeros(nb_epoch) + + for epoch in tqdm(range(nb_epoch)): + + dEdx = jax.grad(self.E,argnums=0)(self.Z,self.X,self.alpha,self.norm) + self.Z = self.Z - eta * dEdx + + # Zの更新 + + + + # 学習過程記録用 + self.history['z'][epoch] = self.Z + self.history['f'][epoch] = self.f(self.Z,self.Z) + self.history['error'][epoch] = self.E(self.Z,self.X,self.alpha,self.norm) + + #--------------以下描画用(上の部分が実装できたら実装してね)--------------------- + def calc_approximate_f(self, resolution): #fのメッシュ描画用,resolution:一辺の代表点の数 + nb_epoch = self.history['z'].shape[0] + self.history['y'] = np.zeros((nb_epoch, resolution ** self.latent_dim, self.ob_dim)) + for epoch in tqdm(range(nb_epoch)): + + y = self.f(self.create_zeta(self.history['z'][epoch],resolution),self.Z) + self.history['y'][epoch] = y + + return self.history['y'] + def create_zeta(self, Z, resolution): #fのメッシュの描画用に潜在空間に代表点zetaを作る. + a = np.linspace(np.min(Z), np.max(Z), resolution) + b = np.linspace(np.min(Z), np.max(Z), resolution) + A,B = np.meshgrid(a,b) + aa = A.reshape(-1) + bb = B.reshape(-1) + zeta = np.concatenate([aa[:,None],bb[:,None]],axis=1) + + return zeta + + +if __name__ == '__main__': + from Lecture_TUKR.tanaka.data_scratch_tanaka import load_kura_tsom + # from Lecture_TUKR.tanaka.data_scratch_tanaka import create_rasen + # from Lecture_TUKR.tanaka.data_scratch_tanaka import create_2d_sin_curve + from visualizer import visualize_history + + #各種パラメータ変えて遊んでみてね. + epoch = 200 #学習回数 + sigma = 0.2 #カーネルの幅 + eta = 1 #学習率 + latent_dim = 2 #潜在空間の次元 + alpha = 0.1 + norm = 2 + seed = 4 + resolution = 100 + np.random.seed(seed) + + + + #入力データ(詳しくはdata.pyを除いてみると良い) + nb_samples = 100 #データ数 + X = load_kura_tsom(nb_samples1,nb_samples2) #鞍型データ ob_dim=3, 真のL=2 + # X = create_rasen(nb_samples) #らせん型データ ob_dim=3, 真のL=1 + # X = create_2d_sin_curve(nb_samples) #sin型データ ob_dim=2, 真のL=1 + + tukr = TUKR(X, latent_dim, sigma, prior='random') + tukr.fit(epoch, eta) + # visualize_history(X, tukr.history['f'], tukr.history['z'], tukr.history['error'], save_gif=False, filename="mp4") + + #----------描画部分が実装されたらコメントアウト外す---------- + tukr.calc_approximate_f(resolution) + visualize_history(X, tukr.history['y'], tukr.history['z'], tukr.history['error'], save_gif=False, filename="tmp") + + + diff --git a/Lecture_TUKR/tanaka/visualizer.py b/Lecture_TUKR/tanaka/visualizer.py new file mode 100644 index 00000000..f888ef86 --- /dev/null +++ b/Lecture_TUKR/tanaka/visualizer.py @@ -0,0 +1,87 @@ +import numpy as np +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation + +STEP = 150 + + +def visualize_history(X, Y_history, Z_history, error_history, save_gif=False, filename="tmp"): + input_dim, latent_dim = X.shape[1], Z_history[0].shape[1] + input_projection_type = '3d' if input_dim > 2 else 'rectilinear' + + fig = plt.figure(figsize=(10, 8)) + gs = fig.add_gridspec(3, 2) + input_ax = fig.add_subplot(gs[0:2, 0], projection=input_projection_type) + latent_ax = fig.add_subplot(gs[0:2, 1], aspect='equal') + error_ax = fig.add_subplot(gs[2, :]) + num_epoch = len(Y_history) + + if input_dim == 3 and latent_dim == 2: + resolution = int(np.sqrt(Y_history.shape[1])) + if Y_history.shape[1] == resolution ** 2: + Y_history = np.array(Y_history).reshape((num_epoch, resolution, resolution, input_dim)) + + observable_drawer = [None, None, draw_observable_2D, + draw_observable_3D][input_dim] + latent_drawer = [None, draw_latent_1D, draw_latent_2D][latent_dim] + + ani = FuncAnimation( + fig, + update_graph, + frames=num_epoch, # // STEP, + repeat=True, + interval=50, + fargs=(observable_drawer, latent_drawer, X, Y_history, Z_history, error_history, fig, + input_ax, latent_ax, error_ax, num_epoch)) + plt.show() + if save_gif: + ani.save(f"{filename}.mp4", writer='ffmpeg') + + +def update_graph(epoch, observable_drawer, latent_drawer, X, Y_history, + Z_history, error_history, fig, input_ax, latent_ax, error_ax, num_epoch): + fig.suptitle(f"epoch: {epoch}") + input_ax.cla() + # input_ax.view_init(azim=(epoch * 400 / num_epoch), elev=30) + latent_ax.cla() + error_ax.cla() + + Y, Z= Y_history[epoch], Z_history[epoch] + colormap = X[:, 0] + + observable_drawer(input_ax, X, Y, colormap) + latent_drawer(latent_ax, Z, colormap) + draw_error(error_ax, error_history, epoch) + + +def draw_observable_3D(ax, X, Y, colormap): + ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=colormap) + # ax.set_zlim(-1, 1) + if len(Y.shape) == 3: + ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + # ax.scatter(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + else: + ax.plot(Y[:, 0], Y[:, 1], Y[:, 2], color='black') +# ax.plot(Y[:, 0], Y[:, 1], Y[:, 2], color='black') +# ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + + +def draw_observable_2D(ax, X, Y, colormap): + ax.scatter(X[:, 0], X[:, 1], c=colormap) + ax.plot(Y[:, 0], Y[:, 1], c='black') + + +def draw_latent_2D(ax, Z, colormap): + ax.set_xlim(-1.1, 1.1) + ax.set_ylim(-1.1, 1.1) + ax.scatter(Z[:, 0], Z[:, 1], c=colormap) + + +def draw_latent_1D(ax, Z, colormap): + ax.scatter(Z, np.zeros(Z.shape), c=colormap) + ax.set_ylim(-1, 1) + +def draw_error(ax, error_history, epoch): + ax.set_title("error_function", fontsize=8) + ax.plot(error_history, label='誤差関数') + ax.scatter(epoch, error_history[epoch], s=55, marker="*") diff --git a/Lecture_UKR/tanaka/__pycache__/data.cpython-39.pyc b/Lecture_UKR/tanaka/__pycache__/data.cpython-39.pyc index 18e5ea79..c858e58d 100644 Binary files a/Lecture_UKR/tanaka/__pycache__/data.cpython-39.pyc and b/Lecture_UKR/tanaka/__pycache__/data.cpython-39.pyc differ diff --git a/Lecture_UKR/tanaka/__pycache__/visualizer.cpython-39.pyc b/Lecture_UKR/tanaka/__pycache__/visualizer.cpython-39.pyc index 1aed8b5a..193a6bea 100644 Binary files a/Lecture_UKR/tanaka/__pycache__/visualizer.cpython-39.pyc and b/Lecture_UKR/tanaka/__pycache__/visualizer.cpython-39.pyc differ diff --git a/Lecture_UKR/tanaka/shippai.gif b/Lecture_UKR/tanaka/shippai.gif deleted file mode 100644 index 7e9477a0..00000000 Binary files a/Lecture_UKR/tanaka/shippai.gif and /dev/null differ diff --git a/Lecture_UKR/tanaka/ukr.py b/Lecture_UKR/tanaka/ukr.py index 402bea3a..1c0e2586 100644 --- a/Lecture_UKR/tanaka/ukr.py +++ b/Lecture_UKR/tanaka/ukr.py @@ -56,7 +56,7 @@ def fit(self, nb_epoch: int, eta: float): for epoch in tqdm(range(nb_epoch)): dEdx = jax.grad(self.E,argnums=0)(self.Z,self.X,self.alpha,self.norm) - self.Z = self.Z - (eta) * dEdx + self.Z = self.Z - eta * dEdx # Zの更新 @@ -81,29 +81,27 @@ def create_zeta(self, Z, resolution): #fのメッシュの描画用に潜在空 a = np.linspace(np.min(Z), np.max(Z), resolution) b = np.linspace(np.min(Z), np.max(Z), resolution) A,B = np.meshgrid(a,b) - # A = np.meshgrid(a) aa = A.reshape(-1) bb = B.reshape(-1) zeta = np.concatenate([aa[:,None],bb[:,None]],axis=1) - #zeta = np.concatenate(aa[:,None],axis=0) return zeta if __name__ == '__main__': from Lecture_UKR.tanaka.data import create_kura - # from Lecture_UKR.tanaka.data import create_rasen - # from Lecture_UKR.tanaka.data import create_2d_sin_curve + from Lecture_UKR.tanaka.data import create_rasen + from Lecture_UKR.tanaka.data import create_2d_sin_curve from visualizer import visualize_history #各種パラメータ変えて遊んでみてね. epoch = 200 #学習回数 - sigma = 0.03 #カーネルの幅 - eta = 0.1 #学習率 + sigma = 0.2 #カーネルの幅 + eta = 1 #学習率 latent_dim = 2 #潜在空間の次元 alpha = 0.1 norm = 2 - seed = 3 + seed = 4 resolution = 100 np.random.seed(seed) @@ -117,11 +115,11 @@ def create_zeta(self, Z, resolution): #fのメッシュの描画用に潜在空 ukr = UKR(X, latent_dim, sigma, prior='random') ukr.fit(epoch, eta) - visualize_history(X, ukr.history['f'], ukr.history['z'], ukr.history['error'], save_gif=False, filename="mp4") + # visualize_history(X, ukr.history['f'], ukr.history['z'], ukr.history['error'], save_gif=False, filename="mp4") #----------描画部分が実装されたらコメントアウト外す---------- - #ukr.calc_approximate_f(resolution) - #visualize_history(X, ukr.history['y'], ukr.history['z'], ukr.history['error'], save_gif=False, filename="tmp") + ukr.calc_approximate_f(resolution) + visualize_history(X, ukr.history['y'], ukr.history['z'], ukr.history['error'], save_gif=False, filename="tmp") diff --git a/Lecture_UKR/tanaka/visualizer.py b/Lecture_UKR/tanaka/visualizer.py index c711506d..68a47732 100644 --- a/Lecture_UKR/tanaka/visualizer.py +++ b/Lecture_UKR/tanaka/visualizer.py @@ -59,11 +59,11 @@ def draw_observable_3D(ax, X, Y, colormap): ax.set_zlim(-1, 1) if len(Y.shape) == 3: ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') - ax.scatter(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') + # ax.scatter(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') else: ax.plot(Y[:, 0], Y[:, 1], Y[:, 2], color='black') # ax.plot(Y[:, 0], Y[:, 1], Y[:, 2], color='black') - ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') +# ax.plot_wireframe(Y[:, :, 0], Y[:, :, 1], Y[:, :, 2], color='black') def draw_observable_2D(ax, X, Y, colormap): @@ -81,7 +81,7 @@ def draw_latent_1D(ax, Z, colormap): ax.set_xlim(np.max(Z), np.min(Z)) ax.set_ylim(np.max(Z), np.min(Z)) ax.scatter(Z, np.zeros(Z.shape), c=colormap) - #ax.set_ylim(-1, 1) + ax.set_ylim(-1, 1) def draw_error(ax, error_history, epoch): ax.set_title("error_function", fontsize=8) diff --git a/Lecture_UKR/tokunaga/seed10.mp4 b/Lecture_UKR/tokunaga/seed10.mp4 deleted file mode 100644 index 8eff02ae..00000000 Binary files a/Lecture_UKR/tokunaga/seed10.mp4 and /dev/null differ