Thais Lima de Sousa - Supervised Graduate Project

SLIC Superpixels

In [1]:
import numpy as np
from matplotlib import pyplot as plt
from skimage import data, segmentation, color, io
from skimage.future import graph
from vpi.io import *
%matplotlib inline

Reference results from scikit-image's `segmentation.slic`

In [2]:
# Reference segmentations from scikit-image: same image and number of
# requested superpixels, two compactness settings (20 and 40).
img1 = data.coffee()

labels1 = segmentation.slic(img1, n_segments=400, compactness=20)
labels2 = segmentation.slic(img1, n_segments=400, compactness=40)

out1 = color.label2rgb(labels1, img1, kind='avg')
out2 = color.label2rgb(labels2, img1, kind='avg')

display_image(out1)
display_image(out2)
In [3]:
# Reference result on a second test image (astronaut), compactness 40.
img2 = data.astronaut()
labels3 = segmentation.slic(img2, n_segments=400, compactness=40)
out3 = color.label2rgb(labels3, img2, kind='avg')

display_image(img2)  # input image
display_image(out3)  # each superpixel filled with its average colour
In [4]:
# Large sentinel used as an initial "infinitely far" distance / gradient value.
FLT_MAX = 10 ** 6

def enforceConnectivity(img, clusters, num_superp):
    """Relabel a superpixel map into 4-connected segments, absorbing small fragments.

    Pixels are visited column by column. Each still-unlabelled pixel seeds a
    breadth-first flood fill over the 4-connected pixels that share its value in
    ``clusters``; the filled segment receives a fresh label. A segment with at most
    a quarter of the mean superpixel size (``H*W // num_superp``) is instead given
    the label of an already-labelled 4-neighbour of its seed.

    Adapted from https://github.com/PSMM/SLIC-Superpixels/blob/master/slic.cpp

    Parameters
    ----------
    img : ndarray
        The segmented image; only ``img.shape[:2]`` is used.
    clusters : ndarray
        2-D array of per-pixel cluster labels with the same height/width as ``img``.
    num_superp : int
        Number of clusters, used to derive the mean superpixel size.

    Returns
    -------
    ndarray of int, shape (H, W)
        The new label map (labels start at 0).
    """
    H, W = img.shape[:2]
    # mean number of pixels per superpixel (guarded against num_superp == 0)
    limit = W * H // max(num_superp, 1)
    min_size = limit >> 2  # segments this small or smaller are merged away
    neighbours = ((-1, 0), (0, -1), (1, 0), (0, 1))
    new_clusters = np.full((H, W), -1, dtype=int)
    label = 0
    adjlabel = 0  # carries over between seeds, as in the reference implementation

    for j in range(W):
        for i in range(H):
            if new_clusters[i, j] != -1:
                continue  # already part of an earlier segment

            # remember an already-labelled neighbour of the seed, in case the
            # segment grown from it turns out too small to keep
            for di, dj in neighbours:
                x, y = i + di, j + dj
                if 0 <= x < H and 0 <= y < W and new_clusters[x, y] >= 0:
                    adjlabel = new_clusters[x, y]

            # flood fill over 4-connected pixels belonging to the seed's cluster;
            # the seed is labelled up front so it is counted exactly once
            seed_cluster = clusters[i, j]
            new_clusters[i, j] = label
            pixels = [(i, j)]
            c = 0
            while c < len(pixels):
                px, py = pixels[c]
                for di, dj in neighbours:
                    x, y = px + di, py + dj
                    if (0 <= x < H and 0 <= y < W and new_clusters[x, y] == -1
                            and clusters[x, y] == seed_cluster):
                        new_clusters[x, y] = label
                        pixels.append((x, y))
                c += 1

            if len(pixels) <= min_size:
                # too small to be its own superpixel: reuse the neighbour's label
                for p in pixels:
                    new_clusters[p] = adjlabel
            else:
                label += 1

    return new_clusters

