Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of ASFormer (not fully completed) #2642

Open
wants to merge 4 commits into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions configs/_base_/models/asformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Base model settings for ASFormer.
# NOTE(review): `num_classes=11` and `input_dim=2048` presumably match the
# GTEA label set and the feature width of the .npy files - confirm against
# the dataset's mapping.txt and features before reusing for other datasets.
model = {
    'type': 'ASFormer',
    'num_layers': 10,
    'num_f_maps': 64,
    'input_dim': 2048,
    'num_decoders': 3,
    'num_classes': 11,
    'channel_masking_rate': 0.5,
    'sample_rate': 1,
    'r1': 2,
    'r2': 2,
}
106 changes: 106 additions & 0 deletions configs/segmentation/asformer/asformer_gtea.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
_base_ = [
    '../../_base_/models/asformer.py',
    '../../_base_/default_runtime.py',
]

# Dataset class and data locations.
dataset_type = 'ActionSegmentDataset'
# NOTE(review): `data_root` differs from the root that holds the split
# files below - confirm it really contains the training data.
data_root = 'data/gtea/csv_mean_100/'
data_root_val = 'data/action_seg/gtea/'

# Split-1 bundle files; validation and testing share the same split.
ann_file_train = data_root_val + 'splits/train.split1.bundle'
ann_file_val = data_root_val + 'splits/test.split1.bundle'
ann_file_test = ann_file_val

# All three pipelines load the pre-extracted features together with the
# frame-wise labels and pack the `classes` array for the model.
#
# The train/val pipelines previously ran `GenerateSegmentationLabels` and
# packed `gt_bbox` through `PackLocalizationInputs`. That transform reads
# 'duration_frame', 'duration_second', 'feature_frame' and 'annotations',
# none of which `ActionSegmentDataset.load_data_list` puts into a sample, so
# the first training sample raised a KeyError. They now mirror the test
# pipeline, which only relies on keys the dataset actually provides.
# NOTE(review): confirm the model's loss consumes `gt_instances.classes`.
train_pipeline = [
    dict(type='LoadSegmentationFeature'),
    dict(
        type='PackSegmentationInputs',
        keys=('classes', ),
        meta_keys=('num_classes', 'actions_dict', 'index2label',
                   'ground_truth', 'classes'))
]

val_pipeline = [
    dict(type='LoadSegmentationFeature'),
    dict(
        type='PackSegmentationInputs',
        keys=('classes', ),
        meta_keys=('num_classes', 'actions_dict', 'index2label',
                   'ground_truth', 'classes'))
]

test_pipeline = [
    dict(type='LoadSegmentationFeature'),
    dict(
        type='PackSegmentationInputs',
        keys=('classes', ),
        meta_keys=('num_classes', 'actions_dict', 'index2label',
                   'ground_truth', 'classes'))
]

# Training loader: shuffled batches of 8, incomplete last batch dropped.
# NOTE(review): the training set is rooted at `data_root` while val/test use
# `data_root_val` - confirm both roots contain the layout the dataset reads.
train_dataloader = {
    'batch_size': 8,
    'num_workers': 8,
    'persistent_workers': True,
    'sampler': {'type': 'DefaultSampler', 'shuffle': True},
    'drop_last': True,
    'dataset': {
        'type': dataset_type,
        'ann_file': ann_file_train,
        'data_prefix': {'video': data_root},
        'pipeline': train_pipeline,
    },
}

# Validation loader: one video at a time, in file order.
val_dataloader = {
    'batch_size': 1,
    'num_workers': 8,
    'persistent_workers': True,
    'sampler': {'type': 'DefaultSampler', 'shuffle': False},
    'dataset': {
        'type': dataset_type,
        'ann_file': ann_file_val,
        'data_prefix': {'video': data_root_val},
        'pipeline': val_pipeline,
        'test_mode': True,
    },
}

# Test loader: same settings as validation, with the test split/pipeline.
test_dataloader = {
    'batch_size': 1,
    'num_workers': 8,
    'persistent_workers': True,
    'sampler': {'type': 'DefaultSampler', 'shuffle': False},
    'dataset': {
        'type': dataset_type,
        'ann_file': ann_file_test,
        'data_prefix': {'video': data_root_val},
        'pipeline': test_pipeline,
        'test_mode': True,
    },
}

