2014년 11월 3일 월요일

Neural Network BP(Backpropagation)를 이용한 XOR 학습 in Java


소스는 http://martinblog.tistory.com/888 여기 cpp 소스를 참고하였습니다. 하지만 일부 구현이 이상하다고 판단되는 부분은 다른 ppt를 참고하였습니다.
http://ai-times.tistory.com/272 여기 1211978906_09. BP알고리즘.ppt 첨부 문서를 참고하였습니다만, 여기에서도 일부 잘못된 내용들이 있었습니다.
p.14 Step 8 부분인데 빨간색으로 표시된 부분이 잘못표현 되어있었습니다.

Class이름을 NNBP라고 이름 지었습니다.
NNBP는 생성자를 호출할때 입출력,히든 노드의 개수와 학습률, 출력 최대 오차를 정의 합니다.
그리고 학습하는 함수는 learn이란 이름으로 하고 인자로는 입력 값과 출력을 넘겨 주면 오차를 넘겨줍니다. 이것을 가지고 전반적인 top down식 설계를 하여, xor을 학습 시키는 예제를 만들도록 하겠습니다.

  final int INPUT_NODE_COUNT = 2;
  final int HIDDEN_NODE_COUNT = 4;
  final int OUTPUT_NODE_COUNT = 1;

  //Step 2 : Set learning rate and Emax
  final double ALPHA = 1.0d;//학습률
  final double E_MAX = 0.01d;//최대 출력 오차
  double E_sum = 1d;

  NNBP nnbp= new NNBP(
    INPUT_NODE_COUNT,
    HIDDEN_NODE_COUNT,
    OUTPUT_NODE_COUNT,
    ALPHA,
    E_MAX
    );

  double x[] = new double [INPUT_NODE_COUNT];
  double y[] = new double [OUTPUT_NODE_COUNT];
  int loop_count = 0;

  while( true ) {
   loop_count++;
   E_sum = 0;

   //Step 3 : For each training pattern pair
   //do Step 4-10 until k = p

   // k=1
   x[0] = 0.0d;
   x[1] = 0.0d;
   y[0] = 0.0d;
   E_sum += nnbp.learn(x,y);

   // k=2
   x[0] = 0.0d;
   x[1] = 1.0d;
   y[0] = 1.0d;
   E_sum += nnbp.learn(x,y);

   // k=3
   x[0] = 1.0d;
   x[1] = 0.0d;
   y[0] = 1.0d;
   E_sum += nnbp.learn(x,y);

   // k=4
   x[0] = 1.0d;
   x[1] = 1.0d;
   y[0] = 0.0d;
   E_sum += nnbp.learn(x,y);

   // Step 11 :   Test stop condition
   if( E_sum < E_MAX ) break;
  }

  System.out.printf("loop_count:%d, error:%f\n",loop_count,E_sum);


이제 학습이 끝나면 테스트가 필요하므로 테스트시 필요한 소스를 제작해 보았습니다.

  // test
  x[0] = 0.0d;
  x[1] = 0.0d;
  y = nnbp.doTest(x);
  System.out.printf("test : %f %f %f \n",x[0],x[1],y[0]);
  x[0] = 1.0d;
  x[1] = 0.0d;
  y = nnbp.doTest(x);
  System.out.printf("test : %f %f %f \n",x[0],x[1],y[0]);
  x[0] = 0.0d;
  x[1] = 1.0d;
  y = nnbp.doTest(x);
  System.out.printf("test : %f %f %f \n",x[0],x[1],y[0]);
  x[0] = 1.0d;
  x[1] = 1.0d;
  y = nnbp.doTest(x);
  System.out.printf("test : %f %f %f \n",x[0],x[1],y[0]);

클래스 생성자에서는 변수 초기화를 하는데 BP에서는 가중치값이 -0.5~+0.5가 되도록 설정합니다. (random함수 이용)

 public NNBP(int input, int hidden, int output, double aLPHA, double e_MAX) {
  inputNodeCount = input;
  hiddenNodeCount = hidden;
  outputNodeCount = output;
  alpha = aLPHA;
  e_max = e_MAX;
  v = new double [input][hidden];
  z = new double [hidden];
  w = new double [hidden][output];
  y = new double [output];
  
  for (int i = 0; i < inputNodeCount; i++) {
   for (int j = 0; j < hiddenNodeCount; j++) {
    v[i][j] = (double)Math.random()-0.5d;
   }
  }
  for (int i = 0; i < hiddenNodeCount; i++) {
   for (int j = 0; j < outputNodeCount; j++) {
    w[i][j] = (double)Math.random()-0.5d;
   }
  }
 }

