fedem.trainer package
Module contents
- class fedem.trainer.MambaTrainer(model: PreTrainedModel | Module | None = None, args: TrainingArguments | None = None, data_collator: DataCollator | None = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | Dict[str, Dataset] | None = None, tokenizer: PreTrainedTokenizerBase | None = None, model_init: Callable[[], PreTrainedModel] | None = None, compute_metrics: Callable[[EvalPrediction], Dict] | None = None, callbacks: List[TrainerCallback] | None = None, optimizers: Tuple[Optimizer, LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[Tensor, Tensor], Tensor] | None = None)[source]
Bases:
Trainer
Trainer subclass for training Mamba models.
Inherits from transformers.Trainer.
- Parameters:
All constructor arguments mirror those of transformers.Trainer and are forwarded to it; see the signature above and the transformers.Trainer documentation for details.
- compute_loss(model, inputs, return_outputs=False)[source]
Computes the loss for Mamba model training.
- Parameters:
model – Mamba model.
inputs – Model inputs.
return_outputs (bool, optional) – Whether to return model outputs. Defaults to False.
- Returns:
Computed language modeling loss.
- Return type:
torch.Tensor
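The body of compute_loss is not shown here. A common way to compute a language modeling loss for Mamba-style models is to shift the logits and labels by one position and apply cross-entropy, reusing input_ids as labels. The helper below is an illustrative sketch of that computation on raw tensors; the function name and stand-alone form are assumptions, not fedem's actual code.

```python
import torch
import torch.nn.functional as F


def shifted_lm_loss(lm_logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Causal LM loss: at each position t, predict token t + 1.

    lm_logits: (batch, seq_len, vocab_size) model outputs.
    input_ids: (batch, seq_len) token ids, reused as labels.
    """
    # Drop the last logit (nothing follows the final token) and the
    # first label (no context precedes the first token).
    shift_logits = lm_logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )


# Inside a Trainer subclass, compute_loss would typically run the model and
# return this scalar (hypothetical sketch, not fedem's actual implementation):
#     lm_logits = model(inputs["input_ids"]).logits
#     return shifted_lm_loss(lm_logits, inputs["input_ids"])
```

Returning a scalar tensor keeps the override compatible with transformers.Trainer, which backpropagates through whatever compute_loss returns when return_outputs is False.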