Skip to content

Commit

Permalink
Add example figures
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Milk committed May 7, 2024
1 parent c5b3b42 commit 70e98a9
Show file tree
Hide file tree
Showing 11 changed files with 1,257 additions and 4 deletions.
17 changes: 13 additions & 4 deletions scripts/benchmark/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,15 @@ def plot_ax(ax, data, xlabel):
palette = ["#5847AD", "#F7B449"]
kws = dict(orient="v", errorbar="sd", palette=palette, width=0.6)
sns.barplot(data=data, ax=ax, **kws)
data.melt().pipe((sns.scatterplot, "data"), y="value", x="variable", ax=ax,
zorder=100, color="grey", alpha=0.5)
data.melt().pipe(
(sns.scatterplot, "data"),
y="value",
x="variable",
ax=ax,
zorder=100,
color="grey",
alpha=0.5,
)
ax.get_xticklabels()[0].set_fontweight("bold")
ax.tick_params(axis="x", rotation=45)
ax.tick_params(bottom=False)
Expand Down Expand Up @@ -140,7 +147,9 @@ def plot_ax(ax, data, xlabel):

fk.install("Lato", verbose=False)
data = pd.DataFrame({"Marsilea": marsilea_tokens, "Matplotlib": matplotlib_tokens})
_, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(6, 5), gridspec_kw={"wspace": 0.5})
_, (ax1, ax2, ax3) = plt.subplots(
ncols=3, figsize=(6, 5), gridspec_kw={"wspace": 0.5}
)
# sns.set(style="whitegrid")
plot_ax(ax1, data, "Tokens")
# Plot the number of lines of code using seaborn
Expand All @@ -156,4 +165,4 @@ def plot_ax(ax, data, xlabel):
root = Path(__file__).parent
plt.savefig(root / "marsilea_vs_matplotlib.png", dpi=300, bbox_inches="tight")
plt.savefig(root / "marsilea_vs_matplotlib.svg", bbox_inches="tight")
plt.show()
plt.show()
164 changes: 164 additions & 0 deletions scripts/example_figures/arc-diagrams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""
Les Miserables Arc Diagram
===============================
This example shows how to create an arc diagram from a network.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import networkx as nx

import marsilea as ma
import marsilea.plotter as mp

# sphinx_gallery_start_ignore
import mpl_fontkit as fk

fk.install("Lato", verbose=False)
# sphinx_gallery_end_ignore

# %%
# Create Arc Diagram
# ------------------

data = pd.read_csv("data/PPI.csv")
G = nx.from_pandas_edgelist(data, "source", "target", create_using=nx.DiGraph)

protein_classification = {
"MAP Kinases": [
"MAP2K7",
"RPS6KA3",
"MAPK7",
"RPS6KB1",
"MAPKAPK5",
"RPS6KB2",
"JUN",
"RPS6KA1",
"MAPK10",
"MAPK1",
"MAPK14",
"MAPK3",
"MAP2K1",
"MAPK12",
"MAP2K4",
"MAPK9",
"MAPK8",
],
"Kinases": [
"CSNK2A1",
"GSK3B",
"CSNK1A1",
"PRKACA",
"SYK",
"JAK2",
"GSK3A",
"PRKCA",
"CDK1",
"ITK",
"ELK1",
"LYN",
"PRKCD",
"CDK4",
"PLK1",
"PAK1",
"CDK7",
"LCK",
"CDK5",
"MAPK8",
"GRK2",
"AURKB",
"PRKCQ",
"CDC25C",
"CHEK2",
"CDK7",
"TP53",
],
"Transcription Factors": ["NFKBIA", "MEF2A", "ESR1", "CREB1", "NFATC4"],
"Tyrosine Kinases": [
"HCLS1",
"FCGR2A",
"BTK",
"SYK",
"HCK",
"PTK2B",
"LCK",
"FYN",
"ZAP70",
],
"Adaptor Proteins": ["IRS1", "CBL", "SHC1", "GRB2"],
"Ubiquitin Ligases": ["MDM2", "CBL"],
"Cell Structure/Signaling": [
"STK11",
"FGFR1",
"CSK",
"ILK",
"CTTN",
"SNCA",
"KRT8",
],
"Other": ["HNRNPK"],
}


colormap = {
"MAP Kinases": "#E2DFD0",
"Kinases": "#CA8787",
"Transcription Factors": "#E65C19",
"Tyrosine Kinases": "#F8D082",
"Adaptor Proteins": "#0A6847",
"Ubiquitin Ligases": "#7ABA78",
"Cell Structure/Signaling": "#03AED2",
"Other": ".7",
}

# Reverse mapping to get labels for each protein
protein_labels = {}
for label, proteins in protein_classification.items():
for protein in proteins:
protein_labels[protein] = label

nodes = list(G.nodes)
nodes = pd.DataFrame(
{"nodes": nodes, "type": [protein_labels[n] for n in nodes]}
).sort_values("type")["nodes"]

degs = nx.degree(G)
degree_arr = np.array([[degs[n] for n in nodes]])
color_arr = np.array([[protein_labels[n] for n in nodes]])

