-
Notifications
You must be signed in to change notification settings - Fork 124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ensemble #19
Open
davek44
wants to merge
72
commits into
master
Choose a base branch
from
ensemble
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Ensemble #19
Changes from all commits
Commits
Show all changes
72 commits
Select commit
Hold shift + click to select a range
26891d5
failing with map_fn
davek44 41444af
cond refuses transform_fn
davek44 0bb926d
another failed attempt
davek44 32a529c
moving cond downstream but map_fn still fails
davek44 bf42d4e
separate train and eval loss
davek44 76f3340
needs preds_length
davek44 2e57fd6
augmentation methods
davek44 2797ccb
h5 in-graph augmentation
davek44 b66bc9c
tuning
davek44 257544b
tran epoch bug
davek44 a013f9e
h5 ensembling in-graph
davek44 18386bc
test feed dict
davek44 9686338
update placeholder
davek44 cd9a875
0 shift defaults
davek44 558f53f
testing
davek44 03fad02
target labels
davek44 bcef7ab
float64 loss mean
davek44 19e110b
default shift 0
davek44 23d2f70
no data open
davek44 b4e5df5
create new data_ops dict
davek44 aaa46fa
average predictions, not representations
davek44 cd77912
rename build
davek44 df29615
predict in-graph ensembling
davek44 269a2ed
penultimate draft
davek44 a16a31b
missing comma
davek44 2e5af1c
penultimate loss fix
davek44 f376d01
hidden and map
davek44 29dd294
align predict_h5 with predict_h5_manual
davek44 0f0f34a
sqrt soft clip
davek44 672b434
debugging TFR
davek44 e0a0408
nan baseline
davek44 777b0e5
seqs_per_tfr
davek44 ab0ed1d
shuffle bug
davek44 c86098f
failing with map_fn
davek44 391abd4
cond refuses transform_fn
davek44 77bda7f
another failed attempt
davek44 57ee2c3
moving cond downstream but map_fn still fails
davek44 d917be8
separate train and eval loss
davek44 17fd1ce
needs preds_length
davek44 96c01d6
augmentation methods
davek44 1770abd
h5 in-graph augmentation
davek44 65c7bb5
tuning
davek44 68f5f65
tran epoch bug
davek44 702ec1e
h5 ensembling in-graph
davek44 40a960e
test feed dict
davek44 e80c4fb
update placeholder
davek44 dfa8224
0 shift defaults
davek44 c253cb9
testing
davek44 f68c6ac
target labels
davek44 b344161
float64 loss mean
davek44 74b12c3
default shift 0
davek44 99fdf38
no data open
davek44 53b9a53
create new data_ops dict
davek44 6799d25
average predictions, not representations
davek44 2769f82
rename build
davek44 44b099d
predict in-graph ensembling
davek44 f6c4379
penultimate draft
davek44 ad01d86
missing comma
davek44 0ec3447
penultimate loss fix
davek44 3f35246
hidden and map
davek44 2f84d67
align predict_h5 with predict_h5_manual
davek44 7ff7ed2
sad tf.data
davek44 6b5e502
reverse_complement bug fix
davek44 51a3671
optimizing
davek44 5e19e7e
sadq multi
davek44 1d6bb21
dynamic batch size
davek44 74dec8b
full batches unnecessary
davek44 8264f08
weighted average losses
davek44 eb5c224
sad optimizations
davek44 33fa21e
increase shuffle buffer
davek44 a4c05a6
typo and descriptions
davek44 415a8ca
conflicts
davek44 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ========================================================================= | ||
|
||
import pdb | ||
import tensorflow as tf | ||
|
||
from basenji import ops | ||
|
||
def shift_sequence(seq, shift_amount, pad_value=0.25): | ||
"""Shift a sequence left or right by shift_amount. | ||
|
||
Args: | ||
seq: a [batch_size, sequence_length, sequence_depth] sequence to shift | ||
shift_amount: the signed amount to shift (tf.int32 or int) | ||
pad_value: value to fill the padding (primitive or scalar tf.Tensor) | ||
""" | ||
if seq.shape.ndims != 3: | ||
raise ValueError('input sequence should be rank 3') | ||
input_shape = seq.shape | ||
|
||
pad = pad_value * tf.ones_like(seq[:, 0:tf.abs(shift_amount), :]) | ||
|
||
def _shift_right(_seq): | ||
sliced_seq = _seq[:, :-shift_amount:, :] | ||
return tf.concat([pad, sliced_seq], axis=1) | ||
|
||
def _shift_left(_seq): | ||
sliced_seq = _seq[:, -shift_amount:, :] | ||
return tf.concat([sliced_seq, pad], axis=1) | ||
|
||
output = tf.cond( | ||
tf.greater(shift_amount, 0), lambda: _shift_right(seq), | ||
lambda: _shift_left(seq)) | ||
|
||
output.set_shape(input_shape) | ||
return output | ||
|
||
def augment_deterministic_set(data_ops, augment_rc=False, augment_shifts=[0]): | ||
""" | ||
|
||
Args: | ||
data_ops: dict with keys 'sequence,' 'label,' and 'na.' | ||
augment_rc: Boolean | ||
augment_shifts: List of ints. | ||
Returns | ||
data_ops_list: | ||
""" | ||
augment_pairs = [] | ||
for ashift in augment_shifts: | ||
augment_pairs.append((False, ashift)) | ||
if augment_rc: | ||
augment_pairs.append((True, ashift)) | ||
|
||
data_ops_list = [] | ||
for arc, ashift in augment_pairs: | ||
data_ops_aug = augment_deterministic(data_ops, arc, ashift) | ||
data_ops_list.append(data_ops_aug) | ||
|
||
return data_ops_list | ||
|
||
|
||
def augment_deterministic(data_ops, augment_rc=False, augment_shift=0): | ||
"""Apply a deterministic augmentation, specified by the parameters. | ||
|
||
Args: | ||
data_ops: dict with keys 'sequence,' 'label,' and 'na.' | ||
augment_rc: Boolean | ||
<<<<<<< HEAD | ||
augment_shift: Int | ||
Returns | ||
data_ops: augmented data, with all existing keys transformed | ||
and 'reverse_preds' bool added. | ||
""" | ||
|
||
data_ops_aug = {} | ||
if 'label' in data_ops: | ||
data_ops_aug['label'] = data_ops['label'] | ||
if 'na' in data_ops: | ||
data_ops_aug['na'] = data_ops['na'] | ||
======= | ||
augment_shifts: Int | ||
Returns | ||
data_ops: augmented data | ||
""" | ||
|
||
data_ops_aug = {'label': data_ops['label'], 'na': data_ops['na']} | ||
>>>>>>> 29dd294bf104eb6f38559a6665fc2ff7d233afc9 | ||
|
||
if augment_shift == 0: | ||
data_ops_aug['sequence'] = data_ops['sequence'] | ||
else: | ||
shift_amount = tf.constant(augment_shift, shape=(), dtype=tf.int64) | ||
data_ops_aug['sequence'] = shift_sequence(data_ops['sequence'], shift_amount) | ||
|
||
if augment_rc: | ||
data_ops_aug = augment_deterministic_rc(data_ops_aug) | ||
else: | ||
data_ops_aug['reverse_preds'] = tf.zeros((), dtype=tf.bool) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what are the semantics of these as targets? |
||
|
||
return data_ops_aug | ||
|
||
|
||
def augment_deterministic_rc(data_ops): | ||
"""Apply a deterministic reverse complement augmentation. | ||
|
||
Args: | ||
data_ops: dict with keys 'sequence,' 'label,' and 'na.' | ||
Returns | ||
data_ops_aug: augmented data ops | ||
""" | ||
<<<<<<< HEAD | ||
data_ops_aug = ops.reverse_complement_transform(data_ops) | ||
data_ops_aug['reverse_preds'] = tf.ones((), dtype=tf.bool) | ||
======= | ||
seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']] | ||
seq, label, na = ops.reverse_complement_transform(seq, label, na) | ||
reverse_preds = tf.ones((), dtype=tf.bool) | ||
data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds} | ||
>>>>>>> 29dd294bf104eb6f38559a6665fc2ff7d233afc9 | ||
return data_ops_aug | ||
|
||
|
||
def augment_stochastic_rc(data_ops): | ||
"""Apply a stochastic reverse complement augmentation. | ||
|
||
Args: | ||
data_ops: dict with keys 'sequence,' 'label,' and 'na.' | ||
Returns | ||
data_ops_aug: augmented data | ||
""" | ||
<<<<<<< HEAD | ||
reverse_preds = tf.random_uniform(shape=[]) > 0.5 | ||
data_ops_aug = tf.cond(reverse_preds, lambda: ops.reverse_complement_transform(data_ops), | ||
lambda: data_ops.copy()) | ||
data_ops_aug['reverse_preds'] = reverse_preds | ||
======= | ||
seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']] | ||
reverse_preds = tf.random_uniform(shape=[]) > 0.5 | ||
seq, label, na = tf.cond(reverse_preds, lambda: ops.reverse_complement_transform(seq, label, na), | ||
lambda: (seq, label, na)) | ||
data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds} | ||
>>>>>>> 29dd294bf104eb6f38559a6665fc2ff7d233afc9 | ||
return data_ops_aug | ||
|
||
|
||
def augment_stochastic_shifts(seq, augment_shifts): | ||
"""Apply a stochastic shift augmentation. | ||
|
||
Args: | ||
seq: input sequence of size [batch_size, length, depth] | ||
augment_shifts: list of int offsets to sample from | ||
Returns: | ||
shifted and padded sequence of size [batch_size, length, depth] | ||
""" | ||
shift_index = tf.random_uniform(shape=[], minval=0, | ||
maxval=len(augment_shifts), dtype=tf.int64) | ||
shift_value = tf.gather(tf.constant(augment_shifts), shift_index) | ||
|
||
seq = tf.cond(tf.not_equal(shift_value, 0), | ||
lambda: shift_sequence(seq, shift_value), | ||
lambda: seq) | ||
|
||
return seq | ||
|
||
|
||
def augment_stochastic(data_ops, augment_rc=False, augment_shifts=[]): | ||
"""Apply stochastic augmentations, | ||
|
||
Args: | ||
data_ops: dict with keys 'sequence,' 'label,' and 'na.' | ||
augment_rc: Boolean for whether to apply reverse complement augmentation. | ||
augment_shifts: list of int offsets to sample shift augmentations. | ||
Returns: | ||
data_ops_aug: augmented data | ||
""" | ||
if augment_shifts: | ||
data_ops['sequence'] = augment_stochastic_shifts(data_ops['sequence'], | ||
augment_shifts) | ||
|
||
if augment_rc: | ||
data_ops = augment_stochastic_rc(data_ops) | ||
else: | ||
data_ops['reverse_preds'] = tf.zeros((), dtype=tf.bool) | ||
|
||
return data_ops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when would you use this function? It seems odd to call with augment_shift != 0, but only a single value.