본문 바로가기

컴퓨터 비전/실습

mean shift 직접 구현하기(python)

Mean Shift

이번 포스팅에서는 mean shift 알고리즘을 직접 구현해보도록 하겠습니다.

산 정상을 가장 빨리 도착하는 방법은 등고선 지도에서 선이 조밀한 곳을 직선거리로 가는 방법일 것입니다.

이처럼 어떤 값의 분포중심을 찾는 방법이 mean shift (평균 이동) 알고리즘 입니다.

구현방법은 데이터 중심으로부터 반경 R안에 있는 전체 데이터의 무게중심을 구하고 중심을 무게중심으로 이동시켜 다시 그 과정을 수렴할때까지 반복하면 됩니다. 따라서 k means 알고리즘과 달리 사람이 직접 k개로 데이터를 나누지 않아도 적절히 세그멘테이션이 가능합니다. 하지만 모든 데이터에 대해 수렴점을 찾기때문에 시간이 오래 걸리고 탐색 윈도우 반경을 얼마로 설정해야할지는 정해야 합니다. 너무 작으면 local_minimum에 빠져서 일반적인 중심을 찾지 못하고 너무 크면 의미있는 세그멘테이션이 안될수도 있습니다.

밀도 함수와 관련된 식에 대해선 위키백과를 참고해 주세요.

 

https://en.wikipedia.org/wiki/Mean_shift

 

Mean shift - Wikipedia

From Wikipedia, the free encyclopedia Mathematical technique Mean shift is a non-parametric feature-space mathematical analysis technique for locating the maxima of a density function, a so-called mode-seeking algorithm.[1] Application domains include clus

en.wikipedia.org

 

Code(좌표 + rgb)

import cv2
import time
import numpy as np
import math
from matplotlib import pyplot as plt
from copy import deepcopy
import random

def g_k(x, y, h):
    return np.exp(-1 * (np.sum((x-y)**2 , axis=1)/h**2) )

def mean_shift2(data, hs, hr):
    w, h, c = data.shape
    data_copy = np.copy(data)
    data_copy = data_copy.reshape(w*h,c)
    
    data_final = np.copy(data)
    data_final = data_final.reshape(w*h,c) # finally to return the result
    
    a = np.zeros(h)
    b = np.arange(0,h)
    d = np.vstack((a,b))
    d = d.T

    for i in range(1, w):
        a = i * np.ones(h)
        f = np.vstack((a,b))
        f = f.T
        d = np.vstack((d,f)) # coordinate (w*h, 2)
    
    data_copy = np.hstack((d,data_copy)) # to concatenate the coordinate and make 5d data 
        
             

    v = [] # v is conversence point list
    t = 0

    for i in range(w*h):
        

            t += 1
            print( (t / (w*h))  * 100) # time(percent)

            y_t = data_copy[i]
            m_s = np.zeros((5))

            while True:
                g = g_k(data_copy[:,:2], y_t[:2], hs)  * g_k(data_copy[:,2:], y_t[2:], hr) # g is weights
                g = np.vstack((g,g,g,g,g)) # to multiply the same weight
                g = g.T # to make (w*h,5) shape
                deno = np.sum(g) # only sum of weights
                m_s =  np.sum((data_copy - y_t) * g) # m_s is direction vector
                
                m_s = m_s / deno # normalization
                temp = np.copy(y_t)
                y_t = y_t + m_s
                
                if dist(y_t, temp) < 1:

                    v.append(y_t) # append conversence points
                    # print(y_t) 
                    break
    
    label = 0 # num of group
    index_list = [] # to save the index of data
 
    for i in range(w*h):
        
        if len(v[i]) == 6: # if len is 6, it is already grouped (length condition).
            continue
        
        sum = np.copy(v[i]) # v[i] is standard 
        count = 1 # to count that how many data belong to that label
        index_list.append(i) # starting from standard

        for j in range(i+1, w*h): # before i, it already had done
            
            if len(v[j]) == 6: # if len is 6, it is already grouped.
                continue

            if dist(v[i][:2], v[j][:2]) < hs and dist(v[i][2:], v[j][2:]) < hr: # if conversence points' distance is less than threshold, suppose that conversence points are same 
                    count += 1
                    sum += v[j]
                    v[j] = np.append(v[j],[label]) # to make length condition
                    index_list.append(j) # append that index
        
        v[i] = np.append(v[i],[label])
        new_value = sum[2:] / count # to renew the value of the same conversence point

        for index in index_list:
            data_final[index] = new_value # data of the same conversence points have the same value (clustering)
       
        index_list = [] # to empty the index list for the next step
        label += 1 # to increase the num of group for the next step



    print(label)
    return data_final.reshape(w,h,c)
    
