Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include a factor of x8 #78

Open
EmanuelCastanho opened this issue Jun 21, 2024 · 1 comment
Open

Include a factor of x8 #78

EmanuelCastanho opened this issue Jun 21, 2024 · 1 comment

Comments

@EmanuelCastanho
Copy link

Hi,

I am trying to add a factor of 8 to SR4RS. To do so, I changed the following:

Inside constants.py:
# 8 is the newly added entry; the network.py code below indexes hr_images with these same keys (1, 2, 4, 8).
factors = [1, 2, 4, 8]

Inside network.py, on the discriminator function:

def discriminator(hr_images, scope, dim):
    """
    Discriminator

    Consumes images at four scales. Each stage mixes one entry of
    ``hr_images`` into the running feature map (keys 1, 2, 4 and 8, in that
    order), applies convolutions and then downsamples, so the image stored
    under key 1 goes through the most downsampling steps.

    Args:
        hr_images: images indexed by factor; keys 1, 2, 4 and 8 are read.
            Presumably the dict returned by ``generator`` (or real images
            resampled to the same scales) -- confirm against the caller.
        scope: name of the enclosing variable scope, opened with AUTO_REUSE.
        dim: depth argument passed to the first stage's convolutions; it is
            doubled before the convolutions of every later stage.

    Returns:
        The output of the final 1x1 ``conv`` call, averaged over axes 1 and 2.
    """
    conv_lrelu = partial(conv, activation_fn=lrelu)

    def _combine(x, newdim, name, z=None):
        # Project the input image with a 1x1 conv, optionally concatenate the
        # features coming from the previous stage on the last axis, then
        # apply the minibatch stddev layer.
        x = conv_lrelu(x, newdim, 1, 1, name)
        y = x if z is None else tf.concat([x, z], axis=-1)
        return minibatch_stddev_layer(y)

    def _conv_downsample(x, dim, ksize, name):
        # Downscaling convolution followed by a leaky ReLU.
        y = conv2d_downscale2d(x, dim, ksize, name=name)
        y = lrelu(y)
        return y

    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        # NOTE: every stage below must sit at the same indentation level;
        # a misaligned ``with`` makes the whole function fail to parse.
        with tf.compat.v1.variable_scope("res_8x"):
            # First stage: no previous features to concatenate (z is None).
            net = _combine(hr_images[1], newdim=dim, name="from_input")
            net = conv_lrelu(net, dim, 3, 1, "conv1")
            net = conv_lrelu(net, dim, 3, 1, "conv2")
            net = conv_lrelu(net, dim, 3, 1, "conv3")
            net = _conv_downsample(net, dim, 3, "conv_down")

        with tf.compat.v1.variable_scope("res_4x"):
            # The input projection uses the previous depth; dim is doubled
            # only afterwards, for this stage's own convolutions.
            net = _combine(hr_images[2], newdim=dim, name="from_input", z=net)
            dim *= 2
            net = conv_lrelu(net, dim, 3, 1, "conv1")
            net = conv_lrelu(net, dim, 3, 1, "conv2")
            net = conv_lrelu(net, dim, 3, 1, "conv3")
            net = _conv_downsample(net, dim, 3, "conv_down")

        with tf.compat.v1.variable_scope("res_2x"):
            net = _combine(hr_images[4], newdim=dim, name="from_input", z=net)
            dim *= 2
            net = conv_lrelu(net, dim, 3, 1, "conv1")
            net = conv_lrelu(net, dim, 3, 1, "conv2")
            net = conv_lrelu(net, dim, 3, 1, "conv3")
            net = _conv_downsample(net, dim, 3, "conv_down")

        with tf.compat.v1.variable_scope("res_1x"):
            # Last image input: a single convolution before downsampling.
            net = _combine(hr_images[8], newdim=dim, name="from_input", z=net)
            dim *= 2
            net = conv_lrelu(net, dim, 3, 1, "conv")
            net = _conv_downsample(net, dim, 3, "conv_down")

        with tf.compat.v1.variable_scope("bn"):
            dim *= 2
            net = conv_lrelu(net, dim, 3, 1, "conv1")
            net = _conv_downsample(net, dim, 3, "conv_down1")
            net = minibatch_stddev_layer(net)

            # dense: two 1x1 convolutions, then average over axes 1 and 2
            dim *= 2
            net = conv_lrelu(net, dim, 1, 1, "dense1")
            net = conv(net, 1, 1, 1, "dense2")
            net = tf.reduce_mean(net, axis=[1, 2])

            return net

Inside network.py, on the generator function:

def generator(lr_image, scope, nchannels, nresblocks, dim):
    """
    Generator

    Runs ``lr_image`` through an encoder of ``nresblocks`` residual blocks,
    then through a "res_1x" stage followed by three upsampling stages
    ("res_2x", "res_4x", "res_8x"). Every stage emits one image through a
    1x1 convolution with ``nchannels`` as its depth argument.

    Returns:
        dict mapping factor -> stage output, filled in the order 8, 4, 2, 1
        (8 comes from "res_1x", 1 from "res_8x").
    """
    outputs = {}
    # Depth argument shared by every layer of the upsampling stages.
    wide = 4 * dim
    # (scope name, kernel size of the stage's last conv, output key)
    upsampling_stages = (
        ("res_2x", 5, 4),
        ("res_4x", 9, 2),
        ("res_8x", 9, 1),
    )

    def _conv_act_norm(x, depth, ksize, name):
        # conv -> leaky ReLU -> pixel norm
        return pixel_norm(lrelu(conv(x, depth, ksize, 1, name)))

    def _upsample(x, depth, ksize, name):
        # upscaling conv -> blur -> leaky ReLU -> pixel norm
        upscaled = upscale2d_conv2d(x, depth, ksize, name)
        return pixel_norm(lrelu(blur2d(upscaled)))

    def _res_block(x, depth, name):
        # Two 3x3 convolutions with a skip connection around them.
        with tf.compat.v1.variable_scope(name):
            y = pixel_norm(lrelu(conv(x, depth, 3, 1, "conv1")))
            y = pixel_norm(conv(y, depth, 3, 1, "conv2"))
            return y + x

    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        with tf.compat.v1.variable_scope("encoder"):
            # Kept aside so it can be added back after the residual blocks.
            skip = lrelu(conv(lr_image, dim, 9, 1, "conv1_9x9"))
            net = skip
            for idx in range(nresblocks):
                net = _res_block(net, depth=dim, name="ResBlock{}".format(idx))

        with tf.compat.v1.variable_scope("res_1x"):
            net = pixel_norm(conv(net, dim, 3, 1, "conv1"))
            net += skip
            outputs[8] = conv(net, nchannels, 1, 1, "output")

        for stage_name, last_ksize, factor in upsampling_stages:
            with tf.compat.v1.variable_scope(stage_name):
                net = _upsample(net, wide, 3, "conv_upsample")
                net = _conv_act_norm(net, wide, 3, "conv1")
                net = _conv_act_norm(net, wide, 3, "conv2")
                net = _conv_act_norm(net, wide, last_ksize, "conv3")
                outputs[factor] = conv(net, nchannels, 1, 1, "output")

        return outputs

Any ideas or suggestions?

@remicres
Copy link
Owner

remicres commented Sep 2, 2024

Hi @EmanuelCastanho, I guess you're on the right track. You also have to modify the main program to feed the various resampled inputs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants