Embedding category data for LSTM network

Hi, could I ask how one would embed categorical data for an LSTM network? Suppose I have 108 unique tokens and m time steps (one column per time step). Currently I pass this into an LSTM network by converting the tokens to numerical values and then one-hot encoding them into a 108-by-m double array, as follows:
cX = cell(1,1);
cY = cell(1,1);
corpusTokens = split(corpus,'@');
corpusValues = categorical(corpusTokens); % sequence of tokens (by columns)
dTokens = double(corpusValues); % numeric tokens
uniqueTokens = categories(corpusValues); % unique tokens
vocabularySize = numel(uniqueTokens); % number of unique tokens
cX{1}= dummyvar(dTokens(1:end-1))'; % one-hot encoding
cY{1}= categorical(corpusTokens(2:end)'); % sequence of tokens (by rows)
cX and cY are then passed into the network. This works fine but I want to embed the input to 32 dimensions. How do you pass the embedded data into the network?
Thank you in advance
Philip

Answers (1)

Hornett on 16 Sep 2024


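As I understand it, you have categorical token data that you want to embed into 32 dimensions before passing it to an LSTM network. You can do this by placing a word embedding layer (wordEmbeddingLayer, which requires Text Analytics Toolbox) before the LSTM layer. Note that the embedding layer expects sequences of integer token indices, one feature per time step, rather than one-hot vectors, so the network input should be the numeric token sequence and not the output of dummyvar. Here is example code to demonstrate this: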
% Assumes dTokens, vocabularySize, and cY are defined as in the question
embeddingDimension = 32; % dimensionality of the embedding space
numHiddenUnits = 128; % example value (an assumption); tune for your data
numClasses = numel(categories(cY{1})); % number of distinct target tokens
% The embedding layer takes integer token indices, not one-hot vectors
cX{1} = dTokens(1:end-1)'; % 1-by-(m-1) row of token indices
% Define the LSTM network architecture
layers = [
    sequenceInputLayer(1) % one feature per time step: the token index
    wordEmbeddingLayer(embeddingDimension, vocabularySize)
    lstmLayer(numHiddenUnits, 'OutputMode', 'sequence')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer
];
% Train the LSTM network
options = trainingOptions('adam', 'MaxEpochs', 10);
net = trainNetwork(cX, cY, layers, options);
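As an aside on why an embedding layer is fed integer token indices rather than one-hot vectors: looking up column k of a weight matrix gives the same result as multiplying that matrix by the one-hot vector for token k. The sketch below illustrates this with base MATLAB only. The toy corpus and the random matrix W are assumptions made for illustration (W stands in for learned weights); it is not a description of how wordEmbeddingLayer is implemented internally.

```matlab
% Toy corpus with '@' as the token separator, as in the question
corpus = "the@cat@sat@on@the@mat";
tokens = split(corpus, '@');              % column vector of string tokens
tokenCats = categorical(tokens);          % categories are sorted alphabetically
idx = double(tokenCats)';                 % 1-by-T row of integer token indices
vocabularySize = numel(categories(tokenCats));

embeddingDimension = 32;
rng(0);                                   % reproducible random weights
W = randn(embeddingDimension, vocabularySize); % hypothetical stand-in for learned weights

% One-hot encoding (vocabularySize-by-T), built with implicit expansion
oneHot = double((1:vocabularySize)' == idx);

% Two equivalent ways to embed the sequence
viaLookup   = W(:, idx);                  % column lookup by token index
viaMultiply = W * oneHot;                 % matrix product with one-hot input

% Largest difference between the two results (zero up to rounding)
maxDiff = max(abs(viaLookup(:) - viaMultiply(:)));
disp(size(viaLookup))
disp(maxDiff)
```

Both results are embeddingDimension-by-T arrays, which is why the one-hot step can be dropped once an embedding layer is in the network.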
For further reference, see the documentation for the functions used above.

Release: R2021b

Asked: 8 Jun 2022
Answered: 16 Sep 2024
