fedem.models package

Submodules

fedem.models.mamba module

class fedem.models.mamba.MambaBlock(config: MambaConfig)[source]

Bases: Module

Mamba block module as described in the Mamba paper.

Parameters:

config (MambaConfig) – Mamba model configuration.

forward(x)[source]

Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

Parameters:

x – shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n…)

Returns:

shape (b, l, d)

Return type:

output

Official Implementation:

  • class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119

  • mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
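
For orientation, a minimal usage sketch of the block. The MambaConfig import location and its field name (d_model here) are assumptions and may differ in the installed package:

    import torch
    from fedem.models.mamba import MambaBlock
    from fedem.models.mamba import MambaConfig  # import location assumed; adjust if MambaConfig lives elsewhere

    config = MambaConfig(d_model=256)        # field name d_model is an assumption
    block = MambaBlock(config)

    x = torch.randn(2, 16, 256)              # (b, l, d): batch 2, length 16, model dim 256
    y = block(x)                              # call the instance (not .forward) so hooks run
    assert y.shape == x.shape                 # output keeps shape (b, l, d)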

selective_scan(u, delta, A, B, C, D)[source]

Runs the selective scan algorithm. See:
  • Section 2 State Space Models in the Mamba paper [1]

  • Algorithm 2 in Section 3.2 in the Mamba paper [1]

  • run_SSM(A, B, C, u) in The Annotated S4 [2]

This is the classic discrete state space formula:

x(t + 1) = Ax(t) + Bu(t)
y(t)     = Cx(t) + Du(t)

except that B and C (and the step size delta, which is used for discretization) depend on the input x(t). A naive reference sketch of this recurrence is given after this entry.

Parameters:
  • u – shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n…)

  • delta – shape (b, l, d_in)

  • A – shape (d_in, n)

  • B – shape (b, l, n)

  • C – shape (b, l, n)

  • D – shape (d_in,)

Returns:

shape (b, l, d_in)

Return type:

output

Official Implementation:

selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86

Note: some parts of selective_scan_ref were refactored out here, so the functionality does not match the reference exactly.
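
The following is a naive, sequential reference sketch of the scan described above (not the package's implementation): it discretizes A and B with the input-dependent delta via a zero-order hold, then applies the recurrence x(t + 1) = Ax(t) + Bu(t), y(t) = Cx(t) + Du(t) step by step. Shapes follow the parameter list above.

    import torch

    def naive_selective_scan(u, delta, A, B, C, D):
        """Sequential reference scan; a sketch only, not fedem's implementation."""
        b, l, d_in = u.shape
        n = A.shape[1]

        # Zero-order-hold discretization with input-dependent step size delta.
        deltaA = torch.exp(torch.einsum("bld,dn->bldn", delta, A))      # (b, l, d_in, n)
        deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)       # (b, l, d_in, n)

        # Classic recurrence, run sequentially over the sequence dimension l.
        x = torch.zeros(b, d_in, n, device=u.device, dtype=u.dtype)
        ys = []
        for t in range(l):
            x = deltaA[:, t] * x + deltaB_u[:, t]                       # x(t + 1) = Abar x(t) + Bbar u(t)
            ys.append(torch.einsum("bdn,bn->bd", x, C[:, t]))           # y(t) = C x(t)
        y = torch.stack(ys, dim=1)                                      # (b, l, d_in)

        return y + u * D                                                # skip term: + D u(t)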

ssm(x)[source]

Runs the SSM. See:
  • Algorithm 2 in Section 3.2 in the Mamba paper [1]

  • run_SSM(A, B, C, u) in The Annotated S4 [2]

Parameters:

x – shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n…)

Returns:

shape (b, l, d_in)

Return type:

output

Official Implementation:

mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
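
The ssm step typically derives the scan inputs from x before delegating to selective_scan. The sketch below follows the reference implementations linked above; the attribute names (A_log, D, x_proj, dt_proj, dt_rank) are assumptions and may not match fedem's exactly.

    import torch
    import torch.nn.functional as F

    def ssm_sketch(x, A_log, D, x_proj, dt_proj, dt_rank, selective_scan):
        """Sketch of the SSM step; names and wiring are assumptions, not fedem's exact API."""
        d_in, n = A_log.shape

        A = -torch.exp(A_log.float())                        # (d_in, n); A kept in log space for stability
        x_dbl = x_proj(x)                                    # (b, l, dt_rank + 2*n)
        delta, B, C = x_dbl.split([dt_rank, n, n], dim=-1)   # input-dependent delta, B, C
        delta = F.softplus(dt_proj(delta))                   # (b, l, d_in); positive step sizes

        return selective_scan(x, delta, A, B, C, D.float())  # (b, l, d_in)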

class fedem.models.mamba.MambaForCausalLM(config)[source]

Bases: MambaPreTrainedModel

Mamba model for Causal Language Modeling.

Parameters:

config (MambaConfig) – Mamba model configuration.

forward(input_ids: LongTensor | None = None, labels: LongTensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None) → Tuple | CausalLMOutputWithPast[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
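
A usage sketch under common Hugging Face conventions: passing labels is assumed to yield a language-modeling loss on the returned CausalLMOutputWithPast, with any label shifting handled inside forward; the MambaConfig import location and field names are assumptions.

    import torch
    from fedem.models.mamba import MambaForCausalLM, MambaConfig  # MambaConfig import location assumed

    config = MambaConfig(vocab_size=32000, d_model=256, n_layer=4)  # field names assumed
    model = MambaForCausalLM(config)

    input_ids = torch.randint(0, 32000, (2, 16))                    # (batch, seq_len)
    out = model(input_ids=input_ids, labels=input_ids, return_dict=True)

    loss = out.loss       # scalar LM loss (assumes forward computes it from labels)
    logits = out.logits   # (batch, seq_len, vocab_size)
    loss.backward()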

get_decoder()[source]

get_input_embeddings()[source]

Returns the model’s input embeddings.

Returns:

A torch module mapping vocabulary to hidden states.

Return type:

nn.Module

get_output_embeddings()[source]

Returns the model’s output embeddings.

Returns:

A torch module mapping hidden states to vocabulary.

Return type:

nn.Module

prepare_inputs_for_generation(input_ids, **kwargs)[source]

set_decoder(decoder)[source]

set_input_embeddings(value)[source]

Set model’s input embeddings.

Parameters:

value (nn.Module) – A module mapping vocabulary to hidden states.

set_output_embeddings(new_embeddings)[source]
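
Because the embedding accessors above follow the standard PreTrainedModel contract, they can be used to swap or grow the vocabulary table. A sketch, assuming get_input_embeddings returns an nn.Embedding:

    import torch.nn as nn
    from fedem.models.mamba import MambaForCausalLM

    def grow_vocab(model: MambaForCausalLM, new_vocab_size: int) -> None:
        """Replace the input embedding table, keeping existing rows (illustrative sketch)."""
        old = model.get_input_embeddings()                   # assumed to be an nn.Embedding
        new = nn.Embedding(new_vocab_size, old.embedding_dim)
        keep = min(old.num_embeddings, new_vocab_size)
        new.weight.data[:keep] = old.weight.data[:keep]      # preserve learned rows
        model.set_input_embeddings(new)
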
class fedem.models.mamba.MambaModel(config: MambaConfig)[source]

Bases: MambaPreTrainedModel

Mamba model architecture consisting of MambaBlocks.

Parameters:

config (MambaConfig) – Mamba model configuration.

forward(input_ids: LongTensor | None = None, return_dict: bool | None = None) → Tuple | BaseModelOutputWithPast[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
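
A usage sketch of the backbone: it consumes input_ids and returns the final hidden states via BaseModelOutputWithPast.last_hidden_state (the MambaConfig import location and field names are assumptions):

    import torch
    from fedem.models.mamba import MambaModel, MambaConfig   # MambaConfig import location assumed

    config = MambaConfig(vocab_size=32000, d_model=256, n_layer=4)  # field names assumed
    backbone = MambaModel(config)

    input_ids = torch.randint(0, 32000, (2, 16))                    # (batch, seq_len)
    hidden = backbone(input_ids=input_ids, return_dict=True).last_hidden_state
    print(hidden.shape)                                             # expected: (2, 16, d_model)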

get_input_embeddings()[source]

Returns the model’s input embeddings.

Returns:

A torch module mapping vocabulary to hidden states.

Return type:

nn.Module

set_input_embeddings(value)[source]

Set model’s input embeddings.

Parameters:

value (nn.Module) – A module mapping vocabulary to hidden states.

class fedem.models.mamba.MambaPreTrainedModel(config: PretrainedConfig, *inputs, **kwargs)[source]

Bases: PreTrainedModel

Base class for pre-trained Mamba models.

base_model_prefix = 'model'

config_class

alias of MambaConfig

supports_gradient_checkpointing = True
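
These class attributes plug into the usual PreTrainedModel machinery; a sketch of what they are expected to enable (the checkpoint name below is a placeholder, not a published model):

    from fedem.models.mamba import MambaForCausalLM

    # Placeholder checkpoint name, not a real repository.
    model = MambaForCausalLM.from_pretrained("your-org/your-mamba-checkpoint")

    # config_class aliases MambaConfig, so the loaded config is expected to be a MambaConfig.
    print(type(model.config).__name__)

    # supports_gradient_checkpointing = True suggests the standard toggle applies.
    model.gradient_checkpointing_enable()
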
class fedem.models.mamba.MambaRMSNorm(d_model: int, eps: float = 1e-05)[source]

Bases: Module

Root Mean Square Normalization module for Mamba model.

Parameters:
  • d_model (int) – Model dimension.

  • eps (float, optional) – Epsilon value for numerical stability. Defaults to 1e-5.

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
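
For reference, a minimal sketch of the usual RMSNorm computation (scale x by the reciprocal root mean square over the last dimension, then by a learned per-channel weight); it mirrors common Mamba implementations and may differ in detail from MambaRMSNorm:

    import torch
    import torch.nn as nn

    class RMSNormSketch(nn.Module):
        """Minimal RMSNorm for illustration; not necessarily identical to MambaRMSNorm."""

        def __init__(self, d_model: int, eps: float = 1e-5):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(d_model))

        def forward(self, x):
            # Normalize by the root mean square over the last (model) dimension.
            rms_inv = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return x * rms_inv * self.weight

    y = RMSNormSketch(256)(torch.randn(2, 16, 256))   # shape preserved: (2, 16, 256)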

Module contents