#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Combined QLoRA training script for: 1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer 2) DPO (Direct Preference Optimization) with TRL DPOTrainer Usage examples (multi-GPU): torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py sft torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both You can override any path/knob via CLI flags (see --help). """ import os import time import json import inspect import argparse from dataclasses import dataclass, asdict from typing import List, Optional import torch from datasets import load_from_disk from transformers import ( AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, Trainer as HFTrainer, TrainerCallback, ) from peft import ( PeftModel, LoraConfig, get_peft_model, prepare_model_for_kbit_training, ) from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig # ========================== # Compatibility patches # ========================== def patch_trl_dpo_trainer_if_needed(): """ Fix TRL<->Transformers API mismatches. 1) get_batch_samples signature mismatch: Transformers Trainer calls get_batch_samples(epoch_iterator, num_batches) Older TRL DPOTrainer defines get_batch_samples(model, batch) => 'generator' object has no attribute 'generate' Patch: delegate to HFTrainer.get_batch_samples. 2) compute_loss signature mismatch: Newer Transformers passes num_items_in_batch kwarg. Patch wrapper to accept it if missing. """ # ---- Patch get_batch_samples ---- try: sig = inspect.signature(DPOTrainer.get_batch_samples) params = list(sig.parameters.keys()) # includes "self" # Old TRL: (self, model, batch, ...) -> params[1] == "model" if len(params) >= 3 and params[1] == "model": print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True) def _get_batch_samples(self, epoch_iterator, num_batches): return HFTrainer.get_batch_samples(self, epoch_iterator, num_batches) DPOTrainer.get_batch_samples = _get_batch_samples else: print("[INFO] DPOTrainer.get_batch_samples signature looks compatible; no patch needed", flush=True) except Exception as e: print(f"[WARN] Could not inspect/patch get_batch_samples: {e}", flush=True) # ---- Patch compute_loss ---- try: sig = inspect.signature(DPOTrainer.compute_loss) if "num_items_in_batch" not in sig.parameters: print("[PATCH] DPOTrainer.compute_loss lacks num_items_in_batch; patching wrapper", flush=True) _orig_compute_loss = DPOTrainer.compute_loss def _compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): return _orig_compute_loss(self, model, inputs, return_outputs=return_outputs) DPOTrainer.compute_loss = _compute_loss else: print("[INFO] DPOTrainer.compute_loss signature looks compatible; no patch needed", flush=True) except Exception as e: print(f"[WARN] Could not inspect/patch compute_loss: {e}", flush=True) # ========================== # Callbacks # ========================== class ProgressETA(TrainerCallback): def __init__(self): self.t0 = None self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) def on_train_begin(self, args, state, control, **kwargs): self.t0 = time.time() if self.local_rank == 0: print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True) def on_log(self, args, state, control, logs=None, **kwargs): if self.local_rank != 0: return if not logs or not state.max_steps: return step = state.global_step max_steps = state.max_steps elapsed = time.time() - (self.t0 or time.time()) sps = step / max(elapsed, 1e-9) rem = max_steps - step eta = rem / max(sps, 1e-9) msg = f"[PROGRESS] step {step}/{max_steps} | {sps:.3f} steps/s | ETA {eta/60:.1f} min" if "loss" in logs: msg += f" | loss {logs['loss']:.4f}" if "learning_rate" in logs: msg += f" | lr {logs['learning_rate']:.3e}" print(msg, flush=True) # ========================== # Config dataclasses # ========================== @dataclass class CommonCfg: base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b" target_modules: List[str] = None # set in __post_init__ # QLoRA quant config load_in_4bit: bool = True bnb_4bit_compute_dtype: str = "float16" bnb_4bit_use_double_quant: bool = True bnb_4bit_quant_type: str = "nf4" # Tokenizer use_fast_tokenizer: bool = False trust_remote_code: bool = True def __post_init__(self): if self.target_modules is None: self.target_modules = ["q_proj", "v_proj"] @dataclass class SFTCfg: data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text" out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora" max_seq_len: int = 1024 num_epochs: int = 1 per_device_bs: int = 1 grad_accum: int = 16 lr: float = 2e-4 warmup_ratio: float = 0.03 logging_steps: int = 10 save_steps: int = 200 save_total_limit: int = 2 # LoRA lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 @dataclass class DPOCfg: # SFT adapter input dir (from SFT run) sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora" dpo_data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_dpo" # must have prompt/chosen/rejected out_dir: str = "/home/hyrenko/Diploma/outputs_dpo/mistral-sk-7b-pku-saferlhf-sk-dpo-qlora" max_length: int = 1024 max_prompt_length: int = 512 num_epochs: int = 1 per_device_bs: int = 1 grad_accum: int = 16 lr: float = 1e-5 warmup_ratio: float = 0.03 logging_steps: int = 10 save_steps: int = 200 save_total_limit: int = 2 beta: float = 0.1 # NEW LoRA adapter for DPO (separate from SFT) lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 # ========================== # Helpers # ========================== def get_local_rank_and_device(): local_rank = int(os.environ.get("LOCAL_RANK", "0")) if torch.cuda.is_available(): torch.cuda.set_device(local_rank) device = torch.device("cuda", local_rank) else: device = torch.device("cpu") print(f"[INFO] local_rank={local_rank} device={device}", flush=True) return local_rank, device def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig: dtype = getattr(torch, common.bnb_4bit_compute_dtype) return BitsAndBytesConfig( load_in_4bit=common.load_in_4bit, bnb_4bit_compute_dtype=dtype, bnb_4bit_use_double_quant=common.bnb_4bit_use_double_quant, bnb_4bit_quant_type=common.bnb_4bit_quant_type, ) def load_tokenizer(common: CommonCfg): tokenizer = AutoTokenizer.from_pretrained( common.base_model_path, use_fast=common.use_fast_tokenizer, trust_remote_code=common.trust_remote_code, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device): model = AutoModelForCausalLM.from_pretrained( common.base_model_path, quantization_config=bnb_config, device_map={"": torch.cuda.current_device()} if device.type == "cuda" else None, torch_dtype=torch.float16, trust_remote_code=common.trust_remote_code, ) model.config.use_cache = False # Avoid DDP + checkpointing reentrant issues if hasattr(model, "gradient_checkpointing_disable"): model.gradient_checkpointing_disable() model = prepare_model_for_kbit_training(model) return model def save_manifest(out_dir: str, local_rank: int, manifest: dict): if local_rank != 0: return os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, "run_manifest.json"), "w", encoding="utf-8") as f: json.dump(manifest, f, ensure_ascii=False, indent=2) # ========================== # SFT # ========================== def run_sft(common: CommonCfg, sft: SFTCfg): local_rank, device = get_local_rank_and_device() os.makedirs(sft.out_dir, exist_ok=True) ds = load_from_disk(sft.data_path) if "text" not in ds.column_names: raise RuntimeError(f"[SFT] Dataset must have 'text' column. Found: {ds.column_names}") print(f"[INFO] [SFT] Dataset rows: {len(ds)}", flush=True) tokenizer = load_tokenizer(common) bnb_config = make_bnb_config(common) model = load_4bit_model(common, bnb_config, device) lora = LoraConfig( r=sft.lora_r, lora_alpha=sft.lora_alpha, lora_dropout=sft.lora_dropout, bias="none", task_type="CAUSAL_LM", target_modules=common.target_modules, ) model = get_peft_model(model, lora) if local_rank == 0: model.print_trainable_parameters() cfg = SFTConfig( output_dir=sft.out_dir, dataset_text_field="text", max_seq_length=sft.max_seq_len, packing=True, num_train_epochs=sft.num_epochs, per_device_train_batch_size=sft.per_device_bs, gradient_accumulation_steps=sft.grad_accum, learning_rate=sft.lr, lr_scheduler_type="cosine", warmup_ratio=sft.warmup_ratio, weight_decay=0.0, logging_steps=sft.logging_steps, save_steps=sft.save_steps, save_total_limit=sft.save_total_limit, fp16=True, bf16=False, optim="paged_adamw_8bit", report_to="none", remove_unused_columns=False, ddp_find_unused_parameters=False, gradient_checkpointing=False, ) trainer = SFTTrainer( model=model, tokenizer=tokenizer, train_dataset=ds, args=cfg, ) trainer.add_callback(ProgressETA()) save_manifest( sft.out_dir, local_rank, { "stage": "sft", "common": asdict(common), "sft": asdict(sft), "quant": "4bit nf4 qlora", "gradient_checkpointing": False, }, ) trainer.train() if local_rank == 0: trainer.save_model(sft.out_dir) # saves PEFT adapter tokenizer.save_pretrained(sft.out_dir) print(f"[OK] [SFT] Finished. Adapter saved to: {sft.out_dir}", flush=True) print("[OK] [SFT] Base model was NOT overwritten; only LoRA adapter weights were trained/saved.", flush=True) # ========================== # DPO # ========================== def run_dpo(common: CommonCfg, dpo: DPOCfg): # Patches first (safe no-op if not needed) patch_trl_dpo_trainer_if_needed() local_rank, device = get_local_rank_and_device() os.makedirs(dpo.out_dir, exist_ok=True) ds = load_from_disk(dpo.dpo_data_path) for c in ("prompt", "chosen", "rejected"): if c not in ds.column_names: raise RuntimeError(f"[DPO] Missing '{c}' in DPO dataset. Found: {ds.column_names}") print(f"[INFO] [DPO] Dataset rows: {len(ds)}", flush=True) tokenizer = load_tokenizer(common) bnb_config = make_bnb_config(common) model = load_4bit_model(common, bnb_config, device) # Load SFT adapter (frozen) + add NEW DPO adapter (trainable) model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False) dpo_lora = LoraConfig( r=dpo.lora_r, lora_alpha=dpo.lora_alpha, lora_dropout=dpo.lora_dropout, bias="none", task_type="CAUSAL_LM", target_modules=common.target_modules, ) model = get_peft_model(model, dpo_lora) if local_rank == 0: model.print_trainable_parameters() cfg = DPOConfig( output_dir=dpo.out_dir, max_length=dpo.max_length, max_prompt_length=dpo.max_prompt_length, num_train_epochs=dpo.num_epochs, per_device_train_batch_size=dpo.per_device_bs, gradient_accumulation_steps=dpo.grad_accum, learning_rate=dpo.lr, lr_scheduler_type="cosine", warmup_ratio=dpo.warmup_ratio, logging_steps=dpo.logging_steps, save_steps=dpo.save_steps, save_total_limit=dpo.save_total_limit, fp16=True, bf16=False, optim="paged_adamw_8bit", report_to="none", beta=dpo.beta, ddp_find_unused_parameters=False, gradient_checkpointing=False, ) trainer = DPOTrainer( model=model, ref_model=None, args=cfg, train_dataset=ds, tokenizer=tokenizer, ) trainer.add_callback(ProgressETA()) save_manifest( dpo.out_dir, local_rank, { "stage": "dpo", "common": asdict(common), "dpo": asdict(dpo), "notes": "Includes TRL/Transformers compatibility patches for get_batch_samples + compute_loss", }, ) trainer.train() if local_rank == 0: trainer.save_model(dpo.out_dir) tokenizer.save_pretrained(dpo.out_dir) print(f"[OK] [DPO] Finished. Adapter saved to: {dpo.out_dir}", flush=True) print("[OK] [DPO] Base model was NOT overwritten; only DPO LoRA adapter weights were trained/saved.", flush=True) # ========================== # CLI # ========================== def build_parser(): p = argparse.ArgumentParser( description="Combined QLoRA training: SFT and DPO (TRL).", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) sub = p.add_subparsers(dest="cmd", required=False) def add_common_args(sp): sp.add_argument("--base_model_path", default=CommonCfg.base_model_path) sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"]) sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer") sp.add_argument("--no_trust_remote_code", action="store_true", help="Disable trust_remote_code") # ---- SFT ---- sft_p = sub.add_parser("sft", help="Run SFT stage only") add_common_args(sft_p) sft_p.add_argument("--sft_data_path", default=SFTCfg.data_path) sft_p.add_argument("--sft_out_dir", default=SFTCfg.out_dir) sft_p.add_argument("--sft_max_seq_len", type=int, default=SFTCfg.max_seq_len) sft_p.add_argument("--sft_num_epochs", type=int, default=SFTCfg.num_epochs) sft_p.add_argument("--sft_per_device_bs", type=int, default=SFTCfg.per_device_bs) sft_p.add_argument("--sft_grad_accum", type=int, default=SFTCfg.grad_accum) sft_p.add_argument("--sft_lr", type=float, default=SFTCfg.lr) sft_p.add_argument("--sft_warmup_ratio", type=float, default=SFTCfg.warmup_ratio) sft_p.add_argument("--sft_logging_steps", type=int, default=SFTCfg.logging_steps) sft_p.add_argument("--sft_save_steps", type=int, default=SFTCfg.save_steps) sft_p.add_argument("--sft_save_total_limit", type=int, default=SFTCfg.save_total_limit) sft_p.add_argument("--sft_lora_r", type=int, default=SFTCfg.lora_r) sft_p.add_argument("--sft_lora_alpha", type=int, default=SFTCfg.lora_alpha) sft_p.add_argument("--sft_lora_dropout", type=float, default=SFTCfg.lora_dropout) # ---- DPO ---- dpo_p = sub.add_parser("dpo", help="Run DPO stage only") add_common_args(dpo_p) dpo_p.add_argument("--sft_adapter_dir", default=DPOCfg.sft_adapter_dir) dpo_p.add_argument("--dpo_data_path", default=DPOCfg.dpo_data_path) dpo_p.add_argument("--dpo_out_dir", default=DPOCfg.out_dir) dpo_p.add_argument("--dpo_max_length", type=int, default=DPOCfg.max_length) dpo_p.add_argument("--dpo_max_prompt_length", type=int, default=DPOCfg.max_prompt_length) dpo_p.add_argument("--dpo_num_epochs", type=int, default=DPOCfg.num_epochs) dpo_p.add_argument("--dpo_per_device_bs", type=int, default=DPOCfg.per_device_bs) dpo_p.add_argument("--dpo_grad_accum", type=int, default=DPOCfg.grad_accum) dpo_p.add_argument("--dpo_lr", type=float, default=DPOCfg.lr) dpo_p.add_argument("--dpo_warmup_ratio", type=float, default=DPOCfg.warmup_ratio) dpo_p.add_argument("--dpo_logging_steps", type=int, default=DPOCfg.logging_steps) dpo_p.add_argument("--dpo_save_steps", type=int, default=DPOCfg.save_steps) dpo_p.add_argument("--dpo_save_total_limit", type=int, default=DPOCfg.save_total_limit) dpo_p.add_argument("--dpo_beta", type=float, default=DPOCfg.beta) dpo_p.add_argument("--dpo_lora_r", type=int, default=DPOCfg.lora_r) dpo_p.add_argument("--dpo_lora_alpha", type=int, default=DPOCfg.lora_alpha) dpo_p.add_argument("--dpo_lora_dropout", type=float, default=DPOCfg.lora_dropout) # ---- BOTH ---- both_p = sub.add_parser("both", help="Run SFT then DPO (uses SFT out dir as DPO sft_adapter_dir unless overridden)") add_common_args(both_p) # SFT args subset both_p.add_argument("--sft_data_path", default=SFTCfg.data_path) both_p.add_argument("--sft_out_dir", default=SFTCfg.out_dir) both_p.add_argument("--sft_max_seq_len", type=int, default=SFTCfg.max_seq_len) both_p.add_argument("--sft_num_epochs", type=int, default=SFTCfg.num_epochs) both_p.add_argument("--sft_per_device_bs", type=int, default=SFTCfg.per_device_bs) both_p.add_argument("--sft_grad_accum", type=int, default=SFTCfg.grad_accum) both_p.add_argument("--sft_lr", type=float, default=SFTCfg.lr) both_p.add_argument("--sft_warmup_ratio", type=float, default=SFTCfg.warmup_ratio) both_p.add_argument("--sft_logging_steps", type=int, default=SFTCfg.logging_steps) both_p.add_argument("--sft_save_steps", type=int, default=SFTCfg.save_steps) both_p.add_argument("--sft_save_total_limit", type=int, default=SFTCfg.save_total_limit) both_p.add_argument("--sft_lora_r", type=int, default=SFTCfg.lora_r) both_p.add_argument("--sft_lora_alpha", type=int, default=SFTCfg.lora_alpha) both_p.add_argument("--sft_lora_dropout", type=float, default=SFTCfg.lora_dropout) # DPO args subset both_p.add_argument("--dpo_data_path", default=DPOCfg.dpo_data_path) both_p.add_argument("--dpo_out_dir", default=DPOCfg.out_dir) both_p.add_argument("--dpo_max_length", type=int, default=DPOCfg.max_length) both_p.add_argument("--dpo_max_prompt_length", type=int, default=DPOCfg.max_prompt_length) both_p.add_argument("--dpo_num_epochs", type=int, default=DPOCfg.num_epochs) both_p.add_argument("--dpo_per_device_bs", type=int, default=DPOCfg.per_device_bs) both_p.add_argument("--dpo_grad_accum", type=int, default=DPOCfg.grad_accum) both_p.add_argument("--dpo_lr", type=float, default=DPOCfg.lr) both_p.add_argument("--dpo_warmup_ratio", type=float, default=DPOCfg.warmup_ratio) both_p.add_argument("--dpo_logging_steps", type=int, default=DPOCfg.logging_steps) both_p.add_argument("--dpo_save_steps", type=int, default=DPOCfg.save_steps) both_p.add_argument("--dpo_save_total_limit", type=int, default=DPOCfg.save_total_limit) both_p.add_argument("--dpo_beta", type=float, default=DPOCfg.beta) both_p.add_argument("--dpo_lora_r", type=int, default=DPOCfg.lora_r) both_p.add_argument("--dpo_lora_alpha", type=int, default=DPOCfg.lora_alpha) both_p.add_argument("--dpo_lora_dropout", type=float, default=DPOCfg.lora_dropout) both_p.add_argument("--sft_adapter_dir", default="", help="If set, use this as DPO SFT adapter dir; otherwise uses --sft_out_dir") return p def interactive_menu() -> str: """ Simple interactive selector when the script is launched without CLI arguments. Uses defaults for all paths/knobs. If you want to override anything, run with CLI flags. """ print("\n=== Training menu ===") print("1) SFT") print("2) DPO") print("3) SFT + DPO (both)") print("q) Quit\n") while True: choice = input("Select an option [1/2/3/q]: ").strip().lower() if choice in ("1", "sft"): return "sft" if choice in ("2", "dpo"): return "dpo" if choice in ("3", "both", "all"): return "both" if choice in ("q", "quit", "exit"): raise SystemExit(0) print("Invalid choice. Please enter 1, 2, 3, or q.") def main(): parser = build_parser() # If user runs the script with no arguments, we show interactive menu # and run with defaults. import sys if len(sys.argv) == 1: # IMPORTANT for torchrun/DDP: # In distributed mode, we must NOT prompt on every process. # Rank0 asks, then writes the choice to a small file; other ranks read it. world_size = int(os.environ.get("WORLD_SIZE", "1")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) if world_size == 1: cmd = interactive_menu() else: choice_file = os.path.join( os.getcwd(), f".train_choice_{os.environ.get('MASTER_PORT', '0')}_{os.environ.get('RUN_ID', '0')}.json", ) if local_rank == 0: cmd = interactive_menu() try: with open(choice_file, "w", encoding="utf-8") as f: json.dump({"cmd": cmd}, f) except Exception as e: print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True) else: # Wait for rank0 to write the choice deadline = time.time() + 600 # seconds cmd = None while time.time() < deadline: if os.path.exists(choice_file): try: with open(choice_file, "r", encoding="utf-8") as f: cmd = json.load(f).get("cmd") if cmd: break except Exception: pass time.sleep(0.2) if not cmd: print("[ERROR] Timed out waiting for rank0 menu selection. " "Run with explicit subcommand: sft/dpo/both.", flush=True) raise SystemExit(2) args = parser.parse_args([cmd]) # defaults else: args = parser.parse_args() if not getattr(args, "cmd", None): parser.print_help() raise SystemExit(2) common = CommonCfg( base_model_path=args.base_model_path, target_modules=args.target_modules, use_fast_tokenizer=args.use_fast_tokenizer, trust_remote_code=not args.no_trust_remote_code, ) if args.cmd == "sft": sft = SFTCfg( data_path=args.sft_data_path, out_dir=args.sft_out_dir, max_seq_len=args.sft_max_seq_len, num_epochs=args.sft_num_epochs, per_device_bs=args.sft_per_device_bs, grad_accum=args.sft_grad_accum, lr=args.sft_lr, warmup_ratio=args.sft_warmup_ratio, logging_steps=args.sft_logging_steps, save_steps=args.sft_save_steps, save_total_limit=args.sft_save_total_limit, lora_r=args.sft_lora_r, lora_alpha=args.sft_lora_alpha, lora_dropout=args.sft_lora_dropout, ) run_sft(common, sft) elif args.cmd == "dpo": dpo = DPOCfg( sft_adapter_dir=args.sft_adapter_dir, dpo_data_path=args.dpo_data_path, out_dir=args.dpo_out_dir, max_length=args.dpo_max_length, max_prompt_length=args.dpo_max_prompt_length, num_epochs=args.dpo_num_epochs, per_device_bs=args.dpo_per_device_bs, grad_accum=args.dpo_grad_accum, lr=args.dpo_lr, warmup_ratio=args.dpo_warmup_ratio, logging_steps=args.dpo_logging_steps, save_steps=args.dpo_save_steps, save_total_limit=args.dpo_save_total_limit, beta=args.dpo_beta, lora_r=args.dpo_lora_r, lora_alpha=args.dpo_lora_alpha, lora_dropout=args.dpo_lora_dropout, ) run_dpo(common, dpo) elif args.cmd == "both": sft = SFTCfg( data_path=args.sft_data_path, out_dir=args.sft_out_dir, max_seq_len=args.sft_max_seq_len, num_epochs=args.sft_num_epochs, per_device_bs=args.sft_per_device_bs, grad_accum=args.sft_grad_accum, lr=args.sft_lr, warmup_ratio=args.sft_warmup_ratio, logging_steps=args.sft_logging_steps, save_steps=args.sft_save_steps, save_total_limit=args.sft_save_total_limit, lora_r=args.sft_lora_r, lora_alpha=args.sft_lora_alpha, lora_dropout=args.sft_lora_dropout, ) run_sft(common, sft) sft_adapter_dir = args.sft_adapter_dir.strip() or sft.out_dir dpo = DPOCfg( sft_adapter_dir=sft_adapter_dir, dpo_data_path=args.dpo_data_path, out_dir=args.dpo_out_dir, max_length=args.dpo_max_length, max_prompt_length=args.dpo_max_prompt_length, num_epochs=args.dpo_num_epochs, per_device_bs=args.dpo_per_device_bs, grad_accum=args.dpo_grad_accum, lr=args.dpo_lr, warmup_ratio=args.dpo_warmup_ratio, logging_steps=args.dpo_logging_steps, save_steps=args.dpo_save_steps, save_total_limit=args.dpo_save_total_limit, beta=args.dpo_beta, lora_r=args.dpo_lora_r, lora_alpha=args.dpo_lora_alpha, lora_dropout=args.dpo_lora_dropout, ) run_dpo(common, dpo) if __name__ == "__main__": main()