Thais Lima de Sousa - Supervised Graduate Project

Graph-based Image Segmentation

In [1]:
import numpy as np
import networkx as nx
from matplotlib import pyplot as plt
from skimage import io, graph, color, segmentation
from skimage.future import graph
%matplotlib inline

Auxiliary functions

In [2]:
def show_img(img, width=5.0):
    """Display an image in a new matplotlib figure with the axes hidden.

    The figure is ``width`` inches wide. Its height is scaled by the image's
    rows/columns ratio so the image keeps its aspect ratio.

    Parameters
    ----------
    img : array-like
        Anything ``plt.imshow`` accepts; only ``img.shape[:2]`` (rows, cols)
        is used to size the figure.
    width : float, optional
        Figure width in inches (default 5.0, the previously hard-coded value).

    Returns nothing on purpose: returning the figure as the last expression of
    a notebook cell would make it display twice with the inline backend.
    """
    rows, cols = img.shape[:2]
    height = rows * width / cols
    # matplotlib's figsize is (width, height) in inches
    plt.figure(figsize=(width, height))
    plt.axis('off')
    plt.imshow(img)
In [3]:
# disjoint-set forest with union by rank and full path compression

class UF:
    """Disjoint-set (union-find) forest over the integers ``0 .. n-1``.

    Union by rank plus full path compression gives near-constant
    (inverse-Ackermann) amortized cost per operation.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets ``{0}, {1}, ..., {n-1}``."""
        self.num = n                  # current number of disjoint sets
        self.parent = list(range(n))  # parent[i] == i  <=>  i is a root
        self.rank = [0] * n           # upper bound on the height of a root's tree
        self.size = [1] * n           # element count; only meaningful at roots

    def find(self, v):
        """Return the root (representative) of the set containing ``v``.

        Every node on the path from ``v`` to the root is re-pointed directly
        at the root, so later lookups along this path are O(1).
        """
        root = v
        while root != self.parent[root]:
            root = self.parent[root]
        # second pass: compress the whole path, not just the first node
        while v != root:
            nxt = self.parent[v]
            self.parent[v] = root
            v = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; no-op if already joined.

        The root of lower rank is attached under the other. On a rank tie,
        ``x``'s root goes under ``y``'s root and ``y``'s rank grows by one.
        """
        xR = self.find(x)
        yR = self.find(y)
        if xR == yR:
            return

        if self.rank[xR] > self.rank[yR]:
            self.parent[yR] = xR
            self.size[xR] += self.size[yR]
        else:
            self.parent[xR] = yR
            self.size[yR] += self.size[xR]
            if self.rank[xR] == self.rank[yR]:
                self.rank[yR] += 1
        self.num -= 1

    def sizec(self, x):
        """Return the number of elements in the set containing ``x``."""
        return self.size[self.find(x)]

    def num_sets(self):
        """Return the current number of disjoint sets."""
        return self.num

    # debugging helpers (they return, not print, the stored values)

    def print_parent(self, x):
        """Return the stored parent pointer of ``x`` (no path compression)."""
        return self.parent[x]

    def print_rank(self, x):
        """Return the stored rank of ``x``."""
        return self.rank[x]
In [4]:
# disjoint-set forest with union by rank and full path compression
# NOTE(review): UF is defined twice in this notebook; this later definition is
# the one in effect for every cell below, so it also carries the debug helpers.

class UF:
    """Disjoint-set (union-find) forest over the integers ``0 .. n-1``.

    Union by rank plus full path compression gives near-constant
    (inverse-Ackermann) amortized cost per operation.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets ``{0}, {1}, ..., {n-1}``."""
        self.num = n                  # current number of disjoint sets
        self.parent = list(range(n))  # parent[i] == i  <=>  i is a root
        self.rank = [0] * n           # upper bound on the height of a root's tree
        self.size = [1] * n           # element count; only meaningful at roots

    def find(self, v):
        """Return the root (representative) of the set containing ``v``.

        Every node on the path from ``v`` to the root is re-pointed directly
        at the root, so later lookups along this path are O(1).
        """
        root = v
        while root != self.parent[root]:
            root = self.parent[root]
        # second pass: compress the whole path, not just the first node
        while v != root:
            nxt = self.parent[v]
            self.parent[v] = root
            v = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; no-op if already joined.

        The root of lower rank is attached under the other. On a rank tie,
        ``x``'s root goes under ``y``'s root and ``y``'s rank grows by one.
        """
        xR = self.find(x)
        yR = self.find(y)
        if xR == yR:
            return

        if self.rank[xR] > self.rank[yR]:
            self.parent[yR] = xR
            self.size[xR] += self.size[yR]
        else:
            self.parent[xR] = yR
            self.size[yR] += self.size[xR]
            if self.rank[xR] == self.rank[yR]:
                self.rank[yR] += 1
        self.num -= 1

    def sizec(self, x):
        """Return the number of elements in the set containing ``x``."""
        return self.size[self.find(x)]

    def num_sets(self):
        """Return the current number of disjoint sets."""
        return self.num

    # debugging helpers (they return, not print, the stored values)

    def print_parent(self, x):
        """Return the stored parent pointer of ``x`` (no path compression)."""
        return self.parent[x]

    def print_rank(self, x):
        """Return the stored rank of ``x``."""
        return self.rank[x]

