This is an optimized version of focal loss for MXNet, modified from unsky/focal-loss. It speeds up training by 30% compared with the original implementation.
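For reference, focal loss scales cross-entropy by a factor that shrinks for well-classified samples: FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The sketch below computes it for one sample in plain Python (no MXNet). It assumes a softmax over raw class scores and a single scalar `alpha` applied to every class, which is a common multi-class variant; the weighting inside focal_loss_OptimizedVersion.py may differ. The names `softmax` and `focal_loss` are hypothetical helpers.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of class scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss(scores, target, alpha=0.25, gamma=2.0):
    """Focal loss for one sample: -alpha * (1 - p_t)**gamma * log(p_t).

    Assumption: alpha is one scalar weight shared by all classes.
    """
    p_t = softmax(scores)[target]
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


# An easy (confidently correct) sample vs. a hard (misclassified) one.
easy = focal_loss([4.0, 0.0, 0.0], target=0)
hard = focal_loss([0.0, 2.0, 0.0], target=0)
print(f"easy={easy:.6f} hard={hard:.6f}")
```

With `gamma=0` the expression reduces to alpha-scaled cross-entropy; larger `gamma` pushes the loss of easy samples toward zero so that hard samples dominate the gradient.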

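To use it, import the module (importing it registers the `FocalLoss` custom operator) and append the operator to your network:

```python
import mxnet as mx
from focal_loss_OptimizedVersion import *  # registers op_type='FocalLoss'
# `net` is the output symbol of your network (e.g. its final class-score layer).
label = mx.sym.Variable('focalloss_label')
net = mx.symbol.Custom(data=net, op_type='FocalLoss', labels=label, name='focalloss', alpha=0.25, gamma=2)
```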
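Because `FocalLoss` is a custom operator, its backward pass has to supply the gradient itself. The sketch below is not the code in focal_loss_OptimizedVersion.py; it assumes a softmax over class scores z and a scalar `alpha`, for which the gradient of the per-sample loss is dFL/dz_j = [alpha * gamma * (1 - p_t)^(gamma - 1) * p_t * log(p_t) - alpha * (1 - p_t)^gamma] * (delta_tj - p_j). It checks that formula against central finite differences. All function names are hypothetical.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss(scores, target, alpha=0.25, gamma=2.0):
    p_t = softmax(scores)[target]
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


def focal_loss_grad(scores, target, alpha=0.25, gamma=2.0):
    """Analytic dFL/dz_j for softmax focal loss with a scalar alpha (assumes p_t < 1)."""
    p = softmax(scores)
    p_t = p[target]
    coeff = (alpha * gamma * (1.0 - p_t) ** (gamma - 1.0) * p_t * math.log(p_t)
             - alpha * (1.0 - p_t) ** gamma)
    return [coeff * ((1.0 if j == target else 0.0) - p_j) for j, p_j in enumerate(p)]


def numeric_grad(scores, target, eps=1e-6, **kw):
    """Central finite differences, used only to check the analytic gradient."""
    grad = []
    for j in range(len(scores)):
        up = list(scores)
        up[j] += eps
        down = list(scores)
        down[j] -= eps
        diff = focal_loss(up, target, **kw) - focal_loss(down, target, **kw)
        grad.append(diff / (2 * eps))
    return grad


z = [0.5, -1.0, 2.0]
analytic = focal_loss_grad(z, target=0)
numeric = numeric_grad(z, target=0)
print(max(abs(a - n) for a, n in zip(analytic, numeric)))
```

The gradient entries sum to zero across classes, as for any loss computed on softmax probabilities, and with `gamma=0` the expression reduces to alpha * (p_j - delta_tj), the familiar cross-entropy gradient.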
Apart from focal_loss_OptimizedVersion.py, I also provide metric.py, which reports the focal loss value during training. Image classification is used as the example:
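```python
import mxnet as mx
from metric import *  # provides the FocalLoss evaluation metric

eval_metric = mx.metric.CompositeEvalMetric()
eval_metric.add(FocalLoss())

model = mx.mod.Module(
    context=mx.gpu(0),
    symbol=net,                        # the symbol ending in the 'focalloss' operator
    label_names=('focalloss_label',),  # must match the label Variable's name
)

model.fit(...,                         # placeholder: training data and other fit() arguments
          eval_metric=eval_metric)
```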

Note: the `alpha` and `gamma` values in metric.py must equal the `alpha` and `gamma` passed to `mx.symbol.Custom(...)`; otherwise the reported value is not the loss that is actually being optimized.
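metric.py itself is not reproduced here. As a rough sketch of how such a metric can work, the hypothetical class below follows the accumulate-then-average pattern of MXNet's `EvalMetric` (`update()` adds to `sum_metric` and `num_inst`, `get()` returns a `(name, sum_metric / num_inst)` pair), written in plain Python so it runs without MXNet. It assumes the predictions are per-class probabilities (for example, the operator's softmax output) and the labels are integer class indices. It also shows why the note above matters: a meter configured with a different `gamma` reports a different number for the same predictions.

```python
import math


class FocalLossMeter:
    """Hypothetical focal-loss metric modeled on MXNet's EvalMetric pattern.

    Assumes `preds` holds per-class probabilities and `labels` holds class indices.
    """

    def __init__(self, alpha=0.25, gamma=2.0, eps=1e-12):
        self.name = "focalloss"
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps  # guards against log(0)
        self.reset()

    def reset(self):
        self.sum_metric = 0.0
        self.num_inst = 0

    def update(self, labels, preds):
        # Accumulate the per-sample focal loss and the number of samples seen.
        for label, probs in zip(labels, preds):
            p_t = max(probs[int(label)], self.eps)
            self.sum_metric += -self.alpha * (1.0 - p_t) ** self.gamma * math.log(p_t)
            self.num_inst += 1

    def get(self):
        value = self.sum_metric / self.num_inst if self.num_inst else float("nan")
        return self.name, value


labels = [0, 1]
preds = [[0.7, 0.2, 0.1], [0.3, 0.6, 0.1]]

matched = FocalLossMeter(alpha=0.25, gamma=2.0)     # same values as the operator
mismatched = FocalLossMeter(alpha=0.25, gamma=0.0)  # gamma differs from the operator
for meter in (matched, mismatched):
    meter.update(labels, preds)
    print(meter.get())
```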