Make exllamav3_hf and exllamav2_hf functional again

commit 9e9ab39892
parent 9c0a833a0a
Author: oobabooga
Date:   2025-09-17 12:29:22 -07:00

2 changed files with 9 additions and 5 deletions

modules/exllamav2_hf.py

@@ -35,7 +35,9 @@ except Exception:
 
 class Exllamav2HF(PreTrainedModel, GenerationMixin):
     def __init__(self, config: ExLlamaV2Config):
-        super().__init__(PretrainedConfig())
+        hf_config = PretrainedConfig.from_pretrained(config.model_dir)
+        super().__init__(hf_config)
+
         self.ex_config = config
         self.loras = None
         self.generation_config = GenerationConfig()
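The fix here swaps the empty PretrainedConfig() for a config loaded from the model directory, so the attributes that transformers' GenerationMixin reads during generation (eos_token_id, vocab_size, and so on) come from the model's config.json instead of library defaults. A minimal sketch of the difference, with "/path/to/model" as a placeholder for any local HF model directory:

    from transformers import PretrainedConfig

    # An empty config leaves generation-related fields at their defaults
    empty = PretrainedConfig()
    print(empty.eos_token_id)  # None

    # from_pretrained() reads config.json from the directory, so the same
    # fields reflect the actual model ("/path/to/model" is a placeholder)
    loaded = PretrainedConfig.from_pretrained("/path/to/model")
    print(loaded.eos_token_id)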

modules/exllamav3_hf.py

@@ -27,11 +27,13 @@ except Exception:
 
 class Exllamav3HF(PreTrainedModel, GenerationMixin):
     def __init__(self, model_dir):
-        super().__init__(PretrainedConfig())
-        self.generation_config = GenerationConfig()
-        config = Config.from_directory(model_dir)
-        self.ex_model = Model.from_config(config)
+        hf_config = PretrainedConfig.from_pretrained(model_dir)
+        super().__init__(hf_config)
+
+        self.generation_config = GenerationConfig()
+        exl3_config = Config.from_directory(model_dir)
+        self.ex_model = Model.from_config(exl3_config)
 
         # Calculate the closest multiple of 256 at or above the chosen value
         max_tokens = shared.args.ctx_size
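Besides loading the HF config the same way as the exllamav2 wrapper, this hunk renames the local config to exl3_config, keeping the ExLlamaV3 Config distinct from the PretrainedConfig loaded just above it. The trailing context also mentions rounding ctx_size up to the closest multiple of 256; a minimal sketch of that arithmetic (the helper name is illustrative, not from the source):

    def round_up_to_256(n: int) -> int:
        # Closest multiple of 256 at or above n
        return ((n + 255) // 256) * 256

    assert round_up_to_256(8192) == 8192  # already a multiple
    assert round_up_to_256(8000) == 8192  # rounded up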