29 lines
817 B
Python
29 lines
817 B
Python
import pandas as pd
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
def plot_heatmap(matrix, len, wid, title="Heatmap", save_path="heatmap.png"):
|
|
# 设置支持中文的字体
|
|
# plt.rcParams['font.family'] = 'SimHei'
|
|
|
|
# 解决负号显示问题
|
|
plt.rcParams['axes.unicode_minus'] = False
|
|
|
|
# 读取数据
|
|
df = matrix
|
|
|
|
# 绘制热力图
|
|
plt.figure(figsize=(len, wid))
|
|
sns.heatmap(
|
|
df,
|
|
annot=True, # 显示数值
|
|
fmt=".2f", # 数值格式
|
|
cmap="coolwarm", # 红-蓝渐变色
|
|
vmin=-1, vmax=1, # 颜色范围固定为-1到1
|
|
linewidths=0.5,
|
|
)
|
|
plt.title(title)
|
|
plt.xticks(rotation=45, ha="right") # 调整X轴标签角度
|
|
plt.tight_layout()
|
|
plt.savefig(save_path)
|
|
plt.show()
|
|
plt.close() |