Skip to content

Commit

Permalink
feat: add support for sensor data classification
Browse files Browse the repository at this point in the history
  • Loading branch information
LynnL4 committed Mar 6, 2024
1 parent fb1a65c commit 41a94a3
Show file tree
Hide file tree
Showing 14 changed files with 275 additions and 128 deletions.
5 changes: 4 additions & 1 deletion configs/_base_/default_runtime_cls.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# defaults to use registries in mmpretrain
default_scope = 'sscma'

# defaults input type image
input_type = 'image'

# ========================Suggested optional parameters========================
# RUNING
# RUNNING
# Model validation interval in epoch
val_interval = 5
# Model weight saving interval in epochs
Expand Down
6 changes: 5 additions & 1 deletion configs/_base_/default_runtime_det.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
default_scope = 'mmdet'

# defaults input type image
input_type = 'image'

# ========================Suggested optional parameters========================
# RUNING
# RUNNING
# Model validation interval in epoch
val_interval = 5
# Model weight saving interval in epochs
Expand Down
6 changes: 5 additions & 1 deletion configs/_base_/default_runtime_pose.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
default_scope = 'sscma'

# defaults input type image
input_type = 'image'

# ========================Suggested optional parameters========================
# RUNING
# RUNNING
# Model validation interval in epoch
val_interval = 5
# Model weight saving interval in epochs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
_base_ = '../_base_/default_runtime_cls.py'
_base_ = './base.py'

default_scope = 'sscma'
# ========================Suggested optional parameters========================
# MODEL
num_classes = 3
num_axes = 3
window_size = 30
window_size = 62
stride = 20

# DATA
dataset_type = 'sscma.SensorDataset'
data_root = 'datasets/aixs-export'
data_root = 'datasets/sensor-export'
train_ann = 'info.labels'
train_data = 'training'
val_ann = 'info.labels'
Expand Down
30 changes: 30 additions & 0 deletions configs/accelerometer/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Base runtime config for accelerometer (sensor) classification models.
# Inherits the generic classification runtime and overrides the input type.
_base_ = '../_base_/default_runtime_cls.py'

# default input type: sensor (overrides the 'image' default from the base runtime)
input_type = 'sensor'

# ========================Suggested optional parameters========================
# RUNNING
# Model validation interval in epoch
val_interval = 5
# Model weight saving interval in epochs
save_interval = val_interval

# defaults to use registries in sscma
default_scope = 'sscma'
# ================================END=================================
# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type='IterTimerHook'),
    # print log every 100 iterations.
    logger=dict(type='TextLoggerHook', interval=100),
    # enable the parameter scheduler.
    param_scheduler=dict(type='ParamSchedulerHook'),
    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', save_best='auto', interval=save_interval),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type='DistSamplerSeedHook'),
    # validation results visualization, set True to enable it.
    visualization=dict(type='sscma.SensorVisualizationHook', enable=False),
)
21 changes: 13 additions & 8 deletions sscma/datasets/sensordataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(

super().__init__(ann_file=ann_file, metainfo=metainfo, data_root=data_root, data_prefix=data_prefix, **kwargs)

self._metainfo = {'classes': self.get_classes()}

def get_classes(self, classes=None):
if classes is not None:
return classes
Expand Down Expand Up @@ -89,26 +91,28 @@ def load_data_list(self):
data_list = []
for filename, gt_label in samples:
ann_path = os.path.join(self.data_dir, filename)
sensors, data_set = self.read_split_data(ann_path)
data_list.extend(
[{'data': np.asanyarray([data]), 'gt_label': int(gt_label)} for data in self.read_split_data(ann_path)]
[{'data': np.asanyarray([data]), 'gt_label': int(gt_label), 'sensors': sensors} for data in data_set]
)

return data_list

def read_split_data(self, file_path: str) -> List:
if file_path.lower().endswith('.cbor'):
with open(file_path, 'rb') as f:
data = cbor.loads(f.read())
info_lables = cbor.loads(f.read())
elif file_path.lower().endswith('.json'):
with open(file_path, 'r') as f:
data = json.load(f)
info_lables = json.load(f)

values = np.asanyarray(data['payload']['values'])
values = np.asanyarray(info_lables['payload']['values'])
sensors = info_lables['payload']['sensors']

result = []
data_set = []
values_len = len(values)
if values_len <= self.window_size:
result.append(self.pad_data(values, self.window_size).transpose(0, 1).reshape(-1))
data_set.append(self.pad_data(values, self.window_size).transpose(0, 1).reshape(-1))
else:
indexes = range(0, values_len, self.stride)
for i in indexes:
Expand All @@ -128,8 +132,9 @@ def read_split_data(self, file_path: str) -> List:
data = values[i:end]
if self.flatten:
data = data.transpose(0, 1).reshape(-1)
result.append(data)
return result
data_set.append(data)

return sensors, data_set

def pad_data(self, data: np.asanyarray, total_len: int, mode='constant', pad_val=0) -> np.array:
pad_len = total_len - len(data)
Expand Down
9 changes: 7 additions & 2 deletions sscma/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
TextLoggerHook,
WandbLoggerHook,
)
from .visualization_hook import DetFomoVisualizationHook, Posevisualization
from .semihook import SemiHook
from .visualization_hook import (
DetFomoVisualizationHook,
Posevisualization,
SensorVisualizationHook,
)

