How to download matplotlib graphs generated in a Streamlit app - python-3.x

Is there any way to allow users to download a graph made using matplotlib/seaborn in a Streamlit app? I managed to get a download button working to download a dataframe as a csv, but I can't figure out how to allow users to download a graph. I have provided a code snippet below of the graph code.
Thank you for your help!
# Build the figure on one explicit Axes. `ax` keeps pointing at that Axes for the
# whole snippet instead of being rebound to the return value of each pyplot call.
fig_RFC_scatter, ax = plt.subplots(1, 1, figsize=(5, 4))

# Two scatter layers on the same Axes; only the first layer contributes legend entries.
sns.scatterplot(data=RFC_results_df_subset, x=X_Element_RM, y=Y_Element_RM, hue="F_sil_remaining", edgecolor="k", legend="full", ax=ax)
sns.scatterplot(data=RFC_results_df_subset, x=X_Element_CM, y=Y_Element_CM, hue="F_sil_remaining", edgecolor="k", legend=None, marker="s", ax=ax)

# Axis labels, limits and scales (same order as before: limits first, then scales).
ax.set_xlabel(X_Element_RM)
ax.set_ylabel(Y_Element_RM)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xscale(x_scale)
ax.set_yscale(y_scale)

# Dashed grey reference lines at y = 1 and x = 1.
ax.axhline(y=1, color="grey", linewidth=0.5, linestyle="--")
ax.axvline(x=1, color="grey", linewidth=0.5, linestyle="--")

ax.legend(bbox_to_anchor=(1, 0.87), frameon=False, title="% Sil. Remaining")

# Render the figure in the Streamlit app.
st.write(fig_RFC_scatter)

There are two approaches: the first saves the image to a file and offers that file for download; the second saves the image to memory (no clutter on disk) and offers it for download.
First
"""
Ref: https://docs.streamlit.io/library/api-reference/widgets/st.download_button
"""
import io

import matplotlib.pyplot as plt
import seaborn as sns
import streamlit as st

X = [1, 2, 3, 4, 5, 6, 7, 8]
Y = [1500, 1550, 1600, 1640, 1680, 1700, 1760, 1800]

sns.scatterplot(x=X, y=Y)

# Save to file first or an image file has already existed.
fn = 'scatter.png'
plt.savefig(fn)

# The button is created inside the `with` block so the file handle is still
# open when it is handed to st.download_button.
with open(fn, "rb") as img:
    btn = st.download_button(
        label="Download image",
        data=img,
        file_name=fn,
        mime="image/png"
    )
Second
# Save to memory first.
# The PNG is rendered into an in-memory buffer, so nothing is written to disk;
# `fn` is only the file name suggested to the user for the download.
fn = 'scatter.png'
img = io.BytesIO()
plt.savefig(img, format='png')

btn = st.download_button(
    label="Download image",
    data=img,
    file_name=fn,
    mime="image/png"
)
Downloaded file output

Related

Add 3 or more legends to a seaborn clustermap

