mlx.nn.LSTM#
- class LSTM(input_size: int, hidden_size: int, bias: bool = True)#
An LSTM recurrent layer.
The input has shape
NLD or LD where:
- N is the optional batch dimension
- L is the sequence length
- D is the input’s feature dimension
Concretely, for each element of the sequence, this layer computes:
\[\begin{split}\begin{aligned}
i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\
f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\
g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\
o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\
c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\
h_{t + 1} &= o_t \odot \text{tanh}(c_{t + 1})
\end{aligned}\end{split}\]

The hidden state \(h\) and cell state \(c\) have shape NH, or H if the input is not batched.

The layer returns two arrays, the hidden state and the cell state at each time step, both of shape NLH (or LH for unbatched input).

- Parameters:
  - input_size (int) – Dimension of the input, D.
  - hidden_size (int) – Dimension of the hidden state, H.
  - bias (bool) – Whether to use biases. Default: True.
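The recurrence above can be sketched in plain NumPy for an unbatched sequence. This is an illustrative reference implementation, not the layer's actual code: the weight layout (the four gate projections stacked into single matrices) and all names here are assumptions for the sketch.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_t, c_t, Wx, Wh, b):
    # Hypothetical layout: Wx is (4H, D), Wh is (4H, H), b is (4H,),
    # with the gates stacked in the order [i, f, g, o].
    z = Wx @ x_t + Wh @ h_t + b
    H = h_t.shape[0]
    i = sigmoid(z[0:H])        # input gate
    f = sigmoid(z[H:2 * H])    # forget gate
    g = np.tanh(z[2 * H:3 * H])  # candidate cell update
    o = sigmoid(z[3 * H:4 * H])  # output gate
    c_next = f * c_t + i * g           # c_{t+1} = f_t ⊙ c_t + i_t ⊙ g_t
    h_next = o * np.tanh(c_next)       # h_{t+1} = o_t ⊙ tanh(c_{t+1})
    return h_next, c_next

def lstm(x, Wx, Wh, b):
    # x: (L, D) unbatched input; returns hidden and cell states, each (L, H).
    L, _ = x.shape
    H = Wh.shape[1]
    h, c = np.zeros(H), np.zeros(H)
    hs, cs = [], []
    for t in range(L):
        h, c = lstm_step(x[t], h, c, Wx, Wh, b)
        hs.append(h)
        cs.append(c)
    return np.stack(hs), np.stack(cs)
```

For an input of shape (L, D) this returns two arrays of shape (L, H), matching the shapes described above; since each hidden state is a sigmoid output times a tanh, every entry of the hidden-state array lies in (-1, 1).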
Methods