def lowestGradPos(img, x, y):
    """Return the position of lowest colour gradient in the 3x3 neighbourhood of (x, y).

    The gradient at (i, j) is the Euclidean colour distance to the pixel below,
    ``img[i + 1, j]``, plus the distance to the pixel to the right,
    ``img[i, j + 1]`` (forward differences). ``slic`` uses this to nudge its
    initial cluster centres, as the SLIC paper does to avoid seeding on an edge.

    Parameters
    ----------
    img : ndarray of shape (H, W) or (H, W, C)
        Image to inspect (``slic`` passes its LAB-converted image).
    x, y : int
        Row and column of the neighbourhood centre.

    Returns
    -------
    (int, int)
        Row and column of the lowest-gradient pixel; the first one visited wins
        ties. Only pixels whose forward neighbours lie inside the image are
        candidates; if there are none, ``(x, y)`` is returned unchanged.
    """
    H, W = img.shape[:2]
    xmin, ymin = x, y
    min_grad = np.inf
    # forward differences need row i + 1 and column j + 1: stay off the last
    # row/column and never wrap around through negative indices
    for i in range(max(x - 1, 0), min(x + 2, H - 1)):
        for j in range(max(y - 1, 0), min(y + 2, W - 1)):
            # cast to float so integer (e.g. uint8) images do not wrap around
            here = np.asarray(img[i, j], dtype=float)
            grad = (np.sqrt(np.sum(np.square(img[i + 1, j] - here)))
                    + np.sqrt(np.sum(np.square(img[i, j + 1] - here))))
            if grad < min_grad:
                min_grad = grad  # same measure as the comparison above
                xmin, ymin = i, j
    return xmin, ymin

def _slic_update_centers(img, labels, C):
    """Return cluster centres moved to the mean of the pixels assigned to them.

    img: (H, W, 3) image the clustering runs on (LAB inside ``slic``).
    labels: (H, W) int array of cluster indices; -1 marks unassigned pixels.
    C: (k, 6) centre array with columns [L, a, b, col, row, pixel count].

    Positions are rounded so they can be used directly as pixel indices.
    Clusters that received no pixels keep their previous centre (instead of
    becoming NaN). Column 5 of the result holds each cluster's pixel count.
    """
    k = C.shape[0]
    assigned = labels >= 0
    lab = labels[assigned]
    rows, cols = np.nonzero(assigned)  # same row-major order as labels[assigned]
    counts = np.bincount(lab, minlength=k)
    sums = np.column_stack(
        [np.bincount(lab, weights=img[..., ch][assigned], minlength=k) for ch in range(3)]
        + [np.bincount(lab, weights=cols, minlength=k),
           np.bincount(lab, weights=rows, minlength=k)])
    new_C = C.astype(float)  # astype returns a copy; C itself is left untouched
    nonempty = counts > 0
    new_C[nonempty, :5] = sums[nonempty] / counts[nonempty, None]
    new_C[nonempty, 3:5] = np.round(new_C[nonempty, 3:5])
    new_C[:, 5] = counts
    return new_C


