Chapter 44: Data Visualization
Data in a table tells you facts. Data in a chart tells you a story.
A bar chart reveals which product dominates. A scatter plot shows whether two variables are related. A histogram shows you the shape of your data in seconds. Visualization is how you turn numbers into understanding — and how you communicate that understanding to others.
This chapter covers matplotlib for full control and seaborn for beautiful statistical plots with almost no code. Both work directly with pandas DataFrames.
Setup
pip install matplotlib seaborn pandas numpy
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
How matplotlib Works
matplotlib uses a figure/axes model:
- Figure — the entire image (the canvas)
- Axes — one plot inside the figure (you can have multiple)
fig, ax = plt.subplots() # create figure and one axes
ax.plot([1, 2, 3], [4, 5, 6]) # first list holds the x values, second the y values
plt.show() # display the figure
Always use fig, ax = plt.subplots() — it gives you explicit control. The older plt.plot() shorthand works for quick scripts but gets confusing with multiple subplots.
Line Plot
import matplotlib.pyplot as plt
import numpy as np
# 100 evenly spaced x values from 0 to 2π
x = np.linspace(0, 2 * np.pi, 100)
fig, ax = plt.subplots(figsize=(8, 4))
# label= supplies the text that ax.legend() shows for each line
ax.plot(x, np.sin(x), label="sin(x)", color="steelblue", linewidth=2)
ax.plot(x, np.cos(x), label="cos(x)", color="coral", linewidth=2, linestyle="--")
ax.set_title("Sine and Cosine", fontsize=14)
ax.set_xlabel("x (radians)")
ax.set_ylabel("y")
ax.legend()
ax.grid(True, alpha=0.3) # alpha=0.3 keeps the gridlines faint
plt.tight_layout()
# Write the image to disk first, then display it
plt.savefig("line_plot.png", dpi=150)
plt.show()
Key options:
- figsize=(width, height) — size in inches
- color= — named color, hex, or RGB tuple
- linewidth=, linestyle= ("solid", "dashed", "--", "-.", ":")
- label= + ax.legend() — add a legend
- ax.grid(True) — gridlines
- plt.tight_layout() — prevent clipping
- plt.savefig("file.png", dpi=150) — save to file
Bar Chart
# Revenue per product category, drawn as a vertical bar chart with a value label above each bar.
categories = ["Electronics", "Clothing", "Food", "Books", "Sports"]
revenue = [45000, 32000, 28000, 15000, 19000]
fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(categories, revenue, color="steelblue", edgecolor="white", width=0.6)
# Add value labels on top of bars
for bar in bars:
    height = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2, # horizontal centre of the bar
        height + 500, # small gap above the bar top, in data units
        f"${height:,.0f}", # thousands separator, no decimals even if the data are floats
        ha="center", va="bottom", fontsize=9
    )
