UNetModel(
channels,
skip_connection_names,
model_factory,
depth=None,
use_input_scale=False,
use_output_scale=False,
activation="sigmoid",
enable_skip_connections=True,
allow_preconvolution=False,
additional_conv_layers=None,
downscale_layer_factory=None,
post_layer_factory=None,
**kwargs
)
Bases: ModelFactory
Factory method for creating a UNet model
Creates a Resnet50 UNet variation for pixelwise output.
The output has the same dimension as the input.
Parameters
skip_connection_names: List[str]
The names of the layers to use as skip connections.
depth: Optional[int], default None
Describes the amount of skip_connections used.
If not given it uses the maximal amount w.r.t. the given
channels or the maximal MobileNet depth (count of downsamplings, 5).
channels: Optional[List[int]], default [64, 128, 256, 512]
How many channels after each upsampling should be used.
use_input_scale: bool, default False
If true the input is divided by 255.0
use_output_scale: bool, default False
If true the output is multiplied by 255.0
activation: Optional[str], default sigmoid
Final activation, used for last layer
enable_skip_connections: Optional[bool], default True,
Determines, whether to use skip connections
allow_preconvolution: bool, default False
Uses a convolution to normalize the amount of layers to three.
**kwargs: Optional, default None
Additional params passed to the model factory.
additional_conv_layers: Optional[List[int]], default None
Additional conv layers to add between the input and model.
downscale_layer_factory: Optional[LayerFactory], default None
Factory to create the downscale layers.
post_layer_factory: Optional[LayerFactory], default None
Factory to create the post layers.
Source code in niceml/dlframeworks/keras/models/unets.py
def __init__(  # pylint: disable=too-many-arguments,too-many-locals
    self,
    channels: Optional[List[int]],
    skip_connection_names: List[str],
    model_factory: Callable,
    depth: Optional[int] = None,
    use_input_scale: bool = False,
    use_output_scale: bool = False,
    activation: str = "sigmoid",
    enable_skip_connections: bool = True,
    allow_preconvolution: bool = False,
    additional_conv_layers: Optional[List[int]] = None,
    downscale_layer_factory: Optional[LayerFactory] = None,
    post_layer_factory: Optional[LayerFactory] = None,
    **kwargs,
):
    """
    Creates a UNet variation for pixelwise output on top of the encoder
    built by ``model_factory``.
    The output has the same dimension as the input.

    Parameters
    ----------
    channels: Optional[List[int]], default None
        How many channels after each upsampling should be used.
        If None, ``[64, 128, 256, 512]`` is used.
    skip_connection_names: List[str]
        The names of the layers to use as skip connections.
    model_factory: Callable
        Factory callable that builds the encoder backbone; it is called
        with ``input_tensor``, ``weights="imagenet"``,
        ``include_top=False`` and the additional ``**kwargs``.
    depth: Optional[int], default None
        Describes the amount of skip_connections used.
        If not given it uses the maximal amount w.r.t. the given
        channels (``len(channels) + 1``).
    use_input_scale: bool, default False
        If true the input is divided by 255.0
    use_output_scale: bool, default False
        If true the output is multiplied by 255.0
    activation: str, default "sigmoid"
        Final activation, used for the last layer
    enable_skip_connections: bool, default True
        Determines whether to use skip connections
    allow_preconvolution: bool, default False
        Uses a convolution to normalize the amount of channels to three.
    additional_conv_layers: Optional[List[int]], default None
        Additional conv layers to add between the input and model.
    downscale_layer_factory: Optional[LayerFactory], default None
        Factory to create the downscale layers.
    post_layer_factory: Optional[LayerFactory], default None
        Factory to create the post layers.
    kwargs:
        Additional params passed to the model factory.
    """
    self.model_factory = model_factory
    self.model_params = kwargs
    self.channels: List[int] = [64, 128, 256, 512] if channels is None else channels
    self.depth: int = len(self.channels) + 1 if depth is None else depth
    # Trim channels and skip connections so both are consistent with depth:
    # one skip connection per upsampling step plus the bottleneck.
    self.channels = self.channels[: self.depth - 1]
    self.skip_connection_names = skip_connection_names[: self.depth]
    self.activation = activation
    self.use_input_scale = use_input_scale
    self.use_output_scale = use_output_scale
    self.enable_skip_connections = enable_skip_connections
    self.allow_preconvolution = allow_preconvolution
    self.additional_conv_layers = additional_conv_layers
    self.downscale_layer_factory = downscale_layer_factory
    self.post_layer_factory = post_layer_factory
