BERT always predicts the same class (fine-tuning)
I am fine-tuning BERT on a financial news dataset.
Unfortunately, BERT seems to be stuck in a local minimum, settling for learning to always predict the same class.
- Balancing the dataset did not help
- Neither did tuning the hyperparameters
Honestly, I am not sure what is causing the problem. With the simpletransformers library I got very good results (a rough sketch of that baseline is below). I would be grateful if anyone could help me. Many thanks!
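For reference, a minimal sketch of what such a simpletransformers baseline typically looks like (my reconstruction for illustration, not the exact code I used; the library expects a DataFrame with 'text' and 'labels' columns):

from simpletransformers.classification import ClassificationModel
import pandas as pd

# tiny placeholder rows; in practice they come from the same CSV as below
train_df = pd.DataFrame({
    "text": ["Profits surge at ACME", "Shares plunge on weak outlook"],
    "labels": [2, 0],  # negative = 0, neutral = 1, positive = 2
})
model = ClassificationModel("bert", "bert-base-cased", num_labels=3, use_cuda=False)
model.train_model(train_df)
predictions, raw_outputs = model.predict(["Markets end the day flat"])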
Full code on GitHub:
https://github.com/Bene939/BERT_News_Sentiment_Classifier
Code:
from transformers import BertForSequenceClassification, AdamW, BertTokenizer, get_linear_schedule_with_warmup, Trainer, TrainingArguments
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import pandas as pd
from pathlib import Path
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
from torch.nn import functional as F
from collections import defaultdict
import random
#defining tokenizer, model and optimizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=3)
if torch.cuda.is_available():
    print("\nUsing: ", torch.cuda.get_device_name(0))
    device = torch.device('cuda')
else:
    print("\nUsing: CPU")
    device = torch.device('cpu')
model = model.to(device)
#loading dataset
labeled_dataset = "news_headlines_sentiment.csv"
labeled_dataset_file = Path(labeled_dataset)
file_loaded = False
while not file_loaded:
    if labeled_dataset_file.exists():
        labeled_dataset = pd.read_csv(labeled_dataset_file)
        file_loaded = True
        print("Dataset Loaded")
    else:
        print("File not Found")
print(labeled_dataset)
#counting sentiments
negative = 0
neutral = 0
positive = 0
for idx, row in labeled_dataset.iterrows():
    if row["sentiment"] == 0:
        negative += 1
    elif row["sentiment"] == 1:
        neutral += 1
    else:
        positive += 1
print("Unbalanced Dataset")
print("negative: ", negative)
print("neutral: ", neutral)
print("positive: ", positive)
#balancing dataset to 1/3 per sentiment
for idx, row in labeled_dataset.iterrows():
    if row["sentiment"] == 0:
        if negative - neutral != 0:
            index_name = labeled_dataset[labeled_dataset["news"] == row["news"]].index
            labeled_dataset.drop(index_name, inplace=True)
            negative -= 1
    elif row["sentiment"] == 2:
        if positive - neutral != 0:
            index_name = labeled_dataset[labeled_dataset["news"] == row["news"]].index
            labeled_dataset.drop(index_name, inplace=True)
            positive -= 1
#custom dataset class
class NewsSentimentDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
#method for tokenizing dataset list
def tokenize_headlines(headlines, labels, tokenizer):
    encodings = tokenizer.batch_encode_plus(
        headlines,
        add_special_tokens = True,
        truncation = True,
        padding = 'max_length',
        return_attention_mask = True,
        return_token_type_ids = True
    )
    dataset = NewsSentimentDataset(encodings, labels)
    return dataset
#splitting dataset into training and validation set
#load news sentiment dataset
all_headlines = labeled_dataset['news'].tolist()
all_labels = labeled_dataset['sentiment'].tolist()
train_headlines, val_headlines, train_labels, val_labels = train_test_split(all_headlines, all_labels, test_size=.2)
val_dataset = tokenize_headlines(val_headlines, val_labels, tokenizer)
train_dataset = tokenize_headlines(train_headlines, train_labels, tokenizer)
#data loader
train_batch_size = 8
val_batch_size = 8
train_data_loader = DataLoader(train_dataset, batch_size = train_batch_size, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size = val_batch_size, sampler=SequentialSampler(val_dataset))
#optimizer and scheduler
num_epochs = 1
num_steps = len(train_data_loader) * num_epochs
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_steps*0.06, num_training_steps=num_steps)
#training and evaluation
seed_val = 64
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
for epoch in range(num_epochs):
    print("\n###################################################")
    print("Epoch: {}/{}".format(epoch+1, num_epochs))
    print("###################################################\n")

    #training phase
    average_train_loss = 0
    average_train_acc = 0
    model.train()
    for step, batch in enumerate(train_data_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        loss = F.cross_entropy(outputs[0], labels)
        average_train_loss += loss
        if step % 40 == 0:
            print("Training Loss: ", loss)
        logits = outputs[0].detach().cpu().numpy()
        label_ids = labels.to('cpu').numpy()
        average_train_acc += sklearn.metrics.accuracy_score(label_ids, np.argmax(logits, axis=1))
        print("predictions: ", np.argmax(logits, axis=1))
        print("labels: ", label_ids)
        print("#############")
        optimizer.zero_grad()
        loss.backward()
        #maximum gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        model.zero_grad()
    average_train_loss = average_train_loss / len(train_data_loader)
    average_train_acc = average_train_acc / len(train_data_loader)
    print("======Average Training Loss: {:.5f}======".format(average_train_loss))
    print("======Average Training Accuracy: {:.2f}%======".format(average_train_acc*100))
    #validation phase
    average_val_loss = 0
    average_val_acc = 0
    model.eval()
    for step, batch in enumerate(val_data_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            loss = F.cross_entropy(outputs[0], labels)
        average_val_loss += loss
        logits = outputs[0].detach().cpu().numpy()
        label_ids = labels.to('cpu').numpy()
        print("predictions: ", np.argmax(logits, axis=1))
        print("labels: ", label_ids)
        print("#############")
        average_val_acc += sklearn.metrics.accuracy_score(label_ids, np.argmax(logits, axis=1))
    average_val_loss = average_val_loss / len(val_data_loader)
    average_val_acc = average_val_acc / len(val_data_loader)
    print("======Average Validation Loss: {:.5f}======".format(average_val_loss))
    print("======Average Validation Accuracy: {:.2f}%======".format(average_val_acc*100))
###################################################
Epoch: 1/1
###################################################
Training Loss: tensor(1.1006, device='cuda:0', grad_fn=<NllLossBackward>)
predictions: [1 0 2 0 0 0 2 0]
labels: [2 0 1 1 0 1 0 1]
#############
predictions: [2 2 0 0 0 2 0 0]
labels: [1 2 1 0 2 0 1 2]
#############
predictions: [0 0 0 0 1 0 0 1]
labels: [0 1 1 0 1 1 2 0]
#############
predictions: [0 0 0 2 0 1 0 0]
labels: [0 0 0 2 0 0 2 1]
#############
predictions: [1 0 0 0 0 0 2 0]
labels: [0 2 2 1 0 0 0 0]
#############
predictions: [0 0 0 0 0 1 0 0]
labels: [1 0 2 2 2 1 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 2 2 2 0 2 0]
#############
predictions: [0 1 0 0 0 0 0 0]
labels: [2 2 0 2 0 0 0 1]
#############
predictions: [0 0 0 0 0 2 0 1]
labels: [0 1 0 2 2 0 1 2]
#############
predictions: [0 0 2 0 0 0 1 0]
labels: [0 0 0 1 2 1 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 1 0 1 0 1 1]
#############
predictions: [0 2 0 0 0 0 0 0]
labels: [2 2 0 1 0 1 2 1]
#############
predictions: [0 1 0 0 0 0 1 2]
labels: [2 2 1 0 2 0 0 2]
#############
predictions: [0 0 1 1 1 1 0 1]
labels: [1 2 1 1 1 1 2 2]
#############
predictions: [1 0 0 0 0 1 2 1]
labels: [1 0 1 1 0 0 0 2]
#############
predictions: [0 1 1 1 1 0 2 1]
labels: [2 2 1 2 2 1 1 2]
#############
predictions: [0 0 1 0 1 1 0 0]
labels: [1 0 0 1 0 1 0 2]
#############
predictions: [1 2 0 0 1 2 0 0]
labels: [0 2 2 1 2 0 1 0]
#############
predictions: [0 2 1 1 0 1 1 0]
labels: [2 2 0 1 1 0 1 2]
#############
predictions: [1 0 1 1 1 1 1 0]
labels: [0 2 0 1 0 1 2 2]
#############
predictions: [0 2 1 2 0 0 1 1]
labels: [2 1 1 1 1 2 2 0]
#############
predictions: [0 1 2 2 2 1 1 2]
labels: [2 2 1 1 2 1 0 1]
#############
predictions: [2 2 2 1 2 1 1 1]
labels: [0 1 1 0 0 2 2 1]
#############
predictions: [1 2 2 2 1 2 1 2]
labels: [0 0 0 0 2 0 1 2]
#############
predictions: [2 1 1 1 2 2 2 2]
labels: [1 0 2 2 1 0 0 0]
#############
predictions: [2 1 2 2 2 1 2 2]
labels: [2 1 1 1 1 1 2 2]
#############
predictions: [1 1 0 2 1 2 1 2]
labels: [2 2 0 2 0 1 2 0]
#############
predictions: [0 1 1 2 0 1 2 1]
labels: [2 2 2 1 2 2 0 1]
#############
predictions: [2 1 1 1 1 2 1 1]
labels: [0 1 1 2 1 0 0 2]
#############
predictions: [1 2 2 0 1 1 1 2]
labels: [0 1 2 1 2 1 0 1]
#############
predictions: [0 1 1 1 1 1 1 0]
labels: [0 2 0 1 1 2 2 2]
#############
predictions: [1 2 1 1 2 1 1 0]
labels: [0 2 2 2 0 0 1 0]
#############
predictions: [2 2 2 1 2 1 1 2]
labels: [2 2 1 2 1 0 0 0]
#############
predictions: [2 2 1 2 2 2 1 2]
labels: [1 1 2 2 2 0 2 1]
#############
predictions: [2 2 2 2 2 0 2 2]
labels: [2 2 1 2 0 1 1 2]
#############
predictions: [1 1 2 1 2 2 0 1]
labels: [2 1 1 1 0 0 2 2]
#############
predictions: [2 1 2 2 2 2 1 0]
labels: [0 2 0 2 0 0 0 0]
#############
predictions: [2 2 2 2 2 2 2 2]
labels: [1 1 0 2 0 1 2 1]
#############
predictions: [2 2 2 2 1 2 2 2]
labels: [1 0 0 1 1 0 0 0]
#############
predictions: [2 2 2 1 2 2 2 2]
labels: [1 0 1 1 0 2 2 0]
#############
Training Loss: tensor(1.1104, device='cuda:0', grad_fn=<NllLossBackward>)
predictions: [2 0 1 2 1 2 2 0]
labels: [2 2 0 0 1 0 0 2]
#############
predictions: [0 2 2 0 2 1 1 1]
labels: [0 0 0 1 0 0 1 0]
#############
predictions: [0 2 2 0 1 1 1 2]
labels: [2 1 1 1 2 2 1 0]
#############
predictions: [2 1 1 2 2 0 2 0]
labels: [1 2 1 2 1 0 2 1]
#############
predictions: [0 2 2 0 0 2 1 2]
labels: [0 0 2 2 0 0 2 0]
#############
predictions: [0 0 1 2 2 0 2 2]
labels: [0 0 0 0 0 0 0 0]
#############
predictions: [1 1 2 1 2 0 1 2]
labels: [0 0 2 0 0 0 1 1]
#############
predictions: [0 0 2 1 0 2 0 1]
labels: [1 1 2 1 1 0 2 0]
#############
predictions: [0 0 0 0 1 0 0 0]
labels: [2 2 1 1 2 1 1 1]
#############
predictions: [0 0 0 0 1 0 0 0]
labels: [1 1 2 2 1 1 2 0]
#############
predictions: [0 0 0 0 0 1 1 1]
labels: [2 0 1 1 0 1 2 2]
#############
predictions: [0 0 1 0 0 1 2 1]
labels: [1 2 0 2 2 0 2 1]
#############
predictions: [1 1 1 1 0 1 0 1]
labels: [2 0 1 0 1 0 1 2]
#############
predictions: [1 2 2 0 0 0 1 1]
labels: [2 0 0 2 1 2 2 2]
#############
predictions: [1 0 2 1 0 2 2 0]
labels: [0 0 2 1 2 1 1 1]
#############
predictions: [0 0 0 1 1 1 1 1]
labels: [1 2 1 0 0 0 1 0]
#############
predictions: [1 1 1 0 1 1 0 1]
labels: [0 2 1 2 1 2 2 0]
#############
predictions: [2 1 0 1 1 2 0 0]
labels: [0 1 0 0 1 2 0 2]
#############
predictions: [0 1 1 0 0 1 0 1]
labels: [1 0 0 2 2 1 1 2]
#############
predictions: [1 1 1 1 1 1 1 1]
labels: [2 0 1 0 2 0 0 2]
#############
predictions: [1 0 0 1 0 1 0 2]
labels: [1 0 0 1 1 2 2 1]
#############
predictions: [1 1 1 1 1 1 0 0]
labels: [1 1 0 2 1 0 2 0]
#############
predictions: [1 1 2 1 0 1 0 0]
labels: [0 2 1 2 1 1 0 2]
#############
predictions: [1 1 0 0 1 2 1 1]
labels: [0 2 1 0 2 2 0 1]
#############
predictions: [0 1 1 0 0 1 0 1]
labels: [0 0 1 2 2 0 1 2]
#############
predictions: [1 0 2 2 2 1 1 0]
labels: [2 2 1 0 0 1 1 2]
#############
predictions: [1 2 2 1 1 2 1 1]
labels: [1 0 0 1 0 0 0 0]
#############
predictions: [0 2 0 2 2 0 2 2]
labels: [2 0 0 0 2 1 1 2]
#############
predictions: [0 0 1 0 1 0 2 2]
labels: [0 0 1 0 1 0 2 0]
#############
predictions: [0 2 0 1 1 2 2 0]
labels: [0 2 0 2 0 2 0 0]
#############
predictions: [2 2 2 2 2 2 2 1]
labels: [2 2 1 1 0 0 2 2]
#############
predictions: [2 0 0 2 2 1 1 0]
labels: [1 0 0 1 0 2 1 2]
#############
predictions: [2 0 0 2 0 2 2 0]
labels: [2 2 2 2 0 1 1 1]
#############
predictions: [0 2 2 0 2 2 0 0]
labels: [1 0 1 2 0 1 1 1]
#############
predictions: [0 0 0 0 0 0 0 2]
labels: [2 1 1 0 0 0 1 2]
#############
predictions: [2 0 2 0 2 1 0 2]
labels: [2 1 1 2 1 1 0 0]
#############
predictions: [1 1 2 0 2 0 2 2]
labels: [0 2 1 2 1 2 1 0]
#############
predictions: [2 0 1 1 0 2 0 0]
labels: [2 1 0 1 1 0 2 0]
#############
predictions: [2 0 0 2 0 2 1 0]
labels: [0 0 0 0 2 1 0 1]
#############
predictions: [1 2 1 0 0 2 0 2]
labels: [2 0 2 1 0 0 1 1]
#############
Training Loss: tensor(1.1162, device='cuda:0', grad_fn=<NllLossBackward>)
predictions: [2 0 0 1 1 1 0 1]
labels: [0 1 1 1 1 2 2 1]
#############
predictions: [0 2 0 1 2 0 0 1]
labels: [2 2 1 0 1 0 0 0]
#############
predictions: [0 0 1 0 0 0 0 1]
labels: [1 0 2 0 0 2 2 0]
#############
predictions: [2 1 2 2 0 1 2 0]
labels: [2 0 1 0 2 1 0 1]
#############
predictions: [1 0 0 2 0 0 1 1]
labels: [2 2 0 2 0 2 0 0]
#############
predictions: [0 0 1 0 0 0 0 0]
labels: [2 2 2 1 2 2 2 2]
#############
predictions: [0 0 1 1 0 1 1 0]
labels: [2 1 1 1 0 2 1 0]
#############
predictions: [0 0 0 1 0 0 1 0]
labels: [2 0 2 2 0 0 1 2]
#############
predictions: [1 0 1 0 0 2 0 0]
labels: [1 1 2 0 0 1 0 0]
#############
predictions: [2 1 0 0 0 1 0 0]
labels: [1 2 0 0 0 0 0 0]
#############
predictions: [0 2 0 0 0 0 0 0]
labels: [2 0 1 1 2 2 1 1]
#############
predictions: [0 1 0 0 0 1 0 2]
labels: [0 2 1 1 0 0 1 2]
#############
predictions: [0 2 1 0 0 1 1 1]
labels: [1 1 0 2 0 1 1 0]
#############
predictions: [0 1 1 0 0 0 1 0]
labels: [0 0 1 0 1 2 1 1]
#############
predictions: [0 1 1 0 1 0 0 0]
labels: [0 1 1 1 2 2 2 0]
#############
predictions: [0 0 0 0 1 1 0 0]
labels: [2 0 2 2 1 2 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 0 2 2 0 1 1]
#############
predictions: [0 1 0 0 0 0 0 0]
labels: [0 2 0 1 1 2 0 2]
#############
predictions: [1 1 0 1 0 1 0 2]
labels: [1 2 0 0 2 2 2 1]
#############
predictions: [1 1 0 0 0 1 2 1]
labels: [0 0 1 2 2 1 2 2]
#############
predictions: [1 1 1 0 1 1 2 0]
labels: [0 0 0 2 0 1 0 2]
#############
predictions: [0 1 0 0 1 1 2 1]
labels: [2 0 0 1 2 2 1 2]
#############
predictions: [1 0 0 0 1 0 0 1]
labels: [1 2 2 2 2 1 0 1]
#############
predictions: [2 0 0 0 0 0 0 0]
labels: [1 2 0 2 2 1 1 1]
#############
predictions: [2 0 1 1 0 0 1 0]
labels: [0 0 0 0 2 2 1 1]
#############
predictions: [2 0 0 1 0 0 1 1]
labels: [2 2 1 1 0 0 1 0]
#############
predictions: [1 1 1 1 1 2 0 0]
labels: [0 0 2 1 0 0 0 0]
#############
predictions: [1 1 2 0 1 2 0 1]
labels: [0 2 1 0 2 0 0 1]
#############
predictions: [0 0 2 1 0 2 0 1]
labels: [1 2 0 2 2 1 0 0]
#############
predictions: [0 0 2 0 2 1 1 2]
labels: [2 2 1 2 2 2 0 0]
#############
predictions: [0 1 0 0 0 0 2 1]
labels: [1 1 0 1 1 1 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 0 0 2 0 0 2]
#############
predictions: [2 2 2 0 1 1 1 0]
labels: [1 0 2 1 1 2 0 0]
#############
predictions: [0 0 1 0 0 0 2 0]
labels: [0 1 2 1 1 0 0 0]
#############
predictions: [0 2 0 1 0 2 0 0]
labels: [0 0 2 1 1 0 2 2]
#############
predictions: [0 0 1 2 0 2 0 1]
labels: [2 2 0 0 0 2 2 2]
#############
predictions: [1 0 0 0 2 0 0 1]
labels: [2 0 1 1 1 0 0 1]
#############
predictions: [0 1 0 0 0 0 0 2]
labels: [1 1 1 0 0 0 2 2]
#############
predictions: [0 2 0 1 0 2 0 0]
labels: [1 1 1 1 2 2 1 0]
#############
predictions: [1 2 0 0 0 0 0 0]
labels: [2 0 2 1 0 1 1 1]
#############
Training Loss: tensor(1.2082, device='cuda:0', grad_fn=<NllLossBackward>)
predictions: [0 2 0 0 0 0 2 0]
labels: [1 0 2 1 2 2 1 1]
#############
predictions: [2 0 0 0 0 0 1 0]
labels: [1 0 0 0 0 2 1 0]
#############
predictions: [0 0 0 0 2 1 1 1]
labels: [0 2 2 0 1 2 1 1]
#############
predictions: [2 1 0 1 0 0 2 0]
labels: [1 0 2 1 0 2 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 0 0 0 0 1 0]
#############
predictions: [0 2 1 0 0 0 1 1]
labels: [0 2 2 2 2 1 1 0]
#############
predictions: [0 0 0 1 1 0 0 1]
labels: [0 1 0 1 2 2 2 2]
#############
predictions: [0 0 0 1 1 1 1 2]
labels: [2 2 1 2 0 1 1 1]
#############
predictions: [0 1 2 0 0 1 0 0]
labels: [0 2 1 0 0 1 0 0]
#############
predictions: [1 1 1 1 0 0 0 0]
labels: [2 1 2 1 0 2 2 1]
#############
predictions: [0 1 2 0 0 1 1 0]
labels: [2 0 2 1 1 1 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 0 0 1 1 0 0]
#############
predictions: [0 0 0 0 0 1 2 2]
labels: [2 2 1 1 0 2 1 2]
#############
predictions: [0 1 0 0 1 1 0 1]
labels: [0 1 0 2 1 0 0 1]
#############
predictions: [0 2 2 0 0 0 0 2]
labels: [0 0 2 1 2 2 0 1]
#############
predictions: [2 0 0 2 2 0 2 0]
labels: [2 1 0 2 2 0 1 0]
#############
predictions: [0 2 2 0 2 1 1 2]
labels: [1 1 0 0 2 1 0 0]
#############
predictions: [1 1 2 2 0 0 1 2]
labels: [2 0 2 0 1 1 1 1]
#############
predictions: [0 1 1 0 0 1 1 0]
labels: [0 2 1 0 0 2 2 0]
#############
predictions: [2 1 0 0 0 0 1 1]
labels: [0 2 0 2 0 0 1 1]
#############
predictions: [1 2 0 1 2 0 0 0]
labels: [1 0 1 1 0 2 2 2]
#############
predictions: [0 0 0 0 2 2 1 2]
labels: [2 2 2 1 1 1 1 0]
#############
predictions: [1 2 0 1 0 0 2 0]
labels: [2 2 1 1 1 0 2 0]
#############
predictions: [2 0 0 0 0 2 1]
labels: [0 1 1 2 2 0 2]
#############
======Average Training Loss: 1.11279======
======Average Training Accuracy: 33.77%======
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 0 1 1 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 0 2 1 0 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 2 2 2 1 2 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 1 2 0 1 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 2 0 0 1 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 0 1 2 1 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 2 1 2 0 2 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 1 2 2 1 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 0 2 2 0 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 0 0 2 0 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 0 1 1 2 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 1 2 2 0 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 0 0 1 2 2 1]
#############
predictions: [0 0 0 1 0 0 0 0]
labels: [0 0 1 1 0 2 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 2 2 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 1 2 2 2 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 2 1 2 0 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 0 0 2 2 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 1 0 1 0 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 2 2 2 2 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 2 1 1 0 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 2 1 1 2 0 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 2 1 2 2 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 0 1 0 2 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 2 1 1 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 0 1 2 1 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 1 1 1 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 1 0 0 2 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 1 0 0 0 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 1 1 2 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 0 1 2 1 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 2 0 1 1 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 1 0 1 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 2 2 1 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 2 0 2 0 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 1 1 1 0 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 1 2 2 0 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 2 0 0 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 0 0 1 0 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 2 1 1 2 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 2 2 2 2 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 2 2 1 0 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 2 2 2 1 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 0 0 1 0 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 1 0 0 0 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 2 1 2 0 2 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 2 0 1 2 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 2 0 0 0 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 1 0 0 0 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 0 1 1 2 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 0 0 0 2 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 2 1 1 1 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 0 0 2 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 2 1 0 2 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 0 1 2 2 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 0 0 2 1 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 2 0 2 1 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 0 2 0 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 0 0 1 0 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 2 2 0 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 1 1 1 0 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 1 2 2 1 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 0 2 0 2 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 0 1 1 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 1 1 1 1 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 2 1 0 0 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 1 2 1 0 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 2 2 0 0 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 2 2 0 0 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 0 2 2 2 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 0 0 1 2 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 1 2 0 1 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 0 0 0 2 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 1 2 0 2 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 1 1 0 1 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 2 0 1 0 0 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 0 0 0 2 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 1 1 2 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 2 2 0 1 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 2 0 1 1 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 0 0 1 2 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 2 1 2 0 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 1 1 1 0 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 1 1 2 0 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 1 0 1 1 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 2 0 2 1 0 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 0 0 0 2 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 0 1 2 2 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 2 0 1 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 2 1 0 2 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 1 2 0 2 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 2 2 2 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 1 2 0 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 0 1 1 0 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 2 2 2 2 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 0 0 0 1 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 2 1 2 1 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 0 0 0 2 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 0 1 1 1 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 1 0 2 2 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 1 1 1 2 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 2 0 1 0 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 0 2 2 0 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 1 2 2 2 1 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 0 1 0 2 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 2 1 0 2 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 2 0 2 2 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 2 0 0 1 0 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 1 0 0 0 2 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 2 0 1 2 1 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 2 2 2 2 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 0 1 2 0 2 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 2 1 1 1 1 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 0 0 0 1 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 1 2 0 1 2 2 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 1 1 1 2 1 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [1 0 1 1 1 0 0 2]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 2 0 0 0 0 1 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [0 0 1 1 2 0 0 1]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 1 1 1 0 1 0 0]
#############
predictions: [0 0 0 0 0 0 0 0]
labels: [2 0 2 2 2 0 0 1]
#############
predictions: [0 0 0 0 0 0 0]
labels: [2 2 1 1 0 0 1]
#############
======Average Validation Loss: 1.09527======
======Average Validation Accuracy: 35.53%======
For multi-class classification/sentiment analysis with BERT, the 'neutral' class has to be 2! It cannot sit between 'negative' = 0 and 'positive' = 2.
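A minimal sketch of that relabeling, assuming the CSV encodes negative = 0, neutral = 1, positive = 2 as in the question (the remap dict is my own illustration, not taken from the linked repository); it would run right after loading the CSV, before counting, balancing, and splitting:

# move 'neutral' to the last class id
# old encoding: negative = 0, neutral = 1, positive = 2
# new encoding: negative = 0, positive = 1, neutral = 2
remap = {0: 0, 1: 2, 2: 1}
labeled_dataset["sentiment"] = labeled_dataset["sentiment"].map(remap)

num_labels=3 still matches three classes after the remapping, so nothing else in the training code needs to change.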