diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000000..c487f62da7
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,622 @@
+[MASTER]
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-whitelist=
+
+# Specify a score threshold to be exceeded before program exits with error.
+fail-under=10
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS,configs
+
+# Add files or directories matching the regex patterns to the blacklist. The
+# regex matches against base names, not paths.
+ignore-patterns=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
+confidence=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=print-statement,
+        parameter-unpacking,
+        unpacking-in-except,
+        old-raise-syntax,
+        backtick,
+        long-suffix,
+        old-ne-operator,
+        old-octal-literal,
+        import-star-module-level,
+        non-ascii-bytes-literal,
+        raw-checker-failed,
+        bad-inline-option,
+        locally-disabled,
+        file-ignored,
+        suppressed-message,
+        useless-suppression,
+        deprecated-pragma,
+        use-symbolic-message-instead,
+        apply-builtin,
+        basestring-builtin,
+        buffer-builtin,
+        cmp-builtin,
+        coerce-builtin,
+        execfile-builtin,
+        file-builtin,
+        long-builtin,
+        raw_input-builtin,
+        reduce-builtin,
+        standarderror-builtin,
+        unicode-builtin,
+        xrange-builtin,
+        coerce-method,
+        delslice-method,
+        getslice-method,
+        setslice-method,
+        no-absolute-import,
+        old-division,
+        dict-iter-method,
+        dict-view-method,
+        next-method-called,
+        metaclass-assignment,
+        indexing-exception,
+        raising-string,
+        reload-builtin,
+        oct-method,
+        hex-method,
+        nonzero-method,
+        cmp-method,
+        input-builtin,
+        round-builtin,
+        intern-builtin,
+        unichr-builtin,
+        map-builtin-not-iterating,
+        zip-builtin-not-iterating,
+        range-builtin-not-iterating,
+        filter-builtin-not-iterating,
+        using-cmp-argument,
+        eq-without-hash,
+        div-method,
+        idiv-method,
+        rdiv-method,
+        exception-message-attribute,
+        invalid-str-codec,
+        sys-max-int,
+        bad-python3-import,
+        deprecated-string-function,
+        deprecated-str-translate-call,
+        deprecated-itertools-function,
+        deprecated-types-field,
+        next-method-defined,
+        dict-items-not-iterating,
+        dict-keys-not-iterating,
+        dict-values-not-iterating,
+        deprecated-operator-function,
+        deprecated-urllib-function,
+        xreadlines-attribute,
+        deprecated-sys-function,
+        exception-escape,
+        comprehension-escape,
+        no-member,
+        invalid-name,
+        too-many-branches,
+        wrong-import-order,
+        too-many-arguments,
+        missing-function-docstring,
+        missing-module-docstring,
+        too-many-locals,
+        too-few-public-methods,
+        abstract-method,
+        broad-except,
+        too-many-nested-blocks,
+        too-many-instance-attributes,
+        missing-class-docstring,
+        duplicate-code,
+        not-callable,
+        protected-access,
+        dangerous-default-value,
+        no-name-in-module,
+        logging-fstring-interpolation,
+        super-init-not-called,
+        redefined-builtin,
+        attribute-defined-outside-init,
+        arguments-differ,
+        cyclic-import,
+        bad-super-call,
+        too-many-statements,
+        line-too-long
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'error', 'warning', 'refactor', and 'convention'
+# which contain the number of messages in each category, as well as 'statement'
+# which is the total number of statements analyzed. This score is used by the
+# global evaluation report (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+#msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+          _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='    '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+      XXX,
+      TODO
+
+# Regular expression of note tags to take in consideration.
+#notes-rgx=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+          bar,
+          baz,
+          toto,
+          tutu,
+          tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style.
+#class-attribute-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+           j,
+           k,
+           ex,
+           Run,
+           _,
+           x,
+           y,
+           w,
+           h,
+           a,
+           b
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style.
+#variable-rgx=
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=optparse,tkinter.tix
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled).
+ext-import-graph=
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled).
+import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+                      __new__,
+                      setUp,
+                      __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+                  _fields,
+                  _replace,
+                  _source,
+                  _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=cls
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "BaseException, Exception".
+overgeneral-exceptions=BaseException,
+                       Exception
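For reference, the `evaluation` expression and `fail-under=10` in the config above are what gate a lint run: pylint computes a 0-10 score from the per-category message counts and exits with an error if the score lands below the threshold. A minimal sketch of the same arithmetic, with made-up message counts:

```python
# Hypothetical counts from a pylint run over 2000 analyzed statements.
error, warning, refactor, convention, statement = 0, 3, 2, 5, 2000

# Same expression as the `evaluation` option in the config above.
score = 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
print(score)  # 9.95 -- below fail-under=10, so this run would still fail
```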
diff --git a/mmaction/apis/inference.py b/mmaction/apis/inference.py
index 89d45ef3f5..46f02736af 100644
--- a/mmaction/apis/inference.py
+++ b/mmaction/apis/inference.py
@@ -72,7 +72,7 @@ def inference_recognizer(model, video_path, label_path, use_frames=False):
     if osp.isfile(video_path) and use_frames:
         raise RuntimeError(
             f"'{video_path}' is a video file, not a rawframe directory")
-    elif osp.isdir(video_path) and not use_frames:
+    if osp.isdir(video_path) and not use_frames:
         raise RuntimeError(
             f"'{video_path}' is a rawframe directory, not a video file")

diff --git a/mmaction/apis/test.py b/mmaction/apis/test.py
index 77c4345041..2541343feb 100644
--- a/mmaction/apis/test.py
+++ b/mmaction/apis/test.py
@@ -125,21 +125,20 @@ def collect_results_cpu(result_part, size, tmpdir=None):
     # collect all parts
     if rank != 0:
         return None
-    else:
-        # load results of all parts from tmp dir
-        part_list = []
-        for i in range(world_size):
-            part_file = osp.join(tmpdir, f'part_{i}.pkl')
-            part_list.append(mmcv.load(part_file))
-        # sort the results
-        ordered_results = []
-        for res in zip(*part_list):
-            ordered_results.extend(list(res))
-        # the dataloader may pad some samples
-        ordered_results = ordered_results[:size]
-        # remove tmp dir
-        shutil.rmtree(tmpdir)
-        return ordered_results
+    # load results of all parts from tmp dir
+    part_list = []
+    for i in range(world_size):
+        part_file = osp.join(tmpdir, f'part_{i}.pkl')
+        part_list.append(mmcv.load(part_file))
+    # sort the results
+    ordered_results = []
+    for res in zip(*part_list):
+        ordered_results.extend(list(res))
+    # the dataloader may pad some samples
+    ordered_results = ordered_results[:size]
+    # remove tmp dir
+    shutil.rmtree(tmpdir)
+    return ordered_results


 def collect_results_gpu(result_part, size):
@@ -185,3 +184,4 @@ def collect_results_gpu(result_part, size):
         # the dataloader may pad some samples
         ordered_results = ordered_results[:size]
         return ordered_results
+    return None
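The two hunks above are the pattern that recurs through most of this patch: pylint's no-else-return (R1705) and no-else-raise (R1720) fixes, which drop an `else` whose sibling branch already returns or raises. A minimal sketch of the transform on a hypothetical function:

```python
# Before: every branch returns, so the else only adds nesting (R1705).
def sign_before(x):
    if x >= 0:
        return 1
    else:
        return -1

# After: identical behaviour, one indentation level less.
def sign_after(x):
    if x >= 0:
        return 1
    return -1
```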
diff --git a/mmaction/core/evaluation/accuracy.py b/mmaction/core/evaluation/accuracy.py
index 83ddf59364..496e39c247 100644
--- a/mmaction/core/evaluation/accuracy.py
+++ b/mmaction/core/evaluation/accuracy.py
@@ -120,9 +120,8 @@ def mmit_mean_average_precision(scores, labels):
         np.float: The MMIT style mean average precision.
     """
     results = []
-    for i in range(len(scores)):
-        precision, recall, _ = binary_precision_recall_curve(
-            scores[i], labels[i])
+    for score, label in zip(scores, labels):
+        precision, recall, _ = binary_precision_recall_curve(score, label)
         ap = -np.sum(np.diff(recall) * np.array(precision)[:-1])
         results.append(ap)
     return np.mean(results)
@@ -144,9 +143,8 @@ def mean_average_precision(scores, labels):
     scores = np.stack(scores).T
     labels = np.stack(labels).T

-    for i in range(len(scores)):
-        precision, recall, _ = binary_precision_recall_curve(
-            scores[i], labels[i])
+    for score, label in zip(scores, labels):
+        precision, recall, _ = binary_precision_recall_curve(score, label)
         ap = -np.sum(np.diff(recall) * np.array(precision)[:-1])
         results.append(ap)
     results = [x for x in results if not np.isnan(x)]
@@ -466,7 +464,7 @@ def average_precision_at_temporal_iou(ground_truth,

     for idx, this_pred in enumerate(prediction):
         # Check if there is at least one ground truth in the video.
-        if (this_pred[0] in ground_truth):
+        if this_pred[0] in ground_truth:
             this_gt = np.array(ground_truth[this_pred[0]], dtype=float)
         else:
             fp[:, idx] = 1
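The `ap` line kept by these hunks computes average precision as a left Riemann sum over the precision-recall curve: `np.diff(recall)` is non-positive because the curve is returned with recall decreasing towards 0, so the leading minus sign turns each step into a positive area term. A small worked check with an invented toy curve:

```python
import numpy as np

# Toy curve, ordered the way a precision-recall curve is returned:
# recall decreases to 0 and a final (precision=1, recall=0) point closes it.
precision = np.array([0.5, 2 / 3, 1.0, 1.0])
recall = np.array([1.0, 1.0, 0.5, 0.0])

ap = -np.sum(np.diff(recall) * precision[:-1])
print(ap)  # 0.8333... == 0.5 * (2 / 3) + 0.5 * 1.0
```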
diff --git a/mmaction/core/evaluation/eval_detection.py b/mmaction/core/evaluation/eval_detection.py
index 367c5a9ebe..158644157a 100644
--- a/mmaction/core/evaluation/eval_detection.py
+++ b/mmaction/core/evaluation/eval_detection.py
@@ -50,7 +50,8 @@ def __init__(self,
                        f'Fixed threshold for tiou score: {self.tiou_thresholds}')
         print_log(log_msg, logger=self.logger)

-    def _import_ground_truth(self, ground_truth_filename):
+    @staticmethod
+    def _import_ground_truth(ground_truth_filename):
         """Read ground truth file and return the ground truth instances and
         the activity classes.

diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py
index ababf877dc..40fd5631f2 100644
--- a/mmaction/core/evaluation/eval_hooks.py
+++ b/mmaction/core/evaluation/eval_hooks.py
@@ -80,7 +80,7 @@ def __init__(self,
                 f'or in {self.less_keys} when rule is None, '
                 f'but got {key_indicator}')

-        if not interval > 0:
+        if interval <= 0:
             raise ValueError(f'interval must be positive, but got {interval}')
         if start is not None and start < 0:
             warnings.warn(
@@ -178,8 +178,8 @@ def evaluate(self, runner, results):
                               'it in config file')
                 return None
             return eval_res[self.key_indicator]
-        else:
-            return None
+
+        return None


 class DistEpochEvalHook(EpochEvalHook):
diff --git a/mmaction/core/runner/omnisource_runner.py b/mmaction/core/runner/omnisource_runner.py
index 29b15923db..0209d5d0b1 100644
--- a/mmaction/core/runner/omnisource_runner.py
+++ b/mmaction/core/runner/omnisource_runner.py
@@ -89,7 +89,7 @@ def train(self, data_loaders, **kwargs):
                 continue

             for idx, n_times in enumerate(auxiliary_iter_times):
-                for step in range(n_times):
+                for _ in range(n_times):
                     data_batch = next(self.aux_iters[idx])
                     self.call_hook('before_train_iter')
                     self.run_iter(
diff --git a/mmaction/datasets/activitynet_dataset.py b/mmaction/datasets/activitynet_dataset.py
index fe792773fd..d448cf19f5 100644
--- a/mmaction/datasets/activitynet_dataset.py
+++ b/mmaction/datasets/activitynet_dataset.py
@@ -114,7 +114,8 @@ def _import_ground_truth(self):
             ground_truth[video_id] = np.array(this_video_ground_truths)
         return ground_truth

-    def proposals2json(self, results, show_progress=False):
+    @staticmethod
+    def proposals2json(results, show_progress=False):
         """Convert all proposals to a final dict(json) format.

         Args:
@@ -141,7 +142,8 @@ def proposals2json(self, results, show_progress=False):
                 prog_bar.update()
         return result_dict

-    def _import_proposals(self, results):
+    @staticmethod
+    def _import_proposals(results):
         """Read predictions from results."""
         proposals = {}
         num_proposals = 0
diff --git a/mmaction/datasets/audio_dataset.py b/mmaction/datasets/audio_dataset.py
index fc306b0fd7..4443402a71 100644
--- a/mmaction/datasets/audio_dataset.py
+++ b/mmaction/datasets/audio_dataset.py
@@ -55,7 +55,7 @@ def load_annotations(self):
                     idx += 1
                 # idx for label[s]
                 label = [int(x) for x in line_split[idx:]]
-                assert len(label), f'missing label in line: {line}'
+                assert label, f'missing label in line: {line}'
                 if self.multi_class:
                     assert self.num_classes is not None
                     onehot = torch.zeros(self.num_classes)
diff --git a/mmaction/datasets/audio_feature_dataset.py b/mmaction/datasets/audio_feature_dataset.py
index 1d0d32be9e..15daa1182c 100644
--- a/mmaction/datasets/audio_feature_dataset.py
+++ b/mmaction/datasets/audio_feature_dataset.py
@@ -56,7 +56,7 @@ def load_annotations(self):
                     idx += 1
                 # idx for label[s]
                 label = [int(x) for x in line_split[idx:]]
-                assert len(label), f'missing label in line: {line}'
+                assert label, f'missing label in line: {line}'
                 if self.multi_class:
                     assert self.num_classes is not None
                     onehot = torch.zeros(self.num_classes)
diff --git a/mmaction/datasets/base.py b/mmaction/datasets/base.py
index 99817d2dfb..fbfc7f9bd9 100644
--- a/mmaction/datasets/base.py
+++ b/mmaction/datasets/base.py
@@ -87,7 +87,6 @@ def __init__(self,
     @abstractmethod
     def load_annotations(self):
         """Load the annotation according to ann_file into video_infos."""
-        pass

     # json annotations already looks like video_infos, so for each dataset,
     # this func should be the same
@@ -224,7 +223,8 @@ def evaluate(self,

         return eval_results

-    def dump_results(self, results, out):
+    @staticmethod
+    def dump_results(results, out):
         """Dump data to json/yaml/pickle strings or files."""
         return mmcv.dump(results, out)

@@ -241,7 +241,7 @@ def prepare_train_frames(self, idx):

         # prepare tensor in getitem
         # If HVU, type(results['label']) is dict
-        if self.multi_class and type(results['label']) is list:
+        if self.multi_class and isinstance(results['label'], list):
             onehot = torch.zeros(self.num_classes)
             onehot[results['label']] = 1.
             results['label'] = onehot
@@ -261,7 +261,7 @@ def prepare_test_frames(self, idx):

         # prepare tensor in getitem
         # If HVU, type(results['label']) is dict
-        if self.multi_class and type(results['label']) is list:
+        if self.multi_class and isinstance(results['label'], list):
             onehot = torch.zeros(self.num_classes)
             onehot[results['label']] = 1.
             results['label'] = onehot
@@ -276,5 +276,5 @@ def __getitem__(self, idx):
         """Get the sample for either training or testing given index."""
         if self.test_mode:
             return self.prepare_test_frames(idx)
-        else:
-            return self.prepare_train_frames(idx)
+
+        return self.prepare_train_frames(idx)
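The `type(...) is list` → `isinstance(...)` changes in `base.py` are pylint's unidiomatic-typecheck (C0123) fix. Beyond style, `isinstance` also accepts subclasses, which an exact-type check silently rejects; a quick illustration with a hypothetical subclass:

```python
class LabelList(list):
    """Hypothetical subclass standing in for any list-like label container."""

labels = LabelList([1, 2, 3])
print(type(labels) is list)      # False -- exact-type check rejects subclasses
print(isinstance(labels, list))  # True  -- subclasses still count as lists
```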
diff --git a/mmaction/datasets/hvu_dataset.py b/mmaction/datasets/hvu_dataset.py
index c57da859bc..b523748093 100644
--- a/mmaction/datasets/hvu_dataset.py
+++ b/mmaction/datasets/hvu_dataset.py
@@ -165,7 +165,7 @@ def evaluate(self,
         gt_labels = [ann['label'] for ann in self.video_infos]

         eval_results = {}
-        for i, category in enumerate(self.tag_categories):
+        for category in self.tag_categories:

             start_idx = self.category2startidx[category]
             num = self.category2num[category]
diff --git a/mmaction/datasets/pipelines/augmentations.py b/mmaction/datasets/pipelines/augmentations.py
index 36586347db..f947e80743 100644
--- a/mmaction/datasets/pipelines/augmentations.py
+++ b/mmaction/datasets/pipelines/augmentations.py
@@ -252,10 +252,7 @@ def __init__(self, flip_ratio=0.5, direction='horizontal'):
         self.direction = direction

     def __call__(self, results):
-        if np.random.rand() < self.flip_ratio:
-            flip = True
-        else:
-            flip = False
+        flip = np.random.rand() < self.flip_ratio

         results['flip'] = flip
         results['flip_direction'] = self.direction
@@ -906,10 +903,7 @@ def __call__(self, results):
         if modality == 'Flow':
             assert self.direction == 'horizontal'

-        if np.random.rand() < self.flip_ratio:
-            flip = True
-        else:
-            flip = False
+        flip = np.random.rand() < self.flip_ratio

         results['flip'] = flip
         results['flip_direction'] = self.direction
@@ -993,7 +987,7 @@ def __call__(self, results):
             results['img_norm_cfg'] = dict(
                 mean=self.mean, std=self.std, to_bgr=self.to_bgr)
             return results
-        elif modality == 'Flow':
+        if modality == 'Flow':
             num_imgs = len(results['imgs'])
             assert num_imgs % 2 == 0
             assert self.mean.shape[0] == 2
@@ -1019,8 +1013,7 @@ def __call__(self, results):
                 adjust_magnitude=self.adjust_magnitude)
             results['img_norm_cfg'] = args
             return results
-        else:
-            raise NotImplementedError
+        raise NotImplementedError

     def __repr__(self):
         repr_str = (f'{self.__class__.__name__}('
diff --git a/mmaction/datasets/pipelines/formating.py b/mmaction/datasets/pipelines/formating.py
index c7234455c5..fd5502fd06 100644
--- a/mmaction/datasets/pipelines/formating.py
+++ b/mmaction/datasets/pipelines/formating.py
@@ -16,16 +16,15 @@ def to_tensor(data):
     """
     if isinstance(data, torch.Tensor):
         return data
-    elif isinstance(data, np.ndarray):
+    if isinstance(data, np.ndarray):
         return torch.from_numpy(data)
-    elif isinstance(data, Sequence) and not mmcv.is_str(data):
+    if isinstance(data, Sequence) and not mmcv.is_str(data):
         return torch.tensor(data)
-    elif isinstance(data, int):
+    if isinstance(data, int):
         return torch.LongTensor([data])
-    elif isinstance(data, float):
+    if isinstance(data, float):
         return torch.FloatTensor([data])
-    else:
-        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+    raise TypeError(f'type {type(data)} cannot be converted to tensor.')


 @PIPELINES.register_module()
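The two `Flip.__call__` hunks above collapse an if/else that only assigned literal `True`/`False` (pylint's simplifiable-if-statement, R1703): the comparison already yields the needed boolean. A sketch of the equivalence:

```python
import numpy as np

flip_ratio = 0.5
rand = np.random.rand()

# Before: four lines to produce a bool.
if rand < flip_ratio:
    flip = True
else:
    flip = False

# After: the comparison itself is the flag (a numpy bool here).
assert flip == (rand < flip_ratio)
```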
diff --git a/mmaction/datasets/pipelines/loading.py b/mmaction/datasets/pipelines/loading.py
index 631aec3d24..bde5e94a1f 100644
--- a/mmaction/datasets/pipelines/loading.py
+++ b/mmaction/datasets/pipelines/loading.py
@@ -24,6 +24,7 @@ class LoadHVULabel:

     def __init__(self, **kwargs):
         self.hvu_initialized = False
+        self.kwargs = kwargs

     def init_hvu_info(self, categories, category_nums):
         assert len(categories) == len(category_nums)
@@ -502,7 +503,8 @@ def __init__(self,
         self.mode = mode
         self.test_interval = test_interval

-    def _get_train_indices(self, valid_length, num_segments):
+    @staticmethod
+    def _get_train_indices(valid_length, num_segments):
         """Get indices of different stages of proposals in train mode.

         It will calculate the average interval for each segment,
@@ -528,7 +530,8 @@ def _get_train_indices(self, valid_length, num_segments):

         return offsets

-    def _get_val_indices(self, valid_length, num_segments):
+    @staticmethod
+    def _get_val_indices(valid_length, num_segments):
         """Get indices of different stages of proposals in validation mode.

         It will calculate the average interval for each segment.
@@ -1230,10 +1233,12 @@ def __init__(self,
         self.kwargs = kwargs
         self.file_client = None

-    def _zero_pad(self, shape):
+    @staticmethod
+    def _zero_pad(shape):
         return np.zeros(shape, dtype=np.float32)

-    def _random_pad(self, shape):
+    @staticmethod
+    def _random_pad(shape):
         # librosa load raw audio file into a distribution of -1~+1
         return np.random.rand(shape).astype(np.float32) * 2 - 1

@@ -1286,10 +1291,12 @@ def __init__(self, pad_method='zero'):
             raise NotImplementedError
         self.pad_method = pad_method

-    def _zero_pad(self, shape):
+    @staticmethod
+    def _zero_pad(shape):
         return np.zeros(shape, dtype=np.float32)

-    def _random_pad(self, shape):
+    @staticmethod
+    def _random_pad(shape):
         # spectrogram is normalized into a distribution of 0~1
         return np.random.rand(shape).astype(np.float32)

@@ -1387,7 +1394,7 @@ def __call__(self, results):
         # the input should be one single image
         assert len(results['imgs']) == 1
         im = results['imgs'][0]
-        for i in range(1, self.clip_len):
+        for _ in range(1, self.clip_len):
             results['imgs'].append(np.copy(im))
         results['clip_len'] = self.clip_len
         results['num_clips'] = 1
diff --git a/mmaction/datasets/rawframe_dataset.py b/mmaction/datasets/rawframe_dataset.py
index 8d7567b7cb..6c632851f2 100644
--- a/mmaction/datasets/rawframe_dataset.py
+++ b/mmaction/datasets/rawframe_dataset.py
@@ -138,7 +138,7 @@ def load_annotations(self):
                     idx += 1
                 # idx for label[s]
                 label = [int(x) for x in line_split[idx:]]
-                assert len(label), f'missing label in line: {line}'
+                assert label, f'missing label in line: {line}'
                 if self.multi_class:
                     assert self.num_classes is not None
                     video_info['label'] = label
diff --git a/mmaction/datasets/rawvideo_dataset.py b/mmaction/datasets/rawvideo_dataset.py
index d4417067e1..5ba5e612eb 100644
--- a/mmaction/datasets/rawvideo_dataset.py
+++ b/mmaction/datasets/rawvideo_dataset.py
@@ -114,7 +114,7 @@ def sample_clip(self, results):
         """Sample a clip from the raw video given the sampling strategy."""
         assert self.sampling_strategy in ['positive', 'random']
         if self.sampling_strategy == 'positive':
-            assert len(results['positive_clip_inds'])
+            assert results['positive_clip_inds']
             ind = random.choice(results['positive_clip_inds'])
         else:
             ind = random.randint(0, results['num_clips'] - 1)
diff --git a/mmaction/datasets/ssn_dataset.py b/mmaction/datasets/ssn_dataset.py
index 8df6793136..26cb7de436 100644
--- a/mmaction/datasets/ssn_dataset.py
+++ b/mmaction/datasets/ssn_dataset.py
@@ -46,13 +46,11 @@ def __init__(self,
             self.size_reg = None
             self.regression_targets = [0., 0.]

-    def compute_regression_targets(self, gt_list, positive_threshold):
+    def compute_regression_targets(self, gt_list):
         """Compute regression targets of positive proposals.

         Args:
             gt_list (list): The list of groundtruth instances.
-            positive_threshold (float): Minimum threshold of overlap of
-                positive/foreground proposals and groundtruths.
         """
         # Find the groundtruth instance with the highest IOU.
         ious = [
@@ -328,20 +326,12 @@ def load_annotations(self):
                     proposals=proposals))
         return video_infos

-    def results_to_detections(self,
-                              results,
-                              top_k=2000,
-                              softmax_before_filter=True,
-                              cls_top_k=2,
-                              **kwargs):
+    def results_to_detections(self, results, top_k=2000, **kwargs):
         """Convert prediction results into detections.

         Args:
             results (list): Prediction results.
             top_k (int): Number of top results. Default: 2000.
-            softmax_before_filter (bool): Whether to perform softmax operations
-                before filtering results. Default: True.
-            cls_top_k (int): Number of top results for each class. Default: 2.

         Returns:
             list: Detection results.
@@ -360,7 +350,8 @@ def results_to_detections(self,
             regression_scores = results[idx]['bbox_preds']
             if regression_scores is None:
                 regression_scores = np.zeros(
-                    len(relative_proposals), num_classes, 2, dtype=np.float32)
+                    (len(relative_proposals), num_classes, 2),
+                    dtype=np.float32)
             regression_scores = regression_scores.reshape((-1, num_classes, 2))

             if top_k <= 0:
@@ -449,7 +440,7 @@ def evaluate(self,

         if self.use_regression:
             self.logger.info('Performing location regression')
-            for class_idx in range(len(detections)):
+            for class_idx, _ in enumerate(detections):
                 detections[class_idx] = {
                     k: perform_regression(v)
                     for k, v in detections[class_idx].items()
@@ -457,7 +448,7 @@ def evaluate(self,
             self.logger.info('Regression finished')

         self.logger.info('Performing NMS')
-        for class_idx in range(len(detections)):
+        for class_idx, _ in enumerate(detections):
             detections[class_idx] = {
                 k: temporal_nms(v, self.evaluater.nms)
                 for k, v in detections[class_idx].items()
@@ -466,13 +457,13 @@ def evaluate(self,

         # get gts
         all_gts = self.get_all_gts()
-        for class_idx in range(len(detections)):
+        for class_idx, _ in enumerate(detections):
             if class_idx not in all_gts:
                 all_gts[class_idx] = dict()

         # get predictions
         plain_detections = {}
-        for class_idx in range(len(detections)):
+        for class_idx, _ in enumerate(detections):
             detection_list = []
             for video, dets in detections[class_idx].items():
                 detection_list.extend([[video, class_idx] + x[:3]
@@ -534,7 +525,8 @@ def get_all_gts(self):

         return gts

-    def get_positives(self, gts, proposals, positive_threshold, with_gt=True):
+    @staticmethod
+    def get_positives(gts, proposals, positive_threshold, with_gt=True):
         """Get positive/foreground proposals.

         Args:
@@ -558,12 +550,12 @@ def get_positives(self, gts, proposals, positive_threshold, with_gt=True):
             positives.extend(gts)

         for proposal in positives:
-            proposal.compute_regression_targets(gts, positive_threshold)
+            proposal.compute_regression_targets(gts)

         return positives

-    def get_negatives(self,
-                      proposals,
+    @staticmethod
+    def get_negatives(proposals,
                       incomplete_iou_threshold,
                       background_iou_threshold,
                       background_coverage_threshold=0.01,
@@ -648,14 +640,11 @@ def sample_video_proposals(proposal_type, video_id, video_pool,
                 idx = np.random.choice(
                     len(dataset_pool), num_requested_proposals, replace=False)
                 return [(dataset_pool[x], proposal_type) for x in idx]
-            else:
-                replicate = len(video_pool) < num_requested_proposals
-                idx = np.random.choice(
-                    len(video_pool),
-                    num_requested_proposals,
-                    replace=replicate)
-                return [((video_id, video_pool[x]), proposal_type)
-                        for x in idx]
+
+            replicate = len(video_pool) < num_requested_proposals
+            idx = np.random.choice(
+                len(video_pool), num_requested_proposals, replace=replicate)
+            return [((video_id, video_pool[x]), proposal_type) for x in idx]

         out_proposals = []
         out_proposals.extend(
@@ -782,7 +771,7 @@ def prepare_train_frames(self, idx):
         num_frames = proposal[0][1].num_video_frames

         (starting_scale_factor, ending_scale_factor,
-         stage_split) = self._get_stage(proposal[0][1], num_frames)
+         _) = self._get_stage(proposal[0][1], num_frames)

         # proposal[1]: Type id of proposal.
         # Positive/Foreground: 0
diff --git a/mmaction/localization/proposal_utils.py b/mmaction/localization/proposal_utils.py
index 4f9f8ef192..0cf0111fcf 100644
--- a/mmaction/localization/proposal_utils.py
+++ b/mmaction/localization/proposal_utils.py
@@ -72,7 +72,8 @@ def soft_nms(proposals, alpha, low_threshold, high_threshold, top_k):
         iou_list = temporal_iou(tstart[max_index], tend[max_index],
                                 np.array(tstart), np.array(tend))
         iou_exp_list = np.exp(-np.square(iou_list) / alpha)
-        for idx in range(len(tscore)):
+
+        for idx, _ in enumerate(tscore):
             if idx != max_index:
                 current_iou = iou_list[idx]
                 if current_iou > low_threshold + (high_threshold -
diff --git a/mmaction/localization/ssn_utils.py b/mmaction/localization/ssn_utils.py
index a4e534c79e..0c1f528d18 100644
--- a/mmaction/localization/ssn_utils.py
+++ b/mmaction/localization/ssn_utils.py
@@ -159,7 +159,7 @@ def eval_ap(detections, gt_by_cls, iou_range):
     ap_values = np.zeros((len(detections), len(iou_range)))

     for iou_idx, min_overlap in enumerate(iou_range):
-        for class_idx in range(len(detections)):
+        for class_idx, _ in enumerate(detections):
             ap = average_precision_at_temporal_iou(gt_by_cls[class_idx],
                                                    detections[class_idx],
                                                    [min_overlap])
diff --git a/mmaction/models/backbones/c3d.py b/mmaction/models/backbones/c3d.py
index f07385fa3e..847ff576d9 100644
--- a/mmaction/models/backbones/c3d.py
+++ b/mmaction/models/backbones/c3d.py
@@ -1,7 +1,7 @@
+import torch.nn as nn
 from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init
 from mmcv.runner import load_checkpoint
 from mmcv.utils import _BatchNorm
-from torch import nn as nn

 from ...utils import get_root_logger
 from ..registry import BACKBONES
@@ -137,6 +137,3 @@ def forward(self, x):
         x = self.relu(self.fc7(x))

         return x
-
-    def train(self, mode=True):
-        super(C3D, self).train(mode)
diff --git a/mmaction/models/backbones/resnet.py b/mmaction/models/backbones/resnet.py
index d874413ff0..ca0a20578b 100644
--- a/mmaction/models/backbones/resnet.py
+++ b/mmaction/models/backbones/resnet.py
@@ -1,7 +1,7 @@
+import torch.nn as nn
 from mmcv.cnn import ConvModule, constant_init, kaiming_init
 from mmcv.runner import _load_checkpoint, load_checkpoint
 from mmcv.utils import _BatchNorm
-from torch import nn as nn
 from torch.utils import checkpoint as cp

 from ...utils import get_root_logger
@@ -355,7 +355,7 @@ def __init__(self,
         self.pretrained = pretrained
         self.torchvision_pretrain = torchvision_pretrain
         self.num_stages = num_stages
-        assert num_stages >= 1 and num_stages <= 4
+        assert 1 <= num_stages <= 4
         self.out_indices = out_indices
         assert max(out_indices) < num_stages
         self.strides = strides
@@ -416,7 +416,8 @@ def _make_stem_layer(self):
             act_cfg=self.act_cfg)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

-    def _load_conv_params(self, conv, state_dict_tv, module_name_tv,
+    @staticmethod
+    def _load_conv_params(conv, state_dict_tv, module_name_tv,
                           loaded_param_names):
         """Load the conv parameters of resnet from torchvision.

@@ -439,8 +440,8 @@ def _load_conv_params(self, conv, state_dict_tv, module_name_tv,
                 conv.bias.data.copy_(state_dict_tv[bias_tv_name])
                 loaded_param_names.append(bias_tv_name)

-    def _load_bn_params(self, bn, state_dict_tv, module_name_tv,
-                        loaded_param_names):
+    @staticmethod
+    def _load_bn_params(bn, state_dict_tv, module_name_tv, loaded_param_names):
         """Load the bn parameters of resnet from torchvision.

         Args:
@@ -467,10 +468,7 @@ def _load_bn_params(self, bn, state_dict_tv, module_name_tv,
             param.data.copy_(param_tv)
             loaded_param_names.append(param_tv_name)

-    def _load_torchvision_checkpoint(self,
-                                     pretrained,
-                                     strict=False,
-                                     logger=None):
+    def _load_torchvision_checkpoint(self, logger=None):
         """Initiate the parameters from torchvision pretrained checkpoint."""
         state_dict_torchvision = _load_checkpoint(self.pretrained)
         if 'state_dict' in state_dict_torchvision:
@@ -511,8 +509,7 @@ def init_weights(self):
             logger = get_root_logger()
             if self.torchvision_pretrain:
                 # torchvision's
-                self._load_torchvision_checkpoint(
-                    self.pretrained, strict=False, logger=logger)
+                self._load_torchvision_checkpoint(logger)
             else:
                 # ours
                 load_checkpoint(
@@ -546,8 +543,8 @@ def forward(self, x):
                 outs.append(x)
         if len(outs) == 1:
             return outs[0]
-        else:
-            return tuple(outs)
+
+        return tuple(outs)

     def _freeze_stages(self):
         """Prevent all the parameters from being optimized before
diff --git a/mmaction/models/backbones/resnet3d.py b/mmaction/models/backbones/resnet3d.py
index 29cd93394b..4fd78bfe4a 100644
--- a/mmaction/models/backbones/resnet3d.py
+++ b/mmaction/models/backbones/resnet3d.py
@@ -417,7 +417,7 @@ def __init__(self,
         self.in_channels = in_channels
         self.base_channels = base_channels
         self.num_stages = num_stages
-        assert num_stages >= 1 and num_stages <= 4
+        assert 1 <= num_stages <= 4
         self.out_indices = out_indices
         assert max(out_indices) < num_stages
         self.spatial_strides = spatial_strides
@@ -481,8 +481,8 @@ def __init__(self,
         self.feat_dim = self.block.expansion * self.base_channels * 2**(
             len(self.stage_blocks) - 1)

-    def make_res_layer(self,
-                       block,
+    @staticmethod
+    def make_res_layer(block,
                        inplanes,
                        planes,
                        blocks,
@@ -595,7 +595,8 @@ def make_res_layer(self,

         return nn.Sequential(*layers)

-    def _inflate_conv_params(self, conv3d, state_dict_2d, module_name_2d,
+    @staticmethod
+    def _inflate_conv_params(conv3d, state_dict_2d, module_name_2d,
                              inflated_param_names):
         """Inflate a conv module from 2d to 3d.

@@ -622,7 +623,8 @@ def _inflate_conv_params(self, conv3d, state_dict_2d, module_name_2d,
             conv3d.bias.data.copy_(state_dict_2d[bias_2d_name])
             inflated_param_names.append(bias_2d_name)

-    def _inflate_bn_params(self, bn3d, state_dict_2d, module_name_2d,
+    @staticmethod
+    def _inflate_bn_params(bn3d, state_dict_2d, module_name_2d,
                            inflated_param_names):
         """Inflate a norm module from 2d to 3d.

@@ -802,8 +804,8 @@ def forward(self, x):
                 outs.append(x)
         if len(outs) == 1:
             return outs[0]
-        else:
-            return tuple(outs)
+
+        return tuple(outs)

     def train(self, mode=True):
         """Set the optimization status when training."""
diff --git a/mmaction/models/backbones/resnet3d_csn.py b/mmaction/models/backbones/resnet3d_csn.py
index a461739259..97c3e420aa 100644
--- a/mmaction/models/backbones/resnet3d_csn.py
+++ b/mmaction/models/backbones/resnet3d_csn.py
@@ -48,7 +48,7 @@ def __init__(self,
         conv2_stride = self.conv2.conv.stride
         conv2_padding = self.conv2.conv.padding
         conv2_dilation = self.conv2.conv.dilation
-        conv2_bias = True if self.conv2.conv.bias else False
+        conv2_bias = bool(self.conv2.conv.bias)
         self.conv2 = ConvModule(
             planes,
             planes,
diff --git a/mmaction/models/backbones/resnet3d_slowfast.py b/mmaction/models/backbones/resnet3d_slowfast.py
index 6637ecac7f..c2f0301f55 100644
--- a/mmaction/models/backbones/resnet3d_slowfast.py
+++ b/mmaction/models/backbones/resnet3d_slowfast.py
@@ -300,7 +300,7 @@ def _freeze_stages(self):
             for param in m.parameters():
                 param.requires_grad = False

-            if (i != len(self.res_layers) and self.lateral):
+            if i != len(self.res_layers) and self.lateral:
                 # No fusion needed in the final stage
                 lateral_name = self.lateral_connections[i - 1]
                 conv_lateral = getattr(self, lateral_name)
@@ -343,8 +343,8 @@ def build_pathway(cfg, *args, **kwargs):
     pathway_type = cfg_.pop('type')
     if pathway_type not in pathway_cfg:
         raise KeyError(f'Unrecognized pathway type {pathway_type}')
-    else:
-        pathway_cls = pathway_cfg[pathway_type]
+
+    pathway_cls = pathway_cfg[pathway_type]
     pathway = pathway_cls(*args, **kwargs, **cfg_)

     return pathway
diff --git a/mmaction/models/backbones/resnet_tsm.py b/mmaction/models/backbones/resnet_tsm.py
index 1f47c32cda..2c4f999b5c 100644
--- a/mmaction/models/backbones/resnet_tsm.py
+++ b/mmaction/models/backbones/resnet_tsm.py
@@ -173,7 +173,7 @@ def make_temporal_shift(self):
             ]
         else:
             num_segment_list = [self.num_segments] * 4
-        if not num_segment_list[-1] > 0:
+        if num_segment_list[-1] <= 0:
             raise ValueError('num_segment_list[-1] must be positive')

         if self.shift_place == 'block':
diff --git a/mmaction/models/backbones/x3d.py b/mmaction/models/backbones/x3d.py
index 32f0f2be24..4d6b85cff3 100644
--- a/mmaction/models/backbones/x3d.py
+++ b/mmaction/models/backbones/x3d.py
@@ -24,7 +24,8 @@ def __init__(self, channels, reduction):
             self.bottleneck, channels, kernel_size=1, padding=0)
         self.sigmoid = nn.Sigmoid()

-    def _round_width(self, width, multiplier, min_width=8, divisor=8):
+    @staticmethod
+    def _round_width(width, multiplier, min_width=8, divisor=8):
         width *= multiplier
         min_width = min_width or divisor
         width_out = max(min_width,
@@ -244,7 +245,7 @@ def __init__(self,
         ]

         self.num_stages = num_stages
-        assert num_stages >= 1 and num_stages <= 4
+        assert 1 <= num_stages <= 4
         self.spatial_strides = spatial_strides
         assert len(spatial_strides) == num_stages
         self.frozen_stages = frozen_stages
@@ -306,7 +307,8 @@ def __init__(self,
             act_cfg=self.act_cfg)
         self.feat_dim = int(self.feat_dim * self.gamma_b)

-    def _round_width(self, width, multiplier, min_depth=8, divisor=8):
+    @staticmethod
+    def _round_width(width, multiplier, min_depth=8, divisor=8):
         """Round width of filters based on width multiplier."""
         if not multiplier:
             return width
@@ -319,9 +321,9 @@ def _round_width(self, width, multiplier, min_depth=8, divisor=8):
             new_filters += divisor
         return int(new_filters)

-    def _round_repeats(self, repeats, multiplier):
+    @staticmethod
+    def _round_repeats(repeats, multiplier):
         """Round number of layers based on depth multiplier."""
-        multiplier = multiplier
         if not multiplier:
             return repeats
         return int(math.ceil(multiplier * repeats))
@@ -391,7 +393,7 @@ def make_res_layer(self,
         if self.se_style == 'all':
             use_se = [True] * blocks
         elif self.se_style == 'half':
-            use_se = [True if i % 2 == 0 else False for i in range(blocks)]
+            use_se = [i % 2 == 0 for i in range(blocks)]
         else:
             raise NotImplementedError

@@ -505,7 +507,7 @@ def forward(self, x):
         """
         x = self.conv1_s(x)
         x = self.conv1_t(x)
-        for i, layer_name in enumerate(self.res_layers):
+        for layer_name in self.res_layers:
             res_layer = getattr(self, layer_name)
             x = res_layer(x)
         x = self.conv5(x)
diff --git a/mmaction/models/builder.py b/mmaction/models/builder.py
index 7fe06c59e7..43a4f32f71 100644
--- a/mmaction/models/builder.py
+++ b/mmaction/models/builder.py
@@ -23,8 +23,8 @@ def build(cfg, registry, default_args=None):
             build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
         ]
         return nn.Sequential(*modules)
-    else:
-        return build_from_cfg(cfg, registry, default_args)
+
+    return build_from_cfg(cfg, registry, default_args)


 def build_backbone(cfg):
@@ -59,8 +59,10 @@ def build_model(cfg, train_cfg=None, test_cfg=None):
     obj_type = args.pop('type')
     if obj_type in LOCALIZERS:
         return build_localizer(cfg)
-    elif obj_type in RECOGNIZERS:
+    if obj_type in RECOGNIZERS:
         return build_recognizer(cfg, train_cfg, test_cfg)
+    raise ValueError(f'{obj_type} is not registered in '
+                     'LOCALIZERS or RECOGNIZERS')


 def build_neck(cfg):
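`build_model` previously fell through and returned `None` for a config whose `type` is in neither registry; the new `raise` makes the failure explicit at build time. A hedged usage sketch (the config type below is deliberately unregistered):

```python
from mmaction.models import build_model

cfg = dict(type='NotARealModel')  # hypothetical, registered nowhere

try:
    model = build_model(cfg)
except ValueError as err:
    print(err)  # NotARealModel is not registered in LOCALIZERS or RECOGNIZERS
```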
diff --git a/mmaction/models/heads/base.py b/mmaction/models/heads/base.py
index 7165766150..92e9b0427f 100644
--- a/mmaction/models/heads/base.py
+++ b/mmaction/models/heads/base.py
@@ -60,12 +60,10 @@ def __init__(self,
     def init_weights(self):
         """Initiate the parameters either from existing checkpoint or from
         scratch."""
-        pass

     @abstractmethod
     def forward(self, x):
         """Defines the computation performed at every call."""
-        pass

     def loss(self, cls_score, labels, **kwargs):
         """Calculate the loss given output ``cls_score``, target ``labels``.
@@ -96,7 +94,7 @@ def loss(self, cls_score, labels, **kwargs):

         loss_cls = self.loss_cls(cls_score, labels, **kwargs)
         # loss_cls may be dictionary or single tensor
-        if type(loss_cls) is dict:
+        if isinstance(loss_cls, dict):
             losses.update(loss_cls)
         else:
             losses['loss_cls'] = loss_cls
diff --git a/mmaction/models/heads/ssn_head.py b/mmaction/models/heads/ssn_head.py
index 3fdd8e1ef0..3399d3dc1a 100644
--- a/mmaction/models/heads/ssn_head.py
+++ b/mmaction/models/heads/ssn_head.py
@@ -19,10 +19,9 @@ def parse_stage_config(stage_cfg):
     """
     if isinstance(stage_cfg, int):
         return (stage_cfg, ), stage_cfg
-    elif isinstance(stage_cfg, tuple):
+    if isinstance(stage_cfg, tuple):
         return stage_cfg, sum(stage_cfg)
-    else:
-        raise ValueError(f'Incorrect STPP config {stage_cfg}')
+    raise ValueError(f'Incorrect STPP config {stage_cfg}')


 class STPPTrain(nn.Module):
@@ -50,7 +49,8 @@ def __init__(self, stpp_stage=(1, (1, 2), 1), num_segments_list=(2, 5, 2)):

         self.num_segments_list = num_segments_list

-    def _extract_stage_feature(self, stage_feat, stage_parts, num_multipliers,
+    @staticmethod
+    def _extract_stage_feature(stage_feat, stage_parts, num_multipliers,
                                scale_factors, num_samples):
         """Extract stage feature based on structured temporal pyramid pooling.

@@ -167,8 +167,9 @@ def __init__(self,
             self.complete_slice.stop,
             self.complete_slice.stop + self.reg_score_len * self.num_multipliers)

-    def _pyramids_pooling(self, out_scores, index, raw_scores, ticks,
-                          scale_factors, score_len, stpp_stage):
+    @staticmethod
+    def _pyramids_pooling(out_scores, index, raw_scores, ticks, scale_factors,
+                          score_len, stpp_stage):
         """Perform pyramids pooling.

         Args:
@@ -401,12 +402,11 @@ def forward(self, x, test_mode=False):
             else:
                 bbox_preds = None
             return activity_scores, complete_scores, bbox_preds
-        else:
-            x, proposal_tick_list, scale_factor_list = x
-            test_scores = self.test_fc(x)
-            (activity_scores, completeness_scores,
-             bbox_preds) = self.consensus(test_scores, proposal_tick_list,
-                                          scale_factor_list)
-
-            return (test_scores, activity_scores, completeness_scores,
-                    bbox_preds)
+
+        x, proposal_tick_list, scale_factor_list = x
+        test_scores = self.test_fc(x)
+        (activity_scores, completeness_scores,
+         bbox_preds) = self.consensus(test_scores, proposal_tick_list,
+                                      scale_factor_list)
+
+        return (test_scores, activity_scores, completeness_scores, bbox_preds)
diff --git a/mmaction/models/localizers/base.py b/mmaction/models/localizers/base.py
index 5cdae15408..abc715593d 100644
--- a/mmaction/models/localizers/base.py
+++ b/mmaction/models/localizers/base.py
@@ -44,19 +44,17 @@ def extract_feat(self, imgs):
     @abstractmethod
     def forward_train(self, imgs, labels):
         """Defines the computation performed at training."""
-        pass

     @abstractmethod
     def forward_test(self, imgs):
         """Defines the computation performed at testing."""
-        pass

     def forward(self, imgs, return_loss=True, **kwargs):
         """Define the computation performed at every call."""
         if return_loss:
             return self.forward_train(imgs, **kwargs)
-        else:
-            return self.forward_test(imgs, **kwargs)
+
+        return self.forward_test(imgs, **kwargs)

     @staticmethod
     def _parse_losses(losses):
diff --git a/mmaction/models/localizers/bmn.py b/mmaction/models/localizers/bmn.py
index c50d78be16..20efbc5b84 100644
--- a/mmaction/models/localizers/bmn.py
+++ b/mmaction/models/localizers/bmn.py
@@ -350,10 +350,11 @@ def forward(self,
             label_end = label_end.to(device)
             return self.forward_train(raw_feature, label_confidence,
                                       label_start, label_end)
-        else:
-            return self.forward_test(raw_feature, video_meta)

-    def _get_interp1d_bin_mask(self, seg_tmin, seg_tmax, tscale, num_samples,
+        return self.forward_test(raw_feature, video_meta)
+
+    @staticmethod
+    def _get_interp1d_bin_mask(seg_tmin, seg_tmax, tscale, num_samples,
                                num_samples_per_bin):
         """Generate sample mask for a boundary-matching pair."""
         plen = float(seg_tmax - seg_tmin)
diff --git a/mmaction/models/localizers/bsn.py b/mmaction/models/localizers/bsn.py
index 01db2bba79..23f3e7230a 100644
--- a/mmaction/models/localizers/bsn.py
+++ b/mmaction/models/localizers/bsn.py
@@ -148,7 +148,8 @@ def forward_test(self, raw_feature, video_meta):
         video_meta_list = [dict(x) for x in video_meta]

         video_results = []
-        for batch_idx in range(len(batch_action)):
+
+        for batch_idx, _ in enumerate(batch_action):
             video_name = video_meta_list[batch_idx]['video_name']
             video_action = batch_action[batch_idx]
             video_start = batch_start[batch_idx]
@@ -220,8 +221,8 @@ def forward(self,
             label_end = label_end.to(device)
             return self.forward_train(raw_feature, label_action, label_start,
                                       label_end)
-        else:
-            return self.forward_test(raw_feature, video_meta)
+
+        return self.forward_test(raw_feature, video_meta)


 @LOCALIZERS.register_module()
@@ -301,7 +302,7 @@ def _forward(self, x):
         Returns:
             torch.Tensor: The output of the module.
         """
-        x = torch.cat([data for data in x])
+        x = torch.cat(list(x))
         x = F.relu(self.fc1_ratio * self.fc1(x))
         x = torch.sigmoid(self.fc2_ratio * self.fc2(x))
         return x
@@ -309,8 +310,7 @@ def _forward(self, x):
     def forward_train(self, bsp_feature, reference_temporal_iou):
         """Define the computation performed at every call when training."""
         pem_output = self._forward(bsp_feature)
-        reference_temporal_iou = torch.cat(
-            [data for data in reference_temporal_iou])
+        reference_temporal_iou = torch.cat(list(reference_temporal_iou))

         device = pem_output.device
         reference_temporal_iou = reference_temporal_iou.to(device)
@@ -390,6 +390,6 @@ def forward(self,
         """Define the computation performed at every call."""
         if return_loss:
             return self.forward_train(bsp_feature, reference_temporal_iou)
-        else:
-            return self.forward_test(bsp_feature, tmin, tmax, tmin_score,
-                                     tmax_score, video_meta)
+
+        return self.forward_test(bsp_feature, tmin, tmax, tmin_score,
+                                 tmax_score, video_meta)
diff --git a/mmaction/models/losses/base.py b/mmaction/models/losses/base.py
index 2bd0797064..eb1a43f6ee 100644
--- a/mmaction/models/losses/base.py
+++ b/mmaction/models/losses/base.py
@@ -35,7 +35,7 @@ def forward(self, *args, **kwargs):
             torch.Tensor: The calculated loss.
         """
         ret = self._forward(*args, **kwargs)
-        if type(ret) is dict:
+        if isinstance(ret, dict):
             for k in ret:
                 if 'loss' in k:
                     ret[k] *= self.loss_weight
diff --git a/mmaction/models/losses/bmn_loss.py b/mmaction/models/losses/bmn_loss.py
index bed1dc5320..50e49729b0 100644
--- a/mmaction/models/losses/bmn_loss.py
+++ b/mmaction/models/losses/bmn_loss.py
@@ -22,7 +22,8 @@ class BMNLoss(nn.Module):
         results of candidate proposals.
     """

-    def tem_loss(self, pred_start, pred_end, gt_start, gt_end):
+    @staticmethod
+    def tem_loss(pred_start, pred_end, gt_start, gt_end):
         """Calculate Temporal Evaluation Module Loss.

         This function calculate the binary_logistic_regression_loss for start
@@ -42,8 +43,8 @@ def tem_loss(self, pred_start, pred_end, gt_start, gt_end):
         loss = loss_start + loss_end
         return loss

-    def pem_reg_loss(self,
-                     pred_score,
+    @staticmethod
+    def pem_reg_loss(pred_score,
                      gt_iou_map,
                      mask,
                      high_temporal_iou_threshold=0.7,
@@ -91,8 +92,8 @@ def pem_reg_loss(self,

         return loss

-    def pem_cls_loss(self,
-                     pred_score,
+    @staticmethod
+    def pem_cls_loss(pred_score,
                      gt_iou_map,
                      mask,
                      threshold=0.9,
diff --git a/mmaction/models/losses/hvu_loss.py b/mmaction/models/losses/hvu_loss.py
index 836ad780c7..9d9b00567d 100644
--- a/mmaction/models/losses/hvu_loss.py
+++ b/mmaction/models/losses/hvu_loss.py
@@ -46,8 +46,8 @@ def __init__(self,
         self.category_nums = category_nums
         self.category_loss_weights = category_loss_weights
         assert len(self.category_nums) == len(self.category_loss_weights)
-        for loss_weight in self.category_loss_weights:
-            assert loss_weight >= 0
+        for category_loss_weight in self.category_loss_weights:
+            assert category_loss_weight >= 0
         self.loss_type = loss_type
         self.with_mask = with_mask
         self.reduction = reduction
@@ -84,11 +84,12 @@ def _forward(self, cls_score, label, mask, category_mask):
                 w_loss_cls = w_loss_cls / torch.sum(mask, dim=1)
                 w_loss_cls = torch.mean(w_loss_cls)
                 return dict(loss_cls=w_loss_cls)
-            else:
-                if self.reduction == 'sum':
-                    loss_cls = torch.sum(loss_cls, dim=-1)
-                return dict(loss_cls=torch.mean(loss_cls))
-        elif self.loss_type == 'individual':
+
+            if self.reduction == 'sum':
+                loss_cls = torch.sum(loss_cls, dim=-1)
+            return dict(loss_cls=torch.mean(loss_cls))
+
+        if self.loss_type == 'individual':
             losses = {}
             loss_weights = {}
             for name, num, start_idx in zip(self.categories,
@@ -135,3 +136,6 @@ def _forward(self, cls_score, label, mask, category_mask):
                 })
             # Note that the loss weights are just for reference.
             return losses
+        else:
+            raise ValueError("loss_type should be 'all' or 'individual', "
+                             f'but got {self.loss_type}')
diff --git a/mmaction/models/losses/ssn_loss.py b/mmaction/models/losses/ssn_loss.py
index 34a978bbd0..492dc8ddaa 100644
--- a/mmaction/models/losses/ssn_loss.py
+++ b/mmaction/models/losses/ssn_loss.py
@@ -9,7 +9,8 @@
 @LOSSES.register_module()
 class SSNLoss(nn.Module):

-    def activity_loss(self, activity_score, labels, activity_indexer):
+    @staticmethod
+    def activity_loss(activity_score, labels, activity_indexer):
         """Activity Loss.

         It will calculate activity loss given activity_score and label.
@@ -26,8 +27,8 @@ def activity_loss(self, activity_score, labels, activity_indexer):
         gt = labels[activity_indexer]
         return F.cross_entropy(pred, gt)

-    def completeness_loss(self,
-                          completeness_score,
+    @staticmethod
+    def completeness_loss(completeness_score,
                           labels,
                           completeness_indexer,
                           positive_per_video,
@@ -77,7 +78,8 @@ def completeness_loss(self,
         return ((positive_loss + incomplete_loss) /
                 float(num_positives + num_incompletes))

-    def classwise_regression_loss(self, bbox_pred, labels, bbox_targets,
+    @staticmethod
+    def classwise_regression_loss(bbox_pred, labels, bbox_targets,
                                   regression_indexer):
         """Classwise Regression Loss.
+ This class is modified from + https://github.com/facebookresearch/SlowFast/blob/master/slowfast/visualization/gradcam_utils.py # noqa For more information about GradCAM, please visit: https://arxiv.org/pdf/1610.02391.pdf """ @@ -20,7 +21,7 @@ def __init__(self, model, target_layer_name, colormap='viridis'): be used to get gradients and feature maps from for creating localization maps. colormap (Optional[str]): matplotlib colormap used to create - heatmap. For more information, please visit: + heatmap. Default: 'viridis'. For more information, please visit https://matplotlib.org/3.3.0/tutorials/colors/colormaps.html """ from ..models.recognizers import Recognizer2D, Recognizer3D