From a8ecbbf36805a3ad002bdca9826d0e2fde477879 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 5 May 2023 14:56:41 +0200 Subject: [PATCH] Allow postponing dataset integrity checks to training time --- returnn/datasets/hdf.py | 97 +++++++++++++++++++++++++++++++++++------ 1 file changed, 84 insertions(+), 13 deletions(-) diff --git a/returnn/datasets/hdf.py b/returnn/datasets/hdf.py index 95d83b7fd..af1cfad89 100644 --- a/returnn/datasets/hdf.py +++ b/returnn/datasets/hdf.py @@ -471,9 +471,10 @@ class StreamParser(object): Stream parser. """ - def __init__(self, seq_names, stream): + def __init__(self, seq_names, stream, use_lazy_data_integrity_checks=False): self.seq_names = seq_names self.stream = stream + self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks self.num_features = None self.feature_type = None # 1 for sparse, 2 for dense @@ -518,8 +519,10 @@ def __init__(self, *args, **kwargs): if self.dtype is None: self.dtype = str(seq_data.dtype) - assert seq_data.shape[1] == self.num_features - assert str(seq_data.dtype) == self.dtype + if self.use_lazy_data_integrity_checks: + break + + self.check_data_integrity(seq_data, s) self.feature_type = 2 @@ -528,7 +531,12 @@ def get_data(self, seq_name): :param str seq_name: :rtype: numpy.ndarray """ - return self.stream["data"][seq_name][...] + data = self.stream["data"][seq_name][...] 
+ + if self.use_lazy_data_integrity_checks: + self.check_data_integrity(data, seq_name) + + return data def get_seq_length(self, seq_name): """ @@ -537,6 +545,18 @@ def get_seq_length(self, seq_name): """ return self.stream["data"][seq_name].shape[0] + def check_data_integrity(self, data, seq_name): + """ + :param numpy.ndarray data: + :param str seq_name: + """ + + assert len(data.shape) == 2, f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)" + assert ( + self.num_features == data.shape[1] + ), f"feature dim mismatch in {seq_name}: {data.shape[1]} (should be {self.num_features})" + assert self.dtype == str(data.dtype), f"dtype mismatch {seq_name}: {str(data.dtype)} (should be {self.dtype})" + class SparseStreamParser(StreamParser): """ @@ -552,7 +572,11 @@ def __init__(self, *args, **kwargs): if self.dtype is None: self.dtype = str(seq_data.dtype) - assert str(seq_data.dtype) == self.dtype + + if self.use_lazy_data_integrity_checks: + break + + self.check_data_integrity(seq_data, s) self.num_features = self.stream["feature_names"].shape[0] self.feature_type = 1 @@ -562,7 +586,12 @@ def get_data(self, seq_name): :param str seq_name: :rtype: numpy.ndarray """ - return self.stream["data"][seq_name][:] + data = self.stream["data"][seq_name][:] + + if self.use_lazy_data_integrity_checks: + self.check_data_integrity(data, seq_name) + + return data def get_seq_length(self, seq_name): """ @@ -571,6 +600,17 @@ def get_seq_length(self, seq_name): """ return self.stream["data"][seq_name].shape[0] + def check_data_integrity(self, data, seq_name): + """ + :param numpy.ndarray data: + :param str seq_name: + """ + + assert len(data.shape) == 1, f"shape length mismatch in {seq_name}: {data.shape} (should be 1-dimensional)" + assert self.dtype == str( + data.dtype + ), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})" + class SegmentAlignmentStreamParser(StreamParser): """ @@ -585,10 +625,11 @@ def __init__(self, *args, 
**kwargs): if self.dtype is None: self.dtype = str(seq_data.dtype) - assert str(seq_data.dtype) == self.dtype - assert len(seq_data.shape) == 2 - assert seq_data.shape[1] == 2 + if self.use_lazy_data_integrity_checks: + break + + self.check_data_integrity(seq_data, s) self.num_features = self.stream["feature_names"].shape[0] self.feature_type = 1 @@ -602,6 +643,9 @@ def get_data(self, seq_name): length = self.get_seq_length(seq_name) // 2 segments = self.stream["data"][seq_name][:] + if self.use_lazy_data_integrity_checks: + self.check_data_integrity(segments, seq_name) + alignment = numpy.zeros((length, 2), dtype=self.dtype) num_segments = segments.shape[0] seg_end = 0 @@ -621,6 +665,18 @@ def get_seq_length(self, seq_name): """ return 2 * sum(self.stream["data"][seq_name][:, 1]) + def check_data_integrity(self, data, seq_name): + """ + :param numpy.ndarray data: + :param str seq_name: + """ + + assert len(data.shape) == 2, f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)" + assert data.shape[1] == 2, f"feature dim mismatch in {seq_name}: {data.shape[1]} (should be 2)" + assert self.dtype == str( + data.dtype + ), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})" + class NextGenHDFDataset(CachedDataset2): """ @@ -633,7 +689,7 @@ class NextGenHDFDataset(CachedDataset2): "segment_alignment": SegmentAlignmentStreamParser, } - def __init__(self, input_stream_name, files=None, **kwargs): + def __init__(self, input_stream_name, files=None, use_lazy_data_integrity_checks=False, **kwargs): """ :param str input_stream_name: :param None|list[str] files: @@ -649,6 +705,7 @@ def __init__(self, input_stream_name, files=None, **kwargs): self.file_indices = [] self.seq_order = [] self.all_parsers = collections.defaultdict(list) + self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks if files: for fn in files: @@ -684,7 +741,9 @@ def add_file(self, path): ) parsers = { - name: 
NextGenHDFDataset.parsers[stream.attrs["parser"]](norm_seqs, stream) + name: NextGenHDFDataset.parsers[stream.attrs["parser"]]( + norm_seqs, stream, use_lazy_data_integrity_checks=self.use_lazy_data_integrity_checks + ) for name, stream in cur_file["streams"].items() } for k, v in parsers.items(): @@ -807,7 +866,15 @@ class SiameseHDFDataset(CachedDataset2): "segment_alignment": SegmentAlignmentStreamParser, } - def __init__(self, input_stream_name, seq_label_stream="words", class_distribution=None, files=None, **kwargs): + def __init__( + self, + input_stream_name, + seq_label_stream="words", + class_distribution=None, + files=None, + use_lazy_data_integrity_checks=False, + **kwargs, + ): """ :param str input_stream_name: name of a feature stream :param str seq_label_stream: name of a stream with labels @@ -833,6 +900,8 @@ def __init__(self, input_stream_name, seq_label_stream="words", class_distributi self.target_to_seqs = {} # (int) class_index -> (string) sequence_names self.curr_epoch_triplets = [] self.targets_stream = seq_label_stream + self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks + if files: for fn in files: self.add_file(fn) @@ -872,7 +941,9 @@ def add_file(self, path): ) parsers = { - name: SiameseHDFDataset.parsers[stream.attrs["parser"]](norm_seqs, stream) + name: SiameseHDFDataset.parsers[stream.attrs["parser"]]( + norm_seqs, stream, use_lazy_data_integrity_checks=self.use_lazy_data_integrity_checks + ) for name, stream in cur_file["streams"].items() } # name - stream name (words, features, orth_features) for k, v in parsers.items():