Using another answer, I'm wondering if it is possible to add 3 or more legends? Adapting the code from the author, I could add 4 row labels, but adding the legend is tricky. If I add more row_dendrogram and col_dendrogram, they simply do not show independently from the others.
import seaborn as sns
from matplotlib.pyplot import gcf
networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])
# Label 1
network_labels = networks.columns.get_level_values("network")
network_pal = sns.cubehelix_palette(network_labels.unique().size, light=.9, dark=.1, reverse=True, start=1, rot=-2)
network_lut = dict(zip(map(str, network_labels.unique()), network_pal))
network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)
# Label 2
node_labels = networks.columns.get_level_values("node")
node_pal = sns.cubehelix_palette(node_labels.unique().size)
node_lut = dict(zip(map(str, node_labels.unique()), node_pal))
node_colors = pd.Series(node_labels, index=networks.columns).map(node_lut)
# Label 3
lab3_labels = networks.columns.get_level_values("node")
lab3_pal = sns.color_palette("hls", lab3_labels.unique().size)
lab3_lut = dict(zip(map(str, lab3_labels.unique()), lab3_pal))
lab3_colors = pd.Series(lab3_labels, index=networks.columns, name='lab3').map(lab3_lut)
# Label 4
lab4_labels = networks.columns.get_level_values("node")
lab4_pal = sns.color_palette("husl", lab4_labels.unique().size)
lab4_lut = dict(zip(map(str, lab4_labels.unique()), lab4_pal))
lab4_colors = pd.Series(lab4_labels, index=networks.columns, name='lab4').map(lab4_lut)
network_node_colors = pd.DataFrame(network_colors).join(pd.DataFrame(node_colors)).join(pd.DataFrame(lab3_colors)).join(pd.DataFrame(lab4_colors))
g = sns.clustermap(networks.corr(),
row_cluster=False, col_cluster=False,
row_colors = network_node_colors,
col_colors = network_node_colors,
linewidths=0,
xticklabels=False, yticklabels=False,
center=0, cmap="vlag")
# add legends
for label in network_labels.unique():
g.ax_col_dendrogram.bar(0, 0, color=network_lut[label], label=label, linewidth=0);
l1 = g.ax_col_dendrogram.legend(title='Network', loc="center", ncol=5, bbox_to_anchor=(0.47, 0.89), bbox_transform=gcf().transFigure)
for label in node_labels.unique():
g.ax_row_dendrogram.bar(0, 0, color=node_lut[label], label=label, linewidth=0);
l2 = g.ax_row_dendrogram.legend(title='Node', loc="center", ncol=2, bbox_to_anchor=(0.8, 0.89), bbox_transform=gcf().transFigure)
#how to add other row dendrograms here without them overlapping with the existing ones?
plt.show()
I believe the problem here is that one cannot directly access the axes of the plot. The legend is based on the bar graph, the row which you add. I have found the following workaround which, to be honest, is not pretty, but it works. It follows the classic matplotlib problem of adding an artist to an Axes; you can read more about it in the following posts:
1, 2, 3 and in the docs.
So what I do is that I save the objects of the bar plot when I create them and then later form the legend out of them. The full code is below. But maybe I would recommend contacting the author and raising a question/issue there.
import matplotlib.pyplot as plt
import pandas as pd  # needed for the pd.Series / pd.DataFrame calls below
import seaborn as sns
from matplotlib.pyplot import gcf

networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])

# Label 1
network_labels = networks.columns.get_level_values("network")
network_pal = sns.cubehelix_palette(network_labels.unique().size, light=.9, dark=.1, reverse=True, start=1, rot=-2)
network_lut = dict(zip(map(str, network_labels.unique()), network_pal))
network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)

# Label 2
node_labels = networks.columns.get_level_values("node")
node_pal = sns.cubehelix_palette(node_labels.unique().size)
node_lut = dict(zip(map(str, node_labels.unique()), node_pal))
node_colors = pd.Series(node_labels, index=networks.columns).map(node_lut)

# Label 3
lab3_labels = networks.columns.get_level_values("node")
lab3_pal = sns.color_palette("hls", lab3_labels.unique().size)
lab3_lut = dict(zip(map(str, lab3_labels.unique()), lab3_pal))
lab3_colors = pd.Series(lab3_labels, index=networks.columns, name='lab3').map(lab3_lut)

# Label 4
lab4_labels = networks.columns.get_level_values("node")
lab4_pal = sns.color_palette("husl", lab4_labels.unique().size)
lab4_lut = dict(zip(map(str, lab4_labels.unique()), lab4_pal))
lab4_colors = pd.Series(lab4_labels, index=networks.columns, name='lab4').map(lab4_lut)

# One colour column per label set, used for both the row and the column strips.
network_node_colors = pd.DataFrame(network_colors).join(pd.DataFrame(node_colors)).join(pd.DataFrame(lab3_colors)).join(pd.DataFrame(lab4_colors))

g = sns.clustermap(networks.corr(),
                   row_cluster=False, col_cluster=False,
                   row_colors=network_node_colors,
                   col_colors=network_node_colors,
                   linewidths=0,
                   xticklabels=False, yticklabels=False,
                   center=0, cmap="vlag")