# Training length; also bounds the LR schedule below.
max_epochs = 9

# Epoch-based training with validation after every epoch, starting at epoch 1.
train_cfg = {
    'type': 'EpochBasedTrainLoop',
    'max_epochs': max_epochs,
    'val_begin': 1,
    'val_interval': 1,
}
val_cfg = {'type': 'ValLoop'}
test_cfg = {'type': 'TestLoop'}

# Adam with weight decay; gradients are clipped to an L2 norm of 40.
optim_wrapper = {
    'optimizer': {'type': 'Adam', 'lr': 0.001, 'weight_decay': 0.0001},
    'clip_grad': {'max_norm': 40, 'norm_type': 2},
}

# Single step decay: multiply the LR by 0.1 at epoch 7.
param_scheduler = [
    {
        'type': 'MultiStepLR',
        'begin': 0,
        'end': max_epochs,
        'by_epoch': True,
        'milestones': [7],
        'gamma': 0.1,
    },
]

# NOTE(review): both paths still carry the BMN/ActivityNet experiment name,
# and `load_from` points at an epoch-120 checkpoint although this config
# trains for 9 epochs. They look copied from another config - confirm the
# intended directory and checkpoint before merging.
work_dir = './work_dirs/bmn_400x100_2x8_9e_activitynet_feature/'
load_from = './work_dirs/bmn_400x100_2x8_9e_activitynet_feature/epoch-120.pth'

# Segmentation metrics; results are also dumped as JSON inside `work_dir`.
# `work_dir` already ends with '/', so strip it before appending the file
# name to avoid a doubled separator in the dump path.
test_evaluator = dict(
    type='SegmentMetric',
    metric_type='ALL',
    dump_config=dict(
        out=work_dir.rstrip('/') + '/results.json', output_format='json'))
# Validation reuses the test evaluator unchanged.
val_evaluator = test_evaluator
4 changes: 3 additions & 1 deletion mmaction/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .action_segment_dataset import ActionSegmentDataset
from .activitynet_dataset import ActivityNetDataset
from .audio_dataset import AudioDataset
from .ava_dataset import AVADataset, AVAKineticsDataset
Expand All @@ -13,5 +14,6 @@
__all__ = [
'AVADataset', 'AVAKineticsDataset', 'ActivityNetDataset', 'AudioDataset',
'BaseActionDataset', 'PoseDataset', 'RawframeDataset', 'RepeatAugDataset',
'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset'
'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset',
'ActionSegmentDataset'
]
63 changes: 63 additions & 0 deletions mmaction/datasets/action_segment_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union

from mmengine.fileio import exists

from mmaction.registry import DATASETS
from mmaction.utils import ConfigType
from .base import BaseActionDataset


@DATASETS.register_module()
class ActionSegmentDataset(BaseActionDataset):
    """Dataset of pre-extracted features for action segmentation.

    ``ann_file`` is a text file listing one ground-truth file name per line.
    Relative to ``data_prefix['video']`` the dataset builds these paths:

    - ``groundTruth/<name>``: label file of the video,
    - ``features/<stem>.npy``: feature file of the video, where ``<stem>``
      is ``<name>`` without its extension,
    - ``mapping.txt``: lines of ``<index> <label>`` shared by all videos.

    Args:
        ann_file (str): Path to the file listing the videos of the split.
        pipeline (list[dict | callable]): A sequence of data transforms.
        data_prefix (dict or ConfigDict, optional): Its ``video`` entry is
            the directory that holds the files described above.
            Defaults to ``dict(video='')``.
        test_mode (bool): Forwarded to the base dataset.
            Defaults to False.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]],
                 data_prefix: Optional[ConfigType] = dict(video=''),
                 test_mode: bool = False,
                 **kwargs):

        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation file to get video information.

        Returns:
            list[dict]: One dict per listed video with the keys
            ``feature_path``, ``ground_truth_path``, ``actions_dict``
            (label -> index), ``index2label`` (index -> label) and
            ``num_classes``. The two mappings are shared by all entries.

        Raises:
            FileNotFoundError: If ``ann_file`` does not exist.
        """
        # Local import: keeps this block self-contained with respect to the
        # module's import section.
        import os.path as osp

        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        prefix = self.data_prefix['video']

        # `splitlines` keeps the last entry even when the file has no
        # trailing newline; blank lines carry no video name and are skipped.
        with open(self.ann_file, 'r') as split_file:
            video_names = [
                line for line in split_file.read().splitlines() if line
            ]

        actions_dict = dict()
        with open(osp.join(prefix, 'mapping.txt'), 'r') as mapping_file:
            for line in mapping_file.read().splitlines():
                if not line.strip():
                    continue
                index, label = line.split()[:2]
                actions_dict[label] = int(index)
        index2label = {index: label for label, index in actions_dict.items()}
        num_classes = len(actions_dict)

        data_list = []
        for name in video_names:
            # The feature file shares the ground-truth file's stem.
            stem = osp.splitext(name)[0]
            data_list.append(
                dict(
                    feature_path=osp.join(prefix, 'features', stem + '.npy'),
                    actions_dict=actions_dict,
                    index2label=index2label,
                    ground_truth_path=osp.join(prefix, 'groundTruth', name),
                    num_classes=num_classes))
        return data_list
