mlx.nn.GRU

class GRU(input_size: int, hidden_size: int, bias: bool = True)

A gated recurrent unit (GRU) RNN layer.

The input has shape NLD or LD where:

  • N is the optional batch dimension

  • L is the sequence length

  • D is the input’s feature dimension

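Concretely, for each element of the sequence, this layer computes the following, where σ is the logistic sigmoid and products with the gates r_t and z_t are taken elementwise: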

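$$
\begin{aligned}
r_t &= \sigma(W_{xr} x_t + W_{hr} h_t + b_r) \\
z_t &= \sigma(W_{xz} x_t + W_{hz} h_t + b_z) \\
n_t &= \tanh(W_{xn} x_t + b_n + r_t \odot (W_{hn} h_t + b_{hn})) \\
h_{t+1} &= (1 - z_t) \odot n_t + z_t \odot h_t
\end{aligned}
$$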
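The four equations can be sketched as a single NumPy update step. This is a minimal reference sketch, not MLX's implementation: the parameter names (`W_xr`, `b_hn`, and so on) simply mirror the symbols above, while the dictionary layout and the `init_params` helper are hypothetical. Weights are stored as (H, D) or (H, H) matrices and applied as `x @ W.T`, so a batched input of shape (N, D) works as well as an unbatched one of shape (D,). The uniform initialization in ±1/√H is an assumption made only so the example runs.

```python
import numpy as np


def sigmoid(x):
    # Logistic function σ used by the reset and update gates.
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x_t, h_t, p):
    """Compute h_{t+1} from x_t (shape (..., D)) and h_t (shape (..., H))."""
    # Reset gate r_t and update gate z_t.
    r = sigmoid(x_t @ p["W_xr"].T + h_t @ p["W_hr"].T + p["b_r"])
    z = sigmoid(x_t @ p["W_xz"].T + h_t @ p["W_hz"].T + p["b_z"])
    # Candidate state: r_t scales only the hidden-state term (W_hn h_t + b_hn).
    n = np.tanh(x_t @ p["W_xn"].T + p["b_n"] + r * (h_t @ p["W_hn"].T + p["b_hn"]))
    # Elementwise interpolation between the candidate and the previous state.
    return (1.0 - z) * n + z * h_t


def init_params(D, H, seed=0):
    """Hypothetical helper: random weights in [-1/sqrt(H), 1/sqrt(H)]."""
    rng = np.random.default_rng(seed)
    s = 1.0 / np.sqrt(H)
    p = {}
    for g in "rzn":
        p[f"W_x{g}"] = rng.uniform(-s, s, (H, D))
        p[f"W_h{g}"] = rng.uniform(-s, s, (H, H))
        p[f"b_{g}"] = rng.uniform(-s, s, (H,))
    p["b_hn"] = rng.uniform(-s, s, (H,))
    return p


# Unbatched use: feed a length-5 sequence of 4-dim inputs one step at a time.
p = init_params(D=4, H=3)
h = np.zeros(3)
for x_t in np.random.default_rng(1).normal(size=(5, 4)):
    h = gru_step(x_t, h, p)
print(h.shape)  # (3,)
```

A quick sanity check on the sign of the (1 - z_t) term: with all weights and biases set to zero, both gates equal σ(0) = 0.5 and n_t = tanh(0) = 0, so one step simply halves h_t.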

The hidden state h has shape NH or H, depending on whether the input is batched. The layer returns the hidden state at every time step, with shape NLH or LH.
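A whole-sequence sketch makes the shape conventions concrete. Here the three input matrices are stacked into one (3H, D) array `Wx` and the three hidden matrices into one (3H, H) array `Wh`, in the order r, z, n. That fused layout, the function name `gru_sequence`, and the zero initial hidden state are assumptions for illustration, not a description of MLX's internals. Stacking lets the input projections for all L time steps be computed with one matrix multiply before the recurrent loop, which then only has to handle the h_t terms.

```python
import numpy as np


def gru_sequence(x, Wx, Wh, b, bhn):
    """Run a GRU over a sequence and return the hidden state at every step.

    x:   (L, D) or (N, L, D) input.
    Wx:  (3H, D) input weights stacked as [W_xr; W_xz; W_xn] (assumed layout).
    Wh:  (3H, H) hidden weights stacked as [W_hr; W_hz; W_hn].
    b:   (3H,) biases [b_r; b_z; b_n]; bhn: (H,) bias inside the reset product.
    Returns an array of shape (L, H) or (N, L, H).
    """
    H = Wh.shape[1]
    # Input projections for every time step at once: (..., L, D) -> (..., L, 3H).
    xp = x @ Wx.T + b
    # Zero initial hidden state of shape (H,) or (N, H) (assumption of this sketch).
    h = np.zeros(x.shape[:-2] + (H,))
    outputs = []
    for t in range(x.shape[-2]):
        xr, xz, xn = np.split(xp[..., t, :], 3, axis=-1)
        hr, hz, hn = np.split(h @ Wh.T, 3, axis=-1)
        r = 1.0 / (1.0 + np.exp(-(xr + hr)))   # reset gate
        z = 1.0 / (1.0 + np.exp(-(xz + hz)))   # update gate
        n = np.tanh(xn + r * (hn + bhn))       # candidate state
        h = (1.0 - z) * n + z * h
        outputs.append(h)
    # Insert the sequence axis just before the hidden-feature axis.
    return np.stack(outputs, axis=-2)
```

Because the loop indexes with `...`, the same code handles both the batched (NLD → NLH) and unbatched (LD → LH) cases.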

Parameters:
  • input_size (int) – Dimension of the input, D.

  • hidden_size (int) – Dimension of the hidden state, H.

  • bias (bool) – Whether to use biases or not. Default: True.
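Reading the equations directly, each of the three gates contributes one H×D input matrix and one H×H hidden matrix, and there are four bias vectors of length H (b_r, b_z, b_n, b_hn). The hypothetical helper below counts those scalars from `input_size` and `hidden_size`. It assumes that `bias=False` removes all four bias vectors and does not query MLX itself.

```python
def gru_param_count(input_size: int, hidden_size: int, bias: bool = True) -> int:
    """Count the scalars appearing in the GRU equations (sketch, not MLX's bookkeeping)."""
    D, H = input_size, hidden_size
    # W_xr, W_xz, W_xn are H x D; W_hr, W_hz, W_hn are H x H.
    weights = 3 * H * D + 3 * H * H
    # b_r, b_z, b_n, b_hn, each of length H (assumed all dropped when bias=False).
    biases = 4 * H if bias else 0
    return weights + biases


print(gru_param_count(10, 20))              # 1880
print(gru_param_count(10, 20, bias=False))  # 1800
```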
