From 46bb19e51c2b394a2b20ef4040d63a63ef98889e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@inrae.fr> Date: Tue, 22 Aug 2023 20:06:08 +0200 Subject: [PATCH 1/2] ENH: option to expand dims in otbtf.Argmax --- otbtf/layers.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/otbtf/layers.py b/otbtf/layers.py index a3680421..8fc76deb 100644 --- a/otbtf/layers.py +++ b/otbtf/layers.py @@ -136,13 +136,15 @@ class Argmax(tf.keras.layers.Layer): Useful to transform a softmax into a "categorical" map for instance. """ - def __init__(self, name: str = None): + def __init__(self, name: str = None, expand_last_dim: bool = True): """ Params: name: layer name + expand_last_dim: expand the last dimension when True """ super().__init__(name=name) + self.expand_last_dim = expand_last_dim def call(self, inputs): """ @@ -157,7 +159,10 @@ class Argmax(tf.keras.layers.Layer): (nb_classes - 1). """ - return tf.expand_dims(tf.math.argmax(inputs, axis=-1), axis=-1) + argmax = tf.math.argmax(inputs, axis=-1) + if self.expand_last_dim: + return tf.expand_dims(argmax, axis=-1) + return argmax class Max(tf.keras.layers.Layer): -- GitLab From 887fb17332dd485edb528c4a53d2d030e3ca7b31 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@inrae.fr> Date: Tue, 22 Aug 2023 20:08:01 +0200 Subject: [PATCH 2/2] STY: linting --- otbtf/layers.py | 2 +- otbtf/ops.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/otbtf/layers.py b/otbtf/layers.py index 8fc76deb..ef65ec1c 100644 --- a/otbtf/layers.py +++ b/otbtf/layers.py @@ -55,7 +55,7 @@ class DilatedMask(tf.keras.layers.Layer): nodata_mask = tf.cast(tf.math.equal(inp, self.nodata_value), tf.uint8) se_size = 1 + 2 * self.radius - # Create a morphological kernel suitable for binary dilatation, see + # Create a morphological kernel suitable for binary dilatation, see # https://stackoverflow.com/q/54686895/13711499 kernel = tf.zeros((se_size, se_size, 1), dtype=tf.uint8) conv2d_out = tf.nn.dilation2d( diff --git a/otbtf/ops.py b/otbtf/ops.py index 4a8d0b96..ef5c52b9 100644 --- a/otbtf/ops.py +++ b/otbtf/ops.py @@ -30,6 +30,8 @@ import tensorflow as tf Tensor = Any Scalars = List[float] | Tuple[float] + + def one_hot(labels: Tensor, nb_classes: int): """ Converts labels values into one-hot vector. @@ -43,4 +45,4 @@ def one_hot(labels: Tensor, nb_classes: int): """ labels_xy = tf.squeeze(tf.cast(labels, tf.int32), axis=-1) # shape [x, y] - return tf.one_hot(labels_xy, depth=nb_classes) # shape [x, y, nb_classes] \ No newline at end of file + return tf.one_hot(labels_xy, depth=nb_classes) # shape [x, y, nb_classes] -- GitLab