add:tqdm
This commit is contained in:
parent
2107d9e189
commit
34f29a3e74
@ -18,16 +18,21 @@ def load_data(file_path):
|
||||
|
||||
# 在analyze_sentiment函数中添加模型路径处理
|
||||
def analyze_sentiment(texts):
|
||||
"""改进的情感分析函数"""
|
||||
try:
|
||||
# 修改为优先使用打包后的模型路径
|
||||
model_path = os.path.join(os.path.dirname(__file__), '.cache/huggingface/hub')
|
||||
# 使用 HuggingFace 的模型
|
||||
model_name = "IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=model_path)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=model_path)
|
||||
# 将模型移动到GPU
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
|
||||
# 批量处理提升效率
|
||||
inputs = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()} # 将输入数据也移动到GPU
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user