Уведомления

Группа в Telegram: @pythonsu

#1 Фев. 3, 2021 19:46:43

Roodger
Зарегистрирован: 2021-02-03
Сообщения: 1
Репутация: +  0  -
Профиль   Отправить e-mail  

Нужна помощь по torchtext

Использую модуль torchtext для классификации текстов.
Классификация с помощью pytorch моделей
Не получается разобраться с механизмом векторизации текста.
В общем, если использовать sclearn для получения векторов, то все идет нормально. Скармливаем трейн и тест модели получаем нормальный скор
Хотел применить вместо sklearn torchtext , но почему то чепуха. Подозреваю, я неправильно использую словарь слов. Валидация показывает аккураси 90, а тест 1
Вот код для создания трейна и валид датасетов и итератора

 TEXT = torchtext.data.Field(sequential=True, tokenize=lambda x: x.split(' '), pad_token='<pad>', unk_token='<unk>')
LABEL = torchtext.data.LabelField(sequential=False, unk_token=None, is_target=True)
datafields = [('text', TEXT), ('label', LABEL)]
data = read_data(train, datafields, label_column='category_id')
valid, trainn = data.split([0.1, 0.9])
TEXT.build_vocab(trainn, max_size=50000)
LABEL.build_vocab(trainn)
model = CBoWTextClassifier2(TEXT, LABEL, emb_dim=128) 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
train_iterator = torchtext.data.BucketIterator(
        trainn,
        device=device,
        batch_size=256,
        sort_key=lambda x: len(x.text),
        shuffle = True,
        repeat=False,
        train=True)
    
valid_iterator = torchtext.data.Iterator(
        valid,
        device=device,
        batch_size=256,
        repeat=False,
        train=False,
        sort=False)

Вот код теста:
 datafields = [('text', TEXT)]
data = read_data(test, datafields)
test_iterator = torchtext.data.Iterator(
    data,
#     device='cpu',
    batch_size=1,
    repeat=False,
    train=False,
    sort=False)

Остальное стандартно. Модель обычная эмбеддинг и лиейные слои

Офлайн

Board footer

Модераторировать

Powered by DjangoBB

Lo-Fi Version