naoya_t@hatenablog

いわゆるチラシノウラであります

PRML復々習レーン #6 @AJITO

11/3(祝) 10:00〜 @AJITO(〜PANGEA) (VOYAGE GROUP)
http://atnd.org/events/32596

§4.4〜§5.3

AJITOでPRML読むのって本レーンぶりかな。
午後の部はPANGEAで1人1テーブル。

ニューラルネットワーク、というか図5.3の再現コードでも書こうかと独りHackathonしていた。
https://github.com/naoyat/PRMLrevenge/tree/master/chap5
f:id:n4_t:20121114214242j:plainf:id:n4_t:20121114214253j:plainf:id:n4_t:20121114214255j:plainf:id:n4_t:20121114214302j:plain
隠れユニットの出力が図5.3の破線みたいにならないのだけれど何故だろう・・・

myplot.m

function myplot(X, T, Y, msg)
  hold on
  clc;
  ymin = min(Y) - 0.1;
  ymax = max(Y) + 0.1;
  axis([-1, 1, ymin, ymax]);
  scatter(X, T, 'b');
  plot(X, Y, 'r');
  text(-0.95, ymax-0.095, msg);
  hold off
end

nnlearn.m

function [W1, W2, iter] = nnlearn(J, T, Eta, Epsilon)
  M = numel(J);
  W1 = rand(3, M+1)/10;
  W2 = rand(1, 3+1)/10;

  for i = 1:2000
    DTotal = 0;
    for j = 1:M
      X = zeros(M, 1);
      X(j) = 1;
      A = W1 * [1;X];
      Z = tanh(A);
      Y = W2 * [1;Z];
      D2 = Y - T(j);
      DTotal += D2*D2;

      W2 -= Eta * (D2 * [1;Z]');

      R = 1 - (Z' * Z);
      D1 = R * (W2' * D2); % 4 1
      W1 -= Eta * (D1(2:4) * [1;X]');
    end
    if DTotal < Epsilon
      break;
    end
  end
  iter = i;
end

nn.m

function [Y] = nn(X, W1, W2)
  A = W1 * [1; X];
  Z = tanh(A);

  Y = W2 * [1; Z];
end

fig53.m

function fig53(type)
  X = linspace(-1,1,50)';

  % 学習率
  eta = 0.05;

  % 学習打ち切り閾値
  epsilon = 1e-4;
  
  exprs = ['y = x^2'
           'y = sin(2x)'
           'y = |x|'
           'y = Heaviside(x)'
           'y = sin(3x) + cos(7x)'
           'y = 4x^3 - x'
           ];
     
  type = menu('type', exprs(1,:), exprs(2,:), exprs(3,:), exprs(4,:), exprs(5,:), exprs(6,:));
  switch type
    case 1
      T = X .* X;
    case 2
      T = sin(X*2);
    case 3
      T = abs(X);
    case 4
      T = (X > 0)*1; % Heaviside(X);
    case 5
      T = (sin(X*3) + cos(X*7)) / 2;
    case 6
      T = 4 * X .* X .* X - X;
  endswitch

  expr = exprs(type,:);

  [W1, W2, iter] = nnlearn(1:50, T, eta, epsilon);

  Y = zeros(50,1);
  for j = 1:50
    IN = zeros(50,1);
    IN(j) = 1;
    Y(j) = nn(IN, W1, W2);
  endfor

  close all ; clc
  msg = sprintf("%s; iteration# until E(x) < %g: %d", expr, epsilon, iter);
  myplot(X, T, Y, msg, iter);
end

次回予告

次回#7は 12/15 @PANGEAで
http://atnd.org/events/33833
その後にPRMLカラオケクラスタひみつ集会Vol.1
http://atnd.org/events/33843