# add legends
# Zero-size bars are drawn only so that each label gets a coloured legend handle.
for label in network_labels.unique():
    g.ax_col_dendrogram.bar(0, 0, color=network_lut[label], label=label, linewidth=0)
l1 = g.ax_col_dendrogram.legend(title='Network', loc="center", ncol=5, bbox_to_anchor=(0.35, 0.89), bbox_transform=gcf().transFigure)

for label in node_labels.unique():
    g.ax_row_dendrogram.bar(0, 0, color=node_lut[label], label=label, linewidth=0)
l2 = g.ax_row_dendrogram.legend(title='Node', loc="center", ncol=2, bbox_to_anchor=(0.66, 0.89), bbox_transform=gcf().transFigure)

# Keep the bar objects for label set 3 so a legend can be built from them explicitly.
lab3_handles = []
for label in lab3_labels.unique():
    handle = g.ax_row_dendrogram.bar(0, 0, color=lab3_lut[label], label=label, linewidth=0)
    lab3_handles.append(handle)
# add the legend
legend3 = plt.legend(lab3_handles, lab3_labels.unique(), loc="center", title='lab3', bbox_to_anchor=(.78, 0.89), bbox_transform=gcf().transFigure)

# Same again for label set 4.
lab4_handles = []
for label in lab4_labels.unique():
    handle = g.ax_row_dendrogram.bar(0, 0, color=lab4_lut[label], label=label, linewidth=0)
    lab4_handles.append(handle)
# add the second legend
legend4 = plt.legend(lab4_handles, lab4_labels.unique(), loc="center", title='lab4', ncol=2, bbox_to_anchor=(.9, 0.89), bbox_transform=gcf().transFigure)

# The second plt.legend call takes over as the current Axes' legend, so the
# first one is added back as a plain artist to keep both visible.
plt.gca().add_artist(legend3)

How to fix 'Dropdown Menu Read' Error in Plotly Dash

I have tried to re-create the following example Towards Data Science Example shown on the web
I have written the following code which I modified to this:
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import pandas as pd
import plotly.graph_objs as go
# Step 1. Launch the application
app = dash.Dash()
# Step 2. Import the dataset
filepath = 'https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv'
st = pd.read_csv(filepath)
# range slider options
st['Date'] = pd.to_datetime(st.Date)
dates = ['2015-02-17', '2015-05-17', '2015-08-17', '2015-11-17',
'2016-02-17', '2016-05-17', '2016-08-17', '2016-11-17', '2017-02-17']
features = st.columns[1:-1]
opts = [{'label' : i, 'value' : i} for i in features]
# Step 3. Create a plotly figure
trace_1 = go.Scatter(x = st.Date, y = st['AAPL.High'],
name = 'AAPL HIGH',
line = dict(width = 2,
color = 'rgb(229, 151, 50)'))
layout = go.Layout(title = 'Time Series Plot',
hovermode = 'closest')
fig = go.Figure(data = [trace_1], layout = layout)
# Step 4. Create a Dash layout
app.layout = html.Div([
# a header and a paragraph
html.Div([
html.H1("This is my first dashboard"),
html.P("Dash is so interesting!!")
],
style = {'padding' : '50px' ,
'backgroundColor' : '#3aaab2'}),
# adding a plot
dcc.Graph(id = 'plot', figure = fig),
# dropdown
html.P([
html.Label("Choose a feature"),
dcc.Dropdown(
id='opt',
options=opts,
value=features[0],
multi=True
),
# range slider
html.P([
html.Label("Time Period"),
dcc.RangeSlider(id = 'slider',
marks = {i : dates[i] for i in range(0, 9)},
min = 0,
max = 8,
value = [1, 7])
], style = {'width' : '80%',
'fontSize' : '20px',
'padding-left' : '100px',
'display': 'inline-block'})
])
])
# Step 5. Add callback functions
@app.callback(Output('plot', 'figure'),
              [Input('opt', 'value'),
               Input('slider', 'value')])
