2014년 11월 6일 목요일

Java로 구현한 Neural Network BP(Backpropagation)

기존에 작업한 NNBP 클래스를 좀 더 다듬어 보도록 하겠습니다.

learn 함수를 호출하는 wile 문이 너무 길어져서 지저분 해보여서, pattern을 ArrayList로 미리 넣어두고 학습을 하도록 변경하였습니다.
그리고 double보다 속도를 높이기 위해서 전체적으로 float로 변경하였습니다.

기존 코드
while( true ) {
   loop_count++;
   E_sum = 0;

   //Step 3 : For each training pattern pair
   //do Step 4-10 until k = p

   // k=1
   x[0] = 0.0f;
   x[1] = 0.0f;
   y[0] = 0.0f;
   E_sum += nnbp.learn(x,y);

   // k=2
   x[0] = 0.0f;
   x[1] = 1.0f;
   y[0] = 1.0f;
   E_sum += nnbp.learn(x,y);

   // k=3
   x[0] = 1.0f;
   x[1] = 0.0f;
   y[0] = 1.0f;
   E_sum += nnbp.learn(x,y);

   // k=4
   x[0] = 1.0f;
   x[1] = 1.0f;
   y[0] = 0.0f;
   E_sum += nnbp.learn(x,y);

   // Step 11 :   Test stop condition
   if( E_sum < E_MAX ) break;
  }

개선 코드

wile 문 내부가 간단해 졌습니다.
  while( true ) {
   loop_count++;
   E_sum = nnbpv2.epoch(xPatterns,yPatterns);
   if( E_sum < E_MAX ) break;
  }

epoch 함수 내에서 learn 함수를 호출 하도록 변경하였습니다.
 public float epoch(ArrayList<float[]> xPatterns, ArrayList<float[]> yPatterns) {
  int patternCount = xPatterns.size();
  float err = 0f;
  for( int i = 0 ; i < patternCount ; i++ ){
   float x[] = xPatterns.get(i);
   float y[] = yPatterns.get(i);
   err += learn(x,y);
  }
  return err;
 }

소스


댓글 없음:

댓글 쓰기