
miraclewkf/FocalLoss-MXNet

This is an optimized version of focal loss for MXNet, modified from unsky/focal-loss. It makes training about 30% faster than the original implementation.
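For reference, here is the standard focal loss for one sample whose true class has predicted probability p_t. alpha weights the loss, and gamma down-weights samples that are already well classified. Multi-class implementations often use a single scalar alpha in place of the per-class α_t.

```latex
FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```

With γ = 0 and α_t = 1 this reduces to ordinary cross-entropy, -log(p_t).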

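  • To use focal_loss_OptimizedVersion.py, import it and attach the FocalLoss custom operator to the output of your network:
import mxnet as mx
from focal_loss_OptimizedVersion import *  # registers the 'FocalLoss' custom op

# `net` is the output symbol of your network
label = mx.sym.Variable('focalloss_label')
net = mx.symbol.Custom(data=net, op_type='FocalLoss', labels=label, name='focalloss', alpha=0.25, gamma=2)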
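To show what alpha and gamma control, here is a minimal NumPy sketch of a softmax focal loss forward pass and its gradient with respect to the logits. A custom operator must compute both of these. This is not the code in focal_loss_OptimizedVersion.py, and the function names are hypothetical. The sketch assumes the operator takes raw class scores, applies softmax internally, and uses alpha as one scalar weight for every class.

```python
import numpy as np


def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def focal_loss_forward(logits, labels, alpha=0.25, gamma=2.0, eps=1e-12):
    """Mean softmax focal loss over the batch (hypothetical helper).

    logits: (N, C) raw class scores; labels: (N,) integer class ids.
    Assumption: alpha is a single scalar weight shared by all classes.
    """
    probs = softmax(logits)
    p_t = probs[np.arange(len(labels)), labels]  # probability of the true class
    loss = -alpha * (1.0 - p_t) ** gamma * np.log(p_t + eps)
    return loss.mean()


def focal_loss_backward(logits, labels, alpha=0.25, gamma=2.0, eps=1e-12):
    """Gradient of focal_loss_forward with respect to the logits."""
    n = len(labels)
    probs = softmax(logits)
    p_t = probs[np.arange(n), labels]
    # dFL/dp_t for each sample.
    d_pt = (alpha * gamma * (1.0 - p_t) ** (gamma - 1.0) * np.log(p_t + eps)
            - alpha * (1.0 - p_t) ** gamma / (p_t + eps))
    # Chain rule through softmax: dp_t/dz_j = p_t * (1[j == y] - p_j).
    onehot = np.zeros_like(probs)
    onehot[np.arange(n), labels] = 1.0
    grad = (d_pt * p_t)[:, None] * (onehot - probs)
    return grad / n  # divide by N because the forward pass averages
```

Setting gamma=0 and alpha=1 recovers plain softmax cross-entropy. Larger gamma shrinks the loss, and the gradient, of samples whose true-class probability is already high.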
  • Apart from focal_loss_OptimizedVersion.py, I also provide metric.py, which reports the focal loss value during training. Image classification is used as the example:
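import mxnet as mx
from metric import *  # provides the FocalLoss evaluation metric

eval_metric = mx.metric.CompositeEvalMetric()
eval_metric.add(FocalLoss())

# `symbol` is the network output, i.e. `net` from the previous snippet
model = mx.mod.Module(
    context=mx.gpu(0),
    symbol=symbol,
    label_names=('focalloss_label',)  # must match the label Variable name
)

model.fit(...,
          eval_metric=eval_metric,
          ...)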
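Here is what a focal-loss evaluation metric has to do: accumulate the per-sample loss across batches and report the running mean. FocalLossMeter is a hypothetical, NumPy-only stand-in that follows the reset()/update()/get() pattern of mx.metric.EvalMetric. It is not the repository's metric.py. It takes plain arrays rather than lists of NDArrays, and it assumes the values passed to update() are per-class probabilities.

```python
import numpy as np


class FocalLossMeter:
    """Hypothetical running-average focal loss metric (not metric.py).

    Follows the reset()/update()/get() pattern of mx.metric.EvalMetric,
    but uses plain NumPy so it runs without MXNet.
    """

    def __init__(self, alpha=0.25, gamma=2.0, eps=1e-12):
        self.alpha, self.gamma, self.eps = alpha, gamma, eps
        self.reset()

    def reset(self):
        self.sum_metric = 0.0
        self.num_inst = 0

    def update(self, labels, probs):
        # labels: (N,) int class ids; probs: (N, C) assumed class probabilities.
        labels = np.asarray(labels, dtype=np.int64)
        probs = np.asarray(probs, dtype=np.float64)
        p_t = probs[np.arange(len(labels)), labels]
        loss = -self.alpha * (1.0 - p_t) ** self.gamma * np.log(p_t + self.eps)
        self.sum_metric += float(loss.sum())
        self.num_inst += len(labels)

    def get(self):
        # Returns (name, mean loss over all samples seen since reset()).
        value = self.sum_metric / self.num_inst if self.num_inst else float("nan")
        return "focalloss", value
```

The metric keeps a running sum and a sample count rather than averaging per batch. This way, batches of different sizes are weighted correctly.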

Attention: the alpha and gamma values set in metric.py must equal the alpha and gamma passed to mx.symbol.Custom(..., alpha=..., gamma=...).
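The values must match because gamma rescales each sample's loss by (1 - p_t)^gamma. If metric.py used a different gamma from the training operator, the reported number would not be the quantity being optimized. Here is a small illustration in plain Python. FOCAL_ALPHA, FOCAL_GAMMA, and focal_term are hypothetical names.

```python
import math

# Hypothetical shared settings: keep one source of truth for both places.
FOCAL_ALPHA = 0.25
FOCAL_GAMMA = 2.0


def focal_term(p_t, alpha, gamma):
    # Per-sample focal loss given the predicted probability of the true class.
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


p_t = 0.9  # a well-classified sample
trained = focal_term(p_t, FOCAL_ALPHA, FOCAL_GAMMA)        # what training optimizes
reported_if_mismatched = focal_term(p_t, FOCAL_ALPHA, 0.0)  # metric with gamma=0
```

For this sample, the mismatched metric reports a value 1 / (1 - 0.9)^2 = 100 times larger than the loss actually being minimized.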