imageFile = '../house.bmp'
img1 = cv2.imread(imageFile, cv2.IMREAD_COLOR)
img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
img1 = cv2.resize(img1,dsize=(256, 256))

t = mean_shift2(img1, 8, 7)

plt.subplot(1,2,1)
plt.imshow(img1, vmin=0, vmax=255)

plt.subplot(1,2,2)
plt.imshow(t, vmin=0, vmax=255)

plt.show()

 

g_k 함수는 weight들을 구하기위한 가우시안 커널 입니다. 

데이터의 차원은 정밀도를 위해 좌표값을 rgb값에 합쳐서 5차원로 만들었습니다.

mean shift함수의 앞부분은 좌표값을 numpy array에 concatenate하기 위해 (w*h, 2) 꼴을 만드는 과정입니다.
그 후에 data(rgb) 앞쪽에 concatenate 하면 5차원 데이터가 완성됩니다.

 

함수는 크게 무한루프의 break를 기준으로 하여 두 부분으로 나뉩니다.

앞부분은 수렴점을 모으는 단계로 현재 값에 가우시안 커널을 씌우고 확률밀도 함수를 이용하여 윈도우를 군집의 중심으로 이동하게 합니다. 차이가 거의 없을 때까지 반복하며 이 과정을 모든 데이터에 대해 하여 수렴점들을 모으는 과정입니다.

뒷부분은 군집화 단계로 전 단계에서 모은 모든 수렴점에 대해서 일정 거리 안에 있는 점들을 같은 라벨로 분류합니다.
그리고 같은 라벨에 있는 모든 수렴점의 평균으로 그 라벨을 가진 데이터를 평균값으로 바꿔줍니다.
자세한 사항은 주석을 확인해 주세요.

 

결과사진

원본 구현한 mean shift
원본 구현한 mean shift

 

mean shift의 세그멘테이션(segmentation)이 색감별로 잘 적용됨을 확인하였습니다. k means에 비해서 성능이 우수하지만 모든 데이터에 대해서 수렴점을 구하는 과정과 그 수렴점을 군집화하는 과정에서 시간이 많이 소요된다는 단점이 있었습니다.

이제 opencv의 라이브러리를 사용한 결과와 비교해 보겠습니다.

 

Code(라이브러리)

import numpy as np
import cv2 as cv
from sklearn.cluster import MeanShift, estimate_bandwidth
from matplotlib import pyplot as plt

imageFile = '../lena.bmp'
img1 = cv2.imread(imageFile, cv2.IMREAD_COLOR)
img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
img1 = cv2.resize(img1,dsize=(256, 256))

# flatten the image
flat_image = img1.reshape((-1,3))
flat_image = np.float32(flat_image)

# meanshift
bandwidth = estimate_bandwidth(flat_image, quantile=.06, n_samples=3000)
ms = MeanShift(bandwidth, max_iter=800, bin_seeding=True)
ms.fit(flat_image)
labeled=ms.labels_


# get number of segments
segments = np.unique(labeled)
print('Number of segments: ', segments.shape[0])