def update_figure(input1, input2):
    """Rebuild the time-series figure from the dropdown and range-slider values.

    Args:
        input1: Value of the 'opt' dropdown. It is used directly as the key
            for ``st2[input1]`` and plotted as one y series, so a selection of
            several columns is not handled here.
        input2: Value of the 'slider' range slider; its two entries are used
            as indices into ``dates`` to bound the plotted period.

    Returns:
        A ``go.Figure`` holding the fixed 'AAPL.High' trace plus one trace for
        the dropdown selection, using the module-level ``layout``.
    """
    # filtering the data
    st2 = st[(st.Date > dates[input2[0]]) & (st.Date < dates[input2[1]])]
    # updating the plot
    trace_1 = go.Scatter(x = st2.Date, y = st2['AAPL.High'],
                         name = 'AAPL HIGH',
                         line = dict(width = 2,
                                     color = 'rgb(229, 151, 50)'))
    trace_2 = go.Scatter(x = st2.Date, y = st2[input1],
                         name = str(input1),
                         line = dict(width = 2,
                                     color = 'rgb(106, 181, 135)'))
    fig = go.Figure(data = [trace_1, trace_2], layout = layout)
    return fig
# Step 6. Add the server clause
if __name__ == '__main__':
app.run_server(debug = True)
When I change the feature input, it does not update the plot correctly and does not show the selected features in the plot.
Either there is something wrong with the callback function or the initialization of the graph with the second trace. But I cant figure out where the issue is.
Your callback only ever builds two scatter traces, and one of them is fixed to 'AAPL.High', so the dropdown can supply just one column at a time. You therefore need to set multi=False on the dropdown.
Valid plots are only generated for options like 'AAPL.LOW'; others, like dic, won't display a second trace. If you keep multi=True, the callback still works as long as only one option is selected. The moment you select two or more options the script fails, because it tries to plot the whole multi-column selection as a single y series here:
trace_2 = go.Scatter(x = st2.Date, y = st2[**MULTIINPUT**],
name = str(input1),
line = dict(width = 2,
color = 'rgb(106, 181, 135)'))
Only one column id is allowed to be passed at MULTIINPUT. If you want to introduce more traces please use a for loop.
Change the code to the following:
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import pandas as pd
import plotly.graph_objs as go
# Step 1. Launch the application
app = dash.Dash()
# Step 2. Import the dataset
filepath = 'https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv'
st = pd.read_csv(filepath)
# range slider options
st['Date'] = pd.to_datetime(st.Date)
dates = ['2015-02-17', '2015-05-17', '2015-08-17', '2015-11-17',
'2016-02-17', '2016-05-17', '2016-08-17', '2016-11-17', '2017-02-17']
features = st.columns
opts = [{'label' : i, 'value' : i} for i in features]
# Step 3. Create a plotly figure
trace_1 = go.Scatter(x = st.Date, y = st['AAPL.High'],
name = 'AAPL HIGH',
line = dict(width = 2,
color = 'rgb(229, 151, 50)'))
layout = go.Layout(title = 'Time Series Plot',
hovermode = 'closest')
fig = go.Figure(data = [trace_1], layout = layout)
# Step 4. Create a Dash layout
app.layout = html.Div([
# a header and a paragraph
html.Div([
html.H1("This is a Test Dashboard"),
html.P("Dash is great!!")
],
style = {'padding' : '50px' ,
'backgroundColor' : '#3aaab2'}),
# adding a plot
dcc.Graph(id = 'plot', figure = fig),
# dropdown
html.P([
html.Label("Choose a feature"),
dcc.Dropdown(
id='opt',
options=opts,
value=features[0],
multi=False
),
# range slider
html.P([
html.Label("Time Period"),
dcc.RangeSlider(id = 'slider',
marks = {i : dates[i] for i in range(0, 9)},
min = 0,
max = 8,
value = [1, 7])
], style = {'width' : '80%',
'fontSize' : '20px',
'padding-left' : '100px',
'display': 'inline-block'})
])
])
# Step 5. Add callback functions
@app.callback(Output('plot', 'figure'),
              [Input('opt', 'value'),
               Input('slider', 'value')])
