fedem.server package

Module contents

class fedem.server.MambaTrainer(model: PreTrainedModel | Module | None = None, args: TrainingArguments | None = None, data_collator: DataCollator | None = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | Dict[str, Dataset] | None = None, tokenizer: PreTrainedTokenizerBase | None = None, model_init: Callable[[], PreTrainedModel] | None = None, compute_metrics: Callable[[EvalPrediction], Dict] | None = None, callbacks: List[TrainerCallback] | None = None, optimizers: Tuple[Optimizer, LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[Tensor, Tensor], Tensor] | None = None)[source]

Bases: Trainer

compute_loss(model, inputs, return_outputs=False)[source]

Compute the language modeling loss for a training step.

Parameters:
  • model – The PyTorch model.

  • inputs – Input data.

  • return_outputs (bool) – Whether to return the model outputs along with the loss. Defaults to False.

Returns:

Computed language modeling loss.

Return type:

torch.Tensor
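
The trainer is constructed like a standard Hugging Face Trainer; the overridden compute_loss() supplies the language modeling loss during training. A minimal construction sketch follows, assuming a causal LM checkpoint and a tokenized dataset whose batches carry input_ids; "gpt2", the toy dataset, and the training arguments are placeholders, not fedem defaults.

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

    from fedem.server import MambaTrainer

    # Stand-in model and tokenizer; any causal LM whose forward pass matches
    # what compute_loss() expects can be substituted here.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token; reuse EOS
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Tiny illustrative dataset with pre-tokenized input_ids.
    train_dataset = Dataset.from_dict(
        {"input_ids": [tokenizer("hello world")["input_ids"]] * 8}
    )

    args = TrainingArguments(
        output_dir="mamba-pretrain",
        per_device_train_batch_size=4,
        num_train_epochs=1,
        logging_steps=1,
    )

    trainer = MambaTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()  # compute_loss() is called for each batch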

class fedem.server.Seshu(adapters: dict | str, config_file: dict | str, hf_token: str | None = None, org_id: str = 'mlsquare', train_args=False)[source]

Bases: object

model_merge_eval(model_path, type_config='small', data='mlsquare/SERVER_samantar_mixed_val')[source]
pretrain(cpt_hours: int | None = None, debug: bool = False)[source]
tokenize(data)[source]
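
A usage sketch, assuming the adapter list and model configuration are supplied as JSON files; the file names, the HF_TOKEN environment variable, and the one-hour pretraining budget are illustrative, while org_id='mlsquare' is the default from the signature above.

    import os

    from fedem.server import Seshu

    seshu = Seshu(
        adapters="adapters.json",          # hypothetical adapter listing
        config_file="model_config.json",   # hypothetical MambaConfig parameters
        hf_token=os.environ.get("HF_TOKEN"),
        org_id="mlsquare",
    )

    seshu.pretrain(cpt_hours=1, debug=True)  # short, debug-mode pretraining run
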
fedem.server.compute_loss(model, inputs, return_outputs=False)[source]
fedem.server.create_JSON(value)[source]
fedem.server.evaluation(data, model, tokenizer, batch_size=32, max_length=1024)[source]
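
A hedged call sketch for evaluation(); "gpt2" is a stand-in checkpoint, and the validation split is reused from model_merge_eval()'s default above.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from fedem.server import evaluation, get_data

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    val_data = get_data("mlsquare/SERVER_samantar_mixed_val", fraction=0.01)
    results = evaluation(val_data, model, tokenizer, batch_size=8, max_length=256)
    print(results)
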
fedem.server.get_checkpoint_model(model_name)[source]

Get a checkpoint model by model name from an organization.

Parameters:

model_name (str) – Name of the model.

Returns:

Model ID if found, False otherwise.

Return type:

str | bool
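
Because the return value is either a model ID or False, callers typically branch on it; a sketch with a hypothetical model name:

    from fedem.server import get_checkpoint_model

    model_id = get_checkpoint_model("seshu-base")  # hypothetical model name
    if model_id:
        print(f"Resuming from checkpoint: {model_id}")
    else:
        print("No checkpoint found; starting from scratch")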

fedem.server.get_data(data_path, fraction=0.01)[source]

Load a fraction of the dataset from the specified path and return it.

Parameters:
  • data_path (str) – Path to the dataset.

  • fraction (float, optional) – Fraction of the dataset to load. Defaults to 0.01.

Returns:

Loaded dataset.

Return type:

datasets.Dataset
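
A short sketch of the fraction parameter; the dataset path below is a placeholder.

    from fedem.server import get_data

    sample = get_data("my-org/my-corpus", fraction=0.05)  # placeholder dataset id
    print(len(sample), sample.column_names)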

fedem.server.load_data(data_path)[source]
fedem.server.load_json(json_path)[source]
fedem.server.make_config(json)[source]

Create a MambaConfig object based on the provided JSON data.

Parameters:

json (dict) – JSON data containing configuration parameters.

Returns:

Created MambaConfig object.

Return type:

MambaConfig
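
A sketch of building a MambaConfig from a plain dict; the hyperparameter names and values below are illustrative and may not match the exact keys fedem expects.

    from fedem.server import make_config

    config_json = {
        "d_model": 256,       # illustrative hyperparameters
        "n_layer": 4,
        "vocab_size": 32000,
    }
    config = make_config(config_json)
    print(config)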

fedem.server.model_merge(adapters, model_path, data, tokenizer)[source]
fedem.server.print_trainable_parameters(model)[source]

Prints the number of trainable parameters in the model.
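
The report is typically the familiar trainable/total parameter count; the following is an illustrative re-implementation, not the packaged source.

    def print_trainable_parameters_sketch(model):
        # Count parameters that receive gradients versus all parameters.
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(
            f"trainable params: {trainable} || all params: {total} "
            f"|| trainable%: {100 * trainable / total:.2f}"
        )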