fedem.server package
Module contents
- class fedem.server.MambaTrainer(model: PreTrainedModel | Module | None = None, args: TrainingArguments | None = None, data_collator: DataCollator | None = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | Dict[str, Dataset] | None = None, tokenizer: PreTrainedTokenizerBase | None = None, model_init: Callable[[], PreTrainedModel] | None = None, compute_metrics: Callable[[EvalPrediction], Dict] | None = None, callbacks: List[TrainerCallback] | None = None, optimizers: Tuple[Optimizer, LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[Tensor, Tensor], Tensor] | None = None)
Bases: Trainer
- class fedem.server.Seshu(adapters: dict | str, config_file: dict | str, hf_token: str | None = None, org_id: str = 'mlsquare', train_args=False)
Bases: object
- fedem.server.get_checkpoint_model(model_name)
Look up the ID of a checkpoint model by name within an organization.
- Parameters:
model_name (str) – Name of the model.
- Returns:
Model ID if found, False otherwise.
- Return type:
str | False
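The `str | False` return type means callers need to check for the `False` sentinel before using the result as a model ID. The following is a minimal sketch of that pattern, not the library's implementation: `fake_get_checkpoint_model`, `resolve_checkpoint`, and the model names are hypothetical stand-ins that only mimic the documented contract, so the example runs without any network access.

```python
from typing import Dict, Union

# Hypothetical stand-in for fedem.server.get_checkpoint_model: it mimics the
# documented contract (model ID string if found, False otherwise) without
# contacting any remote organization. The entries here are made up.
_KNOWN_MODELS: Dict[str, str] = {"demo-model": "example-org/demo-model"}


def fake_get_checkpoint_model(model_name: str) -> Union[str, bool]:
    """Return the model ID for model_name, or False if it is unknown."""
    return _KNOWN_MODELS.get(model_name, False)


def resolve_checkpoint(model_name: str, fallback: str) -> str:
    """Use the checkpoint ID if one exists, otherwise a fallback model ID."""
    model_id = fake_get_checkpoint_model(model_name)
    # Compare against False explicitly so the sentinel is not confused
    # with other falsy values.
    if model_id is False:
        return fallback
    return model_id


print(resolve_checkpoint("demo-model", "example-org/base"))
print(resolve_checkpoint("missing-model", "example-org/base"))
```

With the real function, the same `is False` check would apply to the value returned by `get_checkpoint_model(model_name)`.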
- fedem.server.get_data(data_path, fraction=0.01)
Load a fraction of the dataset from the specified path and return it.
- Parameters:
data_path (str) – Path to the dataset.
fraction (float, optional) – Fraction of the dataset to load. Defaults to 0.01.
- Returns:
Loaded dataset.
- Return type:
datasets.Dataset
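To illustrate what the `fraction` argument implies for the size of the result, here is a rough sketch in plain Python. It rests on assumptions the reference above does not confirm: that the subset is the leading rows rather than a random sample, and that at least one row is kept. `take_fraction` is a hypothetical helper, not part of `fedem.server`.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def take_fraction(rows: Sequence[T], fraction: float = 0.01) -> List[T]:
    """Return the leading `fraction` of rows, keeping at least one row."""
    if not 0 < fraction <= 1:
        raise ValueError("fraction must be in the interval (0, 1]")
    # Assumption: the subset is the first n rows, not a random sample.
    n = max(1, math.floor(len(rows) * fraction)) if rows else 0
    return list(rows[:n])


rows = list(range(1000))
subset = take_fraction(rows)  # default fraction of 0.01
print(len(subset))
```

For an actual `datasets.Dataset`, the analogous slicing step could be written as `ds.select(range(n))`. Whether `get_data` does this internally is not stated here.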