__all__ = [
'TextLoggerHook',
Expand All @@ -15,5 +19,6 @@
'ClearMLLoggerHook',
'Posevisualization',
'DetFomoVisualizationHook',
"SemiHook",
'SensorVisualizationHook',
'SemiHook',
]
106 changes: 102 additions & 4 deletions sscma/engine/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import os
import os.path as osp
import warnings
Expand All @@ -6,10 +7,11 @@
import mmcv
import mmengine
import mmengine.fileio as fileio
from mmcls.structures import ClsDataSample
from mmdet.engine.hooks import DetVisualizationHook
from mmengine.hooks import Hook
from mmengine.hooks.hook import DATA_BATCH
from mmengine.runner import Runner
from mmengine.runner import EpochBasedTrainLoop, Runner
from mmengine.visualization import Visualizer
from mmpose.structures import PoseDataSample, merge_data_samples

Expand Down Expand Up @@ -66,7 +68,7 @@ def after_test_iter(
self.out_dir = os.path.join(runner.work_dir, runner.timestamp, self.out_dir)
mmengine.mkdir_or_exist(self.out_dir)

self._visualizer.set_dataset_meta(runner.test_evaluator.dataset_meta)
self._visualizer.set_data_setet_meta(runner.test_evaluator.data_setet_meta)

for data_sample in outputs:
self._test_index += 1
Expand All @@ -83,7 +85,7 @@ def after_test_iter(
index = len([fname for fname in os.listdir(self.out_dir) if fname.startswith(out_file_name)])
out_file = f'{out_file_name}_{index}.{postfix}'
out_file = os.path.join(self.out_dir, out_file)
self._visualizer.add_datasample(
self._visualizer.add_data_setample(
os.path.basename(img_path) if self.show else 'test_img',
img,
data_sample=data_sample,
Expand Down Expand Up @@ -121,7 +123,7 @@ def after_val_iter(
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')

if total_curr_iter % self.interval == 0:
self._visualizer.add_datasample(
self._visualizer.add_data_setample(
osp.basename(img_path) if self.show else 'val_img',
img,
data_sample=outputs[0],
Expand All @@ -140,3 +142,99 @@ def after_test_iter(
pass
else:
return super().after_test_iter(runner, batch_idx, data_batch, outputs)


@HOOKS.register_module()
class SensorVisualizationHook(Hook):
    """Sensor Classification Visualization Hook. Used to visualize validation
    and testing prediction results.

    - If ``out_dir`` is specified, all storage backends are ignored
      and save the image to the ``out_dir``.
    - If ``show`` is True, plot the result image in a window, please
      confirm you are able to access the graphical interface.

    Args:
        enable (bool): Whether to enable this hook. Defaults to False.
        interval (int): The interval of samples to visualize. Defaults to 5000.
        show (bool): Whether to display the drawn image. Defaults to False.
        out_dir (str, optional): directory where painted images will be saved
            in the testing process. If None, handle with the backends of the
            visualizer. Defaults to None.
        **kwargs: other keyword arguments of
            :meth:`mmcls.visualization.ClsVisualizer.add_datasample`.
    """

    def __init__(self, enable=False, interval: int = 5000, show: bool = False, out_dir: Optional[str] = None, **kwargs):
        self._visualizer: Visualizer = Visualizer.get_current_instance()

        self.enable = enable
        self.interval = interval
        self.show = show
        self.out_dir = out_dir

        # `show` is forwarded to every add_datasample call alongside user kwargs.
        self.draw_args = {**kwargs, 'show': show}

    def _draw_samples(
        self, batch_idx: int, data_batch: dict, data_samples: Sequence[ClsDataSample], step: int = 0
    ) -> None:
        """Visualize every ``self.interval`` samples from a data batch.

        Args:
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            data_samples (Sequence[:obj:`ClsDataSample`]): Outputs from model.
            step (int): Global step value to record. Defaults to 0.
        """
        if self.enable is False:
            return

        batch_size = len(data_samples)
        inputs = data_batch['inputs']
        start_idx = batch_size * batch_idx
        end_idx = start_idx + batch_size

        # The first index divisible by the interval, after the start index
        first_sample_id = math.ceil(start_idx / self.interval) * self.interval

        for sample_id in range(first_sample_id, end_idx, self.interval):
            data_sample = data_samples[sample_id - start_idx]
            sample_name = str(sample_id)
            # BUGFIX: the Visualizer API method is `add_datasample`
            # (the previous name `add_data_setample` does not exist), and the
            # within-batch index is relative to `start_idx`, not
            # `first_sample_id` — the latter picked the wrong element whenever
            # the first interval-aligned id fell past the batch start.
            self._visualizer.add_datasample(
                sample_name,
                data=inputs[sample_id - start_idx],
                data_sample=data_sample,
                step=step,
                **self.draw_args,
            )

    def after_val_iter(
        self, runner: Runner, batch_idx: int, data_batch: dict, outputs: Sequence[ClsDataSample]
    ) -> None:
        """Visualize every ``self.interval`` samples during validation.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
        """
        # Use epoch as the visualization step for epoch-based training,
        # otherwise fall back to the global iteration counter.
        if isinstance(runner.train_loop, EpochBasedTrainLoop):
            step = runner.epoch
        else:
            step = runner.iter

        self._draw_samples(batch_idx, data_batch, outputs, step=step)

    def after_test_iter(
        self, runner: Runner, batch_idx: int, data_batch: dict, outputs: Sequence[ClsDataSample]
    ) -> None:
        """Visualize every ``self.interval`` samples during test.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
        """
        self._draw_samples(batch_idx, data_batch, outputs, step=0)
4 changes: 2 additions & 2 deletions sscma/models/backbones/AxesNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class AxesNet(nn.Module):
def __init__(self, num_axes=3, window_size=80, num_classes=-1): # axes number # sample frequency # window size
super().__init__()
self.num_classes = num_classes
self.intput_feature = num_axes * window_size
self.intput_feature = int(num_axes * window_size)
liner_feature = self.liner_feature_fit()
self.fc1 = nn.Linear(in_features=self.intput_feature, out_features=liner_feature, bias=True)
self.fc2 = nn.Linear(in_features=liner_feature, out_features=liner_feature, bias=True)
Expand All @@ -23,7 +23,7 @@ def liner_feature_fit(self):
def forward(self, x):
x = x[0] if isinstance(x, list) else x
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
# x = F.relu(self.fc2(x))

if self.num_classes > 0:
x = self.classifier(x)
Expand Down
2 changes: 1 addition & 1 deletion sscma/models/classifiers/accelerometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def forward(self, inputs: torch.Tensor, data_samples: Optional[List[ClsDataSampl
head_out = F.softmax(head_out, dim=-1)
else:
head_out = torch.sigmoid(head_out)
return
return head_out
elif mode == 'loss':
return self.loss(inputs, data_samples)
elif mode == 'predict':
Expand Down
Loading

0 comments on commit 41a94a3

Please sign in to comment.