Embedding categorical data for an LSTM network

Philip Hua, June 8, 2022
Answered: Yash Sharma, September 7, 2023
Hi, could I ask how one would embed categorical data for an LSTM network? Suppose I have a sequence of tokens with 108 unique values over m columns (each column is a time step). Currently I pass this into an LSTM network by converting the tokens to numeric values with double and then one-hot encoding them into a 108-by-m matrix, as follows:
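cX = cell(1,1);                            % predictors: a single sequence
cY = cell(1,1);                            % responses: a single sequence
corpusTokens = split(corpus,'@');          % corpus is '@'-delimited; split returns a column of tokens
corpusValues = categorical(corpusTokens);  % tokens as a categorical column vector
dTokens = double(corpusValues);            % numeric category index of each token
uniqueTokens = categories(corpusValues);   % unique tokens
vocabularySize = numel(uniqueTokens);      % number of unique tokens
cX{1} = dummyvar(dTokens(1:end-1))';       % one-hot coding, one column per time step (dummyvar: Statistics and Machine Learning Toolbox)
cY{1} = categorical(corpusTokens(2:end)'); % next token at each time step, as a row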
cX and cY are then passed into the network. This works fine, but I want to embed the input into 32 dimensions. How do I pass the embedded data into the network?
Thank you in advance
Philip
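To make the two possible input representations concrete, here is a minimal sketch on a tiny hypothetical corpus (the string "the@cat@the@dog@the@cat" is invented for illustration). It repeats the question's preprocessing and shows that the one-hot predictors form a vocabularySize-by-S matrix, whereas the index form that an embedding layer consumes is a 1-by-S row. The one-hot matrix is built with plain indexing instead of dummyvar so the snippet runs in base MATLAB.

```matlab
% Tiny hypothetical corpus, '@'-delimited like the one in the question
corpus = "the@cat@the@dog@the@cat";
corpusTokens = split(corpus, '@');          % 6-by-1 string array of tokens
corpusValues = categorical(corpusTokens);   % categories are sorted: cat, dog, the
dTokens = double(corpusValues);             % category index of each token: [3;1;3;2;3;1]
vocabularySize = numel(categories(corpusValues));

% Index representation (what an embedding layer consumes): 1-by-S row
xIndex = dTokens(1:end-1)';
S = numel(xIndex);

% One-hot representation (what the question feeds in): vocabularySize-by-S
xOneHot = zeros(vocabularySize, S);
xOneHot(sub2ind(size(xOneHot), xIndex, 1:S)) = 1;

% Responses: the next token at each time step, keeping every category
y = corpusValues(2:end)';

% prints: xIndex is 1-by-5, xOneHot is 3-by-5
fprintf('xIndex is %d-by-%d, xOneHot is %d-by-%d\n', size(xIndex), size(xOneHot));
```

A small design note: as far as I know, dummyvar sizes its output from the largest index it sees, so if the alphabetically last token occurred only at the final position, the one-hot matrix would have fewer rows than vocabularySize. Building it explicitly from vocabularySize, as above, avoids that edge case.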

Answers (1)

Yash Sharma, September 7, 2023
I understand that you want to embed your categorical tokens into a 32-dimensional space before they enter an LSTM network. To do this, you can add a word embedding layer (wordEmbeddingLayer, from Text Analytics Toolbox) before the LSTM layer. It maps each integer token index to a learned 32-dimensional vector, so the network input should be sequences of token indices rather than one-hot vectors. Here is example code demonstrating this:
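% Assuming corpusValues, dTokens, and vocabularySize are defined as in the question
embeddingDimension = 32;       % dimensionality of the embedding space
numHiddenUnits = 128;          % illustrative choice; tune for your data
numClasses = vocabularySize;   % predicting the next token: one class per unique token
% wordEmbeddingLayer expects token indices, not one-hot vectors,
% so each predictor sequence is a 1-by-S row of indices instead of dummyvar output
cX = {dTokens(1:end-1)'};
% Responses: next token at each step; indexing corpusValues keeps all vocabularySize categories
cY = {corpusValues(2:end)'};
% Create the word embedding layer
embeddingLayer = wordEmbeddingLayer(embeddingDimension, vocabularySize);
% Define the LSTM network architecture
layers = [
    sequenceInputLayer(1)      % one feature per time step: the token index
    embeddingLayer
    lstmLayer(numHiddenUnits, 'OutputMode', 'sequence')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer
    ];
% Train the LSTM network
options = trainingOptions('adam', 'MaxEpochs', 10);
net = trainNetwork(cX, cY, layers, options);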
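Why feed token indices rather than the one-hot vectors from the question? Conceptually, a 32-dimensional embedding is a learned 32-by-vocabularySize weight matrix W, and looking up column k of W gives the same result as multiplying W by the one-hot vector for token k. The sketch below checks this numerically in base MATLAB; the random W only stands in for trained weights, and the token indices are hypothetical.

```matlab
% Illustrative sizes: 108 tokens as in the question, 32-dimensional embedding
vocabularySize = 108;
embeddingDimension = 32;
rng(0);                                          % reproducible random numbers
W = randn(embeddingDimension, vocabularySize);   % stands in for learned embedding weights

tokenIdx = [5 17 5 99];                          % hypothetical 1-by-S index sequence
S = numel(tokenIdx);

% One-hot encode the sequence: vocabularySize-by-S, a single 1 per column
oneHot = zeros(vocabularySize, S);
oneHot(sub2ind(size(oneHot), tokenIdx, 1:S)) = 1;

% Embedding as a matrix product with the one-hot input (32-by-S)...
embedByProduct = W * oneHot;
% ...and as a direct column lookup with the indices (also 32-by-S)
embedByLookup = W(:, tokenIdx);

maxDifference = max(abs(embedByProduct(:) - embedByLookup(:)));
fprintf('Largest difference between the two embeddings: %g\n', maxDifference);
```

The lookup skips building and multiplying the mostly zero one-hot matrix, which is why the embedding layer is paired with sequenceInputLayer(1) and index sequences. After training, the learned matrix should be available as the Weights property of the embedding layer (net.Layers(2) in the layer array above). If Text Analytics Toolbox is not available, the same reasoning suggests keeping the one-hot cX and placing fullyConnectedLayer(32) directly after sequenceInputLayer(vocabularySize), since a fully connected layer acts on each time step independently; note that this adds a bias term, so treat it as an untested alternative rather than an exact equivalent.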
Please refer to the documentation linked below for further reference.

Category: Sequence and Numeric Feature Data Workflows

Release: R2021b