Notebooks

Training data for segmentation

iPython Project Created 4 days ago Free
Script augments input projects to prepare training dataset for segmentation (or instance segmentation) models
Free Signup

Training data for segmentation

Script augments input projects to prepare training dataset for segmentation (or instance segmentation) models.

This Notebook is a replacement for DTL "Create train set -> Segmentation"

Input:

  • Existing annotated Project

Output:

  • New augmented Project, ready for model training

Configuration

Edit the following settings for your own case

In [1]:
%matplotlib inline
import supervisely_lib as sly
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import math
In [2]:
# Names of the source context: the project is looked up as team -> workspace -> project
team_name = "jupyter_tutorials"
workspace_name = "cookbook"
project_name = "tutorial_project"

# Desired name of the augmented output project (a free variant is chosen later if it is taken)
dst_project_name = "tutorial_project_segm"

# Threshold compared against random.random() for every produced image:
# draws <= this value get the 'val' tag, all others get 'train'
validation_portion = 0.1
# Number of random crops generated for each base image (the original and its flipped copy)
image_multiplier = 5

# Background class: a rectangle covering the whole image is added with this class to every result
class_bg = sly.ObjClass('bg', sly.Rectangle, color=[0, 0, 64])
# Value-less image tags that mark the train / validation split
tag_meta_train = sly.TagMeta('train', sly.TagValueType.NONE)
tag_meta_val = sly.TagMeta('val', sly.TagValueType.NONE)

# Obtain the server address and your API token from environment variables
# (a missing variable raises KeyError here).
# Edit those values if you run this notebook on your own PC
address = os.environ['SERVER_ADDRESS']
token = os.environ['API_TOKEN']
In [3]:
# Initialize the API object from the configured server address and token;
# every later cell talks to the server through `api`
api = sly.Api(address, token)

Verify input values

Test that context (team / workspace / project) exists

In [4]:
# Get the team, workspace and project infos by name.
# Each lookup needs the id of its parent, so they run strictly in this order,
# and a missing entity stops the notebook with a RuntimeError before anything is printed.

team = api.team.get_info_by_name(team_name)
if team is None:
    raise RuntimeError("Team {!r} not found".format(team_name))

workspace = api.workspace.get_info_by_name(team.id, workspace_name)
if workspace is None:
    raise RuntimeError("Workspace {!r} not found".format(workspace_name))
    
project = api.project.get_info_by_name(workspace.id, project_name)
if project is None:
    raise RuntimeError("Project {!r} not found".format(project_name))
    
# Echo the resolved ids and names so the reader can confirm the right context is used
print("Team: id={}, name={}".format(team.id, team.name))
print("Workspace: id={}, name={}".format(workspace.id, workspace.name))
print("Project: id={}, name={}".format(project.id, project.name))
Out [4]:
Team: id=30, name=jupyter_tutorials
Workspace: id=76, name=cookbook
Project: id=898, name=tutorial_project

Get Source Project Meta

In [5]:
# Download the source project's meta as JSON and parse it into a ProjectMeta.
# `project` was already resolved and checked for None in the lookup cell above,
# so it is reused here instead of repeating the same API request without a check.
meta_json = api.project.get_meta(project.id)
meta = sly.ProjectMeta.from_json(meta_json)
print("Source ProjectMeta: \n", meta)
Out [5]:
Source ProjectMeta: 
 ProjectMeta:
Object Classes
+--------+-----------+----------------+
|  Name  |   Shape   |     Color      |
+--------+-----------+----------------+
|  bike  | Rectangle | [246, 255, 0]  |
|  car   |  Polygon  | [190, 85, 206] |
|  dog   |  Polygon  |  [253, 0, 0]   |
| person |   Bitmap  |  [0, 255, 18]  |
+--------+-----------+----------------+
Image Tags
+-------------+--------------+-----------------------+
|     Name    |  Value type  |    Possible values    |
+-------------+--------------+-----------------------+
|   situated  | oneof_string | ['inside', 'outside'] |
|     like    |     none     |          None         |
| cars_number |  any_number  |          None         |
+-------------+--------------+-----------------------+
Object Tags
+---------------+--------------+-----------------------+
|      Name     |  Value type  |    Possible values    |
+---------------+--------------+-----------------------+
| person_gender | oneof_string |   ['male', 'female']  |
|  vehicle_age  | oneof_string | ['modern', 'vintage'] |
|   car_color   |  any_string  |          None         |
+---------------+--------------+-----------------------+

Construct Destination ProjectMeta

