|
AutoModelForCausalLM与AutoModelForSequenceClassification的区别,分别给出示例
`AutoModelForCausalLM` 和 `AutoModelForSequenceClassification` 是 Hugging Face Transformers 库中针对不同任务设计的模型类,主要区别如下:
---
### **1. AutoModelForCausalLM**
- **用途**:因果语言建模(生成式任务),如文本生成、对话生成。
- **特点**:
- 模型基于前一个词预测下一个词(自回归生成)。
- 使用单向注意力掩码(防止未来信息泄露)。
- 典型应用:GPT、LLaMA 等生成式模型。
#### 示例:生成文本
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
# 加载模型和分词器
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# 输入文本
input_text = "Artificial intelligence is"
inputs = tokenizer(input_text, return_tensors="pt")
# 生成文本
outputs = model.generate(
inputs.input_ids,
max_length=50,
do_sample=True,
temperature=0.7
)
# 解码输出
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
# 输出可能:"Artificial intelligence is a field of computer science that aims to create systems capable of performing tasks that normally require human intelligence..."
```
---
### **2. AutoModelForSequenceClassification**
- **用途**:序列分类任务(判别式任务),如情感分析、文本分类。
- **特点**:
- 模型输出分类概率(如正面/负面情感)。
- 在基础模型(如 BERT)顶部添加分类层。
- 使用双向注意力(允许全序列上下文)。
#### 示例:情感分析
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# 加载模型和分词器
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# 输入文本
input_text = "I love this movie!"
inputs = tokenizer(input_text, return_tensors="pt")
# 推理
with torch.no_grad():
outputs = model(**inputs)
# 获取分类结果
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
predicted_class = torch.argmax(probabilities).item()
# 输出标签映射(假设标签为 0: NEGATIVE, 1: POSITIVE)
labels = ["NEGATIVE", "POSITIVE"]
print(f"Predicted sentiment: {labels[predicted_class]}")
# 输出可能:"Predicted sentiment: POSITIVE"
```
---
### **关键区别总结**
| 特性 | AutoModelForCausalLM | AutoModelForSequenceClassification |
|---------------------|----------------------------------|--------------------------------------|
| 任务类型 | 生成式任务(文本生成) | 判别式任务(分类) |
| 注意力机制 | 单向(仅左侧上下文) | 双向(全序列上下文) |
| 输出 | 下一个词的概率分布 | 分类标签的概率分布 |
| 典型模型 | GPT、LLaMA | BERT、RoBERTa |
| 应用场景 | 对话、续写、代码生成 | 情感分析、垃圾邮件检测 |
---
### **注意事项**
1. 模型选择需与任务匹配:生成任务用 `CausalLM`,分类任务用 `SequenceClassification`。
2. 输入格式差异:`CausalLM` 可能需要手动添加注意力掩码,而 `SequenceClassification` 通常自动处理。
3. 微调方法不同:`CausalLM` 通过预测下一个词损失优化,`SequenceClassification` 通过交叉熵损失优化分类层。
|
|