Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to recent versions of TensorFlow and Keras (2.13) #582

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions segmentation_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import functools
from tensorflow import keras
from .__version__ import __version__
from . import base

_KERAS_FRAMEWORK_NAME = 'keras'
_KERAS_FRAMEWORK_NAME = 'tf.keras'
_TF_KERAS_FRAMEWORK_NAME = 'tf.keras'

_DEFAULT_KERAS_FRAMEWORK = _KERAS_FRAMEWORK_NAME
Expand Down Expand Up @@ -64,14 +65,12 @@ def set_framework(name):
name = name.lower()

if name == _KERAS_FRAMEWORK_NAME:
import keras
import efficientnet.keras # init custom objects
elif name == _TF_KERAS_FRAMEWORK_NAME:
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers, models, utils, losses
import efficientnet.tfkeras # init custom objects
else:
raise ValueError('Not correct module name `{}`, use `{}` or `{}`'.format(
name, _KERAS_FRAMEWORK_NAME, _TF_KERAS_FRAMEWORK_NAME))
raise ValueError('Not correct module name `{}`, use `{}`'.format(
name, _TF_KERAS_FRAMEWORK_NAME))

global _KERAS_BACKEND, _KERAS_LAYERS, _KERAS_MODELS
global _KERAS_UTILS, _KERAS_LOSSES, _KERAS_FRAMEWORK
Expand Down
1 change: 0 additions & 1 deletion segmentation_models/backbones/backbones_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ class BackbonesFactory(ModelsFactory):
'block3a_expand_activation', 'block2a_expand_activation'),

}

_models_update = {
'inceptionresnetv2': [irv2.InceptionResNetV2, irv2.preprocess_input],
'inceptionv3': [iv3.InceptionV3, iv3.preprocess_input],
Expand Down
31 changes: 31 additions & 0 deletions segmentation_models/base/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,37 @@ def iou_score(gt, pr, class_weights=1., class_indexes=None, smooth=SMOOTH, per_i

return score

def dice_score(gt, pr, class_weights=1., class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs):
    r"""Compute the Sørensen–Dice coefficient between ground truth and prediction.

    The score is twice the intersection over the summed sizes of both masks:
    ``(2 * |gt * pr| + smooth) / (|gt| + |pr| + smooth)``.

    Args:
        gt: Ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W).
        pr: Prediction 4D keras tensor (B, H, W, C) or (B, C, H, W).
        class_weights: 1. or list of class weights, len(weights) = C.
        class_indexes: Optional integer or list of integers, classes to consider, if `None` all classes are used.
        smooth: Value to avoid division by zero.
        per_image: If `True`, metric is calculated as mean over images in batch (B), else over whole batch.
        threshold: Value to round predictions (use `>` comparison), if `None` prediction will not be rounded.

    Returns:
        Dice score in range [0, 1].
    """
    # The active keras backend is threaded through **kwargs by the caller.
    backend = kwargs['backend']

    gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
    pr = round_if_needed(pr, threshold, **kwargs)
    reduce_axes = get_reduce_axes(per_image, **kwargs)

    # Dice = (2 * intersection + smooth) / (|gt| + |pr| + smooth)
    overlap = backend.sum(gt * pr, axis=reduce_axes)
    mask_total = backend.sum(gt, axis=reduce_axes) + backend.sum(pr, axis=reduce_axes)
    dice = (2 * overlap + smooth) / (mask_total + smooth)

    return average(dice, per_image, class_weights, **kwargs)

def f_score(gt, pr, beta=1, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None,
**kwargs):
Expand Down
69 changes: 64 additions & 5 deletions segmentation_models/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,64 @@ def __call__(self, gt, pr):
**self.submodules
)

class DICEScore(Metric):
    r"""The `Dice coefficient`_ (Sørensen–Dice coefficient, Dice similarity) is a
    statistic for gauging the similarity of two samples, widely used for binary
    and multiclass segmentation. It is twice the size of the intersection divided
    by the sum of the sizes of the two sample sets:

    .. math:: DSC(A, B) = \frac{2 |A \cap B|}{|A| + |B|}

    Args:
        class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``).
        class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
        threshold: Value to round predictions (use ``>`` comparison), if ``None`` prediction will not be rounded.
        per_image: If ``True``, metric is calculated as mean over images in batch (B),
            else over whole batch.
        smooth: Value to avoid division by zero.
        name: Optional metric name; defaults to ``'dice_score'``.

    Returns:
        A callable ``dice_score`` instance. Can be used in the ``model.compile(...)`` function.

    .. _`Dice coefficient`: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Example:

    .. code:: python

        metric = DICEScore()
        model.compile('SGD', loss=loss, metrics=[metric])
    """

    def __init__(
            self,
            class_weights=None,
            class_indexes=None,
            threshold=None,
            per_image=False,
            smooth=SMOOTH,
            name=None,
    ):
        super().__init__(name=name or 'dice_score')
        # Fall back to uniform weighting when no explicit weights are given.
        self.class_weights = 1 if class_weights is None else class_weights
        self.class_indexes = class_indexes
        self.threshold = threshold
        self.per_image = per_image
        self.smooth = smooth

    def __call__(self, gt, pr):
        # Delegate to the backend-agnostic functional implementation;
        # ``self.submodules`` carries the keyword arguments (incl. the keras
        # backend) that ``F.dice_score`` consumes.
        return F.dice_score(
            gt,
            pr,
            class_weights=self.class_weights,
            class_indexes=self.class_indexes,
            smooth=self.smooth,
            per_image=self.per_image,
            threshold=self.threshold,
            **self.submodules
        )


class FScore(Metric):
r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall,
Expand Down Expand Up @@ -256,8 +314,9 @@ def __call__(self, gt, pr):


# aliases
# Ready-made metric instances with default settings, suitable for passing
# directly to ``model.compile(..., metrics=[...])``.
# NOTE(review): this alias block appears twice because the source is a diff
# view (pre- and post-change sides); only the second block, which adds
# ``dice_score``, is the post-merge content. The re-assignments are harmless
# in Python (the later binding wins).
iou_score = IOUScore()
f1_score = FScore(beta=1)
f2_score = FScore(beta=2)
precision = Precision()
recall = Recall()
iou_score = IOUScore()
dice_score = DICEScore()
f1_score = FScore(beta=1)
f2_score = FScore(beta=2)
precision = Precision()
recall = Recall()