Skip to content

Commit

Permalink
Merge pull request #582 from eqcorrscan/detection-trim
Browse files Browse the repository at this point in the history
Detection trim
  • Loading branch information
calum-chamberlain authored Jul 26, 2024
2 parents 3d51092 + b27a382 commit d9d826d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 7 deletions.
28 changes: 22 additions & 6 deletions eqcorrscan/core/match_filter/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ def extract_stream(self, stream, length, prepick, all_vert=False,
pick = [p for p in pick
if p.waveform_id.channel_code == channel]
if len(pick) == 0:
Logger.info("No pick for {0}.{1}".format(station, channel))
Logger.info(
"No pick for {0}.{1}".format(station, channel))
continue
elif len(pick) > 1:
Logger.info(
Expand All @@ -406,13 +407,28 @@ def extract_stream(self, stream, length, prepick, all_vert=False,
pick.sort(key=lambda p: p.time)
pick = pick[0]
cut_start = pick.time - prepick
cut_end = cut_start + length
_st = _st.slice(starttime=cut_start, endtime=cut_end).copy()
# Minimum length check
# Find nearest sample to avoid to too-short length - see #573
for tr in _st:
if abs((tr.stats.endtime - tr.stats.starttime) -
sample_offset = (cut_start -
tr.stats.starttime) * tr.stats.sampling_rate
Logger.debug(
f"Sample offset for slice on {tr.id}: {sample_offset}")
sample_offset //= 1
# If the sample offset is not a whole number, always take the
# sample before that requested
_tr_cut_start = tr.stats.starttime + (
sample_offset * tr.stats.delta)
_tr_cut_end = _tr_cut_start + length
Logger.debug(
f"Trimming {tr.id} between {_tr_cut_end} "
f"and {_tr_cut_end}.")
_tr = tr.slice(_tr_cut_start, _tr_cut_end).copy()
Logger.debug(
f"Length: {(_tr.stats.endtime - _tr.stats.starttime)}")
Logger.debug(f"Requested length: {length}")
if abs((_tr.stats.endtime - _tr.stats.starttime) -
length) < tr.stats.delta:
cut_stream += tr
cut_stream += _tr
else:
Logger.info(
"Insufficient data length for {0}".format(tr.id))
Expand Down
5 changes: 4 additions & 1 deletion eqcorrscan/tests/lag_calc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ class SyntheticTests(unittest.TestCase):
def setUpClass(cls):
np.random.seed(999)
print("Setting up class")
np.random.seed(999)
samp_rate = 50
t_length = .75
# Make some synthetic templates
templates, data, seeds = generate_synth_data(
nsta=5, ntemplates=5, nseeds=10, samp_rate=samp_rate,
t_length=t_length, max_amp=10, max_lag=15, phaseout="both",
jitter=0, noise=False, same_phase=True)
print("Made synthetic data")
# Rename channels
channel_mapper = {"SYN_Z": "HHZ", "SYN_H": "HHN"}
for tr in data:
Expand All @@ -44,6 +44,7 @@ def setUpClass(cls):
party = Party()
t = 0
data_start = data[0].stats.starttime
print("Making party")
for template, template_seeds in zip(templates, seeds):
template_name = "template_{0}".format(t)
detections = []
Expand All @@ -68,6 +69,8 @@ def setUpClass(cls):
family = Family(template=_template, detections=detections)
party += family
t += 1
print(f"Made template {template_name}")
print("Made party")
cls.party = party
cls.data = data
cls.t_length = t_length
Expand Down
46 changes: 46 additions & 0 deletions eqcorrscan/tests/matched_filter/match_filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,52 @@ def test_family_catalogs(self):
family.detections.append(additional_detection)
self.assertEqual(family.catalog, get_catalog(family.detections))

def test_detection_extract_stream(self):
# Create simple synthetic stream
traces = []
pick_sids = {p.waveform_id.get_seed_string() for f in self.party
for d in f for p in d.event.picks}
pick_times = sorted([p.time for f in self.party for d in f
for p in d.event.picks])
first_pick, last_pick = pick_times[0], pick_times[-1]
delta = self.party[0].template.st[0].stats.delta
n_samples = int(((last_pick - first_pick) + 120) / delta)
data = np.arange(n_samples, dtype=int)
for sid in pick_sids:
tr = Trace(data=data.copy())
tr.stats.starttime = first_pick - 60
n, s, l, c = sid.split('.')
tr.stats.delta = delta
tr.stats.network = n
tr.stats.station = s
tr.stats.location = l
tr.stats.channel = c
traces.append(tr)
st = Stream(traces=traces)

# Test straightforward extraction
detection = self.party[0][0]
length, pre_pick = 40.0, 1 / delta
for shift in range(6):
pre_pick -= shift * (delta / 6) # Sub-sample shifting
expected_starts = {
p.waveform_id.get_seed_string(): p.time - pre_pick
for p in detection.event.picks}
cut_st = detection.extract_stream(
stream=st, length=length, prepick=pre_pick)
for sid, expected_start in expected_starts.items():
Logger.debug(f"Checking for {sid}")
tr = cut_st.select(id=sid)
# Check that we get a returned trace
self.assertTrue(len(tr), 1)
tr = tr[0]
# Check that start is within one sample of the expected start
self.assertLess(abs(tr.stats.starttime - expected_start),
delta)
# Check that the length is correct
returned_length = tr.stats.endtime - tr.stats.starttime
self.assertEqual(length, returned_length)


class TestTemplateGrouping(unittest.TestCase):
@classmethod
Expand Down

0 comments on commit d9d826d

Please sign in to comment.