import json
import os
import shutil
from datetime import datetime, timezone

from huggingface_hub import CommitInfo, RepoUrl
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)

from ..models.mamba import MambaForCausalLM
from ..trainer import MambaTrainer
from ..utils import load_data, load_model_with_LoRA
from ..utils.huggingface import get_client_details, verify_user_with_org
[docs]
class Seshu:
def __init__(
self,
hf_model_path: str,
hf_tokenizer_path: str,
target_modules: list[str],
hf_adapter_path: str,
hf_data_path: str,
org_id: str = "mlsquare",
hf_token: str | None = None,
):
"""
Initialize a Seshu object.
Args:
hf_model_path (str): Path to the HF model.
hf_tokenizer_path (str): Path to the HF tokenizer.
target_modules (list[str]): List of target modules.
hf_adapter_path (str): Path to the HF adapter.
hf_data_path (str): Path to the HF data.
org_id (str, optional): Organization ID. Defaults to "mlsquare".
hf_token (str | None, optional): HF token. Defaults to None.
"""
self.hf_model_path = hf_model_path
self.hf_tokenizer_path = hf_tokenizer_path
self.target_modules = target_modules
self.hf_adapter_path = hf_adapter_path
self.hf_data_path = hf_data_path
self.hf_token = hf_token
self.api, self.client_details = get_client_details(hf_token=self.hf_token)
self.username: str = self.client_details['name']
self.fullname: str = self.client_details['fullname']
self.org_id = org_id
self.org_details = verify_user_with_org(self.client_details, self.org_id)
print(
f"{self.fullname} is part of the organization {self.org_id} as a contributor."
)
try:
# check if the status json exists in the root adapter repo
os.makedirs(self.org_id, exist_ok=True)
response = self.api.hf_hub_download(
repo_id=self.hf_adapter_path,
filename="status.json",
local_dir=self.org_id,
)
raise Exception(
"""The adapter is being used by another user.
Please use a different adapter or wait for couple of hours for it to be available."""
)
except:
print(
f"The adapter {self.hf_adapter_path} is available for use by {self.username}."
)
# create a new model repo if the model does not exist on the Hugging Face Hub
self.repo_name: str = (
f"{self.org_id}/{self.hf_adapter_path.split('/')[-1]}_{self.username}"
)
try:
response = self.api.model_info(self.repo_name)
print(f"Model repo {self.repo_name} already exists on HF.")
response = self.api.delete_repo(repo_id=self.repo_name)
print(f"Model repo {self.repo_name} deleted successfully.")
except:
print(f"Model repo {self.repo_name} does not exist on HF.")
response: RepoUrl = self.api.create_repo(self.repo_name, repo_type="model")
print(f"New model repo created at {response.url}")
# create a local path to store the model
self.local_path: str = os.path.join(os.getcwd(), self.repo_name)
if os.path.exists(self.local_path):
shutil.rmtree(self.local_path, ignore_errors=True)
print(
f"Local path already exists. Deleting the existing path. - {self.local_path}"
)
os.makedirs(self.local_path)
print(f"Local path created - {self.local_path}")
# create a status json file
status = {
"username": self.username,
"fullname": self.fullname,
"org_id": self.org_id,
"time": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"),
"hf_model_path": self.hf_model_path,
"hf_tokenizer_path": self.hf_tokenizer_path,
"target_modules": self.target_modules,
"hf_adapter_path": self.hf_adapter_path,
"hf_data_path": self.hf_data_path,
}
os.makedirs(os.path.join(self.local_path, "local_copy"), exist_ok=True)
json_path = os.path.join(self.local_path, "local_copy", "status.json")
with open(json_path, "w") as f:
json.dump(status, f, indent=4)
print(f"Status json created at {json_path}")
self.push_to_hub()
self.tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_path)
[docs]
def tokenize(self, data_to_tokenize):
"""
Tokenize the input data.
Args:
data_to_tokenize: Data to be tokenized.
Returns:
dict: Tokenized input data.
"""
outputs = self.tokenizer(
data_to_tokenize["tgt"],
truncation=True,
max_length=1024,
return_overflowing_tokens=True,
return_length=True,
)
if "length" not in outputs and "input_ids" not in outputs:
raise ValueError(
"The tokenizer did not return the expected outputs. Please check the inputs."
)
input_batch = []
for length, input_ids in zip(outputs["length"], outputs["input_ids"]): # type: ignore
if length != 0:
input_batch.append(input_ids)
return {"input_ids": input_batch}
[docs]
    def train_lora(
        self,
        training_args: TrainingArguments | None = None,
        debug: bool = False,
    ):
        """
        Train the model using LoRA.

        Loads the base Mamba model and attaches the published adapter when it
        is loadable; otherwise wraps the base model in a fresh LoRA adapter.
        Tokenizes the dataset with :meth:`tokenize`, trains (unless ``debug``),
        and saves the result into the local ``local_copy`` directory.

        Args:
            training_args (TrainingArguments | None, optional): Training arguments.
                Defaults to None, in which case single-epoch defaults are used.
            debug (bool, optional): Debug mode flag; when True the training
                step is skipped (the model is still saved). Defaults to False.
        """
        if training_args is None:
            # Fallback single-epoch configuration rooted at the local path.
            self.training_args = TrainingArguments(
                output_dir=os.path.join(self.local_path, "outputs"),
                num_train_epochs=1,
                per_device_train_batch_size=8,
                per_device_eval_batch_size=8,
                warmup_steps=500,
                weight_decay=0.01,
                logging_dir=os.path.join(self.local_path, "logs"),
            )
        else:
            self.training_args = training_args
        try:
            # Preferred path: resume from the adapter published on the Hub.
            model = MambaForCausalLM.from_pretrained(self.hf_model_path)
            model.enable_input_require_grads() # type: ignore
            model.load_adapter(self.hf_adapter_path) # type: ignore
        except Exception:
            # Adapter missing or incompatible: start a fresh LoRA adapter
            # over the configured target modules.
            model = MambaForCausalLM.from_pretrained(self.hf_model_path)
            model = load_model_with_LoRA(model, self.target_modules, self.local_path)
            print("Creating new adapter as the previous one is not valid.")
        # Causal-LM setup: pad with EOS and disable masked-LM in the collator.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        data_to_tokenize = load_data(self.hf_data_path)
        # assumes the dataset has "train" and "valid" splits — TODO confirm
        # against load_data; original columns are dropped after tokenization.
        tokenized_data = data_to_tokenize.map(
            self.tokenize,
            batched=True,
            remove_columns=data_to_tokenize["train"].column_names,
        )
        trainer = MambaTrainer(
            model=model, # type: ignore
            tokenizer=self.tokenizer,
            args=self.training_args,
            data_collator=data_collator,
            train_dataset=tokenized_data["train"], # type: ignore
            eval_dataset=tokenized_data["valid"], # type: ignore
        )
        if not debug:
            trainer.train()
        else:
            # NOTE(review): debug=True SKIPS training; the message wording
            # ("will be called in non debug mode") is misleading but is
            # runtime output and left unchanged here.
            print("trainer.train() will be called in non debug mode")
        trainer.save_model(os.path.join(self.local_path, "local_copy"))
[docs]
def push_to_hub(self):
"""
Push files to the Hugging Face Hub.
"""
response: CommitInfo = self.api.upload_folder(
folder_path=os.path.join(self.local_path, "local_copy"),
repo_id=self.repo_name,
repo_type="model",
)
print(f"File(s) uploaded to {response.commit_url} successfully.")