def update_figure(input1, input2):
    """Rebuild the time-series figure from the dropdown and range-slider values.

    Args:
        input1: Value of the single-select 'opt' dropdown, used as one column
            name of ``st``. May be ``None`` if the selection is cleared
            (NOTE(review): confirm against the Dash version in use); in that
            case only the fixed trace is drawn.
        input2: Value of the 'slider' range slider; its two entries are used
            as indices into ``dates`` to bound the plotted period.

    Returns:
        A ``go.Figure`` holding the fixed 'AAPL.High' trace and, when a column
        is selected, a second trace for it, using the module-level ``layout``.
    """
    # filtering the data: keep rows strictly inside the selected slider window,
    # so that moving the slider actually changes the plot
    st2 = st[(st.Date > dates[input2[0]]) & (st.Date < dates[input2[1]])]
    # updating the plot
    trace_1 = go.Scatter(x = st2.Date, y = st2['AAPL.High'],
                         name = 'AAPL HIGH',
                         line = dict(width = 2,
                                     color = 'rgb(229, 151, 50)'))
    traces = [trace_1]
    # Without a selection there is no column to look up, so skip the second trace.
    if input1 is not None:
        trace_2 = go.Scatter(x = st2.Date, y = st2[input1],
                             name = str(input1),
                             line = dict(width = 2,
                                         color = 'rgb(106, 181, 135)'))
        traces.append(trace_2)
    fig = go.Figure(data = traces, layout = layout)
    return fig
# Step 6. Add the server clause
if __name__ == '__main__':
app.run_server(debug = True)
I hope this cleared things up and solved your issues. :)

How can I use the plotly dropdown menu feature to update the z value in my choropleth map?

