
Training data for segmentation

iPython project, created 3 months ago.

This script augments an input project to prepare a training dataset for segmentation (or instance segmentation) models.

This notebook is a replacement for the DTL "Create train set -> Segmentation".

Input:

  • An existing annotated project

Output:

  • A new augmented project, ready for model training

Configuration

Edit the following settings to match your own setup.

In [1]:
%matplotlib inline
import supervisely_lib as sly
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import math
In [2]:
team_name = "jupyter_tutorials"
workspace_name = "cookbook"
project_name = "tutorial_project"

dst_project_name = "tutorial_project_aug_segmentation"

validation_fraction = 0.1
image_multiplier = 5

class_bg = sly.ObjClass('bg', sly.Rectangle, color=[0, 0, 64])
tag_meta_train = sly.TagMeta('train', sly.TagValueType.NONE)
tag_meta_val = sly.TagMeta('val', sly.TagValueType.NONE)

# Obtain server address and your api_token from environment variables
# Edit those values if you run this notebook on your own PC
address = os.environ['SERVER_ADDRESS']
token = os.environ['API_TOKEN']
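Indexing `os.environ['SERVER_ADDRESS']` raises a bare `KeyError` when the variable is not set, which can be confusing on a local machine. A minimal sketch of a friendlier lookup follows; the helper name `require_env` is hypothetical and not part of Supervisely.

```python
import os


def require_env(name: str) -> str:
    """Return an environment variable's value, or fail with a readable message."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            "Environment variable {!r} is not set; export it or edit the "
            "configuration cell by hand".format(name))
    return value


# Hypothetical usage in the configuration cell (same variable names as above):
# address = require_env('SERVER_ADDRESS')
# token = require_env('API_TOKEN')
```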
In [3]:
# Initialize API object
api = sly.Api(address, token)

Verify input values

Check that the context (team / workspace / project) exists.

In [4]:
# get IDs of team, workspace and project by names

team = api.team.get_info_by_name(team_name)
if team is None:
    raise RuntimeError("Team {!r} not found".format(team_name))

workspace = api.workspace.get_info_by_name(team.id, workspace_name)
if workspace is None:
    raise RuntimeError("Workspace {!r} not found".format(workspace_name))
    
project = api.project.get_info_by_name(workspace.id, project_name)
if project is None:
    raise RuntimeError("Project {!r} not found".format(project_name))
    
print("Team: id={}, name={}".format(team.id, team.name))
print("Workspace: id={}, name={}".format(workspace.id, workspace.name))
print("Project: id={}, name={}".format(project.id, project.name))
Out [4]:
Team: id=30, name=jupyter_tutorials
Workspace: id=76, name=cookbook
Project: id=898, name=tutorial_project
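The three lookups above repeat the same "`None` means not found" check. One way to factor it out is sketched below; `require` is a hypothetical helper, and the commented usage assumes the same `api` calls as the cell above.

```python
def require(info, kind, name):
    """Return info, or raise a uniform error if a lookup by name returned None."""
    if info is None:
        raise RuntimeError("{} {!r} not found".format(kind, name))
    return info


# Hypothetical refactor of the cell above:
# team = require(api.team.get_info_by_name(team_name), "Team", team_name)
# workspace = require(api.workspace.get_info_by_name(team.id, workspace_name), "Workspace", workspace_name)
# project = require(api.project.get_info_by_name(workspace.id, project_name), "Project", project_name)
```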

Get Source Project Meta

In [5]:
meta_json = api.project.get_meta(project.id)
meta = sly.ProjectMeta.from_json(meta_json)
print("Source ProjectMeta: \n", meta)
Out [5]:
Source ProjectMeta: 
 ProjectMeta:
Object Classes
+--------+-----------+----------------+
|  Name  |   Shape   |     Color      |
+--------+-----------+----------------+
|  bike  | Rectangle | [246, 255, 0]  |
|  car   |  Polygon  | [190, 85, 206] |
|  dog   |  Polygon  |  [253, 0, 0]   |
| person |   Bitmap  |  [0, 255, 18]  |
+--------+-----------+----------------+
Image Tags
+-------------+--------------+-----------------------+
|     Name    |  Value type  |    Possible values    |
+-------------+--------------+-----------------------+
| cars_number |  any_number  |          None         |
|     like    |     none     |          None         |
|   situated  | oneof_string | ['inside', 'outside'] |
+-------------+--------------+-----------------------+
Object Tags
+---------------+--------------+-----------------------+
|      Name     |  Value type  |    Possible values    |
+---------------+--------------+-----------------------+
|   car_color   |  any_string  |          None         |
| person_gender | oneof_string |   ['male', 'female']  |
|  vehicle_age  | oneof_string | ['modern', 'vintage'] |
+---------------+--------------+-----------------------+

Construct Destination ProjectMeta

In [6]:
class_name_mapping = {}
bitmap_classes = [class_bg]
for obj_class in meta.obj_classes:
    class_name_mapping[obj_class.name] = '{}_bitmap'.format(obj_class.name)
    new_obj_class = sly.ObjClass(class_name_mapping[obj_class.name], sly.Bitmap, color=obj_class.color)
    bitmap_classes.append(new_obj_class)

dst_meta = meta.clone(obj_classes=sly.ObjClassCollection())
dst_meta = dst_meta.add_obj_classes(bitmap_classes)
dst_meta = dst_meta.add_tag_metas([tag_meta_train, tag_meta_val])

print("Destination ProjectMeta:\n", dst_meta)
Out [6]:
Destination ProjectMeta:
 ProjectMeta:
Object Classes
+---------------+-----------+----------------+
|      Name     |   Shape   |     Color      |
+---------------+-----------+----------------+
|       bg      | Rectangle |   [0, 0, 64]   |
|  bike_bitmap  |   Bitmap  | [246, 255, 0]  |
|   car_bitmap  |   Bitmap  | [190, 85, 206] |
|   dog_bitmap  |   Bitmap  |  [253, 0, 0]   |
| person_bitmap |   Bitmap  |  [0, 255, 18]  |
+---------------+-----------+----------------+
Image Tags
+-------------+--------------+-----------------------+
|     Name    |  Value type  |    Possible values    |
+-------------+--------------+-----------------------+
| cars_number |  any_number  |          None         |
|     like    |     none     |          None         |
|   situated  | oneof_string | ['inside', 'outside'] |
|    train    |     none     |          None         |
|     val     |     none     |          None         |
+-------------+--------------+-----------------------+
Object Tags
+---------------+--------------+-----------------------+
|      Name     |  Value type  |    Possible values    |
+---------------+--------------+-----------------------+
|   car_color   |  any_string  |          None         |
| person_gender | oneof_string |   ['male', 'female']  |
|  vehicle_age  | oneof_string | ['modern', 'vintage'] |
+---------------+--------------+-----------------------+
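Every destination class except `bg` has the `Bitmap` shape, so rectangles and polygons must later be rasterized into masks (the notebook does this with `sly.geometry_to_bitmap`). The sketch below shows the idea for a rectangle only, using plain NumPy. It assumes inclusive pixel bounds (`top`, `left`, `bottom`, `right`); polygons would need a polygon-fill routine instead.

```python
import numpy as np


def rectangle_to_mask(top, left, bottom, right, image_shape):
    """Rasterize an axis-aligned box with inclusive pixel bounds into a boolean mask."""
    mask = np.zeros(image_shape[:2], dtype=bool)
    mask[top:bottom + 1, left:right + 1] = True
    return mask


mask = rectangle_to_mask(1, 2, 3, 4, (6, 8, 3))
print(mask.sum())  # 3 rows x 3 columns = 9 pixels
```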

Create Destination Project

In [7]:
# check if destination project already exists. If yes - generate new free name
if api.project.exists(workspace.id, dst_project_name):
    dst_project_name = api.project.get_free_name(workspace.id, dst_project_name)
print("Destination project name: ", dst_project_name)
Out [7]:
Destination project name:  tutorial_project_aug_segmentation
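Here the name was free, so it is unchanged. When it is taken, `api.project.get_free_name` returns an unused variant. The naming scheme it uses is not shown in this notebook; the sketch below assumes a hypothetical `_001`, `_002`, ... suffix just to illustrate the idea.

```python
def get_free_name_sketch(existing, name):
    """Return name if unused, else the first name_NNN not in existing (hypothetical scheme)."""
    if name not in existing:
        return name
    index = 1
    while True:
        candidate = "{}_{:03d}".format(name, index)
        if candidate not in existing:
            return candidate
        index += 1


taken = {"tutorial_project_aug_segmentation", "tutorial_project_aug_segmentation_001"}
print(get_free_name_sketch(taken, "tutorial_project_aug_segmentation"))
```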
In [8]:
dst_project = api.project.create(workspace.id, dst_project_name)
api.project.update_meta(dst_project.id, dst_meta.to_json())
print("Destination project has been created: id={}, name={!r}".format(dst_project.id, dst_project.name))
Out [8]:
Destination project has been created: id=1333, name='tutorial_project_aug_segmentation'

Iterate over all images, augment them and upload them to the destination project

In [9]:
def process(img, ann):
    results = []
    
    bitmap_labels = []
    for label in ann.labels:
        new_class = dst_meta.get_obj_class(class_name_mapping[label.obj_class.name])
        [new_geometry] = sly.geometry_to_bitmap(label.geometry)
        new_label = label.clone(obj_class=new_class, geometry=new_geometry)
        bitmap_labels.append(new_label)
    ann_new = ann.clone(labels=bitmap_labels)
    
    results.append((img, ann_new))

    img_lr, ann_lr = sly.aug.fliplr(img, ann_new)
    results.append((img_lr, ann_lr))

    crops = []
    for cur_img, cur_ann in results:
        for _ in range(image_multiplier):
            res_img, res_ann = sly.aug.random_crop_fraction(cur_img, cur_ann, (0.6, 0.9), (0.6, 0.9))
            crops.append((res_img, res_ann))
    results.extend(crops)

    tagged_results = []
    for cur_img, cur_ann in results:
        bg_label = sly.Label(sly.Rectangle.from_array(cur_img), class_bg)
        # Order matters: the background goes first so that object labels are drawn over it.
        # See "tutorial 1" for more info.
        cur_ann = cur_ann.clone(labels=[bg_label] + cur_ann.labels)
        
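        # random.random() is in [0, 1), so "<" gives a validation probability of exactly validation_fraction
        tag = sly.Tag(tag_meta_val) if random.random() < validation_fraction else sly.Tag(tag_meta_train)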
        cur_ann = cur_ann.add_tag(tag)
        
        tagged_results.append((cur_img, cur_ann))
        
    res_imgs = [img for img, ann in tagged_results]
    res_anns = [ann for img, ann in tagged_results]
    return res_imgs, res_anns
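`process()` keeps the original, adds a horizontal flip, and then takes `image_multiplier` random crops of each of those two, so every source image yields 2 * (1 + image_multiplier) results (12 for `image_multiplier = 5`, which matches the 12 axes in the figure at the end). The sketch below mimics that flow with plain NumPy on an image plus a single mask. It is a hypothetical reimplementation for illustration, not the code behind `sly.aug.fliplr` or `sly.aug.random_crop_fraction`, and it assumes crop sizes are drawn uniformly from the given fraction ranges.

```python
import random
import numpy as np


def fliplr(img, mask):
    """Mirror an image and its mask along the width axis so they stay aligned."""
    return img[:, ::-1], mask[:, ::-1]


def random_crop_fraction(img, mask, height_range, width_range, rng):
    """Cut the same random window from img and mask (hypothetical, not sly.aug code)."""
    h, w = img.shape[:2]
    crop_h = max(1, int(round(h * rng.uniform(*height_range))))
    crop_w = max(1, int(round(w * rng.uniform(*width_range))))
    top = rng.randint(0, h - crop_h)    # randint is inclusive on both ends
    left = rng.randint(0, w - crop_w)
    window = (slice(top, top + crop_h), slice(left, left + crop_w))
    return img[window], mask[window]


def augment(img, mask, image_multiplier, rng):
    """Original + horizontal flip, then image_multiplier crops of each of those two."""
    results = [(img, mask), fliplr(img, mask)]
    crops = [random_crop_fraction(i, m, (0.6, 0.9), (0.6, 0.9), rng)
             for i, m in results
             for _ in range(image_multiplier)]
    return results + crops


rng = random.Random(0)
img = np.arange(100 * 120 * 3, dtype=np.uint32).reshape(100, 120, 3)
mask = np.zeros((100, 120), dtype=bool)
outputs = augment(img, mask, image_multiplier=5, rng=rng)
print(len(outputs))  # 2 * (1 + 5) = 12
```

Cropping the image and the mask with the same window is what keeps them aligned; `sly.aug` does the equivalent for the whole annotation.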
In [10]:
aug_results_debug = None

for dataset in api.dataset.get_list(project.id):
    print('Dataset: {}'.format(dataset.name), flush=True)
    dst_dataset = api.dataset.create(dst_project.id, dataset.name)

    for image in tqdm(api.image.get_list(dataset.id)):
        img = api.image.download_np(image.id)
        ann_json = api.annotation.download(image.id).annotation
        ann = sly.Annotation.from_json(ann_json, meta)

        aug_imgs, aug_anns = process(img, ann)
        names = sly.generate_names(image.name, len(aug_imgs))
        
        dst_image_infos = api.image.upload_nps(dst_dataset.id, names, aug_imgs)
        dst_image_ids = [img_info.id for img_info in dst_image_infos]
        api.annotation.upload_anns(dst_image_ids, aug_anns)
        
        if aug_results_debug is None:
            aug_results_debug = list(zip(aug_imgs, aug_anns))
Out [10]:
Dataset: dataset_01
100%|██████████| 3/3 [00:02<00:00,  1.36it/s]
Dataset: dataset_02
100%|██████████| 2/2 [00:01<00:00,  1.54it/s]
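`process()` draws the train/val tag independently for every augmented image, so crops and flips of the same source image can end up in both splits. If you prefer to keep all variants of a source image together, one option is to decide the split once per source image, for example from a hash of its name. The helper below is a hypothetical sketch; wiring it in would also require passing the chosen tag into `process()`.

```python
import hashlib


def choose_split(image_name, validation_fraction):
    """Deterministically assign a source image to 'train' or 'val' from its name."""
    digest = hashlib.md5(image_name.encode("utf-8")).hexdigest()
    # Map the first 32 bits of the hash to a number in [0, 1) and compare with the fraction.
    bucket = int(digest[:8], 16) / 0x100000000
    return "val" if bucket < validation_fraction else "train"


print(choose_split("image_01.jpg", 0.1))
```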

Visualize the augmentations of the first image

A minimal version draws each augmented image in its own figure:
for aug_img, aug_ann in aug_results_debug:
    draw_img = np.copy(aug_img)
    aug_ann.draw(draw_img)
    plt.figure()
    plt.imshow(draw_img)
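The "Order matters" comment in `process()` suggests labels are drawn in list order, so later labels paint over earlier ones. The NumPy sketch below illustrates why the full-image `bg` rectangle must come first: drawn last, it would hide every object. `render_labels` is a hypothetical helper, not `Annotation.draw`.

```python
import numpy as np


def render_labels(image_shape, labels):
    """Paint (mask, color) pairs onto a blank RGB canvas in list order."""
    canvas = np.zeros(tuple(image_shape[:2]) + (3,), dtype=np.uint8)
    for mask, color in labels:
        canvas[mask] = color  # later labels overwrite earlier ones
    return canvas


bg = np.ones((4, 4), dtype=bool)        # full-image background
obj = np.zeros((4, 4), dtype=bool)
obj[1:3, 1:3] = True                    # a 2x2 object in the middle

good = render_labels((4, 4), [(bg, (0, 0, 64)), (obj, (253, 0, 0))])
bad = render_labels((4, 4), [(obj, (253, 0, 0)), (bg, (0, 0, 64))])
```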
In [11]:
f = plt.figure(figsize=(20, 20))

ncols = 3
nrows = math.ceil(len(aug_results_debug) / ncols)

for index, (aug_img, aug_ann) in enumerate(aug_results_debug, start=1):
    draw_img = np.copy(aug_img)
    aug_ann.draw(draw_img, thickness=5)
    
    f.add_subplot(nrows, ncols, index)
    plt.imshow(draw_img)

plt.show()
Out [11]:
<Figure size 1440x1440 with 12 Axes>
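The grid cell computes `nrows = math.ceil(len(aug_results_debug) / ncols)`, so 12 results with 3 columns give 4 rows, and `add_subplot` fills the cells row by row starting from index 1. (The reported 1440x1440 size is `figsize=(20, 20)` inches at 72 dpi.) A small sketch of that layout arithmetic, with hypothetical helper names:

```python
import math


def grid_shape(n_images, ncols=3):
    """Rows and columns needed to show n_images with a fixed number of columns."""
    return math.ceil(n_images / ncols), ncols


def subplot_position(index, ncols=3):
    """Convert a 1-based subplot index (as passed to add_subplot) into a 0-based (row, col)."""
    return (index - 1) // ncols, (index - 1) % ncols


print(grid_shape(12))        # (4, 3)
print(subplot_position(12))  # (3, 2): last row, last column
```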

More Info

ID: 30
First released: 3 months ago
Last updated: a month ago
Owner: s