ax.set_title("Revenue by Category", fontsize=14)
ax.set_xlabel("Category")
ax.set_ylabel("Revenue ($)")
ax.set_ylim(0, max(revenue) * 1.15) # 15% headroom above the tallest bar for its label
plt.tight_layout()
plt.show()
Horizontal bar chart
# Reuses categories and revenue from the bar chart example
fig, ax = plt.subplots(figsize=(8, 5))
ax.barh(categories, revenue, color="steelblue")
ax.set_xlabel("Revenue ($)")
ax.invert_yaxis() # first category at the top (here that is also the largest bar)
plt.tight_layout()
plt.show()
Histogram
np.random.seed(42) # fixed seed so the random sample is the same on every run
# 500 draws from a normal distribution with mean 70 and standard deviation 12
data = np.random.normal(loc=70, scale=12, size=500)
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(data, bins=30, color="steelblue", edgecolor="white", alpha=0.8)
# Add vertical lines for mean and median
ax.axvline(data.mean(), color="red", linestyle="--", label=f"Mean {data.mean():.1f}")
ax.axvline(np.median(data), color="green", linestyle="--", label=f"Median {np.median(data):.1f}")
ax.set_title("Distribution of Exam Scores")
ax.set_xlabel("Score")
ax.set_ylabel("Count")
ax.legend() # shows the Mean / Median labels defined above
plt.tight_layout()
plt.show()
Scatter Plot
np.random.seed(0) # fixed seed so the synthetic data are reproducible
# Synthetic data: 80 students, score = 50 + 4 * hours + normally distributed noise (sd 5)
hours_studied = np.random.uniform(1, 10, 80)
exam_score = 50 + hours_studied * 4 + np.random.normal(0, 5, 80)
fig, ax = plt.subplots(figsize=(7, 5))
scatter = ax.scatter(
hours_studied,
exam_score,
c=exam_score, # colour by score value
cmap="RdYlGn", # red=low, yellow=mid, green=high
alpha=0.7,
edgecolors="white",
linewidths=0.5,
s=60, # marker size
)
# Trend line
# Degree-1 polyfit returns the slope (m) and intercept (b) of the least-squares line
m, b = np.polyfit(hours_studied, exam_score, 1)
x_line = np.linspace(hours_studied.min(), hours_studied.max(), 100)
ax.plot(x_line, m * x_line + b, color="navy", linewidth=2, label="Trend")
# The colour bar is the key for the c= / cmap= colour encoding of the points
plt.colorbar(scatter, ax=ax, label="Exam Score")
ax.set_title("Hours Studied vs Exam Score")
ax.set_xlabel("Hours Studied")
ax.set_ylabel("Exam Score")
ax.legend()
plt.tight_layout()
plt.show()
Pie / Donut Chart
labels = ["Electronics", "Clothing", "Food", "Books", "Sports"]
sizes = [35, 25, 20, 10, 10]
explode = [0.05, 0, 0, 0, 0] # pull out the first slice
fig, ax = plt.subplots(figsize=(7, 6))
# With autopct set, pie() returns three lists: wedges, label texts, percentage texts
wedges, texts, autotexts = ax.pie(
sizes,
labels=labels,
explode=explode,
autopct="%1.1f%%", # percentage with one decimal place, e.g. "35.0%"
startangle=140,
colors=["#4C72B0", "#DD8452", "#55A868", "#C44E52", "#8172B2"],
)
# Donut variant — add a white circle in the centre
centre_circle = plt.Circle((0, 0), 0.65, fc="white")
ax.add_artist(centre_circle)
ax.set_title("Revenue Share by Category")
plt.tight_layout()
plt.show()
Box Plot
Box plots show distribution, median, quartiles, and outliers at a glance.
np.random.seed(7) # fixed seed so the synthetic scores are reproducible
# 100 synthetic scores per subject, each drawn from normal(mean, sd)
data = {
    "Math": np.random.normal(72, 10, 100),
    "Science": np.random.normal(68, 15, 100),
    "English": np.random.normal(75, 8, 100),
    "History": np.random.normal(65, 12, 100),
}
fig, ax = plt.subplots(figsize=(8, 5))
ax.boxplot(
    list(data.values()), # one box per subject, in dict insertion order
    patch_artist=True, # fill boxes with colour
    medianprops=dict(color="black", linewidth=2),
)
# Name the boxes through the axis instead of boxplot(labels=...): that keyword
# was renamed to tick_labels in Matplotlib 3.9, so this form runs without a
# deprecation warning on both older and newer versions.
ax.set_xticklabels(list(data.keys()))
ax.set_title("Score Distribution by Subject")
ax.set_ylabel("Score")
ax.grid(True, axis="y", alpha=0.3) # horizontal gridlines only
plt.tight_layout()
plt.show()
Multiple Subplots
# axes is a 2x2 array of Axes; index it as axes[row, col]
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
fig.suptitle("Sales Dashboard", fontsize=16, fontweight="bold")
months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun"]
# NOTE: rebinds revenue (and categories further down) from the earlier bar chart example
revenue = [42000, 48000, 45000, 53000, 61000, 58000]
# Top-left: line chart
# "o-" = circle markers joined by a solid line
axes[0, 0].plot(months, revenue, "o-", color="steelblue", linewidth=2)
axes[0, 0].set_title("Monthly Revenue")
axes[0, 0].set_ylabel("Revenue ($)")
axes[0, 0].grid(True, alpha=0.3)
# Top-right: bar chart
axes[0, 1].bar(months, revenue, color="coral")
axes[0, 1].set_title("Revenue by Month")
axes[0, 1].set_ylabel("Revenue ($)")
# Bottom-left: histogram
np.random.seed(1)
# 200 synthetic order values drawn from normal(mean=250, sd=60)
orders = np.random.normal(250, 60, 200)
axes[1, 0].hist(orders, bins=25, color="mediumseagreen", edgecolor="white")
axes[1, 0].set_title("Order Value Distribution")
axes[1, 0].set_xlabel("Order Value ($)")
axes[1, 0].set_ylabel("Frequency")
# Bottom-right: pie chart
categories = ["Online", "In-store", "Mobile"]
sales = [55, 30, 15]
axes[1, 1].pie(sales, labels=categories, autopct="%1.0f%%", startangle=90)
axes[1, 1].set_title("Sales by Channel")
plt.tight_layout()
# Write the image to disk first, then display it
plt.savefig("dashboard.png", dpi=150, bbox_inches="tight")
plt.show()
Seaborn — Statistical Plots Made Easy
seaborn is a high-level wrapper around matplotlib. It produces better-looking charts with less code and works natively with pandas DataFrames.
import seaborn as sns
import pandas as pd
sns.set_theme(style="whitegrid") # clean background with gridlines
# Other styles: "darkgrid", "white", "dark", "ticks"
Distribution plots
# Sample dataset; load_dataset() downloads it from seaborn's online data
# repository (internet access needed on first use) and returns a DataFrame
tips = sns.load_dataset("tips")
# Histogram + KDE
fig, ax = plt.subplots(figsize=(7, 4))
sns.histplot(data=tips, x="total_bill", kde=True, ax=ax) # ax= draws onto the axes created above
ax.set_title("Distribution of Total Bill")
plt.tight_layout()
plt.show()
# KDE only
# No ax= given, so seaborn draws on the current axes; hue="sex" gives one filled curve per group
sns.kdeplot(data=tips, x="total_bill", hue="sex", fill=True)
plt.show()
Categorical plots
# Box plot
# hue="sex" draws side-by-side boxes per day, one for each value of the sex column
sns.boxplot(data=tips, x="day", y="total_bill", hue="sex", palette="Set2")
plt.title("Total Bill by Day and Sex")
plt.show()
# Violin plot — like boxplot but shows the full distribution
# split=True draws the two hue groups as the two halves of a single violin
sns.violinplot(data=tips, x="day", y="total_bill", hue="sex",
split=True, palette="Set2")
plt.show()
# Bar plot with confidence intervals
# Bar height is the mean total_bill per group; the error bar is a confidence interval for that mean
sns.barplot(data=tips, x="day", y="total_bill", hue="sex",
estimator="mean", errorbar="ci")
plt.show()
# Strip plot — individual points
# dodge=True separates the hue groups; jitter=True spreads points sideways to reduce overlap
sns.stripplot(data=tips, x="day", y="tip", hue="sex",
dodge=True, alpha=0.5, jitter=True)
plt.show()
Scatter and relationship plots
# Scatter with regression line
# scatter_kws is forwarded to the scatter layer; here it makes the points semi-transparent
sns.regplot(data=tips, x="total_bill", y="tip", scatter_kws={"alpha": 0.5})
plt.title("Tip vs Total Bill")
plt.show()
# Scatter with colour and size encoding
# size="size" maps the DataFrame column named "size" to marker size; sizes= sets the (min, max) marker size
sns.scatterplot(data=tips, x="total_bill", y="tip",
hue="day", size="size", sizes=(20, 200))
plt.show()
# Pair plot — all variables against each other
iris = sns.load_dataset("iris")
# diag_kind="kde" puts a density curve (instead of a histogram) on the diagonal
sns.pairplot(iris, hue="species", diag_kind="kde")
plt.show()
Heatmaps
# Month-by-year table of passenger counts
flights = sns.load_dataset("flights")
# pivot() only reshapes (one row per month/year pair), so the integer passenger
# counts keep their dtype and fmt="d" below stays valid. pivot_table() would
# aggregate with a mean, which can produce floats that fmt="d" rejects.
pivot = flights.pivot(index="month", columns="year", values="passengers")
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(
    pivot,
    annot=True, # show numbers in cells
    fmt="d", # format as integer
    cmap="YlOrRd", # colour map: yellow -> orange -> red
    linewidths=0.5,
    ax=ax,
)
ax.set_title("Passengers per Month and Year")
plt.tight_layout()
plt.show()
# Correlation heatmap
# Keep only the numeric columns of tips, then compute their pairwise correlations
corr = tips.select_dtypes("number").corr()
# vmin/vmax/center pin the colour scale to the full -1..1 range with 0 in the middle
sns.heatmap(corr, annot=True, cmap="coolwarm", vmin=-1, vmax=1, center=0)
plt.title("Feature Correlations")
plt.show()
Plotting Directly from pandas
pandas has a .plot() method that wraps matplotlib:
import pandas as pd
df = pd.DataFrame({
"month": ["Jan", "Feb", "Mar", "Apr", "May"],
"online": [20000, 24000, 22000, 28000, 31000],
"instore": [15000, 16000, 14000, 17000, 18000],
})
# Make month the index: .plot() uses the index for the x axis and draws each remaining column as a series
df.set_index("month", inplace=True)
# Grouped bar chart
df.plot(kind="bar", figsize=(8, 5), color=["steelblue", "coral"])
plt.title("Sales by Channel")
plt.ylabel("Revenue ($)")
plt.xticks(rotation=0) # keep the month labels horizontal
plt.tight_layout()
plt.show()
# Line chart
df.plot(kind="line", figsize=(8, 5), marker="o")
plt.title("Monthly Sales Trend")
plt.show()
# Stacked area chart
df.plot(kind="area", figsize=(8, 5), alpha=0.6, stacked=True)
plt.title("Stacked Revenue by Channel")
plt.show()
Saving High-Quality Figures
plt.savefig("chart.png", dpi=300, bbox_inches="tight") # raster image; dpi=300 is print quality
plt.savefig("chart.pdf", bbox_inches="tight") # for print (vector)
plt.savefig("chart.svg", bbox_inches="tight") # for web (scalable)
bbox_inches="tight" prevents labels from being cropped. dpi=300 is print-quality.
Project: Complete Sales Dashboard
"""
dashboard.py — A five-panel sales dashboard on a 2x3 grid, using matplotlib + seaborn.

Generates random sample data, draws the panels, saves the figure to
sales_dashboard.png and then displays it.
"""
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import pandas as pd
import numpy as np
sns.set_theme(style="whitegrid", palette="muted")
np.random.seed(42) # fixed seed so the sample data are the same on every run
# ── Generate sample data ──────────────────────────────────────────────────────
months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
# One row per month
monthly = pd.DataFrame({
"month": months,
# normal(mean 50000, sd 8000), floored at 30000 by clip()
"revenue": np.random.normal(50000, 8000, 12).clip(30000),
"orders": np.random.randint(200, 600, 12),
"returns": np.random.randint(5, 30, 12),
})
categories = pd.DataFrame({
"category": ["Electronics", "Clothing", "Food", "Books", "Sports"],
"revenue": [45000, 32000, 28000, 15000, 19000],
})
# 500 log-normally distributed order values
order_values = np.random.lognormal(mean=5.2, sigma=0.5, size=500)
# ── Create dashboard ──────────────────────────────────────────────────────────
fig = plt.figure(figsize=(16, 10))
fig.suptitle("Annual Sales Dashboard", fontsize=18, fontweight="bold", y=0.98)
# 2 rows x 3 columns; hspace/wspace set the gaps between panels
gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.4, wspace=0.35)
# 1 — Revenue trend (top-left, spans 2 columns)
ax1 = fig.add_subplot(gs[0, :2])
# Revenue is divided by 1000 to match the "(£000s)" axis label
ax1.plot(months, monthly["revenue"] / 1000, "o-",
color="#2196F3", linewidth=2.5, markersize=6)
# Shade the area under the line; range(12) supplies x positions 0..11, which
# is where matplotlib places the 12 month-name categories
ax1.fill_between(range(12), monthly["revenue"] / 1000,
alpha=0.15, color="#2196F3")
ax1.set_title("Monthly Revenue (£000s)")
ax1.set_ylabel("Revenue (£000s)")
ax1.tick_params(axis="x", rotation=45)
ax1.grid(True, alpha=0.4)
# 2 — Category breakdown (top-right)
ax2 = fig.add_subplot(gs[0, 2])
colors = ["#2196F3", "#FF9800", "#4CAF50", "#9C27B0", "#F44336"] # one colour per category
ax2.barh(categories["category"], categories["revenue"] / 1000,
color=colors, edgecolor="white")
ax2.set_title("Revenue by Category (£000s)")
ax2.set_xlabel("Revenue (£000s)")
ax2.invert_yaxis() # first category at the top
# 3 — Orders per month (bottom-left)
ax3 = fig.add_subplot(gs[1, 0])
ax3.bar(months, monthly["orders"], color="#4CAF50", edgecolor="white")
ax3.set_title("Orders per Month")
ax3.set_ylabel("Orders")
ax3.tick_params(axis="x", rotation=90)
# 4 — Order value distribution (bottom-middle)
ax4 = fig.add_subplot(gs[1, 1])
ax4.hist(order_values, bins=35, color="#FF9800", edgecolor="white", alpha=0.85)
# Dashed vertical line at the median order value
ax4.axvline(np.median(order_values), color="red", linestyle="--",
label=f"Median £{np.median(order_values):.0f}")
ax4.set_title("Order Value Distribution")
ax4.set_xlabel("Order Value (£)")
ax4.set_ylabel("Frequency")
ax4.legend(fontsize=8)
# 5 — Returns heatmap (bottom-right)
ax5 = fig.add_subplot(gs[1, 2])
# Reshape the 12 monthly return counts into 4 rows (quarters) x 3 columns (month within quarter)
quarters = np.array(monthly["returns"]).reshape(4, 3)
sns.heatmap(quarters,
annot=True, fmt="d", cmap="Reds",
xticklabels=["M1", "M2", "M3"],
yticklabels=["Q1", "Q2", "Q3", "Q4"],
ax=ax5, cbar=False)
ax5.set_title("Returns by Quarter/Month")
# Write the image to disk first, then display it
plt.savefig("sales_dashboard.png", dpi=150, bbox_inches="tight")
print("Dashboard saved to sales_dashboard.png")
plt.show()
What You Learned in This Chapter
- matplotlib uses a Figure (canvas) and Axes (individual plot) model. fig, ax = plt.subplots() gives you both.
- ax.plot() -> line chart, ax.bar() -> bar chart, ax.hist() -> histogram, ax.scatter() -> scatter, ax.boxplot() -> box plot, ax.pie() -> pie chart.
- Label everything: ax.set_title(), ax.set_xlabel(), ax.set_ylabel(), ax.legend().
- plt.subplots(rows, cols) creates a grid of charts. Access each with axes[row, col].
- plt.savefig("file.png", dpi=300, bbox_inches="tight") saves the figure; the file extension selects PNG, PDF, or SVG.
- seaborn wraps matplotlib with better defaults. sns.set_theme() styles all charts.
- seaborn accepts DataFrames directly: sns.boxplot(data=df, x="col", y="col2").
- sns.histplot, sns.kdeplot, sns.boxplot, sns.violinplot, sns.barplot, sns.scatterplot, sns.regplot, sns.pairplot, sns.heatmap cover almost every use case.
- pandas .plot() wraps matplotlib for quick charts directly from a DataFrame.
What's Next?
Chapter 45 covers Machine Learning with scikit-learn — using the DataFrames and visualisations you've built to train models that predict numbers, classify data, and find patterns automatically.