Published on

Class Classification Accuracy 값 구하기

Authors
  • avatar
    Name
    Dongju Lee
    Twitter

YOLOv8 Segment 클래스 분류 정확도 구하기

YOLOv8의 Segemnt Task를 수행하면서 예측한 Mask에 대해서 영역뿐만 아니라 class 또한 잘 예측했는지 확인하기 위해서 함수를 작성했습니다.

예측 class 정보는 아래의 코드로 확인할 수 있습니다.

from ultralytics import YOLO

model = YOLO('학습한 모델 경로')
results = model('예측할 이미지 경로')

results
# 출력 결과
[ultralytics.engine.results.Results object with attributes:

 boxes: ultralytics.engine.results.Boxes object
 keypoints: None
 masks: ultralytics.engine.results.Masks object
 names: {0: 'no_battery_cover', 1: 'velcro'}
 obb: None
 orig_img: array([[[ 7,  2,  1],
         [ 7,  2,  1],
         [ 7,  2,  1],
         ...,
         [10,  8,  8],
         [ 9,  7,  7],
         [ 9,  7,  7]],

        [[ 7,  2,  1],
         [ 7,  2,  1],
         [ 7,  2,  1],
         ...,
         [10,  8,  8],
         [ 9,  7,  7],
         [ 9,  7,  7]],

        [[ 7,  2,  1],
         [ 7,  2,  1],
         [ 7,  2,  1],
         ...,
         [10,  8,  8],
         [ 9,  7,  7],
         [ 9,  7,  7]],

        ...,

        [[12, 10, 10],
         [12, 10, 10],
         [12, 10, 10],
         ...,
         [ 5,  0,  1],
         [ 5,  0,  1],
         [ 5,  0,  1]],

        [[13, 11, 11],
         [13, 11, 11],
         [13, 11, 11],
         ...,
         [ 5,  0,  1],
         [ 5,  0,  1],
         [ 5,  0,  1]],

        [[14, 12, 12],
         [14, 12, 12],
         [14, 12, 12],
         ...,
         [ 5,  0,  1],
         [ 5,  0,  1],
         [ 5,  0,  1]]], dtype=uint8)
 orig_shape: (640, 640)
 path: '/yolov8/datasets/test/images/Screenshot_2024_02_20_16_32_08_jpg.rf.bff8b03c49f87bc3829aae633488f262.jpg'
 probs: None
 save_dir: None
 speed: {'preprocess': 6.598472595214844, 'inference': 7.602930068969727, 'postprocess': 1.863718032836914}]

결과 값인 'results'의 boxes에서 cls 속성을 확인하면 됩니다.

results[0].boxes.cls

# 출력 결과
tensor([1.], device='cuda:0')

확인하게 되면 tensor의 형태로 결과를 return합니다.

숫자의 의미는 label.txt의 첫 번째 값입니다.

해당 값은 labeling할 때 지정한 class name의 index 번호입니다.

model로 예측한 cls 값과 해당 이미지의 label.txt 파일을 비교해서 분류를 잘 했는지 확인할 수 있습니다.

label.txt의 예시는 다음과 같습니다.

# label.txt 형식
'1 0.50390625 0.5052083328125 0.494140625 0.4739583328125 0.4755859375 0.44444444375 0.4541015625 0.4270833328125 0.4384765625 0.4236111109375 0.4365234375 0.4270833328125 0.4228515625 0.4270833328125 0.4169921875 0.43055555625 0.41015625 0.4427083328125 0.404296875 0.4704861109375 0.40625 0.4739583328125 0.40625 0.484375 0.412109375 0.5017361109375 0.4345703125 0.5416666671875 0.4462890625 0.5520833328125 0.4658203125 0.559027778125 0.4677734375 0.5625 0.48046875 0.5607638890625 0.4931640625 0.5520833328125 0.501953125 0.5364583328125 0.50390625 0.5260416671875 0.50390625 0.5052083328125'

label.txt를 리스트 형태로 변환하고 첫 번째 인덱스를 추출해서 비교하는 함수를 작성했습니다.

input 값은 학습한 모델로 이미지를 예측한 값 predict와 해당 이미지의 label.txt인 label입니다.

# predict = model('이미지경로')
# label = 'label.txt경로'

def judge_cls(predict, label):
    # cls 맞춘 개수 초기화
    correct = 0

    # 변수 선언
    label_list = label.split('\n')
    true_mask_cnt = len(label_list)
    predict_mask_cnt = predict[0].boxes.cls.size()[0]

    # 예측 mask와 정답 mask 개수가 같은 경우
    if true_mask_cnt == predict_mask_cnt:
        for cls_index in range(true_mask_cnt):
            if float(label_list[cls_index][0])==predict[0].boxes.cls[cls_index].item():
                correct += 1

        return correct / true_mask_cnt

    # 예측 mask가 정답 mask 보다 많은 경우
    if true_mask_cnt < predict_mask_cnt:
        for cls_index in range(true_mask_cnt):
            if float(label_list[cls_index][0])==predict[0].boxes.cls[cls_index].item():
                correct += 1

        return correct / predict_mask_cnt


    # 정답 mask가 예측 mask 보다 많은 경우
    if true_mask_cnt > predict_mask_cnt:
        for cls_index in range(predict_mask_cnt):
            if float(label_list[cls_index][0])==predict[0].boxes.cls[cls_index].item():
                correct += 1

        return correct / true_mask_cnt