Skip to content

Commit

Permalink
Fix plot_tensor in python3 environment
Browse files Browse the repository at this point in the history
  • Loading branch information
chaklim committed Nov 27, 2019
1 parent b0fa6ef commit e7b73ee
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions clair/plot_tensor.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import matplotlib.pyplot as plt
import sys
import os
import argparse
import math
import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from argparse import ArgumentParser

from clair.utils import setup_environment

def Prepare(args):
setup_environment()


def PlotTensor(ofn, XArray):
def plot_tensor(ofn, XArray):
plot = plt.figure(figsize=(15, 8))

plt.subplot(4, 1, 1)
Expand Down Expand Up @@ -44,7 +39,7 @@ def PlotTensor(ofn, XArray):
plt.close(plot)


def CreatePNGs(args):
def create_png(args):
f = open(args.array_fn, 'r')
array = f.read()
f.close()
Expand All @@ -63,7 +58,11 @@ def CreatePNGs(args):
# for i in range(len(splitted_array)):
# splitted_array[i] = int(splitted_array[i])

XArray = np.array(splitted_array).reshape((-1, 33, 8, 4))
XArray = np.array(splitted_array, dtype=np.float32).reshape((-1, 33, 8, 4))
XArray[0, :, :, 1] -= XArray[0, :, :, 0]
XArray[0, :, :, 2] -= XArray[0, :, :, 0]
XArray[0, :, :, 3] -= XArray[0, :, :, 0]

_YArray = np.zeros((1, 16))
varName = args.name
print("Plotting %s..." % (varName), file=sys.stderr)
Expand All @@ -73,11 +72,11 @@ def CreatePNGs(args):
os.makedirs(varName)

# Plot tensors
PlotTensor(varName+"/tensor.png", XArray)
plot_tensor(varName+"/tensor.png", XArray)


def ParseArgs():
parser = argparse.ArgumentParser(
parser = ArgumentParser(
description="Visualize tensors and hidden layers in PNG")

parser.add_argument('--array_fn', type=str, default="vartensors",
Expand All @@ -97,8 +96,8 @@ def ParseArgs():

def main():
args = ParseArgs()
Prepare(args)
CreatePNGs(args)
setup_environment()
create_png(args)


if __name__ == "__main__":
Expand Down

0 comments on commit e7b73ee

Please sign in to comment.