Ensemble #19

Open
wants to merge 72 commits into
base: master
Commits
26891d5
failing with map_fn
davek44 Jun 30, 2018
41444af
cond refuses transform_fn
davek44 Jun 30, 2018
0bb926d
another failed attempt
davek44 Jul 1, 2018
32a529c
moving cond downstream but map_fn still fails
davek44 Jul 1, 2018
bf42d4e
separate train and eval loss
davek44 Jul 3, 2018
76f3340
needs preds_length
davek44 Jul 4, 2018
2e57fd6
augmentation methods
davek44 Jul 7, 2018
2797ccb
h5 in-graph augmentation
davek44 Jul 7, 2018
b66bc9c
tuning
davek44 Jul 7, 2018
257544b
tran epoch bug
davek44 Jul 7, 2018
a013f9e
h5 ensembling in-graph
davek44 Jul 7, 2018
18386bc
test feed dict
davek44 Jul 8, 2018
9686338
update placeholder
davek44 Jul 11, 2018
cd9a875
0 shift defaults
davek44 Jul 11, 2018
558f53f
testing
davek44 Jul 11, 2018
03fad02
target labels
davek44 Jul 11, 2018
bcef7ab
float64 loss mean
davek44 Jul 11, 2018
19e110b
default shift 0
davek44 Jul 11, 2018
23d2f70
no data open
davek44 Jul 11, 2018
b4e5df5
create new data_ops dict
davek44 Jul 12, 2018
aaa46fa
average predictions, not representations
davek44 Jul 12, 2018
cd77912
rename build
davek44 Jul 20, 2018
df29615
predict in-graph ensembling
davek44 Jul 21, 2018
269a2ed
penultimate draft
davek44 Jul 21, 2018
a16a31b
missing comma
davek44 Jul 22, 2018
2e5af1c
penultimate loss fix
davek44 Jul 22, 2018
f376d01
hidden and map
davek44 Jul 23, 2018
29dd294
align predict_h5 with predict_h5_manual
davek44 Jul 26, 2018
0f0f34a
sqrt soft clip
davek44 Aug 10, 2018
672b434
debugging TFR
davek44 Aug 11, 2018
e0a0408
nan baseline
davek44 Aug 12, 2018
777b0e5
seqs_per_tfr
davek44 Aug 12, 2018
ab0ed1d
shuffle bug
davek44 Aug 13, 2018
c86098f
failing with map_fn
davek44 Jun 30, 2018
391abd4
cond refuses transform_fn
davek44 Jun 30, 2018
77bda7f
another failed attempt
davek44 Jul 1, 2018
57ee2c3
moving cond downstream but map_fn still fails
davek44 Jul 1, 2018
d917be8
separate train and eval loss
davek44 Jul 3, 2018
17fd1ce
needs preds_length
davek44 Jul 4, 2018
96c01d6
augmentation methods
davek44 Jul 7, 2018
1770abd
h5 in-graph augmentation
davek44 Jul 7, 2018
65c7bb5
tuning
davek44 Jul 7, 2018
68f5f65
tran epoch bug
davek44 Jul 7, 2018
702ec1e
h5 ensembling in-graph
davek44 Jul 7, 2018
40a960e
test feed dict
davek44 Jul 8, 2018
e80c4fb
update placeholder
davek44 Jul 11, 2018
dfa8224
0 shift defaults
davek44 Jul 11, 2018
c253cb9
testing
davek44 Jul 11, 2018
f68c6ac
target labels
davek44 Jul 11, 2018
b344161
float64 loss mean
davek44 Jul 11, 2018
74b12c3
default shift 0
davek44 Jul 11, 2018
99fdf38
no data open
davek44 Jul 11, 2018
53b9a53
create new data_ops dict
davek44 Jul 12, 2018
6799d25
average predictions, not representations
davek44 Jul 12, 2018
2769f82
rename build
davek44 Jul 20, 2018
44b099d
predict in-graph ensembling
davek44 Jul 21, 2018
f6c4379
penultimate draft
davek44 Jul 21, 2018
ad01d86
missing comma
davek44 Jul 22, 2018
0ec3447
penultimate loss fix
davek44 Jul 22, 2018
3f35246
hidden and map
davek44 Jul 23, 2018
2f84d67
align predict_h5 with predict_h5_manual
davek44 Jul 26, 2018
7ff7ed2
sad tf.data
davek44 Aug 6, 2018
6b5e502
reverse_complement bug fix
davek44 Aug 6, 2018
51a3671
optimizing
davek44 Aug 6, 2018
5e19e7e
sadq multi
davek44 Aug 7, 2018
1d6bb21
dynamic batch size
davek44 Aug 10, 2018
74dec8b
full batches unnecessary
davek44 Aug 10, 2018
8264f08
weighted average losses
davek44 Aug 10, 2018
eb5c224
sad optimizations
davek44 Aug 15, 2018
33fa21e
increase shuffle buffer
davek44 Aug 15, 2018
a4c05a6
typo and descriptions
davek44 Aug 18, 2018
415a8ca
conflicts
davek44 Aug 20, 2018
195 changes: 195 additions & 0 deletions basenji/augmentation.py
@@ -0,0 +1,195 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

import pdb
import tensorflow as tf

from basenji import ops

def shift_sequence(seq, shift_amount, pad_value=0.25):
  """Shift a sequence left or right by shift_amount.

  Args:
    seq: a [batch_size, sequence_length, sequence_depth] sequence to shift
    shift_amount: the signed amount to shift (tf.int32 or int)
    pad_value: value to fill the padding (primitive or scalar tf.Tensor)
  """
  if seq.shape.ndims != 3:
    raise ValueError('input sequence should be rank 3')
  input_shape = seq.shape

  pad = pad_value * tf.ones_like(seq[:, 0:tf.abs(shift_amount), :])

  def _shift_right(_seq):
    sliced_seq = _seq[:, :-shift_amount:, :]
    return tf.concat([pad, sliced_seq], axis=1)

  def _shift_left(_seq):
    sliced_seq = _seq[:, -shift_amount:, :]
    return tf.concat([sliced_seq, pad], axis=1)

  output = tf.cond(
      tf.greater(shift_amount, 0), lambda: _shift_right(seq),
      lambda: _shift_left(seq))

  output.set_shape(input_shape)
  return output
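
The shift semantics of `shift_sequence` can be sketched without TensorFlow. The helper below, `shift_sequence_ref`, is a hypothetical plain-Python reference and not part of this PR; it assumes nested lists shaped [batch][length][depth] and, unlike the TensorFlow version, clamps the shift to the sequence length.

```python
def shift_sequence_ref(seq, shift_amount, pad_value=0.25):
    """Plain-Python reference for the shift semantics (hypothetical helper).

    seq is a nested list shaped [batch][length][depth]. A positive
    shift_amount moves content right, a negative one moves it left.
    Vacated positions are filled with pad_value.
    """
    shifted_batch = []
    for example in seq:
        length = len(example)
        depth = len(example[0])
        # ASSUMPTION: clamp the shift so it never exceeds the length.
        k = min(abs(shift_amount), length)
        pad = [[pad_value] * depth for _ in range(k)]
        if shift_amount > 0:
            # drop the last k positions and pad on the left
            shifted = pad + [row[:] for row in example[:length - k]]
        else:
            # drop the first k positions and pad on the right
            shifted = [row[:] for row in example[k:]] + pad
        shifted_batch.append(shifted)
    return shifted_batch


# One example of length 4 and depth 1, so the movement is easy to read.
seq = [[[1.0], [2.0], [3.0], [4.0]]]
right = shift_sequence_ref(seq, 1)
left = shift_sequence_ref(seq, -2)
unchanged = shift_sequence_ref(seq, 0)
```

A shift of 0 takes the left-shift branch with an empty pad, which mirrors how the `tf.cond` above falls through to `_shift_left` when `shift_amount` is not greater than 0.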

def augment_deterministic_set(data_ops, augment_rc=False, augment_shifts=[0]):
  """

  Args:
    data_ops: dict with keys 'sequence,' 'label,' and 'na.'
    augment_rc: Boolean
    augment_shifts: List of ints.
  Returns
    data_ops_list:
  """
  augment_pairs = []
  for ashift in augment_shifts:
    augment_pairs.append((False, ashift))
    if augment_rc:
      augment_pairs.append((True, ashift))

  data_ops_list = []
  for arc, ashift in augment_pairs:
    data_ops_aug = augment_deterministic(data_ops, arc, ashift)
    data_ops_list.append(data_ops_aug)

  return data_ops_list
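
To make the size and order of the deterministic ensemble concrete, the pair enumeration above can be run on its own. `enumerate_augment_pairs` is a hypothetical stand-alone copy of that loop, written only for illustration.

```python
def enumerate_augment_pairs(augment_rc=False, augment_shifts=(0,)):
    """List the (reverse_complement, shift) pairs, one per ensemble member.

    Hypothetical stand-alone copy of the loop in augment_deterministic_set.
    """
    pairs = []
    for shift in augment_shifts:
        # the forward-strand member is always included
        pairs.append((False, shift))
        if augment_rc:
            # the reverse-complement member follows its forward partner
            pairs.append((True, shift))
    return pairs


pairs = enumerate_augment_pairs(augment_rc=True, augment_shifts=[0, 1, -1])
forward_only = enumerate_augment_pairs(augment_rc=False, augment_shifts=[0, 1, -1])
```

With reverse complement enabled the ensemble has twice as many members as there are shifts; without it, one member per shift.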


def augment_deterministic(data_ops, augment_rc=False, augment_shift=0):

Review comment: when would you use this function? It seems odd to call with augment_shift != 0, but only a single value.

  """Apply a deterministic augmentation, specified by the parameters.

  Args:
    data_ops: dict with keys 'sequence,' 'label,' and 'na.'
    augment_rc: Boolean
<<<<<<< HEAD
    augment_shift: Int
  Returns
    data_ops: augmented data, with all existing keys transformed
              and 'reverse_preds' bool added.
  """

  data_ops_aug = {}
  if 'label' in data_ops:
    data_ops_aug['label'] = data_ops['label']
  if 'na' in data_ops:
    data_ops_aug['na'] = data_ops['na']
=======
    augment_shifts: Int
  Returns
    data_ops: augmented data
  """

  data_ops_aug = {'label': data_ops['label'], 'na': data_ops['na']}
>>>>>>> 29dd294bf104eb6f38559a6665fc2ff7d233afc9

  if augment_shift == 0:
    data_ops_aug['sequence'] = data_ops['sequence']
  else:
    shift_amount = tf.constant(augment_shift, shape=(), dtype=tf.int64)
    data_ops_aug['sequence'] = shift_sequence(data_ops['sequence'], shift_amount)

  if augment_rc:
    data_ops_aug = augment_deterministic_rc(data_ops_aug)
  else:
    data_ops_aug['reverse_preds'] = tf.zeros((), dtype=tf.bool)

Review comment: what are the semantics of these as targets?


  return data_ops_aug


def augment_deterministic_rc(data_ops):
  """Apply a deterministic reverse complement augmentation.

  Args:
    data_ops: dict with keys 'sequence,' 'label,' and 'na.'
  Returns
    data_ops_aug: augmented data ops
  """
<<<<<<< HEAD
  data_ops_aug = ops.reverse_complement_transform(data_ops)
  data_ops_aug['reverse_preds'] = tf.ones((), dtype=tf.bool)
=======
  seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']]
  seq, label, na = ops.reverse_complement_transform(seq, label, na)
  reverse_preds = tf.ones((), dtype=tf.bool)
  data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds}
>>>>>>> 29dd294bf104eb6f38559a6665fc2ff7d233afc9
  return data_ops_aug


def augment_stochastic_rc(data_ops):
  """Apply a stochastic reverse complement augmentation.

  Args:
    data_ops: dict with keys 'sequence,' 'label,' and 'na.'
  Returns
    data_ops_aug: augmented data
  """
<<<<<<< HEAD
  reverse_preds = tf.random_uniform(shape=[]) > 0.5
  data_ops_aug = tf.cond(reverse_preds, lambda: ops.reverse_complement_transform(data_ops),
                                        lambda: data_ops.copy())
  data_ops_aug['reverse_preds'] = reverse_preds
=======
  seq, label, na = [data_ops[k] for k in ['sequence', 'label', 'na']]
  reverse_preds = tf.random_uniform(shape=[]) > 0.5
  seq, label, na = tf.cond(reverse_preds, lambda: ops.reverse_complement_transform(seq, label, na),
                                          lambda: (seq, label, na))
  data_ops_aug = {'sequence': seq, 'label': label, 'na': na, 'reverse_preds':reverse_preds}
>>>>>>> 29dd294bf104eb6f38559a6665fc2ff7d233afc9
  return data_ops_aug


def augment_stochastic_shifts(seq, augment_shifts):
  """Apply a stochastic shift augmentation.

  Args:
    seq: input sequence of size [batch_size, length, depth]
    augment_shifts: list of int offsets to sample from
  Returns:
    shifted and padded sequence of size [batch_size, length, depth]
  """
  shift_index = tf.random_uniform(shape=[], minval=0,
                                  maxval=len(augment_shifts), dtype=tf.int64)
  shift_value = tf.gather(tf.constant(augment_shifts), shift_index)

  seq = tf.cond(tf.not_equal(shift_value, 0),
                lambda: shift_sequence(seq, shift_value),
                lambda: seq)

  return seq


def augment_stochastic(data_ops, augment_rc=False, augment_shifts=[]):
  """Apply stochastic augmentations,

  Args:
    data_ops: dict with keys 'sequence,' 'label,' and 'na.'
    augment_rc: Boolean for whether to apply reverse complement augmentation.
    augment_shifts: list of int offsets to sample shift augmentations.
  Returns:
    data_ops_aug: augmented data
  """
  if augment_shifts:
    data_ops['sequence'] = augment_stochastic_shifts(data_ops['sequence'],
                                                     augment_shifts)

  if augment_rc:
    data_ops = augment_stochastic_rc(data_ops)
  else:
    data_ops['reverse_preds'] = tf.zeros((), dtype=tf.bool)

  return data_ops
6 changes: 3 additions & 3 deletions basenji/batcher.py
@@ -84,19 +84,19 @@ def next(self, fwdrc=True, shift=0):

     # initialize
     Xb = np.zeros(
-        (self.batch_size, self.seq_len, self.seq_depth), dtype='float32')
+        (Nb, self.seq_len, self.seq_depth), dtype='float32')
     if self.Yf is not None:
       if self.Yf.dtype == np.uint8:
         ytype = 'int32'
       else:
         ytype = 'float32'

       Yb = np.zeros(
-          (self.batch_size, self.seq_len // self.pool_width,
+          (Nb, self.seq_len // self.pool_width,
            self.num_targets),
           dtype=ytype)
       NAb = np.zeros(
-          (self.batch_size, self.seq_len // self.pool_width), dtype='bool')
+          (Nb, self.seq_len // self.pool_width), dtype='bool')

     # copy data
     for i in range(Nb):
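
The effect of sizing the arrays by `Nb` instead of `self.batch_size` shows up on the final, partial batch. The sketch below is hypothetical: the hunk does not show how `Nb` is computed, so it is assumed here to be the number of sequences remaining, capped at the batch size, and `batch_array_shapes` is an illustrative name rather than part of `batcher.py`.

```python
def batch_array_shapes(num_seqs, batch_size, seq_len, seq_depth,
                       pool_width, num_targets):
    """Return the (Xb, Yb, NAb) shapes for each batch in one pass.

    Arrays are sized by nb, the number of sequences actually in the
    batch, rather than by the nominal batch_size.
    """
    shapes = []
    for start in range(0, num_seqs, batch_size):
        # ASSUMPTION: Nb is the count of remaining sequences, capped at
        # batch_size. The diff hunk does not show its definition.
        nb = min(batch_size, num_seqs - start)
        xb_shape = (nb, seq_len, seq_depth)
        yb_shape = (nb, seq_len // pool_width, num_targets)
        nab_shape = (nb, seq_len // pool_width)
        shapes.append((xb_shape, yb_shape, nab_shape))
    return shapes


shapes = batch_array_shapes(num_seqs=10, batch_size=4, seq_len=1024,
                            seq_depth=4, pool_width=128, num_targets=3)
```

Under that assumption, 10 sequences with a batch size of 4 give batches of 4, 4, and 2, and the last batch carries no zero-filled rows.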
27 changes: 23 additions & 4 deletions basenji/ops.py
@@ -41,15 +41,34 @@ def adjust_max(start, stop, start_value, stop_value, name=None):
   else:
     return None

-def reverse_complement_transform(seq, label, na):
+def reverse_complement_transform(data_ops):
   """Reverse complement of batched onehot seq and corresponding label and na."""

+  # initialize reverse complemented data_ops
+  data_ops_rc = {}
+
+  # extract sequence from dict
+  seq = data_ops['sequence']
+
+  # check rank
   rank = seq.shape.ndims
   if rank != 3:
     raise ValueError("input seq must be rank 3.")

-  complement = tf.gather(seq, [3, 2, 1, 0], axis=-1)
-  return (tf.reverse(complement, axis=[1]), tf.reverse(label, axis=[1]),
-          tf.reverse(na, axis=[1]))
+  # reverse complement sequence
+  seq_rc = tf.gather(seq, [3, 2, 1, 0], axis=-1)
+  seq_rc = tf.reverse(seq_rc, axis=[1])
+  data_ops_rc['sequence'] = seq_rc
+
+  # reverse labels
+  if 'label' in data_ops:
+    data_ops_rc['label'] = tf.reverse(data_ops['label'], axis=[1])
+
+  # reverse NA
+  if 'na' in data_ops:
+    data_ops_rc['na'] = tf.reverse(data_ops['na'], axis=[1])
+
+  return data_ops_rc
+

 def reverse_complement(input_seq, lengths=None):
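
The dict-based `reverse_complement_transform` can be checked by hand with a small plain-Python sketch. `reverse_complement_ref` is a hypothetical reference, not part of `ops.py`. It assumes nested lists shaped [batch][length][4] and a one-hot channel order of A, C, G, T, which is what the `[3, 2, 1, 0]` gather implies: reversing the channels swaps A with T and C with G.

```python
def reverse_complement_ref(data_ops):
    """Plain-Python reference for the dict-based transform (hypothetical).

    data_ops['sequence'] is a nested list [batch][length][4], one-hot in
    A, C, G, T order (ASSUMPTION). 'label' and 'na' are optional and are
    only reversed along the length axis.
    """
    data_ops_rc = {}

    # complement by reversing the channel order, then reverse positions
    data_ops_rc['sequence'] = [
        [base[::-1] for base in reversed(example)]
        for example in data_ops['sequence']
    ]

    # optional keys are reversed along the length axis when present
    for key in ('label', 'na'):
        if key in data_ops:
            data_ops_rc[key] = [list(reversed(example))
                                for example in data_ops[key]]

    return data_ops_rc


A, C, G, T = [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]
data = {
    'sequence': [[A, A, C]],             # AAC
    'label': [[10, 20, 30]],
    'na': [[False, False, True]],
}
rc = reverse_complement_ref(data)        # sequence becomes GTT
seq_only = reverse_complement_ref({'sequence': [[A]]})
```

As in the new version of the function, a dict holding only 'sequence' is accepted, so the same transform works at prediction time when no labels exist.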