Functions
create_model
create_model(data_description)
Create a model for the given data description.
Parameters:
- data_description (DataDescription) – Data description the model is based on
Returns:
A Unet model object
Source code in niceml/dlframeworks/keras/models/unets.py
def create_model(self, data_description: DataDescription) -> Any:
    """
    Create a UNet model for the given data description.

    Args:
        data_description: Data description the model is based on; must be
            both an InputImageDataDescription and an
            OutputImageDataDescription.

    Returns:
        A Unet model object

    Raises:
        Exception: If the input channel count is not 3 while
            allow_preconvolution is disabled, if the input/output image
            size ratio is not a power of 2, or if downscaling is needed
            but no downscale_layer_factory was given.
    """
    input_dd: InputImageDataDescription = check_instance(
        data_description, InputImageDataDescription
    )
    output_dd: OutputImageDataDescription = check_instance(
        data_description, OutputImageDataDescription
    )
    expected_input_channels = 3
    if (
        not self.allow_preconvolution
        and input_dd.get_input_channel_count() != expected_input_channels
    ):
        raise Exception(
            f"Input channels must have the size of {expected_input_channels}!"
            f" Instead size == {input_dd.get_input_channel_count()}"
        )
    input_image_size = input_dd.get_input_image_size()
    output_image_size = output_dd.get_output_image_size()
    skip_connection_count = len(self.skip_connection_names)
    image_size_scale = input_image_size.get_division_factor(output_image_size)
    if not math.log(image_size_scale, 2).is_integer():
        raise Exception(
            f"Image size scale must be a power of 2! Instead {image_size_scale}"
        )
    # BUGFIX: work on local copies. The previous implementation popped from
    # self.skip_connection_names and self.channels in place, so a second
    # create_model() call on the same factory instance built a different
    # (broken) model.
    skip_connection_names = list(self.skip_connection_names)
    channel_stack = list(self.channels)
    input_shape = input_image_size.to_numpy_shape() + (3,)
    inputs = layers.Input(shape=input_shape, name="image")
    actual_layer = inputs
    encoder = self.model_factory(
        input_tensor=actual_layer,
        weights="imagenet",
        include_top=False,
        **self.model_params,
    )
    # The deepest listed layer is used as the bottleneck output of the encoder.
    encoder_output = encoder.get_layer(skip_connection_names.pop()).output
    actual_layer = encoder_output
    actual_image_size = input_image_size / (2 ** (skip_connection_count - 1))
    # Decoder: upsample and merge with skip connections, deepest first,
    # until the output image size is reached.
    for skip_connection_name in reversed(skip_connection_names):
        if actual_image_size >= output_image_size:
            break
        channels = channel_stack.pop()
        x_skip = encoder.get_layer(skip_connection_name).output
        actual_layer = layers.UpSampling2D((2, 2))(actual_layer)
        if self.enable_skip_connections:
            actual_layer = layers.Concatenate()([actual_layer, x_skip])
        actual_layer = layers.Conv2D(channels, (3, 3), padding="same")(actual_layer)
        actual_layer = layers.BatchNormalization()(actual_layer)
        actual_layer = layers.Activation("relu")(actual_layer)
        actual_layer = layers.Conv2D(channels, (3, 3), padding="same")(actual_layer)
        actual_layer = layers.BatchNormalization()(actual_layer)
        actual_layer = layers.Activation("relu")(actual_layer)
        actual_image_size *= 2
    # If the decoder overshoots the target size, downscale back down.
    while actual_image_size > output_image_size:
        if self.downscale_layer_factory is None:
            raise Exception(
                "Downscale layer factory must be given, if the image size "
                "after the skip connections is larger than the output image size!"
            )
        actual_layer = self.downscale_layer_factory.create_layers(actual_layer)
        actual_image_size /= 2
    if self.post_layer_factory is not None:
        actual_layer = self.post_layer_factory.create_layers(actual_layer)
    # When output scaling is applied, the Lambda below takes the "output"
    # name so it is always the final layer's name.
    output_conv_name = "output_conv" if self.use_output_scale else "output"
    filters = output_dd.get_output_channel_count()
    if output_dd.get_use_void_class():
        filters += 1  # reserve one extra channel for the void class
    actual_layer = layers.Conv2D(
        name=output_conv_name,
        filters=filters,
        kernel_size=(1, 1),
        activation=self.activation,
        padding="same",
    )(actual_layer)
    if self.use_output_scale:
        actual_layer = layers.Lambda(lambda x: x * 255.0, name="output")(
            actual_layer
        )
    model = Model(inputs, actual_layer)
    model.summary()
    model = add_premodel_layers(
        allow_preconvolution=self.allow_preconvolution,
        use_input_scale=self.use_input_scale,
        data_desc=input_dd,
        model=model,
        additional_conv_layers=self.additional_conv_layers,
    )
    return model