# Packages
import shutil
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
# Render text with LaTeX only when a LaTeX toolchain is available. The JPEG is
# rasterized by Matplotlib's Agg renderer, which in usetex mode needs both the
# `latex` and `dvipng` executables; without them text rendering raises an
# error, so fall back to Matplotlib's built-in text engine instead.
# NOTE(review): while usetex is on, LaTeX typesets the text, so the named font
# families below may not take effect — confirm this is the intended look.
_HAVE_LATEX = all(shutil.which(tool) for tool in ('latex', 'dvipng'))
plt.rcParams.update({'text.usetex': _HAVE_LATEX})
# Figure font config.
# Text-property dicts (keys `fontfamily`/`fontsize`), passed as `fontdict=`
# to axis labels and titles.
label_font = {'fontfamily': 'Arial Black', 'fontsize': 14}
title_font = {'fontfamily': 'Arial Black', 'fontsize': 16}
# FontProperties dicts (keys `family`/`size`), passed as `prop=` to legends.
legend_font = {'family': 'Palatino Linotype', 'size': 12}
# Same FontProperties style; not referenced by the plotting code in this script.
text_font = {'family': 'Palatino Linotype', 'size': 12}
# Generate data.
# A dense grid over the whole domain, plus the two halves on which the
# piecewise Model 3 is evaluated (both halves include the point x = 5).
x = np.linspace(0, 10, 100)
x11 = np.linspace(0, 5, 50)
x12 = np.linspace(5, 10, 50)


def _cubic(t, scale, r1, r2):
    """Return ``scale * t * (t - r1) * (t - r2)``, multiplied left to right."""
    return scale * t * (t - r1) * (t - r2)


# Line-plot curves: cubics in t with roots at 0, r1 and r2.
y1 = _cubic(x, 1, 5, 10)
y2 = _cubic(x, 0.25, 2.5, 10)
y3 = _cubic(x11, 0.25, 7.5, 10)
y4 = _cubic(x12, 0.25, 7.5, 10)

# Scatter-node positions: five hand-picked x values and ten evenly spaced ones.
x2 = [0, 2.5, 5, 7.5, 10]
x3 = np.linspace(0, 10, 10)
# Build the figure (model curves, shaded areas, node markers, a figure-level
# legend, labels, ticks and grid) and save it as a 300-dpi JPEG.
fig, ax = plt.subplots(1, figsize=(6, 3))
# Model curves; Model 3 is drawn as two dashed pieces that meet at x = 5.
ax.plot(x, y1, color=cm.Set2(0), label='Model 1', linestyle='-', linewidth=2)
ax.plot(x, y2, color=cm.Set2(1), label='Model 2', linestyle='-.', linewidth=2)
ax.plot(x11, y3, color=cm.Set2(2), label='Model 3(1)', linestyle='--', linewidth=2)
ax.plot(x12, y4, color=cm.Set2(3), label='Model 3(2)', linestyle='--', linewidth=2)
# Dashed marker at the Model 3 breakpoint x = 5. axvline spans the full axes
# height, so it no longer duplicates the y-limits set below.
ax.axvline(5, color='gray', linestyle='--', linewidth=1)
# Fill between each of Models 1 and 2 and the line y = 0.
ax.fill_between(x, y1, 0, color=cm.Set2(0), alpha=0.1)
ax.fill_between(x, y2, 0, color=cm.Set2(1), alpha=0.1)
# Scatter nodes, all placed on y = 0.
ax.scatter(x2, [0] * len(x2), label='Node 1', color=cm.Set1(1), marker='o', s=30)
ax.scatter(x3, [0] * len(x3), label='Node 2', color='gray', marker='x', s=50)
# Axis range. These are fixed before the tick labels are styled, so the
# styling is applied to the labels of the final ticks.
ax.set_xlim(0, 10)
ax.set_ylim(-50, 50)
# Legend: one figure-level legend collecting every labelled artist on `ax`,
# anchored just right of the axes. bbox_inches='tight' on save keeps it
# inside the output image.
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='center left', ncol=1,
           bbox_to_anchor=(0.98, 0.5), prop=legend_font)
# Label and title
ax.set_xlabel('X Label', fontdict=label_font)
ax.set_ylabel('Y Label', fontdict=label_font)
ax.set_title('Single Line Plot 3', fontdict=title_font)
# Tick font size and font family. A plain loop is used because the calls are
# made only for their side effect; the distinct loop name avoids overwriting
# the legend `labels` above.
ax.tick_params(axis='both', which='major', labelsize=14)
for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
    tick_label.set_fontname('serif')
# Grid
ax.grid(axis='both', color='black', alpha=0.1)
fig.tight_layout()
# The output directory is resolved relative to the current working directory.
# Create it if it is missing so savefig does not fail with FileNotFoundError.
fig_dir = Path('../fig')
fig_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir / 'single-linear-3.jpg', dpi=300, bbox_inches='tight')