I just want to create a menu on the plot where I'm able to change the z-value in data only. I tried looking at other examples on here: https://plot.ly/python/dropdowns/#restyle-dropdown but it was hard since the examples were not exactly similar to my plot.
import plotly
import plotly.plotly as py
import plotly.graph_objs as go
import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/2014_world_gdp_with_codes.csv')
data = [go.Choropleth(
locations = df['CODE'],
z = df['GDP (BILLIONS)'],
text = df['COUNTRY'],
colorscale = [
[0, "rgb(5, 10, 172)"],
[0.35, "rgb(40, 60, 190)"],
[0.5, "rgb(70, 100, 245)"],
[0.6, "rgb(90, 120, 245)"],
[0.7, "rgb(106, 137, 247)"],
[1, "rgb(220, 220, 220)"]
],
autocolorscale = False,
reversescale = True,
marker = go.choropleth.Marker(
line = go.choropleth.marker.Line(
color = 'rgb(180,180,180)',
width = 0.5
)),
colorbar = go.choropleth.ColorBar(
tickprefix = '$',
title = 'GDP<br>Billions US$'),
)]
layout = go.Layout(
title = go.layout.Title(
text = '2014 Global GDP'
),
geo = go.layout.Geo(
showframe = False,
showcoastlines = False,
projection = go.layout.geo.Projection(
type = 'equirectangular'
)
),
annotations = [go.layout.Annotation(
x = 0.55,
y = 0.1,
xref = 'paper',
yref = 'paper',
text = 'Source: <a href="https://www.cia.gov/library/publications/the-world-factbook/fields/2195.html">\
CIA World Factbook</a>',
showarrow = False
)]
)
fig = go.Figure(data = data, layout = layout)
py.iplot(fig, filename = 'd3-world-map')
It's been a while since this was asked, but I figured it was still worth answering. I can't speak to how this might have changed since it was asked in 2019, but this works today.
First, I'll provide the code I used to create the new z values and the dropdown menu, then I'll provide all of the code I used to create these graphs in one chunk (easier to cut and paste...and all that).
This is the data I used for the alternate data in the z field.
import plotly.graph_objects as go
import pandas as pd
import random
z2 = df['GDP (BILLIONS)'] * .667 + 12
random.seed(21)
random.shuffle(z2)
df['z2'] = z2 # example as another column in df
print(df.head()) # validate as expected
z3 = df['GDP (BILLIONS)'] * .2 + 1000
random.seed(231)
random.shuffle(z3) # example as a series outside of df
z4 = df['GDP (BILLIONS)']**(1/3) * df['GDP (BILLIONS)']**(1/2)
random.seed(23)
random.shuffle(z4)
z4 = z4.tolist() # example as a basic Python list
To add buttons to change z, you'll add updatemenus to your layout. Each dict() is a separate dropdown option. At a minimum, each button requires a method, a label, and args. These represent what is changing (method for data, layout, or both), what it's called in the dropdown (label), and the new information (the new z in this example).
args for changes to data (where the method is either restyle or update) can also include the trace the change applies to. So if you had a bar chart and a line graph together, you may have a button that only changes the bar graph.
Using the same structure you have:
updatemenus = [go.layout.Updatemenu(
x = 1, xanchor = 'right', y = 1.15, type = "dropdown",
pad = {'t': 5, 'r': 20, 'b': 5, 'l': 30}, # around all buttons (not indiv buttons)
buttons = list([
dict(
args = [{'z': [df['GDP (BILLIONS)']]}], # original data; nest data in []
label = 'Return to the Original z',
method = 'restyle' # restyle is for trace updates
),
dict(
args = [{'z': [df['z2']]}], # nest data in []
label = 'A different z',
method = 'restyle'
),
dict(
args = [{'z': [z3]}], # nest data in []
label = 'How about this z?',
method = 'restyle'
),
dict(
args = [{'z': [z4]}], # nest data in []
label = 'Last option for z',
method = 'restyle'
)])
)]
All code used to create this graph in one chunk (includes code shown above).
import plotly.graph_objs as go
import pandas as pd
import ssl
import random
# to collect data without an error
# WARNING: this swaps the default HTTPS context factory for the unverified one,
# so certificate checks are skipped for code that relies on that default for the
# rest of the process. Acceptable for a throwaway demo; do not copy into
# production code.
ssl._create_default_https_context = ssl._create_unverified_context
# data used in plot
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/2014_world_gdp_with_codes.csv')
# z values used in buttons
z2 = df['GDP (BILLIONS)'] * .667 + 12
random.seed(21)
random.shuffle(z2)
df['z2'] = z2 # example as another column in the data frame
print(df.head()) # validate as expected
z3 = df['GDP (BILLIONS)'] * .2 + 1000
random.seed(231)
random.shuffle(z3) # example as a series outside of the data frame
z4 = df['GDP (BILLIONS)']**(1/3) * df['GDP (BILLIONS)']**(1/2)
random.seed(23)
random.shuffle(z4)
z4 = z4.tolist() # example as a basic Python list
data = [go.Choropleth(
locations = df['CODE'], z = df['GDP (BILLIONS)'], text = df['COUNTRY'],
colorscale = [
[0, "rgb(5, 10, 172)"],
[0.35, "rgb(40, 60, 190)"],
[0.5, "rgb(70, 100, 245)"],
[0.6, "rgb(90, 120, 245)"],
[0.7, "rgb(106, 137, 247)"],
[1, "rgb(220, 220, 220)"]],
reversescale = True,
marker = go.choropleth.Marker(
line = go.choropleth.marker.Line(
color = 'rgb(180,180,180)', width = 0.5)),
colorbar = go.choropleth.ColorBar(
tickprefix = '$',
title = 'GDP<br>Billions US$',
len = .6) # I added this for aesthetics
)]
layout = go.Layout(
title = go.layout.Title(text = '2014 Global GDP'),
geo = go.layout.Geo(
showframe = False, showcoastlines = False,
projection = go.layout.geo.Projection(
type = 'equirectangular')
),
annotations = [go.layout.Annotation(
x = 0.55, y = 0.1, xref = 'paper', yref = 'paper',
text = 'Source: <a href="https://www.cia.gov/library/publications/the-world-factbook/fields/2195.html">\
CIA World Factbook</a>',
showarrow = False
)],
updatemenus = [go.layout.Updatemenu(
x = 1, xanchor = 'right', y = 1.15, type = "dropdown",
pad = {'t': 5, 'r': 20, 'b': 5, 'l': 30},
buttons = list([
dict(
args = [{'z': [df['GDP (BILLIONS)']]}], # original data; nest data in []
label = 'Return to the Original z',
method = 'restyle' # restyle is for trace updates only
),
dict(
args = [{'z': [df['z2']]}], # nest data in []
label = 'A different z',
method = 'restyle'
),
dict(
args = [{'z': [z3]}], # nest data in []
label = 'How about this z?',
method = 'restyle'
),
dict(
args = [{'z': [z4]}], # nest data in []
label = 'Last option for z',
method = 'restyle'
)])
)]
)
fig = go.Figure(data = data, layout = layout)
fig.show()

