diff --git a/cyvcf2/cyvcf2.pyx b/cyvcf2/cyvcf2.pyx index a729997..6229d87 100644 --- a/cyvcf2/cyvcf2.pyx +++ b/cyvcf2/cyvcf2.pyx @@ -293,6 +293,12 @@ cdef class VCF(HTSFile): if threads is not None: self.set_threads(threads) + def __enter__(self): + return self + + def __exit__(self, type, value, tb): + self.close() + def set_threads(self, int n): v = hts_set_threads(self.hts, n) if v < 0: @@ -2414,6 +2420,12 @@ cdef class Writer(VCF): bcf_hdr_sync(self.hdr) self.header_written = False + def __enter__(self): + return self + + def __exit__(self, type, value, tb): + self.close() + @staticmethod def _infer_file_mode(fname, mode=None): if mode is not None: diff --git a/cyvcf2/tests/test_reader.py b/cyvcf2/tests/test_reader.py index 84d2249..43d954d 100644 --- a/cyvcf2/tests/test_reader.py +++ b/cyvcf2/tests/test_reader.py @@ -1410,3 +1410,9 @@ def test_num_records_no_index(path): vcf = VCF(os.path.join(HERE, path)) with pytest.raises(ValueError, match="must be indexed"): vcf.num_records + +def test_reader_context_manager(): + with VCF(VCF_PATH) as vcf: + pass + with pytest.raises(Exception, match="attempt to iterate over closed"): + next(vcf)