29 lines
815 B
Python
29 lines
815 B
Python
|
import pandas as pd
|
||
|
import seaborn as sns
|
||
|
import matplotlib.pyplot as plt
|
||
|
def plot_heatmap(matrix, len, wid, title="Heatmap", save_path="heatmap.png"):
|
||
|
# 设置支持中文的字体
|
||
|
plt.rcParams['font.family'] = 'SimHei'
|
||
|
|
||
|
# 解决负号显示问题
|
||
|
plt.rcParams['axes.unicode_minus'] = False
|
||
|
|
||
|
# 读取数据
|
||
|
df = matrix
|
||
|
|
||
|
# 绘制热力图
|
||
|
plt.figure(figsize=(len, wid))
|
||
|
sns.heatmap(
|
||
|
df,
|
||
|
annot=True, # 显示数值
|
||
|
fmt=".2f", # 数值格式
|
||
|
cmap="coolwarm", # 红-蓝渐变色
|
||
|
vmin=-1, vmax=1, # 颜色范围固定为-1到1
|
||
|
linewidths=0.5,
|
||
|
)
|
||
|
plt.title(title)
|
||
|
plt.xticks(rotation=45, ha="right") # 调整X轴标签角度
|
||
|
plt.tight_layout()
|
||
|
plt.savefig(save_path)
|
||
|
plt.show()
|
||
|
plt.close()
|