Error when drawing a legend on a separate plot - python-3.x

I want to draw a legend in a separate figure from the original plot. I can do this for the handles returned by plot, but not for the handle returned by fill_between.
Here is some sample code:
#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt
xx = np.linspace(0, 3.14*3, 100)
yy = np.sin (xx)
zz = yy +0.5
# I can draw a plot with legend
fig = plt.figure( )
ax = fig.add_subplot(1, 1, 1,)
line1, = ax.plot (xx, yy, label='xx')
line2, = ax.plot (xx, zz, label='yy')
fill0 = ax.fill_between (xx, yy, zz, label='filling', alpha=0.2, color='grey' )
ax.legend ( handles=[line1, line2, fill0])
plt.show()
# I can draw a empty plot with the legend of the lines
plt.legend(handles=[line1, line2])
plt.show()
# I can't draw an empty plot with the legend of the fill
# Why ?
# Can I fix this ?
plt.legend(handles=[fill0,])
plt.show()
And now the error :
Traceback (most recent call last):
File "Untitled.py", line 34, in <module>
plt.legend(handles=[fill0,])
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/pyplot.py", line 2721, in legend
return gca().legend(*args, **kwargs)
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/axes/_axes.py", line 417, in legend
self.legend_ = mlegend.Legend(self, handles, labels, **kwargs)
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/legend.py", line 503, in __init__
self._init_legend_box(handles, labels, markerfirst)
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/legend.py", line 767, in _init_legend_box
fontsize, handlebox))
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/legend_handler.py", line 117, in legend_artist
fontsize, handlebox.get_transform())
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/legend_handler.py", line 727, in create_artists
self.update_prop(p, orig_handle, legend)
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/legend_handler.py", line 76, in update_prop
legend._set_artist_props(legend_handle)
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/legend.py", line 550, in _set_artist_props
a.set_figure(self.figure)
File "/Users/marti/anaconda3/envs/PROD/lib/python3.7/site-packages/matplotlib/artist.py", line 704, in set_figure
raise RuntimeError("Can not put single artist in "
RuntimeError: Can not put single artist in more than one figure
Any hint to fix this?

The fill_between artist (a PolyCollection) is tied to the figure it was drawn in, and matplotlib does not allow a single artist to be placed in more than one figure, as the error message says. A workaround is to create a new patch for the legend, something like this:
(only the relevant code towards the end is given here)
# to show the filled patch in the legend, create a new Patch
grey_patch1 = mpatches.Patch(label='filling1', alpha=0.2, color='grey')
plt.legend(handles=[line1, line2, grey_patch1])
plt.show()
# grey_patch1 can be reused; there is no need to recreate it
grey_patch2 = mpatches.Patch(label='filling2', alpha=0.2, color='red')
plt.legend(handles=[grey_patch1, grey_patch2,])
plt.show()
The above code needs import matplotlib.patches as mpatches.
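As a sketch of the same idea, the proxy handles can be derived from the existing artists instead of retyping their colours. The helper name proxy_for and the Agg backend are assumptions made for this illustration, not part of the original code:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D

xx = np.linspace(0, 3 * np.pi, 100)
yy = np.sin(xx)
zz = yy + 0.5

fig, ax = plt.subplots()
line1, = ax.plot(xx, yy, label="xx")
fill0 = ax.fill_between(xx, yy, zz, label="filling", alpha=0.2, color="grey")


def proxy_for(artist):
    """Hypothetical helper: build a fresh legend handle that copies the look of `artist`."""
    if isinstance(artist, Line2D):
        return Line2D([], [], color=artist.get_color(),
                      linestyle=artist.get_linestyle(), label=artist.get_label())
    # fill_between returns a PolyCollection; its first face colour already includes alpha
    return mpatches.Patch(facecolor=artist.get_facecolor()[0], label=artist.get_label())


# A second, otherwise empty figure that only holds the legend
legend_fig = plt.figure(figsize=(2, 1))
legend = legend_fig.legend(handles=[proxy_for(line1), proxy_for(fill0)], loc="center")
legend_labels = [text.get_text() for text in legend.get_texts()]
```

Because the proxies are new artists that belong to no figure yet, they can be handed to any legend without triggering the RuntimeError.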

Related

Why using Area under curve in manim is giving me an error?

I am trying to show the area under a curve using manim.
This is my code:
from manimlib import *
import numpy as np
class GraphExample(Scene):
    def construct(self):
        ax = Axes((-3, 10), (-1, 8))
        ax.add_coordinate_labels()
        curve = ax.get_graph(lambda x: 2 * np.sin(x))
        self.add(ax,curve)
        area = ax.get_area_under_graph(graph=curve, x_range= (0,2))
        self.add(curve, area)
        self.wait(1)
this is giving an error message
File "c:\manim-master\manimlib\__main__.py", line 17, in main scene.run()
File "c:\manim-master\manimlib\scene\scene.py", line 75, in run self.construct()
File "test.py", line 21, in construct self.add(area)
File "c:\manim-master\manimlib\scene\scene.py", line 209, in add self.remove(*new_mobjects)
File "c:\manim-master\manimlib\scene\scene.py", line 226, in remove self.mobjects = restructure_list_to_exclude_certain_family_members(
File "c:\manim-master\manimlib\utils\family_ops.py", line 25, in restructure_list_to_exclude_certain_family_members
to_remove = extract_mobject_family_members(to_remove)
File "c:\manim-master\manimlib\utils\family_ops.py", line 5, in extract_mobject_family_members result = list(it.chain(*[
File "c:\manim-master\manimlib\utils\family_ops.py", line 6, in <listcomp>mob.get_family()
AttributeError: 'NoneType' object has no attribute 'get_family'
I don't know what I need to change, someone please help me out here
I changed the code to use Manim's get_area method.
get_area(graph, x_range=None, color=['#58C4DD', '#83C167'], opacity=0.3, bounded_graph=None, **kwargs)
Returns a Polygon representing the area under the graph passed.
from manim import *
import numpy as np
class GraphExample(Scene):
    def construct(self):
        ax = Axes((-3, 10), (-1, 8))
        ax.add_coordinates()
        curve = ax.get_graph(lambda x: 2 * np.sin(x))
        self.add(ax, curve)
        area = ax.get_area(graph=curve, x_range=(0,2))
        self.add(area)
        self.wait(1)

Secondary y-axis using seaborn

I am trying to plot a bar chart and a line chart as a single plot and am inclined to use seaborn for its nice formatting features. However, when I do df1.plot(kind='bar',...) followed by df1.plot(kind='line',..., secondary_y=True), I get an outcome similar to the ones below, i.e., no line chart but also no error.
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
# Sample dataframe.
df1 = pd.DataFrame({'date':pd.date_range(datetime(2020,1,1), periods=699).tolist(), 'amount':range(1,700), 'balance':np.cumsum(range(1,700))})
df1.loc[:, 'month'] = df1['date'].dt.to_period("M")
df1.loc[:, 'month_str'] = df1['date'].dt.year.astype(str) + '-' + df1['date'].dt.month.astype(str)
df1.loc[:, 'month_dt'] = pd.to_datetime(df1.month.dt.year*10000+df1.month.dt.month*100+1,format='%Y%m%d')
# Case-1: This doesn't work.
df2 = df1.groupby(['month']).agg({'amount':'sum','balance':'sum'})
sns.barplot(x='month', y='amount', data=df2.reset_index(), palette="Blues_d")
ax2 = plt.twinx()
sns.lineplot(x='month', y='balance', data=df2.reset_index(), color='red', markers=True, ax=ax2)
# Case-2: This doesn't work as intended: as the number of months grows, the tick labels are not thinned automatically, and the months are not sorted.
df2 = df1.groupby(['month_str']).agg({'amount':'sum','balance':'sum'})
sns.barplot(x='month_str', y='amount', data=df2.reset_index(), palette="Blues_d")
ax2 = plt.twinx()
sns.lineplot(x='month_str', y='balance', data=df2.reset_index(), color='red', markers=True, ax=ax2)
# Case-3: This doesn't work either.
df2 = df1.groupby(['month_dt']).agg({'amount':'sum','balance':'sum'})
sns.barplot(x='month_dt', y='amount', data=df2.reset_index(), palette="Blues_d")
ax2 = plt.twinx()
sns.lineplot(x='month_dt', y='balance', data=df2.reset_index(), color='red', markers=True, ax=ax2)
Case-1:
Traceback (most recent call last):
File "C:\...\lib\site-packages\seaborn\_decorators.py", line 46, in inner_f
return f(**kwargs)
File "C:\...\lib\site-packages\seaborn\relational.py", line 703, in lineplot
p.plot(ax, kwargs)
File "C:\...\lib\site-packages\seaborn\relational.py", line 529, in plot
line, = ax.plot(x, y, **kws)
File "C:\...\lib\site-packages\matplotlib\axes\_axes.py", line 1745, in plot
self.add_line(line)
File "C:\...\lib\site-packages\matplotlib\axes\_base.py", line 1964, in add_line
self._update_line_limits(line)
File "C:\...\lib\site-packages\matplotlib\axes\_base.py", line 1986, in _update_line_limits
path = line.get_path()
File "C:\...\lib\site-packages\matplotlib\lines.py", line 1011, in get_path
self.recache()
File "C:\...\lib\site-packages\matplotlib\lines.py", line 653, in recache
x = _to_unmasked_float_array(xconv).ravel()
File "C:\...\lib\site-packages\matplotlib\cbook\__init__.py", line 1289, in _to_unmasked_float_array
return np.asarray(x, float)
File "C:\...\lib\site-packages\numpy\core\_asarray.py", line 83, in asarray
return array(a, dtype, copy=False, order=order)
TypeError: float() argument must be a string or a number, not 'Period'
Case-2 and Case-3: [plots not shown]
You can build a dummy x axis and replace the values with the month data after constructing the plot:
fig, ax = plt.subplots()
df2 = df1.groupby(['month']).agg({'amount':'sum','balance':'sum'})
# helper axis
data = df2.reset_index()
data['xaxis'] = range(len(data))
sns.barplot(x='xaxis', y='amount', data=data, palette="Blues_d", ax=ax)
ax2 = ax.twinx()
sns.lineplot(x='xaxis', y='balance', data=data, color='red', markers=True, ax=ax2)
# replace helper axis with actual data
ax.set_xticklabels(data['month'].values, rotation = 45, ha="right")
You don't need to create a dummy x axis. Try the following code:
import seaborn as sns
import pandas as pd
import numpy as np
# Sample dataframe.
df1 = pd.DataFrame({'date':pd.date_range('2020-01-01', periods=699).tolist(), 'amount':range(1,700), 'balance':np.cumsum(range(1,700))})
df1.loc[:, 'month'] = df1['date'].dt.to_period("M")
df2 = df1.groupby(['month']).agg({'amount':'sum','balance':'sum'})
g = sns.barplot(x='month', y='amount', data=df2.reset_index(), palette="Blues_d")
g.set_xticklabels(g.get_xticklabels(), rotation=90)
sns.lineplot(x=range(len(df2.reset_index())), y='balance', data=df2.reset_index(), color='red', markers=True, ax=g.twinx())
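Both answers rely on the same point: the categorical bars sit at integer positions 0..n-1, so the line on the twin axis has to be drawn at those positions too, with the month labels applied afterwards. A minimal sketch of that idea with plain matplotlib (an assumption, chosen to keep the example independent of the seaborn version), using the sample data from the question:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Same sample data as in the question
dates = pd.date_range("2020-01-01", periods=699)
df1 = pd.DataFrame({"date": dates,
                    "amount": range(1, 700),
                    "balance": np.cumsum(range(1, 700))})
df1["month"] = df1["date"].dt.to_period("M")
df2 = df1.groupby("month").agg({"amount": "sum", "balance": "sum"})

# Shared integer positions for both the bars and the line
positions = np.arange(len(df2))

fig, ax = plt.subplots()
ax.bar(positions, df2["amount"].to_numpy(), color="steelblue")

ax2 = ax.twinx()  # secondary y-axis sharing the same x positions
ax2.plot(positions, df2["balance"].to_numpy(), color="red", marker="o")

# Only now replace the integer ticks with the month labels
ax.set_xticks(positions)
ax.set_xticklabels(df2.index.astype(str), rotation=90)
```

Plotting the Period values directly on the second axis is what raised the TypeError in Case-1; converting them to strings only for the tick labels avoids that conversion.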

Using matplotlib.pyplot to draw a histogram and a smooth curve that lies on the histogram

I have tried to draw a histogram using matplotlib and pandas, but drawing the smooth curve gave me an error. Can you please help me resolve this, and maybe suggest a method to draw a smooth curve on the histogram using matplotlib? I am trying not to use any other library (such as seaborn). Here is the code:
mu,sigma = 100,15
plt.style.use('dark_background')
x = mu + sigma * np.random.randn(10000)
n,bins,patches = plt.hist(x,bins=50,density=1,facecolor='g',alpha = 0.5)
zee=bins[:-1]
plt.plot(np.round(zee),patches,'ro')
plt.xlabel('Smarts')
plt.ylabel('Probablity')
plt.title('Histogram of the Iq')
plt.axis([40,160,0,0.03])
plt.grid(1)
plt.show()
the error shown is
python3 -u "/home/somesh/Downloads/vscode_code/python ml course /firstml.py"
Traceback (most recent call last):
File "/home/somesh/Downloads/vscode_code/python ml course /firstml.py", line 149, in <module>
plt.plot(np.round(zee),patches,'ro')
File "/home/somesh/.local/lib/python3.8/site-packages/matplotlib/pyplot.py", line 2840, in plot
return gca().plot(
File "/home/somesh/.local/lib/python3.8/site-packages/matplotlib/axes/_axes.py", line 1745, in plot
self.add_line(line)
File "/home/somesh/.local/lib/python3.8/site-packages/matplotlib/axes/_base.py", line 1964, in add_line
self._update_line_limits(line)
File "/home/somesh/.local/lib/python3.8/site-packages/matplotlib/axes/_base.py", line 1986, in _update_line_limits
path = line.get_path()
File "/home/somesh/.local/lib/python3.8/site-packages/matplotlib/lines.py", line 1011, in get_path
self.recache()
File "/home/somesh/.local/lib/python3.8/site-packages/matplotlib/lines.py", line 658, in recache
y = _to_unmasked_float_array(yconv).ravel()
File "/home/somesh/.local/lib/python3.8/site-packages/matplotlib/cbook/__init__.py", line 1289, in _to_unmasked_float_array
return np.asarray(x, float)
File "/home/somesh/.local/lib/python3.8/site-packages/numpy/core/_asarray.py", line 85, in asarray
return array(a, dtype, copy=False, order=order)
TypeError: float() argument must be a string or a number, not 'Rectangle'
Also, is it possible to draw the smooth curve using only the matplotlib library?
Edit 1: thanks for the answer, I was finally able to spot the error.
In your code, patches is a list of matplotlib Rectangle objects (the bars of the histogram), but the plot function needs numeric values as input.
Since what you are plotting is a normal distribution and you want the curve to be smooth, why not generate a normal distribution and plot it in the same figure? Here is a modified version of your code.
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
mu,sigma = 100,15
plt.style.use('dark_background')
x = mu + sigma * np.random.randn(10000)
n,bins,patches = plt.hist(x,bins=50,density=1,facecolor='g',alpha = 0.5)
# zee=bins[:-1]
# plt.plot(np.round(zee),patches,'ro')
x_overlay = np.linspace(mu - 3*sigma, mu + 3*sigma, 100)
plt.plot(x_overlay, stats.norm.pdf(x_overlay, mu, sigma),"ro")
plt.xlabel('Smarts')
plt.ylabel('Probability')
plt.title('Histogram of the Iq')
plt.axis([40,160,0,0.03])
plt.grid(1)
plt.show()
n has the same length as zee, which is len(bins) - 1, so you can plot n against zee:
mu,sigma = 100,15
plt.style.use('dark_background')
x = mu + sigma * np.random.randn(10000)
n,bins,patches = plt.hist(x,bins=50,density=1,facecolor='g',alpha = 0.5)
zee=bins[:-1]
## this
plt.plot(np.round(zee),n,'ro')
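Since the question asks for a solution without extra libraries, here is a sketch that combines the two answers using only numpy and matplotlib: the normal density is written out by hand instead of calling scipy.stats.norm.pdf. The seeded generator, the 201-point grid and the Agg backend are assumptions made for this illustration:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import numpy as np
import matplotlib.pyplot as plt

mu, sigma = 100, 15
rng = np.random.default_rng(0)  # seeded only to make the sketch reproducible
x = mu + sigma * rng.standard_normal(10000)

fig, ax = plt.subplots()
# n holds the bar heights (numbers); patches holds the Rectangle objects
n, bins, patches = ax.hist(x, bins=50, density=True, facecolor="g", alpha=0.5)

# Bar heights belong to the bin centres, one fewer than the number of edges
centers = 0.5 * (bins[:-1] + bins[1:])
ax.plot(centers, n, "ro", markersize=3)

# Smooth curve: the normal probability density evaluated on a fine grid
grid = np.linspace(mu - 4 * sigma, mu + 4 * sigma, 201)
pdf = np.exp(-0.5 * ((grid - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
ax.plot(grid, pdf, "r-")

ax.set_xlabel("Smarts")
ax.set_ylabel("Probability")
```

The red dots show the measured bar heights and the red line shows the theoretical density, so the two can be compared directly.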

subplotting different dataframes and using a dataframe as the x value

I have a DataFrame, read from a CSV file, that contains a Time column and 18 columns of samples taken at those times. The first thing I do is calculate the mean for each set of replicates and create three different DataFrames with the following code:
data = pd.read_csv('growht.csv', delimiter=',', header=0)
file:
# read the following data in with
data = pd.read_clipboard(sep=',')
Time,WT5,WT5,WT5,WT1,WT1,WT1,NF5,NF5,NF5,NF1,NF1,NF1,D5,D5,D5,D1,D1,D1
9.7e-05,0.113,0.11900000000000001,0.11699999999999999,0.081,0.086,0.076,0.102,0.111,0.111,0.086,0.087,0.084,0.1,0.105,0.106,0.085,0.087,0.086
0.041737,0.122,0.121,0.126,0.075,0.07400000000000001,0.07400000000000001,0.10400000000000001,0.105,0.10300000000000001,0.075,0.073,0.073,0.1,0.09699999999999999,0.09699999999999999,0.075,0.073,0.073
0.08340299999999999,0.161,0.163,0.174,0.076,0.075,0.075,0.126,0.129,0.13,0.076,0.07400000000000001,0.07400000000000001,0.12,0.11900000000000001,0.11900000000000001,0.076,0.07400000000000001,0.07400000000000001
0.12507200000000002,0.285,0.307,0.303,0.079,0.079,0.079,0.175,0.188,0.191,0.077,0.07400000000000001,0.075,0.165,0.17,0.172,0.079,0.077,0.077
0.166738,0.34600000000000003,0.368,0.369,0.09,0.091,0.091,0.273,0.28300000000000003,0.292,0.078,0.076,0.077,0.255,0.27,0.278,0.08800000000000001,0.085,0.085
0.208404,0.418,0.461,0.418,0.113,0.122,0.121,0.366,0.41200000000000003,0.38,0.08,0.078,0.079,0.368,0.376,0.382,0.113,0.10400000000000001,0.106
0.25007399999999996,0.48,0.513,0.508,0.18,0.2,0.196,0.418,0.42100000000000004,0.43,0.08800000000000001,0.087,0.08900000000000001,0.446,0.47700000000000004,0.475,0.17300000000000001,0.155,0.158
0.29173699999999997,0.551,0.589,0.5920000000000001,0.311,0.33399999999999996,0.336,0.46399999999999997,0.47600000000000003,0.47,0.10400000000000001,0.105,0.10800000000000001,0.5379999999999999,0.544,0.542,0.24,0.22699999999999998,0.22699999999999998
0.3334,0.612,0.603,0.617,0.436,0.48100000000000004,0.446,0.514,0.556,0.53,0.14,0.147,0.154,0.59,0.644,0.629,0.361,0.35100000000000003,0.341
0.375066,0.682,0.685,0.703,0.516,0.505,0.47600000000000003,0.5670000000000001,0.605,0.5760000000000001,0.215,0.247,0.259,0.6559999999999999,0.72,0.735,0.456,0.41200000000000003,0.409
0.416733,0.7340000000000001,0.741,0.755,0.735,0.624,0.605,0.609,0.614,0.588,0.335,0.355,0.365,0.708,0.746,0.7490000000000001,0.523,0.495,0.494
0.4584,0.763,0.799,0.8420000000000001,0.748,0.682,0.6659999999999999,0.653,0.6759999999999999,0.655,0.42200000000000004,0.442,0.45299999999999996,0.759,0.809,0.81,0.629,0.5870000000000001,0.59
0.500066,0.802,0.858,0.8740000000000001,0.831,0.767,0.757,0.6809999999999999,0.705,0.684,0.47100000000000003,0.47,0.47200000000000003,0.816,0.863,0.8690000000000001,0.645,0.632,0.645
0.541733,0.852,0.893,0.903,0.863,0.748,0.731,0.7170000000000001,0.741,0.722,0.562,0.579,0.5760000000000001,0.872,0.927,0.9279999999999999,0.7070000000000001,0.675,0.6729999999999999
0.583399,0.927,0.907,0.9840000000000001,0.889,0.773,0.742,0.74,0.763,0.741,0.614,0.66,0.64,0.914,0.975,0.975,0.7290000000000001,0.698,0.693
0.625066,0.9590000000000001,0.956,1.041,0.892,0.7829999999999999,0.746,0.762,0.78,0.767,0.647,0.711,0.693,0.95,1.02,1.016,0.76,0.745,0.742
0.666733,0.987,1.04,1.035,0.8909999999999999,0.7959999999999999,0.807,0.769,0.7959999999999999,0.7859999999999999,0.7,0.731,0.718,0.978,1.058,1.047,0.789,0.782,0.782
0.708399,1.042,1.056,1.032,0.848,0.802,0.833,0.777,0.81,0.7979999999999999,0.737,0.782,0.775,0.9790000000000001,1.083,1.075,0.807,0.818,0.8170000000000001
0.750067,1.062,1.0979999999999999,1.0590000000000002,0.8540000000000001,0.8590000000000001,0.8490000000000001,0.785,0.815,0.8079999999999999,0.7929999999999999,0.828,0.804,0.973,1.102,1.091,0.831,0.851,0.85
0.791732,1.0959999999999999,1.102,1.069,0.8590000000000001,0.941,0.889,0.7709999999999999,0.802,0.797,0.809,0.853,0.825,0.956,1.0979999999999999,1.0859999999999999,0.836,0.875,0.872
0.8334,1.125,1.133,1.1,0.8690000000000001,0.9790000000000001,0.932,0.757,0.795,0.7909999999999999,0.835,0.884,0.8440000000000001,0.945,1.103,1.085,0.843,0.8859999999999999,0.889
0.875065,1.133,1.166,1.121,0.89,0.9990000000000001,0.975,0.7440000000000001,0.7829999999999999,0.7809999999999999,0.843,0.898,0.855,0.938,1.097,1.074,0.836,0.8959999999999999,0.8959999999999999
0.916733,1.136,1.198,1.119,0.92,1.056,0.9540000000000001,0.727,0.777,0.773,0.853,0.905,0.858,0.917,1.088,1.07,0.8220000000000001,0.8959999999999999,0.898
0.9584,1.119,1.202,1.115,0.9179999999999999,1.071,1.026,0.7140000000000001,0.7609999999999999,0.76,0.851,0.907,0.8490000000000001,0.904,1.075,1.055,0.812,0.8859999999999999,0.8909999999999999
1.000065,1.167,1.199,1.099,0.9079999999999999,1.093,1.006,0.6970000000000001,0.748,0.7509999999999999,0.835,0.902,0.843,0.889,1.069,1.0490000000000002,0.8009999999999999,0.885,0.892
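data.columns = data.columns.str.replace(r'(\.\d+)$', '', regex=True)  # remove the .N suffix that pandas adds when columns share a name
# mean of the columns with the same name, computed along the row axis
# (the level argument was removed in pandas 2.0; there, data.T.groupby(level=0).mean().T is an equivalent, and likewise .std())
data_mean=data.mean(axis=1, level=0)
data_std=data.std(axis=1, level=0)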
data_time=data.filter(like='Time')
data_WT=data_mean.filter(like='WT')
data_NF=data_mean.filter(like='NF')
data_D=data_mean.filter(like='D')
With the code above I create new DataFrames that contain only the columns with specific titles. So I have three different DataFrames with 2 columns and 24 rows each, which I manage to plot in the same figure using the following code:
fig, axes = plt.subplots(nrows=1, ncols=3,squeeze=False,figsize=(10,5))
axes = axes.flatten()
data_WT.plot(ax=axes[0],yerr=data_std,fontsize=6,grid=True)
data_NF.plot(ax=axes[1],yerr=data_std,fontsize=6,grid=True)
data_D.plot(ax=axes[2],yerr=data_std,fontsize=6,grid=True)
The output looks like this:
Then I wanted to add a scatter point for every data point in each graph. For this I added the DataFrame data_time, which contains the time values, to use as the x value in the scatter plot. However, when I do this, for example for the first subplot:
data_WT.plot.scatter(ax=axes[0],x=data_time,y=data_WT)
I have the following error:
Traceback (most recent call last):
File "c:/Users/Nico/Desktop/bioscreen.py", line 60, in <module>
data_WT.plot.scatter(ax=axes[0],x=data_time,y=data_WT)
File "C:\Users\Nico\AppData\Local\Programs\Python\Python38\lib\site-packages\pandas\plotting\_core.py", line 1499, in scatter
return self(kind="scatter", x=x, y=y, s=s, c=c, **kwargs)
File "C:\Users\Nico\AppData\Local\Programs\Python\Python38\lib\site-packages\pandas\plotting\_core.py", line 792, in __call__
return plot_backend.plot(data, x=x, y=y, kind=kind, **kwargs)
File "C:\Users\Nico\AppData\Local\Programs\Python\Python38\lib\site-packages\pandas\plotting\_matplotlib\__init__.py", line 61, in plot
plot_obj.generate()
File "C:\Users\Nico\AppData\Local\Programs\Python\Python38\lib\site-packages\pandas\plotting\_matplotlib\core.py", line 263, in generate
self._make_plot()
File "C:\Users\Nico\AppData\Local\Programs\Python\Python38\lib\site-packages\pandas\plotting\_matplotlib\core.py", line 970, in _make_plot
data[x].values,
File "C:\Users\Nico\AppData\Local\Programs\Python\Python38\lib\site-packages\pandas\core\frame.py", line 2806, in __getitem__
indexer = self.loc._get_listlike_indexer(key, axis=1, raise_missing=True)[1]
File "C:\Users\Nico\AppData\Local\Programs\Python\Python38\lib\site-packages\pandas\core\indexing.py", line 1551, in _get_listlike_indexer
self._validate_read_indexer(
File "C:\Users\Nico\AppData\Local\Programs\Python\Python38\lib\site-packages\pandas\core\indexing.py", line 1639, in _validate_read_indexer
raise KeyError(f"None of [{key}] are in the [{axis_name}]")
KeyError: "None of [Float64Index([ 9.73e-05, 0.041736991, 0.083402986,\n 0.125072396, 0.166737708, 0.20840449100000003,\n 0.250073843, 0.29173736100000003, 0.333400081,\n 0.375066481, 0.41673263899999996, 0.458399595,\n 0.500066227,
0.541732743, 0.583399375,\n 0.625065949, 0.666732685, 0.7083994790000001,\n
0.75006728, 0.79173228, 0.833399606,\n 0.875064988, 0.916732766, 0.958400093,\n 1.000065417],\n dtype='float64')] are in the [columns]"
Any suggestions on how to overcome this error? I have been reading, but I can't find an answer that helps me with this.
Thank you.
The easiest way to resolve the issue is to set Time as the index.
seaborn.scatterplot also makes it easier to add the scatter points.
Seaborn is a Python data visualization library based on matplotlib. It provides a high-level interface for drawing attractive and informative statistical graphics.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# read the file in with Time as the index
data = pd.read_csv('growth.csv', delimiter=',', header=0, index_col='Time')
# change the column names
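data.columns = data.columns.str.replace(r'(\.\d+)$', '', regex=True)
# don't change any of this code, but data_time isn't needed
# (the level argument was removed in pandas 2.0; there, data.T.groupby(level=0).mean().T is an equivalent, and likewise .std())
data_mean=data.mean(axis=1, level=0)
data_std=data.std(axis=1, level=0)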
data_WT=data_mean.filter(like='WT')
data_NF=data_mean.filter(like='NF')
data_D=data_mean.filter(like='D')
# plot
fig, axes = plt.subplots(nrows=1, ncols=3, squeeze=False, figsize=(16, 8))
axes = axes.flatten()
data_WT.plot(ax=axes[0],yerr=data_std,fontsize=6,grid=True)
sns.scatterplot(data=data_WT, ax=axes[0])
data_NF.plot(ax=axes[1],yerr=data_std,fontsize=6,grid=True)
sns.scatterplot(data=data_NF, ax=axes[1])
data_D.plot(ax=axes[2],yerr=data_std,fontsize=6,grid=True)
sns.scatterplot(data=data_D, ax=axes[2])
If all you're trying to accomplish with the scatter plot is to add markers to the lines, you can use the marker parameter when making each plot:
fig, axes = plt.subplots(nrows=1, ncols=3, squeeze=False, figsize=(16, 8))
axes = axes.flatten()
data_WT.plot(ax=axes[0],yerr=data_std,fontsize=6,grid=True, marker='o')
data_NF.plot(ax=axes[1],yerr=data_std,fontsize=6,grid=True, marker='o')
data_D.plot(ax=axes[2],yerr=data_std,fontsize=6,grid=True, marker='o')
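The KeyError in the question comes from passing whole DataFrames as x and y, where DataFrame.plot.scatter expects column names. Below is a small sketch of the index-based approach on made-up numbers (the four-column CSV text is hypothetical, not the data from the question). It averages the replicates with a transpose and groupby, which is an assumption chosen because it also runs on pandas versions where the level argument of mean is no longer available:

```python
import io
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical miniature of the file: two replicates each of WT5 and WT1
csv_text = (
    "Time,WT5,WT5,WT1,WT1\n"
    "0.0,0.1,0.3,0.2,0.4\n"
    "1.0,0.5,0.7,0.6,0.8\n"
)
data = pd.read_csv(io.StringIO(csv_text), index_col="Time")

# pandas renames duplicate headers to WT5, WT5.1, ...; strip the suffix again
data.columns = data.columns.str.replace(r"\.\d+$", "", regex=True)

# Average the replicates that share a column name
data_mean = data.T.groupby(level=0).mean().T

fig, ax = plt.subplots()
# Lines with markers, using the Time index as x
data_mean.plot(ax=ax, marker="o", grid=True)
# Explicit scatter: x and y are arrays of values, not DataFrames
points = ax.scatter(data_mean.index, data_mean["WT5"], color="black", zorder=3)
```

With Time as the index, no separate data_time frame is needed, and the x values for any extra scatter call come straight from data_mean.index.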

Stop x-axis labels from shrinking the plot in Matplotlib?

I'm trying to make a bar graph with the following code:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
test = {'names':['a','b','abcdefghijklmnopqrstuvwxyz123456789012345678901234567890'], 'values':[1,2,3]}
df = pd.DataFrame(test)
plt.rcParams['figure.autolayout'] = False
ax = sns.barplot(x='names', y='values', data=df)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
plt.show()
But I get the following error because the long value in 'names' as a label on the x-axis is making the image shrink until the bottom is above the top.
Traceback (most recent call last):
File "C:/Users/Adam/.PyCharm2018.2/config/scratches/scratch.py", line 11, in <module>
plt.show()
File "C:\Anaconda3\lib\site-packages\matplotlib\pyplot.py", line 253, in show
return _show(*args, **kw)
File "C:\Program Files\JetBrains\PyCharm 2018.2.3\helpers\pycharm_matplotlib_backend\backend_interagg.py", line 25, in __call__
manager.show(**kwargs)
File "C:\Program Files\JetBrains\PyCharm 2018.2.3\helpers\pycharm_matplotlib_backend\backend_interagg.py", line 107, in show
self.canvas.show()
File "C:\Program Files\JetBrains\PyCharm 2018.2.3\helpers\pycharm_matplotlib_backend\backend_interagg.py", line 62, in show
self.figure.tight_layout()
File "C:\Anaconda3\lib\site-packages\matplotlib\figure.py", line 2276, in tight_layout
self.subplots_adjust(**kwargs)
File "C:\Anaconda3\lib\site-packages\matplotlib\figure.py", line 2088, in subplots_adjust
self.subplotpars.update(*args, **kwargs)
File "C:\Anaconda3\lib\site-packages\matplotlib\figure.py", line 245, in update
raise ValueError('bottom cannot be >= top')
ValueError: bottom cannot be >= top
Here is what it looks like if I reduce the length of that name slightly:
How can I get it to expand the figure to fit the label instead of shrinking the axes?
One workaround is to create the Axes instance yourself with fig.add_axes rather than as a subplot. Then tight_layout() has no effect, even if it's called internally. You can then pass the Axes to sns.barplot with the ax keyword. The problem now is that if you call plt.show() the label may be cut off, but if you call savefig with bbox_inches='tight', the saved image is extended to contain both the figure and all labels:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
test = {'names':['a','b','abcdefghijklmnopqrstuvwxyz123456789012345678901234567890'], 'values':[1,2,3]}
df = pd.DataFrame(test)
#plt.rcParams['figure.autolayout'] = False
ax = sns.barplot(x='names', y='values', data=df, ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
#plt.show()
fig.savefig('long_label.png', bbox_inches='tight')
Disclaimer: I don't have PyCharm, so this code assumes that matplotlib behaves the same with and without PyCharm. For me the outcome looks like this:
If you want this in an interactive backend, I didn't find any way other than manually adjusting the figure size. This is what I get using the qt5agg backend:
ax = sns.barplot(x='names', y='values', data=df)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax.figure.set_size_inches(5, 8) # manually adjust figure size
plt.tight_layout() # automatically adjust elements inside the figure
plt.show()
Note that PyCharm's scientific mode might be doing some magic that prevents this from working, so you might need to deactivate it or just run the script outside PyCharm.
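To check that savefig with bbox_inches='tight' really enlarges the output rather than shrinking the axes, the saved image size can be compared with the nominal figure size. This is a sketch under a few assumptions: plain matplotlib bars stand in for sns.barplot, the figure is 4 x 3 inches at 100 dpi, and the PNG is written to memory and its size read from the file header:

```python
import io
import struct
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

names = ["a", "b", "abcdefghijklmnopqrstuvwxyz123456789012345678901234567890"]
values = [1, 2, 3]

fig = plt.figure(figsize=(4, 3), dpi=100)  # nominal size: 400 x 300 pixels
ax = fig.add_axes([0, 0, 1, 1])            # axes fill the figure, so tight_layout has nothing to adjust
ax.bar(range(len(names)), values)
ax.set_xticks(range(len(names)))
ax.set_xticklabels(names, rotation=90)

# Save into memory; the tight bounding box grows to include the long label
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight")
png = buf.getvalue()

# Width and height are stored big-endian in bytes 16..23 of a PNG (IHDR chunk)
width_px, height_px = struct.unpack(">II", png[16:24])
print(width_px, height_px)
```

The printed height should exceed the nominal 300 pixels, because the saved canvas grows downwards to make room for the rotated label while the axes keep their size.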