Felzenszwalb and Huttenlocher's algorithm

In [5]:
def falzenszwalb(img, k, min_size, labels=None):
    """Felzenszwalb-Huttenlocher segmentation on a region adjacency graph.

    (The function name keeps its original spelling so existing callers work.)

    Parameters
    ----------
    img : ndarray
        Image given to ``graph.rag_mean_color``; ``img.shape[:2]`` is (H, W).
    k : float
        Scale parameter. A component of size |C| accepts a new edge of weight
        ``w`` while ``w <= Int(C) + k / |C|``.
    min_size : int
        In post-processing, components with fewer than ``min_size`` RAG nodes
        (pixels when ``labels`` is None) are merged into a neighbour.
    labels : ndarray of int, shape (H, W), optional
        Initial over-segmentation (e.g. SLIC superpixels). Defaults to one
        label per pixel.

    Returns
    -------
    ndarray, shape (H, W)
        For each pixel, the ``labels`` value of the root node of the
        component it ended up in.
    """
    H, W = img.shape[:2]

    # segment at pixel level
    if labels is None:
        labels = np.arange(H * W).reshape((H, W))

    g = graph.rag_mean_color(img, labels)

    # RAG nodes are label values, which need not be exactly 0..V-1 (e.g. SLIC
    # labels may start at 1). Map them to contiguous indices for the forest.
    nodes = list(g.nodes())
    index = {node: i for i, node in enumerate(nodes)}
    V = len(nodes)
    disj_set = UF(V)

    # scan edges in increasing order of dissimilarity (stable sort)
    edges = sorted(
        ((index[u], index[v], w) for u, v, w in g.edges(data='weight')),
        key=lambda e: e[2],
    )

    c = float(k)
    # threshold[r] = Int(C) + k/|C| for the component rooted at r;
    # a singleton has Int = 0 and |C| = 1, hence k.
    threshold = [c] * V
    for a, b, w in edges:
        ra = disj_set.find(a)
        rb = disj_set.find(b)
        if ra != rb and w <= threshold[ra] and w <= threshold[rb]:
            disj_set.union(ra, rb)
            root = disj_set.find(ra)
            # edges are sorted, so w is the largest edge inside the merged component
            threshold[root] = w + c / disj_set.sizec(root)

    # post-process small components; edges are still in increasing-weight
    # order, so each small component joins across its cheapest edge first
    for a, b, _ in edges:
        ra = disj_set.find(a)
        rb = disj_set.find(b)
        if ra != rb and (disj_set.sizec(ra) < min_size or disj_set.sizec(rb) < min_size):
            disj_set.union(ra, rb)

    # relabel: one find per distinct label, then a vectorized lookup per pixel
    # (assumes every label value is a RAG node, as rag_mean_color builds it)
    uniq, inverse = np.unique(np.asarray(labels).ravel(), return_inverse=True)
    root_labels = np.array([nodes[disj_set.find(index[u])] for u in uniq])
    return root_labels[inverse].reshape(H, W)

Guimarães et al.'s algorithm

In [6]:
def build_hierarchies(G):
    """Build the observation-scale hierarchy of Guimarães et al. from graph ``G``.

    The minimum spanning tree of ``G`` is processed in increasing edge-weight
    order, growing components with a disjoint-set forest. Each MST edge
    (x, y, w) joins components A and B and gets the smallest scale ``k`` at
    which the Felzenszwalb-Huttenlocher criterion would accept that merge on
    both sides:

        w <= Int(C) + k/|C|   <=>   k >= (w - Int(C)) * |C|

    so the edge's scale is ``max((w - Int(A))*|A|, (w - Int(B))*|B|)``.

    Parameters
    ----------
    G : networkx graph
        Weighted graph whose edges carry a ``'weight'`` attribute
        (e.g. a RAG from ``graph.rag_mean_color``).

    Returns
    -------
    networkx.Graph
        Graph on the same node labels as ``G``, containing only the MST edges,
        each with ``'weight'`` set to its merge scale.
    """
    # node labels need not be 0..V-1; map them to contiguous forest indices
    nodes = list(G.nodes())
    index = {node: i for i, node in enumerate(nodes)}
    forest = UF(len(nodes))
    # Int(C): largest MST edge inside the component rooted at r (0 for singletons)
    internal = [0.0] * len(nodes)
    Gh = nx.Graph()

    # the Int(C) update below assumes increasing weight order; sort explicitly
    # instead of relying on the MST routine's output order
    mst = sorted(nx.minimum_spanning_edges(G, data=True), key=lambda e: e[2]['weight'])

    for x, y, data in mst:
        w = data['weight']
        a = forest.find(index[x])
        b = forest.find(index[y])
        scale_a = (w - internal[a]) * forest.sizec(a)
        scale_b = (w - internal[b]) * forest.sizec(b)
        Gh.add_edge(x, y, weight=max(scale_a, scale_b))
        forest.union(a, b)
        internal[forest.find(a)] = w

    return Gh