In [6]:
def process_meta(input_meta):
    """Build the destination ProjectMeta from the source one.

    Every source object class is replaced by a Bitmap class named
    '<source name>_bitmap' that keeps the source color. The notebook-level
    background class (class_bg) and the split tag metas (tag_meta_train,
    tag_meta_val) are then added on top.

    Args:
        input_meta: ProjectMeta of the source project.

    Returns:
        A tuple (output_meta, classes_mapping), where classes_mapping is a dict
        from source class name to destination class name.
    """
    # Start from a clone of the source meta whose object classes are emptied
    result_meta = input_meta.clone(obj_classes=sly.ObjClassCollection())

    classes_mapping = {}
    for src_class in input_meta.obj_classes:
        bitmap_name = '{}_bitmap'.format(src_class.name)
        classes_mapping[src_class.name] = bitmap_name
        bitmap_class = sly.ObjClass(bitmap_name, sly.Bitmap, color=src_class.color)
        # add_obj_class returns a new meta, so the result has to be re-bound
        result_meta = result_meta.add_obj_class(bitmap_class)

    result_meta = result_meta.add_obj_class(class_bg)
    for split_tag_meta in (tag_meta_train, tag_meta_val):
        result_meta = result_meta.add_img_tag_meta(split_tag_meta)
    return result_meta, classes_mapping
In [7]:
# Derive the destination meta and the source -> destination class-name mapping;
# both are used later when labels are converted in process()
dst_meta, classes_mapping = process_meta(meta)
print("Destination ProjectMeta:\n", dst_meta)
Out [7]:
Destination ProjectMeta:
 ProjectMeta:
Object Classes
+---------------+-----------+----------------+
|      Name     |   Shape   |     Color      |
+---------------+-----------+----------------+
|  bike_bitmap  |   Bitmap  | [246, 255, 0]  |
|   car_bitmap  |   Bitmap  | [190, 85, 206] |
|   dog_bitmap  |   Bitmap  |  [253, 0, 0]   |
| person_bitmap |   Bitmap  |  [0, 255, 18]  |
|       bg      | Rectangle |   [0, 0, 64]   |
+---------------+-----------+----------------+
Image Tags
+-------------+--------------+-----------------------+
|     Name    |  Value type  |    Possible values    |
+-------------+--------------+-----------------------+
|   situated  | oneof_string | ['inside', 'outside'] |
|     like    |     none     |          None         |
| cars_number |  any_number  |          None         |
|    train    |     none     |          None         |
|     val     |     none     |          None         |
+-------------+--------------+-----------------------+
Object Tags
+---------------+--------------+-----------------------+
|      Name     |  Value type  |    Possible values    |
+---------------+--------------+-----------------------+
| person_gender | oneof_string |   ['male', 'female']  |
|  vehicle_age  | oneof_string | ['modern', 'vintage'] |
|   car_color   |  any_string  |          None         |
+---------------+--------------+-----------------------+

Create Destination project

In [8]:
# Check if the destination project already exists in the workspace.
# If yes - ask the server for a free name based on the desired one.
# NOTE: this overwrites dst_project_name, so the cell is not idempotent once the project is created.
if api.project.exists(workspace.id, dst_project_name):
    dst_project_name = api.project.get_free_name(workspace.id, dst_project_name)
print("Destination project name: ", dst_project_name)
Out [8]:
Destination project name:  tutorial_project_segm_001
In [9]:
# Create the destination project in the same workspace as the source,
# then push the destination meta (classes and tags) to it
dst_project = api.project.create(workspace.id, dst_project_name)
api.project.update_meta(dst_project.id, dst_meta.to_json())
print("Destination project has been created: id={}, name={!r}".format(dst_project.id, dst_project.name))
Out [9]:
Destination project has been created: id=1150, name='tutorial_project_segm_001'

Iterate over all images, augment them and upload to destination project

In [10]:
def process(img, ann):
    """Produce the augmented, split-tagged variants of one annotated image.

    Steps:
      1. Convert every label to a bitmap of its mapped destination class.
      2. Keep the converted original and add a left-right flipped copy.
      3. For each of those two, add `image_multiplier` random crops, with the
         fraction range (0.6, 0.9) passed for both crop dimensions.
      4. Prepend a full-image background rectangle to every variant and tag it
         'val' when random.random() <= validation_portion, otherwise 'train'.

    Relies on the notebook-level names dst_meta, classes_mapping, class_bg,
    tag_meta_train, tag_meta_val, image_multiplier and validation_portion.

    Args:
        img: image array of the source image.
        ann: sly.Annotation of the source image.

    Returns:
        List of (image, annotation) tuples: original, flipped, crops of the
        original, crops of the flipped copy - in that order.
    """
    # Step 1: rebuild the annotation label by label with bitmap geometries
    bitmap_ann = ann.clone(labels=[])
    for src_label in ann.labels:
        dst_class = dst_meta.get_obj_class(classes_mapping[src_label.obj_class.name])
        # geometry_to_bitmap returns a list; exactly one element is expected here
        [bitmap_geometry] = sly.geometry_to_bitmap(src_label.geometry)
        converted = src_label.clone(obj_class=dst_class, geometry=bitmap_geometry)
        bitmap_ann = bitmap_ann.add_label(converted)

    # Step 2: the converted original plus its horizontal mirror
    flipped_img, flipped_ann = sly.aug.fliplr(img, bitmap_ann)
    variants = [(img, bitmap_ann), (flipped_img, flipped_ann)]

    # Step 3: iterate over a snapshot so that only the two base pairs are cropped,
    # while the crops are appended after them
    for base_img, base_ann in list(variants):
        for _ in range(image_multiplier):
            crop_img, crop_ann = sly.aug.random_crop_fraction(base_img, base_ann, (0.6, 0.9), (0.6, 0.9))
            variants.append((crop_img, crop_ann))

    # Step 4: add the background label and the train/val tag
    tagged_results = []
    for var_img, var_ann in variants:
        background = sly.Label(sly.Rectangle.from_array(var_img), class_bg)
        # The background goes first: label order matters (see "tutorial 1" for more info)
        with_background = var_ann.clone(labels=[background] + var_ann.labels)

        if random.random() <= validation_portion:
            split_tag = sly.Tag(tag_meta_val)
        else:
            split_tag = sly.Tag(tag_meta_train)

        tagged_results.append((var_img, with_background.add_tag(split_tag)))
    return tagged_results
In [11]:
# Keeps the augmentation results of the very first processed image for the visualization cells below
aug_results_debug = None

# Mirror the source project structure: one destination dataset per source dataset
for dataset in api.dataset.get_list(project.id):
    print('Dataset: {}'.format(dataset.name))
    dst_dataset = api.dataset.create(dst_project.id, dataset.name)

    for image in tqdm(api.image.get_list(dataset.id)):
        # Download the image and its annotation; the annotation is parsed against the SOURCE meta
        img = api.image.download_np(image.id)
        ann_json = api.annotation.download(image.id).annotation
        ann = sly.Annotation.from_json(ann_json, meta)

        # All augmented (image, annotation) pairs produced from this single source image
        aug_results = process(img, ann)

        if aug_results_debug is None:
            # Shallow copy of the list, taken only once (first image of the first dataset)
            aug_results_debug = aug_results.copy()

        for aug_img, aug_ann in aug_results:
            # All variants derive from the same source name, so a free name is requested for each one
            dst_img_name = api.image.get_free_name(dst_dataset.id, image.name)
            # Upload the pixels, register the image in the destination dataset, then attach its annotation
            dst_img_hash = api.image.upload_np(aug_img, image.ext)
            dst_image = api.image.add(dst_dataset.id, dst_img_name, dst_img_hash)
            api.annotation.upload(dst_image.id, aug_ann.to_json())
Out [11]:
  0%|          | 0/3 [00:00<?, ?it/s]
Dataset: dataset_01
100%|██████████| 3/3 [00:10<00:00,  3.49s/it]
  0%|          | 0/2 [00:00<?, ?it/s]
Dataset: dataset_02
100%|██████████| 2/2 [00:06<00:00,  3.13s/it]

Visualize first image augmentations

# Simple variant: draw each annotation over a copy of its image and show it in its own figure
for aug_img, aug_ann in aug_results_debug:
    # Draw on a copy so the stored augmented image stays untouched
    draw_img = np.copy(aug_img)
    aug_ann.draw(draw_img)
    plt.figure()
    plt.imshow(draw_img)
In [12]:
# Show all augmentations of the first image in one figure, laid out as a 3-column grid
f = plt.figure(figsize=(20, 20))

ncols = 3
# Enough rows to fit every augmented image
nrows = math.ceil(len(aug_results_debug) / ncols)

# Subplot positions are 1-based, hence start=1
for index, (aug_img, aug_ann) in enumerate(aug_results_debug, start=1):
    # Draw on a copy so the stored augmented image stays untouched
    canvas = np.copy(aug_img)
    aug_ann.draw(canvas, thickness=5)

    # add_subplot makes the new axes current, so plt.imshow renders into it
    f.add_subplot(nrows, ncols, index)
    plt.imshow(canvas)

plt.show()
Out [12]:
<Figure size 1440x1440 with 12 Axes>

More Info

ID
30
First released
4 days ago
Last updated
3 hours ago

Owner

s