Matplotlib figure annotations outside of window

I am making a program that implements a matplotlib pie/donut chart into a tkinter window to illustrate some data, however, I have added "annotations" or labels from each wedge of the pie chart. Because of this the window that opens when I execute the code fits the chart itself, but the labels are cut off at the edges of the window. Specifically, it looks like this...
Note the top two arrows don't actually have text attached to the corresponding labels so the situation is actually worse than my screenshot depicts.
Even if I get rid of the code related to generating a tkinter GUI, and just try to execute code to generate a regular figure window the labels are initially cut-off. But, if I use the built in zoom-out functionality I can zoom out the make the labels fit.
I have tried to adjust the figsize here...
fig, ax = plt.subplots(figsize=(6, 4), subplot_kw=dict(aspect="equal"))
yet it makes no difference. Hopefully there is a solution, thanks...
Here is my full code if anyone needs...
import numpy as np
import matplotlib.pyplot as plt
player1_cards = {'Mustard', 'Plum', 'Revolver', 'Rope', 'Ballroom', 'Library'}
player2_cards = {'Scarlet', 'White', 'Candlestick'}
player3_cards = {'Green', 'Library', 'Kitchen', 'Conservatory'}
middle_cards = {'Peacock'}
unknown_cards = {'Lead Pipe', 'Wrench', 'Knife', 'Hall', 'Lounge', 'Dining Room', 'Study'}
player1_string = ', '.join(player1_cards)
player1_string = player1_string.replace(', ', '\n')
player2_string = ', '.join(player2_cards)
player2_string = player2_string.replace(', ', '\n')
player3_string = ', '.join(player3_cards)
player3_string = player3_string.replace(', ', '\n')
fig, ax = plt.subplots(figsize=(6, 4), subplot_kw=dict(aspect="equal"))
recipe = [player1_string, player2_string, player3_string, '', '']
data = [len(player1_cards), len(player2_cards), len(player3_cards), 1, 7]
cols = ['#339E5A', '#26823E', '#0C5D2E', '#98D6AE', '#5EC488']
wedges, texts = ax.pie(data, wedgeprops=dict(width=0.5), startangle=90, colors = cols)
# Outline every wedge in white so neighbouring slices are visually separated.
for w in wedges:
    w.set_linewidth(4)
    w.set_edgecolor('white')

bbox_props = dict(boxstyle="square,pad=0.3", fc="w", ec="white", lw=0.72)
kw = dict(xycoords='data', textcoords='data', arrowprops=dict(arrowstyle="-"), bbox=bbox_props, zorder=0, va="center")

# Attach one label per wedge, anchored at the wedge's mid-angle on the unit circle.
for i, p in enumerate(wedges):
    ang = (p.theta2 - p.theta1)/2. + p.theta1
    y = np.sin(np.deg2rad(ang))
    x = np.cos(np.deg2rad(ang))
    # Treat x == 0 as the right-hand side; a lookup keyed on np.sign(x) would
    # raise KeyError for a label sitting exactly on the vertical axis.
    side = 1 if x >= 0 else -1
    horizontalalignment = {-1: "right", 1: "left"}[side]
    connectionstyle = "angle,angleA=0,angleB={}".format(ang)
    kw["arrowprops"].update({"connectionstyle": connectionstyle})
    ax.annotate(recipe[i], xy=(x, y), xytext=(x + side*.5, y*1.5),
                horizontalalignment=horizontalalignment, family="Quicksand", **kw)
