generate multiple plots by querying mongodb using multiprocessing - python-3.x

I would like to speed up a plotting function that looks up data from MongoDB Atlas. I used examples found online; however, I'm not sure mine is a correct implementation. Using multiprocessing.Pool() seems slower than doing it without the package. What am I doing wrong? Thanks.
from pymongo import MongoClient
from matplotlib.backends.backend_svg import FigureCanvasSVG
from matplotlib.figure import Figure
import io
import multiprocessing
import time

lstOfwavelengths = list(range(220, 810, 10))

def build_graph_mongo_multiproc(pltcodeWithSuffix, wellID):
    client = MongoClient()
    db = client.databasename
    img = io.BytesIO()
    fig = Figure(figsize=(0.6, 0.6))
    axis = fig.add_subplot(1, 1, 1)
    absvals = db[pltcodeWithSuffix].find({"Wavelength": wavelength})
    absvals = {k: v for k, v in absvals[0].items() if k}
    axis.plot(lstOfwavelengths, absvals)
    axis.set_title(f'{pltcodeWithSuffix}:{wellID}', fontsize=9)
    axis.title.set_position([.5, .6])
    axis.tick_params(
        which='both',
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False)
    FigureCanvasSVG(fig).print_svg(img)
    lstOfPlts.append(img.getvalue())
The only difference between the single-process and multiprocessing versions is that MongoClient is created once, outside the function.

I found this great article: The efficient way of using multiprocessing with pymongo
Using the article as a template, I was able to reduce the computation time to ~7.5 seconds instead of 21. I'm sure someone more experienced can shave off more time, but I think that is sufficient for my level.
manager = multiprocessing.Manager()
lstOfPlots = manager.list()

def chunks(l, n):
    for i in range(0, len(l), n):
        yield l[i:i + n]

def getAllWellVals(db, pltcodeWithSuffix, wellID):
    lstOfVals = []
    for i in db[pltcodeWithSuffix].find({}, {wellID: 1, '_id': 0}):
        lstOfVals.append(i[wellID])
    return lstOfVals

def build_graph_mongo_multiproc(chunk, pltcodeWithSuffix):
    global lstOfPlots
    client = MongoClient(connect_string, maxPoolSize=10000)
    db = client[dbname]
    # loop over the well IDs in the chunk and plot each one
    for wid in chunk:
        img = io.BytesIO()
        fig = Figure(figsize=(0.6, 0.6))
        axis = fig.add_subplot(1, 1, 1)
        absVals = getAllWellVals(db, pltcodeWithSuffix, wid)
        axis.plot(lstOfwavelengths, absVals)
        axis.set_title(f'{wid}', fontsize=9)
        axis.title.set_position([.5, .6])
        axis.tick_params(
            which='both',
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False)
        FigureCanvasSVG(fig).print_svg(img)
        result = img.getvalue()
        lstOfPlots.append(result)
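For completeness, here is a hedged sketch of the driver code the post leaves out (my addition, not from the article): split the well IDs into one chunk per core and hand each chunk to a worker process. pltcodeWithSuffix, connect_string and dbname are assumed to be defined elsewhere in the script, and the 96-well plate IDs are hypothetical.
import math

if __name__ == '__main__':
    # hypothetical plate layout; the real lstOfWellIDs comes from the data
    lstOfWellIDs = [f'{row}{col}' for row in 'ABCDEFGH' for col in range(1, 13)]
    n_cores = multiprocessing.cpu_count()
    chunk_size = math.ceil(len(lstOfWellIDs) / n_cores)
    procs = []
    for chunk in chunks(lstOfWellIDs, chunk_size):
        p = multiprocessing.Process(target=build_graph_mongo_multiproc,
                                    args=(chunk, pltcodeWithSuffix))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
    # lstOfPlots now holds one SVG byte string per well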

Related

Pytorch Dataset for video

Hi, I made a video-frame loader Dataset to be fed into a PyTorch model. I want to sample frames from a video, uniformly spaced across each video. This is the class I came up with. I was wondering if there is any better method to speed up the sampling process.
Do you have any suggestions, especially for the read_video method?
Thanks
import torch
import torchvision as tv
import cv2
import numpy as np
from pathlib import Path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class VideoLoader(torch.utils.data.Dataset):

    def __init__(self, data_path, classes, transforms=None, max_frames=None, frames_ratio=None):
        super(VideoLoader, self).__init__()
        self.data_path = data_path
        self.classes = classes
        self.frames_ratio = frames_ratio
        self.transforms = transforms
        self.max_frames = max_frames

    def read_video(self, path):
        frames = []
        vc = cv2.VideoCapture(path)
        total_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
        if self.frames_ratio:
            if type(self.frames_ratio) is float:
                frames_to_pick = int(total_frames * self.frames_ratio)
            else:
                frames_to_pick = self.frames_ratio
        else:
            frames_to_pick = total_frames
        idxs = np.linspace(0, total_frames, frames_to_pick, endpoint=False)
        for i in idxs:
            # seek before reading, so the decoded frame is the sampled one
            vc.set(cv2.CAP_PROP_POS_FRAMES, i)
            ok, f = vc.read()
            if ok:
                f = tv.transforms.ToTensor()(f)
                f = self.transforms(f) if self.transforms else f
                frames.append(f)
                if self.max_frames and len(frames) == self.max_frames:
                    break
            else:
                break
        vc.release()
        return torch.stack(frames)

    def __getitem__(self, index):
        v_path, label = self.data_path[index]
        return self.read_video(v_path), self.classes[label]

    def __len__(self):
        return len(self.data_path)
Because you can't really seek through a video in parallel, there's not really any faster sampling process you can run locally. I personally had trouble with this problem which is why I started building a simple API for this called Sieve. You can literally upload data directly to Sieve (either from a cloud bucket or from local storage) and it'll quickly cut up all the frames for you and even mark them with things like motion, people, objects, and more. It parallelizes using serverless functions in the cloud which makes it really fast, even for hours or days of footage.
You can then quickly export from Sieve using the dashboard which gives you a quick curl command you can run to download the exact samples you want.
Here's a helpful repo: https://github.com/Sieve-Data/automatic-video-processing
If you are happy with extracting the frames of each video to disk beforehand, this library is exactly what you're looking for:
Video-Dataset-Loading-PyTorch on Github
https://github.com/RaivoKoot/Video-Dataset-Loading-Pytorch
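If the videos must stay local, one hedged micro-optimization on the question's read_video (my sketch, not code from either answer): seeking with CAP_PROP_POS_FRAMES is slow and inexact on many compressed codecs, so for dense sampling it is often faster to decode sequentially and skip unwanted frames with grab(), which advances the stream without decoding the image.
import cv2
import numpy as np
import torch
import torchvision as tv

def read_video_sequential(path, frames_to_pick, transforms=None):
    vc = cv2.VideoCapture(path)
    total = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
    # indices to keep, mirroring the question's np.linspace sampling
    keep = set(np.linspace(0, total, frames_to_pick, endpoint=False).astype(int))
    frames = []
    for idx in range(total):
        if idx in keep:
            ok, f = vc.read()   # decode only the frames we keep
            if not ok:
                break
            f = tv.transforms.ToTensor()(f)
            frames.append(transforms(f) if transforms else f)
        else:
            vc.grab()           # advance without decoding the image
    vc.release()
    return torch.stack(frames)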

Different sessions show different streaming data with a single bokeh server, how to solve it?

I'm working on a simulated oscilloscope where the server PC collects data and will ultimately publish the streaming plot online. Below is a working script that can do the job. However, when I open multiple browsers, the streaming plots show different data, although they supposedly use the same data source. The example 'ohlc' seems to have the same problem. So, what is the right way to do this? I'm considering writing the data to a file, but that brings issues like file I/O delay, disk storage limits, etc. Thank you for any help.
from bokeh.server.server import Server
from bokeh.models import ColumnDataSource, Label
from bokeh.plotting import figure
from bokeh.layouts import column
import numpy as np
import datetime as dt
from functools import partial
import time

# this will be replaced with the real data collector in the end
def f_emitter(p=0.1):
    v = np.random.rand()
    return (dt.datetime.now(), 0. if v > p else v)

def make_document(doc, functions, labels):
    def update():
        for index, func in enumerate(functions):
            data = func()
            sources[index].stream(new_data=dict(time=[data[0]], data=[data[1]]), rollover=1000)
            annotations[index].text = f'{data[1]: .3f}'

    sources = [ColumnDataSource(dict(time=[], data=[])) for _ in range(len(functions))]
    figs = []
    annotations = []
    for i in range(len(functions)):
        figs.append(figure(x_axis_type='datetime',
                           y_axis_label=labels[i], toolbar_location=None,
                           active_drag=None, active_scroll=None))
        figs[i].line(x='time', y='data', source=sources[i])
        annotations.append(Label(x=10, y=10, text='', text_font_size='40px', text_color='black',
                                 x_units='screen', y_units='screen', background_fill_color='white'))
        figs[i].add_layout(annotations[i])
    doc.add_root(column([fig for fig in figs], sizing_mode='stretch_both'))
    doc.add_periodic_callback(callback=update, period_milliseconds=100)

if __name__ == '__main__':
    # list of functions and labels to feed into the scope
    functions = [f_emitter]
    labels = ['emitter']
    server = Server({'/': partial(make_document, functions=functions, labels=labels)})
    server.start()
    server.io_loop.add_callback(server.show, "/")
    try:
        server.io_loop.start()
    except KeyboardInterrupt:
        print('keyboard interruption')
When you connect with a new client, Bokeh by default creates a new session. Each session has its own document, so each session gets its own ColumnDataSource objects and its own update callback, and the callbacks generate independent random data. That is why the data sources end up not being the same.
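A minimal sketch of one way around this (my addition, assuming the goal is that every session sees the same stream): Bokeh models cannot be shared between documents, but plain Python data can. Keep the samples in one module-level buffer filled by a single periodic collector, and let each session build its own ColumnDataSource from that shared buffer. f_emitter is the emitter from the question.
from collections import deque
from tornado.ioloop import PeriodicCallback
from bokeh.server.server import Server
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure

shared = deque(maxlen=1000)  # one buffer, shared by every session

def collect():
    # the single global collector; replace with the real data source
    shared.append(f_emitter())

def make_document_shared(doc):
    source = ColumnDataSource(dict(time=[], data=[]))
    fig = figure(x_axis_type='datetime', toolbar_location=None)
    fig.line(x='time', y='data', source=source)

    def update():
        # every session repaints from the same shared buffer
        if shared:
            t, d = zip(*shared)
            source.data = dict(time=list(t), data=list(d))

    doc.add_root(fig)
    doc.add_periodic_callback(update, 100)

server = Server({'/': make_document_shared})
server.start()
PeriodicCallback(collect, 100).start()  # runs on the server's IO loop
server.io_loop.start()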

How can one parallelize geopandas "to_file" function

I am trying to implement a parallelized function for GeoPandas that takes a single vector dataset (e.g., a Shapefile containing a MultiPolygon) and converts it to a standard cellular grid with cell x and y sizes defined by the user.
As this function may run into serious memory issues (e.g., caused by too high a spatial resolution), I was wondering whether it would be possible to save the data iteratively to the given destination file. That way, as each parallel process runs the "GRID" function, the same process can save the data iteratively in append mode, and one wouldn't have memory issues.
Here is my "SHP_to_GRID_Function". Note that the code below still requires that all the data generated by the multiprocessing be held in memory directly (so overflow is more than certain for large datasets).
import pandas as pd
import numpy as np
import geopandas as gpd
from shapely.geometry import Polygon
from multiprocessing import Pool
import os
from functools import partial

def info(title):
    print(title)
    print('module name:', __name__)
    print('parent process:', os.getppid())
    print('process id:', os.getpid())

def parallelize_df(gdf, func, n_cores, dx=100, dy=100, verbose=False):
    Geometries = gdf.loc[:, 'geometry'].values
    pool = Pool(processes=n_cores)
    # bind dx, dy and verbose; pool.map supplies the polygon argument
    func_partial = partial(func, dx, dy, verbose)
    results = pool.map(func_partial, Geometries)
    pool.close()
    pool.join()
    print(np.shape(results))
    GRID = gpd.GeoSeries(np.array(results).ravel())
    print("GRID well created")
    return GRID

def generate_grid_from_Poligon(dx=100, dy=100, verbose=False, polygon=None):
    if verbose == True:
        info('function parallelize_df')
    xmin, ymin, xmax, ymax = polygon.bounds
    lenght = dx
    wide = dy
    cols = list(np.arange(int(np.floor(xmin)), int(np.ceil(xmax)), wide))
    rows = list(np.arange(int(np.floor(ymin)), int(np.ceil(ymax)), lenght))
    rows.reverse()
    subpolygons = []
    for x in cols:
        for y in rows:
            subpolygons.append(Polygon([(x, y), (x + wide, y), (x + wide, y - lenght), (x, y - lenght)]))
    return subpolygons

def main(GDF, n_cores='standard', dx=100, dy=100, verbose=False):
    """
    GDF: geodataframe
    n_cores: 'standard' or a positive int; sets the number of cores to use in the multiprocessing
    dx: dimension in the x coordinate to make the grid
    dy: dimension in the y coordinate to make the grid
    """
    if isinstance(n_cores, str):
        import multiprocessing
        N_cores = multiprocessing.cpu_count() - 1
    elif isinstance(n_cores, int):
        N_cores = n_cores
    GRID_GDF = parallelize_df(GDF, generate_grid_from_Poligon, n_cores=N_cores, dx=dx, dy=dy, verbose=verbose)
    return GRID_GDF
I thank you for your time,
Sincerely yours,
Philipe Leal
I finally have come across a solution for my question. It is not perfect, since it requires several writing processes and one final concatenation process over all temporary files created during the run.
Feel free to suggest alternatives.
Here is the solution I found.
import numpy as np
import geopandas as gpd
import pandas as pd
from shapely.geometry import Polygon
from multiprocessing import Pool, Lock, freeze_support
import os
from functools import partial
import time

def info(time_value):
    print('module name:', __name__)
    print('parent process:', os.getppid())
    print('process id:', os.getpid())
    print("Time spent: ", time.time() - time_value)

def init(l):
    global lock
    lock = l

def Data_Arranger(to_filename):
    """This function concatenates and deletes temporary files. It is an arranger
    of the multiprocessing data results.
    """
    Base = os.path.join(os.path.dirname(to_filename), 'temp')
    Strings = [file for file in os.listdir(Base)]
    Strings = [os.path.join(Base, S) for S in Strings]
    if not os.path.exists(os.path.dirname(to_filename)):
        os.mkdir(os.path.dirname(to_filename))
    Sq = [S for S in Strings if S.endswith('.shp')]
    gpd.GeoDataFrame(pd.concat([gpd.read_file(sq1) for sq1 in Sq]), crs=GDF.crs).to_file(to_filename)
    for sq1 in Sq:
        os.remove(sq1)
    import shutil
    shutil.rmtree(Base, ignore_errors=True)

def parallelize_df(gdf, func, n_cores, dx=100, dy=100, verbose=False, to_filename=None):
    Geometries = gdf.loc[:, 'geometry'].values
    crs = gdf.crs
    pool = Pool(processes=n_cores, initializer=init, initargs=(Lock(),))
    # bind dx, dy, verbose, to_filename and crs; pool.map supplies the polygon
    func_partial = partial(func, dx, dy, verbose, to_filename, crs)
    pool.map(func_partial, Geometries)
    pool.close()
    pool.join()

def generate_grid_from_gdf(dx=100, dy=100, verbose=False, to_filename=None, crs=None, polygon=None):
    if verbose == True:
        info(time.time())
    xmin, ymin, xmax, ymax = polygon.bounds
    lenght = dx
    wide = dy
    cols = list(np.arange(int(np.floor(xmin)), int(np.ceil(xmax)), wide))
    rows = list(np.arange(int(np.floor(ymin)), int(np.ceil(ymax)), lenght))
    rows.reverse()
    subpolygons = []
    for x in cols:
        for y in rows:
            subpolygons.append(Polygon([(x, y), (x + wide, y), (x + wide, y - lenght), (x, y - lenght)]))
    lock.acquire()
    print('parent process: ', os.getppid(), ' has activated the Lock')
    GDF = gpd.GeoDataFrame(geometry=subpolygons, crs=crs)
    to_filename = os.path.join(os.path.dirname(to_filename), 'temp',
                               str(os.getpid()) + '_' + str(time.time()) + '.' + os.path.basename(to_filename).split('.')[-1])
    if not os.path.exists(os.path.dirname(to_filename)):
        os.mkdir(os.path.dirname(to_filename))
    try:
        print("to_filename: ", to_filename)
        GDF.to_file(to_filename)
    except Exception:
        print("error in the file saving")
    lock.release()
    print('parent process: ', os.getppid(), ' has unlocked')

def main(GDF, n_cores='standard', dx=100, dy=100, verbose=False, to_filename=None):
    """
    GDF: geodataframe
    n_cores: 'standard' or a positive int; sets the number of cores to use in the multiprocessing
    dx: dimension in the x coordinate to make the grid
    dy: dimension in the y coordinate to make the grid
    verbose: whether or not to show info from the processing. Applicable only when not
        running on Windows, or when running in a separate console on Windows.
    to_filename: the path which will be used to save the resulting file.
    """
    if isinstance(n_cores, str):
        import multiprocessing
        N_cores = multiprocessing.cpu_count() - 1
    elif isinstance(n_cores, int):
        N_cores = n_cores
    parallelize_df(GDF, generate_grid_from_gdf, n_cores=N_cores, dx=dx, dy=dy, verbose=verbose, to_filename=to_filename)
    Data_Arranger(to_filename)

####################################################################################

if "__main__" == __name__:
    freeze_support()
    GDF = gpd.read_file("Someone's_file.shp")
    to_filename = "To_file_directory/To_file_name.shp"
    dx = 500  # resampling to 500 units. E.g., if the coordinate reference system is in meters, this returns grid polygons of 500 m in the longitudinal dimension.
    dy = 500  # same here: assuming the CRS is in meters, the resulting file will have polygons of 500 m in the latitudinal dimension.
    main(GDF, dx=dx, dy=dy, verbose=True, to_filename=to_filename)
I thank you for your time.
Philipe Leal
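An addendum to the answer above (my suggestion, not Philipe's code): if writing to GeoPackage instead of Shapefile is acceptable, and assuming a geopandas/fiona version whose to_file supports mode='a', each worker can append directly to one output file under the lock, removing the temporary files and the final concatenation step. lock is the global set by the Pool initializer in the answer's code.
import os
import geopandas as gpd

def save_chunk_appending(subpolygons, crs, to_filename):
    # to_filename is assumed to end in .gpkg; ESRI Shapefile generally
    # cannot be appended to, but GeoPackage can
    gdf = gpd.GeoDataFrame(geometry=subpolygons, crs=crs)
    lock.acquire()
    try:
        if os.path.exists(to_filename):
            gdf.to_file(to_filename, driver='GPKG', mode='a')
        else:
            gdf.to_file(to_filename, driver='GPKG')
    finally:
        lock.release()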

Fastest Way to Update and Plot Lists? [duplicate]

For years, I've been struggling to get efficient live plotting in matplotlib, and to this day I remain unsatisfied.
I want a redraw_figure function that updates the figure "live" (as the code runs), and will display the latest plots if I stop at a breakpoint.
Here is some demo code:
import time
from matplotlib import pyplot as plt
import numpy as np

def live_update_demo():
    plt.subplot(2, 1, 1)
    h1 = plt.imshow(np.random.randn(30, 30))
    redraw_figure()
    plt.subplot(2, 1, 2)
    h2, = plt.plot(np.random.randn(50))
    redraw_figure()
    t_start = time.time()
    for i in range(1000):
        h1.set_data(np.random.randn(30, 30))
        redraw_figure()
        h2.set_ydata(np.random.randn(50))
        redraw_figure()
        print('Mean Frame Rate: %.3gFPS' % ((i + 1) / (time.time() - t_start)))

def redraw_figure():
    plt.draw()
    plt.pause(0.00001)

live_update_demo()
Plots should update live when the code is run, and we should see the latest data when stopping at any breakpoint after redraw_figure(). The question is how best to implement redraw_figure().
In the implementation above (plt.draw(); plt.pause(0.00001)) it works, but is very slow (~3.7 FPS).
I can implement it as:
def redraw_figure():
    plt.gcf().canvas.flush_events()
    plt.show(block=False)
And it runs faster (~11 FPS), but plots are not up to date when you stop at breakpoints (e.g. if I put a breakpoint on the t_start = ... line, the second plot does not appear).
Strangely enough, what does actually work is calling show twice:
def redraw_figure():
    plt.gcf().canvas.flush_events()
    plt.show(block=False)
    plt.show(block=False)
This gives ~11 FPS and does keep plots up to date if you break on any line.
Now I've heard it said that the "block" keyword is deprecated. And calling the same function twice seems like a weird, probably non-portable hack anyway.
So what can I put in this function that will plot at a reasonable frame rate, isn't a giant kludge, and preferably will work across backends and systems?
Some notes:
I'm on OSX, and using TkAgg backend, but solutions on any backend/system are welcome
Interactive mode "On" will not work, because it does not update live. It just updates when in the Python console when the interpreter waits for user input.
A blog suggested the implementation:
def redraw_figure():
    fig = plt.gcf()
    fig.canvas.draw()
    fig.canvas.flush_events()
But at least on my system, that does not redraw the plots at all.
So, if anybody has an answer, you would directly make me and thousands of others very happy. Their happiness would probably trickle through to their friends and relatives, and their friends and relatives, and so on, so that you could potentially improve the lives of billions.
Conclusions
ImportanceOfBeingErnest shows how you can use blit for faster plotting, but it's not as simple as putting something different in the redraw_figure function (you need to keep track of what things to redraw).
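For reference, here is a minimal sketch of that bookkeeping (my addition, assuming an Agg-based interactive backend; it is not code from the answer below): a small helper that caches the static background once and redraws only the artists you register with it.
import matplotlib.pyplot as plt

class BlitRedrawer:
    """Cache the static background once, then redraw only animated artists."""
    def __init__(self, fig, artists):
        self.fig = fig
        self.artists = artists
        for a in artists:
            a.set_animated(True)   # exclude from the cached background
        fig.canvas.draw()          # one full draw to build the background
        self.background = fig.canvas.copy_from_bbox(fig.bbox)

    def redraw(self):
        self.fig.canvas.restore_region(self.background)
        for a in self.artists:
            a.axes.draw_artist(a)  # redraw just the animated artists
        self.fig.canvas.blit(self.fig.bbox)
        self.fig.canvas.flush_events()
With the question's demo, one would build it once as redrawer = BlitRedrawer(plt.gcf(), [h1, h2]) after creating the plots, and call redrawer.redraw() inside the loop instead of redraw_figure().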
First of all, the code posted in the question runs at 7 fps on my machine, with QT4Agg as the backend.
Now, as has been suggested in many posts, like here or here, using blit might be an option. Although this article mentions that blit causes strong memory leakage, I could not observe that.
I have modified your code a bit and compared the frame rate with and without the use of blit. The code below gives
28 fps when run without blit
175 fps with blit
Code:
import time
from matplotlib import pyplot as plt
import numpy as np

def live_update_demo(blit=False):
    x = np.linspace(0, 50., num=100)
    X, Y = np.meshgrid(x, x)
    fig = plt.figure()
    ax1 = fig.add_subplot(2, 1, 1)
    ax2 = fig.add_subplot(2, 1, 2)
    img = ax1.imshow(X, vmin=-1, vmax=1, interpolation="None", cmap="RdBu")
    line, = ax2.plot([], lw=3)
    text = ax2.text(0.8, 0.5, "")
    ax2.set_xlim(x.min(), x.max())
    ax2.set_ylim([-1.1, 1.1])
    fig.canvas.draw()   # note that the first draw comes before setting data
    if blit:
        # cache the background
        axbackground = fig.canvas.copy_from_bbox(ax1.bbox)
        ax2background = fig.canvas.copy_from_bbox(ax2.bbox)
    plt.show(block=False)
    t_start = time.time()
    k = 0.
    for i in np.arange(1000):
        img.set_data(np.sin(X/3. + k) * np.cos(Y/3. + k))
        line.set_data(x, np.sin(x/3. + k))
        tx = 'Mean Frame Rate:\n {fps:.3f}FPS'.format(fps=((i + 1) / (time.time() - t_start)))
        text.set_text(tx)
        k += 0.11
        if blit:
            # restore background
            fig.canvas.restore_region(axbackground)
            fig.canvas.restore_region(ax2background)
            # redraw just the points
            ax1.draw_artist(img)
            ax2.draw_artist(line)
            ax2.draw_artist(text)
            # fill in the axes rectangle
            fig.canvas.blit(ax1.bbox)
            fig.canvas.blit(ax2.bbox)
            # in this post http://bastibe.de/2013-05-30-speeding-up-matplotlib.html
            # it is mentioned that blit causes strong memory leakage;
            # however, I did not observe that.
        else:
            # redraw everything
            fig.canvas.draw()
        fig.canvas.flush_events()
        # alternatively you could use
        # plt.pause(0.000000000001)
        # however, plt.pause calls canvas.draw(), as can be read here:
        # http://bastibe.de/2013-05-30-speeding-up-matplotlib.html

live_update_demo(True)   # 175 fps
#live_update_demo(False) # 28 fps
Update:
For faster plotting, one may consider using pyqtgraph.
As the pyqtgraph documentation puts it: "For plotting, pyqtgraph is not nearly as complete/mature as matplotlib, but runs much faster."
I ported the above example to pyqtgraph. Although it looks kind of ugly, it runs at 250 fps on my machine.
Summing that up,
matplotlib (without blitting): 28 fps
matplotlib (with blitting): 175 fps
pyqtgraph : 250 fps
pyqtgraph code:
import sys
import time
from pyqtgraph.Qt import QtCore, QtGui
import numpy as np
import pyqtgraph as pg

class App(QtGui.QMainWindow):
    def __init__(self, parent=None):
        super(App, self).__init__(parent)

        #### Create Gui Elements ###########
        self.mainbox = QtGui.QWidget()
        self.setCentralWidget(self.mainbox)
        self.mainbox.setLayout(QtGui.QVBoxLayout())

        self.canvas = pg.GraphicsLayoutWidget()
        self.mainbox.layout().addWidget(self.canvas)

        self.label = QtGui.QLabel()
        self.mainbox.layout().addWidget(self.label)

        self.view = self.canvas.addViewBox()
        self.view.setAspectLocked(True)
        self.view.setRange(QtCore.QRectF(0, 0, 100, 100))

        # image plot
        self.img = pg.ImageItem(border='w')
        self.view.addItem(self.img)

        self.canvas.nextRow()
        # line plot
        self.otherplot = self.canvas.addPlot()
        self.h2 = self.otherplot.plot(pen='y')

        #### Set Data #####################
        self.x = np.linspace(0, 50., num=100)
        self.X, self.Y = np.meshgrid(self.x, self.x)

        self.counter = 0
        self.fps = 0.
        self.lastupdate = time.time()

        #### Start #####################
        self._update()

    def _update(self):
        self.data = np.sin(self.X/3. + self.counter/9.) * np.cos(self.Y/3. + self.counter/9.)
        self.ydata = np.sin(self.x/3. + self.counter/9.)

        self.img.setImage(self.data)
        self.h2.setData(self.ydata)

        now = time.time()
        dt = (now - self.lastupdate)
        if dt <= 0:
            dt = 0.000000000001
        fps2 = 1.0 / dt
        self.lastupdate = now
        self.fps = self.fps * 0.9 + fps2 * 0.1
        tx = 'Mean Frame Rate: {fps:.3f} FPS'.format(fps=self.fps)
        self.label.setText(tx)
        QtCore.QTimer.singleShot(1, self._update)
        self.counter += 1

if __name__ == '__main__':
    app = QtGui.QApplication(sys.argv)
    thisapp = App()
    thisapp.show()
    sys.exit(app.exec_())
Here's one way to do live plotting: render the plot to an image array, then draw the image to a multithreaded screen.
Example using a pyformulas screen (~30 FPS):
import pyformulas as pf
import matplotlib.pyplot as plt
import numpy as np
import time

fig = plt.figure()
screen = pf.screen(title='Plot')

start = time.time()
for i in range(10000):
    t = time.time() - start
    x = np.linspace(t - 3, t, 100)
    y = np.sin(2*np.pi*x) + np.sin(3*np.pi*x)
    plt.xlim(t - 3, t)
    plt.ylim(-3, 3)
    plt.plot(x, y, c='black')

    # If we haven't already shown or saved the plot, then we need to draw the figure first...
    fig.canvas.draw()

    # np.frombuffer replaces the deprecated np.fromstring here
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))

    screen.update(image)

#screen.close()
Disclaimer: I'm the maintainer of pyformulas

How to animate multiple figures at the same time

I made an animation for sorting algorithms, and it works great for animating one sorting algorithm, but when I try to animate multiple at the same time, both windows come up but neither of them moves. I was wondering how I could fix this.
When I run the code, the first figure is stuck on the first frame and the second figure jumps to the last frame.
import matplotlib.pyplot as plt
from matplotlib import animation
import random
# my class for getting data from sorting algorithms
from animationSorters import *

def sort_anim(samp_size=100, types=['bubblesort', 'quicksort']):
    rndList = random.sample(range(1, samp_size + 1), samp_size)
    anim = []
    for k in range(0, len(types)):
        sort_type = types[k]
        animation_speed = 1

        def barlist(x):
            if sort_type == 'bubblesort':
                l = bubblesort_swaps(x)  # returns bubble sort data
            elif sort_type == 'quicksort':
                l = quicksort_swaps(x)   # returns quick sort data
            final = splitSwaps(l, len(x))
            return final

        fin = barlist(rndList)
        fig = plt.figure(k + 1)
        plt.rcParams['axes.facecolor'] = 'black'
        n = len(fin)  # number of frames
        x = range(1, len(rndList) + 1)
        barcollection = plt.bar(x, fin[0], color='w')
        anim_title = sort_type.title() + '\nSize: ' + str(samp_size)
        plt.title(anim_title)

        def animate(i):
            y = fin[i]
            for i, b in enumerate(barcollection):
                b.set_height(y[i])

        anim.append(animation.FuncAnimation(fig, animate, repeat=False,
                                            blit=False, frames=n, interval=animation_speed))
    plt.show()

sort_anim()
As explained in the documentation for the animation module:
"it is critical to keep a reference to the instance object. The animation is advanced by a timer (typically from the host GUI framework) which the Animation object holds the only reference to. If you do not hold a reference to the Animation object, it (and hence the timers) will be garbage collected which will stop the animation."
Therefore you need to return the references to your animations from your function; otherwise those objects are destroyed when exiting the function.
Consider the following simplification of your code:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

def my_func(nfigs=2):
    anims = []
    for i in range(nfigs):
        fig = plt.figure(num=i)
        ax = fig.add_subplot(111)
        col = ax.bar(x=range(10), height=np.zeros((10,)))
        ax.set_ylim([0, 1])

        def animate(k, bars):
            new_data = np.random.random(size=(10,))
            for j, b in enumerate(bars):
                b.set_height(new_data[j])
            return bars,

        ani = animation.FuncAnimation(fig, animate, fargs=(col,), frames=100)
        anims.append(ani)
    return anims

my_anims = my_func(3)
# calling simply my_func() here would not work; you need to keep the returned
# array in memory for the animations to stay alive
plt.show()