def guimaraes(img, k, min_size, labels=None):
    """Segment ``img`` by cutting the Guimarães et al. scale hierarchy at ``k``.

    Parameters
    ----------
    img : ndarray
        Image given to ``graph.rag_mean_color``; ``img.shape[:2]`` is (H, W).
    k : float
        Observation scale. Hierarchy edges whose scale is below ``k`` are merged.
    min_size : int
        In post-processing, components with fewer than ``min_size`` RAG nodes
        (pixels when ``labels`` is None) are merged into a neighbour.
    labels : ndarray of int, shape (H, W), optional
        Initial over-segmentation (e.g. SLIC superpixels). Defaults to one
        label per pixel.

    Returns
    -------
    ndarray, shape (H, W)
        For each pixel, the ``labels`` value of the root node of the
        component it ended up in.
    """
    H, W = img.shape[:2]

    # segment at pixel level
    if labels is None:
        labels = np.arange(H * W).reshape((H, W))

    g = graph.rag_mean_color(img, labels)

    # RAG nodes are label values, which need not be exactly 0..V-1 (e.g. SLIC
    # labels may start at 1). Map them to contiguous indices for the forest.
    nodes = list(g.nodes())
    index = {node: i for i, node in enumerate(nodes)}
    disj_set = UF(len(nodes))

    h_map = build_hierarchies(g)
    scale_edges = [(index[u], index[v], s) for u, v, s in h_map.edges(data='weight')]

    # cut the hierarchy at scale k
    # NOTE(review): strict '<' here, while falzenszwalb merges on '<='; the
    # results can differ only when an edge scale equals k exactly.
    for a, b, scale in scale_edges:
        if scale < k:
            disj_set.union(a, b)

    # post-process small components
    # NOTE(review): unlike falzenszwalb, these edges are not sorted by weight,
    # so which neighbour a small component joins follows graph iteration order.
    for a, b, _ in scale_edges:
        ra = disj_set.find(a)
        rb = disj_set.find(b)
        if ra != rb and (disj_set.sizec(ra) < min_size or disj_set.sizec(rb) < min_size):
            disj_set.union(ra, rb)

    # relabel: one find per distinct label, then a vectorized lookup per pixel
    # (assumes every label value is a RAG node, as rag_mean_color builds it)
    uniq, inverse = np.unique(np.asarray(labels).ravel(), return_inverse=True)
    root_labels = np.array([nodes[disj_set.find(index[u])] for u in uniq])
    return root_labels[inverse].reshape(H, W)

Example

In [7]:
# input image, read from data/<filename>.jpg
filename = 'g3_01'
img = io.imread('data/' + filename + '.jpg')

# SLIC parameters: requested number of superpixels and compactness
slk = 500
slm = 30
# graph-segmentation parameters: scale k and minimum component size
k = 200
ms = 20

print('SLIC: %d superpixels with compactness %d' % (slk, slm))
slic_labels = segmentation.slic(img, n_segments=slk, compactness=slm)
# bg_label=-1: treat no label as background. Newer scikit-image releases
# default to bg_label=0, which would paint a region labelled 0 black.
result_slic = color.label2rgb(slic_labels, img, kind='avg', bg_label=-1)
result_slic = segmentation.mark_boundaries(result_slic, slic_labels, (0, 0, 0))
show_img(result_slic)
SLIC: 500 superpixels with compactness 30
In [8]:
print('Felzenszwalb, k = %d' % k)
slic_fh = falzenszwalb(img, k, ms, slic_labels)
# bg_label=-1 so a component whose label happens to be 0 is still painted
# with its mean colour instead of being treated as background
result_fh = color.label2rgb(slic_fh, img, kind='avg', bg_label=-1)
show_img(result_fh)
Felzenszwalb, k = 200
In [9]:
print('Guimarães, k = %d' % k)
slic_gui = guimaraes(img, k, ms, slic_labels)
# bg_label=-1 so a component whose label happens to be 0 is still painted
# with its mean colour instead of being treated as background
result_gui = color.label2rgb(slic_gui, img, kind='avg', bg_label=-1)
show_img(result_gui)
Guimarães, k = 200