13 changes: 8 additions & 5 deletions mmaction/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import (FormatAudioShape, FormatGCNInput, FormatShape,
PackActionInputs, PackLocalizationInputs, Transpose)
PackActionInputs, PackLocalizationInputs,
PackSegmentationInputs, Transpose)
from .loading import (ArrayDecode, AudioFeatureSelector, BuildPseudoClip,
DecordDecode, DecordInit, DenseSampleFrames,
GenerateLocalizationLabels, ImageDecode,
LoadAudioFeature, LoadHVULabel, LoadLocalizationFeature,
LoadProposals, LoadRGBFromFile, OpenCVDecode, OpenCVInit,
GenerateLocalizationLabels, GenerateSegmentationLabels,
ImageDecode, LoadAudioFeature, LoadHVULabel,
LoadLocalizationFeature, LoadProposals, LoadRGBFromFile,
LoadSegmentationFeature, OpenCVDecode, OpenCVInit,
PIMSDecode, PIMSInit, PyAVDecode, PyAVDecodeMotionVector,
PyAVInit, RawFrameDecode, SampleAVAFrames, SampleFrames,
UniformSample, UntrimmedSampleFrames)
Expand Down Expand Up @@ -37,5 +39,6 @@
'SampleAVAFrames', 'SampleFrames', 'TenCrop', 'ThreeCrop', 'ToMotion',
'TorchVisionWrapper', 'Transpose', 'UniformSample', 'UniformSampleFrames',
'UntrimmedSampleFrames', 'MMUniformSampleFrames', 'MMDecode', 'MMCompact',
'CLIPTokenize'
'CLIPTokenize', 'LoadSegmentationFeature', 'GenerateSegmentationLabels',
'PackSegmentationInputs'
]
59 changes: 59 additions & 0 deletions mmaction/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,65 @@ def __repr__(self) -> str:
return repr_str


