Using a nested for loop in subplots [duplicate] - python-3.x

I am a little confused about how this code works:
fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()
How does the fig, axes work in this case? What does it do?
Also why wouldn't this work to do the same thing:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

There are several ways to do it. The subplots method creates the figure along with the subplots that are then stored in the ax array. For example:
import matplotlib.pyplot as plt
x = range(10)
y = range(10)
fig, ax = plt.subplots(nrows=2, ncols=2)
for row in ax:
    for col in row:
        col.plot(x, y)
plt.show()
However, something like this will also work. It is not as "clean", though, since you create the figure first and then add the subplots to it one at a time:
fig = plt.figure()
plt.subplot(2, 2, 1)
plt.plot(x, y)
plt.subplot(2, 2, 2)
plt.plot(x, y)
plt.subplot(2, 2, 3)
plt.plot(x, y)
plt.subplot(2, 2, 4)
plt.plot(x, y)
plt.show()

import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 2)
ax[0, 0].plot(range(10), 'r') #row=0, col=0
ax[1, 0].plot(range(10), 'b') #row=1, col=0
ax[0, 1].plot(range(10), 'g') #row=0, col=1
ax[1, 1].plot(range(10), 'k') #row=1, col=1
plt.show()

You can also unpack the axes in the subplots call and set whether you want to share the x and y axes between the subplots, like this:
import matplotlib.pyplot as plt
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1, ax2, ax3, ax4 = axes.flatten()
ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()
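To make the effect of sharex and sharey concrete, here is a minimal sketch (the limit values 2 and 6 are arbitrary, and the Agg backend is only selected so the snippet runs without a display). Changing the x limits on one shared subplot changes them on the others, while an unshared grid keeps its own limits:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

# shared axes: limits set on one subplot propagate to the rest
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
axes[0, 0].plot(range(10), 'r')
axes[0, 0].set_xlim(2, 6)
shared_xlim = tuple(axes[1, 1].get_xlim())

# unshared axes for comparison: the other subplot keeps its default limits
fig2, axes2 = plt.subplots(nrows=2, ncols=2)
axes2[0, 0].set_xlim(2, 6)
unshared_xlim = tuple(axes2[1, 1].get_xlim())

print(shared_xlim == (2.0, 6.0), unshared_xlim == (0.0, 1.0))
```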

You might be interested in the fact that as of matplotlib version 2.1 the second code from the question works fine as well.
From the change log:
Figure class now has subplots method
The Figure class now has a subplots() method which behaves the same as pyplot.subplots() but on an existing figure.
Example:
import matplotlib.pyplot as plt
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
plt.show()

Read the documentation: matplotlib.pyplot.subplots
pyplot.subplots() returns a tuple (fig, ax), which is unpacked into two variables using the notation
fig, axes = plt.subplots(nrows=2, ncols=2)
The code:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
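does not work in matplotlib versions before 2.1, because there subplots() is only a function in pyplot and not a method of the Figure class. From version 2.1 on, Figure has a subplots() method and this code works.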
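The tuple unpacking can be checked directly. This is a small sketch rather than anything from the original answers; the name result is only used for illustration, and the Agg backend is selected so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

# plt.subplots returns a 2-tuple: (Figure, array of Axes)
result = plt.subplots(nrows=2, ncols=2)
fig, axes = result  # ordinary tuple unpacking into two names

print(type(fig).__name__)  # Figure
print(axes.shape)          # (2, 2)

# every element of the array is an Axes that belongs to fig
belongs = all(ax.figure is fig for ax in axes.flat)
```

So fig is the canvas-level object (size, saving, layout), and axes is a numpy array holding the individual plotting areas.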

Iterating through all subplots sequentially:
fig, axes = plt.subplots(nrows, ncols)
for ax in axes.flatten():
    ax.plot(x,y)
Accessing a specific index:
for row in range(nrows):
    for col in range(ncols):
        axes[row,col].plot(x[row], y[col])
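If both the grid position and the Axes are needed inside the loop, one alternative to the nested range loops is np.ndenumerate. The sketch below uses made-up sine data and passes squeeze=False so the array is 2D for any grid size; neither choice comes from the answer above:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

nrows, ncols = 2, 3
x = np.linspace(0, 2 * np.pi, 100)  # sample data only

# squeeze=False keeps axes two-dimensional even for a single row or column
fig, axes = plt.subplots(nrows, ncols, squeeze=False)

# np.ndenumerate yields ((row, col), element) for every cell of the array
for (row, col), ax in np.ndenumerate(axes):
    ax.plot(x, np.sin((row + 1) * x + col))
    ax.set_title(f"row={row}, col={col}")
```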

Subplots with pandas
This answer is for subplots with pandas, which uses matplotlib as the default plotting backend.
Here are four options to create subplots starting with a pandas.DataFrame
Implementation 1. and 2. are for the data in a wide format, creating subplots for each column.
Implementation 3. and 4. are for data in a long format, creating subplots for each unique value in a column.
Tested in python 3.8.11, pandas 1.3.2, matplotlib 3.4.3, seaborn 0.11.2
Imports and Data
import seaborn as sns # sample data, and the plot in option 4
import pandas as pd
import matplotlib.pyplot as plt
# wide dataframe
df = sns.load_dataset('planets').iloc[:, 2:5]
   orbital_period   mass  distance
0         269.300   7.10     77.40
1         874.774   2.21     56.95
2         763.000   2.60     19.84
3         326.030  19.40    110.62
4         516.220  10.50    119.47
# long dataframe
dfm = sns.load_dataset('planets').iloc[:, 2:5].melt()
         variable    value
0  orbital_period  269.300
1  orbital_period  874.774
2  orbital_period  763.000
3  orbital_period  326.030
4  orbital_period  516.220
1. subplots=True and layout, for each column
Use the parameters subplots=True and layout=(rows, cols) in pandas.DataFrame.plot
This example uses kind='density', but there are different options for kind, and this applies to them all. Without specifying kind, a line plot is the default.
axes is an array of AxesSubplot returned by pandas.DataFrame.plot
See How to get a Figure object, if needed.
How to save pandas subplots
axes = df.plot(kind='density', subplots=True, layout=(2, 2), sharex=False, figsize=(10, 6))
# extract the figure object; only used for tight_layout in this example
fig = axes[0][0].get_figure()
# set the individual titles
for ax, title in zip(axes.ravel(), df.columns):
    ax.set_title(title)
fig.tight_layout()
plt.show()
2. plt.subplots, for each column
Create an array of Axes with matplotlib.pyplot.subplots and then pass axes[i, j] or axes[n] to the ax parameter.
This option uses pandas.DataFrame.plot, but can use other axes level plot calls as a substitute (e.g. sns.kdeplot, plt.plot, etc.)
It's easiest to collapse the subplot array of Axes into one dimension with .ravel or .flatten. See .ravel vs .flatten.
Any variables that apply to each axes and need to be iterated through are combined with zip (e.g. cols, axes, colors, palette, etc.). Each object should be the same length, because zip stops at the shortest one.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
cols = df.columns # create a list of dataframe columns to use
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for col, color, ax in zip(cols, colors, axes):
    df[col].plot(kind='density', ax=ax, color=color, label=col, title=col)
    ax.legend()
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
Result for 1. and 2.
3. plt.subplots, for each group in .groupby
This is similar to 2., except it zips color and axes to a .groupby object.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
dfg = dfm.groupby('variable') # get data for each unique value in the first column
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for (group, data), color, ax in zip(dfg, colors, axes):
    data.plot(kind='density', ax=ax, color=color, title=group, legend=False)
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
4. seaborn figure-level plot
Use a seaborn figure-level plot, and use the col or row parameter. seaborn is a high-level API for matplotlib. See seaborn: API reference
p = sns.displot(data=dfm, kind='kde', col='variable', col_wrap=2, x='value', hue='variable',
facet_kws={'sharey': False, 'sharex': False}, height=3.5, aspect=1.75)
sns.move_legend(p, "upper left", bbox_to_anchor=(.55, .45))

Convert the axes array to 1D
Generating subplots with plt.subplots(nrows, ncols), where both nrows and ncols are greater than 1, returns a nested (2D) array of <AxesSubplot:> objects.
It’s not necessary to flatten axes in cases where either nrows=1 or ncols=1, because axes will already be 1-dimensional, which is a result of the default parameter squeeze=True.
The easiest way to access the objects is to convert the array to 1 dimension with .ravel(), .flatten(), or .flat.
.ravel vs. .flatten
flatten always returns a copy.
ravel returns a view of the original array whenever possible.
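The copy versus view difference is easiest to see on a plain integer array. This is a sketch with hypothetical names (grid, flat_view, flat_copy). For an array of Axes the distinction rarely matters, because both methods hand back references to the same Axes objects:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

grid = np.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
flat_view = grid.ravel()           # a view here, since grid is contiguous
flat_copy = grid.flatten()         # always an independent copy

flat_view[0] = 99   # writes through to grid
flat_copy[1] = -1   # leaves grid untouched

# with Axes arrays, both approaches give the very same Axes objects
fig, axes = plt.subplots(2, 2)
same_axes = axes.ravel()[0] is axes.flatten()[0] is axes[0, 0]
```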
Once the array of axes is converted to 1-d, there are a number of ways to plot.
This answer is relevant to seaborn axes-level plots, which have the ax= parameter (e.g. sns.barplot(…, ax=ax[0])).
seaborn is a high-level API for matplotlib. See Figure-level vs. axes-level functions and seaborn is not plotting within defined subplots
import matplotlib.pyplot as plt
import numpy as np # sample data only
# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]
# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)
# axes before
# array([[<AxesSubplot:>, <AxesSubplot:>],
#        [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
# convert the array to 1 dimension
axes = axes.ravel()
# axes after
# array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
#       dtype=object)
Iterate through the flattened array
If there are more subplots than data, this will result in IndexError: list index out of range
Try option 3. instead, or select a subset of the axes (e.g. axes[:-2])
for i, ax in enumerate(axes):
    ax.plot(x_data[i], y_data[i])
Access each axes by index
axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])
Index the data and axes
for i in range(len(x_data)):
    axes[i].plot(x_data[i], y_data[i])
zip the axes and data together and then iterate through the list of tuples.
for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)
Output
An option is to assign each axes to a variable, fig, (ax1, ax2, ax3) = plt.subplots(1, 3). However, as written, this only works in cases with either nrows=1 or ncols=1. This is based on the shape of the array returned by plt.subplots, and quickly becomes cumbersome.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) for a 2 x 2 array.
This option is most useful for two subplots (e.g.: fig, (ax1, ax2) = plt.subplots(1, 2) or fig, (ax1, ax2) = plt.subplots(2, 1)). For more subplots, it's more efficient to flatten and iterate through the array of axes.
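When there are more subplots than data series, zipping the flattened axes with the data avoids the IndexError, because zip stops at the shorter sequence. A minimal sketch, assuming three series on a 2 x 2 grid; removing the leftover subplot with fig.delaxes is one possible choice, not the only one:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

rads = np.arange(0, 2 * np.pi, 0.01)
y_data = [np.sin(t * rads) for t in range(1, 4)]  # only 3 series for 4 subplots

fig, axes = plt.subplots(nrows=2, ncols=2)
axes = axes.ravel()

# zip stops after the third pair, so the fourth Axes is never indexed
for ax, y in zip(axes, y_data):
    ax.plot(rads, y)

# remove the subplots that received no data
for ax in axes[len(y_data):]:
    fig.delaxes(ax)
```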

You could use the following:
import numpy as np
import matplotlib.pyplot as plt
fig, _ = plt.subplots(nrows=2, ncols=2)
for i, ax in enumerate(fig.axes):
    ax.plot(np.sin(np.linspace(0,2*np.pi,100) + np.pi/2*i))
Or alternatively, using the second variable that plt.subplots returns:
fig, ax_mat = plt.subplots(nrows=2, ncols=2)
for i, ax in enumerate(ax_mat.flatten()):
    ...
ax_mat is a matrix of the axes. Its shape is nrows x ncols.
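Both loops visit the subplots in the same order. The following sketch (same_order is just an illustrative name) checks that fig.axes lists the Axes row by row, matching the flattened array:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

fig, ax_mat = plt.subplots(nrows=2, ncols=2)

# fig.axes is a plain list; ax_mat.flatten() is a 1D array in row-major order
same_order = all(a is b for a, b in zip(fig.axes, ax_mat.flatten()))
print(same_order)  # True
```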

Here is a simple solution:
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=False)
for sp in fig.axes:
    sp.plot(range(10))

Go with the following if you really want to use a loop:
def plot(data):
    fig = plt.figure(figsize=(100, 100))
    for idx, k in enumerate(data.keys(), 1):
        # assumes each data[k] is a pandas Series, so .values is an attribute
        x, y = data[k].keys(), data[k].values
        plt.subplot(63, 10, idx)
        plt.bar(x, y)
    plt.show()
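Rather than hard-coding the grid size, the number of rows can be derived from the number of items to plot. The sketch below is only one way to do it: the data dictionary, the choice of three columns, and the figure size are all made up for illustration.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

# hypothetical data: a dict of dicts, one bar chart per outer key
data = {f"series {i}": {"a": i, "b": 2 * i, "c": 3 * i} for i in range(1, 6)}

ncols = 3
nrows = math.ceil(len(data) / ncols)  # enough rows to hold every item

fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 2.5 * nrows),
                         squeeze=False)  # always a 2D array
flat = axes.ravel()

for ax, (name, values) in zip(flat, data.items()):
    ax.bar(list(values.keys()), list(values.values()))
    ax.set_title(name)

# hide the grid cells that have no data
for ax in flat[len(data):]:
    ax.set_visible(False)
```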

Another concise solution is:
# set up structure of plots
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20,10))
# for plot 1
ax1.set_title('Title A')
ax1.plot(x, y)
# for plot 2
ax2.set_title('Title B')
ax2.plot(x, y)
# for plot 3
ax3.set_title('Title C')
ax3.plot(x,y)

Related

Pandas & Matplotlib: personalize the date format in a line chart

I want to make the dates on the x-axis look nicer; currently they cannot even be read. What is the best way to do it?
Below is the code and also the actual graph picture
import matplotlib.pyplot as plt
import pandas as pd
df = dataset
# gca stands for 'get current axes'
ax = plt.gca()
y1 = df['Predicted_Lower']
y2 = df['Predicted_Upper']
x = df['Date']
ax.fill_between(x,y1, y2, facecolor="#CC6666", alpha=0.7)
df.plot(kind='line',x='Date',y='Predicted_Lower',color='white',ax=ax)
df.plot(kind='line',x='Date',y='Predicted_Upper',color='white', ax=ax)
df.plot(kind='line',x='Date',y='Predicted', color='yellow', ax=ax)
df.plot(kind='line',x='Date',y='Actuals', color='green', ax=ax)
plt.xticks(rotation=45)
plt.show()
You can reduce the number of labels by setting the locs and labels parameters of matplotlib.pyplot.xticks. For example, get the current locs and labels and keep only every third one:
# ...
df.plot(kind='line',x='Date',y='Actuals', color='green', ax=ax)
locs, labels = plt.xticks()
plt.xticks(locs[::3], labels[::3], rotation=45)
plt.show()
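A different technique from the xticks slicing above is to let matplotlib.dates choose and format the ticks. The sketch below is an assumption-laden illustration: it uses synthetic daily data and plots with ax.plot directly instead of df.plot, since pandas may install its own date handling on the axis.

```python
from datetime import datetime, timedelta

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# synthetic data: 180 consecutive days with arbitrary values
days = [datetime(2021, 1, 1) + timedelta(days=i) for i in range(180)]
values = [i % 30 for i in range(180)]

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(days, values, color='green')

# one tick per month, labelled as year-month
ax.xaxis.set_major_locator(mdates.MonthLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))

# rotate and right-align the date labels so they do not overlap
fig.autofmt_xdate(rotation=45)
```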

Matplotlib how to do one or more subplots with same axes specification [duplicate]

I have constructed a set of pie charts with some help from "Insert image into pie chart slice".
My charts look wonderful, now I need to place all 6 of them in a 2x3 figure with common tick marks on the shared x and y axis.
For starting I am looking at subplots and thought I could get it to work. I downloaded some examples and started to try a few things.
f, (a) = (plt.subplots(nrows=1, ncols=1, sharex=True, sharey=True))#,
#squeeze=False, subplot_kw=None, gridspec_kw=None))
print(type(f),'\n',type(a),'\n')#,type(b))
yields:
<class 'matplotlib.figure.Figure'>
<class 'matplotlib.axes._subplots.AxesSubplot'>
while:
f, (a) = (plt.subplots(nrows=1, ncols=1, sharex=True, sharey=True, squeeze=False, subplot_kw=None, gridspec_kw=None))
print(type(f),'\n',type(a),'\n')#,type(b))
returns:
<class 'matplotlib.figure.Figure'>
<class 'numpy.ndarray'>
When I do this:
f, (a,b) = (plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True, squeeze=False, subplot_kw=None, gridspec_kw=None))
print(type(f),'\n',type(a),'\n',type(b))
I get similar results. However, if nrows=1 and ncols=2, I get an error:
f, (a,b) = (plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, squeeze=False, subplot_kw=None, gridspec_kw=None))
print(type(f),'\n',type(a),'\n',type(b))
ValueError: not enough values to unpack (expected 2, got 1)
but again this:
f, (a , b) = (
plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True))#,
#squeeze=False, subplot_kw=None, gridspec_kw=None))
print(type(f),'\n',type(a),'\n',type(b))
gives
<class 'matplotlib.figure.Figure'>
<class 'matplotlib.axes._subplots.AxesSubplot'>
<class 'matplotlib.axes._subplots.AxesSubplot'>
Why is it either an array or an axes, and why does a 2x1 work while a 1x2 does not?
I wish to high heaven I could better understand the documentation. Thanks.
The different return types are due to the squeeze keyword argument to plt.subplots() which is set to True by default.
Let's enhance the documentation with the respective unpackings:
squeeze : bool, optional, default: True
If True, extra dimensions are squeezed out from the returned Axes object:
if only one subplot is constructed (nrows=ncols=1), the resulting single Axes object is returned as a scalar.
fig, ax = plt.subplots()
for Nx1 or 1xN subplots, the returned object is a 1D numpy object array of Axes objects.
fig, (ax1, ..., axN) = plt.subplots(nrows=N, ncols=1) (for Nx1)
fig, (ax1, ..., axN) = plt.subplots(nrows=1, ncols=N) (for 1xN)
for NxM subplots with N>1 and M>1, the returned object is a 2D array.
fig, ((ax11, .., ax1M),..,(axN1, .., axNM)) = plt.subplots(nrows=N, ncols=M)
If False, no squeezing at all is done: the returned Axes object is always a 2D array containing Axes instances, even if it ends up being 1x1.
fig, ((ax,),) = plt.subplots(nrows=1, ncols=1, squeeze=False)
fig, ((ax,), .. ,(axN,)) = plt.subplots(nrows=N, ncols=1, squeeze=False) for Nx1
fig, ((ax, .. ,axN),) = plt.subplots(nrows=1, ncols=N, squeeze=False) for 1xN
fig, ((ax11, .., ax1M),..,(axN1, .., axNM)) = plt.subplots(nrows=N, ncols=M)
Alternatively, you may always use the packed version
fig, ax_arr = plt.subplots(nrows=N, ncols=M, squeeze=False)
and index the array to obtain the axes, ax_arr[1,2].plot(..).
So for a 2 x 3 grid it wouldn't actually matter if you set squeeze to False. The result will always be a 2D array. You may unpack it as
fig, ((ax1, ax2, ax3),(ax4, ax5, ax6)) = plt.subplots(nrows=2, ncols=3)
to have ax{i} as the matplotlib axes objects, or you may use the packed version
fig, ax_arr = plt.subplots(nrows=2, ncols=3)
ax_arr[0,0].plot(..) # plot to first top left axes
ax_arr[1,2].plot(..) # plot to last bottom right axes
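The 2x1 versus 1x2 behaviour from the question follows from the array shapes when squeeze=False. Here is a short sketch with illustrative variable names:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

fig1, arr_2x1 = plt.subplots(nrows=2, ncols=1, squeeze=False)
fig2, arr_1x2 = plt.subplots(nrows=1, ncols=2, squeeze=False)

print(arr_2x1.shape)  # (2, 1)
print(arr_1x2.shape)  # (1, 2)

# two rows: unpacking into two names works, but each name is a length-1 array
a, b = arr_2x1

# one row holding two Axes: unpack the single row first, then its two elements
(a2, b2), = arr_1x2

# unpacking the 1x2 array into two names fails, as in the question
try:
    a3, b3 = arr_1x2
except ValueError as err:
    message = str(err)
    print(message)
```

In the 2x1 case, a and b are still arrays rather than Axes, which is why the question's type() output showed numpy.ndarray.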

How to change the location of the symbols/text within a legend box?

I have a subplot with a single legend entry. I am placing the legend at the bottom of the figure and using mode='expand'; however, the single legend entry is placed to the very left of the legend box. To my understanding, changing kwargs such as bbox_to_anchor changes the legend box parameters but not the parameters of the symbols/text within. Below is an example to reproduce my issue.
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-10, 10, 21)
y = np.exp(x)
z = x **2
fig, axes = plt.subplots(nrows=1, ncols=2)
axes[0].plot(x, y, color='r', label='exponential')
axes[1].plot(x, z, color='b')
# handles, labels = axes[0].get_legend_handles_labels()
plt.subplots_adjust(bottom=0.125)
fig.legend(mode='expand', loc='lower center')
plt.show()
plt.close(fig)
This code produces a figure in which the single legend entry sits at the far left of the legend box. How can I change the position of the symbol and text such that they are centered in the legend box?
PS: I am aware that exponential is a bad label for this subplot since it only describes the first subfigure. But, this is just for examples-sake so that I can apply it to my actual use-case.
The legend entries are placed using an HPacker object, which does not allow them to be centered. Instead, the entries are "justified" (similar to the "justify" option in common word processing software).
A workaround would be to create three (or any odd number of) legend entries, such that the desired entry is in the middle. This would be accomplished via the ncol argument and the use of "dummy" entries (which might be transparent and have no associated label).
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-10, 10, 21)
y = np.exp(x)
z = x **2
fig, axes = plt.subplots(nrows=1, ncols=2)
fig.subplots_adjust(bottom=0.125)
l1, = axes[0].plot(x, y, color='r', label='exponential')
axes[1].plot(x, z, color='b')
dummy = plt.Line2D([],[], alpha=0)
fig.legend(handles=[dummy, l1, dummy],
mode='expand', loc='lower center', ncol=3)
plt.show()

matplotlib, add legend for each line? [duplicate]

TL;DR -> How can one create a legend for a line graph in Matplotlib's PyPlot without creating any extra variables?
Please consider the graphing script below:
if __name__ == '__main__':
    PyPlot.plot(total_lengths, sort_times_bubble, 'b-',
                total_lengths, sort_times_ins, 'r-',
                total_lengths, sort_times_merge_r, 'g+',
                total_lengths, sort_times_merge_i, 'p-', )
    PyPlot.title("Combined Statistics")
    PyPlot.xlabel("Length of list (number)")
    PyPlot.ylabel("Time taken (seconds)")
    PyPlot.show()
As you can see, this is a very basic use of matplotlib's PyPlot. This ideally generates a graph like the one below:
Nothing special, I know. However, it is unclear which data is plotted where (I'm trying to plot the data of some sorting algorithms, length against time taken, and I'd like to make sure people know which line is which). Thus, I need a legend. However, consider the following example (from the official site):
ax = subplot(1,1,1)
p1, = ax.plot([1,2,3], label="line 1")
p2, = ax.plot([3,2,1], label="line 2")
p3, = ax.plot([2,3,1], label="line 3")
handles, labels = ax.get_legend_handles_labels()
# reverse the order
ax.legend(handles[::-1], labels[::-1])
# or sort them by labels
import operator
hl = sorted(zip(handles, labels),
key=operator.itemgetter(1))
handles2, labels2 = zip(*hl)
ax.legend(handles2, labels2)
You will see that I need to create an extra variable ax. How can I add a legend to my graph without having to create this extra variable and retaining the simplicity of my current script?
Add a label= to each of your plot() calls, and then call legend(loc='upper left').
Consider this sample (tested with Python 3.8.0):
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 20, 1000)
y1 = np.sin(x)
y2 = np.cos(x)
plt.plot(x, y1, "-b", label="sine")
plt.plot(x, y2, "-r", label="cosine")
plt.legend(loc="upper left")
plt.ylim(-1.5, 2.0)
plt.show()
Slightly modified from this tutorial: http://jakevdp.github.io/mpl_tutorial/tutorial_pages/tut1.html
You can access the Axes instance (ax) with plt.gca(). In this case, you can use
plt.gca().legend()
You can do this either by using the label= keyword in each of your plt.plot() calls or by assigning your labels as a tuple or list within legend, as in this working example:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-0.75,1,100)
y0 = np.exp(2 + 3*x - 7*x**3)
y1 = 7-4*np.sin(4*x)
plt.plot(x,y0,x,y1)
plt.gca().legend(('y0','y1'))
plt.show()
However, if you need to access the Axes instance more than once, I do recommend saving it to the variable ax with
ax = plt.gca()
and then calling ax instead of plt.gca().
Here's an example to help you out ...
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(111)
ax.set_title('ADR vs Rating (CS:GO)')
ax.scatter(x=data[:,0],y=data[:,1],label='Data')
plt.plot(data[:,0], m*data[:,0] + b,color='red',label='Our Fitting Line')
ax.set_xlabel('ADR')
ax.set_ylabel('Rating')
ax.legend(loc='best')
plt.show()
You can add a custom legend (see the documentation):
first = [1, 2, 4, 5, 4]
second = [3, 4, 2, 2, 3]
plt.plot(first, 'g--', second, 'r--')
plt.legend(['First List', 'Second List'], loc='upper left')
plt.show()
A simple plot for sine and cosine curves with a legend.
Used matplotlib.pyplot
import math
import matplotlib.pyplot as plt
x=[]
for i in range(-314,314):
    x.append(i/100)
ysin=[math.sin(i) for i in x]
ycos=[math.cos(i) for i in x]
plt.plot(x,ysin,label='sin(x)') #specify label for the corresponding curve
plt.plot(x,ycos,label='cos(x)')
plt.xticks([-3.14,-1.57,0,1.57,3.14],[r'-$\pi$',r'-$\pi$/2','0',r'$\pi$/2',r'$\pi$']) # raw strings avoid invalid-escape warnings for \p
plt.legend()
plt.show()
Add a label to each series you plot, e.g. label = "series 1".
Then simply add PyPlot.legend() to the bottom of your script and the legend will display these labels.

matplotlib.pyplot: create a subplot of stored plots

python 3.6 on mac
matplotlib 2.1.0
using matplotlib.pyplot (as plt)
Let's say I have a few figures created with plt.figure() that I appended as objects to a list called figures. When I evaluate figures[0] on the command line, it shows the plot at index 0 of the list.
However, how can I arrange all the plots in figures as subplots of a single figure?
# Pseudo code:
plt.figure()
for i, fig in enumerate(figures): # figures contains the plots
    plt.subplot(2, 2, i+1)
    fig # location i+1 of the subplot is filled with the fig plot element
So as a result, I would have a 2 by 2 grid that contains each plot found in figures.
I hope this makes sense.
A figure is a figure. You cannot have a figure inside a figure. The usual approach is to create a figure, create one or several subplots, plot something in the subplots.
In case it may happen that you want to plot something in different axes or figures, it might make sense to wrap the plotting in a function which takes the axes as argument.
You could then use this function to plot to an axes of a new figure or to plot to an axes of a figure with many subplots.
import numpy as np
import matplotlib.pyplot as plt
def myplot(ax, data_x, data_y, color="C0"):
    # label the line so that ax.legend() has an entry to show
    ax.plot(data_x, data_y, color=color, label=color)
    ax.legend()
x = np.linspace(0,10)
y = np.cumsum(np.random.randn(len(x),4), axis=0)
#create 4 figures
for i in range(4):
    fig, ax = plt.subplots()
    myplot(ax, x, y[:,i], color="C{}".format(i))
# create another figure with each plot as subplot
fig, ax = plt.subplots(2,2)
for i in range(4):
    myplot(ax.flatten()[i], x, y[:,i], color="C{}".format(i))
plt.show()

Resources