지난 글
갈아먹는 Object Detection [1] R-CNN
갈아먹는 Object Detection [2] Spatial Pyramid Pooling Network
들어가며
지난 시간 SPPNet에 이어서 오늘은 Fast R-CNN[1]을 리뷰해보도록 하겠습니다. 저 역시 그랬고, 많은 분들이 R-CNN 다음으로 Fast R-CNN 논문을 보시는데요, 해당 논문을 보다 보면 SPPNet에서 많은 부분들을 참고한 것을 확인할 수 있습니다. 특히나 핵심인 Spatial Pyramid Pooling은 중요한 개념이므로 리뷰하고 넘어가도록 하겠습니다.
영향력: R-CNN 저자인 Ross가 1 저자로 인용 횟수만 8000회에 달합니다.
주요 기여: CNN fine tuning, boundnig box regression, classification을 모두 하나의 네트워크에서 학습시키는 end-to-end 기법을 제시하였습니다. 추후 이어지는 Faster R-CNN에 많은 영향을 주었습니다.
결과: SPPNet보다 3배 더 빠른 학습 속도, 10배 더 빠른 속도를 보이며 Pascal VOC 2007 데이터 셋을 대상으로 mAP 66%를 기록합니다.
핵심 아이디어
Fast R-CNN은 이전 SPP Net이 가지는 한계점들을 극복하고자 하는 시도에서 출발합니다. SPPNet은 기존 RCNN이 selective search로 찾아낸 모든 RoI에 대해서 CNN inference를 하는 문제를 CNN inference를 전체 이미지에 대하여 1회만 수행하고, 이 피쳐맵을 공유하는 방식으로 해결하였습니다. 그러나 여전히 모델을 학습 시키기 위해선 1) 여러 단계를 거쳐야 했고, 2) Fully Connected Layer 밖에 학습 시키지 못하는 한계점이 있었습니다. 이에 저자는 다음과 같은 주장을 펼칩니다.
"CNN 특징 추출부터 classification, bounding box regression 까지
모두 하나의 모델에서 학습시키자!"
전체 알고리즘은 다음과 같습니다.
1. 먼저 전체 이미지를 미리 학습된 CNN을 통과시켜 피쳐맵을 추출합니다.
2. Selective Search를 통해서 찾은 각각의 RoI에 대하여 RoI Pooling을 진행합니다. 그 결과로 고정된 크기의 feature vector를 얻습니다.
3. feature vector는 fully connected layer들을 통과한 뒤, 두 개의 브랜치로 나뉘게 됩니다.
4-1. 하나의 브랜치는 softmax를 통과하여 해당 RoI가 어떤 물체인지 클래시피케이션 합니다. 더 이상 SVM은 사용되지 않습니다.
4-2. bouding box regression을 통해서 selective search로 찾은 박스의 위치를 조정합니다.
CNN을 한번만 통과시킨 뒤, 그 피쳐맵을 공유하는 것은 이미 SPP Net에서 제안된 방법입니다. 그 이후의 스텝들은 SPPNet이나 R-CNN과 그다지 다르지 않습니다. 본 논문의 가장 큰 특징은 이들을 스텝별로 쪼개어 학습을 진행하지 않고, end-to-end로 엮었다는데 있습니다. 그리고 그 결과로 학습 속도, 인퍼런스 속도, 정확도 모두를 향상시켰다는데 의의가 있습니다.
그러면 본격적으로 해당 논문에서 제시되는 RoI Pooling, Multi Task Loss, finetuning for detection 등의 개념을 살펴보도록 하겠습니다.
RoI Pooling
Fast R-CNN에서 먼저 입력 이미지는 CNN을 통과하여 피쳐맵을 추출합니다. 추출된 피쳐맵을 미리 정해놓은 H x W 크기에 맞게끔 그리드를 설정합니다. 그리고 각각의 칸 별로 가장 큰 값을 추출하는 max pooling을 실시하면 결과값은 항상 H x W 크기의 피쳐 맵이 되고, 이를 쫙 펼쳐서 feature vector를 추출하게 됩니다. 이러한 RoI Pooling은 앞서 살펴보았던 Spatial Pyramid Pooling에서 피라미드 레벨이 1인 경우와 동일합니다.
이 때, 제가 들었던 의문점은 인풋 이미지와 피쳐맵의 크기가 다를 경우 어떻게 RoI의 위치를 피쳐맵에서 찾을 수 있을까? 였습니다. 이에 대해서 한 Stack Overflow 답변[2]을 첨부합니다. 여기서는 인풋 이미지의 크기와 피쳐맵의 크기가 다를 경우, 그 비율을 구해서 RoI를 조절한 다음, RoI Pooling을 진행한다고 합니다. (이 또한 개인의 의견을 뿐이니, 더 정확한 내용이 있다면 댓글로 알려주세요 ㅎㅎ)
Multi Task Loss
이제 우리는 이미지로부터 피쳐맵을 추출했고, 해당 피쳐맵에서 RoI들을 찾아서 RoI Pooling을 적용하여 feature vector를 구했습니다. 이제 이 벡터로 classification과 bounding box regression을 적용하여 각각의 loss를 얻어내고, 이를 back propagation하여 전체 모델을 학습시키면 됩니다. 이 때, classificaiton loss와 bounding box regression을 적절하게 엮어주는 것이 필요하며, 이를 multi task loss라고 합니다. 수식은 아래와 같습니다.
먼저 입력으로 p는 softmax를 통해서 얻어낸 K+1 (K개의 object + 1개의 배경, 아무 물체도 아님을 나타내는 클래스)개의 확률 값입니다. 그다음 u는 해당 RoI의 ground truth 라벨 값입니다.
그 다음으로 bounding box regression을 적용하면 이는 K + 1개 클래스에 대해서 각각 x, y, w, h 값을 조정하는 tk를 리턴합니다. 즉, 이 RoI가 사람일 경우 박스를 이렇게 조절해라, 고양이일 경우 이렇게 조절해라는 값을 리턴합니다. 로스 펑션에서는 이 값들 가운데 ground truth 라벨에 해당하는 값만 가져오며, 이는 tu에 해당합니다. v는 ground truth bounding box 조절 값에 해당합니다.
다시 전체 로스로 돌아가보면 앞부분은 p와 u를 가지고 classification loss를 구합니다. 여기서는 다음과 같이 log loss를 사용합니다.
전체 로스의 뒷 부분은 Bounding Box Regression을 통해서 얻는 로스로 수식은 아래와 같습니다.
입력으로는 정답 라벨에 해당하는 BBR 예측 값과 ground truth 조절 값을 받습니다. 그리고 x, y, w, h 각각에 대해서 예측 값과 라벨 값의 차이를 계산한 다음, smoothL1이라는 함수를 통과시킨 합을 계산합니다. smoothL1은 아래와 같습니다.
예측 값과 라벨 값의 차가 1보다 작으면 0.5x^2로 L2 distance를 계산해줍니다. 반면에 1보다 클 경우 L1 distance를 계산해주는 것을 볼 수 있습니다. 이는 Object Detection 테스크에 맞추어 로스 펑션을 커스텀 하는 것으로 볼 수 있습니다. 저자들은 실험 과정에서 라벨 값과 지나치게 차이가 많이 나는 outlier 예측 값들이 발생했고, 이들을 그대로 L2 distance로 계산하여 적용할 경우 gradient가 explode 해버리는 현상을 관찰했다고 합니다. 이를 방지하기 위해서 다음과 같은 함수를 추가한 것입니다.
Backpropagation through RoI Pooling Layer
(이 부분부터는 이론적인 내용이 많이 포함되어 있으니, 건너 뛰어도 무방합니다.)
이제 로스 펑션까지 구했으니 네트워크를 학습시키는 일만 남았습니다. 그런데 그 전에 짚고 넘어가야할 문제가 있습니다. 바로 네트워크를 어디까지 학습시킬 것인가? 입니다. 이전 SPP Net에서는 피쳐 맵을 뽑는 CNN 부분은 그대로 놔두고, SPP 이후의 FC들만 fine-tune 하였습니다. 그러나 이 논문에서는 이럴 경우 이미지로부터 특징을 뽑는 가장 중요한 역할을 하는 CNN이 학습될 수 없기 때문에 성능 향상에 제약이 있다고 주장합니다. 그리고 과연 RoI Pooling 레이어 이전까지 back propagation을 전달할 수 있는지를 이론적으로 검증합니다.
먼저 다음 수식을 살펴보겠습니다.
xi 라고 하는 것은 CNN을 통해 추출된 피쳐 맵에서 하나의 피쳐 값을 의미하며, 이는 실수입니다. 전체 Loss에 대해서 이 피쳐 값의 편미분 값을 구하면 그 값이 곧 xi에 대한 loss 값이 되며 역전파 알고리즘을 수행할 수 있습니다. 자, 이제 피쳐 맵에서 RoI를 찾고 RoI Pooling을 적용하기 위해서 H x W 크기의 grid로 나눕니다. 이 그리드들을 sub-window라고 부르며, 위 수식에서 j란 몇 번째 sub-window 인지를 나타내는 인덱스입니다. yrj란 이 RoI pooling을 통과하여 최종적으로 얻어진 output의 값이며 이 역시 하나의 실수입니다. 이를 그림으로 나타내면 아래와 같습니다.
xi가 최종 prediction 값에 영향을 주려면 xi가 속하는 모든 RoI의 sub-window에서 해당 xi가 최대 값이 되면 됩니다. i*(r, j)란 RoI와 sub window index j가 주어졌을 때 최대 피쳐 값의 인덱스를 말하며, 이는 곧 RoI Pooling을 통과하는 인덱스 값을 말합니다. 이 RoI Pooling을 통과한 이후 값에 대한 Loss는 이미 전체 Loss에 대한 yrj의 편미분 값으로 이미 계산이 되어 있습니다. 그러므로 이를 중첩시키기만 하면 xi에 대한 loss를 구할 수 있는 것입니다.
그러므로 우리는 앞서 구한 multitask loss를 RoI Pooling layer를 통과하여 CNN 단까지 fine-tuning 할 수 있는 것입니다. 저자들은 실험을 통해서 실제로 CNN까지 fine tuning 하는 것이 성능 향상에 도움이 되었다는 실험 결과를 보여줍니다.
위 실험 결과는 fine-tuning 하는 깊이를 조절해가며 성능 변화를 실험한 것입니다. CNN의 단을 깊이 학습시킬 수록 성능이 향상되었으며, 이 때 테스트에 소요되는 시간 변화는 거의 없는 것을 확인할 수 있습니다. 즉, CNN 단을 Object Detection에 맞게끔 fine-tuning 하는 것이 성능 향상의 키 포인트 였습니다.
마치며
해당 논문은 object detection 테스크를 푸는 end-to-end 모델을 제시하면서 학습 단계를 간소화시키고 정확도와 성능 모두를 향상시켰다는 의의가 있습니다. 그러나 여전히 region proposal을 selective search로 수행하고, 이는 CPU 연산으로만 수행 가능하다는 한계점이 있습니다. 이 부분을 제외하면 inference에 소요되는 시간이 0.3초 정도로 짧습니다.
이 다음 이어지는 Faster R-CNN 모델은 Fast R-CNN의 구조를 그대로 계승하면서 Region Proposal 역시 전체 네트워크의 일부로 끌어옵니다. 다음 리뷰에서 이를 어떻게 구현하는 지를 살펴보겠습니다.
이 외에도 해당 논문에서는 SVD (Singular Vector Decomposition, 특이값 분해)를 통해서 Fully Connected Layer 들의 파라미터를 줄이는 방법 등이 소개되었지만 이후의 연구들에서는 사용되어 지지않고, 지나치게 어렵기 때문에 쿨하게 넘기겠습니다. 머신러닝에서 사용되는 다양한 수학들에 대한 포스팅들도 작성할 예정이니 많은 기대 부탁드립니다!
감사합니다.
Reference
[1] Ross, Fast R-CNN, 2015
[2] stackoverflow, How to project a bounding box on feature map?, https://datascience.stackexchange.com/questions/45051/how-to-project-a-bounding-box-on-feature-map
[3] towardsdatascience, Fast R-CNN for object detection, https://towardsdatascience.com/fast-r-cnn-for-object-detection-a-technical-summary-a0ff94faa022
'갈아먹는 머신러닝 시리즈 > 컴퓨터 비전' 카테고리의 다른 글
갈아먹는 Semantic Segmentation [1] Fully Convolutional Network (0) | 2020.01.09 |
---|---|
갈아먹는 Object Detection [4] Faster R-CNN (8) | 2020.01.07 |
갈아먹는 Face Detection [1] MTCNN (3) | 2020.01.06 |
갈아먹는 Object Detection [2] Spatial Pyramid Pooling Network (7) | 2020.01.04 |
갈아먹는 Object Detection [1] R-CNN (18) | 2020.01.01 |