def slic(img, k, m, nItr=10):
    """Segment an RGB image into roughly ``k`` SLIC superpixels.

    Parameters
    ----------
    img : ndarray of shape (H, W, 3)
        RGB image; converted to CIELAB with ``color.rgb2lab`` before clustering.
    k : int
        Desired number of superpixels. Seeds are placed on a grid with step
        ``S = sqrt(H*W/k)``, so the actual number of clusters may differ.
    m : float
        Compactness: larger values weight spatial distance more, giving
        superpixels with more regular and smoother shapes.
    nItr : int, optional
        Number of assignment/update iterations (default 10).

    Returns
    -------
    ndarray of int, shape (H, W)
        Superpixel label of every pixel, after ``enforceConnectivity``.
    """
    H, W = img.shape[:2]
    img = color.rgb2lab(img)
    S = max(int(np.sqrt(H * W / k)), 1)  # grid step; >= 1 so range() below is valid
    l = np.full((H, W), -1, dtype=int)  # labels; -1 = not assigned to any cluster

    # one seed per grid cell, moved to the lowest-gradient pixel nearby
    centers = []
    for i in range(S, H - S // 2, S):
        for j in range(S, W - S // 2, S):
            row, col = lowestGradPos(img, i, j)
            L, a, b = img[row, col]
            centers.append([L, a, b, col, row, 1])
    if not centers:
        # image smaller than one grid cell: fall back to a single central seed
        row, col = H // 2, W // 2
        L, a, b = img[row, col]
        centers.append([L, a, b, col, row, 1])
    n_clusters = len(centers)
    C = np.array(centers, dtype=float)  # columns: L, a, b, col, row, pixel count

    for _ in range(nItr):
        d = np.full((H, W), np.inf)  # distance of each pixel to its current centre
        for j in range(n_clusters):
            crow, ccol = int(C[j, 4]), int(C[j, 3])
            # search window of +-S around the centre, clipped to the image
            rmin, rmax = max(crow - S, 0), min(crow + S, H)
            cmin, cmax = max(ccol - S, 0), min(ccol + S, W)

            # colour distance to the cluster's mean LAB colour (not to the pixel
            # under the centre) combined with spatial distance scaled by m / S
            colordiff = img[rmin:rmax, cmin:cmax] - C[j, :3]
            dc = np.sqrt(np.sum(np.square(colordiff), axis=2))
            yy, xx = np.ogrid[rmin:rmax, cmin:cmax]
            ds = np.sqrt((yy - C[j, 4]) ** 2 + (xx - C[j, 3]) ** 2)
            D = np.sqrt(dc ** 2 + ((ds / S) * m) ** 2)

            # claim pixels that are closer to this centre than to any seen so far
            subd = d[rmin:rmax, cmin:cmax]  # basic slice = view, so writes reach d
            closer = D < subd
            subd[closer] = D[closer]
            l[rmin:rmax, cmin:cmax][closer] = j

        C = _slic_update_centers(img, l, C)

    return enforceConnectivity(img, l, C.shape[0])
In [5]:
# This implementation on the coffee image, using the same settings as the
# scikit-image reference (400 superpixels, compactness 20 and then 40).
labels1 = slic(img1, k=400, m=20)
resp1 = color.label2rgb(labels1, img1, kind='avg')
display_image(resp1)

labels2 = slic(img1, k=400, m=40)
resp2 = color.label2rgb(labels2, img1, kind='avg')
display_image(resp2)
In [6]:
# This implementation on the astronaut image (400 superpixels, compactness 40).
display_image(img2)  # input image
labels3 = slic(img2, k=400, m=40)
resp3 = color.label2rgb(labels3, img2, kind='avg')
display_image(resp3)  # superpixels filled with their average colour

scikit-image (left) versus my implementation (right):

In [11]:
# Coffee image: scikit-image (left) vs. this implementation (right),
# compactness 20 (top row) and 40 (bottom row), 400 requested superpixels.
fig, ax = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(16, 10))
panels = [
    (out1, 'scikit-image, compactness=20'),
    (resp1, 'this implementation, compactness=20'),
    (out2, 'scikit-image, compactness=40'),
    (resp2, 'this implementation, compactness=40'),
]
# ax.flat walks the grid row by row, matching the order of `panels`
for a, (image, title) in zip(ax.flat, panels):
    a.imshow(image)
    a.set_title(title)
    a.axis('off')
plt.tight_layout()
In [13]:
# Astronaut image: scikit-image (left) vs. this implementation (right),
# compactness 40, 400 requested superpixels.
fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(16, 10))
titles = ('scikit-image, compactness=40', 'this implementation, compactness=40')
for a, image, title in zip(ax, (out3, resp3), titles):
    a.imshow(image)
    a.set_title(title)
    a.axis('off')
plt.tight_layout()
In [ ]: