71 lines
3.0 KiB
Python
71 lines
3.0 KiB
Python
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
"""A module for helper tensorflow ops."""
|
||
|
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
                                     image_width):
  """Transforms the box masks back to full image masks.

  Embeds masks in bounding boxes of larger masks whose shapes correspond to
  image shape.

  Args:
    box_masks: A tf.float32 tensor of size [num_masks, mask_height, mask_width].
    boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
           corners. Row i contains [ymin, xmin, ymax, xmax] of the box
           corresponding to mask i. Note that the box corners are in
           normalized coordinates.
    image_height: Image height. The output mask will have the same height as
                  the image height.
    image_width: Image width. The output mask will have the same width as the
                 image width.

  Returns:
    A tf.float32 tensor of size [num_masks, image_height, image_width].
    Pixels that fall outside a mask's box are filled with 0. A box whose
    height or width is below a small epsilon (e.g. a zero-area box) is
    treated as having that minimal extent, rather than producing inf/NaN
    crop coordinates. When num_masks is 0, the result has shape
    [0, image_height, image_width].
  """
  # Smallest box extent (normalized coordinates) used as a divisor. It guards
  # against division by zero for degenerate boxes.
  min_box_extent = 1e-4

  # TODO(rathodv): Make this a public function.
  def reframe_box_masks_to_image_masks_default():
    """The default function when there are more than 0 box masks."""

    def transform_boxes_relative_to_boxes(boxes, reference_boxes):
      """Expresses `boxes` in the normalized frame of `reference_boxes`."""
      boxes = tf.reshape(boxes, [-1, 2, 2])
      min_corner = tf.expand_dims(reference_boxes[:, 0:2], 1)
      max_corner = tf.expand_dims(reference_boxes[:, 2:4], 1)
      # A zero-height or zero-width reference box would divide by zero and
      # yield inf/NaN coordinates, so clamp the extent to a small positive
      # value.
      extent = tf.maximum(max_corner - min_corner, min_box_extent)
      transformed_boxes = (boxes - min_corner) / extent
      return tf.reshape(transformed_boxes, [-1, 4])

    box_masks_expanded = tf.expand_dims(box_masks, axis=3)
    num_boxes = tf.shape(box_masks_expanded)[0]
    # The full image is the unit box [0, 0, 1, 1]. Expressing it relative to
    # each mask's box gives the crop window, in that mask's coordinates,
    # that spans the whole image.
    unit_boxes = tf.concat(
        [tf.zeros([num_boxes, 2]), tf.ones([num_boxes, 2])], axis=1)
    reverse_boxes = transform_boxes_relative_to_boxes(unit_boxes, boxes)
    # The box indices and crop size are passed positionally because the
    # keyword for the indices is `box_ind` in TF1 but `box_indices` in TF2.
    # Regions of the crop window outside the mask are filled with 0.
    return tf.image.crop_and_resize(
        box_masks_expanded,
        reverse_boxes,
        tf.range(num_boxes),
        [image_height, image_width],
        extrapolation_value=0.0)

  # Only run the crop path when there is at least one mask. Otherwise return
  # an empty [0, image_height, image_width, 1] tensor of the same dtype.
  image_masks = tf.cond(
      tf.shape(box_masks)[0] > 0,
      reframe_box_masks_to_image_masks_default,
      lambda: tf.zeros([0, image_height, image_width, 1], dtype=tf.float32))
  return tf.squeeze(image_masks, axis=3)
|