Matplotlib animation.FuncAnimation: Custom frame generator only yields once - python-3.x

I'm encountering a strange problem with the matplotlib animation. I'm trying to create a animated bar plot using the following code:
import os, time
from PIL import Image, ImageSequence
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.path as path
import matplotlib.animation as animation
import blackxample
FILE_PREFIX = "cell-isotohyper"
FILE_SUFFIX = ".tif"
FILE_PATH = "./example-video"
XCUT = (91, 91+266)
YCUT = (646, 646+252)
LIMIT = 100
OFFSET = 0
Y_SCALE = 3000
NUM_OF_BINS = 37
BAR_WIDTH = 1.0
BAR_COLOR = 'b'
RANGE = range(0, NUM_OF_BINS//2+1)
def animate(i, fig, ax, bars):
# = np.random.randn(1000)
print(len(i))
for a in RANGE:
bars[a].set_height(i[a])
return (fig, ax, bars)
def main():
fig, ax = plt.subplots()
ax.set_ylim(0, Y_SCALE)
ax.set_xlim(0, NUM_OF_BINS//2+1)
bars = ax.bar(np.arange(NUM_OF_BINS), [i for i in range(NUM_OF_BINS)], BAR_WIDTH, color=BAR_COLOR)
ani = animation.FuncAnimation(fig, animate, xframes, fargs = (fig, ax, bars), interval=500)
plt.show()
This code snippet works completely fine if I'm using randomly generated data or constant via:
def xframes():
i = 0
while i < 100:
yield [2312.7094266223335, 27.238786592368257, 75.252063484372513, 13.678304922077643, 11.879804374653929, 21.900570139020687, 2.930771773796323, 11.945594479736741, 10.88517941461987, 4.4176609254771506, 4.1075871395528338, 1.248363771876285, 1.4798157379442216, 3.5285036346353564, 3.2583080973651732, 3.4640042567344267, 3.130503535456981, 0.67334205875304676, 0.71393606581800562]
#yield np.histogram(np.random.randn(1000), NUM_OF_BINS//2 + 1)[0]
i+=1
Using the function, aframes, instead, does only yield the first item if it is used together animation.FuncAnimation(). If aframe is iterated manually, however, the generator works completely fine.
def aframes():
list_of_files = []
for dirname, dirnames, filenames in os.walk(FILE_PATH):
for filename in filenames:
if filename.startswith(FILE_PREFIX) and filename.endswith(FILE_SUFFIX):
list_of_files.append(os.path.join(FILE_PATH, filename))
# Open every picture - in every file
count = 0
imagecount = 0
framecount = 0
skipped = 0
for file in list_of_files:
framecount = 0
a = Image.open(file)
for frame in ImageSequence.Iterator(a):
if count > OFFSET and count <= OFFSET+LIMIT:
# Cut image beforehand - probably faster
frame = frame.crop((XCUT[0], YCUT[0], XCUT[1], YCUT[1]))
# Load image intro Matrix
imageMatrix = blackxample.Matrix.fromPillow(frame)
try:
imageMatrix.findContour()
imageMatrix.calculateCentroid()
imageMatrix.transform(NUM_OF_BINS)
#yield imageMatrix.getTransform()
yield [2312.7094266223335, 27.238786592368257, 75.252063484372513, 13.678304922077643, 11.879804374653929, 21.900570139020687, 2.930771773796323, 11.945594479736741, 10.88517941461987, 4.4176609254771506, 4.1075871395528338, 1.248363771876285, 1.4798157379442216, 3.5285036346353564, 3.2583080973651732, 3.4640042567344267, 3.130503535456981, 0.67334205875304676, 0.71393606581800562]
except blackxample.NoConvergenceError:
skipped+=1
print("[", count ,"] done")
framecount+=1
count+=1
imagecount+=1
# Test for frame iterator - works fine
#for i in _frames():
# print(i)
Does someone has a clue what and why is happening? How can I fix it?
The generator also runs as expected if the three imageMatrix-lines inside the try-block are commented out which suggests that there is an error inside imageMatrix.findContour(). But what am I looking for? findContour doesn't do anything weird

Since I have not found any solutions regarding this problem, I've decided to save the result of aframes() in a file, then reading and animating it seperatly which works flawlessly without adjusting the animation code.

Related

Static Colormap in Scatter Plot during Addition of Datapoints

I would like to add the datapoints to a locked portion of the screen one after the other, save all plots as a *.png and make a gif. Everything is working so far, but I want to lock the colorbar, so that it does not change its range during the addition of the points. I have no idea on how to do tackle this problem...
The input data is (the t_string will be modified to make it work):
x = [2.803480599999999, 5.5502475000000056, 6.984381300000002, 4.115224099999998, 5.746583699999995, 8.971469500000019, 12.028179500000032, 13.451193300000014, 12.457393999999972, 12.027555199999998, 16.077930800000015, 5.021229700000006, 11.206380399999999, 7.903262600000004, 11.98195070000001, 12.21701, 10.35045, 10.231890000000002]
y = [11.961321698938578, 5.218986480632915, 5.211628408660906, 4.847852635777481, 4.936266162218553, 5.233256380128127, 5.441388698929861, 5.461721129728066, 5.722170570613203, 5.2698434785261545, 5.645419662253215, 4.617062894639794, 4.973357261130752, 5.906843248930297, 5.256517482861392, 5.537361432952908, 5.339542403148332, 5.376979880224148]
t_string = ['2019-10-7', '2019-10-13', '2019-11-10', '2019-11-16', '2019-11-17', '2019-11-23', '2019-11-24', '2019-11-27', '2019-12-1', '2019-12-4', '2019-12-8', '2019-12-21', '2019-12-23', '2019-12-25', '2019-12-27', '2020-1-2', '2020-1-5', '2020-1-9']
Below you find the whole code I used. It will create a new directory in your working directory and write all files in there, you will also find the .gif in there. It might be necessary to install some packages. Many thanks in ad
# Import some stuff
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
import pathlib
from pathlib import Path
import moviepy.editor as mpy
# Define Input
x = [2.803480599999999, 5.5502475000000056, 6.984381300000002, 4.115224099999998, 5.746583699999995, 8.971469500000019,
12.028179500000032, 13.451193300000014, 12.457393999999972, 12.027555199999998, 16.077930800000015,
5.021229700000006, 11.206380399999999, 7.903262600000004, 11.98195070000001, 12.21701, 10.35045,
10.231890000000002]
y = [11.961321698938578, 5.218986480632915, 5.211628408660906, 4.847852635777481, 4.936266162218553, 5.233256380128127,
5.441388698929861, 5.461721129728066, 5.722170570613203, 5.2698434785261545, 5.645419662253215, 4.617062894639794,
4.973357261130752, 5.906843248930297, 5.256517482861392, 5.537361432952908, 5.339542403148332, 5.376979880224148]
t_string = ['2019-10-7', '2019-10-13', '2019-11-10', '2019-11-16', '2019-11-17', '2019-11-23', '2019-11-24',
'2019-11-27', '2019-12-1', '2019-12-4', '2019-12-8', '2019-12-21', '2019-12-23', '2019-12-25', '2019-12-27',
'2020-1-2', '2020-1-5', '2020-1-9']
# Define a function to get the datafiles with a certain suffix in a path
def getfile_UI(file_directory, file_suffix):
from glob import glob
path_to_search = file_directory / file_suffix
filenames = glob(str(path_to_search))
return filenames
# Start of script/calculations
t = [mdates.date2num(datetime.strptime(i, "%Y-%m-%d")) for i in t_string]
workingdirectory_projectfolder = Path.cwd().parent
my_dpi = 75
for index, entry in enumerate(t_string):
fig = plt.figure(figsize=(480 / my_dpi, 480 / my_dpi), dpi=my_dpi)
sc = plt.scatter(x[0:index + 1],
y[0:index + 1],
c=t[0:index + 1])
plt.xlim(0, 20)
plt.ylim(4, 7)
plt.title(entry)
plt.xlabel("Distace [km]")
plt.ylabel("Pace [min/km]")
loc = mdates.AutoDateLocator()
fig.colorbar(sc, ticks=loc,
format=mdates.AutoDateFormatter(loc))
filename = 'png_' + str(index) + '.png'
new_dir = workingdirectory_projectfolder / "type 3 gif"
pathlib.Path(new_dir).mkdir(exist_ok=True)
plt.savefig(workingdirectory_projectfolder / "type 3 gif" / filename, dpi=96)
plt.gca()
# Make a GIF from all png files
# http://superfluoussextant.com/making-gifs-with-python.html
fps = 10
gif_name = str(fps) + " fps_""type3_"
workingdirectory_projectfolder = Path.cwd().parent
gif_path = workingdirectory_projectfolder / "type 3 gif" / gif_name
filenamelist_path = workingdirectory_projectfolder / "type 3 gif"
filenamelist_png = getfile_UI(filenamelist_path, "*.png")
list.sort(filenamelist_png, key=lambda x: int(
x.split('_')[1].split('.png')[0])) # Sort the images by #, this may need to be tweaked for your use case
clip = mpy.ImageSequenceClip(filenamelist_png, fps=fps)
clip.write_gif('{}.gif'.format(gif_path), fps=fps)
You could call plt.scatter with vmin=min(t), vmax=max(t). This fixes the limits used for coloring.
Something else you could add in your animation is to only show the tick dates up to the current:
loc = mdates.AutoDateLocator()
cbar = fig.colorbar(sc, ticks=loc, format=mdates.AutoDateFormatter(loc))
ticks = [ti for ti in cbar.get_ticks() if ti <= t[index]]
cbar.set_ticks(ticks)

Python 3 Multiprocessing and openCV problem with dictionary sharing between processor

I would like to use multiprocessing to compute the SIFT extraction and SIFT matching for object detection.
For now, I have a problem with the return value of the function that does not insert data in the dictionary.
I'm using Manager class and image that are open inside the function. But does not work.
Finally, my idea is:
Computer the keypoint for every reference image, use this keypoint as a parameter of a second function that compares and match with the keypoint and descriptors of the test image.
My code is:
# %% Import Section
import cv2
import numpy as np
from matplotlib import pyplot as plt
import os
from datetime import datetime
from multiprocessing import Process, cpu_count, Manager, Lock
import argparse
# %% path section
tests_path = 'TestImages/'
references_path = 'ReferenceImages2/'
result_path = 'ResultParametrizer/'
#%% Number of processor
cpus = cpu_count()
# %% parameter section
eps = 1e-7
useTwo = False # using the m and n keypoint better with False
# good point parameters
distanca_coefficient = 0.75
# gms parameter
gms_thresholdFactor = 3
gms_withRotation = True
gms_withScale = True
# flann parameter
flann_trees = 5
flann_checks = 50
#%% Locker
lock = Lock()
# %% function definition
def keypointToDictionaries(keypoint):
x, y = keypoint.pt
pt = float(x), float(y)
angle = float(keypoint.angle) if keypoint.angle is not None else None
size = float(keypoint.size) if keypoint.size is not None else None
response = float(keypoint.response) if keypoint.response is not None else None
class_id = int(keypoint.class_id) if keypoint.class_id is not None else None
octave = int(keypoint.octave) if keypoint.octave is not None else None
return {
'point': pt,
'angle': angle,
'size': size,
'response': response,
'class_id': class_id,
'octave': octave
}
def dictionariesToKeypoint(dictionary):
kp = cv2.KeyPoint()
kp.pt = dictionary['pt']
kp.angle = dictionary['angle']
kp.size = dictionary['size']
kp.response = dictionary['response']
kp.octave = dictionary['octave']
kp.class_id = dictionary['class_id']
return kp
def rootSIFT(dictionary, image_name, image_path,eps=eps):
# SIFT init
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
sift = cv2.xfeatures2d.SIFT_create()
keypoints, descriptors = sift.detectAndCompute(image, None)
descriptors /= (descriptors.sum(axis=1, keepdims=True) + eps)
descriptors = np.sqrt(descriptors)
print('Finito di calcolare, PID: ', os.getpid())
lock.acquire()
dictionary[image_name]['keypoints'] = keypoints
dictionary[image_name]['descriptors'] = descriptors
lock.release()
def featureMatching(reference_image, reference_descriptors, reference_keypoints, test_image, test_descriptors,
test_keypoints, flann_trees=flann_trees, flann_checks=flann_checks):
# FLANN parameter
FLANN_INDEX_KDTREE = 1
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=flann_trees)
search_params = dict(checks=flann_checks) # or pass empty dictionary
flann = cv2.FlannBasedMatcher(index_params, search_params)
flann_matches = flann.knnMatch(reference_descriptors, test_descriptors, k=2)
matches_copy = []
for i, (m, n) in enumerate(flann_matches):
if m.distance < distanca_coefficient * n.distance:
matches_copy.append(m)
gsm_matches = cv2.xfeatures2d.matchGMS(reference_image.shape, test_image.shape, keypoints1=reference_keypoints,
keypoints2=test_keypoints, matches1to2=matches_copy,
withRotation=gms_withRotation, withScale=gms_withScale,
thresholdFactor=gms_thresholdFactor)
#%% Starting reference list file creation
reference_init = datetime.now()
print('Start reference file list creation')
reference_image_process_list = []
manager = Manager()
reference_image_dictionary = manager.dict()
reference_image_list = manager.list()
for root, directories, files in os.walk(references_path):
for file in files:
if file.endswith('.DS_Store'):
continue
reference_image_path = os.path.join(root, file)
reference_name = file.split('.')[0]
image = cv2.imread(reference_image_path, cv2.IMREAD_GRAYSCALE)
reference_image_dictionary[reference_name] = {
'image': image,
'keypoints': None,
'descriptors': None
}
proc = Process(target=rootSIFT, args=(reference_image_list, reference_name, reference_image_path))
reference_image_process_list.append(proc)
proc.start()
for proc in reference_image_process_list:
proc.join()
reference_end = datetime.now()
reference_time = reference_end - reference_init
print('End reference file list creation, time required: ', reference_time)
I faced pretty much the same error. It seems that the code hangs at detectAndCompute in my case, not when creating the dictionary. For some reason, sift feature extraction is not multi-processing safe (to my understanding, it is the case in Macs but I am not totally sure.)
I found this in a github thread. Many people say it works but I couldn't get it worked. (Edit: I tried this later which works fine)
Instead I used multithreading which is pretty much the same code and works perfectly. Of course you need to take multithreading vs multiprocessing into account

How to generate heat map on the Whole Slide Images (.svs format) using some probability values?

I am trying to generate heat map, or probability map, for Whole Slide Images (WSIs) using probability values. I have coordinate points (which determine areas on the WSIs) and corresponding probability values.
Basic Introduction on WSI: WSIs are large is size (almost 100000 x 100000 pixels). Hence, can't open these images using normal image viewer. The WSIs are processed using OpenSlide software.
I have seen previous posts in Stack Overflow on related to heat map, but as WSIs are processed in a different way, I am unable to figure out how to apply these solutions. Some examples that I followed: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, etc.
To generate heat map on WSIs, follow below instructions:
First of all Extract image patches and save the coordinates. Use below code for patch extraction. The code require some changes as per the requirements. The code has been copied from: patch extraction code link
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import logging
try:
import Image
except:
from PIL import Image
import math
import numpy as np
import openslide
import os
from time import strftime,gmtime
parser = argparse.ArgumentParser(description='Extract a series of patches from a whole slide image')
parser.add_argument("-i", "--image", dest='wsi', nargs='+', required=True, help="path to a whole slide image")
parser.add_argument("-p", "--patch_size", dest='patch_size', default=299, type=int, help="pixel width and height for patches")
parser.add_argument("-b", "--grey_limit", dest='grey_limit', default=0.8, type=float, help="greyscale value to determine if there is sufficient tissue present [default: `0.8`]")
parser.add_argument("-o", "--output", dest='output_name', default="output", help="Name of the output file directory [default: `output/`]")
parser.add_argument("-v", "--verbose",
dest="logLevel",
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default="INFO",
help="Set the logging level")
args = parser.parse_args()
if args.logLevel:
logging.basicConfig(level=getattr(logging, args.logLevel))
wsi=' '.join(args.wsi)
""" Set global variables """
mean_grey_values = args.grey_limit * 255
number_of_useful_regions = 0
wsi=os.path.abspath(wsi)
outname=os.path.abspath(args.output_name)
basename = os.path.basename(wsi)
level = 0
def main():
img,num_x_patches,num_y_patches = open_slide()
logging.debug('img: {}, num_x_patches = {}, num_y_patches: {}'.format(img,num_x_patches,num_y_patches))
for x in range(num_x_patches):
for y in range(num_y_patches):
img_data = img.read_region((x*args.patch_size,y*args.patch_size),level, (args.patch_size, args.patch_size))
print_pics(x*args.patch_size,y*args.patch_size,img_data,img)
pc_uninformative = number_of_useful_regions/(num_x_patches*num_y_patches)*100
pc_uninformative = round(pc_uninformative,2)
logging.info('Completed patch extraction of {} images.'.format(number_of_useful_regions))
logging.info('{}% of the image is uninformative\n'.format(pc_uninformative))
def print_pics(x_top_left,y_top_left,img_data,img):
if x_top_left % 100 == 0 and y_top_left % 100 == 0 and x_top_left != 0:
pc_complete = round(x_top_left /img.level_dimensions[0][0],2) * 100
logging.info('{:.2f}% Complete at {}'.format(pc_complete,strftime("%a, %d %b %Y %H:%M:%S +0000", gmtime())))
exit()
img_data_np = np.array(img_data)
""" Convert to grayscale"""
grey_img = rgb2gray(img_data_np)
if np.mean(grey_img) < mean_grey_values:
logging.debug('Image grayscale = {} compared to threshold {}'.format(np.mean(grey_img),mean_grey_values))
global number_of_useful_regions
number_of_useful_regions += 1
wsi_base = os.path.basename(wsi)
wsi_base = wsi_base.split('.')[0]
img_name = wsi_base + "_" + str(x_top_left) + "_" + str(y_top_left) + "_" + str(args.patch_size)
#write_img_rotations(img_data_np,img_name)
logging.debug('Saving {} {} {}'.format(x_top_left,y_top_left,np.mean(grey_img)))
save_image(img_data_np,1,img_name)
def gen_x_and_y(xlist,ylist,img):
for x in xlist:
for y in ylist:
img_data = img.read_region((x*args.patch_size,y*args.patch_size),level, (args.patch_size, args.patch_size))
yield (x, y,img_data)
def open_slide():
"""
The first level is always the main image
Get width and height tuple for the first level
"""
logging.debug('img: {}'.format(wsi))
img = openslide.OpenSlide(wsi)
img_dim = img.level_dimensions[0]
"""
Determine what the patch size should be, and how many iterations it will take to get through the WSI
"""
num_x_patches = int(math.floor(img_dim[0] / args.patch_size))
num_y_patches = int(math.floor(img_dim[1] / args.patch_size))
remainder_x = img_dim[0] % num_x_patches
remainder_y = img_dim[1] % num_y_patches
logging.debug('The WSI shape is {}'.format(img_dim))
logging.debug('There are {} x-patches and {} y-patches to iterate through'.format(num_x_patches,num_y_patches))
return img,num_x_patches,num_y_patches
def validate_dir_exists():
if os.path.isdir(outname) == False:
os.mkdir(outname)
logging.debug('Validated {} directory exists'.format(outname))
if os.path.exists(wsi):
logging.debug('Found the file {}'.format(wsi))
else:
logging.debug('Could not find the file {}'.format(wsi))
exit()
def rgb2gray(rgb):
"""Converts an RGB image into grayscale """
r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
return gray
def save_image(img,j,img_name):
tmp = os.path.join(outname,img_name+"_"+str(j)+".png")
try:
im = Image.fromarray(img)
im.save(tmp)
except:
print('Could not print {}'.format(tmp))
exit()
if __name__ == '__main__':
validate_dir_exists()
main()
Secondly, generate the probability values of each patches.
Finally, replace all the pixel values within a coordinates with the corresponding probability values and display the results using color maps.
This is the basic idea of generating heat map on WSIs. You can modify the code and concept to get a heat map as per your wish.
We have developed a python package for processing whole-slide-images:
https://github.com/amirakbarnejad/PyDmed
Here is a tutorial for getting heatmaps for whole-slide-images:
https://amirakbarnejad.github.io/Tutorial/tutorial_section5.html.
Also here is a sample notebook that gets heatmaps for WSIs using PyDmed:
Link to the sample notebook.
The benefit of PyDmed is that it is multi-processed. The dataloader sends a stream of patches to GPU(s), and the StreamWriter writes to disk in a separate process. Therefore, it is highly efficient. The running time of course depends on the machine, the size of WSIs, etc. On a good machine with a good GPU, PyDmed can generate heatmaps for ~120 WSIs in one day.

Ploting building surfaces from CityGML data

I've been struggling with 3D ploting some coordinates since a long ago and now I'm really frustrated, so your help will be really appreciated.
I'd like to plot the facade of a building from a CityGML file (which is originally simply an XML file). I have no problem with parsing the CityGML file using XML.etree and extracting the coordinates. But after extracting the coordinates, I cann't find a way to 3D plot them.
from xml.etree import ElementTree as ET
tree = ET.parse('3860_5819__.gml')
root = tree.getroot()
namespaces = {
'ns0': "http://www.opengis.net/citygml/1.0",
'ns1': "http://www.opengis.net/gml",
'ns2': "http://www.opengis.net/citygml/building/1.0"
}
c = 0
wallString = []
for wallSurface in root.findall('.//ns2:WallSurface', namespaces):
for posList in wallSurface.findall('.//ns1:posList', namespaces):
c += 1
wallCoordinates = posList.text
wallCoordinates = wallCoordinates.split()
wallString.append(wallCoordinates)
verts = []
for string in wallString:
X, Y, Z = [], [], []
c = 0
for value in string:
value = float(value)
if c % 3 == 0:
X.append(value)
elif c % 3 == 1:
Y.append(value)
else:
Z.append(value)
c += 1
if c > len(string) - 3:
break
vert = [list(zip(X, Y, Z))]
verts.append(vert)
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib.pyplot as plt
fig = plt.figure()
ax = Axes3D(fig)
for vert in verts:
ax.add_collection3d(Poly3DCollection(vert))
ax.autoscale_view(tight=True, scalex=True, scaley=True, scalez=True)
plt.show()
plt.close()
Could the problem be that I can't make my plot "tight"? And if not, is there something I'm doing fundamentally wrong?
If relevant, the CityGML file in this case is related to TU Berlin center of entrepreneurship which can be taken from here.
Just realized that nothing was wrong with the main part of the code. The only issue was that the axis were not set. I change the plot part like this:
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d as mpl3
fig = plt.figure()
ax = mpl3.Axes3D(fig)
for vert in verts:
poly = mpl3.art3d.Poly3DCollection(vert)
ax.add_collection3d(poly)
ax.set_xlim3d(left=386284-50,right=386284+50)
ax.set_ylim3d(bottom=5819224-50, top=5819224+50)
ax.set_zlim3d(bottom=32-10,top=32+20)
plt.show()
plt.close()
Now it works perfectly fine.

how to update a matplotlib heatmap plot without creating a new window

I have matrix class that inherits from list. This class can display itself as a matplotlib heatmap representation of the matrix.
I'm trying to have the class written such that when I change values in the matrix, I can call the matrix's method plot() and it'll update the plot to reflect the matrix changes in the heatmap.
However, every time I run the method plot(), it creates a new heatmap in a new window instead of updating the existing plot. How could I get it simply to update the existing plot?
In the code below, there are three main parts: the main function shows how an instance of the matrix class is created, plotted and updated; the matrix class is basically a list object, with some minor functionality (including plotting) bolted on; the function plotList() is the function the matrix class calls in order to generate the plot object initially.
import time
import random
import matplotlib.pyplot as plt
plt.ion()
import numpy as np
def main():
print("plot 2 x 2 matrix and display it changing in a loop")
matrix = Matrix(
numberOfColumns = 2,
numberOfRows = 2,
randomise = True
)
# Plot the matrix.
matrix.plot()
# Change the matrix, redrawing it after each change.
for row in range(len(matrix)):
for column in range(len(matrix[row])):
input("Press Enter to continue.")
matrix[row][column] = 10
matrix.plot()
input("Press Enter to terminate.")
matrix.closePlot()
class Matrix(list):
def __init__(
self,
*args,
numberOfColumns = 3,
numberOfRows = 3,
element = 0.0,
randomise = False,
randomiseLimitLower = -0.2,
randomiseLimitUpper = 0.2
):
# list initialisation
super().__init__(self, *args)
self.numberOfColumns = numberOfColumns
self.numberOfRows = numberOfRows
self.element = element
self.randomise = randomise
self.randomiseLimitLower = randomiseLimitLower
self.randomiseLimitUpper = randomiseLimitUpper
# fill with default element
for column in range(self.numberOfColumns):
self.append([element] * self.numberOfRows)
# fill with pseudorandom elements
if self.randomise:
random.seed()
for row in range(self.numberOfRows):
for column in range(self.numberOfColumns):
self[row][column] = random.uniform(
self.randomiseLimitUpper,
self.randomiseLimitLower
)
# plot
self._plot = plotList(
list = self,
mode = "return"
)
# for display or redraw plot behaviour
self._plotShown = False
def plot(self):
# display or redraw plot
self._plot.draw()
if self._plotShown:
#self._plot = plotList(
# list = self,
# mode = "return"
# )
array = np.array(self)
fig, ax = plt.subplots()
heatmap = ax.pcolor(array, cmap = plt.cm.Blues)
self._plot.draw()
else:
self._plot.show()
self._plotShown = True
def closePlot(self):
self._plot.close()
def plotList(
list = list,
mode = "plot" # plot/return
):
# convert list to NumPy array
array = np.array(list)
# create axis labels
labelsColumn = []
labelsRow = []
for rowNumber in range(0, len(list)):
labelsRow.append(rowNumber + 1)
for columnNumber in range(0, len(list[rowNumber])):
labelsColumn.append(columnNumber)
fig, ax = plt.subplots()
heatmap = ax.pcolor(array, cmap = plt.cm.Blues)
# display plot or return plot object
if mode == "plot":
plt.show()
elif mode == "return":
return(plt)
else:
Exception
if __name__ == '__main__':
main()
I'm using Python 3 in Ubuntu.
The method plot(self) creates a new figure in the line fig, ax = plt.subplots(). To use an existing figure you can give your figure a number or name when it's first created in plotList():
fig = plt.figure('matrix figure')
ax = fig.add_subplot(111)
then use
plt.figure('matrix figure')
ax = gca() # gets current axes
to make that the active figure and axes. Alternately, you might want to the figure and axis created in plotList and pass them to plot.

Resources