@TRANSFORMS.register_module()
class PackSegmentationInputs(BaseTransform):
    """Pack segmentation features and labels into model inputs.

    Args:
        keys (tuple[str]): Result keys to store on the data sample.
            "classes" is stored as ``gt_instances`` and "proposals" as
            ``proposals``. Keys absent from the results are skipped; any
            other key present in the results raises
            ``NotImplementedError``. Defaults to ``()``.
        meta_keys (tuple[str]): Result keys copied into the data sample's
            meta information when present.
            Defaults to ``('video_name', )``.
    """

    # Data sample attribute that receives each supported key.
    _KEY_TO_FIELD = {'classes': 'gt_instances', 'proposals': 'proposals'}

    def __init__(self, keys=(), meta_keys=('video_name', )):
        self.keys = keys
        self.meta_keys = meta_keys

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_samples' (obj:`ActionDataSample`): The annotation info
              of the sample.

        Raises:
            ValueError: If neither "raw_feature" nor "bsp_feature" is in
                ``results``.
            NotImplementedError: If a requested key present in ``results``
                is not supported.
        """
        packed_results = dict()
        if 'raw_feature' in results:
            packed_results['inputs'] = to_tensor(results['raw_feature'])
        elif 'bsp_feature' in results:
            # No raw feature to forward; a scalar placeholder is used.
            packed_results['inputs'] = torch.tensor(0.)
        else:
            raise ValueError(
                'Cannot get "raw_feature" or "bsp_feature" in the input '
                'dict of `PackSegmentationInputs`.')

        data_sample = ActionDataSample()
        for key in self.keys:
            if key not in results:
                continue
            if key not in self._KEY_TO_FIELD:
                raise NotImplementedError(
                    f"Key '{key}' is not supported in "
                    '`PackSegmentationInputs`')
            instance_data = InstanceData()
            instance_data[key] = to_tensor(results[key])
            setattr(data_sample, self._KEY_TO_FIELD[key], instance_data)

        img_meta = {k: results[k] for k in self.meta_keys if k in results}
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(keys={self.keys}, meta_keys={self.meta_keys})'
        return repr_str


@TRANSFORMS.register_module()
class Transpose(BaseTransform):
"""Transpose image channels to a given order.
Expand Down
71 changes: 71 additions & 0 deletions mmaction/datasets/transforms/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,6 +1854,77 @@ def transform(self, results):
return results


@TRANSFORMS.register_module()
class LoadSegmentationFeature(BaseTransform):
    """Load the features and frame-wise labels of one video.

    Required keys are "feature_path", "ground_truth_path" and
    "actions_dict", added or modified keys are "raw_feature",
    "ground_truth" and "classes".
    """

    def transform(self, results):
        """Perform the LoadSegmentationFeature loading.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: ``results`` with "raw_feature" (the loaded array),
            "ground_truth" (the label lines of the ground-truth file) and
            "classes" (the label indices as a float array).
        """
        raw_feature = np.load(results['feature_path'])
        # `splitlines` keeps the final label even when the file does not
        # end with a newline; the context manager closes the handle.
        with open(results['ground_truth_path'], 'r') as gt_file:
            content = gt_file.read().splitlines()

        # Labels are truncated to the shorter of the feature array's second
        # axis and the number of label lines.
        # NOTE(review): assumes axis 1 of the feature array is time - confirm.
        num_frames = min(np.shape(raw_feature)[1], len(content))
        actions_dict = results['actions_dict']
        classes = np.array(
            [actions_dict[label] for label in content[:num_frames]],
            dtype=np.float64)

        results['raw_feature'] = raw_feature
        results['ground_truth'] = content
        results['classes'] = classes

        return results

    def __repr__(self):
        repr_str = f'{self.__class__.__name__}'
        return repr_str


@TRANSFORMS.register_module()
class GenerateSegmentationLabels(BaseTransform):
    """Turn annotated segments into clamped, relative ``gt_bbox`` pairs.

    Required keys are "duration_frame", "duration_second", "feature_frame",
    "annotations", added or modified keys are "gt_bbox".
    """

    def transform(self, results):
        """Perform the GenerateSegmentationLabels loading.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        total_frames, total_seconds, feat_frames = (
            results[key]
            for key in ('duration_frame', 'duration_second', 'feature_frame'))
        # Seconds scaled by the ratio of feature frames to video frames.
        corrected_second = float(feat_frames) / total_frames * total_seconds

        def _clamp(ratio):
            # Limit the relative position to the [0, 1] range.
            return max(min(1, ratio), 0)

        segments = [[
            _clamp(item['segment'][0] / corrected_second),
            _clamp(item['segment'][1] / corrected_second)
        ] for item in results['annotations']]

        results['gt_bbox'] = np.array(segments)
        return results


@TRANSFORMS.register_module()
class LoadProposals(BaseTransform):
"""Loading proposals with given proposal results.
Expand Down
3 changes: 2 additions & 1 deletion mmaction/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from .ava_metric import AVAMetric
from .multisports_metric import MultiSportsMetric
from .retrieval_metric import RetrievalMetric
from .segment_metric import SegmentMetric

__all__ = [
'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix',
'MultiSportsMetric', 'RetrievalMetric'
'MultiSportsMetric', 'RetrievalMetric', 'SegmentMetric'
]
Loading