You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

6.4 MiB

In [61]:
import os
import cv2
import numpy
import pickle
import skimage
from scipy.misc import imread
from xmljson import badgerfish as bf
from xml.etree.ElementTree import fromstring
from matplotlib import pyplot, rcParams
%matplotlib inline
In [41]:
def negative_regions(positive):
    """Yield rectangular regions of ``positive`` that contain no marked pixels.

    ``positive`` is a 2-D bitmap in which non-zero entries mark diatoms.  Up to
    three times, a band of 200 to 799 consecutive rows is picked at random.
    Within that band the runs of columns carrying no positive mark are located,
    and the left-most run whose extent is more than 200 and less than 800
    columns is yielded (at most one region per band).

    Each yielded region is a list of two slices, ``[rows, columns]``.  Index
    with ``array[tuple(region)]``: recent NumPy versions no longer accept a
    list of slices as a multi-dimensional index.

    A band that is not strictly shorter than the bitmap is skipped, so a bitmap
    with 200 rows or fewer yields nothing.
    """
    for _ in range(3):
        height = numpy.random.randint(200, 800)
        if positive.shape[0] <= height:
            # randint(0, n) needs n > 0; the band does not fit in this bitmap.
            continue
        ymin = numpy.random.randint(0, positive.shape[0] - height)
        ymax = ymin + height

        # True for every column that has no positive pixel inside the band.
        column_is_free = positive[ymin:ymax].sum(0) == 0

        # Pad with zeros so that runs touching either border still produce a
        # rising (+1) and a falling (-1) edge in the difference signal.
        padded = numpy.concatenate(([0], column_is_free.astype(numpy.int8), [0]))
        edges = numpy.diff(padded)
        first_columns = numpy.flatnonzero(edges == 1)
        last_columns = numpy.flatnonzero(edges == -1) - 1  # inclusive end of each run

        for xmin, xmax in zip(first_columns, last_columns):
            if xmax - xmin > 200 and xmax - xmin < 800:
                region = [slice(ymin, ymax), slice(int(xmin), int(xmax))]
                assert positive[tuple(region)].sum() == 0, "Due to a bug, some positive regions were selected as negative."
                yield region
                break
In [42]:
# Image patches cut from annotated bounding boxes (positives) and from
# randomly chosen regions without any annotation (negatives).
negatives = []
positives = []

for annotation in os.listdir('annotations/'):
    # Parse the annotation XML, then release the file before the image work.
    with open('annotations/' + annotation) as f:
        data = bf.data(fromstring(f.read()))['annotation']

    filename = data['filename']['$']
    try:
        image = imread('images/' + filename)
    except FileNotFoundError:
        # Annotation without a matching image on disk: nothing to extract.
        continue

    # Bitmap with the same height/width as the image; True inside any box.
    # (The builtin bool replaces numpy.bool, which NumPy 1.24 removed.)
    positive = numpy.zeros(image.shape[:2], dtype=bool)

    # A single <object> is not wrapped in a list by the parser.
    objects = data['object']
    if not isinstance(objects, list):
        objects = [objects]

    for o in objects:
        box = o['bndbox']
        # Tuple (not list) of slices: a list is rejected as a multi-dimensional
        # index by recent NumPy versions.
        # NOTE(review): assumes the parser already converted the coordinates
        # to integers - confirm.
        region = (
            slice(box['ymin']['$'], box['ymax']['$']),
            slice(box['xmin']['$'], box['xmax']['$']),
        )
        positive[region] = True
        positives.append(image[region])

    for region in negative_regions(positive):
        negatives.append(image[tuple(region)])
In [45]:
# Samples: all negative patches first, followed by all positive patches.
X = [*negatives, *positives]
# Labels in the same order: False per negative patch, True per positive patch.
y = [False for _ in negatives] + [True for _ in positives]
In [53]:
rcParams['figure.figsize'] = (20,5)


def _axes_stream(per_row=4):
    """Yield axes endlessly, opening a new one-row figure whenever a row is used up."""
    while True:
        for axis in pyplot.subplots(1, per_row)[1]:
            yield axis


# One figure per row of four patches.  With many sampled patches this opens
# more than 20 figures and matplotlib emits its `figure.max_open_warning`.
axes = _axes_stream()

for image, class_ in zip(X, y):
    # Show a random sample of roughly 5 % of the patches.
    if numpy.random.rand() < 0.05:
        axis = next(axes)
        axis.imshow(image)
        axis.set_xticks([])
        axis.set_yticks([])
        # Only the first two dimensions are needed; this also accepts 2-D
        # (grayscale) patches, which have no channel axis to unpack.
        h, w = image.shape[:2]
        # Frame the patch: green for positive, red for negative.
        axis.plot([w, 0, 0, w, w], [0, 0, h, h, 0], 'g' if class_ else 'r', linewidth=8)
        axis.set_xlim(0, w)
        # imshow puts row 0 at the top; keep the y-axis inverted so the patch
        # is not drawn upside down.
        axis.set_ylim(h, 0)
/home/herbert/.virtualenvs/medical/lib/python3.5/site-packages/matplotlib/pyplot.py:524: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  max_open_warning, RuntimeWarning)
In [62]:
# Persist the (samples, labels) pair as one pickled tuple.
dataset = (X, y)
with open('true_false_data.p3', 'wb') as sink:
    pickle.dump(dataset, sink)