# get the average color of each segment
total = np.zeros((segments.shape[0], 3), dtype=float)
count = np.zeros(total.shape, dtype=float)
for i, label in enumerate(labeled):
    total[label] = total[label] + flat_image[i]
    count[label] += 1
avg = total/count
avg = np.uint8(avg)

# cast the labeled image into the corresponding average color
res = avg[labeled]
result = res.reshape((img1.shape))

plt.subplot(1,2,1)
plt.imshow(img1, vmin=0, vmax=255)

plt.subplot(1,2,2)
plt.imshow(result, vmin=0, vmax=255)

plt.show()

 

결과사진

라이브러리 구현한 mean shift
라이브러리 구현한 mean shift

 

직접 구현할때는 좌표정보도 rgb값에 추가하여 좌표와 rgb값에 대한 관련성이 높아졌기 때문에 라이브러리보다 더 선명해 보입니다. 이렇게 mean shift를 직접 구현하는데 성공했습니다.

하지만 라이브러리는 10초안에 동작이 된 반면 직접구현한 코드를 실행할때는 2~3분정도 걸렸습니다. 이 부분은 쓰레드 추가와 군집화 알고리즘 개선으로 더 빠르게 처리할수 있다고 생각합니다.

정확한 비교를 위해 이번에는 라이브러리와 같은 조건으로 좌표값을 포함시키지 않고 구현하여 비교해 보겠습니다.

 

Code(only rgb)

def mean_shift(data, h1):
    w, h, c = data.shape
    data = np.copy(data)
    data = data.reshape(w*h,c)

    v = [] # v is convergence point list
    t = 0

    for i in range(w*h):
        
            t += 1
            print( (t / (w*h))  * 100) # time(percent)

            y_t = data[i]
            m_s = np.zeros((3))
            

            while True:
                g = g_k(data, y_t, h1) # g is weights
                g = np.vstack((g,g,g)) # to multiply the same weight
                g = g.T # to make (w*h,3) shape
                deno = np.sum(g) # only sum of weights
                m_s =  np.sum((data - y_t) * g) # m_s is direction vector
                
                m_s = m_s / deno  # normalization
                temp = np.copy(y_t)
                y_t = y_t + m_s
                
                if dist(y_t, temp) < 1:

                    v.append(y_t) # append convergence points
                    # print(y_t) 
                    break
    
    label = 0 # num of group
    index_list = [] # to save the index of data
 
    for i in range(w*h):
        
        if len(v[i]) == 4: # if len is 4, it is already grouped (length condition).
            continue
        
        sum = np.copy(v[i]) # v[i] is standard 
        count = 1 # to count that how many data belong to that label
        index_list.append(i) # starting from standard

        for j in range(i+1, w*h): # before i, it already had done
            
            if len(v[j]) == 4: # if len is 4, it is already grouped.
                continue

            if dist(v[i], v[j]) < h1: # if convergence points' distance is less than threshold, suppose that convergence points are same 
                    count += 1
                    sum += v[j]
                    v[j] = np.append(v[j],[label]) # to make length condition
                    index_list.append(j) # append that index
        
        v[i] = np.append(v[i],[label])
        new_value = sum / count # to renew the value of the same convergence point

        for index in index_list:
            data[index] = new_value # data of the same convergence points have the same value (clustering)
       
        index_list = [] # to empty the index list for the next step
        label += 1 # to increase the num of group for the next step



    print(label)
    return data.reshape(w,h,c)

 

결과 사진

구현한 mean shift(rgb) 라이브러리
구현한 mean shift(rgb) 라이브러리

 

색감이 과도하게 밝아진 부분을 제외하면 segmentation은 된것으로 보이나 라이브러리가 성능과 소요시간 면에서 더 우수했습니다.

따라서 직접 구현할때는 앞의 것과 같이 좌표정보도 추가하여 색감과 세그멘테이션에서 보다 정밀하게 구현해야 될것 같습니다.

감사합니다.