import random
import cv2
# Public names exported by ``from <this module> import *``.
__all__ = [
    'to_tuple',
    'BasicTransform',
    'DualTransform',
    'ImageOnlyTransform',
    'NoOp',
]
def to_tuple(param, low=None):
    """Normalize a range-like argument into a tuple.

    Args:
        param: ``None``, a list/tuple, or a scalar value.
        low: optional second bound, used only when ``param`` is a scalar.

    Returns:
        * ``None`` when ``param`` is ``None``;
        * ``tuple(param)`` when ``param`` is a list or tuple (``low`` ignored);
        * ``(-param, param)`` for a scalar when ``low`` is ``None``;
        * otherwise ``(low, param)`` or ``(param, low)``, ordered so that the
          first element is ``low`` only when ``low < param``.
    """
    if param is None:
        return None
    if isinstance(param, (list, tuple)):
        return tuple(param)
    if low is None:
        # Symmetric range around zero.
        return -param, param
    if low < param:
        return low, param
    return param, low
class BasicTransform(object):
    """Base class for transforms that are called with keyword targets.

    Calling an instance with keyword arguments (for example ``image=...``)
    either applies the transform or returns the arguments unchanged.

    * The transform is applied when ``random.random() < p`` or when
      ``always_apply`` is true.
    * When applied, every non-``None`` argument is passed to the handler
      registered for its name in :attr:`targets`.
    * Arguments with no registered handler are returned as-is, and
      ``None`` values stay ``None``.

    Args:
        always_apply (bool): apply the transform regardless of ``p``.
        p (float): probability of applying the transform.
    """

    def __init__(self, always_apply=False, p=0.5):
        self.p = p
        self.always_apply = always_apply
        # Extra target name -> existing target name; see add_targets().
        self._additional_targets = {}

    def __call__(self, **kwargs):
        # The random draw is made on every call (before always_apply is
        # consulted), so the module-level RNG advances consistently.
        should_apply = random.random() < self.p or self.always_apply
        if not should_apply:
            return kwargs

        params = self.update_params(self.get_params(), **kwargs)

        if self.targets_as_params:
            needed = {name: kwargs[name] for name in self.targets_as_params}
            params.update(self.get_params_dependent_on_targets(needed))

        result = {}
        for name, value in kwargs.items():
            if value is None:
                result[name] = None
                continue
            handler = self._get_target_function(name)
            # Some handlers also need the raw values of other targets.
            extra = {dep: kwargs[dep] for dep in self.target_dependence.get(name, [])}
            result[name] = handler(value, **dict(params, **extra))
        return result

    def _get_target_function(self, key):
        """Return the handler for target ``key``.

        Names registered through :meth:`add_targets` resolve to the handler
        of the target they alias. Names without a handler get an identity
        function that returns its first argument unchanged.
        """
        alias = self._additional_targets.get(key, key)
        return self.targets.get(alias, lambda x, **p: x)

    def apply(self, img, **params):
        """Transform the ``image`` target; must be overridden."""
        raise NotImplementedError

    def get_params(self):
        """Return per-call parameters (none by default)."""
        return {}

    @property
    def targets(self):
        """Dict mapping target names to handler callables.

        Subclasses must override this, e.g.
        ``{'image': self.apply, 'mask': self.apply_to_mask}``. Each handler is
        called as ``handler(value, **params)``.
        """
        raise NotImplementedError

    def update_params(self, params, **kwargs):
        """Add parameters shared by every target to ``params`` and return it.

        ``self.interpolation`` is copied in when the transform defines it.
        ``kwargs['image'].shape[1]`` is stored as ``cols`` and
        ``kwargs['image'].shape[0]`` as ``rows``. The ``image`` keyword is
        therefore required.
        """
        if hasattr(self, 'interpolation'):
            params['interpolation'] = self.interpolation
        shape = kwargs['image'].shape
        cols, rows = shape[1], shape[0]
        params['cols'] = cols
        params['rows'] = rows
        return params

    @property
    def target_dependence(self):
        """Map a target name to names of other arguments its handler needs.

        Those arguments' values are forwarded to the handler as extra
        keyword parameters. Empty by default.
        """
        return {}

    def add_targets(self, additional_targets):
        """Register extra target names that reuse an existing target's handler.

        The given dict replaces any previously registered aliases. A call
        that uses these aliases must still include an ``image`` keyword,
        because :meth:`update_params` reads its shape.

        Args:
            additional_targets (dict): new target name -> existing target name,
                e.g. ``{'image2': 'image'}`` or
                ``{'obj1_mask': 'mask', 'obj2_mask': 'mask'}``.
        """
        self._additional_targets = additional_targets

    @property
    def targets_as_params(self):
        """Names of arguments passed to :meth:`get_params_dependent_on_targets`."""
        return []

    def get_params_dependent_on_targets(self, params):
        """Compute parameters from the targets listed in :attr:`targets_as_params`.

        Must be overridden by transforms that declare such targets.
        """
        raise NotImplementedError
class NoOp(DualTransform):
    """Does nothing: every handler returns its input unchanged.

    ``apply``, ``apply_to_mask`` and ``apply_to_bbox`` each return their first
    argument as-is and ignore ``params``.
    """

    def apply_to_bbox(self, bbox, **params):
        return bbox

    def apply(self, img, **params):
        return img

    def apply_to_mask(self, img, **params):
        return img