26 lines
566 B
Python
26 lines
566 B
Python
|
import numpy as np
|
||
|
|
||
|
def cal_metrics(estimated, true):
    """Compute goodness-of-fit metrics between estimated and true values.

    Parameters
    ----------
    estimated : array_like
        Predicted / fitted values.
    true : array_like
        Observed (ground-truth) values. Must contain the same number of
        elements as ``estimated``; both inputs are flattened before they
        are compared element-wise.

    Returns
    -------
    dict
        ``{'RMSE': ..., 'R-squared': ..., 'MAE': ...}`` where ``RMSE`` is the
        root-mean-squared error, ``R-squared`` is the coefficient of
        determination ``1 - SS_res / SS_tot`` (reported as 0.0 when ``true``
        is constant, i.e. ``SS_tot == 0``), and ``MAE`` is the mean absolute
        error.

    Raises
    ------
    ValueError
        If the inputs are empty or contain different numbers of elements.
    """
    # Cast to float so unsigned / small integer dtypes (e.g. uint8) cannot
    # wrap around on subtraction or squaring.
    estimated = np.asarray(estimated, dtype=float).ravel()
    true = np.asarray(true, dtype=float).ravel()

    # Flatten and compare sizes explicitly: mismatched shapes such as (n, 1)
    # vs (n,) would otherwise broadcast to (n, n) and silently yield wrong
    # metrics.
    if estimated.size != true.size:
        raise ValueError(
            "estimated and true must have the same number of elements "
            f"({estimated.size} != {true.size})"
        )
    if estimated.size == 0:
        raise ValueError("estimated and true must not be empty")

    residuals = true - estimated

    # RMSE
    rmse = np.sqrt(np.mean(residuals ** 2))

    # R-squared (coefficient of determination)
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((true - true.mean()) ** 2)
    # A constant target has zero variance, so R^2 is undefined; keep the
    # existing convention of reporting 0 rather than dividing by zero.
    r_squared = 1 - ss_res / ss_tot if ss_tot != 0 else 0.0

    # MAE
    mae = np.mean(np.abs(residuals))

    return {
        'RMSE': rmse,
        'R-squared': r_squared,
        'MAE': mae,
    }
|