ax.set_title("Matplotlib bakery: A donut")
plt.show()
You would want to play around with the subplot parameters to make space for the text outside the axes.
fig.subplots_adjust(bottom=..., top=..., left=..., right=...)
E.g. in this case
fig.subplots_adjust(bottom=0.2, top=0.9)
seems to give a nice representation

Why can't I plot multiple scatter subplots in one figure for a data set from a DataFrame of Pandas (Python) in the way I plot subplots of histograms? [duplicate]

This question already has an answer here:
Matplotlib: create multiple subplot in one figure
(1 answer)
Closed 5 years ago.
I am learning how to play with matplotlib recently. However, some problems come up. I read in a non-standard data file named students.data with the following command.
student_dataset = pd.read_csv("students.data", index_col=0)
Here is how students.data looks like.
Then I plot a figure with four subplots of histograms in it with the following commands.
fig = plt.figure(0) #Use it to create subplots.
fig.subplots_adjust(hspace=0.5, wspace=0.5) #Adjust height-spacing to
#de-overlap titles and ticks
ax1 = fig.add_subplot(2, 2, 1)
my_series1 = student_dataset["G1"]
my_series1.plot.hist(alpha=0.5, color = "blue", histtype = "bar", bins = 30)
ax2 = fig.add_subplot(2, 2, 2)
my_series2 = student_dataset["G2"]
my_series2.plot.hist(alpha=1, color = "green", histtype = "step", bins = 20)
ax3 = fig.add_subplot(2, 2, 3)
my_series3 = student_dataset["G3"]
my_series3.plot.hist(alpha=0.5, color = "red", histtype = "stepfilled")
ax4 = fig.add_subplot(2, 2, 4)
my_series1.plot.hist(alpha=0.5, color = "blue")
my_series2.plot.hist(alpha=0.5, color = "green")
my_series3.plot.hist(alpha=0.5, color = "red")
And the result is exactly the stuff I want. However, as I try to do so for scatter subplots, they are separated in different figures. And I cannot figure out why. Here are the commands.
fig = plt.figure(2)
ax1 = fig.add_subplot(2, 2, 1)
student_dataset.plot.scatter(x = "freetime", y = "G1")
ax2 = fig.add_subplot(2, 2, 2)
student_dataset.plot.scatter(x = "freetime", y = "G2")
ax3 = fig.add_subplot(2, 2, 3)
student_dataset.plot.scatter(x = "freetime", y = "G3")
After searching for a day, I found a solution that almost fits my target. But still, why? Why is my original method not working?
Here are the new commands and the result.
# One figure with a 2x2 grid of Axes; each scatter is drawn directly on its own Axes.
fig, axes = plt.subplots(2, 2, figsize=(6, 6), sharex=False, sharey=False)
x = student_dataset["freetime"].values
# Three plots fill the grid row by row (the fourth Axes stays empty).
# Columns 25-27 are selected by position -- presumably G1..G3; verify against the file.
for i in range(3):
    axes[i//2, i%2].scatter(x, student_dataset.iloc[:, i + 25].values)
fig.tight_layout()
Sorry that I cannot put more images in this post to describe my question. Hope you can understand my point.
Thanks in advance.
You may choose to use option 2 of the linked question,
fig = plt.figure(2)
ax1 = fig.add_subplot(2, 2, 1)
student_dataset.plot.scatter(x = "freetime", y = "G1", ax=ax1)
ax2 = fig.add_subplot(2, 2, 2)
student_dataset.plot.scatter(x = "freetime", y = "G2", ax=ax2)
ax3 = fig.add_subplot(2, 2, 3)
student_dataset.plot.scatter(x = "freetime", y = "G3", ax=ax3)
If you don't specify ax, pandas will produce a new figure.
At the moment I don't have any good explanation for why plot.hist does not require the ax keyword; it probably has to do with it directly calling the plt.hist function instead of preprocessing the data first.

Resources