mlx.nn.GRU
class GRU(input_size: int, hidden_size: int, bias: bool = True)
A gated recurrent unit (GRU) RNN layer.
The input has shape NLD or LD, where:
- N is the optional batch dimension
- L is the sequence length
- D is the input’s feature dimension
Concretely, for each element of the sequence, this layer computes:
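A sketch of the update, assuming the common GRU formulation: the weight and bias names below are assumptions chosen for illustration, and the exact bias placement in MLX may differ.

```latex
\begin{aligned}
r_t     &= \sigma\left(W_{xr} x_t + W_{hr} h_t + b_r\right) \\
z_t     &= \sigma\left(W_{xz} x_t + W_{hz} h_t + b_z\right) \\
n_t     &= \tanh\left(W_{xn} x_t + b_n + r_t \odot \left(W_{hn} h_t + b_{hn}\right)\right) \\
h_{t+1} &= (1 - z_t) \odot n_t + z_t \odot h_t
\end{aligned}
```

Here x_t is the input at step t, h_t is the hidden state, r_t and z_t are the reset and update gates, n_t is the candidate state, σ is the logistic sigmoid, and ⊙ is elementwise multiplication.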
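The hidden state has shape NH or H, depending on whether the input is batched or not. Returns the hidden state at each time step, of shape NLH or LH.
Parameters:
- input_size (int): Dimension of the input, D.
- hidden_size (int): Dimension of the hidden state, H.
- bias (bool): Whether to use biases. Default: True.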
Methods