PRML復々習レーン #6 @AJITO
11/3(祝) 10:00〜 @AJITO(〜PANGEA) (VOYAGE GROUP)
http://atnd.org/events/32596
§4.4〜§5.3
AJITOでPRML読むのって本レーンぶりかな。
午後の部はPANGEAで1人1テーブル。
ニューラルネットワーク、というか図5.3の再現コードでも書こうかと独りHackathonしていた。
https://github.com/naoyat/PRMLrevenge/tree/master/chap5




隠れユニットの出力が図5.3の破線みたいにならないのだけれど何故だろう・・・
myplot.m
function myplot(X, T, Y, msg) hold on clc; ymin = min(Y) - 0.1; ymax = max(Y) + 0.1; axis([-1, 1, ymin, ymax]); scatter(X, T, 'b'); plot(X, Y, 'r'); text(-0.95, ymax-0.095, msg); hold off end
nnlearn.m
function [W1, W2, iter] = nnlearn(J, T, Eta, Epsilon)
M = numel(J);
W1 = rand(3, M+1)/10;
W2 = rand(1, 3+1)/10;
for i = 1:2000
DTotal = 0;
for j = 1:M
X = zeros(M, 1);
X(j) = 1;
A = W1 * [1;X];
Z = tanh(A);
Y = W2 * [1;Z];
D2 = Y - T(j);
DTotal += D2*D2;
W2 -= Eta * (D2 * [1;Z]');
R = 1 - (Z' * Z);
D1 = R * (W2' * D2); % 4 1
W1 -= Eta * (D1(2:4) * [1;X]');
end
if DTotal < Epsilon
break;
end
end
iter = i;
end
nn.m
function [Y] = nn(X, W1, W2) A = W1 * [1; X]; Z = tanh(A); Y = W2 * [1; Z]; end
fig53.m
function fig53(type)
X = linspace(-1,1,50)';
% 学習率
eta = 0.05;
% 学習打ち切り閾値
epsilon = 1e-4;
exprs = ['y = x^2'
'y = sin(2x)'
'y = |x|'
'y = Heaviside(x)'
'y = sin(3x) + cos(7x)'
'y = 4x^3 - x'
];
type = menu('type', exprs(1,:), exprs(2,:), exprs(3,:), exprs(4,:), exprs(5,:), exprs(6,:));
switch type
case 1
T = X .* X;
case 2
T = sin(X*2);
case 3
T = abs(X);
case 4
T = (X > 0)*1; % Heaviside(X);
case 5
T = (sin(X*3) + cos(X*7)) / 2;
case 6
T = 4 * X .* X .* X - X;
endswitch
expr = exprs(type,:);
[W1, W2, iter] = nnlearn(1:50, T, eta, epsilon);
Y = zeros(50,1);
for j = 1:50
IN = zeros(50,1);
IN(j) = 1;
Y(j) = nn(IN, W1, W2);
endfor
close all ; clc
msg = sprintf("%s; iteration# until E(x) < %g: %d", expr, epsilon, iter);
myplot(X, T, Y, msg, iter);
end
次回予告
次回#7は 12/15 @PANGEAで
http://atnd.org/events/33833
その後にPRMLカラオケクラスタひみつ集会Vol.1
http://atnd.org/events/33843