소스내의 변수들은 아래의 이미지를 참고하여 구현하였습니다. x는 입력, v는 입력층과 은닉층의 연결강도, z는 은닉층, w는 은닉층과 출력층의 연결강도, y는 출력층을 의미합니다. 그리고 ppt문서와는 다르게 상수로 1로 표현되고 있는 bias 부분은 적용하지 않았습니다.(빨간색 X표로 표시해 두었습니다.)

learn함수는 4~9단계로 이루어지며, 4,5단계에서는 출력을 계산합니다.

  //Step 4 : Compute output of hidden layer 
  for (int i = 0; i < hiddenNodeCount; i++)
  {
   double NETz = 0.0d;
   for (int j = 0; j < inputNodeCount; j++) {
    NETz += x[j] * this.v[j][i];
   }
   this.z[i] = sigmoid(NETz);
  }

  //Step 5 : Compute output
  for (int k = 0; k < outputNodeCount; k++)
  {
   double NETy = 0.0d;
   for (int i = 0; i < hiddenNodeCount; i++) {
    NETy += z[i] * this.w[i][k];
   }
   this.y[k] = sigmoid(NETy);
  }

6단계에서는 출력 오차를 계산하고 마지막에 오차를 리턴할때 사용합니다.

  //Step 6 : Compute output error
  double Err = 0.0d;
  for (int k = 0; k < outputNodeCount; k++) {
   Err = (float)(Err + ((t[k] - this.y[k])*(t[k] - this.y[k])) / 2.0D);
  }

7,8,9단에서는 오차 기울기를 구하고 가중치를 변경합니다.

  //Step 7 : Compute error signal of output layer
  // dy = (d-y)*y*(1-y) :: y*(1-y)시그모이드 함수의 미분  * (d-y)출력 오차값  => 오차기울기
  double dy[] = new double [ outputNodeCount ];
  for( int k=0; k < outputNodeCount; k++ ) 
  {
   dy[k] = (t[k] - y[k]) * y[k] * (1 - y[k]);
  }

  //Step 8 : Compute error signal of hidden layer
  // dz = z*(1-z) * sigma[i=1]to[m]( dy * w )
  double dz[] = new double [ hiddenNodeCount ];
  for (int i = 0; i < hiddenNodeCount; i++)
  {
   double Sum = 0.0d;
   for (int k = 0; k < outputNodeCount; k++) {
    Sum += dy[k] * w[i][k];
   }
   dz[i] = (z[i] * (1.0d - z[i]) * Sum);
  }

  //Step 9 : Update weights
  for (int i = 0; i < hiddenNodeCount; i++) {
   for (int j = 0; j < outputNodeCount; j++)
   {
    w[i][j] += (alpha * dy[j] * z[i]);
   }
  }
  for (int i = 0; i < inputNodeCount; i++) {
   for (int j = 0; j < hiddenNodeCount; j++)
   {
    v[i][j] += (alpha * dz[j] * x[i]);
   }
  }

테스트 하는 함수는 4,5단계에 사용하는 z,y의 변수를 class 멤버변수로 사용하지 않고 로컬 변수로 사용하며 최종결과 값인 y값을 리턴해주면 됩니다.
public double[] doTest(double[] testx) {
  double testz[];
  double testy[];
  testz = new double [hiddenNodeCount];
  testy = new double [outputNodeCount];
  // Compute output of hidden layer 
  for (int i = 0; i < hiddenNodeCount; i++)
  {
   double NETz = 0.0d;
   for (int j = 0; j < inputNodeCount; j++) {
    NETz += testx[j] * this.v[j][i];
   }
   testz[i] = sigmoid(NETz);
  }

  // Compute output
  for (int k = 0; k < outputNodeCount; k++)
  {
   double NETy = 0.0d;
   for (int i = 0; i < hiddenNodeCount; i++) {
    NETy += testz[i] * this.w[i][k];
   }
   testy[k] = sigmoid(NETy);
  }
  return testy;
 }
테스트시 아래와 같은 결과가 나옵니다.
loop_count:1827, error:0.009997
test : 0.000000 0.000000 0.077762 
test : 1.000000 0.000000 0.932610 
test : 0.000000 1.000000 0.933328 
test : 1.000000 1.000000 0.069122 
각 출력의 의미하는 바는 다음과 같습니다.
loop_count는 학습한 횟수를 의미합니다. error는 학습이 끝났을때 오차를 의미하고, test 는 각각의 test case에서 결과값이 어떤지를 보여줍니다.
개인적으로 완벽하고 이해하기 쉽게 구현했다고 생각되는 java 구현 소스를 공개합니다.


댓글 없음:

댓글 쓰기