fedem.models package
Submodules
fedem.models.mamba module
- class fedem.models.mamba.MambaBlock(config: MambaConfig)[source]
Bases: Module
Mamba block module as described in the Mamba paper.
- Parameters:
config (MambaConfig) – Mamba model configuration.
- forward(x)[source]
Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
- Parameters:
x – shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n…)
- Returns:
shape (b, l, d)
- Return type:
output
- Official Implementation:
class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119 mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
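A minimal usage sketch of the block's forward pass follows. Only the (b, l, d) -> (b, l, d) shape contract above is taken from this documentation; the MambaConfig import path and its d_model field are assumptions for illustration.

    import torch
    from fedem.models.mamba import MambaBlock
    # MambaConfig's import path and constructor fields are assumptions.
    from fedem.models.mamba import MambaConfig

    config = MambaConfig(d_model=256)      # hypothetical field name
    block = MambaBlock(config)

    x = torch.randn(2, 64, 256)            # (b, l, d)
    y = block(x)                           # shape preserved: (b, l, d)
    assert y.shape == x.shape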
- selective_scan(u, delta, A, B, C, D)[source]
- Runs the selective scan algorithm. See:
Section 2 State Space Models in the Mamba paper [1]
Algorithm 2 in Section 3.2 in the Mamba paper [1]
run_SSM(A, B, C, u) in The Annotated S4 [2]
- This is the classic discrete state space formula:
x(t + 1) = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)
except that B and C (and the step size delta, which is used for discretization) depend on the input x(t); a sequential reference sketch is given at the end of this entry.
- Parameters:
u – shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n…)
delta – shape (b, l, d_in)
A – shape (d_in, n)
B – shape (b, l, n)
C – shape (b, l, n)
D – shape (d_in,)
- Returns:
shape (b, l, d_in)
- Return type:
output
- Official Implementation:
selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86 Note: some parts of selective_scan_ref have been refactored out, so the functionality does not match exactly.
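As a reading aid, here is a minimal sequential sketch of the scan described above, written directly from the documented shapes. It is illustrative only, not the package's implementation, and uses the simplified discretization of B common in minimal reference implementations.

    import torch

    def selective_scan_reference(u, delta, A, B, C, D):
        """Sequential reference for the selective scan, following the documented shapes.

        u:     (b, l, d_in)   input sequence
        delta: (b, l, d_in)   input-dependent step size
        A:     (d_in, n)      state matrix (continuous-time)
        B, C:  (b, l, n)      input-dependent projections
        D:     (d_in,)        skip connection
        Returns: (b, l, d_in)
        """
        b, l, d_in = u.shape
        n = A.shape[1]

        # Discretize: exact zero-order hold for A, simplified Euler step for B.
        deltaA = torch.exp(delta.unsqueeze(-1) * A)                          # (b, l, d_in, n)
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)    # (b, l, d_in, n)

        x = torch.zeros(b, d_in, n, dtype=u.dtype, device=u.device)          # hidden state
        ys = []
        for t in range(l):
            x = deltaA[:, t] * x + deltaB_u[:, t]                            # x(t + 1) = Ax(t) + Bu(t), discretized
            y = (x * C[:, t].unsqueeze(1)).sum(dim=-1)                       # y(t) = Cx(t)
            ys.append(y)
        return torch.stack(ys, dim=1) + u * D                                # add the Du(t) skip term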
- ssm(x)[source]
- Runs the SSM. See:
Algorithm 2 in Section 3.2 in the Mamba paper [1]
run_SSM(A, B, C, u) in The Annotated S4 [2]
- Parameters:
x – shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n…)
- Returns:
shape (b, l, d_in)
- Return type:
output
- Official Implementation:
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
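For orientation, a sketch of how Algorithm 2 derives the input-dependent delta, B, and C from x before scanning, reusing selective_scan_reference from the sketch above. The projection names (x_proj, dt_proj, A_log) are hypothetical and not necessarily the block's actual attributes.

    import torch
    import torch.nn.functional as F

    def ssm_sketch(x, A_log, D, x_proj, dt_proj):
        """Illustrative only: project x to (delta, B, C), then run the reference scan.

        x:       (b, l, d_in)
        A_log:   (d_in, n)   log-parameterized state matrix, A = -exp(A_log)
        D:       (d_in,)
        x_proj:  nn.Linear(d_in, dt_rank + 2 * n)   # hypothetical projection
        dt_proj: nn.Linear(dt_rank, d_in)           # hypothetical projection
        """
        n = A_log.shape[1]
        A = -torch.exp(A_log)                         # negative A keeps the recurrence stable

        dt_rank = x_proj.out_features - 2 * n
        delta_bc = x_proj(x)                          # (b, l, dt_rank + 2n)
        delta, B, C = torch.split(delta_bc, [dt_rank, n, n], dim=-1)
        delta = F.softplus(dt_proj(delta))            # (b, l, d_in), positive step sizes

        return selective_scan_reference(x, delta, A, B, C, D)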
- class fedem.models.mamba.MambaForCausalLM(config)[source]
Bases: MambaPreTrainedModel
Mamba model for Causal Language Modeling.
- Parameters:
config (MambaConfig) – Mamba model configuration.
- forward(input_ids: LongTensor | None = None, labels: LongTensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None) Tuple | CausalLMOutputWithPast [source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
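A training-step sketch follows, assuming the usual Transformers convention that passing labels populates outputs.loss. The MambaConfig import path and field names (vocab_size, d_model, n_layer) are assumptions for illustration.

    import torch
    from fedem.models.mamba import MambaForCausalLM
    from fedem.models.mamba import MambaConfig   # import path is an assumption

    config = MambaConfig(vocab_size=32000, d_model=256, n_layer=4)  # hypothetical fields
    model = MambaForCausalLM(config)

    input_ids = torch.randint(0, 32000, (2, 128))            # (batch, seq_len)
    outputs = model(input_ids=input_ids, labels=input_ids, return_dict=True)

    outputs.loss.backward()        # next-token loss (assumed from the labels argument)
    logits = outputs.logits        # (batch, seq_len, vocab_size)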
- get_input_embeddings()[source]
Returns the model’s input embeddings.
- Returns:
A torch module mapping vocabulary to hidden states.
- Return type:
nn.Module
- get_output_embeddings()[source]
Returns the model’s output embeddings.
- Returns:
A torch module mapping hidden states to vocabulary.
- Return type:
nn.Module
- class fedem.models.mamba.MambaModel(config: MambaConfig)[source]
Bases: MambaPreTrainedModel
Mamba model architecture consisting of MambaBlocks.
- Parameters:
config (MambaConfig) – Mamba model configuration.
- forward(input_ids: LongTensor | None = None, return_dict: bool | None = None) Tuple | BaseModelOutputWithPast [source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
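A sketch of calling the backbone directly, assuming a config constructed as in the MambaForCausalLM example above; with return_dict=True the documented BaseModelOutputWithPast exposes last_hidden_state.

    import torch
    from fedem.models.mamba import MambaModel

    backbone = MambaModel(config)                            # config as in the sketch above
    out = backbone(input_ids=torch.randint(0, 32000, (2, 128)), return_dict=True)
    hidden = out.last_hidden_state                           # (batch, seq_len, d_model)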
- class fedem.models.mamba.MambaPreTrainedModel(config: PretrainedConfig, *inputs, **kwargs)[source]
Bases: PreTrainedModel
Base class for pre-trained Mamba models.
- base_model_prefix = 'model'
- config_class: alias of MambaConfig
- supports_gradient_checkpointing = True
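Because the base class builds on Transformers' PreTrainedModel with config_class = MambaConfig and supports_gradient_checkpointing = True, the standard save/load and checkpointing calls should apply; the local path below is hypothetical and `model` is the instance from the MambaForCausalLM sketch above.

    # Save and reload with the standard PreTrainedModel machinery;
    # config_class lets from_pretrained rebuild the MambaConfig.
    model.save_pretrained("mamba-checkpoint")                        # hypothetical local path
    reloaded = MambaForCausalLM.from_pretrained("mamba-checkpoint")

    # Advertised by supports_gradient_checkpointing = True.
    reloaded.gradient_checkpointing_enable()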
- class fedem.models.mamba.MambaRMSNorm(d_model: int, eps: float = 1e-05)[source]
Bases: Module
Root Mean Square Normalization module for Mamba model.
- Parameters:
d_model (int) – Model dimension.
eps (float, optional) – Epsilon value for numerical stability. Defaults to 1e-5.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
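For reference, the standard RMSNorm computation this module is named after can be sketched as follows; this is the usual formulation, not necessarily the exact implementation.

    import torch

    def rms_norm(x, weight, eps=1e-5):
        # Rescale by the root mean square over the last (model) dimension.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x * rms * weight                  # weight: (d_model,), learned scale

    x = torch.randn(2, 64, 256)                  # (b, l, d_model)
    y = rms_norm(x, torch.ones(256))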