PRML復々習レーン #6 @AJITO
11/3(祝) 10:00〜 @AJITO(〜PANGEA) (VOYAGE GROUP)
http://atnd.org/events/32596
§4.4〜§5.3
AJITOでPRML読むのって本レーンぶりかな。
午後の部はPANGEAで1人1テーブル。
ニューラルネットワーク、というか図5.3の再現コードでも書こうかと独りHackathonしていた。
https://github.com/naoyat/PRMLrevenge/tree/master/chap5
隠れユニットの出力が図5.3の破線みたいにならないのだけれど何故だろう・・・
myplot.m
function myplot(X, T, Y, msg) hold on clc; ymin = min(Y) - 0.1; ymax = max(Y) + 0.1; axis([-1, 1, ymin, ymax]); scatter(X, T, 'b'); plot(X, Y, 'r'); text(-0.95, ymax-0.095, msg); hold off end
nnlearn.m
function [W1, W2, iter] = nnlearn(J, T, Eta, Epsilon) M = numel(J); W1 = rand(3, M+1)/10; W2 = rand(1, 3+1)/10; for i = 1:2000 DTotal = 0; for j = 1:M X = zeros(M, 1); X(j) = 1; A = W1 * [1;X]; Z = tanh(A); Y = W2 * [1;Z]; D2 = Y - T(j); DTotal += D2*D2; W2 -= Eta * (D2 * [1;Z]'); R = 1 - (Z' * Z); D1 = R * (W2' * D2); % 4 1 W1 -= Eta * (D1(2:4) * [1;X]'); end if DTotal < Epsilon break; end end iter = i; end
nn.m
function [Y] = nn(X, W1, W2) A = W1 * [1; X]; Z = tanh(A); Y = W2 * [1; Z]; end
fig53.m
function fig53(type) X = linspace(-1,1,50)'; % 学習率 eta = 0.05; % 学習打ち切り閾値 epsilon = 1e-4; exprs = ['y = x^2' 'y = sin(2x)' 'y = |x|' 'y = Heaviside(x)' 'y = sin(3x) + cos(7x)' 'y = 4x^3 - x' ]; type = menu('type', exprs(1,:), exprs(2,:), exprs(3,:), exprs(4,:), exprs(5,:), exprs(6,:)); switch type case 1 T = X .* X; case 2 T = sin(X*2); case 3 T = abs(X); case 4 T = (X > 0)*1; % Heaviside(X); case 5 T = (sin(X*3) + cos(X*7)) / 2; case 6 T = 4 * X .* X .* X - X; endswitch expr = exprs(type,:); [W1, W2, iter] = nnlearn(1:50, T, eta, epsilon); Y = zeros(50,1); for j = 1:50 IN = zeros(50,1); IN(j) = 1; Y(j) = nn(IN, W1, W2); endfor close all ; clc msg = sprintf("%s; iteration# until E(x) < %g: %d", expr, epsilon, iter); myplot(X, T, Y, msg, iter); end
次回予告
次回#7は 12/15 @PANGEAで
http://atnd.org/events/33833
その後にPRMLカラオケクラスタひみつ集会Vol.1
http://atnd.org/events/33843