edges = list(G.edges)
edges_colors = [colormap[protein_labels[a]] for a, _ in edges]

sources = set([a for a, _ in edges])
is_sources = np.array([["*" if n in sources else "" for n in nodes]])

wb = ma.SizedHeatmap(
degree_arr,
color_arr,
palette=colormap,
sizes=(10, 300),
frameon=False,
width=10.5,
height=0.3,
size_legend_kws={"func": lambda x: [int(i) for i in x], "title": "Count"},
color_legend_kws={"title": "Protein Type"},
)
wb.add_bottom(mp.Labels(nodes, align="bottom"))
wb.add_bottom(mp.Labels(is_sources, fontsize=16))
arc = mp.Arc(nodes, edges, colors=edges_colors, lw=1, alpha=0.5)
wb.add_top(arc, size=2)
wb.add_legends(stack_size=1)
wb.render()

# sphinx_gallery_start_ignore
if "__file__" in globals():
from pathlib import Path

save_path = Path(__file__).parent / "figures"
wb.save(save_path / "arc_diagram.svg")
else:
plt.show()
# sphinx_gallery_end_ignore
122 changes: 122 additions & 0 deletions scripts/example_figures/cooking_oils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Fat content in cooking oils
===========================
This example shows how to apply x-layout on statistical plots.
"""

import marsilea as ma
import marsilea.plotter as mp

import mpl_fontkit as fk

fk.install_fontawesome(verbose=False)
fk.install("Lato", verbose=False)

# sphinx_gallery_start_ignore
import matplotlib as mpl

mpl.rcParams["font.size"] = 12
# sphinx_gallery_end_ignore


# %%
# Load data
# ---------
oils = ma.load_data("cooking_oils")

red = "#cd442a"
yellow = "#f0bd00"
green = "#7e9437"
gray = "#eee"

mapper = {0: "\uf58a", 1: "\uf11a", 2: "\uf567"}
cmapper = {0: "#609966", 1: "#DC8449", 2: "#F16767"}
flavour = [mapper[i] for i in oils["flavour"].values]
flavour_colors = [cmapper[i] for i in oils["flavour"].values]
fat_content = oils[
["saturated", "polyunsaturated (omega 3 & 6)", "monounsaturated", "other fat"]
]

# %%
# Visualize the oil contents
# --------------------------

fat_stack_bar = mp.StackBar(
fat_content.T * 100,
colors=[red, yellow, green, gray],
width=0.8,
orient="h",
label="Fat Content (%)",
legend_kws={"ncol": 2, "fontsize": 10},
)
fmt = lambda x: f"{x:.1f}" if x > 0 else ""
trans_fat_bar = mp.Numbers(
oils["trans fat"] * 100,
fmt=fmt,
color="#3A98B9",
label="Trans Fat (%)",
)

flavour_emoji = mp.Labels(
flavour, fontfamily="Font Awesome 6 Free", text_props={"color": flavour_colors}
)

oil_names = mp.Labels(oils.index.str.capitalize())

fmt = lambda x: f"{int(x)}" if x > 0 else ""

omege_bar = ma.plotter.CenterBar(
(oils[["omega 3", "omega 6"]] * 100).astype(int),
names=["Omega 3 (%)", "Omega 6 (%)"],
colors=["#7DB9B6", "#F5E9CF"],
fmt=fmt,
show_value=True,
)
conditions_text = [
"Control",
">230 °C\nDeep-frying",
"200-229 °C\nStir-frying",
"150-199 °C\nLight saute",
"<150 °C\nDressings",
]
colors = ["#e5e7eb", "#c2410c", "#fb923c", "#fca5a5", "#fecaca"]
conditions = ma.plotter.Chunk(conditions_text, colors, rotation=0, padding=10)

cb = ma.ClusterBoard(fat_content.to_numpy(), height=10)
cb.add_layer(fat_stack_bar)
cb.add_left(trans_fat_bar, pad=0.2, name="trans fat")
cb.add_right(flavour_emoji)
cb.add_right(oil_names, pad=0.1)
cb.add_right(omege_bar, size=2, pad=0.2)

order = [
"Control",
">230 °C (Deep-frying)",
"200-229 °C (Stir-frying)",
"150-199 °C (Light saute)",
"<150 °C (Dressings)",
]
cb.hsplit(labels=oils["cooking conditions"], order=order)
cb.add_left(conditions, pad=0.1)
cb.add_dendrogram(
"left", add_meta=False, colors=colors, linewidth=1.5, size=0.5, pad=0.02
)
cb.add_title(top="Fat in Cooking Oils", fontsize=16)
cb.add_legends("bottom", pad=0.3)
cb.render()

axes = cb.get_ax("trans fat")
for ax in axes:
ax.set_xlim(4.2, 0)

# sphinx_gallery_start_ignore
if "__file__" in globals():
from pathlib import Path
import matplotlib.pyplot as plt

plt.rcParams["svg.fonttype"] = "none"
save_path = Path(__file__).parent / "imgs"
plt.savefig(save_path / "oil_well.svg", bbox_inches="tight")
# sphinx_gallery_end_ignore
Loading

0 comments on commit 70e98a9

Please sign in to comment.